use crate::device::{GpuBuffer, GpuDevice};
use anyhow::{Result, ensure};
/// Handle to a tensor recorded on a [`Tape`]. The wrapped value is the
/// tensor's index in the tape's creation order.
///
/// `Hash` is derived alongside `Eq` so ids can be used as keys in hash maps
/// and sets (e.g. to map parameters to optimizer state).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(pub u32);
/// An operation recorded on a [`Tape`], holding its operand ids plus every
/// shape and hyper-parameter that `Tape::backward` needs to differentiate it.
#[derive(Copy, Clone, Debug)]
pub enum Op {
    /// Tensor uploaded directly from host data; it has no operands.
    Leaf,
    /// `a + b`.
    Add { a: TensorId, b: TensorId },
    /// `a - b`.
    Sub { a: TensorId, b: TensorId },
    /// Elementwise product `a * b`.
    Mul { a: TensorId, b: TensorId },
    /// `a` multiplied by the constant `s`.
    Scale { a: TensorId, s: f32 },
    /// Rectified linear unit applied to `a`.
    Relu { a: TensorId },
    /// Logistic sigmoid applied to `a`.
    Sigmoid { a: TensorId },
    /// Swish, `a * sigmoid(a)`, applied to `a`.
    Swish { a: TensorId },
    /// Hyperbolic tangent applied to `a`.
    Tanh { a: TensorId },
    /// Matrix product of `a` (treated as `m x k`) and `b` (`k x n`),
    /// producing an `m x n` result.
    Matmul {
        a: TensorId,
        b: TensorId,
        m: u32,
        n: u32,
        k: u32,
    },
    /// Mean squared error between `pred` and `target`; the result holds a
    /// single element.
    MseLoss { pred: TensorId, target: TensorId },
    /// 2-D convolution with its full geometry recorded.
    ///
    /// NOTE(review): the tensor layouts are whatever `GpuDevice::conv2d`
    /// uses — presumably NCHW input and `(out_c, in_c / groups, kh, kw)`
    /// weights; confirm against the device implementation.
    Conv2d {
        input: TensorId,
        weight: TensorId,
        /// Optional bias tensor.
        bias: Option<TensorId>,
        batch: u32,
        in_c: u32,
        in_h: u32,
        in_w: u32,
        out_c: u32,
        /// Output height, derived from the other fields by `Tape::conv2d`.
        out_h: u32,
        /// Output width, derived from the other fields by `Tape::conv2d`.
        out_w: u32,
        kh: u32,
        kw: u32,
        stride_h: u32,
        stride_w: u32,
        pad_h: u32,
        pad_w: u32,
        dil_h: u32,
        dil_w: u32,
        groups: u32,
    },
}
/// One recorded step of the tape: the operation and the tensor it produced.
struct TapeEntry {
    /// Operation that produced `output` (with its operand ids).
    op: Op,
    /// Id of the tensor this step produced.
    output: TensorId,
}
/// Reverse-mode autodiff tape: records every tensor produced on a device and
/// the operation that produced it, so gradients can be propagated back from
/// a scalar loss.
pub struct Tape<'d> {
    /// Device used for uploads, reads and every kernel launch.
    dev: &'d GpuDevice,
    /// Recorded operations, one per tensor, in creation order.
    entries: Vec<TapeEntry>,
    /// Forward values, indexed by `TensorId`.
    bufs: Vec<GpuBuffer>,
    /// Accumulated gradients, indexed by `TensorId`; `None` until a gradient
    /// has reached the tensor.
    grads: Vec<Option<GpuBuffer>>,
}
impl<'d> Tape<'d> {
    /// Creates an empty tape that uploads to and computes on `dev`.
    pub fn new(dev: &'d GpuDevice) -> Self {
        Self {
            dev,
            entries: Vec::new(),
            bufs: Vec::new(),
            grads: Vec::new(),
        }
    }

