parentree/
lib.rs

1use serde::{Deserialize, Serialize};
2use slotmap::{DefaultKey, DenseSlotMap};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct Node<T> {
6    parent_key: Option<DefaultKey>,
7    children: Vec<DefaultKey>,
8    val: T,
9}
10
11pub struct Tree<T>(DenseSlotMap<DefaultKey, Node<T>>);
12
13impl<T> Tree<T> {
14    pub fn root(val: T) -> (Self, DefaultKey) {
15        let mut this = Self(DenseSlotMap::new());
16        let root = this.0.insert(Node {
17            parent_key: None,
18            children: vec![],
19            val,
20        });
21        (this, root)
22    }
23
24    pub fn add_child(&mut self, parent_key: DefaultKey, val: T) -> DefaultKey {
25        let new = self.0.insert(Node {
26            parent_key: Some(parent_key),
27            children: vec![],
28            val,
29        });
30        let parent = self.0.get_mut(parent_key).expect("key does not exist");
31        parent.children.push(new);
32        new
33    }
34
35    pub fn modify_recursive<F, Arg>(&mut self, key: DefaultKey, func: F, arg: Arg) -> Arg
36    where
37        F: Fn(&mut T, Arg) -> Arg,
38    {
39        let my_value = self.0.get_mut(key).expect("key does not exist");
40        let arg = func(&mut my_value.val, arg);
41
42        if let Some(parent_key) = my_value.parent_key {
43            return self.modify_recursive(parent_key, func, arg);
44        }
45        return arg;
46    }
47}
48impl<T> std::ops::Deref for Tree<T> {
49    type Target = DenseSlotMap<DefaultKey, Node<T>>;
50    fn deref(&self) -> &Self::Target {
51        &self.0
52    }
53}
54
#[cfg(test)]
mod test {
    use crate::Tree;

    /// Adds `arg` to every value on the path from `child2` up to the root and checks
    /// that nodes off that path are left alone.
    #[test]
    fn test_tree() {
        fn reducer(val: &mut i32, arg: i32) -> i32 {
            *val += arg;
            arg
        }

        let (mut tree, root) = Tree::root(2);

        let child1 = tree.add_child(root, 2);
        let child2 = tree.add_child(child1, 2);
        // A sibling of `child1`: not an ancestor of `child2`, so it must stay untouched.
        let sibling = tree.add_child(root, 2);
        let returned = tree.modify_recursive(child2, reducer, 1);

        assert_eq!(returned, 1);
        assert_eq!(tree[root].val, 3);
        assert_eq!(tree[child1].val, 3);
        assert_eq!(tree[child2].val, 3);
        assert_eq!(tree[sibling].val, 2);
        assert_eq!(tree[root].children, vec![child1, sibling]);
        assert_eq!(tree[child2].parent_key, Some(child1));
        assert_eq!(tree[root].parent_key, None);
    }

    /// Checks that the accumulator returned by each call is fed to the next ancestor.
    #[test]
    fn accumulator_is_threaded_upward() {
        let (mut tree, root) = Tree::root(10);
        let child = tree.add_child(root, 20);
        let grandchild = tree.add_child(child, 30);

        // Each step stores the running total in the node and passes it upward.
        let total = tree.modify_recursive(
            grandchild,
            |val: &mut i32, acc: i32| {
                let sum = acc + *val;
                *val = sum;
                sum
            },
            0,
        );

        assert_eq!(tree[grandchild].val, 30);
        assert_eq!(tree[child].val, 50);
        assert_eq!(tree[root].val, 60);
        assert_eq!(total, 60);
    }
}