use std::fmt;
use std::marker::PhantomData;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::sync::Arc;
use indexmap::{IndexMap, IndexSet};
use crate::math;
use crate::registry::VarRegistry;
use crate::tape::{Tape, TapeStorage};
/// Active real number: a plain value of type `T` paired with a tape slot index.
///
/// A slot equal to `INVALID_SLOT` means the value has not been given a slot on
/// any tape.
#[derive(Clone)]
pub struct AReal<T>
where
    T: TapeStorage,
{
    /// The numeric value carried by this variable.
    value: T,
    /// Slot index handed out by the tape, or `INVALID_SLOT` when there is none.
    pub(crate) slot: u32,
}
/// Sentinel slot index (`u32::MAX`) carried by an `AReal` that has no tape slot.
const INVALID_SLOT: u32 = u32::MAX;
impl<T: TapeStorage> AReal<T> {
    /// Wraps `value` without assigning a tape slot (the slot is `INVALID_SLOT`).
    pub fn new(value: T) -> Self {
        Self {
            value,
            slot: INVALID_SLOT,
        }
    }

    /// Returns the stored numeric value.
    #[inline]
    pub fn value(&self) -> T {
        self.value
    }

    /// Overwrites the stored numeric value; the slot is left untouched.
    #[inline]
    pub fn set_value(&mut self, v: T) {
        self.value = v;
    }

    /// Returns the raw slot index (`INVALID_SLOT` when none has been assigned).
    #[inline]
    pub fn slot(&self) -> u32 {
        self.slot
    }

    /// Returns `true` when this variable holds a slot other than `INVALID_SLOT`.
    #[inline]
    pub fn should_record(&self) -> bool {
        INVALID_SLOT != self.slot
    }

    /// Reads this variable's derivative from `tape`.
    ///
    /// A variable without a slot yields `T::zero()` and the tape is not consulted.
    #[inline]
    pub fn adjoint(&self, tape: &Tape<T>) -> T {
        if self.should_record() {
            tape.derivative(self.slot)
        } else {
            T::zero()
        }
    }

    /// Writes `value` as this variable's derivative on `tape`.
    ///
    /// Does nothing when the variable has no slot.
    #[inline]
    pub fn set_adjoint(&self, tape: &mut Tape<T>, value: T) {
        if !self.should_record() {
            return;
        }
        tape.set_derivative(self.slot, value);
    }

    /// Assigns a fresh slot from `tape` to every element of `vars` that does
    /// not have one yet; elements that already hold a slot are skipped.
    pub fn register_input(vars: &mut [AReal<T>], tape: &mut Tape<T>) {
        for var in vars.iter_mut().filter(|var| !var.should_record()) {
            var.slot = tape.register_variable();
        }
    }