    /// Uploads `data` to the device and records it as a leaf tensor.
    ///
    /// Leaves have no inputs; [`Tape::backward`] accumulates the gradient of
    /// the loss with respect to them.
    pub fn leaf(&mut self, data: &[f32]) -> TensorId {
        let buf = self.dev.upload(data);
        self.push_result(buf, Op::Leaf)
    }

    /// Copies the forward value of `id` back to the host.
    ///
    /// # Panics
    /// Panics if `id` was not created by this tape.
    pub fn read(&self, id: TensorId) -> Result<Vec<f32>> {
        self.dev.read(self.buf(id))
    }

    /// Copies the accumulated gradient of `id` back to the host.
    ///
    /// Returns `Ok(None)` if no gradient has reached `id` yet, e.g. before the
    /// first call to [`Tape::backward`].
    ///
    /// # Panics
    /// Panics if `id` was not created by this tape.
    pub fn read_grad(&self, id: TensorId) -> Result<Option<Vec<f32>>> {
        self.grads[id.0 as usize]
            .as_ref()
            .map(|buf| self.dev.read(buf))
            .transpose()
    }

    /// Records `buf` as the output of `op` and returns its new id.
    ///
    /// `bufs`, `grads` and `entries` always grow together, so `entries[i]`
    /// describes tensor `TensorId(i)`; `backward` relies on this.
    fn push_result(&mut self, buf: GpuBuffer, op: Op) -> TensorId {
        let id = TensorId(self.bufs.len() as u32);
        self.bufs.push(buf);
        self.grads.push(None);
        self.entries.push(TapeEntry { op, output: id });
        id
    }

    /// Forward value of `id`.
    ///
    /// # Panics
    /// Panics if `id` was not created by this tape.
    fn buf(&self, id: TensorId) -> &GpuBuffer {
        &self.bufs[id.0 as usize]
    }

    /// Records `a + b`.
    pub fn add(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
        let out = self.dev.add(self.buf(a), self.buf(b))?;
        Ok(self.push_result(out, Op::Add { a, b }))
    }

    /// Records `a - b`.
    pub fn sub(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
        let out = self.dev.sub(self.buf(a), self.buf(b))?;
        Ok(self.push_result(out, Op::Sub { a, b }))
    }

    /// Records the elementwise product `a * b`.
    pub fn mul(&mut self, a: TensorId, b: TensorId) -> Result<TensorId> {
        let out = self.dev.mul(self.buf(a), self.buf(b))?;
        Ok(self.push_result(out, Op::Mul { a, b }))
    }

    /// Records `a` multiplied by the constant `s`.
    pub fn scale(&mut self, a: TensorId, s: f32) -> Result<TensorId> {
        let out = self.dev.scale(self.buf(a), s)?;
        Ok(self.push_result(out, Op::Scale { a, s }))
    }

    /// Records `relu(a)`.
    pub fn relu(&mut self, a: TensorId) -> Result<TensorId> {
        let out = self.dev.relu(self.buf(a))?;
        Ok(self.push_result(out, Op::Relu { a }))
    }

    /// Records `sigmoid(a)`.
    pub fn sigmoid(&mut self, a: TensorId) -> Result<TensorId> {
        let out = self.dev.sigmoid(self.buf(a))?;
        Ok(self.push_result(out, Op::Sigmoid { a }))
    }

    /// Records `swish(a) = a * sigmoid(a)`.
    pub fn swish(&mut self, a: TensorId) -> Result<TensorId> {
        let out = self.dev.swish(self.buf(a))?;
        Ok(self.push_result(out, Op::Swish { a }))
    }

    /// Records `tanh(a)`.
    pub fn tanh_act(&mut self, a: TensorId) -> Result<TensorId> {
        let out = self.dev.tanh_act(self.buf(a))?;
        Ok(self.push_result(out, Op::Tanh { a }))
    }

    /// Records the matrix product of `a` (`m x k`) and `b` (`k x n`), giving
    /// an `m x n` tensor.
    pub fn matmul(&mut self, a: TensorId, b: TensorId, m: u32, n: u32, k: u32) -> Result<TensorId> {
        let out = self.dev.matmul(self.buf(a), self.buf(b), m, n, k)?;
        Ok(self.push_result(out, Op::Matmul { a, b, m, n, k }))
    }

