// etensor_core/autograd/gradients.rs

//! The Gradient Map: accumulates and stores backward-pass derivatives.
2
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use crate::buffer::{Buffer, CpuBuffer};
use crate::errors::{EtensorError, EtensorResult};
use crate::tensor::TensorId;
8
9/// A memory store mapping mathematical tensor nodes to their computed derivatives.
10pub struct Gradients {
11    grads: HashMap<TensorId, Buffer>,
12}
13
14impl Gradients {
15    /// Initializes an empty gradient tracker.
16    #[allow(clippy::new_without_default)]
17    pub fn new() -> Self {
18        Self {
19            grads: HashMap::new(),
20        }
21    }
22
23    /// Retrieves a reference to the computed gradient buffer for a specific Tensor.
24    pub fn get(&self, id: &TensorId) -> Option<&Buffer> {
25        self.grads.get(id)
26    }
27
28    /// Removes and takes ownership of a gradient buffer (useful for optimizers).
29    pub fn remove(&mut self, id: &TensorId) -> Option<Buffer> {
30        self.grads.remove(id)
31    }
32
33    /// Pushes a computed derivative into the store.
34    /// 
35    /// **The Calculus Accumulation Rule:** 
36    /// If a gradient already exists for this `TensorId` (meaning the tensor was used 
37    /// multiple times in the forward graph), the new gradient is mathematically 
38    /// added to the existing one.
39    pub fn insert(&mut self, id: TensorId, new_grad: Buffer) -> EtensorResult<()> {
40        if let Some(existing_grad) = self.grads.get_mut(&id) {
41            // Collision detected! We must accumulate (Add) the buffers.
42            match (existing_grad, &new_grad) {
43                (Buffer::Cpu(CpuBuffer::F32(existing_arc)), Buffer::Cpu(CpuBuffer::F32(new_arc))) => {
44                    // Arc::make_mut ensures we mutate the vector directly in RAM if we are 
45                    // the only owners, avoiding memory allocation. If another tensor is 
46                    // sharing this exact buffer, it safely clones it before mutating.
47                    let existing_vec = Arc::make_mut(existing_arc);
48                    
49                    if existing_vec.len() != new_arc.len() {
50                        return Err(EtensorError::InternalError(
51                            "Gradient accumulation failed: Buffer lengths mismatch!".to_string()
52                        ));
53                    }
54
55                    // Vectorized CPU addition
56                    for (a, b) in existing_vec.iter_mut().zip(new_arc.iter()) {
57                        *a += b;
58                    }
59                    Ok(())
60                }
61                _ => {
62                    // In Phase 4, we will replace this match arm with a call to our `dispatch::add` 
63                    // kernel so we can support CudaNative and other DTypes automatically.
64                    Err(EtensorError::InternalError(
65                        "Gradient accumulation for non-F32 or GPU buffers is deferred to Phase 4 Dispatcher.".to_string()
66                    ))
67                }
68            }
69        } else {
70            // No collision, just insert it normally.
71            self.grads.insert(id, new_grad);
72            Ok(())
73        }
74    }
75}
76
// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    // A single insertion must be retrievable unchanged.
    #[test]
    fn test_gradient_insertion_and_retrieval() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        let grad_buf = Buffer::from_f32_vec(vec![1.5, 2.5]);
        grads.insert(id, grad_buf).unwrap();

        let retrieved = grads.get(&id).unwrap();
        let slice = retrieved.as_f32_slice().unwrap();

        assert_eq!(slice, &[1.5, 2.5]);
    }

    // Two contributions for the same tensor must be summed element-wise.
    #[test]
    fn test_gradient_accumulation_on_collision() {
        let mut grads = Gradients::new();
        let id = TensorId::new(); // A tensor used twice in the graph

        // First backward pass computes a derivative of [1.0, 2.0]
        let grad1 = Buffer::from_f32_vec(vec![1.0, 2.0]);
        grads.insert(id, grad1).unwrap();

        // Second backward pass computes a derivative of [3.0, 4.0]
        let grad2 = Buffer::from_f32_vec(vec![3.0, 4.0]);
        grads.insert(id, grad2).unwrap(); // <--- Collision!

        // The Calculus Accumulation Rule dictates the result must be [4.0, 6.0]
        let accumulated = grads.get(&id).unwrap();
        let slice = accumulated.as_f32_slice().unwrap();

        assert_eq!(slice, &[4.0, 6.0], "Gradients failed to sum together upon collision!");
    }

    // Accumulating a gradient of a different length must fail without touching the stored one.
    #[test]
    fn test_gradient_length_mismatch_is_rejected() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0])).unwrap();
        let result = grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]));

        assert!(result.is_err(), "Mismatched gradient lengths must be rejected!");

        let stored = grads.get(&id).unwrap();
        let slice = stored.as_f32_slice().unwrap();
        assert_eq!(slice, &[1.0, 2.0], "A failed accumulation must not modify the stored gradient!");
    }

    // `remove` hands the buffer to the caller and clears the entry.
    #[test]
    fn test_gradient_remove_takes_ownership() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![0.5])).unwrap();

        let removed = grads.remove(&id).expect("gradient should be present before removal");
        assert_eq!(removed.as_f32_slice().unwrap(), &[0.5]);

        assert!(grads.get(&id).is_none(), "Removed gradient must no longer be retrievable!");
        assert!(grads.remove(&id).is_none(), "Removing twice must yield None!");
    }
}