common_tree/
node.rs

1
2use std::collections::VecDeque;
3
/// A node of an ordered tree: a value of type `T` plus an ordered list of
/// child nodes. A node without children is a leaf; the node a method is
/// called on acts as the root of the subtree it owns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node<T> {
    /// Value stored in this node.
    data: T,
    /// Direct children of this node; `Node::add` appends to the end.
    children: Vec<Node<T>>,
}
9
10impl<T> Node<T> {
11
12    ///
13    /// Create a Node
14    ///
15    pub fn new(data: T) -> Self {
16        Node {
17            data,
18            children: Vec::new(),
19        }
20    }
21
22    ///
23    /// Add a node to self
24    ///
25    pub fn add(&mut self, node: Node<T>) {
26        self.children.push(node);
27    }
28
29    /// Get mut data
30    pub fn data_mut(&mut self) -> &mut T {
31        &mut self.data
32    }
33
34    /// Get data
35    pub fn data(&self) -> &T {
36        &self.data
37    }
38
39    /// Get child by index
40    /// 
41    pub fn child(&self, index: usize) -> Option<&Node<T>> {
42        self.children.get(index)
43    }
44
45    /// Get last child
46    /// 
47    pub fn last_child(&self) -> Option<&Node<T>> {
48        self.children.last()
49    }
50
51    /// Get mut child by index
52    /// 
53    pub fn child_mut(&mut self, index: usize) -> Option<&mut Node<T>> {
54        self.children.get_mut(index)
55    }
56
57    /// Get last mut child
58    /// 
59    pub fn last_child_mut(&mut self) -> Option<&mut Node<T>> {
60        self.children.last_mut()
61    }
62
63    /// Get children
64    pub fn children(&self) -> &Vec<Node<T>> {
65        &self.children
66    }
67
68    /// Get mut children
69    pub fn children_mut(&mut self) -> &mut Vec<Node<T>> {
70        &mut self.children
71    }
72
73    ///
74    /// Get child by a vec path
75    /// path: [0, 1, 3]
76    ///
77    pub fn child_by_path(&self, path: &Vec<usize>) -> Option<&Node<T>> {
78        let mut node: Option<&Node<T>> = Some(self);
79        let level = path.len();
80        for i in 1..level {
81            if node.is_some() {
82                node = node.unwrap().child(*path.get(i).unwrap());
83            }
84        }
85        return node;
86    }
87
88    ///
89    /// Get mut child by a vec path
90    /// path: [0, 1, 3]
91    ///
92    pub fn child_mut_by_path(&mut self, path: &Vec<usize>) -> Option<&mut Node<T>> {
93        let mut node: Option<&mut Node<T>> = Some(self);
94        let level = path.len();
95        for i in 1..level {
96            if node.is_some() {
97                node = node.unwrap().child_mut(*path.get(i).unwrap());
98            }
99        }
100        return node;
101    }
102
103    ///
104    /// get right child by deepth
105    /// 
106    pub fn last_child_by_level(&self, level: usize) -> Option<&Node<T>> {
107        let mut node: Option<&Node<T>> = Some(self);
108        for _ in 0..level {
109            if node.is_some() {
110                node = node.unwrap().last_child()
111            } else {
112                node = None;
113                break;
114            }
115        }
116        node
117    }
118
119    ///
120    /// get right mut child by deepth
121    /// 
122    pub fn last_child_mut_by_level(&mut self, level: usize) -> Option<&mut Node<T>> {
123        let mut node: Option<&mut Node<T>> = Some(self);
124        for _ in 0..level {
125            if node.is_some() {
126                node = node.unwrap().last_child_mut()
127            } else {
128                node = None;
129                break;
130            }
131        }
132        node
133    }
134
135    ///  Deepth first traversal (Preorder) of a tree
136    ///  
137    pub fn deepth_first_search<F: FnMut(&T)>(&self, mut f: F) {
138        self.dfs_helper(&mut f);
139    }
140
141    fn dfs_helper<F: FnMut(&T)>(&self, f: &mut F) {
142        f(self.data());
143        for child in self.children() {
144            child.dfs_helper(f);
145        }
146    }
147    /// Breadth first traversal (Level Order) of a tree
148    ///  
149    pub fn breadth_first_search<F: FnMut(&T)>(&self, mut f: F) {
150        let mut queue: VecDeque<&Node<T>> = VecDeque::new();
151        queue.push_back(self);
152        while let Some(node) = queue.pop_front() {
153            f(node.data());
154            for child in node.children() {
155                queue.push_back(child);
156            }
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a childless node holding an owned copy of `data`.
    fn leaf(data: &str) -> Node<String> {
        Node::new(String::from(data))
    }

    //    ---o(root)--
    //   /            \
    //  o(level1_1)    o(level1_2)
    //                 /
    //                o(level2_1)
    fn get_tree() -> Node<String> {
        let mut level1_2 = leaf("level1_2");
        level1_2.add(leaf("level2_1"));
        let mut root = leaf("root");
        root.add(leaf("level1_1"));
        root.add(level1_2);
        root
    }

    //       ------------------o(root)---------------
    //      /                                        \
    //     o(level1_1)                               o(level1_2)
    //    /           \                              /
    //   o(level2_2)   o(level2_3)                  o(level2_1)
    fn get_tree2() -> Node<String> {
        let mut level1_1 = leaf("level1_1");
        level1_1.add(leaf("level2_2"));
        level1_1.add(leaf("level2_3"));
        let mut level1_2 = leaf("level1_2");
        level1_2.add(leaf("level2_1"));
        let mut root = leaf("root");
        root.add(level1_1);
        root.add(level1_2);
        root
    }

    #[test]
    fn data_works() {
        let mut node = leaf("hello");
        assert_eq!(node.data(), "hello");
        *node.data_mut() = String::from("world");
        assert_eq!(node.data(), "world");
    }

    #[test]
    fn child_works() {
        let mut root = get_tree();
        assert_eq!(root.children().len(), 2);
        assert_eq!(root.child(0).unwrap().data(), "level1_1");
        assert_eq!(root.child(1).unwrap().data(), "level1_2");
        assert!(root.child(2).is_none());
        assert!(root.child(0).unwrap().children().is_empty());
        assert_eq!(root.child(1).unwrap().children().len(), 1);
        assert_eq!(root.child(1).unwrap().child(0).unwrap().data(), "level2_1");
        assert_eq!(root.last_child().unwrap().data(), "level1_2");

        // Mutate through `child_mut`: rename level1_2 and give it a second child.
        let level1_2 = root.child_mut(1).expect("root has a second child");
        *level1_2.data_mut() = String::from("level1_2_new_data");
        level1_2.add(leaf("level2_2"));
        assert_eq!(level1_2.data(), "level1_2_new_data");
        assert_eq!(level1_2.children().len(), 2);
        assert_eq!(level1_2.last_child().unwrap().data(), "level2_2");
        assert!(root.child_mut(2).is_none());
    }

    #[test]
    fn child_path_works() {
        let root = get_tree();
        // The first path element addresses `root` itself.
        assert_eq!(root.child_by_path(&vec![0]).unwrap().data(), "root");
        assert_eq!(root.child_by_path(&vec![0, 1]).unwrap().data(), "level1_2");
        assert_eq!(root.child_by_path(&vec![0, 1, 0]).unwrap().data(), "level2_1");
        assert!(root.child_by_path(&vec![0, 2]).is_none());
        assert!(root.child_by_path(&vec![0, 1, 0, 0]).is_none());
    }

    #[test]
    fn child_mut_path_works() {
        let mut root = get_tree();
        let node = root
            .child_mut_by_path(&vec![0, 1, 0])
            .expect("path [0, 1, 0] exists");
        *node.data_mut() = String::from("changed");
        assert_eq!(root.child_by_path(&vec![0, 1, 0]).unwrap().data(), "changed");
        assert!(root.child_mut_by_path(&vec![0, 0, 0]).is_none());
    }

    #[test]
    fn last_child_level_works() {
        let mut root = get_tree();
        assert_eq!(root.last_child_by_level(0).unwrap().data(), "root");
        assert_eq!(root.last_child_by_level(1).unwrap().data(), "level1_2");
        assert_eq!(root.last_child_by_level(2).unwrap().data(), "level2_1");
        assert!(root.last_child_by_level(3).is_none());
        assert!(root.last_child_by_level(4).is_none());

        *root.last_child_mut_by_level(2).unwrap().data_mut() = String::from("changed");
        assert_eq!(root.last_child_by_level(2).unwrap().data(), "changed");
        assert!(root.last_child_mut_by_level(3).is_none());
    }

    #[test]
    fn traversal_works() {
        let root = get_tree2();

        let mut dfs_str = String::new();
        root.deepth_first_search(|d| {
            dfs_str.push('-');
            dfs_str.push_str(d);
        });
        assert_eq!(dfs_str, "-root-level1_1-level2_2-level2_3-level1_2-level2_1");

        let mut bfs_str = String::new();
        root.breadth_first_search(|d| {
            bfs_str.push('-');
            bfs_str.push_str(d);
        });
        assert_eq!(bfs_str, "-root-level1_1-level1_2-level2_2-level2_3-level2_1");
    }

    #[test]
    fn traversal_of_single_node_visits_it_once() {
        let node = leaf("only");
        let mut seen = Vec::new();
        node.deepth_first_search(|d| seen.push(d.clone()));
        node.breadth_first_search(|d| seen.push(d.clone()));
        assert_eq!(seen, ["only", "only"]);
    }
}