Skip to main content

oxilean_codegen/opt_algebraic/
functions.rs

1//! Functions for the Algebraic Simplification optimisation pass.
2
3use std::collections::HashMap;
4
5use super::types::{AlgExpr, AlgSimplConfig, SimplResult, SimplStats};
6
7// ── Expression utilities ──────────────────────────────────────────────────────
8
9/// Count the number of nodes in an expression tree.
10pub fn expr_size(expr: &AlgExpr) -> usize {
11    match expr {
12        AlgExpr::Const(_) | AlgExpr::Var(_) => 1,
13        AlgExpr::Neg(e) => 1 + expr_size(e),
14        AlgExpr::Add(l, r)
15        | AlgExpr::Sub(l, r)
16        | AlgExpr::Mul(l, r)
17        | AlgExpr::Div(l, r)
18        | AlgExpr::Pow(l, r)
19        | AlgExpr::Mod(l, r) => 1 + expr_size(l) + expr_size(r),
20    }
21}
22
23/// Count the occurrences of each variable name in `expr`.
24pub fn count_vars(expr: &AlgExpr) -> HashMap<String, usize> {
25    let mut map: HashMap<String, usize> = HashMap::new();
26    count_vars_impl(expr, &mut map);
27    map
28}
29
30fn count_vars_impl(expr: &AlgExpr, map: &mut HashMap<String, usize>) {
31    match expr {
32        AlgExpr::Const(_) => {}
33        AlgExpr::Var(v) => *map.entry(v.clone()).or_insert(0) += 1,
34        AlgExpr::Neg(e) => count_vars_impl(e, map),
35        AlgExpr::Add(l, r)
36        | AlgExpr::Sub(l, r)
37        | AlgExpr::Mul(l, r)
38        | AlgExpr::Div(l, r)
39        | AlgExpr::Pow(l, r)
40        | AlgExpr::Mod(l, r) => {
41            count_vars_impl(l, map);
42            count_vars_impl(r, map);
43        }
44    }
45}
46
47/// Substitute variables in `expr` using `subs`.  Variables not present in
48/// `subs` are left unchanged.
49pub fn substitute_vars(expr: &AlgExpr, subs: &HashMap<String, AlgExpr>) -> AlgExpr {
50    match expr {
51        AlgExpr::Const(n) => AlgExpr::Const(*n),
52        AlgExpr::Var(v) => subs
53            .get(v)
54            .cloned()
55            .unwrap_or_else(|| AlgExpr::Var(v.clone())),
56        AlgExpr::Neg(e) => AlgExpr::Neg(Box::new(substitute_vars(e, subs))),
57        AlgExpr::Add(l, r) => AlgExpr::Add(
58            Box::new(substitute_vars(l, subs)),
59            Box::new(substitute_vars(r, subs)),
60        ),
61        AlgExpr::Sub(l, r) => AlgExpr::Sub(
62            Box::new(substitute_vars(l, subs)),
63            Box::new(substitute_vars(r, subs)),
64        ),
65        AlgExpr::Mul(l, r) => AlgExpr::Mul(
66            Box::new(substitute_vars(l, subs)),
67            Box::new(substitute_vars(r, subs)),
68        ),
69        AlgExpr::Div(l, r) => AlgExpr::Div(
70            Box::new(substitute_vars(l, subs)),
71            Box::new(substitute_vars(r, subs)),
72        ),
73        AlgExpr::Pow(l, r) => AlgExpr::Pow(
74            Box::new(substitute_vars(l, subs)),
75            Box::new(substitute_vars(r, subs)),
76        ),
77        AlgExpr::Mod(l, r) => AlgExpr::Mod(
78            Box::new(substitute_vars(l, subs)),
79            Box::new(substitute_vars(r, subs)),
80        ),
81    }
82}
83
84/// Format an expression as a human-readable string.
85pub fn alg_expr_to_string(expr: &AlgExpr) -> String {
86    match expr {
87        AlgExpr::Const(n) => n.to_string(),
88        AlgExpr::Var(v) => v.clone(),
89        AlgExpr::Neg(e) => format!("(-{})", alg_expr_to_string(e)),
90        AlgExpr::Add(l, r) => format!("({} + {})", alg_expr_to_string(l), alg_expr_to_string(r)),
91        AlgExpr::Sub(l, r) => format!("({} - {})", alg_expr_to_string(l), alg_expr_to_string(r)),
92        AlgExpr::Mul(l, r) => format!("({} * {})", alg_expr_to_string(l), alg_expr_to_string(r)),
93        AlgExpr::Div(l, r) => format!("({} / {})", alg_expr_to_string(l), alg_expr_to_string(r)),
94        AlgExpr::Pow(l, r) => format!("({} ^ {})", alg_expr_to_string(l), alg_expr_to_string(r)),
95        AlgExpr::Mod(l, r) => format!("({} % {})", alg_expr_to_string(l), alg_expr_to_string(r)),
96    }
97}
98
99// ── Core simplification steps ─────────────────────────────────────────────────
100
101/// Try to evaluate `expr` to a constant if both children are constants.
102///
103/// Returns `Some(simplified)` if folding was possible, `None` otherwise.
104pub fn fold_constants(expr: &AlgExpr) -> Option<AlgExpr> {
105    match expr {
106        AlgExpr::Add(l, r) => {
107            if let (AlgExpr::Const(a), AlgExpr::Const(b)) = (l.as_ref(), r.as_ref()) {
108                return a.checked_add(*b).map(AlgExpr::Const);
109            }
110        }
111        AlgExpr::Sub(l, r) => {
112            if let (AlgExpr::Const(a), AlgExpr::Const(b)) = (l.as_ref(), r.as_ref()) {
113                return a.checked_sub(*b).map(AlgExpr::Const);
114            }
115        }
116        AlgExpr::Mul(l, r) => {
117            if let (AlgExpr::Const(a), AlgExpr::Const(b)) = (l.as_ref(), r.as_ref()) {
118                return a.checked_mul(*b).map(AlgExpr::Const);
119            }
120        }
121        AlgExpr::Div(l, r) => {
122            if let (AlgExpr::Const(a), AlgExpr::Const(b)) = (l.as_ref(), r.as_ref()) {
123                if *b != 0 {
124                    return a.checked_div(*b).map(AlgExpr::Const);
125                }
126            }
127        }
128        AlgExpr::Mod(l, r) => {
129            if let (AlgExpr::Const(a), AlgExpr::Const(b)) = (l.as_ref(), r.as_ref()) {
130                if *b != 0 {
131                    return a.checked_rem(*b).map(AlgExpr::Const);
132                }
133            }
134        }
135        AlgExpr::Neg(e) => {
136            if let AlgExpr::Const(n) = e.as_ref() {
137                return n.checked_neg().map(AlgExpr::Const);
138            }
139        }
140        AlgExpr::Pow(base, exp) => {
141            if let (AlgExpr::Const(b), AlgExpr::Const(e)) = (base.as_ref(), exp.as_ref()) {
142                if *e >= 0 {
143                    let e_u32 = *e as u32;
144                    return b.checked_pow(e_u32).map(AlgExpr::Const);
145                }
146            }
147        }
148        _ => {}
149    }
150    None
151}
152
153/// Try to apply a single algebraic identity to the top-level node of `expr`.
154///
155/// Returns `Some((simplified_expr, rule_name))` when a rule fires.
156///
157/// Identities applied:
158/// - `x + 0 = x`, `0 + x = x`
159/// - `x * 1 = x`, `1 * x = x`
160/// - `x * 0 = 0`, `0 * x = 0`
161/// - `x - 0 = x`
162/// - `x + x = 2 * x`
163/// - `x - x = 0`
164/// - `x / x = 1` (when x is a non-zero constant or a variable)
165/// - `0 - x = -x`
166/// - `-(-x) = x`
167/// - `x ^ 0 = 1`
168/// - `x ^ 1 = x`
169/// - `0 ^ x = 0` (x must be a positive constant)
170pub fn apply_identity(expr: &AlgExpr) -> Option<(AlgExpr, String)> {
171    match expr {
172        // x + 0 = x
173        AlgExpr::Add(_, r) if *r.as_ref() == AlgExpr::Const(0) => {
174            Some((expr_left(expr).clone(), "add_zero_right".to_string()))
175        }
176        // 0 + x = x
177        AlgExpr::Add(l, _) if *l.as_ref() == AlgExpr::Const(0) => {
178            Some((expr_right(expr).clone(), "add_zero_left".to_string()))
179        }
180        // x - 0 = x
181        AlgExpr::Sub(_, r) if *r.as_ref() == AlgExpr::Const(0) => {
182            Some((expr_left(expr).clone(), "sub_zero".to_string()))
183        }
184        // 0 - x = -x
185        AlgExpr::Sub(l, _) if *l.as_ref() == AlgExpr::Const(0) => Some((
186            AlgExpr::Neg(Box::new(expr_right(expr).clone())),
187            "neg_zero".to_string(),
188        )),
189        // x * 1 = x
190        AlgExpr::Mul(_, r) if *r.as_ref() == AlgExpr::Const(1) => {
191            Some((expr_left(expr).clone(), "mul_one_right".to_string()))
192        }
193        // 1 * x = x
194        AlgExpr::Mul(l, _) if *l.as_ref() == AlgExpr::Const(1) => {
195            Some((expr_right(expr).clone(), "mul_one_left".to_string()))
196        }
197        // x * 0 = 0
198        AlgExpr::Mul(_, r) if *r.as_ref() == AlgExpr::Const(0) => {
199            Some((AlgExpr::Const(0), "mul_zero_right".to_string()))
200        }
201        // 0 * x = 0
202        AlgExpr::Mul(l, _) if *l.as_ref() == AlgExpr::Const(0) => {
203            Some((AlgExpr::Const(0), "mul_zero_left".to_string()))
204        }
205        // x + x = 2 * x
206        AlgExpr::Add(l, r) if l == r => Some((
207            AlgExpr::Mul(Box::new(AlgExpr::Const(2)), Box::new(l.as_ref().clone())),
208            "add_self".to_string(),
209        )),
210        // x - x = 0
211        AlgExpr::Sub(l, r) if l == r => Some((AlgExpr::Const(0), "sub_self".to_string())),
212        // x / x = 1 (when x ≠ 0; we check constants and variables)
213        AlgExpr::Div(l, r) if l == r => match l.as_ref() {
214            AlgExpr::Const(n) if *n != 0 => Some((AlgExpr::Const(1), "div_self".to_string())),
215            AlgExpr::Var(_) => Some((AlgExpr::Const(1), "div_self".to_string())),
216            _ => None,
217        },
218        // -(-x) = x
219        AlgExpr::Neg(inner) => match inner.as_ref() {
220            AlgExpr::Neg(inner2) => Some((inner2.as_ref().clone(), "double_neg".to_string())),
221            _ => None,
222        },
223        // x ^ 0 = 1
224        AlgExpr::Pow(_, r) if *r.as_ref() == AlgExpr::Const(0) => {
225            Some((AlgExpr::Const(1), "pow_zero".to_string()))
226        }
227        // x ^ 1 = x
228        AlgExpr::Pow(_, r) if *r.as_ref() == AlgExpr::Const(1) => {
229            Some((expr_left(expr).clone(), "pow_one".to_string()))
230        }
231        // 0 ^ x = 0 (x > 0)
232        AlgExpr::Pow(l, r) if *l.as_ref() == AlgExpr::Const(0) => match r.as_ref() {
233            AlgExpr::Const(e) if *e > 0 => Some((AlgExpr::Const(0), "zero_pow".to_string())),
234            _ => None,
235        },
236        _ => None,
237    }
238}
239
240/// Helper: extract left child of a binary expression.
241fn expr_left(expr: &AlgExpr) -> &AlgExpr {
242    match expr {
243        AlgExpr::Add(l, _)
244        | AlgExpr::Sub(l, _)
245        | AlgExpr::Mul(l, _)
246        | AlgExpr::Div(l, _)
247        | AlgExpr::Pow(l, _)
248        | AlgExpr::Mod(l, _) => l,
249        _ => expr,
250    }
251}
252
253/// Helper: extract right child of a binary expression.
254fn expr_right(expr: &AlgExpr) -> &AlgExpr {
255    match expr {
256        AlgExpr::Add(_, r)
257        | AlgExpr::Sub(_, r)
258        | AlgExpr::Mul(_, r)
259        | AlgExpr::Div(_, r)
260        | AlgExpr::Pow(_, r)
261        | AlgExpr::Mod(_, r) => r,
262        _ => expr,
263    }
264}
265
266/// Produce a canonical form for `expr`:
267/// - For commutative `Add` and `Mul`, sort the operand string representations
268///   so that `a + b` and `b + a` both become the same canonical form.
269/// - Flatten nested `Add(Add(a,b),c)` → `Add(a, Add(b,c))`.
270/// - Flatten nested `Mul(Mul(a,b),c)` → `Mul(a, Mul(b,c))`.
271pub fn normalize(expr: &AlgExpr) -> AlgExpr {
272    match expr {
273        AlgExpr::Const(_) | AlgExpr::Var(_) => expr.clone(),
274        AlgExpr::Neg(e) => AlgExpr::Neg(Box::new(normalize(e))),
275        AlgExpr::Add(l, r) => {
276            let nl = normalize(l);
277            let nr = normalize(r);
278            // Flatten left-nested add: (a+b)+c => a+(b+c)
279            if let AlgExpr::Add(ll, lr) = nl.clone() {
280                return normalize(&AlgExpr::Add(ll, Box::new(AlgExpr::Add(lr, Box::new(nr)))));
281            }
282            // Sort for commutativity
283            let ls = alg_expr_to_string(&nl);
284            let rs = alg_expr_to_string(&nr);
285            if ls <= rs {
286                AlgExpr::Add(Box::new(nl), Box::new(nr))
287            } else {
288                AlgExpr::Add(Box::new(nr), Box::new(nl))
289            }
290        }
291        AlgExpr::Mul(l, r) => {
292            let nl = normalize(l);
293            let nr = normalize(r);
294            // Flatten left-nested mul: (a*b)*c => a*(b*c)
295            if let AlgExpr::Mul(ll, lr) = nl.clone() {
296                return normalize(&AlgExpr::Mul(ll, Box::new(AlgExpr::Mul(lr, Box::new(nr)))));
297            }
298            // Sort for commutativity
299            let ls = alg_expr_to_string(&nl);
300            let rs = alg_expr_to_string(&nr);
301            if ls <= rs {
302                AlgExpr::Mul(Box::new(nl), Box::new(nr))
303            } else {
304                AlgExpr::Mul(Box::new(nr), Box::new(nl))
305            }
306        }
307        AlgExpr::Sub(l, r) => AlgExpr::Sub(Box::new(normalize(l)), Box::new(normalize(r))),
308        AlgExpr::Div(l, r) => AlgExpr::Div(Box::new(normalize(l)), Box::new(normalize(r))),
309        AlgExpr::Pow(l, r) => AlgExpr::Pow(Box::new(normalize(l)), Box::new(normalize(r))),
310        AlgExpr::Mod(l, r) => AlgExpr::Mod(Box::new(normalize(l)), Box::new(normalize(r))),
311    }
312}
313
314/// Apply one simplification step (fold_constants or apply_identity) to the
315/// top-level node of `expr`.  Recurse into children first.
316///
317/// Returns `(simplified, changed, rule_name)`.
318fn simplify_step(expr: AlgExpr, fold: bool) -> (AlgExpr, bool, Option<String>) {
319    // Recurse into children first.
320    let expr = match expr {
321        AlgExpr::Neg(e) => {
322            let (se, _, _) = simplify_step(*e, fold);
323            AlgExpr::Neg(Box::new(se))
324        }
325        AlgExpr::Add(l, r) => {
326            let (sl, _, _) = simplify_step(*l, fold);
327            let (sr, _, _) = simplify_step(*r, fold);
328            AlgExpr::Add(Box::new(sl), Box::new(sr))
329        }
330        AlgExpr::Sub(l, r) => {
331            let (sl, _, _) = simplify_step(*l, fold);
332            let (sr, _, _) = simplify_step(*r, fold);
333            AlgExpr::Sub(Box::new(sl), Box::new(sr))
334        }
335        AlgExpr::Mul(l, r) => {
336            let (sl, _, _) = simplify_step(*l, fold);
337            let (sr, _, _) = simplify_step(*r, fold);
338            AlgExpr::Mul(Box::new(sl), Box::new(sr))
339        }
340        AlgExpr::Div(l, r) => {
341            let (sl, _, _) = simplify_step(*l, fold);
342            let (sr, _, _) = simplify_step(*r, fold);
343            AlgExpr::Div(Box::new(sl), Box::new(sr))
344        }
345        AlgExpr::Pow(l, r) => {
346            let (sl, _, _) = simplify_step(*l, fold);
347            let (sr, _, _) = simplify_step(*r, fold);
348            AlgExpr::Pow(Box::new(sl), Box::new(sr))
349        }
350        AlgExpr::Mod(l, r) => {
351            let (sl, _, _) = simplify_step(*l, fold);
352            let (sr, _, _) = simplify_step(*r, fold);
353            AlgExpr::Mod(Box::new(sl), Box::new(sr))
354        }
355        other => other,
356    };
357
358    // Try constant folding.
359    if fold {
360        if let Some(folded) = fold_constants(&expr) {
361            return (folded, true, Some("fold_constants".to_string()));
362        }
363    }
364
365    // Try identity rules.
366    if let Some((simplified, rule)) = apply_identity(&expr) {
367        return (simplified, true, Some(rule));
368    }
369
370    (expr, false, None)
371}
372
373/// Simplify `expr` according to `cfg`, returning a `SimplResult` with the
374/// simplified expression, a trace of applied rules, and statistics.
375pub fn simplify(expr: AlgExpr, cfg: &AlgSimplConfig) -> SimplResult {
376    let size_before = expr_size(&expr);
377    let mut current = expr;
378    let mut steps: Vec<String> = Vec::new();
379    let mut passes = 0usize;
380    for _pass in 0..cfg.max_passes {
381        passes += 1;
382        let (next, changed, rule) = simplify_step(current.clone(), cfg.fold_constants);
383        if changed {
384            if let Some(r) = rule {
385                steps.push(format!(
386                    "pass {}: {} => {}",
387                    passes,
388                    r,
389                    alg_expr_to_string(&next)
390                ));
391            }
392            current = normalize(&next);
393        } else {
394            break;
395        }
396    }
397
398    let size_after = expr_size(&current);
399    let reduced = size_after < size_before || !steps.is_empty();
400    SimplResult {
401        expr: current,
402        steps,
403        reduced,
404    }
405}
406
407/// Return a `SimplStats` for simplifying `expr` with `cfg`.
408pub fn simplify_with_stats(expr: AlgExpr, cfg: &AlgSimplConfig) -> (SimplResult, SimplStats) {
409    let size_before = expr_size(&expr);
410    let mut current = expr;
411    let mut steps: Vec<String> = Vec::new();
412    let mut passes_completed = 0usize;
413    let mut rules_applied = 0usize;
414
415    for _pass in 0..cfg.max_passes {
416        passes_completed += 1;
417        let (next, changed, rule) = simplify_step(current.clone(), cfg.fold_constants);
418        if changed {
419            if let Some(r) = rule {
420                steps.push(format!(
421                    "pass {}: {} => {}",
422                    passes_completed,
423                    r,
424                    alg_expr_to_string(&next)
425                ));
426            }
427            rules_applied += 1;
428            current = normalize(&next);
429        } else {
430            break;
431        }
432    }
433
434    let size_after = expr_size(&current);
435    let reduced = size_after < size_before || !steps.is_empty();
436    let result = SimplResult {
437        expr: current,
438        steps,
439        reduced,
440    };
441    let stats = SimplStats {
442        rules_applied,
443        passes_completed,
444        size_before,
445        size_after,
446    };
447    (result, stats)
448}
449
450// ── Tests ─────────────────────────────────────────────────────────────────────
451
#[cfg(test)]
mod tests {
    use super::super::types::{AlgExpr, AlgSimplConfig};
    use super::*;

