mathhook_core/simplify/arithmetic/
multiplication.rs1mod binary_numeric;
4mod power_combining;
5
6pub use binary_numeric::try_simplify_binary;
7pub use power_combining::combine_like_powers;
8
9use super::addition::simplify_addition;
10use super::helpers::expression_order;
11use super::power::simplify_power;
12use super::Simplify;
13use crate::core::commutativity::Commutativity;
14use crate::core::constants::EPSILON;
15use crate::core::{Expression, Number};
16use num_bigint::BigInt;
17use num_rational::BigRational;
18use num_traits::{One, ToPrimitive, Zero};
19use std::sync::Arc;
20
21pub fn simplify_multiplication(factors: &[Expression]) -> Expression {
23 if factors.is_empty() {
24 return Expression::integer(1);
25 }
26 if factors.len() == 1 {
27 return factors[0].clone();
28 }
29
30 let mut flattened_factors = Vec::new();
31 let mut to_process: Vec<&Expression> = factors.iter().collect();
32
33 while !to_process.is_empty() {
34 let factor = to_process.remove(0);
35 match factor {
36 Expression::Mul(nested_factors) => {
37 for (i, nested) in nested_factors.iter().enumerate() {
38 to_process.insert(i, nested);
39 }
40 }
41 _ => {
42 let simplified = match factor {
43 Expression::Add(terms) => simplify_addition(terms),
44 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
45 _ => factor.simplify(),
46 };
47 flattened_factors.push(simplified);
48 }
49 }
50 }
51
52 let factors = &flattened_factors;
53
54 if factors.len() == 2 {
55 if let Some(result) = try_simplify_binary(&factors[0], &factors[1]) {
56 return result;
57 }
58
59 if let Some(Ok(result)) = super::matrix_ops::try_matrix_multiply(&factors[0], &factors[1]) {
60 return result;
61 }
62
63 match (&factors[0], &factors[1]) {
64 (a, Expression::Add(terms)) => {
65 let simplified_add = simplify_addition(terms);
66 if !matches!(simplified_add, Expression::Add(_)) {
67 return simplify_multiplication(&[a.clone(), simplified_add]);
68 }
69 }
70 (Expression::Add(terms), b) => {
71 let simplified_add = simplify_addition(terms);
72 if !matches!(simplified_add, Expression::Add(_)) {
73 return simplify_multiplication(&[simplified_add, b.clone()]);
74 }
75 }
76 _ => {}
77 }
78 }
79
80 let mut all_integers = true;
81 let mut integer_product = 1i64;
82 for factor in factors {
83 match factor {
84 Expression::Number(Number::Integer(n)) => {
85 integer_product = integer_product.saturating_mul(*n);
86 }
87 _ => {
88 all_integers = false;
89 break;
90 }
91 }
92 }
93 if all_integers && factors.len() > 2 {
94 return Expression::integer(integer_product);
95 }
96
97 let mut int_product = 1i64;
98 let mut float_product = 1.0;
99 let mut has_float = false;
100 let mut non_numeric_count = 0;
101 let mut first_non_numeric = None;
102 let mut numeric_result = None;
103
104 let mut rational_product: Option<BigRational> = None;
105
106 let has_undefined = factors
107 .iter()
108 .any(|f| matches!(f, Expression::Function { name, .. } if name.as_ref() == "undefined"));
109
110 for factor in factors {
111 match factor {
112 Expression::Number(Number::Integer(n)) => {
113 int_product = int_product.saturating_mul(*n);
114 if int_product == 0 && !has_undefined {
115 return Expression::integer(0);
116 }
117 }
118 Expression::Number(Number::Float(f)) => {
119 float_product *= f;
120 has_float = true;
121 if float_product.abs() < EPSILON && !has_undefined {
122 return Expression::integer(0);
123 }
124 }
125 Expression::Number(Number::Rational(r)) => {
126 if let Some(ref mut current_rational) = rational_product {
127 *current_rational *= r.as_ref();
128 } else {
129 rational_product = Some(r.as_ref().clone());
130 }
131 if rational_product
132 .as_ref()
133 .expect("BUG: rational_product should be Some at this point")
134 .is_zero()
135 && !has_undefined
136 {
137 return Expression::integer(0);
138 }
139 }
140 _ => {
141 non_numeric_count += 1;
142 if first_non_numeric.is_none() {
143 first_non_numeric = Some(factor);
144 }
145 }
146 }
147 }
148
149 if let Some(rational) = rational_product {
150 let mut final_rational = rational;
151 if int_product != 1 {
152 final_rational *= BigRational::from(BigInt::from(int_product));
153 }
154 if has_float {
155 let float_val = final_rational.to_f64().unwrap_or(0.0) * float_product;
156 if (float_val - 1.0).abs() >= EPSILON {
157 numeric_result = Some(Expression::Number(Number::float(float_val)));
158 }
159 } else if final_rational.denom() == &BigInt::from(1) {
160 if let Some(int_val) = final_rational.numer().to_i64() {
161 if int_val != 1 {
162 numeric_result = Some(Expression::integer(int_val));
163 }
164 } else if !final_rational.is_one() {
165 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
166 }
167 } else if !final_rational.is_one() {
168 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
169 }
170 } else if has_float {
171 let total = int_product as f64 * float_product;
172 if (total - 1.0).abs() >= EPSILON {
173 numeric_result = Some(Expression::Number(Number::float(total)));
174 }
175 } else if int_product != 1 {
176 numeric_result = Some(Expression::integer(int_product));
177 }
178
179 match (numeric_result.as_ref(), non_numeric_count) {
180 (None, 0) => Expression::integer(1),
181 (Some(num), 0) => num.clone(),
182 (None, 1) => {
183 let factor = first_non_numeric
184 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
185 match factor {
186 Expression::Add(terms) => simplify_addition(terms),
187 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
188 _ => factor.simplify(),
189 }
190 }
191 (Some(num), 1) => {
192 let factor = first_non_numeric
193 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
194 let simplified_non_numeric = match factor {
195 Expression::Add(terms) => simplify_addition(terms),
196 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
197 _ => factor.simplify(),
198 };
199 match num {
200 Expression::Number(Number::Integer(1)) => simplified_non_numeric,
201 Expression::Number(Number::Float(f)) if (f - 1.0).abs() < EPSILON => {
202 simplified_non_numeric
203 }
204 _ => Expression::Mul(Arc::new(vec![num.clone(), simplified_non_numeric])),
205 }
206 }
207 _ => {
208 let mut result_factors = Vec::with_capacity(non_numeric_count + 1);
209 if let Some(num) = numeric_result {
210 match num {
211 Expression::Number(Number::Integer(1)) => {}
212 Expression::Number(Number::Float(1.0)) => {}
213 _ => result_factors.push(num),
214 }
215 }
216 for factor in factors {
217 if !matches!(factor, Expression::Number(_)) {
218 let simplified_factor = match factor {
219 Expression::Add(terms) => simplify_addition(terms),
220 Expression::Pow(base, exp) => simplify_power(base.as_ref(), exp.as_ref()),
221 _ => factor.simplify(),
222 };
223 result_factors.push(simplified_factor);
224 }
225 }
226 match result_factors.len() {
227 0 => Expression::integer(1),
228 1 => result_factors
229 .into_iter()
230 .next()
231 .expect("BUG: result_factors has length 1 but iterator is empty"),
232 _ => {
233 let commutativity =
234 Commutativity::combine(result_factors.iter().map(|f| f.commutativity()));
235
236 if commutativity.can_sort() {
237 result_factors = combine_like_powers(result_factors);
238 result_factors.sort_by(expression_order);
239 }
240
241 match result_factors.len() {
242 0 => Expression::integer(1),
243 1 => result_factors.into_iter().next().unwrap(),
244 _ => Expression::Mul(Arc::new(result_factors)),
245 }
246 }
247 }
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Asserts that `simplified` is a `Mul` whose factors equal `expected`,
    /// element by element and in the same order.
    fn assert_mul_factors(simplified: Expression, expected: &[Expression]) {
        match simplified {
            Expression::Mul(actual) => {
                assert_eq!(actual.len(), expected.len());
                for (actual_factor, expected_factor) in actual.iter().zip(expected.iter()) {
                    assert_eq!(actual_factor, expected_factor);
                }
            }
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_multiplication_simplification() {
        let cases = [
            (Expression::integer(2), Expression::integer(3), Expression::integer(6)),
            (Expression::integer(5), Expression::integer(1), Expression::integer(5)),
            (Expression::integer(5), Expression::integer(0), Expression::integer(0)),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(simplify_multiplication(&[lhs, rhs]), expected);
        }
    }

    #[test]
    fn test_nested_multiplication_flattening() {
        let inner = Expression::mul(vec![Expression::integer(3), Expression::integer(4)]);
        let product = simplify_multiplication(&[Expression::integer(2), inner]);
        assert_eq!(product, Expression::integer(24));
    }

    #[test]
    fn test_scalar_multiplication_sorts() {
        let y = Expression::symbol(symbol!(y));
        let x = Expression::symbol(symbol!(x));
        let simplified = Expression::mul(vec![y.clone(), x.clone()]).simplify();
        assert_mul_factors(simplified, &[x, y]);
    }

    #[test]
    fn test_matrix_multiplication_preserves_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let simplified = Expression::mul(vec![b.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[b, a]);
    }

    #[test]
    fn test_mixed_scalar_matrix_preserves_order() {
        let x = Expression::symbol(symbol!(x));
        let a = Expression::symbol(symbol!(A; matrix));
        let simplified = Expression::mul(vec![x.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[x, a]);
    }

    #[test]
    fn test_operator_multiplication_preserves_order() {
        let p = Expression::symbol(symbol!(P; operator));
        let q = Expression::symbol(symbol!(Q; operator));
        let simplified = Expression::mul(vec![q.clone(), p.clone()]).simplify();
        assert_mul_factors(simplified, &[q, p]);
    }

    #[test]
    fn test_quaternion_multiplication_preserves_order() {
        let qi = Expression::symbol(symbol!(i; quaternion));
        let qj = Expression::symbol(symbol!(j; quaternion));
        let simplified = Expression::mul(vec![qj.clone(), qi.clone()]).simplify();
        assert_mul_factors(simplified, &[qj, qi]);
    }

    #[test]
    fn test_three_scalar_factors_sort() {
        let z = Expression::symbol(symbol!(z));
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let simplified = Expression::mul(vec![z.clone(), x.clone(), y.clone()]).simplify();
        assert_mul_factors(simplified, &[x, y, z]);
    }

    #[test]
    fn test_three_matrix_factors_preserve_order() {
        let c = Expression::symbol(symbol!(C; matrix));
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let simplified = Expression::mul(vec![c.clone(), a.clone(), b.clone()]).simplify();
        assert_mul_factors(simplified, &[c, a, b]);
    }

    #[test]
    fn test_numeric_coefficient_with_scalars_sorts() {
        let y = Expression::symbol(symbol!(y));
        let x = Expression::symbol(symbol!(x));
        let two = Expression::integer(2);
        let simplified = Expression::mul(vec![two.clone(), y.clone(), x.clone()]).simplify();
        assert_mul_factors(simplified, &[two, x, y]);
    }

    #[test]
    fn test_numeric_coefficient_with_matrices_preserves_order() {
        let b = Expression::symbol(symbol!(B; matrix));
        let a = Expression::symbol(symbol!(A; matrix));
        let two = Expression::integer(2);
        let simplified = Expression::mul(vec![two.clone(), b.clone(), a.clone()]).simplify();
        assert_mul_factors(simplified, &[two, b, a]);
    }
}