Skip to main content

etensor_core/autograd/
tape.rs

//! The Flight Recorder: a per-thread tape that records differentiable operations for the backward pass.

3use std::cell::RefCell;
4use crate::errors::EtensorResult;
5use super::gradients::Gradients;
6
7// =====================================================================
8// THE ACTION CONTRACT
9// =====================================================================
10
11/// The fundamental contract for any mathematical operation that supports backpropagation.
12pub trait TapeAction {
13    /// Executes the backward mathematical derivative, accumulating results into the `Gradients` map.
14    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()>;
15    
16    /// Returns a human-readable identifier for explainability and graph visualization.
17    /// (e.g., "Add(TensorId(1), TensorId(2))")
18    fn name(&self) -> String;
19}
20
21// =====================================================================
22// THE TAPE
23// =====================================================================
24
25/// The linear log of operations required to compute gradients.
26pub struct Tape {
27    pub actions: Vec<Box<dyn TapeAction>>,
28}
29
30impl Tape {
31    pub fn new() -> Self {
32        Self { actions: Vec::new() }
33    }
34
35    /// Appends a new operation to the history log.
36    pub fn push(&mut self, action: Box<dyn TapeAction>) {
37        self.actions.push(action);
38    }
39
40    /// Extracts the entire history log, leaving the Tape completely empty.
41    /// This is called right before executing the backward pass.
42    pub fn take_all(&mut self) -> Vec<Box<dyn TapeAction>> {
43        std::mem::take(&mut self.actions)
44    }
45}
46
47impl Default for Tape {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53// =====================================================================
54// GLOBAL THREAD-LOCAL STATE
55// =====================================================================
56
57// The core architectural decision: A thread-local global state.
58// Every native OS thread gets its own isolated Tape. There are zero locks, 
59// zero deadlocks, and zero waiting.
60thread_local! {
61    static TAPE: RefCell<Tape> = RefCell::new(Tape::new());
62}
63
64/// Pushes a new operation onto the current thread's Tape.
65pub fn record(action: Box<dyn TapeAction>) {
66    TAPE.with(|t| t.borrow_mut().push(action));
67}
68
69/// Extracts and clears the current thread's Tape for backward execution.
70pub fn take() -> Vec<Box<dyn TapeAction>> {
71    TAPE.with(|t| t.borrow_mut().take_all())
72}
73
// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Inert action used only to observe what the tape stores: `backward`
    /// does nothing and `name` encodes `id`.
    struct MockAction {
        id: usize,
    }

    impl TapeAction for MockAction {
        fn backward(&self, _grads: &mut Gradients) -> EtensorResult<()> {
            Ok(())
        }

        fn name(&self) -> String {
            format!("MockAction({})", self.id)
        }
    }

    /// Collects the names of a taken batch so whole batches compare at once.
    fn names_of(batch: &[Box<dyn TapeAction>]) -> Vec<String> {
        batch.iter().map(|action| action.name()).collect()
    }

    #[test]
    fn test_tape_record_and_take() {
        // Start from a known-empty tape regardless of earlier activity.
        drop(take());

        for id in [1, 2] {
            record(Box::new(MockAction { id }));
        }

        let recorded = take();
        assert_eq!(recorded.len(), 2);
        assert_eq!(names_of(&recorded), ["MockAction(1)", "MockAction(2)"]);

        // Taking must have drained the tape completely.
        assert!(take().is_empty());
    }

    #[test]
    fn test_thread_local_isolation() {
        // Seed this thread's tape with a single marker action.
        drop(take());
        record(Box::new(MockAction { id: 99 }));

        let worker = thread::spawn(|| {
            // A freshly spawned OS thread must start with its own, empty tape.
            assert_eq!(take().len(), 0, "Thread tape bled over! Isolation failed.");

            record(Box::new(MockAction { id: 42 }));
            assert_eq!(names_of(&take()), ["MockAction(42)"]);
        });

        // Re-raise any assertion failure that happened on the worker thread.
        worker.join().unwrap();

        // The worker's activity must not have leaked into this thread's tape.
        assert_eq!(names_of(&take()), ["MockAction(99)"]);
    }
}