use std::cell::Cell;
use std::marker::PhantomData;
use crate::Float;
/// Sentinel variable index meaning "not a tape variable" (a constant operand).
///
/// `push_unary`/`push_binary` record no operand entry for an index equal to
/// this value, so no adjoint is ever propagated to it.
pub const CONSTANT: u32 = u32::MAX;
/// One recorded operation on the tape.
///
/// The operand entries of statement `i` occupy
/// `statements[i - 1].end_plus_one .. statements[i].end_plus_one` in the
/// tape's parallel `multipliers`/`indices` buffers.
#[derive(Clone, Copy, Debug)]
struct Statement {
    /// Variable index produced by this operation; its adjoint is consumed
    /// (propagated to the operands, then zeroed) during the reverse sweep.
    lhs_index: u32,
    /// Exclusive end of this statement's operand entries.
    end_plus_one: u32,
}
/// Linear tape recording operations for reverse-mode differentiation.
///
/// Operands are stored in flat parallel buffers (`multipliers`, `indices`)
/// that each [`Statement`] slices into via its `end_plus_one` offset.
pub struct Tape<F: Float> {
    /// Recorded operations in execution order; `statements[0]` is a sentinel
    /// with `end_plus_one == 0` so every real statement has a predecessor.
    statements: Vec<Statement>,
    /// Factor applied to a result's adjoint when propagating it to the operand
    /// at the same position in `indices` (presumably d result / d operand, as
    /// supplied by the caller).
    multipliers: Vec<F>,
    /// Operand variable index for each entry of `multipliers`; `CONSTANT`
    /// operands are never stored.
    indices: Vec<u32>,
    /// Number of variable indices handed out so far; also the length of the
    /// adjoint vector produced by the reverse sweep.
    num_variables: u32,
}
impl<F: Float> Default for Tape<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> Tape<F> {
    /// Creates an empty tape.
    ///
    /// Every tape starts with a sentinel statement at `statements[0]` whose
    /// `end_plus_one` is 0. The reverse sweep reads the previous statement's
    /// `end_plus_one` as the start of the current operand range, so the
    /// sentinel removes the special case for the first real statement.
    #[must_use]
    pub fn new() -> Self {
        let mut tape = Tape {
            statements: Vec::new(),
            multipliers: Vec::new(),
            indices: Vec::new(),
            num_variables: 0,
        };
        tape.push_sentinel();
        tape
    }

    /// Creates an empty tape with buffers pre-sized for about `est_ops`
    /// operations: `est_ops` statements plus the sentinel, and two operand
    /// entries per operation.
    #[must_use]
    pub fn with_capacity(est_ops: usize) -> Self {
        let mut tape = Tape {
            statements: Vec::with_capacity(est_ops + 1),
            multipliers: Vec::with_capacity(est_ops * 2),
            indices: Vec::with_capacity(est_ops * 2),
            num_variables: 0,
        };
        tape.push_sentinel();
        tape
    }

    /// Discards everything recorded and restarts variable numbering at 0,
    /// keeping the allocated buffers for reuse.
    pub fn clear(&mut self) {
        self.statements.clear();
        self.multipliers.clear();
        self.indices.clear();
        self.num_variables = 0;
        self.push_sentinel();
    }

    /// Registers an input variable and returns its index together with
    /// `value` unchanged.
    ///
    /// No statement is recorded for it, so the adjoint that the reverse sweep
    /// accumulates at this index is left in the returned vector.
    #[inline]
    pub fn new_variable(&mut self, value: F) -> (u32, F) {
        (self.alloc_index(), value)
    }

    /// Records a one-operand operation and returns the index of its result.
    ///
    /// `multiplier` scales the result's adjoint when it is propagated back to
    /// `operand_idx`. A `CONSTANT` operand records a statement with no operand
    /// entries.
    #[inline]
    pub fn push_unary(&mut self, operand_idx: u32, multiplier: F) -> u32 {
        let result_idx = self.alloc_index();
        self.record_operand(operand_idx, multiplier);
        self.finish_statement(result_idx);
        result_idx
    }

    /// Records a two-operand operation and returns the index of its result.
    ///
    /// Each `*_mult` scales the result's adjoint when it is propagated back to
    /// the matching `*_idx`. `CONSTANT` operands are skipped.
    #[inline]
    pub fn push_binary(&mut self, lhs_idx: u32, lhs_mult: F, rhs_idx: u32, rhs_mult: F) -> u32 {
        let result_idx = self.alloc_index();
        self.record_operand(lhs_idx, lhs_mult);
        self.record_operand(rhs_idx, rhs_mult);
        self.finish_statement(result_idx);
        result_idx
    }

