use std::fmt;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// A first-order dual number `value + derivative·ε` with `ε² = 0`.
///
/// Arithmetic and the elementary functions on `Dual` carry the exact first
/// derivative alongside the value (forward-mode automatic differentiation).
/// Seed the input of interest with [`Dual::variable`] and every other input
/// with [`Dual::constant`].
#[derive(Clone, Copy, PartialEq)]
pub struct Dual {
    /// Primal (real) part.
    pub value: f64,
    /// Tangent (ε) part: the derivative along the seeded direction.
    pub derivative: f64,
}

impl fmt::Debug for Dual {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Hand-written instead of derived: both fields use `Display`
        // formatting, e.g. `Dual { value: 3, derivative: 1 }`.
        write!(f, "Dual {{ value: {}, derivative: {} }}", self.value, self.derivative)
    }
}

impl fmt::Display for Dual {
    /// Formats as `a + bε` or `a - bε`; an explicit precision such as `{:.3}`
    /// is applied to both parts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Test the sign bit rather than `>= 0.0` so a negative-zero tangent
        // renders as `- 0ε` instead of `+ -0ε`.
        let (sign, magnitude) = if self.derivative.is_sign_negative() {
            ('-', -self.derivative)
        } else {
            ('+', self.derivative)
        };
        match f.precision() {
            Some(p) => write!(f, "{:.*} {} {:.*}ε", p, self.value, sign, p, magnitude),
            None => write!(f, "{} {} {}ε", self.value, sign, magnitude),
        }
    }
}

impl Dual {
    /// Creates a dual number from its value and derivative parts.
    #[inline]
    pub fn new(value: f64, derivative: f64) -> Self {
        Self { value, derivative }
    }

    /// A constant (derivative `0`).
    #[inline]
    pub fn constant(value: f64) -> Self {
        Self { value, derivative: 0.0 }
    }

    /// The independent variable being differentiated against (derivative `1`).
    #[inline]
    pub fn variable(value: f64) -> Self {
        Self { value, derivative: 1.0 }
    }

    /// The constant `0`.
    #[inline]
    pub fn zero() -> Self {
        Self::constant(0.0)
    }

    /// The constant `1`.
    #[inline]
    pub fn one() -> Self {
        Self::constant(1.0)
    }

    /// Applies the chain rule for a scalar function `f` at `self.value`.
    ///
    /// `value` is `f(self.value)`. `tangent` maps the incoming derivative to
    /// the outgoing one, i.e. multiplies it by `f'(self.value)`. When `self`
    /// carries no derivative, the result's derivative is exactly `0.0` and
    /// `tangent` is not called. A constant must stay constant even where `f'`
    /// is infinite or undefined (e.g. `sqrt` at 0); otherwise `0 * inf = NaN`
    /// would poison every gradient the constant feeds into.
    #[inline]
    fn chain_rule(self, value: f64, tangent: impl FnOnce(f64) -> f64) -> Self {
        let derivative = if self.derivative == 0.0 { 0.0 } else { tangent(self.derivative) };
        Self { value, derivative }
    }

    /// Sine; `d/dx sin x = cos x`.
    #[inline]
    pub fn sin(self) -> Self {
        let x = self.value;
        self.chain_rule(x.sin(), |d| d * x.cos())
    }

    /// Cosine; `d/dx cos x = -sin x`.
    #[inline]
    pub fn cos(self) -> Self {
        let x = self.value;
        self.chain_rule(x.cos(), |d| -d * x.sin())
    }

    /// Tangent; `d/dx tan x = 1 / cos² x`.
    #[inline]
    pub fn tan(self) -> Self {
        let x = self.value;
        let c = x.cos();
        self.chain_rule(x.tan(), |d| d / (c * c))
    }

    /// Natural exponential; `d/dx eˣ = eˣ`.
    #[inline]
    pub fn exp(self) -> Self {
        let e = self.value.exp();
        self.chain_rule(e, |d| d * e)
    }

    /// Natural logarithm; `d/dx ln x = 1 / x`.
    #[inline]
    pub fn ln(self) -> Self {
        let x = self.value;
        self.chain_rule(x.ln(), |d| d / x)
    }

    /// Base-2 logarithm; `d/dx log₂ x = 1 / (x·ln 2)`.
    #[inline]
    pub fn log2(self) -> Self {
        let x = self.value;
        self.chain_rule(x.log2(), |d| d / (x * std::f64::consts::LN_2))
    }

