// ghostflow_autograd/tape.rs
use std::cell::RefCell;
5thread_local! {
6 static GRAD_TAPE: RefCell<Option<GradTape>> = RefCell::new(None);
7}
8
9#[derive(Debug, Default)]
11pub struct GradTape {
12 operations: Vec<RecordedOp>,
13 enabled: bool,
14}
15
/// One operation captured on a [`GradTape`].
///
/// Values are referred to by integer ids (presumably tensor/node ids assigned
/// elsewhere — confirm with callers).
pub struct RecordedOp {
    /// Static name of the operation; printed by the `Debug` impl.
    pub op_name: &'static str,
    /// Ids of the operation's inputs.
    pub input_ids: Vec<usize>,
    /// Id of the value the operation produced.
    pub output_id: usize,
    /// Backward closure. NOTE(review): presumably maps (output gradient,
    /// input values) to one gradient vector per input — confirm with callers.
    pub backward_fn: Box<dyn Fn(&[f32], &[Vec<f32>]) -> Vec<Vec<f32>> + Send + Sync>,
}

impl std::fmt::Debug for RecordedOp {
    /// Prints all data fields; the closure is opaque, so `backward_fn` is shown
    /// as the placeholder string `"<closure>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RecordedOp {
            op_name,
            input_ids,
            output_id,
            backward_fn: _,
        } = self;

        let mut out = f.debug_struct("RecordedOp");
        out.field("op_name", op_name);
        out.field("input_ids", input_ids);
        out.field("output_id", output_id);
        // Closures do not implement `Debug`; emit a fixed marker instead.
        out.field("backward_fn", &"<closure>");
        out.finish()
    }
}
35impl GradTape {
36 pub fn new() -> Self {
38 GradTape {
39 operations: Vec::new(),
40 enabled: true,
41 }
42 }
43
44 pub fn is_enabled(&self) -> bool {
46 self.enabled
47 }
48
49 pub fn enable(&mut self) {
51 self.enabled = true;
52 }
53
54 pub fn disable(&mut self) {
56 self.enabled = false;
57 }
58
59 pub fn record(&mut self, op: RecordedOp) {
61 if self.enabled {
62 self.operations.push(op);
63 }
64 }
65
66 pub fn operations(&self) -> &[RecordedOp] {
68 &self.operations
69 }
70
71 pub fn clear(&mut self) {
73 self.operations.clear();
74 }
75}
76
77pub struct GradTapeContext;
79
80impl GradTapeContext {
81 pub fn new() -> Self {
83 GRAD_TAPE.with(|tape| {
84 *tape.borrow_mut() = Some(GradTape::new());
85 });
86 GradTapeContext
87 }
88}
89
90impl Default for GradTapeContext {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl Drop for GradTapeContext {
97 fn drop(&mut self) {
98 GRAD_TAPE.with(|tape| {
99 *tape.borrow_mut() = None;
100 });
101 }
102}
103
104pub fn is_recording() -> bool {
106 GRAD_TAPE.with(|tape| {
107 tape.borrow().as_ref().map_or(false, |t| t.is_enabled())
108 })
109}
110
111pub fn record_op(op: RecordedOp) {
113 GRAD_TAPE.with(|tape| {
114 if let Some(ref mut t) = *tape.borrow_mut() {
115 t.record(op);
116 }
117 });
118}