1use std::cmp::Ordering;
2use super::{Expr, E, constant};
3
4fn is_const(e: &Expr, v: f64) -> bool {
5 matches!(e, Expr::Const(c) if *c == v)
6}
7
8fn is_const_int(e: &Expr) -> Option<i64> {
9 if let Expr::Const(v) = e
10 && *v == v.floor() && v.abs() < 1e15 {
11 return Some(*v as i64);
12 }
13 None
14}
15
16fn type_priority(e: &Expr) -> u8 {
21 match e {
22 Expr::Const(_) => 100, Expr::Sym(_) => 0,
24 Expr::Pow(base, _) => {
25 if matches!(base.as_ref(), Expr::Sym(_)) { 0 } else { 2 }
27 }
28 Expr::Mul(_, _) => 1,
29 Expr::Neg(_) => 3,
30 Expr::Add(_, _) | Expr::Sub(_, _) => 4,
31 _ => 5, }
33}
34
35fn leading_sym(e: &Expr) -> Option<&str> {
37 match e {
38 Expr::Sym(s) => Some(s),
39 Expr::Pow(base, _) => leading_sym(base),
40 Expr::Mul(a, b) => leading_sym(a).or_else(|| leading_sym(b)),
41 Expr::Neg(a) => leading_sym(a),
42 _ => None,
43 }
44}
45
46fn degree(e: &Expr) -> i64 {
48 match e {
49 Expr::Sym(_) => 1,
50 Expr::Pow(_, exp) => {
51 if let Expr::Const(v) = exp.as_ref() {
52 *v as i64
53 } else {
54 2 }
56 }
57 Expr::Mul(a, b) => degree(a) + degree(b),
58 Expr::Neg(a) => degree(a),
59 Expr::Const(_) => 0,
60 _ => 1,
61 }
62}
63
64fn mul_factor_cmp(a: &E, b: &E) -> Ordering {
66 let sa = leading_sym(a);
67 let sb = leading_sym(b);
68 match (sa, sb) {
69 (Some(sa), Some(sb)) => {
70 let cmp = sa.cmp(sb);
71 if cmp != Ordering::Equal { return cmp; }
72 degree(a).cmp(°ree(b))
74 }
75 (Some(_), None) => Ordering::Less,
76 (None, Some(_)) => Ordering::Greater,
77 (None, None) => {
78 let cmp = type_priority(a).cmp(&type_priority(b));
79 if cmp != Ordering::Equal { return cmp; }
80 format!("{}", a).cmp(&format!("{}", b))
82 }
83 }
84}
85
86fn add_term_cmp(a: &E, b: &E) -> Ordering {
89 let pa = type_priority(a);
90 let pb = type_priority(b);
91 if pa == 100 && pb != 100 { return Ordering::Greater; }
93 if pa != 100 && pb == 100 { return Ordering::Less; }
94 if pa == 100 && pb == 100 {
95 if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref()) {
97 return vb.partial_cmp(va).unwrap_or(Ordering::Equal);
98 }
99 return Ordering::Equal;
100 }
101
102 let da = degree(a);
104 let db = degree(b);
105 if da != db { return db.cmp(&da); }
106
107 let sa = leading_sym(a);
109 let sb = leading_sym(b);
110 match (sa, sb) {
111 (Some(sa), Some(sb)) => sa.cmp(sb),
112 (Some(_), None) => Ordering::Less,
113 (None, Some(_)) => Ordering::Greater,
114 (None, None) => Ordering::Equal,
115 }
116}
117
118fn base_and_exp(e: &E) -> (E, f64) {
124 if let Expr::Pow(base, exp) = e.as_ref()
125 && let Expr::Const(n) = exp.as_ref() {
126 return (base.clone(), *n);
127 }
128 (e.clone(), 1.0)
129}
130
131fn flatten_mul(e: &E) -> (f64, Vec<E>) {
134 match e.as_ref() {
135 Expr::Mul(a, b) => {
136 let (ca, mut fa) = flatten_mul(a);
137 let (cb, fb) = flatten_mul(b);
138 fa.extend(fb);
139 (ca * cb, fa)
140 }
141 Expr::Neg(inner) => {
142 let (c, f) = flatten_mul(inner);
143 (-c, f)
144 }
145 Expr::Const(v) => (*v, vec![]),
146 Expr::Pow(base, exp) if matches!(base.as_ref(), Expr::Mul(..) | Expr::Neg(..)) => {
148 if let Expr::Const(n) = exp.as_ref() {
149 let (c_base, factors) = flatten_mul(base);
150 let coeff = c_base.powf(*n);
151 let powered: Vec<E> = factors
152 .into_iter()
153 .map(|f| E::new(Expr::Pow(f, exp.clone())))
154 .collect();
155 (coeff, powered)
156 } else {
157 (1.0, vec![e.clone()])
158 }
159 }
160 _ => (1.0, vec![e.clone()]),
161 }
162}
163
164fn combine_powers(factors: Vec<E>) -> Vec<(E, f64)> {
167 let mut groups: Vec<(E, f64)> = Vec::new();
168 for f in factors {
169 let (base, exp) = base_and_exp(&f);
170 if let Some(entry) = groups.iter_mut().find(|(b, _)| *b == base) {
171 entry.1 += exp;
172 } else {
173 groups.push((base, exp));
174 }
175 }
176 groups
177}
178
179fn build_product(coeff: f64, mut factors: Vec<E>) -> E {
181 if coeff == 0.0 { return constant(0.0); }
182
183 factors.sort_by(mul_factor_cmp);
185
186 let factors_expr = if factors.is_empty() {
188 return constant(coeff);
189 } else {
190 let mut iter = factors.into_iter();
191 let first = iter.next().unwrap();
192 iter.fold(first, |acc, f| E::new(Expr::Mul(acc, f)))
193 };
194
195 if coeff == 1.0 {
196 factors_expr
197 } else if coeff == -1.0 {
198 E::new(Expr::Neg(factors_expr))
199 } else {
200 E::new(Expr::Mul(constant(coeff), factors_expr))
201 }
202}
203
204fn simplify_product(a: E, b: E) -> E {
206 let (ca, fa) = flatten_mul(&a);
208 let (cb, fb) = flatten_mul(&b);
209
210 let coeff = ca * cb;
211 let mut all_factors = fa;
212 all_factors.extend(fb);
213
214 if coeff == 0.0 { return constant(0.0); }
215 if all_factors.is_empty() { return constant(coeff); }
216
217 let has_div = all_factors.iter().any(|f| matches!(f.as_ref(), Expr::Div(..)));
220 if has_div {
221 let mut num_factors = Vec::new();
222 let mut den_factors = Vec::new();
223 let mut num_coeff = coeff;
224 for f in all_factors {
225 let (fc, nf, df) = flatten_fraction(&f);
226 num_coeff *= fc;
227 num_factors.extend(nf);
228 den_factors.extend(df);
229 }
230 let num_groups = combine_powers(num_factors);
231 let den_groups = combine_powers(den_factors);
232 let (final_coeff, final_num, final_den) = cancel_common(num_coeff, num_groups, den_groups);
233 let num_expr = build_product_from_groups(final_coeff, final_num);
234 let den_expr = build_product_from_groups(1.0, final_den);
235 if is_const(&den_expr, 1.0) {
236 return num_expr;
237 }
238 return E::new(Expr::Div(num_expr, den_expr));
239 }
240
241 let groups = combine_powers(all_factors);
243
244 let mut factors: Vec<E> = Vec::new();
246 for (base, exp) in groups {
247 if exp == 0.0 {
248 } else if exp == 1.0 {
250 factors.push(base);
251 } else {
252 factors.push(E::new(Expr::Pow(base, constant(exp))));
253 }
254 }
255
256 factors.sort_by(mul_factor_cmp);
258
259 build_product(coeff, factors)
260}
261
262fn flatten_additive(e: &E) -> Vec<(f64, E)> {
269 match e.as_ref() {
270 Expr::Add(a, b) => {
271 let mut terms = flatten_additive(a);
272 terms.extend(flatten_additive(b));
273 terms
274 }
275 Expr::Sub(a, b) => {
276 let mut terms = flatten_additive(a);
277 let neg_terms: Vec<(f64, E)> = flatten_additive(b)
278 .into_iter()
279 .map(|(c, base)| (-c, base))
280 .collect();
281 terms.extend(neg_terms);
282 terms
283 }
284 Expr::Neg(inner) => {
285 flatten_additive(inner)
286 .into_iter()
287 .map(|(c, base)| (-c, base))
288 .collect()
289 }
290 _ => {
291 let (coeff, base) = extract_coeff(e);
292 vec![(coeff, base)]
293 }
294 }
295}
296
297fn extract_coeff(e: &E) -> (f64, E) {
299 match e.as_ref() {
300 Expr::Const(v) => (*v, constant(1.0)),
301 Expr::Mul(a, b) => {
302 if let Expr::Const(v) = a.as_ref() {
303 let (inner_c, inner_b) = extract_coeff(b);
305 return (v * inner_c, inner_b);
306 }
307 if let Expr::Const(v) = b.as_ref() {
308 let (inner_c, inner_b) = extract_coeff(a);
309 return (v * inner_c, inner_b);
310 }
311 (1.0, e.clone())
312 }
313 Expr::Neg(inner) => {
314 let (c, base) = extract_coeff(inner);
315 (-c, base)
316 }
317 _ => (1.0, e.clone()),
318 }
319}
320
321fn combine_like_terms(terms: Vec<(f64, E)>) -> Vec<(f64, E)> {
323 let mut groups: Vec<(f64, E)> = Vec::new();
324 for (coeff, base) in terms {
325 if let Some(entry) = groups.iter_mut().find(|(_, b)| *b == base) {
326 entry.0 += coeff;
327 } else {
328 groups.push((coeff, base));
329 }
330 }
331 groups
332}
333
334fn build_sum(mut terms: Vec<(f64, E)>) -> E {
336 terms.retain(|(c, _)| c.abs() > f64::EPSILON);
338
339 if terms.is_empty() {
340 return constant(0.0);
341 }
342
343 terms.sort_by(|(_, a), (_, b)| add_term_cmp(a, b));
345
346 let make_term = |coeff: f64, base: E| -> E {
347 if is_const(&base, 1.0) {
348 constant(coeff)
349 } else if coeff == 1.0 {
350 base
351 } else if coeff == -1.0 {
352 E::new(Expr::Neg(base))
353 } else {
354 E::new(Expr::Mul(constant(coeff), base))
355 }
356 };
357
358 let mut iter = terms.into_iter();
359 let (first_c, first_b) = iter.next().unwrap();
360 let mut result = make_term(first_c, first_b);
361
362 for (coeff, base) in iter {
363 if coeff > 0.0 {
364 result = E::new(Expr::Add(result, make_term(coeff, base)));
365 } else {
366 result = E::new(Expr::Sub(result, make_term(-coeff, base)));
367 }
368 }
369
370 result
371}
372
373fn simplify_sum(a: E, b: E, negate_b: bool) -> E {
375 let mut terms = flatten_additive(&a);
376 let b_terms = flatten_additive(&b);
377 if negate_b {
378 terms.extend(b_terms.into_iter().map(|(c, base)| (-c, base)));
379 } else {
380 terms.extend(b_terms);
381 }
382
383 let combined = combine_like_terms(terms);
384 build_sum(combined)
385}
386
387fn flatten_fraction(e: &E) -> (f64, Vec<E>, Vec<E>) {
395 match e.as_ref() {
396 Expr::Div(a, b) => {
397 let (ca, na, da) = flatten_fraction(a);
398 let (cb, nb, db) = flatten_fraction(b);
399 let mut num = na;
401 num.extend(db);
402 let mut den = da;
403 den.extend(nb);
404 (ca / cb, num, den)
405 }
406 _ => {
407 let (c, factors) = flatten_mul(e);
408 (c, factors, vec![])
409 }
410 }
411}
412
413fn cancel_common(
417 coeff: f64,
418 mut num: Vec<(E, f64)>,
419 den: Vec<(E, f64)>,
420) -> (f64, Vec<(E, f64)>, Vec<(E, f64)>) {
421 let mut final_den = Vec::new();
422 for (base, den_exp) in den {
423 if let Some(entry) = num.iter_mut().find(|(b, _)| *b == base) {
424 entry.1 -= den_exp;
425 } else {
426 final_den.push((base, den_exp));
427 }
428 }
429 let mut moved = Vec::new();
431 for (i, (_base, exp)) in num.iter().enumerate() {
432 if *exp < 0.0 {
433 moved.push(i);
434 }
435 }
436 for i in moved.into_iter().rev() {
437 let (base, exp) = num.remove(i);
438 final_den.push((base, -exp));
439 }
440 num.retain(|(_, exp)| *exp != 0.0);
441 (coeff, num, final_den)
442}
443
444fn build_product_from_groups(coeff: f64, groups: Vec<(E, f64)>) -> E {
446 let factors: Vec<E> = groups
447 .into_iter()
448 .map(|(base, exp)| {
449 if exp == 1.0 {
450 base
451 } else {
452 E::new(Expr::Pow(base, constant(exp)))
453 }
454 })
455 .collect();
456 build_product(coeff, factors)
457}
458
459fn simplify_div(a: E, b: E) -> E {
460 if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref())
462 && *vb != 0.0 {
463 return constant(va / vb);
464 }
465 if is_const(&a, 0.0) { return constant(0.0); }
466 if is_const(&b, 1.0) { return a; }
467 if a == b { return constant(1.0); }
468
469 let (ca, na, da) = flatten_fraction(&a);
471 let (cb, nb, db) = flatten_fraction(&b);
472 let mut num_factors = na;
474 num_factors.extend(db);
475 let mut den_factors = da;
476 den_factors.extend(nb);
477 let coeff = ca / cb;
478
479 if coeff == 0.0 { return constant(0.0); }
480
481 let num_groups = combine_powers(num_factors);
483 let den_groups = combine_powers(den_factors);
484
485 let (coeff, final_num, final_den) = cancel_common(coeff, num_groups, den_groups);
487
488 let num_expr = build_product_from_groups(coeff, final_num);
490 let den_expr = build_product_from_groups(1.0, final_den);
491
492 if is_const(&den_expr, 1.0) {
493 num_expr
494 } else {
495 E::new(Expr::Div(num_expr, den_expr))
496 }
497}
498
499impl Expr {
504 pub fn simplify(&self) -> E {
510 let mut result = self.simplify_once();
511 for _ in 0..10 {
512 let next = result.simplify_once();
513 if next == result { break; }
514 result = next;
515 }
516 result
517 }
518
519 fn simplify_once(&self) -> E {
520 fn is_pi(e: &E) -> bool {
522 matches!(e.as_ref(), Expr::NamedConst { name, .. } if name == "pi")
523 }
524
525 fn is_euler(e: &E) -> bool {
527 matches!(e.as_ref(), Expr::NamedConst { name, .. } if name == "e")
528 }
529
530 fn pi_coeff(e: &E) -> Option<f64> {
533 if is_pi(e) { return Some(1.0); }
534 match e.as_ref() {
535 Expr::Neg(inner) => pi_coeff(inner).map(|c| -c),
536 Expr::Mul(a, b) => {
537 if let Expr::Const(c) = a.as_ref() && is_pi(b) { return Some(*c); }
538 if let Expr::Const(c) = b.as_ref() && is_pi(a) { return Some(*c); }
539 None
540 }
541 Expr::Div(a, b) => {
542 if let Expr::Const(d) = b.as_ref() { return pi_coeff(a).map(|c| c / d); }
543 None
544 }
545 _ => None,
546 }
547 }
548
549 fn sin_pi(k: f64) -> Option<E> {
552 let twelfths = k * 12.0;
553 if (twelfths - twelfths.round()).abs() > 1e-9 { return None; }
554 let idx = ((twelfths.round() as i64) % 24 + 24) % 24;
555 match idx {
557 0 | 12 => Some(constant(0.0)), 6 | 18 => Some(if idx == 6 { constant(1.0) } else { constant(-1.0) }), 2 | 10 => Some(constant(0.5)), 14 | 22 => Some(constant(-0.5)), 3 | 9 => Some(crate::sqrt(constant(2.0)) / 2.0), 15 | 21 => Some(-crate::sqrt(constant(2.0)) / 2.0), 4 | 8 => Some(crate::sqrt(constant(3.0)) / 2.0), 16 | 20 => Some(-crate::sqrt(constant(3.0)) / 2.0), _ => None,
566 }
567 }
568
569 fn cos_pi(k: f64) -> Option<E> {
571 sin_pi(k + 0.5)
573 }
574
575
576 match self {
577 Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),
578
579 Expr::Neg(a) => {
580 let a = a.simplify_once();
581 if let Expr::Neg(inner) = a.as_ref() {
582 return inner.clone();
583 }
584 if let Expr::Const(v) = a.as_ref() {
585 return constant(-v);
586 }
587 E::new(Expr::Neg(a))
588 }
589
590 Expr::Add(a, b) => {
591 let a = a.simplify_once();
592 let b = b.simplify_once();
593 simplify_sum(a, b, false)
594 }
595
596 Expr::Sub(a, b) => {
597 let a = a.simplify_once();
598 let b = b.simplify_once();
599 simplify_sum(a, b, true)
600 }
601
602 Expr::Mul(a, b) => {
603 let a = a.simplify_once();
604 let b = b.simplify_once();
605 simplify_product(a, b)
606 }
607
608 Expr::Div(a, b) => {
609 let a = a.simplify_once();
610 let b = b.simplify_once();
611 simplify_div(a, b)
612 }
613
614 Expr::Pow(a, b) => {
615 let a = a.simplify_once();
616 let b = b.simplify_once();
617 if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref()) {
618 return constant(va.powf(*vb));
619 }
620 if is_const(&b, 0.0) { return constant(1.0); }
621 if is_const(&b, 1.0) { return a; }
622 if is_const(&a, 0.0) { return constant(0.0); }
623 if is_const(&a, 1.0) { return constant(1.0); }
624 E::new(Expr::Pow(a, b))
625 }
626
627 Expr::Ln(a) => {
629 let a = a.simplify_once();
630 if let Expr::Exp(inner) = a.as_ref() { return inner.clone(); }
631 if let Expr::Const(v) = a.as_ref() { return constant(v.ln()); }
632 if is_euler(&a) { return constant(1.0); }
633 if let Expr::Pow(base, exp) = a.as_ref()
635 && is_euler(base) { return exp.clone(); }
636 E::new(Expr::Ln(a))
637 }
638 Expr::Exp(a) => {
639 let a = a.simplify_once();
640 if let Expr::Ln(inner) = a.as_ref() { return inner.clone(); }
641 if let Expr::Const(v) = a.as_ref() { return constant(v.exp()); }
642 E::new(Expr::Exp(a))
643 }
644
645 Expr::Sin(a) => {
647 let a = a.simplify_once();
648 if let Expr::Const(v) = a.as_ref() { return constant(v.sin()); }
649 if let Some(k) = pi_coeff(&a) && let Some(v) = sin_pi(k) { return v; }
650 E::new(Expr::Sin(a))
651 }
652 Expr::Cos(a) => {
653 let a = a.simplify_once();
654 if let Expr::Const(v) = a.as_ref() { return constant(v.cos()); }
655 if let Some(k) = pi_coeff(&a) && let Some(v) = cos_pi(k) { return v; }
656 E::new(Expr::Cos(a))
657 }
658 Expr::Tan(a) => {
659 let a = a.simplify_once();
660 if let Expr::Const(v) = a.as_ref() { return constant(v.tan()); }
661 if let Some(k) = pi_coeff(&a)
663 && (k - k.round()).abs() < 1e-9 { return constant(0.0); }
664 E::new(Expr::Tan(a))
665 }
666 Expr::Asin(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.asin()); } E::new(Expr::Asin(a)) }
667 Expr::Acos(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.acos()); } E::new(Expr::Acos(a)) }
668 Expr::Atan(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.atan()); } E::new(Expr::Atan(a)) }
669 Expr::Sinh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.sinh()); } E::new(Expr::Sinh(a)) }
670 Expr::Cosh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.cosh()); } E::new(Expr::Cosh(a)) }
671 Expr::Tanh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.tanh()); } E::new(Expr::Tanh(a)) }
672 Expr::Log2(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.log2()); } E::new(Expr::Log2(a)) }
673 Expr::Log10(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.log10()); } E::new(Expr::Log10(a)) }
674 Expr::Sqrt(a) => {
675 let a = a.simplify_once();
676 if let Expr::Const(v) = a.as_ref() { return constant(v.sqrt()); }
677 if let Expr::Pow(base, exp) = a.as_ref()
678 && is_const(exp, 2.0) {
679 return E::new(Expr::Abs(base.clone()));
680 }
681 E::new(Expr::Sqrt(a))
682 }
683 Expr::Abs(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.abs()); } E::new(Expr::Abs(a)) }
684 Expr::Heaviside(a) => {
685 let a = a.simplify_once();
686 if let Expr::Const(v) = a.as_ref() {
687 return constant(if *v < 0.0 { 0.0 } else { 1.0 });
688 }
689 E::new(Expr::Heaviside(a))
690 }
691 Expr::Clamp(val, lo, hi) => {
692 let val = val.simplify_once();
693 let lo = lo.simplify_once();
694 let hi = hi.simplify_once();
695 if let (Expr::Const(v), Expr::Const(l), Expr::Const(h)) = (val.as_ref(), lo.as_ref(), hi.as_ref()) {
696 return constant(v.clamp(*l, *h));
697 }
698 E::new(Expr::Clamp(val, lo, hi))
699 }
700 Expr::Atan2(y, x) => {
701 let y = y.simplify_once();
702 let x = x.simplify_once();
703 if let (Expr::Const(vy), Expr::Const(vx)) = (y.as_ref(), x.as_ref()) {
704 return constant(vy.atan2(*vx));
705 }
706 E::new(Expr::Atan2(y, x))
707 }
708 Expr::Func { name, params, kind, args } => {
709 let new_args: Vec<E> = args.iter().map(|a| a.simplify_once()).collect();
710 if let Some(body) = kind.body()
712 && new_args.iter().all(|a| matches!(a.as_ref(), Expr::Const(_))) {
713 let expanded = crate::expand_func(params, body, &new_args);
714 return expanded.simplify_once();
715 }
716 E::new(Expr::Func {
717 name: name.clone(), params: params.clone(),
718 kind: kind.clone(), args: new_args,
719 })
720 }
721 }
722 }
723
724 pub fn expand(&self) -> E {
730 self.expand_inner().simplify()
731 }
732
733 fn expand_inner(&self) -> E {
734 match self {
735 Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),
736 Expr::Neg(a) => E::new(Expr::Neg(a.expand_inner())),
737 Expr::Add(a, b) => E::new(Expr::Add(a.expand_inner(), b.expand_inner())),
738 Expr::Sub(a, b) => E::new(Expr::Sub(a.expand_inner(), b.expand_inner())),
739 Expr::Mul(a, b) => {
740 let a = a.expand_inner();
741 let b = b.expand_inner();
742 if let Expr::Add(b1, b2) = b.as_ref() {
743 let left = E::new(Expr::Mul(a.clone(), b1.clone()));
744 let right = E::new(Expr::Mul(a, b2.clone()));
745 return E::new(Expr::Add(left.expand_inner(), right.expand_inner()));
746 }
747 if let Expr::Sub(b1, b2) = b.as_ref() {
748 let left = E::new(Expr::Mul(a.clone(), b1.clone()));
749 let right = E::new(Expr::Mul(a, b2.clone()));
750 return E::new(Expr::Sub(left.expand_inner(), right.expand_inner()));
751 }
752 if let Expr::Add(a1, a2) = a.as_ref() {
753 let left = E::new(Expr::Mul(a1.clone(), b.clone()));
754 let right = E::new(Expr::Mul(a2.clone(), b));
755 return E::new(Expr::Add(left.expand_inner(), right.expand_inner()));
756 }
757 if let Expr::Sub(a1, a2) = a.as_ref() {
758 let left = E::new(Expr::Mul(a1.clone(), b.clone()));
759 let right = E::new(Expr::Mul(a2.clone(), b));
760 return E::new(Expr::Sub(left.expand_inner(), right.expand_inner()));
761 }
762 E::new(Expr::Mul(a, b))
763 }
764 Expr::Div(a, b) => E::new(Expr::Div(a.expand_inner(), b.expand_inner())),
765 Expr::Pow(base, exp) => {
766 let base = base.expand_inner();
767 let exp = exp.expand_inner();
768 if let Some(n) = is_const_int(&exp)
769 && (2..=8).contains(&n) {
770 let mut result = base.clone();
771 for _ in 1..n {
772 result = E::new(Expr::Mul(result, base.clone()));
773 }
774 return result.expand_inner();
775 }
776 E::new(Expr::Pow(base, exp))
777 }
778 Expr::Sin(a) => E::new(Expr::Sin(a.expand_inner())),
779 Expr::Cos(a) => E::new(Expr::Cos(a.expand_inner())),
780 Expr::Tan(a) => E::new(Expr::Tan(a.expand_inner())),
781 Expr::Asin(a) => E::new(Expr::Asin(a.expand_inner())),
782 Expr::Acos(a) => E::new(Expr::Acos(a.expand_inner())),
783 Expr::Atan(a) => E::new(Expr::Atan(a.expand_inner())),
784 Expr::Atan2(y, x) => E::new(Expr::Atan2(y.expand_inner(), x.expand_inner())),
785 Expr::Sinh(a) => E::new(Expr::Sinh(a.expand_inner())),
786 Expr::Cosh(a) => E::new(Expr::Cosh(a.expand_inner())),
787 Expr::Tanh(a) => E::new(Expr::Tanh(a.expand_inner())),
788 Expr::Exp(a) => E::new(Expr::Exp(a.expand_inner())),
789 Expr::Ln(a) => E::new(Expr::Ln(a.expand_inner())),
790 Expr::Log2(a) => E::new(Expr::Log2(a.expand_inner())),
791 Expr::Log10(a) => E::new(Expr::Log10(a.expand_inner())),
792 Expr::Sqrt(a) => E::new(Expr::Sqrt(a.expand_inner())),
793 Expr::Abs(a) => E::new(Expr::Abs(a.expand_inner())),
794 Expr::Heaviside(a) => E::new(Expr::Heaviside(a.expand_inner())),
795 Expr::Clamp(val, lo, hi) => E::new(Expr::Clamp(val.expand_inner(), lo.expand_inner(), hi.expand_inner())),
796 Expr::Func { name, params, kind, args } => {
797 let expanded_args: Vec<E> = args.iter().map(|a| a.expand_inner()).collect();
798 if let Some(body) = kind.body() {
799 crate::expand_func(params, body, &expanded_args).expand_inner()
800 } else {
801 E::new(Expr::Func {
802 name: name.clone(), params: params.clone(),
803 kind: kind.clone(), args: expanded_args,
804 })
805 }
806 }
807 }
808 }
809
810 pub fn collect(&self, var: impl crate::AsVarName) -> E {
817 let var = var.var_expr();
818 let terms = flatten_add_simple(&E::new(self.clone()));
819 let mut with_var: Vec<E> = Vec::new();
820 let mut without_var: Vec<E> = Vec::new();
821
822 for term in &terms {
823 if let Some(coeff) = extract_factor(term, &var) {
824 with_var.push(coeff);
825 } else {
826 without_var.push(term.clone());
827 }
828 }
829
830 let mut result: Option<E> = None;
831
832 if !with_var.is_empty() {
833 let coeff_sum = sum_terms(with_var);
834 let collected = coeff_sum * var;
835 result = Some(collected);
836 }
837
838 for t in without_var {
839 result = Some(match result {
840 Some(acc) => acc + t,
841 None => t,
842 });
843 }
844
845 result.unwrap_or_else(|| constant(0.0))
846 }
847}
848
849fn flatten_add_simple(e: &E) -> Vec<E> {
850 match e.as_ref() {
851 Expr::Add(a, b) => {
852 let mut terms = flatten_add_simple(a);
853 terms.extend(flatten_add_simple(b));
854 terms
855 }
856 _ => vec![e.clone()],
857 }
858}
859
860fn extract_factor(term: &E, var: &E) -> Option<E> {
861 if term == var {
862 return Some(constant(1.0));
863 }
864 if let Expr::Mul(a, b) = term.as_ref() {
865 if b == var { return Some(a.clone()); }
866 if a == var { return Some(b.clone()); }
867 }
868 None
869}
870
871fn sum_terms(terms: Vec<E>) -> E {
872 let mut iter = terms.into_iter();
873 let first = iter.next().unwrap();
874 iter.fold(first, |acc, t| acc + t)
875}