extol_sprite_layer/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::cmp::Reverse;
3use std::fmt::Debug;
4use std::hash::Hash;
5use std::marker::PhantomData;
6
7use bevy::ecs::entity::EntityHashMap; // noticeably faster than std's
8use bevy::prelude::*;
9use ordered_float::OrderedFloat;
10use tap::Tap;
11
/// This plugin adjusts your entities' [`GlobalTransform`]s so that their z-coordinates are sorted
/// in the proper order, where the order is specified by the `Layer` component.
///
/// Layers propagate to children, including 'through' entities with no [`GlobalTransform`].
///
/// If you need to know the z-coordinate, you can read it out of the [`GlobalTransform`] after the
/// [`SpriteLayerSet::SetZCoordinates`] set has run.
///
/// In general you should only instantiate this plugin with a single type you use throughout your
/// program: every instance writes the z-coordinate of the same `GlobalTransform`s, so multiple
/// instances would overwrite each other's results.
///
/// By default your sprites will also be y-sorted. If you don't need this, replace the
/// [`SpriteLayerOptions`] like so:
///
/// ```
/// # use bevy::prelude::*;
/// # use extol_sprite_layer::SpriteLayerOptions;
/// # let mut app = App::new();
/// app.insert_resource(SpriteLayerOptions { y_sort: false });
/// ```
pub struct SpriteLayerPlugin<Layer> {
    phantom: PhantomData<Layer>,
}
36
37impl<Layer> Default for SpriteLayerPlugin<Layer> {
38    fn default() -> Self {
39        Self {
40            phantom: Default::default(),
41        }
42    }
43}
44
45impl<Layer: LayerIndex> Plugin for SpriteLayerPlugin<Layer> {
46    fn build(&self, app: &mut App) {
47        app.init_resource::<SpriteLayerOptions>()
48            .add_systems(
49                First,
50                clear_z_coordinates.in_set(SpriteLayerSet::ClearZCoordinates),
51            )
52            .add_systems(
53                Last,
54                // We need to run these systems *after* the transform's systems because they need the
55                // proper y-coordinate to be set for y-sorting.
56                (propagate_layers::<Layer>.pipe(set_z_coordinates::<Layer>),)
57                    .chain()
58                    .in_set(SpriteLayerSet::SetZCoordinates),
59            )
60            .register_type::<RenderZCoordinate>();
61    }
62}
63
64/// Configure how the sprite layer
65#[derive(Debug, Resource, Reflect)]
66pub struct SpriteLayerOptions {
67    pub y_sort: bool,
68}
69
70impl Default for SpriteLayerOptions {
71    fn default() -> Self {
72        Self { y_sort: true }
73    }
74}
75
76/// Set for all systems related to [`SpriteLayerPlugin`]. This is run in the
77/// render app's [`ExtractSchedule`], *not* the main app.
78#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, SystemSet)]
79pub enum SpriteLayerSet {
80    ClearZCoordinates,
81    SetZCoordinates,
82}
83
84/// Trait for the type you use to indicate your sprites' layers. Add this as a
85/// component to any entity you want to treat as a sprite. Note that this does
86/// *not* propagate.
87pub trait LayerIndex: Eq + Hash + Component + Clone + Debug {
88    /// The actual numeric z-value that the layer index corresponds to.  Note
89    /// that the z-value for an entity can be any value in the range
90    /// `layer.as_z_coordinate() <= z < layer.as_z_coordinate() + 1.0`, and the
91    /// exact values are an implementation detail!
92    ///
93    /// With the default Bevy camera settings, your return values from this
94    /// function should be between 0 and 999.0, since the camera is at z =
95    /// 1000.0. Prefer smaller z-values since that gives more precision.
96    fn as_z_coordinate(&self) -> f32;
97}
98
99/// Clears the z-coordinate of everything with a `RenderZCoordinate` component.
100pub fn clear_z_coordinates(mut query: Query<&mut Transform, With<RenderZCoordinate>>) {
101    for mut transform in query.iter_mut() {
102        transform.bypass_change_detection().translation.z = 0.0;
103    }
104}
105
106/// Propagates the `Layer` of each entity to the `InheritedLayer` of itself and all of its
107/// descendants.
108pub fn propagate_layers<Layer: LayerIndex>(
109    recursive_query: Query<(Option<&Children>, Option<&Layer>)>,
110    root_query: Query<(Entity, &Layer), Without<Parent>>,
111    mut size: Local<usize>,
112) -> EntityHashMap<Layer> {
113    let mut layer_map = EntityHashMap::default();
114    layer_map.reserve(*size);
115    for (entity, layer) in &root_query {
116        propagate_layers_impl(entity, layer, &recursive_query, &mut layer_map);
117    }
118    *size = size.max(layer_map.len());
119    layer_map
120}
121
122/// Recursive impl for [`inherited_layers`].
123fn propagate_layers_impl<Layer: LayerIndex>(
124    entity: Entity,
125    propagated_layer: &Layer,
126    query: &Query<(Option<&Children>, Option<&Layer>)>,
127    layer_map: &mut EntityHashMap<Layer>,
128) {
129    let (children, layer) = query.get(entity).expect("query shouldn't ever fail");
130    let layer = layer.unwrap_or(propagated_layer);
131    layer_map.insert(entity, layer.clone());
132
133    let Some(children) = children else {
134        return;
135    };
136
137    for child in children {
138        propagate_layers_impl(*child, layer, query, layer_map);
139    }
140}
141
142/// Compute the z-coordinate that each entity should have. This is equal to its layer's equivalent
143/// z-coordinate, plus an offset in the range [0, 1) corresponding to its y-sorted position
144/// (if y-sorting is enabled).
145pub fn set_z_coordinates<Layer: LayerIndex>(
146    In(layers): In<EntityHashMap<Layer>>,
147    mut transform_query: Query<&mut GlobalTransform>,
148    options: Res<SpriteLayerOptions>,
149) {
150    if options.y_sort {
151        // We y-sort everything because this avoids the overhead of grouping
152        // entities by their layer.
153        let key_fn = |entity: &Entity| {
154            transform_query
155                .get(*entity)
156                .map(ZIndexSortKey::new)
157                .unwrap_or_else(|_| ZIndexSortKey::new(&Default::default()))
158        };
159        // note: parallelizing with rayon is slower(!) here. I'm not sure why. maybe it has to do
160        // with some kind of inter-thread overhead or L1/L2 cache not being shared?
161        let y_sorted = layers
162            .keys()
163            .cloned()
164            .collect::<Vec<_>>()
165            .tap_mut(|v| v.sort_by_cached_key(key_fn));
166
167        let scale_factor = 1.0 / y_sorted.len() as f32;
168        for (i, entity) in y_sorted.into_iter().enumerate() {
169            let z = layers[&entity].as_z_coordinate() + (i as f32) * scale_factor;
170            set_transform_z(&mut transform_query, entity, z);
171        }
172    } else {
173        for (entity, layer) in layers {
174            set_transform_z(&mut transform_query, entity, layer.as_z_coordinate());
175        }
176    }
177}
178
179/// Sets the given entity's global transform z. Does nothing if it doesn't have one.
180fn set_transform_z(query: &mut Query<&mut GlobalTransform>, entity: Entity, z: f32) {
181    // hacky hacky; I can't find a way to directly mutate the GlobalTransform.
182    let Some(mut transform) = query.get_mut(entity).ok() else {
183        return;
184    };
185    let transform = transform.bypass_change_detection();
186    let mut affine = transform.affine();
187    affine.translation.z = z;
188    *transform = GlobalTransform::from(affine);
189}
190
191/// Used to sort the entities within a sprite layer.
192#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
193pub struct ZIndexSortKey(Reverse<OrderedFloat<f32>>);
194
195impl ZIndexSortKey {
196    // This is reversed because bevy uses +y pointing upwards, which is the
197    // opposite of what you generally want.
198    fn new(transform: &GlobalTransform) -> Self {
199        Self(Reverse(OrderedFloat(transform.translation().y)))
200    }
201}
202
203/// Stores the z-coordinate that will be used at render time. Don't modify this yourself.
204#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Component, Reflect)]
205pub struct RenderZCoordinate(pub f32);
206
#[cfg(test)]
mod tests {
    use bevy::ecs::system::RunSystemOnce;

