use crate::tree::{EmlNode, EmlTree};
use std::collections::HashMap;
use std::sync::Arc;
/// Returns a simplified version of `tree`.
///
/// A fresh memo table is created for every call and handed to
/// `simplify_node`, which performs the actual rewrite starting at the root.
/// The input tree is only read; the result is wrapped in a new `EmlTree` via
/// `EmlTree::from_node`.
pub fn simplify(tree: &EmlTree) -> EmlTree {
    let mut memo = HashMap::<EmlNodeKey, Arc<EmlNode>>::new();
    EmlTree::from_node(simplify_node(&tree.root, &mut memo))
}
/// Structural, hashable mirror of an `EmlNode`, used as the key of the
/// simplification cache.
///
/// Two nodes produce equal keys exactly when `make_key` walks the same shape
/// with the same leaves, regardless of which allocations back them.
#[derive(Clone, Hash, PartialEq, Eq)]
enum EmlNodeKey {
/// Key for the `EmlNode::One` leaf.
One,
/// Key for an `EmlNode::Var` leaf, carrying the variable's index.
Var(usize),
/// Key for an `EmlNode::Eml` node: the keys of its left and right children.
Eml(EmlNodeKey2, EmlNodeKey2),
}
/// Newtype around a boxed `EmlNodeKey`.
///
/// The `Box` supplies the heap indirection that the recursive
/// `EmlNodeKey::Eml` variant needs in order to have a finite size.
#[derive(Clone, Hash, PartialEq, Eq)]
struct EmlNodeKey2(Box<EmlNodeKey>);
/// Builds the structural cache key for `node`.
///
/// The whole subtree under `node` is walked (left child before right child),
/// so the cost is proportional to the size of that subtree.
fn make_key(node: &EmlNode) -> EmlNodeKey {
    match node {
        EmlNode::Eml { left, right } => {
            let left_key = EmlNodeKey2(Box::new(make_key(left)));
            let right_key = EmlNodeKey2(Box::new(make_key(right)));
            EmlNodeKey::Eml(left_key, right_key)
        }
        EmlNode::Var(index) => EmlNodeKey::Var(*index),
        EmlNode::One => EmlNodeKey::One,
    }
}
/// Recursively simplifies `node`, memoising results in `cache`.
///
/// The cache is keyed by the structural key of the *input* node (see
/// `make_key`), so structurally equal inputs resolve to the same `Arc`.
///
/// * Leaves (`One`, `Var`) are cloned into a fresh `Arc`.
/// * For an `Eml` node both children are simplified first. If the simplified
///   pair forms one of the cancellation shapes recognised by
///   `cancel_inverse_pair`, the extracted operand is simplified once more and
///   returned in place of the node. Otherwise the node is rebuilt from its
///   simplified children.
///
/// Whatever is produced is stored in `cache` under the input node's key
/// before being returned.
fn simplify_node(node: &EmlNode, cache: &mut HashMap<EmlNodeKey, Arc<EmlNode>>) -> Arc<EmlNode> {
    // The key is built before the lookup; `make_key` walks the full subtree.
    let key = make_key(node);
    if let Some(cached) = cache.get(&key) {
        return Arc::clone(cached);
    }

    let result = match node {
        EmlNode::One | EmlNode::Var(_) => Arc::new(node.clone()),
        EmlNode::Eml { left, right } => {
            let left_s = simplify_node(left, cache);
            let right_s = simplify_node(right, cache);
            let rewritten = cancel_inverse_pair(&left_s, &right_s);
            match rewritten {
                // The extracted operand goes through the simplifier again so
                // that it is routed through (and recorded in) the cache.
                Some(inner) => simplify_node(&inner, cache),
                None => Arc::new(EmlNode::Eml {
                    left: left_s,
                    right: right_s,
                }),
            }
        }
    };

    cache.insert(key, Arc::clone(&result));
    result
}