    /// Base-10 logarithm; `d/dx log₁₀ x = 1 / (x·ln 10)`.
    #[inline]
    pub fn log10(self) -> Self {
        let x = self.value;
        self.chain_rule(x.log10(), |d| d / (x * std::f64::consts::LN_10))
    }

    /// Square root; `d/dx √x = 1 / (2√x)`.
    ///
    /// The derivative is infinite at `x = 0` for a non-constant input.
    #[inline]
    pub fn sqrt(self) -> Self {
        let s = self.value.sqrt();
        self.chain_rule(s, |d| d / (2.0 * s))
    }

    /// Cube root; `d/dx ∛x = 1 / (3·∛x²)`.
    #[inline]
    pub fn cbrt(self) -> Self {
        let c = self.value.cbrt();
        self.chain_rule(c, |d| d / (3.0 * c * c))
    }

    /// Absolute value. Uses `0` as the (sub)gradient at the kink `x = 0`
    /// and for NaN.
    #[inline]
    pub fn abs(self) -> Self {
        let x = self.value;
        let sign = if x > 0.0 {
            1.0
        } else if x < 0.0 {
            -1.0
        } else {
            0.0
        };
        self.chain_rule(x.abs(), |d| d * sign)
    }

    /// Integer power; `d/dx xⁿ = n·xⁿ⁻¹`.
    ///
    /// `n == 0` is a constant (`x⁰ = 1`) with derivative `0` everywhere.
    /// Evaluating `0·x⁻¹` there would give NaN at `x = 0`.
    #[inline]
    pub fn powi(self, n: i32) -> Self {
        let x = self.value;
        if n == 0 {
            return Self::constant(x.powi(0));
        }
        // `saturating_sub` avoids an overflow panic for `n == i32::MIN`.
        self.chain_rule(x.powi(n), |d| d * f64::from(n) * x.powi(n.saturating_sub(1)))
    }

    /// Real power with a constant exponent; `d/dx xᵖ = p·xᵖ⁻¹`.
    ///
    /// `p == 0` is a constant (`x⁰ = 1`) with derivative `0`, avoiding
    /// `0·0⁻¹ = NaN` at `x = 0`.
    #[inline]
    pub fn powf(self, p: f64) -> Self {
        let x = self.value;
        if p == 0.0 {
            return Self::constant(x.powf(p));
        }
        self.chain_rule(x.powf(p), |d| d * p * x.powf(p - 1.0))
    }

    /// `self` raised to a dual exponent:
    /// `d(bᵉ) = e·bᵉ⁻¹·db + bᵉ·ln b·de`.
    ///
    /// Each term is only formed when it can be non-zero. A negative base with
    /// a constant exponent (where `ln b` is NaN) therefore keeps a finite
    /// derivative. So does a zero base with a positive variable exponent
    /// (where `ln b` is `-inf` but `bᵉ` is `0`).
    pub fn pow_dual(self, exponent: Dual) -> Self {
        let value = self.value.powf(exponent.value);
        // Contribution through the base: e·bᵉ⁻¹·db.
        let through_base = self.powf(exponent.value).derivative;
        // Contribution through the exponent: bᵉ·ln b·de.
        let through_exponent = if exponent.derivative == 0.0 || value == 0.0 {
            0.0
        } else {
            value * self.value.ln() * exponent.derivative
        };
        Self {
            value,
            derivative: through_base + through_exponent,
        }
    }

    /// Arcsine; `d/dx asin x = 1 / √(1 − x²)`.
    #[inline]
    pub fn asin(self) -> Self {
        let x = self.value;
        self.chain_rule(x.asin(), |d| d / (1.0 - x * x).sqrt())
    }

    /// Arccosine; `d/dx acos x = −1 / √(1 − x²)`.
    #[inline]
    pub fn acos(self) -> Self {
        let x = self.value;
        self.chain_rule(x.acos(), |d| -d / (1.0 - x * x).sqrt())
    }

    /// Arctangent; `d/dx atan x = 1 / (1 + x²)`.
    #[inline]
    pub fn atan(self) -> Self {
        let x = self.value;
        self.chain_rule(x.atan(), |d| d / (1.0 + x * x))
    }

