quantrs2_symengine_pure/simplify/
mod.rs

//! Expression simplification using e-graph equality saturation.
//!
//! Most functions here use the egg library to perform term rewriting and
//! simplification via equality saturation; [`expand`] is the exception and
//! distributes products over sums by direct recursion instead.
5
6use std::collections::HashMap;
7
8use egg::{rewrite, CostFunction, Id, Language, RecExpr, Rewrite, Runner, Symbol};
9
10use crate::expr::{ExprLang, Expression};
11
12/// Expand an expression (distribute products over sums)
13pub fn expand(expr: &Expression) -> Expression {
14    // First, manually expand any power-2 expressions
15    let expanded_pow = expand_powers(expr);
16
17    // Then fully distribute all multiplications over additions
18    distribute_fully(&expanded_pow)
19}
20
21/// Fully distribute multiplications over additions
22/// This implements FOIL-like expansion for all product-of-sums
23fn distribute_fully(expr: &Expression) -> Expression {
24    // Recursively process the expression
25    if expr.is_mul() {
26        // SAFETY: is_mul() check guarantees as_mul() will succeed
27        let operands = expr.as_mul().expect("is_mul() was true");
28        let left = distribute_fully(&operands[0]);
29        let right = distribute_fully(&operands[1]);
30
31        // Distribute multiplication over additions
32        distribute_product(&left, &right)
33    } else if expr.is_add() {
34        // SAFETY: is_add() check guarantees as_add() will succeed
35        let operands = expr.as_add().expect("is_add() was true");
36        let left = distribute_fully(&operands[0]);
37        let right = distribute_fully(&operands[1]);
38        left + right
39    } else if expr.is_neg() {
40        // SAFETY: is_neg() check guarantees as_neg() will succeed
41        let inner = expr.as_neg().expect("is_neg() was true");
42        -distribute_fully(&inner)
43    } else if expr.is_pow() {
44        // SAFETY: is_pow() check guarantees as_pow() will succeed
45        let (base, exp) = expr.as_pow().expect("is_pow() was true");
46        let expanded_base = distribute_fully(&base);
47        expanded_base.pow(&exp)
48    } else {
49        // Symbols, numbers, etc. - return as-is
50        expr.clone()
51    }
52}
53
54/// Distribute a product: (a + b) * (c + d) = a*c + a*d + b*c + b*d
55fn distribute_product(left: &Expression, right: &Expression) -> Expression {
56    // Get all addends from left
57    let left_terms = collect_addends(left);
58    // Get all addends from right
59    let right_terms = collect_addends(right);
60
61    // Multiply each pair
62    let mut result_terms: Vec<Expression> = Vec::new();
63    for l in &left_terms {
64        for r in &right_terms {
65            let product = multiply_terms(l, r);
66            result_terms.push(product);
67        }
68    }
69
70    // Build sum
71    if result_terms.is_empty() {
72        Expression::zero()
73    } else {
74        let mut result = result_terms.remove(0);
75        for term in result_terms {
76            result = result + term;
77        }
78        result
79    }
80}
81
82/// Collect all addends from an expression (handles nested additions)
83fn collect_addends(expr: &Expression) -> Vec<Expression> {
84    if expr.is_add() {
85        // SAFETY: is_add() check guarantees as_add() will succeed
86        let operands = expr.as_add().expect("is_add() was true");
87        let mut terms = collect_addends(&operands[0]);
88        terms.extend(collect_addends(&operands[1]));
89        terms
90    } else {
91        vec![expr.clone()]
92    }
93}
94
95/// Multiply two terms, handling negations
96fn multiply_terms(a: &Expression, b: &Expression) -> Expression {
97    // Handle negations to keep things clean
98    let (a_neg, a_inner) = unwrap_neg(a);
99    let (b_neg, b_inner) = unwrap_neg(b);
100
101    let product = a_inner * b_inner;
102
103    // XOR the negations
104    if a_neg ^ b_neg {
105        -product
106    } else {
107        product
108    }
109}
110
111/// Unwrap negation: returns (is_negated, inner_expression)
112fn unwrap_neg(expr: &Expression) -> (bool, Expression) {
113    if expr.is_neg() {
114        // SAFETY: is_neg() check guarantees as_neg() will succeed
115        let inner = expr.as_neg().expect("is_neg() was true");
116        let (inner_neg, inner_expr) = unwrap_neg(&inner);
117        // Double negation cancels out
118        (!inner_neg, inner_expr)
119    } else {
120        (false, expr.clone())
121    }
122}
123
124/// Recursively expand power expressions with exponent 2
125fn expand_powers(expr: &Expression) -> Expression {
126    // Check if this is a power expression
127    if expr.is_pow() {
128        // SAFETY: is_pow() check guarantees as_pow() will succeed
129        let (base, exp) = expr.as_pow().expect("is_pow() was true");
130
131        // First recursively expand powers in the base
132        let expanded_base = expand_powers(&base);
133
134        // Check if exponent is 2
135        if exp.is_number() {
136            if let Some(exp_val) = exp.to_f64() {
137                if (exp_val - 2.0).abs() < 1e-10 {
138                    // a^2 => a * a
139                    return expanded_base.clone() * expanded_base;
140                }
141            }
142        }
143
144        // For other exponents, return base^exp with expanded base
145        return expanded_base.pow(&exp);
146    }
147
148    // Check if this is an addition - recursively expand
149    if expr.is_add() {
150        // SAFETY: is_add() check guarantees as_add() will succeed
151        let operands = expr.as_add().expect("is_add() was true");
152        let left = expand_powers(&operands[0]);
153        let right = expand_powers(&operands[1]);
154        return left + right;
155    }
156
157    // Check if this is a multiplication - recursively expand
158    if expr.is_mul() {
159        // SAFETY: is_mul() check guarantees as_mul() will succeed
160        let operands = expr.as_mul().expect("is_mul() was true");
161        let left = expand_powers(&operands[0]);
162        let right = expand_powers(&operands[1]);
163        return left * right;
164    }
165
166    // Check if this is a negation - recursively expand
167    if expr.is_neg() {
168        // SAFETY: is_neg() check guarantees as_neg() will succeed
169        let inner = expr.as_neg().expect("is_neg() was true");
170        return -expand_powers(&inner);
171    }
172
173    // For all other expressions (symbols, numbers), return as-is
174    expr.clone()
175}
176
177/// Simplify an expression using e-graph equality saturation
178pub fn simplify(expr: &Expression) -> Expression {
179    let rules = get_simplification_rules();
180
181    let runner = Runner::default()
182        .with_expr(expr.as_rec_expr())
183        .with_iter_limit(20)
184        .run(&rules);
185
186    let root = runner.roots[0];
187    let extractor = egg::Extractor::new(&runner.egraph, AstSize);
188    let (_, best) = extractor.find_best(root);
189
190    Expression::from_rec_expr(best)
191}
192
193/// Substitute a variable with an expression
194pub fn substitute(expr: &Expression, var: &Expression, value: &Expression) -> Expression {
195    let var_name = match var.as_symbol() {
196        Some(name) => name.to_string(),
197        None => return expr.clone(), // Can only substitute symbols
198    };
199
200    let rec_expr = expr.as_rec_expr();
201    let value_expr = value.as_rec_expr();
202
203    // Build a new expression with substitution
204    let mut new_expr = RecExpr::default();
205    let mut id_map: HashMap<usize, Id> = HashMap::new();
206
207    substitute_rec(
208        rec_expr,
209        rec_expr.as_ref().len() - 1,
210        &var_name,
211        value_expr,
212        &mut new_expr,
213        &mut id_map,
214    );
215
216    Expression::from_rec_expr(new_expr)
217}
218
219/// Recursive substitution helper
220fn substitute_rec(
221    expr: &RecExpr<ExprLang>,
222    idx: usize,
223    var_name: &str,
224    value: &RecExpr<ExprLang>,
225    new_expr: &mut RecExpr<ExprLang>,
226    id_map: &mut HashMap<usize, Id>,
227) -> Id {
228    if let Some(&new_id) = id_map.get(&idx) {
229        return new_id;
230    }
231
232    let node = &expr[Id::from(idx)];
233
234    // Check if this is the variable to substitute
235    if let ExprLang::Num(s) = node {
236        if s.as_str() == var_name {
237            // Insert the value expression
238            let offset = new_expr.as_ref().len();
239            for (i, n) in value.as_ref().iter().enumerate() {
240                let mapped_node = n
241                    .clone()
242                    .map_children(|child_id| Id::from(usize::from(child_id) + offset));
243                new_expr.add(mapped_node);
244            }
245            let new_id = Id::from(new_expr.as_ref().len() - 1);
246            id_map.insert(idx, new_id);
247            return new_id;
248        }
249    }
250
251    // Otherwise, recursively process children
252    let new_node = node.clone().map_children(|child_id| {
253        substitute_rec(
254            expr,
255            usize::from(child_id),
256            var_name,
257            value,
258            new_expr,
259            id_map,
260        )
261    });
262    let new_id = new_expr.add(new_node);
263    id_map.insert(idx, new_id);
264    new_id
265}
266
267/// Cost function for extracting the simplest expression
268struct AstSize;
269
270impl CostFunction<ExprLang> for AstSize {
271    type Cost = usize;
272
273    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
274    where
275        C: FnMut(Id) -> Self::Cost,
276    {
277        let node_cost = match node {
278            // Prefer simpler nodes (Num covers both constants and symbols)
279            ExprLang::Num(_) => 1,
280            _ => 3,
281        };
282
283        node.fold(node_cost, |sum, id| sum + costs(id))
284    }
285}
286
287/// Cost function that prefers expanded (distributed) forms
288/// Note: Currently unused as expand() uses direct distribute_fully() instead.
289#[allow(dead_code)]
290struct ExpandedSize;
291
292impl CostFunction<ExprLang> for ExpandedSize {
293    type Cost = usize;
294
295    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
296    where
297        C: FnMut(Id) -> Self::Cost,
298    {
299        let node_cost = match node {
300            ExprLang::Num(_) => 1,
301            // Prefer additions over multiplications (expanded form)
302            ExprLang::Add(_) => 2,
303            ExprLang::Mul(_) => 4,
304            _ => 3,
305        };
306
307        node.fold(node_cost, |sum, id| sum + costs(id))
308    }
309}
310
311/// Get distribution rewrite rules (for expanding products over sums)
312/// Note: Currently unused as expand() uses direct distribute_fully() instead,
313/// but kept for potential future e-graph based expansion use cases.
314#[allow(dead_code)]
315fn get_distribution_rules() -> Vec<Rewrite<ExprLang, ()>> {
316    vec![
317        // Distributivity (left and right)
318        rewrite!("distrib-left"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
319        rewrite!("distrib-right"; "(* (+ ?a ?b) ?c)" => "(+ (* ?a ?c) (* ?b ?c))"),
320        // Note: pow-2 rule doesn't work due to Num(Symbol) not matching literal "2"
321        // Power expansion is now handled by expand_powers() function
322        // Negation handling for expansion
323        // (neg a) * b = neg(a * b)
324        rewrite!("neg-mul-left"; "(* (neg ?a) ?b)" => "(neg (* ?a ?b))"),
325        // a * (neg b) = neg(a * b)
326        rewrite!("neg-mul-right"; "(* ?a (neg ?b))" => "(neg (* ?a ?b))"),
327        // (neg a) * (neg b) = a * b (double negation in multiplication)
328        rewrite!("neg-neg-mul"; "(* (neg ?a) (neg ?b))" => "(* ?a ?b)"),
329        // Distribute negation over addition
330        rewrite!("neg-add"; "(neg (+ ?a ?b))" => "(+ (neg ?a) (neg ?b))"),
331        // Double negation elimination
332        rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
333        // Multiplication associativity (helps with nested distributions)
334        rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
335        rewrite!("mul-assoc-rev"; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"),
336        // Addition associativity (helps flatten sums)
337        rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
338        rewrite!("add-assoc-rev"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"),
339        // Commutativity (needed for proper expansion)
340        rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
341        rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
342        // Basic simplifications needed during expansion
343        rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
344        rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
345        rewrite!("mul-one"; "(* ?a 1)" => "?a"),
346        rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
347        rewrite!("mul-zero"; "(* ?a 0)" => "0"),
348        rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
349        // Handle negation with constants
350        rewrite!("neg-zero"; "(neg 0)" => "0"),
351    ]
352}
353
354/// Get the simplification rewrite rules
355fn get_simplification_rules() -> Vec<Rewrite<ExprLang, ()>> {
356    vec![
357        // Additive identity: a + 0 = a
358        rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
359        rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
360        // Multiplicative identity: a * 1 = a
361        rewrite!("mul-one"; "(* ?a 1)" => "?a"),
362        rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
363        // Multiplicative zero: a * 0 = 0
364        rewrite!("mul-zero"; "(* ?a 0)" => "0"),
365        rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
366        // Double negation: -(-a) = a
367        rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
368        // Power rules
369        rewrite!("pow-zero"; "(^ ?a 0)" => "1"),
370        rewrite!("pow-one"; "(^ ?a 1)" => "?a"),
371        // Commutativity
372        rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
373        rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
374        // Associativity
375        rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
376        rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
377        // Distributivity
378        rewrite!("distrib"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
379        // Trigonometric identities
380        // sin^2 + cos^2 = 1 (this is tricky to express as a rewrite)
381
382        // Exponential/logarithm identities
383        rewrite!("exp-log"; "(exp (log ?a))" => "?a"),
384        rewrite!("log-exp"; "(log (exp ?a))" => "?a"),
385        // sqrt(x^2) = |x|
386        rewrite!("sqrt-sq"; "(sqrt (^ ?a 2))" => "(abs ?a)"),
387    ]
388}
389
390/// Get quantum-specific simplification rules
391pub fn get_quantum_rules() -> Vec<Rewrite<ExprLang, ()>> {
392    vec![
393        // Commutator identities
394        // [A, A] = 0
395        rewrite!("comm-self"; "(comm ?a ?a)" => "0"),
396        // [A, B] = -[B, A] (antisymmetry)
397        rewrite!("comm-antisym"; "(comm ?a ?b)" => "(neg (comm ?b ?a))"),
398        // [0, A] = 0, [A, 0] = 0
399        rewrite!("comm-zero-left"; "(comm 0 ?a)" => "0"),
400        rewrite!("comm-zero-right"; "(comm ?a 0)" => "0"),
401        // Anticommutator identities
402        // {A, A} = 2A
403        rewrite!("anticomm-self"; "(anticomm ?a ?a)" => "(* 2 ?a)"),
404        // {A, B} = {B, A} (symmetry)
405        rewrite!("anticomm-sym"; "(anticomm ?a ?b)" => "(anticomm ?b ?a)"),
406        // {0, A} = A, {A, 0} = A
407        rewrite!("anticomm-zero"; "(anticomm 0 ?a)" => "?a"),
408        // Hermitian conjugate (dagger) identities
409        // (A†)† = A
410        rewrite!("dagger-dagger"; "(dagger (dagger ?a))" => "?a"),
411        // (AB)† = B†A† (reversal for products)
412        rewrite!("dagger-mul"; "(dagger (* ?a ?b))" => "(* (dagger ?b) (dagger ?a))"),
413        // (A + B)† = A† + B†
414        rewrite!("dagger-add"; "(dagger (+ ?a ?b))" => "(+ (dagger ?a) (dagger ?b))"),
415        // (cA)† = c*A† (for complex scalars, * denotes conjugate)
416        // This is handled via (conj c) * (dagger A)
417
418        // 0† = 0
419        rewrite!("dagger-zero"; "(dagger 0)" => "0"),
420        // 1† = 1
421        rewrite!("dagger-one"; "(dagger 1)" => "1"),
422        // Trace identities
423        // tr(A + B) = tr(A) + tr(B)
424        rewrite!("trace-add"; "(trace (+ ?a ?b))" => "(+ (trace ?a) (trace ?b))"),
425        // tr(cA) = c * tr(A)
426        rewrite!("trace-scale"; "(trace (* ?c ?a))" => "(* ?c (trace ?a))"),
427        // tr(0) = 0
428        rewrite!("trace-zero"; "(trace 0)" => "0"),
429        // Tensor product identities
430        // (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD) (this is a simplification hint)
431        rewrite!("tensor-mul"; "(* (tensor ?a ?b) (tensor ?c ?d))" => "(tensor (* ?a ?c) (* ?b ?d))"),
432        // A ⊗ 1 = A (for identity operator 1)
433        rewrite!("tensor-one-right"; "(tensor ?a 1)" => "?a"),
434        rewrite!("tensor-one-left"; "(tensor 1 ?a)" => "?a"),
435        // A ⊗ 0 = 0
436        rewrite!("tensor-zero"; "(tensor ?a 0)" => "0"),
437        rewrite!("tensor-zero-left"; "(tensor 0 ?a)" => "0"),
438        // Determinant identities
439        // det(AB) = det(A) * det(B) - only true for square matrices
440        // det(I) = 1
441        rewrite!("det-one"; "(det 1)" => "1"),
442        // Transpose identities
443        // (A^T)^T = A
444        rewrite!("transpose-transpose"; "(transpose (transpose ?a))" => "?a"),
445        // (AB)^T = B^T A^T
446        rewrite!("transpose-mul"; "(transpose (* ?a ?b))" => "(* (transpose ?b) (transpose ?a))"),
447        // (A + B)^T = A^T + B^T
448        rewrite!("transpose-add"; "(transpose (+ ?a ?b))" => "(+ (transpose ?a) (transpose ?b))"),
449    ]
450}
451
452/// Simplify an expression with quantum-specific rules
453pub fn simplify_quantum(expr: &Expression) -> Expression {
454    let mut rules = get_simplification_rules();
455    rules.extend(get_quantum_rules());
456
457    let runner = Runner::default()
458        .with_expr(expr.as_rec_expr())
459        .with_iter_limit(30)
460        .run(&rules);
461
462    let root = runner.roots[0];
463    let extractor = egg::Extractor::new(&runner.egraph, AstSize);
464    let (_, best) = extractor.find_best(root);
465
466    Expression::from_rec_expr(best)
467}
468
469/// Get trigonometric identities useful in quantum computing
470pub fn get_trig_rules() -> Vec<Rewrite<ExprLang, ()>> {
471    vec![
472        // Pythagorean identity: sin²(x) + cos²(x) = 1
473        // This is hard to express directly, but we can express some related rules
474
475        // sin(0) = 0
476        rewrite!("sin-zero"; "(sin 0)" => "0"),
477        // cos(0) = 1
478        rewrite!("cos-zero"; "(cos 0)" => "1"),
479        // tan(0) = 0
480        rewrite!("tan-zero"; "(tan 0)" => "0"),
481        // exp(0) = 1
482        rewrite!("exp-zero"; "(exp 0)" => "1"),
483        // log(1) = 0
484        rewrite!("log-one"; "(log 1)" => "0"),
485        // sin(-x) = -sin(x) (odd function)
486        rewrite!("sin-neg"; "(sin (neg ?x))" => "(neg (sin ?x))"),
487        // cos(-x) = cos(x) (even function)
488        rewrite!("cos-neg"; "(cos (neg ?x))" => "(cos ?x)"),
489        // tan(-x) = -tan(x) (odd function)
490        rewrite!("tan-neg"; "(tan (neg ?x))" => "(neg (tan ?x))"),
491        // exp(a + b) = exp(a) * exp(b)
492        rewrite!("exp-add"; "(exp (+ ?a ?b))" => "(* (exp ?a) (exp ?b))"),
493        // log(a * b) = log(a) + log(b)
494        rewrite!("log-mul"; "(log (* ?a ?b))" => "(+ (log ?a) (log ?b))"),
495        // exp(log(x)) = x
496        rewrite!("exp-log"; "(exp (log ?x))" => "?x"),
497        // log(exp(x)) = x
498        rewrite!("log-exp"; "(log (exp ?x))" => "?x"),
499        // sqrt(x)^2 = x
500        rewrite!("sqrt-sq"; "(^ (sqrt ?x) 2)" => "?x"),
501        // sqrt(x^2) = |x|
502        rewrite!("sq-sqrt"; "(sqrt (^ ?x 2))" => "(abs ?x)"),
503    ]
504}
505
506/// Simplify with trigonometric rules
507pub fn simplify_trig(expr: &Expression) -> Expression {
508    let mut rules = get_simplification_rules();
509    rules.extend(get_trig_rules());
510
511    let runner = Runner::default()
512        .with_expr(expr.as_rec_expr())
513        .with_iter_limit(30)
514        .run(&rules);
515
516    let root = runner.roots[0];
517    let extractor = egg::Extractor::new(&runner.egraph, AstSize);
518    let (_, best) = extractor.find_best(root);
519
520    Expression::from_rec_expr(best)
521}
522
523/// Collect like terms in a polynomial expression
524///
525/// This is a more aggressive simplification that tries to collect
526/// terms with the same variable factors.
527pub fn collect(expr: &Expression, var: &Expression) -> Expression {
528    // First expand, then simplify
529    let expanded = expand(expr);
530
531    // For now, just return simplified form
532    // Full polynomial collection would require more sophisticated analysis
533    simplify(&expanded)
534}
535
536/// Factor common terms out of a sum
537///
538/// For example: ax + ay -> a(x + y)
539pub fn factor(expr: &Expression) -> Expression {
540    let factor_rules = vec![
541        // Reverse distributivity: common factor extraction
542        rewrite!("factor-left"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
543        rewrite!("factor-right"; "(+ (* ?a ?c) (* ?b ?c))" => "(* (+ ?a ?b) ?c)"),
544        // a + a = 2a
545        rewrite!("add-same"; "(+ ?a ?a)" => "(* 2 ?a)"),
546        // Basic simplifications
547        rewrite!("mul-one"; "(* ?a 1)" => "?a"),
548        rewrite!("mul-zero"; "(* ?a 0)" => "0"),
549    ];
550
551    let runner: Runner<ExprLang, ()> = Runner::default()
552        .with_expr(expr.as_rec_expr())
553        .with_iter_limit(20)
554        .run(&factor_rules);
555
556    let root = runner.roots[0];
557
558    // Use a cost function that prefers factored forms
559    let extractor = egg::Extractor::new(&runner.egraph, FactoredSize);
560    let (_, best) = extractor.find_best(root);
561
562    Expression::from_rec_expr(best)
563}
564
565/// Cost function that prefers factored (shorter) forms
566struct FactoredSize;
567
568impl CostFunction<ExprLang> for FactoredSize {
569    type Cost = usize;
570
571    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
572    where
573        C: FnMut(Id) -> Self::Cost,
574    {
575        let node_cost = match node {
576            ExprLang::Num(_) => 1,
577            // Prefer multiplications over additions for factored form
578            ExprLang::Mul(_) => 2,
579            ExprLang::Add(_) => 4,
580            _ => 3,
581        };
582
583        node.fold(node_cost, |sum, id| sum + costs(id))
584    }
585}
586
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Tolerance for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Build the variable bindings passed to `Expression::eval`.
    fn bindings(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    #[test]
    fn test_simplify_add_zero() {
        let expr = Expression::symbol("x") + Expression::zero();
        // x + 0 must reduce to the bare symbol x.
        assert!(simplify(&expr).as_symbol().is_some());
    }

    #[test]
    fn test_simplify_mul_one() {
        let expr = Expression::symbol("x") * Expression::one();
        assert!(simplify(&expr).as_symbol().is_some());
    }

    #[test]
    fn test_simplify_mul_zero() {
        let expr = Expression::symbol("x") * Expression::zero();
        assert!(simplify(&expr).is_zero());
    }

    #[test]
    fn test_substitute_simple() {
        let x = Expression::symbol("x");
        // (x + y)[x := 2] = 2 + y, i.e. 5 at y = 3.
        let expr = x.clone() + Expression::symbol("y");
        let result = substitute(&expr, &x, &Expression::int(2));

        let value = result.eval(&bindings(&[("y", 3.0)]));
        assert!(value.is_ok());
        assert!((value.expect("eval") - 5.0).abs() < EPS);
    }

    #[test]
    fn test_substitute_nested() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        // (x * x)[x := y] = y * y, i.e. 9 at y = 3.
        let expr = x.clone() * x.clone();
        let result = substitute(&expr, &x, &y);

        let value = result.eval(&bindings(&[("y", 3.0)]));
        assert!(value.is_ok());
        assert!((value.expect("eval") - 9.0).abs() < EPS);
    }

    #[test]
    fn test_expand_distribution() {
        // x * (y + z) should expand to x*y + x*z.
        let expr =
            Expression::symbol("x") * (Expression::symbol("y") + Expression::symbol("z"));
        let expanded = expand(&expr);

        let env = bindings(&[("x", 2.0), ("y", 3.0), ("z", 4.0)]);
        let orig_val = expr.eval(&env).expect("eval original");
        let exp_val = expanded.eval(&env).expect("eval expanded");

        assert!((orig_val - exp_val).abs() < EPS);
        assert!((exp_val - 14.0).abs() < EPS); // 2*(3+4) = 14
    }

    #[test]
    fn test_factor_common_terms() {
        let a = Expression::symbol("a");
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        // a*x + a*y should factor to a*(x + y); both must evaluate alike.
        let expr = a.clone() * x + a * y;
        let factored = factor(&expr);

        let env = bindings(&[("a", 2.0), ("x", 3.0), ("y", 4.0)]);
        let orig_val = expr.eval(&env).expect("eval original");
        let fact_val = factored.eval(&env).expect("eval factored");

        assert!((orig_val - fact_val).abs() < EPS);
        assert!((fact_val - 14.0).abs() < EPS); // 2*3 + 2*4 = 14
    }

    #[test]
    fn test_simplify_trig() {
        let sin_zero = crate::ops::trig::sin(&Expression::zero());
        let simplified = simplify_trig(&sin_zero);

        // sin(0) = 0, so the result evaluates to zero without any bindings.
        let result = simplified.eval(&bindings(&[]));
        assert!(result.is_ok());
        assert!(result.expect("eval").abs() < EPS);
    }

    #[test]
    fn test_simplify_quantum_dagger() {
        // (A†)† = A cannot be exercised directly through the current DSL, so
        // check instead that the quantum rule set is populated.
        let rules = get_quantum_rules();
        assert!(!rules.is_empty());
        assert!(rules.len() >= 15);
    }

    #[test]
    fn test_collect() {
        let x = Expression::symbol("x");
        // x + x should be equivalent to 2x after collect.
        let expr = x.clone() + x.clone();
        let collected = collect(&expr, &x);

        let env = bindings(&[("x", 5.0)]);
        let orig_val = expr.eval(&env).expect("eval original");
        let coll_val = collected.eval(&env).expect("eval collected");

        assert!((orig_val - coll_val).abs() < EPS);
        assert!((coll_val - 10.0).abs() < EPS); // 5 + 5 = 10
    }

    #[test]
    fn test_expand_simple_pow2() {
        // a^2 expands to a*a, i.e. 9 at a = 3.
        let a = Expression::symbol("a");
        let expanded = expand(&a.pow(&Expression::from(2)));

        let exp_val = expanded.eval(&bindings(&[("a", 3.0)])).expect("eval");
        assert!((exp_val - 9.0).abs() < EPS);
    }

    #[test]
    fn test_expand_binomial_squared() {
        // (a + b)^2 = a^2 + 2ab + b^2, checked numerically at several points.
        let a = Expression::symbol("a");
        let b = Expression::symbol("b");
        let expr = (a + b).pow(&Expression::from(2));
        let expanded = expand(&expr);

        for (a_val, b_val) in [(2.0, 3.0), (1.0, 1.0), (0.0, 5.0)] {
            let env = bindings(&[("a", a_val), ("b", b_val)]);
            let orig_val = expr.eval(&env).expect("eval original");
            let exp_val = expanded.eval(&env).expect("eval expanded");

            assert!(
                (orig_val - exp_val).abs() < EPS,
                "Mismatch at a={a_val}, b={b_val}: orig={orig_val}, expanded={exp_val}"
            );
            let expected = (a_val + b_val).powi(2);
            assert!(
                (exp_val - expected).abs() < EPS,
                "Unexpected value at a={a_val}, b={b_val}: got {exp_val}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_expand_polynomial_constraint() {
        // (x + y + z - 1)^2, the form used in QUBO constraint expressions.
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");
        let expr = (x + y + z - Expression::from(1)).pow(&Expression::from(2));
        let expanded = expand(&expr);

        let points = [
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (1.0, 1.0, 0.0),
            (0.0, 1.0, 1.0),
            (1.0, 1.0, 1.0),
            (0.5, 0.5, 0.0),
        ];
        for (x_val, y_val, z_val) in points {
            let env = bindings(&[("x", x_val), ("y", y_val), ("z", z_val)]);
            let orig_val = expr.eval(&env).expect("eval original");
            let exp_val = expanded.eval(&env).expect("eval expanded");

            assert!(
                (orig_val - exp_val).abs() < EPS,
                "Mismatch at x={x_val}, y={y_val}, z={z_val}: orig={orig_val}, expanded={exp_val}"
            );
            let expected = (x_val + y_val + z_val - 1.0).powi(2);
            assert!(
                (exp_val - expected).abs() < EPS,
                "Unexpected value at x={x_val}, y={y_val}, z={z_val}: got {exp_val}, expected {expected}"
            );
        }
    }
}