mathhook_core/algebra/
factor.rs

1//! Factorization operations for expressions
2//! Handles polynomial factorization, common factor extraction, and algebraic factoring
3
4mod common;
5mod noncommutative;
6mod quadratic;
7
8use crate::core::commutativity::Commutativity;
9use crate::core::Expression;
10// num_traits imports removed
11
/// Trait for factoring expressions.
///
/// Every method borrows `self` and returns a new value; the receiver is never modified.
pub trait Factor {
    /// Return a factored form of `self`.
    #[must_use = "factoring returns a new expression and does not modify `self`"]
    fn factor(&self) -> Self;

    /// Pull the greatest common divisor shared by all terms of a sum out in front,
    /// returning `self` unchanged when there is nothing to extract.
    #[must_use = "factoring returns a new expression and does not modify `self`"]
    fn factor_out_gcd(&self) -> Self;

    /// Extract the factors common to every term of a sum, returning `self`
    /// unchanged when there is nothing to extract.
    #[must_use = "factoring returns a new expression and does not modify `self`"]
    fn factor_common(&self) -> Self;
}
18
19impl Factor for Expression {
20    /// Factor the expression by extracting common factors
21    fn factor(&self) -> Self {
22        match self {
23            Expression::Number(_) | Expression::Symbol(_) => self.clone(),
24
25            Expression::Add(terms) => self.factor_addition(terms),
26
27            Expression::Mul(factors) => {
28                let factored_factors: Vec<Expression> =
29                    factors.iter().map(|f| f.factor()).collect();
30                Expression::mul(factored_factors)
31            }
32
33            Expression::Pow(base, exp) => Expression::pow(base.factor(), exp.factor()),
34
35            Expression::Function { name, args } => {
36                let factored_args: Vec<Expression> = args.iter().map(|arg| arg.factor()).collect();
37                Expression::function(name.clone(), factored_args)
38            }
39            _ => self.clone(),
40        }
41    }
42
43    /// Factor out the GCD from an expression
44    fn factor_out_gcd(&self) -> Self {
45        match self {
46            Expression::Add(terms) => {
47                if terms.len() < 2 {
48                    return self.clone();
49                }
50
51                let mut common_factor = terms[0].clone();
52                for term in &terms[1..] {
53                    common_factor = common_factor.gcd(term);
54                    if common_factor.is_one() {
55                        return self.clone();
56                    }
57                }
58
59                if !common_factor.is_one() {
60                    let factored_terms: Vec<Expression> = terms
61                        .iter()
62                        .map(|term| self.divide_by_factor(term, &common_factor))
63                        .collect();
64
65                    Expression::mul(vec![common_factor, Expression::add(factored_terms)])
66                } else {
67                    self.clone()
68                }
69            }
70            _ => self.clone(),
71        }
72    }
73
74    /// Factor common elements
75    fn factor_common(&self) -> Self {
76        self.factor_out_gcd()
77    }
78}
79
80impl Expression {
81    /// Factor addition expressions by finding common factors
82    ///
83    /// For commutative terms: AB + AC = A(B+C)
84    /// For noncommutative terms: Try left factoring first, then right factoring
85    ///   Left: AB + AC = A(B+C)
86    ///   Right: BA + CA = (B+C)A
87    fn factor_addition(&self, terms: &[Expression]) -> Expression {
88        if terms.len() < 2 {
89            return Expression::add(terms.to_vec());
90        }
91
92        let commutativity = Commutativity::combine(terms.iter().map(|t| t.commutativity()));
93
94        if commutativity.can_sort() {
95            let common_factor = self.find_common_factor_in_terms(terms);
96
97            if !common_factor.is_one() {
98                let factored_terms: Vec<Expression> = terms
99                    .iter()
100                    .map(|term| self.divide_by_factor(term, &common_factor))
101                    .collect();
102
103                Expression::mul(vec![common_factor, Expression::add(factored_terms)])
104            } else {
105                self.try_quadratic_factoring(terms)
106                    .unwrap_or_else(|| Expression::add(terms.to_vec()))
107            }
108        } else {
109            if let Some(left_factored) = self.try_left_factor(terms) {
110                return left_factored;
111            }
112
113            if let Some(right_factored) = self.try_right_factor(terms) {
114                return right_factored;
115            }
116
117            Expression::add(terms.to_vec())
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use num_bigint::BigInt;

    #[test]
    fn test_basic_factoring() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(4),
        ]);

        let result = expr.factor();
        println!("2x + 4 factored = {}", result);

        // Factoring must never collapse a non-zero expression to zero.
        assert!(!result.is_zero());

        match result {
            Expression::Mul(_) => println!("Successfully factored"),
            _ => println!("Factoring result: {}", result),
        }
    }

    #[test]
    fn test_gcd_factoring() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::mul(vec![Expression::integer(6), Expression::symbol(x.clone())]),
            Expression::integer(9),
        ]);

        let result = expr.factor_out_gcd();
        println!("6x + 9 GCD factored = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_numeric_coefficient_extraction() {
        let x = symbol!(x);

        let expr = Expression::mul(vec![
            Expression::integer(12),
            Expression::symbol(x.clone()),
            Expression::integer(5),
        ]);

        let (coeff, remaining) = expr.factor_numeric_coefficient();

        println!("Coefficient: {}, Remaining: {}", coeff, remaining);
        assert_eq!(coeff, BigInt::from(60));
        assert_eq!(remaining, Expression::symbol(x));
    }

    #[test]
    fn test_difference_of_squares() {
        let x = symbol!(x);
        let y = symbol!(y);

        let result = Expression::integer(1).factor_difference_of_squares(
            &Expression::symbol(x.clone()),
            &Expression::symbol(y.clone()),
        );

        println!("x^2 - y^2 factored = {}", result);

        match result {
            Expression::Mul(factors) => assert_eq!(factors.len(), 2),
            _ => panic!("Expected multiplication"),
        }
    }

    #[test]
    fn test_common_factor_extraction() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.factor_common();
        println!("xy + x factored = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_no_common_factor() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let result = expr.factor();

        assert_eq!(result, expr);
    }

    #[test]
    fn test_left_factoring_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(c.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("AB + AC factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form A(B+C) or (B+C)A");
                let has_a = factors.iter().any(|f| f == &Expression::symbol(a.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_a, "Should contain factor A");
                assert!(has_sum, "Should contain sum (B+C)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    #[test]
    fn test_right_factoring_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(b.clone()),
                Expression::symbol(a.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(c.clone()),
                Expression::symbol(a.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("BA + CA factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form (B+C)A or A(B+C)");
                let has_a = factors.iter().any(|f| f == &Expression::symbol(a.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_a, "Should contain factor A");
                assert!(has_sum, "Should contain sum (B+C)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    #[test]
    fn test_cannot_cross_factor_noncommutative() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(c.clone()),
                Expression::symbol(d.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("AB + CD factored = {}", result);

        match result {
            Expression::Add(_) => (),
            _ => panic!("Expected no factoring for AB + CD"),
        }
    }

    #[test]
    fn test_operator_left_factoring() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(h; operator);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(p.clone()),
                Expression::symbol(x.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(p.clone()),
                Expression::symbol(h.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("px + ph factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form p(x+h) or (x+h)p");
                let has_p = factors.iter().any(|f| f == &Expression::symbol(p.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_p, "Should contain factor p");
                assert!(has_sum, "Should contain sum (x+h)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    #[test]
    fn test_commutative_factoring_unchanged() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        // xy + xz: the second term uses z so there is a genuine common factor x
        // with distinct cofactors (previously this built xy + xy by mistake).
        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(z.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("Commutative xy + xz factored = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_matrix_same_position_factoring() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        // Test 2AB + 3AB = 5AB (same order, should combine)
        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
        ]);

        let result = expr.factor();

        // Whether the terms combine to 5AB or AB is factored out, the result must not vanish.
        assert!(!result.is_zero());
    }
}