mathhook_core/algebra/
collect.rs

//! Term collection and combination operations.
//! Handles collecting like terms, combining coefficients, and organizing expressions.
3
4mod coefficients;
5mod terms;
6
7use crate::core::{Expression, Symbol};
8
9/// Trait for collecting terms in expressions
10pub trait Collect {
11    fn collect(&self, var: &Symbol) -> Self;
12    fn collect_terms(&self) -> Self;
13    fn combine_like_terms(&self) -> Self;
14}
15
16impl Collect for Expression {
17    /// Collect terms with respect to a specific variable
18    fn collect(&self, var: &Symbol) -> Self {
19        match self {
20            Expression::Add(terms) => self.collect_addition_terms(terms, var),
21            _ => self.clone(),
22        }
23    }
24
25    /// Collect and combine all like terms
26    fn collect_terms(&self) -> Self {
27        match self {
28            Expression::Add(terms) => self.collect_all_like_terms(terms),
29            Expression::Mul(factors) => self.collect_multiplication_terms(factors),
30            _ => self.clone(),
31        }
32    }
33
34    /// Combine like terms in the expression
35    fn combine_like_terms(&self) -> Self {
36        self.collect_terms()
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Fail unless `result` is a single combined term, i.e. it is not a sum
    /// that still holds two or more separate terms.
    ///
    /// Any non-`Add` result (e.g. a `Mul`) and a one-element `Add` are both
    /// accepted, mirroring what these tests previously treated as success.
    fn assert_combined(result: &Expression, context: &str) {
        if let Expression::Add(terms) = result {
            assert!(
                terms.len() <= 1,
                "{}: expected like terms to combine into one term, got {}",
                context,
                result
            );
        }
    }

    #[test]
    fn test_collect_like_terms() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::mul(vec![expr!(2), Expression::symbol(x.clone())]),
            Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]),
        ]);

        let result = expr.collect(&x);
        println!("2x + 3x collected = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_collect_different_powers() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), expr!(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::pow(Expression::symbol(x.clone()), expr!(2)),
        ]);

        let result = expr.collect(&x);
        println!("x^2 + 2x + x^2 collected = {}", result);

        // x^2 + 2x + x^2 must collect to the two-term sum 2x^2 + 2x; any other
        // shape means the powers were not collected correctly.
        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2);
            }
            other => panic!("Expected a sum of two terms, got {}", other),
        }
    }

    #[test]
    fn test_combine_like_terms() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::mul(vec![expr!(3), Expression::symbol(x.clone())]),
            Expression::mul(vec![expr!(2), Expression::symbol(y.clone())]),
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let result = expr.combine_like_terms();
        println!("3x + 2y + x + y combined = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_collect_constants() {
        let x = symbol!(x);

        let expr = Expression::add(vec![
            Expression::integer(5),
            Expression::mul(vec![expr!(3), Expression::symbol(x.clone())]),
            expr!(2),
        ]);

        let result = expr.collect(&x);
        println!("5 + 3x + 2 collected = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_separate_constants() {
        let x = symbol!(x);

        let expr = Expression::add(vec![expr!(5), Expression::symbol(x.clone()), expr!(3)]);

        let (constants, variables) = expr.separate_constants();

        println!("Constants: {}, Variables: {}", constants, variables);

        assert!(!constants.is_zero());
        assert!(!variables.is_zero());
    }

    #[test]
    fn test_collect_multiplication_powers() {
        let x = symbol!(x);

        let expr = Expression::mul(vec![
            Expression::pow(Expression::symbol(x.clone()), expr!(2)),
            Expression::pow(Expression::symbol(x.clone()), expr!(3)),
        ]);

        let result = expr.collect_terms();
        println!("x^2 * x^3 collected = {}", result);

        assert!(!result.is_zero());
    }

    #[test]
    fn test_commutative_collection() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2xy + 3xy = {}", result);

        // Commutative symbols: 2xy and 3xy are like terms and must merge.
        assert_combined(&result, "2xy + 3xy");
    }

    #[test]
    fn test_noncommutative_no_collection_different_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(b.clone()),
                Expression::symbol(a.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2AB + 3BA = {}", result);

        match result {
            Expression::Add(terms) => {
                assert_eq!(
                    terms.len(),
                    2,
                    "AB and BA should NOT combine (different order)"
                );
            }
            _ => panic!("Expected addition of 2 separate terms"),
        }
    }

    #[test]
    fn test_noncommutative_collection_same_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2AB + 3AB = {}", result);

        // Same factor order, so the matrix products are like terms and merge.
        assert_combined(&result, "2AB + 3AB");
    }

    #[test]
    fn test_operator_collection() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(p.clone()),
                Expression::symbol(x.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(x.clone()),
                Expression::symbol(p.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2px + 3xp = {}", result);

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2, "px and xp should NOT combine");
            }
            _ => panic!("Expected addition of 2 separate terms"),
        }
    }

    #[test]
    fn test_quaternion_collection() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(i.clone()),
                Expression::symbol(j.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(j.clone()),
                Expression::symbol(i.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2ij + 3ji = {}", result);

        match result {
            Expression::Add(terms) => {
                assert_eq!(terms.len(), 2, "ij and ji should NOT combine");
            }
            _ => panic!("Expected addition of 2 separate terms"),
        }
    }

    #[test]
    fn test_mixed_commutative_noncommutative() {
        let x = symbol!(x);
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::symbol(x.clone()),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(3),
                Expression::symbol(x.clone()),
                Expression::symbol(a.clone()),
                Expression::symbol(b.clone()),
            ]),
        ]);

        let result = expr.combine_like_terms();
        println!("2xAB + 3xAB = {}", result);

        assert!(!result.is_zero());
    }
}
322}