use super::generated::types::Dual;
use crate::scalar::Float;
use std::ops::Div;
impl<T: Float> Dual<T> {
    /// Returns the multiplicative identity `1 + 0ε`.
    #[inline]
    pub fn one() -> Self {
        Self::new(T::one(), T::zero())
    }

    /// Lifts a plain value into a dual number with a zero tangent, i.e. a
    /// constant with respect to the differentiation variable.
    #[inline]
    pub fn from_real(real: T) -> Self {
        Self::new(real, T::zero())
    }

    /// Returns the infinitesimal unit `0 + 1ε`.
    #[inline]
    pub fn epsilon() -> Self {
        Self::new(T::zero(), T::one())
    }

    /// Seeds `x` as the independent variable `x + 1ε`.
    ///
    /// The tangent of any expression built from it is then the derivative of
    /// that expression with respect to `x`.
    #[inline]
    pub fn variable(x: T) -> Self {
        Self::new(x, T::one())
    }

    /// Evaluates `f` at `x` and returns `(f(x), f'(x))`.
    ///
    /// `f` receives `x` seeded as the independent variable. The value and the
    /// derivative are read from the real and dual parts of its result.
    #[inline]
    pub fn differentiate<F>(x: T, f: F) -> (T, T)
    where
        F: FnOnce(Dual<T>) -> Dual<T>,
    {
        let result = f(Self::variable(x));
        (result.real(), result.dual())
    }

    /// Returns `true` when the tangent is exactly zero, i.e. `self` is a constant.
    ///
    /// Functions whose derivative is infinite at a finite point (`sqrt`, `ln`,
    /// `powf`, `powi` at zero) use this guard. Without it, IEEE `0 * inf` and
    /// `0 / 0` would give constants a NaN tangent, and that NaN would then
    /// contaminate every derivative computed from the result.
    #[inline]
    fn has_zero_tangent(&self) -> bool {
        self.dual() == T::zero()
    }

    /// Exponential: `d(e^u) = u' · e^u`.
    #[inline]
    pub fn exp(&self) -> Self {
        let exp_real = self.real().exp();
        Self::new(exp_real, self.dual() * exp_real)
    }

    /// Natural logarithm: `d(ln u) = u' / u`.
    ///
    /// A constant argument yields a zero tangent, even at `u = 0` where
    /// `u' / u` would be `0 / 0`.
    #[inline]
    pub fn ln(&self) -> Self {
        if self.has_zero_tangent() {
            return Self::from_real(self.real().ln());
        }
        Self::new(self.real().ln(), self.dual() / self.real())
    }

    /// Sine: `d(sin u) = u' · cos u`.
    #[inline]
    pub fn sin(&self) -> Self {
        Self::new(self.real().sin(), self.dual() * self.real().cos())
    }

    /// Cosine: `d(cos u) = -u' · sin u`.
    #[inline]
    pub fn cos(&self) -> Self {
        Self::new(self.real().cos(), -self.dual() * self.real().sin())
    }

    /// Tangent: `d(tan u) = u' · sec² u = u' / cos² u`.
    #[inline]
    pub fn tan(&self) -> Self {
        let cos_a = self.real().cos();
        let sec_sq = T::one() / (cos_a * cos_a);
        Self::new(self.real().tan(), self.dual() * sec_sq)
    }

    /// Square root: `d(√u) = u' / (2√u)`.
    ///
    /// A constant argument yields a zero tangent, even at `u = 0`. A
    /// non-constant argument at `u = 0` yields an infinite tangent (vertical
    /// tangent line).
    #[inline]
    pub fn sqrt(&self) -> Self {
        let sqrt_real = self.real().sqrt();
        if self.has_zero_tangent() {
            return Self::from_real(sqrt_real);
        }
        Self::new(sqrt_real, self.dual() / (T::TWO * sqrt_real))
    }

    /// Integer power: `d(uⁿ) = n · uⁿ⁻¹ · u'`.
    ///
    /// `n == 0` returns exactly [`Dual::one`]. A constant argument yields a
    /// zero tangent, even when `uⁿ⁻¹` is infinite (`u = 0`, `n < 1`).
    #[inline]
    pub fn powi(&self, n: i32) -> Self {
        if n == 0 {
            return Self::one();
        }
        let a_pow_n = self.real().powi(n);
        if self.has_zero_tangent() {
            return Self::from_real(a_pow_n);
        }
        let a_pow_nm1 = self.real().powi(n - 1);
        Self::new(a_pow_n, T::from_f64(f64::from(n)) * a_pow_nm1 * self.dual())
    }

    /// Real power with a constant exponent: `d(uᵖ) = p · uᵖ⁻¹ · u'`.
    ///
    /// `p` is a plain value, so no derivative with respect to the exponent is
    /// propagated. A constant argument yields a zero tangent, even when `uᵖ⁻¹`
    /// is infinite (`u = 0`, `p < 1`).
    #[inline]
    pub fn powf(&self, p: T) -> Self {
        let a_pow_p = self.real().powf(p);
        if self.has_zero_tangent() {
            return Self::from_real(a_pow_p);
        }
        let a_pow_pm1 = self.real().powf(p - T::one());
        Self::new(a_pow_p, p * a_pow_pm1 * self.dual())
    }

