Skip to main content

proof_engine/symbolic/
simplify.rs

//! Expression simplification — constant folding plus basic algebraic identities
//! (`a + 0`, `a * 1`, `a * 0`, `a - a`, `a / a`, `a^0`, `--a`, `ln(e^x)`, `e^(ln x)`, ...).
2
3use super::expr::Expr;
4
5/// Simplify an expression by applying algebraic identities.
6pub fn simplify(expr: &Expr) -> Expr {
7    let result = simplify_once(expr);
8    // Iterate until fixed point
9    let result2 = simplify_once(&result);
10    if format!("{result}") == format!("{result2}") { result } else { simplify(&result2) }
11}
12
13fn simplify_once(expr: &Expr) -> Expr {
14    match expr {
15        // Recurse into children first
16        Expr::Neg(a) => {
17            let a = simplify_once(a);
18            match a {
19                Expr::Const(v) => Expr::Const(-v),
20                Expr::Neg(inner) => *inner, // --a = a
21                _ => Expr::Neg(Box::new(a)),
22            }
23        }
24        Expr::Add(a, b) => {
25            let a = simplify_once(a);
26            let b = simplify_once(b);
27            match (&a, &b) {
28                (Expr::Const(x), Expr::Const(y)) => Expr::Const(x + y),
29                (Expr::Const(x), _) if *x == 0.0 => b,  // 0 + b = b
30                (_, Expr::Const(y)) if *y == 0.0 => a,    // a + 0 = a
31                _ => Expr::Add(Box::new(a), Box::new(b)),
32            }
33        }
34        Expr::Sub(a, b) => {
35            let a = simplify_once(a);
36            let b = simplify_once(b);
37            match (&a, &b) {
38                (Expr::Const(x), Expr::Const(y)) => Expr::Const(x - y),
39                (_, Expr::Const(y)) if *y == 0.0 => a,
40                _ if format!("{a}") == format!("{b}") => Expr::zero(), // a - a = 0
41                _ => Expr::Sub(Box::new(a), Box::new(b)),
42            }
43        }
44        Expr::Mul(a, b) => {
45            let a = simplify_once(a);
46            let b = simplify_once(b);
47            match (&a, &b) {
48                (Expr::Const(x), Expr::Const(y)) => Expr::Const(x * y),
49                (Expr::Const(x), _) if *x == 0.0 => Expr::zero(),
50                (_, Expr::Const(y)) if *y == 0.0 => Expr::zero(),
51                (Expr::Const(x), _) if *x == 1.0 => b,
52                (_, Expr::Const(y)) if *y == 1.0 => a,
53                (Expr::Const(x), _) if *x == -1.0 => Expr::Neg(Box::new(b)),
54                (_, Expr::Const(y)) if *y == -1.0 => Expr::Neg(Box::new(a)),
55                _ => Expr::Mul(Box::new(a), Box::new(b)),
56            }
57        }
58        Expr::Div(a, b) => {
59            let a = simplify_once(a);
60            let b = simplify_once(b);
61            match (&a, &b) {
62                (Expr::Const(x), Expr::Const(y)) if *y != 0.0 => Expr::Const(x / y),
63                (Expr::Const(x), _) if *x == 0.0 => Expr::zero(),
64                (_, Expr::Const(y)) if *y == 1.0 => a,
65                _ if format!("{a}") == format!("{b}") => Expr::one(), // a/a = 1
66                _ => Expr::Div(Box::new(a), Box::new(b)),
67            }
68        }
69        Expr::Pow(a, b) => {
70            let a = simplify_once(a);
71            let b = simplify_once(b);
72            match (&a, &b) {
73                (_, Expr::Const(y)) if *y == 0.0 => Expr::one(),  // a^0 = 1
74                (_, Expr::Const(y)) if *y == 1.0 => a,             // a^1 = a
75                (Expr::Const(x), _) if *x == 0.0 => Expr::zero(), // 0^b = 0
76                (Expr::Const(x), _) if *x == 1.0 => Expr::one(),  // 1^b = 1
77                (Expr::Const(x), Expr::Const(y)) => Expr::Const(x.powf(*y)),
78                _ => Expr::Pow(Box::new(a), Box::new(b)),
79            }
80        }
81        Expr::Sin(a) => {
82            let a = simplify_once(a);
83            if let Expr::Const(v) = a { Expr::Const(v.sin()) }
84            else { Expr::Sin(Box::new(a)) }
85        }
86        Expr::Cos(a) => {
87            let a = simplify_once(a);
88            if let Expr::Const(v) = a { Expr::Const(v.cos()) }
89            else { Expr::Cos(Box::new(a)) }
90        }
91        Expr::Ln(a) => {
92            let a = simplify_once(a);
93            match a {
94                Expr::Const(v) if (v - 1.0).abs() < 1e-15 => Expr::zero(), // ln(1) = 0
95                Expr::Exp(inner) => *inner, // ln(e^x) = x
96                Expr::Const(v) => Expr::Const(v.ln()),
97                _ => Expr::Ln(Box::new(a)),
98            }
99        }
100        Expr::Exp(a) => {
101            let a = simplify_once(a);
102            match a {
103                Expr::Const(v) if v == 0.0 => Expr::one(), // e^0 = 1
104                Expr::Ln(inner) => *inner, // e^(ln(x)) = x
105                Expr::Const(v) => Expr::Const(v.exp()),
106                _ => Expr::Exp(Box::new(a)),
107            }
108        }
109        Expr::Sqrt(a) => {
110            let a = simplify_once(a);
111            if let Expr::Const(v) = a { Expr::Const(v.sqrt()) }
112            else { Expr::Sqrt(Box::new(a)) }
113        }
114        // Pass through unhandled cases
115        _ => expr.clone(),
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `e` is exactly the variable named `name`.
    fn is_var(e: &Expr, name: &str) -> bool {
        matches!(e, Expr::Var(n) if n == name)
    }

    /// Extracts the literal value of a constant node, or `None` for any other node.
    fn const_value(e: &Expr) -> Option<f64> {
        if let Expr::Const(v) = e {
            Some(*v)
        } else {
            None
        }
    }

    /// Returns `true` when `e` is a constant lying within `1e-10` of `expected`.
    fn is_const_near(e: &Expr, expected: f64) -> bool {
        matches!(const_value(e), Some(v) if (v - expected).abs() < 1e-10)
    }

    #[test]
    fn simplify_zero_add() {
        let input = Expr::var("x").add(Expr::zero());
        assert!(is_var(&simplify(&input), "x"));
    }

    #[test]
    fn simplify_multiply_by_one() {
        let input = Expr::var("x").mul(Expr::one());
        assert!(is_var(&simplify(&input), "x"));
    }

    #[test]
    fn simplify_multiply_by_zero() {
        let input = Expr::var("x").mul(Expr::zero());
        assert_eq!(const_value(&simplify(&input)), Some(0.0));
    }

    #[test]
    fn simplify_constant_folding() {
        let input = Expr::c(3.0).add(Expr::c(4.0));
        assert!(is_const_near(&simplify(&input), 7.0));
    }

    #[test]
    fn simplify_x_minus_x() {
        let input = Expr::var("x").sub(Expr::var("x"));
        assert_eq!(const_value(&simplify(&input)), Some(0.0));
    }

    #[test]
    fn simplify_x_div_x() {
        let input = Expr::var("x").div(Expr::var("x"));
        assert!(is_const_near(&simplify(&input), 1.0));
    }

    #[test]
    fn simplify_power_zero() {
        let input = Expr::var("x").pow(Expr::zero());
        assert!(is_const_near(&simplify(&input), 1.0));
    }

    #[test]
    fn simplify_ln_exp() {
        let input = Expr::Ln(Box::new(Expr::Exp(Box::new(Expr::var("x")))));
        assert!(is_var(&simplify(&input), "x"));
    }

    #[test]
    fn simplify_double_negation() {
        let input = Expr::var("x").neg().neg();
        assert!(is_var(&simplify(&input), "x"));
    }
}