use crate::tensor::TensorId;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::errors::{EtensorError, EtensorResult};
use crate::autograd::tape::TapeAction;
use crate::autograd::gradients::Gradients;
/// Backward node for element-wise addition (`out = lhs + rhs`).
///
/// Addition has a unit local derivative with respect to both operands, so the
/// upstream gradient of `output_id` is handed, unchanged, to every operand that
/// is being tracked.
pub struct AddBackward {
    /// Id of the forward result whose gradient is read.
    pub output_id: TensorId,
    /// Left operand, or `None` if no gradient should be produced for it.
    pub lhs_id: Option<TensorId>,
    /// Right operand, or `None` if no gradient should be produced for it.
    pub rhs_id: Option<TensorId>,
}

impl TapeAction for AddBackward {
    /// Passes the upstream gradient through to `lhs_id` and then to `rhs_id`.
    ///
    /// NOTE(review): whether `Gradients::insert` overwrites or accumulates an
    /// existing entry is decided by `Gradients`, not here; `x + x` (both ids
    /// equal) is only correct if it accumulates — confirm.
    ///
    /// # Errors
    /// `AutogradError` when no gradient is stored for `output_id`; any error
    /// returned by `Gradients::insert` is propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        let upstream = match grads.get(&self.output_id) {
            Some(buf) => buf.clone(),
            None => {
                return Err(EtensorError::AutogradError(format!(
                    "Gradient missing for Output ID {:?}",
                    self.output_id
                )));
            }
        };

        // Left operand is always written first, matching the tape's expectations.
        match (self.lhs_id, self.rhs_id) {
            (Some(left), Some(right)) => {
                grads.insert(left, upstream.clone())?;
                grads.insert(right, upstream)?;
            }
            (Some(left), None) => {
                grads.insert(left, upstream)?;
            }
            (None, Some(right)) => {
                grads.insert(right, upstream)?;
            }
            (None, None) => {}
        }
        Ok(())
    }

    fn name(&self) -> String {
        "AddBackward".to_string()
    }
}
/// Backward node for element-wise multiplication (`out = lhs * rhs`).
///
/// By the product rule `d(out)/d(lhs) = rhs` and `d(out)/d(rhs) = lhs`, so both
/// forward operands are captured. Operands are combined position by position
/// with the upstream gradient; broadcasting is not handled here.
pub struct MulBackward {
    /// Forward output whose gradient is consumed.
    pub output_id: TensorId,
    /// Left operand; `None` if no gradient should be produced for it.
    pub lhs_id: Option<TensorId>,
    /// Right operand; `None` if no gradient should be produced for it.
    pub rhs_id: Option<TensorId>,
    /// Values of the left operand captured during the forward pass.
    pub lhs_data: Buffer,
    /// Values of the right operand captured during the forward pass.
    pub rhs_data: Buffer,
}

impl TapeAction for MulBackward {
    /// Stores `dy * rhs` for `lhs_id` and `dy * lhs` for `rhs_id`.
    ///
    /// # Errors
    /// `AutogradError` if no gradient is stored for `output_id`, or if an operand
    /// needed for a requested gradient has a different element count than the
    /// upstream gradient. Errors from `as_f32_slice` and `Gradients::insert` are
    /// propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Both gradients are computed while `grads` is only read, so the upstream
        // buffer does not have to be cloned before the inserts below.
        let (da, db) = {
            let dy_buf = grads
                .get(&self.output_id)
                .ok_or_else(|| EtensorError::AutogradError("Gradient missing".to_string()))?;
            let dy = dy_buf.as_f32_slice()?;
            let a = self.lhs_data.as_f32_slice()?;
            let b = self.rhs_data.as_f32_slice()?;

            // Element-wise `dy * factor`. Mismatched lengths would otherwise panic
            // (factor too short) or silently truncate (factor too long).
            let scaled = |factor: &[f32], operand: &str| -> EtensorResult<Vec<f32>> {
                if factor.len() != dy.len() {
                    return Err(EtensorError::AutogradError(format!(
                        "MulBackward: {} has {} elements but upstream gradient has {}",
                        operand,
                        factor.len(),
                        dy.len()
                    )));
                }
                Ok(dy.iter().zip(factor).map(|(&g, &f)| g * f).collect())
            };

            let da = match self.lhs_id {
                Some(_) => Some(scaled(b, "rhs")?),
                None => None,
            };
            let db = match self.rhs_id {
                Some(_) => Some(scaled(a, "lhs")?),
                None => None,
            };
            (da, db)
        };

