use num_traits::{real::Real, One, Zero};
use std::{
fmt::{Debug, Display, Formatter, LowerExp, Result},
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
/// Scalar type usable as the real part of a dual number.
///
/// Any [`Real`] type that also supports `+=`, `*=`, `-=` and `Debug`
/// qualifies automatically through the blanket impl that follows.
pub trait Value
where
    Self: Real + AddAssign + MulAssign + SubAssign + Debug,
{
}

impl<T: Real + AddAssign + MulAssign + SubAssign + Debug> Value for T {}
/// Type that can serve as the dual (derivative) part of a dual number over
/// the scalar `V`.
///
/// It must be cloneable, comparable for equality and have a zero. It must
/// also support in-place addition, negation, and scaling by a `V` (both in
/// place and by value).
pub trait Grad<V: Value>:
    Clone + PartialEq + Zero + AddAssign + Neg<Output = Self> + Mul<V, Output = Self> + MulAssign<V>
{
}
// Blanket impl: every type satisfying exactly the supertraits of `Grad` is a
// `Grad`. No extra bounds (such as ordering) are demanded beyond what the
// trait itself declares. That way an unordered type that meets the trait's
// contract (e.g. a vector-like gradient) is still accepted.
impl<V: Value, G> Grad<V> for G where
    G: Clone
        + AddAssign
        + MulAssign<V>
        + Mul<V, Output = Self>
        + Neg<Output = Self>
        + PartialEq
        + Zero
{
}
/// A dual number: a real value paired with a derivative ("dual") part.
///
/// The default methods propagate derivatives with the chain and product
/// rules (forward-mode automatic differentiation). For an elementary function
/// `f`, the result's dual part is the input's dual part scaled by
/// `f'(value)`.
pub trait Dual
where
    Self: Sized
        + Clone
        + PartialEq
        + Add<Output = Self>
        + Mul<Output = Self>
        + Sub<Output = Self>
        + Div<Output = Self>
        + AddAssign
        + DivAssign
        + MulAssign
        + SubAssign
        + Neg,
{
    /// Scalar type of the real part.
    type Value: Value;

    /// Returns the real part.
    fn value(&self) -> &Self::Value;

    /// Returns the real part mutably.
    fn value_mut(&mut self) -> &mut Self::Value;

    /// Type of the dual (derivative) part; it can be scaled by `Self::Value`.
    type Grad: Grad<Self::Value>;

    /// Consumes `self` and returns its `(value, dual)` parts.
    fn decompose(self) -> (Self::Value, Self::Grad);

    /// Returns the dual part.
    fn dual(&self) -> &Self::Grad;

    /// Returns the dual part mutably.
    fn dual_mut(&mut self) -> &mut Self::Grad;

    /// Builds a dual number from its real and dual parts.
    fn new(value: Self::Value, grad: Self::Grad) -> Self;

    /// Builds a dual number with the given value and a zero dual part.
    fn parameter(value: Self::Value) -> Self {
        Self::new(value, Self::Grad::zero())
    }

    /// Applies a scalar function through the chain rule.
    ///
    /// `func` receives the current value `x` and must return `(f(x), f'(x))`.
    /// The result has value `f(x)` and dual part `dual * f'(x)`.
    #[must_use]
    fn chain(&self, func: impl Fn(&Self::Value) -> (Self::Value, Self::Value)) -> Self {
        let (f, df) = func(self.value());
        Self::new(f, self.dual().clone() * df)
    }

    /// Raises to a constant real power: `d/dx x^e = e * x^(e - 1)`.
    #[must_use]
    #[inline]
    fn powf(&self, exp: Self::Value) -> Self {
        self.chain(|x: &Self::Value| (x.powf(exp), x.powf(exp - Self::Value::one()) * exp))
    }

    /// Sine: `d/dx sin(x) = cos(x)`.
    #[must_use]
    fn sin(&self) -> Self {
        // Build only the sine result. Going through `Dual::sin_cos` would also
        // construct (and then discard) the cosine dual number.
        self.chain(|x| x.sin_cos())
    }

    /// Cosine: `d/dx cos(x) = -sin(x)`.
    #[must_use]
    fn cos(&self) -> Self {
        self.chain(|x| {
            let (sin, cos) = x.sin_cos();
            (cos, -sin)
        })
    }

    /// Sine and cosine together, evaluating the scalar `sin_cos` only once.
    #[must_use]
    fn sin_cos(&self) -> (Self, Self) {
        let (sin, cos) = self.value().sin_cos();
        (self.chain(|_| (sin, cos)), self.chain(|_| (cos, -sin)))
    }

    /// Natural exponential: `d/dx e^x = e^x`.
    #[must_use]
    fn exp(&self) -> Self {
        let real = self.value().exp();
        self.chain(|_| (real, real))
    }

    /// Natural logarithm: `d/dx ln(x) = 1 / x`.
    #[must_use]
    fn ln(&self) -> Self {
        self.chain(|x| (x.ln(), x.recip()))
    }

    /// Reciprocal: `d/dx (1 / x) = -1 / x^2`.
    #[must_use]
    fn recip(&self) -> Self {
        // Use the scalar `recip` directly instead of `powf(-1)`. For the
        // primitive floats it is a single correctly rounded division, and it
        // avoids two `powf` evaluations. `div_assign_impl` benefits as well.
        self.chain(|x| {
            let inv = x.recip();
            (inv, -(inv * inv))
        })
    }

    /// Absolute value. The derivative is taken to be `signum(x)`, so at
    /// `x == 0` it is whatever the scalar `signum` returns for zero.
    #[must_use]
    fn abs(&self) -> Self {
        self.chain(|x| (x.abs(), x.signum()))
    }

    /// Sign of the value; its derivative is taken to be zero everywhere.
    #[must_use]
    fn signum(&self) -> Self {
        self.chain(|x| (x.signum(), Self::Value::zero()))
    }

    /// Returns `self + rhs` (see [`Dual::add_assign_impl`]).
    #[must_use]
    fn add_impl(&self, rhs: &Self) -> Self {
        let mut output = self.clone();
        let _ = output.add_assign_impl(rhs);
        output
    }

    /// Returns `self * rhs` (see [`Dual::mul_assign_impl`]).
    #[must_use]
    fn mul_impl(&self, rhs: &Self) -> Self {
        let mut output = self.clone();
        let _ = output.mul_assign_impl(rhs);
        output
    }

    /// Returns `self - rhs` (see [`Dual::sub_assign_impl`]).
    #[must_use]
    fn sub_impl(&self, rhs: &Self) -> Self {
        let mut output = self.clone();
        let _ = output.sub_assign_impl(rhs);
        output
    }

    /// Returns `self / rhs` (see [`Dual::div_assign_impl`]).
    #[must_use]
    fn div_impl(&self, rhs: &Self) -> Self {
        let mut output = self.clone();
        let _ = output.div_assign_impl(rhs);
        output
    }

    /// In-place sum rule: real parts add and dual parts add.
    fn add_assign_impl(&mut self, rhs: &Self) -> &mut Self {
        *self.value_mut() += *rhs.value();
        *self.dual_mut() += rhs.dual().clone();
        self
    }

    /// In-place product rule: `(a, a') * (b, b') = (a * b, a' * b + a * b')`.
    fn mul_assign_impl(&mut self, rhs: &Self) -> &mut Self {
        // Save the old real part of `self`; the `a * b'` term needs it after
        // the value has been overwritten.
        let value_local = *self.value();
        *self.value_mut() *= *rhs.value();
        *self.dual_mut() *= *rhs.value();
        *self.dual_mut() += rhs.dual().clone() * value_local;
        self
    }

    /// In-place subtraction, implemented as addition of `-rhs`.
    fn sub_assign_impl(&mut self, rhs: &Self) -> &mut Self {
        self.add_assign_impl(&rhs.neg_impl())
    }

    /// In-place division, implemented as multiplication by `rhs.recip()`.
    fn div_assign_impl(&mut self, rhs: &Self) -> &mut Self {
        self.mul_assign_impl(&rhs.recip())
    }

    /// Negates both the real and the dual part.
    #[must_use]
    fn neg_impl(&self) -> Self {
        Self::new(self.value().neg(), self.dual().clone().neg())
    }

    /// Passes `self` by value to `func` and returns its result, which keeps a
    /// method chain going.
    fn map<Output>(self, func: impl Fn(Self) -> Output) -> Output {
        func(self)
    }
}
/// Formats a dual number as `{value}{±dual}∆` via `Display`.
///
/// The dual part is written with the `+` flag, so types that honor it (such
/// as the primitive floats) print an explicit sign and the output reads as a
/// sum. If the caller supplied a precision (e.g. `{:.3}`), it is forwarded to
/// both parts. Other formatter options (width, fill, alignment) are not
/// applied.
pub(crate) fn display_impl<V, G, D>(dual_number: &D, f: &mut Formatter<'_>) -> Result
where
    V: Value + Display,
    G: Grad<V> + Display,
    D: Dual<Value = V, Grad = G>,
{
    match f.precision() {
        Some(precision) => write!(
            f,
            "{:.*}{:+.*}∆",
            precision,
            dual_number.value(),
            precision,
            dual_number.dual()
        ),
        None => write!(f, "{}{:+}∆", dual_number.value(), dual_number.dual()),
    }
}
/// Formats a dual number as `{value:e}{±dual:e}∆` via `LowerExp`.
///
/// The dual part is written with the `+` flag, matching `display_impl`.
/// Without it a positive dual part runs straight into the value's exponent
/// (e.g. `1e02e0∆` instead of `1e0+2e0∆`). If the caller supplied a precision
/// (e.g. `{:.3e}`), it is forwarded to both parts. Other formatter options
/// are not applied.
pub(crate) fn lower_exp_impl<V, G, D>(dual_number: &D, f: &mut Formatter<'_>) -> Result
where
    V: Value + LowerExp,
    G: Grad<V> + LowerExp,
    D: Dual<Value = V, Grad = G>,
{
    match f.precision() {
        Some(precision) => write!(
            f,
            "{:.*e}{:+.*e}∆",
            precision,
            dual_number.value(),
            precision,
            dual_number.dual()
        ),
        None => write!(f, "{:e}{:+e}∆", dual_number.value(), dual_number.dual()),
    }
}