use std::fmt;
use crate::dual::Dual;
/// Forward-mode dual number whose partial derivatives are looked up by variable name.
///
/// Names are resolved through the active `LabeledForwardTape` registry when calling
/// [`LabeledDual::partial`] or [`LabeledDual::gradient`].
#[derive(Debug, Clone)]
pub struct LabeledDual {
    /// Wrapped dual number; holds the real part (`inner.real`) and the partials.
    pub(super) inner: Dual,
    /// Debug-only generation stamp. Binary operations between two `LabeledDual`s hand both
    /// operands' stamps to `forward_tape::check_gen` (presumably to catch values coming from
    /// different tapes — confirm against `check_gen`).
    #[cfg(debug_assertions)]
    pub(super) gen_id: u64,
}
impl LabeledDual {
    /// Wraps a raw [`Dual`] produced by the forward tape.
    ///
    /// In debug builds the result is stamped with the generation returned by
    /// `forward_tape::current_gen()` at the time of this call.
    #[inline]
    pub(crate) fn __from_inner(inner: Dual) -> Self {
        Self {
            inner,
            #[cfg(debug_assertions)]
            gen_id: crate::labeled::forward_tape::current_gen(),
        }
    }

    /// Wraps `inner`, copying `self`'s generation stamp (debug builds only).
    ///
    /// Used by the unary functions so each result stays tied to the same generation as its
    /// operand. Contrast with [`Self::__from_inner`], which reads the *current* generation.
    #[inline]
    fn __rewrap(&self, inner: Dual) -> Self {
        Self {
            inner,
            #[cfg(debug_assertions)]
            gen_id: self.gen_id,
        }
    }

    /// Returns the real (value) part of the wrapped dual number.
    #[inline]
    pub fn real(&self) -> f64 {
        self.inner.real
    }

    /// Returns the partial derivative with respect to the variable registered as `name`.
    ///
    /// `name` is mapped to an index with the active registry's `index_of`, and that index is
    /// passed to `Dual::partial`.
    ///
    /// # Panics
    ///
    /// Panics if no active registry is available (i.e. outside a frozen `LabeledForwardTape`
    /// scope), or if `name` is not present in the registry.
    pub fn partial(&self, name: &str) -> f64 {
        let idx = crate::labeled::forward_tape::with_active_registry(|r| {
            let r =
                r.expect("LabeledDual::partial called outside a frozen LabeledForwardTape scope");
            r.index_of(name).unwrap_or_else(|| {
                panic!(
                    "LabeledDual::partial: name {:?} not present in registry",
                    name
                )
            })
        });
        self.inner.partial(idx)
    }

    /// Returns a `(name, partial)` pair for every name in the active registry.
    ///
    /// Pairs are produced in the registry's iteration order. A name's position in that order
    /// is the index passed to `Dual::partial`.
    ///
    /// # Panics
    ///
    /// Panics if no active registry is available (outside a frozen `LabeledForwardTape` scope).
    pub fn gradient(&self) -> Vec<(String, f64)> {
        crate::labeled::forward_tape::with_active_registry(|r| {
            let r =
                r.expect("LabeledDual::gradient called outside a frozen LabeledForwardTape scope");
            // Reserve from the registry length up front; the iterator's size hint may be inexact.
            let mut out = Vec::with_capacity(r.len());
            out.extend(
                r.iter()
                    .enumerate()
                    .map(|(i, name)| (name.to_string(), self.inner.partial(i))),
            );
            out
        })
    }

    /// Borrows the wrapped [`Dual`].
    #[inline]
    pub fn inner(&self) -> &Dual {
        &self.inner
    }

    /// Exponential, delegated to `Dual::exp`; keeps the generation stamp.
    #[inline]
    pub fn exp(&self) -> Self {
        self.__rewrap(self.inner.exp())
    }

    /// Natural logarithm, delegated to `Dual::ln`; keeps the generation stamp.
    #[inline]
    pub fn ln(&self) -> Self {
        self.__rewrap(self.inner.ln())
    }

    /// Square root, delegated to `Dual::sqrt`; keeps the generation stamp.
    #[inline]
    pub fn sqrt(&self) -> Self {
        self.__rewrap(self.inner.sqrt())
    }

    /// Sine, delegated to `Dual::sin`; keeps the generation stamp.
    #[inline]
    pub fn sin(&self) -> Self {
        self.__rewrap(self.inner.sin())
    }

    /// Cosine, delegated to `Dual::cos`; keeps the generation stamp.
    #[inline]
    pub fn cos(&self) -> Self {
        self.__rewrap(self.inner.cos())
    }