    /// Absolute value.
    ///
    /// - Positive inputs are returned unchanged.
    /// - Negative inputs have both parts negated (`d|u| = -u'`).
    /// - At `u = 0`, where `|u|` is not differentiable, the tangent is passed
    ///   through unchanged (the right-hand derivative). The real part is
    ///   normalised to `+0`, so a `-0.0` input never yields a negative result.
    #[inline]
    pub fn abs(&self) -> Self {
        let a = self.real();
        if a > T::zero() {
            *self
        } else if a == T::zero() {
            // `-0.0 == 0.0`, so this branch also catches negative zero.
            Self::new(T::zero(), self.dual())
        } else {
            // Negative values land here. NaN fails both comparisons above and
            // lands here too.
            Self::new(-a, -self.dual())
        }
    }
}
impl<T: Float> Div for Dual<T> {
    type Output = Self;

    /// Divides two dual numbers using the quotient rule.
    ///
    /// `(a + a'ε) / (c + c'ε) = q + ((a' - q·c') / c) ε`, where `q = a / c`.
    ///
    /// This is algebraically identical to the textbook `(a'c - ac') / c²`, but
    /// it never forms `c²`. For very large or very small `|c|` (beyond roughly
    /// `1e154` / `1e-154` for `f64`), `c²` overflows to infinity or underflows
    /// to zero, which turns a finite derivative into `0`, `inf`, or NaN.
    ///
    /// Dividing by a dual whose real part is zero follows IEEE semantics and
    /// yields infinite or NaN components.
    #[inline]
    fn div(self, other: Self) -> Self::Output {
        let c = other.real();
        let quotient = self.real() / c;
        Self::new(quotient, (self.dual() - quotient * other.dual()) / c)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the dual-number arithmetic and elementary functions.
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::{E, FRAC_PI_4};

    #[test]
    fn test_one() {
        let one = Dual::<f64>::one();
        assert_eq!(one.real(), 1.0);
        assert_eq!(one.dual(), 0.0);
    }

    #[test]
    fn test_from_real_and_epsilon() {
        let c = Dual::from_real(5.0);
        assert_eq!(c.real(), 5.0);
        assert_eq!(c.dual(), 0.0);
        let e = Dual::<f64>::epsilon();
        assert_eq!(e.real(), 0.0);
        assert_eq!(e.dual(), 1.0);
    }

    #[test]
    fn test_variable() {
        let d = Dual::variable(3.0);
        assert_eq!(d.real(), 3.0);
        assert_eq!(d.dual(), 1.0);
    }

    #[test]
    fn test_differentiate_square() {
        let (value, derivative) = Dual::differentiate(3.0, |d| d * d);
        assert_eq!(value, 9.0);
        assert_eq!(derivative, 6.0);
    }

    #[test]
    fn test_differentiate_cubic() {
        let (value, derivative) = Dual::differentiate(2.0, |d| d * d * d);
        assert_eq!(value, 8.0);
        assert_eq!(derivative, 12.0);
    }

    #[test]
    fn test_exp() {
        let d = Dual::variable(0.0);
        let result = d.exp();
        assert_relative_eq!(result.real(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 1.0, epsilon = 1e-10);
        let d2 = Dual::variable(1.0);
        let result2 = d2.exp();
        assert_relative_eq!(result2.real(), E, epsilon = 1e-10);
        assert_relative_eq!(result2.dual(), E, epsilon = 1e-10);
    }

    #[test]
    fn test_ln() {
        let d = Dual::variable(E);
        let result = d.ln();
        assert_relative_eq!(result.real(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 1.0 / E, epsilon = 1e-10);
    }

    #[test]
    fn test_sin() {
        let d = Dual::variable(0.0);
        let result = d.sin();
        assert_relative_eq!(result.real(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cos() {
        let d = Dual::variable(0.0);
        let result = d.cos();
        assert_relative_eq!(result.real(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_tan() {
        let d = Dual::variable(FRAC_PI_4);
        let result = d.tan();
        assert_relative_eq!(result.real(), 1.0, epsilon = 1e-10);
        // sec²(π/4) = 2
        assert_relative_eq!(result.dual(), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sqrt() {
        let d = Dual::variable(4.0);
        let result = d.sqrt();
        assert_relative_eq!(result.real(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 0.25, epsilon = 1e-10);
    }

    #[test]
    fn test_powi() {
        let d = Dual::variable(2.0);
        let result = d.powi(3);
        assert_relative_eq!(result.real(), 8.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 12.0, epsilon = 1e-10);
    }

    #[test]
    fn test_powi_zero_exponent() {
        // x⁰ is the constant 1, so its derivative is 0.
        let result = Dual::variable(5.0).powi(0);
        assert_eq!(result.real(), 1.0);
        assert_eq!(result.dual(), 0.0);
    }

    #[test]
    fn test_powf() {
        let d = Dual::variable(4.0);
        let result = d.powf(1.5);
        assert_relative_eq!(result.real(), 8.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_division() {
        let d1 = Dual::variable(4.0);
        let d2 = Dual::from_real(2.0);
        let result = d1 / d2;
        assert_relative_eq!(result.real(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_division_variable_denominator() {
        // d/dx (1 / x) = -1 / x², which is -0.25 at x = 2.
        let result = Dual::from_real(1.0) / Dual::variable(2.0);
        assert_relative_eq!(result.real(), 0.5, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), -0.25, epsilon = 1e-10);
    }

    #[test]
    fn test_chain_rule() {
        let (_, derivative) = Dual::differentiate(1.0, |d| (d * d).sin());
        assert_relative_eq!(derivative, 2.0 * 1.0_f64.cos(), epsilon = 1e-10);
    }

    #[test]
    fn test_abs() {
        let d = Dual::variable(-3.0);
        let result = d.abs();
        assert_relative_eq!(result.real(), 3.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_abs_positive() {
        let result = Dual::variable(2.0).abs();
        assert_relative_eq!(result.real(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(result.dual(), 1.0, epsilon = 1e-10);
    }
}