Skip to main content

oxiphysics_core/
symbolic_math.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Symbolic mathematics: expression trees, evaluation, differentiation, and simplification.
5//!
6//! # Quick start
7//!
8//! ```no_run
9//! use oxiphysics_core::symbolic_math::{cst, var, diff, simplify, eval};
10//! use std::collections::HashMap;
11//!
12//! // f(x) = x^2 + 1
13//! let x = var("x");
14//! let expr = x.pow(cst(2.0)).add_expr(cst(1.0));
15//! let df = simplify(&diff(&expr, "x")); // 2*x
16//! let mut vars = HashMap::new();
17//! vars.insert("x".to_string(), 3.0);
18//! assert!((eval(&df, &vars).unwrap() - 6.0).abs() < 1e-12);
19//! ```
20
21use std::collections::HashMap;
22
23// ---------------------------------------------------------------------------
24// Expression tree
25// ---------------------------------------------------------------------------
26
/// Node of a symbolic expression tree.
///
/// `Const` and `Var` are the leaves; every other variant owns its operand(s)
/// through a `Box`. There is no dedicated subtraction or division node:
/// `sub_expr` builds `a + (-b)`, and reciprocals are written as `b^(-1)`.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Literal `f64` value.
    Const(f64),
    /// Variable identified by name; its value is supplied at evaluation time.
    Var(String),
    /// Sum `lhs + rhs`.
    Add(Box<Self>, Box<Self>),
    /// Product `lhs * rhs`.
    Mul(Box<Self>, Box<Self>),
    /// Power `base ^ exponent`.
    Pow(Box<Self>, Box<Self>),
    /// Arithmetic negation `-operand`.
    Neg(Box<Self>),
    /// `sin(operand)`, operand in radians.
    Sin(Box<Self>),
    /// `cos(operand)`, operand in radians.
    Cos(Box<Self>),
    /// `e^operand`.
    Exp(Box<Self>),
    /// Natural logarithm `ln(operand)`.
    Ln(Box<Self>),
}
53
54// ---------------------------------------------------------------------------
55// Convenience constructors (free functions)
56// ---------------------------------------------------------------------------
57
58/// Create a named-variable expression.
59pub fn var(name: &str) -> Expr {
60    Expr::Var(name.to_string())
61}
62
63/// Create a numeric constant expression.
64pub fn cst(v: f64) -> Expr {
65    Expr::Const(v)
66}
67
68// ---------------------------------------------------------------------------
69// Expr: helper methods
70// ---------------------------------------------------------------------------
71
72impl Expr {
73    /// Add another expression to `self`: `self + rhs`.
74    pub fn add_expr(self, rhs: Expr) -> Expr {
75        Expr::Add(Box::new(self), Box::new(rhs))
76    }
77
78    /// Subtract another expression from `self`: `self + (-rhs)`.
79    pub fn sub_expr(self, rhs: Expr) -> Expr {
80        Expr::Add(Box::new(self), Box::new(Expr::Neg(Box::new(rhs))))
81    }
82
83    /// Multiply `self` by another expression: `self * rhs`.
84    pub fn mul_expr(self, rhs: Expr) -> Expr {
85        Expr::Mul(Box::new(self), Box::new(rhs))
86    }
87
88    /// Raise `self` to the power of `exp`.
89    pub fn pow(self, exp: Expr) -> Expr {
90        Expr::Pow(Box::new(self), Box::new(exp))
91    }
92
93    /// Negate `self`: `-self`.
94    #[allow(clippy::should_implement_trait)]
95    pub fn neg(self) -> Expr {
96        Expr::Neg(Box::new(self))
97    }
98
99    /// Apply sine to `self`.
100    pub fn sin(self) -> Expr {
101        Expr::Sin(Box::new(self))
102    }
103
104    /// Apply cosine to `self`.
105    pub fn cos(self) -> Expr {
106        Expr::Cos(Box::new(self))
107    }
108
109    /// Apply the natural exponential to `self`.
110    pub fn exp(self) -> Expr {
111        Expr::Exp(Box::new(self))
112    }
113
114    /// Apply the natural logarithm to `self`.
115    pub fn ln(self) -> Expr {
116        Expr::Ln(Box::new(self))
117    }
118}
119
120// ---------------------------------------------------------------------------
121// Evaluation
122// ---------------------------------------------------------------------------
123
124/// Evaluate `expr` by substituting variable values from `vars`.
125///
126/// Returns `Err` if a variable is encountered that is not present in `vars`,
127/// or if the expression is undefined (e.g., `ln(0)`).
128pub fn eval(expr: &Expr, vars: &HashMap<String, f64>) -> Result<f64, String> {
129    match expr {
130        Expr::Const(c) => Ok(*c),
131        Expr::Var(name) => vars
132            .get(name)
133            .copied()
134            .ok_or_else(|| format!("undefined variable: {name}")),
135        Expr::Add(a, b) => Ok(eval(a, vars)? + eval(b, vars)?),
136        Expr::Mul(a, b) => Ok(eval(a, vars)? * eval(b, vars)?),
137        Expr::Pow(base, exp) => Ok(eval(base, vars)?.powf(eval(exp, vars)?)),
138        Expr::Neg(inner) => Ok(-eval(inner, vars)?),
139        Expr::Sin(inner) => Ok(eval(inner, vars)?.sin()),
140        Expr::Cos(inner) => Ok(eval(inner, vars)?.cos()),
141        Expr::Exp(inner) => Ok(eval(inner, vars)?.exp()),
142        Expr::Ln(inner) => {
143            let v = eval(inner, vars)?;
144            if v <= 0.0 {
145                Err(format!("ln of non-positive value: {v}"))
146            } else {
147                Ok(v.ln())
148            }
149        }
150    }
151}
152
153// ---------------------------------------------------------------------------
154// Symbolic differentiation
155// ---------------------------------------------------------------------------
156
157/// Symbolically differentiate `expr` with respect to the variable named `var`.
158///
159/// Returns a new expression for the derivative (not yet simplified).
160pub fn diff(expr: &Expr, var: &str) -> Expr {
161    match expr {
162        Expr::Const(_) => cst(0.0),
163        Expr::Var(name) => {
164            if name == var {
165                cst(1.0)
166            } else {
167                cst(0.0)
168            }
169        }
170        // (f + g)' = f' + g'
171        Expr::Add(f, g) => Expr::Add(Box::new(diff(f, var)), Box::new(diff(g, var))),
172        // (f * g)' = f' * g + f * g'
173        Expr::Mul(f, g) => Expr::Add(
174            Box::new(Expr::Mul(Box::new(diff(f, var)), g.clone())),
175            Box::new(Expr::Mul(f.clone(), Box::new(diff(g, var)))),
176        ),
177        // (f^n)' = n * f^(n-1) * f'   — general power rule via chain rule
178        Expr::Pow(base, exp) => {
179            // d/dx [f^g] = f^g * (g' * ln(f) + g * f'/f)
180            // For the common case where exp is a Const, simplify:
181            if let Expr::Const(n) = exp.as_ref() {
182                let n = *n;
183                // n * base^(n-1) * base'
184                Expr::Mul(
185                    Box::new(Expr::Mul(
186                        Box::new(cst(n)),
187                        Box::new(Expr::Pow(base.clone(), Box::new(cst(n - 1.0)))),
188                    )),
189                    Box::new(diff(base, var)),
190                )
191            } else {
192                // General: (f^g)' = f^g * (g' * ln(f) + g * f'/f)
193                let f = base.as_ref();
194                let g = exp.as_ref();
195                let fg = Expr::Pow(base.clone(), exp.clone());
196                let term1 = Expr::Mul(Box::new(diff(g, var)), Box::new(Expr::Ln(base.clone())));
197                let term2 = Expr::Mul(
198                    g.clone().into(),
199                    Box::new(Expr::Mul(
200                        Box::new(diff(f, var)),
201                        Box::new(Expr::Pow(base.clone(), Box::new(cst(-1.0)))),
202                    )),
203                );
204                Expr::Mul(
205                    Box::new(fg),
206                    Box::new(Expr::Add(Box::new(term1), Box::new(term2))),
207                )
208            }
209        }
210        // (-f)' = -(f')
211        Expr::Neg(f) => Expr::Neg(Box::new(diff(f, var))),
212        // sin(f)' = cos(f) * f'
213        Expr::Sin(f) => Expr::Mul(Box::new(Expr::Cos(f.clone())), Box::new(diff(f, var))),
214        // cos(f)' = -sin(f) * f'
215        Expr::Cos(f) => Expr::Neg(Box::new(Expr::Mul(
216            Box::new(Expr::Sin(f.clone())),
217            Box::new(diff(f, var)),
218        ))),
219        // exp(f)' = exp(f) * f'
220        Expr::Exp(f) => Expr::Mul(Box::new(Expr::Exp(f.clone())), Box::new(diff(f, var))),
221        // ln(f)' = f' / f  = f' * f^(-1)
222        Expr::Ln(f) => Expr::Mul(
223            Box::new(diff(f, var)),
224            Box::new(Expr::Pow(f.clone(), Box::new(cst(-1.0)))),
225        ),
226    }
227}
228
229// ---------------------------------------------------------------------------
230// Simplification
231// ---------------------------------------------------------------------------
232
233/// Apply basic algebraic simplifications to `expr`.
234///
235/// Rules applied (recursively, bottom-up):
236/// - `0 + x = x`, `x + 0 = x`
237/// - `0 * x = 0`, `x * 0 = 0`, `1 * x = x`, `x * 1 = x`
238/// - `x ^ 0 = 1`, `x ^ 1 = x`
239/// - `--x = x` (double negation)
240/// - `-0 = 0`
241/// - Fold constant sub-expressions to a single `Const`.
242pub fn simplify(expr: &Expr) -> Expr {
243    match expr {
244        // Atoms are already in simplest form.
245        Expr::Const(_) | Expr::Var(_) => expr.clone(),
246
247        Expr::Add(a, b) => {
248            let a = simplify(a);
249            let b = simplify(b);
250            // Constant folding
251            if let (Expr::Const(x), Expr::Const(y)) = (&a, &b) {
252                return cst(x + y);
253            }
254            // 0 + b = b
255            if matches!(a, Expr::Const(x) if x == 0.0) {
256                return b;
257            }
258            // a + 0 = a
259            if matches!(b, Expr::Const(x) if x == 0.0) {
260                return a;
261            }
262            Expr::Add(Box::new(a), Box::new(b))
263        }
264
265        Expr::Mul(a, b) => {
266            let a = simplify(a);
267            let b = simplify(b);
268            // Constant folding
269            if let (Expr::Const(x), Expr::Const(y)) = (&a, &b) {
270                return cst(x * y);
271            }
272            // 0 * b = 0  or  a * 0 = 0
273            if matches!(a, Expr::Const(x) if x == 0.0) {
274                return cst(0.0);
275            }
276            if matches!(b, Expr::Const(x) if x == 0.0) {
277                return cst(0.0);
278            }
279            // 1 * b = b
280            if matches!(a, Expr::Const(x) if x == 1.0) {
281                return b;
282            }
283            // a * 1 = a
284            if matches!(b, Expr::Const(x) if x == 1.0) {
285                return a;
286            }
287            Expr::Mul(Box::new(a), Box::new(b))
288        }
289
290        Expr::Pow(base, exp) => {
291            let base = simplify(base);
292            let exp = simplify(exp);
293            // Constant folding
294            if let (Expr::Const(b), Expr::Const(e)) = (&base, &exp) {
295                return cst(b.powf(*e));
296            }
297            // x^0 = 1
298            if matches!(exp, Expr::Const(e) if e == 0.0) {
299                return cst(1.0);
300            }
301            // x^1 = x
302            if matches!(exp, Expr::Const(e) if e == 1.0) {
303                return base;
304            }
305            Expr::Pow(Box::new(base), Box::new(exp))
306        }
307
308        Expr::Neg(inner) => {
309            let inner = simplify(inner);
310            // Constant folding
311            if let Expr::Const(c) = &inner {
312                return cst(-c);
313            }
314            // Double negation: -(-x) = x
315            if let Expr::Neg(x) = inner {
316                return *x;
317            }
318            Expr::Neg(Box::new(inner))
319        }
320
321        Expr::Sin(inner) => Expr::Sin(Box::new(simplify(inner))),
322        Expr::Cos(inner) => Expr::Cos(Box::new(simplify(inner))),
323        Expr::Exp(inner) => {
324            let inner = simplify(inner);
325            if let Expr::Const(c) = &inner {
326                return cst(c.exp());
327            }
328            Expr::Exp(Box::new(inner))
329        }
330        Expr::Ln(inner) => {
331            let inner = simplify(inner);
332            if let Expr::Const(c) = &inner
333                && *c > 0.0
334            {
335                return cst(c.ln());
336            }
337            Expr::Ln(Box::new(inner))
338        }
339    }
340}
341
342// ---------------------------------------------------------------------------
343// Pretty-print
344// ---------------------------------------------------------------------------
345
346/// Convert `expr` to a human-readable infix string.
347pub fn to_string(expr: &Expr) -> String {
348    expr_to_str(expr)
349}
350
351fn expr_to_str(expr: &Expr) -> String {
352    match expr {
353        Expr::Const(c) => {
354            // Print integers without a decimal point for readability
355            if c.fract() == 0.0 && c.abs() < 1e15 {
356                format!("{}", *c as i64)
357            } else {
358                format!("{c}")
359            }
360        }
361        Expr::Var(name) => name.clone(),
362        Expr::Add(a, b) => {
363            let bs = expr_to_str(b);
364            // If b is a negation, display as subtraction
365            if let Expr::Neg(inner) = b.as_ref() {
366                format!("({} - {})", expr_to_str(a), expr_to_str(inner))
367            } else if let Some(bs_stripped) = bs.strip_prefix('-') {
368                format!("({} - {})", expr_to_str(a), bs_stripped)
369            } else {
370                format!("({} + {})", expr_to_str(a), bs)
371            }
372        }
373        Expr::Mul(a, b) => format!("({} * {})", expr_to_str(a), expr_to_str(b)),
374        Expr::Pow(base, exp) => format!("({}^{})", expr_to_str(base), expr_to_str(exp)),
375        Expr::Neg(inner) => format!("(-{})", expr_to_str(inner)),
376        Expr::Sin(inner) => format!("sin({})", expr_to_str(inner)),
377        Expr::Cos(inner) => format!("cos({})", expr_to_str(inner)),
378        Expr::Exp(inner) => format!("exp({})", expr_to_str(inner)),
379        Expr::Ln(inner) => format!("ln({})", expr_to_str(inner)),
380    }
381}
382
383// ---------------------------------------------------------------------------
384// Tests
385// ---------------------------------------------------------------------------
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a variable-binding map from `(name, value)` pairs.
    fn vars(bindings: &[(&str, f64)]) -> HashMap<String, f64> {
        bindings.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    // --- eval ---

    #[test]
    fn eval_const() {
        assert_eq!(eval(&cst(3.125), &HashMap::new()).unwrap(), 3.125);
    }

    #[test]
    fn eval_var_found() {
        let e = var("x");
        assert_eq!(eval(&e, &vars(&[("x", 5.0)])).unwrap(), 5.0);
    }

    #[test]
    fn eval_var_missing_returns_err() {
        let e = var("y");
        assert_eq!(
            eval(&e, &HashMap::new()),
            Err("undefined variable: y".to_string())
        );
    }

    #[test]
    fn eval_reports_leftmost_missing_var() {
        // Left operands are evaluated first, so `a` is reported, not `b`.
        let e = var("a").add_expr(var("b"));
        assert_eq!(
            eval(&e, &HashMap::new()),
            Err("undefined variable: a".to_string())
        );
    }

    #[test]
    fn eval_add() {
        let e = var("x").add_expr(cst(1.0));
        assert_eq!(eval(&e, &vars(&[("x", 4.0)])).unwrap(), 5.0);
    }

    #[test]
    fn eval_mul() {
        let e = var("x").mul_expr(cst(3.0));
        assert_eq!(eval(&e, &vars(&[("x", 2.0)])).unwrap(), 6.0);
    }

    #[test]
    fn eval_pow() {
        let e = var("x").pow(cst(3.0));
        assert!((eval(&e, &vars(&[("x", 2.0)])).unwrap() - 8.0).abs() < 1e-12);
    }

    #[test]
    fn eval_neg() {
        let e = var("x").neg();
        assert_eq!(eval(&e, &vars(&[("x", 7.0)])).unwrap(), -7.0);
    }

    #[test]
    fn eval_sin() {
        let e = var("x").sin();
        let got = eval(&e, &vars(&[("x", 0.0)])).unwrap();
        assert!(got.abs() < 1e-12);
    }

    #[test]
    fn eval_cos() {
        let e = var("x").cos();
        let got = eval(&e, &vars(&[("x", 0.0)])).unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn eval_exp() {
        let e = var("x").exp();
        let got = eval(&e, &vars(&[("x", 0.0)])).unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn eval_ln() {
        let e = var("x").ln();
        let got = eval(&e, &vars(&[("x", 1.0)])).unwrap();
        assert!(got.abs() < 1e-12);
    }

    #[test]
    fn eval_ln_nonpositive_returns_err() {
        let e = var("x").ln();
        assert!(eval(&e, &vars(&[("x", 0.0)])).is_err());
        assert!(eval(&e, &vars(&[("x", -1.0)])).is_err());
        assert_eq!(
            eval(&e, &vars(&[("x", -1.0)])),
            Err("ln of non-positive value: -1".to_string())
        );
    }

    #[test]
    fn eval_complex_poly() {
        // 3x^2 + 2x + 1 at x=2 → 3*4 + 4 + 1 = 17
        let x = var("x");
        let e = cst(3.0)
            .mul_expr(x.clone().pow(cst(2.0)))
            .add_expr(cst(2.0).mul_expr(x.clone()))
            .add_expr(cst(1.0));
        let got = eval(&e, &vars(&[("x", 2.0)])).unwrap();
        assert!((got - 17.0).abs() < 1e-12);
    }

    // --- diff ---

    #[test]
    fn diff_const_is_zero() {
        let e = diff(&cst(42.0), "x");
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn diff_var_self_is_one() {
        let e = diff(&var("x"), "x");
        assert_eq!(simplify(&e), cst(1.0));
    }

    #[test]
    fn diff_var_other_is_zero() {
        let e = diff(&var("y"), "x");
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn diff_linear() {
        // d/dx (3x) = 3
        let e = cst(3.0).mul_expr(var("x"));
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &HashMap::new()).unwrap();
        assert!((got - 3.0).abs() < 1e-12);
    }

    #[test]
    fn diff_sub() {
        // d/dx (x - 3) = 1
        let e = var("x").sub_expr(cst(3.0));
        assert_eq!(simplify(&diff(&e, "x")), cst(1.0));
    }

    #[test]
    fn diff_quadratic() {
        // d/dx (x^2) = 2x  → at x=3: 6
        let e = var("x").pow(cst(2.0));
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 3.0)])).unwrap();
        assert!((got - 6.0).abs() < 1e-12);
    }

    #[test]
    fn diff_cubic() {
        // d/dx (x^3) = 3x^2 → at x=2: 12
        let e = var("x").pow(cst(3.0));
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 2.0)])).unwrap();
        assert!((got - 12.0).abs() < 1e-12);
    }

    #[test]
    fn diff_sin() {
        // d/dx sin(x) = cos(x) → at x=0: 1
        let e = var("x").sin();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn diff_cos() {
        // d/dx cos(x) = -sin(x) → at x=0: 0
        let e = var("x").cos();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!(got.abs() < 1e-12);
    }

    #[test]
    fn diff_exp() {
        // d/dx exp(x) = exp(x) → at x=0: 1
        let e = var("x").exp();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn diff_ln() {
        // d/dx ln(x) = 1/x → at x=2: 0.5
        let e = var("x").ln();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 2.0)])).unwrap();
        assert!((got - 0.5).abs() < 1e-12);
    }

    #[test]
    fn diff_product_rule() {
        // d/dx (x * sin(x)) = sin(x) + x*cos(x) → at x=0: 0
        let e = var("x").mul_expr(var("x").sin());
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!(got.abs() < 1e-12);
    }

    #[test]
    fn diff_chain_sin_of_poly() {
        // d/dx sin(x^2) = cos(x^2) * 2x → at x=0: 0
        let e = var("x").pow(cst(2.0)).sin();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!(got.abs() < 1e-12);
    }

    #[test]
    fn diff_neg() {
        // d/dx (-x) = -1
        let e = var("x").neg();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &HashMap::new()).unwrap();
        assert!((got + 1.0).abs() < 1e-12);
    }

    // --- simplify ---

    #[test]
    fn simplify_zero_plus_x() {
        let e = cst(0.0).add_expr(var("x"));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_plus_zero() {
        let e = var("x").add_expr(cst(0.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_zero_times_x() {
        let e = cst(0.0).mul_expr(var("x"));
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn simplify_x_times_zero() {
        let e = var("x").mul_expr(cst(0.0));
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn simplify_one_times_x() {
        let e = cst(1.0).mul_expr(var("x"));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_times_one() {
        let e = var("x").mul_expr(cst(1.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_pow_zero() {
        let e = var("x").pow(cst(0.0));
        assert_eq!(simplify(&e), cst(1.0));
    }

    #[test]
    fn simplify_x_pow_one() {
        let e = var("x").pow(cst(1.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_double_neg() {
        let e = var("x").neg().neg();
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_neg_const() {
        let e = cst(2.0).neg();
        assert_eq!(simplify(&e), cst(-2.0));
    }

    #[test]
    fn simplify_nested_identities() {
        // (x * 1) + 0 → x
        let e = var("x").mul_expr(cst(1.0)).add_expr(cst(0.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_const_fold_add() {
        let e = cst(3.0).add_expr(cst(4.0));
        assert_eq!(simplify(&e), cst(7.0));
    }

    #[test]
    fn simplify_const_fold_mul() {
        let e = cst(3.0).mul_expr(cst(4.0));
        assert_eq!(simplify(&e), cst(12.0));
    }

    #[test]
    fn simplify_const_fold_pow() {
        let e = cst(2.0).pow(cst(10.0));
        assert_eq!(simplify(&e), cst(1024.0));
    }

    // --- to_string ---

    #[test]
    fn to_string_const() {
        assert_eq!(to_string(&cst(3.0)), "3");
    }

    #[test]
    fn to_string_fractional_const() {
        assert_eq!(to_string(&cst(2.5)), "2.5");
    }

    #[test]
    fn to_string_var() {
        assert_eq!(to_string(&var("theta")), "theta");
    }

    #[test]
    fn to_string_add() {
        let e = var("x").add_expr(cst(1.0));
        assert_eq!(to_string(&e), "(x + 1)");
    }

    #[test]
    fn to_string_sub() {
        let e = var("x").sub_expr(var("y"));
        assert_eq!(to_string(&e), "(x - y)");
    }

    #[test]
    fn to_string_add_negative_const() {
        let e = var("x").add_expr(cst(-2.0));
        assert_eq!(to_string(&e), "(x - 2)");
    }

    #[test]
    fn to_string_mul() {
        let e = var("a").mul_expr(var("b"));
        assert_eq!(to_string(&e), "(a * b)");
    }

    #[test]
    fn to_string_pow() {
        let e = var("x").pow(cst(2.0));
        assert_eq!(to_string(&e), "(x^2)");
    }

    #[test]
    fn to_string_sin() {
        assert_eq!(to_string(&var("x").sin()), "sin(x)");
    }

    #[test]
    fn to_string_cos() {
        assert_eq!(to_string(&var("x").cos()), "cos(x)");
    }

    #[test]
    fn to_string_exp() {
        assert_eq!(to_string(&var("x").exp()), "exp(x)");
    }

    #[test]
    fn to_string_ln() {
        assert_eq!(to_string(&var("x").ln()), "ln(x)");
    }

    #[test]
    fn to_string_neg() {
        assert_eq!(to_string(&var("x").neg()), "(-x)");
    }

    #[test]
    fn to_string_nested() {
        assert_eq!(to_string(&var("x").pow(cst(2.0)).sin()), "sin((x^2))");
    }

    // --- combined ---

    #[test]
    fn diff_poly_numeric_check() {
        // d/dx (x^4 - 3x^2 + 2) at x=1: 4*1 - 6*1 = -2
        let x = var("x");
        let poly = x
            .clone()
            .pow(cst(4.0))
            .sub_expr(cst(3.0).mul_expr(x.clone().pow(cst(2.0))))
            .add_expr(cst(2.0));
        let d = simplify(&diff(&poly, "x"));
        let got = eval(&d, &vars(&[("x", 1.0)])).unwrap();
        assert!((got - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn diff_exp_of_linear() {
        // d/dx exp(3x) = 3*exp(3x) → at x=0: 3
        let e = cst(3.0).mul_expr(var("x")).exp();
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 0.0)])).unwrap();
        assert!((got - 3.0).abs() < 1e-12);
    }
}