mathhook_core/simplify/arithmetic/
addition.rs

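//! Addition simplification: flattens nested sums, folds numeric constants, combines like terms, and applies the `sin²(x) + cos²(x) = 1` identity.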
2
3use super::helpers::{expression_order, extract_arithmetic_coefficient_and_base};
4use super::multiplication::simplify_multiplication;
5use super::power::simplify_power;
6use super::Simplify;
7use crate::core::commutativity::Commutativity;
8use crate::core::constants::EPSILON;
9use crate::core::{Expression, Number};
10use num_bigint::BigInt;
11use num_rational::BigRational;
12use num_traits::{ToPrimitive, Zero};
13use std::collections::VecDeque;
14use std::sync::Arc;
15
16fn extract_trig_squared(expr: &Expression, func: &str) -> Option<Expression> {
17    if let Expression::Pow(base, exp) = expr {
18        if let Expression::Number(Number::Integer(2)) = exp.as_ref() {
19            if let Expression::Function { name, args } = base.as_ref() {
20                if name.as_ref() == func && args.len() == 1 {
21                    return Some(args[0].clone());
22                }
23            }
24        }
25    }
26    None
27}
28
29fn check_pythagorean(terms: &[Expression]) -> Option<Vec<Expression>> {
30    for (i, t1) in terms.iter().enumerate() {
31        for (j, t2) in terms.iter().enumerate() {
32            if i >= j {
33                continue;
34            }
35            if let (Some(arg1), Some(arg2)) = (
36                extract_trig_squared(t1, "sin"),
37                extract_trig_squared(t2, "cos"),
38            ) {
39                if arg1 == arg2 {
40                    let mut remaining: Vec<_> = terms
41                        .iter()
42                        .enumerate()
43                        .filter(|(k, _)| *k != i && *k != j)
44                        .map(|(_, e)| e.clone())
45                        .collect();
46                    remaining.push(Expression::integer(1));
47                    return Some(remaining);
48                }
49            }
50            if let (Some(arg1), Some(arg2)) = (
51                extract_trig_squared(t1, "cos"),
52                extract_trig_squared(t2, "sin"),
53            ) {
54                if arg1 == arg2 {
55                    let mut remaining: Vec<_> = terms
56                        .iter()
57                        .enumerate()
58                        .filter(|(k, _)| *k != i && *k != j)
59                        .map(|(_, e)| e.clone())
60                        .collect();
61                    remaining.push(Expression::integer(1));
62                    return Some(remaining);
63                }
64            }
65        }
66    }
67    None
68}
69
70/// Simplify addition expressions with minimal overhead
71pub fn simplify_addition(terms: &[Expression]) -> Expression {
72    if terms.is_empty() {
73        return Expression::integer(0);
74    }
75
76    let mut flattened_terms: Vec<Expression> = Vec::new();
77    let mut to_process: VecDeque<&Expression> = terms.iter().collect();
78
79    while let Some(term) = to_process.pop_front() {
80        match term {
81            Expression::Add(nested_terms) => {
82                for nested_term in nested_terms.iter().rev() {
83                    to_process.push_front(nested_term);
84                }
85            }
86            Expression::Mul(factors) if factors.len() == 2 => {
87                if let (Expression::Number(coeff), Expression::Add(add_terms)) =
88                    (&factors[0], &factors[1])
89                {
90                    for add_term in add_terms.iter() {
91                        let distributed = Expression::mul(vec![
92                            Expression::Number(coeff.clone()),
93                            add_term.clone(),
94                        ]);
95                        flattened_terms.push(distributed);
96                    }
97                } else if let (Expression::Add(add_terms), Expression::Number(coeff)) =
98                    (&factors[0], &factors[1])
99                {
100                    for add_term in add_terms.iter() {
101                        let distributed = Expression::mul(vec![
102                            Expression::Number(coeff.clone()),
103                            add_term.clone(),
104                        ]);
105                        flattened_terms.push(distributed);
106                    }
107                } else {
108                    flattened_terms.push(term.clone());
109                }
110            }
111            _ => flattened_terms.push(term.clone()),
112        }
113    }
114
115    let terms = &flattened_terms;
116
117    if terms.len() == 2 {
118        if let Some(Ok(result)) = super::matrix_ops::try_matrix_add(&terms[0], &terms[1]) {
119            return result;
120        }
121    }
122
123    let mut int_sum = 0i64;
124    let mut float_sum = 0.0;
125    let mut has_float = false;
126    let mut rational_sum: Option<BigRational> = None;
127    let mut non_numeric_count = 0;
128    let mut first_non_numeric: Option<Expression> = None;
129    let mut numeric_result = None;
130
131    for term in terms {
132        let simplified_term = match term {
133            Expression::Add(_) => term.clone(),
134            Expression::Mul(factors) => simplify_multiplication(factors),
135            Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
136            _ => term.simplify(),
137        };
138        match simplified_term {
139            Expression::Number(Number::Integer(n)) => {
140                int_sum = int_sum.saturating_add(n);
141            }
142            Expression::Number(Number::Float(f)) => {
143                float_sum += f;
144                has_float = true;
145            }
146            Expression::Number(Number::Rational(r)) => {
147                if let Some(ref mut current_sum) = rational_sum {
148                    *current_sum += r.as_ref();
149                } else {
150                    rational_sum = Some(r.as_ref().clone());
151                }
152            }
153            _ => {
154                non_numeric_count += 1;
155                if first_non_numeric.is_none() {
156                    first_non_numeric = Some(simplified_term);
157                }
158            }
159        }
160    }
161
162    if let Some(rational) = rational_sum {
163        let mut final_rational = rational;
164        if int_sum != 0 {
165            final_rational += BigRational::from(BigInt::from(int_sum));
166        }
167        if has_float {
168            let float_val = final_rational.to_f64().unwrap_or(0.0) + float_sum;
169            if float_val.abs() >= EPSILON {
170                numeric_result = Some(Expression::Number(Number::float(float_val)));
171            }
172        } else if !final_rational.is_zero() {
173            numeric_result = Some(Expression::Number(Number::rational(final_rational)));
174        }
175    } else if has_float {
176        let total = int_sum as f64 + float_sum;
177        if total.abs() >= EPSILON {
178            numeric_result = Some(Expression::Number(Number::float(total)));
179        }
180    } else if int_sum != 0 {
181        numeric_result = Some(Expression::integer(int_sum));
182    }
183
184    match (numeric_result.as_ref(), non_numeric_count) {
185        (None, 0) => Expression::integer(0),
186        (Some(num), 0) => num.clone(),
187        (None, 1) => {
188            first_non_numeric.expect("BUG: non_numeric_count is 1 but first_non_numeric is None")
189        }
190        (Some(num), 1) => {
191            let simplified_non_numeric = first_non_numeric
192                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
193            match num {
194                Expression::Number(Number::Integer(0)) => simplified_non_numeric,
195                Expression::Number(Number::Float(f)) if f.abs() < EPSILON => simplified_non_numeric,
196                _ => Expression::Add(Arc::new(vec![num.clone(), simplified_non_numeric])),
197            }
198        }
199        _ => {
200            let mut result_terms = Vec::with_capacity(non_numeric_count + 1);
201            if let Some(num) = numeric_result {
202                match num {
203                    Expression::Number(Number::Integer(0)) => {}
204                    Expression::Number(Number::Float(0.0)) => {}
205                    _ => result_terms.push(num),
206                }
207            }
208
209            let mut like_terms: Vec<(String, Expression, Vec<Expression>)> = Vec::new();
210
211            for term in terms {
212                if !matches!(term, Expression::Number(_)) {
213                    let simplified_term = match term {
214                        Expression::Add(_) => term.clone(),
215                        Expression::Mul(factors) => simplify_multiplication(factors),
216                        Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
217                        _ => term.simplify(),
218                    };
219                    match simplified_term {
220                        Expression::Number(Number::Integer(0)) => {}
221                        Expression::Number(Number::Float(0.0)) => {}
222                        _ => {
223                            let (coeff, base) =
224                                extract_arithmetic_coefficient_and_base(&simplified_term);
225
226                            let base_key = format!("{:?}", base);
227
228                            if let Some(entry) =
229                                like_terms.iter_mut().find(|(key, _, _)| key == &base_key)
230                            {
231                                entry.2.push(coeff);
232                            } else {
233                                like_terms.push((base_key, base.clone(), vec![coeff]));
234                            }
235                        }
236                    }
237                }
238            }
239
240            for (_, base, coeffs) in like_terms {
241                if coeffs.len() == 1 {
242                    let coeff = &coeffs[0];
243                    match coeff {
244                        Expression::Number(Number::Integer(1)) => {
245                            result_terms.push(base);
246                        }
247                        _ => {
248                            result_terms.push(Expression::Mul(Arc::new(vec![coeff.clone(), base])));
249                        }
250                    }
251                } else {
252                    let coeff_sum = simplify_addition(&coeffs);
253                    match coeff_sum {
254                        Expression::Number(Number::Integer(0)) => {}
255                        Expression::Number(Number::Float(0.0)) => {}
256                        Expression::Number(Number::Integer(1)) => {
257                            result_terms.push(base);
258                        }
259                        _ => {
260                            result_terms.push(Expression::Mul(Arc::new(vec![coeff_sum, base])));
261                        }
262                    }
263                }
264            }
265
266            if let Some(pythagorean_terms) = check_pythagorean(&result_terms) {
267                return simplify_addition(&pythagorean_terms);
268            }
269
270            match result_terms.len() {
271                0 => Expression::integer(0),
272                1 => result_terms
273                    .into_iter()
274                    .next()
275                    .expect("BUG: result_terms has length 1 but iterator is empty"),
276                _ => {
277                    let commutativity =
278                        Commutativity::combine(result_terms.iter().map(|t| t.commutativity()));
279
280                    if commutativity.can_sort() {
281                        result_terms.sort_by(expression_order);
282                    }
283
284                    Expression::Add(Arc::new(result_terms))
285                }
286            }
287        }
288    }
289}
290
// Unit tests for addition simplification: numeric folding, like-term
// combination, commutativity-aware sorting and ordering, matrix addition,
// the Pythagorean identity, and distribution of numeric coefficients.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::{expr, symbol, Expression};

    // Integer constants fold together, a zero addend disappears, and a
    // constant plus a symbol stays a two-term sum.
    #[test]
    fn test_addition_simplification() {
        let expr = simplify_addition(&[Expression::integer(2), Expression::integer(3)]);
        assert_eq!(expr, Expression::integer(5));

        let expr = simplify_addition(&[Expression::integer(5), Expression::integer(0)]);
        assert_eq!(expr, Expression::integer(5));

        let x = symbol!(x);
        let expr = simplify_addition(&[Expression::integer(2), Expression::symbol(x.clone())]);
        assert_eq!(
            expr,
            Expression::add(vec![Expression::integer(2), Expression::symbol(x)])
        );
    }

    // For plain scalar symbols, x*y + y*x is expected to merge into a single
    // product whose first factor is 2 (three factors in total).
    #[test]
    fn test_scalar_terms_combine() {
        let x = symbol!(x);
        let y = symbol!(y);

        let xy = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        let yx = Expression::mul(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);
        let expr = Expression::add(vec![xy.clone(), yx.clone()]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 3);
                assert_eq!(factors[0], Expression::integer(2));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // For matrix symbols, A*B + B*A must stay a two-term sum (the products
    // are not treated as like terms).
    #[test]
    fn test_matrix_terms_not_combined() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let ba = Expression::mul(vec![
            Expression::symbol(mat_b.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let expr = Expression::add(vec![ab.clone(), ba.clone()]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // Two identical matrix products, A*B + A*B, do combine into 2*A*B.
    #[test]
    fn test_identical_matrix_terms_combine() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let ab1 = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let ab2 = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let expr = Expression::add(vec![ab1, ab2]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 3);
                assert_eq!(factors[0], Expression::integer(2));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // Operator symbols: P*Q + Q*P must stay a two-term sum.
    #[test]
    fn test_operator_terms_not_combined() {
        let operator_p = symbol!(P; operator);
        let operator_q = symbol!(Q; operator);

        let pq = Expression::mul(vec![
            Expression::symbol(operator_p.clone()),
            Expression::symbol(operator_q.clone()),
        ]);
        let qp = Expression::mul(vec![
            Expression::symbol(operator_q.clone()),
            Expression::symbol(operator_p.clone()),
        ]);
        let expr = Expression::add(vec![pq, qp]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // Quaternion symbols: i*j + j*i must stay a two-term sum.
    #[test]
    fn test_quaternion_terms_not_combined() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let ij = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
        ]);
        let ji = Expression::mul(vec![
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);
        let expr = Expression::add(vec![ij, ji]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // A sum of scalar symbols is sorted: y + x becomes x + y.
    #[test]
    fn test_scalar_addition_sorts() {
        let y = symbol!(y);
        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], Expression::symbol(symbol!(x)));
                assert_eq!(terms[1], Expression::symbol(symbol!(y)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // A sum of matrix symbols keeps its input order: B + A stays B + A.
    #[test]
    fn test_matrix_addition_preserves_order() {
        let mat_b = symbol!(B; matrix);
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(mat_b.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], Expression::symbol(symbol!(B; matrix)));
                assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // A mixed scalar + matrix sum keeps its input order: x + A.
    #[test]
    fn test_mixed_scalar_matrix_addition_preserves_order() {
        let x = symbol!(x);
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], expr!(x));
                assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // x + x + x combines to the two-factor product 3*x.
    #[test]
    fn test_three_scalar_like_terms_combine() {
        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], Expression::integer(3));
                assert_eq!(factors[1], expr!(x));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // A + A + A combines to 3*A even for a matrix symbol.
    #[test]
    fn test_three_matrix_like_terms_combine() {
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], Expression::integer(3));
                assert_eq!(factors[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // Adding a 2x2 matrix to a 1x3 matrix must not fail during
    // simplification; the sum is left as two terms.
    #[test]
    fn test_incompatible_matrix_addition_during_simplification() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let expr = Expression::add(vec![a.clone(), b.clone()]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!(
                "Expected Add with 2 terms for incompatible matrices during simplification, got {:?}",
                simplified
            ),
        }
    }

    // sin²(x) + cos²(x) simplifies to 1.
    #[test]
    fn test_pythagorean_identity_sin_cos() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![sin_squared, cos_squared]);
        let simplified = expr.simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    // The identity also applies in the reversed order: cos²(x) + sin²(x) = 1.
    #[test]
    fn test_pythagorean_identity_cos_sin() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![cos_squared, sin_squared]);
        let simplified = expr.simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    // With different arguments, sin²(x) + cos²(y) must remain a sum.
    #[test]
    fn test_pythagorean_identity_different_args() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_y = Expression::function("cos", vec![Expression::symbol(y.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_y, Expression::integer(2));

        let expr = Expression::add(vec![sin_squared, cos_squared]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    // Other addends survive the identity: sin²(x) + cos²(x) + y becomes 1 + y.
    #[test]
    fn test_pythagorean_identity_with_additional_terms() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![
            sin_squared,
            cos_squared,
            Expression::symbol(y.clone()),
        ]);
        let simplified = expr.simplify();

        assert_eq!(
            simplified,
            Expression::add(vec![Expression::integer(1), Expression::symbol(y)])
        );
    }

    // Unsquared sin(x) + cos(x) must not trigger the identity.
    #[test]
    fn test_pythagorean_identity_not_squared() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);

        let expr = Expression::add(vec![sin_x, cos_x]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    // -1 * (x + 1) inside a sum is distributed into a two-term sum. The
    // assertion only requires that at least one distributed form (the
    // constant -1, or a product led by -1) is present.
    #[test]
    fn test_distribute_numeric_over_addition() {
        let x = symbol!(x);

        let expr = Expression::add(vec![Expression::mul(vec![
            Expression::integer(-1),
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
        ])]);

        let simplified = expr.simplify();

        match &simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                let has_neg_one = terms
                    .iter()
                    .any(|t| matches!(t, Expression::Number(Number::Integer(-1))));
                let has_neg_x = terms.iter().any(|t| {
                    matches!(t, Expression::Mul(factors)
                        if factors.len() == 2
                        && matches!(factors[0], Expression::Number(Number::Integer(-1)))
                    )
                });
                assert!(
                    has_neg_one || has_neg_x,
                    "Expected distributed terms, got {:?}",
                    simplified
                );
            }
            _ => panic!("Expected Add with distributed terms, got {:?}", simplified),
        }
    }
}