// etensor-core 0.0.1
//
// The pure Rust tensor math and autograd engine.
//! Mathematical calculus rules for Tape operations.

use crate::tensor::TensorId;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::errors::{EtensorError, EtensorResult};
use crate::autograd::tape::TapeAction;
use crate::autograd::gradients::Gradients;

// =====================================================================
// ADDITION: y = a + b
// Calculus: dy/da = 1 * dy, dy/db = 1 * dy
// =====================================================================

/// Backward rule for element-wise addition `y = a + b`.
///
/// Addition has a unit partial derivative with respect to each operand, so the
/// upstream gradient `dy` is forwarded unchanged to every tracked input.
pub struct AddBackward {
    /// Tensor whose upstream gradient `dy` is read from the gradient store.
    pub output_id: TensorId,
    /// Left operand; `None` means no gradient is recorded for it (e.g. requires_grad=false).
    pub lhs_id: Option<TensorId>,
    /// Right operand; `None` means no gradient is recorded for it.
    pub rhs_id: Option<TensorId>,
}

impl TapeAction for AddBackward {
    /// Copies `dy` into the gradient slot of each tracked operand.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` if no gradient is stored for
    /// `output_id`, and propagates any error from `Gradients::insert`.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Take an owned copy so the shared borrow of `grads` ends before inserting.
        let upstream = match grads.get(&self.output_id) {
            Some(buf) => buf.clone(),
            None => {
                return Err(EtensorError::AutogradError(format!(
                    "Gradient missing for Output ID {:?}",
                    self.output_id
                )));
            }
        };

        match (self.lhs_id, self.rhs_id) {
            (Some(lhs), Some(rhs)) => {
                grads.insert(lhs, upstream.clone())?;
                grads.insert(rhs, upstream)?;
            }
            (Some(only), None) | (None, Some(only)) => {
                grads.insert(only, upstream)?;
            }
            (None, None) => {}
        }
        Ok(())
    }

    fn name(&self) -> String {
        String::from("AddBackward")
    }
}

// =====================================================================
// MULTIPLICATION: y = a * b
// Calculus: dy/da = b * dy, dy/db = a * dy
// =====================================================================

/// The backward operation for element-wise multiplication `y = a * b`.
///
/// Each partial derivative depends on the *other* operand
/// (`da = dy * b`, `db = dy * a`), so both forward operands are saved.
pub struct MulBackward {
    /// Tensor whose upstream gradient `dy` is read from the gradient store.
    pub output_id: TensorId,
    /// Left operand `a`; `None` skips computing `da`.
    pub lhs_id: Option<TensorId>,
    /// Right operand `b`; `None` skips computing `db`.
    pub rhs_id: Option<TensorId>,
    /// Forward-pass values of the left operand `a`.
    pub lhs_data: Buffer,
    /// Forward-pass values of the right operand `b`.
    pub rhs_data: Buffer,
}

impl TapeAction for MulBackward {
    /// Computes `da = dy * b` and `db = dy * a` element-wise and stores them.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` if no gradient is stored for
    /// `output_id`, or if either saved operand has fewer elements than `dy`
    /// (this used to panic on an out-of-bounds index). Errors from buffer
    /// access and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        // Clone so the shared borrow of `grads` ends before we insert into it.
        let dy_buf = grads.get(&self.output_id)
            .ok_or_else(|| EtensorError::AutogradError("Gradient missing".to_string()))?
            .clone();

        let dy = dy_buf.as_f32_slice()?;
        let a = self.lhs_data.as_f32_slice()?;
        let b = self.rhs_data.as_f32_slice()?;

        // Each output element needs a matching operand element; report a short
        // operand as an error instead of panicking mid-backward-pass.
        if a.len() < dy.len() || b.len() < dy.len() {
            return Err(EtensorError::AutogradError(format!(
                "MulBackward shape mismatch: dy has {} elements, lhs has {}, rhs has {}",
                dy.len(),
                a.len(),
                b.len()
            )));
        }

        if let Some(id) = self.lhs_id {
            let da: Vec<f32> = dy.iter().zip(b.iter()).map(|(g, v)| g * v).collect();
            grads.insert(id, Buffer::from_f32_vec(da))?;
        }

        if let Some(id) = self.rhs_id {
            let db: Vec<f32> = dy.iter().zip(a.iter()).map(|(g, v)| g * v).collect();
            grads.insert(id, Buffer::from_f32_vec(db))?;
        }
        Ok(())
    }

    fn name(&self) -> String { "MulBackward".to_string() }
}

// =====================================================================
// MATRIX MULTIPLICATION: C = A @ B
// Calculus: dA = dC @ B^T,  dB = A^T @ dC
// =====================================================================

