use std::collections::HashMap;
use super::error::TensorError;
use super::tensor::{GpuTensor, SavedTensor, TensorId};
#[derive(Debug, Clone)]
pub enum GradFn {
Add {
lhs: TensorId,
rhs: TensorId,
},
Sub {
lhs: TensorId,
rhs: TensorId,
},
Mul {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Div {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Neg {
input: TensorId,
},
Abs {
input: TensorId,
input_data: SavedTensor,
},
Relu {
input: TensorId,
mask: SavedTensor,
},
LeakyRelu {
input: TensorId,
alpha: f64,
input_data: SavedTensor,
},
Sigmoid {
input: TensorId,
output: SavedTensor,
},
Tanh {
input: TensorId,
output: SavedTensor,
},
Gelu {
input: TensorId,
input_data: SavedTensor,
},
Silu {
input: TensorId,
input_data: SavedTensor,
},
Exp {
input: TensorId,
output: SavedTensor,
},
Log {
input: TensorId,
input_data: SavedTensor,
},
Sqrt {
input: TensorId,
output: SavedTensor,
},
Pow {
input: TensorId,
exponent: f64,
input_data: SavedTensor,
},
MatMul {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Softmax {
input: TensorId,
output: SavedTensor,
dim: usize,
},
LogSoftmax {
input: TensorId,
output: SavedTensor,
dim: usize,
},
Sum {
input: TensorId,
input_shape: Vec<usize>,
},
Mean {
input: TensorId,
input_shape: Vec<usize>,
},
Max {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
Min {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
BatchNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
num_channels: usize,
},
LayerNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
norm_shape: Vec<usize>,
},
CrossEntropy {
input: TensorId,
probs: SavedTensor,
targets: Vec<usize>,
},
MseLoss {
input: TensorId,
target: TensorId,
diff: SavedTensor,
},
L1Loss {
input: TensorId,
sign_diff: SavedTensor,
},
SmoothL1Loss {
input: TensorId,
beta: f64,
diff: SavedTensor,
},
NllLoss {
input: TensorId,
targets: Vec<usize>,
batch_size: usize,
},
Conv2d {
input: TensorId,
col_data: SavedTensor,
weight_data: SavedTensor,
input_shape: Vec<usize>,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
MaxPool2d {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
AvgPool2d {
input: TensorId,
input_shape: Vec<usize>,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
GroupNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
num_groups: usize,
},
}
use std::cell::Cell;
thread_local! {
static GRAD_ENABLED: Cell<bool> = const { Cell::new(true) };
}
#[must_use]
pub fn is_grad_enabled() -> bool {
GRAD_ENABLED.with(|c| c.get())
}
pub struct NoGradGuard {
prev: bool,
}
#[must_use]
pub fn no_grad() -> NoGradGuard {
let prev = GRAD_ENABLED.with(|c| c.replace(false));
NoGradGuard { prev }
}
impl Drop for NoGradGuard {
fn drop(&mut self) {
GRAD_ENABLED.with(|c| c.set(self.prev));
}
}
#[derive(Debug, Clone)]
pub struct TapeEntry {
pub output_id: TensorId,
pub grad_fn: GradFn,
}
#[derive(Debug, Clone, Default)]
pub struct AutogradTape {
entries: Vec<TapeEntry>,
}
impl AutogradTape {
#[must_use]
pub fn new() -> Self {
Self {
entries: Vec::new(),
}
}
pub fn record(&mut self, output_id: TensorId, grad_fn: GradFn) {
self.entries.push(TapeEntry { output_id, grad_fn });
}
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn clear(&mut self) {
self.entries.clear();
}
pub fn backward(
&self,
loss_id: TensorId,
tensors: &mut HashMap<TensorId, GpuTensor>,
) -> Result<(), TensorError> {
let loss_shape = tensors
.get(&loss_id)
.ok_or_else(|| TensorError::AutogradError("loss tensor not found on tape".into()))?
.shape()
.to_vec();
let loss_numel = tensors
.get(&loss_id)
.ok_or_else(|| TensorError::AutogradError("loss missing".into()))?
.numel();
let mut grads: HashMap<TensorId, Vec<f64>> = HashMap::new();
grads.insert(loss_id, vec![1.0; loss_numel]);
for entry in self.entries.iter().rev() {
let out_grad = match grads.get(&entry.output_id) {
Some(g) => g.clone(),
None => continue, };
backward_one(&entry.grad_fn, &out_grad, &loss_shape, &mut grads)?;
}
for (tid, grad_data) in &grads {
if let Some(tensor) = tensors.get_mut(tid) {
if tensor.requires_grad() {
let grad_tensor = GpuTensor::from_parts(
tensor.shape().to_vec(),
tensor.dtype(),
tensor.device_id(),
grad_data.clone(),
false,
None,
);
tensor.accumulate_grad(&grad_tensor)?;
}
}
}
Ok(())
}
}
#[allow(clippy::too_many_lines)]
fn backward_one(
grad_fn: &GradFn,
out_grad: &[f64],
_loss_shape: &[usize],
grads: &mut HashMap<TensorId, Vec<f64>>,
) -> Result<(), TensorError> {
match grad_fn {
GradFn::Add { lhs, rhs } => {
accumulate(grads, *lhs, out_grad);
accumulate(grads, *rhs, out_grad);
}
GradFn::Sub { lhs, rhs } => {
accumulate(grads, *lhs, out_grad);
let neg: Vec<f64> = out_grad.iter().map(|&g| -g).collect();
accumulate(grads, *rhs, &neg);
}
GradFn::Mul {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(rhs_data.data.iter())
.map(|(&g, &b)| g * b)
.collect();
accumulate(grads, *lhs, &da);
let db: Vec<f64> = out_grad
.iter()
.zip(lhs_data.data.iter())
.map(|(&g, &a)| g * a)
.collect();
accumulate(grads, *rhs, &db);
}
GradFn::Div {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(rhs_data.data.iter())
.map(|(&g, &b)| if b.abs() > 1e-30 { g / b } else { 0.0 })
.collect();
accumulate(grads, *lhs, &da);
let db: Vec<f64> = out_grad
.iter()
.zip(lhs_data.data.iter())
.zip(rhs_data.data.iter())
.map(|((&g, &a), &b)| {
if b.abs() > 1e-30 {
-g * a / (b * b)
} else {
0.0
}
})
.collect();
accumulate(grads, *rhs, &db);
}
GradFn::Neg { input } => {
let neg: Vec<f64> = out_grad.iter().map(|&g| -g).collect();
accumulate(grads, *input, &neg);
}
GradFn::Abs { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
if x > 0.0 {
g
} else if x < 0.0 {
-g
} else {
0.0
}
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Relu { input, mask } => {
let da: Vec<f64> = out_grad
.iter()
.zip(mask.data.iter())
.map(|(&g, &m)| g * m)
.collect();
accumulate(grads, *input, &da);
}
GradFn::LeakyRelu {
input,
alpha,
input_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| if x > 0.0 { g } else { g * alpha })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sigmoid { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * z * (1.0 - z))
.collect();
accumulate(grads, *input, &da);
}
GradFn::Tanh { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * (1.0 - z * z))
.collect();
accumulate(grads, *input, &da);
}
GradFn::Gelu { input, input_data } => {
let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
let tanh_inner = inner.tanh();
let cdf = 0.5 * (1.0 + tanh_inner);
let pdf = sqrt_2_over_pi
* (1.0 + 3.0 * 0.044715 * x * x)
* (1.0 - tanh_inner * tanh_inner);
g * (cdf + 0.5 * x * pdf)
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Silu { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
let s = 1.0 / (1.0 + (-x).exp());
g * (s * (1.0 + x * (1.0 - s)))
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Exp { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * z)
.collect();
accumulate(grads, *input, &da);
}
GradFn::Log { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| if x.abs() > 1e-30 { g / x } else { 0.0 })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sqrt { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| if z.abs() > 1e-30 { g / (2.0 * z) } else { 0.0 })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Pow {
input,
exponent,
input_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| g * exponent * x.powf(exponent - 1.0))
.collect();
accumulate(grads, *input, &da);
}
GradFn::MatMul {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
backward_matmul(out_grad, lhs_data, rhs_data, *lhs, *rhs, grads)?;
}
GradFn::Softmax {
input,
output,
dim: _,
} => {
let n = output.data.len();
let dot: f64 = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &s)| g * s)
.sum();
let da: Vec<f64> = (0..n)
.map(|i| output.data[i] * (out_grad[i] - dot))
.collect();
accumulate(grads, *input, &da);
}
GradFn::LogSoftmax {
input,
output,
dim: _,
} => {
let sum_dz: f64 = out_grad.iter().sum();
let da: Vec<f64> = output
.data
.iter()
.zip(out_grad.iter())
.map(|(&log_s, &g)| g - log_s.exp() * sum_dz)
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sum { input, input_shape } => {
let numel: usize = input_shape.iter().product();
let expanded = vec![out_grad.first().copied().unwrap_or(0.0); numel];
accumulate(grads, *input, &expanded);
}
GradFn::Mean { input, input_shape } => {
let numel: usize = input_shape.iter().product();
let scale = if numel > 0 { 1.0 / numel as f64 } else { 0.0 };
let expanded = vec![out_grad.first().copied().unwrap_or(0.0) * scale; numel];
accumulate(grads, *input, &expanded);
}
GradFn::Max {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::Min {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::BatchNorm {
input,
normalized,
std_inv,
gamma,
num_channels,
} => {
backward_batch_norm(
out_grad,
normalized,
std_inv,
gamma,
*num_channels,
*input,
grads,
);
}
GradFn::LayerNorm {
input,
normalized,
std_inv,
gamma,
norm_shape,
} => {
backward_layer_norm(
out_grad, normalized, std_inv, gamma, norm_shape, *input, grads,
);
}
GradFn::CrossEntropy {
input,
probs,
targets,
} => {
backward_cross_entropy(out_grad, probs, targets, *input, grads);
}
GradFn::MseLoss {
input,
target: _,
diff,
} => {
let n = diff.data.len();
let scale = if n > 0 { 2.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&d, &g)| g * scale * d)
.collect();
accumulate(grads, *input, &da);
}
GradFn::L1Loss { input, sign_diff } => {
let n = sign_diff.data.len();
let scale = if n > 0 { 1.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = sign_diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&s, &g)| g * scale * s)
.collect();
accumulate(grads, *input, &da);
}
GradFn::SmoothL1Loss { input, beta, diff } => {
let n = diff.data.len();
let scale = if n > 0 { 1.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&d, &g)| {
if d.abs() < *beta {
g * scale * d / beta
} else {
g * scale * d.signum()
}
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::NllLoss {
input,
targets,
batch_size,
} => {
backward_nll_loss(out_grad, targets, *batch_size, *input, grads);
}
GradFn::Conv2d {
input,
col_data,
weight_data,
input_shape,
kernel_size,
stride,
padding,
} => {
backward_conv2d(
out_grad,
col_data,
weight_data,
input_shape,
(*kernel_size, *stride, *padding),
*input,
grads,
)?;
}
GradFn::MaxPool2d {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::AvgPool2d {
input,
input_shape,
kernel_size,
stride,
padding: _,
} => {
backward_avg_pool2d(out_grad, input_shape, *kernel_size, *stride, *input, grads);
}
GradFn::GroupNorm {
input,
normalized,
std_inv,
gamma,
num_groups,
} => {
backward_group_norm(
out_grad,
normalized,
std_inv,
gamma,
*num_groups,
*input,
grads,
);
}
}
Ok(())
}
fn accumulate(grads: &mut HashMap<TensorId, Vec<f64>>, id: TensorId, grad: &[f64]) {
let entry = grads.entry(id).or_insert_with(|| vec![0.0; grad.len()]);
if entry.len() == grad.len() {
for (a, b) in entry.iter_mut().zip(grad.iter()) {
*a += b;
}
}
}
fn backward_matmul(
out_grad: &[f64],
lhs_data: &SavedTensor,
rhs_data: &SavedTensor,
lhs_id: TensorId,
rhs_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) -> Result<(), TensorError> {
let lhs_shape = &lhs_data.shape;
let rhs_shape = &rhs_data.shape;
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
return Err(TensorError::AutogradError(
"matmul backward requires 2D tensors".into(),
));
}
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let mut da = vec![0.0; m * k];
for i in 0..m {
for j in 0..k {
let mut sum = 0.0;
for l in 0..n {
sum += out_grad[i * n + l] * rhs_data.data[j * n + l];
}
da[i * k + j] = sum;
}
}
accumulate(grads, lhs_id, &da);
let mut db = vec![0.0; k * n];
for i in 0..k {
for j in 0..n {
let mut sum = 0.0;
for l in 0..m {
sum += lhs_data.data[l * k + i] * out_grad[l * n + j];
}
db[i * n + j] = sum;
}
}
accumulate(grads, rhs_id, &db);
Ok(())
}
fn backward_cross_entropy(
out_grad: &[f64],
probs: &SavedTensor,
targets: &[usize],
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let batch_size = targets.len();
let num_classes = match probs.data.len().checked_div(batch_size) {
Some(n) => n,
None => return,
};
let mut da = probs.data.clone();
for (b, &t) in targets.iter().enumerate() {
if t < num_classes {
da[b * num_classes + t] -= 1.0;
}
}
let scale = if batch_size > 0 {
out_grad.first().copied().unwrap_or(1.0) / batch_size as f64
} else {
0.0
};
for v in &mut da {
*v *= scale;
}
accumulate(grads, input_id, &da);
}
fn backward_nll_loss(
out_grad: &[f64],
targets: &[usize],
batch_size: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let num_classes = if batch_size > 0 && !targets.is_empty() {
targets.iter().copied().max().unwrap_or(0) + 1
} else {
return;
};
let mut da = vec![0.0; batch_size * num_classes];
let scale = if batch_size > 0 {
out_grad.first().copied().unwrap_or(1.0) / batch_size as f64
} else {
0.0
};
for (b, &t) in targets.iter().enumerate() {
if t < num_classes {
da[b * num_classes + t] = -scale;
}
}
accumulate(grads, input_id, &da);
}
#[allow(clippy::too_many_arguments)]
fn backward_batch_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
num_channels: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
let per_channel = match total.checked_div(num_channels) {
Some(n) => n,
None => return,
};
let mut da = vec![0.0; total];
for c in 0..num_channels {
let g = gamma.get(c).copied().unwrap_or(1.0);
let inv = std_inv.get(c).copied().unwrap_or(1.0);
let start = c * per_channel;
let end = start + per_channel;
let mean_dz: f64 = out_grad[start..end].iter().sum::<f64>() / per_channel as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(normalized.data[start..end].iter())
.map(|(&dz, &xh)| dz * xh)
.sum::<f64>()
/ per_channel as f64;
for i in start..end {
da[i] = g * inv * (out_grad[i] - mean_dz - normalized.data[i] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
#[allow(clippy::too_many_arguments)]
fn backward_layer_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
norm_shape: &[usize],
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
let norm_size: usize = norm_shape.iter().product();
if norm_size == 0 {
return;
}
let num_instances = total / norm_size;
let mut da = vec![0.0; total];
for inst in 0..num_instances {
let start = inst * norm_size;
let end = start + norm_size;
let inv = std_inv.get(inst).copied().unwrap_or(1.0);
let mean_dz: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.map(|(&dz, &g)| dz * g)
.sum::<f64>()
/ norm_size as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.zip(normalized.data[start..end].iter())
.map(|((&dz, &g), &xh)| dz * g * xh)
.sum::<f64>()
/ norm_size as f64;
for (i, idx) in (start..end).enumerate() {
let g = gamma.get(i % gamma.len()).copied().unwrap_or(1.0);
da[idx] = inv * (g * out_grad[idx] - mean_dz - normalized.data[idx] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
fn backward_avg_pool2d(
out_grad: &[f64],
input_shape: &[usize],
kernel_size: (usize, usize),
stride: (usize, usize),
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
if input_shape.len() != 4 {
return;
}
let (n, c, h, _w) = (
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
);
let numel: usize = input_shape.iter().product();
let out_h = (h - kernel_size.0) / stride.0 + 1;
let out_w = (input_shape[3] - kernel_size.1) / stride.1 + 1;
let pool_size = (kernel_size.0 * kernel_size.1) as f64;
let mut da = vec![0.0; numel];
for batch in 0..n {
for ch in 0..c {
for oh in 0..out_h {
for ow in 0..out_w {
let out_idx = ((batch * c + ch) * out_h + oh) * out_w + ow;
let g = if out_idx < out_grad.len() {
out_grad[out_idx]
} else {
0.0
};
let val = g / pool_size;
for kh in 0..kernel_size.0 {
for kw in 0..kernel_size.1 {
let ih = oh * stride.0 + kh;
let iw = ow * stride.1 + kw;
let in_idx =
((batch * c + ch) * input_shape[2] + ih) * input_shape[3] + iw;
if in_idx < numel {
da[in_idx] += val;
}
}
}
}
}
}
}
accumulate(grads, input_id, &da);
}
fn backward_conv2d(
out_grad: &[f64],
col_data: &SavedTensor,
weight_data: &SavedTensor,
input_shape: &[usize],
conv_params: ((usize, usize), (usize, usize), (usize, usize)),
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) -> Result<(), TensorError> {
if input_shape.len() != 4 {
return Err(TensorError::AutogradError(
"conv2d backward requires 4D input".into(),
));
}
let (kernel_size, stride, padding) = conv_params;
let (k_h, k_w) = kernel_size;
let (stride_h, stride_w) = stride;
let (pad_h, pad_w) = padding;
if stride_h == 0 || stride_w == 0 {
return Err(TensorError::AutogradError(
"conv2d backward requires non-zero stride".into(),
));
}
let n = input_shape[0];
let c_in = input_shape[1];
let h = input_shape[2];
let w = input_shape[3];
let col_rows = c_in
.checked_mul(k_h)
.and_then(|v| v.checked_mul(k_w))
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
let padded_h =
h.checked_add(pad_h.checked_mul(2).ok_or_else(|| {
TensorError::AutogradError("conv2d backward padding overflow".into())
})?)
.ok_or_else(|| TensorError::AutogradError("conv2d backward padding overflow".into()))?;
let padded_w =
w.checked_add(pad_w.checked_mul(2).ok_or_else(|| {
TensorError::AutogradError("conv2d backward padding overflow".into())
})?)
.ok_or_else(|| TensorError::AutogradError("conv2d backward padding overflow".into()))?;
if padded_h < k_h || padded_w < k_w {
return Err(TensorError::AutogradError(
"conv2d backward kernel larger than padded input".into(),
));
}
let out_h = (padded_h - k_h) / stride_h + 1;
let out_w = (padded_w - k_w) / stride_w + 1;
let col_cols = out_h
.checked_mul(out_w)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
let patch_count = n
.checked_mul(col_cols)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
if col_data.shape.len() != 3
|| col_data.shape[0] != n
|| col_data.shape[1] != col_rows
|| col_data.shape[2] != col_cols
{
return Err(TensorError::AutogradError(
"conv2d backward saved im2col shape mismatch".into(),
));
}
let expected_col_len = patch_count
.checked_mul(col_rows)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
if col_data.data.len() != expected_col_len {
return Err(TensorError::AutogradError(
"conv2d backward saved im2col data mismatch".into(),
));
}
let c_out = weight_data
.data
.len()
.checked_div(col_rows)
.ok_or_else(|| TensorError::AutogradError("conv2d backward invalid weight shape".into()))?;
if c_out == 0 || weight_data.data.len() != c_out * col_rows {
return Err(TensorError::AutogradError(
"conv2d backward invalid weight shape".into(),
));
}
let expected_out_grad = patch_count
.checked_mul(c_out)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
if out_grad.len() != expected_out_grad {
return Err(TensorError::AutogradError(
"conv2d backward output gradient shape mismatch".into(),
));
}
let numel = n
.checked_mul(c_in)
.and_then(|v| v.checked_mul(h))
.and_then(|v| v.checked_mul(w))
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
let max_h = pad_h
.checked_add(h)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
let max_w = pad_w
.checked_add(w)
.ok_or_else(|| TensorError::AutogradError("conv2d backward shape overflow".into()))?;
let mut d_col = vec![0.0; expected_col_len];
for batch in 0..n {
for patch in 0..col_cols {
let patch_idx = batch * col_cols + patch;
for kc in 0..col_rows {
let mut sum = 0.0;
for co in 0..c_out {
let out_idx = (batch * c_out + co) * col_cols + patch;
sum += out_grad[out_idx] * weight_data.data[co * col_rows + kc];
}
d_col[patch_idx * col_rows + kc] = sum;
}
}
}
let mut da = vec![0.0; numel];
for patch_idx in 0..patch_count {
let batch = patch_idx / col_cols;
let patch = patch_idx % col_cols;
let oh = patch / out_w;
let ow = patch % out_w;
for c in 0..c_in {
for kh_idx in 0..k_h {
for kw_idx in 0..k_w {
let ih_padded = oh * stride_h + kh_idx;
let iw_padded = ow * stride_w + kw_idx;
if ih_padded < pad_h
|| ih_padded >= max_h
|| iw_padded < pad_w
|| iw_padded >= max_w
{
continue;
}
let ih = ih_padded - pad_h;
let iw = iw_padded - pad_w;
let kc = c * k_h * k_w + kh_idx * k_w + kw_idx;
let input_idx = ((batch * c_in + c) * h + ih) * w + iw;
da[input_idx] += d_col[patch_idx * col_rows + kc];
}
}
}
}
accumulate(grads, input_id, &da);
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn backward_group_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
num_groups: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
if num_groups == 0 {
return;
}
let group_size = total / num_groups;
let mut da = vec![0.0; total];
for g in 0..num_groups {
let start = g * group_size;
let end = start + group_size;
let inv = std_inv.get(g).copied().unwrap_or(1.0);
let mean_dz: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.map(|(&dz, &gm)| dz * gm)
.sum::<f64>()
/ group_size as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.zip(normalized.data[start..end].iter())
.map(|((&dz, &gm), &xh)| dz * gm * xh)
.sum::<f64>()
/ group_size as f64;
for (i, idx) in (start..end).enumerate() {
let gm = gamma.get(i % gamma.len()).copied().unwrap_or(1.0);
da[idx] = inv * (gm * out_grad[idx] - mean_dz - normalized.data[idx] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointStrategy {
#[default]
StoreAll,
RecomputeAll,
EveryN(usize),
}
#[derive(Debug, Clone)]
pub struct GradientCheckpointing {
pub strategy: CheckpointStrategy,
pub enabled: bool,
}
impl GradientCheckpointing {
#[must_use]
pub fn new(strategy: CheckpointStrategy) -> Self {
Self {
strategy,
enabled: true,
}
}
#[must_use]
pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
if !self.enabled {
return false;
}
match self.strategy {
CheckpointStrategy::StoreAll => false,
CheckpointStrategy::RecomputeAll => true,
CheckpointStrategy::EveryN(n) => n > 0 && layer_idx % n == 0,
}
}
}
impl Default for GradientCheckpointing {
fn default() -> Self {
Self {
strategy: CheckpointStrategy::StoreAll,
enabled: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_backend::dtype::TensorDtype;
use crate::tensor_backend::ops::conv2d;
#[test]
fn test_no_grad_context() {
assert!(is_grad_enabled());
{
let _guard = no_grad();
assert!(!is_grad_enabled());
}
assert!(is_grad_enabled());
}
#[test]
fn test_tape_record() {
let mut tape = AutogradTape::new();
assert!(tape.is_empty());
let id_a = TensorId::new();
let id_b = TensorId::new();
let id_c = TensorId::new();
tape.record(
id_c,
GradFn::Add {
lhs: id_a,
rhs: id_b,
},
);
assert_eq!(tape.len(), 1);
tape.clear();
assert!(tape.is_empty());
}
#[test]
fn test_backward_add() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[2.0, 3.0], &[2], 0).unwrap();
a.set_requires_grad(true);
let mut b = GpuTensor::from_host_f64(&[4.0, 5.0], &[2], 0).unwrap();
b.set_requires_grad(true);
let c_data: Vec<f64> = a
.host_data()
.iter()
.zip(b.host_data().iter())
.map(|(&x, &y)| x + y)
.collect();
let c = GpuTensor::from_parts(vec![2], TensorDtype::Float32, 0, c_data, false, None);
tape.record(
c.id(),
GradFn::Add {
lhs: a.id(),
rhs: b.id(),
},
);
let loss_val: f64 = c.host_data().iter().sum();
let loss = GpuTensor::from_parts(
vec![1],
TensorDtype::Float32,
0,
vec![loss_val],
false,
None,
);
tape.record(
loss.id(),
GradFn::Sum {
input: c.id(),
input_shape: vec![2],
},
);
let mut tensors = HashMap::new();
tensors.insert(a.id(), a);
tensors.insert(b.id(), b);
tensors.insert(c.id(), c);
tensors.insert(loss.id(), loss.clone());
tape.backward(loss.id(), &mut tensors).unwrap();
let a_grad = tensors.get(&TensorId(1)).map(|t| t.grad());
if let Some(Some(g)) = a_grad {
assert!((g.host_data()[0] - 1.0).abs() < 1e-10);
assert!((g.host_data()[1] - 1.0).abs() < 1e-10);
}
}
#[test]
fn test_backward_mul() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[3.0], &[1], 0).unwrap();
a.set_requires_grad(true);
let mut b = GpuTensor::from_host_f64(&[5.0], &[1], 0).unwrap();
b.set_requires_grad(true);
let c_val = a.host_data()[0] * b.host_data()[0];
let c = GpuTensor::from_parts(vec![1], TensorDtype::Float32, 0, vec![c_val], false, None);
tape.record(
c.id(),
GradFn::Mul {
lhs: a.id(),
rhs: b.id(),
lhs_data: SavedTensor::from_tensor(&a),
rhs_data: SavedTensor::from_tensor(&b),
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
let b_id = b.id();
let c_id = c.id();
tensors.insert(a_id, a);
tensors.insert(b_id, b);
tensors.insert(c_id, c);
tape.backward(c_id, &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
let b_grad = tensors.get(&b_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 5.0).abs() < 1e-10);
}
if let Some(g) = b_grad {
assert!((g.host_data()[0] - 3.0).abs() < 1e-10);
}
}
#[test]
fn test_checkpoint_strategy() {
let cp = GradientCheckpointing::new(CheckpointStrategy::EveryN(3));
assert!(cp.should_checkpoint(0));
assert!(!cp.should_checkpoint(1));
assert!(!cp.should_checkpoint(2));
assert!(cp.should_checkpoint(3));
let cp_all = GradientCheckpointing::new(CheckpointStrategy::RecomputeAll);
assert!(cp_all.should_checkpoint(0));
assert!(cp_all.should_checkpoint(99));
let cp_none = GradientCheckpointing::new(CheckpointStrategy::StoreAll);
assert!(!cp_none.should_checkpoint(0));
}
#[test]
fn test_backward_sigmoid() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[0.0], &[1], 0).unwrap();
a.set_requires_grad(true);
let sig_val = 1.0 / (1.0 + (-a.host_data()[0]).exp());
let c = GpuTensor::from_parts(vec![1], TensorDtype::Float32, 0, vec![sig_val], false, None);
tape.record(
c.id(),
GradFn::Sigmoid {
input: a.id(),
output: SavedTensor::from_tensor(&c),
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
let c_id = c.id();
tensors.insert(a_id, a);
tensors.insert(c_id, c);
tape.backward(c_id, &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 0.25).abs() < 1e-10);
}
}
#[test]
fn test_backward_relu() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[-1.0, 2.0, 0.0], &[3], 0).unwrap();
a.set_requires_grad(true);
let relu_data: Vec<f64> = a.host_data().iter().map(|&x| x.max(0.0)).collect();
let mask_data: Vec<f64> = a
.host_data()
.iter()
.map(|&x| if x > 0.0 { 1.0 } else { 0.0 })
.collect();
let c = GpuTensor::from_parts(vec![3], TensorDtype::Float32, 0, relu_data, false, None);
let mask = SavedTensor {
id: TensorId::new(),
shape: vec![3],
dtype: TensorDtype::Float32,
data: mask_data,
};
tape.record(
c.id(),
GradFn::Relu {
input: a.id(),
mask,
},
);
let loss_val: f64 = c.host_data().iter().sum();
let loss = GpuTensor::from_parts(
vec![1],
TensorDtype::Float32,
0,
vec![loss_val],
false,
None,
);
tape.record(
loss.id(),
GradFn::Sum {
input: c.id(),
input_shape: vec![3],
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
tensors.insert(a_id, a);
tensors.insert(c.id(), c);
tensors.insert(loss.id(), loss.clone());
tape.backward(loss.id(), &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 0.0).abs() < 1e-10); assert!((g.host_data()[1] - 1.0).abs() < 1e-10); assert!((g.host_data()[2] - 0.0).abs() < 1e-10); }
}
#[test]
fn conv2d_backward_col2im_correctness() {
let mut tape = AutogradTape::new();
let mut input = GpuTensor::from_host_f64(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
&[1, 1, 3, 3],
0,
)
.unwrap();
input.set_requires_grad(true);
let weight = GpuTensor::from_host_f64(&[1.0, 1.0, 1.0, 1.0], &[1, 1, 2, 2], 0).unwrap();
let output = conv2d(&input, &weight, None, (1, 1), (0, 0), Some(&mut tape)).unwrap();
let input_id = input.id();
let output_id = output.id();
let mut tensors = HashMap::new();
tensors.insert(input_id, input);
tensors.insert(output_id, output);
tape.backward(output_id, &mut tensors).unwrap();
let input_grad = tensors
.get(&input_id)
.and_then(GpuTensor::grad)
.map(|grad| grad.host_data().to_vec())
.unwrap();
assert_eq!(
input_grad,
vec![1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]
);
}
}