1use crate::core::{Expression, Number};
5use crate::simplify::Simplify;
6use num_bigint::BigInt;
7use num_traits::{One, Zero};
8
9pub trait ZeroDetection {
11 fn is_algebraic_zero(&self) -> bool;
12 fn detect_zero_patterns(&self) -> bool;
13 fn simplify_to_zero(&self) -> Option<Expression>;
14}
15
16impl ZeroDetection for Expression {
17 fn is_algebraic_zero(&self) -> bool {
19 if self.is_zero() {
21 return true;
22 }
23
24 let simplified = self.simplify();
26 if simplified.is_zero() {
27 return true;
28 }
29
30 self.detect_zero_patterns()
32 }
33
34 fn detect_zero_patterns(&self) -> bool {
36 match self {
37 Expression::Add(terms) => self.detect_additive_zero_patterns(terms),
38
39 Expression::Mul(factors) => {
40 factors.iter().any(|f| f.is_zero() || f.is_algebraic_zero())
42 }
43
44 _ => false,
45 }
46 }
47
48 fn simplify_to_zero(&self) -> Option<Expression> {
50 if self.is_algebraic_zero() {
51 Some(Expression::integer(0))
52 } else {
53 None
54 }
55 }
56}
57
58impl Expression {
59 fn detect_additive_zero_patterns(&self, terms: &[Expression]) -> bool {
61 if self.has_additive_inverses(terms) {
63 return true;
64 }
65
66 if self.terms_cancel_out(terms) {
68 return true;
69 }
70
71 if self.detect_complex_zero_identities(terms) {
73 return true;
74 }
75
76 false
77 }
78
79 fn has_additive_inverses(&self, terms: &[Expression]) -> bool {
81 for (i, term1) in terms.iter().enumerate() {
82 for (j, term2) in terms.iter().enumerate() {
83 if i != j && self.are_additive_inverses(term1, term2) {
84 let remaining_terms: Vec<&Expression> = terms
86 .iter()
87 .enumerate()
88 .filter(|(k, _)| *k != i && *k != j)
89 .map(|(_, t)| t)
90 .collect();
91
92 if remaining_terms.is_empty() {
93 return true; }
95
96 let remaining_expr =
98 Expression::add(remaining_terms.into_iter().cloned().collect());
99 if remaining_expr.is_algebraic_zero() {
100 return true;
101 }
102 }
103 }
104 }
105 false
106 }
107
108 fn are_additive_inverses(&self, expr1: &Expression, expr2: &Expression) -> bool {
110 match (expr1, expr2) {
111 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(b))) => {
113 *a + *b == 0
114 }
115
116 (Expression::Symbol(s1), Expression::Mul(factors)) => {
118 if factors.len() == 2 {
119 if let (Expression::Number(Number::Integer(n)), Expression::Symbol(s2)) =
120 (&factors[0], &factors[1])
121 {
122 *n == -1 && s1 == s2
123 } else if let (Expression::Symbol(s2), Expression::Number(Number::Integer(n))) =
124 (&factors[0], &factors[1])
125 {
126 *n == -1 && s1 == s2
127 } else {
128 false
129 }
130 } else {
131 false
132 }
133 }
134
135 (Expression::Mul(_factors), Expression::Symbol(_s1)) => {
137 self.are_additive_inverses(expr2, expr1)
138 }
139
140 (Expression::Mul(factors1), Expression::Mul(factors2)) => {
142 self.are_multiplicative_inverses(factors1, factors2)
143 }
144
145 _ => false,
146 }
147 }
148
149 fn are_multiplicative_inverses(
151 &self,
152 factors1: &[Expression],
153 factors2: &[Expression],
154 ) -> bool {
155 let (neg_factors, pos_factors) = if self.has_negative_one_factor(factors1) {
157 (factors1, factors2)
158 } else if self.has_negative_one_factor(factors2) {
159 (factors2, factors1)
160 } else {
161 return false;
162 };
163
164 let neg_without_minus_one: Vec<Expression> = neg_factors
166 .iter()
167 .filter(|f| !matches!(f, Expression::Number(Number::Integer(n)) if *n == -1))
168 .cloned()
169 .collect();
170
171 self.are_factor_sets_equal(&neg_without_minus_one, pos_factors)
173 }
174
175 fn has_negative_one_factor(&self, factors: &[Expression]) -> bool {
177 factors
178 .iter()
179 .any(|f| matches!(f, Expression::Number(Number::Integer(n)) if *n == -1))
180 }
181
182 fn are_factor_sets_equal(&self, factors1: &[Expression], factors2: &[Expression]) -> bool {
184 if factors1.len() != factors2.len() {
185 return false;
186 }
187
188 for factor1 in factors1 {
190 if !factors2.contains(factor1) {
191 return false;
192 }
193 }
194
195 true
196 }
197
198 fn terms_cancel_out(&self, terms: &[Expression]) -> bool {
200 let mut term_coefficients: Vec<(Expression, BigInt)> = Vec::new();
203
204 for term in terms {
205 let (coeff, base) = self.extract_coefficient_and_base_term(term);
206 let mut found = false;
208 for (existing_expr, existing_coeff) in term_coefficients.iter_mut() {
209 if *existing_expr == base {
210 *existing_coeff += &coeff;
211 found = true;
212 break;
213 }
214 }
215 if !found {
216 term_coefficients.push((base, coeff));
217 }
218 }
219
220 term_coefficients.iter().all(|(_, coeff)| coeff.is_zero())
222 }
223
224 fn extract_coefficient_and_base_term(&self, term: &Expression) -> (BigInt, Expression) {
226 match term {
227 Expression::Number(Number::Integer(n)) => (BigInt::from(*n), Expression::integer(1)),
228 Expression::Symbol(_) => (BigInt::one(), term.clone()),
229 Expression::Mul(factors) => {
230 let mut coefficient = BigInt::one();
231 let mut base_factors = Vec::new();
232
233 for factor in factors.iter() {
234 if let Expression::Number(Number::Integer(n)) = factor {
235 coefficient *= BigInt::from(*n);
236 } else {
237 base_factors.push(factor.clone());
238 }
239 }
240
241 let base = if base_factors.is_empty() {
242 Expression::integer(1)
243 } else if base_factors.len() == 1 {
244 base_factors[0].clone()
245 } else {
246 Expression::mul(base_factors)
247 };
248
249 (coefficient, base)
250 }
251 _ => (BigInt::one(), term.clone()),
252 }
253 }
254
255 fn detect_complex_zero_identities(&self, terms: &[Expression]) -> bool {
257 if terms.len() >= 3 {
262 if let Some(expanded) = self.try_expand_and_simplify(terms) {
264 return expanded.is_zero();
265 }
266 }
267
268 false
269 }
270
271 fn try_expand_and_simplify(&self, terms: &[Expression]) -> Option<Expression> {
273 let mut simplified_terms = Vec::new();
277
278 for term in terms {
279 match term {
280 Expression::Mul(factors) if factors.len() >= 2 => {
282 if let Some(expanded) = self.try_expand_multiplication(factors) {
283 if let Expression::Add(expanded_terms) = expanded {
284 simplified_terms.extend(expanded_terms.into_iter());
285 } else {
286 simplified_terms.push(expanded);
287 }
288 } else {
289 simplified_terms.push(term.clone());
290 }
291 }
292 _ => simplified_terms.push(term.clone()),
293 }
294 }
295
296 let result = Expression::add(simplified_terms).simplify();
297 Some(result)
298 }
299
300 fn try_expand_multiplication(&self, factors: &[Expression]) -> Option<Expression> {
302 if factors.len() == 2 {
304 match (&factors[0], &factors[1]) {
305 (Expression::Number(Number::Integer(coeff)), Expression::Add(terms)) => {
306 let distributed_terms: Vec<Expression> = terms
308 .iter()
309 .map(|term| {
310 Expression::mul(vec![Expression::integer(*coeff), term.clone()])
311 })
312 .collect();
313 Some(Expression::add(distributed_terms))
314 }
315 (Expression::Add(terms), Expression::Number(Number::Integer(coeff))) => {
316 let distributed_terms: Vec<Expression> = terms
318 .iter()
319 .map(|term| {
320 Expression::mul(vec![term.clone(), Expression::integer(*coeff)])
321 })
322 .collect();
323 Some(Expression::add(distributed_terms))
324 }
325 _ => None,
326 }
327 } else {
328 None
329 }
330 }
331
332 pub fn detect_advanced_zero_patterns(&self) -> bool {
334 match self {
335 Expression::Add(terms) if terms.len() == 2 => {
337 if self.are_additive_inverses(&terms[0], &terms[1]) {
338 return true;
339 }
340 false
341 }
342
343 Expression::Add(terms) => {
345 if let Some(factored) = self.try_factor_for_zero_detection(terms) {
347 return factored.is_zero();
348 }
349 false
350 }
351
352 _ => false,
353 }
354 }
355
356 fn try_factor_for_zero_detection(&self, _terms: &[Expression]) -> Option<Expression> {
358 None
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    // `-1 · expr`: the product shape the additive-inverse checks look for.
    fn negate(expr: Expression) -> Expression {
        Expression::mul(vec![Expression::integer(-1), expr])
    }

    #[test]
    fn test_zero_detection_basic() {
        let zero = Expression::integer(0);
        assert!(zero.is_algebraic_zero());

        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            negate(Expression::symbol(x.clone())),
        ]);

        println!("x + (-x) = {}", expr);
        assert!(expr.is_algebraic_zero());
    }

    #[test]
    fn test_additive_inverse_detection() {
        let x = symbol!(x);
        let term1 = Expression::symbol(x.clone());
        let term2 = negate(Expression::symbol(x.clone()));

        // The receiver is not inspected by the check; any expression works.
        let expr = Expression::integer(1);
        assert!(expr.are_additive_inverses(&term1, &term2));
        // The relation must not depend on argument order.
        assert!(expr.are_additive_inverses(&term2, &term1));
    }

    #[test]
    fn test_numeric_zero_detection() {
        let expr = Expression::add(vec![Expression::integer(5), Expression::integer(-5)]);
        assert!(expr.is_algebraic_zero());
    }

    #[test]
    fn test_complex_zero_pattern() {
        let x = symbol!(x);

        // 4 + 4x + (-2)·(2 + 2x)
        let expr = Expression::add(vec![
            Expression::integer(4),
            Expression::mul(vec![Expression::integer(4), Expression::symbol(x.clone())]),
            Expression::mul(vec![
                Expression::integer(-2),
                Expression::add(vec![
                    Expression::integer(2),
                    Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
                ]),
            ]),
        ]);

        println!("Complex zero pattern: {}", expr);

        // Exercises the expansion path only: the result is printed, not asserted.
        let is_zero = expr.is_algebraic_zero();
        println!("Is algebraic zero: {}", is_zero);
    }

    #[test]
    fn test_zero_simplification() {
        let x = symbol!(x);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            negate(Expression::symbol(x.clone())),
        ]);

        // Assert the `Some`, so a `None` result fails instead of passing vacuously.
        assert_eq!(expr.simplify_to_zero(), Some(Expression::integer(0)));
    }

    #[test]
    fn test_non_zero_is_not_reported() {
        let x = symbol!(x);

        let doubled = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]);
        assert!(!doubled.is_algebraic_zero());

        assert_eq!(Expression::symbol(x.clone()).simplify_to_zero(), None);
    }

    #[test]
    fn test_multiplication_zero_detection() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(0), Expression::symbol(x.clone())]);
        assert!(expr.is_algebraic_zero());
    }
}