1#![doc = include_str!("../README.md")]
2use std::cmp::Reverse;
3use std::fmt::Debug;
4use std::hash::Hash;
5use std::marker::PhantomData;
6
7use bevy::ecs::entity::EntityHashMap; use bevy::prelude::*;
9use ordered_float::OrderedFloat;
10use tap::Tap;
11
/// Plugin that assigns render z-coordinates from a user-supplied layer component.
///
/// `Layer` is the user's [`LayerIndex`] component type. The plugin itself holds
/// no data; the type parameter only selects which component its systems read.
pub struct SpriteLayerPlugin<Layer> {
    /// Ties the plugin to `Layer` without storing a value of it.
    phantom: PhantomData<Layer>,
}
36
37impl<Layer> Default for SpriteLayerPlugin<Layer> {
38 fn default() -> Self {
39 Self {
40 phantom: Default::default(),
41 }
42 }
43}
44
45impl<Layer: LayerIndex> Plugin for SpriteLayerPlugin<Layer> {
46 fn build(&self, app: &mut App) {
47 app.init_resource::<SpriteLayerOptions>()
48 .add_systems(
49 First,
50 clear_z_coordinates.in_set(SpriteLayerSet::ClearZCoordinates),
51 )
52 .add_systems(
53 Last,
54 (propagate_layers::<Layer>.pipe(set_z_coordinates::<Layer>),)
57 .chain()
58 .in_set(SpriteLayerSet::SetZCoordinates),
59 )
60 .register_type::<RenderZCoordinate>();
61 }
62}
63
64#[derive(Debug, Resource, Reflect)]
66pub struct SpriteLayerOptions {
67 pub y_sort: bool,
68}
69
70impl Default for SpriteLayerOptions {
71 fn default() -> Self {
72 Self { y_sort: true }
73 }
74}
75
76#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, SystemSet)]
79pub enum SpriteLayerSet {
80 ClearZCoordinates,
81 SetZCoordinates,
82}
83
84pub trait LayerIndex: Eq + Hash + Component + Clone + Debug {
88 fn as_z_coordinate(&self) -> f32;
97}
98
99pub fn clear_z_coordinates(mut query: Query<&mut Transform, With<RenderZCoordinate>>) {
101 for mut transform in query.iter_mut() {
102 transform.bypass_change_detection().translation.z = 0.0;
103 }
104}
105
106pub fn propagate_layers<Layer: LayerIndex>(
109 recursive_query: Query<(Option<&Children>, Option<&Layer>)>,
110 root_query: Query<(Entity, &Layer), Without<Parent>>,
111 mut size: Local<usize>,
112) -> EntityHashMap<Layer> {
113 let mut layer_map = EntityHashMap::default();
114 layer_map.reserve(*size);
115 for (entity, layer) in &root_query {
116 propagate_layers_impl(entity, layer, &recursive_query, &mut layer_map);
117 }
118 *size = size.max(layer_map.len());
119 layer_map
120}
121
122fn propagate_layers_impl<Layer: LayerIndex>(
124 entity: Entity,
125 propagated_layer: &Layer,
126 query: &Query<(Option<&Children>, Option<&Layer>)>,
127 layer_map: &mut EntityHashMap<Layer>,
128) {
129 let (children, layer) = query.get(entity).expect("query shouldn't ever fail");
130 let layer = layer.unwrap_or(propagated_layer);
131 layer_map.insert(entity, layer.clone());
132
133 let Some(children) = children else {
134 return;
135 };
136
137 for child in children {
138 propagate_layers_impl(*child, layer, query, layer_map);
139 }
140}
141
142pub fn set_z_coordinates<Layer: LayerIndex>(
146 In(layers): In<EntityHashMap<Layer>>,
147 mut transform_query: Query<&mut GlobalTransform>,
148 options: Res<SpriteLayerOptions>,
149) {
150 if options.y_sort {
151 let key_fn = |entity: &Entity| {
154 transform_query
155 .get(*entity)
156 .map(ZIndexSortKey::new)
157 .unwrap_or_else(|_| ZIndexSortKey::new(&Default::default()))
158 };
159 let y_sorted = layers
162 .keys()
163 .cloned()
164 .collect::<Vec<_>>()
165 .tap_mut(|v| v.sort_by_cached_key(key_fn));
166
167 let scale_factor = 1.0 / y_sorted.len() as f32;
168 for (i, entity) in y_sorted.into_iter().enumerate() {
169 let z = layers[&entity].as_z_coordinate() + (i as f32) * scale_factor;
170 set_transform_z(&mut transform_query, entity, z);
171 }
172 } else {
173 for (entity, layer) in layers {
174 set_transform_z(&mut transform_query, entity, layer.as_z_coordinate());
175 }
176 }
177}
178
179fn set_transform_z(query: &mut Query<&mut GlobalTransform>, entity: Entity, z: f32) {
181 let Some(mut transform) = query.get_mut(entity).ok() else {
183 return;
184 };
185 let transform = transform.bypass_change_detection();
186 let mut affine = transform.affine();
187 affine.translation.z = z;
188 *transform = GlobalTransform::from(affine);
189}
190
191#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
193pub struct ZIndexSortKey(Reverse<OrderedFloat<f32>>);
194
195impl ZIndexSortKey {
196 fn new(transform: &GlobalTransform) -> Self {
199 Self(Reverse(OrderedFloat(transform.translation().y)))
200 }
201}
202
203#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Component, Reflect)]
205pub struct RenderZCoordinate(pub f32);
206
#[cfg(test)]
mod tests {
    use bevy::ecs::system::RunSystemOnce;

