Skip to main content

proof_engine/symbolic/
differentiate.rs

1//! Symbolic differentiation — d/dx of any expression tree.
2
3use super::expr::Expr;
4
5/// Symbolically differentiate an expression with respect to a variable.
6pub fn diff(expr: &Expr, var: &str) -> Expr {
7    match expr {
8        Expr::Var(name) => {
9            if name == var { Expr::one() } else { Expr::zero() }
10        }
11        Expr::Const(_) => Expr::zero(),
12
13        // d/dx(-a) = -da
14        Expr::Neg(a) => Expr::Neg(Box::new(diff(a, var))),
15
16        // d/dx(a + b) = da + db
17        Expr::Add(a, b) => Expr::Add(Box::new(diff(a, var)), Box::new(diff(b, var))),
18
19        // d/dx(a - b) = da - db
20        Expr::Sub(a, b) => Expr::Sub(Box::new(diff(a, var)), Box::new(diff(b, var))),
21
22        // Product rule: d/dx(a * b) = a*db + da*b
23        Expr::Mul(a, b) => {
24            let left = Expr::Mul(a.clone(), Box::new(diff(b, var)));
25            let right = Expr::Mul(Box::new(diff(a, var)), b.clone());
26            Expr::Add(Box::new(left), Box::new(right))
27        }
28
29        // Quotient rule: d/dx(a/b) = (da*b - a*db) / b²
30        Expr::Div(a, b) => {
31            let num_left = Expr::Mul(Box::new(diff(a, var)), b.clone());
32            let num_right = Expr::Mul(a.clone(), Box::new(diff(b, var)));
33            let numerator = Expr::Sub(Box::new(num_left), Box::new(num_right));
34            let denominator = Expr::Pow(b.clone(), Box::new(Expr::c(2.0)));
35            Expr::Div(Box::new(numerator), Box::new(denominator))
36        }
37
38        // Power rule with chain rule
39        Expr::Pow(base, exp) => {
40            let base_has_var = base.contains_var(var);
41            let exp_has_var = exp.contains_var(var);
42
43            if !base_has_var && !exp_has_var {
44                Expr::zero()
45            } else if base_has_var && !exp_has_var {
46                // d/dx(f^n) = n * f^(n-1) * f'
47                let n_minus_1 = Expr::Sub(exp.clone(), Box::new(Expr::one()));
48                let term = Expr::Mul(
49                    exp.clone(),
50                    Box::new(Expr::Pow(base.clone(), Box::new(n_minus_1))),
51                );
52                Expr::Mul(Box::new(term), Box::new(diff(base, var)))
53            } else if !base_has_var && exp_has_var {
54                // d/dx(a^g) = a^g * ln(a) * g'
55                let term = Expr::Mul(
56                    Box::new(expr.clone()),
57                    Box::new(Expr::Ln(base.clone())),
58                );
59                Expr::Mul(Box::new(term), Box::new(diff(exp, var)))
60            } else {
61                // General: d/dx(f^g) = f^g * (g'*ln(f) + g*f'/f)
62                let ln_f = Expr::Ln(base.clone());
63                let term1 = Expr::Mul(Box::new(diff(exp, var)), Box::new(ln_f));
64                let term2 = Expr::Mul(
65                    exp.clone(),
66                    Box::new(Expr::Div(Box::new(diff(base, var)), base.clone())),
67                );
68                Expr::Mul(Box::new(expr.clone()), Box::new(Expr::Add(Box::new(term1), Box::new(term2))))
69            }
70        }
71
72        // Chain rule for trig/transcendental
73        Expr::Sin(a) => {
74            // d/dx sin(f) = cos(f) * f'
75            Expr::Mul(Box::new(Expr::Cos(a.clone())), Box::new(diff(a, var)))
76        }
77        Expr::Cos(a) => {
78            // d/dx cos(f) = -sin(f) * f'
79            Expr::Mul(
80                Box::new(Expr::Neg(Box::new(Expr::Sin(a.clone())))),
81                Box::new(diff(a, var)),
82            )
83        }
84        Expr::Tan(a) => {
85            // d/dx tan(f) = (1 + tan²(f)) * f' = sec²(f) * f'
86            let sec_sq = Expr::Add(
87                Box::new(Expr::one()),
88                Box::new(Expr::Pow(Box::new(Expr::Tan(a.clone())), Box::new(Expr::c(2.0)))),
89            );
90            Expr::Mul(Box::new(sec_sq), Box::new(diff(a, var)))
91        }
92        Expr::Ln(a) => {
93            // d/dx ln(f) = f'/f
94            Expr::Div(Box::new(diff(a, var)), a.clone())
95        }
96        Expr::Exp(a) => {
97            // d/dx exp(f) = exp(f) * f'
98            Expr::Mul(Box::new(Expr::Exp(a.clone())), Box::new(diff(a, var)))
99        }
100        Expr::Sqrt(a) => {
101            // d/dx √f = f' / (2√f)
102            Expr::Div(
103                Box::new(diff(a, var)),
104                Box::new(Expr::Mul(Box::new(Expr::c(2.0)), Box::new(Expr::Sqrt(a.clone())))),
105            )
106        }
107
108        // Default: return symbolic derivative node
109        _ => Expr::Derivative { body: Box::new(expr.clone()), var: var.to_string() },
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Absolute tolerance used when comparing an evaluated derivative with its
    /// closed-form value. The trees built by `diff` evaluate to the closed form
    /// up to floating-point rounding, so a tight bound is sufficient.
    const TOL: f64 = 1e-9;

    /// Evaluate `expr` with the single binding `x = x`.
    fn eval_at(expr: &Expr, x: f64) -> f64 {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), x);
        expr.eval(&vars)
    }

    /// Differentiate `expr` with respect to `x`, evaluate the result at `x`,
    /// and assert that it matches `expected` within `TOL`.
    fn assert_derivative(expr: &Expr, x: f64, expected: f64) {
        let got = eval_at(&diff(expr, "x"), x);
        assert!(
            (got - expected).abs() < TOL,
            "derivative at x = {x}: expected {expected}, got {got}"
        );
    }

    #[test]
    fn diff_constant_is_zero() {
        let d = diff(&Expr::c(5.0), "x");
        assert_eq!(eval_at(&d, 1.0), 0.0);
    }

    #[test]
    fn diff_x_is_one() {
        let d = diff(&Expr::var("x"), "x");
        assert_eq!(eval_at(&d, 42.0), 1.0);
    }

    #[test]
    fn diff_other_variable_is_zero() {
        // Neither x nor x² depends on y, so both derivatives vanish.
        let d = diff(&Expr::var("x"), "y");
        assert_eq!(eval_at(&d, 3.0), 0.0);
        let d = diff(&Expr::var("x").pow(Expr::c(2.0)), "y");
        assert_eq!(eval_at(&d, 3.0), 0.0);
    }

    #[test]
    fn diff_x_squared() {
        // d/dx(x²) = 2x
        assert_derivative(&Expr::var("x").pow(Expr::c(2.0)), 3.0, 6.0);
    }

    #[test]
    fn diff_sin_x() {
        // d/dx(sin(x)) = cos(x); cos(0) = 1
        assert_derivative(&Expr::var("x").sin(), 0.0, 1.0);
    }

    #[test]
    fn diff_exp_x() {
        // d/dx(e^x) = e^x
        assert_derivative(&Expr::var("x").exp(), 1.0, std::f64::consts::E);
    }

    #[test]
    fn diff_product_rule() {
        // d/dx(x * sin(x)) = sin(x) + x*cos(x)
        let expr = Expr::var("x").mul(Expr::var("x").sin());
        // The binding needs an explicit type: calling `.sin()` on an
        // unannotated `{float}` binding is compile error E0689.
        let x: f64 = 1.0;
        assert_derivative(&expr, x, x.sin() + x * x.cos());
    }

    #[test]
    fn diff_quotient_rule() {
        // d/dx(x / (x + 1)) = 1 / (x + 1)²; at x = 1 this is 1/4
        let x_plus_one = Expr::Add(Box::new(Expr::var("x")), Box::new(Expr::one()));
        let expr = Expr::Div(Box::new(Expr::var("x")), Box::new(x_plus_one));
        assert_derivative(&expr, 1.0, 0.25);
    }

    #[test]
    fn diff_chain_rule() {
        // d/dx(sin(x²)) = cos(x²) * 2x
        let expr = Expr::var("x").pow(Expr::c(2.0)).sin();
        let x: f64 = 1.0;
        assert_derivative(&expr, x, (x * x).cos() * 2.0 * x);
    }

    #[test]
    fn diff_ln_x() {
        // d/dx(ln(x)) = 1/x
        assert_derivative(&Expr::var("x").ln(), 2.0, 0.5);
    }
}