1mod common;
5mod noncommutative;
6mod quadratic;
7
8use crate::core::commutativity::Commutativity;
9use crate::core::Expression;
/// Factoring operations that produce a new value of the implementing type.
///
/// Every method borrows `self` and returns an owned result; the receiver is
/// never modified.
pub trait Factor {
    /// Returns the result of extracting a common factor from `self`.
    fn factor_common(&self) -> Self;

    /// Returns the result of pulling the greatest common divisor out of `self`.
    fn factor_out_gcd(&self) -> Self;

    /// Returns a factored form of `self`.
    fn factor(&self) -> Self;
}
18
19impl Factor for Expression {
20 fn factor(&self) -> Self {
22 match self {
23 Expression::Number(_) | Expression::Symbol(_) => self.clone(),
24
25 Expression::Add(terms) => self.factor_addition(terms),
26
27 Expression::Mul(factors) => {
28 let factored_factors: Vec<Expression> =
29 factors.iter().map(|f| f.factor()).collect();
30 Expression::mul(factored_factors)
31 }
32
33 Expression::Pow(base, exp) => Expression::pow(base.factor(), exp.factor()),
34
35 Expression::Function { name, args } => {
36 let factored_args: Vec<Expression> = args.iter().map(|arg| arg.factor()).collect();
37 Expression::function(name.clone(), factored_args)
38 }
39 _ => self.clone(),
40 }
41 }
42
43 fn factor_out_gcd(&self) -> Self {
45 match self {
46 Expression::Add(terms) => {
47 if terms.len() < 2 {
48 return self.clone();
49 }
50
51 let mut common_factor = terms[0].clone();
52 for term in &terms[1..] {
53 common_factor = common_factor.gcd(term);
54 if common_factor.is_one() {
55 return self.clone();
56 }
57 }
58
59 if !common_factor.is_one() {
60 let factored_terms: Vec<Expression> = terms
61 .iter()
62 .map(|term| self.divide_by_factor(term, &common_factor))
63 .collect();
64
65 Expression::mul(vec![common_factor, Expression::add(factored_terms)])
66 } else {
67 self.clone()
68 }
69 }
70 _ => self.clone(),
71 }
72 }
73
74 fn factor_common(&self) -> Self {
76 self.factor_out_gcd()
77 }
78}
79
80impl Expression {
81 fn factor_addition(&self, terms: &[Expression]) -> Expression {
88 if terms.len() < 2 {
89 return Expression::add(terms.to_vec());
90 }
91
92 let commutativity = Commutativity::combine(terms.iter().map(|t| t.commutativity()));
93
94 if commutativity.can_sort() {
95 let common_factor = self.find_common_factor_in_terms(terms);
96
97 if !common_factor.is_one() {
98 let factored_terms: Vec<Expression> = terms
99 .iter()
100 .map(|term| self.divide_by_factor(term, &common_factor))
101 .collect();
102
103 Expression::mul(vec![common_factor, Expression::add(factored_terms)])
104 } else {
105 self.try_quadratic_factoring(terms)
106 .unwrap_or_else(|| Expression::add(terms.to_vec()))
107 }
108 } else {
109 if let Some(left_factored) = self.try_left_factor(terms) {
110 return left_factored;
111 }
112
113 if let Some(right_factored) = self.try_right_factor(terms) {
114 return right_factored;
115 }
116
117 Expression::add(terms.to_vec())
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    //! Unit tests for the factoring entry points on `Expression`.
    //!
    //! Several tests only assert that the result is non-zero: they are smoke
    //! tests that exercise the code path and print the result, without
    //! pinning the exact factored shape.

    use super::*;
    use crate::symbol;
    use num_bigint::BigInt;

    /// Smoke test: `factor` on `2x + 4`.
    #[test]
    fn test_basic_factoring() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(4),
        ]);

        let result = expr.factor();
        println!("2x + 4 factored = {}", result);

        // Same sanity check the sibling smoke tests use, so that this test
        // can actually fail.
        assert!(!result.is_zero());

        match result {
            Expression::Mul(_) => println!("Successfully factored"),
            _ => println!("Factoring result: {}", result),
        }
    }

    /// Smoke test: `factor_out_gcd` on `6x + 9`.
    #[test]
    fn test_gcd_factoring() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::mul(vec![Expression::integer(6), Expression::symbol(x.clone())]),
            Expression::integer(9),
        ]);

        let result = expr.factor_out_gcd();
        println!("6x + 9 GCD factored = {}", result);

        assert!(!result.is_zero());
    }

    /// `factor_numeric_coefficient` on `12 * x * 5` must yield coefficient 60
    /// and leave `x` as the remainder.
    #[test]
    fn test_numeric_coefficient_extraction() {
        let x = symbol!(x);

        let expr = Expression::mul(vec![
            Expression::integer(12),
            Expression::symbol(x.clone()),
            Expression::integer(5),
        ]);

        let (coeff, remaining) = expr.factor_numeric_coefficient();

        println!("Coefficient: {}, Remaining: {}", coeff, remaining);
        assert_eq!(coeff, BigInt::from(60));
        assert_eq!(remaining, Expression::symbol(x));
    }

    /// `factor_difference_of_squares(x, y)` must return a product of exactly
    /// two factors.
    #[test]
    fn test_difference_of_squares() {
        let x = symbol!(x);
        let y = symbol!(y);

        let result = Expression::integer(1).factor_difference_of_squares(
            &Expression::symbol(x.clone()),
            &Expression::symbol(y.clone()),
        );

        println!("x^2 - y^2 factored = {}", result);

        match result {
            Expression::Mul(factors) => assert_eq!(factors.len(), 2),
            _ => panic!("Expected multiplication"),
        }
    }

    /// Smoke test: `factor_common` on `xy + x`.
    #[test]
    fn test_common_factor_extraction() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.factor_common();
        println!("xy + x factored = {}", result);

        assert!(!result.is_zero());
    }

    /// `x + y` has nothing to factor, so `factor` must return it unchanged.
    #[test]
    fn test_no_common_factor() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let result = expr.factor();

        assert_eq!(result, expr);
    }

    /// `AB + AC` over matrix symbols must factor into a two-factor product
    /// containing `A` and a sum.
    #[test]
    fn test_left_factoring_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(c.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("AB + AC factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form A(B+C) or (B+C)A");
                let has_a = factors.iter().any(|f| f == &Expression::symbol(a.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_a, "Should contain factor A");
                assert!(has_sum, "Should contain sum (B+C)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    /// `BA + CA` over matrix symbols must factor into a two-factor product
    /// containing `A` and a sum.
    #[test]
    fn test_right_factoring_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(b.clone()),
                Expression::symbol(a.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(c.clone()),
                Expression::symbol(a.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("BA + CA factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form (B+C)A or A(B+C)");
                let has_a = factors.iter().any(|f| f == &Expression::symbol(a.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_a, "Should contain factor A");
                assert!(has_sum, "Should contain sum (B+C)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    /// `AB + CD` shares no one-sided factor, so the result must stay a sum.
    #[test]
    fn test_cannot_cross_factor_noncommutative() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(c.clone()),
                Expression::symbol(d.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("AB + CD factored = {}", result);

        match result {
            Expression::Add(_) => (),
            _ => panic!("Expected no factoring for AB + CD"),
        }
    }

    /// `px + ph` over operator symbols must factor into a two-factor product
    /// containing `p` and a sum.
    #[test]
    fn test_operator_left_factoring() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(h; operator);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(p.clone()),
                Expression::symbol(x.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(p.clone()),
                Expression::symbol(h.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("px + ph factored = {}", result);

        match result {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2, "Expected factored form p(x+h) or (x+h)p");
                let has_p = factors.iter().any(|f| f == &Expression::symbol(p.clone()));
                let has_sum = factors.iter().any(|f| matches!(f, Expression::Add(_)));
                assert!(has_p, "Should contain factor p");
                assert!(has_sum, "Should contain sum (x+h)");
            }
            _ => panic!("Expected multiplication after factoring, got: {}", result),
        }
    }

    /// Smoke test: `factor` on the commutative sum `xy + xz`.
    #[test]
    fn test_commutative_factoring_unchanged() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        // The second term uses `z`, so the expression really is `xy + xz`
        // as the printed message states.
        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(z.clone()),
            ]),
        ]);

        let result = expr.factor();
        println!("Commutative xy + xz factored = {}", result);

        assert!(!result.is_zero());
    }

    /// Smoke test: `factor` on `2AB + 3AB` over matrix symbols.
    #[test]
    fn test_matrix_same_position_factoring() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
        ]);

        let result = expr.factor();

        assert!(!result.is_zero());
    }
}