//! xad-rs 0.2.0 — automatic differentiation library for Rust
//! (forward/reverse mode AD); a Rust port of the C++ XAD library:
//! <https://github.com/auto-differentiation/xad>.
//!
//! `LabeledTape` + `LabeledAReal` — string-keyed reverse-mode AD.
//!
//! The labeled reverse-mode hero type. `LabeledTape` owns a `Tape<f64>` plus
//! a per-instance name -> slot map; the user calls `input()` for each named
//! input, then `freeze()` to lock the registry and activate the tape, then
//! runs the forward closure normally and reads the gradient back by name via
//! `gradient()`.
//!
//! # The two-phase contract
//!
//! 1. **Setup phase** (`new()` -> `input()` calls): registers each named
//!    input on the inner `Tape<f64>` via `AReal::register_input`, which does
//!    NOT require an active tape — slots are assigned eagerly through the
//!    `&mut Tape` reference. The tape stays inactive throughout the setup
//!    phase, so you can construct multiple `LabeledTape`s on one thread
//!    sequentially without conflict (just not concurrently).
//! 2. **Forward + readback phase** (`freeze()` onward): `freeze()` builds the
//!    final `Arc<VarRegistry>` and calls `Tape::activate`. From this point
//!    forward, all arithmetic on the `LabeledAReal` handles is recorded on
//!    this tape. After the forward closure produces its output, call
//!    `gradient(&output)` to read the per-name adjoints as an
//!    `IndexMap<String, f64>` in registry insertion order.
//!
//! # Thread-local discipline (`!Send`)
//!
//! `Tape<f64>` uses a thread-local active-tape pointer (see `src/tape.rs`).
//! `LabeledTape` is structurally `!Send` via `PhantomData<*const ()>` so the
//! compiler refuses to let you move a `LabeledTape` across threads — doing
//! so would corrupt the TLS contract on either side. Two `LabeledTape`s on
//! two threads work fine (each thread has its own TLS pointer); two
//! `LabeledTape`s on **one** thread cannot both be `freeze()`d at the same
//! time (the second `freeze()` panics with `"A tape is already active on
//! this thread"`).
//!
//! # `std::mem::forget` and panic-during-forward hazards
//!
//! - `std::mem::forget(labeled_tape)` skips `Drop`, leaving the TLS pointer
//!   dangling. **Recovery:** call [`LabeledTape::deactivate_all`] before
//!   constructing the next `LabeledTape`.
//! - A panic inside the user's forward closure unwinds normally and runs
//!   `LabeledTape::Drop`, which deactivates the tape — panic safety is
//!   preserved unless the panic itself happens inside another `Drop` (the
//!   standard double-panic abort rule).
//!
//! # `!Send` compile-fail assertion
//!
//! ```compile_fail,E0277
//! use xad_rs::labeled::LabeledTape;
//! fn assert_send<T: Send>(_: T) {}
//! assert_send(LabeledTape::new());
//! ```

use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;

use indexmap::{IndexMap, IndexSet};

use crate::areal::AReal;
use crate::labeled::VarRegistry;
use crate::math;
use crate::tape::Tape;

/// Labeled wrapper around a positional [`AReal<f64>`].
///
/// **Shape A (minimal):** does NOT carry an `Arc<VarRegistry>` field. The
/// only way to construct a `LabeledAReal` is via [`LabeledTape::input`], so
/// every `LabeledAReal` is structurally tied to exactly one tape via the
/// thread-local active pointer; cross-tape mixing is structurally
/// impossible (the second `LabeledTape::freeze()` on one thread panics).
/// Skipping the `Arc` field saves one atomic increment per operator and is
/// the gate-binding choice for the labeled reverse-mode bench in Phase 2.
#[derive(Clone)]
pub struct LabeledAReal {
    /// Positional reverse-mode value; every labeled operation delegates to
    /// this. Recording happens through the thread-local active tape, so no
    /// per-handle tape reference (or `Arc`) is stored (Shape A).
    inner: AReal<f64>,
}

