1use super::{AsVarName, Expr, E, constant, sin, cos, cosh, sinh, tanh, exp, ln, sqrt, abs, pow};
2
3impl Expr {
4 pub fn diff(&self, var: impl AsVarName) -> E {
9 let var = var.var_name();
10 let zero = || constant(0.0);
11 let one = || constant(1.0);
12 let two = || constant(2.0);
13
14 match self {
15 Expr::Sym(name) => {
16 if name == var { one() } else { zero() }
17 }
18 Expr::Const(_) | Expr::NamedConst { .. } => zero(),
19 Expr::Neg(a) => {
20 -a.diff(var)
21 }
22 Expr::Add(a, b) => {
23 a.diff(var) + b.diff(var)
24 }
25 Expr::Sub(a, b) => {
26 a.diff(var) - b.diff(var)
27 }
28 Expr::Mul(a, b) => {
29 let da = a.diff(var);
31 let db = b.diff(var);
32 da * b.clone() + a.clone() * db
33 }
34 Expr::Div(a, b) => {
35 let da = a.diff(var);
37 let db = b.diff(var);
38 (da * b.clone() - a.clone() * db) / pow(b.clone(), two())
39 }
40 Expr::Pow(a, b) => {
41 let da = a.diff(var);
42 let db = b.diff(var);
43 if matches!(b.as_ref(), Expr::Const(_)) {
44 b.clone() * pow(a.clone(), b.clone() - constant(1.0)) * da
46 } else if matches!(a.as_ref(), Expr::Const(_)) {
47 pow(a.clone(), b.clone()) * ln(a.clone()) * db
49 } else {
50 let base = pow(a.clone(), b.clone());
52 base * (db * ln(a.clone()) + b.clone() * da / a.clone())
53 }
54 }
55 Expr::Sin(a) => {
56 cos(a.clone()) * a.diff(var)
57 }
58 Expr::Cos(a) => {
59 -(sin(a.clone()) * a.diff(var))
60 }
61 Expr::Tan(a) => {
62 a.diff(var) / pow(cos(a.clone()), two())
64 }
65 Expr::Asin(a) => {
66 a.diff(var) / sqrt(one() - pow(a.clone(), two()))
68 }
69 Expr::Acos(a) => {
70 -(a.diff(var) / sqrt(one() - pow(a.clone(), two())))
72 }
73 Expr::Atan(a) => {
74 a.diff(var) / (one() + pow(a.clone(), two()))
76 }
77 Expr::Atan2(y, x) => {
78 let dy = y.diff(var);
80 let dx = x.diff(var);
81 (x.clone() * dy - y.clone() * dx) / (pow(x.clone(), two()) + pow(y.clone(), two()))
82 }
83 Expr::Sinh(a) => {
84 cosh(a.clone()) * a.diff(var)
85 }
86 Expr::Cosh(a) => {
87 sinh(a.clone()) * a.diff(var)
88 }
89 Expr::Tanh(a) => {
90 a.diff(var) * (one() - pow(tanh(a.clone()), two()))
92 }
93 Expr::Exp(a) => {
94 exp(a.clone()) * a.diff(var)
95 }
96 Expr::Ln(a) => {
97 a.diff(var) / a.clone()
99 }
100 Expr::Log2(a) => {
101 a.diff(var) / (a.clone() * ln(constant(2.0)))
103 }
104 Expr::Log10(a) => {
105 a.diff(var) / (a.clone() * ln(constant(10.0)))
107 }
108 Expr::Sqrt(a) => {
109 a.diff(var) / (two() * sqrt(a.clone()))
111 }
112 Expr::Abs(a) => {
113 a.clone() * a.diff(var) / abs(a.clone())
115 }
116 Expr::Heaviside(_) => {
117 zero()
119 }
120 Expr::Clamp(val, _, _) => {
121 val.diff(var)
123 }
124 Expr::Func { params, kind, args, .. } => {
125 if let Some(body) = kind.auto_diff_body() {
126 super::expand_func(params, body, args).diff(var)
128 } else {
129 let derivs = kind.derivs().unwrap();
131 let mut result = zero();
132 for (d, a) in derivs.iter().zip(args.iter()) {
133 let da = a.diff(var);
134 if !matches!(da.as_ref(), Expr::Const(v) if *v == 0.0) {
135 result = result + super::expand_func(params, d, args) * da;
136 }
137 }
138 result
139 }
140 }
141 }
142 }
143}