1use crate::device::{GpuBuffer, GpuDevice};
8use anyhow::{Result, ensure};
9
/// Handle to a tensor recorded on a [`Tape`].
///
/// The wrapped value is the tensor's position in the tape's buffer list, so a
/// `TensorId` is only meaningful for the tape that produced it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TensorId(pub u32);
13
14#[derive(Copy, Clone, Debug)]
16pub enum Op {
17 Leaf,
19 Add { a: TensorId, b: TensorId },
20 Sub { a: TensorId, b: TensorId },
21 Mul { a: TensorId, b: TensorId },
22 Scale { a: TensorId, s: f32 },
23 Relu { a: TensorId },
24 Sigmoid { a: TensorId },
25 Swish { a: TensorId },
26 Tanh { a: TensorId },
27 Matmul { a: TensorId, b: TensorId, m: u32, n: u32, k: u32 },
28 MseLoss { pred: TensorId, target: TensorId },
29 Conv2d {
30 input: TensorId,
31 weight: TensorId,
32 bias: Option<TensorId>,
33 batch: u32, in_c: u32, in_h: u32, in_w: u32,
34 out_c: u32, out_h: u32, out_w: u32,
35 kh: u32, kw: u32,
36 stride_h: u32, stride_w: u32,
37 pad_h: u32, pad_w: u32,
38 dil_h: u32, dil_w: u32,
39 groups: u32,
40 },
41}
42
43struct TapeEntry {
45 op: Op,
46 output: TensorId,
47}
48
49pub struct Tape<'d> {
51 dev: &'d GpuDevice,
52 entries: Vec<TapeEntry>,
53 bufs: Vec<GpuBuffer>,
54 grads: Vec<Option<GpuBuffer>>,
55}
56
57impl<'d> Tape<'d> {
58 pub fn new(dev: &'d GpuDevice) -> Self {
59 Self {
60 dev,
61 entries: Vec::new(),
62 bufs: Vec::new(),
63 grads: Vec::new(),
64 }
65 }
66
67 pub fn leaf(&mut self, data: &[f32]) -> TensorId {
69 let buf = self.dev.upload(data);
70 let id = TensorId(self.bufs.len() as u32);
71 self.bufs.push(buf);
72 self.grads.push(None);
73 self.entries.push(TapeEntry { op: Op::Leaf, output: id });
74 id
75 }
76
77 pub fn read(&self, id: TensorId) -> Result<Vec<f32>> {
79 self.dev.read(&self.bufs[id.0 as usize])
80 }
81
82 pub fn read_grad(&self, id: TensorId) -> Result<Option<Vec<f32>>> {
84 match &self.grads[id.0 as usize] {
85 Some(buf) => Ok(Some(self.dev.read(buf)?)),
86 None => Ok(None),
87 }
88 }
89
90 fn push_result(&mut self, buf: GpuBuffer, op: Op) -> TensorId {
91 let id = TensorId(self.bufs.len() as u32);
92 self.bufs.push(buf);
93 self.grads.push(None);
94 self.entries.push(TapeEntry { op, output: id });
95 id
96 }
97
98 fn buf(&self, id: TensorId) -> &GpuBuffer {
99 &self.bufs[id.0 as usize]
100 }
101
102 pub fn add(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
105 let out = self.dev.add(self.buf(a), self.buf(b))?;
106 Ok(self.push_result(out, Op::Add { a, b }))
107 }
108
109 pub fn sub(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
110 let out = self.dev.sub(self.buf(a), self.buf(b))?;
111 Ok(self.push_result(out, Op::Sub { a, b }))
112 }
113
114 pub fn mul(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
115 let out = self.dev.mul(self.buf(a), self.buf(b))?;
116 Ok(self.push_result(out, Op::Mul { a, b }))
117 }
118
119 pub fn scale(&mut self, a: TensorId, s: f32) -> Result<TensorId> {
120 let out = self.dev.scale(self.buf(a), s)?;
121 Ok(self.push_result(out, Op::Scale { a, s }))
122 }
123
124 pub fn relu(&mut self, a: TensorId) -> Result<TensorId> {
125 let out = self.dev.relu(self.buf(a))?;
126 Ok(self.push_result(out, Op::Relu { a }))
127 }
128
129 pub fn sigmoid(&mut self, a: TensorId) -> Result<TensorId> {
130 let out = self.dev.sigmoid(self.buf(a))?;
131 Ok(self.push_result(out, Op::Sigmoid { a }))
132 }
133
134 pub fn swish(&mut self, a: TensorId) -> Result<TensorId> {
135 let out = self.dev.swish(self.buf(a))?;
136 Ok(self.push_result(out, Op::Swish { a }))
137 }
138
139 pub fn tanh_act(&mut self, a: TensorId) -> Result<TensorId> {
140 let out = self.dev.tanh_act(self.buf(a))?;
141 Ok(self.push_result(out, Op::Tanh { a }))
142 }
143
144 pub fn matmul(&mut self, a: TensorId, b: TensorId, m: u32, n: u32, k: u32) -> Result<TensorId> {
145 let out = self.dev.matmul(self.buf(a), self.buf(b), m, n, k)?;
146 Ok(self.push_result(out, Op::Matmul { a, b, m, n, k }))
147 }
148
149 pub fn mse_loss(&mut self, pred: TensorId, target: TensorId) -> Result<TensorId> {
150 let out = self.dev.mse_loss(self.buf(pred), self.buf(target))?;
151 Ok(self.push_result(out, Op::MseLoss { pred, target }))
152 }
153
154 pub fn conv2d(
155 &mut self,
156 input: TensorId,
157 weight: TensorId,
158 bias: Option<TensorId>,
159 batch: u32, in_c: u32, in_h: u32, in_w: u32,
160 out_c: u32, kh: u32, kw: u32,
161 stride: (u32, u32), padding: (u32, u32),
162 dilation: (u32, u32), groups: u32,
163 ) -> Result<TensorId> {
164 let out_h = (in_h + 2 * padding.0 - dilation.0 * (kh - 1) - 1) / stride.0 + 1;
165 let out_w = (in_w + 2 * padding.1 - dilation.1 * (kw - 1) - 1) / stride.1 + 1;
166 let out = self.dev.conv2d(
167 self.buf(input), self.buf(weight),
168 bias.map(|id| &self.bufs[id.0 as usize]).as_deref(),
169 batch, in_c, in_h, in_w, out_c, kh, kw, stride, padding, dilation, groups,
170 )?;
171 Ok(self.push_result(out, Op::Conv2d {
172 input, weight, bias,
173 batch, in_c, in_h, in_w,
174 out_c, out_h, out_w,
175 kh, kw,
176 stride_h: stride.0, stride_w: stride.1,
177 pad_h: padding.0, pad_w: padding.1,
178 dil_h: dilation.0, dil_w: dilation.1,
179 groups,
180 }))
181 }
182
183 fn accum_grad(&mut self, id: TensorId, grad: GpuBuffer) -> Result<()> {
187 match &self.grads[id.0 as usize] {
188 Some(existing) => {
189 let summed = self.dev.add(existing, &grad)?;
190 self.grads[id.0 as usize] = Some(summed);
191 }
192 None => {
193 self.grads[id.0 as usize] = Some(grad);
194 }
195 }
196 Ok(())
197 }
198
199 pub fn backward(&mut self, loss: TensorId) -> Result<()> {
201 ensure!(self.bufs[loss.0 as usize].len == 1, "backward: loss must be a scalar (1 element)");
202
203 self.grads[loss.0 as usize] = Some(self.dev.upload(&[1.0]));
205
206 for i in (0..self.entries.len()).rev() {
208 let entry = &self.entries[i];
209 let out_id = entry.output;
210
211 let grad_out = match &self.grads[out_id.0 as usize] {
213 Some(g) => g,
214 None => continue,
215 };
216
217 match entry.op {
220 Op::Leaf => {} Op::Add { a, b } => {
223 let ga = self.dev.scale(grad_out, 1.0)?; let gb = self.dev.scale(grad_out, 1.0)?;
226 self.accum_grad(a, ga)?;
227 self.accum_grad(b, gb)?;
228 }
229
230 Op::Sub { a, b } => {
231 let ga = self.dev.scale(grad_out, 1.0)?;
233 let gb = self.dev.scale(grad_out, -1.0)?;
234 self.accum_grad(a, ga)?;
235 self.accum_grad(b, gb)?;
236 }
237
238 Op::Mul { a, b } => {
239 let ga = self.dev.mul(grad_out, &self.bufs[b.0 as usize])?;
241 let gb = self.dev.mul(grad_out, &self.bufs[a.0 as usize])?;
242 self.accum_grad(a, ga)?;
243 self.accum_grad(b, gb)?;
244 }
245
246 Op::Scale { a, s } => {
247 let ga = self.dev.scale(grad_out, s)?;
249 self.accum_grad(a, ga)?;
250 }
251
252 Op::Relu { a } => {
253 let ga = self.dev.relu_backward(grad_out, &self.bufs[a.0 as usize])?;
255 self.accum_grad(a, ga)?;
256 }
257
258 Op::Sigmoid { a } => {
259 let ga = self.dev.sigmoid_backward(grad_out, &self.bufs[out_id.0 as usize])?;
261 self.accum_grad(a, ga)?;
262 }
263
264 Op::Swish { a } => {
265 let ga = self.dev.swish_backward(grad_out, &self.bufs[a.0 as usize])?;
267 self.accum_grad(a, ga)?;
268 }
269
270 Op::Tanh { a } => {
271 let ga = self.dev.tanh_backward(grad_out, &self.bufs[out_id.0 as usize])?;
273 self.accum_grad(a, ga)?;
274 }
275
276 Op::Matmul { a, b, m, n, k } => {
277 let bt = self.dev.transpose(&self.bufs[b.0 as usize], 1, k, n, 1)?;
279 let ga = self.dev.matmul(grad_out, &bt, m, k, n)?;
280 let at = self.dev.transpose(&self.bufs[a.0 as usize], 1, m, k, 1)?;
282 let gb = self.dev.matmul(&at, grad_out, k, n, m)?;
283 self.accum_grad(a, ga)?;
284 self.accum_grad(b, gb)?;
285 }
286
287 Op::MseLoss { pred, target } => {
288 let n = self.bufs[pred.0 as usize].len as f32;
290 let diff = self.dev.sub(&self.bufs[pred.0 as usize], &self.bufs[target.0 as usize])?;
291 let ga = self.dev.scale(&diff, 2.0 / n)?;
292 self.accum_grad(pred, ga)?;
293 }
294
295 Op::Conv2d { input, weight, bias, batch, in_c, in_h, in_w, out_c, out_h, out_w, kh, kw, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, groups } => {
296 let ga = self.dev.conv_transpose2d(
298 grad_out,
299 &self.bufs[weight.0 as usize],
300 None,
301 batch, out_c, out_h, out_w,
302 in_c, kh, kw,
303 (stride_h, stride_w),
304 (pad_h, pad_w),
305 (0, 0),
306 (dil_h, dil_w),
307 groups,
308 )?;
309 let gw = self.dev.conv2d_grad_weight(
311 &self.bufs[input.0 as usize],
312 grad_out,
313 batch, in_c, in_h, in_w,
314 out_c, out_h, out_w, kh, kw,
315 stride_h, stride_w, pad_h, pad_w,
316 dil_h, dil_w, groups,
317 )?;
318 let gb = if bias.is_some() {
320 Some(self.dev.conv2d_grad_bias(grad_out, batch, out_c, out_h, out_w)?)
321 } else {
322 None
323 };
324 self.accum_grad(input, ga)?;
325 self.accum_grad(weight, gw)?;
326 if let (Some(bias_id), Some(gb_buf)) = (bias, gb) {
327 self.accum_grad(bias_id, gb_buf)?;
328 }
329 }
330 }
331 }
332 Ok(())
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Device shared by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// Records an all-zero target of `len` elements followed by the MSE loss
    /// of `pred` against it, and returns the loss id.
    fn zero_target_loss(tape: &mut Tape<'_>, pred: TensorId, len: usize) -> TensorId {
        let zeros = vec![0.0f32; len];
        let target = tape.leaf(&zeros);
        tape.mse_loss(pred, target).unwrap()
    }

    /// Reads the gradient of `id`, panicking if it is missing.
    fn grad_of(tape: &Tape<'_>, id: TensorId) -> Vec<f32> {
        tape.read_grad(id).unwrap().unwrap()
    }

    /// Records a bias-free 1x1 convolution over a single 3x3 channel and its
    /// MSE loss against zeros; returns `(input, weight, loss)`.
    fn conv1x1_zero_loss(
        tape: &mut Tape<'_>,
        input: &[f32],
        weight: &[f32],
    ) -> (TensorId, TensorId, TensorId) {
        let inp = tape.leaf(input);
        let w = tape.leaf(weight);
        let out = tape
            .conv2d(inp, w, None, 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let loss = zero_target_loss(tape, out, 9);
        (inp, w, loss)
    }

    #[test]
    fn test_backward_add() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0, 6.0]);
        let c = tape.add(a, b).unwrap();
        let loss = zero_target_loss(&mut tape, c, 3);
        tape.backward(loss).unwrap();

        // c = [5, 7, 9], so loss = (25 + 49 + 81) / 3.
        assert_approx(&tape.read(loss).unwrap(), &[155.0 / 3.0], 1e-3);

        // d(loss)/dc = 2 * c / 3 flows unchanged to both operands.
        let expected = [10.0 / 3.0, 14.0 / 3.0, 18.0 / 3.0];
        assert_approx(&grad_of(&tape, a), &expected, 1e-3);
        assert_approx(&grad_of(&tape, b), &expected, 1e-3);
    }

    #[test]
    fn test_backward_mul() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0]);
        let c = tape.mul(a, b).unwrap();
        let loss = zero_target_loss(&mut tape, c, 2);
        tape.backward(loss).unwrap();

        // c = [8, 15], so loss = (64 + 225) / 2.
        assert_approx(&tape.read(loss).unwrap(), &[144.5], 1e-3);

        // d(loss)/dc = c; each operand gets that times the other operand.
        assert_approx(&grad_of(&tape, a), &[32.0, 75.0], 1e-3);
        assert_approx(&grad_of(&tape, b), &[16.0, 45.0], 1e-3);
    }

    #[test]
    fn test_backward_matmul() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        let b = tape.leaf(&[3.0, 4.0]);
        let c = tape.matmul(a, b, 1, 1, 2).unwrap();
        let loss = zero_target_loss(&mut tape, c, 1);
        tape.backward(loss).unwrap();

        // c = 1*3 + 2*4 = 11, so loss = 121 and d(loss)/dc = 22.
        assert_approx(&tape.read(loss).unwrap(), &[121.0], 1e-3);

        assert_approx(&grad_of(&tape, a), &[66.0, 88.0], 1e-3);
        assert_approx(&grad_of(&tape, b), &[22.0, 44.0], 1e-3);
    }

    #[test]
    fn test_backward_relu() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[-1.0, 2.0, -3.0, 4.0]);
        let b = tape.relu(a).unwrap();
        let loss = zero_target_loss(&mut tape, b, 4);
        tape.backward(loss).unwrap();

        // relu(a) = [0, 2, 0, 4], so loss = (4 + 16) / 4.
        assert_approx(&tape.read(loss).unwrap(), &[5.0], 1e-3);

        // Gradient is blocked where the input was negative.
        assert_approx(&grad_of(&tape, a), &[0.0, 1.0, 0.0, 2.0], 1e-3);
    }

    #[test]
    fn test_backward_scale() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let b = tape.scale(a, 3.0).unwrap();
        let loss = zero_target_loss(&mut tape, b, 3);
        tape.backward(loss).unwrap();

        assert_approx(&grad_of(&tape, a), &[6.0, 12.0, 18.0], 1e-3);
    }

    #[test]
    fn test_backward_sub() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[5.0, 10.0]);
        let b = tape.leaf(&[1.0, 2.0]);
        let c = tape.sub(a, b).unwrap();
        let loss = zero_target_loss(&mut tape, c, 2);
        tape.backward(loss).unwrap();

        // c = [4, 8]; the subtrahend receives the negated gradient.
        assert_approx(&grad_of(&tape, a), &[4.0, 8.0], 1e-3);
        assert_approx(&grad_of(&tape, b), &[-4.0, -8.0], 1e-3);
    }

    #[test]
    fn test_backward_sigmoid() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.sigmoid(a).unwrap();
        let loss = zero_target_loss(&mut tape, b, 3);
        tape.backward(loss).unwrap();

        // Reference sigmoid values for the inputs 0, 1, -1.
        let s = [0.5f32, 0.7311, 0.2689];
        let expected: Vec<f32> = s
            .iter()
            .map(|&v| 2.0 * v / 3.0 * v * (1.0 - v))
            .collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-3);
    }

    #[test]
    fn test_backward_tanh() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.tanh_act(a).unwrap();
        let loss = zero_target_loss(&mut tape, b, 3);
        tape.backward(loss).unwrap();

        // Reference tanh values for the inputs 0, 1, -1.
        let t = [0.0f32, 0.7616, -0.7616];
        let expected: Vec<f32> = t
            .iter()
            .map(|&v| 2.0 * v / 3.0 * (1.0 - v * v))
            .collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-2);
    }

    #[test]
    fn test_backward_swish() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.swish(a).unwrap();
        let loss = zero_target_loss(&mut tape, b, 3);
        tape.backward(loss).unwrap();

        let x = [0.0f32, 1.0, -1.0];
        let expected: Vec<f32> = x
            .iter()
            .map(|&v| {
                let sw = v / (1.0 + (-v).exp());
                let s = 1.0 / (1.0 + (-v).exp());
                let d_swish = s + v * s * (1.0 - s);
                2.0 * sw / 3.0 * d_swish
            })
            .collect();
        assert_approx(&grad_of(&tape, a), &expected, 1e-2);
    }

    #[test]
    fn test_read_grad_before_backward() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.read_grad(a).unwrap().is_none());
    }

    #[test]
    fn test_backward_non_scalar_loss() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.backward(a).is_err());
    }

    #[test]
    fn test_backward_diamond_graph() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0]);
        let b = tape.scale(a, 2.0).unwrap();
        let c = tape.scale(a, 3.0).unwrap();
        let d = tape.add(b, c).unwrap();
        let loss = zero_target_loss(&mut tape, d, 1);
        tape.backward(loss).unwrap();

        // d = 5, d(loss)/dd = 10; both paths add up: 10 * 2 + 10 * 3.
        assert_approx(&grad_of(&tape, a), &[50.0], 1e-3);
    }

    #[test]
    fn test_tape_leaf_data_roundtrip() {
        let mut tape = Tape::new(dev());
        let data = vec![1.5, -2.7, 0.0, 99.9];
        let a = tape.leaf(&data);
        assert_eq!(tape.read(a).unwrap(), data);
    }

    #[test]
    fn test_tape_conv2d_forward() {
        let mut tape = Tape::new(dev());
        // A 1x1 kernel of weight 1 with zero bias must reproduce the input.
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let inp = tape.leaf(&input_data);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let out = tape
            .conv2d(inp, w, Some(b), 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        assert_approx(&tape.read(out).unwrap(), &input_data, 1e-5);
    }

    #[test]
    fn test_tape_conv2d_backward_weight_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32 * 0.1).collect();
        let weight_data = vec![0.5f32];

        // Forward-only loss for a given weight, used for finite differences.
        let run = |w_val: f32| -> f32 {
            let mut tape = Tape::new(dev());
            let (_, _, loss) = conv1x1_zero_loss(&mut tape, &input_data, &[w_val]);
            tape.read(loss).unwrap()[0]
        };

        let mut tape = Tape::new(dev());
        let (_, w, loss) = conv1x1_zero_loss(&mut tape, &input_data, &weight_data);
        tape.backward(loss).unwrap();
        let gw = grad_of(&tape, w);

        let numeric = (run(weight_data[0] + eps) - run(weight_data[0] - eps)) / (2.0 * eps);
        assert!(
            (gw[0] - numeric).abs() < 1e-2,
            "weight grad: analytical={}, numeric={}",
            gw[0],
            numeric
        );
    }

    #[test]
    fn test_tape_conv2d_backward_input_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32 * 0.1).collect();
        let weight_data = vec![0.5f32];

        // Forward-only loss with input element `idx` replaced by `x_val`.
        let run = |x_val: f32, idx: usize| -> f32 {
            let mut inp_data = input_data.clone();
            inp_data[idx] = x_val;
            let mut tape = Tape::new(dev());
            let (_, _, loss) = conv1x1_zero_loss(&mut tape, &inp_data, &weight_data);
            tape.read(loss).unwrap()[0]
        };

        let mut tape = Tape::new(dev());
        let (inp, _, loss) = conv1x1_zero_loss(&mut tape, &input_data, &weight_data);
        tape.backward(loss).unwrap();
        let gi = grad_of(&tape, inp);

        for (i, &x) in input_data.iter().enumerate() {
            let numeric = (run(x + eps, i) - run(x - eps, i)) / (2.0 * eps);
            assert!(
                (gi[i] - numeric).abs() < 1e-2,
                "input grad[{i}]: analytical={}, numeric={}",
                gi[i],
                numeric
            );
        }
    }

    #[test]
    fn test_tape_conv2d_backward_bias_grad() {
        let mut tape = Tape::new(dev());
        let inp = tape.leaf(&[1.0f32, 2.0, 3.0, 4.0]);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let out = tape
            .conv2d(inp, w, Some(b), 1, 1, 2, 2, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let loss = zero_target_loss(&mut tape, out, 4);
        tape.backward(loss).unwrap();

        // d(loss)/d(out) = 2 * out / 4 = [0.5, 1, 1.5, 2]; the bias gradient is their sum.
        assert_approx(&grad_of(&tape, b), &[5.0], 1e-3);
    }
}