use std::fmt;
use crate::dual2::Dual2;
use crate::traits::Scalar;
/// A `Dual2<T>` tagged with the registry index of the variable it was seeded with.
///
/// `seeded` records which named variable the wrapped dual's derivative components belong to
/// (`None` when no variable seeds this value). Derivatives are queried by variable name through
/// `first_derivative` / `second_derivative`.
#[derive(Clone)]
pub struct LabeledDual2<T: Scalar> {
    // The wrapped dual carrying the value and its first/second derivative components.
    pub(super) inner: Dual2<T>,
    // Registry index of the seeding variable, or `None` for an unseeded value.
    pub(super) seeded: Option<usize>,
    // Debug-only generation tag: set from `forward_tape::current_gen()` in `__from_parts` and
    // handed to `forward_tape::check_gen` by the binary operators.
    #[cfg(debug_assertions)]
    pub(super) gen_id: u64,
}
impl<T: Scalar> LabeledDual2<T> {
    /// Assembles a labeled dual from an already-computed `Dual2` and its seed index.
    ///
    /// In debug builds the result is stamped with `forward_tape::current_gen()`.
    #[inline]
    pub(crate) fn __from_parts(inner: Dual2<T>, seeded: Option<usize>) -> Self {
        #[cfg(debug_assertions)]
        let gen_id = crate::labeled::forward_tape::current_gen();
        Self {
            inner,
            seeded,
            #[cfg(debug_assertions)]
            gen_id,
        }
    }

    /// Returns the value component of the wrapped `Dual2`.
    #[inline]
    pub fn value(&self) -> T {
        self.inner.value()
    }

    /// Returns the first derivative with respect to the variable called `name`.
    ///
    /// `name` is resolved to an index through the active registry. When that index equals this
    /// value's seed, the wrapped dual's first derivative is returned; in every other case
    /// (including an unseeded value) the result is `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `with_active_registry` supplies no registry, or if `name` is not in it.
    pub fn first_derivative(&self, name: &str) -> T {
        use crate::labeled::forward_tape::with_active_registry;

        let wanted = with_active_registry(|registry| {
            let registry = registry.expect(
                "LabeledDual2::first_derivative called outside a frozen LabeledForwardTape scope",
            );
            match registry.index_of(name) {
                Some(index) => index,
                None => panic!(
                    "LabeledDual2::first_derivative: name {:?} not present in registry",
                    name
                ),
            }
        });
        match self.seeded {
            Some(active) if active == wanted => self.inner.first_derivative(),
            _ => T::zero(),
        }
    }

    /// Returns the second derivative with respect to the variable called `name`.
    ///
    /// Resolution works exactly like [`Self::first_derivative`]: a match against this value's
    /// seed yields the wrapped dual's second derivative, anything else yields `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `with_active_registry` supplies no registry, or if `name` is not in it.
    pub fn second_derivative(&self, name: &str) -> T {
        use crate::labeled::forward_tape::with_active_registry;

        let wanted = with_active_registry(|registry| {
            let registry = registry.expect(
                "LabeledDual2::second_derivative called outside a frozen LabeledForwardTape scope",
            );
            match registry.index_of(name) {
                Some(index) => index,
                None => panic!(
                    "LabeledDual2::second_derivative: name {:?} not present in registry",
                    name
                ),
            }
        });
        match self.seeded {
            Some(active) if active == wanted => self.inner.second_derivative(),
            _ => T::zero(),
        }
    }

    /// Borrows the wrapped `Dual2`.
    #[inline]
    pub fn inner(&self) -> &Dual2<T> {
        &self.inner
    }

    /// Re-labels `inner` with this value's seed (and, in debug builds, its generation tag).
    #[inline]
    fn rewrap(&self, inner: Dual2<T>) -> Self {
        Self {
            inner,
            seeded: self.seeded,
            #[cfg(debug_assertions)]
            gen_id: self.gen_id,
        }
    }

    /// Applies `Dual2::exp` to the wrapped dual, keeping the seed.
    #[inline]
    pub fn exp(&self) -> Self {
        self.rewrap(self.inner.exp())
    }

    /// Applies `Dual2::ln` to the wrapped dual, keeping the seed.
    #[inline]
    pub fn ln(&self) -> Self {
        self.rewrap(self.inner.ln())
    }

    /// Applies `Dual2::sqrt` to the wrapped dual, keeping the seed.
    #[inline]
    pub fn sqrt(&self) -> Self {
        self.rewrap(self.inner.sqrt())
    }

    /// Applies `Dual2::sin` to the wrapped dual, keeping the seed.
    #[inline]
    pub fn sin(&self) -> Self {
        self.rewrap(self.inner.sin())
    }

