1use super::helpers::{expression_order, extract_arithmetic_coefficient_and_base};
4use super::multiplication::simplify_multiplication;
5use super::power::simplify_power;
6use super::Simplify;
7use crate::core::commutativity::Commutativity;
8use crate::core::constants::EPSILON;
9use crate::core::{Expression, Number};
10use num_bigint::BigInt;
11use num_rational::BigRational;
12use num_traits::{ToPrimitive, Zero};
13use std::collections::VecDeque;
14
15fn extract_trig_squared(expr: &Expression, func: &str) -> Option<Expression> {
16 if let Expression::Pow(base, exp) = expr {
17 if let Expression::Number(Number::Integer(2)) = exp.as_ref() {
18 if let Expression::Function { name, args } = base.as_ref() {
19 if name == func && args.len() == 1 {
20 return Some(args[0].clone());
21 }
22 }
23 }
24 }
25 None
26}
27
28fn check_pythagorean(terms: &[Expression]) -> Option<Vec<Expression>> {
29 for (i, t1) in terms.iter().enumerate() {
30 for (j, t2) in terms.iter().enumerate() {
31 if i >= j {
32 continue;
33 }
34 if let (Some(arg1), Some(arg2)) = (
35 extract_trig_squared(t1, "sin"),
36 extract_trig_squared(t2, "cos"),
37 ) {
38 if arg1 == arg2 {
39 let mut remaining: Vec<_> = terms
40 .iter()
41 .enumerate()
42 .filter(|(k, _)| *k != i && *k != j)
43 .map(|(_, e)| e.clone())
44 .collect();
45 remaining.push(Expression::integer(1));
46 return Some(remaining);
47 }
48 }
49 if let (Some(arg1), Some(arg2)) = (
50 extract_trig_squared(t1, "cos"),
51 extract_trig_squared(t2, "sin"),
52 ) {
53 if arg1 == arg2 {
54 let mut remaining: Vec<_> = terms
55 .iter()
56 .enumerate()
57 .filter(|(k, _)| *k != i && *k != j)
58 .map(|(_, e)| e.clone())
59 .collect();
60 remaining.push(Expression::integer(1));
61 return Some(remaining);
62 }
63 }
64 }
65 }
66 None
67}
68
69pub fn simplify_addition(terms: &[Expression]) -> Expression {
71 if terms.is_empty() {
72 return Expression::integer(0);
73 }
74
75 let mut flattened_terms: Vec<Expression> = Vec::new();
77 let mut to_process: VecDeque<&Expression> = terms.iter().collect();
78
79 while let Some(term) = to_process.pop_front() {
80 match term {
81 Expression::Add(nested_terms) => {
82 for nested_term in nested_terms.iter().rev() {
83 to_process.push_front(nested_term);
84 }
85 }
86 Expression::Mul(factors) if factors.len() == 2 => {
88 if let (Expression::Number(coeff), Expression::Add(add_terms)) =
89 (&factors[0], &factors[1])
90 {
91 for add_term in add_terms.iter() {
92 let distributed = Expression::mul(vec![
93 Expression::Number(coeff.clone()),
94 add_term.clone(),
95 ]);
96 flattened_terms.push(distributed);
97 }
98 } else if let (Expression::Add(add_terms), Expression::Number(coeff)) =
99 (&factors[0], &factors[1])
100 {
101 for add_term in add_terms.iter() {
102 let distributed = Expression::mul(vec![
103 Expression::Number(coeff.clone()),
104 add_term.clone(),
105 ]);
106 flattened_terms.push(distributed);
107 }
108 } else {
109 flattened_terms.push(term.clone());
110 }
111 }
112 _ => flattened_terms.push(term.clone()),
113 }
114 }
115
116 let terms = &flattened_terms;
118
119 if terms.len() == 2 {
124 if let Some(Ok(result)) = super::matrix_ops::try_matrix_add(&terms[0], &terms[1]) {
125 return result;
126 }
127 }
128
129 let mut int_sum = 0i64;
131 let mut float_sum = 0.0;
132 let mut has_float = false;
133 let mut rational_sum: Option<BigRational> = None;
134 let mut non_numeric_count = 0;
135 let mut first_non_numeric: Option<Expression> = None;
136 let mut numeric_result = None;
137
138 for term in terms {
139 let simplified_term = match term {
141 Expression::Add(_) => {
142 term.clone()
145 }
146 Expression::Mul(factors) => simplify_multiplication(factors),
147 Expression::Pow(base, exp) => simplify_power(base, exp),
148 _ => term.simplify(),
149 };
150 match simplified_term {
151 Expression::Number(Number::Integer(n)) => {
152 int_sum = int_sum.saturating_add(n);
153 }
154 Expression::Number(Number::Float(f)) => {
155 float_sum += f;
156 has_float = true;
157 }
158 Expression::Number(Number::Rational(r)) => {
159 if let Some(ref mut current_sum) = rational_sum {
160 *current_sum += r.as_ref();
161 } else {
162 rational_sum = Some(r.as_ref().clone());
163 }
164 }
165 _ => {
166 non_numeric_count += 1;
167 if first_non_numeric.is_none() {
168 first_non_numeric = Some(simplified_term);
169 }
170 }
171 }
172 }
173
174 if let Some(rational) = rational_sum {
176 let mut final_rational = rational;
178 if int_sum != 0 {
179 final_rational += BigRational::from(BigInt::from(int_sum));
180 }
181 if has_float {
182 let float_val = final_rational.to_f64().unwrap_or(0.0) + float_sum;
184 if float_val.abs() >= EPSILON {
185 numeric_result = Some(Expression::Number(Number::float(float_val)));
186 }
187 } else {
188 if !final_rational.is_zero() {
190 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
191 }
192 }
193 } else if has_float {
194 let total = int_sum as f64 + float_sum;
195 if total.abs() >= EPSILON {
196 numeric_result = Some(Expression::Number(Number::float(total)));
197 }
198 } else if int_sum != 0 {
199 numeric_result = Some(Expression::integer(int_sum));
200 }
201
202 match (numeric_result.as_ref(), non_numeric_count) {
203 (None, 0) => Expression::integer(0),
204 (Some(num), 0) => num.clone(),
205 (None, 1) => {
206 first_non_numeric.expect("BUG: non_numeric_count is 1 but first_non_numeric is None")
208 }
209 (Some(num), 1) => {
210 let simplified_non_numeric = first_non_numeric
212 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
213 match num {
215 Expression::Number(Number::Integer(0)) => simplified_non_numeric,
216 Expression::Number(Number::Float(f)) if f.abs() < EPSILON => simplified_non_numeric,
217 _ => Expression::Add(Box::new(vec![num.clone(), simplified_non_numeric])),
218 }
219 }
220 _ => {
221 let mut result_terms = Vec::with_capacity(non_numeric_count + 1);
223 if let Some(num) = numeric_result {
224 match num {
226 Expression::Number(Number::Integer(0)) => {}
227 Expression::Number(Number::Float(0.0)) => {}
228 _ => result_terms.push(num),
229 }
230 }
231
232 let mut like_terms: Vec<(String, Expression, Vec<Expression>)> = Vec::new();
235
236 for term in terms {
237 if !matches!(term, Expression::Number(_)) {
238 let simplified_term = match term {
240 Expression::Add(_) => term.clone(), Expression::Mul(factors) => simplify_multiplication(factors),
242 Expression::Pow(base, exp) => simplify_power(base, exp),
243 _ => term.simplify(),
244 };
245 match simplified_term {
246 Expression::Number(Number::Integer(0)) => {}
247 Expression::Number(Number::Float(0.0)) => {}
248 _ => {
249 let (coeff, base) =
251 extract_arithmetic_coefficient_and_base(&simplified_term);
252
253 let base_key = format!("{:?}", base);
254
255 if let Some(entry) =
257 like_terms.iter_mut().find(|(key, _, _)| key == &base_key)
258 {
259 entry.2.push(coeff);
260 } else {
261 like_terms.push((base_key, base.clone(), vec![coeff]));
262 }
263 }
264 }
265 }
266 }
267
268 for (_, base, coeffs) in like_terms {
270 if coeffs.len() == 1 {
271 let coeff = &coeffs[0];
273 match coeff {
274 Expression::Number(Number::Integer(1)) => {
275 result_terms.push(base);
277 }
278 _ => {
279 result_terms.push(Expression::Mul(Box::new(vec![coeff.clone(), base])));
280 }
281 }
282 } else {
283 let coeff_sum = simplify_addition(&coeffs);
285 match coeff_sum {
286 Expression::Number(Number::Integer(0)) => {}
287 Expression::Number(Number::Float(0.0)) => {}
288 Expression::Number(Number::Integer(1)) => {
289 result_terms.push(base);
291 }
292 _ => {
293 result_terms.push(Expression::Mul(Box::new(vec![coeff_sum, base])));
294 }
295 }
296 }
297 }
298
299 if let Some(pythagorean_terms) = check_pythagorean(&result_terms) {
301 return simplify_addition(&pythagorean_terms);
302 }
303
304 match result_terms.len() {
305 0 => Expression::integer(0),
306 1 => result_terms
307 .into_iter()
308 .next()
309 .expect("BUG: result_terms has length 1 but iterator is empty"),
310 _ => {
311 let commutativity =
314 Commutativity::combine(result_terms.iter().map(|t| t.commutativity()));
315
316 if commutativity.can_sort() {
317 result_terms.sort_by(expression_order);
319 }
320 Expression::Add(Box::new(result_terms))
323 }
324 }
325 }
326 }
327}
328
// Unit tests for `simplify_addition` and the `Simplify` path for `Add`
// expressions. They cover constant folding, like-term combination, ordering
// rules for commutative vs. non-commutative symbols, matrix addition,
// the Pythagorean identity and numeric distribution.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::{expr, symbol, Expression};

    // Constants fold (2 + 3 = 5), adding 0 is a no-op, and a constant plus a
    // symbol stays a two-term sum.
    #[test]
    fn test_addition_simplification() {
        let expr = simplify_addition(&[Expression::integer(2), Expression::integer(3)]);
        assert_eq!(expr, Expression::integer(5));

        let expr = simplify_addition(&[Expression::integer(5), Expression::integer(0)]);
        assert_eq!(expr, Expression::integer(5));

        let x = symbol!(x);
        let expr = simplify_addition(&[Expression::integer(2), Expression::symbol(x.clone())]);
        assert_eq!(
            expr,
            Expression::add(vec![Expression::integer(2), Expression::symbol(x)])
        );
    }

    // For scalar symbols, x*y and y*x are like terms and combine into a product
    // whose leading factor is 2.
    #[test]
    fn test_scalar_terms_combine() {
        let x = symbol!(x);
        let y = symbol!(y);

        let xy = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        let yx = Expression::mul(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);
        let expr = Expression::add(vec![xy.clone(), yx.clone()]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 3);
                assert_eq!(factors[0], Expression::integer(2));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // For matrix symbols, A*B and B*A are different terms, so the sum keeps both.
    #[test]
    fn test_matrix_terms_not_combined() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let ba = Expression::mul(vec![
            Expression::symbol(mat_b.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let expr = Expression::add(vec![ab.clone(), ba.clone()]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // Identical matrix products (A*B + A*B) do combine into a product with
    // leading factor 2.
    #[test]
    fn test_identical_matrix_terms_combine() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let ab1 = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let ab2 = Expression::mul(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        ]);
        let expr = Expression::add(vec![ab1, ab2]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 3);
                assert_eq!(factors[0], Expression::integer(2));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // For operator symbols, P*Q and Q*P stay as two distinct terms.
    #[test]
    fn test_operator_terms_not_combined() {
        let operator_p = symbol!(P; operator);
        let operator_q = symbol!(Q; operator);

        let pq = Expression::mul(vec![
            Expression::symbol(operator_p.clone()),
            Expression::symbol(operator_q.clone()),
        ]);
        let qp = Expression::mul(vec![
            Expression::symbol(operator_q.clone()),
            Expression::symbol(operator_p.clone()),
        ]);
        let expr = Expression::add(vec![pq, qp]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // For quaternion symbols, i*j and j*i stay as two distinct terms.
    #[test]
    fn test_quaternion_terms_not_combined() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let ij = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
        ]);
        let ji = Expression::mul(vec![
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);
        let expr = Expression::add(vec![ij, ji]);

        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!("Expected Add with 2 terms, got {:?}", simplified),
        }
    }

    // A sum of scalar symbols is sorted: y + x becomes x + y.
    #[test]
    fn test_scalar_addition_sorts() {
        let y = symbol!(y);
        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], Expression::symbol(symbol!(x)));
                assert_eq!(terms[1], Expression::symbol(symbol!(y)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // A sum of matrix symbols is not reordered: B + A stays B + A.
    #[test]
    fn test_matrix_addition_preserves_order() {
        let mat_b = symbol!(B; matrix);
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(mat_b.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], Expression::symbol(symbol!(B; matrix)));
                assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // A sum that mixes a scalar and a matrix symbol keeps its input order.
    #[test]
    fn test_mixed_scalar_matrix_addition_preserves_order() {
        let x = symbol!(x);
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                assert_eq!(terms[0], expr!(x));
                assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Add, got {:?}", simplified),
        }
    }

    // x + x + x collapses to the two-factor product 3 * x.
    #[test]
    fn test_three_scalar_like_terms_combine() {
        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], Expression::integer(3));
                assert_eq!(factors[1], expr!(x));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // A + A + A for a matrix symbol also collapses to 3 * A.
    #[test]
    fn test_three_matrix_like_terms_combine() {
        let mat_a = symbol!(A; matrix);
        let expr = Expression::add(vec![
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_a.clone()),
        ]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], Expression::integer(3));
                assert_eq!(factors[1], Expression::symbol(symbol!(A; matrix)));
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    // Adding a 2x2 and a 1x3 concrete matrix must not panic; the sum stays a
    // two-term `Add`.
    #[test]
    fn test_incompatible_matrix_addition_during_simplification() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let expr = Expression::add(vec![a.clone(), b.clone()]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            _ => panic!(
                "Expected Add with 2 terms for incompatible matrices during simplification, got {:?}",
                simplified
            ),
        }
    }

    // sin(x)^2 + cos(x)^2 simplifies to 1.
    #[test]
    fn test_pythagorean_identity_sin_cos() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![sin_squared, cos_squared]);
        let simplified = expr.simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    // The identity also applies in the other order: cos(x)^2 + sin(x)^2 = 1.
    #[test]
    fn test_pythagorean_identity_cos_sin() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![cos_squared, sin_squared]);
        let simplified = expr.simplify();

        assert_eq!(simplified, Expression::integer(1));
    }

    // The identity requires the same argument: sin(x)^2 + cos(y)^2 stays a sum.
    #[test]
    fn test_pythagorean_identity_different_args() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_y = Expression::function("cos", vec![Expression::symbol(y.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_y, Expression::integer(2));

        let expr = Expression::add(vec![sin_squared, cos_squared]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    // The identity fires inside a larger sum: sin(x)^2 + cos(x)^2 + y = 1 + y.
    #[test]
    fn test_pythagorean_identity_with_additional_terms() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let sin_squared = Expression::pow(sin_x, Expression::integer(2));
        let cos_squared = Expression::pow(cos_x, Expression::integer(2));

        let expr = Expression::add(vec![
            sin_squared,
            cos_squared,
            Expression::symbol(y.clone()),
        ]);
        let simplified = expr.simplify();

        assert_eq!(
            simplified,
            Expression::add(vec![Expression::integer(1), Expression::symbol(y)])
        );
    }

    // Unsquared sin(x) + cos(x) is not touched by the identity.
    #[test]
    fn test_pythagorean_identity_not_squared() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);

        let expr = Expression::add(vec![sin_x, cos_x]);
        let simplified = expr.simplify();

        match simplified {
            Expression::Add(_) => {}
            _ => panic!("Expected Add (unchanged), got {:?}", simplified),
        }
    }

    // -1 * (x + 1) inside a sum is distributed into a two-term sum.
    // NOTE(review): the check uses `has_neg_one || has_neg_x`, so it passes when
    // only one of the two distributed terms is present. Consider `&&` once the
    // exact shape of `-x` is confirmed.
    #[test]
    fn test_distribute_numeric_over_addition() {
        let x = symbol!(x);

        let expr = Expression::add(vec![Expression::mul(vec![
            Expression::integer(-1),
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
        ])]);

        let simplified = expr.simplify();

        match &simplified {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
                let has_neg_one = terms
                    .iter()
                    .any(|t| matches!(t, Expression::Number(Number::Integer(-1))));
                let has_neg_x = terms.iter().any(|t| {
                    matches!(t, Expression::Mul(factors)
                        if factors.len() == 2
                        && matches!(factors[0], Expression::Number(Number::Integer(-1)))
                    )
                });
                assert!(
                    has_neg_one || has_neg_x,
                    "Expected distributed terms, got {:?}",
                    simplified
                );
            }
            _ => panic!("Expected Add with distributed terms, got {:?}", simplified),
        }
    }
}