use crate::scalar::Scalar;
use std::cell::Cell;
use std::ptr;
// Per-thread registry of the active tape, one cell per scalar type that
// implements `TapeStorage` in this file. A null pointer means "no active tape".
// Only the raw address is stored; nothing here keeps the tape alive.
thread_local! {
// Active `Tape<f32>` for the current thread (null when none is active).
static ACTIVE_TAPE_F32: Cell<*mut Tape<f32>> = const { Cell::new(ptr::null_mut()) };
// Active `Tape<f64>` for the current thread (null when none is active).
static ACTIVE_TAPE_F64: Cell<*mut Tape<f64>> = const { Cell::new(ptr::null_mut()) };
}
/// Scalar types that own a slot for an "active tape" pointer.
///
/// The implementations in this file (`f32`, `f64`) back the slot with one
/// thread-local cell per type and represent "no active tape" as a null pointer.
pub trait TapeStorage: Scalar {
/// Returns the currently stored tape pointer, or `None` when no tape is stored.
fn get_active_ptr() -> Option<*mut Tape<Self>>;
/// Stores `ptr` as the active tape pointer; `None` clears the slot.
fn set_active_ptr(ptr: Option<*mut Tape<Self>>);
}
impl TapeStorage for f32 {
    /// Reads this thread's `f32` tape pointer, mapping the null pointer to `None`.
    #[inline]
    fn get_active_ptr() -> Option<*mut Tape<f32>> {
        match ACTIVE_TAPE_F32.with(|cell| cell.get()) {
            raw if raw.is_null() => None,
            raw => Some(raw),
        }
    }

    /// Writes this thread's `f32` tape pointer; `None` is stored as null.
    ///
    /// `Some(null)` is stored as null too, so it reads back as `None`.
    #[inline]
    fn set_active_ptr(ptr: Option<*mut Tape<f32>>) {
        let raw = match ptr {
            Some(tape) => tape,
            None => std::ptr::null_mut(),
        };
        ACTIVE_TAPE_F32.with(|cell| cell.set(raw));
    }
}
impl TapeStorage for f64 {
    /// Reads this thread's `f64` tape pointer, mapping the null pointer to `None`.
    #[inline]
    fn get_active_ptr() -> Option<*mut Tape<f64>> {
        match ACTIVE_TAPE_F64.with(|cell| cell.get()) {
            raw if raw.is_null() => None,
            raw => Some(raw),
        }
    }

    /// Writes this thread's `f64` tape pointer; `None` is stored as null.
    ///
    /// `Some(null)` is stored as null too, so it reads back as `None`.
    #[inline]
    fn set_active_ptr(ptr: Option<*mut Tape<f64>>) {
        let raw = match ptr {
            Some(tape) => tape,
            None => std::ptr::null_mut(),
        };
        ACTIVE_TAPE_F64.with(|cell| cell.set(raw));
    }
}
/// One recorded statement: the slot it assigns and where its operand run ends.
#[derive(Debug, Clone)]
struct Statement {
/// Exclusive end index of this statement's operands in `Tape::operations`.
/// The run begins at the previous statement's `op_end`.
op_end: u32,
/// Derivative slot of the statement's left-hand side. The sentinel statement
/// created by `Tape::new` / `Tape::new_recording` uses `u32::MAX`.
slot: u32,
}
/// One operand of a recorded statement.
#[derive(Debug, Clone, Copy)]
struct Operation<T: TapeStorage> {
/// Factor applied to the statement's adjoint when it is accumulated into
/// `slot` during the adjoint sweep.
multiplier: T,
/// Derivative slot of the operand that receives the contribution.
slot: u32,
}
/// Snapshot of the tape's buffer lengths, as produced by `Tape::get_position`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TapePosition {
/// Length of the statement buffer when the snapshot was taken, i.e. the
/// index the next recorded statement receives.
pub(crate) statement_pos: u32,
/// Length of the operation buffer when the snapshot was taken.
pub(crate) operation_pos: u32,
}
/// Recording of statements and their operands, plus one derivative value per slot.
pub struct Tape<T: TapeStorage> {
/// Recorded statements; index 0 is a sentinel with `op_end == 0`.
statements: Vec<Statement>,
/// Operands of all statements, stored back to back and delimited by
/// consecutive `Statement::op_end` values.
operations: Vec<Operation<T>>,
/// Derivative (adjoint) storage, indexed by slot.
derivatives: Vec<T>,
/// Slot handed out by the next `register_variable` call.
next_slot: u32,
/// Number of variables registered since the recording was started.
num_variables: u32,
}
// SAFETY / NOTE(review): the fields visible in `Tape` are plain `Vec`s and
// integers, so this manual impl only has an effect when `T` itself is not
// `Send` — confirm whether `Scalar` already guarantees that. Also note that
// `activate` stores this tape's address in a thread-local of the activating
// thread; a tape sent to another thread while active is presumably left
// registered on the original thread — verify that callers deactivate first.
unsafe impl<T: TapeStorage> Send for Tape<T> {}
impl<T: TapeStorage> Tape<T> {
    /// Creates an empty, inactive tape that contains only the sentinel statement.
    ///
    /// `_activate` is accepted but ignored; call [`Tape::activate`] explicitly.
    /// NOTE(review): presumably the flag cannot be honoured here because the
    /// tape is returned by value, so an address registered inside `new` would
    /// not be the tape's final address — confirm the intended contract.
    pub fn new(_activate: bool) -> Self {
        Self {
            // Statement 0 is a sentinel: its `op_end` of 0 marks where the
            // operands of the first real statement begin.
            statements: vec![Statement { op_end: 0, slot: u32::MAX }],
            operations: Vec::new(),
            derivatives: Vec::new(),
            next_slot: 0,
            num_variables: 0,
        }
    }