    /// Four-quadrant arctangent of `self / other`, where `self` is the
    /// y-coordinate and `other` the x-coordinate.
    ///
    /// The derivative is `(x·dy − y·dx) / (x² + y²)`. If both arguments are
    /// constants it is `0`, which also avoids `0/0` at the origin.
    pub fn atan2(self, other: Dual) -> Self {
        let value = self.value.atan2(other.value);
        if self.derivative == 0.0 && other.derivative == 0.0 {
            return Self::constant(value);
        }
        let denom = other.value * other.value + self.value * self.value;
        Self {
            value,
            derivative: (other.value * self.derivative - self.value * other.derivative) / denom,
        }
    }

    /// Hyperbolic sine; `d/dx sinh x = cosh x`.
    #[inline]
    pub fn sinh(self) -> Self {
        let x = self.value;
        self.chain_rule(x.sinh(), |d| d * x.cosh())
    }

    /// Hyperbolic cosine; `d/dx cosh x = sinh x`.
    #[inline]
    pub fn cosh(self) -> Self {
        let x = self.value;
        self.chain_rule(x.cosh(), |d| d * x.sinh())
    }

    /// Hyperbolic tangent; `d/dx tanh x = 1 − tanh² x`.
    #[inline]
    pub fn tanh(self) -> Self {
        let t = self.value.tanh();
        self.chain_rule(t, |d| d * (1.0 - t * t))
    }

    /// Larger of two dual numbers by value.
    ///
    /// If the values tie, or either is NaN, the result has `self.value` and
    /// the mean of the two derivatives.
    #[inline]
    pub fn max(self, other: Dual) -> Self {
        use std::cmp::Ordering;
        match self.value.partial_cmp(&other.value) {
            Some(Ordering::Greater) => self,
            Some(Ordering::Less) => other,
            _ => Self::new(self.value, 0.5 * (self.derivative + other.derivative)),
        }
    }

    /// Smaller of two dual numbers by value.
    ///
    /// If the values tie, or either is NaN, the result has `self.value` and
    /// the mean of the two derivatives.
    #[inline]
    pub fn min(self, other: Dual) -> Self {
        use std::cmp::Ordering;
        match self.value.partial_cmp(&other.value) {
            Some(Ordering::Less) => self,
            Some(Ordering::Greater) => other,
            _ => Self::new(self.value, 0.5 * (self.derivative + other.derivative)),
        }
    }

    /// Rectified linear unit. Passes `self` through when `value > 0` and
    /// returns the constant `0` otherwise, including at exactly `0` and
    /// for NaN.
    #[inline]
    pub fn relu(self) -> Self {
        if self.value > 0.0 {
            self
        } else {
            Self::zero()
        }
    }

    /// Logistic sigmoid `σ(x) = 1 / (1 + e⁻ˣ)`; `σ' = σ·(1 − σ)`.
    pub fn sigmoid(self) -> Self {
        let s = 1.0 / (1.0 + (-self.value).exp());
        self.chain_rule(s, |d| d * s * (1.0 - s))
    }

    /// Softplus `ln(1 + eˣ)`; its derivative is the logistic sigmoid.
    ///
    /// Evaluated as `max(x, 0) + ln_1p(e^{−|x|})`, which never exponentiates
    /// a large positive number. The naive form overflows to `inf`, with a NaN
    /// derivative, for `x` above roughly 709.
    pub fn softplus(self) -> Self {
        let x = self.value;
        let value = x.max(0.0) + (-x.abs()).exp().ln_1p();
        let slope = 1.0 / (1.0 + (-x).exp());
        self.chain_rule(value, |d| d * slope)
    }

    /// Whether the value part is NaN; the derivative is not inspected.
    #[inline]
    pub fn is_nan(self) -> bool {
        self.value.is_nan()
    }

    /// Whether the value part is ±infinity; the derivative is not inspected.
    #[inline]
    pub fn is_infinite(self) -> bool {
        self.value.is_infinite()
    }

    /// Whether the value part is finite; the derivative is not inspected.
    #[inline]
    pub fn is_finite(self) -> bool {
        self.value.is_finite()
    }