impl LabeledAReal {
    /// Internal constructor used by `LabeledTape::input` and the operator
    /// impls. Not part of the public API.
    #[inline]
    pub(crate) fn from_inner(inner: AReal<f64>) -> Self {
        Self { inner }
    }

    /// Underlying value (no derivative information).
    #[inline]
    pub fn value(&self) -> f64 {
        self.inner.value()
    }

    /// Escape hatch: read-only access to the inner `AReal<f64>`.
    #[inline]
    pub fn inner(&self) -> &AReal<f64> {
        &self.inner
    }

    // ---- Elementary math: thin delegations to `math::ad` (mirrors the
    // ---- Phase 1 `LabeledFReal` surface) ----

    /// Sine.
    #[inline]
    pub fn sin(&self) -> Self {
        Self::from_inner(math::ad::sin(&self.inner))
    }

    /// Cosine.
    #[inline]
    pub fn cos(&self) -> Self {
        Self::from_inner(math::ad::cos(&self.inner))
    }

    /// Tangent.
    #[inline]
    pub fn tan(&self) -> Self {
        Self::from_inner(math::ad::tan(&self.inner))
    }

    /// Natural exponential.
    #[inline]
    pub fn exp(&self) -> Self {
        Self::from_inner(math::ad::exp(&self.inner))
    }

    /// Natural logarithm.
    #[inline]
    pub fn ln(&self) -> Self {
        Self::from_inner(math::ad::ln(&self.inner))
    }

    /// Square root.
    #[inline]
    pub fn sqrt(&self) -> Self {
        Self::from_inner(math::ad::sqrt(&self.inner))
    }

    /// Hyperbolic tangent.
    #[inline]
    pub fn tanh(&self) -> Self {
        Self::from_inner(math::ad::tanh(&self.inner))
    }
}

impl fmt::Debug for LabeledAReal {
    /// Debug form shows the primal value and the tape slot of the inner
    /// `AReal` — enough to correlate a handle with tape entries.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("LabeledAReal");
        dbg.field("value", &self.inner.value());
        dbg.field("slot", &self.inner.slot());
        dbg.finish()
    }
}

// ============ Operator overloads — hand-written, Shape A ============
// No shared op-stamping macro is used: Shape A does not carry a
// `registry: Arc<VarRegistry>` field, and the historical LBLF-07
// stamping scaffold has been deleted in Plan 02.2-02. The four
// reference variants
// (owned/owned, ref/ref, owned/ref, ref/owned) plus scalar variants are
// stamped explicitly below. The inner `AReal` operators read TLS via
// `record_binary` / `record_unary`, so the labeled wrapper is a pure
// pass-through with zero atomic increments.

// Stamps one arithmetic trait (`Add`/`Sub`/`Mul`/`Div`) for `LabeledAReal`:
// all four owned/borrowed operand combinations plus `f64` scalar right-hand
// sides. Each impl is a pure pass-through to the inner `AReal` operator,
// which does the actual tape recording via TLS.
macro_rules! __lbl_areal_binop {
    ($trait:ident, $method:ident, $op:tt) => {
        // owned ∘ owned
        impl ::core::ops::$trait<LabeledAReal> for LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: LabeledAReal) -> LabeledAReal {
                LabeledAReal { inner: self.inner $op rhs.inner }
            }
        }
        // &ref ∘ &ref
        impl ::core::ops::$trait<&LabeledAReal> for &LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: &LabeledAReal) -> LabeledAReal {
                LabeledAReal { inner: &self.inner $op &rhs.inner }
            }
        }
        // owned ∘ &ref
        impl ::core::ops::$trait<&LabeledAReal> for LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: &LabeledAReal) -> LabeledAReal {
                LabeledAReal { inner: self.inner $op &rhs.inner }
            }
        }
        // &ref ∘ owned
        impl ::core::ops::$trait<LabeledAReal> for &LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: LabeledAReal) -> LabeledAReal {
                LabeledAReal { inner: &self.inner $op rhs.inner }
            }
        }
        // owned ∘ f64 scalar
        impl ::core::ops::$trait<f64> for LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: f64) -> LabeledAReal {
                LabeledAReal { inner: self.inner $op rhs }
            }
        }
        // &ref ∘ f64 scalar
        impl ::core::ops::$trait<f64> for &LabeledAReal {
            type Output = LabeledAReal;
            #[inline]
            fn $method(self, rhs: f64) -> LabeledAReal {
                LabeledAReal { inner: &self.inner $op rhs }
            }
        }
    };
}

