use orx_tree::*;
use std::fmt::Display;
/// One operation of the instruction tree.
///
/// A node's value is derived from its instruction and, depending on the variant,
/// from the values of its children.
#[derive(Debug, Clone, Copy)]
enum Instruction {
    /// Reads `inputs[i]`; the values of any children are ignored.
    Input(usize),
    /// Sums the values of all children.
    Add,
    /// Sums the values of all children and adds the constant `val`.
    AddI { val: f32 },
}
impl Display for Instruction {
    /// Writes the instruction as `Input(i)`, `Add` or `AddI(val)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Instruction::Add => write!(f, "Add"),
            Instruction::AddI { val } => write!(f, "AddI({val})"),
            Instruction::Input(x) => write!(f, "Input({x})"),
        }
    }
}
#[derive(Debug)]
struct InstructionNode {
instruction: Instruction,
value: f32,
}
impl InstructionNode {
fn new(instruction: Instruction, value: f32) -> Self {
Self { instruction, value }
}
}
impl Display for InstructionNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.instruction {
Instruction::Input(x) => write!(f, "Input({}) => {}", x, self.value),
Instruction::Add => write!(f, "Add => {}", self.value),
Instruction::AddI { val } => write!(f, "AddI({}) => {}", val, self.value),
}
}
}
/// Owns the instruction tree that the evaluation strategies in this file operate on.
#[derive(Debug)]
struct Instructions {
    tree: DynTree<InstructionNode>,
}
impl Instructions {
    /// Builds a small fixed program in which every node value starts at `0.0`:
    ///
    /// ```text
    /// AddI(100)
    /// ├── Input(1)
    /// │   └── Input(0)
    /// └── AddI(2)
    ///     ├── Add
    ///     └── AddI(5)
    /// ```
    fn example() -> Self {
        // Every node starts out unevaluated, with a placeholder value of zero.
        let unevaluated = |instruction| InstructionNode::new(instruction, 0.0);

        let mut tree = DynTree::new(unevaluated(Instruction::AddI { val: 100.0 }));

        let [input_branch, add_i_branch] = tree.root_mut().push_children([
            unevaluated(Instruction::Input(1)),
            unevaluated(Instruction::AddI { val: 2.0 }),
        ]);

        let _ = tree
            .node_mut(input_branch)
            .push_child(unevaluated(Instruction::Input(0)));

        let _ = tree.node_mut(add_i_branch).push_children([
            unevaluated(Instruction::Add),
            unevaluated(Instruction::AddI { val: 5.0 }),
        ]);

        Self { tree }
    }
}
/// Evaluates the subtree rooted at `node` by moving a single mutable cursor down into each
/// child and back up again, writing every node's computed value into its `value` field.
///
/// Returns the cursor, positioned on the same node it was given, together with that
/// node's value, so the caller can keep navigating from it. Children are evaluated even
/// below an `Input` node, whose own value ignores them.
///
/// # Panics
///
/// Panics if an `Input(i)` node refers to `i >= inputs.len()`.
fn recursive_traversal_over_nodes<'a>(
    inputs: &[f32],
    node: NodeMut<'a, Dyn<InstructionNode>>,
) -> (NodeMut<'a, Dyn<InstructionNode>>, f32) {
    let child_count = node.num_children();

    // Thread the cursor through the children: step into child `position`, evaluate it,
    // then climb back to the parent before moving on to the next child.
    let (mut node, children_sum) =
        (0..child_count).fold((node, 0.0_f32), |(parent, sum), position| {
            let child = parent.into_child_mut(position).unwrap();
            let (child, child_value) = recursive_traversal_over_nodes(inputs, child);
            (child.into_parent_mut().unwrap(), sum + child_value)
        });

    let value = match node.data().instruction {
        Instruction::Input(i) => inputs[i],
        Instruction::Add => children_sum,
        Instruction::AddI { val } => val + children_sum,
    };
    node.data_mut().value = value;
    (node, value)
}
/// Evaluates the subtree rooted at `node_idx`, addressing nodes through their indices, and
/// returns the value computed for `node_idx`.
///
/// Every visited node's computed value is also stored in its `value` field. Children are
/// always evaluated first, even below an `Input` node whose own value ignores them.
///
/// # Panics
///
/// Panics if an `Input(i)` node refers to `i >= inputs.len()`.
fn recursive_traversal_over_indices(
    tree: &mut DynTree<InstructionNode>,
    inputs: &[f32],
    node_idx: NodeIdx<Dyn<InstructionNode>>,
) -> f32 {
    // Snapshot the child indices up front: iterating `children()` directly would keep
    // `tree` borrowed immutably while the recursion below needs it mutably.
    let node = tree.node(node_idx);
    let children_ids: Vec<_> = node.children().map(|child| child.idx()).collect();

    // Fold the children's values straight into a sum instead of collecting them into a
    // temporary `Vec<f32>` that each match arm would then have to re-consume.
    let children_sum: f32 = children_ids
        .into_iter()
        .map(|child_idx| recursive_traversal_over_indices(tree, inputs, child_idx))
        .sum();

    let mut node = tree.node_mut(node_idx);
    let new_value = match node.data().instruction {
        Instruction::Input(i) => inputs[i],
        Instruction::Add => children_sum,
        Instruction::AddI { val } => children_sum + val,
    };
    node.data_mut().value = new_value;
    new_value
}
fn recursive_set(inputs: &[f32], mut node: NodeMut<Dyn<InstructionNode>>) {
node.recursive_set(|node_data, children_data| {
let instruction = node_data.instruction;
let children_sum: f32 = children_data.iter().map(|x| x.value).sum();
let value = match node_data.instruction {
Instruction::Input(i) => inputs[i],
Instruction::Add => children_sum,
Instruction::AddI { val } => val + children_sum,
};
InstructionNode { instruction, value }
});
}
fn main() {
    // Runs one evaluation strategy against a freshly built example program and prints the
    // resulting tree, so the strategies' outputs can be compared side by side.
    fn test_implementation(method: &str, f: impl FnOnce(&[f32], &mut Instructions)) {
        let mut program = Instructions::example();
        let inputs = [10.0, 20.0];
        println!("\n\n### {method}");
        f(&inputs, &mut program);
        println!("\n{}\n", &program.tree);
    }

    test_implementation("recursive_traversal_over_indices", |inputs, program| {
        let root = program.tree.root().idx();
        recursive_traversal_over_indices(&mut program.tree, inputs, root);
    });
    test_implementation("recursive_traversal_over_nodes", |inputs, program| {
        recursive_traversal_over_nodes(inputs, program.tree.root_mut());
    });
    test_implementation("recursive_set", |inputs, program| {
        recursive_set(inputs, program.tree.root_mut());
    });
}