1use crate::arena::{ExprArena, ExprId, ExprNode};
2
3pub fn simplify(arena: &mut ExprArena, id: ExprId) -> ExprId {
12 let folded = match arena.get(id).clone() {
13 ExprNode::Neg(inner) => match arena.get(inner) {
14 ExprNode::Const(c) => Some(ExprNode::Const(-*c)),
15 _ => None,
16 },
17 ExprNode::Pow(base, exp) => match (arena.get(base), arena.get(exp)) {
18 (ExprNode::Const(b), ExprNode::Const(e)) => Some(ExprNode::Const(b.powf(*e))),
19 _ => None,
20 },
21 ExprNode::Div(num, den) => match (arena.get(num), arena.get(den)) {
22 (ExprNode::Const(n), ExprNode::Const(d)) => Some(ExprNode::Const(n / d)),
23 _ => None,
24 },
25 ExprNode::Sin(inner)
26 | ExprNode::Cos(inner)
27 | ExprNode::Exp(inner)
28 | ExprNode::Log(inner)
29 | ExprNode::Abs(inner) => {
30 let node = arena.get(id).clone();
31 if let ExprNode::Const(c) = arena.get(inner) {
32 Some(ExprNode::Const(match node {
33 ExprNode::Sin(_) => c.sin(),
34 ExprNode::Cos(_) => c.cos(),
35 ExprNode::Exp(_) => c.exp(),
36 ExprNode::Log(_) => c.ln(),
37 ExprNode::Abs(_) => c.abs(),
38 _ => unreachable!(),
39 }))
40 } else {
41 None
42 }
43 }
44 _ => None,
45 };
46 match folded {
47 Some(node) => arena.push(node),
48 None => id,
49 }
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode};

    /// `Abs` applied to the constant `-5.0` must collapse to the constant `5.0`.
    #[test]
    fn folds_abs_of_const() {
        // Arrange: an arena holding `Abs(Const(-5.0))`.
        let mut arena = ExprArena::new();
        let operand = arena.push(ExprNode::Const(-5.0));
        let wrapped = arena.push(ExprNode::Abs(operand));

        // Act.
        let result = simplify(&mut arena, wrapped);

        // Assert: the returned id refers to a constant within EPSILON of 5.0.
        let is_five = match arena.get(result) {
            ExprNode::Const(value) => (*value - 5.0).abs() < f64::EPSILON,
            _ => false,
        };
        assert!(is_five);
    }
}