use std::cell::RefCell;
use crate::errors::EtensorResult;
use super::gradients::Gradients;
/// An operation recorded on a [`Tape`] whose backward step can be run later
/// against a [`Gradients`] store.
///
/// Actions are stored type-erased as `Box<dyn TapeAction>`.
pub trait TapeAction {
/// Runs this action's backward step, with mutable access to `grads`.
///
/// # Errors
/// Propagates any failure as the `Err` variant of `EtensorResult`.
fn backward(&self, grads: &mut Gradients) -> EtensorResult<()>;
/// Human-readable label for this action (e.g. for diagnostics; the unit tests
/// below compare it against expected strings).
fn name(&self) -> String;
}
/// Ordered list of recorded [`TapeAction`]s.
///
/// Actions are appended with [`Tape::push`] and handed back, in the same order,
/// by [`Tape::take_all`], which leaves the tape empty.
pub struct Tape {
    /// Recorded actions, oldest first.
    pub actions: Vec<Box<dyn TapeAction>>,
}

impl Tape {
    /// Creates an empty tape.
    ///
    /// This is a `const fn` (`Vec::new` does not allocate and is itself const),
    /// so an empty tape can also be built in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            actions: Vec::new(),
        }
    }

    /// Appends `action` to the end of the tape.
    pub fn push(&mut self, action: Box<dyn TapeAction>) {
        self.actions.push(action);
    }

    /// Removes and returns every recorded action in insertion order.
    ///
    /// The tape is left empty, with a fresh unallocated buffer in place of the
    /// old one.
    pub fn take_all(&mut self) -> Vec<Box<dyn TapeAction>> {
        std::mem::take(&mut self.actions)
    }
}

impl Default for Tape {
    /// Equivalent to [`Tape::new`].
    fn default() -> Self {
        Self::new()
    }
}
thread_local! {
    // One tape per thread: each thread lazily gets its own empty `Tape`, so
    // actions recorded on one thread are never observed by another.
    static TAPE: RefCell<Tape> = RefCell::new(Tape::default());
}
/// Appends `action` to the end of the calling thread's tape.
///
/// # Panics
/// Panics if the thread-local tape is already borrowed on this thread (a
/// re-entrant tape access). Also panics if it is called after the thread-local
/// has been destroyed during thread teardown.
pub fn record(action: Box<dyn TapeAction>) {
    TAPE.with(|cell| {
        let mut tape = cell.borrow_mut();
        tape.push(action);
    });
}
/// Drains the calling thread's tape and leaves it empty.
///
/// Returns every recorded action in the order it was recorded.
///
/// # Panics
/// Panics if the thread-local tape is already borrowed on this thread (a
/// re-entrant tape access). Also panics if it is called after the thread-local
/// has been destroyed during thread teardown.
pub fn take() -> Vec<Box<dyn TapeAction>> {
    TAPE.with(|cell| {
        let mut tape = cell.borrow_mut();
        tape.take_all()
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Minimal `TapeAction` whose only observable behaviour is its `name`.
    struct MockAction {
        id: usize,
    }

    impl TapeAction for MockAction {
        fn backward(&self, _grads: &mut Gradients) -> EtensorResult<()> {
            Ok(())
        }

        fn name(&self) -> String {
            format!("MockAction({})", self.id)
        }
    }

    /// Collects the names of `actions`, so whole sequences can be compared at once.
    fn names(actions: &[Box<dyn TapeAction>]) -> Vec<String> {
        actions.iter().map(|action| action.name()).collect()
    }

    #[test]
    fn test_tape_record_and_take() {
        // Start from a clean tape in case anything is left on this thread.
        let _ = take();

        for id in 1..=2 {
            record(Box::new(MockAction { id }));
        }

        let recorded = take();
        assert_eq!(recorded.len(), 2);
        assert_eq!(names(&recorded), ["MockAction(1)", "MockAction(2)"]);

        // The previous `take` drained the tape.
        assert_eq!(take().len(), 0);
    }

    #[test]
    fn test_thread_local_isolation() {
        let _ = take();
        record(Box::new(MockAction { id: 99 }));

        thread::spawn(|| {
            // A freshly spawned thread must see its own, empty tape.
            let initial = take();
            assert_eq!(initial.len(), 0, "Thread tape bled over! Isolation failed.");

            record(Box::new(MockAction { id: 42 }));
            let modified = take();
            assert_eq!(modified.len(), 1);
            assert_eq!(names(&modified), ["MockAction(42)"]);
        })
        .join()
        .unwrap();

        // The child thread's activity must not touch this thread's tape.
        let final_actions = take();
        assert_eq!(final_actions.len(), 1);
        assert_eq!(names(&final_actions), ["MockAction(99)"]);
    }
}