Skip to main content

tree_traversal/
tree_traversal.rs

1//! Tree traversal example demonstrating recursive operations with EntityPtr.
2//!
3//! This example shows how to:
4//! - Build a tree structure using EntityHandle references
5//! - Recursively traverse trees to sum values
6//! - Find the root of a tree by following parent links
7//! - Compute tree depth
8//! - Use `WorldExt` to create `EntityPtr` without unsafe blocks
9//!
10//! Run with: `cargo run --example tree_traversal`
11
12use bevy_ecs::prelude::*;
13use bevy_entity_ptr::{EntityHandle, EntityPtr, WorldExt};
14
15// Components for our tree structure
16
17#[derive(Component)]
18struct Name(&'static str);
19
20#[derive(Component)]
21struct Value(i32);
22
23#[derive(Component)]
24struct Parent(EntityHandle);
25
26#[derive(Component)]
27struct Children(Vec<EntityHandle>);
28
29// Recursive function to sum all values in a subtree
30fn sum_tree(node: EntityPtr) -> i32 {
31    let my_value = node.get::<Value>().map(|v| v.0).unwrap_or(0);
32
33    let children_sum: i32 = node
34        .get::<Children>()
35        .map(|c| c.0.iter().map(|h| sum_tree(node.follow_handle(*h))).sum())
36        .unwrap_or(0);
37
38    my_value + children_sum
39}
40
41// Recursive function to find the root by traversing parent links
42fn find_root(node: EntityPtr) -> EntityPtr {
43    match node.follow::<Parent, _>(|p| p.0) {
44        Some(parent) => find_root(parent),
45        None => node,
46    }
47}
48
49// Recursive function to compute tree depth
50fn tree_depth(node: EntityPtr) -> usize {
51    node.get::<Children>()
52        .map(|c| {
53            c.0.iter()
54                .map(|h| tree_depth(node.follow_handle(*h)))
55                .max()
56                .unwrap_or(0)
57                + 1
58        })
59        .unwrap_or(0)
60}
61
62// Recursive function to collect all node names in pre-order
63fn collect_names(node: EntityPtr, names: &mut Vec<&'static str>) {
64    if let Some(name) = node.get::<Name>() {
65        names.push(name.0);
66    }
67
68    if let Some(children) = node.get::<Children>() {
69        for child_handle in &children.0 {
70            collect_names(node.follow_handle(*child_handle), names);
71        }
72    }
73}
74
75fn main() {
76    let mut world = World::new();
77
78    // Build a tree:
79    //
80    //           root (10)
81    //          /    \
82    //       a (5)   b (3)
83    //         |
84    //       c (2)
85    //         |
86    //       d (7)
87    //
88    println!("Building tree structure...\n");
89
90    let d = world.spawn((Name("d"), Value(7))).id();
91    let c = world
92        .spawn((Name("c"), Value(2), Children(vec![EntityHandle::new(d)])))
93        .id();
94    // Add parent link to d
95    world.entity_mut(d).insert(Parent(EntityHandle::new(c)));
96
97    let a = world
98        .spawn((Name("a"), Value(5), Children(vec![EntityHandle::new(c)])))
99        .id();
100    world.entity_mut(c).insert(Parent(EntityHandle::new(a)));
101
102    let b = world.spawn((Name("b"), Value(3))).id();
103
104    let root = world
105        .spawn((
106            Name("root"),
107            Value(10),
108            Children(vec![EntityHandle::new(a), EntityHandle::new(b)]),
109        ))
110        .id();
111    world.entity_mut(a).insert(Parent(EntityHandle::new(root)));
112    world.entity_mut(b).insert(Parent(EntityHandle::new(root)));
113
114    // No unsafe needed! WorldExt provides ergonomic access
115    // Demonstrate tree operations
116    let root_ptr = world.entity_ptr(root);
117
118    // 1. Sum all values
119    let total = sum_tree(root_ptr);
120    println!("Sum of all values in tree: {}", total);
121    println!("  Expected: 10 + 5 + 2 + 7 + 3 = 27");
122    assert_eq!(total, 27);
123
124    // 2. Find root from any node
125    let d_ptr = world.entity_ptr(d);
126    let found_root = find_root(d_ptr);
127    let root_name = found_root.get::<Name>().unwrap().0;
128    println!("\nFinding root from node 'd': {}", root_name);
129    assert_eq!(root_name, "root");
130
131    let b_ptr = world.entity_ptr(b);
132    let found_root = find_root(b_ptr);
133    let root_name = found_root.get::<Name>().unwrap().0;
134    println!("Finding root from node 'b': {}", root_name);
135    assert_eq!(root_name, "root");
136
137    // 3. Compute tree depth (number of edges on longest path)
138    let depth = tree_depth(root_ptr);
139    println!("\nTree depth from root: {}", depth);
140    println!("  Expected: 3 (root->a, a->c, c->d = 3 edges)");
141    assert_eq!(depth, 3);
142
143    let a_ptr = world.entity_ptr(a);
144    let a_depth = tree_depth(a_ptr);
145    println!("Subtree depth from 'a': {}", a_depth);
146    assert_eq!(a_depth, 2);
147
148    // 4. Collect all names in pre-order traversal
149    let mut names = Vec::new();
150    collect_names(root_ptr, &mut names);
151    println!("\nPre-order traversal: {:?}", names);
152    assert_eq!(names, vec!["root", "a", "c", "d", "b"]);
153
154    println!("\nAll assertions passed!");
155}