Skip to main content

arael_sym/
diff.rs

1use super::{AsVarName, Expr, E, constant, sin, cos, cosh, sinh, tanh, exp, ln, sqrt, abs, pow};
2
3impl Expr {
4    /// Symbolically differentiate this expression with respect to a variable.
5    ///
6    /// Applies the chain rule, product rule, and quotient rule automatically.
7    /// The result is simplified.
8    pub fn diff(&self, var: impl AsVarName) -> E {
9        let var = var.var_name();
10        let zero = || constant(0.0);
11        let one = || constant(1.0);
12        let two = || constant(2.0);
13
14        match self {
15            Expr::Sym(name) => {
16                if name == var { one() } else { zero() }
17            }
18            Expr::Const(_) | Expr::NamedConst { .. } => zero(),
19            Expr::Neg(a) => {
20                -a.diff(var)
21            }
22            Expr::Add(a, b) => {
23                a.diff(var) + b.diff(var)
24            }
25            Expr::Sub(a, b) => {
26                a.diff(var) - b.diff(var)
27            }
28            Expr::Mul(a, b) => {
29                // product rule: a'*b + a*b'
30                let da = a.diff(var);
31                let db = b.diff(var);
32                da * b.clone() + a.clone() * db
33            }
34            Expr::Div(a, b) => {
35                // quotient rule: (a'*b - a*b') / b^2
36                let da = a.diff(var);
37                let db = b.diff(var);
38                (da * b.clone() - a.clone() * db) / pow(b.clone(), two())
39            }
40            Expr::Pow(a, b) => {
41                let da = a.diff(var);
42                let db = b.diff(var);
43                if matches!(b.as_ref(), Expr::Const(_)) {
44                    // Power rule: n * a^(n-1) * a'
45                    b.clone() * pow(a.clone(), b.clone() - constant(1.0)) * da
46                } else if matches!(a.as_ref(), Expr::Const(_)) {
47                    // Constant base: a^b * ln(a) * b'
48                    pow(a.clone(), b.clone()) * ln(a.clone()) * db
49                } else {
50                    // General: a^b * (b' * ln(a) + b * a' / a)
51                    let base = pow(a.clone(), b.clone());
52                    base * (db * ln(a.clone()) + b.clone() * da / a.clone())
53                }
54            }
55            Expr::Sin(a) => {
56                cos(a.clone()) * a.diff(var)
57            }
58            Expr::Cos(a) => {
59                -(sin(a.clone()) * a.diff(var))
60            }
61            Expr::Tan(a) => {
62                // a' / cos(a)^2
63                a.diff(var) / pow(cos(a.clone()), two())
64            }
65            Expr::Asin(a) => {
66                // a' / sqrt(1 - a^2)
67                a.diff(var) / sqrt(one() - pow(a.clone(), two()))
68            }
69            Expr::Acos(a) => {
70                // -a' / sqrt(1 - a^2)
71                -(a.diff(var) / sqrt(one() - pow(a.clone(), two())))
72            }
73            Expr::Atan(a) => {
74                // a' / (1 + a^2)
75                a.diff(var) / (one() + pow(a.clone(), two()))
76            }
77            Expr::Atan2(y, x) => {
78                // (x*dy - y*dx) / (x^2 + y^2)
79                let dy = y.diff(var);
80                let dx = x.diff(var);
81                (x.clone() * dy - y.clone() * dx) / (pow(x.clone(), two()) + pow(y.clone(), two()))
82            }
83            Expr::Sinh(a) => {
84                cosh(a.clone()) * a.diff(var)
85            }
86            Expr::Cosh(a) => {
87                sinh(a.clone()) * a.diff(var)
88            }
89            Expr::Tanh(a) => {
90                // a' * (1 - tanh(a)^2)
91                a.diff(var) * (one() - pow(tanh(a.clone()), two()))
92            }
93            Expr::Exp(a) => {
94                exp(a.clone()) * a.diff(var)
95            }
96            Expr::Ln(a) => {
97                // a' / a
98                a.diff(var) / a.clone()
99            }
100            Expr::Log2(a) => {
101                // a' / (a * ln(2))
102                a.diff(var) / (a.clone() * ln(constant(2.0)))
103            }
104            Expr::Log10(a) => {
105                // a' / (a * ln(10))
106                a.diff(var) / (a.clone() * ln(constant(10.0)))
107            }
108            Expr::Sqrt(a) => {
109                // a' / (2 * sqrt(a))
110                a.diff(var) / (two() * sqrt(a.clone()))
111            }
112            Expr::Abs(a) => {
113                // a * a' / |a|  (i.e., sign(a) * a')
114                a.clone() * a.diff(var) / abs(a.clone())
115            }
116            Expr::Heaviside(_) => {
117                // Derivative is 0 everywhere (pragmatic, not Dirac delta)
118                zero()
119            }
120            Expr::Clamp(val, _, _) => {
121                // Pass-through: derivative ignores the clamping
122                val.diff(var)
123            }
124            Expr::Func { params, kind, args, .. } => {
125                if let Some(body) = kind.auto_diff_body() {
126                    // Auto-diff: expand body, differentiate
127                    super::expand_func(params, body, args).diff(var)
128                } else {
129                    // Explicit derivs: chain rule df/dvar = sum_i(df/dp_i * dp_i/dvar)
130                    let derivs = kind.derivs().unwrap();
131                    let mut result = zero();
132                    for (d, a) in derivs.iter().zip(args.iter()) {
133                        let da = a.diff(var);
134                        if !matches!(da.as_ref(), Expr::Const(v) if *v == 0.0) {
135                            result = result + super::expand_func(params, d, args) * da;
136                        }
137                    }
138                    result
139                }
140            }
141        }
142    }
143}