    /// Assigns a fresh slot from `tape` to every element of `vars` that does
    /// not have one yet.
    ///
    /// As written in this file, the effect is the same as `register_input`.
    pub fn register_output(vars: &mut [AReal<T>], tape: &mut Tape<T>) {
        vars.iter_mut()
            .filter(|var| !var.should_record())
            .for_each(|var| var.slot = tape.register_variable());
    }
}
/// Builds the result of a two-operand operation and records it when a tape is active.
///
/// When `Tape::<T>::get_active()` yields a tape, a new variable is registered
/// on it and a binary entry `(slot, lhs_mul, lhs_slot, rhs_mul, rhs_slot)` is
/// pushed; the returned `AReal` carries the new slot. Without an active tape
/// the result is a plain `AReal::new(result_value)`.
#[inline]
fn record_binary<T: TapeStorage>(
    result_value: T,
    lhs_slot: u32,
    lhs_mul: T,
    rhs_slot: u32,
    rhs_mul: T,
) -> AReal<T> {
    match Tape::<T>::get_active() {
        Some(active) => {
            // SAFETY: NOTE(review) — relies on `get_active` only handing out a
            // pointer to a live tape that is not otherwise borrowed here; that
            // contract lives in the tape module and should be confirmed there.
            let tape = unsafe { &mut *active };
            let slot = tape.register_variable();
            tape.push_binary(slot, lhs_mul, lhs_slot, rhs_mul, rhs_slot);
            AReal {
                value: result_value,
                slot,
            }
        }
        None => AReal::new(result_value),
    }
}
/// Builds the result of a one-operand operation and records it when a tape is active.
///
/// When `Tape::<T>::get_active()` yields a tape, a new variable is registered
/// on it and a unary entry `(slot, multiplier, input_slot)` is pushed; the
/// returned `AReal` carries the new slot. Without an active tape the result is
/// a plain `AReal::new(result_value)`.
#[inline]
fn record_unary<T: TapeStorage>(result_value: T, input_slot: u32, multiplier: T) -> AReal<T> {
    match Tape::<T>::get_active() {
        Some(active) => {
            // SAFETY: NOTE(review) — relies on `get_active` only handing out a
            // pointer to a live tape that is not otherwise borrowed here; that
            // contract lives in the tape module and should be confirmed there.
            let tape = unsafe { &mut *active };
            let slot = tape.register_variable();
            tape.push_unary(slot, multiplier, input_slot);
            AReal {
                value: result_value,
                slot,
            }
        }
        None => AReal::new(result_value),
    }
}
/// Crate-visible entry point that forwards unchanged to the private `record_unary`.
pub(crate) fn record_unary_op<T: TapeStorage>(
    result_value: T,
    input_slot: u32,
    multiplier: T,
) -> AReal<T> {
    record_unary::<T>(result_value, input_slot, multiplier)
}
/// Crate-visible entry point that forwards unchanged to the private `record_binary`.
pub(crate) fn record_binary_op<T: TapeStorage>(
    result_value: T,
    a_slot: u32,
    a_mul: T,
    b_slot: u32,
    b_mul: T,
) -> AReal<T> {
    record_binary::<T>(result_value, a_slot, a_mul, b_slot, b_mul)
}
/// Converts a bare `T` into an `AReal<T>` with no tape slot, exactly like `AReal::new`.
impl<T: TapeStorage> From<T> for AReal<T> {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}
/// Converts an `i32` into an `AReal<f64>` with no tape slot.
impl From<i32> for AReal<f64> {
    fn from(value: i32) -> Self {
        // Every `i32` is exactly representable as `f64`, so this equals `value as f64`.
        let widened = f64::from(value);
        AReal::new(widened)
    }
}
/// Converts an `i32` into an `AReal<f32>` with no tape slot.
impl From<i32> for AReal<f32> {
    fn from(value: i32) -> Self {
        // `as` is used deliberately: `i32 -> f32` can round, so there is no `From` impl.
        let converted = value as f32;
        AReal::new(converted)
    }
}
/// `AReal + AReal`, both operands by value.
///
/// Passes the sum to `record_binary` with a multiplier of one for each operand
/// (the partials of `a + b`).
impl<T: TapeStorage> Add for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        let sum = self.value + rhs.value;
        record_binary(sum, self.slot, T::one(), rhs.slot, T::one())
    }
}
/// `&AReal + &AReal`, both operands by reference.
///
/// Passes the sum to `record_binary` with a multiplier of one for each operand.
impl<T: TapeStorage> Add for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: &AReal<T>) -> AReal<T> {
        let sum = self.value + rhs.value;
        record_binary(sum, self.slot, T::one(), rhs.slot, T::one())
    }
}
/// `AReal + &AReal`: left operand by value, right operand by reference.
///
/// Passes the sum to `record_binary` with a multiplier of one for each operand.
impl<T: TapeStorage> Add<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: &AReal<T>) -> Self::Output {
        let sum = self.value + rhs.value;
        record_binary(sum, self.slot, T::one(), rhs.slot, T::one())
    }
}
/// `&AReal + AReal`: left operand by reference, right operand by value.
///
/// Passes the sum to `record_binary` with a multiplier of one for each operand.
impl<T: TapeStorage> Add<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: AReal<T>) -> Self::Output {
        let sum = self.value + rhs.value;
        record_binary(sum, self.slot, T::one(), rhs.slot, T::one())
    }
}
/// `AReal + T`: adds a plain scalar on the right.
///
/// Only the `AReal` operand is passed to `record_unary`, with multiplier one.
impl<T: TapeStorage> Add<T> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: T) -> Self::Output {
        let sum = self.value + rhs;
        record_unary(sum, self.slot, T::one())
    }
}
/// `&AReal + T`: adds a plain scalar on the right of a borrowed `AReal`.
///
/// Only the `AReal` operand is passed to `record_unary`, with multiplier one.
impl<T: TapeStorage> Add<T> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn add(self, rhs: T) -> Self::Output {
        let sum = self.value + rhs;
        record_unary(sum, self.slot, T::one())
    }
}
/// `f64 + AReal<f64>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` operand is passed to `record_unary` with multiplier `1.0`.
impl Add<AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn add(self, rhs: AReal<f64>) -> Self::Output {
        let sum = self + rhs.value;
        record_unary(sum, rhs.slot, 1.0_f64)
    }
}
/// `f64 + &AReal<f64>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` operand is passed to `record_unary` with multiplier `1.0`.
impl Add<&AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn add(self, rhs: &AReal<f64>) -> Self::Output {
        let sum = self + rhs.value;
        record_unary(sum, rhs.slot, 1.0_f64)
    }
}
/// `f32 + AReal<f32>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` operand is passed to `record_unary` with multiplier `1.0`.
impl Add<AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn add(self, rhs: AReal<f32>) -> Self::Output {
        let sum = self + rhs.value;
        record_unary(sum, rhs.slot, 1.0_f32)
    }
}
/// `f32 + &AReal<f32>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` operand is passed to `record_unary` with multiplier `1.0`.
impl Add<&AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn add(self, rhs: &AReal<f32>) -> Self::Output {
        let sum = self + rhs.value;
        record_unary(sum, rhs.slot, 1.0_f32)
    }
}
/// `AReal - AReal`, both operands by value.
///
/// Passes the difference to `record_binary` with multipliers `+1` for the left
/// operand and `-1` for the right operand (the partials of `a - b`).
impl<T: TapeStorage> Sub for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        let diff = self.value - rhs.value;
        record_binary(diff, self.slot, T::one(), rhs.slot, -T::one())
    }
}
/// `&AReal - &AReal`, both operands by reference.
///
/// Passes the difference to `record_binary` with multipliers `+1` (left) and `-1` (right).
impl<T: TapeStorage> Sub for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: &AReal<T>) -> AReal<T> {
        let diff = self.value - rhs.value;
        record_binary(diff, self.slot, T::one(), rhs.slot, -T::one())
    }
}
/// `AReal - &AReal`: left operand by value, right operand by reference.
///
/// Passes the difference to `record_binary` with multipliers `+1` (left) and `-1` (right).
impl<T: TapeStorage> Sub<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: &AReal<T>) -> Self::Output {
        let diff = self.value - rhs.value;
        record_binary(diff, self.slot, T::one(), rhs.slot, -T::one())
    }
}
/// `&AReal - AReal`: left operand by reference, right operand by value.
///
/// Passes the difference to `record_binary` with multipliers `+1` (left) and `-1` (right).
impl<T: TapeStorage> Sub<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: AReal<T>) -> Self::Output {
        let diff = self.value - rhs.value;
        record_binary(diff, self.slot, T::one(), rhs.slot, -T::one())
    }
}
/// `AReal - T`: subtracts a plain scalar on the right.
///
/// Only the `AReal` operand is passed to `record_unary`, with multiplier one.
impl<T: TapeStorage> Sub<T> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: T) -> Self::Output {
        let diff = self.value - rhs;
        record_unary(diff, self.slot, T::one())
    }
}
/// `&AReal - T`: subtracts a plain scalar from a borrowed `AReal`.
///
/// Only the `AReal` operand is passed to `record_unary`, with multiplier one.
impl<T: TapeStorage> Sub<T> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn sub(self, rhs: T) -> Self::Output {
        let diff = self.value - rhs;
        record_unary(diff, self.slot, T::one())
    }
}
/// `f64 - AReal<f64>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` is the subtrahend, so it is passed to `record_unary` with multiplier `-1.0`.
impl Sub<AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn sub(self, rhs: AReal<f64>) -> Self::Output {
        let diff = self - rhs.value;
        record_unary(diff, rhs.slot, -1.0_f64)
    }
}
/// `f64 - &AReal<f64>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` is the subtrahend, so it is passed to `record_unary` with multiplier `-1.0`.
impl Sub<&AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn sub(self, rhs: &AReal<f64>) -> Self::Output {
        let diff = self - rhs.value;
        record_unary(diff, rhs.slot, -1.0_f64)
    }
}
/// `f32 - AReal<f32>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` is the subtrahend, so it is passed to `record_unary` with multiplier `-1.0`.
impl Sub<AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn sub(self, rhs: AReal<f32>) -> Self::Output {
        let diff = self - rhs.value;
        record_unary(diff, rhs.slot, -1.0_f32)
    }
}
/// `f32 - &AReal<f32>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` is the subtrahend, so it is passed to `record_unary` with multiplier `-1.0`.
impl Sub<&AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn sub(self, rhs: &AReal<f32>) -> Self::Output {
        let diff = self - rhs.value;
        record_unary(diff, rhs.slot, -1.0_f32)
    }
}
/// `AReal * AReal`, both operands by value.
///
/// Passes the product to `record_binary`; each operand's multiplier is the
/// other operand's value (the partials of `a * b`).
impl<T: TapeStorage> Mul for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        let product = self.value * rhs.value;
        record_binary(product, self.slot, rhs.value, rhs.slot, self.value)
    }
}
/// `&AReal * &AReal`, both operands by reference.
///
/// Passes the product to `record_binary`; each operand's multiplier is the
/// other operand's value.
impl<T: TapeStorage> Mul for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: &AReal<T>) -> AReal<T> {
        let product = self.value * rhs.value;
        record_binary(product, self.slot, rhs.value, rhs.slot, self.value)
    }
}
/// `AReal * &AReal`: left operand by value, right operand by reference.
///
/// Passes the product to `record_binary`; each operand's multiplier is the
/// other operand's value.
impl<T: TapeStorage> Mul<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: &AReal<T>) -> Self::Output {
        let product = self.value * rhs.value;
        record_binary(product, self.slot, rhs.value, rhs.slot, self.value)
    }
}
/// `&AReal * AReal`: left operand by reference, right operand by value.
///
/// Passes the product to `record_binary`; each operand's multiplier is the
/// other operand's value.
impl<T: TapeStorage> Mul<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: AReal<T>) -> Self::Output {
        let product = self.value * rhs.value;
        record_binary(product, self.slot, rhs.value, rhs.slot, self.value)
    }
}
/// `AReal * T`: scales by a plain scalar on the right.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl<T: TapeStorage> Mul<T> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: T) -> Self::Output {
        let product = self.value * rhs;
        record_unary(product, self.slot, rhs)
    }
}
/// `&AReal * T`: scales a borrowed `AReal` by a plain scalar.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl<T: TapeStorage> Mul<T> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn mul(self, rhs: T) -> Self::Output {
        let product = self.value * rhs;
        record_unary(product, self.slot, rhs)
    }
}
/// `f64 * AReal<f64>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl Mul<AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn mul(self, rhs: AReal<f64>) -> Self::Output {
        let product = self * rhs.value;
        record_unary(product, rhs.slot, self)
    }
}
/// `f64 * &AReal<f64>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl Mul<&AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn mul(self, rhs: &AReal<f64>) -> Self::Output {
        let product = self * rhs.value;
        record_unary(product, rhs.slot, self)
    }
}
/// `f32 * AReal<f32>`: plain scalar on the left, `AReal` by value on the right.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl Mul<AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn mul(self, rhs: AReal<f32>) -> Self::Output {
        let product = self * rhs.value;
        record_unary(product, rhs.slot, self)
    }
}
/// `f32 * &AReal<f32>`: plain scalar on the left, borrowed `AReal` on the right.
///
/// The `AReal` operand is passed to `record_unary` with the scalar as its multiplier.
impl Mul<&AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn mul(self, rhs: &AReal<f32>) -> Self::Output {
        let product = self * rhs.value;
        record_unary(product, rhs.slot, self)
    }
}
/// `AReal / AReal`, both operands by value.
///
/// The quotient is computed as `a * (1 / b)`. The multipliers handed to
/// `record_binary` are `1 / b` for the numerator and `-a / b^2` for the
/// denominator (the partials of `a / b`).
impl<T: TapeStorage> Div for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: Self) -> Self::Output {
        let recip = T::one() / rhs.value;
        let quotient = self.value * recip;
        let denom_mul = -self.value * recip * recip;
        record_binary(quotient, self.slot, recip, rhs.slot, denom_mul)
    }
}
/// `&AReal / &AReal`, both operands by reference.
///
/// The quotient is computed as `a * (1 / b)`; the multipliers are `1 / b`
/// (numerator) and `-a / b^2` (denominator).
impl<T: TapeStorage> Div for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: &AReal<T>) -> AReal<T> {
        let recip = T::one() / rhs.value;
        let quotient = self.value * recip;
        let denom_mul = -self.value * recip * recip;
        record_binary(quotient, self.slot, recip, rhs.slot, denom_mul)
    }
}
/// `AReal / &AReal`: numerator by value, denominator by reference.
///
/// The quotient is computed as `a * (1 / b)`; the multipliers are `1 / b`
/// (numerator) and `-a / b^2` (denominator).
impl<T: TapeStorage> Div<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: &AReal<T>) -> Self::Output {
        let recip = T::one() / rhs.value;
        let quotient = self.value * recip;
        let denom_mul = -self.value * recip * recip;
        record_binary(quotient, self.slot, recip, rhs.slot, denom_mul)
    }
}
/// `&AReal / AReal`: numerator by reference, denominator by value.
///
/// The quotient is computed as `a * (1 / b)`; the multipliers are `1 / b`
/// (numerator) and `-a / b^2` (denominator).
impl<T: TapeStorage> Div<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: AReal<T>) -> Self::Output {
        let recip = T::one() / rhs.value;
        let quotient = self.value * recip;
        let denom_mul = -self.value * recip * recip;
        record_binary(quotient, self.slot, recip, rhs.slot, denom_mul)
    }
}
/// `AReal / T`: divides by a plain scalar.
///
/// Implemented as multiplication by `1 / rhs`; that reciprocal is also the
/// multiplier passed to `record_unary`.
impl<T: TapeStorage> Div<T> for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: T) -> Self::Output {
        let recip = T::one() / rhs;
        let quotient = self.value * recip;
        record_unary(quotient, self.slot, recip)
    }
}
/// `&AReal / T`: divides a borrowed `AReal` by a plain scalar.
///
/// Implemented as multiplication by `1 / rhs`; that reciprocal is also the
/// multiplier passed to `record_unary`.
impl<T: TapeStorage> Div<T> for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn div(self, rhs: T) -> Self::Output {
        let recip = T::one() / rhs;
        let quotient = self.value * recip;
        record_unary(quotient, self.slot, recip)
    }
}
/// `f64 / AReal<f64>`: plain scalar numerator, `AReal` denominator by value.
///
/// The quotient is `c * (1 / b)` and the multiplier passed to `record_unary`
/// is `-c / b^2`.
impl Div<AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn div(self, rhs: AReal<f64>) -> Self::Output {
        let recip = 1.0_f64 / rhs.value;
        let quotient = self * recip;
        let denom_mul = -self * recip * recip;
        record_unary(quotient, rhs.slot, denom_mul)
    }
}
/// `f64 / &AReal<f64>`: plain scalar numerator, borrowed `AReal` denominator.
///
/// The quotient is `c * (1 / b)` and the multiplier passed to `record_unary`
/// is `-c / b^2`.
impl Div<&AReal<f64>> for f64 {
    type Output = AReal<f64>;

