use crate::{DualNum, DualNumFloat};
use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero};
use std::convert::Infallible;
use std::fmt;
use std::iter::{Product, Sum};
use std::marker::PhantomData;
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};
/// A hyper-dual number `re + eps1·ε1 + eps2·ε2 + eps1eps2·ε1ε2`.
///
/// The multiplication rule in this module drops every ε1·ε1 and ε2·ε2 term and keeps the
/// mixed ε1ε2 term. Seeding `eps1` and `eps2` with one therefore yields first derivatives
/// in two directions plus the mixed second derivative.
///
/// `T` is the component type and `F` is the underlying scalar type associated with `T`
/// through `DualNum<F>`.
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub struct HyperDual<T: DualNum<F>, F> {
    /// Real (value) part.
    pub re: T,
    /// Coefficient of ε1.
    pub eps1: T,
    /// Coefficient of ε2.
    pub eps2: T,
    /// Coefficient of the mixed term ε1ε2.
    pub eps1eps2: T,
    // Zero-sized marker so that `F` is a used type parameter of the struct.
    f: PhantomData<F>,
}
/// Hyper-dual number with `f32` components.
pub type HyperDual32 = HyperDual<f32, f32>;
/// Hyper-dual number with `f64` components.
pub type HyperDual64 = HyperDual<f64, f64>;
impl<T: DualNum<F>, F> HyperDual<T, F> {
#[inline]
pub fn new(re: T, eps1: T, eps2: T, eps1eps2: T) -> Self {
Self {
re,
eps1,
eps2,
eps1eps2,
f: PhantomData,
}
}
}
impl<T: DualNum<F>, F> HyperDual<T, F> {
    /// Returns `self` with the ε1 coefficient set to one, seeding the first
    /// differentiation direction. All other components are kept unchanged.
    #[inline]
    pub fn derivative1(self) -> Self {
        Self {
            eps1: T::one(),
            ..self
        }
    }

    /// Returns `self` with the ε2 coefficient set to one, seeding the second
    /// differentiation direction. All other components are kept unchanged.
    #[inline]
    pub fn derivative2(self) -> Self {
        Self {
            eps2: T::one(),
            ..self
        }
    }
}
impl<T: DualNum<F>, F> HyperDual<T, F> {
#[inline]
pub fn from_re(re: T) -> Self {
Self::new(re, T::zero(), T::zero(), T::zero())
}
}
/// Evaluates `g` at `(x, y)` together with its first partial derivatives and the
/// mixed second partial derivative.
///
/// `x` is seeded in the ε1 direction and `y` in the ε2 direction. The returned tuple is
/// `(g, ∂g/∂x, ∂g/∂y, ∂²g/∂x∂y)`, read off the ε-coefficients of the result.
///
/// This is the infallible counterpart of [`try_second_partial_derivative`]. It never
/// panics, because the wrapped closure's error type is [`Infallible`].
pub fn second_partial_derivative<G, T: DualNum<F>, F>(g: G, x: T, y: T) -> (T, T, T, T)
where
    G: FnOnce(HyperDual<T, F>, HyperDual<T, F>) -> HyperDual<T, F>,
{
    try_second_partial_derivative(|lhs, rhs| Ok::<_, Infallible>(g(lhs, rhs)), x, y)
        // `Infallible` has no values, so this arm can never be reached.
        .unwrap_or_else(|never| match never {})
}
/// Fallible variant of [`second_partial_derivative`].
///
/// `x` is seeded in the ε1 direction and `y` in the ε2 direction before calling `g`. On
/// success it returns `(g, ∂g/∂x, ∂g/∂y, ∂²g/∂x∂y)`, taken from the components of `g`'s
/// result.
///
/// # Errors
///
/// Returns the error produced by `g`, unchanged.
pub fn try_second_partial_derivative<G, T: DualNum<F>, F, E>(
    g: G,
    x: T,
    y: T,
) -> Result<(T, T, T, T), E>
where
    G: FnOnce(HyperDual<T, F>, HyperDual<T, F>) -> Result<HyperDual<T, F>, E>,
{
    let seeded_x = HyperDual::from_re(x).derivative1();
    let seeded_y = HyperDual::from_re(y).derivative2();
    let out = g(seeded_x, seeded_y)?;
    Ok((out.re, out.eps1, out.eps2, out.eps1eps2))
}
impl<T: DualNum<F>, F: Float> HyperDual<T, F> {
    /// Applies a scalar function to `self` using the second-order chain rule.
    ///
    /// The result is
    /// `f0 + f1·eps1·ε1 + f1·eps2·ε2 + (f1·eps1eps2 + f2·eps1·eps2)·ε1ε2`.
    /// For this to equal `f(self)`, callers pass `f0 = f(re)`, `f1 = f'(re)` and
    /// `f2 = f''(re)`.
    #[inline]
    fn chain_rule(&self, f0: T, f1: T, f2: T) -> Self {
        let d1 = self.eps1.clone() * f1.clone();
        let d2 = self.eps2.clone() * f1.clone();
        // Curvature contribution: product of the two first-order parts, scaled by f''.
        let curvature = self.eps1.clone() * self.eps2.clone() * f2;
        let d12 = self.eps1eps2.clone() * f1 + curvature;
        Self::new(f0, d1, d2, d12)
    }
}
impl<'a, 'b, T: DualNum<F>, F: Float> Mul<&'a HyperDual<T, F>> for &'b HyperDual<T, F> {
    type Output = HyperDual<T, F>;