    fn c(n: i64) -> AlgExpr {
        AlgExpr::Const(n)
    }
    fn v(s: &str) -> AlgExpr {
        AlgExpr::Var(s.to_string())
    }
    fn add(l: AlgExpr, r: AlgExpr) -> AlgExpr {
        AlgExpr::Add(Box::new(l), Box::new(r))
    }
    fn sub(l: AlgExpr, r: AlgExpr) -> AlgExpr {
        AlgExpr::Sub(Box::new(l), Box::new(r))
    }
    fn mul(l: AlgExpr, r: AlgExpr) -> AlgExpr {
        AlgExpr::Mul(Box::new(l), Box::new(r))
    }
    fn div(l: AlgExpr, r: AlgExpr) -> AlgExpr {
        AlgExpr::Div(Box::new(l), Box::new(r))
    }
    fn rem(l: AlgExpr, r: AlgExpr) -> AlgExpr {
        AlgExpr::Mod(Box::new(l), Box::new(r))
    }
    fn neg(e: AlgExpr) -> AlgExpr {
        AlgExpr::Neg(Box::new(e))
    }
    fn pow(b: AlgExpr, e: AlgExpr) -> AlgExpr {
        AlgExpr::Pow(Box::new(b), Box::new(e))
    }
    fn cfg_default() -> AlgSimplConfig {
        AlgSimplConfig::default()
    }

