1use crate::tensor::TensorId;
4use crate::buffer::Buffer;
5use crate::shape::Shape;
6use crate::errors::{EtensorError, EtensorResult};
7use crate::autograd::tape::TapeAction;
8use crate::autograd::gradients::Gradients;
9
10pub struct AddBackward {
17 pub output_id: TensorId,
18 pub lhs_id: Option<TensorId>, pub rhs_id: Option<TensorId>,
20}
21
22impl TapeAction for AddBackward {
23 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
24 let dy = grads.get(&self.output_id)
25 .ok_or_else(|| EtensorError::AutogradError(
26 format!("Gradient missing for Output ID {:?}", self.output_id)
27 ))?
28 .clone();
29
30 if let Some(id) = self.lhs_id {
31 grads.insert(id, dy.clone())?;
32 }
33 if let Some(id) = self.rhs_id {
34 grads.insert(id, dy)?;
35 }
36 Ok(())
37 }
38 fn name(&self) -> String { "AddBackward".to_string() }
39}
40
41pub struct MulBackward {
48 pub output_id: TensorId,
49 pub lhs_id: Option<TensorId>,
50 pub rhs_id: Option<TensorId>,
51 pub lhs_data: Buffer,
52 pub rhs_data: Buffer,
53}
54
55impl TapeAction for MulBackward {
56 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
57 let dy_buf = grads.get(&self.output_id)
58 .ok_or_else(|| EtensorError::AutogradError("Gradient missing".to_string()))?
59 .clone();
60
61 let dy = dy_buf.as_f32_slice()?;
62 let a = self.lhs_data.as_f32_slice()?;
63 let b = self.rhs_data.as_f32_slice()?;
64
65 if let Some(id) = self.lhs_id {
66 let mut da = vec![0.0; dy.len()];
67 for i in 0..dy.len() { da[i] = dy[i] * b[i]; }
68 grads.insert(id, Buffer::from_f32_vec(da))?;
69 }
70
71 if let Some(id) = self.rhs_id {
72 let mut db = vec![0.0; dy.len()];
73 for i in 0..dy.len() { db[i] = dy[i] * a[i]; }
74 grads.insert(id, Buffer::from_f32_vec(db))?;
75 }
76 Ok(())
77 }
78 fn name(&self) -> String { "MulBackward".to_string() }
79}
80
81pub struct MatMulBackward {
88 pub output_id: TensorId,
89 pub lhs_id: Option<TensorId>,
90 pub rhs_id: Option<TensorId>,
91 pub lhs_data: Buffer,
92 pub rhs_data: Buffer,
93 pub lhs_shape: Shape,
94 pub rhs_shape: Shape,
95}
96
97impl TapeAction for MatMulBackward {
98 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
99 let dc_buf = grads.get(&self.output_id)
100 .ok_or_else(|| EtensorError::AutogradError("Gradient missing for MatMul Output".to_string()))?
101 .clone();
102
103 let dc = dc_buf.as_f32_slice()?;
104 let a = self.lhs_data.as_f32_slice()?;
105 let b = self.rhs_data.as_f32_slice()?;
106
107 let m = self.lhs_shape.dims[0];
108 let k = self.lhs_shape.dims[1];
109 let n = self.rhs_shape.dims[1];
110
111 let stride_a0 = self.lhs_shape.strides[0];
112 let stride_a1 = self.lhs_shape.strides[1];
113 let stride_b0 = self.rhs_shape.strides[0];
114 let stride_b1 = self.rhs_shape.strides[1];
115
116 if let Some(id) = self.lhs_id {
117 let mut da = vec![0.0; m * k];
118 for i in 0..m {
119 for j in 0..k {
120 let mut sum = 0.0;
121 for p in 0..n {
122 let idx_dc = i * n + p;
123 let idx_b = j * stride_b0 + p * stride_b1;
124 sum += dc[idx_dc] * b[idx_b];
125 }
126 da[i * k + j] = sum;
127 }
128 }
129 grads.insert(id, Buffer::from_f32_vec(da))?;
130 }
131
132 if let Some(id) = self.rhs_id {
133 let mut db = vec![0.0; k * n];
134 for i in 0..k {
135 for j in 0..n {
136 let mut sum = 0.0;
137 for p in 0..m {
138 let idx_a = p * stride_a0 + i * stride_a1;
139 let idx_dc = p * n + j;
140 sum += a[idx_a] * dc[idx_dc];
141 }
142 db[i * n + j] = sum;
143 }
144 }
145 grads.insert(id, Buffer::from_f32_vec(db))?;
146 }
147 Ok(())
148 }
149 fn name(&self) -> String { "MatMulBackward".to_string() }
150}
151
152pub struct SumAllBackward {
158 pub output_id: TensorId,
159 pub input_id: TensorId,
160 pub input_shape: Shape,
161}
162
163impl TapeAction for SumAllBackward {
164 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
165 let dy_buf = grads.get(&self.output_id)
166 .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sum Output".to_string()))?
167 .clone();
168
169 let dy_scalar = dy_buf.as_f32_slice()?[0];
171
172 let num_elements = self.input_shape.num_elements();
173 let dx = vec![dy_scalar; num_elements];
174
175 grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
176 Ok(())
177 }
178 fn name(&self) -> String { "SumAllBackward".to_string() }
179}
180
181pub struct ReluBackward {
187 pub output_id: TensorId,
188 pub input_id: TensorId,
189 pub input_data: Buffer,
190}
191
192impl TapeAction for ReluBackward {
193 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
194 let dy_buf = grads.get(&self.output_id)
195 .ok_or_else(|| EtensorError::AutogradError("Gradient missing for ReLU Output".to_string()))?
196 .clone();
197
198 let dy = dy_buf.as_f32_slice()?;
199 let x = self.input_data.as_f32_slice()?;
200
201 let mut dx = vec![0.0; dy.len()];
202 for i in 0..dy.len() {
203 dx[i] = if x[i] > 0.0 { dy[i] } else { 0.0 };
204 }
205
206 grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
207 Ok(())
208 }
209 fn name(&self) -> String { "ReluBackward".to_string() }
210}
211
212pub struct SigmoidBackward {
218 pub output_id: TensorId,
219 pub input_id: TensorId,
220 pub output_data: Buffer, }
222
223impl TapeAction for SigmoidBackward {
224 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
225 let dy_buf = grads.get(&self.output_id)
226 .ok_or_else(|| EtensorError::AutogradError("Gradient missing for Sigmoid Output".to_string()))?
227 .clone();
228
229 let dy = dy_buf.as_f32_slice()?;
230 let y = self.output_data.as_f32_slice()?;
231
232 let mut dx = vec![0.0; dy.len()];
233 for i in 0..dy.len() {
234 dx[i] = dy[i] * y[i] * (1.0 - y[i]);
235 }
236
237 grads.insert(self.input_id, Buffer::from_f32_vec(dx))?;
238 Ok(())
239 }
240 fn name(&self) -> String { "SigmoidBackward".to_string() }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a gradient store holding `values` under a freshly allocated output id.
    fn seeded(values: Vec<f32>) -> (Gradients, TensorId) {
        let mut grads = Gradients::new();
        let out_id = TensorId::new();
        grads.insert(out_id, Buffer::from_f32_vec(values)).unwrap();
        (grads, out_id)
    }

    /// Copies the gradient stored for `id` into a vector; panics if it is missing.
    fn grad_of(grads: &Gradients, id: &TensorId) -> Vec<f32> {
        grads.get(id).unwrap().as_f32_slice().unwrap().to_vec()
    }

    /// Addition forwards the upstream gradient unchanged to both operands.
    #[test]
    fn test_add_backward_logic() {
        let (mut grads, out_id) = seeded(vec![5.0, 5.0]);
        let lhs_id = TensorId::new();
        let rhs_id = TensorId::new();

        let node = AddBackward {
            output_id: out_id,
            lhs_id: Some(lhs_id),
            rhs_id: Some(rhs_id),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(grad_of(&grads, &lhs_id), vec![5.0, 5.0]);
        assert_eq!(grad_of(&grads, &rhs_id), vec![5.0, 5.0]);
    }

    /// Each operand's gradient is the upstream gradient scaled by the other operand.
    #[test]
    fn test_mul_backward_logic() {
        let (mut grads, out_id) = seeded(vec![2.0, 2.0]);
        let lhs_id = TensorId::new();
        let rhs_id = TensorId::new();

        let node = MulBackward {
            output_id: out_id,
            lhs_id: Some(lhs_id),
            rhs_id: Some(rhs_id),
            lhs_data: Buffer::from_f32_vec(vec![3.0, 4.0]),
            rhs_data: Buffer::from_f32_vec(vec![10.0, 20.0]),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(grad_of(&grads, &lhs_id), vec![20.0, 40.0]);
        assert_eq!(grad_of(&grads, &rhs_id), vec![6.0, 8.0]);
    }

    /// With an all-ones upstream gradient, dA holds row sums of B and dB column sums of A.
    #[test]
    fn test_matmul_backward_logic() {
        let (mut grads, out_id) = seeded(vec![1.0, 1.0, 1.0, 1.0]);
        let lhs_id = TensorId::new();
        let rhs_id = TensorId::new();

        let node = MatMulBackward {
            output_id: out_id,
            lhs_id: Some(lhs_id),
            rhs_id: Some(rhs_id),
            lhs_data: Buffer::from_f32_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            rhs_data: Buffer::from_f32_vec(vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]),
            lhs_shape: Shape::new(vec![2, 3]),
            rhs_shape: Shape::new(vec![3, 2]),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(
            grad_of(&grads, &lhs_id),
            vec![15.0, 10.0, 5.0, 15.0, 10.0, 5.0]
        );
        assert_eq!(
            grad_of(&grads, &rhs_id),
            vec![5.0, 5.0, 7.0, 7.0, 9.0, 9.0]
        );
    }

    /// The scalar upstream gradient is repeated once per input element.
    #[test]
    fn test_sum_all_backward_logic() {
        let (mut grads, out_id) = seeded(vec![42.0]);
        let in_id = TensorId::new();

        let node = SumAllBackward {
            output_id: out_id,
            input_id: in_id,
            input_shape: Shape::new(vec![2, 2]),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(grad_of(&grads, &in_id), vec![42.0, 42.0, 42.0, 42.0]);
    }

    /// Gradient passes only where the saved input is strictly positive.
    #[test]
    fn test_relu_backward_logic() {
        let (mut grads, out_id) = seeded(vec![2.0, 2.0, 2.0]);
        let in_id = TensorId::new();

        let node = ReluBackward {
            output_id: out_id,
            input_id: in_id,
            input_data: Buffer::from_f32_vec(vec![-5.0, 0.0, 10.0]),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(grad_of(&grads, &in_id), vec![0.0, 0.0, 2.0]);
    }

    /// For y = 0.5 and dy = 2 the gradient is 2 * 0.5 * (1 - 0.5) = 0.5.
    #[test]
    fn test_sigmoid_backward_logic() {
        let (mut grads, out_id) = seeded(vec![2.0]);
        let in_id = TensorId::new();

        let node = SigmoidBackward {
            output_id: out_id,
            input_id: in_id,
            output_data: Buffer::from_f32_vec(vec![0.5]),
        };
        node.backward(&mut grads).unwrap();

        assert_eq!(grad_of(&grads, &in_id), vec![0.5]);
    }
}