    /// Registers this tape as the active tape for `T` on the current thread.
    ///
    /// Only the tape's address is stored, so the tape has to stay at that
    /// address until it is deactivated.
    ///
    /// # Panics
    ///
    /// Panics if any tape for `T` (including this one) is already active on
    /// this thread.
    pub fn activate(&mut self) {
        assert!(
            T::get_active_ptr().is_none(),
            "A tape is already active on this thread"
        );
        T::set_active_ptr(Some(self as *mut Tape<T>));
    }

    /// Clears the active-tape registration if, and only if, it refers to this tape.
    pub fn deactivate(&mut self) {
        if self.is_active() {
            T::set_active_ptr(None);
        }
    }

    /// Returns `true` when the registered active tape is this very tape
    /// (compared by address).
    pub fn is_active(&self) -> bool {
        let this = self as *const Tape<T> as *mut Tape<T>;
        T::get_active_ptr() == Some(this)
    }

    /// Returns the raw pointer to the active tape for `T`, if any.
    pub(crate) fn get_active() -> Option<*mut Tape<T>> {
        T::get_active_ptr()
    }

    /// Clears the active-tape registration for `T`, whichever tape it refers to.
    pub fn deactivate_all() {
        T::set_active_ptr(None);
    }

    /// Allocates the next slot, gives it a zero derivative and returns its index.
    #[inline]
    pub fn register_variable(&mut self) -> u32 {
        let slot = self.next_slot;
        debug_assert_eq!(
            self.derivatives.len(),
            slot as usize,
            "tape derivatives invariant: len == next_slot"
        );
        self.next_slot += 1;
        self.num_variables += 1;
        self.derivatives.push(T::zero());
        slot
    }

    /// Appends the statement record that closes the current operand run.
    #[inline]
    fn seal_statement(&mut self, lhs_slot: u32) {
        self.statements.push(Statement {
            op_end: self.operations.len() as u32,
            slot: lhs_slot,
        });
    }

    /// Records `lhs_slot = Σ multiplier · operand` for an arbitrary number of
    /// `(multiplier, slot)` operands.
    ///
    /// Operands whose slot is `u32::MAX` carry no derivative slot and are not
    /// recorded, matching [`Tape::push_unary`] and [`Tape::push_binary`].
    #[inline]
    pub fn push_statement(&mut self, lhs_slot: u32, operands: &[(T, u32)]) {
        self.operations.extend(
            operands
                .iter()
                .filter(|operand| operand.1 != u32::MAX)
                .map(|&(multiplier, slot)| Operation { multiplier, slot }),
        );
        self.seal_statement(lhs_slot);
    }

