use std::collections::HashMap;
use super::error::TensorError;
use super::tensor::{GpuTensor, SavedTensor, TensorId};
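/// Backward rules recorded on the autograd tape.
///
/// Each variant names the tensor ids that produced an output, plus whatever
/// saved data (inputs, outputs, masks, shapes) the corresponding backward
/// computation needs.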
#[derive(Debug, Clone)]
pub enum GradFn {
Add {
lhs: TensorId,
rhs: TensorId,
},
Sub {
lhs: TensorId,
rhs: TensorId,
},
Mul {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Div {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Neg {
input: TensorId,
},
Abs {
input: TensorId,
input_data: SavedTensor,
},
Relu {
input: TensorId,
mask: SavedTensor,
},
LeakyRelu {
input: TensorId,
alpha: f64,
input_data: SavedTensor,
},
Sigmoid {
input: TensorId,
output: SavedTensor,
},
Tanh {
input: TensorId,
output: SavedTensor,
},
Gelu {
input: TensorId,
input_data: SavedTensor,
},
Silu {
input: TensorId,
input_data: SavedTensor,
},
Exp {
input: TensorId,
output: SavedTensor,
},
Log {
input: TensorId,
input_data: SavedTensor,
},
Sqrt {
input: TensorId,
output: SavedTensor,
},
Pow {
input: TensorId,
exponent: f64,
input_data: SavedTensor,
},
MatMul {
lhs: TensorId,
rhs: TensorId,
lhs_data: SavedTensor,
rhs_data: SavedTensor,
},
Softmax {
input: TensorId,
output: SavedTensor,
dim: usize,
},
LogSoftmax {
input: TensorId,
output: SavedTensor,
dim: usize,
},
Sum {
input: TensorId,
input_shape: Vec<usize>,
},
Mean {
input: TensorId,
input_shape: Vec<usize>,
},
Max {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
Min {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
BatchNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
num_channels: usize,
},
LayerNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
norm_shape: Vec<usize>,
},
CrossEntropy {
input: TensorId,
probs: SavedTensor,
targets: Vec<usize>,
},
MseLoss {
input: TensorId,
target: TensorId,
diff: SavedTensor,
},
L1Loss {
input: TensorId,
sign_diff: SavedTensor,
},
SmoothL1Loss {
input: TensorId,
beta: f64,
diff: SavedTensor,
},
NllLoss {
input: TensorId,
targets: Vec<usize>,
batch_size: usize,
},
Conv2d {
input: TensorId,
col_data: SavedTensor,
weight_data: SavedTensor,
input_shape: Vec<usize>,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
MaxPool2d {
input: TensorId,
indices: Vec<usize>,
input_shape: Vec<usize>,
},
AvgPool2d {
input: TensorId,
input_shape: Vec<usize>,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
GroupNorm {
input: TensorId,
normalized: SavedTensor,
std_inv: Vec<f64>,
gamma: Vec<f64>,
num_groups: usize,
},
}
use std::cell::Cell;
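// Thread-local flag controlling whether new operations are recorded for autograd.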
thread_local! {
static GRAD_ENABLED: Cell<bool> = const { Cell::new(true) };
}
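/// Returns whether gradient recording is currently enabled on this thread.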
#[must_use]
pub fn is_grad_enabled() -> bool {
GRAD_ENABLED.with(|c| c.get())
}
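/// RAII guard returned by [`no_grad`]; restores the previous grad mode on drop.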
pub struct NoGradGuard {
prev: bool,
}
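/// Disables gradient recording on the current thread until the returned guard is dropped.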
#[must_use]
pub fn no_grad() -> NoGradGuard {
let prev = GRAD_ENABLED.with(|c| c.replace(false));
NoGradGuard { prev }
}
impl Drop for NoGradGuard {
fn drop(&mut self) {
GRAD_ENABLED.with(|c| c.set(self.prev));
}
}
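/// A single recorded operation: the id of the tensor it produced and its backward rule.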
#[derive(Debug, Clone)]
pub struct TapeEntry {
pub output_id: TensorId,
pub grad_fn: GradFn,
}
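/// Append-only record of differentiable operations, replayed in reverse by
/// [`AutogradTape::backward`].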
#[derive(Debug, Clone, Default)]
pub struct AutogradTape {
entries: Vec<TapeEntry>,
}
impl AutogradTape {
#[must_use]
pub fn new() -> Self {
Self {
entries: Vec::new(),
}
}
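/// Records the backward rule for the tensor identified by `output_id`.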
pub fn record(&mut self, output_id: TensorId, grad_fn: GradFn) {
self.entries.push(TapeEntry { output_id, grad_fn });
}
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn clear(&mut self) {
self.entries.clear();
}
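/// Runs reverse-mode differentiation starting from `loss_id`.
///
/// The loss gradient is seeded with ones, tape entries are replayed in
/// reverse order, and the resulting gradients are accumulated into every
/// tensor in `tensors` that has `requires_grad` set.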
pub fn backward(
&self,
loss_id: TensorId,
tensors: &mut HashMap<TensorId, GpuTensor>,
) -> Result<(), TensorError> {
let loss_tensor = tensors
.get(&loss_id)
.ok_or_else(|| TensorError::AutogradError("loss tensor not found on tape".into()))?;
let loss_shape = loss_tensor.shape().to_vec();
let loss_numel = loss_tensor.numel();
let mut grads: HashMap<TensorId, Vec<f64>> = HashMap::new();
grads.insert(loss_id, vec![1.0; loss_numel]);
for entry in self.entries.iter().rev() {
let out_grad = match grads.get(&entry.output_id) {
Some(g) => g.clone(),
None => continue,
};
backward_one(&entry.grad_fn, &out_grad, &loss_shape, &mut grads)?;
}
for (tid, grad_data) in &grads {
if let Some(tensor) = tensors.get_mut(tid) {
if tensor.requires_grad() {
let grad_tensor = GpuTensor::from_parts(
tensor.shape().to_vec(),
tensor.dtype(),
tensor.device_id(),
grad_data.clone(),
false,
None,
);
tensor.accumulate_grad(&grad_tensor)?;
}
}
}
Ok(())
}
}
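/// Applies the backward rule for a single tape entry, accumulating input
/// gradients into `grads`.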
#[allow(clippy::too_many_lines)]
fn backward_one(
grad_fn: &GradFn,
out_grad: &[f64],
_loss_shape: &[usize],
grads: &mut HashMap<TensorId, Vec<f64>>,
) -> Result<(), TensorError> {
match grad_fn {
GradFn::Add { lhs, rhs } => {
accumulate(grads, *lhs, out_grad);
accumulate(grads, *rhs, out_grad);
}
GradFn::Sub { lhs, rhs } => {
accumulate(grads, *lhs, out_grad);
let neg: Vec<f64> = out_grad.iter().map(|&g| -g).collect();
accumulate(grads, *rhs, &neg);
}
GradFn::Mul {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(rhs_data.data.iter())
.map(|(&g, &b)| g * b)
.collect();
accumulate(grads, *lhs, &da);
let db: Vec<f64> = out_grad
.iter()
.zip(lhs_data.data.iter())
.map(|(&g, &a)| g * a)
.collect();
accumulate(grads, *rhs, &db);
}
GradFn::Div {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(rhs_data.data.iter())
.map(|(&g, &b)| if b.abs() > 1e-30 { g / b } else { 0.0 })
.collect();
accumulate(grads, *lhs, &da);
let db: Vec<f64> = out_grad
.iter()
.zip(lhs_data.data.iter())
.zip(rhs_data.data.iter())
.map(|((&g, &a), &b)| {
if b.abs() > 1e-30 {
-g * a / (b * b)
} else {
0.0
}
})
.collect();
accumulate(grads, *rhs, &db);
}
GradFn::Neg { input } => {
let neg: Vec<f64> = out_grad.iter().map(|&g| -g).collect();
accumulate(grads, *input, &neg);
}
GradFn::Abs { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
if x > 0.0 {
g
} else if x < 0.0 {
-g
} else {
0.0
}
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Relu { input, mask } => {
let da: Vec<f64> = out_grad
.iter()
.zip(mask.data.iter())
.map(|(&g, &m)| g * m)
.collect();
accumulate(grads, *input, &da);
}
GradFn::LeakyRelu {
input,
alpha,
input_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| if x > 0.0 { g } else { g * alpha })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sigmoid { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * z * (1.0 - z))
.collect();
accumulate(grads, *input, &da);
}
GradFn::Tanh { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * (1.0 - z * z))
.collect();
accumulate(grads, *input, &da);
}
GradFn::Gelu { input, input_data } => {
let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
let tanh_inner = inner.tanh();
let cdf = 0.5 * (1.0 + tanh_inner);
let pdf = sqrt_2_over_pi
* (1.0 + 3.0 * 0.044715 * x * x)
* (1.0 - tanh_inner * tanh_inner);
g * (cdf + 0.5 * x * pdf)
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Silu { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| {
let s = 1.0 / (1.0 + (-x).exp());
g * (s * (1.0 + x * (1.0 - s)))
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::Exp { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| g * z)
.collect();
accumulate(grads, *input, &da);
}
GradFn::Log { input, input_data } => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| if x.abs() > 1e-30 { g / x } else { 0.0 })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sqrt { input, output } => {
let da: Vec<f64> = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &z)| if z.abs() > 1e-30 { g / (2.0 * z) } else { 0.0 })
.collect();
accumulate(grads, *input, &da);
}
GradFn::Pow {
input,
exponent,
input_data,
} => {
let da: Vec<f64> = out_grad
.iter()
.zip(input_data.data.iter())
.map(|(&g, &x)| g * exponent * x.powf(exponent - 1.0))
.collect();
accumulate(grads, *input, &da);
}
GradFn::MatMul {
lhs,
rhs,
lhs_data,
rhs_data,
} => {
backward_matmul(out_grad, lhs_data, rhs_data, *lhs, *rhs, grads)?;
}
GradFn::Softmax {
input,
output,
dim: _,
} => {
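// Note: the saved `dim` is not used here; the softmax Jacobian-vector
// product is applied over the whole saved output.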
let n = output.data.len();
let dot: f64 = out_grad
.iter()
.zip(output.data.iter())
.map(|(&g, &s)| g * s)
.sum();
let da: Vec<f64> = (0..n)
.map(|i| output.data[i] * (out_grad[i] - dot))
.collect();
accumulate(grads, *input, &da);
}
GradFn::LogSoftmax {
input,
output,
dim: _,
} => {
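// Note: the saved `dim` is not used here; the rule is applied over the
// whole saved log-softmax output.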
let sum_dz: f64 = out_grad.iter().sum();
let da: Vec<f64> = output
.data
.iter()
.zip(out_grad.iter())
.map(|(&log_s, &g)| g - log_s.exp() * sum_dz)
.collect();
accumulate(grads, *input, &da);
}
GradFn::Sum { input, input_shape } => {
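// Assumes a full reduction to a scalar: the single upstream gradient value
// is broadcast to the input shape.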
let numel: usize = input_shape.iter().product();
let expanded = vec![out_grad.first().copied().unwrap_or(0.0); numel];
accumulate(grads, *input, &expanded);
}
GradFn::Mean { input, input_shape } => {
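// Assumes a full reduction to a scalar: the single upstream gradient value
// is broadcast and scaled by 1 / numel.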
let numel: usize = input_shape.iter().product();
let scale = if numel > 0 { 1.0 / numel as f64 } else { 0.0 };
let expanded = vec![out_grad.first().copied().unwrap_or(0.0) * scale; numel];
accumulate(grads, *input, &expanded);
}
GradFn::Max {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::Min {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::BatchNorm {
input,
normalized,
std_inv,
gamma,
num_channels,
} => {
backward_batch_norm(
out_grad,
normalized,
std_inv,
gamma,
*num_channels,
*input,
grads,
);
}
GradFn::LayerNorm {
input,
normalized,
std_inv,
gamma,
norm_shape,
} => {
backward_layer_norm(
out_grad, normalized, std_inv, gamma, norm_shape, *input, grads,
);
}
GradFn::CrossEntropy {
input,
probs,
targets,
} => {
backward_cross_entropy(out_grad, probs, targets, *input, grads);
}
GradFn::MseLoss {
input,
target: _,
diff,
} => {
let n = diff.data.len();
let scale = if n > 0 { 2.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&d, &g)| g * scale * d)
.collect();
accumulate(grads, *input, &da);
}
GradFn::L1Loss { input, sign_diff } => {
let n = sign_diff.data.len();
let scale = if n > 0 { 1.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = sign_diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&s, &g)| g * scale * s)
.collect();
accumulate(grads, *input, &da);
}
GradFn::SmoothL1Loss { input, beta, diff } => {
let n = diff.data.len();
let scale = if n > 0 { 1.0 / n as f64 } else { 0.0 };
let da: Vec<f64> = diff
.data
.iter()
.zip(out_grad.iter().cycle())
.map(|(&d, &g)| {
if d.abs() < *beta {
g * scale * d / beta
} else {
g * scale * d.signum()
}
})
.collect();
accumulate(grads, *input, &da);
}
GradFn::NllLoss {
input,
targets,
batch_size,
} => {
backward_nll_loss(out_grad, targets, *batch_size, *input, grads);
}
GradFn::Conv2d {
input,
col_data: _,
weight_data: _,
input_shape,
kernel_size: _,
stride: _,
padding: _,
} => {
let numel: usize = input_shape.iter().product();
// Input gradient is currently zero-filled; the saved col/weight buffers are
// not used by this backward rule yet.
let da = vec![0.0; numel];
accumulate(grads, *input, &da);
}
GradFn::MaxPool2d {
input,
indices,
input_shape,
} => {
let numel: usize = input_shape.iter().product();
let mut da = vec![0.0; numel];
for (i, &idx) in indices.iter().enumerate() {
if idx < numel && i < out_grad.len() {
da[idx] += out_grad[i];
}
}
accumulate(grads, *input, &da);
}
GradFn::AvgPool2d {
input,
input_shape,
kernel_size,
stride,
padding: _,
} => {
backward_avg_pool2d(out_grad, input_shape, *kernel_size, *stride, *input, grads);
}
GradFn::GroupNorm {
input,
normalized,
std_inv,
gamma,
num_groups,
} => {
backward_group_norm(
out_grad,
normalized,
std_inv,
gamma,
*num_groups,
*input,
grads,
);
}
}
Ok(())
}
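/// Adds `grad` element-wise into the gradient buffer for `id`, creating a
/// zero buffer first if none exists. Mismatched lengths are silently skipped.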
fn accumulate(grads: &mut HashMap<TensorId, Vec<f64>>, id: TensorId, grad: &[f64]) {
let entry = grads.entry(id).or_insert_with(|| vec![0.0; grad.len()]);
if entry.len() == grad.len() {
for (a, b) in entry.iter_mut().zip(grad.iter()) {
*a += b;
}
}
}
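/// MatMul backward for 2D operands: dA = dC * B^T and dB = A^T * dC.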
fn backward_matmul(
out_grad: &[f64],
lhs_data: &SavedTensor,
rhs_data: &SavedTensor,
lhs_id: TensorId,
rhs_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) -> Result<(), TensorError> {
let lhs_shape = &lhs_data.shape;
let rhs_shape = &rhs_data.shape;
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
return Err(TensorError::AutogradError(
"matmul backward requires 2D tensors".into(),
));
}
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let mut da = vec![0.0; m * k];
for i in 0..m {
for j in 0..k {
let mut sum = 0.0;
for l in 0..n {
sum += out_grad[i * n + l] * rhs_data.data[j * n + l];
}
da[i * k + j] = sum;
}
}
accumulate(grads, lhs_id, &da);
let mut db = vec![0.0; k * n];
for i in 0..k {
for j in 0..n {
let mut sum = 0.0;
for l in 0..m {
sum += lhs_data.data[l * k + i] * out_grad[l * n + j];
}
db[i * n + j] = sum;
}
}
accumulate(grads, rhs_id, &db);
Ok(())
}
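/// Cross-entropy backward: (probs - one_hot(target)), scaled by the upstream
/// gradient divided by the batch size.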
fn backward_cross_entropy(
out_grad: &[f64],
probs: &SavedTensor,
targets: &[usize],
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let batch_size = targets.len();
let num_classes = match probs.data.len().checked_div(batch_size) {
Some(n) => n,
None => return,
};
let mut da = probs.data.clone();
for (b, &t) in targets.iter().enumerate() {
if t < num_classes {
da[b * num_classes + t] -= 1.0;
}
}
let scale = if batch_size > 0 {
out_grad.first().copied().unwrap_or(1.0) / batch_size as f64
} else {
0.0
};
for v in &mut da {
*v *= scale;
}
accumulate(grads, input_id, &da);
}
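/// NLL-loss backward: -upstream / batch_size at each target class, zero
/// elsewhere. The class count is inferred from the largest target index.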
fn backward_nll_loss(
out_grad: &[f64],
targets: &[usize],
batch_size: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let num_classes = if batch_size > 0 && !targets.is_empty() {
targets.iter().copied().max().unwrap_or(0) + 1
} else {
return;
};
let mut da = vec![0.0; batch_size * num_classes];
let scale = if batch_size > 0 {
out_grad.first().copied().unwrap_or(1.0) / batch_size as f64
} else {
0.0
};
// Only the first `batch_size` targets are used, so the write below stays in bounds.
for (b, &t) in targets.iter().enumerate().take(batch_size) {
if t < num_classes {
da[b * num_classes + t] = -scale;
}
}
accumulate(grads, input_id, &da);
}
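/// Batch-norm backward over the saved normalized activations, treating the
/// data as contiguous per-channel blocks:
/// dx = gamma * std_inv * (dz - mean(dz) - x_hat * mean(dz * x_hat)).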
#[allow(clippy::too_many_arguments)]
fn backward_batch_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
num_channels: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
let per_channel = match total.checked_div(num_channels) {
Some(n) => n,
None => return,
};
let mut da = vec![0.0; total];
for c in 0..num_channels {
let g = gamma.get(c).copied().unwrap_or(1.0);
let inv = std_inv.get(c).copied().unwrap_or(1.0);
let start = c * per_channel;
let end = start + per_channel;
let mean_dz: f64 = out_grad[start..end].iter().sum::<f64>() / per_channel as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(normalized.data[start..end].iter())
.map(|(&dz, &xh)| dz * xh)
.sum::<f64>()
/ per_channel as f64;
for i in start..end {
da[i] = g * inv * (out_grad[i] - mean_dz - normalized.data[i] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
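/// Layer-norm backward, applied independently to each normalized instance.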
#[allow(clippy::too_many_arguments)]
fn backward_layer_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
norm_shape: &[usize],
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
let norm_size: usize = norm_shape.iter().product();
if norm_size == 0 {
return;
}
let num_instances = total / norm_size;
let mut da = vec![0.0; total];
for inst in 0..num_instances {
let start = inst * norm_size;
let end = start + norm_size;
let inv = std_inv.get(inst).copied().unwrap_or(1.0);
let mean_dz: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.map(|(&dz, &g)| dz * g)
.sum::<f64>()
/ norm_size as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.zip(normalized.data[start..end].iter())
.map(|((&dz, &g), &xh)| dz * g * xh)
.sum::<f64>()
/ norm_size as f64;
for (i, idx) in (start..end).enumerate() {
// Guard against an empty gamma: `i % 0` would panic.
let g = if gamma.is_empty() { 1.0 } else { gamma[i % gamma.len()] };
da[idx] = inv * (g * out_grad[idx] - mean_dz - normalized.data[idx] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
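/// Average-pool backward: each output gradient is spread uniformly over its
/// pooling window. Padding is not taken into account here.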
fn backward_avg_pool2d(
out_grad: &[f64],
input_shape: &[usize],
kernel_size: (usize, usize),
stride: (usize, usize),
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
if input_shape.len() != 4 {
return;
}
let (n, c, h, w) = (
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
);
let numel: usize = input_shape.iter().product();
let out_h = (h - kernel_size.0) / stride.0 + 1;
let out_w = (w - kernel_size.1) / stride.1 + 1;
let pool_size = (kernel_size.0 * kernel_size.1) as f64;
let mut da = vec![0.0; numel];
for batch in 0..n {
for ch in 0..c {
for oh in 0..out_h {
for ow in 0..out_w {
let out_idx = ((batch * c + ch) * out_h + oh) * out_w + ow;
let g = if out_idx < out_grad.len() {
out_grad[out_idx]
} else {
0.0
};
let val = g / pool_size;
for kh in 0..kernel_size.0 {
for kw in 0..kernel_size.1 {
let ih = oh * stride.0 + kh;
let iw = ow * stride.1 + kw;
let in_idx = ((batch * c + ch) * h + ih) * w + iw;
if in_idx < numel {
da[in_idx] += val;
}
}
}
}
}
}
}
accumulate(grads, input_id, &da);
}
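/// Group-norm backward, applied independently to each contiguous group of elements.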
#[allow(clippy::too_many_arguments)]
fn backward_group_norm(
out_grad: &[f64],
normalized: &SavedTensor,
std_inv: &[f64],
gamma: &[f64],
num_groups: usize,
input_id: TensorId,
grads: &mut HashMap<TensorId, Vec<f64>>,
) {
let total = normalized.data.len();
if num_groups == 0 {
return;
}
let group_size = total / num_groups;
let mut da = vec![0.0; total];
for g in 0..num_groups {
let start = g * group_size;
let end = start + group_size;
let inv = std_inv.get(g).copied().unwrap_or(1.0);
let mean_dz: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.map(|(&dz, &gm)| dz * gm)
.sum::<f64>()
/ group_size as f64;
let mean_dz_xhat: f64 = out_grad[start..end]
.iter()
.zip(gamma.iter().cycle())
.zip(normalized.data[start..end].iter())
.map(|((&dz, &gm), &xh)| dz * gm * xh)
.sum::<f64>()
/ group_size as f64;
for (i, idx) in (start..end).enumerate() {
// Guard against an empty gamma: `i % 0` would panic.
let gm = if gamma.is_empty() { 1.0 } else { gamma[i % gamma.len()] };
da[idx] = inv * (gm * out_grad[idx] - mean_dz - normalized.data[idx] * mean_dz_xhat);
}
}
accumulate(grads, input_id, &da);
}
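/// Policy controlling which layers [`GradientCheckpointing::should_checkpoint`] selects.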
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CheckpointStrategy {
#[default]
StoreAll,
RecomputeAll,
EveryN(usize),
}
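/// Gradient-checkpointing configuration. `Default` is disabled with
/// `StoreAll`; [`GradientCheckpointing::new`] enables the given strategy.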
#[derive(Debug, Clone)]
pub struct GradientCheckpointing {
pub strategy: CheckpointStrategy,
pub enabled: bool,
}
impl GradientCheckpointing {
#[must_use]
pub fn new(strategy: CheckpointStrategy) -> Self {
Self {
strategy,
enabled: true,
}
}
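/// Returns whether layer `layer_idx` should be checkpointed under the
/// configured strategy; always false when checkpointing is disabled.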
#[must_use]
pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
if !self.enabled {
return false;
}
match self.strategy {
CheckpointStrategy::StoreAll => false,
CheckpointStrategy::RecomputeAll => true,
CheckpointStrategy::EveryN(n) => n > 0 && layer_idx % n == 0,
}
}
}
impl Default for GradientCheckpointing {
fn default() -> Self {
Self {
strategy: CheckpointStrategy::StoreAll,
enabled: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_backend::dtype::TensorDtype;
#[test]
fn test_no_grad_context() {
assert!(is_grad_enabled());
{
let _guard = no_grad();
assert!(!is_grad_enabled());
}
assert!(is_grad_enabled());
}
#[test]
fn test_tape_record() {
let mut tape = AutogradTape::new();
assert!(tape.is_empty());
let id_a = TensorId::new();
let id_b = TensorId::new();
let id_c = TensorId::new();
tape.record(
id_c,
GradFn::Add {
lhs: id_a,
rhs: id_b,
},
);
assert_eq!(tape.len(), 1);
tape.clear();
assert!(tape.is_empty());
}
#[test]
fn test_backward_add() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[2.0, 3.0], &[2], 0).unwrap();
a.set_requires_grad(true);
let mut b = GpuTensor::from_host_f64(&[4.0, 5.0], &[2], 0).unwrap();
b.set_requires_grad(true);
let c_data: Vec<f64> = a
.host_data()
.iter()
.zip(b.host_data().iter())
.map(|(&x, &y)| x + y)
.collect();
let c = GpuTensor::from_parts(vec![2], TensorDtype::Float32, 0, c_data, false, None);
tape.record(
c.id(),
GradFn::Add {
lhs: a.id(),
rhs: b.id(),
},
);
let loss_val: f64 = c.host_data().iter().sum();
let loss = GpuTensor::from_parts(
vec![1],
TensorDtype::Float32,
0,
vec![loss_val],
false,
None,
);
tape.record(
loss.id(),
GradFn::Sum {
input: c.id(),
input_shape: vec![2],
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
tensors.insert(a_id, a);
tensors.insert(b.id(), b);
tensors.insert(c.id(), c);
tensors.insert(loss.id(), loss.clone());
tape.backward(loss.id(), &mut tensors).unwrap();
// Look the tensor up by its own id rather than a hard-coded TensorId, which
// is not stable when other tests allocate ids.
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 1.0).abs() < 1e-10);
assert!((g.host_data()[1] - 1.0).abs() < 1e-10);
}
}
#[test]
fn test_backward_mul() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[3.0], &[1], 0).unwrap();
a.set_requires_grad(true);
let mut b = GpuTensor::from_host_f64(&[5.0], &[1], 0).unwrap();
b.set_requires_grad(true);
let c_val = a.host_data()[0] * b.host_data()[0];
let c = GpuTensor::from_parts(vec![1], TensorDtype::Float32, 0, vec![c_val], false, None);
tape.record(
c.id(),
GradFn::Mul {
lhs: a.id(),
rhs: b.id(),
lhs_data: SavedTensor::from_tensor(&a),
rhs_data: SavedTensor::from_tensor(&b),
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
let b_id = b.id();
let c_id = c.id();
tensors.insert(a_id, a);
tensors.insert(b_id, b);
tensors.insert(c_id, c);
tape.backward(c_id, &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
let b_grad = tensors.get(&b_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 5.0).abs() < 1e-10);
}
if let Some(g) = b_grad {
assert!((g.host_data()[0] - 3.0).abs() < 1e-10);
}
}
#[test]
fn test_checkpoint_strategy() {
let cp = GradientCheckpointing::new(CheckpointStrategy::EveryN(3));
assert!(cp.should_checkpoint(0));
assert!(!cp.should_checkpoint(1));
assert!(!cp.should_checkpoint(2));
assert!(cp.should_checkpoint(3));
let cp_all = GradientCheckpointing::new(CheckpointStrategy::RecomputeAll);
assert!(cp_all.should_checkpoint(0));
assert!(cp_all.should_checkpoint(99));
let cp_none = GradientCheckpointing::new(CheckpointStrategy::StoreAll);
assert!(!cp_none.should_checkpoint(0));
}
#[test]
fn test_backward_sigmoid() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[0.0], &[1], 0).unwrap();
a.set_requires_grad(true);
let sig_val = 1.0 / (1.0 + (-a.host_data()[0]).exp());
let c = GpuTensor::from_parts(vec![1], TensorDtype::Float32, 0, vec![sig_val], false, None);
tape.record(
c.id(),
GradFn::Sigmoid {
input: a.id(),
output: SavedTensor::from_tensor(&c),
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
let c_id = c.id();
tensors.insert(a_id, a);
tensors.insert(c_id, c);
tape.backward(c_id, &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 0.25).abs() < 1e-10);
}
}
#[test]
fn test_backward_relu() {
let mut tape = AutogradTape::new();
let mut a = GpuTensor::from_host_f64(&[-1.0, 2.0, 0.0], &[3], 0).unwrap();
a.set_requires_grad(true);
let relu_data: Vec<f64> = a.host_data().iter().map(|&x| x.max(0.0)).collect();
let mask_data: Vec<f64> = a
.host_data()
.iter()
.map(|&x| if x > 0.0 { 1.0 } else { 0.0 })
.collect();
let c = GpuTensor::from_parts(vec![3], TensorDtype::Float32, 0, relu_data, false, None);
let mask = SavedTensor {
id: TensorId::new(),
shape: vec![3],
dtype: TensorDtype::Float32,
data: mask_data,
};
tape.record(
c.id(),
GradFn::Relu {
input: a.id(),
mask,
},
);
let loss_val: f64 = c.host_data().iter().sum();
let loss = GpuTensor::from_parts(
vec![1],
TensorDtype::Float32,
0,
vec![loss_val],
false,
None,
);
tape.record(
loss.id(),
GradFn::Sum {
input: c.id(),
input_shape: vec![3],
},
);
let mut tensors = HashMap::new();
let a_id = a.id();
tensors.insert(a_id, a);
tensors.insert(c.id(), c);
tensors.insert(loss.id(), loss.clone());
tape.backward(loss.id(), &mut tensors).unwrap();
let a_grad = tensors.get(&a_id).and_then(|t| t.grad());
if let Some(g) = a_grad {
assert!((g.host_data()[0] - 0.0).abs() < 1e-10);
assert!((g.host_data()[1] - 1.0).abs() < 1e-10);
assert!((g.host_data()[2] - 0.0).abs() < 1e-10);
}
}
}