radiate_gp/collections/trees/
eval.rs

1use super::Tree;
2use crate::{Eval, TreeNode, node::Node};
3
4/// Implements the `Reduce` trait for `Vec<Tree<T>>`. This is a wrapper around a `Vec<Tree<T>>`
5/// and allows for the evaluation of each `Tree` in the `Vec` with a single input.
6/// This is useful for things like `Ensemble` models where multiple models are used to make a prediction.
7///
8/// This is a simple implementation that just maps over the `Vec` and calls `eval` on each `Tree`.
9impl<T, V> Eval<[V], Vec<V>> for Vec<Tree<T>>
10where
11    T: Eval<[V], V>,
12    V: Clone,
13{
14    #[inline]
15    fn eval(&self, inputs: &[V]) -> Vec<V> {
16        self.iter().map(|tree| tree.eval(inputs)).collect()
17    }
18}
19
20/// Implements the `Reduce` trait for `Tree<T>` where `T` is `Eval<[V], V>`. All this really does is
21/// call the `reduce` method on the root node of the `Tree`. The real work is
22/// done in the `TreeNode` implementation below.
23impl<T, V> Eval<[V], V> for Tree<T>
24where
25    T: Eval<[V], V>,
26    V: Clone,
27{
28    #[inline]
29    fn eval(&self, input: &[V]) -> V {
30        self.root()
31            .map(|root| root.eval(input))
32            .unwrap_or_else(|| panic!("Tree has no root node."))
33    }
34}
35
36/// Implements the `Reduce` trait for `TreeNode<T>` where `T` is `Eval<[V], V>`. This is where the real work is done.
37/// It recursively evaluates the `TreeNode` and its children until it reaches a leaf node,
38/// at which point it applies the `T`'s eval fn to the input.
39///
40/// Because a `Tree` has only a single root node, this can only be used to return a single value.
41/// We assume here that each leaf can eval the incoming input - this is a safe and the
42/// only real logical assumption we can make.
43impl<T, V> Eval<[V], V> for TreeNode<T>
44where
45    T: Eval<[V], V>,
46    V: Clone,
47{
48    #[inline]
49    fn eval(&self, input: &[V]) -> V {
50        if self.is_leaf() {
51            self.value().eval(input)
52        } else {
53            if let Some(children) = self.children() {
54                let mut inputs = Vec::with_capacity(children.len());
55
56                for child in children {
57                    inputs.push(child.eval(input));
58                }
59
60                return self.value().eval(&inputs);
61            }
62
63            panic!("Node is not a leaf and has no children - this should never happen.");
64        }
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Op, TreeNode};

    /// Evaluates `1 + 2` directly on a hand-built root node.
    /// The expression reads no variables, so an empty input is sufficient.
    #[test]
    fn test_tree_reduce_simple() {
        let mut sum_node = TreeNode::new(Op::add());
        for operand in [1.0, 2.0] {
            sum_node.add_child(TreeNode::new(Op::constant(operand)));
        }

        let no_inputs: Vec<f32> = Vec::new();
        assert_eq!(sum_node.eval(&no_inputs), 3.0);
    }

    /// Evaluates `(2 * 3) + (2 + x0)`, i.e. `8 + x0`, for several values of `x0`.
    #[test]
    fn test_tree_reduce_complex() {
        let product = TreeNode::new(Op::mul())
            .attach(TreeNode::new(Op::constant(2.0)))
            .attach(TreeNode::new(Op::constant(3.0)));
        let offset = TreeNode::new(Op::add())
            .attach(TreeNode::new(Op::constant(2.0)))
            .attach(TreeNode::new(Op::var(0)));
        let tree = Tree::new(TreeNode::new(Op::add()).attach(product).attach(offset));

        let cases: [(f32, f32); 3] = [(1.0, 9.0), (2.0, 10.0), (3.0, 11.0)];
        for &(x0, expected) in cases.iter() {
            assert_eq!(tree.eval(&vec![x0]), expected);
        }
    }
}