    /// Hyper-dual product, following the product rule with ε1² = ε2² = 0.
    ///
    /// The mixed ε1ε2 coefficient collects the two second-order terms and the two
    /// cross products of the first-order parts.
    #[inline]
    fn mul(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
        let (lhs, rhs) = (self, other);

        let re = lhs.re.clone() * rhs.re.clone();
        let eps1 = rhs.eps1.clone() * lhs.re.clone() + lhs.eps1.clone() * rhs.re.clone();
        let eps2 = rhs.eps2.clone() * lhs.re.clone() + lhs.eps2.clone() * rhs.re.clone();
        let eps1eps2 = rhs.eps1eps2.clone() * lhs.re.clone()
            + lhs.eps1.clone() * rhs.eps2.clone()
            + rhs.eps1.clone() * lhs.eps2.clone()
            + lhs.eps1eps2.clone() * rhs.re.clone();

        HyperDual::new(re, eps1, eps2, eps1eps2)
    }
}
impl<'a, 'b, T: DualNum<F>, F: Float> Div<&'a HyperDual<T, F>> for &'b HyperDual<T, F> {
    type Output = HyperDual<T, F>;

    /// Hyper-dual quotient `self / other`.
    ///
    /// Equivalent to `self · (1/other)`, where `1/other` is expanded with the
    /// derivatives of `1/b`: `-1/b²` (first) and `2/b³` (second).
    #[inline]
    fn div(self, other: &HyperDual<T, F>) -> HyperDual<T, F> {
        // Powers of the reciprocal of the divisor's real part: 1/b and 1/b².
        let inv = other.re.recip();
        let inv2 = inv.clone() * &inv;

        let re = self.re.clone() * &inv;
        let eps1 = (self.eps1.clone() * other.re.clone() - other.eps1.clone() * self.re.clone())
            * inv2.clone();
        let eps2 = (self.eps2.clone() * other.re.clone() - other.eps2.clone() * self.re.clone())
            * inv2.clone();

        // Terms that get scaled by 1/b².
        let mixed = other.eps1eps2.clone() * self.re.clone()
            + self.eps1.clone() * other.eps2.clone()
            + other.eps1.clone() * self.eps2.clone();
        // 2·a/b³, which comes from the second derivative of 1/b.
        let curvature = (T::one() + T::one()) * self.re.clone() * inv2.clone() * inv.clone();
        let eps1eps2 = self.eps1eps2.clone() * inv - mixed * inv2
            + other.eps1.clone() * other.eps2.clone() * curvature;

        HyperDual::new(re, eps1, eps2, eps1eps2)
    }
}
impl<T: DualNum<F>, F: fmt::Display> fmt::Display for HyperDual<T, F> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{} + {}ε1 + {}ε2 + {}ε1ε2",
self.re, self.eps1, self.eps2, self.eps1eps2
)
}
}
// Macro-generated impls for `HyperDual`, parameterized by its infinitesimal fields.
// The macros are defined elsewhere in the crate. Presumably `impl_second_derivatives!`
// provides the elementary functions (via `chain_rule`) and `impl_dual!` provides the
// remaining operator/num-traits impls — TODO confirm against the macro definitions.
impl_second_derivatives!(HyperDual, [eps1, eps2, eps1eps2]);
impl_dual!(HyperDual, [eps1, eps2, eps1eps2]);