    #[inline]
    fn div(self, rhs: &AReal<f64>) -> Self::Output {
        let recip = 1.0_f64 / rhs.value;
        let quotient = self * recip;
        let denom_mul = -self * recip * recip;
        record_unary(quotient, rhs.slot, denom_mul)
    }
}
/// `f32 / AReal<f32>`: plain scalar numerator, `AReal` denominator by value.
///
/// The quotient is `c * (1 / b)` and the multiplier passed to `record_unary`
/// is `-c / b^2`.
impl Div<AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn div(self, rhs: AReal<f32>) -> Self::Output {
        let recip = 1.0_f32 / rhs.value;
        let quotient = self * recip;
        let denom_mul = -self * recip * recip;
        record_unary(quotient, rhs.slot, denom_mul)
    }
}
/// `f32 / &AReal<f32>`: plain scalar numerator, borrowed `AReal` denominator.
///
/// The quotient is `c * (1 / b)` and the multiplier passed to `record_unary`
/// is `-c / b^2`.
impl Div<&AReal<f32>> for f32 {
    type Output = AReal<f32>;

    #[inline]
    fn div(self, rhs: &AReal<f32>) -> Self::Output {
        let recip = 1.0_f32 / rhs.value;
        let quotient = self * recip;
        let denom_mul = -self * recip * recip;
        record_unary(quotient, rhs.slot, denom_mul)
    }
}
/// Unary minus on an owned `AReal`.
///
/// The negated value is passed to `record_unary` with multiplier `-1`.
impl<T: TapeStorage> Neg for AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn neg(self) -> Self::Output {
        let negated = -self.value;
        record_unary(negated, self.slot, -T::one())
    }
}
/// Unary minus on a borrowed `AReal`.
///
/// The negated value is passed to `record_unary` with multiplier `-1`.
impl<T: TapeStorage> Neg for &AReal<T> {
    type Output = AReal<T>;

