radiate_gp/collections/trees/
tree.rs

1use crate::collections::TreeIterator;
2use crate::collections::TreeNode;
3use crate::node::Node;
4
5use std::fmt::Debug;
6
7#[derive(Clone, PartialEq, Default)]
8pub struct Tree<T> {
9    root: Option<TreeNode<T>>,
10}
11
12impl<T> Tree<T> {
13    pub fn new(root: impl Into<TreeNode<T>>) -> Self {
14        Tree {
15            root: Some(root.into()),
16        }
17    }
18
19    pub fn root(&self) -> Option<&TreeNode<T>> {
20        self.root.as_ref()
21    }
22
23    pub fn root_mut(&mut self) -> Option<&mut TreeNode<T>> {
24        self.root.as_mut()
25    }
26
27    pub fn take_root(self) -> Option<TreeNode<T>> {
28        self.root
29    }
30
31    pub fn size(&self) -> usize {
32        self.root.as_ref().map_or(0, |node| node.size())
33    }
34
35    pub fn height(&self) -> usize {
36        self.root.as_ref().map_or(0, |node| node.height())
37    }
38}
39
40impl<T> AsRef<TreeNode<T>> for Tree<T> {
41    fn as_ref(&self) -> &TreeNode<T> {
42        self.root.as_ref().unwrap()
43    }
44}
45
46impl<T> AsMut<TreeNode<T>> for Tree<T> {
47    fn as_mut(&mut self) -> &mut TreeNode<T> {
48        self.root.as_mut().unwrap()
49    }
50}
51
52impl<T: Debug> Debug for Tree<T> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "Tree {{\n")?;
55        for node in self.iter_breadth_first() {
56            write!(f, "  {:?},\n", node.value())?;
57        }
58        write!(f, "}}")
59    }
60}
61
#[cfg(test)]
mod test {
    use super::*;

    use crate::Op;

    /// Swaps one subtree between two small operator trees and checks the
    /// constants that end up in the first tree.
    #[test]
    fn test_tree() {
        let mut tree_one = Tree::new(TreeNode::with_children(
            Op::add(),
            vec![Op::constant(1.0), Op::constant(2.0)],
        ));

        let mut tree_two = Tree::new(TreeNode::with_children(
            Op::mul(),
            vec![Op::constant(3.0), Op::constant(4.0)],
        ));

        // Exchange the subtree at index 1 of each tree (the first child,
        // judging by the expected values asserted below).
        tree_one.as_mut().swap_subtrees(tree_two.as_mut(), 1, 1);

        // Gather the constant values of the first tree in breadth-first order.
        let constants_one: Vec<f32> = tree_one
            .iter_breadth_first()
            .filter_map(|node| {
                if let Op::Const(_, value) = &node.value() {
                    Some(*value)
                } else {
                    None
                }
            })
            .collect();

        assert_eq!(constants_one, vec![3.0, 2.0]);
    }
}