        if let (Some(id), Some(da)) = (self.lhs_id, da) {
            grads.insert(id, Buffer::from_f32_vec(da))?;
        }
        if let (Some(id), Some(db)) = (self.rhs_id, db) {
            grads.insert(id, Buffer::from_f32_vec(db))?;
        }
        Ok(())
    }

    fn name(&self) -> String {
        "MulBackward".to_string()
    }
}
/// Backward node for a 2-D matrix product `C = A @ B` with `A: [m, k]`, `B: [k, n]`.
///
/// Uses `dA = dC @ B^T` and `dB = A^T @ dC`. `A` and `B` are read through the
/// element strides recorded in `lhs_shape` / `rhs_shape`; the upstream gradient
/// `dC` is read as a contiguous row-major `[m, n]` buffer, and the gradients
/// written back are contiguous row-major `[m, k]` and `[k, n]` buffers.
pub struct MatMulBackward {
    /// Forward output `C` whose gradient is consumed.
    pub output_id: TensorId,
    /// Id of `A`; `None` if no gradient should be produced for it.
    pub lhs_id: Option<TensorId>,
    /// Id of `B`; `None` if no gradient should be produced for it.
    pub rhs_id: Option<TensorId>,
    /// Values of `A` captured during the forward pass.
    pub lhs_data: Buffer,
    /// Values of `B` captured during the forward pass.
    pub rhs_data: Buffer,
    /// Dims and element strides used to index `lhs_data`.
    pub lhs_shape: Shape,
    /// Dims and element strides used to index `rhs_data`.
    pub rhs_shape: Shape,
}

impl MatMulBackward {
    /// Checks that both recorded shapes are 2-D with matching inner dimensions
    /// and returns `(m, k, n)` for `A: [m, k]`, `B: [k, n]`.
    fn checked_dims(&self) -> EtensorResult<(usize, usize, usize)> {
        let (lhs, rhs) = (&self.lhs_shape, &self.rhs_shape);
        let is_2d = |s: &Shape| s.dims.len() == 2 && s.strides.len() == 2;
        if !is_2d(lhs) || !is_2d(rhs) {
            return Err(EtensorError::AutogradError(format!(
                "MatMulBackward expects 2-D operands, got ranks {} and {}",
                lhs.dims.len(),
                rhs.dims.len()
            )));
        }
        let (m, k, n) = (lhs.dims[0], lhs.dims[1], rhs.dims[1]);
        if rhs.dims[0] != k {
            return Err(EtensorError::AutogradError(format!(
                "MatMulBackward inner dimensions differ: lhs is [{}, {}], rhs is [{}, {}]",
                m, k, rhs.dims[0], n
            )));
        }
        Ok((m, k, n))
    }
}

