use super::tensor::Tensor;
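/// One node in the reverse-mode autodiff graph: given the gradient flowing
/// into an op's output, `backward` returns one gradient per input, in the
/// order the inputs were given to the forward op.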
pub trait GradFn: Send + Sync {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
fn name(&self) -> &'static str;
}
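// `maybe_reduce_grad` is used by every binary op below but was not defined in
// this file. A minimal sketch follows, assuming NumPy-style broadcasting in
// the forward pass: it sums the incoming gradient over every dimension that
// was broadcast, so the result matches the original operand's shape.
pub(crate) fn maybe_reduce_grad(grad: &Tensor, target_shape: &[usize]) -> Tensor {
    let grad_shape: Vec<usize> = grad.shape().to_vec();
    if grad_shape.as_slice() == target_shape {
        return Tensor::new(&grad.data().to_vec(), target_shape);
    }
    let ndim = grad_shape.len();
    // Left-pad the target shape with 1s so both shapes have the same rank.
    let mut padded = vec![1usize; ndim - target_shape.len()];
    padded.extend_from_slice(target_shape);
    // Row-major strides for the gradient shape and the padded target shape.
    let mut g_strides = vec![1usize; ndim];
    let mut t_strides = vec![1usize; ndim];
    for d in (0..ndim.saturating_sub(1)).rev() {
        g_strides[d] = g_strides[d + 1] * grad_shape[d + 1];
        t_strides[d] = t_strides[d + 1] * padded[d + 1];
    }
    // Accumulate each gradient element into the target position it broadcast
    // from; size-1 target dimensions collapse to index 0.
    let numel: usize = padded.iter().product();
    let mut out = vec![0.0f32; numel];
    for (flat, &g) in grad.data().iter().enumerate() {
        let mut out_idx = 0;
        for d in 0..ndim {
            let coord = (flat / g_strides[d]) % grad_shape[d];
            if padded[d] != 1 {
                out_idx += coord * t_strides[d];
            }
        }
        out[out_idx] += g;
    }
    Tensor::new(&out, target_shape)
}
/// d(x + y)/dx = 1 and d(x + y)/dy = 1: the incoming gradient passes through
/// unchanged, then is summed over any broadcast dimensions.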
pub(crate) struct AddBackward {
pub(crate) x_shape: Vec<usize>,
pub(crate) y_shape: Vec<usize>,
}
impl GradFn for AddBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_x = maybe_reduce_grad(grad_output, &self.x_shape);
let grad_y = maybe_reduce_grad(grad_output, &self.y_shape);
vec![grad_x, grad_y]
}
fn name(&self) -> &'static str {
"AddBackward"
}
}
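/// d(x - y)/dx = 1 and d(x - y)/dy = -1: pass the gradient through to x and
/// negate it for y, reducing over broadcast dimensions as needed.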
pub(crate) struct SubBackward {
pub(crate) x_shape: Vec<usize>,
pub(crate) y_shape: Vec<usize>,
}
impl GradFn for SubBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_x = maybe_reduce_grad(grad_output, &self.x_shape);
let grad_y_data: Vec<f32> = grad_output.data().iter().map(|&g| -g).collect();
let grad_y_full = Tensor::new(&grad_y_data, grad_output.shape());
let grad_y = maybe_reduce_grad(&grad_y_full, &self.y_shape);
vec![grad_x, grad_y]
}
fn name(&self) -> &'static str {
"SubBackward"
}
}
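/// Product rule: d(x * y)/dx = y and d(x * y)/dy = x, so both forward
/// operands must be saved for the backward pass.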
pub(crate) struct MulBackward {
pub(crate) x: Tensor,
pub(crate) y: Tensor,
}
impl GradFn for MulBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_x_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.y.data().iter())
.map(|(&g, &y)| g * y)
.collect();
let grad_y_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| g * x)
.collect();
let grad_x = maybe_reduce_grad(
&Tensor::new(&grad_x_data, grad_output.shape()),
self.x.shape(),
);
let grad_y = maybe_reduce_grad(
&Tensor::new(&grad_y_data, grad_output.shape()),
self.y.shape(),
);
vec![grad_x, grad_y]
}
fn name(&self) -> &'static str {
"MulBackward"
}
}
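/// Quotient rule: d(x / y)/dx = 1/y and d(x / y)/dy = -x / y^2.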
pub(crate) struct DivBackward {
pub(crate) x: Tensor,
pub(crate) y: Tensor,
}
impl GradFn for DivBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_x_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.y.data().iter())
.map(|(&g, &y)| g / y)
.collect();
let grad_y_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.zip(self.y.data().iter())
.map(|((&g, &x), &y)| -g * x / (y * y))
.collect();
let grad_x = maybe_reduce_grad(
&Tensor::new(&grad_x_data, grad_output.shape()),
self.x.shape(),
);
let grad_y = maybe_reduce_grad(
&Tensor::new(&grad_y_data, grad_output.shape()),
self.y.shape(),
);
vec![grad_x, grad_y]
}
fn name(&self) -> &'static str {
"DivBackward"
}
}
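/// d(-x)/dx = -1: negate the incoming gradient.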
pub(crate) struct NegBackward;
impl GradFn for NegBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output.data().iter().map(|&g| -g).collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"NegBackward"
}
}
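/// d(e^x)/dx = e^x, so the forward output is saved and reused rather than
/// recomputing the exponential.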
pub(crate) struct ExpBackward {
    pub(crate) output: Tensor,
}
impl GradFn for ExpBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.output.data().iter())
.map(|(&g, &exp_x)| g * exp_x)
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"ExpBackward"
}
}
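/// d(ln x)/dx = 1/x.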
pub(crate) struct LogBackward {
pub(crate) x: Tensor,
}
impl GradFn for LogBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| g / x)
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"LogBackward"
}
}
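/// Power rule: d(x^n)/dx = n * x^(n-1).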
pub(crate) struct PowBackward {
pub(crate) x: Tensor,
pub(crate) n: f32,
}
impl GradFn for PowBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| g * self.n * x.powf(self.n - 1.0))
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"PowBackward"
}
}
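/// d(sqrt x)/dx = 1 / (2 * sqrt x); the saved forward output is sqrt x.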
pub(crate) struct SqrtBackward {
    pub(crate) output: Tensor,
}
impl GradFn for SqrtBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.output.data().iter())
.map(|(&g, &sqrt_x)| g * 0.5 / sqrt_x)
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"SqrtBackward"
}
}
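/// Full-reduction sum produces a scalar, so the scalar incoming gradient
/// (read via `item()`) is broadcast back over the input shape.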
pub(crate) struct SumBackward {
pub(crate) input_shape: Vec<usize>,
}
impl GradFn for SumBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let g = grad_output.item();
let numel: usize = self.input_shape.iter().product();
vec![Tensor::new(&vec![g; numel], &self.input_shape)]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
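/// Full-reduction mean: each of the N input elements contributes 1/N, so the
/// scalar incoming gradient is spread as g / N over the input shape.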
pub(crate) struct MeanBackward {
pub(crate) input_shape: Vec<usize>,
}
impl GradFn for MeanBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let g = grad_output.item();
let numel: usize = self.input_shape.iter().product();
let grad_val = g / numel as f32;
vec![Tensor::new(&vec![grad_val; numel], &self.input_shape)]
}
fn name(&self) -> &'static str {
"MeanBackward"
}
}
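/// ReLU passes the gradient through where x > 0 and blocks it elsewhere.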
pub(crate) struct ReluBackward {
pub(crate) x: Tensor,
}
impl GradFn for ReluBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"ReluBackward"
}
}
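/// Like ReLU, but the gradient is scaled by `negative_slope` where x <= 0.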
pub(crate) struct LeakyReluBackward {
pub(crate) x: Tensor,
pub(crate) negative_slope: f32,
}
impl GradFn for LeakyReluBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| if x > 0.0 { g } else { g * self.negative_slope })
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"LeakyReluBackward"
}
}
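/// Derivative of the tanh approximation of GELU:
/// gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
/// differentiated with the chain rule below.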
pub(crate) struct GeluBackward {
pub(crate) x: Tensor,
}
impl GradFn for GeluBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
let grad_data: Vec<f32> = grad_output
.data()
.iter()
.zip(self.x.data().iter())
.map(|(&g, &x)| {
let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
let tanh_inner = inner.tanh();
let inner_deriv = sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * x.powi(2));
let gelu_deriv =
0.5 * (1.0 + tanh_inner) + 0.5 * x * (1.0 - tanh_inner.powi(2)) * inner_deriv;
g * gelu_deriv
})
.collect();
vec![Tensor::new(&grad_data, grad_output.shape())]
}
fn name(&self) -> &'static str {
"GeluBackward"
}
}
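/// Row-wise softmax Jacobian-vector product: for each row,
/// grad_in_i = y_i * (g_i - sum_j(g_j * y_j)), which avoids materializing
/// the full Jacobian.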
pub(crate) struct SoftmaxBackward {
    pub(crate) output: Tensor,
}
impl GradFn for SoftmaxBackward {
fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
assert_eq!(self.output.ndim(), 2, "SoftmaxBackward expects 2D tensor");
let (batch, features) = (self.output.shape()[0], self.output.shape()[1]);
let mut grad_input = vec![0.0; batch * features];
let out_data = self.output.data();
let grad_data = grad_output.data();
for b in 0..batch {
let row_start = b * features;
let mut dot_product = 0.0;
for j in 0..features {
dot_product += grad_data[row_start + j] * out_data[row_start + j];
}
for j in 0..features {
let idx = row_start + j;
grad_input[idx] = out_data[idx] * (grad_data[idx] - dot_product);
}
}
vec![Tensor::new(&grad_input, grad_output.shape())]
}
fn name(&self) -> &'static str {
"SoftmaxBackward"
}
}
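/// Fused softmax + cross-entropy backward: saving the softmax output and the
/// integer class targets suffices, since d(loss)/d(logits) simplifies to
/// softmax - one_hot(target).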
pub(crate) struct CrossEntropyBackward {
    pub(crate) softmax_output: Tensor,
    pub(crate) targets: Vec<usize>,
}
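// The `GradFn` impl for `CrossEntropyBackward` was missing above. A minimal
// sketch, assuming the forward loss was the mean over the batch of
// -log(softmax[b][target_b]), so d(loss)/d(logits) = (softmax - one_hot) / batch,
// scaled by the scalar incoming gradient.
impl GradFn for CrossEntropyBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        let g = grad_output.item();
        let (batch, features) = (
            self.softmax_output.shape()[0],
            self.softmax_output.shape()[1],
        );
        // Start from the softmax probabilities, subtract 1 at each target
        // class, then scale by g / batch for the mean reduction.
        let mut grad_input: Vec<f32> = self.softmax_output.data().to_vec();
        for (b, &target) in self.targets.iter().enumerate() {
            grad_input[b * features + target] -= 1.0;
        }
        for v in grad_input.iter_mut() {
            *v *= g / batch as f32;
        }
        vec![Tensor::new(&grad_input, self.softmax_output.shape())]
    }
    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }
}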
include!("gradient.rs");
include!("grad_fn_tests.rs");