1use super::helpers::{expression_order, extract_arithmetic_coefficient_and_base};
4use super::multiplication::simplify_multiplication;
5use super::power::simplify_power;
6use super::Simplify;
7use crate::core::commutativity::Commutativity;
8use crate::core::constants::EPSILON;
9use crate::core::{Expression, Number};
10use num_bigint::BigInt;
11use num_rational::BigRational;
12use num_traits::{ToPrimitive, Zero};
13use std::collections::VecDeque;
14use std::sync::Arc;
15
16fn extract_trig_squared(expr: &Expression, func: &str) -> Option<Expression> {
17 if let Expression::Pow(base, exp) = expr {
18 if let Expression::Number(Number::Integer(2)) = exp.as_ref() {
19 if let Expression::Function { name, args } = base.as_ref() {
20 if name.as_ref() == func && args.len() == 1 {
21 return Some(args[0].clone());
22 }
23 }
24 }
25 }
26 None
27}
28
29fn check_pythagorean(terms: &[Expression]) -> Option<Vec<Expression>> {
30 for (i, t1) in terms.iter().enumerate() {
31 for (j, t2) in terms.iter().enumerate() {
32 if i >= j {
33 continue;
34 }
35 if let (Some(arg1), Some(arg2)) = (
36 extract_trig_squared(t1, "sin"),
37 extract_trig_squared(t2, "cos"),
38 ) {
39 if arg1 == arg2 {
40 let mut remaining: Vec<_> = terms
41 .iter()
42 .enumerate()
43 .filter(|(k, _)| *k != i && *k != j)
44 .map(|(_, e)| e.clone())
45 .collect();
46 remaining.push(Expression::integer(1));
47 return Some(remaining);
48 }
49 }
50 if let (Some(arg1), Some(arg2)) = (
51 extract_trig_squared(t1, "cos"),
52 extract_trig_squared(t2, "sin"),
53 ) {
54 if arg1 == arg2 {
55 let mut remaining: Vec<_> = terms
56 .iter()
57 .enumerate()
58 .filter(|(k, _)| *k != i && *k != j)
59 .map(|(_, e)| e.clone())
60 .collect();
61 remaining.push(Expression::integer(1));
62 return Some(remaining);
63 }
64 }
65 }
66 }
67 None
68}
69
70pub fn simplify_addition(terms: &[Expression]) -> Expression {
72 if terms.is_empty() {
73 return Expression::integer(0);
74 }
75
76 let mut flattened_terms: Vec<Expression> = Vec::new();
77 let mut to_process: VecDeque<&Expression> = terms.iter().collect();
78
79 while let Some(term) = to_process.pop_front() {
80 match term {
81 Expression::Add(nested_terms) => {
82 for nested_term in nested_terms.iter().rev() {
83 to_process.push_front(nested_term);
84 }
85 }
86 Expression::Mul(factors) if factors.len() == 2 => {
87 if let (Expression::Number(coeff), Expression::Add(add_terms)) =
88 (&factors[0], &factors[1])
89 {
90 for add_term in add_terms.iter() {
91 let distributed = Expression::mul(vec![
92 Expression::Number(coeff.clone()),
93 add_term.clone(),
94 ]);
95 flattened_terms.push(distributed);
96 }
97 } else if let (Expression::Add(add_terms), Expression::Number(coeff)) =
98 (&factors[0], &factors[1])
99 {
100 for add_term in add_terms.iter() {
101 let distributed = Expression::mul(vec![
102 Expression::Number(coeff.clone()),
103 add_term.clone(),
104 ]);
105 flattened_terms.push(distributed);
106 }
107 } else {
108 flattened_terms.push(term.clone());
109 }
110 }
111 _ => flattened_terms.push(term.clone()),
112 }
113 }
114
115 let terms = &flattened_terms;
116
117 if terms.len() == 2 {
118 if let Some(Ok(result)) = super::matrix_ops::try_matrix_add(&terms[0], &terms[1]) {
119 return result;
120 }
121 }
122
123 let mut int_sum = 0i64;
124 let mut float_sum = 0.0;
125 let mut has_float = false;
126 let mut rational_sum: Option<BigRational> = None;
127 let mut non_numeric_count = 0;
128 let mut first_non_numeric: Option<Expression> = None;
129 let mut numeric_result = None;
130
131 for term in terms {
132 let simplified_term = match term {
133 Expression::Add(_) => term.clone(),
134 Expression::Mul(factors) => simplify_multiplication(factors),
135 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
136 _ => term.simplify(),
137 };
138 match simplified_term {
139 Expression::Number(Number::Integer(n)) => {
140 int_sum = int_sum.saturating_add(n);
141 }
142 Expression::Number(Number::Float(f)) => {
143 float_sum += f;
144 has_float = true;
145 }
146 Expression::Number(Number::Rational(r)) => {
147 if let Some(ref mut current_sum) = rational_sum {
148 *current_sum += r.as_ref();
149 } else {
150 rational_sum = Some(r.as_ref().clone());
151 }
152 }
153 _ => {
154 non_numeric_count += 1;
155 if first_non_numeric.is_none() {
156 first_non_numeric = Some(simplified_term);
157 }
158 }
159 }
160 }
161
162 if let Some(rational) = rational_sum {
163 let mut final_rational = rational;
164 if int_sum != 0 {
165 final_rational += BigRational::from(BigInt::from(int_sum));
166 }
167 if has_float {
168 let float_val = final_rational.to_f64().unwrap_or(0.0) + float_sum;
169 if float_val.abs() >= EPSILON {
170 numeric_result = Some(Expression::Number(Number::float(float_val)));
171 }
172 } else if !final_rational.is_zero() {
173 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
174 }
175 } else if has_float {
176 let total = int_sum as f64 + float_sum;
177 if total.abs() >= EPSILON {
178 numeric_result = Some(Expression::Number(Number::float(total)));
179 }
180 } else if int_sum != 0 {
181 numeric_result = Some(Expression::integer(int_sum));
182 }
183
184 match (numeric_result.as_ref(), non_numeric_count) {
185 (None, 0) => Expression::integer(0),
186 (Some(num), 0) => num.clone(),
187 (None, 1) => {
188 first_non_numeric.expect("BUG: non_numeric_count is 1 but first_non_numeric is None")
189 }
190 (Some(num), 1) => {
191 let simplified_non_numeric = first_non_numeric
192 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
193 match num {
194 Expression::Number(Number::Integer(0)) => simplified_non_numeric,
195 Expression::Number(Number::Float(f)) if f.abs() < EPSILON => simplified_non_numeric,
196 _ => Expression::Add(Arc::new(vec![num.clone(), simplified_non_numeric])),
197 }
198 }
199 _ => {
200 let mut result_terms = Vec::with_capacity(non_numeric_count + 1);
201 if let Some(num) = numeric_result {
202 match num {
203 Expression::Number(Number::Integer(0)) => {}
204 Expression::Number(Number::Float(0.0)) => {}
205 _ => result_terms.push(num),
206 }
207 }
208
209 let mut like_terms: Vec<(String, Expression, Vec<Expression>)> = Vec::new();
210
211 for term in terms {
212 if !matches!(term, Expression::Number(_)) {
213 let simplified_term = match term {
214 Expression::Add(_) => term.clone(),
215 Expression::Mul(factors) => simplify_multiplication(factors),
216 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
217 _ => term.simplify(),
218 };
219 match simplified_term {
220 Expression::Number(Number::Integer(0)) => {}
221 Expression::Number(Number::Float(0.0)) => {}
222 _ => {
223 let (coeff, base) =
224 extract_arithmetic_coefficient_and_base(&simplified_term);
225
226 let base_key = format!("{:?}", base);
227
228 if let Some(entry) =
229 like_terms.iter_mut().find(|(key, _, _)| key == &base_key)
230 {
231 entry.2.push(coeff);
232 } else {
233 like_terms.push((base_key, base.clone(), vec![coeff]));
234 }
235 }
236 }
237 }
238 }
239
240 for (_, base, coeffs) in like_terms {
241 if coeffs.len() == 1 {
242 let coeff = &coeffs[0];
243 match coeff {
244 Expression::Number(Number::Integer(1)) => {
245 result_terms.push(base);
246 }
247 _ => {
248 result_terms.push(Expression::Mul(Arc::new(vec![coeff.clone(), base])));
249 }
250 }
251 } else {
252 let coeff_sum = simplify_addition(&coeffs);
253 match coeff_sum {
254 Expression::Number(Number::Integer(0)) => {}
255 Expression::Number(Number::Float(0.0)) => {}
256 Expression::Number(Number::Integer(1)) => {
257 result_terms.push(base);
258 }
259 _ => {
260 result_terms.push(Expression::Mul(Arc::new(vec![coeff_sum, base])));
261 }
262 }
263 }
264 }
265
266 if let Some(pythagorean_terms) = check_pythagorean(&result_terms) {
267 return simplify_addition(&pythagorean_terms);
268 }
269
270 match result_terms.len() {
271 0 => Expression::integer(0),
272 1 => result_terms
273 .into_iter()
274 .next()
275 .expect("BUG: result_terms has length 1 but iterator is empty"),
276 _ => {
277 let commutativity =
278 Commutativity::combine(result_terms.iter().map(|t| t.commutativity()));
279
280 if commutativity.can_sort() {
281 result_terms.sort_by(expression_order);
282 }
283
284 Expression::Add(Arc::new(result_terms))
285 }
286 }
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::{expr, symbol, Expression};

    /// Wraps `terms` in a sum and returns the simplified result.
    fn simplified_sum(terms: Vec<Expression>) -> Expression {
        Expression::add(terms).simplify()
    }

    /// Builds the two-factor product `lhs * rhs`.
    fn product_of(lhs: Expression, rhs: Expression) -> Expression {
        Expression::mul(vec![lhs, rhs])
    }

    /// Builds `func(arg)^2`.
    fn squared_call(func: &'static str, arg: Expression) -> Expression {
        Expression::pow(Expression::function(func, vec![arg]), Expression::integer(2))
    }

    /// Returns the terms of `expr` when it is an `Add`.
    fn add_terms(expr: &Expression) -> Option<&[Expression]> {
        match expr {
            Expression::Add(terms) => Some(terms.as_slice()),
            _ => None,
        }
    }

    /// Returns the factors of `expr` when it is a `Mul`.
    fn mul_factors(expr: &Expression) -> Option<&[Expression]> {
        match expr {
            Expression::Mul(factors) => Some(factors.as_slice()),
            _ => None,
        }
    }

    #[test]
    fn test_addition_simplification() {
        let two_plus_three = simplify_addition(&[Expression::integer(2), Expression::integer(3)]);
        assert_eq!(two_plus_three, Expression::integer(5));

        let five_plus_zero = simplify_addition(&[Expression::integer(5), Expression::integer(0)]);
        assert_eq!(five_plus_zero, Expression::integer(5));

        let x = symbol!(x);
        let two_plus_x =
            simplify_addition(&[Expression::integer(2), Expression::symbol(x.clone())]);
        let expected = Expression::add(vec![Expression::integer(2), Expression::symbol(x)]);
        assert_eq!(two_plus_x, expected);
    }

    #[test]
    fn test_scalar_terms_combine() {
        let x = symbol!(x);
        let y = symbol!(y);

        let xy = product_of(Expression::symbol(x.clone()), Expression::symbol(y.clone()));
        let yx = product_of(Expression::symbol(y), Expression::symbol(x));

        let simplified = simplified_sum(vec![xy, yx]);

        let factors = mul_factors(&simplified)
            .unwrap_or_else(|| panic!("Expected Mul, got {:?}", simplified));
        assert_eq!(factors.len(), 3);
        assert_eq!(factors[0], Expression::integer(2));
    }

    #[test]
    fn test_matrix_terms_not_combined() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let ab = product_of(
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        );
        let ba = product_of(Expression::symbol(mat_b), Expression::symbol(mat_a));

        let simplified = simplified_sum(vec![ab, ba]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add with 2 terms, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
    }

    #[test]
    fn test_identical_matrix_terms_combine() {
        let mat_a = symbol!(A; matrix);
        let mat_b = symbol!(B; matrix);

        let first = product_of(
            Expression::symbol(mat_a.clone()),
            Expression::symbol(mat_b.clone()),
        );
        let second = product_of(Expression::symbol(mat_a), Expression::symbol(mat_b));

        let simplified = simplified_sum(vec![first, second]);

        let factors = mul_factors(&simplified)
            .unwrap_or_else(|| panic!("Expected Mul, got {:?}", simplified));
        assert_eq!(factors.len(), 3);
        assert_eq!(factors[0], Expression::integer(2));
    }

    #[test]
    fn test_operator_terms_not_combined() {
        let operator_p = symbol!(P; operator);
        let operator_q = symbol!(Q; operator);

        let pq = product_of(
            Expression::symbol(operator_p.clone()),
            Expression::symbol(operator_q.clone()),
        );
        let qp = product_of(Expression::symbol(operator_q), Expression::symbol(operator_p));

        let simplified = simplified_sum(vec![pq, qp]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add with 2 terms, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
    }

    #[test]
    fn test_quaternion_terms_not_combined() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let ij = product_of(Expression::symbol(i.clone()), Expression::symbol(j.clone()));
        let ji = product_of(Expression::symbol(j), Expression::symbol(i));

        let simplified = simplified_sum(vec![ij, ji]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add with 2 terms, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
    }

    #[test]
    fn test_scalar_addition_sorts() {
        let y = symbol!(y);
        let x = symbol!(x);

        let simplified = simplified_sum(vec![Expression::symbol(y), Expression::symbol(x)]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], Expression::symbol(symbol!(x)));
        assert_eq!(terms[1], Expression::symbol(symbol!(y)));
    }

    #[test]
    fn test_matrix_addition_preserves_order() {
        let mat_b = symbol!(B; matrix);
        let mat_a = symbol!(A; matrix);

        let simplified =
            simplified_sum(vec![Expression::symbol(mat_b), Expression::symbol(mat_a)]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], Expression::symbol(symbol!(B; matrix)));
        assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
    }

    #[test]
    fn test_mixed_scalar_matrix_addition_preserves_order() {
        let x = symbol!(x);
        let mat_a = symbol!(A; matrix);

        let simplified = simplified_sum(vec![Expression::symbol(x), Expression::symbol(mat_a)]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add, got {:?}", simplified));
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], expr!(x));
        assert_eq!(terms[1], Expression::symbol(symbol!(A; matrix)));
    }

    #[test]
    fn test_three_scalar_like_terms_combine() {
        let x = symbol!(x);

        let simplified = simplified_sum(vec![Expression::symbol(x); 3]);

        let factors = mul_factors(&simplified)
            .unwrap_or_else(|| panic!("Expected Mul, got {:?}", simplified));
        assert_eq!(factors.len(), 2);
        assert_eq!(factors[0], Expression::integer(3));
        assert_eq!(factors[1], expr!(x));
    }

    #[test]
    fn test_three_matrix_like_terms_combine() {
        let mat_a = symbol!(A; matrix);

        let simplified = simplified_sum(vec![Expression::symbol(mat_a); 3]);

        let factors = mul_factors(&simplified)
            .unwrap_or_else(|| panic!("Expected Mul, got {:?}", simplified));
        assert_eq!(factors.len(), 2);
        assert_eq!(factors[0], Expression::integer(3));
        assert_eq!(factors[1], Expression::symbol(symbol!(A; matrix)));
    }

    #[test]
    fn test_incompatible_matrix_addition_during_simplification() {
        let two_by_two =
            Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let one_by_three = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let simplified = simplified_sum(vec![two_by_two, one_by_three]);

        let terms = add_terms(&simplified).unwrap_or_else(|| {
            panic!(
                "Expected Add with 2 terms for incompatible matrices during simplification, got {:?}",
                simplified
            )
        });
        assert_eq!(terms.len(), 2);
    }

    #[test]
    fn test_pythagorean_identity_sin_cos() {
        let x = symbol!(x);
        let sin_squared = squared_call("sin", Expression::symbol(x.clone()));
        let cos_squared = squared_call("cos", Expression::symbol(x));

        let simplified = simplified_sum(vec![sin_squared, cos_squared]);

        assert_eq!(simplified, Expression::integer(1));
    }

    #[test]
    fn test_pythagorean_identity_cos_sin() {
        let x = symbol!(x);
        let sin_squared = squared_call("sin", Expression::symbol(x.clone()));
        let cos_squared = squared_call("cos", Expression::symbol(x));

        let simplified = simplified_sum(vec![cos_squared, sin_squared]);

        assert_eq!(simplified, Expression::integer(1));
    }

    #[test]
    fn test_pythagorean_identity_different_args() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_squared = squared_call("sin", Expression::symbol(x));
        let cos_squared = squared_call("cos", Expression::symbol(y));

        let simplified = simplified_sum(vec![sin_squared, cos_squared]);

        if add_terms(&simplified).is_none() {
            panic!("Expected Add (unchanged), got {:?}", simplified);
        }
    }

    #[test]
    fn test_pythagorean_identity_with_additional_terms() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sin_squared = squared_call("sin", Expression::symbol(x.clone()));
        let cos_squared = squared_call("cos", Expression::symbol(x));

        let simplified = simplified_sum(vec![
            sin_squared,
            cos_squared,
            Expression::symbol(y.clone()),
        ]);

        let expected = Expression::add(vec![Expression::integer(1), Expression::symbol(y)]);
        assert_eq!(simplified, expected);
    }

    #[test]
    fn test_pythagorean_identity_not_squared() {
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x)]);

        let simplified = simplified_sum(vec![sin_x, cos_x]);

        if add_terms(&simplified).is_none() {
            panic!("Expected Add (unchanged), got {:?}", simplified);
        }
    }

    #[test]
    fn test_distribute_numeric_over_addition() {
        let x = symbol!(x);

        let inner_sum = Expression::add(vec![Expression::symbol(x), Expression::integer(1)]);
        let negated = product_of(Expression::integer(-1), inner_sum);

        let simplified = simplified_sum(vec![negated]);

        let terms = add_terms(&simplified)
            .unwrap_or_else(|| panic!("Expected Add with distributed terms, got {:?}", simplified));
        assert_eq!(terms.len(), 2);

        let is_neg_one =
            |e: &Expression| matches!(e, Expression::Number(Number::Integer(-1)));
        let has_neg_one = terms.iter().any(|t| is_neg_one(t));
        let has_neg_x = terms.iter().any(|t| match mul_factors(t) {
            Some(factors) => factors.len() == 2 && is_neg_one(&factors[0]),
            None => false,
        });
        assert!(
            has_neg_one || has_neg_x,
            "Expected distributed terms, got {:?}",
            simplified
        );
    }
}