mathhook_core/simplify/arithmetic/
addition.rs

1//! Addition simplification operations
2
3use super::helpers::{expression_order, extract_arithmetic_coefficient_and_base};
4use super::multiplication::simplify_multiplication;
5use super::power::simplify_power;
6use super::Simplify;
7use crate::core::commutativity::Commutativity;
8use crate::core::constants::EPSILON;
9use crate::core::{Expression, Number};
10use num_bigint::BigInt;
11use num_rational::BigRational;
12use num_traits::{ToPrimitive, Zero};
13use std::collections::VecDeque;
14
15fn extract_trig_squared(expr: &Expression, func: &str) -> Option<Expression> {
16    if let Expression::Pow(base, exp) = expr {
17        if let Expression::Number(Number::Integer(2)) = exp.as_ref() {
18            if let Expression::Function { name, args } = base.as_ref() {
19                if name == func && args.len() == 1 {
20                    return Some(args[0].clone());
21                }
22            }
23        }
24    }
25    None
26}
27
28fn check_pythagorean(terms: &[Expression]) -> Option<Vec<Expression>> {
29    for (i, t1) in terms.iter().enumerate() {
30        for (j, t2) in terms.iter().enumerate() {
31            if i >= j {
32                continue;
33            }
34            if let (Some(arg1), Some(arg2)) = (
35                extract_trig_squared(t1, "sin"),
36                extract_trig_squared(t2, "cos"),
37            ) {
38                if arg1 == arg2 {
39                    let mut remaining: Vec<_> = terms
40                        .iter()
41                        .enumerate()
42                        .filter(|(k, _)| *k != i && *k != j)
43                        .map(|(_, e)| e.clone())
44                        .collect();
45                    remaining.push(Expression::integer(1));
46                    return Some(remaining);
47                }
48            }
49            if let (Some(arg1), Some(arg2)) = (
50                extract_trig_squared(t1, "cos"),
51                extract_trig_squared(t2, "sin"),
52            ) {
53                if arg1 == arg2 {
54                    let mut remaining: Vec<_> = terms
55                        .iter()
56                        .enumerate()
57                        .filter(|(k, _)| *k != i && *k != j)
58                        .map(|(_, e)| e.clone())
59                        .collect();
60                    remaining.push(Expression::integer(1));
61                    return Some(remaining);
62                }
63            }
64        }
65    }
66    None
67}
68
69/// Simplify addition expressions with minimal overhead
70pub fn simplify_addition(terms: &[Expression]) -> Expression {
71    if terms.is_empty() {
72        return Expression::integer(0);
73    }
74
75    // Iteratively flatten nested Add expressions and distribute Mul over Add
76    let mut flattened_terms: Vec<Expression> = Vec::new();
77    let mut to_process: VecDeque<&Expression> = terms.iter().collect();
78
79    while let Some(term) = to_process.pop_front() {
80        match term {
81            Expression::Add(nested_terms) => {
82                for nested_term in nested_terms.iter().rev() {
83                    to_process.push_front(nested_term);
84                }
85            }
86            // c * (a + b) → c*a + c*b (distribute numeric constants over addition)
87            Expression::Mul(factors) if factors.len() == 2 => {
88                if let (Expression::Number(coeff), Expression::Add(add_terms)) =
89                    (&factors[0], &factors[1])
90                {
91                    for add_term in add_terms.iter() {
92                        let distributed = Expression::mul(vec![
93                            Expression::Number(coeff.clone()),
94                            add_term.clone(),
95                        ]);
96                        flattened_terms.push(distributed);
97                    }
98                } else if let (Expression::Add(add_terms), Expression::Number(coeff)) =
99                    (&factors[0], &factors[1])
100                {
101                    for add_term in add_terms.iter() {
102                        let distributed = Expression::mul(vec![
103                            Expression::Number(coeff.clone()),
104                            add_term.clone(),
105                        ]);
106                        flattened_terms.push(distributed);
107                    }
108                } else {
109                    flattened_terms.push(term.clone());
110                }
111            }
112            _ => flattened_terms.push(term.clone()),
113        }
114    }
115
116    // Use flattened terms for the rest of the function
117    let terms = &flattened_terms;
118
119    // Matrix fast-path: try direct matrix addition for 2-term case
120    // Note: During simplification, we only apply the fast-path if it succeeds.
121    // If dimensions are incompatible (Some(Err(_))), we fall through to symbolic form.
122    // Domain errors will be caught during evaluation, not simplification.
123    if terms.len() == 2 {
124        if let Some(Ok(result)) = super::matrix_ops::try_matrix_add(&terms[0], &terms[1]) {
125            return result;
126        }
127    }
128
129    // Ultra-fast path for numeric addition
130    let mut int_sum = 0i64;
131    let mut float_sum = 0.0;
132    let mut has_float = false;
133    let mut rational_sum: Option<BigRational> = None;
134    let mut non_numeric_count = 0;
135    let mut first_non_numeric: Option<Expression> = None;
136    let mut numeric_result = None;
137
138    for term in terms {
139        // Simplify the term, but avoid recursive calls for Add expressions (already flattened)
140        let simplified_term = match term {
141            Expression::Add(_) => {
142                // Add expressions should already be flattened, so this shouldn't happen
143                // But if it does, just use the term as-is to avoid recursion
144                term.clone()
145            }
146            Expression::Mul(factors) => simplify_multiplication(factors),
147            Expression::Pow(base, exp) => simplify_power(base, exp),
148            _ => term.simplify(),
149        };
150        match simplified_term {
151            Expression::Number(Number::Integer(n)) => {
152                int_sum = int_sum.saturating_add(n);
153            }
154            Expression::Number(Number::Float(f)) => {
155                float_sum += f;
156                has_float = true;
157            }
158            Expression::Number(Number::Rational(r)) => {
159                if let Some(ref mut current_sum) = rational_sum {
160                    *current_sum += r.as_ref();
161                } else {
162                    rational_sum = Some(r.as_ref().clone());
163                }
164            }
165            _ => {
166                non_numeric_count += 1;
167                if first_non_numeric.is_none() {
168                    first_non_numeric = Some(simplified_term);
169                }
170            }
171        }
172    }
173
174    // Determine numeric result
175    if let Some(rational) = rational_sum {
176        // Combine rational with integer and float sums
177        let mut final_rational = rational;
178        if int_sum != 0 {
179            final_rational += BigRational::from(BigInt::from(int_sum));
180        }
181        if has_float {
182            // Convert to float if we have float terms
183            let float_val = final_rational.to_f64().unwrap_or(0.0) + float_sum;
184            if float_val.abs() >= EPSILON {
185                numeric_result = Some(Expression::Number(Number::float(float_val)));
186            }
187        } else {
188            // Keep as rational if it's not zero
189            if !final_rational.is_zero() {
190                numeric_result = Some(Expression::Number(Number::rational(final_rational)));
191            }
192        }
193    } else if has_float {
194        let total = int_sum as f64 + float_sum;
195        if total.abs() >= EPSILON {
196            numeric_result = Some(Expression::Number(Number::float(total)));
197        }
198    } else if int_sum != 0 {
199        numeric_result = Some(Expression::integer(int_sum));
200    }
201
202    match (numeric_result.as_ref(), non_numeric_count) {
203        (None, 0) => Expression::integer(0),
204        (Some(num), 0) => num.clone(),
205        (None, 1) => {
206            // Return the single remaining term (already simplified)
207            first_non_numeric.expect("BUG: non_numeric_count is 1 but first_non_numeric is None")
208        }
209        (Some(num), 1) => {
210            // Use the already simplified non-numeric term
211            let simplified_non_numeric = first_non_numeric
212                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
213            // If numeric part is zero, just return the non-numeric part
214            match num {
215                Expression::Number(Number::Integer(0)) => simplified_non_numeric,
216                Expression::Number(Number::Float(f)) if f.abs() < EPSILON => simplified_non_numeric,
217                _ => Expression::Add(Box::new(vec![num.clone(), simplified_non_numeric])),
218            }
219        }
220        _ => {
221            // Multiple non-numeric terms - collect like terms and build result efficiently
222            let mut result_terms = Vec::with_capacity(non_numeric_count + 1);
223            if let Some(num) = numeric_result {
224                // Only include numeric result if it's not zero
225                match num {
226                    Expression::Number(Number::Integer(0)) => {}
227                    Expression::Number(Number::Float(0.0)) => {}
228                    _ => result_terms.push(num),
229                }
230            }
231
232            // Collect like terms using an order-preserving approach
233            // For noncommutative terms, only combine if structurally identical
234            let mut like_terms: Vec<(String, Expression, Vec<Expression>)> = Vec::new();
235
236            for term in terms {
237                if !matches!(term, Expression::Number(_)) {
238                    // Each non-numeric term - use controlled simplification to avoid recursion
239                    let simplified_term = match term {
240                        Expression::Add(_) => term.clone(), // Already flattened
241                        Expression::Mul(factors) => simplify_multiplication(factors),
242                        Expression::Pow(base, exp) => simplify_power(base, exp),
243                        _ => term.simplify(),
244                    };
245                    match simplified_term {
246                        Expression::Number(Number::Integer(0)) => {}
247                        Expression::Number(Number::Float(0.0)) => {}
248                        _ => {
249                            // Extract coefficient and base term
250                            let (coeff, base) =
251                                extract_arithmetic_coefficient_and_base(&simplified_term);
252
253                            let base_key = format!("{:?}", base);
254
255                            // Find existing entry or create new one
256                            if let Some(entry) =
257                                like_terms.iter_mut().find(|(key, _, _)| key == &base_key)
258                            {
259                                entry.2.push(coeff);
260                            } else {
261                                like_terms.push((base_key, base.clone(), vec![coeff]));
262                            }
263                        }
264                    }
265                }
266            }
267
268            // Combine like terms
269            for (_, base, coeffs) in like_terms {
270                if coeffs.len() == 1 {
271                    // Single term, reconstruct if coefficient is not 1
272                    let coeff = &coeffs[0];
273                    match coeff {
274                        Expression::Number(Number::Integer(1)) => {
275                            // Just add the base term (coefficient is 1)
276                            result_terms.push(base);
277                        }
278                        _ => {
279                            result_terms.push(Expression::Mul(Box::new(vec![coeff.clone(), base])));
280                        }
281                    }
282                } else {
283                    // Multiple coefficients for the same base - sum them
284                    let coeff_sum = simplify_addition(&coeffs);
285                    match coeff_sum {
286                        Expression::Number(Number::Integer(0)) => {}
287                        Expression::Number(Number::Float(0.0)) => {}
288                        Expression::Number(Number::Integer(1)) => {
289                            // Coefficient sum is 1, just add the base
290                            result_terms.push(base);
291                        }
292                        _ => {
293                            result_terms.push(Expression::Mul(Box::new(vec![coeff_sum, base])));
294                        }
295                    }
296                }
297            }
298
299            // sin²(x) + cos²(x) = 1
300            if let Some(pythagorean_terms) = check_pythagorean(&result_terms) {
301                return simplify_addition(&pythagorean_terms);
302            }
303
304            match result_terms.len() {
305                0 => Expression::integer(0),
306                1 => result_terms
307                    .into_iter()
308                    .next()
309                    .expect("BUG: result_terms has length 1 but iterator is empty"),
310                _ => {
311                    // Check commutativity BEFORE sorting
312                    // Only sort if all terms are commutative (safe to reorder)
313                    let commutativity =
314                        Commutativity::combine(result_terms.iter().map(|t| t.commutativity()));
315
316                    if commutativity.can_sort() {
317                        // Safe to sort - all terms commutative
318                        result_terms.sort_by(expression_order);
319                    }
320                    // Else: preserve order for noncommutative terms
321
322                    Expression::Add(Box::new(result_terms))
323                }
324            }
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::{expr, symbol, Expression};

    /// Build the two-factor product `left * right`, keeping the factor order.
    fn product(left: &Expression, right: &Expression) -> Expression {
        Expression::mul(vec![left.clone(), right.clone()])
    }

    /// Build `inner^2`.
    fn squared(inner: Expression) -> Expression {
        Expression::pow(inner, Expression::integer(2))
    }

    /// Assert that `simplified` is a product of `factor_count` factors whose
    /// leading factor is the integer `coefficient`.
    fn assert_scaled_product(simplified: &Expression, coefficient: i64, factor_count: usize) {
        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), factor_count);
                assert_eq!(factors[0], Expression::integer(coefficient));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    /// Assert that `simplified` is still a sum of exactly two terms.
    fn assert_two_term_sum(simplified: &Expression) {
        match simplified {
            Expression::Add(terms) => assert_eq!(terms.len(), 2),
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    /// Assert that `simplified` is exactly the two-term sum `first + second`,
    /// in that order.
    fn assert_ordered_sum(simplified: &Expression, first: &Expression, second: &Expression) {
        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(&terms[0], first);
                assert_eq!(&terms[1], second);
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    #[test]
    fn test_addition_simplification() {
        // Simple integer addition
        let sum = simplify_addition(&[Expression::integer(2), Expression::integer(3)]);
        assert_eq!(sum, Expression::integer(5));

        // Addition with zero
        let sum = simplify_addition(&[Expression::integer(5), Expression::integer(0)]);
        assert_eq!(sum, Expression::integer(5));

        // Mixed numeric and symbolic
        let x = symbol!(x);
        let sum = simplify_addition(&[Expression::integer(2), Expression::symbol(x.clone())]);
        assert_eq!(
            sum,
            Expression::add(vec![Expression::integer(2), Expression::symbol(x)])
        );
    }

    #[test]
    fn test_scalar_terms_combine() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));

        // x*y + y*x should combine to 2*x*y (commutative)
        let simplified = Expression::add(vec![product(&x, &y), product(&y, &x)]).simplify();

        assert_scaled_product(&simplified, 2, 3);
    }

    #[test]
    fn test_matrix_terms_not_combined() {
        let mat_a = Expression::symbol(symbol!(A; matrix));
        let mat_b = Expression::symbol(symbol!(B; matrix));

        // A*B + B*A should NOT combine (noncommutative)
        let simplified =
            Expression::add(vec![product(&mat_a, &mat_b), product(&mat_b, &mat_a)]).simplify();

        assert_two_term_sum(&simplified);
    }

    #[test]
    fn test_identical_matrix_terms_combine() {
        let mat_a = Expression::symbol(symbol!(A; matrix));
        let mat_b = Expression::symbol(symbol!(B; matrix));

        // A*B + A*B should combine to 2*A*B (same term)
        let simplified =
            Expression::add(vec![product(&mat_a, &mat_b), product(&mat_a, &mat_b)]).simplify();

        assert_scaled_product(&simplified, 2, 3);
    }

    #[test]
    fn test_operator_terms_not_combined() {
        let operator_p = Expression::symbol(symbol!(P; operator));
        let operator_q = Expression::symbol(symbol!(Q; operator));

        // P*Q + Q*P should NOT combine (noncommutative)
        let simplified = Expression::add(vec![
            product(&operator_p, &operator_q),
            product(&operator_q, &operator_p),
        ])
        .simplify();

        assert_two_term_sum(&simplified);
    }

    #[test]
    fn test_quaternion_terms_not_combined() {
        let i = Expression::symbol(symbol!(i; quaternion));
        let j = Expression::symbol(symbol!(j; quaternion));

        // i*j + j*i should NOT combine (noncommutative)
        let simplified = Expression::add(vec![product(&i, &j), product(&j, &i)]).simplify();

        assert_two_term_sum(&simplified);
    }

    #[test]
    fn test_scalar_addition_sorts() {
        let y = Expression::symbol(symbol!(y));
        let x = Expression::symbol(symbol!(x));

        // y + x is commutative, so the terms are sorted to x + y
        let simplified = Expression::add(vec![y.clone(), x.clone()]).simplify();

        assert_ordered_sum(&simplified, &x, &y);
    }

    #[test]
    fn test_matrix_addition_preserves_order() {
        let mat_b = Expression::symbol(symbol!(B; matrix));
        let mat_a = Expression::symbol(symbol!(A; matrix));

        // B + A must stay in input order (noncommutative terms are not sorted)
        let simplified = Expression::add(vec![mat_b.clone(), mat_a.clone()]).simplify();

        assert_ordered_sum(&simplified, &mat_b, &mat_a);
    }

    #[test]
    fn test_mixed_scalar_matrix_addition_preserves_order() {
        let x = Expression::symbol(symbol!(x));
        let mat_a = Expression::symbol(symbol!(A; matrix));

        // One noncommutative term is enough to disable sorting
        let simplified = Expression::add(vec![x, mat_a.clone()]).simplify();

        assert_ordered_sum(&simplified, &expr!(x), &mat_a);
    }

    #[test]
    fn test_three_scalar_like_terms_combine() {
        let x = Expression::symbol(symbol!(x));

        // x + x + x should combine to 3*x
        let simplified = Expression::add(vec![x.clone(), x.clone(), x]).simplify();

        assert_scaled_product(&simplified, 3, 2);
        match simplified {
            Expression::Mul(factors) => assert_eq!(factors[1], expr!(x)),
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    #[test]
    fn test_three_matrix_like_terms_combine() {
        let mat_a = Expression::symbol(symbol!(A; matrix));

        // A + A + A should combine to 3*A
        let simplified =
            Expression::add(vec![mat_a.clone(), mat_a.clone(), mat_a.clone()]).simplify();

        assert_scaled_product(&simplified, 3, 2);
        match simplified {
            Expression::Mul(factors) => assert_eq!(factors[1], mat_a),
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    #[test]
    fn test_incompatible_matrix_addition_during_simplification() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let simplified = Expression::add(vec![a, b]).simplify();

        // During simplification, incompatible matrices are NOT simplified.
        // They remain in symbolic Add form; the error will be caught during
        // evaluation, not simplification.
        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!(
                "Expected Add with 2 terms for incompatible matrices during simplification, got {:?}",
                simplified
            ),
        }
    }

    #[test]
    fn test_pythagorean_identity_sin_cos() {
        let x = Expression::symbol(symbol!(x));
        let sin_squared = squared(Expression::function("sin", vec![x.clone()]));
        let cos_squared = squared(Expression::function("cos", vec![x]));

        let simplified = Expression::add(vec![sin_squared, cos_squared]).simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    #[test]
    fn test_pythagorean_identity_cos_sin() {
        let x = Expression::symbol(symbol!(x));
        let sin_squared = squared(Expression::function("sin", vec![x.clone()]));
        let cos_squared = squared(Expression::function("cos", vec![x]));

        // Same identity with the operands swapped
        let simplified = Expression::add(vec![cos_squared, sin_squared]).simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    #[test]
    fn test_pythagorean_identity_different_args() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let sin_squared = squared(Expression::function("sin", vec![x]));
        let cos_squared = squared(Expression::function("cos", vec![y]));

        // sin²(x) + cos²(y) must not collapse: the arguments differ
        let simplified = Expression::add(vec![sin_squared, cos_squared]).simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    #[test]
    fn test_pythagorean_identity_with_additional_terms() {
        let x = Expression::symbol(symbol!(x));
        let y = symbol!(y);
        let sin_squared = squared(Expression::function("sin", vec![x.clone()]));
        let cos_squared = squared(Expression::function("cos", vec![x]));

        // sin²(x) + cos²(x) + y should become 1 + y
        let simplified = Expression::add(vec![
            sin_squared,
            cos_squared,
            Expression::symbol(y.clone()),
        ])
        .simplify();

        assert_eq!(
            simplified,
            Expression::add(vec![Expression::integer(1), Expression::symbol(y)])
        );
    }

    #[test]
    fn test_pythagorean_identity_not_squared() {
        let x = Expression::symbol(symbol!(x));
        let sin_x = Expression::function("sin", vec![x.clone()]);
        let cos_x = Expression::function("cos", vec![x]);

        // sin(x) + cos(x) has no squares, so the identity does not apply
        let simplified = Expression::add(vec![sin_x, cos_x]).simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    #[test]
    fn test_distribute_numeric_over_addition() {
        let x = symbol!(x);

        // -1 * (x + 1) should distribute to -x - 1
        let expr = Expression::add(vec![Expression::mul(vec![
            Expression::integer(-1),
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
        ])]);

        let simplified = expr.simplify();

        // Should be Add([Mul([-1, x]), -1]) which simplifies to Add([-1, Mul([-1, x])])
        match &simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                // Distribution must produce BOTH the constant -1 and the
                // product -1*x; requiring only one of them would let a
                // half-distributed result pass.
                let has_neg_one = terms
                    .iter()
                    .any(|t| matches!(t, Expression::Number(Number::Integer(-1))));
                let has_neg_x = terms.iter().any(|t| {
                    matches!(t, Expression::Mul(factors)
                        if factors.len() == 2
                        && matches!(factors[0], Expression::Number(Number::Integer(-1)))
                    )
                });
                assert!(
                    has_neg_one && has_neg_x,
                    "Expected distributed terms, got {:?}",
                    simplified
                );
            }
            _ => panic!("Expected Add with distributed terms, got {:?}", simplified),
        }
    }
}