use crate::scalar::Scalar;
use std::fmt;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// Second-order forward-mode dual number.
///
/// Carries a value together with its first and second derivatives along a
/// single seeded direction (see `Dual2::variable`).
#[derive(Copy, Clone)]
pub struct Dual2<T: Scalar> {
    /// The value component.
    value: T,
    /// First derivative with respect to the seeded direction.
    d1: T,
    /// Second derivative with respect to the seeded direction.
    d2: T,
}
impl<T: Scalar> Dual2<T> {
    /// Builds a dual number from an explicit value, first derivative and
    /// second derivative.
    #[inline]
    pub fn new(value: T, d1: T, d2: T) -> Self {
        Dual2 { value, d1, d2 }
    }

    /// Builds a constant: both derivatives are zero.
    #[inline]
    pub fn constant(value: T) -> Self {
        Dual2 {
            value,
            d1: T::zero(),
            d2: T::zero(),
        }
    }

    /// Builds the independent variable at `value`: first derivative one,
    /// second derivative zero.
    #[inline]
    pub fn variable(value: T) -> Self {
        Dual2 {
            value,
            d1: T::one(),
            d2: T::zero(),
        }
    }

    /// Returns the value component.
    #[inline]
    pub fn value(&self) -> T {
        self.value
    }

    /// Returns the first-derivative component.
    #[inline]
    pub fn first_derivative(&self) -> T {
        self.d1
    }

    /// Returns the second-derivative component.
    #[inline]
    pub fn second_derivative(&self) -> T {
        self.d2
    }

    /// Second-order chain rule for an outer scalar function `g` applied to `self`.
    ///
    /// `g`, `gp` and `gpp` are `g(v)`, `g'(v)` and `g''(v)` evaluated at
    /// `v = self.value`. The result carries `(g o f)' = g' * f'` and
    /// `(g o f)'' = g'' * f'^2 + g' * f''`.
    #[inline]
    fn compose(self, g: T, gp: T, gpp: T) -> Dual2<T> {
        Dual2 {
            value: g,
            d1: gp * self.d1,
            d2: gpp * self.d1 * self.d1 + gp * self.d2,
        }
    }

    /// Raises `self` to the scalar power `n`.
    ///
    /// Uses `g(v) = v^n`, `g'(v) = n v^(n-1)` and `g''(v) = n (n-1) v^(n-2)`.
    /// For `n == 0` (a constant) and `n == 1` (linear), the coefficients that
    /// vanish identically are set to exactly zero rather than evaluated. At
    /// `v == 0` the unguarded form computes `0 * inf = NaN` and would poison
    /// the derivatives of a perfectly smooth function.
    pub fn powf(self, n: T) -> Dual2<T> {
        let v = self.value;
        let zero = T::zero();
        let one = T::one();
        let two = T::from(2.0).unwrap();
        let gp = if n == zero {
            zero
        } else {
            n * v.powf(n - one)
        };
        let gpp = if n == zero || n == one {
            zero
        } else {
            n * (n - one) * v.powf(n - two)
        };
        self.compose(v.powf(n), gp, gpp)
    }

    /// Exponential: `g = g' = g'' = exp(v)`.
    pub fn exp(self) -> Dual2<T> {
        let e = self.value.exp();
        self.compose(e, e, e)
    }

    /// Natural logarithm: `g' = 1/v`, `g'' = -1/v^2`.
    pub fn ln(self) -> Dual2<T> {
        let v = self.value;
        let inv = T::one() / v;
        self.compose(v.ln(), inv, -(inv * inv))
    }

    /// Square root: `g' = 1/(2 sqrt(v))`, `g'' = -1/(4 v sqrt(v))`.
    pub fn sqrt(self) -> Dual2<T> {
        let v = self.value;
        let s = v.sqrt();
        let half = T::from(0.5).unwrap();
        let quarter = T::from(0.25).unwrap();
        self.compose(s, half / s, -quarter / (s * v))
    }

    /// Sine: `g' = cos(v)`, `g'' = -sin(v)`.
    pub fn sin(self) -> Dual2<T> {
        let v = self.value;
        let (s, c) = (v.sin(), v.cos());
        self.compose(s, c, -s)
    }

    /// Cosine: `g' = -sin(v)`, `g'' = -cos(v)`.
    pub fn cos(self) -> Dual2<T> {
        let v = self.value;
        let (s, c) = (v.sin(), v.cos());
        self.compose(c, -s, -c)
    }

