1use crate::core::commutativity::Commutativity;
5use crate::core::{Expression, Number};
6
/// Algebraic expansion of a value into an equivalent, multiplied-out form.
pub trait Expand {
    /// Returns an expanded copy of `self`.
    ///
    /// The receiver is only borrowed (`&self`) and is never modified, so discarding
    /// the returned value is almost certainly a mistake.
    #[must_use = "expand returns a new value and does not modify the receiver"]
    fn expand(&self) -> Self;
}
11
12impl Expand for Expression {
13 fn expand(&self) -> Self {
15 match self {
16 Expression::Number(_) | Expression::Symbol(_) => self.clone(),
17
18 Expression::Add(terms) => {
19 let expanded_terms: Vec<Expression> =
20 terms.iter().map(|term| term.expand()).collect();
21 Expression::add(expanded_terms)
22 }
23
24 Expression::Mul(factors) => self.expand_multiplication(factors),
25
26 Expression::Pow(base, exp) => self.expand_power(base, exp),
27
28 Expression::Function { name, args } => {
29 let expanded_args: Vec<Expression> = args.iter().map(|arg| arg.expand()).collect();
30 Expression::function(name.clone(), expanded_args)
31 }
32 _ => self.clone(),
33 }
34 }
35}
36
37impl Expression {
38 fn expand_multiplication(&self, factors: &[Expression]) -> Expression {
40 if factors.is_empty() {
41 return Expression::integer(1);
42 }
43
44 if factors.len() == 1 {
45 return factors[0].expand();
46 }
47
48 let mut result = factors[0].expand();
49
50 for factor in &factors[1..] {
51 result = result.distribute_multiply(&factor.expand());
52 }
53
54 result
55 }
56
57 fn distribute_multiply(&self, right: &Expression) -> Expression {
59 match (self, right) {
60 (Expression::Add(left_terms), _) => {
61 let distributed_terms: Vec<Expression> = left_terms
62 .iter()
63 .map(|term| term.distribute_multiply(right))
64 .collect();
65 Expression::add(distributed_terms)
66 }
67
68 (_, Expression::Add(right_terms)) => {
69 let distributed_terms: Vec<Expression> = right_terms
70 .iter()
71 .map(|term| self.distribute_multiply(term))
72 .collect();
73 Expression::add(distributed_terms)
74 }
75
76 _ => Expression::mul(vec![self.clone(), right.clone()]),
77 }
78 }
79
80 fn expand_power(&self, base: &Expression, exp: &Expression) -> Expression {
82 if let Expression::Number(Number::Integer(n)) = exp {
83 let exp_val = *n;
84 if (0..=10).contains(&exp_val) {
85 return self.expand_integer_power(base, exp_val as u32);
86 }
87 }
88
89 Expression::pow(base.clone(), exp.clone())
90 }
91
92 fn expand_integer_power(&self, base: &Expression, exp: u32) -> Expression {
98 match exp {
99 0 => Expression::integer(1),
100 1 => base.expand(),
101 2 => match base {
102 Expression::Add(terms) if terms.len() == 2 => {
103 let a = &terms[0];
104 let b = &terms[1];
105
106 let commutativity =
107 Commutativity::combine(terms.iter().map(|t| t.commutativity()));
108
109 if commutativity.can_sort() {
110 Expression::add(vec![
111 Expression::pow(a.clone(), Expression::integer(2)).expand(),
112 Expression::mul(vec![Expression::integer(2), a.clone(), b.clone()])
113 .expand(),
114 Expression::pow(b.clone(), Expression::integer(2)).expand(),
115 ])
116 } else {
117 Expression::add(vec![
118 Expression::pow(a.clone(), Expression::integer(2)).expand(),
119 Expression::mul(vec![a.clone(), b.clone()]).expand(),
120 Expression::mul(vec![b.clone(), a.clone()]).expand(),
121 Expression::pow(b.clone(), Expression::integer(2)).expand(),
122 ])
123 }
124 }
125 _ => {
126 let expanded_base = base.expand();
127 expanded_base.distribute_multiply(&expanded_base)
128 }
129 },
130 _ => {
131 let expanded_base = base.expand();
132 let mut result = expanded_base.clone();
133
134 for _ in 1..exp {
135 result = result.distribute_multiply(&expanded_base);
136 }
137
138 result
139 }
140 }
141 }
142
143 pub fn expand_binomial(&self, a: &Expression, b: &Expression, n: u32) -> Expression {
148 if n == 0 {
149 return Expression::integer(1);
150 }
151
152 if n == 1 {
153 return Expression::add(vec![a.clone(), b.clone()]);
154 }
155
156 let commutativity = Commutativity::combine(vec![a.commutativity(), b.commutativity()]);
157
158 if !commutativity.can_sort() {
159 let base = Expression::add(vec![a.clone(), b.clone()]);
160 let mut result = base.clone();
161 for _ in 1..n {
162 result = result.distribute_multiply(&base);
163 }
164 return result;
165 }
166
167 if n <= 5 {
168 let mut terms = Vec::new();
169
170 for k in 0..=n {
171 let coeff = self.binomial_coefficient(n, k);
172 let a_power = if k == 0 {
173 Expression::integer(1)
174 } else {
175 Expression::pow(a.clone(), Expression::integer(k as i64))
176 };
177 let b_power = if n - k == 0 {
178 Expression::integer(1)
179 } else {
180 Expression::pow(b.clone(), Expression::integer((n - k) as i64))
181 };
182
183 let term = Expression::mul(vec![Expression::integer(coeff), a_power, b_power]);
184
185 terms.push(term);
186 }
187
188 Expression::add(terms)
189 } else {
190 Expression::pow(
191 Expression::add(vec![a.clone(), b.clone()]),
192 Expression::integer(n as i64),
193 )
194 }
195 }
196
197 fn binomial_coefficient(&self, n: u32, k: u32) -> i64 {
199 if k > n {
200 return 0;
201 }
202
203 if k == 0 || k == n {
204 return 1;
205 }
206
207 let mut result = 1i64;
208 let k = k.min(n - k); for i in 0..k {
211 if let Some(new_result) = result.checked_mul((n - i) as i64) {
212 if let Some(final_result) = new_result.checked_div((i + 1) as i64) {
213 result = final_result;
214 } else {
215 return 1; }
217 } else {
218 return 1; }
220 }
221
222 result
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_basic_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        // (x + y) * 2 should distribute into a two-term sum.
        let expr = Expression::mul(vec![
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::integer(2),
        ]);

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            other => panic!("Expected addition of 2 terms, got: {}", other),
        }
    }

    #[test]
    fn test_square_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 3);
            }
            other => panic!("Expected addition of 3 terms, got: {}", other),
        }
    }

    #[test]
    fn test_binomial_coefficients() {
        let expr = Expression::integer(1);

        assert_eq!(expr.binomial_coefficient(5, 0), 1);
        assert_eq!(expr.binomial_coefficient(5, 1), 5);
        assert_eq!(expr.binomial_coefficient(5, 2), 10);
        assert_eq!(expr.binomial_coefficient(5, 3), 10);
        assert_eq!(expr.binomial_coefficient(5, 4), 5);
        assert_eq!(expr.binomial_coefficient(5, 5), 1);

        // Edge cases and larger values.
        assert_eq!(expr.binomial_coefficient(0, 0), 1);
        assert_eq!(expr.binomial_coefficient(5, 6), 0);
        assert_eq!(expr.binomial_coefficient(10, 3), 120);
        assert_eq!(expr.binomial_coefficient(20, 10), 184756);
    }

    #[test]
    fn test_nested_expansion() {
        let x = symbol!(x);

        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(2)]),
        ]);

        let result = expr.expand();

        assert!(!result.is_zero());
    }

    #[test]
    fn test_expansion_with_numbers() {
        let expr = Expression::mul(vec![
            Expression::integer(3),
            Expression::add(vec![Expression::integer(2), Expression::integer(4)]),
        ]);

        let result = expr.expand();

        assert!(!result.is_zero());
    }

    #[test]
    fn test_commutative_square_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 3, "Expected 3 terms for commutative square");
            }
            _ => panic!("Expected addition of 3 terms"),
        }
    }

    #[test]
    fn test_noncommutative_matrix_square_expansion() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for noncommutative square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_noncommutative_operator_square_expansion() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(p.clone()),
                Expression::symbol(x.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for operator square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_noncommutative_quaternion_square_expansion() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(i.clone()),
                Expression::symbol(j.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for quaternion square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_mixed_commutative_noncommutative_expansion() {
        let x = symbol!(x);
        let a = symbol!(A; matrix);

        let expr = Expression::pow(
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(a.clone()),
            ]),
            Expression::integer(2),
        );

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(
                    terms.len(),
                    4,
                    "Expected 4 terms when ANY term is noncommutative"
                );
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_distribution_preserves_order_for_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::add(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::symbol(c.clone()),
        ]);

        let result = expr.expand();

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2, "Expected AC + BC");
            }
            _ => panic!("Expected addition"),
        }
    }

    #[test]
    fn test_binomial_theorem_not_used_for_noncommutative() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let result = Expression::integer(1).expand_binomial(
            &Expression::symbol(a.clone()),
            &Expression::symbol(b.clone()),
            3,
        );

        assert!(!result.is_zero());
        // Noncommutative operands are multiplied out by distribution, so the result
        // must be a sum rather than a power left unexpanded.
        assert!(
            matches!(result, Expression::Add(_)),
            "Expected (A + B)^3 to be distributed into a sum, got: {}",
            result
        );
    }
}