mathhook_core/simplify/arithmetic/multiplication.rs

1//! Multiplication simplification operations
2
3mod binary_numeric;
4mod power_combining;
5
6pub use binary_numeric::try_simplify_binary;
7pub use power_combining::combine_like_powers;
8
9use super::addition::simplify_addition;
10use super::helpers::expression_order;
11use super::power::simplify_power;
12use super::Simplify;
13use crate::core::commutativity::Commutativity;
14use crate::core::constants::EPSILON;
15use crate::core::{Expression, Number};
16use num_bigint::BigInt;
17use num_rational::BigRational;
18use num_traits::{One, ToPrimitive, Zero};
19
20/// Simplify multiplication with minimal overhead and flattening
21pub fn simplify_multiplication(factors: &[Expression]) -> Expression {
22    if factors.is_empty() {
23        return Expression::integer(1);
24    }
25    if factors.len() == 1 {
26        return factors[0].clone();
27    }
28
29    let mut flattened_factors = Vec::new();
30    let mut to_process: Vec<&Expression> = factors.iter().collect();
31
32    while !to_process.is_empty() {
33        let factor = to_process.remove(0);
34        match factor {
35            Expression::Mul(nested_factors) => {
36                for (i, nested) in nested_factors.iter().enumerate() {
37                    to_process.insert(i, nested);
38                }
39            }
40            _ => {
41                let simplified = match factor {
42                    Expression::Add(terms) => simplify_addition(terms),
43                    Expression::Pow(base, exp) => simplify_power(base, exp),
44                    _ => factor.simplify(),
45                };
46                flattened_factors.push(simplified);
47            }
48        }
49    }
50
51    let factors = &flattened_factors;
52
53    if factors.len() == 2 {
54        if let Some(result) = try_simplify_binary(&factors[0], &factors[1]) {
55            return result;
56        }
57
58        // Matrix fast-path: try direct matrix multiplication
59        // Note: During simplification, we only apply if it succeeds.
60        // Dimension errors will be caught during evaluation, not simplification.
61        if let Some(Ok(result)) = super::matrix_ops::try_matrix_multiply(&factors[0], &factors[1]) {
62            return result;
63        }
64
65        // Handle Add simplification special cases
66        match (&factors[0], &factors[1]) {
67            (a, Expression::Add(terms)) => {
68                let simplified_add = simplify_addition(terms);
69                if !matches!(simplified_add, Expression::Add(_)) {
70                    return simplify_multiplication(&[a.clone(), simplified_add]);
71                }
72            }
73            (Expression::Add(terms), b) => {
74                let simplified_add = simplify_addition(terms);
75                if !matches!(simplified_add, Expression::Add(_)) {
76                    return simplify_multiplication(&[simplified_add, b.clone()]);
77                }
78            }
79            _ => {}
80        }
81    }
82
83    let mut all_integers = true;
84    let mut integer_product = 1i64;
85    for factor in factors {
86        match factor {
87            Expression::Number(Number::Integer(n)) => {
88                integer_product = integer_product.saturating_mul(*n);
89            }
90            _ => {
91                all_integers = false;
92                break;
93            }
94        }
95    }
96    if all_integers && factors.len() > 2 {
97        return Expression::integer(integer_product);
98    }
99
100    let mut int_product = 1i64;
101    let mut float_product = 1.0;
102    let mut has_float = false;
103    let mut non_numeric_count = 0;
104    let mut first_non_numeric = None;
105    let mut numeric_result = None;
106
107    let mut rational_product: Option<BigRational> = None;
108
109    let has_undefined = factors
110        .iter()
111        .any(|f| matches!(f, Expression::Function { name, .. } if name == "undefined"));
112
113    for factor in factors {
114        match factor {
115            Expression::Number(Number::Integer(n)) => {
116                int_product = int_product.saturating_mul(*n);
117                if int_product == 0 && !has_undefined {
118                    return Expression::integer(0);
119                }
120            }
121            Expression::Number(Number::Float(f)) => {
122                float_product *= f;
123                has_float = true;
124                if float_product.abs() < EPSILON && !has_undefined {
125                    return Expression::integer(0);
126                }
127            }
128            Expression::Number(Number::Rational(r)) => {
129                if let Some(ref mut current_rational) = rational_product {
130                    *current_rational *= r.as_ref();
131                } else {
132                    rational_product = Some(r.as_ref().clone());
133                }
134                if rational_product
135                    .as_ref()
136                    .expect("BUG: rational_product should be Some at this point")
137                    .is_zero()
138                    && !has_undefined
139                {
140                    return Expression::integer(0);
141                }
142            }
143            _ => {
144                non_numeric_count += 1;
145                if first_non_numeric.is_none() {
146                    first_non_numeric = Some(factor);
147                }
148            }
149        }
150    }
151
152    if let Some(rational) = rational_product {
153        let mut final_rational = rational;
154        if int_product != 1 {
155            final_rational *= BigRational::from(BigInt::from(int_product));
156        }
157        if has_float {
158            let float_val = final_rational.to_f64().unwrap_or(0.0) * float_product;
159            if (float_val - 1.0).abs() >= EPSILON {
160                numeric_result = Some(Expression::Number(Number::float(float_val)));
161            }
162        } else if final_rational.denom() == &BigInt::from(1) {
163            if let Some(int_val) = final_rational.numer().to_i64() {
164                if int_val != 1 {
165                    numeric_result = Some(Expression::integer(int_val));
166                }
167            } else if !final_rational.is_one() {
168                numeric_result = Some(Expression::Number(Number::rational(final_rational)));
169            }
170        } else if !final_rational.is_one() {
171            numeric_result = Some(Expression::Number(Number::rational(final_rational)));
172        }
173    } else if has_float {
174        let total = int_product as f64 * float_product;
175        if (total - 1.0).abs() >= EPSILON {
176            numeric_result = Some(Expression::Number(Number::float(total)));
177        }
178    } else if int_product != 1 {
179        numeric_result = Some(Expression::integer(int_product));
180    }
181
182    match (numeric_result.as_ref(), non_numeric_count) {
183        (None, 0) => Expression::integer(1),
184        (Some(num), 0) => num.clone(),
185        (None, 1) => {
186            let factor = first_non_numeric
187                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
188            match factor {
189                Expression::Add(terms) => simplify_addition(terms),
190                Expression::Pow(base, exp) => simplify_power(base, exp),
191                _ => factor.simplify(),
192            }
193        }
194        (Some(num), 1) => {
195            let factor = first_non_numeric
196                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
197            let simplified_non_numeric = match factor {
198                Expression::Add(terms) => simplify_addition(terms),
199                Expression::Pow(base, exp) => simplify_power(base, exp),
200                _ => factor.simplify(),
201            };
202            match num {
203                Expression::Number(Number::Integer(1)) => simplified_non_numeric,
204                Expression::Number(Number::Float(f)) if (f - 1.0).abs() < EPSILON => {
205                    simplified_non_numeric
206                }
207                _ => Expression::Mul(Box::new(vec![num.clone(), simplified_non_numeric])),
208            }
209        }
210        _ => {
211            let mut result_factors = Vec::with_capacity(non_numeric_count + 1);
212            if let Some(num) = numeric_result {
213                match num {
214                    Expression::Number(Number::Integer(1)) => {}
215                    Expression::Number(Number::Float(1.0)) => {}
216                    _ => result_factors.push(num),
217                }
218            }
219            for factor in factors {
220                if !matches!(factor, Expression::Number(_)) {
221                    let simplified_factor = match factor {
222                        Expression::Add(terms) => simplify_addition(terms),
223                        Expression::Pow(base, exp) => simplify_power(base, exp),
224                        _ => factor.simplify(),
225                    };
226                    result_factors.push(simplified_factor);
227                }
228            }
229            match result_factors.len() {
230                0 => Expression::integer(1),
231                1 => result_factors
232                    .into_iter()
233                    .next()
234                    .expect("BUG: result_factors has length 1 but iterator is empty"),
235                _ => {
236                    let commutativity =
237                        Commutativity::combine(result_factors.iter().map(|f| f.commutativity()));
238
239                    if commutativity.can_sort() {
240                        result_factors = combine_like_powers(result_factors);
241                        result_factors.sort_by(expression_order);
242                    }
243
244                    match result_factors.len() {
245                        0 => Expression::integer(1),
246                        1 => result_factors.into_iter().next().unwrap(),
247                        _ => Expression::Mul(Box::new(result_factors)),
248                    }
249                }
250            }
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Simplifies `expr` and checks that the result is a `Mul` whose factors are exactly
    /// `expected`, in the same order.
    fn assert_product_factors(expr: Expression, expected: &[Expression]) {
        let simplified = expr.simplify();
        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), expected.len());
                for (actual, wanted) in factors.iter().zip(expected.iter()) {
                    assert_eq!(actual, wanted);
                }
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    #[test]
    fn test_multiplication_simplification() {
        // (left operand, right operand, expected product)
        let cases: [(i64, i64, i64); 3] = [(2, 3, 6), (5, 1, 5), (5, 0, 0)];
        for &(lhs, rhs, expected) in &cases {
            let product =
                simplify_multiplication(&[Expression::integer(lhs), Expression::integer(rhs)]);
            assert_eq!(product, Expression::integer(expected));
        }
    }

    #[test]
    fn test_nested_multiplication_flattening() {
        let inner = Expression::mul(vec![Expression::integer(3), Expression::integer(4)]);
        let product = simplify_multiplication(&[Expression::integer(2), inner]);
        assert_eq!(product, Expression::integer(24));
    }

    #[test]
    fn test_scalar_multiplication_sorts() {
        let y = Expression::symbol(symbol!(y));
        let x = Expression::symbol(symbol!(x));
        assert_product_factors(Expression::mul(vec![y.clone(), x.clone()]), &[x, y]);
    }

    #[test]
    fn test_matrix_multiplication_preserves_order() {
        let mat_a = Expression::symbol(symbol!(A; matrix));
        let mat_b = Expression::symbol(symbol!(B; matrix));
        assert_product_factors(
            Expression::mul(vec![mat_b.clone(), mat_a.clone()]),
            &[mat_b, mat_a],
        );
    }

    #[test]
    fn test_mixed_scalar_matrix_preserves_order() {
        let x = Expression::symbol(symbol!(x));
        let mat_a = Expression::symbol(symbol!(A; matrix));
        assert_product_factors(Expression::mul(vec![x.clone(), mat_a.clone()]), &[x, mat_a]);
    }

    #[test]
    fn test_operator_multiplication_preserves_order() {
        let op_p = Expression::symbol(symbol!(P; operator));
        let op_q = Expression::symbol(symbol!(Q; operator));
        assert_product_factors(Expression::mul(vec![op_q.clone(), op_p.clone()]), &[op_q, op_p]);
    }

    #[test]
    fn test_quaternion_multiplication_preserves_order() {
        let quat_i = Expression::symbol(symbol!(i; quaternion));
        let quat_j = Expression::symbol(symbol!(j; quaternion));
        assert_product_factors(
            Expression::mul(vec![quat_j.clone(), quat_i.clone()]),
            &[quat_j, quat_i],
        );
    }

    #[test]
    fn test_three_scalar_factors_sort() {
        let z = Expression::symbol(symbol!(z));
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        assert_product_factors(
            Expression::mul(vec![z.clone(), x.clone(), y.clone()]),
            &[x, y, z],
        );
    }

    #[test]
    fn test_three_matrix_factors_preserve_order() {
        let mat_c = Expression::symbol(symbol!(C; matrix));
        let mat_a = Expression::symbol(symbol!(A; matrix));
        let mat_b = Expression::symbol(symbol!(B; matrix));
        assert_product_factors(
            Expression::mul(vec![mat_c.clone(), mat_a.clone(), mat_b.clone()]),
            &[mat_c, mat_a, mat_b],
        );
    }

    #[test]
    fn test_numeric_coefficient_with_scalars_sorts() {
        let y = Expression::symbol(symbol!(y));
        let x = Expression::symbol(symbol!(x));
        assert_product_factors(
            Expression::mul(vec![Expression::integer(2), y.clone(), x.clone()]),
            &[Expression::integer(2), x, y],
        );
    }

    #[test]
    fn test_numeric_coefficient_with_matrices_preserves_order() {
        let mat_b = Expression::symbol(symbol!(B; matrix));
        let mat_a = Expression::symbol(symbol!(A; matrix));
        assert_product_factors(
            Expression::mul(vec![Expression::integer(2), mat_b.clone(), mat_a.clone()]),
            &[Expression::integer(2), mat_b, mat_a],
        );
    }
}