#![doc = include_str!("../README.md")]
use std::cmp::Reverse;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use bevy::ecs::entity::EntityHashMap; use bevy::prelude::*;
use ordered_float::OrderedFloat;
use tap::Tap;
/// Plugin that drives z-coordinate assignment for entities tagged with a `Layer` component.
///
/// `Layer` is used purely at the type level: the plugin stores no runtime data, and the
/// systems it installs are added by its `Plugin` implementation (which needs
/// `Layer: LayerIndex`).
pub struct SpriteLayerPlugin<Layer> {
    // Zero-sized marker tying the plugin to its layer type.
    _layer: PhantomData<Layer>,
}

impl<Layer> Default for SpriteLayerPlugin<Layer> {
    /// Builds the plugin.
    ///
    /// Written by hand rather than derived so that `Layer` itself is not required to
    /// implement `Default`.
    fn default() -> Self {
        SpriteLayerPlugin {
            _layer: PhantomData,
        }
    }
}
impl<Layer: LayerIndex> Plugin for SpriteLayerPlugin<Layer> {
    /// Installs everything the plugin needs into `app`:
    ///
    /// * the [`SpriteLayerOptions`] resource (initialised with its default),
    /// * reflection registration for [`RenderZCoordinate`],
    /// * [`clear_z_coordinates`] in the `First` schedule,
    /// * [`propagate_layers`] piped into [`set_z_coordinates`] in the `Last` schedule.
    fn build(&self, app: &mut App) {
        app.init_resource::<SpriteLayerOptions>();
        app.register_type::<RenderZCoordinate>();

        // Start of frame: reset z on the entities matched by `clear_z_coordinates`.
        app.add_systems(
            First,
            clear_z_coordinates.in_set(SpriteLayerSet::ClearZCoordinates),
        );

        // End of frame: resolve each entity's layer, then hand the resulting map
        // straight to the system that writes the z-coordinates.
        app.add_systems(
            Last,
            (propagate_layers::<Layer>.pipe(set_z_coordinates::<Layer>),)
                .chain()
                .in_set(SpriteLayerSet::SetZCoordinates),
        );
    }
}
/// Global settings read by [`set_z_coordinates`].
#[derive(Debug, Resource, Reflect)]
pub struct SpriteLayerOptions {
    /// When `true`, entities are additionally ordered by the y-coordinate of their
    /// `GlobalTransform`: an entity with a larger y receives a smaller offset on top
    /// of its layer's base z. When `false`, every entity gets exactly its layer's
    /// base z.
    pub y_sort: bool,
}