    // ── fold_constants ────────────────────────────────────────────────────────

    #[test]
    fn test_fold_add_constants() {
        assert_eq!(fold_constants(&add(c(3), c(4))), Some(c(7)));
    }

    #[test]
    fn test_fold_sub_constants() {
        assert_eq!(fold_constants(&sub(c(10), c(3))), Some(c(7)));
    }

    #[test]
    fn test_fold_mul_constants() {
        assert_eq!(fold_constants(&mul(c(6), c(7))), Some(c(42)));
    }

    #[test]
    fn test_fold_div_constants() {
        assert_eq!(fold_constants(&div(c(12), c(4))), Some(c(3)));
    }

    #[test]
    fn test_fold_div_by_zero_returns_none() {
        assert_eq!(fold_constants(&div(c(5), c(0))), None);
    }

    #[test]
    fn test_fold_div_overflow_returns_none() {
        assert_eq!(fold_constants(&div(c(i64::MIN), c(-1))), None);
    }

    #[test]
    fn test_fold_mod_constants() {
        assert_eq!(fold_constants(&rem(c(7), c(3))), Some(c(1)));
    }

    #[test]
    fn test_fold_mod_by_zero_returns_none() {
        assert_eq!(fold_constants(&rem(c(7), c(0))), None);
    }

