1mod coefficients;
5mod terms;
6
7use crate::core::{Expression, Symbol};
8
9pub trait Collect {
11 fn collect(&self, var: &Symbol) -> Self;
12 fn collect_terms(&self) -> Self;
13 fn combine_like_terms(&self) -> Self;
14}
15
16impl Collect for Expression {
17 fn collect(&self, var: &Symbol) -> Self {
19 match self {
20 Expression::Add(terms) => self.collect_addition_terms(terms, var),
21 _ => self.clone(),
22 }
23 }
24
25 fn collect_terms(&self) -> Self {
27 match self {
28 Expression::Add(terms) => self.collect_all_like_terms(terms),
29 Expression::Mul(factors) => self.collect_multiplication_terms(factors),
30 _ => self.clone(),
31 }
32 }
33
34 fn combine_like_terms(&self) -> Self {
36 self.collect_terms()
37 }
38}
39
40#[cfg(test)]
41mod tests {
42 use super::*;
43 use crate::{expr, symbol};
44
/// Smoke test: collecting `2x + 3x` with respect to `x` gives a non-zero result.
#[test]
fn test_collect_like_terms() {
    let x = symbol!(x);

    let two_x = Expression::mul(vec![expr!(2), Expression::symbol(x.clone())]);
    let three_x = Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]);
    let sum = Expression::add(vec![two_x, three_x]);

    let collected = sum.collect(&x);
    println!("2x + 3x collected = {}", collected);

    assert!(!collected.is_zero());
}
59
/// Collects `x^2 + 2x + x^2` with respect to `x`.
///
/// When the result is a sum it must have exactly two terms; any other result
/// shape is only printed, not rejected.
#[test]
fn test_collect_different_powers() {
    let x = symbol!(x);

    // The quadratic term appears twice in the input, so build it on demand.
    let x_squared = || Expression::pow(Expression::symbol(x.clone()), expr!(2));

    let sum = Expression::add(vec![
        x_squared(),
        Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
        x_squared(),
    ]);

    let collected = sum.collect(&x);
    println!("x^2 + 2x + x^2 collected = {}", collected);

    if let Expression::Add(terms) = &collected {
        assert_eq!(terms.len(), 2);
    } else {
        println!("Collection result: {}", collected);
    }
}
80
/// Smoke test: combining `3x + 2y + x + y` gives a non-zero result.
#[test]
fn test_combine_like_terms() {
    let x = symbol!(x);
    let y = symbol!(y);

    let three_x = Expression::mul(vec![expr!(3), Expression::symbol(x.clone())]);
    let two_y = Expression::mul(vec![expr!(2), Expression::symbol(y.clone())]);
    let sum = Expression::add(vec![
        three_x,
        two_y,
        Expression::symbol(x.clone()),
        Expression::symbol(y.clone()),
    ]);

    let combined = sum.combine_like_terms();
    println!("3x + 2y + x + y combined = {}", combined);

    assert!(!combined.is_zero());
}
98
/// Smoke test: collecting `5 + 3x + 2` with respect to `x` gives a non-zero result.
#[test]
fn test_collect_constants() {
    let x = symbol!(x);

    let five = Expression::integer(5);
    let three_x = Expression::mul(vec![expr!(3), Expression::symbol(x.clone())]);
    let two = expr!(2);
    let sum = Expression::add(vec![five, three_x, two]);

    let collected = sum.collect(&x);
    println!("5 + 3x + 2 collected = {}", collected);

    assert!(!collected.is_zero());
}
114
/// Splits `5 + x + 3` with `separate_constants` and checks that neither the
/// constant part nor the variable part is zero.
#[test]
fn test_separate_constants() {
    let x = symbol!(x);

    let five = expr!(5);
    let three = expr!(3);
    let sum = Expression::add(vec![five, Expression::symbol(x.clone()), three]);

    let (constant_part, variable_part) = sum.separate_constants();

    println!("Constants: {}, Variables: {}", constant_part, variable_part);

    assert!(!constant_part.is_zero());
    assert!(!variable_part.is_zero());
}
128
/// Smoke test: collecting the product `x^2 * x^3` gives a non-zero result.
#[test]
fn test_collect_multiplication_powers() {
    let x = symbol!(x);

    let x_squared = Expression::pow(Expression::symbol(x.clone()), expr!(2));
    let x_cubed = Expression::pow(Expression::symbol(x.clone()), expr!(3));
    let product = Expression::mul(vec![x_squared, x_cubed]);

    let collected = product.collect_terms();
    println!("x^2 * x^3 collected = {}", collected);

    assert!(!collected.is_zero());
}
143
/// Like terms built from commutative symbols must merge: `2xy + 3xy`.
///
/// The exact shape of the merged result is not pinned down (a product or a
/// single-term sum are both fine); the test only rejects a result that is
/// still a sum of two or more terms, i.e. one where nothing was combined.
#[test]
fn test_commutative_collection() {
    let x = symbol!(x);
    let y = symbol!(y);

    let two_xy = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(x.clone()),
        Expression::symbol(y.clone()),
    ]);
    let three_xy = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(x.clone()),
        Expression::symbol(y.clone()),
    ]);
    let sum = Expression::add(vec![two_xy, three_xy]);

    let result = sum.combine_like_terms();
    println!("2xy + 3xy = {}", result);

    if matches!(&result, Expression::Mul(_)) {
        println!("Successfully combined like terms");
    } else {
        println!("Result: {}", result);
    }

    // Without this check the test could never fail on wrong output.
    let left_uncombined = matches!(&result, Expression::Add(terms) if terms.len() >= 2);
    assert!(
        !left_uncombined,
        "2xy and 3xy are like terms and should combine into a single term"
    );
}
172
/// Matrix symbols in a different order must not merge: `2AB + 3BA` stays a
/// sum of two terms.
#[test]
fn test_noncommutative_no_collection_different_order() {
    let a = symbol!(A; matrix);
    let b = symbol!(B; matrix);

    let two_ab = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(a.clone()),
        Expression::symbol(b.clone()),
    ]);
    let three_ba = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(b.clone()),
        Expression::symbol(a.clone()),
    ]);
    let sum = Expression::add(vec![two_ab, three_ba]);

    let result = sum.combine_like_terms();
    println!("2AB + 3BA = {}", result);

    if let Expression::Add(terms) = &result {
        assert_eq!(
            terms.len(),
            2,
            "AB and BA should NOT combine (different order)"
        );
    } else {
        panic!("Expected addition of 2 separate terms");
    }
}
205
/// Matrix symbols in the same order are like terms: `2AB + 3AB` must merge.
///
/// A product or a single-term sum are both accepted; the test only rejects a
/// result that is still a sum of two or more terms.
#[test]
fn test_noncommutative_collection_same_order() {
    let a = symbol!(A; matrix);
    let b = symbol!(B; matrix);

    let two_ab = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(a.clone()),
        Expression::symbol(b.clone()),
    ]);
    let three_ab = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(a.clone()),
        Expression::symbol(b.clone()),
    ]);
    let sum = Expression::add(vec![two_ab, three_ab]);

    let result = sum.combine_like_terms();
    println!("2AB + 3AB = {}", result);

    match &result {
        Expression::Mul(_) => println!("Successfully combined like terms with same order"),
        Expression::Add(terms) if terms.len() == 1 => {
            println!("Single term result (acceptable)");
        }
        other => println!("Result: {}", other),
    }

    // Without this check the test could never fail on wrong output.
    let left_uncombined = matches!(&result, Expression::Add(terms) if terms.len() >= 2);
    assert!(
        !left_uncombined,
        "2AB and 3AB share the same factor order and should combine"
    );
}
237
/// Operator symbols in a different order must not merge: `2px + 3xp` stays a
/// sum of two terms.
#[test]
fn test_operator_collection() {
    let p = symbol!(p; operator);
    let x = symbol!(x; operator);

    let two_px = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(p.clone()),
        Expression::symbol(x.clone()),
    ]);
    let three_xp = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(x.clone()),
        Expression::symbol(p.clone()),
    ]);
    let sum = Expression::add(vec![two_px, three_xp]);

    let result = sum.combine_like_terms();
    println!("2px + 3xp = {}", result);

    if let Expression::Add(terms) = &result {
        assert_eq!(terms.len(), 2, "px and xp should NOT combine");
    } else {
        panic!("Expected addition of 2 separate terms");
    }
}
266
/// Quaternion symbols in a different order must not merge: `2ij + 3ji` stays
/// a sum of two terms.
#[test]
fn test_quaternion_collection() {
    let i = symbol!(i; quaternion);
    let j = symbol!(j; quaternion);

    let two_ij = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(i.clone()),
        Expression::symbol(j.clone()),
    ]);
    let three_ji = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(j.clone()),
        Expression::symbol(i.clone()),
    ]);
    let sum = Expression::add(vec![two_ij, three_ji]);

    let result = sum.combine_like_terms();
    println!("2ij + 3ji = {}", result);

    if let Expression::Add(terms) = &result {
        assert_eq!(terms.len(), 2, "ij and ji should NOT combine");
    } else {
        panic!("Expected addition of 2 separate terms");
    }
}
295
/// Smoke test: combining `2xAB + 3xAB` (a scalar symbol times two matrix
/// symbols) gives a non-zero result.
#[test]
fn test_mixed_commutative_noncommutative() {
    let x = symbol!(x);
    let a = symbol!(A; matrix);
    let b = symbol!(B; matrix);

    let two_xab = Expression::mul(vec![
        Expression::integer(2),
        Expression::symbol(x.clone()),
        Expression::symbol(a.clone()),
        Expression::symbol(b.clone()),
    ]);
    let three_xab = Expression::mul(vec![
        Expression::integer(3),
        Expression::symbol(x.clone()),
        Expression::symbol(a.clone()),
        Expression::symbol(b.clone()),
    ]);
    let sum = Expression::add(vec![two_xab, three_xab]);

    let combined = sum.combine_like_terms();
    println!("2xAB + 3xAB = {}", combined);

    assert!(!combined.is_zero());
}
322}