    /// Clamps the value to `[min, max]`.
    ///
    /// An out-of-range input becomes a constant at the violated bound
    /// (derivative `0`). An in-range input, or a NaN input, is returned
    /// unchanged.
    #[inline]
    pub fn clamp(self, min: f64, max: f64) -> Self {
        if self.value < min {
            Self::constant(min)
        } else if self.value > max {
            Self::constant(max)
        } else {
            self
        }
    }
}
impl Neg for Dual {
type Output = Self;
#[inline]
fn neg(self) -> Self {
Self { value: -self.value, derivative: -self.derivative }
}
}
impl Add for Dual {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self {
Self {
value: self.value + rhs.value,
derivative: self.derivative + rhs.derivative,
}
}
}
impl Add<f64> for Dual {
type Output = Self;
#[inline]
fn add(self, rhs: f64) -> Self {
Self { value: self.value + rhs, derivative: self.derivative }
}
}
impl Add<Dual> for f64 {
type Output = Dual;
#[inline]
fn add(self, rhs: Dual) -> Dual {
Dual { value: self + rhs.value, derivative: rhs.derivative }
}
}
impl Sub for Dual {
type Output = Self;
#[inline]
fn sub(self, rhs: Self) -> Self {
Self {
value: self.value - rhs.value,
derivative: self.derivative - rhs.derivative,
}
}
}
impl Sub<f64> for Dual {
type Output = Self;
#[inline]
fn sub(self, rhs: f64) -> Self {
Self { value: self.value - rhs, derivative: self.derivative }
}
}
impl Sub<Dual> for f64 {
type Output = Dual;
#[inline]
fn sub(self, rhs: Dual) -> Dual {
Dual { value: self - rhs.value, derivative: -rhs.derivative }
}
}
impl Mul for Dual {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self {
Self {
value: self.value * rhs.value,
derivative: self.derivative * rhs.value + self.value * rhs.derivative,
}
}
}
impl Mul<f64> for Dual {
type Output = Self;
#[inline]
fn mul(self, rhs: f64) -> Self {
Self { value: self.value * rhs, derivative: self.derivative * rhs }
}
}
impl Mul<Dual> for f64 {
type Output = Dual;
#[inline]
fn mul(self, rhs: Dual) -> Dual {
Dual { value: self * rhs.value, derivative: self * rhs.derivative }
}
}
/// Quotient rule for dual numbers: `(a / b)' = (a' − (a / b)·b') / b`.
///
/// This uses a single division by `b` instead of the textbook
/// `(a'·b − a·b') / b²`. `b²` overflows or underflows long before the quotient
/// and its derivative do (e.g. `b = 1e200` or `b = 1e-170`), which would turn
/// representable results into NaN or `inf`.
impl Div for Dual {
    type Output = Self;

    #[inline]
    fn div(self, rhs: Self) -> Self {
        let quotient = self.value / rhs.value;
        Self {
            value: quotient,
            derivative: (self.derivative - quotient * rhs.derivative) / rhs.value,
        }
    }
}

/// Dividing by a plain scalar scales both parts.
impl Div<f64> for Dual {
    type Output = Self;

    #[inline]
    fn div(self, rhs: f64) -> Self {
        Self {
            value: self.value / rhs,
            derivative: self.derivative / rhs,
        }
    }
}

/// Scalar numerator: `(s / b)' = −(s / b)·b' / b`, which also avoids forming `b²`.
impl Div<Dual> for f64 {
    type Output = Dual;