    #[inline]
    fn neg(self) -> Self::Output {
        let negated = -self.value;
        record_unary(negated, self.slot, -T::one())
    }
}
/// `a += b` with an owned `AReal` on the right.
///
/// `self` is replaced by the result of the `+` operator, so afterwards it
/// carries whatever slot that addition produced.
impl<T: TapeStorage> AddAssign for AReal<T> {
    fn add_assign(&mut self, rhs: AReal<T>) {
        let updated = &*self + rhs;
        *self = updated;
    }
}
/// `a += &b` with a borrowed `AReal` on the right.
///
/// `self` is replaced by the result of the `+` operator.
impl<T: TapeStorage> AddAssign<&AReal<T>> for AReal<T> {
    fn add_assign(&mut self, rhs: &AReal<T>) {
        let updated = &*self + rhs;
        *self = updated;
    }
}
/// `a += c` with a plain scalar on the right.
///
/// `self` is replaced by the result of the `+` operator.
impl<T: TapeStorage> AddAssign<T> for AReal<T> {
    fn add_assign(&mut self, rhs: T) {
        let updated = &*self + rhs;
        *self = updated;
    }
}
/// `a -= b` with an owned `AReal` on the right.
///
/// `self` is replaced by the result of the `-` operator.
impl<T: TapeStorage> SubAssign for AReal<T> {
    fn sub_assign(&mut self, rhs: AReal<T>) {
        let updated = &*self - rhs;
        *self = updated;
    }
}
/// `a -= &b` with a borrowed `AReal` on the right.
///
/// `self` is replaced by the result of the `-` operator.
impl<T: TapeStorage> SubAssign<&AReal<T>> for AReal<T> {
    fn sub_assign(&mut self, rhs: &AReal<T>) {
        let updated = &*self - rhs;
        *self = updated;
    }
}
/// `a -= c` with a plain scalar on the right.
///
/// `self` is replaced by the result of the `-` operator.
impl<T: TapeStorage> SubAssign<T> for AReal<T> {
    fn sub_assign(&mut self, rhs: T) {
        let updated = &*self - rhs;
        *self = updated;
    }
}
/// `a *= b` with an owned `AReal` on the right.
///
/// `self` is replaced by the result of the `*` operator.
impl<T: TapeStorage> MulAssign for AReal<T> {
    fn mul_assign(&mut self, rhs: AReal<T>) {
        let updated = &*self * rhs;
        *self = updated;
    }
}
/// `a *= &b` with a borrowed `AReal` on the right.
///
/// `self` is replaced by the result of the `*` operator.
impl<T: TapeStorage> MulAssign<&AReal<T>> for AReal<T> {
    fn mul_assign(&mut self, rhs: &AReal<T>) {
        let updated = &*self * rhs;
        *self = updated;
    }
}
/// `a *= c` with a plain scalar on the right.
///
/// `self` is replaced by the result of the `*` operator.
impl<T: TapeStorage> MulAssign<T> for AReal<T> {
    fn mul_assign(&mut self, rhs: T) {
        let updated = &*self * rhs;
        *self = updated;
    }
}
/// `a /= b` with an owned `AReal` on the right.
///
/// `self` is replaced by the result of the `/` operator.
impl<T: TapeStorage> DivAssign for AReal<T> {
    fn div_assign(&mut self, rhs: AReal<T>) {
        let updated = &*self / rhs;
        *self = updated;
    }
}
/// `a /= &b` with a borrowed `AReal` on the right.
///
/// `self` is replaced by the result of the `/` operator.
impl<T: TapeStorage> DivAssign<&AReal<T>> for AReal<T> {
    fn div_assign(&mut self, rhs: &AReal<T>) {
        let updated = &*self / rhs;
        *self = updated;
    }
}
/// `a /= c` with a plain scalar on the right.
///
/// `self` is replaced by the result of the `/` operator.
impl<T: TapeStorage> DivAssign<T> for AReal<T> {
    fn div_assign(&mut self, rhs: T) {
        let updated = &*self / rhs;
        *self = updated;
    }
}
/// Equality between two `AReal`s compares their values only; slots are ignored.
impl<T: TapeStorage> PartialEq for AReal<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.value, &other.value);
        lhs == rhs
    }
}
/// Equality between an `AReal` and a bare `T` compares the stored value with the scalar.
impl<T: TapeStorage> PartialEq<T> for AReal<T> {
    fn eq(&self, other: &T) -> bool {
        let lhs = &self.value;
        lhs == other
    }
}
/// Ordering between two `AReal`s is the ordering of their values; slots are ignored.
impl<T: TapeStorage> PartialOrd for AReal<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.value, &other.value)
    }
}
/// Ordering between an `AReal` and a bare `T` compares the stored value with the scalar.
impl<T: TapeStorage> PartialOrd<T> for AReal<T> {
    fn partial_cmp(&self, other: &T) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.value, other)
    }
}
/// Displays only the stored value, using its own `{}` formatting.
impl<T: TapeStorage> fmt::Display for AReal<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Goes through `format_args!("{}")` rather than delegating to the
        // value's `Display::fmt`, so width/precision flags on `f` are not forwarded.
        f.write_fmt(format_args!("{}", self.value))
    }
}
/// Debug output of the form `AReal(<value>, slot=<slot>)`.
impl<T: TapeStorage> fmt::Debug for AReal<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (value, slot) = (&self.value, self.slot);
        f.write_fmt(format_args!("AReal({}, slot={})", value, slot))
    }
}
/// The default `AReal` holds `T::zero()` and has no tape slot.
impl<T: TapeStorage> Default for AReal<T> {
    fn default() -> Self {
        let zero = T::zero();
        Self::new(zero)
    }
}
/// Thin wrapper around an `AReal<f64>`, used as the variable type returned by
/// `NamedTape::input` and by the arithmetic and math helpers defined on it.
#[derive(Clone)]
pub struct NamedAReal {
    /// The wrapped active real; every operation forwards to it.
    inner: AReal<f64>,
}
impl NamedAReal {
#[inline]
pub(crate) fn from_inner(inner: AReal<f64>) -> Self {
Self { inner }
}
#[inline]
pub fn value(&self) -> f64 {
self.inner.value()
}
#[inline]
pub fn inner(&self) -> &AReal<f64> {
&self.inner
}
#[inline]
pub fn sin(&self) -> Self {
Self {
inner: math::ad::sin(&self.inner),
}
}
#[inline]
pub fn cos(&self) -> Self {
Self {
inner: math::ad::cos(&self.inner),
}
}
#[inline]
pub fn tan(&self) -> Self {
Self {
inner: math::ad::tan(&self.inner),
}
}
#[inline]
pub fn exp(&self) -> Self {
Self {
inner: math::ad::exp(&self.inner),
}
}
#[inline]
pub fn ln(&self) -> Self {
Self {
inner: math::ad::ln(&self.inner),
}
}
#[inline]
pub fn sqrt(&self) -> Self {
Self {
inner: math::ad::sqrt(&self.inner),
}
}
#[inline]
pub fn tanh(&self) -> Self {
Self {
inner: math::ad::tanh(&self.inner),
}
}
#[inline]
pub fn norm_cdf(&self) -> Self {
Self {
inner: math::ad::norm_cdf(&self.inner),
}
}
#[inline]
pub fn inv_norm_cdf(&self) -> Self {
Self {
inner: math::ad::inv_norm_cdf(&self.inner),
}
}
}
/// Debug output as a `NamedAReal { value, slot }` struct.
impl fmt::Debug for NamedAReal {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("NamedAReal");
        out.field("value", &self.inner.value());
        out.field("slot", &self.inner.slot());
        out.finish()
    }
}
/// Generates one binary operator for `NamedAReal`.
///
/// For the operator trait `$Op` (method `$op_fn`, token `$sym`) this emits six
/// impls: the four owned/borrowed combinations of `NamedAReal` with
/// `NamedAReal`, plus owned and borrowed `NamedAReal` with an `f64` on the
/// right. Each impl applies the same operator to the wrapped `AReal<f64>` and
/// re-wraps the result.
macro_rules! __named_areal_binop {
    ($Op:ident, $op_fn:ident, $sym:tt) => {
        // owned (op) owned
        impl ::core::ops::$Op<NamedAReal> for NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: NamedAReal) -> Self::Output {
                NamedAReal::from_inner(self.inner $sym rhs.inner)
            }
        }
        // borrowed (op) borrowed
        impl ::core::ops::$Op<&NamedAReal> for &NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: &NamedAReal) -> Self::Output {
                NamedAReal::from_inner(&self.inner $sym &rhs.inner)
            }
        }
        // owned (op) borrowed
        impl ::core::ops::$Op<&NamedAReal> for NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: &NamedAReal) -> Self::Output {
                NamedAReal::from_inner(self.inner $sym &rhs.inner)
            }
        }
        // borrowed (op) owned
        impl ::core::ops::$Op<NamedAReal> for &NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: NamedAReal) -> Self::Output {
                NamedAReal::from_inner(&self.inner $sym rhs.inner)
            }
        }
        // owned (op) f64
        impl ::core::ops::$Op<f64> for NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: f64) -> Self::Output {
                NamedAReal::from_inner(self.inner $sym rhs)
            }
        }
        // borrowed (op) f64
        impl ::core::ops::$Op<f64> for &NamedAReal {
            type Output = NamedAReal;
            #[inline]
            fn $op_fn(self, rhs: f64) -> Self::Output {
                NamedAReal::from_inner(&self.inner $sym rhs)
            }
        }
    };
}