impl TapeAction for MatMulBackward {
    /// Stores `dC @ B^T` for `lhs_id` and `A^T @ dC` for `rhs_id`.
    ///
    /// # Errors
    /// `AutogradError` if no gradient is stored for `output_id`, if either
    /// recorded shape is not 2-D or the inner dimensions disagree, or if the
    /// upstream gradient does not hold exactly `m * n` elements. Errors from
    /// `as_f32_slice` and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Compute while `grads` is only read so the upstream buffer need not be cloned.
        let (da, db) = {
            let dc_buf = grads.get(&self.output_id).ok_or_else(|| {
                EtensorError::AutogradError("Gradient missing for MatMul Output".to_string())
            })?;
            let dc = dc_buf.as_f32_slice()?;
            let a = self.lhs_data.as_f32_slice()?;
            let b = self.rhs_data.as_f32_slice()?;
            let (m, k, n) = self.checked_dims()?;
            // `dc` is indexed contiguously as [m, n]; anything else would panic or
            // read the wrong elements.
            if dc.len() != m * n {
                return Err(EtensorError::AutogradError(format!(
                    "MatMulBackward: upstream gradient has {} elements, expected {} ({} x {})",
                    dc.len(),
                    m * n,
                    m,
                    n
                )));
            }
            let (stride_a0, stride_a1) = (self.lhs_shape.strides[0], self.lhs_shape.strides[1]);
            let (stride_b0, stride_b1) = (self.rhs_shape.strides[0], self.rhs_shape.strides[1]);

            // dA[i, j] = sum_p dC[i, p] * B[j, p]
            let da = self.lhs_id.map(|_| {
                let mut da = vec![0.0f32; m * k];
                for i in 0..m {
                    let dc_row = &dc[i * n..(i + 1) * n];
                    for j in 0..k {
                        let mut sum = 0.0f32;
                        for (p, &g) in dc_row.iter().enumerate() {
                            sum += g * b[j * stride_b0 + p * stride_b1];
                        }
                        da[i * k + j] = sum;
                    }
                }
                da
            });

            // dB[i, j] = sum_p A[p, i] * dC[p, j]
            let db = self.rhs_id.map(|_| {
                let mut db = vec![0.0f32; k * n];
                for i in 0..k {
                    for j in 0..n {
                        let mut sum = 0.0f32;
                        for p in 0..m {
                            sum += a[p * stride_a0 + i * stride_a1] * dc[p * n + j];
                        }
                        db[i * n + j] = sum;
                    }
                }
                db
            });

            (da, db)
        };

        if let (Some(id), Some(da)) = (self.lhs_id, da) {
            grads.insert(id, Buffer::from_f32_vec(da))?;
        }
        if let (Some(id), Some(db)) = (self.rhs_id, db) {
            grads.insert(id, Buffer::from_f32_vec(db))?;
        }
        Ok(())
    }

    fn name(&self) -> String {
        "MatMulBackward".to_string()
    }
}
/// Backward node for a reduction that sums every element of the input.
///
/// Each input element contributes with weight 1, so every element of the input
/// gradient equals the single upstream gradient value.
pub struct SumAllBackward {
    /// Forward output (the reduced value) whose gradient is consumed.
    pub output_id: TensorId,
    /// Forward input that receives the broadcast gradient.
    pub input_id: TensorId,
    /// Shape of the forward input; only its element count is used.
    pub input_shape: Shape,
}

impl TapeAction for SumAllBackward {
    /// Fills a gradient of `input_shape.num_elements()` copies of the upstream
    /// value (its first element) and stores it for `input_id`.
    ///
    /// # Errors
    /// `AutogradError` if no gradient is stored for `output_id` or that gradient
    /// is empty. Errors from `as_f32_slice` and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Read the scalar inside a scope so the upstream buffer is not cloned.
        let dy_scalar = {
            let dy_buf = grads.get(&self.output_id).ok_or_else(|| {
                EtensorError::AutogradError("Gradient missing for Sum Output".to_string())
            })?;
            // An empty upstream buffer previously caused an index panic.
            match dy_buf.as_f32_slice()?.first() {
                Some(&value) => value,
                None => {
                    return Err(EtensorError::AutogradError(
                        "SumAllBackward: upstream gradient is empty".to_string(),
                    ));
                }
            }
        };
        let dx = vec![dy_scalar; self.input_shape.num_elements()];
        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String {
        "SumAllBackward".to_string()
    }
}
/// Backward node for the rectified linear unit `y = max(x, 0)`.
///
/// The local derivative is 1 where the forward input was strictly positive and
/// 0 elsewhere (including exactly 0 and NaN), so the forward input is kept to
/// build that mask.
pub struct ReluBackward {
    /// Forward output whose gradient is consumed.
    pub output_id: TensorId,
    /// Forward input that receives the gradient.
    pub input_id: TensorId,
    /// Values of the forward input `x`.
    pub input_data: Buffer,
}

