use core::fmt;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use numra_core::Scalar;
/// A forward-mode dual number `val + eps·ε`, where `ε² = 0`.
///
/// `val` holds the function value and `eps` the derivative with respect to
/// whichever input was seeded as the variable; arithmetic on `Dual` carries the
/// chain rule along in `eps`.
#[derive(Clone, Copy)]
pub struct Dual<S: Scalar> {
    /// Real (value) component.
    val: S,
    /// Infinitesimal (derivative) component.
    eps: S,
}
impl<S: Scalar> Dual<S> {
    /// Builds `value + derivative·ε` from explicit components.
    #[inline]
    pub fn new(value: S, derivative: S) -> Self {
        Self {
            val: value,
            eps: derivative,
        }
    }

    /// Builds a constant: a value whose derivative component is zero.
    #[inline]
    pub fn constant(value: S) -> Self {
        Self::new(value, S::ZERO)
    }

    /// Builds the independent variable: a value seeded with derivative one, so
    /// the `eps` of any expression built from it is `d(expr)/d(variable)`.
    #[inline]
    pub fn variable(value: S) -> Self {
        Self::new(value, S::ONE)
    }

    /// Returns the real (value) component.
    #[inline]
    pub fn value(&self) -> S {
        self.val
    }

    /// Returns the derivative (ε) component.
    #[inline]
    pub fn deriv(&self) -> S {
        self.eps
    }

    /// One chain-rule step for the tangent `eps`: returns `rule(eps)`, except
    /// that a zero tangent is returned as exactly zero without calling `rule`.
    ///
    /// This keeps constants constant at points where the local derivative is
    /// infinite or undefined (`sqrt`/`ln` at 0, `asin`/`acos` at ±1, …);
    /// evaluating the rule there would produce `0 · ∞` or `0 / 0`, i.e. NaN.
    #[inline]
    fn propagate(eps: S, rule: impl FnOnce(S) -> S) -> S {
        if eps == S::ZERO {
            S::ZERO
        } else {
            rule(eps)
        }
    }

    /// Sine: `d sin(x) = cos(x)·dx`.
    #[inline]
    pub fn sin(self) -> Self {
        Self {
            val: self.val.sin(),
            eps: self.eps * self.val.cos(),
        }
    }

    /// Cosine: `d cos(x) = -sin(x)·dx`.
    #[inline]
    pub fn cos(self) -> Self {
        Self {
            val: self.val.cos(),
            eps: -self.eps * self.val.sin(),
        }
    }

    /// Tangent: `d tan(x) = dx / cos²(x)`.
    #[inline]
    pub fn tan(self) -> Self {
        let c = self.val.cos();
        Self {
            val: self.val.tan(),
            eps: self.eps / (c * c),
        }
    }

    /// Exponential: `d eˣ = eˣ·dx` (the exponential is evaluated once).
    #[inline]
    pub fn exp(self) -> Self {
        let e = self.val.exp();
        Self {
            val: e,
            eps: self.eps * e,
        }
    }

    /// Natural logarithm: `d ln(x) = dx / x`.
    ///
    /// A constant stays constant at `x = 0` (derivative exactly zero, not NaN).
    #[inline]
    pub fn ln(self) -> Self {
        Self {
            val: self.val.ln(),
            eps: Self::propagate(self.eps, |e| e / self.val),
        }
    }

    /// Square root: `d √x = dx / (2√x)`.
    ///
    /// A constant stays constant at `x = 0` (derivative exactly zero, not NaN).
    #[inline]
    pub fn sqrt(self) -> Self {
        let root = self.val.sqrt();
        Self {
            val: root,
            eps: Self::propagate(self.eps, |e| e / (S::TWO * root)),
        }
    }

    /// Absolute value: `d|x| = signum(x)·dx`.
    ///
    /// `|x|` is not differentiable at 0; this returns `signum(0)·dx` there, which
    /// depends on `Scalar::signum` (std floats give `+1` for `+0.0` and `-1` for
    /// `-0.0` — confirm for other `Scalar` impls).
    #[inline]
    pub fn abs(self) -> Self {
        Self {
            val: self.val.abs(),
            eps: self.eps * self.val.signum(),
        }
    }

    /// Raises to a constant real power: `d xⁿ = n·xⁿ⁻¹·dx`.
    ///
    /// The derivative is exactly zero when `n == 0` (the result is constant) or
    /// when `self` is a constant, instead of `0 · 0⁻¹ = NaN` at `x = 0`.
    #[inline]
    pub fn powf(self, n: S) -> Self {
        let eps = if n == S::ZERO {
            S::ZERO
        } else {
            Self::propagate(self.eps, |e| e * n * self.val.powf(n - S::ONE))
        };
        Self {
            val: self.val.powf(n),
            eps,
        }
    }

