etensor_core/autograd/gradients.rs
1//! The Gradient Map: Accumulates and stores backward pass derivatives.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use crate::tensor::TensorId;
6use crate::buffer::{Buffer, CpuBuffer};
7use crate::errors::{EtensorError, EtensorResult};
8
9/// A memory store mapping mathematical tensor nodes to their computed derivatives.
10pub struct Gradients {
11 grads: HashMap<TensorId, Buffer>,
12}
13
14impl Gradients {
15 /// Initializes an empty gradient tracker.
16 #[allow(clippy::new_without_default)]
17 pub fn new() -> Self {
18 Self {
19 grads: HashMap::new(),
20 }
21 }
22
23 /// Retrieves a reference to the computed gradient buffer for a specific Tensor.
24 pub fn get(&self, id: &TensorId) -> Option<&Buffer> {
25 self.grads.get(id)
26 }
27
28 /// Removes and takes ownership of a gradient buffer (useful for optimizers).
29 pub fn remove(&mut self, id: &TensorId) -> Option<Buffer> {
30 self.grads.remove(id)
31 }
32
33 /// Pushes a computed derivative into the store.
34 ///
35 /// **The Calculus Accumulation Rule:**
36 /// If a gradient already exists for this `TensorId` (meaning the tensor was used
37 /// multiple times in the forward graph), the new gradient is mathematically
38 /// added to the existing one.
39 pub fn insert(&mut self, id: TensorId, new_grad: Buffer) -> EtensorResult<()> {
40 if let Some(existing_grad) = self.grads.get_mut(&id) {
41 // Collision detected! We must accumulate (Add) the buffers.
42 match (existing_grad, &new_grad) {
43 (Buffer::Cpu(CpuBuffer::F32(existing_arc)), Buffer::Cpu(CpuBuffer::F32(new_arc))) => {
44 // Arc::make_mut ensures we mutate the vector directly in RAM if we are
45 // the only owners, avoiding memory allocation. If another tensor is
46 // sharing this exact buffer, it safely clones it before mutating.
47 let existing_vec = Arc::make_mut(existing_arc);
48
49 if existing_vec.len() != new_arc.len() {
50 return Err(EtensorError::InternalError(
51 "Gradient accumulation failed: Buffer lengths mismatch!".to_string()
52 ));
53 }
54
55 // Vectorized CPU addition
56 for (a, b) in existing_vec.iter_mut().zip(new_arc.iter()) {
57 *a += b;
58 }
59 Ok(())
60 }
61 _ => {
62 // In Phase 4, we will replace this match arm with a call to our `dispatch::add`
63 // kernel so we can support CudaNative and other DTypes automatically.
64 Err(EtensorError::InternalError(
65 "Gradient accumulation for non-F32 or GPU buffers is deferred to Phase 4 Dispatcher.".to_string()
66 ))
67 }
68 }
69 } else {
70 // No collision, just insert it normally.
71 self.grads.insert(id, new_grad);
72 Ok(())
73 }
74 }
75}
76
77// =====================================================================
78// UNIT TESTS
79// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// A first insertion stores the gradient verbatim.
    #[test]
    fn test_gradient_insertion_and_retrieval() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        let grad_buf = Buffer::from_f32_vec(vec![1.5, 2.5]);
        grads.insert(id, grad_buf).unwrap();

        let retrieved = grads.get(&id).unwrap();
        let slice = retrieved.as_f32_slice().unwrap();

        assert_eq!(slice, &[1.5, 2.5]);
    }

    /// Inserting twice for the same id sums the gradients element-wise.
    #[test]
    fn test_gradient_accumulation_on_collision() {
        let mut grads = Gradients::new();
        let id = TensorId::new(); // A tensor used twice in the graph

        // First backward pass computes a derivative of [1.0, 2.0]
        let grad1 = Buffer::from_f32_vec(vec![1.0, 2.0]);
        grads.insert(id, grad1).unwrap();

        // Second backward pass computes a derivative of [3.0, 4.0]
        let grad2 = Buffer::from_f32_vec(vec![3.0, 4.0]);
        grads.insert(id, grad2).unwrap(); // <--- Collision!

        // The Calculus Accumulation Rule dictates the result must be [4.0, 6.0]
        let accumulated = grads.get(&id).unwrap();
        let slice = accumulated.as_f32_slice().unwrap();

        assert_eq!(slice, &[4.0, 6.0], "Gradients failed to sum together upon collision!");
    }

    /// A length mismatch on collision is an error, and the stored gradient is unchanged.
    #[test]
    fn test_length_mismatch_is_rejected_and_existing_kept() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0])).unwrap();
        let result = grads.insert(id, Buffer::from_f32_vec(vec![1.0]));
        assert!(result.is_err(), "Mismatched gradient lengths must not accumulate!");

        let slice = grads.get(&id).unwrap().as_f32_slice().unwrap();
        assert_eq!(slice, &[1.0, 2.0]);
    }

    /// `remove` hands the buffer to the caller and clears the entry.
    #[test]
    fn test_remove_takes_ownership() {
        let mut grads = Gradients::new();
        let id = TensorId::new();

        grads.insert(id, Buffer::from_f32_vec(vec![9.0])).unwrap();
        let taken = grads.remove(&id).expect("gradient was just inserted");
        assert_eq!(taken.as_f32_slice().unwrap(), &[9.0]);

        assert!(grads.get(&id).is_none());
        assert!(grads.remove(&id).is_none());
    }

    /// Looking up an id that never received a gradient yields `None`.
    #[test]
    fn test_unknown_id_has_no_gradient() {
        let grads = Gradients::new();
        let id = TensorId::new();
        assert!(grads.get(&id).is_none());
    }
}