mathhook_core/simplify/arithmetic/
multiplication.rs1mod binary_numeric;
4mod power_combining;
5
6pub use binary_numeric::try_simplify_binary;
7pub use power_combining::combine_like_powers;
8
9use super::addition::simplify_addition;
10use super::helpers::expression_order;
11use super::power::simplify_power;
12use super::Simplify;
13use crate::core::commutativity::Commutativity;
14use crate::core::constants::EPSILON;
15use crate::core::{Expression, Number};
16use num_bigint::BigInt;
17use num_rational::BigRational;
18use num_traits::{One, ToPrimitive, Zero};
19
20pub fn simplify_multiplication(factors: &[Expression]) -> Expression {
22 if factors.is_empty() {
23 return Expression::integer(1);
24 }
25 if factors.len() == 1 {
26 return factors[0].clone();
27 }
28
29 let mut flattened_factors = Vec::new();
30 let mut to_process: Vec<&Expression> = factors.iter().collect();
31
32 while !to_process.is_empty() {
33 let factor = to_process.remove(0);
34 match factor {
35 Expression::Mul(nested_factors) => {
36 for (i, nested) in nested_factors.iter().enumerate() {
37 to_process.insert(i, nested);
38 }
39 }
40 _ => {
41 let simplified = match factor {
42 Expression::Add(terms) => simplify_addition(terms),
43 Expression::Pow(base, exp) => simplify_power(base, exp),
44 _ => factor.simplify(),
45 };
46 flattened_factors.push(simplified);
47 }
48 }
49 }
50
51 let factors = &flattened_factors;
52
53 if factors.len() == 2 {
54 if let Some(result) = try_simplify_binary(&factors[0], &factors[1]) {
55 return result;
56 }
57
58 if let Some(Ok(result)) = super::matrix_ops::try_matrix_multiply(&factors[0], &factors[1]) {
62 return result;
63 }
64
65 match (&factors[0], &factors[1]) {
67 (a, Expression::Add(terms)) => {
68 let simplified_add = simplify_addition(terms);
69 if !matches!(simplified_add, Expression::Add(_)) {
70 return simplify_multiplication(&[a.clone(), simplified_add]);
71 }
72 }
73 (Expression::Add(terms), b) => {
74 let simplified_add = simplify_addition(terms);
75 if !matches!(simplified_add, Expression::Add(_)) {
76 return simplify_multiplication(&[simplified_add, b.clone()]);
77 }
78 }
79 _ => {}
80 }
81 }
82
83 let mut all_integers = true;
84 let mut integer_product = 1i64;
85 for factor in factors {
86 match factor {
87 Expression::Number(Number::Integer(n)) => {
88 integer_product = integer_product.saturating_mul(*n);
89 }
90 _ => {
91 all_integers = false;
92 break;
93 }
94 }
95 }
96 if all_integers && factors.len() > 2 {
97 return Expression::integer(integer_product);
98 }
99
100 let mut int_product = 1i64;
101 let mut float_product = 1.0;
102 let mut has_float = false;
103 let mut non_numeric_count = 0;
104 let mut first_non_numeric = None;
105 let mut numeric_result = None;
106
107 let mut rational_product: Option<BigRational> = None;
108
109 let has_undefined = factors
110 .iter()
111 .any(|f| matches!(f, Expression::Function { name, .. } if name == "undefined"));
112
113 for factor in factors {
114 match factor {
115 Expression::Number(Number::Integer(n)) => {
116 int_product = int_product.saturating_mul(*n);
117 if int_product == 0 && !has_undefined {
118 return Expression::integer(0);
119 }
120 }
121 Expression::Number(Number::Float(f)) => {
122 float_product *= f;
123 has_float = true;
124 if float_product.abs() < EPSILON && !has_undefined {
125 return Expression::integer(0);
126 }
127 }
128 Expression::Number(Number::Rational(r)) => {
129 if let Some(ref mut current_rational) = rational_product {
130 *current_rational *= r.as_ref();
131 } else {
132 rational_product = Some(r.as_ref().clone());
133 }
134 if rational_product
135 .as_ref()
136 .expect("BUG: rational_product should be Some at this point")
137 .is_zero()
138 && !has_undefined
139 {
140 return Expression::integer(0);
141 }
142 }
143 _ => {
144 non_numeric_count += 1;
145 if first_non_numeric.is_none() {
146 first_non_numeric = Some(factor);
147 }
148 }
149 }
150 }
151
152 if let Some(rational) = rational_product {
153 let mut final_rational = rational;
154 if int_product != 1 {
155 final_rational *= BigRational::from(BigInt::from(int_product));
156 }
157 if has_float {
158 let float_val = final_rational.to_f64().unwrap_or(0.0) * float_product;
159 if (float_val - 1.0).abs() >= EPSILON {
160 numeric_result = Some(Expression::Number(Number::float(float_val)));
161 }
162 } else if final_rational.denom() == &BigInt::from(1) {
163 if let Some(int_val) = final_rational.numer().to_i64() {
164 if int_val != 1 {
165 numeric_result = Some(Expression::integer(int_val));
166 }
167 } else if !final_rational.is_one() {
168 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
169 }
170 } else if !final_rational.is_one() {
171 numeric_result = Some(Expression::Number(Number::rational(final_rational)));
172 }
173 } else if has_float {
174 let total = int_product as f64 * float_product;
175 if (total - 1.0).abs() >= EPSILON {
176 numeric_result = Some(Expression::Number(Number::float(total)));
177 }
178 } else if int_product != 1 {
179 numeric_result = Some(Expression::integer(int_product));
180 }
181
182 match (numeric_result.as_ref(), non_numeric_count) {
183 (None, 0) => Expression::integer(1),
184 (Some(num), 0) => num.clone(),
185 (None, 1) => {
186 let factor = first_non_numeric
187 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
188 match factor {
189 Expression::Add(terms) => simplify_addition(terms),
190 Expression::Pow(base, exp) => simplify_power(base, exp),
191 _ => factor.simplify(),
192 }
193 }
194 (Some(num), 1) => {
195 let factor = first_non_numeric
196 .expect("BUG: non_numeric_count is 1 but first_non_numeric is None");
197 let simplified_non_numeric = match factor {
198 Expression::Add(terms) => simplify_addition(terms),
199 Expression::Pow(base, exp) => simplify_power(base, exp),
200 _ => factor.simplify(),
201 };
202 match num {
203 Expression::Number(Number::Integer(1)) => simplified_non_numeric,
204 Expression::Number(Number::Float(f)) if (f - 1.0).abs() < EPSILON => {
205 simplified_non_numeric
206 }
207 _ => Expression::Mul(Box::new(vec![num.clone(), simplified_non_numeric])),
208 }
209 }
210 _ => {
211 let mut result_factors = Vec::with_capacity(non_numeric_count + 1);
212 if let Some(num) = numeric_result {
213 match num {
214 Expression::Number(Number::Integer(1)) => {}
215 Expression::Number(Number::Float(1.0)) => {}
216 _ => result_factors.push(num),
217 }
218 }
219 for factor in factors {
220 if !matches!(factor, Expression::Number(_)) {
221 let simplified_factor = match factor {
222 Expression::Add(terms) => simplify_addition(terms),
223 Expression::Pow(base, exp) => simplify_power(base, exp),
224 _ => factor.simplify(),
225 };
226 result_factors.push(simplified_factor);
227 }
228 }
229 match result_factors.len() {
230 0 => Expression::integer(1),
231 1 => result_factors
232 .into_iter()
233 .next()
234 .expect("BUG: result_factors has length 1 but iterator is empty"),
235 _ => {
236 let commutativity =
237 Commutativity::combine(result_factors.iter().map(|f| f.commutativity()));
238
239 if commutativity.can_sort() {
240 result_factors = combine_like_powers(result_factors);
241 result_factors.sort_by(expression_order);
242 }
243
244 match result_factors.len() {
245 0 => Expression::integer(1),
246 1 => result_factors.into_iter().next().unwrap(),
247 _ => Expression::Mul(Box::new(result_factors)),
248 }
249 }
250 }
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Asserts that `simplified` is a product whose factors equal `expected`,
    /// element by element and in the same order.
    fn assert_product(simplified: Expression, expected: &[Expression]) {
        match simplified {
            Expression::Mul(actual) => {
                assert_eq!(actual.len(), expected.len());
                for (got, want) in actual.iter().zip(expected) {
                    assert_eq!(got, want);
                }
            }
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    /// Two integer factors fold into their numeric product.
    #[test]
    fn test_multiplication_simplification() {
        let cases: [(i64, i64, i64); 3] = [(2, 3, 6), (5, 1, 5), (5, 0, 0)];
        for (lhs, rhs, product) in cases {
            let simplified =
                simplify_multiplication(&[Expression::integer(lhs), Expression::integer(rhs)]);
            assert_eq!(simplified, Expression::integer(product));
        }
    }

    /// A nested product is flattened before the integers are folded.
    #[test]
    fn test_nested_multiplication_flattening() {
        let inner = Expression::mul(vec![Expression::integer(3), Expression::integer(4)]);
        let simplified = simplify_multiplication(&[Expression::integer(2), inner]);
        assert_eq!(simplified, Expression::integer(24));
    }

    /// Scalar symbols are reordered: `y * x` becomes `x * y`.
    #[test]
    fn test_scalar_multiplication_sorts() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(x)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(x)),
                Expression::symbol(symbol!(y)),
            ],
        );
    }

    /// Matrix symbols keep their written order: `B * A` stays `B * A`.
    #[test]
    fn test_matrix_multiplication_preserves_order() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(B; matrix)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(B; matrix)),
                Expression::symbol(symbol!(A; matrix)),
            ],
        );
    }

    /// A scalar followed by a matrix keeps its written order.
    #[test]
    fn test_mixed_scalar_matrix_preserves_order() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(x)),
                Expression::symbol(symbol!(A; matrix)),
            ],
        );
    }

    /// Operator symbols keep their written order: `Q * P` stays `Q * P`.
    #[test]
    fn test_operator_multiplication_preserves_order() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(Q; operator)),
            Expression::symbol(symbol!(P; operator)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(Q; operator)),
                Expression::symbol(symbol!(P; operator)),
            ],
        );
    }

    /// Quaternion symbols keep their written order: `j * i` stays `j * i`.
    #[test]
    fn test_quaternion_multiplication_preserves_order() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(j; quaternion)),
            Expression::symbol(symbol!(i; quaternion)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(j; quaternion)),
                Expression::symbol(symbol!(i; quaternion)),
            ],
        );
    }

    /// Three scalar symbols are reordered: `z * x * y` becomes `x * y * z`.
    #[test]
    fn test_three_scalar_factors_sort() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(z)),
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(x)),
                Expression::symbol(symbol!(y)),
                Expression::symbol(symbol!(z)),
            ],
        );
    }

    /// Three matrix symbols keep their written order: `C * A * B`.
    #[test]
    fn test_three_matrix_factors_preserve_order() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(C; matrix)),
            Expression::symbol(symbol!(A; matrix)),
            Expression::symbol(symbol!(B; matrix)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::symbol(symbol!(C; matrix)),
                Expression::symbol(symbol!(A; matrix)),
                Expression::symbol(symbol!(B; matrix)),
            ],
        );
    }

    /// The coefficient stays first and the scalars after it are reordered.
    #[test]
    fn test_numeric_coefficient_with_scalars_sorts() {
        let product = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(x)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::integer(2),
                Expression::symbol(symbol!(x)),
                Expression::symbol(symbol!(y)),
            ],
        );
    }

    /// The coefficient stays first and the matrices keep their written order.
    #[test]
    fn test_numeric_coefficient_with_matrices_preserves_order() {
        let product = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(symbol!(B; matrix)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        assert_product(
            product.simplify(),
            &[
                Expression::integer(2),
                Expression::symbol(symbol!(B; matrix)),
                Expression::symbol(symbol!(A; matrix)),
            ],
        );
    }
}