use crate::common::IntegrateFloat;
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::fmt;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// A first-order dual number `val + der·ε` (with `ε² = 0`), used to carry a
/// value together with its derivative through arithmetic.
///
/// `val` is the ordinary (primal) value and `der` is the derivative with
/// respect to whichever variable was seeded with derivative `1`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual<F: IntegrateFloat> {
    /// Primal value.
    pub val: F,
    /// Derivative (tangent) component.
    pub der: F,
}
/// Constructors, accessors and elementary functions for [`Dual`] numbers.
///
/// Each function `g` propagates the derivative with the chain rule:
/// for `x = val + der·ε`, `g(x) = g(val) + g'(val)·der·ε`.
impl<F: IntegrateFloat> Dual<F> {
    /// Creates a dual number from an explicit value and derivative.
    pub fn new(val: F, der: F) -> Self {
        Dual { val, der }
    }

    /// Creates a constant: the derivative is `0`.
    pub fn constant(val: F) -> Self {
        Dual {
            val,
            der: F::zero(),
        }
    }

    /// Creates the independent variable: the derivative is seeded to `1`.
    pub fn variable(val: F) -> Self {
        Dual { val, der: F::one() }
    }

    /// Returns the value (primal) component.
    pub fn value(&self) -> F {
        self.val
    }

    /// Returns the derivative component.
    pub fn derivative(&self) -> F {
        self.der
    }

    /// Sine: `d sin(x) = cos(x)·dx`.
    pub fn sin(&self) -> Self {
        Dual {
            val: self.val.sin(),
            der: self.der * self.val.cos(),
        }
    }

    /// Cosine: `d cos(x) = -sin(x)·dx`.
    pub fn cos(&self) -> Self {
        Dual {
            val: self.val.cos(),
            der: -self.der * self.val.sin(),
        }
    }

    /// Exponential: `d exp(x) = exp(x)·dx`.
    pub fn exp(&self) -> Self {
        let exp_val = self.val.exp();
        Dual {
            val: exp_val,
            der: self.der * exp_val,
        }
    }

    /// Natural logarithm: `d ln(x) = dx / x`.
    ///
    /// The derivative is infinite or NaN when `x == 0`.
    pub fn ln(&self) -> Self {
        Dual {
            val: self.val.ln(),
            der: self.der / self.val,
        }
    }

    /// Square root: `d sqrt(x) = dx / (2·sqrt(x))`.
    ///
    /// The derivative is infinite or NaN when `x == 0`.
    pub fn sqrt(&self) -> Self {
        let sqrt_val = self.val.sqrt();
        Dual {
            val: sqrt_val,
            der: self.der / (F::from(2.0).expect("Failed to convert constant to float") * sqrt_val),
        }
    }

    /// Power with a constant real exponent: `d x^n = n·x^(n-1)·dx`.
    ///
    /// For `n == 0` the result `x^0` is constant, so the derivative is exactly
    /// zero; this avoids `0 · x^(-1) = 0 · inf = NaN` at `x == 0`.
    pub fn powf(&self, n: F) -> Self {
        let der = if n == F::zero() {
            F::zero()
        } else {
            self.der * n * self.val.powf(n - F::one())
        };
        Dual {
            val: self.val.powf(n),
            der,
        }
    }

    /// Absolute value.
    ///
    /// Returns `self` unchanged when `val >= 0` (including at `0`, where the
    /// derivative `der` is kept as-is) and the negation otherwise.
    pub fn abs(&self) -> Self {
        if self.val >= F::zero() {
            *self
        } else {
            -*self
        }
    }

    /// Tangent: `d tan(x) = dx / cos²(x)`.
    pub fn tan(&self) -> Self {
        let cos_val = self.val.cos();
        Dual {
            val: self.val.tan(),
            der: self.der / (cos_val * cos_val),
        }
    }

    /// Hyperbolic tangent: `d tanh(x) = (1 - tanh²(x))·dx`.
    pub fn tanh(&self) -> Self {
        let tanh_val = self.val.tanh();
        Dual {
            val: tanh_val,
            der: self.der * (F::one() - tanh_val * tanh_val),
        }
    }

    /// Hyperbolic sine: `d sinh(x) = cosh(x)·dx`.
    pub fn sinh(&self) -> Self {
        Dual {
            val: self.val.sinh(),
            der: self.der * self.val.cosh(),
        }
    }

    /// Hyperbolic cosine: `d cosh(x) = sinh(x)·dx`.
    pub fn cosh(&self) -> Self {
        Dual {
            val: self.val.cosh(),
            der: self.der * self.val.sinh(),
        }
    }

    /// Arctangent: `d atan(x) = dx / (1 + x²)`.
    pub fn atan(&self) -> Self {
        Dual {
            val: self.val.atan(),
            der: self.der / (F::one() + self.val * self.val),
        }
    }

    /// Arcsine: `d asin(x) = dx / sqrt(1 - x²)`.
    ///
    /// The derivative is not finite for `|x| >= 1`.
    pub fn asin(&self) -> Self {
        Dual {
            val: self.val.asin(),
            der: self.der / (F::one() - self.val * self.val).sqrt(),
        }
    }

