use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use std::fmt;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// Converts an `f64` value into the generic float type `T`.
///
/// # Errors
///
/// Returns [`NumRs2Error::NumericalError`] when `T::from(val)` yields `None`,
/// i.e. when the conversion into `T` fails.
pub(super) fn float_const<T: Float>(val: f64) -> Result<T> {
    match T::from(val) {
        Some(converted) => Ok(converted),
        None => Err(NumRs2Error::NumericalError(format!(
            "Cannot represent {} in target float type",
            val
        ))),
    }
}
/// Hyper-dual number `real + eps1*e1 + eps2*e2 + eps1eps2*e1e2`.
///
/// The multiplication defined for this type drops every term containing
/// `e1*e1` or `e2*e2`, so the `eps1` / `eps2` components carry first-order
/// information and `eps1eps2` carries the mixed second-order information.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HyperDual<T> {
/// Real (value) component.
real: T,
/// Coefficient of the first infinitesimal unit `e1`.
eps1: T,
/// Coefficient of the second infinitesimal unit `e2`.
eps2: T,
/// Coefficient of the mixed unit `e1e2`.
eps1eps2: T,
}
impl<T: Float> HyperDual<T> {
    /// Builds a hyper-dual number from its four components.
    #[inline]
    pub fn new(real: T, eps1: T, eps2: T, eps1eps2: T) -> Self {
        Self {
            real,
            eps1,
            eps2,
            eps1eps2,
        }
    }

    /// Builds a hyper-dual number whose three dual components are all zero.
    #[inline]
    pub fn constant(value: T) -> Self {
        Self::new(value, T::zero(), T::zero(), T::zero())
    }

    /// Builds a seeded variable with real part `value`.
    ///
    /// `eps1` is set to one when `is_dir_i` is true and `eps2` is set to one
    /// when `is_dir_j` is true; otherwise they are zero. The mixed component
    /// always starts at zero.
    // NOTE(review): the flag names suggest this seeds the (i, j) entry of a
    // Hessian evaluation — confirm against the callers.
    #[inline]
    pub fn make_variable(value: T, is_dir_i: bool, is_dir_j: bool) -> Self {
        Self::new(
            value,
            if is_dir_i { T::one() } else { T::zero() },
            if is_dir_j { T::one() } else { T::zero() },
            T::zero(),
        )
    }

    /// Returns the real component.
    #[inline]
    pub fn real(&self) -> T {
        self.real
    }

    /// Returns the coefficient of `e1`.
    #[inline]
    pub fn eps1(&self) -> T {
        self.eps1
    }

    /// Returns the coefficient of `e2`.
    #[inline]
    pub fn eps2(&self) -> T {
        self.eps2
    }

    /// Returns the coefficient of the mixed unit `e1e2`.
    #[inline]
    pub fn eps1eps2(&self) -> T {
        self.eps1eps2
    }

    /// Multiplies every component by the plain scalar `s`.
    #[inline]
    pub fn scale(self, s: T) -> Self {
        Self::new(
            self.real * s,
            self.eps1 * s,
            self.eps2 * s,
            self.eps1eps2 * s,
        )
    }

    /// Lifts a scalar function through the second-order chain rule.
    ///
    /// `value`, `d1` and `d2` are `f(a)`, `f'(a)` and `f''(a)` evaluated at
    /// `a = self.real`. The result is
    /// `f(a) + f'(a)*eps1 e1 + f'(a)*eps2 e2 + (f''(a)*eps1*eps2 + f'(a)*eps1eps2) e1e2`.
    #[inline]
    fn chain_rule(self, value: T, d1: T, d2: T) -> Self {
        Self::new(
            value,
            d1 * self.eps1,
            d1 * self.eps2,
            d2 * self.eps1 * self.eps2 + d1 * self.eps1eps2,
        )
    }