// Instantiate the four arithmetic operators.
__named_areal_binop!(Add, add, +);
__named_areal_binop!(Sub, sub, -);
__named_areal_binop!(Mul, mul, *);
__named_areal_binop!(Div, div, /);
impl ::core::ops::Neg for NamedAReal {
type Output = NamedAReal;
#[inline]
fn neg(self) -> NamedAReal {
NamedAReal { inner: -self.inner }
}
}
impl ::core::ops::Neg for &NamedAReal {
type Output = NamedAReal;
#[inline]
fn neg(self) -> NamedAReal {
NamedAReal {
inner: -&self.inner,
}
}
}
/// An `f64` tape whose inputs are registered under string names.
///
/// Inputs are added with `input`, the name set is sealed with `freeze`, and
/// `gradient` then reports derivatives keyed by input name.
pub struct NamedTape {
    /// The underlying `f64` tape.
    tape: Tape<f64>,
    /// Distinct input names, in first-insertion order.
    builder: IndexSet<String>,
    /// One `(name, slot)` entry per `input` call, in call order.
    inputs: Vec<(String, u32)>,
    /// Registry built from `builder` by `freeze`; `None` until then.
    registry: Option<Arc<VarRegistry>>,
    /// Set to `true` by `freeze`.
    frozen: bool,
    /// Raw-pointer marker: keeps `NamedTape` from being `Send` or `Sync`.
    _not_send: PhantomData<*const ()>,
}
impl NamedTape {
    /// Creates an empty, unfrozen tape with no inputs and no registry.
    ///
    /// NOTE(review): the meaning of the `true` passed to `Tape::new` is defined
    /// in the tape module and is not visible here — confirm before changing it.
    pub fn new() -> Self {
        Self {
            tape: Tape::<f64>::new(true),
            builder: IndexSet::new(),
            inputs: Vec::new(),
            registry: None,
            frozen: false,
            _not_send: PhantomData,
        }
    }