    use super::*;

    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Component)]
    enum Layer {
        Top,
        Middle,
        Bottom,
    }

    impl LayerIndex for Layer {
        fn as_z_coordinate(&self) -> f32 {
            use Layer::*;
            match self {
                Bottom => 0.0,
                Middle => 1.0,
                Top => 2.0,
            }
        }
    }

    fn test_app() -> App {
        let mut app = App::new();
        app.add_plugins(MinimalPlugins)
            .add_plugins(TransformPlugin)
            .add_plugins(SpriteLayerPlugin::<Layer>::default());

        app
    }

    /// Just verify that adding the plugin doesn't somehow blow everything up.
    #[test]
    fn plugin_add_smoke_check() {
        let _ = test_app();
    }

    fn transform_at(x: f32, y: f32) -> TransformBundle {
        TransformBundle::from_transform(Transform::from_xyz(x, y, 0.0))
    }

    fn get_z(world: &World, entity: Entity) -> f32 {
        world
            .get::<GlobalTransform>(entity)
            .unwrap()
            .translation()
            .z
    }

    #[test]
    fn simple() {
        let mut app = test_app();
        let top = app
            .world_mut()
            .spawn((transform_at(1.0, 1.0), Layer::Top))
            .id();
        let middle = app
            .world_mut()
            .spawn((transform_at(1.0, 1.0), Layer::Middle))
            .id();
        let bottom = app
            .world_mut()
            .spawn((transform_at(1.0, 1.0), Layer::Bottom))
            .id();
        app.update();

        assert!(get_z(app.world(), bottom) < get_z(app.world(), middle));
        assert!(get_z(app.world(), middle) < get_z(app.world(), top));
    }

    fn layer_bundle(layer: Layer) -> impl Bundle {
        (transform_at(0.0, 0.0), layer)
    }

    #[test]
    fn inherited() {
        let mut app = test_app();
        let top = app.world_mut().spawn(layer_bundle(Layer::Top)).id();
        let child_with_layer = app
            .world_mut()
            .spawn(layer_bundle(Layer::Middle))
            .set_parent(top)
            .id();
        let child_without_layer = app
            .world_mut()
            .spawn(transform_at(0.0, 0.0))
            .set_parent(top)
            .id();
        app.update();

        // we use .floor() here since y-sorting can add a fractional amount to the coordinates
        assert_eq!(
            get_z(app.world(), child_with_layer).floor(),
            Layer::Middle.as_z_coordinate()
        );
        assert_eq!(
            get_z(app.world(), child_without_layer).floor(),
            get_z(app.world(), top).floor()
        );
    }

    #[test]
    fn y_sorting() {
        let mut app = test_app();
        // Fixed, deliberately scrambled y-values keep the test reproducible.
        for y in [0.3, 0.9, 0.1, 0.7, 0.5, 0.2, 0.8, 0.4, 0.6, 0.0] {
            app.world_mut()
                .spawn((transform_at(0.0, y), Layer::Top));
        }
        app.update();
        let positions =
            app.world_mut()
                .run_system_once(|query: Query<&GlobalTransform>| -> Vec<Vec3> {
                    query
                        .into_iter()
                        .map(|transform| transform.translation())
                        .collect()
                });
        // Every entity must stay inside its layer's [z, z + 1) band.
        let top_z = Layer::Top.as_z_coordinate();
        assert!(positions
            .iter()
            .all(|p| (top_z..top_z + 1.0).contains(&p.z)));
        let sorted_by_z = positions
            .clone()
            .tap_mut(|positions| positions.sort_by_key(|vec| OrderedFloat(vec.z)));
        let sorted_by_y = positions
            .tap_mut(|positions| positions.sort_by_key(|vec| Reverse(OrderedFloat(vec.y))));
        assert_eq!(sorted_by_z, sorted_by_y);
    }

    #[test]
    fn without_y_sort_z_is_exactly_layer_z() {
        let mut app = test_app();
        app.insert_resource(SpriteLayerOptions { y_sort: false });
        let top = app
            .world_mut()
            .spawn((transform_at(0.0, 5.0), Layer::Top))
            .id();
        let bottom = app
            .world_mut()
            .spawn((transform_at(0.0, -5.0), Layer::Bottom))
            .id();
        app.update();

        assert_eq!(get_z(app.world(), top), Layer::Top.as_z_coordinate());
        assert_eq!(get_z(app.world(), bottom), Layer::Bottom.as_z_coordinate());
    }

    #[test]
    fn child_with_no_transform() {
        let mut app = test_app();
        let entity = app.world_mut().spawn(layer_bundle(Layer::Top)).id();
        let child = app.world_mut().spawn_empty().set_parent(entity).id();
        let grandchild = app
            .world_mut()
            .spawn(transform_at(0.0, 0.0))
            .set_parent(child)
            .id();
        app.update();
        assert_eq!(
            get_z(app.world(), grandchild).floor(),
            Layer::Top.as_z_coordinate()
        );
    }
}