Skip to main content

oximo_expr/
simplify.rs

1use crate::arena::{ExprArena, ExprId, ExprNode};
2
3/// Apply local algebraic simplifications to the subtree rooted at `id`,
4/// returning a (possibly fresh) `ExprId` that is observationally equivalent.
5///
6/// Current rules: constant folding for unary nodes and `Pow`. The linear
7/// fast-path is already canonical, so we leave `Linear` and n-ary `Add`/`Mul`
8/// alone.
9///
10/// TODO: Extend this once we add a CSE pass.
11pub fn simplify(arena: &mut ExprArena, id: ExprId) -> ExprId {
12    let folded = match arena.get(id).clone() {
13        ExprNode::Neg(inner) => match arena.get(inner) {
14            ExprNode::Const(c) => Some(ExprNode::Const(-*c)),
15            _ => None,
16        },
17        ExprNode::Pow(base, exp) => match (arena.get(base), arena.get(exp)) {
18            (ExprNode::Const(b), ExprNode::Const(e)) => Some(ExprNode::Const(b.powf(*e))),
19            _ => None,
20        },
21        ExprNode::Div(num, den) => match (arena.get(num), arena.get(den)) {
22            (ExprNode::Const(n), ExprNode::Const(d)) => Some(ExprNode::Const(n / d)),
23            _ => None,
24        },
25        ExprNode::Sin(inner)
26        | ExprNode::Cos(inner)
27        | ExprNode::Exp(inner)
28        | ExprNode::Log(inner)
29        | ExprNode::Abs(inner) => {
30            let node = arena.get(id).clone();
31            if let ExprNode::Const(c) = arena.get(inner) {
32                Some(ExprNode::Const(match node {
33                    ExprNode::Sin(_) => c.sin(),
34                    ExprNode::Cos(_) => c.cos(),
35                    ExprNode::Exp(_) => c.exp(),
36                    ExprNode::Log(_) => c.ln(),
37                    ExprNode::Abs(_) => c.abs(),
38                    _ => unreachable!(),
39                }))
40            } else {
41                None
42            }
43        }
44        _ => None,
45    };
46    match folded {
47        Some(node) => arena.push(node),
48        None => id,
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprId, ExprNode};

    /// Returns the constant stored at `id`, or `None` if that node is not a `Const`.
    fn const_at(a: &ExprArena, id: ExprId) -> Option<f64> {
        match a.get(id) {
            ExprNode::Const(v) => Some(*v),
            _ => None,
        }
    }

    /// Pushes `node`, runs `simplify` on it, and returns the folded constant, if any.
    fn fold(a: &mut ExprArena, node: ExprNode) -> Option<f64> {
        let id = a.push(node);
        let out = simplify(a, id);
        const_at(a, out)
    }

    #[test]
    fn folds_abs_of_const() {
        let mut a = ExprArena::new();
        let c = a.push(ExprNode::Const(-5.0));
        assert_eq!(fold(&mut a, ExprNode::Abs(c)), Some(5.0));
    }

    #[test]
    fn folds_neg_of_const() {
        let mut a = ExprArena::new();
        let c = a.push(ExprNode::Const(3.0));
        assert_eq!(fold(&mut a, ExprNode::Neg(c)), Some(-3.0));
    }

    #[test]
    fn folds_pow_of_consts() {
        let mut a = ExprArena::new();
        let b = a.push(ExprNode::Const(2.0));
        let e = a.push(ExprNode::Const(3.0));
        let v = fold(&mut a, ExprNode::Pow(b, e)).expect("Pow of constants should fold");
        assert!((v - 8.0).abs() < 1e-12);
    }

    #[test]
    fn folds_div_of_consts() {
        let mut a = ExprArena::new();
        let n = a.push(ExprNode::Const(7.0));
        let d = a.push(ExprNode::Const(2.0));
        assert_eq!(fold(&mut a, ExprNode::Div(n, d)), Some(3.5));
    }

    #[test]
    fn folds_transcendental_unary_ops() {
        let mut a = ExprArena::new();
        let zero = a.push(ExprNode::Const(0.0));
        let one = a.push(ExprNode::Const(1.0));
        assert_eq!(fold(&mut a, ExprNode::Sin(zero)), Some(0.0));
        assert_eq!(fold(&mut a, ExprNode::Cos(zero)), Some(1.0));
        assert_eq!(fold(&mut a, ExprNode::Exp(zero)), Some(1.0));
        assert_eq!(fold(&mut a, ExprNode::Log(one)), Some(0.0));
    }

    #[test]
    fn only_folds_the_root() {
        // The operand of `Neg` is `Abs(Const)`, not a `Const`, and `simplify` does not
        // recurse, so the root must come back unchanged.
        let mut a = ExprArena::new();
        let c = a.push(ExprNode::Const(-2.0));
        let abs = a.push(ExprNode::Abs(c));
        let neg = a.push(ExprNode::Neg(abs));
        let out = simplify(&mut a, neg);
        assert!(matches!(a.get(out), ExprNode::Neg(_)));
    }
}