use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use indexmap::{IndexMap, IndexSet};
use crate::areal::AReal;
use crate::labeled::VarRegistry;
use crate::math;
use crate::tape::Tape;
/// A scalar automatic-differentiation value for use with [`LabeledTape`].
///
/// Thin wrapper around an `AReal<f64>`. The arithmetic operators and math
/// functions defined for it in this module forward to the wrapped value.
#[derive(Clone)]
pub struct LabeledAReal {
/// The wrapped AD scalar. Its `value()` and `slot()` are exposed through
/// this type's accessors and `Debug` output.
inner: AReal<f64>,
}
impl LabeledAReal {
#[inline]
pub(crate) fn from_inner(inner: AReal<f64>) -> Self {
Self { inner }
}
#[inline]
pub fn value(&self) -> f64 {
self.inner.value()
}
#[inline]
pub fn inner(&self) -> &AReal<f64> {
&self.inner
}
#[inline]
pub fn sin(&self) -> Self {
Self {
inner: math::ad::sin(&self.inner),
}
}
#[inline]
pub fn cos(&self) -> Self {
Self {
inner: math::ad::cos(&self.inner),
}
}
#[inline]
pub fn tan(&self) -> Self {
Self {
inner: math::ad::tan(&self.inner),
}
}
#[inline]
pub fn exp(&self) -> Self {
Self {
inner: math::ad::exp(&self.inner),
}
}
#[inline]
pub fn ln(&self) -> Self {
Self {
inner: math::ad::ln(&self.inner),
}
}
#[inline]
pub fn sqrt(&self) -> Self {
Self {
inner: math::ad::sqrt(&self.inner),
}
}
#[inline]
pub fn tanh(&self) -> Self {
Self {
inner: math::ad::tanh(&self.inner),
}
}
}
impl fmt::Debug for LabeledAReal {
    /// Prints the wrapped scalar's current value and its slot index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let wrapped = &self.inner;
        let mut builder = f.debug_struct("LabeledAReal");
        builder.field("value", &wrapped.value());
        builder.field("slot", &wrapped.slot());
        builder.finish()
    }
}
macro_rules! __lbl_areal_binop {
($trait:ident, $method:ident, $op:tt) => {
impl ::core::ops::$trait<LabeledAReal> for LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: LabeledAReal) -> LabeledAReal {
LabeledAReal { inner: self.inner $op rhs.inner }
}
}
impl ::core::ops::$trait<&LabeledAReal> for &LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: &LabeledAReal) -> LabeledAReal {
LabeledAReal { inner: &self.inner $op &rhs.inner }
}
}
impl ::core::ops::$trait<&LabeledAReal> for LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: &LabeledAReal) -> LabeledAReal {
LabeledAReal { inner: self.inner $op &rhs.inner }
}
}
impl ::core::ops::$trait<LabeledAReal> for &LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: LabeledAReal) -> LabeledAReal {
LabeledAReal { inner: &self.inner $op rhs.inner }
}
}
impl ::core::ops::$trait<f64> for LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: f64) -> LabeledAReal {
LabeledAReal { inner: self.inner $op rhs }
}
}
impl ::core::ops::$trait<f64> for &LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn $method(self, rhs: f64) -> LabeledAReal {
LabeledAReal { inner: &self.inner $op rhs }
}
}
};
}
__lbl_areal_binop!(Add, add, +);
__lbl_areal_binop!(Sub, sub, -);
__lbl_areal_binop!(Mul, mul, *);
__lbl_areal_binop!(Div, div, /);
impl ::core::ops::Neg for LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn neg(self) -> LabeledAReal {
LabeledAReal { inner: -self.inner }
}
}
impl ::core::ops::Neg for &LabeledAReal {
type Output = LabeledAReal;
#[inline]
fn neg(self) -> LabeledAReal {
LabeledAReal {
inner: -&self.inner,
}
}
}
/// An AD tape whose inputs are identified by name.
///
/// The call order is enforced by assertions:
/// 1. register every input with `input`;
/// 2. call `freeze` (this builds the name registry and activates the tape);
/// 3. run the forward computation;
/// 4. call `gradient` to get derivatives keyed by input name.
pub struct LabeledTape {
/// The underlying `f64` tape.
tape: Tape<f64>,
/// Distinct input names, in first-registration order; becomes the `VarRegistry` on `freeze`.
builder: IndexSet<String>,
/// Every registered input as `(name, slot)`, in registration order (names may repeat).
inputs: Vec<(String, u32)>,
/// Shared name registry; `None` until `freeze` is called.
registry: Option<Arc<VarRegistry>>,
/// Set by `freeze`; gates `input`, `gradient` and the deactivation in `Drop`.
frozen: bool,
/// The raw-pointer marker makes `LabeledTape` neither `Send` nor `Sync`.
/// NOTE(review): presumably because tape activation is per-thread — confirm.
_not_send: PhantomData<*const ()>,
}
impl LabeledTape {
    /// Creates an empty, unfrozen tape with no inputs and no registry.
    ///
    /// The underlying tape is built with `Tape::<f64>::new(true)`. `freeze`
    /// is what calls `activate()` on it.
    pub fn new() -> Self {
        Self {
            tape: Tape::<f64>::new(true),
            builder: IndexSet::new(),
            inputs: Vec::new(),
            registry: None,
            frozen: false,
            _not_send: PhantomData,
        }
    }