    /// Raises `self` to the real power `n`.
    ///
    /// Uses `f'(a) = n * a^(n-1)` and `f''(a) = n * (n-1) * a^(n-2)`.
    /// When a polynomial coefficient (`n` or `n * (n-1)`) is exactly zero the
    /// corresponding derivative is returned as zero without evaluating the
    /// power, so `x^0` and `x^1` stay finite at `x == 0` instead of producing
    /// `0 * inf = NaN`.
    pub fn powf(self, n: T) -> Self {
        let a = self.real;
        let one = T::one();
        let zero = T::zero();

        let first_coeff = n;
        let second_coeff = n * (n - one);

        // A zero coefficient means the derivative vanishes identically; skip
        // the power, which can be infinite at a == 0 for negative exponents.
        let gp = if first_coeff == zero {
            zero
        } else {
            first_coeff * a.powf(n - one)
        };
        let gpp = if second_coeff == zero {
            zero
        } else {
            second_coeff * a.powf(n - one - one)
        };

        self.chain_rule(a.powf(n), gp, gpp)
    }

    /// Raises `self` to the integer power `n` by repeated squaring.
    ///
    /// Negative exponents are evaluated as `1 / self^|n|`. The magnitude is
    /// taken with `unsigned_abs`, so `n == i32::MIN` is handled without the
    /// overflow that negating it as an `i32` would cause.
    pub fn powi(self, n: i32) -> Self {
        let magnitude = self.powi_unsigned(n.unsigned_abs());
        if n < 0 {
            Self::constant(T::one()) / magnitude
        } else {
            magnitude
        }
    }

    /// Computes `self^n` for a non-negative exponent by repeated squaring.
    fn powi_unsigned(self, n: u32) -> Self {
        match n {
            0 => Self::constant(T::one()),
            1 => self,
            _ => {
                let half = self.powi_unsigned(n / 2);
                if n % 2 == 0 {
                    half * half
                } else {
                    half * half * self
                }
            }
        }
    }

    /// Exponential function.
    ///
    /// Since `f = f' = f'' = exp(a)`, the common factor is pulled out of the
    /// mixed term rather than going through the generic chain-rule helper.
    pub fn exp(self) -> Self {
        let ea = self.real.exp();
        Self::new(
            ea,
            ea * self.eps1,
            ea * self.eps2,
            ea * (self.eps1 * self.eps2 + self.eps1eps2),
        )
    }

    /// Natural logarithm, with `f'(a) = 1/a` and `f''(a) = -1/a^2`.
    pub fn ln(self) -> Self {
        let a = self.real;
        let gp = T::one() / a;
        let gpp = -T::one() / (a * a);
        self.chain_rule(a.ln(), gp, gpp)
    }

    /// Sine, with `f'(a) = cos(a)` and `f''(a) = -sin(a)`.
    pub fn sin(self) -> Self {
        let sin_a = self.real.sin();
        let cos_a = self.real.cos();
        self.chain_rule(sin_a, cos_a, -sin_a)
    }

    /// Cosine, with `f'(a) = -sin(a)` and `f''(a) = -cos(a)`.
    pub fn cos(self) -> Self {
        let sin_a = self.real.sin();
        let cos_a = self.real.cos();
        self.chain_rule(cos_a, -sin_a, -cos_a)
    }

    /// Tangent, with `f'(a) = sec^2(a)` and `f''(a) = 2 * tan(a) * sec^2(a)`.
    pub fn tan(self) -> Self {
        let tan_a = self.real.tan();
        let cos_a = self.real.cos();
        let sec2 = T::one() / (cos_a * cos_a);
        let two = T::one() + T::one();
        let gp = sec2;
        let gpp = two * tan_a * sec2;
        self.chain_rule(tan_a, gp, gpp)
    }

    /// Hyperbolic sine, with `f'(a) = cosh(a)` and `f''(a) = sinh(a)`.
    pub fn sinh(self) -> Self {
        let sinh_a = self.real.sinh();
        let cosh_a = self.real.cosh();
        self.chain_rule(sinh_a, cosh_a, sinh_a)
    }

    /// Hyperbolic cosine, with `f'(a) = sinh(a)` and `f''(a) = cosh(a)`.
    pub fn cosh(self) -> Self {
        let sinh_a = self.real.sinh();
        let cosh_a = self.real.cosh();
        self.chain_rule(cosh_a, sinh_a, cosh_a)
    }

    /// Hyperbolic tangent, with `f'(a) = 1 - tanh^2(a)` and
    /// `f''(a) = -2 * tanh(a) * (1 - tanh^2(a))`.
    pub fn tanh(self) -> Self {
        let tanh_a = self.real.tanh();
        let sech2 = T::one() - tanh_a * tanh_a;
        let two = T::one() + T::one();
        let gp = sech2;
        let gpp = -two * tanh_a * sech2;
        self.chain_rule(tanh_a, gp, gpp)
    }