    /// Records a statement for `lhs_slot` that has no operands.
    #[inline]
    pub fn push_nullary(&mut self, lhs_slot: u32) {
        self.seal_statement(lhs_slot);
    }

    /// Records a statement for `lhs_slot` with a single operand.
    ///
    /// An `operand_slot` of `u32::MAX` means the operand has no slot; the
    /// statement is then recorded without operands.
    #[inline]
    pub fn push_unary(&mut self, lhs_slot: u32, multiplier: T, operand_slot: u32) {
        if operand_slot != u32::MAX {
            self.operations.push(Operation { multiplier, slot: operand_slot });
        }
        self.seal_statement(lhs_slot);
    }

    /// Records a statement for `lhs_slot` with up to two operands.
    ///
    /// An operand whose slot is `u32::MAX` is skipped.
    #[inline]
    pub fn push_binary(
        &mut self,
        lhs_slot: u32,
        m1: T,
        s1: u32,
        m2: T,
        s2: u32,
    ) {
        if s1 != u32::MAX {
            self.operations.push(Operation { multiplier: m1, slot: s1 });
        }
        if s2 != u32::MAX {
            self.operations.push(Operation { multiplier: m2, slot: s2 });
        }
        self.seal_statement(lhs_slot);
    }

    /// Discards everything recorded so far and returns the tape to the state
    /// produced by [`Tape::new`]. Buffer capacity and the active-tape
    /// registration are left untouched.
    pub fn new_recording(&mut self) {
        self.statements.clear();
        self.statements.push(Statement { op_end: 0, slot: u32::MAX });
        self.operations.clear();
        self.derivatives.clear();
        self.next_slot = 0;
        self.num_variables = 0;
    }

    /// Propagates adjoints backwards through every recorded statement.
    pub fn compute_adjoints(&mut self) {
        let end = self.statements.len() as u32;
        self.compute_adjoints_to_impl(0, end);
    }

    /// Propagates adjoints backwards from the end of the tape through every
    /// statement recorded after `pos` was taken.
    ///
    /// `pos.statement_pos` is the index of the first statement recorded after
    /// the snapshot (see [`Tape::get_position`]); that statement is part of
    /// the sweep, exactly as it is part of what [`Tape::reset_to`] removes and
    /// [`Tape::clear_derivatives_after`] clears.
    pub fn compute_adjoints_to(&mut self, pos: TapePosition) {
        let end = self.statements.len() as u32;
        // The sweep processes indices strictly greater than its target, so the
        // target is the last statement recorded *before* the snapshot.
        // `saturating_sub` keeps the sentinel at index 0 out of the sweep.
        let last_kept = pos.statement_pos.saturating_sub(1);
        self.compute_adjoints_to_impl(last_kept, end);
    }

    /// Reverse sweep over the statements with index in `(target_pos, start)`,
    /// highest index first.
    ///
    /// For each statement whose left-hand side has a non-zero adjoint, adds
    /// `multiplier * adjoint` to the derivative of every operand.
    ///
    /// All buffer accesses are bounds-checked, because slots come from the
    /// public `push_*` methods unvalidated:
    /// * a left-hand slot without derivative storage (for example `u32::MAX`)
    ///   has no adjoint to propagate and is skipped;
    /// * an operand slot of `u32::MAX` is skipped;
    /// * any other out-of-range index is a broken tape invariant and panics.
    fn compute_adjoints_to_impl(&mut self, target_pos: u32, start: u32) {
        debug_assert_eq!(self.derivatives.len(), self.num_variables as usize);
        let stmts = self.statements.as_slice();
        let ops = self.operations.as_slice();
        let derivs = self.derivatives.as_mut_slice();

        // `first >= 1`, so `stmts[i - 1]` below never underflows and the
        // sentinel statement is never processed.
        let first = (target_pos as usize).saturating_add(1);
        let end = (start as usize).min(stmts.len());
        for i in (first..end).rev() {
            let stmt = &stmts[i];
            let adjoint = match derivs.get(stmt.slot as usize) {
                Some(&value) => value,
                None => continue,
            };
            if adjoint == T::zero() {
                continue;
            }
            // Operands of statement `i` live between the previous statement's
            // end marker and its own.
            let op_start = stmts[i - 1].op_end as usize;
            let op_end = stmt.op_end as usize;
            for op in &ops[op_start..op_end] {
                if op.slot == u32::MAX {
                    continue;
                }
                derivs[op.slot as usize] += op.multiplier * adjoint;
            }
        }
    }

