etensor_core/autograd/
engine.rs1use crate::tensor::Tensor;
4use crate::buffer::Buffer;
5use crate::dtypes::DType;
6use crate::errors::{EtensorError, EtensorResult};
7use crate::autograd::tape;
8use crate::autograd::gradients::Gradients;
9
10pub fn backward(root: &Tensor) -> EtensorResult<Gradients> {
18 if !root.requires_grad {
20 return Err(EtensorError::AutogradError(
21 "Cannot call backward() on a tensor where requires_grad is false.".to_string(),
22 ));
23 }
24
25 let mut grads = Gradients::new();
27
28 if root.dtype != DType::F32 || !root.device.is_cpu() {
32 return Err(EtensorError::AutogradError(
33 "Phase 3 backward engine currently only supports CPU Float32 seeding.".to_string(),
34 ));
35 }
36
37 let num_elements = root.shape.num_elements();
39 let seed_buffer = Buffer::from_f32_vec(vec![1.0; num_elements]);
40
41 grads.insert(root.id, seed_buffer)?;
43
44 let actions = tape::take();
46
47 for action in actions.into_iter().rev() {
50 action.backward(&mut grads)?;
53 }
54
55 Ok(grads)
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::tape::{record, TapeAction};
    use crate::device::Device;
    use crate::shape::Shape;
    use crate::tensor::{Tensor, TensorId};

    /// Test-only tape action: copies the gradient stored for `output_id` onto `input_id`.
    struct MockPassThroughBackward {
        input_id: TensorId,
        output_id: TensorId,
    }

    impl TapeAction for MockPassThroughBackward {
        fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
            let upstream = grads.get(&self.output_id).unwrap().clone();
            grads.insert(self.input_id, upstream)?;
            Ok(())
        }

        fn name(&self) -> String {
            String::from("MockPassThrough")
        }
    }

    /// Builds a one-element CPU Float32 tensor holding `value`.
    fn scalar_f32(value: f32, requires_grad: bool) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(vec![value]),
            Shape::new(vec![1]),
            Device::Cpu,
            DType::F32,
            requires_grad,
        )
    }

    #[test]
    fn test_requires_grad_guard() {
        let frozen = scalar_f32(42.0, false);

        let outcome = backward(&frozen);
        assert!(outcome.is_err());

        match outcome {
            Err(EtensorError::AutogradError(msg)) => {
                assert!(msg.contains("requires_grad is false"));
            }
            _ => panic!("Engine bypassed the requires_grad safety guard!"),
        }
    }

    #[test]
    fn test_engine_seeding_and_execution() {
        // Discard anything an earlier computation may have left on the tape.
        let _ = tape::take();

        let x = scalar_f32(5.0, true);
        let y = scalar_f32(5.0, true);

        record(Box::new(MockPassThroughBackward {
            input_id: x.id,
            output_id: y.id,
        }));

        let grads = backward(&y).unwrap();

        // The root must be seeded with ones.
        let seed = grads.get(&y.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(seed[0], 1.0);

        // The mock action must have forwarded the seed to the input.
        let forwarded = grads.get(&x.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(forwarded[0], 1.0);

        // backward() must have drained the tape.
        assert_eq!(tape::take().len(), 0);
    }
}