ghostflow_autograd/
tape.rs

1//! Gradient tape for recording operations
2
3use std::cell::RefCell;
4
5thread_local! {
6    static GRAD_TAPE: RefCell<Option<GradTape>> = RefCell::new(None);
7}
8
9/// Gradient tape that records operations for backward pass
10#[derive(Debug, Default)]
11pub struct GradTape {
12    operations: Vec<RecordedOp>,
13    enabled: bool,
14}
15
/// One entry on the tape: an operation together with the closure used for its backward step.
pub struct RecordedOp {
    /// Static name of the operation; also shown in the `Debug` output.
    pub op_name: &'static str,
    /// Identifiers of the operation's inputs.
    pub input_ids: Vec<usize>,
    /// Identifier of the operation's output.
    pub output_id: usize,
    /// Closure invoked for the backward step.
    ///
    /// NOTE(review): presumably called with the output gradient and the saved input values, and
    /// expected to return one gradient per input — confirm against the code that walks the tape.
    pub backward_fn: Box<dyn Fn(&[f32], &[Vec<f32>]) -> Vec<Vec<f32>> + Send + Sync>,
}

// `Debug` is implemented manually because the boxed closure has no `Debug` impl of its own.
impl std::fmt::Debug for RecordedOp {
    /// Prints the plain fields as-is and stands in `"<closure>"` for `backward_fn`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that a newly added field forces this impl to be revisited.
        let RecordedOp {
            op_name,
            input_ids,
            output_id,
            backward_fn: _,
        } = self;

        let mut out = f.debug_struct("RecordedOp");
        out.field("op_name", op_name);
        out.field("input_ids", input_ids);
        out.field("output_id", output_id);
        out.field("backward_fn", &"<closure>");
        out.finish()
    }
}
34
35impl GradTape {
36    /// Create a new gradient tape
37    pub fn new() -> Self {
38        GradTape {
39            operations: Vec::new(),
40            enabled: true,
41        }
42    }
43
44    /// Check if tape is recording
45    pub fn is_enabled(&self) -> bool {
46        self.enabled
47    }
48
49    /// Enable recording
50    pub fn enable(&mut self) {
51        self.enabled = true;
52    }
53
54    /// Disable recording
55    pub fn disable(&mut self) {
56        self.enabled = false;
57    }
58
59    /// Record an operation
60    pub fn record(&mut self, op: RecordedOp) {
61        if self.enabled {
62            self.operations.push(op);
63        }
64    }
65
66    /// Get recorded operations
67    pub fn operations(&self) -> &[RecordedOp] {
68        &self.operations
69    }
70
71    /// Clear the tape
72    pub fn clear(&mut self) {
73        self.operations.clear();
74    }
75}
76
77/// Context manager for gradient tape
78pub struct GradTapeContext;
79
80impl GradTapeContext {
81    /// Start recording gradients
82    pub fn new() -> Self {
83        GRAD_TAPE.with(|tape| {
84            *tape.borrow_mut() = Some(GradTape::new());
85        });
86        GradTapeContext
87    }
88}
89
90impl Default for GradTapeContext {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96impl Drop for GradTapeContext {
97    fn drop(&mut self) {
98        GRAD_TAPE.with(|tape| {
99            *tape.borrow_mut() = None;
100        });
101    }
102}
103
104/// Check if we're currently recording
105pub fn is_recording() -> bool {
106    GRAD_TAPE.with(|tape| {
107        tape.borrow().as_ref().map_or(false, |t| t.is_enabled())
108    })
109}
110
111/// Record an operation to the current tape
112pub fn record_op(op: RecordedOp) {
113    GRAD_TAPE.with(|tape| {
114        if let Some(ref mut t) = *tape.borrow_mut() {
115            t.record(op);
116        }
117    });
118}