    /// Raises to a dual power: `d(xʸ) = y·xʸ⁻¹·dx + xʸ·ln(x)·dy`.
    ///
    /// Each partial is only evaluated when its tangent is non-zero. A negative
    /// base with a constant exponent is therefore differentiable (as far as
    /// `powf` itself is), and a zero base gives no NaN.
    #[inline]
    pub fn powf_dual(self, n: Self) -> Self {
        let val = self.val.powf(n.val);
        // ∂/∂x xʸ = y·xʸ⁻¹, written without dividing by x so x = 0 and x < 0 work.
        let from_base = if n.val == S::ZERO {
            S::ZERO
        } else {
            Self::propagate(self.eps, |e| n.val * self.val.powf(n.val - S::ONE) * e)
        };
        // ∂/∂y xʸ = xʸ·ln x. ln x is NaN for x < 0, so only evaluate it when the
        // exponent varies. When xʸ is exactly zero the term vanishes (xʸ·ln x → 0
        // as x → 0⁺), which avoids 0·(−∞) = NaN.
        let from_exponent = if val == S::ZERO {
            S::ZERO
        } else {
            Self::propagate(n.eps, |e| val * self.val.ln() * e)
        };
        Self {
            val,
            eps: from_base + from_exponent,
        }
    }

    /// Raises to an integer power: `d xⁿ = n·xⁿ⁻¹·dx`.
    ///
    /// The derivative is exactly zero for `n == 0` or a constant `self`, and
    /// `n == i32::MIN` is handled without overflowing `n - 1`.
    #[inline]
    pub fn powi(self, n: i32) -> Self {
        let val = self.val.powi(n);
        let eps = if n == 0 {
            S::ZERO
        } else {
            Self::propagate(self.eps, |e| {
                // xⁿ⁻¹; `n - 1` overflows for i32::MIN, so use xⁿ / x there instead.
                let lower = match n.checked_sub(1) {
                    Some(m) => self.val.powi(m),
                    None => val / self.val,
                };
                e * S::from_i32(n) * lower
            })
        };
        Self { val, eps }
    }

    /// Arcsine: `d asin(x) = dx / √(1 − x²)`.
    ///
    /// A constant stays constant at `x = ±1` (derivative exactly zero, not NaN).
    #[inline]
    pub fn asin(self) -> Self {
        Self {
            val: self.val.asin(),
            eps: Self::propagate(self.eps, |e| e / (S::ONE - self.val * self.val).sqrt()),
        }
    }

    /// Arccosine: `d acos(x) = −dx / √(1 − x²)`.
    ///
    /// A constant stays constant at `x = ±1` (derivative exactly zero, not NaN).
    #[inline]
    pub fn acos(self) -> Self {
        Self {
            val: self.val.acos(),
            eps: Self::propagate(self.eps, |e| -e / (S::ONE - self.val * self.val).sqrt()),
        }
    }

    /// Arctangent: `d atan(x) = dx / (1 + x²)`.
    #[inline]
    pub fn atan(self) -> Self {
        Self {
            val: self.val.atan(),
            eps: self.eps / (S::ONE + self.val * self.val),
        }
    }

    /// Hyperbolic sine: `d sinh(x) = cosh(x)·dx`.
    #[inline]
    pub fn sinh(self) -> Self {
        Self {
            val: self.val.sinh(),
            eps: self.eps * self.val.cosh(),
        }
    }

    /// Hyperbolic cosine: `d cosh(x) = sinh(x)·dx`.
    #[inline]
    pub fn cosh(self) -> Self {
        Self {
            val: self.val.cosh(),
            eps: self.eps * self.val.sinh(),
        }
    }

    /// Hyperbolic tangent: `d tanh(x) = (1 − tanh²(x))·dx` (tanh evaluated once).
    #[inline]
    pub fn tanh(self) -> Self {
        let t = self.val.tanh();
        Self {
            val: t,
            eps: self.eps * (S::ONE - t * t),
        }
    }
}
/// Sum rule: values and derivatives add component-wise.
impl<S: Scalar> Add for Dual<S> {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self {
            val: self.val + rhs.val,
            eps: self.eps + rhs.eps,
        }
    }
}

/// Difference rule: values and derivatives subtract component-wise.
impl<S: Scalar> Sub for Dual<S> {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Self {
            val: self.val - rhs.val,
            eps: self.eps - rhs.eps,
        }
    }
}

/// Product rule: `(ab)' = a'b + ab'` (the `ε²` term vanishes).
impl<S: Scalar> Mul for Dual<S> {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        Self {
            val: self.val * rhs.val,
            eps: self.eps * rhs.val + self.val * rhs.eps,
        }
    }
}

