etensor_core/autograd/tape.rs
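//! The Flight Recorder: a thread-local tape of differentiable operations.
//!
//! Operations that support backpropagation are recorded here (via `record`) and
//! later drained (via `take`) so the backward pass can compute gradients.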
2
3use std::cell::RefCell;
4use crate::errors::EtensorResult;
5use super::gradients::Gradients;
6
7// =====================================================================
8// THE ACTION CONTRACT
9// =====================================================================
10
11/// The fundamental contract for any mathematical operation that supports backpropagation.
12pub trait TapeAction {
13 /// Executes the backward mathematical derivative, accumulating results into the `Gradients` map.
14 fn backward(&self, grads: &mut Gradients) -> EtensorResult<()>;
15
16 /// Returns a human-readable identifier for explainability and graph visualization.
17 /// (e.g., "Add(TensorId(1), TensorId(2))")
18 fn name(&self) -> String;
19}
20
21// =====================================================================
22// THE TAPE
23// =====================================================================
24
25/// The linear log of operations required to compute gradients.
26pub struct Tape {
27 pub actions: Vec<Box<dyn TapeAction>>,
28}
29
30impl Tape {
31 pub fn new() -> Self {
32 Self { actions: Vec::new() }
33 }
34
35 /// Appends a new operation to the history log.
36 pub fn push(&mut self, action: Box<dyn TapeAction>) {
37 self.actions.push(action);
38 }
39
40 /// Extracts the entire history log, leaving the Tape completely empty.
41 /// This is called right before executing the backward pass.
42 pub fn take_all(&mut self) -> Vec<Box<dyn TapeAction>> {
43 std::mem::take(&mut self.actions)
44 }
45}
46
47impl Default for Tape {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53// =====================================================================
54// GLOBAL THREAD-LOCAL STATE
55// =====================================================================
56
57// The core architectural decision: A thread-local global state.
58// Every native OS thread gets its own isolated Tape. There are zero locks,
59// zero deadlocks, and zero waiting.
60thread_local! {
61 static TAPE: RefCell<Tape> = RefCell::new(Tape::new());
62}
63
64/// Pushes a new operation onto the current thread's Tape.
65pub fn record(action: Box<dyn TapeAction>) {
66 TAPE.with(|t| t.borrow_mut().push(action));
67}
68
69/// Extracts and clears the current thread's Tape for backward execution.
70pub fn take() -> Vec<Box<dyn TapeAction>> {
71 TAPE.with(|t| t.borrow_mut().take_all())
72}
73
74// =====================================================================
75// UNIT TESTS
76// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Minimal `TapeAction` used only to observe what the tape stores.
    struct MockAction {
        id: usize,
    }

    impl TapeAction for MockAction {
        fn backward(&self, _grads: &mut Gradients) -> EtensorResult<()> {
            Ok(())
        }

        fn name(&self) -> String {
            format!("MockAction({})", self.id)
        }
    }

    /// Names of `actions`, in the same order.
    fn names_of(actions: &[Box<dyn TapeAction>]) -> Vec<String> {
        actions.iter().map(|action| action.name()).collect()
    }

    #[test]
    fn test_tape_record_and_take() {
        // Start from a known-empty tape.
        drop(take());

        record(Box::new(MockAction { id: 1 }));
        record(Box::new(MockAction { id: 2 }));

        let recorded = take();
        assert_eq!(recorded.len(), 2);
        assert_eq!(names_of(&recorded), ["MockAction(1)", "MockAction(2)"]);

        // Taking drains the tape, so an immediate second take sees nothing.
        assert_eq!(take().len(), 0);
    }

    #[test]
    fn test_thread_local_isolation() {
        // Give this thread exactly one recorded operation.
        drop(take());
        record(Box::new(MockAction { id: 99 }));

        let worker = thread::spawn(|| {
            // A brand-new OS thread must start with its own empty tape.
            let inherited = take();
            assert_eq!(inherited.len(), 0, "Thread tape bled over! Isolation failed.");

            // Whatever this thread records stays on this thread.
            record(Box::new(MockAction { id: 42 }));
            assert_eq!(names_of(&take()), ["MockAction(42)"]);
        });
        worker.join().unwrap();

        // The worker's activity must not have touched this thread's tape.
        assert_eq!(names_of(&take()), ["MockAction(99)"]);
    }
}