    /// Registers a new input called `name` with the given `value` and returns
    /// the variable that represents it.
    ///
    /// Every call registers a fresh variable on the tape and appends a
    /// `(name, slot)` entry to the input list, even when `name` was used
    /// before; the name itself is stored only once in the name set.
    ///
    /// # Panics
    ///
    /// Panics if called after [`NamedTape::freeze`].
    pub fn input(&mut self, name: &str, value: f64) -> NamedAReal {
        assert!(
            !self.frozen,
            "NamedTape::input({:?}) called after freeze(); add all inputs before running the forward pass",
            name
        );
        // Check membership first so that a repeated name does not allocate a `String`.
        if !self.builder.contains(name) {
            self.builder.insert(name.to_string());
        }
        let mut ar = AReal::<f64>::new(value);
        AReal::register_input(std::slice::from_mut(&mut ar), &mut self.tape);
        self.inputs.push((name.to_string(), ar.slot()));
        NamedAReal::from_inner(ar)
    }

    /// Seals the set of input names, builds the shared [`VarRegistry`] from
    /// them, activates the underlying tape and marks this tape as frozen.
    ///
    /// Returns a handle to the registry; the same registry is afterwards
    /// available through [`NamedTape::registry`].
    ///
    /// NOTE(review): `activate` is called on the tape stored inside `self`.
    /// If activation records the tape's address, moving the `NamedTape` after
    /// this call would need care — confirm against the tape module.
    ///
    /// # Panics
    ///
    /// Panics if the tape is already frozen.
    pub fn freeze(&mut self) -> Arc<VarRegistry> {
        assert!(
            !self.frozen,
            "NamedTape::freeze() called twice on the same tape"
        );
        let reg = Arc::new(VarRegistry::from_names(self.builder.iter().cloned()));
        // Keep one handle on the tape and hand the other back to the caller.
        self.registry = Some(Arc::clone(&reg));
        self.tape.activate();
        self.frozen = true;
        reg
    }