impl TapeAction for ReluBackward {
    /// Stores `dy` where `x > 0` and `0` elsewhere for `input_id`.
    ///
    /// # Errors
    /// `AutogradError` if no gradient is stored for `output_id` or its length
    /// differs from `input_data`. Errors from `as_f32_slice` and
    /// `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Compute while `grads` is only read so the upstream buffer is not cloned.
        let dx: Vec<f32> = {
            let dy_buf = grads.get(&self.output_id).ok_or_else(|| {
                EtensorError::AutogradError("Gradient missing for ReLU Output".to_string())
            })?;
            let dy = dy_buf.as_f32_slice()?;
            let x = self.input_data.as_f32_slice()?;
            // Mismatched lengths would otherwise panic or silently truncate.
            if x.len() != dy.len() {
                return Err(EtensorError::AutogradError(format!(
                    "ReluBackward: input has {} elements but upstream gradient has {}",
                    x.len(),
                    dy.len()
                )));
            }
            dy.iter()
                .zip(x)
                .map(|(&g, &v)| if v > 0.0 { g } else { 0.0 })
                .collect()
        };
        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String {
        "ReluBackward".to_string()
    }
}
/// Backward node for the logistic sigmoid `y = sigmoid(x)`.
///
/// The derivative is expressed through the forward output,
/// `dy/dx = y * (1 - y)`, so the output values are captured instead of the input.
pub struct SigmoidBackward {
    /// Forward output whose gradient is consumed.
    pub output_id: TensorId,
    /// Forward input that receives the gradient.
    pub input_id: TensorId,
    /// Values of the forward output `y`.
    pub output_data: Buffer,
}