    #[inline]
    fn div(self, rhs: Dual) -> Dual {
        let quotient = self / rhs.value;
        Dual {
            value: quotient,
            derivative: -(quotient * rhs.derivative) / rhs.value,
        }
    }
}
impl From<f64> for Dual {
#[inline]
fn from(v: f64) -> Self {
Dual::constant(v)
}
}
impl From<f32> for Dual {
#[inline]
fn from(v: f32) -> Self {
Dual::constant(f64::from(v))
}
}
impl From<i32> for Dual {
#[inline]
fn from(v: i32) -> Self {
Dual::constant(f64::from(v))
}
}
impl From<i64> for Dual {
#[inline]
fn from(v: i64) -> Self {
Dual::constant(v as f64)
}
}
impl PartialOrd for Dual {
    /// Compares the `value` parts only; derivatives never affect ordering.
    ///
    /// NOTE(review): the derived `PartialEq` compares both fields. Two numbers
    /// with equal values but different derivatives therefore give
    /// `Some(Equal)` here while `==` reports them unequal. `PartialOrd` asks
    /// for the two to agree. This is kept as-is because callers may rely on
    /// value-only ordering.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        <f64 as PartialOrd>::partial_cmp(&self.value, &other.value)
    }
}
impl std::iter::Sum for Dual {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Dual::zero(), |acc, x| acc + x)
}
}
impl std::iter::Product for Dual {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Dual::one(), |acc, x| acc * x)
}
}
/// A hyper-dual number `value + d1·ε₁ + d2·ε₂ + d12·ε₁ε₂`, where
/// `ε₁² = ε₂² = 0` but `ε₁ε₂ ≠ 0`.
///
/// Seed one input with `d1 = 1` and another with `d2 = 1`. The result's `d12`
/// is then the exact mixed second partial derivative with respect to those
/// two inputs. Seeding the same input in both slots
/// ([`HyperDual::variable`]) yields its pure second derivative.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct HyperDual {
    /// Primal value.
    pub value: f64,
    /// First-order part along `ε₁`.
    pub d1: f64,
    /// First-order part along `ε₂`.
    pub d2: f64,
    /// Second-order mixed part along `ε₁ε₂`.
    pub d12: f64,
}

impl fmt::Display for HyperDual {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "HyperDual({}, {}, {}, {})", self.value, self.d1, self.d2, self.d12)
    }
}

impl HyperDual {
    /// Creates a hyper-dual number from all four parts.
    #[inline]
    pub fn new(value: f64, d1: f64, d2: f64, d12: f64) -> Self {
        Self { value, d1, d2, d12 }
    }

    /// A constant: every infinitesimal part is `0`.
    #[inline]
    pub fn constant(value: f64) -> Self {
        Self { value, d1: 0.0, d2: 0.0, d12: 0.0 }
    }

    /// A variable seeded along both `ε₁` and `ε₂`, so a result's `d12` is its
    /// second derivative with respect to this input.
    #[inline]
    pub fn variable(value: f64) -> Self {
        Self { value, d1: 1.0, d2: 1.0, d12: 0.0 }
    }

    /// A variable seeded along `ε₁` only.
    #[inline]
    pub fn variable1(value: f64) -> Self {
        Self { value, d1: 1.0, d2: 0.0, d12: 0.0 }
    }

    /// A variable seeded along `ε₂` only.
    #[inline]
    pub fn variable2(value: f64) -> Self {
        Self { value, d1: 0.0, d2: 1.0, d12: 0.0 }
    }

    /// The constant `0`.
    #[inline]
    pub fn zero() -> Self {
        Self::constant(0.0)
    }

    /// The constant `1`.
    #[inline]
    pub fn one() -> Self {
        Self::constant(1.0)
    }

    /// Lifts a scalar function `f` to hyper-dual numbers, given
    /// `value = f(x)`, `df = f'(x)` and `d2f = f''(x)` at `x = self.value`.
    ///
    /// Second-order Taylor expansion:
    /// `f(x + h) = f(x) + f'(x)·(d1 ε₁ + d2 ε₂ + d12 ε₁ε₂) + f''(x)·d1·d2·ε₁ε₂`.
    /// Each product is skipped when its infinitesimal factor is zero. An
    /// infinite or undefined derivative at a singular point (e.g. `sqrt` at 0)
    /// then cannot turn a constant's zero parts into `0·inf = NaN`.
    #[inline]
    fn chain_rule(self, value: f64, df: f64, d2f: f64) -> Self {
        let scale = |factor: f64, part: f64| if part == 0.0 { 0.0 } else { factor * part };
        Self {
            value,
            d1: scale(df, self.d1),
            d2: scale(df, self.d2),
            d12: scale(df, self.d12) + scale(d2f, self.d1 * self.d2),
        }
    }

    /// Natural exponential; all derivatives equal `eˣ`.
    pub fn exp(self) -> Self {
        let e = self.value.exp();
        self.chain_rule(e, e, e)
    }

    /// Natural logarithm; `f' = 1/x`, `f'' = −1/x²`.
    pub fn ln(self) -> Self {
        let x = self.value;
        let inv = 1.0 / x;
        self.chain_rule(x.ln(), inv, -inv * inv)
    }

    /// Square root; `f' = 1/(2√x)`, `f'' = −1/(4·√x³)`.
    pub fn sqrt(self) -> Self {
        let s = self.value.sqrt();
        self.chain_rule(s, 0.5 / s, -0.25 / (s * s * s))
    }