    use super::*;

    /// Test layers; `as_z_coordinate` places `Top` highest and `Bottom` lowest.
    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Component)]
    enum Layer {
        Top,
        Middle,
        Bottom,
    }

    impl LayerIndex for Layer {
        fn as_z_coordinate(&self) -> f32 {
            match *self {
                Layer::Bottom => 0.0,
                Layer::Middle => 1.0,
                Layer::Top => 2.0,
            }
        }
    }

    /// An app with just enough plugins for transforms, plus the plugin under test.
    fn test_app() -> App {
        let mut app = App::new();
        app.add_plugins(MinimalPlugins);
        app.add_plugins(TransformPlugin);
        app.add_plugins(SpriteLayerPlugin::<Layer>::default());
        app
    }

    #[test]
    fn plugin_add_smoke_check() {
        let _app = test_app();
    }

    /// A transform bundle positioned at `(x, y, 0)`.
    fn transform_at(x: f32, y: f32) -> TransformBundle {
        let transform = Transform::from_xyz(x, y, 0.0);
        TransformBundle::from_transform(transform)
    }

    /// The z of `entity`'s global transform; panics if it has none.
    fn get_z(world: &World, entity: Entity) -> f32 {
        let global = world.get::<GlobalTransform>(entity).unwrap();
        global.translation().z
    }

    #[test]
    fn simple() {
        let mut app = test_app();
        let mut spawn_at_one = |layer: Layer| {
            app.world_mut()
                .spawn((transform_at(1.0, 1.0), layer))
                .id()
        };
        let top = spawn_at_one(Layer::Top);
        let middle = spawn_at_one(Layer::Middle);
        let bottom = spawn_at_one(Layer::Bottom);
        app.update();

        let world = app.world();
        assert!(get_z(world, bottom) < get_z(world, middle));
        assert!(get_z(world, middle) < get_z(world, top));
    }

    /// A layered entity bundle at the origin.
    fn layer_bundle(layer: Layer) -> impl Bundle {
        (transform_at(0.0, 0.0), layer)
    }

    #[test]
    fn inherited() {
        let mut app = test_app();
        let world = app.world_mut();
        let top = world.spawn(layer_bundle(Layer::Top)).id();
        let child_with_layer = world
            .spawn(layer_bundle(Layer::Middle))
            .set_parent(top)
            .id();
        let child_without_layer = world
            .spawn(transform_at(0.0, 0.0))
            .set_parent(top)
            .id();
        app.update();

        let world = app.world();
        // The integer part encodes the layer; the fraction is the y-sort offset.
        assert_eq!(
            get_z(world, child_with_layer).floor(),
            Layer::Middle.as_z_coordinate()
        );
        assert_eq!(
            get_z(world, child_without_layer).floor(),
            get_z(world, top).floor()
        );
    }

    #[test]
    fn y_sorting() {
        let mut app = test_app();
        for _ in 0..10 {
            let y = fastrand::f32();
            app.world_mut().spawn((transform_at(0.0, y), Layer::Top));
        }
        app.update();

        let positions = app
            .world_mut()
            .run_system_once(|query: Query<&GlobalTransform>| -> Vec<Vec3> {
                query.iter().map(|global| global.translation()).collect()
            });

        let mut sorted_by_z = positions.clone();
        sorted_by_z.sort_by_key(|pos| OrderedFloat(pos.z));
        let mut sorted_by_y = positions;
        sorted_by_y.sort_by_key(|pos| Reverse(OrderedFloat(pos.y)));
        assert_eq!(sorted_by_z, sorted_by_y);
    }

    #[test]
    fn child_with_no_transform() {
        let mut app = test_app();
        let world = app.world_mut();
        let entity = world.spawn(layer_bundle(Layer::Top)).id();
        let child = world.spawn_empty().set_parent(entity).id();
        let grandchild = world
            .spawn(transform_at(0.0, 0.0))
            .set_parent(child)
            .id();
        app.update();

        // The transform-less middle entity must still pass `Top` down to its child.
        assert_eq!(
            get_z(app.world(), grandchild).floor(),
            Layer::Top.as_z_coordinate()
        );
    }
}