impl TapeAction for SigmoidBackward {
    /// Stores `dy * y * (1 - y)` for `input_id`.
    ///
    /// # Errors
    /// `AutogradError` if no gradient is stored for `output_id` or its length
    /// differs from `output_data`. Errors from `as_f32_slice` and
    /// `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Compute while `grads` is only read so the upstream buffer is not cloned.
        let dx: Vec<f32> = {
            let dy_buf = grads.get(&self.output_id).ok_or_else(|| {
                EtensorError::AutogradError("Gradient missing for Sigmoid Output".to_string())
            })?;
            let dy = dy_buf.as_f32_slice()?;
            let y = self.output_data.as_f32_slice()?;
            // Mismatched lengths would otherwise panic or silently truncate.
            if y.len() != dy.len() {
                return Err(EtensorError::AutogradError(format!(
                    "SigmoidBackward: output has {} elements but upstream gradient has {}",
                    y.len(),
                    dy.len()
                )));
            }
            dy.iter()
                .zip(y)
                .map(|(&g, &s)| g * s * (1.0 - s))
                .collect()
        };
        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String {
        "SigmoidBackward".to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Addition forwards the upstream gradient unchanged to both operands.
    #[test]
    fn test_add_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let left = TensorId::new();
        let right = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![5.0, 5.0]))
            .expect("seed upstream gradient");

        let op = AddBackward {
            output_id: output,
            lhs_id: Some(left),
            rhs_id: Some(right),
        };
        op.backward(&mut grads).expect("AddBackward::backward");

        let expected: [f32; 2] = [5.0, 5.0];
        assert_eq!(
            grads.get(&left).expect("lhs gradient").as_f32_slice().expect("lhs as f32"),
            &expected
        );
        assert_eq!(
            grads.get(&right).expect("rhs gradient").as_f32_slice().expect("rhs as f32"),
            &expected
        );
    }

    /// Product rule: d(lhs) = dy * rhs, d(rhs) = dy * lhs.
    #[test]
    fn test_mul_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let left = TensorId::new();
        let right = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![2.0, 2.0]))
            .expect("seed upstream gradient");

        let op = MulBackward {
            output_id: output,
            lhs_id: Some(left),
            rhs_id: Some(right),
            lhs_data: Buffer::from_f32_vec(vec![3.0, 4.0]),
            rhs_data: Buffer::from_f32_vec(vec![10.0, 20.0]),
        };
        op.backward(&mut grads).expect("MulBackward::backward");

        let expected_lhs: [f32; 2] = [20.0, 40.0];
        let expected_rhs: [f32; 2] = [6.0, 8.0];
        assert_eq!(
            grads.get(&left).expect("lhs gradient").as_f32_slice().expect("lhs as f32"),
            &expected_lhs
        );
        assert_eq!(
            grads.get(&right).expect("rhs gradient").as_f32_slice().expect("rhs as f32"),
            &expected_rhs
        );
    }

    /// With dC = ones(2, 2): dA holds the row sums of B, dB the column sums of A.
    #[test]
    fn test_matmul_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let left = TensorId::new();
        let right = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![1.0, 1.0, 1.0, 1.0]))
            .expect("seed upstream gradient");

        let op = MatMulBackward {
            output_id: output,
            lhs_id: Some(left),
            rhs_id: Some(right),
            lhs_data: Buffer::from_f32_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            rhs_data: Buffer::from_f32_vec(vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]),
            lhs_shape: Shape::new(vec![2, 3]),
            rhs_shape: Shape::new(vec![3, 2]),
        };
        op.backward(&mut grads).expect("MatMulBackward::backward");

        let expected_lhs: [f32; 6] = [15.0, 10.0, 5.0, 15.0, 10.0, 5.0];
        let expected_rhs: [f32; 6] = [5.0, 5.0, 7.0, 7.0, 9.0, 9.0];
        assert_eq!(
            grads.get(&left).expect("lhs gradient").as_f32_slice().expect("lhs as f32"),
            &expected_lhs
        );
        assert_eq!(
            grads.get(&right).expect("rhs gradient").as_f32_slice().expect("rhs as f32"),
            &expected_rhs
        );
    }

    /// The scalar upstream gradient is broadcast to every input element.
    #[test]
    fn test_sum_all_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let input = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![42.0]))
            .expect("seed upstream gradient");

        let op = SumAllBackward {
            output_id: output,
            input_id: input,
            input_shape: Shape::new(vec![2, 2]),
        };
        op.backward(&mut grads).expect("SumAllBackward::backward");

        let expected: [f32; 4] = [42.0, 42.0, 42.0, 42.0];
        assert_eq!(
            grads.get(&input).expect("input gradient").as_f32_slice().expect("input as f32"),
            &expected
        );
    }

    /// Only strictly positive inputs let the gradient through.
    #[test]
    fn test_relu_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let input = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![2.0, 2.0, 2.0]))
            .expect("seed upstream gradient");

        let op = ReluBackward {
            output_id: output,
            input_id: input,
            input_data: Buffer::from_f32_vec(vec![-5.0, 0.0, 10.0]),
        };
        op.backward(&mut grads).expect("ReluBackward::backward");

        let expected: [f32; 3] = [0.0, 0.0, 2.0];
        assert_eq!(
            grads.get(&input).expect("input gradient").as_f32_slice().expect("input as f32"),
            &expected
        );
    }

    /// dy * y * (1 - y) with y = 0.5 and dy = 2.0 gives 0.5.
    #[test]
    fn test_sigmoid_backward_logic() {
        let mut grads = Gradients::new();
        let output = TensorId::new();
        let input = TensorId::new();
        grads
            .insert(output, Buffer::from_f32_vec(vec![2.0]))
            .expect("seed upstream gradient");

        let op = SigmoidBackward {
            output_id: output,
            input_id: input,
            output_data: Buffer::from_f32_vec(vec![0.5]),
        };
        op.backward(&mut grads).expect("SigmoidBackward::backward");

        let expected: [f32; 1] = [0.5];
        assert_eq!(
            grads.get(&input).expect("input gradient").as_f32_slice().expect("input as f32"),
            &expected
        );
    }
}