mathhook_core/algebra/
expand.rs

1//! Expression expansion operations
2//! Handles polynomial expansion, distribution, and algebraic expansion
3
4use crate::core::commutativity::Commutativity;
5use crate::core::{Expression, Number};
6
/// Rewriting of a value into its expanded (distributed) form.
///
/// The receiver is borrowed immutably and the expanded result is returned as a
/// new value of the same type.
pub trait Expand {
    /// Return the expanded form of `self`.
    fn expand(&self) -> Self;
}
11
12impl Expand for Expression {
13    /// Expand the expression by distributing multiplication over addition
14    fn expand(&self) -> Self {
15        match self {
16            Expression::Number(_) | Expression::Symbol(_) => self.clone(),
17
18            Expression::Add(terms) => {
19                let expanded_terms: Vec<Expression> =
20                    terms.iter().map(|term| term.expand()).collect();
21                Expression::add(expanded_terms)
22            }
23
24            Expression::Mul(factors) => self.expand_multiplication(factors),
25
26            Expression::Pow(base, exp) => self.expand_power(base, exp),
27
28            Expression::Function { name, args } => {
29                let expanded_args: Vec<Expression> = args.iter().map(|arg| arg.expand()).collect();
30                Expression::function(name.clone(), expanded_args)
31            }
32            _ => self.clone(),
33        }
34    }
35}
36
37impl Expression {
38    /// Expand multiplication by distributing over addition
39    fn expand_multiplication(&self, factors: &[Expression]) -> Expression {
40        if factors.is_empty() {
41            return Expression::integer(1);
42        }
43
44        if factors.len() == 1 {
45            return factors[0].expand();
46        }
47
48        let mut result = factors[0].expand();
49
50        for factor in &factors[1..] {
51            result = result.distribute_multiply(&factor.expand());
52        }
53
54        result
55    }
56
57    /// Distribute multiplication: (a + b) * c = a*c + b*c
58    fn distribute_multiply(&self, right: &Expression) -> Expression {
59        match (self, right) {
60            (Expression::Add(left_terms), _) => {
61                let distributed_terms: Vec<Expression> = left_terms
62                    .iter()
63                    .map(|term| term.distribute_multiply(right))
64                    .collect();
65                Expression::add(distributed_terms)
66            }
67
68            (_, Expression::Add(right_terms)) => {
69                let distributed_terms: Vec<Expression> = right_terms
70                    .iter()
71                    .map(|term| self.distribute_multiply(term))
72                    .collect();
73                Expression::add(distributed_terms)
74            }
75
76            _ => Expression::mul(vec![self.clone(), right.clone()]),
77        }
78    }
79
80    /// Expand power expressions
81    fn expand_power(&self, base: &Expression, exp: &Expression) -> Expression {
82        if let Expression::Number(Number::Integer(n)) = exp {
83            let exp_val = *n;
84            if (0..=10).contains(&exp_val) {
85                return self.expand_integer_power(base, exp_val as u32);
86            }
87        }
88
89        Expression::pow(base.clone(), exp.clone())
90    }
91
92    /// Expand integer powers: (a + b)^n
93    ///
94    /// For noncommutative terms, preserves order:
95    /// (A+B)^2 = A^2 + AB + BA + B^2 (4 terms for noncommutative)
96    /// (x+y)^2 = x^2 + 2xy + y^2 (3 terms for commutative)
97    fn expand_integer_power(&self, base: &Expression, exp: u32) -> Expression {
98        match exp {
99            0 => Expression::integer(1),
100            1 => base.expand(),
101            2 => match base {
102                Expression::Add(terms) if terms.len() == 2 => {
103                    let a = &terms[0];
104                    let b = &terms[1];
105
106                    let commutativity =
107                        Commutativity::combine(terms.iter().map(|t| t.commutativity()));
108
109                    if commutativity.can_sort() {
110                        Expression::add(vec![
111                            Expression::pow(a.clone(), Expression::integer(2)).expand(),
112                            Expression::mul(vec![Expression::integer(2), a.clone(), b.clone()])
113                                .expand(),
114                            Expression::pow(b.clone(), Expression::integer(2)).expand(),
115                        ])
116                    } else {
117                        Expression::add(vec![
118                            Expression::pow(a.clone(), Expression::integer(2)).expand(),
119                            Expression::mul(vec![a.clone(), b.clone()]).expand(),
120                            Expression::mul(vec![b.clone(), a.clone()]).expand(),
121                            Expression::pow(b.clone(), Expression::integer(2)).expand(),
122                        ])
123                    }
124                }
125                _ => {
126                    let expanded_base = base.expand();
127                    expanded_base.distribute_multiply(&expanded_base)
128                }
129            },
130            _ => {
131                let expanded_base = base.expand();
132                let mut result = expanded_base.clone();
133
134                for _ in 1..exp {
135                    result = result.distribute_multiply(&expanded_base);
136                }
137
138                result
139            }
140        }
141    }
142
143    /// Expand binomial expressions: (a + b)^n using binomial theorem
144    ///
145    /// For commutative terms, uses binomial theorem: C(n,k) * a^k * b^(n-k)
146    /// For noncommutative terms, uses direct multiplication to preserve order
147    pub fn expand_binomial(&self, a: &Expression, b: &Expression, n: u32) -> Expression {
148        if n == 0 {
149            return Expression::integer(1);
150        }
151
152        if n == 1 {
153            return Expression::add(vec![a.clone(), b.clone()]);
154        }
155
156        let commutativity = Commutativity::combine(vec![a.commutativity(), b.commutativity()]);
157
158        if !commutativity.can_sort() {
159            let base = Expression::add(vec![a.clone(), b.clone()]);
160            let mut result = base.clone();
161            for _ in 1..n {
162                result = result.distribute_multiply(&base);
163            }
164            return result;
165        }
166
167        if n <= 5 {
168            let mut terms = Vec::new();
169
170            for k in 0..=n {
171                let coeff = self.binomial_coefficient(n, k);
172                let a_power = if k == 0 {
173                    Expression::integer(1)
174                } else {
175                    Expression::pow(a.clone(), Expression::integer(k as i64))
176                };
177                let b_power = if n - k == 0 {
178                    Expression::integer(1)
179                } else {
180                    Expression::pow(b.clone(), Expression::integer((n - k) as i64))
181                };
182
183                let term = Expression::mul(vec![Expression::integer(coeff), a_power, b_power]);
184
185                terms.push(term);
186            }
187
188            Expression::add(terms)
189        } else {
190            Expression::pow(
191                Expression::add(vec![a.clone(), b.clone()]),
192                Expression::integer(n as i64),
193            )
194        }
195    }
196
197    /// Calculate binomial coefficient C(n, k)
198    fn binomial_coefficient(&self, n: u32, k: u32) -> i64 {
199        if k > n {
200            return 0;
201        }
202
203        if k == 0 || k == n {
204            return 1;
205        }
206
207        let mut result = 1i64;
208        let k = k.min(n - k); // Take advantage of symmetry
209
210        for i in 0..k {
211            if let Some(new_result) = result.checked_mul((n - i) as i64) {
212                if let Some(final_result) = new_result.checked_div((i + 1) as i64) {
213                    result = final_result;
214                } else {
215                    return 1; // Fallback on division error
216                }
217            } else {
218                return 1; // Fallback on overflow
219            }
220        }
221
222        result
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// `(x + y) * 2` distributes into two terms.
    #[test]
    fn test_basic_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x), Expression::symbol(y)]),
            Expression::integer(2),
        ]);

        // A non-Add result used to be printed and silently accepted; it is a failure.
        match expr.expand() {
            Expression::Add(terms) => assert_eq!(terms.len(), 2, "Expected 2x + 2y"),
            other => panic!("Expected an addition of 2 terms, got {}", other),
        }
    }

    /// `(x + y)^2` expands to the three-term commutative square.
    #[test]
    fn test_square_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(x), Expression::symbol(y)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => assert_eq!(terms.len(), 3),
            other => panic!("Expected an addition of 3 terms, got {}", other),
        }
    }

    /// Spot-check `C(n, k)` including symmetry and the `k > n` case.
    #[test]
    fn test_binomial_coefficients() {
        let expr = Expression::integer(1); // Dummy expression for method access

        assert_eq!(expr.binomial_coefficient(5, 0), 1);
        assert_eq!(expr.binomial_coefficient(5, 1), 5);
        assert_eq!(expr.binomial_coefficient(5, 2), 10);
        assert_eq!(expr.binomial_coefficient(5, 3), 10);
        assert_eq!(expr.binomial_coefficient(5, 4), 5);
        assert_eq!(expr.binomial_coefficient(5, 5), 1);
        assert_eq!(expr.binomial_coefficient(10, 3), 120);
        assert_eq!(expr.binomial_coefficient(10, 7), 120);
        assert_eq!(expr.binomial_coefficient(3, 5), 0);
    }

    /// `(x + 1)(x + 2)` expands into a sum.
    #[test]
    fn test_nested_expansion() {
        let x = symbol!(x);

        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![Expression::symbol(x), Expression::integer(2)]),
        ]);

        let result = expr.expand();

        assert!(!result.is_zero());
        assert!(
            matches!(result, Expression::Add(_)),
            "Expected a sum, got {}",
            result
        );
    }

    /// Purely numeric products stay non-zero after expansion.
    #[test]
    fn test_expansion_with_numbers() {
        let expr = Expression::mul(vec![
            Expression::integer(3),
            Expression::add(vec![Expression::integer(2), Expression::integer(4)]),
        ]);

        let result = expr.expand();

        assert!(!result.is_zero());
    }

    #[test]
    fn test_commutative_square_expansion() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(x), Expression::symbol(y)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 3, "Expected 3 terms for commutative square");
            }
            _ => panic!("Expected addition of 3 terms"),
        }
    }

    #[test]
    fn test_noncommutative_matrix_square_expansion() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(a), Expression::symbol(b)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for noncommutative square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_noncommutative_operator_square_expansion() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(p), Expression::symbol(x)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for operator square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_noncommutative_quaternion_square_expansion() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(i), Expression::symbol(j)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 4, "Expected 4 terms for quaternion square");
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_mixed_commutative_noncommutative_expansion() {
        let x = symbol!(x);
        let a = symbol!(A; matrix);

        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(x), Expression::symbol(a)]),
            Expression::integer(2),
        );

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(
                    terms.len(),
                    4,
                    "Expected 4 terms when ANY term is noncommutative"
                );
            }
            _ => panic!("Expected addition of 4 terms"),
        }
    }

    #[test]
    fn test_distribution_preserves_order_for_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(a), Expression::symbol(b)]),
            Expression::symbol(c),
        ]);

        match expr.expand() {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2, "Expected AC + BC");
            }
            _ => panic!("Expected addition"),
        }
    }

    /// Noncommutative binomials are expanded by repeated multiplication,
    /// which must still yield a sum rather than a collapsed power.
    #[test]
    fn test_binomial_theorem_not_used_for_noncommutative() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let result = Expression::integer(1).expand_binomial(
            &Expression::symbol(a),
            &Expression::symbol(b),
            3,
        );

        assert!(!result.is_zero());
        assert!(
            matches!(result, Expression::Add(_)),
            "Expected a sum from repeated multiplication, got {}",
            result
        );
    }
}