1use std::collections::HashMap;
7
8use egg::{rewrite, CostFunction, Id, Language, RecExpr, Rewrite, Runner, Symbol};
9
10use crate::expr::{ExprLang, Expression};
11
12pub fn expand(expr: &Expression) -> Expression {
14 let expanded_pow = expand_powers(expr);
16
17 distribute_fully(&expanded_pow)
19}
20
21fn distribute_fully(expr: &Expression) -> Expression {
24 if expr.is_mul() {
26 let operands = expr.as_mul().expect("is_mul() was true");
28 let left = distribute_fully(&operands[0]);
29 let right = distribute_fully(&operands[1]);
30
31 distribute_product(&left, &right)
33 } else if expr.is_add() {
34 let operands = expr.as_add().expect("is_add() was true");
36 let left = distribute_fully(&operands[0]);
37 let right = distribute_fully(&operands[1]);
38 left + right
39 } else if expr.is_neg() {
40 let inner = expr.as_neg().expect("is_neg() was true");
42 -distribute_fully(&inner)
43 } else if expr.is_pow() {
44 let (base, exp) = expr.as_pow().expect("is_pow() was true");
46 let expanded_base = distribute_fully(&base);
47 expanded_base.pow(&exp)
48 } else {
49 expr.clone()
51 }
52}
53
54fn distribute_product(left: &Expression, right: &Expression) -> Expression {
56 let left_terms = collect_addends(left);
58 let right_terms = collect_addends(right);
60
61 let mut result_terms: Vec<Expression> = Vec::new();
63 for l in &left_terms {
64 for r in &right_terms {
65 let product = multiply_terms(l, r);
66 result_terms.push(product);
67 }
68 }
69
70 if result_terms.is_empty() {
72 Expression::zero()
73 } else {
74 let mut result = result_terms.remove(0);
75 for term in result_terms {
76 result = result + term;
77 }
78 result
79 }
80}
81
82fn collect_addends(expr: &Expression) -> Vec<Expression> {
84 if expr.is_add() {
85 let operands = expr.as_add().expect("is_add() was true");
87 let mut terms = collect_addends(&operands[0]);
88 terms.extend(collect_addends(&operands[1]));
89 terms
90 } else {
91 vec![expr.clone()]
92 }
93}
94
95fn multiply_terms(a: &Expression, b: &Expression) -> Expression {
97 let (a_neg, a_inner) = unwrap_neg(a);
99 let (b_neg, b_inner) = unwrap_neg(b);
100
101 let product = a_inner * b_inner;
102
103 if a_neg ^ b_neg {
105 -product
106 } else {
107 product
108 }
109}
110
111fn unwrap_neg(expr: &Expression) -> (bool, Expression) {
113 if expr.is_neg() {
114 let inner = expr.as_neg().expect("is_neg() was true");
116 let (inner_neg, inner_expr) = unwrap_neg(&inner);
117 (!inner_neg, inner_expr)
119 } else {
120 (false, expr.clone())
121 }
122}
123
124fn expand_powers(expr: &Expression) -> Expression {
126 if expr.is_pow() {
128 let (base, exp) = expr.as_pow().expect("is_pow() was true");
130
131 let expanded_base = expand_powers(&base);
133
134 if exp.is_number() {
136 if let Some(exp_val) = exp.to_f64() {
137 if (exp_val - 2.0).abs() < 1e-10 {
138 return expanded_base.clone() * expanded_base;
140 }
141 }
142 }
143
144 return expanded_base.pow(&exp);
146 }
147
148 if expr.is_add() {
150 let operands = expr.as_add().expect("is_add() was true");
152 let left = expand_powers(&operands[0]);
153 let right = expand_powers(&operands[1]);
154 return left + right;
155 }
156
157 if expr.is_mul() {
159 let operands = expr.as_mul().expect("is_mul() was true");
161 let left = expand_powers(&operands[0]);
162 let right = expand_powers(&operands[1]);
163 return left * right;
164 }
165
166 if expr.is_neg() {
168 let inner = expr.as_neg().expect("is_neg() was true");
170 return -expand_powers(&inner);
171 }
172
173 expr.clone()
175}
176
177pub fn simplify(expr: &Expression) -> Expression {
179 let rules = get_simplification_rules();
180
181 let runner = Runner::default()
182 .with_expr(expr.as_rec_expr())
183 .with_iter_limit(20)
184 .run(&rules);
185
186 let root = runner.roots[0];
187 let extractor = egg::Extractor::new(&runner.egraph, AstSize);
188 let (_, best) = extractor.find_best(root);
189
190 Expression::from_rec_expr(best)
191}
192
193pub fn substitute(expr: &Expression, var: &Expression, value: &Expression) -> Expression {
195 let var_name = match var.as_symbol() {
196 Some(name) => name.to_string(),
197 None => return expr.clone(), };
199
200 let rec_expr = expr.as_rec_expr();
201 let value_expr = value.as_rec_expr();
202
203 let mut new_expr = RecExpr::default();
205 let mut id_map: HashMap<usize, Id> = HashMap::new();
206
207 substitute_rec(
208 rec_expr,
209 rec_expr.as_ref().len() - 1,
210 &var_name,
211 value_expr,
212 &mut new_expr,
213 &mut id_map,
214 );
215
216 Expression::from_rec_expr(new_expr)
217}
218
219fn substitute_rec(
221 expr: &RecExpr<ExprLang>,
222 idx: usize,
223 var_name: &str,
224 value: &RecExpr<ExprLang>,
225 new_expr: &mut RecExpr<ExprLang>,
226 id_map: &mut HashMap<usize, Id>,
227) -> Id {
228 if let Some(&new_id) = id_map.get(&idx) {
229 return new_id;
230 }
231
232 let node = &expr[Id::from(idx)];
233
234 if let ExprLang::Num(s) = node {
236 if s.as_str() == var_name {
237 let offset = new_expr.as_ref().len();
239 for (i, n) in value.as_ref().iter().enumerate() {
240 let mapped_node = n
241 .clone()
242 .map_children(|child_id| Id::from(usize::from(child_id) + offset));
243 new_expr.add(mapped_node);
244 }
245 let new_id = Id::from(new_expr.as_ref().len() - 1);
246 id_map.insert(idx, new_id);
247 return new_id;
248 }
249 }
250
251 let new_node = node.clone().map_children(|child_id| {
253 substitute_rec(
254 expr,
255 usize::from(child_id),
256 var_name,
257 value,
258 new_expr,
259 id_map,
260 )
261 });
262 let new_id = new_expr.add(new_node);
263 id_map.insert(idx, new_id);
264 new_id
265}
266
267struct AstSize;
269
270impl CostFunction<ExprLang> for AstSize {
271 type Cost = usize;
272
273 fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
274 where
275 C: FnMut(Id) -> Self::Cost,
276 {
277 let node_cost = match node {
278 ExprLang::Num(_) => 1,
280 _ => 3,
281 };
282
283 node.fold(node_cost, |sum, id| sum + costs(id))
284 }
285}
286
287#[allow(dead_code)]
290struct ExpandedSize;
291
292impl CostFunction<ExprLang> for ExpandedSize {
293 type Cost = usize;
294
295 fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
296 where
297 C: FnMut(Id) -> Self::Cost,
298 {
299 let node_cost = match node {
300 ExprLang::Num(_) => 1,
301 ExprLang::Add(_) => 2,
303 ExprLang::Mul(_) => 4,
304 _ => 3,
305 };
306
307 node.fold(node_cost, |sum, id| sum + costs(id))
308 }
309}
310
311#[allow(dead_code)]
315fn get_distribution_rules() -> Vec<Rewrite<ExprLang, ()>> {
316 vec![
317 rewrite!("distrib-left"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
319 rewrite!("distrib-right"; "(* (+ ?a ?b) ?c)" => "(+ (* ?a ?c) (* ?b ?c))"),
320 rewrite!("neg-mul-left"; "(* (neg ?a) ?b)" => "(neg (* ?a ?b))"),
325 rewrite!("neg-mul-right"; "(* ?a (neg ?b))" => "(neg (* ?a ?b))"),
327 rewrite!("neg-neg-mul"; "(* (neg ?a) (neg ?b))" => "(* ?a ?b)"),
329 rewrite!("neg-add"; "(neg (+ ?a ?b))" => "(+ (neg ?a) (neg ?b))"),
331 rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
333 rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
335 rewrite!("mul-assoc-rev"; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"),
336 rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
338 rewrite!("add-assoc-rev"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"),
339 rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
341 rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
342 rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
344 rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
345 rewrite!("mul-one"; "(* ?a 1)" => "?a"),
346 rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
347 rewrite!("mul-zero"; "(* ?a 0)" => "0"),
348 rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
349 rewrite!("neg-zero"; "(neg 0)" => "0"),
351 ]
352}
353
354fn get_simplification_rules() -> Vec<Rewrite<ExprLang, ()>> {
356 vec![
357 rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
359 rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
360 rewrite!("mul-one"; "(* ?a 1)" => "?a"),
362 rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
363 rewrite!("mul-zero"; "(* ?a 0)" => "0"),
365 rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
366 rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
368 rewrite!("pow-zero"; "(^ ?a 0)" => "1"),
370 rewrite!("pow-one"; "(^ ?a 1)" => "?a"),
371 rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
373 rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
374 rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
376 rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
377 rewrite!("distrib"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
379 rewrite!("exp-log"; "(exp (log ?a))" => "?a"),
384 rewrite!("log-exp"; "(log (exp ?a))" => "?a"),
385 rewrite!("sqrt-sq"; "(sqrt (^ ?a 2))" => "(abs ?a)"),
387 ]
388}
389
390pub fn get_quantum_rules() -> Vec<Rewrite<ExprLang, ()>> {
392 vec![
393 rewrite!("comm-self"; "(comm ?a ?a)" => "0"),
396 rewrite!("comm-antisym"; "(comm ?a ?b)" => "(neg (comm ?b ?a))"),
398 rewrite!("comm-zero-left"; "(comm 0 ?a)" => "0"),
400 rewrite!("comm-zero-right"; "(comm ?a 0)" => "0"),
401 rewrite!("anticomm-self"; "(anticomm ?a ?a)" => "(* 2 ?a)"),
404 rewrite!("anticomm-sym"; "(anticomm ?a ?b)" => "(anticomm ?b ?a)"),
406 rewrite!("anticomm-zero"; "(anticomm 0 ?a)" => "?a"),
408 rewrite!("dagger-dagger"; "(dagger (dagger ?a))" => "?a"),
411 rewrite!("dagger-mul"; "(dagger (* ?a ?b))" => "(* (dagger ?b) (dagger ?a))"),
413 rewrite!("dagger-add"; "(dagger (+ ?a ?b))" => "(+ (dagger ?a) (dagger ?b))"),
415 rewrite!("dagger-zero"; "(dagger 0)" => "0"),
420 rewrite!("dagger-one"; "(dagger 1)" => "1"),
422 rewrite!("trace-add"; "(trace (+ ?a ?b))" => "(+ (trace ?a) (trace ?b))"),
425 rewrite!("trace-scale"; "(trace (* ?c ?a))" => "(* ?c (trace ?a))"),
427 rewrite!("trace-zero"; "(trace 0)" => "0"),
429 rewrite!("tensor-mul"; "(* (tensor ?a ?b) (tensor ?c ?d))" => "(tensor (* ?a ?c) (* ?b ?d))"),
432 rewrite!("tensor-one-right"; "(tensor ?a 1)" => "?a"),
434 rewrite!("tensor-one-left"; "(tensor 1 ?a)" => "?a"),
435 rewrite!("tensor-zero"; "(tensor ?a 0)" => "0"),
437 rewrite!("tensor-zero-left"; "(tensor 0 ?a)" => "0"),
438 rewrite!("det-one"; "(det 1)" => "1"),
442 rewrite!("transpose-transpose"; "(transpose (transpose ?a))" => "?a"),
445 rewrite!("transpose-mul"; "(transpose (* ?a ?b))" => "(* (transpose ?b) (transpose ?a))"),
447 rewrite!("transpose-add"; "(transpose (+ ?a ?b))" => "(+ (transpose ?a) (transpose ?b))"),
449 ]
450}
451
452pub fn simplify_quantum(expr: &Expression) -> Expression {
454 let mut rules = get_simplification_rules();
455 rules.extend(get_quantum_rules());
456
457 let runner = Runner::default()
458 .with_expr(expr.as_rec_expr())
459 .with_iter_limit(30)
460 .run(&rules);
461
462 let root = runner.roots[0];
463 let extractor = egg::Extractor::new(&runner.egraph, AstSize);
464 let (_, best) = extractor.find_best(root);
465
466 Expression::from_rec_expr(best)
467}
468
469pub fn get_trig_rules() -> Vec<Rewrite<ExprLang, ()>> {
471 vec![
472 rewrite!("sin-zero"; "(sin 0)" => "0"),
477 rewrite!("cos-zero"; "(cos 0)" => "1"),
479 rewrite!("tan-zero"; "(tan 0)" => "0"),
481 rewrite!("exp-zero"; "(exp 0)" => "1"),
483 rewrite!("log-one"; "(log 1)" => "0"),
485 rewrite!("sin-neg"; "(sin (neg ?x))" => "(neg (sin ?x))"),
487 rewrite!("cos-neg"; "(cos (neg ?x))" => "(cos ?x)"),
489 rewrite!("tan-neg"; "(tan (neg ?x))" => "(neg (tan ?x))"),
491 rewrite!("exp-add"; "(exp (+ ?a ?b))" => "(* (exp ?a) (exp ?b))"),
493 rewrite!("log-mul"; "(log (* ?a ?b))" => "(+ (log ?a) (log ?b))"),
495 rewrite!("exp-log"; "(exp (log ?x))" => "?x"),
497 rewrite!("log-exp"; "(log (exp ?x))" => "?x"),
499 rewrite!("sqrt-sq"; "(^ (sqrt ?x) 2)" => "?x"),
501 rewrite!("sq-sqrt"; "(sqrt (^ ?x 2))" => "(abs ?x)"),
503 ]
504}
505
506pub fn simplify_trig(expr: &Expression) -> Expression {
508 let mut rules = get_simplification_rules();
509 rules.extend(get_trig_rules());
510
511 let runner = Runner::default()
512 .with_expr(expr.as_rec_expr())
513 .with_iter_limit(30)
514 .run(&rules);
515
516 let root = runner.roots[0];
517 let extractor = egg::Extractor::new(&runner.egraph, AstSize);
518 let (_, best) = extractor.find_best(root);
519
520 Expression::from_rec_expr(best)
521}
522
523pub fn collect(expr: &Expression, var: &Expression) -> Expression {
528 let expanded = expand(expr);
530
531 simplify(&expanded)
534}
535
536pub fn factor(expr: &Expression) -> Expression {
540 let factor_rules = vec![
541 rewrite!("factor-left"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
543 rewrite!("factor-right"; "(+ (* ?a ?c) (* ?b ?c))" => "(* (+ ?a ?b) ?c)"),
544 rewrite!("add-same"; "(+ ?a ?a)" => "(* 2 ?a)"),
546 rewrite!("mul-one"; "(* ?a 1)" => "?a"),
548 rewrite!("mul-zero"; "(* ?a 0)" => "0"),
549 ];
550
551 let runner: Runner<ExprLang, ()> = Runner::default()
552 .with_expr(expr.as_rec_expr())
553 .with_iter_limit(20)
554 .run(&factor_rules);
555
556 let root = runner.roots[0];
557
558 let extractor = egg::Extractor::new(&runner.egraph, FactoredSize);
560 let (_, best) = extractor.find_best(root);
561
562 Expression::from_rec_expr(best)
563}
564
565struct FactoredSize;
567
568impl CostFunction<ExprLang> for FactoredSize {
569 type Cost = usize;
570
571 fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
572 where
573 C: FnMut(Id) -> Self::Cost,
574 {
575 let node_cost = match node {
576 ExprLang::Num(_) => 1,
577 ExprLang::Mul(_) => 2,
579 ExprLang::Add(_) => 4,
580 _ => 3,
581 };
582
583 node.fold(node_cost, |sum, id| sum + costs(id))
584 }
585}
586
#[cfg(test)]
// The tests build expressions from owned operands and clone symbols they reuse;
// presumably some of those clones are flagged by `clippy::redundant_clone`.
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;

    // `x + 0` should simplify to a bare symbol.
    #[test]
    fn test_simplify_add_zero() {
        let x = Expression::symbol("x");
        let zero = Expression::zero();
        let expr = x + zero;

        let simplified = simplify(&expr);
        assert!(simplified.as_symbol().is_some());
    }

    // `x * 1` should simplify to a bare symbol.
    #[test]
    fn test_simplify_mul_one() {
        let x = Expression::symbol("x");
        let one = Expression::one();
        let expr = x * one;

        let simplified = simplify(&expr);
        assert!(simplified.as_symbol().is_some());
    }

    // `x * 0` should simplify to zero.
    #[test]
    fn test_simplify_mul_zero() {
        let x = Expression::symbol("x");
        let zero = Expression::zero();
        let expr = x * zero;

        let simplified = simplify(&expr);
        assert!(simplified.is_zero());
    }

    // Substituting x := 2 in `x + y` and evaluating at y = 3 gives 5.
    #[test]
    fn test_substitute_simple() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let two = Expression::int(2);

        let expr = x.clone() + y;
        let result = substitute(&expr, &x, &two);

        let mut values = std::collections::HashMap::new();
        values.insert("y".to_string(), 3.0);
        let eval_result = result.eval(&values);
        assert!(eval_result.is_ok());
        assert!((eval_result.expect("eval") - 5.0).abs() < 1e-10);
    }

    // Substituting x := y in `x * x` replaces both occurrences; at y = 3 this gives 9.
    #[test]
    fn test_substitute_nested() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let expr = x.clone() * x.clone();
        let result = substitute(&expr, &x, &y);

        let mut values = std::collections::HashMap::new();
        values.insert("y".to_string(), 3.0);
        let eval_result = result.eval(&values);
        assert!(eval_result.is_ok());
        assert!((eval_result.expect("eval") - 9.0).abs() < 1e-10);
    }

    // Expanding `x * (y + z)` must preserve the value (checked numerically only).
    #[test]
    fn test_expand_distribution() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");

        let expr = x * (y + z);
        let expanded = expand(&expr);

        let mut values = std::collections::HashMap::new();
        values.insert("x".to_string(), 2.0);
        values.insert("y".to_string(), 3.0);
        values.insert("z".to_string(), 4.0);

        let orig_val = expr.eval(&values).expect("eval original");
        let exp_val = expanded.eval(&values).expect("eval expanded");

        assert!((orig_val - exp_val).abs() < 1e-10);
        assert!((exp_val - 14.0).abs() < 1e-10);
    }

    // Factoring `a*x + a*y` must preserve the value (the factored shape is not checked).
    #[test]
    fn test_factor_common_terms() {
        let a = Expression::symbol("a");
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let expr = a.clone() * x.clone() + a.clone() * y.clone();
        let factored = factor(&expr);

        let mut values = std::collections::HashMap::new();
        values.insert("a".to_string(), 2.0);
        values.insert("x".to_string(), 3.0);
        values.insert("y".to_string(), 4.0);

        let orig_val = expr.eval(&values).expect("eval original");
        let fact_val = factored.eval(&values).expect("eval factored");

        assert!((orig_val - fact_val).abs() < 1e-10);
        assert!((fact_val - 14.0).abs() < 1e-10);
    }

    // `sin(0)` simplifies to something that evaluates to 0 with no variables bound.
    #[test]
    fn test_simplify_trig() {
        let zero = Expression::zero();
        let sin_zero = crate::ops::trig::sin(&zero);
        let simplified = simplify_trig(&sin_zero);

        let result = simplified.eval(&std::collections::HashMap::new());
        assert!(result.is_ok());
        assert!(result.expect("eval").abs() < 1e-10);
    }

    // NOTE(review): despite its name, this only checks that the quantum rule list is
    // non-empty and has at least 15 rules; no dagger rewrite is exercised.
    #[test]
    fn test_simplify_quantum_dagger() {
        let rules = get_quantum_rules();
        assert!(!rules.is_empty());

        assert!(rules.len() >= 15);
    }

    // `collect(x + x, x)` must preserve the value; the collected shape is not checked.
    #[test]
    fn test_collect() {
        let x = Expression::symbol("x");

        let expr = x.clone() + x.clone();
        let collected = collect(&expr, &x);

        let mut values = std::collections::HashMap::new();
        values.insert("x".to_string(), 5.0);

        let orig_val = expr.eval(&values).expect("eval original");
        let coll_val = collected.eval(&values).expect("eval collected");

        assert!((orig_val - coll_val).abs() < 1e-10);
        assert!((coll_val - 10.0).abs() < 1e-10);
    }

    // `a^2` expands to something evaluating to 9 at a = 3.
    #[test]
    fn test_expand_simple_pow2() {
        let a = Expression::symbol("a");
        let two = Expression::from(2);

        let expr = a.clone().pow(&two);
        let expanded = expand(&expr);

        let mut values = std::collections::HashMap::new();
        values.insert("a".to_string(), 3.0);
        let exp_val = expanded.eval(&values).expect("eval");
        assert!((exp_val - 9.0).abs() < 1e-10);
    }

    // `(a + b)^2` expanded must agree with the original and with (a + b)^2 computed in f64.
    #[test]
    fn test_expand_binomial_squared() {
        let a = Expression::symbol("a");
        let b = Expression::symbol("b");
        let two = Expression::from(2);

        let expr = (a.clone() + b.clone()).pow(&two);
        let expanded = expand(&expr);

        for (a_val, b_val) in [(2.0, 3.0), (1.0, 1.0), (0.0, 5.0)] {
            let mut values = std::collections::HashMap::new();
            values.insert("a".to_string(), a_val);
            values.insert("b".to_string(), b_val);

            let orig_val = expr.eval(&values).expect("eval original");
            let exp_val = expanded.eval(&values).expect("eval expanded");

            assert!(
                (orig_val - exp_val).abs() < 1e-10,
                "Mismatch at a={a_val}, b={b_val}: orig={orig_val}, expanded={exp_val}"
            );
            let expected = (a_val + b_val).powi(2);
            assert!(
                (exp_val - expected).abs() < 1e-10,
                "Unexpected value at a={a_val}, b={b_val}: got {exp_val}, expected {expected}"
            );
        }
    }

    // `(x + y + z - 1)^2` (a squared penalty-style constraint) expanded must agree with
    // the original at several points, including boundary and fractional ones.
    #[test]
    fn test_expand_polynomial_constraint() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");
        let one = Expression::from(1);
        let two = Expression::from(2);

        let expr = (x.clone() + y.clone() + z.clone() - one).pow(&two);
        let expanded = expand(&expr);

        for (x_val, y_val, z_val) in [
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (1.0, 1.0, 0.0),
            (0.0, 1.0, 1.0),
            (1.0, 1.0, 1.0),
            (0.5, 0.5, 0.0),
        ] {
            let mut values = std::collections::HashMap::new();
            values.insert("x".to_string(), x_val);
            values.insert("y".to_string(), y_val);
            values.insert("z".to_string(), z_val);

            let orig_val = expr.eval(&values).expect("eval original");
            let exp_val = expanded.eval(&values).expect("eval expanded");

            assert!(
                (orig_val - exp_val).abs() < 1e-10,
                "Mismatch at x={x_val}, y={y_val}, z={z_val}: orig={orig_val}, expanded={exp_val}"
            );

            let expected = (x_val + y_val + z_val - 1.0).powi(2);
            assert!(
                (exp_val - expected).abs() < 1e-10,
                "Unexpected value at x={x_val}, y={y_val}, z={z_val}: got {exp_val}, expected {expected}"
            );
        }
    }
}