/// Quotient rule, evaluated as `(a/b)' = (a' − (a/b)·b') / b`.
impl<S: Scalar> Div for Dual<S> {
    type Output = Self;
    /// Algebraically equal to the textbook `(a'b − ab') / b²`, but never forms
    /// `b²`. That square overflows to infinity (or underflows to zero) once `|b|`
    /// exceeds ~1e154 (or drops below ~1e-154) for `f64`, which would silently
    /// zero out or blow up the derivative.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let quotient = self.val / rhs.val;
        Self {
            val: quotient,
            eps: (self.eps - quotient * rhs.eps) / rhs.val,
        }
    }
}

/// Negation flips the sign of both components.
impl<S: Scalar> Neg for Dual<S> {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self {
        Self {
            val: -self.val,
            eps: -self.eps,
        }
    }
}
impl<S: Scalar> AddAssign for Dual<S> {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl<S: Scalar> SubAssign for Dual<S> {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<S: Scalar> MulAssign for Dual<S> {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl<S: Scalar> DivAssign for Dual<S> {
#[inline]
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
impl<S: Scalar> fmt::Debug for Dual<S> {
    /// Renders as `Dual { val: …, eps: … }`, showing both components.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { val, eps } = self;
        let mut builder = f.debug_struct("Dual");
        builder.field("val", val);
        builder.field("eps", eps);
        builder.finish()
    }
}

impl<S: Scalar> fmt::Display for Dual<S> {
    /// Renders in algebraic form `<val> + <eps>e`, where `e` stands for ε.
    /// Formatter flags (width, precision) of the caller are not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { val, eps } = self;
        write!(f, "{} + {}e", val, eps)
    }
}
impl<S: Scalar> PartialEq for Dual<S> {
    /// Two duals are equal when their values are equal; the derivative parts
    /// are deliberately ignored.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.val, &other.val)
    }
}

impl<S: Scalar> PartialOrd for Dual<S> {
    /// Orders by value alone, consistent with `eq`; yields `None` exactly when
    /// `S::partial_cmp` does for the two values.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        let (lhs, rhs) = (&self.val, &other.val);
        lhs.partial_cmp(rhs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that are exact up to a few ulps.
    const TOL: f64 = 1e-12;
    /// Looser tolerance for results routed through `powf` or trig of π-fractions.
    const LOOSE: f64 = 1e-10;

    /// Fails unless `|actual - expected| < tol`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {} (tolerance {})",
            expected,
            actual,
            tol
        );
    }

    /// Checks the value and then the derivative of `d` against one tolerance.
    fn assert_dual(d: Dual<f64>, value: f64, deriv: f64, tol: f64) {
        assert_close(d.value(), value, tol);
        assert_close(d.deriv(), deriv, tol);
    }

    #[test]
    fn test_dual_arithmetic() {
        let a = Dual::new(2.0_f64, 1.0);
        let b = Dual::new(3.0, 4.0);
        assert_dual(a + b, 5.0, 5.0, TOL);
        assert_dual(a * b, 6.0, 11.0, TOL);
        let quotient = Dual::new(6.0, 1.0) / Dual::constant(3.0);
        assert_dual(quotient, 2.0, 1.0 / 3.0, TOL);
    }

    #[test]
    fn test_dual_transcendental() {
        let x = Dual::variable(0.0_f64);
        assert_dual(x.sin(), 0.0, 1.0, TOL);
        assert_dual(x.exp(), 1.0, 1.0, TOL);
    }

    #[test]
    fn test_dual_chain_rule() {
        // d/dx sin(x²) = 2x·cos(x²), evaluated at x = 1.
        let x = Dual::variable(1.0_f64);
        assert_dual((x * x).sin(), 1.0_f64.sin(), 2.0 * 1.0_f64.cos(), TOL);
    }

    #[test]
    fn test_dual_negation() {
        assert_dual(-Dual::new(3.0_f64, 2.0), -3.0, -2.0, TOL);
    }

    #[test]
    fn test_dual_constant() {
        let c = Dual::<f64>::constant(5.0);
        assert_dual(c, 5.0, 0.0, TOL);
        assert_dual(Dual::variable(3.0_f64) + c, 8.0, 1.0, TOL);
    }

    #[test]
    fn test_dual_powf() {
        // d/dx x^2.5 = 2.5·x^1.5 → 20 at x = 4.
        assert_dual(Dual::variable(4.0_f64).powf(2.5), 32.0, 20.0, LOOSE);
    }

    #[test]
    fn test_dual_sqrt() {
        assert_dual(Dual::variable(4.0_f64).sqrt(), 2.0, 0.25, TOL);
    }

    #[test]
    fn test_dual_ln() {
        assert_dual(Dual::variable(2.0_f64).ln(), 2.0_f64.ln(), 0.5, TOL);
    }

