use super::Tree;
use crate::{Eval, TreeNode, node::Node};
/// Evaluates a collection of trees against one shared input slice.
///
/// Every tree receives the same `inputs`; the returned vector holds one
/// result per tree, in the same order as the trees appear in `self`.
impl<T, V> Eval<[V], Vec<V>> for Vec<Tree<T>>
where
    T: Eval<[V], V>,
    V: Clone,
{
    #[inline]
    fn eval(&self, inputs: &[V]) -> Vec<V> {
        // One output slot per tree, so the buffer never has to grow.
        let mut outputs: Vec<V> = Vec::with_capacity(self.len());
        for tree in self {
            outputs.push(tree.eval(inputs));
        }
        outputs
    }
}
/// Evaluates a collection of borrowed tree nodes against one shared input slice.
///
/// Every node receives the same `inputs`; the returned vector holds one
/// result per node, in the same order as the nodes appear in `self`.
impl<T, V> Eval<[V], Vec<V>> for Vec<&TreeNode<T>>
where
    T: Eval<[V], V>,
    V: Clone,
{
    #[inline]
    fn eval(&self, inputs: &[V]) -> Vec<V> {
        // One output slot per node, so the buffer never has to grow.
        let mut outputs: Vec<V> = Vec::with_capacity(self.len());
        for node in self {
            outputs.push(node.eval(inputs));
        }
        outputs
    }
}
/// Evaluates a whole tree by delegating to its root node.
impl<T, V> Eval<[V], V> for Tree<T>
where
    T: Eval<[V], V>,
    V: Clone,
{
    /// Returns the result of evaluating the root node with `input`.
    ///
    /// # Panics
    ///
    /// Panics with "Tree has no root node." when `self.root()` is `None`.
    #[inline]
    fn eval(&self, input: &[V]) -> V {
        match self.root() {
            Some(root) => root.eval(input),
            None => panic!("Tree has no root node."),
        }
    }
}
/// Recursively evaluates a tree node.
impl<T, V> Eval<[V], V> for TreeNode<T>
where
    T: Eval<[V], V>,
    V: Clone,
{
    /// Evaluates this node and, through recursion, everything beneath it.
    ///
    /// * When `self.children()` is `Some`, each child is evaluated against the
    ///   original `input`, and the collected child results (in child order)
    ///   become the input handed to this node's own value. A `Some` holding
    ///   zero children therefore evaluates the value on an empty slice.
    /// * When `self.children()` is `None`, this node's value is evaluated
    ///   directly on `input`.
    #[inline]
    fn eval(&self, input: &[V]) -> V {
        match self.children() {
            Some(children) => {
                // Children always see the caller's input; only this node's
                // value sees the children's outputs.
                let inputs: Vec<V> = children.iter().map(|child| child.eval(input)).collect();
                self.value().eval(&inputs)
            }
            None => self.value().eval(input),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Op, TreeNode};

    /// An `add` node with two constant children (1 and 2) evaluates to 3
    /// when given an empty input slice.
    #[test]
    fn test_tree_reduce_simple() {
        let mut sum_node = TreeNode::new(Op::add());
        sum_node.add_child(TreeNode::new(Op::constant(1.0)));
        sum_node.add_child(TreeNode::new(Op::constant(2.0)));

        let total = sum_node.eval(&[]);

        assert_eq!(total, 3.0);
    }

    /// A two-level tree: `add(mul(2, 3), add(2, var(0)))`.
    ///
    /// Going by the asserted results, this behaves as `(2 * 3) + (2 + input[0])`.
    #[test]
    fn test_tree_reduce_complex() {
        let product = TreeNode::new(Op::mul())
            .attach(TreeNode::new(Op::constant(2.0)))
            .attach(TreeNode::new(Op::constant(3.0)));
        let offset = TreeNode::new(Op::add())
            .attach(TreeNode::new(Op::constant(2.0)))
            .attach(TreeNode::new(Op::var(0)));
        let node = TreeNode::new(Op::add()).attach(product).attach(offset);

        // (input value, expected result) pairs.
        let cases = [(1_f32, 9.0), (2_f32, 10.0), (3_f32, 11.0)];
        for (x, expected) in cases {
            assert_eq!(node.eval(&[x]), expected);
        }
    }
}