// Instantiate the four arithmetic operators.
__lbl_areal_binop!(Add, add, +);
__lbl_areal_binop!(Sub, sub, -);
__lbl_areal_binop!(Mul, mul, *);
__lbl_areal_binop!(Div, div, /);

impl ::core::ops::Neg for LabeledAReal {
    type Output = LabeledAReal;
    #[inline]
    fn neg(self) -> LabeledAReal {
        LabeledAReal { inner: -self.inner }
    }
}
impl ::core::ops::Neg for &LabeledAReal {
    type Output = LabeledAReal;
    #[inline]
    fn neg(self) -> LabeledAReal {
        LabeledAReal {
            inner: -&self.inner,
        }
    }
}

// ============ LabeledTape ============

/// Labeled reverse-mode tape. Owns a `Tape<f64>`, a name -> slot map, and
/// (after `freeze()`) an `Arc<VarRegistry>`.
///
/// Two-phase usage: see the [module-level docs](crate::labeled::areal).
///
/// `!Send` is enforced structurally via [`PhantomData<*const ()>`]; see the
/// `compile_fail` doctest in the module-level docs.
pub struct LabeledTape {
    /// Positional reverse-mode tape; constructed inactive, activated by
    /// `freeze()`.
    tape: Tape<f64>,
    /// Insertion-ordered input names accumulated during setup; duplicate
    /// names keep the first position ("first wins").
    builder: IndexSet<String>,
    /// `(name, tape slot)` per `input()` call, in call order — read back by
    /// `gradient()`.
    inputs: Vec<(String, u32)>,
    /// Frozen registry; `None` until `freeze()` has run.
    registry: Option<Arc<VarRegistry>>,
    /// Phase flag set by `freeze()`; guards the `input()`/`gradient()`
    /// two-phase contract.
    frozen: bool,
    /// Raw pointers are `!Send`, so this zero-sized field makes the whole
    /// struct `!Send` — enforcing the thread-local active-tape discipline.
    _not_send: PhantomData<*const ()>,
}

impl LabeledTape {
    /// Construct a new labeled tape. Does NOT activate the inner `Tape<f64>`
    /// — call [`freeze`](Self::freeze) to lock the registry and activate.
    pub fn new() -> Self {
        Self {
            tape: Tape::<f64>::new(true),
            builder: IndexSet::new(),
            inputs: Vec::new(),
            registry: None,
            frozen: false,
            _not_send: PhantomData,
        }
    }

    /// Register a named input and return a [`LabeledAReal`] handle.
    ///
    /// Eagerly assigns a tape slot via `AReal::register_input(&mut [v], &mut self.tape)`
    /// — `register_input` does NOT require the tape to be active, so this
    /// works during the setup phase before [`freeze`](Self::freeze) runs.
    ///
    /// Registering the same `name` twice keeps the FIRST occurrence in the
    /// registry ("first wins", matching `VarRegistry::from_names`); each call
    /// still gets its own tape slot and an independent handle.
    ///
    /// # Panics
    /// Panics if called after [`freeze`](Self::freeze).
    pub fn input(&mut self, name: &str, value: f64) -> LabeledAReal {
        assert!(
            !self.frozen,
            "LabeledTape::input({:?}) called after freeze(); add all inputs before running the forward pass",
            name
        );
        // `IndexSet::insert` is already idempotent and first-wins: inserting
        // an existing key is a no-op that keeps the original insertion
        // position, so no separate `contains` pre-check (double lookup) is
        // needed.
        self.builder.insert(name.to_string());
        let mut ar = AReal::<f64>::new(value);
        AReal::register_input(std::slice::from_mut(&mut ar), &mut self.tape);
        self.inputs.push((name.to_string(), ar.slot()));
        LabeledAReal::from_inner(ar)
    }

