Skip to main content

etensor_core/autograd/
nodes.rs

//! Calculus (backward-pass) rules for the operations recorded on the autograd tape.
2
3use crate::tensor::TensorId;
4use crate::buffer::Buffer;
5use crate::shape::Shape;
6use crate::errors::{EtensorError, EtensorResult};
7use crate::autograd::tape::TapeAction;
8use crate::autograd::gradients::Gradients;
9
10// =====================================================================
11// ADDITION: y = a + b
12// Calculus: dy/da = 1 * dy, dy/db = 1 * dy
13// =====================================================================
14
15/// The backward operation for element-wise addition.
16pub struct AddBackward {
17    pub output_id: TensorId,
18    pub lhs_id: Option<TensorId>, // Option allows us to skip tracking if requires_grad=false
19    pub rhs_id: Option<TensorId>,
20}
21
22impl TapeAction for AddBackward {
23    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
24        let dy = grads.get(&self.output_id)
25            .ok_or_else(|| EtensorError::AutogradError(
26                format!("Gradient missing for Output ID {:?}", self.output_id)
27            ))?
28            .clone();
29
30        if let Some(id) = self.lhs_id {
31            grads.insert(id, dy.clone())?;
32        }
33        if let Some(id) = self.rhs_id {
34            grads.insert(id, dy)?;
35        }
36        Ok(())
37    }
38    fn name(&self) -> String { "AddBackward".to_string() }
39}
40
41// =====================================================================
42// MULTIPLICATION: y = a * b
43// Calculus: dy/da = b * dy, dy/db = a * dy
44// =====================================================================
45
46/// The backward operation for element-wise multiplication.
47pub struct MulBackward {
48    pub output_id: TensorId,
49    pub lhs_id: Option<TensorId>,
50    pub rhs_id: Option<TensorId>,
51    pub lhs_data: Buffer,
52    pub rhs_data: Buffer, 
53}
54
55impl TapeAction for MulBackward {
56    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
57        let dy_buf = grads.get(&self.output_id)
58            .ok_or_else(|| EtensorError::AutogradError("Gradient missing".to_string()))?
59            .clone();
60        
61        let dy = dy_buf.as_f32_slice()?;
62        let a = self.lhs_data.as_f32_slice()?;
63        let b = self.rhs_data.as_f32_slice()?;
64
65        if let Some(id) = self.lhs_id {
66            let mut da = vec![0.0; dy.len()];
67            for i in 0..dy.len() { da[i] = dy[i] * b[i]; }
68            grads.insert(id, Buffer::from_f32_vec(da))?;
69        }
70
71        if let Some(id) = self.rhs_id {
72            let mut db = vec![0.0; dy.len()];
73            for i in 0..dy.len() { db[i] = dy[i] * a[i]; }
74            grads.insert(id, Buffer::from_f32_vec(db))?;
75        }
76        Ok(())
77    }
78    fn name(&self) -> String { "MulBackward".to_string() }
79}
80
81// =====================================================================
82// MATRIX MULTIPLICATION: C = A @ B
83// Calculus: dA = dC @ B^T,  dB = A^T @ dC
84// =====================================================================
85
86/// The backward operation for matrix multiplication.
87pub struct MatMulBackward {
88    pub output_id: TensorId,
89    pub lhs_id: Option<TensorId>,
90    pub rhs_id: Option<TensorId>,
91    pub lhs_data: Buffer,
92    pub rhs_data: Buffer,
93    pub lhs_shape: Shape,
94    pub rhs_shape: Shape,
95}
96
97impl TapeAction for MatMulBackward {
98    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
99        let dc_buf = grads.get(&self.output_id)
100            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for MatMul Output".to_string()))?
101            .clone();
102            
103        let dc = dc_buf.as_f32_slice()?;
104        let a = self.lhs_data.as_f32_slice()?;
105        let b = self.rhs_data.as_f32_slice()?;
106
107        let m = self.lhs_shape.dims[0];
108        let k = self.lhs_shape.dims[1];
109        let n = self.rhs_shape.dims[1];
110
111        let stride_a0 = self.lhs_shape.strides[0];
112        let stride_a1 = self.lhs_shape.strides[1];
113        let stride_b0 = self.rhs_shape.strides[0];
114        let stride_b1 = self.rhs_shape.strides[1];
115
116        if let Some(id) = self.lhs_id {
117            let mut da = vec![0.0; m * k];
118            for i in 0..m {
119                for j in 0..k {
120                    let mut sum = 0.0;
121                    for p in 0..n {
122                        let idx_dc = i * n + p;
123                        let idx_b = j * stride_b0 + p * stride_b1;
124                        sum += dc[idx_dc] * b[idx_b];
125                    }
126                    da[i * k + j] = sum;
127                }
128            }
129            grads.insert(id, Buffer::from_f32_vec(da))?;
130        }
131
132        if let Some(id) = self.rhs_id {
133            let mut db = vec![0.0; k * n];
134            for i in 0..k {
135                for j in 0..n {
136                    let mut sum = 0.0;
137                    for p in 0..m {
138                        let idx_a = p * stride_a0 + i * stride_a1;
139                        let idx_dc = p * n + j;
140                        sum += a[idx_a] * dc[idx_dc];
141                    }
142                    db[i * n + j] = sum;
143                }
144            }
145            grads.insert(id, Buffer::from_f32_vec(db))?;
146        }
147        Ok(())
148    }
149    fn name(&self) -> String { "MatMulBackward".to_string() }
150}
151
152// =====================================================================
153// GLOBAL SUM REDUCTION: y = sum(x)
154// Calculus: dx = 1 * dy (Broadcast upstream scalar to all elements)
155// =====================================================================
156
157pub struct SumAllBackward {
158    pub output_id: TensorId,
159    pub input_id: TensorId,
160    pub input_shape: Shape,
161}
162
163impl TapeAction for SumAllBackward {
164    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
165        let dy_buf = grads.get(&self.output_id)
166            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sum Output".to_string()))?
167            .clone();
168            
169        // dy is a single scalar for a global sum
170        let dy_scalar = dy_buf.as_f32_slice()?[0];
171        
172        let num_elements = self.input_shape.num_elements();
173        let dx = vec![dy_scalar; num_elements];
174        
175        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
176        Ok(())
177    }
178    fn name(&self) -> String { "SumAllBackward".to_string() }
179}
180
181// =====================================================================
182// RELU: y = max(0, x)
183// Calculus: dx = dy if x > 0 else 0
184// =====================================================================
185
186pub struct ReluBackward {
187    pub output_id: TensorId,
188    pub input_id: TensorId,
189    pub input_data: Buffer,
190}
191
192impl TapeAction for ReluBackward {
193    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
194        let dy_buf = grads.get(&self.output_id)
195            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for ReLU Output".to_string()))?
196            .clone();
197            
198        let dy = dy_buf.as_f32_slice()?;
199        let x = self.input_data.as_f32_slice()?;
200        
201        let mut dx = vec![0.0; dy.len()];
202        for i in 0..dy.len() {
203            dx[i] = if x[i] > 0.0 { dy[i] } else { 0.0 };
204        }
205        
206        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
207        Ok(())
208    }
209    fn name(&self) -> String { "ReluBackward".to_string() }
210}
211
212// =====================================================================
213// SIGMOID: y = 1 / (1 + exp(-x))
214// Calculus: dx = dy * y * (1 - y)
215// =====================================================================
216
217pub struct SigmoidBackward {
218    pub output_id: TensorId,
219    pub input_id: TensorId,
220    pub output_data: Buffer, // We save y, not x, because the math is faster!
221}
222
223impl TapeAction for SigmoidBackward {
224    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
225        let dy_buf = grads.get(&self.output_id)
226            .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sigmoid Output".to_string()))?
227            .clone();
228            
229        let dy = dy_buf.as_f32_slice()?;
230        let y = self.output_data.as_f32_slice()?;
231        
232        let mut dx = vec![0.0; dy.len()];
233        for i in 0..dy.len() {
234            dx[i] = dy[i] * y[i] * (1.0 - y[i]);
235        }
236        
237        grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
238        Ok(())
239    }
240    fn name(&self) -> String { "SigmoidBackward".to_string() }
241}
242
// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a gradient store whose only entry is `upstream` for `out_id`.
    fn seeded(out_id: TensorId, upstream: Vec<f32>) -> Gradients {
        let mut grads = Gradients::new();
        grads.insert(out_id, Buffer::from_f32_vec(upstream)).unwrap();
        grads
    }

    /// Returns a copy of the gradient stored for `id`, panicking if it is absent.
    fn grad_of(grads: &Gradients, id: TensorId) -> Vec<f32> {
        grads.get(&id).unwrap().as_f32_slice().unwrap().to_vec()
    }

    #[test]
    fn test_add_backward_logic() {
        let (out_id, lhs_id, rhs_id) = (TensorId::new(), TensorId::new(), TensorId::new());
        let mut grads = seeded(out_id, vec![5.0, 5.0]);

        AddBackward { output_id: out_id, lhs_id: Some(lhs_id), rhs_id: Some(rhs_id) }
            .backward(&mut grads)
            .unwrap();

        // Addition passes the upstream gradient straight through to both sides.
        assert_eq!(grad_of(&grads, lhs_id), vec![5.0, 5.0]);
        assert_eq!(grad_of(&grads, rhs_id), vec![5.0, 5.0]);
    }

    #[test]
    fn test_mul_backward_logic() {
        let (out_id, lhs_id, rhs_id) = (TensorId::new(), TensorId::new(), TensorId::new());
        let mut grads = seeded(out_id, vec![2.0, 2.0]);

        MulBackward {
            output_id: out_id,
            lhs_id: Some(lhs_id),
            rhs_id: Some(rhs_id),
            lhs_data: Buffer::from_f32_vec(vec![3.0, 4.0]),
            rhs_data: Buffer::from_f32_vec(vec![10.0, 20.0]),
        }
        .backward(&mut grads)
        .unwrap();

        // da = dy * b, db = dy * a
        assert_eq!(grad_of(&grads, lhs_id), vec![20.0, 40.0]);
        assert_eq!(grad_of(&grads, rhs_id), vec![6.0, 8.0]);
    }

    #[test]
    fn test_matmul_backward_logic() {
        let (out_id, lhs_id, rhs_id) = (TensorId::new(), TensorId::new(), TensorId::new());
        // dC is an all-ones 2x2 matrix.
        let mut grads = seeded(out_id, vec![1.0; 4]);

        MatMulBackward {
            output_id: out_id,
            lhs_id: Some(lhs_id),
            rhs_id: Some(rhs_id),
            // A = [[1, 2, 3], [4, 5, 6]]
            lhs_data: Buffer::from_f32_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            // B = [[7, 8], [9, 1], [2, 3]]
            rhs_data: Buffer::from_f32_vec(vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]),
            lhs_shape: Shape::new(vec![2, 3]),
            rhs_shape: Shape::new(vec![3, 2]),
        }
        .backward(&mut grads)
        .unwrap();

        // dA = ones @ B^T: every row holds the row sums of B.
        assert_eq!(grad_of(&grads, lhs_id), vec![15.0, 10.0, 5.0, 15.0, 10.0, 5.0]);
        // dB = A^T @ ones: row i repeats the i-th column sum of A.
        assert_eq!(grad_of(&grads, rhs_id), vec![5.0, 5.0, 7.0, 7.0, 9.0, 9.0]);
    }

    #[test]
    fn test_sum_all_backward_logic() {
        let (out_id, in_id) = (TensorId::new(), TensorId::new());
        let mut grads = seeded(out_id, vec![42.0]);

        SumAllBackward { output_id: out_id, input_id: in_id, input_shape: Shape::new(vec![2, 2]) }
            .backward(&mut grads)
            .unwrap();

        // The scalar upstream gradient reaches all four input elements.
        assert_eq!(grad_of(&grads, in_id), vec![42.0; 4]);
    }

    #[test]
    fn test_relu_backward_logic() {
        let (out_id, in_id) = (TensorId::new(), TensorId::new());
        let mut grads = seeded(out_id, vec![2.0; 3]);

        ReluBackward {
            output_id: out_id,
            input_id: in_id,
            input_data: Buffer::from_f32_vec(vec![-5.0, 0.0, 10.0]),
        }
        .backward(&mut grads)
        .unwrap();

        // Only the strictly positive input lets the gradient through.
        assert_eq!(grad_of(&grads, in_id), vec![0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_sigmoid_backward_logic() {
        let (out_id, in_id) = (TensorId::new(), TensorId::new());
        let mut grads = seeded(out_id, vec![2.0]);

        // y = 0.5 corresponds to x = 0.0.
        SigmoidBackward {
            output_id: out_id,
            input_id: in_id,
            output_data: Buffer::from_f32_vec(vec![0.5]),
        }
        .backward(&mut grads)
        .unwrap();

        // dx = dy * y * (1 - y) = 2.0 * 0.5 * 0.5
        assert_eq!(grad_of(&grads, in_id), vec![0.5]);
    }
}