    /// Arccosine: `d acos(x) = -dx / sqrt(1 - x²)`.
    ///
    /// The derivative is not finite for `|x| >= 1`.
    pub fn acos(&self) -> Self {
        Dual {
            val: self.val.acos(),
            der: -self.der / (F::one() - self.val * self.val).sqrt(),
        }
    }

    /// Four-quadrant arctangent `atan2(self, x)`, with `self` as the
    /// y-coordinate: `d atan2(y, x) = (x·dy - y·dx) / (x² + y²)`.
    ///
    /// The derivative is not finite at the origin (`x = y = 0`).
    pub fn atan2(&self, x: Self) -> Self {
        let r2 = self.val * self.val + x.val * x.val;
        Dual {
            val: self.val.atan2(x.val),
            der: (self.der * x.val - self.val * x.der) / r2,
        }
    }

    /// Maximum of two dual numbers.
    ///
    /// When the values are equal (the kink of `max`), the derivative is the
    /// average of the two derivatives.
    pub fn max(&self, other: Self) -> Self {
        if self.val > other.val {
            *self
        } else if self.val < other.val {
            other
        } else {
            Dual {
                val: self.val,
                der: (self.der + other.der)
                    / F::from(2.0).expect("Failed to convert constant to float"),
            }
        }
    }

    /// Minimum of two dual numbers.
    ///
    /// When the values are equal (the kink of `min`), the derivative is the
    /// average of the two derivatives.
    pub fn min(&self, other: Self) -> Self {
        if self.val < other.val {
            *self
        } else if self.val > other.val {
            other
        } else {
            Dual {
                val: self.val,
                der: (self.der + other.der)
                    / F::from(2.0).expect("Failed to convert constant to float"),
            }
        }
    }

    /// General power `x^y` where both base and exponent may carry derivatives:
    /// `d x^y = y·x^(y-1)·dx + x^y·ln(x)·dy`.
    ///
    /// With a constant exponent this agrees with [`Dual::powf`], including for
    /// negative bases such as `(-2)^2`. The `ln(x)·dy` term is only real for
    /// `x > 0`; for `x <= 0` that contribution is taken as zero.
    pub fn pow(&self, other: Self) -> Self {
        let zero = F::zero();
        let val = self.val.powf(other.val);
        // Power-rule term `y·x^(y-1)·dx`. It is exactly zero when `y == 0`
        // (x^0 is constant) or when the base is constant, so skip it in those
        // cases to avoid `0 · inf = NaN` at `x == 0`.
        let base_term = if other.val == zero || self.der == zero {
            zero
        } else {
            other.val * self.val.powf(other.val - F::one()) * self.der
        };
        // Exponential term `x^y·ln(x)·dy`, only needed when the exponent
        // varies. For `x == 0, y > 0` zero is the true derivative; for negative
        // bases `x^y` is not differentiable in `y` over the reals.
        let exp_term = if other.der != zero && self.val > zero {
            val * self.val.ln() * other.der
        } else {
            zero
        };
        Dual {
            val,
            der: base_term + exp_term,
        }
    }
}
/// Sum of two dual numbers: values and derivatives add component-wise.
impl<F: IntegrateFloat> Add for Dual<F> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let Dual { val: a, der: da } = self;
        let Dual { val: b, der: db } = rhs;
        Dual {
            val: a + b,
            der: da + db,
        }
    }
}
/// Difference of two dual numbers: values and derivatives subtract component-wise.
impl<F: IntegrateFloat> Sub for Dual<F> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        let val = self.val - rhs.val;
        let der = self.der - rhs.der;
        Dual { val, der }
    }
}
/// Product of two dual numbers, using the product rule `(u·v)' = u'·v + u·v'`.
impl<F: IntegrateFloat> Mul for Dual<F> {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        let (u, du) = (self.val, self.der);
        let (v, dv) = (rhs.val, rhs.der);
        Dual {
            val: u * v,
            der: du * v + u * dv,
        }
    }
}
/// Quotient of two dual numbers, using the quotient rule
/// `(u / v)' = (u'·v - u·v') / v² = (u' - (u / v)·v') / v`.
///
/// The value is a true division `u / v`, not a multiplication by a
/// precomputed reciprocal. It is therefore correctly rounded and matches
/// plain scalar division; for example, `49 / 49` is exactly `1`. Dividing the
/// derivative once by `v` instead of by `v·v` also avoids overflow or
/// underflow of `v²` for very large or very small divisors.
impl<F: IntegrateFloat> Div for Dual<F> {
    type Output = Self;

    fn div(self, other: Self) -> Self {
        let val = self.val / other.val;
        Dual {
            val,
            der: (self.der - val * other.der) / other.val,
        }
    }
}
/// Negation flips the sign of both the value and the derivative.
impl<F: IntegrateFloat> Neg for Dual<F> {
    type Output = Self;

    fn neg(self) -> Self {
        let Dual { val, der } = self;
        Dual {
            val: -val,
            der: -der,
        }
    }
}
/// Adds a plain scalar (treated as a constant): only the value shifts; the
/// derivative is carried over unchanged.
impl<F: IntegrateFloat> Add<F> for Dual<F> {
    type Output = Self;