    /// Square root, with `f'(a) = 1 / (2*sqrt(a))` and
    /// `f''(a) = -1 / (4 * a * sqrt(a))`.
    ///
    /// Both derivative factors divide by `sqrt(a)`, so they are not finite
    /// when the real part is zero.
    pub fn sqrt(self) -> Self {
        let a = self.real;
        let sqrt_a = a.sqrt();
        let two = T::one() + T::one();
        let four = two * two;
        let gp = T::one() / (two * sqrt_a);
        let gpp = -T::one() / (four * a * sqrt_a);
        self.chain_rule(sqrt_a, gp, gpp)
    }

    /// Absolute value: returns `self` when the real part is `>= 0`, and the
    /// negation of every component otherwise.
    pub fn abs(self) -> Self {
        if self.real >= T::zero() {
            self
        } else {
            -self
        }
    }

    /// Logistic sigmoid `s(a) = 1 / (1 + exp(-a))`, with `s' = s * (1 - s)`
    /// and `s'' = s' * (1 - 2*s)`.
    pub fn sigmoid(self) -> Self {
        let a = self.real;
        let s = T::one() / (T::one() + (-a).exp());
        let two = T::one() + T::one();
        let gp = s * (T::one() - s);
        let gpp = gp * (T::one() - two * s);
        self.chain_rule(s, gp, gpp)
    }

    /// Rectified linear unit: returns `self` when the real part is strictly
    /// positive, otherwise the constant zero (all dual components zero).
    pub fn relu(self) -> Self {
        if self.real > T::zero() {
            self
        } else {
            Self::constant(T::zero())
        }
    }
}
impl<T: Float> Add for HyperDual<T> {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self::Output {
Self::new(
self.real + rhs.real,
self.eps1 + rhs.eps1,
self.eps2 + rhs.eps2,
self.eps1eps2 + rhs.eps1eps2,
)
}
}
impl<T: Float> Sub for HyperDual<T> {
type Output = Self;
#[inline]
fn sub(self, rhs: Self) -> Self::Output {
Self::new(
self.real - rhs.real,
self.eps1 - rhs.eps1,
self.eps2 - rhs.eps2,
self.eps1eps2 - rhs.eps1eps2,
)
}
}
impl<T: Float> Mul for HyperDual<T> {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self::Output {
Self::new(
self.real * rhs.real,
self.real * rhs.eps1 + self.eps1 * rhs.real,
self.real * rhs.eps2 + self.eps2 * rhs.real,
self.real * rhs.eps1eps2
+ self.eps1 * rhs.eps2
+ self.eps2 * rhs.eps1
+ self.eps1eps2 * rhs.real,
)
}
}
impl<T: Float> Div for HyperDual<T> {
type Output = Self;
#[inline]
fn div(self, rhs: Self) -> Self::Output {
let (a, b, c, d) = (self.real, self.eps1, self.eps2, self.eps1eps2);
let (e, f, g, h) = (rhs.real, rhs.eps1, rhs.eps2, rhs.eps1eps2);
let e2 = e * e;
let e3 = e2 * e;
let two = T::one() + T::one();
Self::new(
a / e,
(b * e - a * f) / e2,
(c * e - a * g) / e2,
(d * e2 - a * h * e - b * g * e - c * f * e + two * a * f * g) / e3,
)
}
}
impl<T: Float> Neg for HyperDual<T> {
type Output = Self;
#[inline]
fn neg(self) -> Self::Output {
Self::new(-self.real, -self.eps1, -self.eps2, -self.eps1eps2)
}
}
impl<T: Float + fmt::Display> fmt::Display for HyperDual<T> {
    /// Writes the number as `real + eps1*e1 + eps2*e2 + eps1eps2*e1e2`,
    /// formatting each component with its own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            real,
            eps1,
            eps2,
            eps1eps2,
        } = self;
        write!(
            f,
            "{} + {}*e1 + {}*e2 + {}*e1e2",
            real, eps1, eps2, eps1eps2
        )
    }
}