    /// Error function.
    ///
    /// The value comes from `crate::math::erf`. The derivatives are
    /// `g' = (2/sqrt(pi)) exp(-v^2)` and `g'' = -2 v g'`.
    pub fn erf(self) -> Dual2<T> {
        let v = self.value;
        let two = T::from(2.0).unwrap();
        let two_over_sqrt_pi = T::from(std::f64::consts::FRAC_2_SQRT_PI).unwrap();
        let gp = two_over_sqrt_pi * (-v * v).exp();
        let gpp = -two * v * gp;
        self.compose(crate::math::erf(v), gp, gpp)
    }

    /// Normal CDF as computed by `crate::math::norm_cdf`.
    ///
    /// The derivatives are `g' = norm_pdf(v)` and `g'' = -v norm_pdf(v)`.
    pub fn norm_cdf(self) -> Dual2<T> {
        let v = self.value;
        let gp = crate::math::norm_pdf(v);
        self.compose(crate::math::norm_cdf(v), gp, -v * gp)
    }

    /// Inverse normal CDF as computed by `crate::math::inv_norm_cdf`.
    ///
    /// With `r = inv_norm_cdf(v)`, the derivatives are `g' = 1 / norm_pdf(r)`
    /// and `g'' = r / norm_pdf(r)^2`.
    pub fn inv_norm_cdf(self) -> Dual2<T> {
        let r = crate::math::inv_norm_cdf(self.value);
        let gp = T::one() / crate::math::norm_pdf(r);
        self.compose(r, gp, r * gp * gp)
    }
}
impl<T: Scalar> Add for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn add(self, rhs: Self) -> Self {
Dual2 {
value: self.value + rhs.value,
d1: self.d1 + rhs.d1,
d2: self.d2 + rhs.d2,
}
}
}
impl<T: Scalar> Add<T> for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn add(self, rhs: T) -> Self {
Dual2 {
value: self.value + rhs,
d1: self.d1,
d2: self.d2,
}
}
}
impl Add<Dual2<f64>> for f64 {
type Output = Dual2<f64>;
#[inline]
fn add(self, rhs: Dual2<f64>) -> Dual2<f64> {
Dual2 {
value: self + rhs.value,
d1: rhs.d1,
d2: rhs.d2,
}
}
}
impl Add<Dual2<f32>> for f32 {
type Output = Dual2<f32>;
#[inline]
fn add(self, rhs: Dual2<f32>) -> Dual2<f32> {
Dual2 {
value: self + rhs.value,
d1: rhs.d1,
d2: rhs.d2,
}
}
}
impl<T: Scalar> Sub for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn sub(self, rhs: Self) -> Self {
Dual2 {
value: self.value - rhs.value,
d1: self.d1 - rhs.d1,
d2: self.d2 - rhs.d2,
}
}
}
impl<T: Scalar> Sub<T> for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn sub(self, rhs: T) -> Self {
Dual2 {
value: self.value - rhs,
d1: self.d1,
d2: self.d2,
}
}
}
impl Sub<Dual2<f64>> for f64 {
type Output = Dual2<f64>;
#[inline]
fn sub(self, rhs: Dual2<f64>) -> Dual2<f64> {
Dual2 {
value: self - rhs.value,
d1: -rhs.d1,
d2: -rhs.d2,
}
}
}
impl Sub<Dual2<f32>> for f32 {
type Output = Dual2<f32>;
#[inline]
fn sub(self, rhs: Dual2<f32>) -> Dual2<f32> {
Dual2 {
value: self - rhs.value,
d1: -rhs.d1,
d2: -rhs.d2,
}
}
}
impl<T: Scalar> Mul for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn mul(self, rhs: Self) -> Self {
let two = T::from(2.0).unwrap();
Dual2 {
value: self.value * rhs.value,
d1: self.d1 * rhs.value + self.value * rhs.d1,
d2: self.d2 * rhs.value + two * self.d1 * rhs.d1 + self.value * rhs.d2,
}
}
}
impl<T: Scalar> Mul<T> for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn mul(self, rhs: T) -> Self {
Dual2 {
value: self.value * rhs,
d1: self.d1 * rhs,
d2: self.d2 * rhs,
}
}
}
impl Mul<Dual2<f64>> for f64 {
type Output = Dual2<f64>;
#[inline]
fn mul(self, rhs: Dual2<f64>) -> Dual2<f64> {
Dual2 {
value: self * rhs.value,
d1: self * rhs.d1,
d2: self * rhs.d2,
}
}
}
impl Mul<Dual2<f32>> for f32 {
type Output = Dual2<f32>;
#[inline]
fn mul(self, rhs: Dual2<f32>) -> Dual2<f32> {
Dual2 {
value: self * rhs.value,
d1: self * rhs.d1,
d2: self * rhs.d2,
}
}
}
impl<T: Scalar> Div for Dual2<T> {
    type Output = Dual2<T>;
    /// Quotient rule to second order.
    ///
    /// With `q = a / b`, differentiating `a = q b` twice gives
    /// `q' = (a' - q b') / b` and `q'' = (a'' - 2 q' b' - q b'') / b`.
    /// The value is the true quotient `a / b`. Multiplying by a precomputed
    /// reciprocal is not correctly rounded: `49 * (1/49) != 1` in `f64`, so the
    /// primal value would disagree with plain scalar division.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let two = T::from(2.0).unwrap();
        let b = rhs.value;
        let q = self.value / b;
        let dq = (self.d1 - q * rhs.d1) / b;
        let ddq = (self.d2 - two * dq * rhs.d1 - q * rhs.d2) / b;
        Dual2 {
            value: q,
            d1: dq,
            d2: ddq,
        }
    }
}
impl<T: Scalar> Div<T> for Dual2<T> {
    type Output = Dual2<T>;
    /// Divides every component by the constant `rhs`.
    ///
    /// Each component is divided directly instead of being multiplied by
    /// `1 / rhs`. The reciprocal form is not correctly rounded (e.g.
    /// `49 * (1/49) != 1` in `f64`). Direct division also agrees with dividing
    /// by `Dual2::constant(rhs)`.
    #[inline]
    fn div(self, rhs: T) -> Self {
        Dual2 {
            value: self.value / rhs,
            d1: self.d1 / rhs,
            d2: self.d2 / rhs,
        }
    }
}
impl Div<Dual2<f64>> for f64 {
type Output = Dual2<f64>;
#[inline]
fn div(self, rhs: Dual2<f64>) -> Dual2<f64> {
Dual2::constant(self) / rhs
}
}
impl Div<Dual2<f32>> for f32 {
type Output = Dual2<f32>;
#[inline]
fn div(self, rhs: Dual2<f32>) -> Dual2<f32> {
Dual2::constant(self) / rhs
}
}
impl<T: Scalar> Neg for Dual2<T> {
type Output = Dual2<T>;
#[inline]
fn neg(self) -> Self {
Dual2 {
value: -self.value,
d1: -self.d1,
d2: -self.d2,
}
}
}
impl<T: Scalar> AddAssign for Dual2<T> {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl<T: Scalar> AddAssign<T> for Dual2<T> {
#[inline]
fn add_assign(&mut self, rhs: T) {
self.value = self.value + rhs;
}
}
impl<T: Scalar> SubAssign for Dual2<T> {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<T: Scalar> SubAssign<T> for Dual2<T> {
#[inline]
fn sub_assign(&mut self, rhs: T) {
self.value = self.value - rhs;
}
}
impl<T: Scalar> MulAssign for Dual2<T> {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl<T: Scalar> MulAssign<T> for Dual2<T> {
#[inline]
fn mul_assign(&mut self, rhs: T) {
*self = *self * rhs;
}
}
impl<T: Scalar> DivAssign for Dual2<T> {
#[inline]
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
impl<T: Scalar> DivAssign<T> for Dual2<T> {
#[inline]
fn div_assign(&mut self, rhs: T) {
*self = *self / rhs;
}
}
impl<T: Scalar> fmt::Display for Dual2<T> {
    /// Displays only the value component.
    ///
    /// The caller's formatter is forwarded unchanged, so width, precision,
    /// alignment and sign flags (e.g. `{:.3}`, `{:>10}`) apply just as they
    /// would to a bare `T`. Writing through `write!(f, "{}", ..)` would
    /// silently drop them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.value, f)
    }
}
impl<T: Scalar> fmt::Debug for Dual2<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Dual2(v={}, d1={}, d2={})",
self.value, self.d1, self.d2
)
}
}
impl<T: Scalar> Default for Dual2<T> {
    /// The constant zero: value and both derivatives are zero.
    fn default() -> Self {
        Dual2::new(T::zero(), T::zero(), T::zero())
    }
}
impl<T: Scalar> From<T> for Dual2<T> {
    /// Lifts a plain scalar into a constant dual number (zero derivatives).
    fn from(value: T) -> Self {
        let zero = T::zero();
        Dual2::new(value, zero, zero)
    }
}
/// A `Dual2` tagged with the registry slot of the variable it was seeded on.
///
/// Derivatives are queried by variable name through `first_derivative` /
/// `second_derivative`, which resolve the name against the active registry.
#[derive(Clone)]
pub struct NamedDual2<T: Scalar> {
    /// The underlying value / first / second derivative triple.
    pub(crate) inner: Dual2<T>,
    /// Registry index of the single seeded variable, or `None` if this value
    /// does not depend on any seeded variable.
    pub(crate) seeded: Option<usize>,
    /// Generation tag (debug builds only). It is passed to
    /// `crate::forward_tape::check_gen` when two `NamedDual2`s are combined.
    #[cfg(debug_assertions)]
    pub(crate) gen_id: u64,
}
impl<T: Scalar> NamedDual2<T> {
#[inline]
pub(crate) fn __from_parts(inner: Dual2<T>, seeded: Option<usize>) -> Self {
Self {
inner,
seeded,
#[cfg(debug_assertions)]
gen_id: crate::forward_tape::current_gen(),
}
}
#[inline]
pub fn value(&self) -> T {
self.inner.value()
}
pub fn first_derivative(&self, name: &str) -> T {
let idx = crate::forward_tape::with_active_registry(|r| {
let r = r.expect(
"NamedDual2::first_derivative called outside a frozen NamedForwardTape scope",
);
r.index_of(name).unwrap_or_else(|| {
panic!(
"NamedDual2::first_derivative: name {:?} not present in registry",
name
)
})
});
if self.seeded == Some(idx) {
self.inner.first_derivative()
} else {
T::zero()
}
}
pub fn second_derivative(&self, name: &str) -> T {
let idx = crate::forward_tape::with_active_registry(|r| {
let r = r.expect(
"NamedDual2::second_derivative called outside a frozen NamedForwardTape scope",
);
r.index_of(name).unwrap_or_else(|| {
panic!(
"NamedDual2::second_derivative: name {:?} not present in registry",
name
)
})
});
if self.seeded == Some(idx) {
self.inner.second_derivative()
} else {
T::zero()
}
}
#[inline]
pub fn inner(&self) -> &Dual2<T> {
&self.inner
}
#[inline]
pub fn exp(&self) -> Self {
Self {
inner: self.inner.exp(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn ln(&self) -> Self {
Self {
inner: self.inner.ln(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn sqrt(&self) -> Self {
Self {
inner: self.inner.sqrt(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn sin(&self) -> Self {
Self {
inner: self.inner.sin(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn cos(&self) -> Self {
Self {
inner: self.inner.cos(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn norm_cdf(&self) -> Self {
Self {
inner: self.inner.norm_cdf(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
#[inline]
pub fn inv_norm_cdf(&self) -> Self {
Self {
inner: self.inner.inv_norm_cdf(),
seeded: self.seeded,
#[cfg(debug_assertions)]
gen_id: self.gen_id,
}
}
}
impl<T: Scalar> fmt::Debug for NamedDual2<T> {
    /// Shows the inner components and the seed slot (the debug-only generation
    /// tag is omitted).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dual = &self.inner;
        let mut out = f.debug_struct("NamedDual2");
        out.field("value", &dual.value());
        out.field("first", &dual.first_derivative());
        out.field("second", &dual.second_derivative());
        out.field("seeded", &self.seeded);
        out.finish()
    }
}
impl<T: Scalar> fmt::Display for NamedDual2<T> {
    /// Shows only the value, wrapped as `NamedDual2(<value>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.inner.value();
        write!(f, "NamedDual2({})", value)
    }
}
/// Combines the seed slots of two operands of a binary operation.
///
/// An unseeded operand contributes nothing, so the other operand's seed wins.
/// Two equal seeds collapse to that seed. Two *different* seeds are not
/// representable, because a seeded `Dual2` tracks only one direction. Debug
/// builds panic in that case; release builds keep the left-hand seed.
#[inline]
pub(crate) fn merge_seeded(a: Option<usize>, b: Option<usize>) -> Option<usize> {
    if let (Some(x), Some(y)) = (a, b) {
        debug_assert!(
            x == y,
            "NamedDual2: operation between two differently-seeded variables; \
             seeded Dual2 supports only one active direction"
        );
    }
    a.or(b)
}
/// Generates every `NamedDual2` impl of one binary `core::ops` trait.
///
/// `__named_d2_binop!(Trait, method, op)` emits:
/// * `NamedDual2 op NamedDual2` for all four owned/borrowed operand
///   combinations. Seeds are combined with `merge_seeded`, and in debug builds
///   both `gen_id`s are passed to `crate::forward_tape::check_gen`.
/// * `NamedDual2 op T` for an owned or borrowed left operand, keeping the left
///   operand's seed.
///
/// The `@pair` / `@scalar` arms are internal; `[&]` selects a borrowed operand
/// and `[]` an owned one.
macro_rules! __named_d2_binop {
    (@pair $trait_:ident, $method:ident, $op:tt, [$($lhs_ref:tt)*], [$($rhs_ref:tt)*]) => {
        impl<T: Scalar> ::core::ops::$trait_<$($rhs_ref)* NamedDual2<T>> for $($lhs_ref)* NamedDual2<T> {
            type Output = NamedDual2<T>;
            #[inline]
            fn $method(self, rhs: $($rhs_ref)* NamedDual2<T>) -> NamedDual2<T> {
                #[cfg(debug_assertions)]
                crate::forward_tape::check_gen(self.gen_id, rhs.gen_id);
                NamedDual2 {
                    inner: self.inner $op rhs.inner,
                    seeded: merge_seeded(self.seeded, rhs.seeded),
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
    };
    (@scalar $trait_:ident, $method:ident, $op:tt, [$($lhs_ref:tt)*]) => {
        impl<T: Scalar> ::core::ops::$trait_<T> for $($lhs_ref)* NamedDual2<T> {
            type Output = NamedDual2<T>;
            #[inline]
            fn $method(self, rhs: T) -> NamedDual2<T> {
                NamedDual2 {
                    inner: self.inner $op rhs,
                    seeded: self.seeded,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
    };
    ($trait_:ident, $method:ident, $op:tt) => {
        __named_d2_binop!(@pair $trait_, $method, $op, [], []);
        __named_d2_binop!(@pair $trait_, $method, $op, [&], [&]);
        __named_d2_binop!(@pair $trait_, $method, $op, [], [&]);
        __named_d2_binop!(@pair $trait_, $method, $op, [&], []);
        __named_d2_binop!(@scalar $trait_, $method, $op, []);
        __named_d2_binop!(@scalar $trait_, $method, $op, [&]);
    };
}
__named_d2_binop!(Add, add, +);
__named_d2_binop!(Sub, sub, -);
__named_d2_binop!(Mul, mul, *);
__named_d2_binop!(Div, div, /);
impl<T: Scalar> ::core::ops::Neg for NamedDual2<T> {
    type Output = NamedDual2<T>;
    /// Delegates to the borrowed implementation.
    #[inline]
    fn neg(self) -> NamedDual2<T> {
        -&self
    }
}
impl<T: Scalar> ::core::ops::Neg for &NamedDual2<T> {
    type Output = NamedDual2<T>;
    /// Negates the inner dual number. The seed slot and (in debug builds) the
    /// generation tag are carried over unchanged.
    #[inline]
    fn neg(self) -> NamedDual2<T> {
        let negated = -self.inner;
        NamedDual2 {
            inner: negated,
            seeded: self.seeded,
            #[cfg(debug_assertions)]
            gen_id: self.gen_id,
        }
    }
}
/// Generates `scalar op NamedDual2<scalar>` impls (`+ - * /`) for one concrete
/// scalar type, with both owned and borrowed right-hand operands.
///
/// The result keeps the right operand's seed slot and (in debug builds) its
/// generation tag. The `@op` arm is internal; `[&]` selects a borrowed
/// right-hand operand and `[]` an owned one.
macro_rules! __named_d2_scalar_lhs {
    (@op $scalar:ty, $trait_:ident, $method:ident, $op:tt, [$($rhs_ref:tt)*]) => {
        impl ::core::ops::$trait_<$($rhs_ref)* NamedDual2<$scalar>> for $scalar {
            type Output = NamedDual2<$scalar>;
            #[inline]
            fn $method(self, rhs: $($rhs_ref)* NamedDual2<$scalar>) -> NamedDual2<$scalar> {
                NamedDual2 {
                    inner: self $op rhs.inner,
                    seeded: rhs.seeded,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
    };
    ($scalar:ty) => {
        __named_d2_scalar_lhs!(@op $scalar, Add, add, +, []);
        __named_d2_scalar_lhs!(@op $scalar, Add, add, +, [&]);
        __named_d2_scalar_lhs!(@op $scalar, Sub, sub, -, []);
        __named_d2_scalar_lhs!(@op $scalar, Sub, sub, -, [&]);
        __named_d2_scalar_lhs!(@op $scalar, Mul, mul, *, []);
        __named_d2_scalar_lhs!(@op $scalar, Mul, mul, *, [&]);
        __named_d2_scalar_lhs!(@op $scalar, Div, div, /, []);
        __named_d2_scalar_lhs!(@op $scalar, Div, div, /, [&]);
    };
}
__named_d2_scalar_lhs!(f64);
__named_d2_scalar_lhs!(f32);