    /// Records the mean squared error between `pred` and `target` as a
    /// single-element tensor.
    pub fn mse_loss(&mut self, pred: TensorId, target: TensorId) -> Result<TensorId> {
        let out = self.dev.mse_loss(self.buf(pred), self.buf(target))?;
        Ok(self.push_result(out, Op::MseLoss { pred, target }))
    }

    /// Records a 2-D convolution of `input` with `weight` (plus optional `bias`).
    ///
    /// The output spatial size is
    /// `(in + 2 * pad - (dil * (k - 1) + 1)) / stride + 1` per axis.
    ///
    /// NOTE(review): tensor layouts are whatever `GpuDevice::conv2d` expects —
    /// presumably NCHW input and `(out_c, in_c / groups, kh, kw)` weights.
    ///
    /// # Errors
    /// Fails if a kernel size, stride, dilation or `groups` is zero, if
    /// `groups` does not divide both `in_c` and `out_c`, if the dilated kernel
    /// does not fit in the padded input, or if the device call fails.
    pub fn conv2d(
        &mut self,
        input: TensorId,
        weight: TensorId,
        bias: Option<TensorId>,
        batch: u32, in_c: u32, in_h: u32, in_w: u32,
        out_c: u32, kh: u32, kw: u32,
        stride: (u32, u32), padding: (u32, u32),
        dilation: (u32, u32), groups: u32,
    ) -> Result<TensorId> {
        // Validate up front: the u32 size formula below would otherwise
        // underflow (panic in debug, wrap in release) or divide by zero.
        ensure!(kh > 0 && kw > 0, "conv2d: kernel size must be non-zero");
        ensure!(stride.0 > 0 && stride.1 > 0, "conv2d: stride must be non-zero");
        ensure!(dilation.0 > 0 && dilation.1 > 0, "conv2d: dilation must be non-zero");
        ensure!(
            groups > 0 && in_c % groups == 0 && out_c % groups == 0,
            "conv2d: groups ({}) must be non-zero and divide in_c ({}) and out_c ({})",
            groups, in_c, out_c
        );
        let eff_kh = dilation.0 * (kh - 1) + 1;
        let eff_kw = dilation.1 * (kw - 1) + 1;
        let padded_h = in_h + 2 * padding.0;
        let padded_w = in_w + 2 * padding.1;
        ensure!(
            eff_kh <= padded_h && eff_kw <= padded_w,
            "conv2d: dilated kernel {}x{} does not fit padded input {}x{}",
            eff_kh, eff_kw, padded_h, padded_w
        );
        let out_h = (padded_h - eff_kh) / stride.0 + 1;
        let out_w = (padded_w - eff_kw) / stride.1 + 1;

        let out = self.dev.conv2d(
            self.buf(input),
            self.buf(weight),
            bias.map(|id| self.buf(id)),
            batch, in_c, in_h, in_w, out_c, kh, kw, stride, padding, dilation, groups,
        )?;
        Ok(self.push_result(out, Op::Conv2d {
            input, weight, bias,
            batch, in_c, in_h, in_w,
            out_c, out_h, out_w,
            kh, kw,
            stride_h: stride.0, stride_w: stride.1,
            pad_h: padding.0, pad_w: padding.1,
            dil_h: dilation.0, dil_w: dilation.1,
            groups,
        }))
    }

    /// Adds `grad` into `slots[id]`, or stores it there if the slot is empty.
    ///
    /// On a device error the existing slot contents are left untouched.
    fn accumulate_into(
        dev: &GpuDevice,
        slots: &mut [Option<GpuBuffer>],
        id: TensorId,
        grad: GpuBuffer,
    ) -> Result<()> {
        let idx = id.0 as usize;
        let merged = match &slots[idx] {
            Some(existing) => dev.add(existing, &grad)?,
            None => grad,
        };
        slots[idx] = Some(merged);
        Ok(())
    }