    #[test]
    fn test_fold_add_overflow_returns_none() {
        assert_eq!(fold_constants(&add(c(i64::MAX), c(1))), None);
    }

    #[test]
    fn test_fold_neg_constant() {
        assert_eq!(fold_constants(&neg(c(7))), Some(c(-7)));
    }

    #[test]
    fn test_fold_neg_overflow_returns_none() {
        assert_eq!(fold_constants(&neg(c(i64::MIN))), None);
    }

    #[test]
    fn test_fold_pow_constants() {
        assert_eq!(fold_constants(&pow(c(2), c(10))), Some(c(1024)));
    }

    #[test]
    fn test_fold_pow_negative_exponent_returns_none() {
        assert_eq!(fold_constants(&pow(c(2), c(-1))), None);
    }

    #[test]
    fn test_fold_not_applicable_for_vars() {
        assert_eq!(fold_constants(&add(v("x"), c(1))), None);
    }

    // ── apply_identity ────────────────────────────────────────────────────────

    #[test]
    fn test_identity_add_zero_right() {
        let (e, rule) = apply_identity(&add(v("x"), c(0))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "add_zero_right");
    }

    #[test]
    fn test_identity_add_zero_left() {
        let (e, rule) = apply_identity(&add(c(0), v("x"))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "add_zero_left");
    }