    /// Sets every stored derivative to zero.
    pub fn clear_derivatives(&mut self) {
        self.derivatives.fill(T::zero());
    }

    /// Returns the derivative stored for `slot`, or zero if the slot has no storage.
    pub fn derivative(&self, slot: u32) -> T {
        self.derivatives.get(slot as usize).copied().unwrap_or_else(T::zero)
    }

    /// Overwrites the derivative of `slot`, growing the storage with zeros if
    /// `slot` lies beyond it.
    ///
    /// NOTE(review): growing makes `derivatives.len()` exceed `next_slot`,
    /// which the debug assertions in `register_variable` and the adjoint sweep
    /// reject — confirm callers only pass registered slots.
    pub fn set_derivative(&mut self, slot: u32, value: T) {
        let index = slot as usize;
        if index >= self.derivatives.len() {
            self.derivatives.resize(index + 1, T::zero());
        }
        self.derivatives[index] = value;
    }

    /// Adds `value` to the derivative of `slot`, growing the storage with
    /// zeros if `slot` lies beyond it (same caveat as [`Tape::set_derivative`]).
    pub fn increment_adjoint(&mut self, slot: u32, value: T) {
        let index = slot as usize;
        if index >= self.derivatives.len() {
            self.derivatives.resize(index + 1, T::zero());
        }
        self.derivatives[index] += value;
    }

    /// Returns a snapshot of the current statement and operation counts.
    pub fn get_position(&self) -> TapePosition {
        TapePosition {
            statement_pos: self.statements.len() as u32,
            operation_pos: self.operations.len() as u32,
        }
    }

    /// Zeroes the derivative of the left-hand slot of every statement recorded
    /// after `pos` was taken. Slots equal to `u32::MAX` or without storage are
    /// ignored.
    pub fn clear_derivatives_after(&mut self, pos: TapePosition) {
        // `skip` (rather than slicing) keeps a position beyond the end a no-op.
        for stmt in self.statements.iter().skip(pos.statement_pos as usize) {
            if stmt.slot == u32::MAX {
                continue;
            }
            if let Some(derivative) = self.derivatives.get_mut(stmt.slot as usize) {
                *derivative = T::zero();
            }
        }
    }

    /// Drops every statement and operation recorded after `pos` was taken.
    ///
    /// Derivative storage and slot numbering are not rewound. The sentinel
    /// statement at index 0 is always kept, even for a position whose
    /// `statement_pos` is 0, because the adjoint sweep relies on it.
    pub fn reset_to(&mut self, pos: TapePosition) {
        self.statements.truncate((pos.statement_pos as usize).max(1));
        self.operations.truncate(pos.operation_pos as usize);
    }

    /// Number of variables registered since the recording was started.
    pub fn num_variables(&self) -> u32 {
        self.num_variables
    }

    /// Number of recorded operands.
    pub fn num_operations(&self) -> usize {
        self.operations.len()
    }

    /// Number of recorded statements, not counting the sentinel.
    pub fn num_statements(&self) -> usize {
        self.statements.len().saturating_sub(1)
    }

    /// Bytes reserved by the three internal buffers (based on capacity, not length).
    pub fn memory(&self) -> usize {
        self.statements.capacity() * std::mem::size_of::<Statement>()
            + self.operations.capacity() * std::mem::size_of::<Operation<T>>()
            + self.derivatives.capacity() * std::mem::size_of::<T>()
    }
}
impl<T: TapeStorage> Drop for Tape<T> {
fn drop(&mut self) {
self.deactivate();
}
}