use crate::{TreeIterator, collections::TreeNode};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// A rooted tree whose nodes are [`TreeNode`]s.
///
/// The tree may be empty (its root is `None`), which is the state produced by
/// [`Default`]. Trees built with [`Tree::new`] always have a root.
#[derive(Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Tree<T> {
    /// The root node, or `None` for an empty tree.
    root: Option<TreeNode<T>>,
}

/// Returns an empty tree (no root).
///
/// Implemented by hand rather than derived: `#[derive(Default)]` would add a
/// needless `T: Default` bound, even though an empty tree never holds a `T`.
impl<T> Default for Tree<T> {
    fn default() -> Self {
        Tree { root: None }
    }
}
impl<T> Tree<T> {
pub fn new(root: impl Into<TreeNode<T>>) -> Self {
Tree {
root: Some(root.into()),
}
}
pub fn root(&self) -> Option<&TreeNode<T>> {
self.root.as_ref()
}
pub fn root_mut(&mut self) -> Option<&mut TreeNode<T>> {
self.root.as_mut()
}
pub fn take_root(self) -> Option<TreeNode<T>> {
self.root
}
pub fn size(&self) -> usize {
self.root.as_ref().map_or(0, |node| node.size())
}
pub fn height(&self) -> usize {
self.root.as_ref().map_or(0, |node| node.height())
}
}
impl<T> AsRef<TreeNode<T>> for Tree<T> {
    /// Borrows the root node of the tree.
    ///
    /// # Panics
    ///
    /// Panics if the tree is empty (e.g. one created with `Tree::default()`),
    /// because `AsRef` has no way to report a missing root. Use [`Tree::root`]
    /// for a non-panicking alternative.
    fn as_ref(&self) -> &TreeNode<T> {
        self.root
            .as_ref()
            .expect("Tree::as_ref called on an empty tree (no root node)")
    }
}
impl<T> AsMut<TreeNode<T>> for Tree<T> {
    /// Mutably borrows the root node of the tree.
    ///
    /// # Panics
    ///
    /// Panics if the tree is empty (e.g. one created with `Tree::default()`),
    /// because `AsMut` has no way to report a missing root. Use
    /// [`Tree::root_mut`] for a non-panicking alternative.
    fn as_mut(&mut self) -> &mut TreeNode<T> {
        self.root
            .as_mut()
            .expect("Tree::as_mut called on an empty tree (no root node)")
    }
}
impl<T: Debug> Debug for Tree<T> {
    /// Writes `Tree {`, then one line per node in breadth-first order (each
    /// prefixed by a single space and rendered with the node's own `Debug`),
    /// then a closing `}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Tree {{")?;
        for node in self.iter_breadth_first() {
            writeln!(f, " {:?}", node)?;
        }
        write!(f, "}}")
    }
}
/// Unit tests for [`Tree`]. Several of them rely on `Tree::with_depth`, which
/// is not defined in this part of the file.
#[cfg(test)]
mod test {
use super::*;
use crate::{Arity, Node, NodeType, Op, TreeIterator};
// Swaps a subtree of `tree_one` with one from `tree_two`, then checks which
// constants `tree_one` holds afterwards (collected in breadth-first order).
// NOTE(review): the `1, 1` arguments presumably select the node at index 1 in
// each tree (here the first child) — confirm against `TreeNode::swap_subtrees`.
#[test]
fn test_swap_subtrees() {
let mut tree_one = Tree::new(
TreeNode::new(Op::add())
.attach(TreeNode::new(Op::constant(1.0)))
.attach(TreeNode::new(Op::constant(2.0))),
);
let mut tree_two = Tree::new(
TreeNode::new(Op::mul())
.attach(TreeNode::new(Op::constant(3.0)))
.attach(TreeNode::new(Op::constant(4.0))),
);
tree_one.as_mut().swap_subtrees(tree_two.as_mut(), 1, 1);
// Keep only the values of constant nodes.
let values_one = tree_one
.iter_breadth_first()
.filter_map(|n| match &n.value() {
Op::Const(_, v) => Some(*v),
_ => None,
})
.collect::<Vec<f32>>();
// The `1.0` constant is expected to have been replaced by `3.0` from `tree_two`.
assert_eq!(values_one, vec![3.0, 2.0]);
}
// A root with two leaf children has three nodes in total.
#[test]
fn test_size() {
let tree = Tree::new(
TreeNode::new(Op::add())
.attach(TreeNode::from(Op::constant(1.0)))
.attach(TreeNode::from(Op::constant(2.0))),
);
assert_eq!(tree.size(), 3);
}
// Expects `with_depth(5, ..)` to produce a tree whose height is exactly 5.
#[test]
fn test_depth() {
let store = vec![
(NodeType::Vertex, vec![Op::add(), Op::sub(), Op::mul()]),
(NodeType::Leaf, vec![Op::constant(1.0), Op::constant(2.0)]),
];
let tree = Tree::with_depth(5, store);
assert_eq!(tree.height(), 5);
}
// The Vertex store mixes a binary op (`add`), a constant and `sigmoid`; every
// generated node must carry an arity consistent with its op.
#[test]
fn test_tree_with_mixed_arity() {
let store = vec![
(
NodeType::Vertex,
vec![
Op::add(), Op::constant(1.0), Op::sigmoid(), ],
),
(NodeType::Leaf, vec![Op::constant(2.0)]),
];
let tree = Tree::with_depth(3, store);
for node in tree.iter_breadth_first() {
match node.value() {
// Only `add` is in the store, so the `sub`/`mul` alternatives never match here.
Op::Fn(name, arity, _) if *name == "add" || *name == "sub" || *name == "mul" => {
assert_eq!(**arity, 2, "Binary operator should have arity 2")
}
Op::Const(_, _) => assert_eq!(*node.arity(), 0, "Constant should have arity 0"),
// NOTE(review): this accepts arity 0, 1 or 2 for `sigmoid`, which is looser
// than "unary" — confirm the intended arity of `Op::sigmoid()`.
Op::Fn(name, arity, _) if *name == "sigmoid" => {
assert!(
vec![0, 1, 2].contains(&**arity),
"Unary operator should have arity 0 or 1 or 2"
)
}
_ => (), }
}
}
// With only a constant available for Vertex positions, every generated node is
// expected to have zero arity and no children.
#[test]
fn test_tree_with_zero_arity() {
let store = vec![
(NodeType::Vertex, vec![Op::constant(1.0)]), (NodeType::Leaf, vec![Op::constant(2.0)]),
];
let tree = Tree::with_depth(2, store);
for node in tree.iter_breadth_first() {
println!("Node: {:?}", node);
assert_eq!(*node.arity(), 0, "Vertex node should have zero arity");
assert!(
node.children().is_none(),
"Vertex node should have no children"
);
}
}
// Every Vertex node built from binary ops must report `Arity::Exact(2)` and
// have exactly two children.
#[test]
fn test_tree_with_exact_arity() {
let store = vec![
(NodeType::Vertex, vec![Op::add(), Op::sub()]), (NodeType::Leaf, vec![Op::constant(1.0), Op::constant(2.0)]),
];
let tree = Tree::with_depth(2, store);
for node in tree.iter_breadth_first() {
if node.node_type() == NodeType::Vertex {
assert_eq!(node.arity(), Arity::Exact(2));
assert_eq!(node.children().unwrap().len(), 2);
}
}
}
// With no Vertex entry in the store, `with_depth` is expected to yield a
// single leaf root (size 1, height 0) regardless of the requested depth.
#[test]
fn test_tree_with_only_leaf_nodes() {
let store = vec![(NodeType::Leaf, vec![Op::constant(1.0), Op::constant(2.0)])];
let tree = Tree::with_depth(3, store);
assert!(tree.root().is_some());
assert_eq!(tree.root().unwrap().node_type(), NodeType::Leaf);
assert_eq!(tree.size(), 1);
assert_eq!(tree.height(), 0);
}
// Even with an empty store, `with_depth` is expected to produce a one-node tree.
// NOTE(review): what value that node holds is decided by `with_depth` — not
// visible here.
#[test]
fn test_tree_with_empty_store() {
let empty_store: Vec<(NodeType, Vec<Op<f32>>)> = vec![];
let tree = Tree::with_depth(3, empty_store);
assert!(tree.root().is_some());
assert_eq!(tree.size(), 1);
assert_eq!(tree.height(), 0);
}
// Smoke test for the `Debug` impl: checks the header and that node output appears.
// NOTE(review): `contains("C")` is a weak check — presumably it targets how
// constants render; consider asserting a more specific substring.
#[test]
fn test_tree_debug() {
let tree = Tree::new(
TreeNode::new(Op::add())
.attach(TreeNode::new(Op::constant(1.0)))
.attach(TreeNode::new(Op::constant(2.0))),
);
let debug_str = format!("{:?}", tree);
assert!(debug_str.contains("Tree {"));
assert!(debug_str.contains("add"));
assert!(debug_str.contains("C"));
}
// `AsRef`/`AsMut` expose the root node; a child pushed through the `&mut`
// borrow is visible afterwards.
#[test]
fn test_tree_as_ref_as_mut() {
let mut tree = Tree::new(
TreeNode::new(Op::add())
.attach(TreeNode::new(Op::constant(1.0)))
.attach(TreeNode::new(Op::constant(2.0))),
);
let root_ref: &TreeNode<Op<f32>> = tree.as_ref();
assert_eq!(root_ref.value(), &Op::add());
assert_eq!(root_ref.children().unwrap().len(), 2);
let root_mut: &mut TreeNode<Op<f32>> = tree.as_mut();
assert_eq!(root_mut.value(), &Op::add());
root_mut
.children_mut()
.unwrap()
.push(TreeNode::new(Op::constant(3.0))); assert_eq!(root_mut.children().unwrap().len(), 3); }
// An empty (default) tree has no root through any accessor; a tree built with
// `new` exposes its root both by reference and by value via `take_root`.
#[test]
fn test_tree_root_operations() {
let mut empty_tree = Tree::<Op<f32>>::default();
assert!(empty_tree.root().is_none());
assert!(empty_tree.root_mut().is_none());
assert!(empty_tree.take_root().is_none());
let tree = Tree::new(
TreeNode::new(Op::add())
.attach(TreeNode::new(Op::constant(1.0)))
.attach(TreeNode::new(Op::constant(2.0))),
);
let root = tree.root().unwrap();
assert_eq!(root.value(), &Op::add());
assert_eq!(root.children().unwrap().len(), 2);
let root = tree.take_root().unwrap();
assert_eq!(root.value(), &Op::add());
}
// Round-trips a tree built by `with_depth` through JSON, then checks that
// evaluation on input `[3.0]` and structural equality are both preserved.
#[test]
#[cfg(feature = "serde")]
fn test_tree_can_serde() {
use crate::Eval;
let store = vec![
(
NodeType::Vertex,
vec![
Op::add(),
Op::sub(),
Op::mul(),
Op::div(),
Op::sigmoid(),
Op::tanh(),
],
),
(
NodeType::Leaf,
vec![Op::constant(1.0), Op::constant(2.0), Op::var(0)],
),
];
let tree = Tree::with_depth(5, store);
let eval_before = tree.eval(&[3.0]);
let serialized = serde_json::to_string(&tree).expect("Failed to serialize tree");
let deserialized: Tree<Op<f32>> =
serde_json::from_str(&serialized).expect("Failed to deserialize tree");
let eval_after = deserialized.eval(&[3.0]);
assert_eq!(
eval_before, eval_after,
"Tree evaluation should match before and after serialization"
);
assert_eq!(tree, deserialized);
}
}