    fn add(self, rhs: F) -> Self {
        Dual {
            val: self.val + rhs,
            ..self
        }
    }
}
/// Subtracts a plain scalar (treated as a constant): only the value shifts;
/// the derivative is carried over unchanged.
impl<F: IntegrateFloat> Sub<F> for Dual<F> {
    type Output = Self;

    fn sub(self, rhs: F) -> Self {
        let shifted = self.val - rhs;
        Dual {
            val: shifted,
            der: self.der,
        }
    }
}
/// Scales a dual number by a plain scalar: both value and derivative are
/// multiplied by the same factor.
impl<F: IntegrateFloat> Mul<F> for Dual<F> {
    type Output = Self;

    fn mul(self, rhs: F) -> Self {
        let scale = |component: F| component * rhs;
        Dual {
            val: scale(self.val),
            der: scale(self.der),
        }
    }
}
/// Divides a dual number by a plain scalar: both value and derivative are
/// divided by the same divisor.
impl<F: IntegrateFloat> Div<F> for Dual<F> {
    type Output = Self;

    fn div(self, rhs: F) -> Self {
        let Dual { val, der } = self;
        Dual {
            val: val / rhs,
            der: der / rhs,
        }
    }
}
impl<F: IntegrateFloat> fmt::Display for Dual<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} + {}ε", self.val, self.der)
}
}
/// A vector of values paired with one row of derivative data per value.
///
/// The constructors in this file allocate `jacobian` as `n` rows of length
/// `n`, where `n = values.len()`. Presumably `jacobian[i]` holds the
/// derivatives of `values[i]`; confirm against callers.
pub struct DualVector<F: IntegrateFloat> {
    /// Primal values.
    pub values: Array1<F>,
    /// Derivative rows, stored as an array of 1-D arrays.
    pub jacobian: Array1<Array1<F>>,
}
impl<F: IntegrateFloat> DualVector<F> {
    /// Builds a `DualVector` from already-computed values and Jacobian rows.
    ///
    /// No shape validation is performed; callers are responsible for keeping
    /// `jacobian` consistent with `values`.
    pub fn new(values: Array1<F>, jacobian: Array1<Array1<F>>) -> Self {
        DualVector { values, jacobian }
    }

    /// Seeds a `DualVector` from `values`, marking the element at index
    /// `activevar` as the variable being differentiated.
    ///
    /// The Jacobian is `n` zero rows of length `n` (`n = values.len()`),
    /// except for a `1` at `[activevar][activevar]`.
    ///
    /// # Panics
    ///
    /// Panics if `activevar >= values.len()`.
    pub fn from_vector(values: ArrayView1<F>, activevar: usize) -> Self {
        let n = values.len();
        assert!(
            activevar < n,
            "DualVector::from_vector: activevar ({}) out of bounds for length {}",
            activevar,
            n
        );
        let mut jacobian = Array1::from_elem(n, Array1::zeros(n));
        jacobian[activevar][activevar] = F::one();
        DualVector {
            values: values.to_owned(),
            jacobian,
        }
    }

    /// Wraps `values` as constants: every Jacobian entry is zero.
    pub fn constant(values: Array1<F>) -> Self {
        let n = values.len();
        let jacobian = Array1::from_elem(n, Array1::zeros(n));
        DualVector { values, jacobian }
    }

    /// Number of values held.
    pub fn dim(&self) -> usize {
        self.values.len()
    }

    /// Borrows the values.
    pub fn values(&self) -> &Array1<F> {
        &self.values
    }

    /// Borrows the Jacobian rows.
    pub fn jacobian(&self) -> &Array1<Array1<F>> {
        &self.jacobian
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dual_arithmetic() {
        let x = Dual::new(2.0, 1.0);
        let y = Dual::new(3.0, 0.0);

        // Sum rule: d(x + y) = 1 + 0.
        let sum = x + y;
        assert_eq!((sum.val, sum.der), (5.0, 1.0));

        // Product rule with a constant factor: d(x·y) = 1·3 + 2·0.
        let prod = x * y;
        assert_eq!((prod.val, prod.der), (6.0, 3.0));

        // d(x²) = 2x = 4 at x = 2.
        let square = x * x;
        assert_eq!((square.val, square.der), (4.0, 4.0));
    }

    #[test]
    fn test_dual_functions() {
        let tol = 1e-10_f64;
        let x = Dual::variable(0.0_f64);
        // (name, result, expected value, expected derivative) at x = 0.
        let cases: [(&str, Dual<f64>, f64, f64); 3] = [
            ("sin", x.sin(), 0.0, 1.0),
            ("cos", x.cos(), 1.0, 0.0),
            ("exp", x.exp(), 1.0, 1.0),
        ];
        for &(name, result, want_val, want_der) in cases.iter() {
            assert!(
                (result.val - want_val).abs() < tol,
                "{}: value {}",
                name,
                result.val
            );
            assert!(
                (result.der - want_der).abs() < tol,
                "{}: derivative {}",
                name,
                result.der
            );
        }
    }
}