etensor-core 0.0.1

The pure Rust tensor math and autograd engine
Documentation
//! The Gradient Map: Accumulates and stores backward pass derivatives.

use std::collections::HashMap;
use std::sync::Arc;
use crate::tensor::TensorId;
use crate::buffer::{Buffer, CpuBuffer};
use crate::errors::{EtensorError, EtensorResult};

/// A memory store mapping mathematical tensor nodes to their computed derivatives.
///
/// Holds at most one gradient [`Buffer`] per [`TensorId`]. Inserting a second gradient for
/// an id that already has one sums the two element-wise (see [`Gradients::insert`]); this is
/// currently supported only for CPU `F32` buffers.
pub struct Gradients {
    // One (possibly accumulated) gradient buffer per tensor id.
    grads: HashMap<TensorId, Buffer>,
}

impl Gradients {
    /// Creates an empty gradient store.
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
        }
    }

    /// Returns a reference to the gradient stored for `id`, or `None` if none was recorded.
    pub fn get(&self, id: &TensorId) -> Option<&Buffer> {
        self.grads.get(id)
    }

    /// Removes and returns the gradient stored for `id`, if any.
    ///
    /// Useful when a consumer (e.g. an optimizer) wants to take ownership of the buffer.
    pub fn remove(&mut self, id: &TensorId) -> Option<Buffer> {
        self.grads.remove(id)
    }

    /// Records `new_grad` as the gradient of the tensor identified by `id`.
    ///
    /// **The Calculus Accumulation Rule:** if a gradient already exists for this `TensorId`
    /// (the tensor was used multiple times in the forward graph), `new_grad` is added
    /// element-wise to the existing one instead of replacing it. The total derivative is
    /// the sum of the contributions from every use.
    ///
    /// # Errors
    ///
    /// Returns [`EtensorError::InternalError`] when accumulation is required and either:
    /// - the two buffers have different lengths, or
    /// - either buffer is not a CPU `F32` buffer (other devices/dtypes are not supported yet).
    ///
    /// On error, the previously stored gradient is left unchanged and `new_grad` is dropped.
    pub fn insert(&mut self, id: TensorId, new_grad: Buffer) -> EtensorResult<()> {
        // A single hash lookup serves both the "collision" and "first gradient" paths.
        match self.grads.entry(id) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                Self::accumulate_into(slot.get_mut(), &new_grad)
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(new_grad);
                Ok(())
            }
        }
    }

    /// Adds `incoming` element-wise into `existing` (private helper for [`Gradients::insert`]).
    fn accumulate_into(existing: &mut Buffer, incoming: &Buffer) -> EtensorResult<()> {
        match (existing, incoming) {
            (Buffer::Cpu(CpuBuffer::F32(existing_arc)), Buffer::Cpu(CpuBuffer::F32(new_arc))) => {
                // Validate before `Arc::make_mut`, so the error path never pays for a
                // copy-on-write clone of a shared buffer.
                if existing_arc.len() != new_arc.len() {
                    return Err(EtensorError::InternalError(
                        "Gradient accumulation failed: Buffer lengths mismatch!".to_string(),
                    ));
                }

                // `Arc::make_mut` mutates in place when this is the only strong reference.
                // Otherwise it clones the data first, so other holders of the same Arc
                // keep seeing the old values.
                let existing_vals = Arc::make_mut(existing_arc);

                // Element-wise in-place addition. A plain zip loop like this is typically
                // auto-vectorized by the compiler in release builds.
                for (acc, delta) in existing_vals.iter_mut().zip(new_arc.iter()) {
                    *acc += *delta;
                }
                Ok(())
            }
            _ => {
                // In Phase 4, we will replace this match arm with a call to our `dispatch::add`
                // kernel so we can support CudaNative and other DTypes automatically.
                Err(EtensorError::InternalError(
                    "Gradient accumulation for non-F32 or GPU buffers is deferred to Phase 4 Dispatcher.".to_string()
                ))
            }
        }
    }
}

impl Default for Gradients {
    /// Equivalent to [`Gradients::new`]: an empty gradient store.
    fn default() -> Self {
        Self::new()
    }
}

// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_insertion_and_retrieval() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        let grad_buf = Buffer::from_f32_vec(vec![1.5, 2.5]);
        grads.insert(id, grad_buf).unwrap();

        let retrieved = grads.get(&id).unwrap();
        let slice = retrieved.as_f32_slice().unwrap();

        assert_eq!(slice, &[1.5, 2.5]);
    }

    #[test]
    fn test_gradient_accumulation_on_collision() {
        let mut grads = Gradients::new();
        let id = TensorId::new(); // A tensor used twice in the graph

        // First backward pass computes a derivative of [1.0, 2.0]
        let grad1 = Buffer::from_f32_vec(vec![1.0, 2.0]);
        grads.insert(id, grad1).unwrap();

        // Second backward pass computes a derivative of [3.0, 4.0]
        let grad2 = Buffer::from_f32_vec(vec![3.0, 4.0]);
        grads.insert(id, grad2).unwrap(); // <--- Collision!

        // The Calculus Accumulation Rule dictates the result must be [4.0, 6.0]
        let accumulated = grads.get(&id).unwrap();
        let slice = accumulated.as_f32_slice().unwrap();

        assert_eq!(slice, &[4.0, 6.0], "Gradients failed to sum together upon collision!");
    }

    /// A length mismatch must be rejected and must not corrupt the stored gradient.
    #[test]
    fn test_gradient_accumulation_length_mismatch_is_error() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0])).unwrap();
        let result = grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]));
        assert!(result.is_err(), "Mismatched gradient lengths must not be accumulated!");

        let stored = grads.get(&id).unwrap();
        let slice = stored.as_f32_slice().unwrap();
        assert_eq!(slice, &[1.0, 2.0], "Stored gradient changed after a failed accumulation!");
    }

    /// `remove` hands back the stored buffer and leaves no entry behind.
    #[test]
    fn test_gradient_removal() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![0.5])).unwrap();

        let taken = grads.remove(&id).expect("gradient should be present before removal");
        assert_eq!(taken.as_f32_slice().unwrap(), &[0.5]);
        assert!(grads.get(&id).is_none());
        assert!(grads.remove(&id).is_none());
    }
}