    /// Runs the reverse sweep with adjoint 1 at `seed_index`.
    ///
    /// Returns one adjoint per variable index. Entries for inputs created with
    /// [`Tape::new_variable`] hold the accumulated adjoint. Entries for
    /// statement results are zeroed by the sweep. If `seed_index` is
    /// `CONSTANT` (the output does not depend on any variable), the result is
    /// all zeros.
    ///
    /// # Panics
    /// If `seed_index` is neither `CONSTANT` nor a valid index on this tape.
    #[must_use]
    pub fn reverse(&self, seed_index: u32) -> Vec<F> {
        self.reverse_seeded(&[(seed_index, F::one())])
    }

    /// Runs the reverse sweep starting from several seeded adjoints.
    ///
    /// Seeds naming the same index are summed. Seeds at `CONSTANT` are ignored
    /// because a constant has no adjoint slot. The returned vector has the
    /// same layout as for [`Tape::reverse`].
    ///
    /// # Panics
    /// If a seed index is neither `CONSTANT` nor a valid index on this tape.
    #[must_use]
    pub fn reverse_seeded(&self, seeds: &[(u32, F)]) -> Vec<F> {
        let mut adjoints = vec![F::zero(); self.num_variables as usize];
        for &(idx, seed) in seeds {
            if idx == CONSTANT {
                continue;
            }
            let slot = &mut adjoints[idx as usize];
            *slot = *slot + seed;
        }
        self.propagate_adjoints(&mut adjoints);
        adjoints
    }

    /// Pushes the index-0 sentinel statement (empty operand range).
    fn push_sentinel(&mut self) {
        self.statements.push(Statement {
            lhs_index: 0,
            end_plus_one: 0,
        });
    }

    /// Hands out the next variable index.
    #[inline]
    fn alloc_index(&mut self) -> u32 {
        // `CONSTANT` (u32::MAX) is reserved, so the last usable index is
        // `CONSTANT - 1`. Checking against `CONSTANT` also guarantees that the
        // increment below cannot overflow.
        debug_assert!(
            self.num_variables < CONSTANT,
            "tape variable count overflow: exceeded u32::MAX"
        );
        let idx = self.num_variables;
        self.num_variables += 1;
        idx
    }

    /// Appends one operand entry unless the operand is `CONSTANT`.
    #[inline]
    fn record_operand(&mut self, idx: u32, multiplier: F) {
        if idx != CONSTANT {
            self.multipliers.push(multiplier);
            self.indices.push(idx);
        }
    }

    /// Closes the current statement: its operand range ends at the current
    /// length of the operand buffers.
    #[inline]
    fn finish_statement(&mut self, lhs_index: u32) {
        let end = self.multipliers.len();
        // Offsets are stored as u32; catch silent truncation in debug builds.
        debug_assert!(
            end <= u32::MAX as usize,
            "tape operand count overflow: exceeded u32::MAX"
        );
        self.statements.push(Statement {
            lhs_index,
            end_plus_one: end as u32,
        });
    }

