Skip to main content

tang_expr/
diff.rs

1//! Symbolic differentiation.
2
3use std::collections::HashMap;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9    /// Differentiate `expr` with respect to `Var(var)`.
10    ///
11    /// Returns a new ExprId in the same graph. Uses memoization to avoid
12    /// recomputing derivatives of shared subexpressions.
13    pub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId {
14        let mut memo = HashMap::new();
15        self.diff_inner(expr, var, &mut memo)
16    }
17
18    fn diff_inner(
19        &mut self,
20        expr: ExprId,
21        var: u16,
22        memo: &mut HashMap<(ExprId, u16), ExprId>,
23    ) -> ExprId {
24        if let Some(&cached) = memo.get(&(expr, var)) {
25            return cached;
26        }
27
28        let result = match self.node(expr) {
29            Node::Var(n) => {
30                if n == var {
31                    ExprId::ONE
32                } else {
33                    ExprId::ZERO
34                }
35            }
36            Node::Lit(_) => ExprId::ZERO,
37
38            Node::Add(a, b) => {
39                // d(a + b) = da + db
40                let da = self.diff_inner(a, var, memo);
41                let db = self.diff_inner(b, var, memo);
42                self.add(da, db)
43            }
44
45            Node::Mul(a, b) => {
46                // d(a * b) = da*b + a*db (product rule)
47                let da = self.diff_inner(a, var, memo);
48                let db = self.diff_inner(b, var, memo);
49                let t1 = self.mul(da, b);
50                let t2 = self.mul(a, db);
51                self.add(t1, t2)
52            }
53
54            Node::Neg(a) => {
55                // d(-a) = -da
56                let da = self.diff_inner(a, var, memo);
57                self.neg(da)
58            }
59
60            Node::Recip(a) => {
61                // d(1/a) = -da / a^2
62                let da = self.diff_inner(a, var, memo);
63                let a_sq = self.mul(a, a);
64                let r = self.recip(a_sq);
65                let t = self.mul(da, r);
66                self.neg(t)
67            }
68
69            Node::Sqrt(a) => {
70                // d(sqrt(a)) = da / (2 * sqrt(a))
71                let da = self.diff_inner(a, var, memo);
72                let sq = self.sqrt(a);
73                let two_sq = self.mul(ExprId::TWO, sq);
74                let r = self.recip(two_sq);
75                self.mul(da, r)
76            }
77
78            Node::Sin(a) => {
79                // d(sin(a)) = cos(a) * da
80                // cos(a) = sin(a + PI/2)
81                let da = self.diff_inner(a, var, memo);
82                let half_pi = self.lit(std::f64::consts::FRAC_PI_2);
83                let shifted = self.add(a, half_pi);
84                let cos_a = self.sin(shifted);
85                self.mul(cos_a, da)
86            }
87
88            Node::Atan2(y, x) => {
89                // d(atan2(y, x)) = (x*dy - y*dx) / (x^2 + y^2)
90                let dy = self.diff_inner(y, var, memo);
91                let dx = self.diff_inner(x, var, memo);
92                let x_dy = self.mul(x, dy);
93                let y_dx = self.mul(y, dx);
94                let neg_y_dx = self.neg(y_dx);
95                let numer = self.add(x_dy, neg_y_dx);
96                let xx = self.mul(x, x);
97                let yy = self.mul(y, y);
98                let denom = self.add(xx, yy);
99                let r = self.recip(denom);
100                self.mul(numer, r)
101            }
102
103            Node::Exp2(a) => {
104                // d(2^a) = ln(2) * 2^a * da
105                let da = self.diff_inner(a, var, memo);
106                let ln2 = self.lit(std::f64::consts::LN_2);
107                let exp2_a = self.exp2(a);
108                let t = self.mul(ln2, exp2_a);
109                self.mul(t, da)
110            }
111
112            Node::Log2(a) => {
113                // d(log2(a)) = da / (ln(2) * a)
114                let da = self.diff_inner(a, var, memo);
115                let ln2 = self.lit(std::f64::consts::LN_2);
116                let ln2_a = self.mul(ln2, a);
117                let r = self.recip(ln2_a);
118                self.mul(da, r)
119            }
120
121            Node::Select(c, a, b) => {
122                // Straight-through: condition doesn't contribute gradient
123                let da = self.diff_inner(a, var, memo);
124                let db = self.diff_inner(b, var, memo);
125                self.select(c, da, db)
126            }
127        };
128
129        memo.insert((expr, var), result);
130        result
131    }
132}
133
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;
    use crate::node::ExprId;

    /// Absolute tolerance for comparing evaluated derivatives.
    const TOL: f64 = 1e-10;

    /// Assert that `actual` is within `TOL` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn diff_constant() {
        let mut g = ExprGraph::new();
        let c = g.lit(5.0);
        let dc = g.diff(c, 0);
        assert_eq!(dc, ExprId::ZERO);
    }

    #[test]
    fn diff_var_self() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 0);
        assert_eq!(dx, ExprId::ONE);
    }

    #[test]
    fn diff_var_other() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 1);
        assert_eq!(dx, ExprId::ZERO);
    }

    #[test]
    fn diff_add() {
        // d/dx (x + 3) = 1
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.0);
        let sum = g.add(x, c);
        let d = g.diff(sum, 0);
        // d evaluates to 1 + 0 for any x; the input value doesn't matter.
        assert_close(g.eval(d, &[99.0]), 1.0);
    }

    #[test]
    fn diff_mul_product_rule() {
        // d/dx (x * x) = 2x
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d = g.diff(xx, 0);
        // At x=3, d/dx x^2 = 6
        assert_close(g.eval(d, &[3.0]), 6.0);
    }

    #[test]
    fn diff_neg() {
        // d/dx -(x * x) = -2x
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let n = g.neg(xx);
        let d = g.diff(n, 0);
        // At x=3: -6
        assert_close(g.eval(d, &[3.0]), -6.0);
    }

    #[test]
    fn diff_sin() {
        // d/dx sin(x) = cos(x)
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let ds = g.diff(s, 0);
        // At x=0, cos(0) = 1
        assert_close(g.eval(ds, &[0.0]), 1.0);
    }

    #[test]
    fn diff_sin_away_from_zero() {
        // The derivative uses cos(a) = sin(a + PI/2); check it where cos != 1.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let ds = g.diff(s, 0);
        assert_close(g.eval(ds, &[1.0]), 1.0_f64.cos());
    }

    #[test]
    fn diff_chain_rule() {
        // d/dx sin(x^2) = 2x * cos(x^2)
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let s = g.sin(xx);
        let ds = g.diff(s, 0);
        // At x=1: 2*1*cos(1)
        assert_close(g.eval(ds, &[1.0]), 2.0 * 1.0_f64.cos());
    }

    #[test]
    fn diff_sqrt() {
        // d/dx sqrt(x) = 1 / (2*sqrt(x))
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sq = g.sqrt(x);
        let d = g.diff(sq, 0);
        // At x=4: 1/(2*2) = 0.25
        assert_close(g.eval(d, &[4.0]), 0.25);
    }

    #[test]
    fn diff_recip() {
        // d/dx (1/x) = -1/x^2
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let r = g.recip(x);
        let d = g.diff(r, 0);
        // At x=2: -1/4 = -0.25
        assert_close(g.eval(d, &[2.0]), -0.25);
    }

    #[test]
    fn diff_recip_chain() {
        // d/dx 1/(x*x) = -2x / x^4 = -2/x^3
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let r = g.recip(xx);
        let d = g.diff(r, 0);
        // At x=2: -2/8 = -0.25
        assert_close(g.eval(d, &[2.0]), -0.25);
    }

    #[test]
    fn diff_exp2() {
        // d/dx 2^x = ln(2) * 2^x
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let e = g.exp2(x);
        let d = g.diff(e, 0);
        // At x=3: 8 * ln(2)
        assert_close(g.eval(d, &[3.0]), 8.0 * std::f64::consts::LN_2);
    }

    #[test]
    fn diff_memoization() {
        // Shared subexpression: d/dx (x*x + x*x)
        // Should reuse derivative of x*x
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let sum = g.add(xx, xx);
        let d = g.diff(sum, 0);
        // d/dx (2x^2) = 4x, at x=3 → 12
        assert_close(g.eval(d, &[3.0]), 12.0);
    }

    #[test]
    fn diff_select() {
        // d/dx select(x, x*x, x+1) at x=2 (cond>0 → d/dx x² = 2x = 4)
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let xp1 = g.add(x, ExprId::ONE);
        let s = g.select(x, xx, xp1);
        let ds = g.diff(s, 0);
        assert_close(g.eval(ds, &[2.0]), 4.0);

        // At x=-1 (cond<=0 → d/dx (x+1) = 1)
        assert_close(g.eval(ds, &[-1.0]), 1.0);
    }

    #[test]
    fn diff_dot_product() {
        // f = x0*x3 + x1*x4 + x2*x5  (dot product)
        let mut g = ExprGraph::new();
        let x0 = g.var(0);
        let x1 = g.var(1);
        let x2 = g.var(2);
        let x3 = g.var(3);
        let x4 = g.var(4);
        let x5 = g.var(5);

        let t0 = g.mul(x0, x3);
        let t1 = g.mul(x1, x4);
        let t2 = g.mul(x2, x5);
        let s01 = g.add(t0, t1);
        let dot = g.add(s01, t2);

        let inputs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // df/dx0 = x3 → 4
        let d0 = g.diff(dot, 0);
        assert_close(g.eval(d0, &inputs), 4.0);

        // Same graph, different variable: df/dx4 = x1 → 2
        let d4 = g.diff(dot, 4);
        assert_close(g.eval(d4, &inputs), 2.0);

        // A variable that does not appear in f has zero derivative.
        let d7 = g.diff(dot, 7);
        assert_close(g.eval(d7, &inputs), 0.0);
    }
}