/// The backward operation for matrix multiplication `C = A @ B`.
///
/// `A` is `m x k`, `B` is `k x n` and `C` is `m x n`. The saved operand buffers
/// are read through their shapes' strides. The upstream gradient `dC` is
/// indexed as a contiguous row-major `m x n` buffer.
pub struct MatMulBackward {
    /// Tensor whose upstream gradient `dC` is read from the gradient store.
    pub output_id: TensorId,
    /// Left operand `A`; `None` skips computing `dA`.
    pub lhs_id: Option<TensorId>,
    /// Right operand `B`; `None` skips computing `dB`.
    pub rhs_id: Option<TensorId>,
    /// Forward-pass storage of `A`, addressed with `lhs_shape.strides`.
    pub lhs_data: Buffer,
    /// Forward-pass storage of `B`, addressed with `rhs_shape.strides`.
    pub rhs_data: Buffer,
    /// Dims and strides of `A`.
    pub lhs_shape: Shape,
    /// Dims and strides of `B`.
    pub rhs_shape: Shape,
}

impl TapeAction for MatMulBackward {
    /// Computes `dA = dC @ B^T` and `dB = A^T @ dC` as contiguous row-major buffers.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` in these cases:
    /// - no gradient is stored for `output_id`;
    /// - either operand is not 2-D;
    /// - the inner dimensions disagree;
    /// - `dC` does not have `m * n` elements;
    /// - an operand's storage is too small for its dims and strides.
    ///
    /// Errors from buffer access and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        let dc_buf = grads.get(&self.output_id)
            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for MatMul Output".to_string()))?
            .clone();

        let dc = dc_buf.as_f32_slice()?;
        let a = self.lhs_data.as_f32_slice()?;
        let b = self.rhs_data.as_f32_slice()?;

        let lhs = &self.lhs_shape;
        let rhs = &self.rhs_shape;

        // Validate up front so malformed input surfaces as an error, not an index panic.
        if lhs.dims.len() != 2 || rhs.dims.len() != 2
            || lhs.strides.len() != 2 || rhs.strides.len() != 2
        {
            return Err(EtensorError::AutogradError(format!(
                "MatMulBackward expects 2-D operands, got lhs rank {} and rhs rank {}",
                lhs.dims.len(),
                rhs.dims.len()
            )));
        }

        let m = lhs.dims[0];
        let k = lhs.dims[1];
        let n = rhs.dims[1];

        if rhs.dims[0] != k {
            return Err(EtensorError::AutogradError(format!(
                "MatMulBackward inner dimension mismatch: lhs is {}x{}, rhs is {}x{}",
                m, k, rhs.dims[0], n
            )));
        }
        if dc.len() != m * n {
            return Err(EtensorError::AutogradError(format!(
                "MatMulBackward expected {} upstream gradient elements ({}x{}), got {}",
                m * n, m, n, dc.len()
            )));
        }

        let stride_a0 = lhs.strides[0];
        let stride_a1 = lhs.strides[1];
        let stride_b0 = rhs.strides[0];
        let stride_b1 = rhs.strides[1];

        // Minimum storage length needed to address every element of a strided
        // `rows x cols` view (0 when the view is empty).
        let required_len = |rows: usize, cols: usize, s0: usize, s1: usize| -> usize {
            if rows == 0 || cols == 0 {
                0
            } else {
                (rows - 1) * s0 + (cols - 1) * s1 + 1
            }
        };

        if let Some(id) = self.lhs_id {
            // dA[i, j] = sum_p dC[i, p] * B[j, p]   (i.e. dA = dC @ B^T)
            let needed = required_len(k, n, stride_b0, stride_b1);
            if b.len() < needed {
                return Err(EtensorError::AutogradError(format!(
                    "MatMulBackward rhs storage has {} elements but its shape needs {}",
                    b.len(),
                    needed
                )));
            }
            let mut da = vec![0.0; m * k];
            for i in 0..m {
                for j in 0..k {
                    let mut sum = 0.0;
                    for p in 0..n {
                        sum += dc[i * n + p] * b[j * stride_b0 + p * stride_b1];
                    }
                    da[i * k + j] = sum;
                }
            }
            grads.insert(id, Buffer::from_f32_vec(da))?;
        }

        if let Some(id) = self.rhs_id {
            // dB[i, j] = sum_p A[p, i] * dC[p, j]   (i.e. dB = A^T @ dC)
            let needed = required_len(m, k, stride_a0, stride_a1);
            if a.len() < needed {
                return Err(EtensorError::AutogradError(format!(
                    "MatMulBackward lhs storage has {} elements but its shape needs {}",
                    a.len(),
                    needed
                )));
            }
            let mut db = vec![0.0; k * n];
            for i in 0..k {
                for j in 0..n {
                    let mut sum = 0.0;
                    for p in 0..m {
                        sum += a[p * stride_a0 + i * stride_a1] * dc[p * n + j];
                    }
                    db[i * n + j] = sum;
                }
            }
            grads.insert(id, Buffer::from_f32_vec(db))?;
        }
        Ok(())
    }

    fn name(&self) -> String { "MatMulBackward".to_string() }
}