    /// Reverse sweep shared by `reverse` and `reverse_seeded`.
    ///
    /// Walks the statements from last to first. For each statement whose
    /// result adjoint `a` is non-zero, zeroes that adjoint and then adds
    /// `multiplier * a` to every recorded operand. Skipping zero adjoints
    /// avoids touching operands that cannot receive a contribution.
    fn propagate_adjoints(&self, adjoints: &mut [F]) {
        // `windows(2)` pairs each statement with its predecessor. This always
        // exists because of the sentinel, and its `end_plus_one` is where the
        // current operand range starts.
        for pair in self.statements.windows(2).rev() {
            let (prev, stmt) = (pair[0], pair[1]);
            let lhs = stmt.lhs_index as usize;
            let a = adjoints[lhs];
            if a == F::zero() {
                continue;
            }
            adjoints[lhs] = F::zero();
            let range = prev.end_plus_one as usize..stmt.end_plus_one as usize;
            let operands = self.indices[range.clone()].iter();
            for (&idx, &m) in operands.zip(&self.multipliers[range]) {
                let slot = &mut adjoints[idx as usize];
                *slot = *slot + m * a;
            }
        }
    }
}
// Per-thread pointer to the currently active tape, one slot per float type.
// Null means no tape is active. `TapeGuard` installs and restores the pointer;
// `with_active_tape` dereferences it.
thread_local! {
    static TAPE_F32: Cell<*mut Tape<f32>> = const { Cell::new(std::ptr::null_mut()) };
    static TAPE_F64: Cell<*mut Tape<f64>> = const { Cell::new(std::ptr::null_mut()) };
}
// Per-thread single-slot pool holding at most one spare tape per float type.
// `Tape::take_pooled` takes it out and `Tape::return_to_pool` puts one back,
// so recording buffers can be reused across calls on the same thread.
thread_local! {
    static POOL_F32: Cell<Option<Tape<f32>>> = const { Cell::new(None) };
    static POOL_F64: Cell<Option<Tape<f64>>> = const { Cell::new(None) };
}
/// Float types that own per-thread tape state: the active-tape pointer, a
/// one-tape reuse pool, and the reentrancy flag used by `with_active_tape`.
///
/// Implemented in this file for `f32` and `f64`.
pub trait TapeThreadLocal: Float {
    /// Thread-local slot with the pointer to the active tape (null if none).
    fn cell() -> &'static std::thread::LocalKey<Cell<*mut Tape<Self>>>;
    /// Thread-local single-slot pool of a spare tape.
    fn pool_cell() -> &'static std::thread::LocalKey<Cell<Option<Tape<Self>>>>;
    /// Thread-local flag that is `true` while `with_active_tape` holds a `&mut`.
    fn borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>>;
}
/// `f32` tapes use the `*_F32` thread-locals.
impl TapeThreadLocal for f32 {
    /// Reentrancy flag for `with_active_tape::<f32>`.
    fn borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
        &TAPE_BORROWED_F32
    }

    /// Spare-tape pool for `f32`.
    fn pool_cell() -> &'static std::thread::LocalKey<Cell<Option<Tape<Self>>>> {
        &POOL_F32
    }

    /// Active-tape pointer for `f32`.
    fn cell() -> &'static std::thread::LocalKey<Cell<*mut Tape<Self>>> {
        &TAPE_F32
    }
}
/// `f64` tapes use the `*_F64` thread-locals.
impl TapeThreadLocal for f64 {
    /// Reentrancy flag for `with_active_tape::<f64>`.
    fn borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
        &TAPE_BORROWED_F64
    }

    /// Spare-tape pool for `f64`.
    fn pool_cell() -> &'static std::thread::LocalKey<Cell<Option<Tape<Self>>>> {
        &POOL_F64
    }

    /// Active-tape pointer for `f64`.
    fn cell() -> &'static std::thread::LocalKey<Cell<*mut Tape<Self>>> {
        &TAPE_F64
    }
}
impl<F: TapeThreadLocal> Tape<F> {
    /// Takes this thread's pooled tape for `F` if there is one; otherwise
    /// allocates a fresh tape with [`Tape::with_capacity`].
    ///
    /// A reused tape is cleared, and its buffers are grown (never shrunk) to at
    /// least the capacities `with_capacity(capacity)` would allocate. This way
    /// the `capacity` hint is honoured on both paths, instead of being ignored
    /// whenever a previously smaller tape is reused.
    pub(crate) fn take_pooled(capacity: usize) -> Self {
        match F::pool_cell().with(|slot| slot.take()) {
            Some(mut tape) => {
                tape.clear();
                // `clear` leaves only the sentinel in `statements`, so reserving
                // `capacity` more gives room for `capacity + 1` statements.
                tape.statements.reserve(capacity);
                tape.multipliers.reserve(capacity * 2);
                tape.indices.reserve(capacity * 2);
                tape
            }
            None => Tape::with_capacity(capacity),
        }
    }