    #[test]
    fn test_identity_mul_one_right() {
        let (e, rule) = apply_identity(&mul(v("x"), c(1))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "mul_one_right");
    }

    #[test]
    fn test_identity_mul_one_left() {
        let (e, rule) = apply_identity(&mul(c(1), v("x"))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "mul_one_left");
    }

    #[test]
    fn test_identity_mul_zero_right() {
        let (e, rule) = apply_identity(&mul(v("x"), c(0))).unwrap();
        assert_eq!(e, c(0));
        assert_eq!(rule, "mul_zero_right");
    }

    #[test]
    fn test_identity_mul_zero_left() {
        let (e, rule) = apply_identity(&mul(c(0), v("x"))).unwrap();
        assert_eq!(e, c(0));
        assert_eq!(rule, "mul_zero_left");
    }

    #[test]
    fn test_identity_sub_zero() {
        let (e, rule) = apply_identity(&sub(v("x"), c(0))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "sub_zero");
    }

    #[test]
    fn test_identity_add_self() {
        let (e, rule) = apply_identity(&add(v("x"), v("x"))).unwrap();
        assert_eq!(e, mul(c(2), v("x")));
        assert_eq!(rule, "add_self");
    }

    #[test]
    fn test_identity_sub_self() {
        let (e, rule) = apply_identity(&sub(v("x"), v("x"))).unwrap();
        assert_eq!(e, c(0));
        assert_eq!(rule, "sub_self");
    }