// =====================================================================
// GLOBAL SUM REDUCTION: y = sum(x)
// Calculus: dx = 1 * dy (Broadcast upstream scalar to all elements)
// =====================================================================

/// The backward operation for a global sum reduction `y = sum(x)`.
///
/// Every input element contributes to `y` with weight 1. The scalar upstream
/// gradient is therefore broadcast into a buffer with
/// `input_shape.num_elements()` entries.
pub struct SumAllBackward {
    /// Scalar output whose upstream gradient is read from the gradient store.
    pub output_id: TensorId,
    /// Input tensor that receives the broadcast gradient.
    pub input_id: TensorId,
    /// Shape of the input; only used to size the gradient buffer.
    pub input_shape: Shape,
}

impl TapeAction for SumAllBackward {
    /// Fills the input gradient with the scalar upstream gradient.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` if no gradient is stored for
    /// `output_id`, or if that gradient is empty (this used to panic on `[0]`).
    /// Errors from buffer access and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        let dy_buf = grads.get(&self.output_id)
            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sum Output".to_string()))?
            .clone();

        // dy is a single scalar for a global sum; only its first element is used.
        let dy = dy_buf.as_f32_slice()?;
        let dy_scalar = match dy.first() {
            Some(&value) => value,
            None => {
                return Err(EtensorError::AutogradError(
                    "Empty gradient for Sum Output".to_string(),
                ));
            }
        };

        let num_elements = self.input_shape.num_elements();
        let dx = vec![dy_scalar; num_elements];

        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String { "SumAllBackward".to_string() }
}

// =====================================================================
// RELU: y = max(0, x)
// Calculus: dx = dy if x > 0 else 0
// =====================================================================

/// The backward operation for the rectified linear unit `y = max(0, x)`.
///
/// The forward input `x` is saved. The gradient passes through only where
/// `x > 0`; elsewhere, including `x == 0` and NaN, the input gradient is 0.
pub struct ReluBackward {
    /// Tensor whose upstream gradient `dy` is read from the gradient store.
    pub output_id: TensorId,
    /// Input tensor that receives `dx`.
    pub input_id: TensorId,
    /// Forward-pass input values `x`.
    pub input_data: Buffer,
}

impl TapeAction for ReluBackward {
    /// Computes `dx = dy` where `x > 0` and `0` elsewhere.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` if no gradient is stored for
    /// `output_id`, or if the saved input has fewer elements than `dy`
    /// (this used to panic on an out-of-bounds index). Errors from buffer
    /// access and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        let dy_buf = grads.get(&self.output_id)
            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for ReLU Output".to_string()))?
            .clone();

        let dy = dy_buf.as_f32_slice()?;
        let x = self.input_data.as_f32_slice()?;

        if x.len() < dy.len() {
            return Err(EtensorError::AutogradError(format!(
                "ReluBackward shape mismatch: dy has {} elements but saved input has {}",
                dy.len(),
                x.len()
            )));
        }

        let dx: Vec<f32> = dy
            .iter()
            .zip(x.iter())
            .map(|(&g, &xi)| if xi > 0.0 { g } else { 0.0 })
            .collect();

        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String { "ReluBackward".to_string() }
}

// =====================================================================
// SIGMOID: y = 1 / (1 + exp(-x))
// Calculus: dx = dy * y * (1 - y)
// =====================================================================

/// The backward operation for the logistic sigmoid `y = 1 / (1 + exp(-x))`.
///
/// The forward *output* `y` is saved rather than `x`, because the derivative
/// `y * (1 - y)` can be computed from `y` without re-evaluating `exp`.
pub struct SigmoidBackward {
    /// Tensor whose upstream gradient `dy` is read from the gradient store.
    pub output_id: TensorId,
    /// Input tensor that receives `dx`.
    pub input_id: TensorId,
    /// Forward-pass output values `y`.
    pub output_data: Buffer,
}

impl TapeAction for SigmoidBackward {
    /// Computes `dx = dy * y * (1 - y)` element-wise.
    ///
    /// # Errors
    /// Returns `EtensorError::AutogradError` if no gradient is stored for
    /// `output_id`, or if the saved output has fewer elements than `dy`
    /// (this used to panic on an out-of-bounds index). Errors from buffer
    /// access and `Gradients::insert` are propagated.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
        let dy_buf = grads.get(&self.output_id)
            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sigmoid Output".to_string()))?
            .clone();

