// mathhook_core/simplify/arithmetic/multiplication.rs

1//! Multiplication simplification operations
2
3mod binary_numeric;
4mod power_combining;
5
6pub use binary_numeric::try_simplify_binary;
7pub use power_combining::combine_like_powers;
8
9use super::addition::simplify_addition;
10use super::helpers::expression_order;
11use super::power::simplify_power;
12use super::Simplify;
13use crate::core::commutativity::Commutativity;
14use crate::core::constants::EPSILON;
15use crate::core::{Expression, Number};
16use num_bigint::BigInt;
17use num_rational::BigRational;
18use num_traits::{One, ToPrimitive, Zero};
19use std::sync::Arc;
20
21/// Simplify multiplication with minimal overhead and flattening
22pub fn simplify_multiplication(factors: &[Expression]) -> Expression {
23    if factors.is_empty() {
24        return Expression::integer(1);
25    }
26    if factors.len() == 1 {
27        return factors[0].clone();
28    }
29
30    let mut flattened_factors = Vec::new();
31    let mut to_process: Vec<&Expression> = factors.iter().collect();
32
33    while !to_process.is_empty() {
34        let factor = to_process.remove(0);
35        match factor {
36            Expression::Mul(nested_factors) => {
37                for (i, nested) in nested_factors.iter().enumerate() {
38                    to_process.insert(i, nested);
39                }
40            }
41            _ => {
42                let simplified = match factor {
43                    Expression::Add(terms) => simplify_addition(terms),
44                    Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
45                    _ => factor.simplify(),
46                };
47                flattened_factors.push(simplified);
48            }
49        }
50    }
51
52    let factors = &flattened_factors;
53
54    if factors.len() == 2 {
55        if let Some(result) = try_simplify_binary(&factors[0], &factors[1]) {
56            return result;
57        }
58
59        if let Some(Ok(result)) = super::matrix_ops::try_matrix_multiply(&factors[0], &factors[1]) {
60            return result;
61        }
62
63        match (&factors[0], &factors[1]) {
64            (a, Expression::Add(terms)) => {
65                let simplified_add = simplify_addition(terms);
66                if !matches!(simplified_add, Expression::Add(_)) {
67                    return simplify_multiplication(&[a.clone(), simplified_add]);
68                }
69            }
70            (Expression::Add(terms), b) => {
71                let simplified_add = simplify_addition(terms);
72                if !matches!(simplified_add, Expression::Add(_)) {
73                    return simplify_multiplication(&[simplified_add, b.clone()]);
74                }
75            }
76            _ => {}
77        }
78    }
79
80    let mut all_integers = true;
81    let mut integer_product = 1i64;
82    for factor in factors {
83        match factor {
84            Expression::Number(Number::Integer(n)) => {
85                integer_product = integer_product.saturating_mul(*n);
86            }
87            _ => {
88                all_integers = false;
89                break;
90            }
91        }
92    }
93    if all_integers && factors.len() > 2 {
94        return Expression::integer(integer_product);
95    }
96
97    let mut int_product = 1i64;
98    let mut float_product = 1.0;
99    let mut has_float = false;
100    let mut non_numeric_count = 0;
101    let mut first_non_numeric = None;
102    let mut numeric_result = None;
103
104    let mut rational_product: Option<BigRational> = None;
105
106    let has_undefined = factors
107        .iter()
108        .any(|f| matches!(f, Expression::Function { name, .. } if name.as_ref() == "undefined"));
109
110    for factor in factors {
111        match factor {
112            Expression::Number(Number::Integer(n)) => {
113                int_product = int_product.saturating_mul(*n);
114                if int_product == 0 && !has_undefined {
115                    return Expression::integer(0);
116                }
117            }
118            Expression::Number(Number::Float(f)) => {
119                float_product *= f;
120                has_float = true;
121                if float_product.abs() < EPSILON && !has_undefined {
122                    return Expression::integer(0);
123                }
124            }
125            Expression::Number(Number::Rational(r)) => {
126                if let Some(ref mut current_rational) = rational_product {
127                    *current_rational *= r.as_ref();
128                } else {
129                    rational_product = Some(r.as_ref().clone());
130                }
131                if rational_product
132                    .as_ref()
133                    .expect("BUG: rational_product should be Some at this point")
134                    .is_zero()
135                    && !has_undefined
136                {
137                    return Expression::integer(0);
138                }
139            }
140            _ => {
141                non_numeric_count += 1;
142                if first_non_numeric.is_none() {
143                    first_non_numeric = Some(factor);
144                }
145            }
146        }
147    }
148
149    if let Some(rational) = rational_product {
150        let mut final_rational = rational;
151        if int_product != 1 {
152            final_rational *= BigRational::from(BigInt::from(int_product));
153        }
154        if has_float {
155            let float_val = final_rational.to_f64().unwrap_or(0.0) * float_product;
156            if (float_val - 1.0).abs() >= EPSILON {
157                numeric_result = Some(Expression::Number(Number::float(float_val)));
158            }
159        } else if final_rational.denom() == &BigInt::from(1) {
160            if let Some(int_val) = final_rational.numer().to_i64() {
161                if int_val != 1 {
162                    numeric_result = Some(Expression::integer(int_val));
163                }
164            } else if !final_rational.is_one() {
165                numeric_result = Some(Expression::Number(Number::rational(final_rational)));
166            }
167        } else if !final_rational.is_one() {
168            numeric_result = Some(Expression::Number(Number::rational(final_rational)));
169        }
170    } else if has_float {
171        let total = int_product as f64 * float_product;
172        if (total - 1.0).abs() >= EPSILON {
173            numeric_result = Some(Expression::Number(Number::float(total)));
174        }
175    } else if int_product != 1 {
176        numeric_result = Some(Expression::integer(int_product));
177    }
178
179    match (numeric_result.as_ref(), non_numeric_count) {
180        (None, 0) => Expression::integer(1),
181        (Some(num), 0) => num.clone(),
182        (None, 1) => {
183            let factor = first_non_numeric
184                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
185            match factor {
186                Expression::Add(terms) => simplify_addition(terms),
187                Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
188                _ => factor.simplify(),
189            }
190        }
191        (Some(num), 1) => {
192            let factor = first_non_numeric
193                .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
194            let simplified_non_numeric = match factor {
195                Expression::Add(terms) => simplify_addition(terms),
196                Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
197                _ => factor.simplify(),
198            };
199            match num {
200                Expression::Number(Number::Integer(1)) => simplified_non_numeric,
201                Expression::Number(Number::Float(f)) if (f - 1.0).abs() < EPSILON => {
202                    simplified_non_numeric
203                }
204                _ => Expression::Mul(Arc::new(vec![num.clone(), simplified_non_numeric])),
205            }
206        }
207        _ => {
208            let mut result_factors = Vec::with_capacity(non_numeric_count + 1);
209            if let Some(num) = numeric_result {
210                match num {
211                    Expression::Number(Number::Integer(1)) => {}
212                    Expression::Number(Number::Float(1.0)) => {}
213                    _ => result_factors.push(num),
214                }
215            }
216            for factor in factors {
217                if !matches!(factor, Expression::Number(_)) {
218                    let simplified_factor = match factor {
219                        Expression::Add(terms) => simplify_addition(terms),
220                        Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
221                        _ => factor.simplify(),
222                    };
223                    result_factors.push(simplified_factor);
224                }
225            }
226            match result_factors.len() {
227                0 => Expression::integer(1),
228                1 => result_factors
229                    .into_iter()
230                    .next()
231                    .expect("BUG: result_factors has length 1 but iterator is empty"),
232                _ => {
233                    let commutativity =
234                        Commutativity::combine(result_factors.iter().map(|f| f.commutativity()));
235
236                    if commutativity.can_sort() {
237                        result_factors = combine_like_powers(result_factors);
238                        result_factors.sort_by(expression_order);
239                    }
240
241                    match result_factors.len() {
242                        0 => Expression::integer(1),
243                        1 => result_factors.into_iter().next().unwrap(),
244                        _ => Expression::Mul(Arc::new(result_factors)),
245                    }
246                }
247            }
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Assert that `simplified` is a `Mul` whose factors equal `expected`, in order.
    fn assert_mul_factors(simplified: Expression, expected: &[Expression]) {
        match simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), expected.len());
                for (actual, wanted) in factors.iter().zip(expected) {
                    assert_eq!(actual, wanted);
                }
            }
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    #[test]
    fn test_multiplication_simplification() {
        // (lhs, rhs, expected product)
        let cases: [(i64, i64, i64); 3] = [(2, 3, 6), (5, 1, 5), (5, 0, 0)];
        for &(lhs, rhs, product) in cases.iter() {
            let result =
                simplify_multiplication(&[Expression::integer(lhs), Expression::integer(rhs)]);
            assert_eq!(result, Expression::integer(product));
        }
    }

    #[test]
    fn test_nested_multiplication_flattening() {
        let inner = Expression::mul(vec![Expression::integer(3), Expression::integer(4)]);
        let result = simplify_multiplication(&[Expression::integer(2), inner]);
        assert_eq!(result, Expression::integer(24));
    }

    #[test]
    fn test_scalar_multiplication_sorts() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let simplified = Expression::mul(vec![y.clone(), x.clone()]).simplify();
        assert_mul_factors(simplified, &[x, y]);
    }

    #[test]
    fn test_matrix_multiplication_preserves_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let simplified = Expression::mul(vec![b.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[b, a]);
    }

    #[test]
    fn test_mixed_scalar_matrix_preserves_order() {
        let x = Expression::symbol(symbol!(x));
        let a = Expression::symbol(symbol!(A; matrix));
        let simplified = Expression::mul(vec![x.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[x, a]);
    }

    #[test]
    fn test_operator_multiplication_preserves_order() {
        let p = Expression::symbol(symbol!(P; operator));
        let q = Expression::symbol(symbol!(Q; operator));
        let simplified = Expression::mul(vec![q.clone(), p.clone()]).simplify();
        assert_mul_factors(simplified, &[q, p]);
    }

    #[test]
    fn test_quaternion_multiplication_preserves_order() {
        let qi = Expression::symbol(symbol!(i; quaternion));
        let qj = Expression::symbol(symbol!(j; quaternion));
        let simplified = Expression::mul(vec![qj.clone(), qi.clone()]).simplify();
        assert_mul_factors(simplified, &[qj, qi]);
    }

    #[test]
    fn test_three_scalar_factors_sort() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let z = Expression::symbol(symbol!(z));
        let simplified = Expression::mul(vec![z.clone(), x.clone(), y.clone()]).simplify();
        assert_mul_factors(simplified, &[x, y, z]);
    }

    #[test]
    fn test_three_matrix_factors_preserve_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let c = Expression::symbol(symbol!(C; matrix));
        let simplified = Expression::mul(vec![c.clone(), a.clone(), b.clone()]).simplify();
        assert_mul_factors(simplified, &[c, a, b]);
    }

    #[test]
    fn test_numeric_coefficient_with_scalars_sorts() {
        let two = Expression::integer(2);
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let simplified = Expression::mul(vec![two.clone(), y.clone(), x.clone()]).simplify();
        assert_mul_factors(simplified, &[two, x, y]);
    }

    #[test]
    fn test_numeric_coefficient_with_matrices_preserves_order() {
        let two = Expression::integer(2);
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let simplified = Expression::mul(vec![two.clone(), b.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[two, b, a]);
    }
}