    #[test]
    fn test_identity_div_self_var() {
        let (e, rule) = apply_identity(&div(v("x"), v("x"))).unwrap();
        assert_eq!(e, c(1));
        assert_eq!(rule, "div_self");
    }

    #[test]
    fn test_identity_div_self_nonzero_const() {
        let (e, rule) = apply_identity(&div(c(5), c(5))).unwrap();
        assert_eq!(e, c(1));
        assert_eq!(rule, "div_self");
    }

    #[test]
    fn test_identity_div_self_zero_const_not_applied() {
        assert!(apply_identity(&div(c(0), c(0))).is_none());
    }

    #[test]
    fn test_identity_div_self_compound_not_applied() {
        let e = add(v("x"), v("y"));
        assert!(apply_identity(&div(e.clone(), e)).is_none());
    }

    #[test]
    fn test_identity_neg_zero() {
        let (e, rule) = apply_identity(&sub(c(0), v("x"))).unwrap();
        assert_eq!(e, neg(v("x")));
        assert_eq!(rule, "neg_zero");
    }

    #[test]
    fn test_identity_double_neg() {
        let (e, rule) = apply_identity(&neg(neg(v("x")))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "double_neg");
    }

    #[test]
    fn test_identity_pow_zero() {
        let (e, rule) = apply_identity(&pow(v("x"), c(0))).unwrap();
        assert_eq!(e, c(1));
        assert_eq!(rule, "pow_zero");
    }

