1use super::expr::Expr;
4
5pub fn diff(expr: &Expr, var: &str) -> Expr {
7 match expr {
8 Expr::Var(name) => {
9 if name == var { Expr::one() } else { Expr::zero() }
10 }
11 Expr::Const(_) => Expr::zero(),
12
13 Expr::Neg(a) => Expr::Neg(Box::new(diff(a, var))),
15
16 Expr::Add(a, b) => Expr::Add(Box::new(diff(a, var)), Box::new(diff(b, var))),
18
19 Expr::Sub(a, b) => Expr::Sub(Box::new(diff(a, var)), Box::new(diff(b, var))),
21
22 Expr::Mul(a, b) => {
24 let left = Expr::Mul(a.clone(), Box::new(diff(b, var)));
25 let right = Expr::Mul(Box::new(diff(a, var)), b.clone());
26 Expr::Add(Box::new(left), Box::new(right))
27 }
28
29 Expr::Div(a, b) => {
31 let num_left = Expr::Mul(Box::new(diff(a, var)), b.clone());
32 let num_right = Expr::Mul(a.clone(), Box::new(diff(b, var)));
33 let numerator = Expr::Sub(Box::new(num_left), Box::new(num_right));
34 let denominator = Expr::Pow(b.clone(), Box::new(Expr::c(2.0)));
35 Expr::Div(Box::new(numerator), Box::new(denominator))
36 }
37
38 Expr::Pow(base, exp) => {
40 let base_has_var = base.contains_var(var);
41 let exp_has_var = exp.contains_var(var);
42
43 if !base_has_var && !exp_has_var {
44 Expr::zero()
45 } else if base_has_var && !exp_has_var {
46 let n_minus_1 = Expr::Sub(exp.clone(), Box::new(Expr::one()));
48 let term = Expr::Mul(
49 exp.clone(),
50 Box::new(Expr::Pow(base.clone(), Box::new(n_minus_1))),
51 );
52 Expr::Mul(Box::new(term), Box::new(diff(base, var)))
53 } else if !base_has_var && exp_has_var {
54 let term = Expr::Mul(
56 Box::new(expr.clone()),
57 Box::new(Expr::Ln(base.clone())),
58 );
59 Expr::Mul(Box::new(term), Box::new(diff(exp, var)))
60 } else {
61 let ln_f = Expr::Ln(base.clone());
63 let term1 = Expr::Mul(Box::new(diff(exp, var)), Box::new(ln_f));
64 let term2 = Expr::Mul(
65 exp.clone(),
66 Box::new(Expr::Div(Box::new(diff(base, var)), base.clone())),
67 );
68 Expr::Mul(Box::new(expr.clone()), Box::new(Expr::Add(Box::new(term1), Box::new(term2))))
69 }
70 }
71
72 Expr::Sin(a) => {
74 Expr::Mul(Box::new(Expr::Cos(a.clone())), Box::new(diff(a, var)))
76 }
77 Expr::Cos(a) => {
78 Expr::Mul(
80 Box::new(Expr::Neg(Box::new(Expr::Sin(a.clone())))),
81 Box::new(diff(a, var)),
82 )
83 }
84 Expr::Tan(a) => {
85 let sec_sq = Expr::Add(
87 Box::new(Expr::one()),
88 Box::new(Expr::Pow(Box::new(Expr::Tan(a.clone())), Box::new(Expr::c(2.0)))),
89 );
90 Expr::Mul(Box::new(sec_sq), Box::new(diff(a, var)))
91 }
92 Expr::Ln(a) => {
93 Expr::Div(Box::new(diff(a, var)), a.clone())
95 }
96 Expr::Exp(a) => {
97 Expr::Mul(Box::new(Expr::Exp(a.clone())), Box::new(diff(a, var)))
99 }
100 Expr::Sqrt(a) => {
101 Expr::Div(
103 Box::new(diff(a, var)),
104 Box::new(Expr::Mul(Box::new(Expr::c(2.0)), Box::new(Expr::Sqrt(a.clone())))),
105 )
106 }
107
108 _ => Expr::Derivative { body: Box::new(expr.clone()), var: var.to_string() },
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Absolute tolerance when comparing a numerically evaluated symbolic
    /// derivative against a hand-computed closed form (both are plain f64 math).
    const TOL: f64 = 1e-9;

    /// Evaluates `expr` with the single binding `x = x`.
    fn eval_at(expr: &Expr, x: f64) -> f64 {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), x);
        expr.eval(&vars)
    }

    /// Differentiates `expr` w.r.t. `x`, evaluates the result at `at`, and
    /// asserts it is within `TOL` of `expected`.
    fn assert_derivative(expr: &Expr, at: f64, expected: f64) {
        let result = eval_at(&diff(expr, "x"), at);
        assert!(
            (result - expected).abs() < TOL,
            "d/dx at x={at}: got {result}, expected {expected}"
        );
    }

    /// Shorthand for boxing a sub-expression in a variant constructor.
    fn b(e: Expr) -> Box<Expr> {
        Box::new(e)
    }

    #[test]
    fn diff_constant_is_zero() {
        let d = diff(&Expr::c(5.0), "x");
        assert_eq!(eval_at(&d, 1.0), 0.0);
    }

    #[test]
    fn diff_x_is_one() {
        let d = diff(&Expr::var("x"), "x");
        assert_eq!(eval_at(&d, 42.0), 1.0);
    }

    #[test]
    fn diff_other_variable_is_zero() {
        let d = diff(&Expr::var("y"), "x");
        assert_eq!(eval_at(&d, 3.0), 0.0);
    }

    #[test]
    fn diff_x_squared() {
        // d/dx(x²) = 2x
        assert_derivative(&Expr::var("x").pow(Expr::c(2.0)), 3.0, 6.0);
    }

    #[test]
    fn diff_neg_and_sub() {
        // d/dx(x² − (−x)) = 2x + 1
        let expr = Expr::Sub(
            b(Expr::var("x").pow(Expr::c(2.0))),
            b(Expr::Neg(b(Expr::var("x")))),
        );
        assert_derivative(&expr, 3.0, 7.0);
    }

    #[test]
    fn diff_sin_x() {
        assert_derivative(&Expr::var("x").sin(), 0.0, 1.0);
    }

    #[test]
    fn diff_cos_x() {
        // d/dx cos x = −sin x
        let expr = Expr::Cos(b(Expr::var("x")));
        assert_derivative(&expr, std::f64::consts::FRAC_PI_2, -1.0);
    }

    #[test]
    fn diff_tan_x() {
        // d/dx tan x = 1 + tan² x
        let x: f64 = 0.5;
        let expr = Expr::Tan(b(Expr::var("x")));
        assert_derivative(&expr, x, 1.0 + x.tan() * x.tan());
    }

    #[test]
    fn diff_exp_x() {
        assert_derivative(&Expr::var("x").exp(), 1.0, std::f64::consts::E);
    }

    #[test]
    fn diff_ln_x() {
        assert_derivative(&Expr::var("x").ln(), 2.0, 0.5);
    }

    #[test]
    fn diff_sqrt_x() {
        // d/dx √x = 1 / (2√x)
        let expr = Expr::Sqrt(b(Expr::var("x")));
        assert_derivative(&expr, 4.0, 0.25);
    }

    #[test]
    fn diff_product_rule() {
        // The explicit `f64` is required: calling `.sin()` on an untyped float
        // literal binding is a compile error (E0689).
        let x: f64 = 1.0;
        let expr = Expr::var("x").mul(Expr::var("x").sin());
        assert_derivative(&expr, x, x.sin() + x * x.cos());
    }

    #[test]
    fn diff_quotient_rule() {
        // d/dx (x / (x + 1)) = 1 / (x + 1)²
        let expr = Expr::Div(
            b(Expr::var("x")),
            b(Expr::Add(b(Expr::var("x")), b(Expr::c(1.0)))),
        );
        assert_derivative(&expr, 1.0, 0.25);
    }

    #[test]
    fn diff_chain_rule() {
        // d/dx sin(x²) = 2x·cos(x²)
        let x: f64 = 1.5;
        let expr = Expr::var("x").pow(Expr::c(2.0)).sin();
        assert_derivative(&expr, x, 2.0 * x * (x * x).cos());
    }

    #[test]
    fn diff_constant_base_power() {
        // d/dx 2^x = 2^x·ln 2
        let expr = Expr::c(2.0).pow(Expr::var("x"));
        assert_derivative(&expr, 1.0, 2.0 * std::f64::consts::LN_2);
    }

    #[test]
    fn diff_variable_base_and_exponent() {
        // d/dx x^x = x^x·(ln x + 1)
        let expr = Expr::var("x").pow(Expr::var("x"));
        assert_derivative(&expr, 2.0, 4.0 * (std::f64::consts::LN_2 + 1.0));
    }
}