Skip to main content

etensor_core/autograd/
engine.rs

1//! The Backward Pass Runner.
2
3use crate::tensor::Tensor;
4use crate::buffer::Buffer;
5use crate::dtypes::DType;
6use crate::errors::{EtensorError, EtensorResult};
7use crate::autograd::tape;
8use crate::autograd::gradients::Gradients;
9
10/// Triggers the backpropagation engine starting from the given root tensor (usually the Loss).
11/// 
12/// This function:
13/// 1. Seeds the root tensor's gradient with `1.0`.
14/// 2. Extracts the `thread_local!` Tape.
15/// 3. Executes all recorded `TapeActions` in strict reverse order.
16/// 4. Returns the fully populated `Gradients` map containing all derivatives.
17pub fn backward(root: &Tensor) -> EtensorResult<Gradients> {
18    // 1. Gatekeeper: Ensure the tensor is part of the graph
19    if !root.requires_grad {
20        return Err(EtensorError::AutogradError(
21            "Cannot call backward() on a tensor where requires_grad is false.".to_string(),
22        ));
23    }
24
25    // 2. Initialize the accumulation store
26    let mut grads = Gradients::new();
27
28    // 3. Seed the Root Gradient
29    // In calculus, the derivative of a variable with respect to itself (dx/dx) is exactly 1.0.
30    // For Phase 3, we enforce CPU F32 execution. (Hardware routing logic happens in Phase 4).
31    if root.dtype != DType::F32 || !root.device.is_cpu() {
32        return Err(EtensorError::AutogradError(
33            "Phase 3 backward engine currently only supports CPU Float32 seeding.".to_string(),
34        ));
35    }
36
37    // Allocate a buffer of 1.0s matching the exact size of the root tensor
38    let num_elements = root.shape.num_elements();
39    let seed_buffer = Buffer::from_f32_vec(vec![1.0; num_elements]);
40    
41    // Insert the seed into the gradients map
42    grads.insert(root.id, seed_buffer)?;
43
44    // 4. Extract the computation history
45    let actions = tape::take();
46
47    // 5. Execute Reverse-Mode Autodifferentiation
48    // We must read the tape backwards to correctly apply the Chain Rule from output to input.
49    for action in actions.into_iter().rev() {
50        // If an operation fails internally, we catch the error, stop the loop, 
51        // and bubble it up to Python safely.
52        action.backward(&mut grads)?;
53    }
54
55    Ok(grads)
56}
57
// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::tape::{record, TapeAction};
    use crate::device::Device;
    use crate::shape::Shape;
    use crate::tensor::{Tensor, TensorId};

    /// Test double for the operation `y = x + 0`.
    ///
    /// Its local derivative is 1, so the backward step forwards the output
    /// gradient to the input unchanged.
    struct MockPassThroughBackward {
        input_id: TensorId,
        output_id: TensorId,
    }

    impl TapeAction for MockPassThroughBackward {
        fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
            // dx = dy * 1, so a copy of the upstream gradient is exactly dx.
            let upstream = grads.get(&self.output_id).unwrap().clone();
            grads.insert(self.input_id, upstream)?;
            Ok(())
        }

        fn name(&self) -> String {
            String::from("MockPassThrough")
        }
    }

    /// Builds a one-element CPU `F32` tensor holding `value`.
    fn scalar_f32(value: f32, requires_grad: bool) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(vec![value]),
            Shape::new(vec![1]),
            Device::Cpu,
            DType::F32,
            requires_grad,
        )
    }

    #[test]
    fn test_requires_grad_guard() {
        // A tensor that is not tracked by autograd must be rejected.
        let frozen = scalar_f32(42.0, false);

        match backward(&frozen) {
            Err(EtensorError::AutogradError(msg)) => {
                assert!(msg.contains("requires_grad is false"));
            }
            _ => panic!("Engine bypassed the requires_grad safety guard!"),
        }
    }

    #[test]
    fn test_engine_seeding_and_execution() {
        // Discard anything an earlier test may have left on this thread's tape.
        drop(tape::take());

        let x = scalar_f32(5.0, true);
        let y = scalar_f32(5.0, true);

        // Record y = x + 0 so the engine has one action to replay.
        record(Box::new(MockPassThroughBackward {
            input_id: x.id,
            output_id: y.id,
        }));

        let grads = backward(&y).unwrap();

        // The root is seeded with 1.0 ...
        let seed = grads.get(&y.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(seed[0], 1.0);

        // ... and the replayed action passes that 1.0 back to the input.
        let propagated = grads.get(&x.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(propagated[0], 1.0);

        // backward() must have consumed every recorded action.
        assert_eq!(tape::take().len(), 0);
    }
}