    /// Sine; `f' = cos x`, `f'' = −sin x`.
    pub fn sin(self) -> Self {
        let (s, c) = self.value.sin_cos();
        self.chain_rule(s, c, -s)
    }

    /// Cosine; `f' = −sin x`, `f'' = −cos x`.
    pub fn cos(self) -> Self {
        let (s, c) = self.value.sin_cos();
        self.chain_rule(c, -s, -c)
    }

    /// Hyperbolic tangent; `f' = sech² x`, `f'' = −2·tanh x·sech² x`.
    pub fn tanh(self) -> Self {
        let t = self.value.tanh();
        let sech2 = 1.0 - t * t;
        self.chain_rule(t, sech2, -2.0 * t * sech2)
    }

    /// Integer power; `f' = n·xⁿ⁻¹`, `f'' = n(n−1)·xⁿ⁻²`.
    ///
    /// Only the orders whose coefficient is identically zero are short-cut:
    /// `f'` for `n = 0`, and `f''` for `n ∈ {0, 1}`. Negative exponents keep
    /// their non-zero second derivative.
    pub fn powi(self, n: i32) -> Self {
        let x = self.value;
        let nf = f64::from(n);
        // `saturating_sub` avoids an overflow panic for exponents near `i32::MIN`.
        let df = if n == 0 { 0.0 } else { nf * x.powi(n.saturating_sub(1)) };
        let d2f = if n == 0 || n == 1 {
            0.0
        } else {
            nf * (nf - 1.0) * x.powi(n.saturating_sub(2))
        };
        self.chain_rule(x.powi(n), df, d2f)
    }

    /// Real power with a constant exponent; `f' = p·xᵖ⁻¹`,
    /// `f'' = p(p−1)·xᵖ⁻²`.
    ///
    /// Orders whose coefficient is identically zero (`p = 0`, or `p = 1` for
    /// `f''`) are short-cut, so `x = 0` does not produce `0·inf = NaN`.
    pub fn powf(self, p: f64) -> Self {
        let x = self.value;
        let df = if p == 0.0 { 0.0 } else { p * x.powf(p - 1.0) };
        let d2f = if p == 0.0 || p == 1.0 {
            0.0
        } else {
            p * (p - 1.0) * x.powf(p - 2.0)
        };
        self.chain_rule(x.powf(p), df, d2f)
    }
}
impl Neg for HyperDual {
type Output = Self;
fn neg(self) -> Self {
Self { value: -self.value, d1: -self.d1, d2: -self.d2, d12: -self.d12 }
}
}
impl Add for HyperDual {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
value: self.value + rhs.value,
d1: self.d1 + rhs.d1,
d2: self.d2 + rhs.d2,
d12: self.d12 + rhs.d12,
}
}
}
impl Add<f64> for HyperDual {
type Output = Self;
fn add(self, rhs: f64) -> Self {
Self { value: self.value + rhs, ..self }
}
}
impl Sub for HyperDual {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self {
value: self.value - rhs.value,
d1: self.d1 - rhs.d1,
d2: self.d2 - rhs.d2,
d12: self.d12 - rhs.d12,
}
}
}
impl Mul for HyperDual {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
Self {
value: self.value * rhs.value,
d1: self.d1 * rhs.value + self.value * rhs.d1,
d2: self.d2 * rhs.value + self.value * rhs.d2,
d12: self.d12 * rhs.value
+ self.d1 * rhs.d2
+ self.d2 * rhs.d1
+ self.value * rhs.d12,
}
}
}
impl Mul<f64> for HyperDual {
type Output = Self;
fn mul(self, rhs: f64) -> Self {
Self {
value: self.value * rhs,
d1: self.d1 * rhs,
d2: self.d2 * rhs,
d12: self.d12 * rhs,
}
}
}
/// Quotient `h = f / g`, obtained by differentiating the identity `h·g = f`:
///
/// * `h₁  = (f₁ − h·g₁) / g` (likewise `h₂`)
/// * `h₁₂ = (f₁₂ − h₁·g₂ − h₂·g₁ − h·g₁₂) / g`
///
/// Each term divides by `g` once instead of by `g²` and `g³`. Those powers
/// overflow or underflow long before the quotient and its derivatives do,
/// which would turn representable results into NaN or `inf`.
impl Div for HyperDual {
    type Output = Self;