    /// Adds `grad` into the stored gradient of `id`.
    fn accum_grad(&mut self, id: TensorId, grad: GpuBuffer) -> Result<()> {
        Self::accumulate_into(self.dev, &mut self.grads, id, grad)
    }

    /// Runs reverse-mode differentiation from the scalar `loss`.
    ///
    /// The gradient of `loss` with respect to every tensor it depends on is
    /// *added* to that tensor's stored gradient, so repeated calls accumulate
    /// (read results with [`Tape::read_grad`]). Each call propagates only the
    /// gradients it produced itself, so gradients left over from earlier calls
    /// are never pushed through the graph a second time.
    ///
    /// # Errors
    /// Fails if `loss` is not a tensor of this tape, if it does not hold
    /// exactly one element, or if any device operation fails.
    pub fn backward(&mut self, loss: TensorId) -> Result<()> {
        let root = loss.0 as usize;
        ensure!(root < self.bufs.len(), "backward: unknown tensor id {}", loss.0);
        ensure!(self.bufs[root].len == 1, "backward: loss must be a scalar (1 element)");

        // Gradients produced by this call only, kept separate from `self.grads`.
        // Tensors created after `loss` cannot influence it, so slots `0..=root`
        // are enough.
        let mut pass: Vec<Option<GpuBuffer>> =
            std::iter::repeat_with(|| None).take(root + 1).collect();
        pass[root] = Some(self.dev.upload(&[1.0]));

        // Every consumer of tensor `i` has a larger id, so by the time the
        // reverse walk reaches `i`, its gradient for this pass is complete.
        for i in (0..=root).rev() {
            let grad_out = match pass[i].take() {
                Some(g) => g,
                None => continue,
            };
            let op = self.entries[i].op;
            let output = self.entries[i].output;
            Self::propagate(self.dev, &self.bufs, op, output, &grad_out, i == root, &mut pass)?;
            self.accum_grad(output, grad_out)?;
        }
        Ok(())
    }