    /// Lock the registry, activate the tape, and return the shared
    /// `Arc<VarRegistry>`. Panics if already frozen, or if another tape is
    /// already active on this thread (panic message: `"A tape is already
    /// active on this thread"`).
    pub fn freeze(&mut self) -> Arc<VarRegistry> {
        assert!(
            !self.frozen,
            "LabeledTape::freeze() called twice on the same tape"
        );
        let reg = Arc::new(VarRegistry::from_names(self.builder.iter().cloned()));
        self.registry = Some(Arc::clone(&reg));
        // Activation can panic if another tape already owns the TLS pointer;
        // `frozen` is only set once activation succeeded.
        self.tape.activate();
        self.frozen = true;
        reg
    }

    /// True if [`freeze`](Self::freeze) has been called.
    #[inline]
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Access the frozen registry, if any. Returns `None` until
    /// [`freeze`](Self::freeze) has been called.
    #[inline]
    pub fn registry(&self) -> Option<&Arc<VarRegistry>> {
        self.registry.as_ref()
    }

    /// Compute the gradient of `output` with respect to every registered
    /// input, returning an `IndexMap<String, f64>` in registry insertion
    /// order (i.e. the order of the [`input`](Self::input) calls).
    ///
    /// For a name registered more than once, the adjoint of the FIRST
    /// registered slot is reported, consistent with the registry's
    /// "first wins" semantics.
    ///
    /// # Panics
    /// Panics if called before [`freeze`](Self::freeze).
    pub fn gradient(&mut self, output: &LabeledAReal) -> IndexMap<String, f64> {
        assert!(
            self.frozen,
            "LabeledTape::gradient() called before freeze()"
        );
        self.tape.clear_derivatives();
        output.inner.set_adjoint(&mut self.tape, 1.0);
        self.tape.compute_adjoints();
        let mut grad = IndexMap::with_capacity(self.inputs.len());
        for (name, slot) in &self.inputs {
            // Entry API keeps the first slot's adjoint on duplicate names;
            // a plain `insert` would silently overwrite it with the LAST
            // slot's value, contradicting the documented first-wins contract.
            let d = self.tape.derivative(*slot);
            grad.entry(name.clone()).or_insert(d);
        }
        grad
    }

    /// Static escape hatch for the `std::mem::forget` recovery path.
    ///
    /// Wraps `Tape::<f64>::deactivate_all()`. Call this before constructing
    /// a new `LabeledTape` if a previous `LabeledTape` was leaked via
    /// `std::mem::forget` (which would otherwise leave the thread-local
    /// pointer dangling).
    pub fn deactivate_all() {
        Tape::<f64>::deactivate_all();
    }
}

impl Default for LabeledTape {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Debug for LabeledTape {
    /// Summarizes lifecycle state: phase flag, input count, and (if frozen)
    /// the registry size — without dumping tape contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let registry_len = self.registry.as_ref().map(|r| r.len());
        let mut dbg = f.debug_struct("LabeledTape");
        dbg.field("frozen", &self.frozen);
        dbg.field("inputs", &self.inputs.len());
        dbg.field("registry_len", &registry_len);
        dbg.finish()
    }
}

impl Drop for LabeledTape {
    fn drop(&mut self) {
        // A tape that never reached `freeze()` was never activated, so there
        // is no thread-local pointer to release.
        if !self.frozen {
            return;
        }
        // Belt-and-suspenders: the inner `Tape::Drop` also deactivates, but
        // doing it explicitly keeps the lifecycle contract visible here.
        self.tape.deactivate();
    }
}