1use crate::component::{Children, Parent};
2use crate::entity::Entity;
3use crate::world::World;
4
5pub trait HierarchyExt {
7 fn despawn_recursive(&mut self, entity: Entity);
9
10 fn add_child(&mut self, parent: Entity, child: Entity);
12
13 fn remove_child(&mut self, parent: Entity, child: Entity);
15}
16
17impl HierarchyExt for World {
18 fn despawn_recursive(&mut self, entity: Entity) {
19 let mut children_to_despawn = Vec::new();
20
21 if let Some(children_ptr) = self.get_component_ptr(entity, std::any::TypeId::of::<Children>()) {
22 let children = unsafe { &*(children_ptr as *const Children) };
23 for &child_id in &children.0 {
24 if let Some(child_entity) = self.reconstruct_entity(child_id) {
25 children_to_despawn.push(child_entity);
26 }
27 }
28 }
29
30 if let Some(parent_ptr) = self.get_component_ptr(entity, std::any::TypeId::of::<Parent>()) {
32 let parent_id = unsafe { (*(parent_ptr as *const Parent)).0 };
33 if let Some(parent_entity) = self.reconstruct_entity(parent_id) {
34 self.remove_child(parent_entity, entity);
35 }
36 }
37
38 for child in children_to_despawn {
40 self.despawn_recursive(child);
41 }
42
43 self.despawn(entity);
44 }
45
46 fn add_child(&mut self, parent: Entity, child: Entity) {
47 if let Some(parent_ptr) = self.get_component_ptr(child, std::any::TypeId::of::<Parent>()) {
49 let old_parent_id = unsafe { (*(parent_ptr as *const Parent)).0 };
50 if old_parent_id != parent.id() {
51 if let Some(old_parent) = self.reconstruct_entity(old_parent_id) {
52 self.remove_child(old_parent, child);
53 }
54 }
55 }
56
57 self.add_component(child, Parent(parent.id()));
59
60 if let Some(children_ptr) = self.get_component_mut_ptr(parent, std::any::TypeId::of::<Children>()) {
62 let children = unsafe { &mut *(children_ptr as *mut Children) };
63 if !children.0.contains(&child.id()) {
64 children.0.push(child.id());
65 }
66 } else {
67 self.add_component(parent, Children(vec![child.id()]));
68 }
69 }
70
71 fn remove_child(&mut self, parent: Entity, child: Entity) {
72 self.remove_component::<Parent>(child);
73
74 if let Some(children_ptr) = self.get_component_mut_ptr(parent, std::any::TypeId::of::<Children>()) {
75 let children = unsafe { &mut *(children_ptr as *mut Children) };
76 children.0.retain(|&id| id != child.id());
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::world::World;
    use std::any::TypeId;

    /// Adding a child links both directions; removing it clears the child's
    /// `Parent` and leaves no trace of it in the parent's `Children`.
    #[test]
    fn test_hierarchy_add_remove() {
        let mut world = World::new();
        let parent = world.spawn();
        let child = world.spawn();

        world.add_child(parent, child);

        let parent_ptr = world
            .get_component_ptr(child, TypeId::of::<Parent>())
            .expect("Child missing Parent component");
        // SAFETY: assumes `get_component_ptr` yields a valid pointer to a component of
        // the requested `TypeId` (here `Parent`) — TODO confirm against `World`.
        let recorded_parent = unsafe { (*(parent_ptr as *const Parent)).0 };
        assert_eq!(recorded_parent, parent.id());

        {
            let children_ptr = world
                .get_component_ptr(parent, TypeId::of::<Children>())
                .expect("Parent missing Children component");
            // SAFETY: as above, for `Children`; the reference goes out of scope
            // before `world` is mutated again.
            let children = unsafe { &*(children_ptr as *const Children) };
            assert_eq!(children.0, vec![child.id()]);
        }

        world.remove_child(parent, child);

        assert!(world.get_component_ptr(child, TypeId::of::<Parent>()).is_none());

        // The parent may keep its `Children` component after the last child is
        // removed; if it does, the list must be empty.
        if let Some(children_ptr) = world.get_component_ptr(parent, TypeId::of::<Children>()) {
            // SAFETY: as above, for `Children`.
            let children = unsafe { &*(children_ptr as *const Children) };
            assert!(children.0.is_empty());
        }
    }

    /// Despawning the root of a three-level tree removes every entity in it.
    #[test]
    fn test_despawn_recursive() {
        let mut world = World::new();
        let root = world.spawn();
        let first_child = world.spawn();
        let second_child = world.spawn();
        let grandchild = world.spawn();

        let edges = [
            (root, first_child),
            (root, second_child),
            (first_child, grandchild),
        ];
        for (parent, child) in edges {
            world.add_child(parent, child);
        }

        assert_eq!(world.entity_count(), 4);

        world.despawn_recursive(root);

        assert_eq!(world.entity_count(), 0);
    }
}