#[cfg(test)]
mod ast_tests {
    //! Tests for the DAG -> AST projection (`dag_to_ast`) and the AST -> DAG writeback
    //! (`ast_to_dag`).

    use rssn_advanced::ast::convert::{ast_to_dag, dag_to_ast};
    use rssn_advanced::ast::pointer::RelPtr;
    use rssn_advanced::ast::projection::{AstChildList, AstNode};
    use rssn_advanced::dag::builder::DagBuilder;
    use rssn_advanced::dag::node::DagNodeId;
    use rssn_advanced::dag::symbol::{OpKind, SymbolKind};

    /// Panics unless `kind` is a variable whose name in `builder`'s registry is `expected`.
    fn assert_variable_named(builder: &DagBuilder, kind: &SymbolKind, expected: &str) {
        match kind {
            SymbolKind::Variable(sym_id) => assert_eq!(
                builder.registry().name(*sym_id),
                Some(expected),
                "variable has the wrong name"
            ),
            other => panic!("Expected variable `{}`, found {:?}", expected, other),
        }
    }

    /// Projects `x * y + z ^ 2` into an AST and walks it through its relative child pointers,
    /// checking every node's kind along the way.
    #[test]
    fn test_complex_ast_projection_traversal() {
        let mut builder = DagBuilder::new();
        let x = builder.variable("x");
        let y = builder.variable("y");
        let z = builder.variable("z");
        let two = builder.constant(2.0);
        let mul = builder.mul(x, y);
        let power = builder.pow(z, two);
        let root = builder.add(mul, power);

        let ast = dag_to_ast(builder.arena(), root);
        // x, y, z, 2, mul, pow, add: no subexpression is shared, so one AST node per DAG node.
        assert_eq!(
            ast.len(),
            7,
            "AST projection should contain exactly 7 nodes"
        );

        let root_node = ast.root().unwrap();
        assert_eq!(root_node.kind, SymbolKind::Operator(OpKind::Add));
        let (left_ptr, right_ptr) = match root_node.children {
            AstChildList::Two([left, right]) => (left, right),
            _ => panic!("Expected addition to have exactly 2 children"),
        };

        // Child pointers are relative: they are resolved against the index of the node that
        // owns them. The root's children are resolved against index 0.
        let left_idx = left_ptr.resolve(0).unwrap();
        let left_child = ast.resolve(0, left_ptr).unwrap();
        assert_eq!(left_child.kind, SymbolKind::Operator(OpKind::Mul));
        // Every shape mismatch must fail loudly; a silent fall-through would let the test pass
        // without checking the operands.
        match left_child.children {
            AstChildList::Two([ll_ptr, lr_ptr]) => {
                let ll_child = ast.resolve(left_idx, ll_ptr).unwrap();
                assert_variable_named(&builder, &ll_child.kind, "x");
                let lr_child = ast.resolve(left_idx, lr_ptr).unwrap();
                assert_variable_named(&builder, &lr_child.kind, "y");
            }
            _ => panic!("Expected multiplication to have exactly 2 children"),
        }

        let right_idx = right_ptr.resolve(0).unwrap();
        let right_child = ast.resolve(0, right_ptr).unwrap();
        assert_eq!(right_child.kind, SymbolKind::Operator(OpKind::Pow));
        match right_child.children {
            AstChildList::Two([base_ptr, exp_ptr]) => {
                let base = ast.resolve(right_idx, base_ptr).unwrap();
                assert_variable_named(&builder, &base.kind, "z");
                let exponent = ast.resolve(right_idx, exp_ptr).unwrap();
                assert!(
                    matches!(exponent.kind, SymbolKind::Constant(c) if c == 2.0),
                    "Expected exponent to be the constant 2, found {:?}",
                    exponent.kind
                );
            }
            _ => panic!("Expected power to have exactly 2 children"),
        }
    }

    /// Edits a projected AST in place, turning `(a + b) + (a + b)` into `(a + b) * 2`. Then
    /// checks that writing it back reuses the existing DAG nodes for `a`, `b` and `a + b`.
    #[test]
    fn test_ast_projection_mutability_and_dedup_writeback() {
        let mut builder = DagBuilder::new();
        let a = builder.variable("a");
        let b = builder.variable("b");
        let sum = builder.add(a, b);
        let root = builder.add(sum, sum);
        assert_eq!(
            builder.arena().len(),
            4,
            "DAG size should be exactly 4 due to perfect sharing"
        );

        let mut ast = dag_to_ast(builder.arena(), root);
        // The projection is a tree, so the shared `a + b` subtree appears twice:
        // root + 2 * (sum, a, b) = 7 nodes.
        assert_eq!(
            ast.len(),
            7,
            "Unshared AST tree projection should have 7 nodes"
        );

        // Reuse the root's existing pointer to its left `a + b` subtree instead of assuming
        // which index the projection placed that subtree at.
        let sum_ptr = match ast.nodes[0].children {
            AstChildList::Two([left, _]) => left,
            _ => panic!("Expected projected root to have exactly 2 children"),
        };

        // The new constant does not come from the DAG, so it carries `DagNodeId::NONE`.
        ast.nodes.push(AstNode {
            kind: SymbolKind::Constant(2.0),
            value: 2.0,
            dag_id: DagNodeId::NONE,
            children: AstChildList::Empty,
        });
        let const_idx = ast.len() - 1;
        let const_ptr = RelPtr::<_, i32>::from_indices(0, const_idx);

        // Rewrite the root (index 0) from `sum + sum` to `sum * 2`.
        ast.nodes[0].kind = SymbolKind::Operator(OpKind::Mul);
        ast.nodes[0].children = AstChildList::Two([sum_ptr, const_ptr]);

        let new_root = ast_to_dag(&ast, &mut builder);
        let new_root_node = builder.arena().get(new_root).unwrap();
        assert_eq!(new_root_node.kind, SymbolKind::Operator(OpKind::Mul));
        // Only the constant `2` and the new `Mul` should be added; `a`, `b` and `a + b` must be
        // deduplicated against the nodes already in the arena.
        assert_eq!(
            builder.arena().len(),
            6,
            "Writeback failed to deduplicate subexpressions correctly"
        );
    }
}