    /// Pushes `grad_out` (the gradient at the output of `op`) back to the
    /// op's operands, accumulating the contributions into `pass`.
    ///
    /// `is_root` is true when `out` is the loss `backward` started from; its
    /// gradient is then known to be the `1.0` seed.
    fn propagate(
        dev: &GpuDevice,
        bufs: &[GpuBuffer],
        op: Op,
        out: TensorId,
        grad_out: &GpuBuffer,
        is_root: bool,
        pass: &mut [Option<GpuBuffer>],
    ) -> Result<()> {
        match op {
            Op::Leaf => {}
            Op::Add { a, b } => {
                // Identity Jacobian; scaling by 1.0 yields an owned copy per operand.
                let ga = dev.scale(grad_out, 1.0)?;
                let gb = dev.scale(grad_out, 1.0)?;
                Self::accumulate_into(dev, pass, a, ga)?;
                Self::accumulate_into(dev, pass, b, gb)?;
            }
            Op::Sub { a, b } => {
                let ga = dev.scale(grad_out, 1.0)?;
                let gb = dev.scale(grad_out, -1.0)?;
                Self::accumulate_into(dev, pass, a, ga)?;
                Self::accumulate_into(dev, pass, b, gb)?;
            }
            Op::Mul { a, b } => {
                let ga = dev.mul(grad_out, &bufs[b.0 as usize])?;
                let gb = dev.mul(grad_out, &bufs[a.0 as usize])?;
                Self::accumulate_into(dev, pass, a, ga)?;
                Self::accumulate_into(dev, pass, b, gb)?;
            }
            Op::Scale { a, s } => {
                let ga = dev.scale(grad_out, s)?;
                Self::accumulate_into(dev, pass, a, ga)?;
            }
            Op::Relu { a } => {
                let ga = dev.relu_backward(grad_out, &bufs[a.0 as usize])?;
                Self::accumulate_into(dev, pass, a, ga)?;
            }
            Op::Sigmoid { a } => {
                // The sigmoid derivative is expressed via the saved output.
                let ga = dev.sigmoid_backward(grad_out, &bufs[out.0 as usize])?;
                Self::accumulate_into(dev, pass, a, ga)?;
            }
            Op::Swish { a } => {
                let ga = dev.swish_backward(grad_out, &bufs[a.0 as usize])?;
                Self::accumulate_into(dev, pass, a, ga)?;
            }
            Op::Tanh { a } => {
                // The tanh derivative is expressed via the saved output.
                let ga = dev.tanh_backward(grad_out, &bufs[out.0 as usize])?;
                Self::accumulate_into(dev, pass, a, ga)?;
            }
            Op::Matmul { a, b, m, n, k } => {
                // dA = dC · Bᵀ (m x k) and dB = Aᵀ · dC (k x n).
                // NOTE(review): `transpose(buf, 1, rows, cols, 1)` is assumed to
                // transpose one rows x cols matrix; confirm against GpuDevice.
                let bt = dev.transpose(&bufs[b.0 as usize], 1, k, n, 1)?;
                let ga = dev.matmul(grad_out, &bt, m, k, n)?;
                let at = dev.transpose(&bufs[a.0 as usize], 1, m, k, 1)?;
                let gb = dev.matmul(&at, grad_out, k, n, m)?;
                Self::accumulate_into(dev, pass, a, ga)?;
                Self::accumulate_into(dev, pass, b, gb)?;
            }
            Op::MseLoss { pred, target } => {
                let n = bufs[pred.0 as usize].len as f32;
                // Chain rule: d/dpred = 2 (pred - target) / n times the upstream
                // scalar, and d/dtarget is its negation. At the root the upstream
                // is the 1.0 seed, which avoids a device readback in the common case.
                let upstream = if is_root {
                    1.0
                } else {
                    let host = dev.read(grad_out)?;
                    ensure!(
                        host.len() == 1,
                        "backward: mse_loss upstream gradient must have 1 element, got {}",
                        host.len()
                    );
                    host[0]
                };
                let coeff = 2.0 * upstream / n;
                let diff = dev.sub(&bufs[pred.0 as usize], &bufs[target.0 as usize])?;
                let g_pred = dev.scale(&diff, coeff)?;
                let g_target = dev.scale(&diff, -coeff)?;
                Self::accumulate_into(dev, pass, pred, g_pred)?;
                Self::accumulate_into(dev, pass, target, g_target)?;
            }
            Op::Conv2d {
                input,
                weight,
                bias,
                batch,
                in_c,
                in_h,
                in_w,
                out_c,
                out_h,
                out_w,
                kh,
                kw,
                stride_h,
                stride_w,
                pad_h,
                pad_w,
                dil_h,
                dil_w,
                groups,
            } => {
                // The forward pass floors `(padded - dilated_kernel) / stride`, so
                // up to `stride - 1` trailing rows/cols of the padded input are
                // never visited. The transposed conv must restore them as output
                // padding or the input gradient is smaller than the input (the
                // padding is 0 whenever the stride divides evenly, e.g. stride 1).
                // No underflow here: `Tape::conv2d` validated these sizes.
                // NOTE(review): the tuple after `padding` is taken to be output
                // padding, as in PyTorch's conv_transpose2d; confirm against GpuDevice.
                let out_pad_h = (in_h + 2 * pad_h - (dil_h * (kh - 1) + 1)) % stride_h;
                let out_pad_w = (in_w + 2 * pad_w - (dil_w * (kw - 1) + 1)) % stride_w;
                let g_input = dev.conv_transpose2d(
                    grad_out,
                    &bufs[weight.0 as usize],
                    None,
                    batch, out_c, out_h, out_w,
                    in_c, kh, kw,
                    (stride_h, stride_w),
                    (pad_h, pad_w),
                    (out_pad_h, out_pad_w),
                    (dil_h, dil_w),
                    groups,
                )?;
                let g_weight = dev.conv2d_grad_weight(
                    &bufs[input.0 as usize],
                    grad_out,
                    batch, in_c, in_h, in_w,
                    out_c, out_h, out_w, kh, kw,
                    stride_h, stride_w, pad_h, pad_w,
                    dil_h, dil_w, groups,
                )?;
                let g_bias = match bias {
                    Some(bias_id) => Some((
                        bias_id,
                        dev.conv2d_grad_bias(grad_out, batch, out_c, out_h, out_w)?,
                    )),
                    None => None,
                };
                Self::accumulate_into(dev, pass, input, g_input)?;
                Self::accumulate_into(dev, pass, weight, g_weight)?;
                if let Some((bias_id, gb)) = g_bias {
                    Self::accumulate_into(dev, pass, bias_id, gb)?;
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared device used by every test.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// Forward-only MSE (against an all-zero target) of a bias-free 1x1,
    /// single-channel, stride-1 conv over a 3x3 `input`; used for
    /// central-difference gradient checks.
    fn conv1x1_loss(input: &[f32], weight: f32) -> f32 {
        let mut tape = Tape::new(dev());
        let inp = tape.leaf(input);
        let w = tape.leaf(&[weight]);
        let out = tape
            .conv2d(inp, w, None, 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let target = tape.leaf(&[0.0f32; 9]);
        let loss = tape.mse_loss(out, target).unwrap();
        tape.read(loss).unwrap()[0]
    }

    #[test]
    fn test_backward_add() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0, 6.0]);
        let c = tape.add(a, b).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(c, target).unwrap();
        tape.backward(loss).unwrap();

        // c = [5, 7, 9]; loss = (25 + 49 + 81) / 3; dL/dc = 2c / 3.
        let loss_val = tape.read(loss).unwrap();
        assert_approx(&loss_val, &[155.0 / 3.0], 1e-3);
        let ga = tape.read_grad(a).unwrap().unwrap();
        let gb = tape.read_grad(b).unwrap().unwrap();
        assert_approx(&ga, &[10.0 / 3.0, 14.0 / 3.0, 18.0 / 3.0], 1e-3);
        assert_approx(&gb, &[10.0 / 3.0, 14.0 / 3.0, 18.0 / 3.0], 1e-3);
    }

    #[test]
    fn test_backward_mul() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[2.0, 3.0]);
        let b = tape.leaf(&[4.0, 5.0]);
        let c = tape.mul(a, b).unwrap();
        let target = tape.leaf(&[0.0, 0.0]);
        let loss = tape.mse_loss(c, target).unwrap();
        tape.backward(loss).unwrap();

        // c = [8, 15]; dL/dc = c; dL/da = c * b, dL/db = c * a.
        let loss_val = tape.read(loss).unwrap();
        assert_approx(&loss_val, &[144.5], 1e-3);
        let ga = tape.read_grad(a).unwrap().unwrap();
        let gb = tape.read_grad(b).unwrap().unwrap();
        assert_approx(&ga, &[32.0, 75.0], 1e-3);
        assert_approx(&gb, &[16.0, 45.0], 1e-3);
    }