    /// Parks `self` in this thread's pool so that a later `take_pooled` can
    /// reuse its buffers. The pool holds a single tape, so any tape already
    /// parked there is dropped.
    pub(crate) fn return_to_pool(self) {
        F::pool_cell().with(|slot| slot.set(Some(self)));
    }
}
// Per-thread "active tape is mutably borrowed" flags, one per float type.
// `TapeBorrowGuard::new` sets the flag for the duration of `with_active_tape`,
// and the guard's `Drop` clears it. A nested call on the same thread therefore
// panics instead of creating a second `&mut` to the active tape.
thread_local! {
    static TAPE_BORROWED_F32: Cell<bool> = const { Cell::new(false) };
    static TAPE_BORROWED_F64: Cell<bool> = const { Cell::new(false) };
}
/// Scope guard for one float type's per-thread borrow flag: the flag is set
/// when the guard is created and cleared when the guard is dropped.
struct TapeBorrowGuard {
    /// Flag to clear on drop (`F::borrow_cell()` for the guarded float type).
    cell: &'static std::thread::LocalKey<Cell<bool>>,
}
impl TapeBorrowGuard {
    /// Marks the active `F` tape as mutably borrowed on this thread.
    ///
    /// # Panics
    /// If the flag is already set, i.e. a `with_active_tape::<F>` call is
    /// already in progress further up this thread's stack.
    fn new<F: TapeThreadLocal>() -> Self {
        let key = F::borrow_cell();
        key.with(|in_use| {
            // Setting an already-set flag is a no-op, so checking the previous
            // value after `replace` leaves the same state as check-then-set.
            let already_borrowed = in_use.replace(true);
            assert!(
                !already_borrowed,
                "reentrant with_active_tape call detected — this would create aliased &mut references"
            );
        });
        Self { cell: key }
    }
}
impl Drop for TapeBorrowGuard {
    // Runs on normal exit and on unwind, so a panicking closure passed to
    // `with_active_tape` does not leave the flag stuck at `true`.
    fn drop(&mut self) {
        let key = self.cell;
        key.with(|in_use| in_use.set(false));
    }
}
/// Runs `f` with a mutable reference to this thread's active tape for `F`.
///
/// # Panics
/// - If no tape is active (no live `TapeGuard` for `F` on this thread).
/// - If called reentrantly from inside another `with_active_tape::<F>` closure.
#[inline]
pub fn with_active_tape<F: TapeThreadLocal, R>(f: impl FnOnce(&mut Tape<F>) -> R) -> R {
    // Held for the whole call. A nested call panics instead of handing out a
    // second `&mut`, and dropping the guard (also on unwind) clears the flag.
    let _reentrancy = TapeBorrowGuard::new::<F>();
    let active = F::cell().with(|slot| slot.get());
    // SAFETY: a non-null pointer in the slot was stored by `TapeGuard::new`
    // from a `&'a mut Tape<F>` that the guard keeps borrowed until its `Drop`
    // restores the previous pointer. `_reentrancy` ensures this is the only
    // `&mut` derived from it on this thread for the duration of `f`.
    // NOTE(review): this relies on guards actually being dropped, in reverse
    // creation order. A leaked guard (`mem::forget`) or an outer guard dropped
    // before an inner one can leave a stale pointer here.
    let tape = unsafe { active.as_mut() }
        .expect("No active tape. Use echidna::grad() or similar API.");
    f(tape)
}
/// Guard that makes a tape this thread's active tape for `F` until dropped.
///
/// `_borrow` keeps the tape mutably borrowed for `'a`, so the caller cannot
/// touch it directly while recording is routed through `with_active_tape`.
/// The raw-pointer field makes the guard `!Send`, so it is dropped on the
/// thread that installed it.
///
/// NOTE(review): soundness relies on the guard actually being dropped, and on
/// nested guards being dropped in reverse creation order. Leaking a guard
/// (e.g. `std::mem::forget`) or dropping an outer guard before an inner one
/// can leave the thread-local slot pointing at a tape whose borrow has ended.
pub struct TapeGuard<'a, F: TapeThreadLocal> {
    /// Pointer that was active before this guard (possibly null); restored on drop.
    prev: *mut Tape<F>,
    _borrow: PhantomData<&'a mut Tape<F>>,
}
impl<'a, F: TapeThreadLocal> TapeGuard<'a, F> {
    /// Installs `tape` as this thread's active tape for `F`.
    ///
    /// The previously active pointer is remembered and put back when the guard
    /// is dropped, so guards nest correctly when dropped in reverse order.
    #[must_use = "dropping the guard immediately deactivates the tape; bind it to extend the recording scope"]
    pub fn new(tape: &'a mut Tape<F>) -> Self {
        let incoming: *mut Tape<F> = tape;
        let prev = F::cell().with(|slot| slot.replace(incoming));
        Self {
            prev,
            _borrow: PhantomData,
        }
    }
}
impl<'a, F: TapeThreadLocal> Drop for TapeGuard<'a, F> {
    // Reinstates whatever was active before this guard (null if nothing was).
    fn drop(&mut self) {
        let restored = self.prev;
        F::cell().with(move |slot| slot.set(restored));
    }
}