        let dy = dy_buf.as_f32_slice()?;
        let y = self.output_data.as_f32_slice()?;

        if y.len() < dy.len() {
            return Err(EtensorError::AutogradError(format!(
                "SigmoidBackward shape mismatch: dy has {} elements but saved output has {}",
                dy.len(),
                y.len()
            )));
        }

        let dx: Vec<f32> = dy
            .iter()
            .zip(y.iter())
            .map(|(&g, &yi)| g * yi * (1.0 - yi))
            .collect();

        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
        Ok(())
    }

    fn name(&self) -> String { "SigmoidBackward".to_string() }
}

// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a gradient store that already holds `upstream` for tensor `id`.
    fn seeded(id: TensorId, upstream: Vec<f32>) -> Gradients {
        let mut store = Gradients::new();
        store.insert(id, Buffer::from_f32_vec(upstream)).unwrap();
        store
    }

    /// Copies out the gradient stored for `id`; panics if none was written.
    fn grad_of(store: &Gradients, id: &TensorId) -> Vec<f32> {
        store.get(id).unwrap().as_f32_slice().unwrap().to_vec()
    }

    #[test]
    fn test_add_backward_logic() {
        let (out, left, right) = (TensorId::new(), TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![5.0, 5.0]);

        let op = AddBackward { output_id: out, lhs_id: Some(left), rhs_id: Some(right) };
        op.backward(&mut store).unwrap();

        // Both operands receive dy unchanged.
        assert_eq!(grad_of(&store, &left), vec![5.0, 5.0]);
        assert_eq!(grad_of(&store, &right), vec![5.0, 5.0]);
    }

    #[test]
    fn test_mul_backward_logic() {
        let (out, left, right) = (TensorId::new(), TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![2.0, 2.0]);

        let op = MulBackward {
            output_id: out,
            lhs_id: Some(left),
            rhs_id: Some(right),
            lhs_data: Buffer::from_f32_vec(vec![3.0, 4.0]),
            rhs_data: Buffer::from_f32_vec(vec![10.0, 20.0]),
        };
        op.backward(&mut store).unwrap();

        // da = dy * b, db = dy * a
        assert_eq!(grad_of(&store, &left), vec![20.0, 40.0]);
        assert_eq!(grad_of(&store, &right), vec![6.0, 8.0]);
    }

    #[test]
    fn test_matmul_backward_logic() {
        let (out, left, right) = (TensorId::new(), TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![1.0; 4]);

        let op = MatMulBackward {
            output_id: out,
            lhs_id: Some(left),
            rhs_id: Some(right),
            lhs_data: Buffer::from_f32_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            rhs_data: Buffer::from_f32_vec(vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]),
            lhs_shape: Shape::new(vec![2, 3]),
            rhs_shape: Shape::new(vec![3, 2]),
        };
        op.backward(&mut store).unwrap();

        // With dC all ones: dA rows are row-sums of B, dB rows are column-sums of A.
        assert_eq!(grad_of(&store, &left), vec![15.0, 10.0, 5.0, 15.0, 10.0, 5.0]);
        assert_eq!(grad_of(&store, &right), vec![5.0, 5.0, 7.0, 7.0, 9.0, 9.0]);
    }

    #[test]
    fn test_sum_all_backward_logic() {
        let (out, input) = (TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![42.0]);

        let op = SumAllBackward {
            output_id: out,
            input_id: input,
            input_shape: Shape::new(vec![2, 2]),
        };
        op.backward(&mut store).unwrap();

        // The scalar 42.0 is broadcast across all 2x2 input elements.
        assert_eq!(grad_of(&store, &input), vec![42.0; 4]);
    }

    #[test]
    fn test_relu_backward_logic() {
        let (out, input) = (TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![2.0, 2.0, 2.0]);

        let op = ReluBackward {
            output_id: out,
            input_id: input,
            input_data: Buffer::from_f32_vec(vec![-5.0, 0.0, 10.0]),
        };
        op.backward(&mut store).unwrap();

        // Only the strictly positive input lets the gradient through.
        assert_eq!(grad_of(&store, &input), vec![0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_sigmoid_backward_logic() {
        let (out, input) = (TensorId::new(), TensorId::new());
        let mut store = seeded(out, vec![2.0]);

        // y = 0.5 corresponds to x = 0.0.
        let op = SigmoidBackward {
            output_id: out,
            input_id: input,
            output_data: Buffer::from_f32_vec(vec![0.5]),
        };
        op.backward(&mut store).unwrap();

        // dx = dy * y * (1 - y) = 2.0 * 0.5 * 0.5 = 0.5
        assert_eq!(grad_of(&store, &input), vec![0.5]);
    }
}