ghostflow_autograd/
tape.rs

//! Thread-local gradient tape that records operations for the backward pass.
2
3use std::cell::RefCell;
4
5thread_local! {
6    static GRAD_TAPE: RefCell<Option<GradTape>> = const { RefCell::new(None) };
7}
8
9/// Gradient tape that records operations for backward pass
10#[derive(Debug, Default)]
11pub struct GradTape {
12    operations: Vec<RecordedOp>,
13    enabled: bool,
14}
15
/// Boxed backward closure stored on every [`RecordedOp`].
///
/// NOTE(review): the meaning of the arguments is not established in this file. Presumably
/// the first slice is the upstream gradient, the second holds the op's input values, and
/// the result has one gradient per input. Confirm against the call sites.
type BackwardFn = Box<dyn Fn(&[f32], &[Vec<f32>]) -> Vec<Vec<f32>> + Send + Sync + 'static>;

/// One operation captured on a [`GradTape`].
pub struct RecordedOp {
    /// Static name of the operation (printed by the `Debug` impl).
    pub op_name: &'static str,
    /// Identifiers of the operation's inputs (presumably tensor ids; confirm with callers).
    pub input_ids: Vec<usize>,
    /// Identifier of the value this operation produced.
    pub output_id: usize,
    /// Closure invoked to compute gradients for this operation.
    pub backward_fn: BackwardFn,
}

impl std::fmt::Debug for RecordedOp {
    /// Prints every field except the closure, which is shown as the placeholder `"<closure>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `RecordedOp` forces this impl to be revisited.
        let Self {
            op_name,
            input_ids,
            output_id,
            backward_fn: _,
        } = self;
        let mut out = f.debug_struct("RecordedOp");
        out.field("op_name", op_name);
        out.field("input_ids", input_ids);
        out.field("output_id", output_id);
        out.field("backward_fn", &"<closure>");
        out.finish()
    }
}
37
38impl GradTape {
39    /// Create a new gradient tape
40    pub fn new() -> Self {
41        GradTape {
42            operations: Vec::new(),
43            enabled: true,
44        }
45    }
46
47    /// Check if tape is recording
48    pub fn is_enabled(&self) -> bool {
49        self.enabled
50    }
51
52    /// Enable recording
53    pub fn enable(&mut self) {
54        self.enabled = true;
55    }
56
57    /// Disable recording
58    pub fn disable(&mut self) {
59        self.enabled = false;
60    }
61
62    /// Record an operation
63    pub fn record(&mut self, op: RecordedOp) {
64        if self.enabled {
65            self.operations.push(op);
66        }
67    }
68
69    /// Get recorded operations
70    pub fn operations(&self) -> &[RecordedOp] {
71        &self.operations
72    }
73
74    /// Clear the tape
75    pub fn clear(&mut self) {
76        self.operations.clear();
77    }
78}
79
80/// Context manager for gradient tape
81pub struct GradTapeContext;
82
83impl GradTapeContext {
84    /// Start recording gradients
85    pub fn new() -> Self {
86        GRAD_TAPE.with(|tape| {
87            *tape.borrow_mut() = Some(GradTape::new());
88        });
89        GradTapeContext
90    }
91}
92
93impl Default for GradTapeContext {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl Drop for GradTapeContext {
100    fn drop(&mut self) {
101        GRAD_TAPE.with(|tape| {
102            *tape.borrow_mut() = None;
103        });
104    }
105}
106
107/// Check if we're currently recording
108pub fn is_recording() -> bool {
109    GRAD_TAPE.with(|tape| {
110        tape.borrow().as_ref().is_some_and(|t| t.is_enabled())
111    })
112}
113
114/// Record an operation to the current tape
115pub fn record_op(op: RecordedOp) {
116    GRAD_TAPE.with(|tape| {
117        if let Some(ref mut t) = *tape.borrow_mut() {
118            t.record(op);
119        }
120    });
121}