1use std::collections::HashMap;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9 pub fn simplify(&mut self, expr: ExprId) -> ExprId {
14 let mut memo = HashMap::new();
15 self.simplify_inner(expr, &mut memo)
16 }
17
18 fn simplify_inner(&mut self, expr: ExprId, memo: &mut HashMap<ExprId, ExprId>) -> ExprId {
19 if let Some(&cached) = memo.get(&expr) {
20 return cached;
21 }
22
23 let simplified_children = match self.node(expr) {
25 Node::Var(_) | Node::Lit(_) => expr,
26 Node::Add(a, b) => {
27 let sa = self.simplify_inner(a, memo);
28 let sb = self.simplify_inner(b, memo);
29 self.add(sa, sb)
30 }
31 Node::Mul(a, b) => {
32 let sa = self.simplify_inner(a, memo);
33 let sb = self.simplify_inner(b, memo);
34 self.mul(sa, sb)
35 }
36 Node::Neg(a) => {
37 let sa = self.simplify_inner(a, memo);
38 self.neg(sa)
39 }
40 Node::Recip(a) => {
41 let sa = self.simplify_inner(a, memo);
42 self.recip(sa)
43 }
44 Node::Sqrt(a) => {
45 let sa = self.simplify_inner(a, memo);
46 self.sqrt(sa)
47 }
48 Node::Sin(a) => {
49 let sa = self.simplify_inner(a, memo);
50 self.sin(sa)
51 }
52 Node::Atan2(y, x) => {
53 let sy = self.simplify_inner(y, memo);
54 let sx = self.simplify_inner(x, memo);
55 self.atan2(sy, sx)
56 }
57 Node::Exp2(a) => {
58 let sa = self.simplify_inner(a, memo);
59 self.exp2(sa)
60 }
61 Node::Log2(a) => {
62 let sa = self.simplify_inner(a, memo);
63 self.log2(sa)
64 }
65 Node::Select(c, a, b) => {
66 let sc = self.simplify_inner(c, memo);
67 let sa = self.simplify_inner(a, memo);
68 let sb = self.simplify_inner(b, memo);
69 self.select(sc, sa, sb)
70 }
71 };
72
73 let result = self.rewrite(simplified_children);
75
76 let final_result = if result != simplified_children {
78 self.simplify_inner(result, memo)
79 } else {
80 result
81 };
82
83 memo.insert(expr, final_result);
84 final_result
85 }
86
87 fn rewrite(&mut self, expr: ExprId) -> ExprId {
89 match self.node(expr) {
90 Node::Add(a, b) if b == ExprId::ZERO => a,
94 Node::Add(a, b) if a == ExprId::ZERO => b,
96
97 Node::Mul(a, b) if b == ExprId::ONE => a,
99 Node::Mul(a, b) if a == ExprId::ONE => b,
101 Node::Mul(_, b) if b == ExprId::ZERO => ExprId::ZERO,
103 Node::Mul(a, _) if a == ExprId::ZERO => ExprId::ZERO,
105
106 Node::Neg(a) => match self.node(a) {
108 Node::Neg(inner) => inner,
109 _ if a == ExprId::ZERO => ExprId::ZERO,
111 Node::Lit(bits) => {
113 let v = f64::from_bits(bits);
114 self.lit(-v)
115 }
116 _ => expr,
117 },
118
119 Node::Recip(a) => match self.node(a) {
121 Node::Recip(inner) => inner,
122 Node::Lit(bits) => {
124 let v = f64::from_bits(bits);
125 self.lit(1.0 / v)
126 }
127 _ => expr,
128 },
129
130 Node::Add(a, b) => {
134 if let Node::Neg(inner) = self.node(b) {
135 if inner == a {
136 return ExprId::ZERO;
137 }
138 }
139 if let Node::Neg(inner) = self.node(a) {
140 if inner == b {
141 return ExprId::ZERO;
142 }
143 }
144 if let (Some(va), Some(vb)) = (self.node(a).as_f64(), self.node(b).as_f64()) {
146 return self.lit(va + vb);
147 }
148 expr
149 }
150
151 Node::Mul(a, b) => {
152 if let Node::Recip(inner) = self.node(b) {
154 if inner == a {
155 return ExprId::ONE;
156 }
157 }
158 if let Node::Recip(inner) = self.node(a) {
159 if inner == b {
160 return ExprId::ONE;
161 }
162 }
163 if let (Some(va), Some(vb)) = (self.node(a).as_f64(), self.node(b).as_f64()) {
165 return self.lit(va * vb);
166 }
167 expr
168 }
169
170 Node::Sqrt(a) => {
172 if let Some(v) = self.node(a).as_f64() {
173 self.lit(v.sqrt())
174 } else {
175 expr
176 }
177 }
178 Node::Sin(a) => {
179 if let Some(v) = self.node(a).as_f64() {
180 self.lit(v.sin())
181 } else {
182 expr
183 }
184 }
185 Node::Exp2(a) => {
186 if let Some(v) = self.node(a).as_f64() {
187 self.lit(v.exp2())
188 } else {
189 expr
190 }
191 }
192 Node::Log2(a) => {
193 if let Some(v) = self.node(a).as_f64() {
194 self.lit(v.log2())
195 } else {
196 expr
197 }
198 }
199
200 Node::Select(c, a, b) => {
202 if let Some(vc) = self.node(c).as_f64() {
203 if vc > 0.0 { a } else { b }
204 } else {
205 expr
206 }
207 }
208
209 _ => expr,
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;
    use crate::node::{ExprId, Node};

    #[test]
    fn simplify_add_zero() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sum = g.add(x, ExprId::ZERO);
        let s = g.simplify(sum);
        assert_eq!(s, x);

        let sum2 = g.add(ExprId::ZERO, x);
        let s2 = g.simplify(sum2);
        assert_eq!(s2, x);
    }

    #[test]
    fn simplify_mul_one() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let prod = g.mul(x, ExprId::ONE);
        let s = g.simplify(prod);
        assert_eq!(s, x);
    }

    #[test]
    fn simplify_mul_zero() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let prod = g.mul(x, ExprId::ZERO);
        let s = g.simplify(prod);
        assert_eq!(s, ExprId::ZERO);
    }

    #[test]
    fn simplify_neg_neg() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let nn = g.neg(x);
        let nnn = g.neg(nn);
        let s = g.simplify(nnn);
        assert_eq!(s, x);
    }

    #[test]
    fn simplify_recip_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let r = g.recip(x);
        let rr = g.recip(r);
        let s = g.simplify(rr);
        assert_eq!(s, x);
    }

    #[test]
    fn simplify_cancel_add_neg() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let nx = g.neg(x);
        let sum = g.add(x, nx);
        let s = g.simplify(sum);
        assert_eq!(s, ExprId::ZERO);
    }

    #[test]
    fn simplify_cancel_mul_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let rx = g.recip(x);
        let prod = g.mul(x, rx);
        let s = g.simplify(prod);
        assert_eq!(s, ExprId::ONE);
    }

    #[test]
    fn simplify_constant_fold_add() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let sum = g.add(a, b);
        let s = g.simplify(sum);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_constant_fold_mul() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let prod = g.mul(a, b);
        let s = g.simplify(prod);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 12.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_neg_zero() {
        let mut g = ExprGraph::new();
        let nz = g.neg(ExprId::ZERO);
        let s = g.simplify(nz);
        let result: f64 = g.eval(s, &[]);
        // Under IEEE comparison `-0.0 == 0.0`, so this accepts either sign of zero.
        // It only checks that the simplified value is zero.
        assert!(result == 0.0);
    }

    #[test]
    fn simplify_derivative() {
        let mut g = ExprGraph::new();
        // d/dx (x * x) evaluated at x = 5 must stay 2 * 5 after simplification.
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d = g.diff(xx, 0);

        let s = g.simplify(d);
        let result: f64 = g.eval(s, &[5.0]);
        assert!((result - 10.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_select_constant_condition() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);

        let positive = g.lit(2.0);
        let pick_x = g.select(positive, x, y);
        assert_eq!(g.simplify(pick_x), x);

        // Only a strictly positive condition selects the first branch.
        let pick_y = g.select(ExprId::ZERO, x, y);
        assert_eq!(g.simplify(pick_y), y);
    }

    #[test]
    fn simplify_folds_unary_constants() {
        let mut g = ExprGraph::new();

        let nine = g.lit(9.0);
        let root = g.sqrt(nine);
        let s = g.simplify(root);
        assert!(matches!(g.node(s), Node::Lit(_)));
        let value: f64 = g.eval(s, &[]);
        assert!((value - 3.0).abs() < 1e-12);

        let four = g.lit(4.0);
        let inv = g.recip(four);
        let s = g.simplify(inv);
        assert!(matches!(g.node(s), Node::Lit(_)));
        let value: f64 = g.eval(s, &[]);
        assert!((value - 0.25).abs() < 1e-12);

        let c = g.lit(2.5);
        let neg = g.neg(c);
        let s = g.simplify(neg);
        assert!(matches!(g.node(s), Node::Lit(_)));
        let value: f64 = g.eval(s, &[]);
        assert!((value + 2.5).abs() < 1e-12);
    }
}