    /// Applies `Dual2::cos` to the wrapped dual, keeping the seed.
    #[inline]
    pub fn cos(&self) -> Self {
        self.rewrap(self.inner.cos())
    }
}
impl<T: Scalar> fmt::Debug for LabeledDual2<T> {
    // Shows the wrapped dual's value, first and second derivative, then the seed index.
    // The debug-only generation tag is deliberately left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dual = &self.inner;
        let mut builder = f.debug_struct("LabeledDual2");
        builder.field("value", &dual.value());
        builder.field("first", &dual.first_derivative());
        builder.field("second", &dual.second_derivative());
        builder.field("seeded", &self.seeded);
        builder.finish()
    }
}
impl<T: Scalar> fmt::Display for LabeledDual2<T> {
    // User-facing form: only the value component, wrapped in the type name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.inner.value();
        write!(f, "LabeledDual2({})", value)
    }
}
/// Combines the seed indices of the two operands of a binary operation.
///
/// An unseeded operand (`None`) defers to the other operand. Two operands seeded with the same
/// index keep that index. Two different indices are a usage error, because a seeded `Dual2`
/// supports only one active direction.
///
/// # Panics
///
/// In debug builds, panics when the operands carry different seed indices. In release builds
/// that case is not checked and `a` (the left operand's seed) is returned.
#[inline]
pub(super) fn merge_seeded(a: Option<usize>, b: Option<usize>) -> Option<usize> {
    match (a, b) {
        (Some(lhs), Some(rhs)) => {
            // NOTE(review): release builds silently keep the left seed here. Confirm that is
            // acceptable, since the combined derivative components are then likely meaningless.
            debug_assert!(
                lhs == rhs,
                "LabeledDual2: operation between two differently-seeded variables; \
                 seeded Dual2 supports only one active direction"
            );
            a
        }
        (Some(_), None) => a,
        (None, _) => b,
    }
}
macro_rules! __lbld2_binop {
($trait_:ident, $method:ident, $op:tt) => {
impl<T: Scalar> ::core::ops::$trait_<LabeledDual2<T>> for LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: LabeledDual2<T>) -> LabeledDual2<T> {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual2 {
inner: self.inner $op rhs.inner,
seeded: merge_seeded(self.seeded, rhs.seeded),
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::$trait_<&LabeledDual2<T>> for &LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: &LabeledDual2<T>) -> LabeledDual2<T> {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual2 {
inner: self.inner $op rhs.inner,
seeded: merge_seeded(self.seeded, rhs.seeded),
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::$trait_<&LabeledDual2<T>> for LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: &LabeledDual2<T>) -> LabeledDual2<T> {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual2 {
inner: self.inner $op rhs.inner,
seeded: merge_seeded(self.seeded, rhs.seeded),
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::$trait_<LabeledDual2<T>> for &LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: LabeledDual2<T>) -> LabeledDual2<T> {
#[cfg(debug_assertions)]
crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
LabeledDual2 {
inner: self.inner $op rhs.inner,
seeded: merge_seeded(self.seeded, rhs.seeded),
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::$trait_<T> for LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: T) -> LabeledDual2<T> {
LabeledDual2 {
inner: self.inner $op rhs,
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::$trait_<T> for &LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn $method(self, rhs: T) -> LabeledDual2<T> {
LabeledDual2 {
inner: self.inner $op rhs,
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
};
}
__lbld2_binop!(Add, add, +);
__lbld2_binop!(Sub, sub, -);
__lbld2_binop!(Mul, mul, *);
__lbld2_binop!(Div, div, /);
impl<T: Scalar> ::core::ops::Neg for LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn neg(self) -> LabeledDual2<T> {
LabeledDual2 {
inner: -self.inner,
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> ::core::ops::Neg for &LabeledDual2<T> {
type Output = LabeledDual2<T>;
#[inline]
fn neg(self) -> LabeledDual2<T> {
LabeledDual2 {
inner: -self.inner,
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
macro_rules! __lbld2_scalar_lhs {
($scalar:ty) => {
impl ::core::ops::Add<LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn add(self, rhs: LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self + rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Add<&LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn add(self, rhs: &LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self + rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Sub<LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn sub(self, rhs: LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self - rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Sub<&LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn sub(self, rhs: &LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self - rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Mul<LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn mul(self, rhs: LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self * rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Mul<&LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn mul(self, rhs: &LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self * rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Div<LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn div(self, rhs: LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self / rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
impl ::core::ops::Div<&LabeledDual2<$scalar>> for $scalar {
type Output = LabeledDual2<$scalar>;
#[inline]
fn div(self, rhs: &LabeledDual2<$scalar>) -> LabeledDual2<$scalar> {
LabeledDual2 {
inner: self / rhs.inner,
seeded: rhs.seeded,
#[cfg(debug_assertions)]
gen_id: rhs.gen_id,
}
}
}
};
}
__lbld2_scalar_lhs!(f64);
__lbld2_scalar_lhs!(f32);