    #[test]
    fn test_dual_complex_expression() {
        // y = exp(-x²/2)  ⇒  y' = -x·exp(-x²/2) → -exp(-1/2) at x = 1.
        let x = Dual::variable(1.0_f64);
        let half = Dual::constant(0.5_f64);
        let gauss = (-(x * x) * half).exp();
        let peak = (-0.5_f64).exp();
        assert_dual(gauss, peak, -peak, TOL);
    }

    #[test]
    fn test_dual_powi() {
        assert_dual(Dual::variable(2.0_f64).powi(3), 8.0, 12.0, TOL);
    }

    #[test]
    fn test_dual_cos() {
        let third_pi = core::f64::consts::PI / 3.0;
        assert_dual(Dual::variable(third_pi).cos(), 0.5, -third_pi.sin(), LOOSE);
    }

    #[test]
    fn test_dual_tan() {
        let quarter_pi = core::f64::consts::PI / 4.0;
        assert_dual(Dual::variable(quarter_pi).tan(), 1.0, 2.0, LOOSE);
    }

    #[test]
    fn test_dual_display_debug() {
        let x = Dual::new(1.0_f64, 2.0);
        let shown = x.to_string();
        assert!(shown.contains('1') && shown.contains('2'));
        assert!(format!("{:?}", x).contains("Dual"));
    }

    #[test]
    fn test_dual_partial_eq_ord() {
        // Equality and ordering look only at the value, never the derivative.
        let a = Dual::new(2.0_f64, 1.0);
        let b = Dual::new(2.0, 99.0);
        assert_eq!(a, b);
        let c = Dual::new(3.0_f64, 0.0);
        assert!(a < c);
        assert!(c > a);
    }

    #[test]
    fn test_dual_compound_assignment() {
        let mut x = Dual::new(2.0_f64, 1.0);
        x += Dual::new(3.0, 4.0);
        assert_dual(x, 5.0, 5.0, TOL);
        x -= Dual::new(1.0, 1.0);
        assert_dual(x, 4.0, 4.0, TOL);
        x *= Dual::constant(2.0);
        assert_dual(x, 8.0, 8.0, TOL);
        x /= Dual::constant(4.0);
        assert_dual(x, 2.0, 2.0, TOL);
    }

    #[test]
    fn test_dual_asin_acos_atan() {
        let x = Dual::variable(0.5_f64);
        let slope = 1.0 / 0.75_f64.sqrt();

        let arcsin = x.asin();
        assert_close(arcsin.value(), 0.5_f64.asin(), TOL);
        assert_close(arcsin.deriv(), slope, LOOSE);

        let arccos = x.acos();
        assert_close(arccos.value(), 0.5_f64.acos(), TOL);
        assert_close(arccos.deriv(), -slope, LOOSE);

        assert_dual(Dual::variable(1.0_f64).atan(), 1.0_f64.atan(), 0.5, TOL);
    }

    #[test]
    fn test_dual_sinh_cosh_tanh() {
        let x = Dual::variable(1.0_f64);
        assert_dual(x.sinh(), 1.0_f64.sinh(), 1.0_f64.cosh(), TOL);
        assert_dual(x.cosh(), 1.0_f64.cosh(), 1.0_f64.sinh(), TOL);
        let t = 1.0_f64.tanh();
        assert_dual(x.tanh(), t, 1.0 - t * t, TOL);
    }

    #[test]
    fn test_dual_abs() {
        assert_dual(Dual::variable(-3.0_f64).abs(), 3.0, -1.0, TOL);
        assert_dual(Dual::variable(5.0_f64).abs(), 5.0, 1.0, TOL);
    }

    #[test]
    fn test_dual_powf_dual() {
        // Variable base, constant exponent: d/dx x³ = 3x² → 12 at x = 2.
        let z = Dual::variable(2.0_f64).powf_dual(Dual::constant(3.0_f64));
        assert_dual(z, 8.0, 12.0, LOOSE);
        // Constant base, variable exponent: d/dy 2^y = 2^y·ln 2 → 8·ln 2 at y = 3.
        let z = Dual::constant(2.0_f64).powf_dual(Dual::variable(3.0_f64));
        assert_dual(z, 8.0, 8.0 * 2.0_f64.ln(), LOOSE);
    }

    #[test]
    fn test_dual_powf_dual_both_variable() {
        // z(t) = (2t)^t  ⇒  z'(t) = (2t)^t·(ln(2t) + 1) → 2·(ln 2 + 1) at t = 1.
        let t = Dual::variable(1.0_f64);
        let base = Dual::constant(2.0) * t;
        let z = base.powf_dual(t);
        assert_dual(z, 2.0, 2.0 * (2.0_f64.ln() + 1.0), LOOSE);
    }
}