ghostflow_autograd/
tape.rs1use std::cell::RefCell;
4
5thread_local! {
6 static GRAD_TAPE: RefCell<Option<GradTape>> = const { RefCell::new(None) };
7}
8
9#[derive(Debug, Default)]
11pub struct GradTape {
12 operations: Vec<RecordedOp>,
13 enabled: bool,
14}
15
/// Boxed, thread-safe closure stored on each `RecordedOp`.
///
/// It takes a slice of `f32` plus a slice of `Vec<f32>` and returns one
/// `Vec<f32>` per result. Presumably the arguments are the forward output and
/// the upstream gradients, and the result holds the input gradients — confirm
/// against the code that builds these closures.
type BackwardFn = Box<dyn Send + Sync + Fn(&[f32], &[Vec<f32>]) -> Vec<Vec<f32>>>;
18
19pub struct RecordedOp {
21 pub op_name: &'static str,
22 pub input_ids: Vec<usize>,
23 pub output_id: usize,
24 pub backward_fn: BackwardFn,
25}
26
27impl std::fmt::Debug for RecordedOp {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("RecordedOp")
30 .field("op_name", &self.op_name)
31 .field("input_ids", &self.input_ids)
32 .field("output_id", &self.output_id)
33 .field("backward_fn", &"<closure>")
34 .finish()
35 }
36}
37
38impl GradTape {
39 pub fn new() -> Self {
41 GradTape {
42 operations: Vec::new(),
43 enabled: true,
44 }
45 }
46
47 pub fn is_enabled(&self) -> bool {
49 self.enabled
50 }
51
52 pub fn enable(&mut self) {
54 self.enabled = true;
55 }
56
57 pub fn disable(&mut self) {
59 self.enabled = false;
60 }
61
62 pub fn record(&mut self, op: RecordedOp) {
64 if self.enabled {
65 self.operations.push(op);
66 }
67 }
68
69 pub fn operations(&self) -> &[RecordedOp] {
71 &self.operations
72 }
73
74 pub fn clear(&mut self) {
76 self.operations.clear();
77 }
78}
79
80pub struct GradTapeContext;
82
83impl GradTapeContext {
84 pub fn new() -> Self {
86 GRAD_TAPE.with(|tape| {
87 *tape.borrow_mut() = Some(GradTape::new());
88 });
89 GradTapeContext
90 }
91}
92
93impl Default for GradTapeContext {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl Drop for GradTapeContext {
100 fn drop(&mut self) {
101 GRAD_TAPE.with(|tape| {
102 *tape.borrow_mut() = None;
103 });
104 }
105}
106
107pub fn is_recording() -> bool {
109 GRAD_TAPE.with(|tape| {
110 tape.borrow().as_ref().is_some_and(|t| t.is_enabled())
111 })
112}
113
114pub fn record_op(op: RecordedOp) {
116 GRAD_TAPE.with(|tape| {
117 if let Some(ref mut t) = *tape.borrow_mut() {
118 t.record(op);
119 }
120 });
121}