Skip to main content

arael_sym/
simplify.rs

1use std::cmp::Ordering;
2use super::{Expr, E, constant};
3
4fn is_const(e: &Expr, v: f64) -> bool {
5    matches!(e, Expr::Const(c) if *c == v)
6}
7
8fn is_const_int(e: &Expr) -> Option<i64> {
9    if let Expr::Const(v) = e
10        && *v == v.floor() && v.abs() < 1e15 {
11            return Some(*v as i64);
12        }
13    None
14}
15
16// ---------------------------------------------------------------------------
17// Canonical ordering for expressions
18// ---------------------------------------------------------------------------
19
20fn type_priority(e: &Expr) -> u8 {
21    match e {
22        Expr::Const(_) => 100, // constants sort last in Add, first extracted in Mul
23        Expr::Sym(_) => 0,
24        Expr::Pow(base, _) => {
25            // Pow of symbol sorts alongside symbols for degree ordering
26            if matches!(base.as_ref(), Expr::Sym(_)) { 0 } else { 2 }
27        }
28        Expr::Mul(_, _) => 1,
29        Expr::Neg(_) => 3,
30        Expr::Add(_, _) | Expr::Sub(_, _) => 4,
31        _ => 5, // functions
32    }
33}
34
35/// Extract the "leading symbol name" for ordering purposes.
36fn leading_sym(e: &Expr) -> Option<&str> {
37    match e {
38        Expr::Sym(s) => Some(s),
39        Expr::Pow(base, _) => leading_sym(base),
40        Expr::Mul(a, b) => leading_sym(a).or_else(|| leading_sym(b)),
41        Expr::Neg(a) => leading_sym(a),
42        _ => None,
43    }
44}
45
46/// Estimate "degree" for ordering within Add (higher degree first).
47fn degree(e: &Expr) -> i64 {
48    match e {
49        Expr::Sym(_) => 1,
50        Expr::Pow(_, exp) => {
51            if let Expr::Const(v) = exp.as_ref() {
52                *v as i64
53            } else {
54                2 // treat non-const exponent as degree 2
55            }
56        }
57        Expr::Mul(a, b) => degree(a) + degree(b),
58        Expr::Neg(a) => degree(a),
59        Expr::Const(_) => 0,
60        _ => 1,
61    }
62}
63
64/// Canonical comparison for sorting factors in Mul.
65fn mul_factor_cmp(a: &E, b: &E) -> Ordering {
66    let sa = leading_sym(a);
67    let sb = leading_sym(b);
68    match (sa, sb) {
69        (Some(sa), Some(sb)) => {
70            let cmp = sa.cmp(sb);
71            if cmp != Ordering::Equal { return cmp; }
72            // Same leading symbol: compare by degree (lower first in Mul context)
73            degree(a).cmp(&degree(b))
74        }
75        (Some(_), None) => Ordering::Less,
76        (None, Some(_)) => Ordering::Greater,
77        (None, None) => {
78            let cmp = type_priority(a).cmp(&type_priority(b));
79            if cmp != Ordering::Equal { return cmp; }
80            // Tiebreaker: compare by string representation for deterministic ordering
81            format!("{}", a).cmp(&format!("{}", b))
82        }
83    }
84}
85
86/// Canonical comparison for sorting terms in Add.
87/// Higher degree first, then alphabetical, constants last.
88fn add_term_cmp(a: &E, b: &E) -> Ordering {
89    let pa = type_priority(a);
90    let pb = type_priority(b);
91    // Constants always last
92    if pa == 100 && pb != 100 { return Ordering::Greater; }
93    if pa != 100 && pb == 100 { return Ordering::Less; }
94    if pa == 100 && pb == 100 {
95        // Both constants: compare by value
96        if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref()) {
97            return vb.partial_cmp(va).unwrap_or(Ordering::Equal);
98        }
99        return Ordering::Equal;
100    }
101
102    // Higher degree first
103    let da = degree(a);
104    let db = degree(b);
105    if da != db { return db.cmp(&da); }
106
107    // Same degree: alphabetical by leading symbol
108    let sa = leading_sym(a);
109    let sb = leading_sym(b);
110    match (sa, sb) {
111        (Some(sa), Some(sb)) => sa.cmp(sb),
112        (Some(_), None) => Ordering::Less,
113        (None, Some(_)) => Ordering::Greater,
114        (None, None) => Ordering::Equal,
115    }
116}
117
118// ---------------------------------------------------------------------------
119// Mul normalization helpers
120// ---------------------------------------------------------------------------
121
122/// Extract (base, const_exponent) from a factor.
123fn base_and_exp(e: &E) -> (E, f64) {
124    if let Expr::Pow(base, exp) = e.as_ref()
125        && let Expr::Const(n) = exp.as_ref() {
126            return (base.clone(), *n);
127        }
128    (e.clone(), 1.0)
129}
130
131/// Flatten nested Mul into a list of factors, extracting a numeric coefficient.
132/// Also expands (a*b)^n → a^n * b^n for constant exponents.
133fn flatten_mul(e: &E) -> (f64, Vec<E>) {
134    match e.as_ref() {
135        Expr::Mul(a, b) => {
136            let (ca, mut fa) = flatten_mul(a);
137            let (cb, fb) = flatten_mul(b);
138            fa.extend(fb);
139            (ca * cb, fa)
140        }
141        Expr::Neg(inner) => {
142            let (c, f) = flatten_mul(inner);
143            (-c, f)
144        }
145        Expr::Const(v) => (*v, vec![]),
146        // (a * b * ...)^n → a^n * b^n * ... (expand compound-base powers)
147        Expr::Pow(base, exp) if matches!(base.as_ref(), Expr::Mul(..) | Expr::Neg(..)) => {
148            if let Expr::Const(n) = exp.as_ref() {
149                let (c_base, factors) = flatten_mul(base);
150                let coeff = c_base.powf(*n);
151                let powered: Vec<E> = factors
152                    .into_iter()
153                    .map(|f| E::new(Expr::Pow(f, exp.clone())))
154                    .collect();
155                (coeff, powered)
156            } else {
157                (1.0, vec![e.clone()])
158            }
159        }
160        _ => (1.0, vec![e.clone()]),
161    }
162}
163
164/// Combine factors with the same base by summing exponents.
165/// Returns a list of (base, total_exponent) pairs.
166fn combine_powers(factors: Vec<E>) -> Vec<(E, f64)> {
167    let mut groups: Vec<(E, f64)> = Vec::new();
168    for f in factors {
169        let (base, exp) = base_and_exp(&f);
170        if let Some(entry) = groups.iter_mut().find(|(b, _)| *b == base) {
171            entry.1 += exp;
172        } else {
173            groups.push((base, exp));
174        }
175    }
176    groups
177}
178
179/// Build a simplified Mul from coefficient and factor list.
180fn build_product(coeff: f64, mut factors: Vec<E>) -> E {
181    if coeff == 0.0 { return constant(0.0); }
182
183    // Sort factors canonically
184    factors.sort_by(mul_factor_cmp);
185
186    // Build factor chain
187    let factors_expr = if factors.is_empty() {
188        return constant(coeff);
189    } else {
190        let mut iter = factors.into_iter();
191        let first = iter.next().unwrap();
192        iter.fold(first, |acc, f| E::new(Expr::Mul(acc, f)))
193    };
194
195    if coeff == 1.0 {
196        factors_expr
197    } else if coeff == -1.0 {
198        E::new(Expr::Neg(factors_expr))
199    } else {
200        E::new(Expr::Mul(constant(coeff), factors_expr))
201    }
202}
203
204/// Full Mul simplification: flatten, fold constants, combine powers, sort, rebuild.
205fn simplify_product(a: E, b: E) -> E {
206    // Flatten both sides
207    let (ca, fa) = flatten_mul(&a);
208    let (cb, fb) = flatten_mul(&b);
209
210    let coeff = ca * cb;
211    let mut all_factors = fa;
212    all_factors.extend(fb);
213
214    if coeff == 0.0 { return constant(0.0); }
215    if all_factors.is_empty() { return constant(coeff); }
216
217    // If any factor is a Div, combine everything into a single fraction
218    // e.g. Mul(gamma^2, Div(-a, gamma*sigma)) → Div(-a*gamma, sigma)
219    let has_div = all_factors.iter().any(|f| matches!(f.as_ref(), Expr::Div(..)));
220    if has_div {
221        let mut num_factors = Vec::new();
222        let mut den_factors = Vec::new();
223        let mut num_coeff = coeff;
224        for f in all_factors {
225            let (fc, nf, df) = flatten_fraction(&f);
226            num_coeff *= fc;
227            num_factors.extend(nf);
228            den_factors.extend(df);
229        }
230        let num_groups = combine_powers(num_factors);
231        let den_groups = combine_powers(den_factors);
232        let (final_coeff, final_num, final_den) = cancel_common(num_coeff, num_groups, den_groups);
233        let num_expr = build_product_from_groups(final_coeff, final_num);
234        let den_expr = build_product_from_groups(1.0, final_den);
235        if is_const(&den_expr, 1.0) {
236            return num_expr;
237        }
238        return E::new(Expr::Div(num_expr, den_expr));
239    }
240
241    // Combine like bases
242    let groups = combine_powers(all_factors);
243
244    // Rebuild factors from groups
245    let mut factors: Vec<E> = Vec::new();
246    for (base, exp) in groups {
247        if exp == 0.0 {
248            // x^0 = 1, skip
249        } else if exp == 1.0 {
250            factors.push(base);
251        } else {
252            factors.push(E::new(Expr::Pow(base, constant(exp))));
253        }
254    }
255
256    // Sort factors into canonical order so a*b == b*a structurally
257    factors.sort_by(mul_factor_cmp);
258
259    build_product(coeff, factors)
260}
261
262// ---------------------------------------------------------------------------
263// Add normalization helpers
264// ---------------------------------------------------------------------------
265
266/// Flatten Add/Sub/Neg into a list of (coefficient, base) pairs.
267/// The base is the expression without the numeric coefficient.
268fn flatten_additive(e: &E) -> Vec<(f64, E)> {
269    match e.as_ref() {
270        Expr::Add(a, b) => {
271            let mut terms = flatten_additive(a);
272            terms.extend(flatten_additive(b));
273            terms
274        }
275        Expr::Sub(a, b) => {
276            let mut terms = flatten_additive(a);
277            let neg_terms: Vec<(f64, E)> = flatten_additive(b)
278                .into_iter()
279                .map(|(c, base)| (-c, base))
280                .collect();
281            terms.extend(neg_terms);
282            terms
283        }
284        Expr::Neg(inner) => {
285            flatten_additive(inner)
286                .into_iter()
287                .map(|(c, base)| (-c, base))
288                .collect()
289        }
290        _ => {
291            let (coeff, base) = extract_coeff(e);
292            vec![(coeff, base)]
293        }
294    }
295}
296
297/// Extract numeric coefficient and base from a term.
298fn extract_coeff(e: &E) -> (f64, E) {
299    match e.as_ref() {
300        Expr::Const(v) => (*v, constant(1.0)),
301        Expr::Mul(a, b) => {
302            if let Expr::Const(v) = a.as_ref() {
303                // Const * rest — check if rest is also Mul(Const, ...)
304                let (inner_c, inner_b) = extract_coeff(b);
305                return (v * inner_c, inner_b);
306            }
307            if let Expr::Const(v) = b.as_ref() {
308                let (inner_c, inner_b) = extract_coeff(a);
309                return (v * inner_c, inner_b);
310            }
311            (1.0, e.clone())
312        }
313        Expr::Neg(inner) => {
314            let (c, base) = extract_coeff(inner);
315            (-c, base)
316        }
317        _ => (1.0, e.clone()),
318    }
319}
320
321/// Group terms by base, summing numeric coefficients.
322fn combine_like_terms(terms: Vec<(f64, E)>) -> Vec<(f64, E)> {
323    let mut groups: Vec<(f64, E)> = Vec::new();
324    for (coeff, base) in terms {
325        if let Some(entry) = groups.iter_mut().find(|(_, b)| *b == base) {
326            entry.0 += coeff;
327        } else {
328            groups.push((coeff, base));
329        }
330    }
331    groups
332}
333
334/// Build an Add/Sub chain from (coefficient, base) pairs.
335fn build_sum(mut terms: Vec<(f64, E)>) -> E {
336    // Remove zero-coefficient terms
337    terms.retain(|(c, _)| c.abs() > f64::EPSILON);
338
339    if terms.is_empty() {
340        return constant(0.0);
341    }
342
343    // Sort terms: sort bases for consistent output
344    terms.sort_by(|(_, a), (_, b)| add_term_cmp(a, b));
345
346    let make_term = |coeff: f64, base: E| -> E {
347        if is_const(&base, 1.0) {
348            constant(coeff)
349        } else if coeff == 1.0 {
350            base
351        } else if coeff == -1.0 {
352            E::new(Expr::Neg(base))
353        } else {
354            E::new(Expr::Mul(constant(coeff), base))
355        }
356    };
357
358    let mut iter = terms.into_iter();
359    let (first_c, first_b) = iter.next().unwrap();
360    let mut result = make_term(first_c, first_b);
361
362    for (coeff, base) in iter {
363        if coeff > 0.0 {
364            result = E::new(Expr::Add(result, make_term(coeff, base)));
365        } else {
366            result = E::new(Expr::Sub(result, make_term(-coeff, base)));
367        }
368    }
369
370    result
371}
372
373/// Full Add simplification: flatten, combine like terms, sort, rebuild.
374fn simplify_sum(a: E, b: E, negate_b: bool) -> E {
375    let mut terms = flatten_additive(&a);
376    let b_terms = flatten_additive(&b);
377    if negate_b {
378        terms.extend(b_terms.into_iter().map(|(c, base)| (-c, base)));
379    } else {
380        terms.extend(b_terms);
381    }
382
383    let combined = combine_like_terms(terms);
384    build_sum(combined)
385}
386
387// ---------------------------------------------------------------------------
388// Div / fraction helpers
389// ---------------------------------------------------------------------------
390
391/// Flatten an expression into (coeff, numerator_factors, denominator_factors).
392/// Handles Div chains: (a/b)/c → num=[a_factors], den=[b_factors, c_factors]
393/// Handles reciprocals: a/(b/c) → num=[a_factors, c_factors], den=[b_factors]
394fn flatten_fraction(e: &E) -> (f64, Vec<E>, Vec<E>) {
395    match e.as_ref() {
396        Expr::Div(a, b) => {
397            let (ca, na, da) = flatten_fraction(a);
398            let (cb, nb, db) = flatten_fraction(b);
399            // (na/da) / (nb/db) = (na*db) / (da*nb)
400            let mut num = na;
401            num.extend(db);
402            let mut den = da;
403            den.extend(nb);
404            (ca / cb, num, den)
405        }
406        _ => {
407            let (c, factors) = flatten_mul(e);
408            (c, factors, vec![])
409        }
410    }
411}
412
413/// Cancel common bases between numerator and denominator power groups.
414/// gamma^3 in num + gamma^2 in den → gamma^1 in num.
415/// sigma^1 in num + sigma^2 in den → sigma^1 in den.
416fn cancel_common(
417    coeff: f64,
418    mut num: Vec<(E, f64)>,
419    den: Vec<(E, f64)>,
420) -> (f64, Vec<(E, f64)>, Vec<(E, f64)>) {
421    let mut final_den = Vec::new();
422    for (base, den_exp) in den {
423        if let Some(entry) = num.iter_mut().find(|(b, _)| *b == base) {
424            entry.1 -= den_exp;
425        } else {
426            final_den.push((base, den_exp));
427        }
428    }
429    // Move negative-exponent entries from num to den
430    let mut moved = Vec::new();
431    for (i, (_base, exp)) in num.iter().enumerate() {
432        if *exp < 0.0 {
433            moved.push(i);
434        }
435    }
436    for i in moved.into_iter().rev() {
437        let (base, exp) = num.remove(i);
438        final_den.push((base, -exp));
439    }
440    num.retain(|(_, exp)| *exp != 0.0);
441    (coeff, num, final_den)
442}
443
444/// Build a product from (base, exponent) groups.
445fn build_product_from_groups(coeff: f64, groups: Vec<(E, f64)>) -> E {
446    let factors: Vec<E> = groups
447        .into_iter()
448        .map(|(base, exp)| {
449            if exp == 1.0 {
450                base
451            } else {
452                E::new(Expr::Pow(base, constant(exp)))
453            }
454        })
455        .collect();
456    build_product(coeff, factors)
457}
458
459fn simplify_div(a: E, b: E) -> E {
460    // Quick constant cases
461    if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref())
462        && *vb != 0.0 {
463            return constant(va / vb);
464        }
465    if is_const(&a, 0.0) { return constant(0.0); }
466    if is_const(&b, 1.0) { return a; }
467    if a == b { return constant(1.0); }
468
469    // Flatten both sides, collecting all num/den factors across Div chains
470    let (ca, na, da) = flatten_fraction(&a);
471    let (cb, nb, db) = flatten_fraction(&b);
472    // a/b = (na*db) / (da*nb), coeff = ca/cb
473    let mut num_factors = na;
474    num_factors.extend(db);
475    let mut den_factors = da;
476    den_factors.extend(nb);
477    let coeff = ca / cb;
478
479    if coeff == 0.0 { return constant(0.0); }
480
481    // Combine powers within each side
482    let num_groups = combine_powers(num_factors);
483    let den_groups = combine_powers(den_factors);
484
485    // Cancel common bases between num and den
486    let (coeff, final_num, final_den) = cancel_common(coeff, num_groups, den_groups);
487
488    // Rebuild
489    let num_expr = build_product_from_groups(coeff, final_num);
490    let den_expr = build_product_from_groups(1.0, final_den);
491
492    if is_const(&den_expr, 1.0) {
493        num_expr
494    } else {
495        E::new(Expr::Div(num_expr, den_expr))
496    }
497}
498
499// ---------------------------------------------------------------------------
500// Main simplify
501// ---------------------------------------------------------------------------
502
503impl Expr {
504    /// Apply algebraic simplification rules.
505    ///
506    /// Performs constant folding, identity elimination (0+x=x, 1*x=x),
507    /// like-term collection, power combination, fraction cancellation, and
508    /// canonical ordering. Iterates until a fixed point is reached.
509    pub fn simplify(&self) -> E {
510        let mut result = self.simplify_once();
511        for _ in 0..10 {
512            let next = result.simplify_once();
513            if next == result { break; }
514            result = next;
515        }
516        result
517    }
518
519    fn simplify_once(&self) -> E {
520        /// Check if expression is the named constant "pi".
521        fn is_pi(e: &E) -> bool {
522            matches!(e.as_ref(), Expr::NamedConst { name, .. } if name == "pi")
523        }
524
525        /// Check if expression is the named constant "e" (Euler's number).
526        fn is_euler(e: &E) -> bool {
527            matches!(e.as_ref(), Expr::NamedConst { name, .. } if name == "e")
528        }
529
530        /// Extract coefficient k if expression is of the form k*pi.
531        /// Returns Some(k) for: pi->1, 2*pi->2, pi/2->0.5, -pi->-1, etc.
532        fn pi_coeff(e: &E) -> Option<f64> {
533            if is_pi(e) { return Some(1.0); }
534            match e.as_ref() {
535                Expr::Neg(inner) => pi_coeff(inner).map(|c| -c),
536                Expr::Mul(a, b) => {
537                    if let Expr::Const(c) = a.as_ref() && is_pi(b) { return Some(*c); }
538                    if let Expr::Const(c) = b.as_ref() && is_pi(a) { return Some(*c); }
539                    None
540                }
541                Expr::Div(a, b) => {
542                    if let Expr::Const(d) = b.as_ref() { return pi_coeff(a).map(|c| c / d); }
543                    None
544                }
545                _ => None,
546            }
547        }
548
549        /// Try to simplify sin(k*pi) for special values of k.
550        /// Uses twelfths: k*12 mod 24 to cover pi/6, pi/4, pi/3, pi/2, etc.
551        fn sin_pi(k: f64) -> Option<E> {
552            let twelfths = k * 12.0;
553            if (twelfths - twelfths.round()).abs() > 1e-9 { return None; }
554            let idx = ((twelfths.round() as i64) % 24 + 24) % 24;
555            // sin at 0, pi/12, pi/6, pi/4, pi/3, 5pi/12, pi/2, ...
556            match idx {
557                0 | 12 => Some(constant(0.0)),                          // sin(0), sin(pi)
558                6 | 18 => Some(if idx == 6 { constant(1.0) } else { constant(-1.0) }), // sin(pi/2), sin(3pi/2)
559                2 | 10 => Some(constant(0.5)),                          // sin(pi/6), sin(5pi/6)
560                14 | 22 => Some(constant(-0.5)),                        // sin(7pi/6), sin(11pi/6)
561                3 | 9 => Some(crate::sqrt(constant(2.0)) / 2.0),              // sin(pi/4), sin(3pi/4)
562                15 | 21 => Some(-crate::sqrt(constant(2.0)) / 2.0),           // sin(5pi/4), sin(7pi/4)
563                4 | 8 => Some(crate::sqrt(constant(3.0)) / 2.0),              // sin(pi/3), sin(2pi/3)
564                16 | 20 => Some(-crate::sqrt(constant(3.0)) / 2.0),           // sin(4pi/3), sin(5pi/3)
565                _ => None,
566            }
567        }
568
569        /// Try to simplify cos(k*pi) for special values of k.
570        fn cos_pi(k: f64) -> Option<E> {
571            // cos(k*pi) = sin((k + 0.5)*pi)
572            sin_pi(k + 0.5)
573        }
574
575
576        match self {
577            Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),
578
579            Expr::Neg(a) => {
580                let a = a.simplify_once();
581                if let Expr::Neg(inner) = a.as_ref() {
582                    return inner.clone();
583                }
584                if let Expr::Const(v) = a.as_ref() {
585                    return constant(-v);
586                }
587                E::new(Expr::Neg(a))
588            }
589
590            Expr::Add(a, b) => {
591                let a = a.simplify_once();
592                let b = b.simplify_once();
593                simplify_sum(a, b, false)
594            }
595
596            Expr::Sub(a, b) => {
597                let a = a.simplify_once();
598                let b = b.simplify_once();
599                simplify_sum(a, b, true)
600            }
601
602            Expr::Mul(a, b) => {
603                let a = a.simplify_once();
604                let b = b.simplify_once();
605                simplify_product(a, b)
606            }
607
608            Expr::Div(a, b) => {
609                let a = a.simplify_once();
610                let b = b.simplify_once();
611                simplify_div(a, b)
612            }
613
614            Expr::Pow(a, b) => {
615                let a = a.simplify_once();
616                let b = b.simplify_once();
617                if let (Expr::Const(va), Expr::Const(vb)) = (a.as_ref(), b.as_ref()) {
618                    return constant(va.powf(*vb));
619                }
620                if is_const(&b, 0.0) { return constant(1.0); }
621                if is_const(&b, 1.0) { return a; }
622                if is_const(&a, 0.0) { return constant(0.0); }
623                if is_const(&a, 1.0) { return constant(1.0); }
624                E::new(Expr::Pow(a, b))
625            }
626
627            // Inverse function pairs
628            Expr::Ln(a) => {
629                let a = a.simplify_once();
630                if let Expr::Exp(inner) = a.as_ref() { return inner.clone(); }
631                if let Expr::Const(v) = a.as_ref() { return constant(v.ln()); }
632                if is_euler(&a) { return constant(1.0); }
633                // ln(e^n) -> n
634                if let Expr::Pow(base, exp) = a.as_ref()
635                    && is_euler(base) { return exp.clone(); }
636                E::new(Expr::Ln(a))
637            }
638            Expr::Exp(a) => {
639                let a = a.simplify_once();
640                if let Expr::Ln(inner) = a.as_ref() { return inner.clone(); }
641                if let Expr::Const(v) = a.as_ref() { return constant(v.exp()); }
642                E::new(Expr::Exp(a))
643            }
644
645            // Trig functions: constant-fold + pi rules
646            Expr::Sin(a) => {
647                let a = a.simplify_once();
648                if let Expr::Const(v) = a.as_ref() { return constant(v.sin()); }
649                if let Some(k) = pi_coeff(&a) && let Some(v) = sin_pi(k) { return v; }
650                E::new(Expr::Sin(a))
651            }
652            Expr::Cos(a) => {
653                let a = a.simplify_once();
654                if let Expr::Const(v) = a.as_ref() { return constant(v.cos()); }
655                if let Some(k) = pi_coeff(&a) && let Some(v) = cos_pi(k) { return v; }
656                E::new(Expr::Cos(a))
657            }
658            Expr::Tan(a) => {
659                let a = a.simplify_once();
660                if let Expr::Const(v) = a.as_ref() { return constant(v.tan()); }
661                // tan(n*pi) = 0 for integer n
662                if let Some(k) = pi_coeff(&a)
663                    && (k - k.round()).abs() < 1e-9 { return constant(0.0); }
664                E::new(Expr::Tan(a))
665            }
666            Expr::Asin(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.asin()); } E::new(Expr::Asin(a)) }
667            Expr::Acos(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.acos()); } E::new(Expr::Acos(a)) }
668            Expr::Atan(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.atan()); } E::new(Expr::Atan(a)) }
669            Expr::Sinh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.sinh()); } E::new(Expr::Sinh(a)) }
670            Expr::Cosh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.cosh()); } E::new(Expr::Cosh(a)) }
671            Expr::Tanh(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.tanh()); } E::new(Expr::Tanh(a)) }
672            Expr::Log2(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.log2()); } E::new(Expr::Log2(a)) }
673            Expr::Log10(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.log10()); } E::new(Expr::Log10(a)) }
674            Expr::Sqrt(a) => {
675                let a = a.simplify_once();
676                if let Expr::Const(v) = a.as_ref() { return constant(v.sqrt()); }
677                if let Expr::Pow(base, exp) = a.as_ref()
678                    && is_const(exp, 2.0) {
679                        return E::new(Expr::Abs(base.clone()));
680                    }
681                E::new(Expr::Sqrt(a))
682            }
683            Expr::Abs(a) => { let a = a.simplify_once(); if let Expr::Const(v) = a.as_ref() { return constant(v.abs()); } E::new(Expr::Abs(a)) }
684            Expr::Heaviside(a) => {
685                let a = a.simplify_once();
686                if let Expr::Const(v) = a.as_ref() {
687                    return constant(if *v < 0.0 { 0.0 } else { 1.0 });
688                }
689                E::new(Expr::Heaviside(a))
690            }
691            Expr::Clamp(val, lo, hi) => {
692                let val = val.simplify_once();
693                let lo = lo.simplify_once();
694                let hi = hi.simplify_once();
695                if let (Expr::Const(v), Expr::Const(l), Expr::Const(h)) = (val.as_ref(), lo.as_ref(), hi.as_ref()) {
696                    return constant(v.clamp(*l, *h));
697                }
698                E::new(Expr::Clamp(val, lo, hi))
699            }
700            Expr::Atan2(y, x) => {
701                let y = y.simplify_once();
702                let x = x.simplify_once();
703                if let (Expr::Const(vy), Expr::Const(vx)) = (y.as_ref(), x.as_ref()) {
704                    return constant(vy.atan2(*vx));
705                }
706                E::new(Expr::Atan2(y, x))
707            }
708            Expr::Func { name, params, kind, args } => {
709                let new_args: Vec<E> = args.iter().map(|a| a.simplify_once()).collect();
710                // Constant-fold functions with a symbolic body; Extern stays opaque
711                if let Some(body) = kind.body()
712                    && new_args.iter().all(|a| matches!(a.as_ref(), Expr::Const(_))) {
713                        let expanded = crate::expand_func(params, body, &new_args);
714                        return expanded.simplify_once();
715                    }
716                E::new(Expr::Func {
717                    name: name.clone(), params: params.clone(),
718                    kind: kind.clone(), args: new_args,
719                })
720            }
721        }
722    }
723
724    /// Expand products and integer powers over sums.
725    ///
726    /// Distributes multiplication: `(a + b) * c` becomes `a*c + b*c`.
727    /// Integer powers up to 8 are expanded: `(a + b)^3` becomes the full
728    /// multinomial expansion. The result is simplified afterwards.
729    pub fn expand(&self) -> E {
730        self.expand_inner().simplify()
731    }
732
733    fn expand_inner(&self) -> E {
734        match self {
735            Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),
736            Expr::Neg(a) => E::new(Expr::Neg(a.expand_inner())),
737            Expr::Add(a, b) => E::new(Expr::Add(a.expand_inner(), b.expand_inner())),
738            Expr::Sub(a, b) => E::new(Expr::Sub(a.expand_inner(), b.expand_inner())),
739            Expr::Mul(a, b) => {
740                let a = a.expand_inner();
741                let b = b.expand_inner();
742                if let Expr::Add(b1, b2) = b.as_ref() {
743                    let left = E::new(Expr::Mul(a.clone(), b1.clone()));
744                    let right = E::new(Expr::Mul(a, b2.clone()));
745                    return E::new(Expr::Add(left.expand_inner(), right.expand_inner()));
746                }
747                if let Expr::Sub(b1, b2) = b.as_ref() {
748                    let left = E::new(Expr::Mul(a.clone(), b1.clone()));
749                    let right = E::new(Expr::Mul(a, b2.clone()));
750                    return E::new(Expr::Sub(left.expand_inner(), right.expand_inner()));
751                }
752                if let Expr::Add(a1, a2) = a.as_ref() {
753                    let left = E::new(Expr::Mul(a1.clone(), b.clone()));
754                    let right = E::new(Expr::Mul(a2.clone(), b));
755                    return E::new(Expr::Add(left.expand_inner(), right.expand_inner()));
756                }
757                if let Expr::Sub(a1, a2) = a.as_ref() {
758                    let left = E::new(Expr::Mul(a1.clone(), b.clone()));
759                    let right = E::new(Expr::Mul(a2.clone(), b));
760                    return E::new(Expr::Sub(left.expand_inner(), right.expand_inner()));
761                }
762                E::new(Expr::Mul(a, b))
763            }
764            Expr::Div(a, b) => E::new(Expr::Div(a.expand_inner(), b.expand_inner())),
765            Expr::Pow(base, exp) => {
766                let base = base.expand_inner();
767                let exp = exp.expand_inner();
768                if let Some(n) = is_const_int(&exp)
769                    && (2..=8).contains(&n) {
770                        let mut result = base.clone();
771                        for _ in 1..n {
772                            result = E::new(Expr::Mul(result, base.clone()));
773                        }
774                        return result.expand_inner();
775                    }
776                E::new(Expr::Pow(base, exp))
777            }
778            Expr::Sin(a) => E::new(Expr::Sin(a.expand_inner())),
779            Expr::Cos(a) => E::new(Expr::Cos(a.expand_inner())),
780            Expr::Tan(a) => E::new(Expr::Tan(a.expand_inner())),
781            Expr::Asin(a) => E::new(Expr::Asin(a.expand_inner())),
782            Expr::Acos(a) => E::new(Expr::Acos(a.expand_inner())),
783            Expr::Atan(a) => E::new(Expr::Atan(a.expand_inner())),
784            Expr::Atan2(y, x) => E::new(Expr::Atan2(y.expand_inner(), x.expand_inner())),
785            Expr::Sinh(a) => E::new(Expr::Sinh(a.expand_inner())),
786            Expr::Cosh(a) => E::new(Expr::Cosh(a.expand_inner())),
787            Expr::Tanh(a) => E::new(Expr::Tanh(a.expand_inner())),
788            Expr::Exp(a) => E::new(Expr::Exp(a.expand_inner())),
789            Expr::Ln(a) => E::new(Expr::Ln(a.expand_inner())),
790            Expr::Log2(a) => E::new(Expr::Log2(a.expand_inner())),
791            Expr::Log10(a) => E::new(Expr::Log10(a.expand_inner())),
792            Expr::Sqrt(a) => E::new(Expr::Sqrt(a.expand_inner())),
793            Expr::Abs(a) => E::new(Expr::Abs(a.expand_inner())),
794            Expr::Heaviside(a) => E::new(Expr::Heaviside(a.expand_inner())),
795            Expr::Clamp(val, lo, hi) => E::new(Expr::Clamp(val.expand_inner(), lo.expand_inner(), hi.expand_inner())),
796            Expr::Func { name, params, kind, args } => {
797                let expanded_args: Vec<E> = args.iter().map(|a| a.expand_inner()).collect();
798                if let Some(body) = kind.body() {
799                    crate::expand_func(params, body, &expanded_args).expand_inner()
800                } else {
801                    E::new(Expr::Func {
802                        name: name.clone(), params: params.clone(),
803                        kind: kind.clone(), args: expanded_args,
804                    })
805                }
806            }
807        }
808    }
809
810    /// Collect like terms containing `var` by structural match.
811    ///
812    /// `var` can be any [`AsVarName`] -- a `&str`, a `String`, or an
813    /// [`E`] handle wrapping a `Sym` node. Groups additive terms
814    /// that share `var` as a factor, summing their coefficients.
815    /// For example, `a*x + b*x + c` becomes `(a + b)*x + c`.
816    pub fn collect(&self, var: impl crate::AsVarName) -> E {
817        let var = var.var_expr();
818        let terms = flatten_add_simple(&E::new(self.clone()));
819        let mut with_var: Vec<E> = Vec::new();
820        let mut without_var: Vec<E> = Vec::new();
821
822        for term in &terms {
823            if let Some(coeff) = extract_factor(term, &var) {
824                with_var.push(coeff);
825            } else {
826                without_var.push(term.clone());
827            }
828        }
829
830        let mut result: Option<E> = None;
831
832        if !with_var.is_empty() {
833            let coeff_sum = sum_terms(with_var);
834            let collected = coeff_sum * var;
835            result = Some(collected);
836        }
837
838        for t in without_var {
839            result = Some(match result {
840                Some(acc) => acc + t,
841                None => t,
842            });
843        }
844
845        result.unwrap_or_else(|| constant(0.0))
846    }
847}
848
849fn flatten_add_simple(e: &E) -> Vec<E> {
850    match e.as_ref() {
851        Expr::Add(a, b) => {
852            let mut terms = flatten_add_simple(a);
853            terms.extend(flatten_add_simple(b));
854            terms
855        }
856        _ => vec![e.clone()],
857    }
858}
859
860fn extract_factor(term: &E, var: &E) -> Option<E> {
861    if term == var {
862        return Some(constant(1.0));
863    }
864    if let Expr::Mul(a, b) = term.as_ref() {
865        if b == var { return Some(a.clone()); }
866        if a == var { return Some(b.clone()); }
867    }
868    None
869}
870
871fn sum_terms(terms: Vec<E>) -> E {
872    let mut iter = terms.into_iter();
873    let first = iter.next().unwrap();
874    iter.fold(first, |acc, t| acc + t)
875}