impl Default for SpriteLayerOptions {
    /// Y-sorting is switched on by default.
    fn default() -> Self {
        let y_sort = true;
        SpriteLayerOptions { y_sort }
    }
}
/// System sets that the plugin's systems are placed in, so that other systems can be
/// ordered relative to them.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, SystemSet)]
pub enum SpriteLayerSet {
/// Contains `clear_z_coordinates`, which the plugin adds to the `First` schedule.
ClearZCoordinates,
/// Contains the `propagate_layers` -> `set_z_coordinates` pipeline, which the plugin
/// adds to the `Last` schedule.
SetZCoordinates,
}
/// Trait for the component type that says which layer an entity belongs to.
///
/// The value is cloned into a per-frame map by `propagate_layers`, hence the `Clone`
/// bound.
pub trait LayerIndex: Eq + Hash + Component + Clone + Debug {
/// Base z-coordinate for this layer.
///
/// With y-sorting enabled, `set_z_coordinates` adds an offset between 0 and 1 to this
/// value, so layers whose base values are closer than 1.0 apart can interleave.
fn as_z_coordinate(&self) -> f32;
}
/// Resets `Transform::translation.z` to `0.0` on every entity that has a
/// [`RenderZCoordinate`] component.
///
/// The write goes through `bypass_change_detection`, so it does not mark the
/// `Transform` as changed.
pub fn clear_z_coordinates(mut query: Query<&mut Transform, With<RenderZCoordinate>>) {
    query.iter_mut().for_each(|mut local| {
        local.bypass_change_detection().translation.z = 0.0;
    });
}
/// Builds a map from entity to the layer it effectively belongs to.
///
/// The walk starts at every entity that has a `Layer` and no `Parent`, then descends
/// through `Children`. A descendant with its own `Layer` uses (and passes down) that
/// layer; one without inherits the nearest ancestor's. Hierarchies whose root has no
/// `Layer` are not visited.
///
/// `size` remembers the largest map produced so far and is used to pre-reserve
/// capacity on later runs.
pub fn propagate_layers<Layer: LayerIndex>(
    recursive_query: Query<(Option<&Children>, Option<&Layer>)>,
    root_query: Query<(Entity, &Layer), Without<Parent>>,
    mut size: Local<usize>,
) -> EntityHashMap<Layer> {
    let mut resolved = EntityHashMap::default();
    resolved.reserve(*size);

    root_query.iter().for_each(|(root, root_layer)| {
        propagate_layers_impl(root, root_layer, &recursive_query, &mut resolved);
    });

    // Keep the high-water mark so the next run can allocate once up front.
    let high_water_mark = (*size).max(resolved.len());
    *size = high_water_mark;
    resolved
}
/// Records the effective layer of `entity` in `layer_map`, then recurses into its children.
///
/// * `entity` - the entity being visited.
/// * `propagated_layer` - the layer inherited from the nearest ancestor that has one;
///   used when `entity` has no `Layer` component of its own.
/// * `query` - lookup for an entity's optional `Children` and optional `Layer`.
/// * `layer_map` - output map, receives one entry per visited entity.
///
/// Entities that `query` cannot resolve are skipped together with their subtree.
fn propagate_layers_impl<Layer: LayerIndex>(
    entity: Entity,
    propagated_layer: &Layer,
    query: &Query<(Option<&Children>, Option<&Layer>)>,
    layer_map: &mut EntityHashMap<Layer>,
) {
    // Both query terms are optional, so the lookup only fails when the entity no longer
    // exists. That can legitimately happen: a `Children` list may still name an entity
    // that was despawned without being detached from its parent. Skipping it is
    // preferable to panicking inside a system that runs every frame.
    let Ok((children, own_layer)) = query.get(entity) else {
        return;
    };

    let layer = own_layer.unwrap_or(propagated_layer);
    layer_map.insert(entity, layer.clone());

    if let Some(children) = children {
        for &child in children {
            propagate_layers_impl(child, layer, query, layer_map);
        }
    }
}
pub fn set_z_coordinates<Layer: LayerIndex>(
In(layers): In<EntityHashMap<Layer>>,
mut transform_query: Query<&mut GlobalTransform>,
options: Res<SpriteLayerOptions>,
) {
if options.y_sort {
let key_fn = |entity: &Entity| {
transform_query
.get(*entity)
.map(ZIndexSortKey::new)
.unwrap_or_else(|_| ZIndexSortKey::new(&Default::default()))
};
let y_sorted = layers
.keys()
.cloned()
.collect::<Vec<_>>()
.tap_mut(|v| v.sort_by_cached_key(key_fn));
let scale_factor = 1.0 / y_sorted.len() as f32;
for (i, entity) in y_sorted.into_iter().enumerate() {
let z = layers[&entity].as_z_coordinate() + (i as f32) * scale_factor;
set_transform_z(&mut transform_query, entity, z);
}
} else {
for (entity, layer) in layers {
set_transform_z(&mut transform_query, entity, layer.as_z_coordinate());
}
}
}
/// Overwrites the z component of `entity`'s `GlobalTransform` translation with `z`.
///
/// Does nothing if `query` has no match for `entity`. The write bypasses change
/// detection, so the `GlobalTransform` is not flagged as changed.
fn set_transform_z(query: &mut Query<&mut GlobalTransform>, entity: Entity, z: f32) {
    if let Ok(mut global) = query.get_mut(entity) {
        let global = global.bypass_change_detection();
        // Rebuild the transform from its affine matrix with only z replaced.
        let mut matrix = global.affine();
        matrix.translation.z = z;
        *global = GlobalTransform::from(matrix);
    }
}
/// Sort key used for y-sorting: wraps the y-coordinate in `Reverse`, so that sorting
/// keys in ascending order yields entities in descending-y order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ZIndexSortKey(Reverse<OrderedFloat<f32>>);