/// Returns the operand to substitute when the already-simplified children of
/// an `Eml` node form a cancellation shape, or `None` when no rewrite applies.
///
/// The left-is-`One` shape (checked with `match_ln_of_exp` on `right`) takes
/// priority over the right-is-`One` shape (checked with `match_exp_of_ln` on
/// `left`).
fn cancel_inverse_pair(left: &EmlNode, right: &EmlNode) -> Option<EmlNode> {
    if matches!(left, EmlNode::One) {
        if let Some(inner) = match_ln_of_exp(right) {
            return Some(inner);
        }
    }
    if matches!(right, EmlNode::One) {
        return match_exp_of_ln(left);
    }
    None
}
/// Checks whether `right` has the shape `Eml(Eml(One, Eml(x, One)), One)`
/// (written as `Eml(left, right)`) and, if so, returns a clone of the `x`
/// node.
///
/// The caller is expected to have already established that the sibling left
/// child is `One`. Any deviation from the shape yields `None`.
fn match_ln_of_exp(right: &EmlNode) -> Option<EmlNode> {
    // Outer layer: Eml(mid, One).
    let mid = match right {
        EmlNode::Eml { left: mid, right: unit } if matches!(unit.as_ref(), EmlNode::One) => mid,
        _ => return None,
    };
    // Middle layer: Eml(One, candidate).
    let candidate = match mid.as_ref() {
        EmlNode::Eml { left: unit, right: candidate }
            if matches!(unit.as_ref(), EmlNode::One) =>
        {
            candidate
        }
        _ => return None,
    };
    // Innermost layer: Eml(x, One); `x` is what gets returned.
    match candidate.as_ref() {
        EmlNode::Eml { left: operand, right: unit } if matches!(unit.as_ref(), EmlNode::One) => {
            Some(operand.as_ref().clone())
        }
        _ => None,
    }
}
/// Checks whether `left` has the shape recognised by `match_ln_pattern` and
/// returns the operand it extracts.
///
/// This is a plain delegation: the result is exactly what
/// `match_ln_pattern(left)` returns.
fn match_exp_of_ln(left: &EmlNode) -> Option<EmlNode> {
    match_ln_pattern(left)
}
/// Checks whether `node` has the shape `Eml(One, Eml(Eml(One, x), One))`
/// (written as `Eml(left, right)`) and, if so, returns a clone of the `x`
/// node.
///
/// Any deviation from the shape yields `None`.
fn match_ln_pattern(node: &EmlNode) -> Option<EmlNode> {
    // Outer layer: Eml(One, tail).
    let tail = match node {
        EmlNode::Eml { left: unit, right: tail } if matches!(unit.as_ref(), EmlNode::One) => tail,
        _ => return None,
    };
    // Middle layer: Eml(mid, One).
    let mid = match tail.as_ref() {
        EmlNode::Eml { left: mid, right: unit } if matches!(unit.as_ref(), EmlNode::One) => mid,
        _ => return None,
    };
    // Innermost layer: Eml(One, x); `x` is what gets returned.
    match mid.as_ref() {
        EmlNode::Eml { left: unit, right: operand } if matches!(unit.as_ref(), EmlNode::One) => {
            Some(operand.as_ref().clone())
        }
        _ => None,
    }
}
pub fn count_shared_subtrees(tree: &EmlTree) -> (usize, usize) {
let mut all_nodes = Vec::new();
collect_arcs(&tree.root, &mut all_nodes);
let total = all_nodes.len();
all_nodes.sort_by_key(|a| Arc::as_ptr(a) as usize);
all_nodes.dedup_by(|a, b| Arc::ptr_eq(a, b));
let unique = all_nodes.len();
(total, unique)
}
/// Appends to `out` a clone of every child `Arc` reachable from `node`.
///
/// For each `Eml` node visited, its left and right `Arc`s are pushed (in that
/// order) and then the left subtree is expanded before the right one. A
/// subtree reachable through several parents is visited once per path, so its
/// `Arc`s appear in `out` once per path as well.
fn collect_arcs(node: &EmlNode, out: &mut Vec<Arc<EmlNode>>) {
    let mut pending: Vec<&EmlNode> = vec![node];
    while let Some(current) = pending.pop() {
        if let EmlNode::Eml { left, right } = current {
            out.push(Arc::clone(left));
            out.push(Arc::clone(right));
            // Right is stacked first so that the left subtree is popped and
            // fully expanded before the right subtree.
            pending.push(right);
            pending.push(left);
        }
    }
}
/// Normalizes `tree`.
///
/// Currently this is exactly `simplify(tree)`; no additional passes are run.
pub fn normalize(tree: &EmlTree) -> EmlTree {
simplify(tree)
}
// Unit tests for `simplify`, `count_shared_subtrees` and `normalize`.
#[cfg(test)]
mod tests {
use super::*;
use crate::canonical::Canonical;
use crate::eval::EvalCtx;
// A lone `One` leaf must come back unchanged.
#[test]
fn test_simplify_leaf() {
let one = EmlTree::one();
let simplified = simplify(&one);
assert_eq!(simplified, one);
}
// `eml(1, 1)` matches no rewrite, so depth and size stay at 1 and 3.
#[test]
fn test_simplify_preserves_eml() {
let one = EmlTree::one();
let euler = EmlTree::eml(&one, &one);
let simplified = simplify(&euler);
assert_eq!(simplified.depth(), 1);
assert_eq!(simplified.size(), 3);
}
// `Canonical::ln(Canonical::exp(x))` must collapse to the bare variable.
#[test]
fn test_simplify_ln_of_exp() {
let x = EmlTree::var(0);
let exp_x = Canonical::exp(&x); let ln_exp_x = Canonical::ln(&exp_x); let simplified = simplify(&ln_exp_x);
assert_eq!(simplified.size(), 1);
assert_eq!(*simplified.root, EmlNode::Var(0));
}
// `Canonical::exp(Canonical::ln(x))` must collapse to the bare variable.
#[test]
fn test_simplify_exp_of_ln() {
let x = EmlTree::var(0);
let ln_x = Canonical::ln(&x); let exp_ln_x = Canonical::exp(&ln_x); let simplified = simplify(&exp_ln_x);
assert_eq!(simplified.size(), 1);
assert_eq!(*simplified.root, EmlNode::Var(0));
}
// Evaluating before and after simplification must agree to within 1e-10.
// NOTE(review): presumably `EvalCtx::new(&[2.5])` binds variable 0 to 2.5 — confirm.
#[test]
fn test_simplify_preserves_semantics() {
let x = EmlTree::var(0);
let exp_x = Canonical::exp(&x);
let ln_exp_x = Canonical::ln(&exp_x);
let ctx = EvalCtx::new(&[2.5]);
let before = ln_exp_x
.eval_real(&ctx)
.expect("ln(exp(x)) eval before simplify should succeed");
let simplified = simplify(&ln_exp_x);
let after = simplified
.eval_real(&ctx)
.expect("x eval after simplify should succeed");
assert!((before - after).abs() < 1e-10);
}
// Two separately built but structurally equal subtrees must end up as the
// same `Arc` after simplification (the cache is keyed structurally).
#[test]
fn test_common_subexpression_sharing() {
let one = EmlTree::one();
let x = EmlTree::var(0);
let sub1 = EmlTree::eml(&x, &one);
let sub2 = EmlTree::eml(&x, &one); let tree = EmlTree::eml(&sub1, &sub2);
let simplified = simplify(&tree);
if let EmlNode::Eml { left, right } = simplified.root.as_ref() {
assert!(Arc::ptr_eq(left, right));
} else {
panic!("expected Eml node");
}
}
// Only loose bounds are asserted here: at least two child links are seen,
// and the unique count never exceeds the total.
#[test]
fn test_shared_subtree_counting() {
let one = EmlTree::one();
let x = EmlTree::var(0);
let t1 = EmlTree::eml(&x, &one);
let t2 = EmlTree::eml(&t1, &t1);
let (total, unique) = count_shared_subtrees(&t2);
assert!(total >= 2);
assert!(unique <= total);
}
// Normalizing an already normalized tree must not change it.
#[test]
fn test_normalize_idempotent() {
let one = EmlTree::one();
let x = EmlTree::var(0);
let t = EmlTree::eml(&x, &one);
let n1 = normalize(&t);
let n2 = normalize(&n1);
assert_eq!(n1, n2);
}
}