    /// Tangent, delegated to `Dual::tan`; keeps the generation stamp.
    #[inline]
    pub fn tan(&self) -> Self {
        self.__rewrap(self.inner.tan())
    }
}
impl fmt::Display for LabeledDual {
    /// Formats the value as `LabeledDual(<real part>)`; partials are not shown.
    ///
    /// A precision in the format spec (e.g. `{:.3}`) is applied to the real part. Without one,
    /// the real part uses `f64`'s default `Display` output. Width, alignment and sign flags
    /// are not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match f.precision() {
            Some(prec) => write!(f, "LabeledDual({:.*})", prec, self.inner.real),
            None => write!(f, "LabeledDual({})", self.inner.real),
        }
    }
}
/// Implements one `core::ops` binary operator trait for `LabeledDual`.
///
/// `$trait` / `$method` name the trait and its method (e.g. `Add`, `add`). `$op` is the
/// operator token applied to the wrapped `Dual` values. Each invocation expands to six impls:
/// - `LabeledDual op LabeledDual` for all four owned/borrowed operand combinations. In debug
///   builds these first pass both operands' generation stamps to `forward_tape::check_gen`.
/// - `LabeledDual op f64` and `&LabeledDual op f64`, which perform no generation check.
///
/// Every result carries the left-hand operand's generation stamp. Impls with an `f64`
/// left-hand side are not generated by this macro.
macro_rules! __lbl_dual_binop {
($trait:ident, $method:ident, $op:tt) => {
// owned op owned: requires `Dual: $trait<Dual>`.
impl ::core::ops::$trait<LabeledDual> for LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: LabeledDual) -> LabeledDual {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual {
inner: self.inner $op rhs.inner,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
// borrowed op borrowed: operates on `&Dual` on both sides, so neither operand is consumed.
impl ::core::ops::$trait<&LabeledDual> for &LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: &LabeledDual) -> LabeledDual {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual {
inner: &self.inner $op &rhs.inner,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
// owned op borrowed: requires `Dual: $trait<&Dual>`.
impl ::core::ops::$trait<&LabeledDual> for LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: &LabeledDual) -> LabeledDual {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual {
inner: self.inner $op &rhs.inner,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
// borrowed op owned: requires `&Dual: $trait<Dual>`.
impl ::core::ops::$trait<LabeledDual> for &LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: LabeledDual) -> LabeledDual {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual {
inner: &self.inner $op rhs.inner,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
// owned op scalar: the f64 carries no stamp, so there is nothing to check.
impl ::core::ops::$trait<f64> for LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: f64) -> LabeledDual {
LabeledDual {
inner: self.inner $op rhs,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
// borrowed op scalar: requires `&Dual: $trait<f64>`.
impl ::core::ops::$trait<f64> for &LabeledDual {
type Output = LabeledDual;
#[inline]
fn $method(self, rhs: f64) -> LabeledDual {
LabeledDual {
inner: &self.inner $op rhs,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
};
}
// Additive operators: six impls each (LabeledDual/&LabeledDual with LabeledDual/&LabeledDual/f64).
__lbl_dual_binop! { Add, add, + }
__lbl_dual_binop! { Sub, sub, - }

// Multiplicative operators: same six-impl expansion.
__lbl_dual_binop! { Mul, mul, * }
__lbl_dual_binop! { Div, div, / }
impl ::core::ops::Neg for LabeledDual {
type Output = LabeledDual;
#[inline]
fn neg(self) -> LabeledDual {
LabeledDual {
inner: -self.inner,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl ::core::ops::Neg for &LabeledDual {
    type Output = LabeledDual;

    /// Negates a borrowed value without consuming it, using `Neg` for `&Dual`;
    /// the generation stamp is carried over.
    #[inline]
    fn neg(self) -> Self::Output {
        let negated = -&self.inner;
        LabeledDual {
            inner: negated,
            #[cfg(debug_assertions)]
            gen_id: self.gen_id,
        }
    }
}
impl ::core::ops::Add<LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn add(self, rhs: LabeledDual) -> LabeledDual {
LabeledDual {
inner: self + rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Add<&LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn add(self, rhs: &LabeledDual) -> LabeledDual {
LabeledDual {
inner: self + &rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Sub<LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn sub(self, rhs: LabeledDual) -> LabeledDual {
LabeledDual {
inner: self - rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Sub<&LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn sub(self, rhs: &LabeledDual) -> LabeledDual {
LabeledDual {
inner: self - &rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Mul<LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn mul(self, rhs: LabeledDual) -> LabeledDual {
LabeledDual {
inner: self * rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Mul<&LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn mul(self, rhs: &LabeledDual) -> LabeledDual {
LabeledDual {
inner: self * &rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Div<LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn div(self, rhs: LabeledDual) -> LabeledDual {
LabeledDual {
inner: self / rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Div<&LabeledDual> for f64 {
type Output = LabeledDual;
#[inline]
fn div(self, rhs: &LabeledDual) -> LabeledDual {
LabeledDual {
inner: self / &rhs.inner,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}