impl ZIndexSortKey {
    /// Builds the key from the y component of `transform`'s translation.
    fn new(transform: &GlobalTransform) -> Self {
        let y = transform.translation().y;
        ZIndexSortKey(Reverse(OrderedFloat(y)))
    }
}
/// Component wrapping a single `f32`.
///
/// Within this file it is registered for reflection by the plugin and used as the
/// filter in `clear_z_coordinates`; nothing in this file inserts it on entities.
// NOTE(review): presumably meant to hold the computed render z — confirm intended use.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Component, Reflect)]
pub struct RenderZCoordinate(pub f32);
// Tests that run the plugin inside a minimal app and inspect the resulting
// `GlobalTransform` z-coordinates.
#[cfg(test)]
mod tests {
use bevy::ecs::system::RunSystemOnce;
use super::*;
// Layer component used by the tests; its z values are defined by `as_z_coordinate`
// below, not by the declaration order of the variants.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Component)]
enum Layer {
Top,
Middle,
Bottom,
}
impl LayerIndex for Layer {
// Bottom, Middle and Top map to 0.0, 1.0 and 2.0 respectively.
fn as_z_coordinate(&self) -> f32 {
use Layer::*;
match self {
Bottom => 0.0,
Middle => 1.0,
Top => 2.0,
}
}
}
// Builds an app with the minimal plugins, the transform plugin, and the plugin under test.
fn test_app() -> App {
let mut app = App::new();
app.add_plugins(MinimalPlugins)
.add_plugins(TransformPlugin)
.add_plugins(SpriteLayerPlugin::<Layer>::default());
app
}
// Constructing the app with the plugin must not panic.
#[test]
fn plugin_add_smoke_check() {
let _ = test_app();
}
// Transform bundle positioned at (x, y) with z = 0.
fn transform_at(x: f32, y: f32) -> TransformBundle {
TransformBundle::from_transform(Transform::from_xyz(x, y, 0.0))
}
// Reads the z component of an entity's `GlobalTransform`; panics if it has none.
fn get_z(world: &World, entity: Entity) -> f32 {
world
.get::<GlobalTransform>(entity)
.unwrap()
.translation()
.z
}
// Three entities at the same position on different layers must end up ordered
// bottom < middle < top in z.
#[test]
fn simple() {
let mut app = test_app();
let top = app
.world_mut()
.spawn((transform_at(1.0, 1.0), Layer::Top))
.id();
let middle = app
.world_mut()
.spawn((transform_at(1.0, 1.0), Layer::Middle))
.id();
let bottom = app
.world_mut()
.spawn((transform_at(1.0, 1.0), Layer::Bottom))
.id();
app.update();
assert!(get_z(app.world(), bottom) < get_z(app.world(), middle));
assert!(get_z(app.world(), middle) < get_z(app.world(), top));
}
// Bundle with a transform at the origin plus the given layer.
fn layer_bundle(layer: Layer) -> impl Bundle {
(transform_at(0.0, 0.0), layer)
}
// A child with its own layer keeps it; a child without one inherits the parent's.
// `floor()` strips the y-sort offset so only the layer's base z is compared.
#[test]
fn inherited() {
let mut app = test_app();
let top = app.world_mut().spawn(layer_bundle(Layer::Top)).id();
let child_with_layer = app
.world_mut()
.spawn(layer_bundle(Layer::Middle))
.set_parent(top)
.id();
let child_without_layer = app
.world_mut()
.spawn(transform_at(0.0, 0.0))
.set_parent(top)
.id();
app.update();
assert_eq!(
get_z(app.world(), child_with_layer).floor(),
Layer::Middle.as_z_coordinate()
);
assert_eq!(
get_z(app.world(), child_without_layer).floor(),
get_z(app.world(), top).floor()
);
}
// Ten entities on the same layer at random y positions: ordering them by ascending z
// must give the same sequence as ordering them by descending y.
#[test]
fn y_sorting() {
let mut app = test_app();
for _ in 0..10 {
app.world_mut()
.spawn((transform_at(0.0, fastrand::f32()), Layer::Top));
}
app.update();
let positions =
app.world_mut()
.run_system_once(|query: Query<&GlobalTransform>| -> Vec<Vec3> {
query
.into_iter()
.map(|transform| transform.translation())
.collect()
});
let sorted_by_z = positions
.clone()
.tap_mut(|positions| positions.sort_by_key(|vec| OrderedFloat(vec.z)));
let sorted_by_y = positions
.tap_mut(|positions| positions.sort_by_key(|vec| Reverse(OrderedFloat(vec.y))));
assert_eq!(sorted_by_z, sorted_by_y);
}
// The layer must still reach a grandchild when the intermediate child was spawned
// empty (no transform and no layer).
#[test]
fn child_with_no_transform() {
let mut app = test_app();
let entity = app.world_mut().spawn(layer_bundle(Layer::Top)).id();
let child = app.world_mut().spawn_empty().set_parent(entity).id();
let grandchild = app
.world_mut()
.spawn(transform_at(0.0, 0.0))
.set_parent(child)
.id();
app.update();
assert_eq!(
get_z(app.world(), grandchild).floor(),
Layer::Top.as_z_coordinate()
);
}
}