    #[test]
    fn test_backward_matmul() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        let b = tape.leaf(&[3.0, 4.0]);
        let c = tape.matmul(a, b, 1, 1, 2).unwrap();
        let target = tape.leaf(&[0.0]);
        let loss = tape.mse_loss(c, target).unwrap();
        tape.backward(loss).unwrap();

        // c = 1*3 + 2*4 = 11; dL/dc = 22; dA = 22 * b, dB = 22 * a.
        let loss_val = tape.read(loss).unwrap();
        assert_approx(&loss_val, &[121.0], 1e-3);
        let ga = tape.read_grad(a).unwrap().unwrap();
        let gb = tape.read_grad(b).unwrap().unwrap();
        assert_approx(&ga, &[66.0, 88.0], 1e-3);
        assert_approx(&gb, &[22.0, 44.0], 1e-3);
    }

    #[test]
    fn test_backward_relu() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[-1.0, 2.0, -3.0, 4.0]);
        let b = tape.relu(a).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(b, target).unwrap();
        tape.backward(loss).unwrap();

        // relu zeroes the negatives: b = [0, 2, 0, 4]; dL/db = b / 2.
        let loss_val = tape.read(loss).unwrap();
        assert_approx(&loss_val, &[5.0], 1e-3);
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &[0.0, 1.0, 0.0, 2.0], 1e-3);
    }

    #[test]
    fn test_backward_scale() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0, 3.0]);
        let b = tape.scale(a, 3.0).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(b, target).unwrap();
        tape.backward(loss).unwrap();

        // b = 3a; dL/da = 3 * (2b / 3) = 2b = 6a.
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &[6.0, 12.0, 18.0], 1e-3);
    }

    #[test]
    fn test_backward_sub() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[5.0, 10.0]);
        let b = tape.leaf(&[1.0, 2.0]);
        let c = tape.sub(a, b).unwrap();
        let target = tape.leaf(&[0.0, 0.0]);
        let loss = tape.mse_loss(c, target).unwrap();
        tape.backward(loss).unwrap();

        // c = [4, 8]; dL/dc = c; b receives the negation.
        let ga = tape.read_grad(a).unwrap().unwrap();
        let gb = tape.read_grad(b).unwrap().unwrap();
        assert_approx(&ga, &[4.0, 8.0], 1e-3);
        assert_approx(&gb, &[-4.0, -8.0], 1e-3);
    }

    #[test]
    fn test_backward_sigmoid() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.sigmoid(a).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(b, target).unwrap();
        tape.backward(loss).unwrap();

        // sigmoid of [0, 1, -1]; d sigmoid = s (1 - s).
        let s = [0.5f32, 0.7311, 0.2689];
        let expected: Vec<f32> = s
            .iter()
            .map(|&v| 2.0 * v / 3.0 * v * (1.0 - v))
            .collect();
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &expected, 1e-3);
    }

    #[test]
    fn test_backward_tanh() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.tanh_act(a).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(b, target).unwrap();
        tape.backward(loss).unwrap();

        // tanh of [0, 1, -1]; d tanh = 1 - t^2.
        let t = [0.0f32, 0.7616, -0.7616];
        let expected: Vec<f32> = t
            .iter()
            .map(|&v| 2.0 * v / 3.0 * (1.0 - v * v))
            .collect();
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &expected, 1e-2);
    }

    #[test]
    fn test_backward_swish() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[0.0, 1.0, -1.0]);
        let b = tape.swish(a).unwrap();
        let target = tape.leaf(&[0.0, 0.0, 0.0]);
        let loss = tape.mse_loss(b, target).unwrap();
        tape.backward(loss).unwrap();

        // swish(x) = x * s(x); d swish = s + x s (1 - s).
        let x = [0.0f32, 1.0, -1.0];
        let sw: Vec<f32> = x.iter().map(|&v| v / (1.0 + (-v).exp())).collect();
        let expected: Vec<f32> = x
            .iter()
            .zip(&sw)
            .map(|(&v, &out)| {
                let s = 1.0 / (1.0 + (-v).exp());
                let d_swish = s + v * s * (1.0 - s);
                2.0 * out / 3.0 * d_swish
            })
            .collect();
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &expected, 1e-2);
    }

    #[test]
    fn test_read_grad_before_backward() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.read_grad(a).unwrap().is_none());
    }

    #[test]
    fn test_backward_non_scalar_loss() {
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0, 2.0]);
        assert!(tape.backward(a).is_err());
    }

    #[test]
    fn test_backward_diamond_graph() {
        // a feeds two branches that rejoin: d = 2a + 3a = 5a.
        let mut tape = Tape::new(dev());
        let a = tape.leaf(&[1.0]);
        let b = tape.scale(a, 2.0).unwrap();
        let c = tape.scale(a, 3.0).unwrap();
        let d = tape.add(b, c).unwrap();
        let target = tape.leaf(&[0.0]);
        let loss = tape.mse_loss(d, target).unwrap();
        tape.backward(loss).unwrap();

        // dL/dd = 2 * 5 = 10, so dL/da = 10 * (2 + 3) = 50.
        let ga = tape.read_grad(a).unwrap().unwrap();
        assert_approx(&ga, &[50.0], 1e-3);
    }

    #[test]
    fn test_tape_leaf_data_roundtrip() {
        let mut tape = Tape::new(dev());
        let data: Vec<f32> = vec![1.5, -2.7, 0.0, 99.9];
        let a = tape.leaf(&data);
        assert_eq!(tape.read(a).unwrap(), data);
    }

    #[test]
    fn test_tape_conv2d_forward() {
        // A 1x1 kernel of weight 1 with zero bias is the identity.
        let mut tape = Tape::new(dev());
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let inp = tape.leaf(&input_data);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let out = tape
            .conv2d(inp, w, Some(b), 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let result = tape.read(out).unwrap();
        assert_approx(&result, &input_data, 1e-5);
    }

    #[test]
    fn test_tape_conv2d_backward_weight_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32 * 0.1).collect();
        let weight = 0.5f32;

        let mut tape = Tape::new(dev());
        let inp = tape.leaf(&input_data);
        let w = tape.leaf(&[weight]);
        let out = tape
            .conv2d(inp, w, None, 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let target = tape.leaf(&[0.0f32; 9]);
        let loss = tape.mse_loss(out, target).unwrap();
        tape.backward(loss).unwrap();
        let gw = tape.read_grad(w).unwrap().unwrap();

        // Central difference in the weight.
        let numeric = (conv1x1_loss(&input_data, weight + eps)
            - conv1x1_loss(&input_data, weight - eps))
            / (2.0 * eps);
        assert!((gw[0] - numeric).abs() < 1e-2,
            "weight grad: analytical={}, numeric={}", gw[0], numeric);
    }

    #[test]
    fn test_tape_conv2d_backward_input_grad() {
        let eps = 1e-3f32;
        let input_data: Vec<f32> = (1..=9).map(|x| x as f32 * 0.1).collect();
        let weight = 0.5f32;

        let mut tape = Tape::new(dev());
        let inp = tape.leaf(&input_data);
        let w = tape.leaf(&[weight]);
        let out = tape
            .conv2d(inp, w, None, 1, 1, 3, 3, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let target = tape.leaf(&[0.0f32; 9]);
        let loss = tape.mse_loss(out, target).unwrap();
        tape.backward(loss).unwrap();
        let gi = tape.read_grad(inp).unwrap().unwrap();

        // Central difference in each input element in turn.
        for (i, &x) in input_data.iter().enumerate() {
            let mut plus = input_data.clone();
            plus[i] = x + eps;
            let mut minus = input_data.clone();
            minus[i] = x - eps;
            let numeric =
                (conv1x1_loss(&plus, weight) - conv1x1_loss(&minus, weight)) / (2.0 * eps);
            assert!((gi[i] - numeric).abs() < 1e-2,
                "input grad[{i}]: analytical={}, numeric={}", gi[i], numeric);
        }
    }

    #[test]
    fn test_tape_conv2d_backward_bias_grad() {
        let mut tape = Tape::new(dev());
        let inp = tape.leaf(&[1.0f32, 2.0, 3.0, 4.0]);
        let w = tape.leaf(&[1.0f32]);
        let b = tape.leaf(&[0.0f32]);
        let out = tape
            .conv2d(inp, w, Some(b), 1, 1, 2, 2, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1)
            .unwrap();
        let target = tape.leaf(&[0.0f32; 4]);
        let loss = tape.mse_loss(out, target).unwrap();
        tape.backward(loss).unwrap();

        // out = inp; dL/dout = out / 2 = [0.5, 1, 1.5, 2]; bias grad is the sum.
        let gb = tape.read_grad(b).unwrap().unwrap();
        assert_approx(&gb, &[5.0], 1e-3);
    }
}