    #[test]
    fn test_identity_pow_one() {
        let (e, rule) = apply_identity(&pow(v("x"), c(1))).unwrap();
        assert_eq!(e, v("x"));
        assert_eq!(rule, "pow_one");
    }

    #[test]
    fn test_identity_zero_pow() {
        let (e, rule) = apply_identity(&pow(c(0), c(3))).unwrap();
        assert_eq!(e, c(0));
        assert_eq!(rule, "zero_pow");
    }

    #[test]
    fn test_identity_zero_pow_negative_exponent_not_applied() {
        assert!(apply_identity(&pow(c(0), c(-2))).is_none());
    }

    #[test]
    fn test_identity_mod_has_no_rules() {
        assert!(apply_identity(&rem(v("x"), c(1))).is_none());
    }

    #[test]
    fn test_identity_no_match() {
        assert!(apply_identity(&add(v("x"), v("y"))).is_none());
    }

    // ── simplify ──────────────────────────────────────────────────────────────

    #[test]
    fn test_simplify_constant_fold_through() {
        let expr = add(c(3), c(4));
        let result = simplify(expr, &cfg_default());
        assert_eq!(result.expr, c(7));
        assert!(result.reduced);
    }

    #[test]
    fn test_simplify_add_zero() {
        let expr = add(v("x"), c(0));
        let result = simplify(expr, &cfg_default());
        assert_eq!(result.expr, v("x"));
    }

    #[test]
    fn test_simplify_mul_one() {
        let expr = mul(v("x"), c(1));
        let result = simplify(expr, &cfg_default());
        assert_eq!(result.expr, v("x"));
    }

    #[test]
    fn test_simplify_sub_self() {
        let result = simplify(sub(v("x"), v("x")), &cfg_default());
        assert_eq!(result.expr, c(0));
    }

    #[test]
    fn test_simplify_no_change() {
        let expr = add(v("x"), v("y"));
        let result = simplify(expr, &cfg_default());
        // normalize may reorder but no identity fires
        assert!(!result.steps.iter().any(|s| s.contains("fold")));
    }

    #[test]
    fn test_simplify_nested() {
        // (x + 0) * 1  =>  x
        let expr = mul(add(v("x"), c(0)), c(1));
        let result = simplify(expr, &cfg_default());
        assert_eq!(result.expr, v("x"));
    }

    #[test]
    fn test_simplify_with_stats_reports_sizes() {
        let (result, stats) = simplify_with_stats(add(c(3), c(4)), &cfg_default());
        assert_eq!(result.expr, c(7));
        assert_eq!(stats.size_before, 3);
        assert_eq!(stats.size_after, 1);
        assert_eq!(stats.rules_applied, 1);
    }

    // ── expr_size ─────────────────────────────────────────────────────────────