    fn div(self, rhs: Self) -> Self {
        let g = rhs.value;
        let h = self.value / g;
        let h1 = (self.d1 - h * rhs.d1) / g;
        let h2 = (self.d2 - h * rhs.d2) / g;
        let h12 = (self.d12 - h1 * rhs.d2 - h2 * rhs.d1 - h * rhs.d12) / g;
        Self {
            value: h,
            d1: h1,
            d2: h2,
            d12: h12,
        }
    }
}

/// Dividing by a plain scalar scales every component.
impl Div<f64> for HyperDual {
    type Output = Self;

    fn div(self, rhs: f64) -> Self {
        Self {
            value: self.value / rhs,
            d1: self.d1 / rhs,
            d2: self.d2 / rhs,
            d12: self.d12 / rhs,
        }
    }
}
impl From<f64> for HyperDual {
fn from(v: f64) -> Self {
HyperDual::constant(v)
}
}
impl std::iter::Sum for HyperDual {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(HyperDual::zero(), |acc, x| acc + x)
}
}
impl std::iter::Product for HyperDual {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(HyperDual::one(), |acc, x| acc * x)
}
}
/// Evaluates `f` and its gradient at `x` using dual numbers.
///
/// `f` is called once per coordinate. On the `i`-th pass `x[i]` carries
/// derivative `1` and every other coordinate is a constant, so the output's
/// derivative is `∂f/∂x[i]`.
///
/// Returns `(f(x), ∇f(x))`. When `x` is empty, `f` is still called once
/// (with an empty slice), so the reported value is `f(&[])` rather than a
/// `0.0` placeholder.
pub fn eval_gradient<F>(f: F, x: &[f64]) -> (f64, Vec<f64>)
where
    F: Fn(&[Dual]) -> Dual,
{
    // One buffer of constants is reused; each pass toggles a single seed
    // instead of rebuilding an n-element Vec per coordinate.
    let mut seeds: Vec<Dual> = x.iter().map(|&xi| Dual::constant(xi)).collect();
    if seeds.is_empty() {
        return (f(&seeds).value, Vec::new());
    }
    let mut value = 0.0;
    let mut grad = Vec::with_capacity(seeds.len());
    for i in 0..seeds.len() {
        seeds[i].derivative = 1.0;
        let out = f(&seeds);
        seeds[i].derivative = 0.0;
        if i == 0 {
            // The value part does not depend on which input is seeded for
            // functions built from `Dual` arithmetic, so the first pass's
            // value is reported.
            value = out.value;
        }
        grad.push(out.derivative);
    }
    (value, grad)
}
/// Computes the Hessian of `f` at `x` using hyper-dual numbers.
///
/// For each pair `i <= j`, `x[i]` is seeded along `ε₁` and `x[j]` along `ε₂`
/// (both on the same coordinate when `i == j`). The result's `d12` is then
/// `∂²f/∂x[i]∂x[j]`. Only the upper triangle is evaluated, using
/// `n(n+1)/2` calls to `f`, and each entry is mirrored, so the returned
/// `n × n` matrix is symmetric. An empty `x` yields an empty matrix without
/// calling `f`.
pub fn eval_hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[HyperDual]) -> HyperDual,
{
    let n = x.len();
    let mut h = vec![vec![0.0f64; n]; n];
    // Reuse one buffer of constants instead of rebuilding an n-element Vec for
    // every (i, j) pair. Seeds are set before each call and cleared after it.
    let mut seeds: Vec<HyperDual> = x.iter().map(|&xk| HyperDual::constant(xk)).collect();
    for i in 0..n {
        for j in i..n {
            seeds[i].d1 = 1.0;
            seeds[j].d2 = 1.0;
            let mixed = f(&seeds).d12;
            seeds[i].d1 = 0.0;
            seeds[j].d2 = 0.0;
            h[i][j] = mixed;
            h[j][i] = mixed;
        }
    }
    h
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for first-order checks.
    const TIGHT: f64 = 1e-12;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {}, got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_dual_basic_ops() {
        let a = Dual::new(3.0, 1.0);
        let b = Dual::new(2.0, 0.0);
        // (result, expected value, expected derivative)
        let cases = [(a + b, 5.0, 1.0), (a * b, 6.0, 2.0), (a / b, 1.5, 0.5)];
        for &(got, value, derivative) in cases.iter() {
            assert_close(got.value, value, TIGHT);
            assert_close(got.derivative, derivative, TIGHT);
        }
    }

    #[test]
    fn test_dual_transcendentals() {
        let exp_at_zero = Dual::variable(0.0_f64).exp();
        assert_close(exp_at_zero.value, 1.0, TIGHT);
        assert_close(exp_at_zero.derivative, 1.0, TIGHT);

        let ln_at_one = Dual::variable(1.0_f64).ln();
        assert_close(ln_at_one.value, 0.0, TIGHT);
        assert_close(ln_at_one.derivative, 1.0, TIGHT);

        let sin_at_peak = Dual::variable(std::f64::consts::FRAC_PI_2).sin();
        assert_close(sin_at_peak.value, 1.0, TIGHT);
        assert_close(sin_at_peak.derivative, 0.0, TIGHT);
    }

    #[test]
    fn test_eval_gradient_quadratic() {
        // f(x, y) = x² + y²  ⇒  ∇f(3, 4) = (6, 8)
        let (val, grad) = eval_gradient(|xs| xs[0] * xs[0] + xs[1] * xs[1], &[3.0, 4.0]);
        assert_close(val, 25.0, TIGHT);
        assert_close(grad[0], 6.0, TIGHT);
        assert_close(grad[1], 8.0, TIGHT);
    }

    #[test]
    fn test_eval_gradient_mixed() {
        // f(x, y) = x·y  ⇒  ∇f(3, 2) = (2, 3)
        let (val, grad) = eval_gradient(|xs| xs[0] * xs[1], &[3.0, 2.0]);
        assert_close(val, 6.0, TIGHT);
        assert_close(grad[0], 2.0, TIGHT);
        assert_close(grad[1], 3.0, TIGHT);
    }

    #[test]
    fn test_eval_hessian_quadratic() {
        // f(x, y) = x² + 3xy + 2y²  ⇒  H = [[2, 3], [3, 4]]
        let h = eval_hessian(
            |xs| {
                xs[0].powi(2)
                    + HyperDual::constant(3.0) * xs[0] * xs[1]
                    + HyperDual::constant(2.0) * xs[1].powi(2)
            },
            &[1.0, 1.0],
        );
        let expected = [[2.0, 3.0], [3.0, 4.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert!((h[i][j] - want).abs() < 1e-10, "H[{}][{}]={}", i, j, h[i][j]);
            }
        }
    }

    #[test]
    fn test_hyper_dual_exp() {
        let y = HyperDual::variable(1.0).exp();
        let e = std::f64::consts::E;
        for part in [y.value, y.d1, y.d2, y.d12].iter() {
            assert_close(*part, e, TIGHT);
        }
    }

    #[test]
    fn test_dual_relu() {
        assert_eq!(Dual::new(2.0, 1.0).relu(), Dual::new(2.0, 1.0));
        assert_eq!(Dual::new(-1.0, 1.0).relu(), Dual::zero());
    }

    #[test]
    fn test_dual_sigmoid_derivative() {
        let s = Dual::variable(0.0).sigmoid();
        assert_close(s.value, 0.5, TIGHT);
        assert_close(s.derivative, 0.25, TIGHT);
    }

    #[test]
    fn test_dual_display() {
        let rendered = Dual::new(3.0, -1.5).to_string();
        assert!(rendered.contains('3') && rendered.contains("1.5"), "display: {}", rendered);
    }

    #[test]
    fn test_dual_from_primitives() {
        let from_float = Dual::from(3.14_f64);
        assert_close(from_float.value, 3.14, TIGHT);
        assert_eq!(from_float.derivative, 0.0);

        let from_int = Dual::from(5_i32);
        assert_close(from_int.value, 5.0, TIGHT);
    }

    #[test]
    fn test_dual_clamp() {
        let inside = Dual::new(0.5, 2.0).clamp(0.0, 1.0);
        assert_close(inside.value, 0.5, TIGHT);
        assert_close(inside.derivative, 2.0, TIGHT);

        let above = Dual::new(1.5, 2.0).clamp(0.0, 1.0);
        assert_close(above.value, 1.0, TIGHT);
        assert_eq!(above.derivative, 0.0);
    }
}