    /// Registers a named input with the given value and returns it.
    ///
    /// The name is added to the registry builder the first time it is seen.
    /// Every call records a new `(name, slot)` pair, so registering the same
    /// name twice creates two distinct input slots.
    ///
    /// NOTE(review): `gradient` keys its result by name, so for a repeated
    /// name only the derivative of the last-registered slot is reported.
    ///
    /// # Panics
    ///
    /// Panics if called after [`LabeledTape::freeze`].
    pub fn input(&mut self, name: &str, value: f64) -> LabeledAReal {
        assert!(
            !self.frozen,
            "LabeledTape::input({:?}) called after freeze(); add all inputs before running the forward pass",
            name
        );
        // Check before inserting so an owned String is only allocated for new names.
        if !self.builder.contains(name) {
            self.builder.insert(name.to_string());
        }
        let mut ar = AReal::<f64>::new(value);
        AReal::register_input(std::slice::from_mut(&mut ar), &mut self.tape);
        self.inputs.push((name.to_string(), ar.slot()));
        LabeledAReal::from_inner(ar)
    }

    /// Freezes the input set and activates the underlying tape.
    ///
    /// Builds a [`VarRegistry`] from the distinct input names in first-seen
    /// order and keeps a shared handle to it. It then calls `activate()` on
    /// the tape, marks this tape frozen, and returns another handle to the
    /// same registry.
    ///
    /// # Panics
    ///
    /// Panics if called more than once on the same tape.
    pub fn freeze(&mut self) -> Arc<VarRegistry> {
        assert!(
            !self.frozen,
            "LabeledTape::freeze() called twice on the same tape"
        );
        let reg = Arc::new(VarRegistry::from_names(self.builder.iter().cloned()));
        self.registry = Some(Arc::clone(&reg));
        self.tape.activate();
        self.frozen = true;
        reg
    }

    /// Returns `true` once [`LabeledTape::freeze`] has been called.
    #[inline]
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Returns the registry built by `freeze`, or `None` before freezing.
    #[inline]
    pub fn registry(&self) -> Option<&Arc<VarRegistry>> {
        self.registry.as_ref()
    }

    /// Computes the gradient of `output` with respect to every registered input.
    ///
    /// Steps:
    /// 1. clear previously accumulated derivatives on the tape;
    /// 2. seed `output`'s adjoint with `1.0`;
    /// 3. propagate adjoints;
    /// 4. read the derivative at each input slot.
    ///
    /// The returned map is keyed by input name, in registration order.
    ///
    /// # Panics
    ///
    /// Panics if called before [`LabeledTape::freeze`].
    pub fn gradient(&mut self, output: &LabeledAReal) -> IndexMap<String, f64> {
        assert!(
            self.frozen,
            "LabeledTape::gradient() called before freeze()"
        );
        self.tape.clear_derivatives();
        output.inner.set_adjoint(&mut self.tape, 1.0);
        self.tape.compute_adjoints();
        let mut grad = IndexMap::with_capacity(self.inputs.len());
        for (name, slot) in &self.inputs {
            grad.insert(name.clone(), self.tape.derivative(*slot));
        }
        grad
    }

    /// Forwards to `Tape::<f64>::deactivate_all()`.
    pub fn deactivate_all() {
        Tape::<f64>::deactivate_all();
    }
}
impl Default for LabeledTape {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for LabeledTape {
    /// Prints a compact summary of the tape:
    /// - the frozen flag;
    /// - the number of registered inputs (duplicates included);
    /// - the registry length (`None` before `freeze`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let input_count = self.inputs.len();
        let registry_len = self.registry.as_ref().map(|reg| reg.len());
        f.debug_struct("LabeledTape")
            .field("frozen", &self.frozen)
            .field("inputs", &input_count)
            .field("registry_len", &registry_len)
            .finish()
    }
}
impl Drop for LabeledTape {
    /// Deactivates the underlying tape, but only if `freeze` was called on it.
    ///
    /// `freeze` is the only place this type calls `activate()`. An unfrozen
    /// tape is dropped without a `deactivate()` call.
    fn drop(&mut self) {
        if !self.frozen {
            return;
        }
        self.tape.deactivate();
    }
}