    /// Returns `true` once [`NamedTape::freeze`] has been called.
    #[inline]
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Returns the registry built by [`NamedTape::freeze`], or `None` before that.
    #[inline]
    pub fn registry(&self) -> Option<&Arc<VarRegistry>> {
        self.registry.as_ref()
    }

    /// Computes the derivative of `output` with respect to every registered input.
    ///
    /// Clears the tape's derivatives, seeds `output` with an adjoint of `1.0`,
    /// runs `compute_adjoints`, and then reads the derivative stored at each
    /// input's slot. The returned map is keyed by input name in registration
    /// order; if a name was registered more than once, the value read for the
    /// later registration replaces the earlier one.
    ///
    /// # Panics
    ///
    /// Panics if called before [`NamedTape::freeze`].
    pub fn gradient(&mut self, output: &NamedAReal) -> IndexMap<String, f64> {
        assert!(
            self.frozen,
            "NamedTape::gradient() called before freeze()"
        );
        self.tape.clear_derivatives();
        output.inner.set_adjoint(&mut self.tape, 1.0);
        self.tape.compute_adjoints();
        let mut grad = IndexMap::with_capacity(self.inputs.len());
        for (name, slot) in &self.inputs {
            grad.insert(name.clone(), self.tape.derivative(*slot));
        }
        grad
    }

    /// Forwards to `Tape::<f64>::deactivate_all()`.
    pub fn deactivate_all() {
        Tape::<f64>::deactivate_all();
    }
}
impl Default for NamedTape {
fn default() -> Self {
Self::new()
}
}
/// Debug output showing the frozen flag, the number of registered inputs and
/// the registry length (`None` before `freeze`).
impl fmt::Debug for NamedTape {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let input_count = self.inputs.len();
        let registry_len = self.registry.as_ref().map(|r| r.len());

        let mut out = f.debug_struct("NamedTape");
        out.field("frozen", &self.frozen);
        out.field("inputs", &input_count);
        out.field("registry_len", &registry_len);
        out.finish()
    }
}
/// Deactivates the underlying tape on drop, but only if `freeze` was called
/// (the only place in this file where the tape is activated).
impl Drop for NamedTape {
    fn drop(&mut self) {
        if !self.frozen {
            return;
        }
        self.tape.deactivate();
    }
}