    #[test]
    fn test_expr_size_leaf() {
        assert_eq!(expr_size(&c(5)), 1);
        assert_eq!(expr_size(&v("x")), 1);
    }

    #[test]
    fn test_expr_size_unary() {
        assert_eq!(expr_size(&neg(v("x"))), 2);
    }

    #[test]
    fn test_expr_size_add() {
        assert_eq!(expr_size(&add(c(1), c(2))), 3);
    }

    #[test]
    fn test_expr_size_nested() {
        assert_eq!(expr_size(&add(mul(v("a"), v("b")), c(1))), 5);
    }

    // ── count_vars ────────────────────────────────────────────────────────────

    #[test]
    fn test_count_vars_single() {
        let m = count_vars(&v("x"));
        assert_eq!(m.get("x"), Some(&1));
    }

    #[test]
    fn test_count_vars_repeated() {
        let m = count_vars(&add(v("x"), v("x")));
        assert_eq!(m.get("x"), Some(&2));
    }

    #[test]
    fn test_count_vars_multiple() {
        let m = count_vars(&add(v("x"), v("y")));
        assert_eq!(m.get("x"), Some(&1));
        assert_eq!(m.get("y"), Some(&1));
    }

    #[test]
    fn test_count_vars_ignores_constants() {
        assert!(count_vars(&add(c(1), c(2))).is_empty());
    }

    // ── substitute_vars ───────────────────────────────────────────────────────

    #[test]
    fn test_substitute_simple() {
        let mut subs = HashMap::new();
        subs.insert("x".to_string(), c(5));
        let result = substitute_vars(&v("x"), &subs);
        assert_eq!(result, c(5));
    }

    #[test]
    fn test_substitute_partial() {
        let mut subs = HashMap::new();
        subs.insert("x".to_string(), c(3));
        let result = substitute_vars(&add(v("x"), v("y")), &subs);
        assert_eq!(result, add(c(3), v("y")));
    }

    #[test]
    fn test_substitute_nested() {
        let mut subs = HashMap::new();
        subs.insert("x".to_string(), c(2));
        let result = substitute_vars(&neg(mul(v("x"), pow(v("y"), v("x")))), &subs);
        assert_eq!(result, neg(mul(c(2), pow(v("y"), c(2)))));
    }

    #[test]
    fn test_substitute_then_simplify() {
        let mut subs = HashMap::new();
        subs.insert("x".to_string(), c(0));
        let expr = add(v("x"), v("y"));
        let subbed = substitute_vars(&expr, &subs);
        let result = simplify(subbed, &cfg_default());
        assert_eq!(result.expr, v("y"));
    }

    // ── alg_expr_to_string ────────────────────────────────────────────────────

    #[test]
    fn test_to_string_const() {
        assert_eq!(alg_expr_to_string(&c(42)), "42");
    }

    #[test]
    fn test_to_string_var() {
        assert_eq!(alg_expr_to_string(&v("x")), "x");
    }

    #[test]
    fn test_to_string_add() {
        assert_eq!(alg_expr_to_string(&add(v("a"), c(1))), "(a + 1)");
    }

    #[test]
    fn test_to_string_all_operators() {
        assert_eq!(alg_expr_to_string(&sub(v("x"), c(1))), "(x - 1)");
        assert_eq!(alg_expr_to_string(&mul(v("x"), c(2))), "(x * 2)");
        assert_eq!(alg_expr_to_string(&div(v("x"), c(3))), "(x / 3)");
        assert_eq!(alg_expr_to_string(&pow(v("x"), c(4))), "(x ^ 4)");
        assert_eq!(alg_expr_to_string(&rem(v("x"), c(5))), "(x % 5)");
        assert_eq!(alg_expr_to_string(&neg(c(-6))), "(--6)");
    }

    // ── normalize ─────────────────────────────────────────────────────────────

    #[test]
    fn test_normalize_commutes_add() {
        let e1 = normalize(&add(v("z"), v("a")));
        let e2 = normalize(&add(v("a"), v("z")));
        assert_eq!(e1, e2, "normalize should produce same form for a+z and z+a");
    }

    #[test]
    fn test_normalize_commutes_mul() {
        let e1 = normalize(&mul(v("z"), v("a")));
        let e2 = normalize(&mul(v("a"), v("z")));
        assert_eq!(e1, e2);
    }
}