etensor-core 0.0.1

The pure Rust tensor math and autograd engine
Documentation
//! The Flight Recorder: Thread-local calculus tracking.

use std::cell::RefCell;
use crate::errors::EtensorResult;
use super::gradients::Gradients;

// =====================================================================
// THE ACTION CONTRACT
// =====================================================================

/// Contract implemented by every operation that can be replayed during backpropagation.
///
/// Implementors are stored on the tape as `Box<dyn TapeAction>`, so this trait must
/// stay object-safe.
pub trait TapeAction {
    /// Human-readable label for this operation, used for explainability and graph
    /// visualization (for example `"Add(TensorId(1), TensorId(2))"`).
    fn name(&self) -> String;

    /// Runs this operation's backward derivative and accumulates the results into `grads`.
    ///
    /// # Errors
    /// Returns an error when the derivative cannot be computed. Which failures are
    /// possible is up to each implementor.
    fn backward(&self, grads: &mut Gradients) -> EtensorResult<()>;
}

// =====================================================================
// THE TAPE
// =====================================================================

/// The linear, append-only log of operations required to compute gradients.
///
/// Actions are kept in the order they were pushed, and [`Tape::take_all`] hands
/// them back in that same order. `Tape::default()` yields an empty tape,
/// identical to [`Tape::new`].
#[derive(Default)]
pub struct Tape {
    /// Recorded operations, oldest first.
    pub actions: Vec<Box<dyn TapeAction>>,
}

impl Tape {
    /// Creates an empty tape.
    #[must_use]
    pub fn new() -> Self {
        Self { actions: Vec::new() }
    }

    /// Appends `action` to the end of the history log.
    pub fn push(&mut self, action: Box<dyn TapeAction>) {
        self.actions.push(action);
    }

    /// Removes and returns the entire history log, oldest action first, leaving
    /// the tape empty.
    ///
    /// The tape's previous allocation moves out with the returned `Vec`; the tape
    /// is left holding a fresh, unallocated `Vec`. This is called right before
    /// executing the backward pass.
    #[must_use = "dropping the returned actions discards the recorded history"]
    pub fn take_all(&mut self) -> Vec<Box<dyn TapeAction>> {
        std::mem::take(&mut self.actions)
    }

    /// Number of actions currently recorded on this tape.
    #[must_use]
    pub fn len(&self) -> usize {
        self.actions.len()
    }

    /// Returns `true` when no actions are recorded on this tape.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.actions.is_empty()
    }
}

// =====================================================================
// GLOBAL THREAD-LOCAL STATE
// =====================================================================

// Design note: the tape lives in thread-local storage, so every OS thread owns an
// independent `Tape`. Recording never contends on a lock, and one thread's graph
// cannot leak into another's. Access goes through a `RefCell`, so a nested mutable
// borrow of the same thread's tape panics instead of blocking.
thread_local! {
    static TAPE: RefCell<Tape> = RefCell::new(Tape::new());
}

/// Appends `action` to the calling thread's tape.
pub fn record(action: Box<dyn TapeAction>) {
    TAPE.with(|cell| {
        let mut tape = cell.borrow_mut();
        tape.push(action);
    });
}

/// Removes and returns every action recorded on the calling thread, oldest first,
/// leaving that thread's tape empty.
///
/// Intended to be called right before running the backward pass.
pub fn take() -> Vec<Box<dyn TapeAction>> {
    TAPE.with(|cell| {
        let mut tape = cell.borrow_mut();
        tape.take_all()
    })
}

// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Inert `TapeAction` whose only purpose is to make recorded entries identifiable.
    struct MockAction {
        id: usize,
    }

    impl TapeAction for MockAction {
        fn backward(&self, _grads: &mut Gradients) -> EtensorResult<()> {
            Ok(())
        }

        fn name(&self) -> String {
            format!("MockAction({})", self.id)
        }
    }

    /// Drains the current thread's tape and returns the recorded names in order.
    fn drain_names() -> Vec<String> {
        take().iter().map(|action| action.name()).collect()
    }

    #[test]
    fn test_tape_record_and_take() {
        // Start from a known-empty tape in case this thread was reused.
        drop(take());

        for id in [1, 2] {
            record(Box::new(MockAction { id }));
        }

        assert_eq!(drain_names(), ["MockAction(1)", "MockAction(2)"]);

        // Draining must leave nothing behind.
        assert!(take().is_empty());
    }

    #[test]
    fn test_thread_local_isolation() {
        // Seed this thread's tape with a single marker action.
        drop(take());
        record(Box::new(MockAction { id: 99 }));

        let worker = thread::spawn(|| {
            // A freshly spawned OS thread must start with its own, blank tape.
            let inherited = take();
            assert_eq!(inherited.len(), 0, "Thread tape bled over! Isolation failed.");

            record(Box::new(MockAction { id: 42 }));
            assert_eq!(drain_names(), ["MockAction(42)"]);
        });
        worker.join().unwrap();

        // The worker's activity must not have touched this thread's tape.
        assert_eq!(drain_names(), ["MockAction(99)"]);
    }
}