// proof_engine/symbolic/simplify.rs
use super::expr::Expr;

5pub fn simplify(expr: &Expr) -> Expr {
7 let result = simplify_once(expr);
8 let result2 = simplify_once(&result);
10 if format!("{result}") == format!("{result2}") { result } else { simplify(&result2) }
11}
12
13fn simplify_once(expr: &Expr) -> Expr {
14 match expr {
15 Expr::Neg(a) => {
17 let a = simplify_once(a);
18 match a {
19 Expr::Const(v) => Expr::Const(-v),
20 Expr::Neg(inner) => *inner, _ => Expr::Neg(Box::new(a)),
22 }
23 }
24 Expr::Add(a, b) => {
25 let a = simplify_once(a);
26 let b = simplify_once(b);
27 match (&a, &b) {
28 (Expr::Const(x), Expr::Const(y)) => Expr::Const(x + y),
29 (Expr::Const(x), _) if *x == 0.0 => b, (_, Expr::Const(y)) if *y == 0.0 => a, _ => Expr::Add(Box::new(a), Box::new(b)),
32 }
33 }
34 Expr::Sub(a, b) => {
35 let a = simplify_once(a);
36 let b = simplify_once(b);
37 match (&a, &b) {
38 (Expr::Const(x), Expr::Const(y)) => Expr::Const(x - y),
39 (_, Expr::Const(y)) if *y == 0.0 => a,
40 _ if format!("{a}") == format!("{b}") => Expr::zero(), _ => Expr::Sub(Box::new(a), Box::new(b)),
42 }
43 }
44 Expr::Mul(a, b) => {
45 let a = simplify_once(a);
46 let b = simplify_once(b);
47 match (&a, &b) {
48 (Expr::Const(x), Expr::Const(y)) => Expr::Const(x * y),
49 (Expr::Const(x), _) if *x == 0.0 => Expr::zero(),
50 (_, Expr::Const(y)) if *y == 0.0 => Expr::zero(),
51 (Expr::Const(x), _) if *x == 1.0 => b,
52 (_, Expr::Const(y)) if *y == 1.0 => a,
53 (Expr::Const(x), _) if *x == -1.0 => Expr::Neg(Box::new(b)),
54 (_, Expr::Const(y)) if *y == -1.0 => Expr::Neg(Box::new(a)),
55 _ => Expr::Mul(Box::new(a), Box::new(b)),
56 }
57 }
58 Expr::Div(a, b) => {
59 let a = simplify_once(a);
60 let b = simplify_once(b);
61 match (&a, &b) {
62 (Expr::Const(x), Expr::Const(y)) if *y != 0.0 => Expr::Const(x / y),
63 (Expr::Const(x), _) if *x == 0.0 => Expr::zero(),
64 (_, Expr::Const(y)) if *y == 1.0 => a,
65 _ if format!("{a}") == format!("{b}") => Expr::one(), _ => Expr::Div(Box::new(a), Box::new(b)),
67 }
68 }
69 Expr::Pow(a, b) => {
70 let a = simplify_once(a);
71 let b = simplify_once(b);
72 match (&a, &b) {
73 (_, Expr::Const(y)) if *y == 0.0 => Expr::one(), (_, Expr::Const(y)) if *y == 1.0 => a, (Expr::Const(x), _) if *x == 0.0 => Expr::zero(), (Expr::Const(x), _) if *x == 1.0 => Expr::one(), (Expr::Const(x), Expr::Const(y)) => Expr::Const(x.powf(*y)),
78 _ => Expr::Pow(Box::new(a), Box::new(b)),
79 }
80 }
81 Expr::Sin(a) => {
82 let a = simplify_once(a);
83 if let Expr::Const(v) = a { Expr::Const(v.sin()) }
84 else { Expr::Sin(Box::new(a)) }
85 }
86 Expr::Cos(a) => {
87 let a = simplify_once(a);
88 if let Expr::Const(v) = a { Expr::Const(v.cos()) }
89 else { Expr::Cos(Box::new(a)) }
90 }
91 Expr::Ln(a) => {
92 let a = simplify_once(a);
93 match a {
94 Expr::Const(v) if (v - 1.0).abs() < 1e-15 => Expr::zero(), Expr::Exp(inner) => *inner, Expr::Const(v) => Expr::Const(v.ln()),
97 _ => Expr::Ln(Box::new(a)),
98 }
99 }
100 Expr::Exp(a) => {
101 let a = simplify_once(a);
102 match a {
103 Expr::Const(v) if v == 0.0 => Expr::one(), Expr::Ln(inner) => *inner, Expr::Const(v) => Expr::Const(v.exp()),
106 _ => Expr::Exp(Box::new(a)),
107 }
108 }
109 Expr::Sqrt(a) => {
110 let a = simplify_once(a);
111 if let Expr::Const(v) = a { Expr::Const(v.sqrt()) }
112 else { Expr::Sqrt(Box::new(a)) }
113 }
114 _ => expr.clone(),
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `e` is the variable named `x`.
    fn is_var_x(e: &Expr) -> bool {
        matches!(e, Expr::Var(name) if name == "x")
    }

    /// Returns `true` when `e` is a constant that compares equal to zero.
    fn is_zero_const(e: &Expr) -> bool {
        matches!(e, Expr::Const(v) if *v == 0.0)
    }

    #[test]
    fn simplify_zero_add() {
        // x + 0 => x
        let input = Expr::var("x").add(Expr::zero());
        assert!(is_var_x(&simplify(&input)));
    }

    #[test]
    fn simplify_multiply_by_one() {
        // x * 1 => x
        let input = Expr::var("x").mul(Expr::one());
        assert!(is_var_x(&simplify(&input)));
    }

    #[test]
    fn simplify_multiply_by_zero() {
        // x * 0 => 0
        let input = Expr::var("x").mul(Expr::zero());
        assert!(is_zero_const(&simplify(&input)));
    }

    #[test]
    fn simplify_constant_folding() {
        // 3 + 4 => 7
        let input = Expr::c(3.0).add(Expr::c(4.0));
        assert!(matches!(simplify(&input), Expr::Const(v) if (v - 7.0).abs() < 1e-10));
    }

    #[test]
    fn simplify_x_minus_x() {
        // x - x => 0
        let input = Expr::var("x").sub(Expr::var("x"));
        assert!(is_zero_const(&simplify(&input)));
    }

    #[test]
    fn simplify_x_div_x() {
        // x / x => 1
        let input = Expr::var("x").div(Expr::var("x"));
        assert!(matches!(simplify(&input), Expr::Const(v) if (v - 1.0).abs() < 1e-10));
    }

    #[test]
    fn simplify_power_zero() {
        // x^0 => 1
        let input = Expr::var("x").pow(Expr::zero());
        assert!(matches!(simplify(&input), Expr::Const(v) if (v - 1.0).abs() < 1e-10));
    }

    #[test]
    fn simplify_ln_exp() {
        // ln(exp(x)) => x
        let inner = Expr::Exp(Box::new(Expr::var("x")));
        let input = Expr::Ln(Box::new(inner));
        assert!(is_var_x(&simplify(&input)));
    }

    #[test]
    fn simplify_double_negation() {
        // --x => x
        let input = Expr::var("x").neg().neg();
        assert!(is_var_x(&simplify(&input)));
    }
}