// any_gpu/autograd.rs

// Unlicense — cochranblock.org
// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
//
// Autograd: reverse-mode automatic differentiation.
// Flat tape, enum ops, no trait objects. The tape owns all tensors.

7use crate::device::{GpuBuffer, GpuDevice};
8use anyhow::{Result, ensure};
9
/// Tensor ID — index into the tape's tensor storage.
///
/// Ids are handed out in creation order by the tape that recorded the tensor
/// and are only meaningful for that tape. `Hash` is derived alongside `Eq` so
/// ids can be used as keys in hash maps and sets (e.g. a parameter registry).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(pub u32);
13
14/// Recorded operation for backward pass.
15#[derive(Copy, Clone, Debug)]
16pub enum Op {
17    /// Leaf tensor (parameter or input). No backward.
18    Leaf,
19    Add { a: TensorId, b: TensorId },
20    Sub { a: TensorId, b: TensorId },
21    Mul { a: TensorId, b: TensorId },
22    Scale { a: TensorId, s: f32 },
23    Relu { a: TensorId },
24    Sigmoid { a: TensorId },
25    Swish { a: TensorId },
26    Tanh { a: TensorId },
27    Matmul { a: TensorId, b: TensorId, m: u32, n: u32, k: u32 },
28    MseLoss { pred: TensorId, target: TensorId },
29    Conv2d {
30        input: TensorId,
31        weight: TensorId,
32        bias: Option<TensorId>,
33        batch: u32, in_c: u32, in_h: u32, in_w: u32,
34        out_c: u32, out_h: u32, out_w: u32,
35        kh: u32, kw: u32,
36        stride_h: u32, stride_w: u32,
37        pad_h: u32, pad_w: u32,
38        dil_h: u32, dil_w: u32,
39        groups: u32,
40    },
41}
42
43/// Tape entry: one recorded operation.
44struct TapeEntry {
45    op: Op,
46    output: TensorId,
47}
48
49/// Autograd tape. Records forward operations, runs backward to compute gradients.
50pub struct Tape<'d> {
51    dev: &'d GpuDevice,
52    entries: Vec<TapeEntry>,
53    bufs: Vec<GpuBuffer>,
54    grads: Vec<Option<GpuBuffer>>,
55}
56
57impl<'d> Tape<'d> {
58    pub fn new(dev: &'d GpuDevice) -> Self {
59        Self {
60            dev,
61            entries: Vec::new(),
62            bufs: Vec::new(),
63            grads: Vec::new(),
64        }
65    }
66
67    /// Register a leaf tensor (parameter or input data). No backward through this.
68    pub fn leaf(&mut self, data: &[f32]) -> TensorId {
69        let buf = self.dev.upload(data);
70        let id = TensorId(self.bufs.len() as u32);
71        self.bufs.push(buf);
72        self.grads.push(None);
73        self.entries.push(TapeEntry { op: Op::Leaf, output: id });
74        id
75    }
76
77    /// Read tensor data back to CPU.
78    pub fn read(&self, id: TensorId) -> Result<Vec<f32>> {
79        self.dev.read(&self.bufs[id.0 as usize])
80    }
81
82    /// Read gradient data back to CPU. Returns None if no gradient computed.
83    pub fn read_grad(&self, id: TensorId) -> Result<Option<Vec<f32>>> {
84        match &self.grads[id.0 as usize] {
85            Some(buf) => Ok(Some(self.dev.read(buf)?)),
86            None => Ok(None),
87        }
88    }
89
90    fn push_result(&mut self, buf: GpuBuffer, op: Op) -> TensorId {
91        let id = TensorId(self.bufs.len() as u32);
92        self.bufs.push(buf);
93        self.grads.push(None);
94        self.entries.push(TapeEntry { op, output: id });
95        id
96    }
97
98    fn buf(&self, id: TensorId) -> &GpuBuffer {
99        &self.bufs[id.0 as usize]
100    }
101
102    // --- Forward ops (recorded on tape) ---
103
104    pub fn add(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
105        let out = self.dev.add(self.buf(a), self.buf(b))?;
106        Ok(self.push_result(out, Op::Add { a, b }))
107    }
108
109    pub fn sub(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
110        let out = self.dev.sub(self.buf(a), self.buf(b))?;
111        Ok(self.push_result(out, Op::Sub { a, b }))
112    }
113
114    pub fn mul(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
115        let out = self.dev.mul(self.buf(a), self.buf(b))?;
116        Ok(self.push_result(out, Op::Mul { a, b }))
117    }
118
119    pub fn scale(&mut self, a: TensorId, s: f32) -> Result<TensorId> {
120        let out = self.dev.scale(self.buf(a), s)?;
121        Ok(self.push_result(out, Op::Scale { a, s }))
122    }
123
124    pub fn relu(&mut self, a: TensorId) -> Result<TensorId> {
125        let out = self.dev.relu(self.buf(a))?;
126        Ok(self.push_result(out, Op::Relu { a }))
127    }
128
129    pub fn sigmoid(&mut self, a: TensorId) -> Result<TensorId> {
130        let out = self.dev.sigmoid(self.buf(a))?;
131        Ok(self.push_result(out, Op::Sigmoid { a }))
132    }
133
134    pub fn swish(&mut self, a: TensorId) -> Result<TensorId> {
135        let out = self.dev.swish(self.buf(a))?;
136        Ok(self.push_result(out, Op::Swish { a }))
137    }
138
139    pub fn tanh_act(&mut self, a: TensorId) -> Result<TensorId> {
140        let out = self.dev.tanh_act(self.buf(a))?;
141        Ok(self.push_result(out, Op::Tanh { a }))
142    }
143
144    pub fn matmul(&mut self, a: TensorId, b: TensorId, m: u32, n: u32, k: u32) -> Result<TensorId> {
145        let out = self.dev.matmul(self.buf(a), self.buf(b), m, n, k)?;
146        Ok(self.push_result(out, Op::Matmul { a, b, m, n, k }))
147    }
148
149    pub fn mse_loss(&mut self, pred: TensorId, target: TensorId) -> Result<TensorId> {
150        let out = self.dev.mse_loss(self.buf(pred), self.buf(target))?;
151        Ok(self.push_result(out, Op::MseLoss { pred, target }))
152    }
153
154    pub fn conv2d(
155        &mut self,
156        input: TensorId,
157        weight: TensorId,
158        bias: Option<TensorId>,
159        batch: u32, in_c: u32, in_h: u32, in_w: u32,
160        out_c: u32, kh: u32, kw: u32,
161        stride: (u32, u32), padding: (u32, u32),
162        dilation: (u32, u32), groups: u32,
163    ) -> Result<TensorId> {
164        let out_h = (in_h + 2 * padding.0 - dilation.0 * (kh - 1) - 1) / stride.0 + 1;
165        let out_w = (in_w + 2 * padding.1 - dilation.1 * (kw - 1) - 1) / stride.1 + 1;
166        let out = self.dev.conv2d(
167            self.buf(input), self.buf(weight),
168            bias.map(|id| &self.bufs[id.0 as usize]).as_deref(),
169            batch, in_c, in_h, in_w, out_c, kh, kw, stride, padding, dilation, groups,
170        )?;
171        Ok(self.push_result(out, Op::Conv2d {
172            input, weight, bias,
173            batch, in_c, in_h, in_w,
174            out_c, out_h, out_w,
175            kh, kw,
176            stride_h: stride.0, stride_w: stride.1,
177            pad_h: padding.0, pad_w: padding.1,
178            dil_h: dilation.0, dil_w: dilation.1,
179            groups,
180        }))
181    }
182
183    // --- Backward ---
184
185    /// Accumulate gradient into a tensor's grad buffer.
186    fn accum_grad(&mut self, id: TensorId, grad: GpuBuffer) -> Result<()> {
187        match &self.grads[id.0 as usize] {
188            Some(existing) => {
189                let summed = self.dev.add(existing, &grad)?;
190                self.grads[id.0 as usize] = Some(summed);
191            }
192            None => {
193                self.grads[id.0 as usize] = Some(grad);
194            }
195        }
196        Ok(())
197    }
198
199    /// Run backward pass from a loss tensor. Computes gradients for all tensors on the tape.
200    pub fn backward(&mut self, loss: TensorId) -> Result<()> {
201        ensure!(self.bufs[loss.0 as usize].len == 1, "backward: loss must be a scalar (1 element)");
202
203        // Seed: d(loss)/d(loss) = 1.0
204        self.grads[loss.0 as usize] = Some(self.dev.upload(&[1.0]));
205
206        // Walk tape in reverse
207        for i in (0..self.entries.len()).rev() {
208            let entry = &self.entries[i];
209            let out_id = entry.output;
210
211            // Skip if no gradient flows to this node
212            let grad_out = match &self.grads[out_id.0 as usize] {
213                Some(g) => g,
214                None => continue,
215            };
216
217            // Clone the grad_out reference data we need before mutating self
218            // We need to read grad_out's buffer info before calling accum_grad
219            match entry.op {
220                Op::Leaf => {} // no backward for leaves
221
222                Op::Add { a, b } => {
223                    // grad_a = grad_out, grad_b = grad_out
224                    let ga = self.dev.scale(grad_out, 1.0)?; // copy
225                    let gb = self.dev.scale(grad_out, 1.0)?;
226                    self.accum_grad(a, ga)?;
227                    self.accum_grad(b, gb)?;
228                }
229
230                Op::Sub { a, b } => {
231                    // grad_a = grad_out, grad_b = -grad_out
232                    let ga = self.dev.scale(grad_out, 1.0)?;
233                    let gb = self.dev.scale(grad_out, -1.0)?;
234                    self.accum_grad(a, ga)?;
235                    self.accum_grad(b, gb)?;
236                }
237
238                Op::Mul { a, b } => {
239                    // grad_a = grad_out * b, grad_b = grad_out * a
240                    let ga = self.dev.mul(grad_out, &self.bufs[b.0 as usize])?;
241                    let gb = self.dev.mul(grad_out, &self.bufs[a.0 as usize])?;
242                    self.accum_grad(a, ga)?;
243                    self.accum_grad(b, gb)?;
244                }
245
246                Op::Scale { a, s } => {
247                    // grad_a = grad_out * s
248                    let ga = self.dev.scale(grad_out, s)?;
249                    self.accum_grad(a, ga)?;
250                }
251
252                Op::Relu { a } => {
253                    // grad_a = grad_out * (input > 0)
254                    let ga = self.dev.relu_backward(grad_out, &self.bufs[a.0 as usize])?;
255                    self.accum_grad(a, ga)?;
256                }
257
258                Op::Sigmoid { a } => {
259                    // grad_a = grad_out * sig * (1 - sig) where sig = output
260                    let ga = self.dev.sigmoid_backward(grad_out, &self.bufs[out_id.0 as usize])?;
261                    self.accum_grad(a, ga)?;
262                }
263
264                Op::Swish { a } => {
265                    // grad_a = grad_out * (sig + x * sig * (1 - sig)) where sig = sigmoid(x)
266                    let ga = self.dev.swish_backward(grad_out, &self.bufs[a.0 as usize])?;
267                    self.accum_grad(a, ga)?;
268                }
269
270                Op::Tanh { a } => {
271                    // grad_a = grad_out * (1 - tanh(x)^2) where tanh(x) = output
272                    let ga = self.dev.tanh_backward(grad_out, &self.bufs[out_id.0 as usize])?;
273                    self.accum_grad(a, ga)?;
274                }
275
276                Op::Matmul { a, b, m, n, k } => {
277                    // grad_a = grad_out @ B^T  (grad_out is m x n, B is k x n, B^T is n x k -> grad_a is m x k)
278                    let bt = self.dev.transpose(&self.bufs[b.0 as usize], 1, k, n, 1)?;
279                    let ga = self.dev.matmul(grad_out, &bt, m, k, n)?;
280                    // grad_b = A^T @ grad_out  (A is m x k, A^T is k x m, grad_out is m x n -> grad_b is k x n)
281                    let at = self.dev.transpose(&self.bufs[a.0 as usize], 1, m, k, 1)?;
282                    let gb = self.dev.matmul(&at, grad_out, k, n, m)?;
283                    self.accum_grad(a, ga)?;
284                    self.accum_grad(b, gb)?;
285                }
286
287                Op::MseLoss { pred, target } => {
288                    // grad_pred = 2 * (pred - target) / n
289                    let n = self.bufs[pred.0 as usize].len as f32;
290                    let diff = self.dev.sub(&self.bufs[pred.0 as usize], &self.bufs[target.0 as usize])?;
291                    let ga = self.dev.scale(&diff, 2.0 / n)?;
292                    self.accum_grad(pred, ga)?;
293                }
294
295                Op::Conv2d { input, weight, bias, batch, in_c, in_h, in_w, out_c, out_h, out_w, kh, kw, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, groups } => {
296                    // grad_input via conv_transpose2d
297                    let ga = self.dev.conv_transpose2d(
298                        grad_out,
299                        &self.bufs[weight.0 as usize],
300                        None,
301                        batch, out_c, out_h, out_w,
302                        in_c, kh, kw,
303                        (stride_h, stride_w),
304                        (pad_h, pad_w),
305                        (0, 0),
306                        (dil_h, dil_w),
307                        groups,
308                    )?;
309                    // grad_weight
310                    let gw = self.dev.conv2d_grad_weight(
311                        &self.bufs[input.0 as usize],
312                        grad_out,
313                        batch, in_c, in_h, in_w,
314                        out_c, out_h, out_w, kh, kw,
315                        stride_h, stride_w, pad_h, pad_w,
316                        dil_h, dil_w, groups,
317                    )?;
318                    // grad_bias
319                    let gb = if bias.is_some() {
320                        Some(self.dev.conv2d_grad_bias(grad_out, batch, out_c, out_h, out_w)?)
321                    } else {
322                        None
323                    };
324                    self.accum_grad(input, ga)?;
325                    self.accum_grad(weight, gw)?;
326                    if let (Some(bias_id), Some(gb_buf)) = (bias, gb) {
327                        self.accum_grad(bias_id, gb_buf)?;
328                    }
329                }
330            }
331        }
332        Ok(())
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared test device.
    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    /// Records `mse_loss(x, zeros)` using a zero target of `len` elements.
    fn mse_vs_zero(tape: &mut Tape<'_>, x: TensorId, len: usize) -> TensorId {
        let zeros = tape.leaf(&vec![0.0f32; len]);
        tape.mse_loss(x, zeros).unwrap()
    }

    /// Reads back a gradient the test expects to have been computed.
    fn grad_of(tape: &Tape<'_>, id: TensorId) -> Vec<f32> {
        tape.read_grad(id).unwrap().unwrap()
    }

    /// Records a batch-1, single-channel 1x1 convolution with stride 1 and no padding.
    fn conv_1x1(
        tape: &mut Tape<'_>,
        x: TensorId,
        w: TensorId,
        bias: Option<TensorId>,
        height: u32,
        width: u32,
    ) -> TensorId {
        tape.conv2d(x, w, bias, 1, 1, height, width, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1).unwrap()
    }

    #[test]
    fn test_backward_add() {
        // c = a + b = [5, 7, 9]; loss = MSE(c, 0) = (25 + 49 + 81) / 3.
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0, 6.0]);
        let c = tape.add(a, b).unwrap();
        let loss = mse_vs_zero(&mut tape, c, 3);
        tape.backward(loss).unwrap();

        assert_approx(&tape.read(loss).unwrap(), &[155.0 / 3.0], 1e-3);

        // dMSE/dc = 2c/3, and addition forwards it unchanged to both operands.
        let expected = [10.0f32 / 3.0, 14.0 / 3.0, 18.0 / 3.0];
        assert_approx(&grad_of(&tape, a), &expected, 1e-3);
        assert_approx(&grad_of(&tape, b), &expected, 1e-3);
    }

    #[test]
    fn test_backward_mul() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0]);
        let c = tape.mul(a, b).unwrap(); // [8, 15]
        let loss = mse_vs_zero(&mut tape, c, 2);
        tape.backward(loss).unwrap();

        // (64 + 225) / 2
        assert_approx(&tape.read(loss).unwrap(), &[144.5], 1e-3);

        // dMSE/dc = c = [8, 15]; the product rule multiplies by the other operand.
        assert_approx(&grad_of(&tape, a), &[32.0, 75.0], 1e-3); // [8, 15] * [4, 5]
        assert_approx(&grad_of(&tape, b), &[16.0, 45.0], 1e-3); // [8, 15] * [2, 3]
    }

    #[test]
    fn test_backward_matmul() {
        // [[1, 2]] @ [[3], [4]] = [[11]]
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]); // 1x2
        let b = tape.leaf(&[3.0, 4.0]); // 2x1
        let c = tape.matmul(a, b, 1, 1, 2).unwrap();
        let loss = mse_vs_zero(&mut tape, c, 1);
        tape.backward(loss).unwrap();

        assert_approx(&tape.read(loss).unwrap(), &[121.0], 1e-3);

        // upstream = 2 * 11 = 22
        // grad_a = 22 * B^T = [66, 88]; grad_b = A^T * 22 = [22, 44]
        assert_approx(&grad_of(&tape, a), &[66.0, 88.0], 1e-3);
        assert_approx(&grad_of(&tape, b), &[22.0, 44.0], 1e-3);
    }

    #[test]
    fn test_backward_relu() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[-1.0, 2.0, -3.0, 4.0]);
        let r = tape.relu(a).unwrap(); // [0, 2, 0, 4]
        let loss = mse_vs_zero(&mut tape, r, 4);
        tape.backward(loss).unwrap();

        // (0 + 4 + 0 + 16) / 4
        assert_approx(&tape.read(loss).unwrap(), &[5.0], 1e-3);

        // upstream = r / 2 = [0, 1, 0, 2], masked where a <= 0
        assert_approx(&grad_of(&tape, a), &[0.0, 1.0, 0.0, 2.0], 1e-3);
    }

    #[test]
    fn test_backward_scale() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let scaled = tape.scale(a, 3.0).unwrap();
        let loss = mse_vs_zero(&mut tape, scaled, 3);
        tape.backward(loss).unwrap();
        assert_approx(&grad_of(&tape, a), &[6.0, 12.0, 18.0], 1e-3);
    }

    #[test]
    fn test_backward_sub() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[5.0, 10.0]);
        let b = tape.leaf(&[1.0, 2.0]);
        let c = tape.sub(a, b).unwrap(); // [4, 8]
        let loss = mse_vs_zero(&mut tape, c, 2);
        tape.backward(loss).unwrap();

        // upstream = [4, 8]; the subtrahend receives its negation.
        assert_approx(&grad_of(&tape, a), &[4.0, 8.0], 1e-3);
        assert_approx(&grad_of(&tape, b), &[-4.0, -8.0], 1e-3);
    }

    #[test]
    fn test_backward_sigmoid() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let sig = tape.sigmoid(a).unwrap();
        let loss = mse_vs_zero(&mut tape, sig, 3);
        tape.backward(loss).unwrap();

        // sigmoid values at 0, 1, -1; grad = (2 * s / 3) * s * (1 - s)
        let s = [0.5f32, 0.7311, 0.2689];
        let expected: Vec<f32> = s.iter().map(|&v| 2.0 * v / 3.0 * v * (1.0 - v)).collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-3);
    }

    #[test]
    fn test_backward_tanh() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let th = tape.tanh_act(a).unwrap();
        let loss = mse_vs_zero(&mut tape, th, 3);
        tape.backward(loss).unwrap();

        // tanh values at 0, 1, -1; grad = (2 * t / 3) * (1 - t^2)
        let t = [0.0f32, 0.7616, -0.7616];
        let expected: Vec<f32> = t.iter().map(|&v| 2.0 * v / 3.0 * (1.0 - v * v)).collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-2);
    }

    #[test]
    fn test_backward_swish() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let sw = tape.swish(a).unwrap();
        let loss = mse_vs_zero(&mut tape, sw, 3);
        tape.backward(loss).unwrap();

        // swish(x) = x * sig(x); d(swish)/dx = sig(x) + x * sig(x) * (1 - sig(x))
        let x = [0.0f32, 1.0, -1.0];
        let expected: Vec<f32> = x
            .iter()
            .map(|&v| {
                let sig = 1.0 / (1.0 + (-v).exp());
                let swish = v / (1.0 + (-v).exp());
                let d_swish = sig + v * sig * (1.0 - sig);
                2.0 * swish / 3.0 * d_swish
            })
            .collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-2);
    }

    #[test]
    fn test_read_grad_before_backward() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.read_grad(a).unwrap().is_none());
    }

    #[test]
    fn test_backward_non_scalar_loss() {
        // A 2-element tensor is not a valid loss.
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.backward(a).is_err());
    }

    #[test]
    fn test_backward_diamond_graph() {
        // a fans out to 2a and 3a, which are summed. The gradients from both
        // paths must accumulate on a.
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0]);
        let twice = tape.scale(a, 2.0).unwrap();
        let thrice = tape.scale(a, 3.0).unwrap();
        let total = tape.add(twice, thrice).unwrap(); // 5
        let loss = mse_vs_zero(&mut tape, total, 1);
        tape.backward(loss).unwrap();

        // upstream = 2 * 5 = 10; via 2a -> 20, via 3a -> 30; sum = 50
        assert_approx(&grad_of(&tape, a), &[50.0], 1e-3);
    }

    #[test]
    fn test_tape_leaf_data_roundtrip() {
        let mut tape = Tape::new(dev());
        let data = vec![1.5f32, -2.7, 0.0, 99.9];
        let a = tape.leaf(&data);
        assert_eq!(tape.read(a).unwrap(), data);
    }

    #[test]
    fn test_tape_conv2d_forward() {
        // An identity 1x1 kernel with zero bias reproduces the 3x3 input.
        let mut tape = Tape::new(dev());
        let pixels: Vec<f32> = (1..=9).map(|v| v as f32).collect();
        let x = tape.leaf(&pixels);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let y = conv_1x1(&mut tape, x, w, Some(b), 3, 3);
        assert_approx(&tape.read(y).unwrap(), &pixels, 1e-5);
    }

    #[test]
    fn test_tape_conv2d_backward_weight_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|v| v as f32 * 0.1).collect();
        let w0 = 0.5f32;

        // Forward-only loss for a given weight value.
        let loss_at = |w_val: f32| -> f32 {
            let mut tape = Tape::new(dev());
            let x = tape.leaf(&input_data);
            let w = tape.leaf(&[w_val]);
            let y = conv_1x1(&mut tape, x, w, None, 3, 3);
            let loss = mse_vs_zero(&mut tape, y, 9);
            tape.read(loss).unwrap()[0]
        };

        let mut tape = Tape::new(dev());
        let x = tape.leaf(&input_data);
        let w = tape.leaf(&[w0]);
        let y = conv_1x1(&mut tape, x, w, None, 3, 3);
        let loss = mse_vs_zero(&mut tape, y, 9);
        tape.backward(loss).unwrap();
        let gw = grad_of(&tape, w);

        // Central finite difference.
        let numeric = (loss_at(w0 + eps) - loss_at(w0 - eps)) / (2.0 * eps);
        assert!((gw[0] - numeric).abs() < 1e-2,
            "weight grad: analytical={}, numeric={}", gw[0], numeric);
    }

    #[test]
    fn test_tape_conv2d_backward_input_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|v| v as f32 * 0.1).collect();
        let w0 = 0.5f32;

        // Forward-only loss with input element `idx` replaced by `value`.
        let loss_with = |idx: usize, value: f32| -> f32 {
            let mut perturbed = input_data.clone();
            perturbed[idx] = value;
            let mut tape = Tape::new(dev());
            let x = tape.leaf(&perturbed);
            let w = tape.leaf(&[w0]);
            let y = conv_1x1(&mut tape, x, w, None, 3, 3);
            let loss = mse_vs_zero(&mut tape, y, 9);
            tape.read(loss).unwrap()[0]
        };

        let mut tape = Tape::new(dev());
        let x = tape.leaf(&input_data);
        let w = tape.leaf(&[w0]);
        let y = conv_1x1(&mut tape, x, w, None, 3, 3);
        let loss = mse_vs_zero(&mut tape, y, 9);
        tape.backward(loss).unwrap();
        let gi = grad_of(&tape, x);

        for (i, &x0) in input_data.iter().enumerate() {
            let numeric = (loss_with(i, x0 + eps) - loss_with(i, x0 - eps)) / (2.0 * eps);
            assert!((gi[i] - numeric).abs() < 1e-2,
                "input grad[{i}]: analytical={}, numeric={}", gi[i], numeric);
        }
    }

    #[test]
    fn test_tape_conv2d_backward_bias_grad() {
        // 2x2 input, identity 1x1 kernel, zero bias, so the output equals the input.
        // upstream = 2 * out / 4 = [0.5, 1.0, 1.5, 2.0]; grad_bias is its sum = 5.0.
        let mut tape = Tape::new(dev());
        let x = tape.leaf(&[1.0f32, 2.0, 3.0, 4.0]);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let y = conv_1x1(&mut tape, x, w, Some(b), 2, 2);
        let loss = mse_vs_zero(&mut tape, y, 4);
        tape.backward(loss).unwrap();

        assert_approx(&grad_of(&tape, b), &[5.0], 1e-3);
    }
}