use std::{
array,
ops::{Add, Div, Mul, Neg, Sub},
};
use super::Scalarish;
/// Forward-mode dual number: a primal value plus `N` partial derivatives.
///
/// `v` holds the value and `d[i]` holds its derivative with respect to the
/// input that was seeded in slot `i` (see [`Dual::var`]). The arithmetic
/// operators below propagate derivatives with the sum, product and quotient
/// rules.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual<const N: usize> {
    /// Primal (function) value.
    pub v: f64,
    /// Partial derivatives, one per derivative slot.
    pub d: [f64; N],
}

impl<const N: usize> Dual<N> {
    /// Creates a constant: value `v` with every derivative slot set to zero.
    pub fn cst(v: f64) -> Self {
        Self { v, d: [0.0; N] }
    }

    /// Creates an independent variable with value `v`, seeding derivative
    /// slot `slot` with `1.0` and leaving all other slots at `0.0`.
    ///
    /// If `slot >= N`, no slot is seeded and the result equals
    /// `Dual::cst(v)`; an out-of-range slot is tolerated rather than rejected.
    pub fn var(v: f64, slot: usize) -> Self {
        let mut d = [0.0; N];
        // `get_mut` instead of indexing, so an out-of-range slot leaves every
        // derivative at zero instead of panicking.
        if let Some(seed) = d.get_mut(slot) {
            *seed = 1.0;
        }
        Self { v, d }
    }
}

/// Sum rule: `(u + w)' = u' + w'`.
impl<const N: usize> Add for Dual<N> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Self {
            v: self.v + rhs.v,
            d: array::from_fn(|index| self.d[index] + rhs.d[index]),
        }
    }
}

/// Difference rule: `(u - w)' = u' - w'`.
impl<const N: usize> Sub for Dual<N> {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        Self {
            v: self.v - rhs.v,
            d: array::from_fn(|index| self.d[index] - rhs.d[index]),
        }
    }
}

/// Product rule: `(u * w)' = u' * w + u * w'`.
impl<const N: usize> Mul for Dual<N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        Self {
            v: self.v * rhs.v,
            // `mul_add` fuses `u' * w` with the second term, saving one rounding.
            d: array::from_fn(|index| self.d[index].mul_add(rhs.v, rhs.d[index] * self.v)),
        }
    }
}

/// Quotient rule, evaluated as `(u / w)' = (u' - (u / w) * w') / w`.
///
/// This is algebraically equal to the textbook `(u' * w - u * w') / w²`.
/// It avoids forming `w²`, which overflows to infinity for |w| above about
/// 1e154 (silently zeroing the derivative). It also underflows to zero for
/// |w| below about 1e-162 (turning a finite derivative into infinity).
impl<const N: usize> Div for Dual<N> {
    type Output = Self;
    fn div(self, rhs: Self) -> Self::Output {
        let quotient = self.v / rhs.v;
        Self {
            v: quotient,
            d: array::from_fn(|index| (self.d[index] - quotient * rhs.d[index]) / rhs.v),
        }
    }
}

/// Negation: `(-u)' = -u'`.
impl<const N: usize> Neg for Dual<N> {
    type Output = Self;
    fn neg(self) -> Self::Output {
        Self {
            v: -self.v,
            d: array::from_fn(|index| -self.d[index]),
        }
    }
}
/// Elementary functions on dual numbers. Each one applies the chain rule
/// `f(u)' = f'(u) * u'` independently to every derivative slot.
impl<const N: usize> Scalarish for Dual<N> {
    /// Embeds a plain `f64` as a constant (all derivatives zero).
    fn from_f64(x: f64) -> Self {
        Self::cst(x)
    }

    /// `sin(u)' = cos(u) * u'`.
    fn sin(self) -> Self {
        let slope = self.v.cos();
        Self {
            v: self.v.sin(),
            d: self.d.map(|du| du * slope),
        }
    }

    /// `cos(u)' = -sin(u) * u'`.
    fn cos(self) -> Self {
        let slope = self.v.sin();
        Self {
            v: self.v.cos(),
            d: self.d.map(|du| -du * slope),
        }
    }

    /// `exp(u)' = exp(u) * u'`; the computed value doubles as the slope.
    fn exp(self) -> Self {
        let value = self.v.exp();
        Self {
            v: value,
            d: self.d.map(|du| du * value),
        }
    }

    /// `ln(u)' = u' / u`.
    fn ln(self) -> Self {
        let u = self.v;
        Self {
            v: u.ln(),
            d: self.d.map(|du| du / u),
        }
    }

    /// `sqrt(u)' = u' / (2 * sqrt(u))`.
    ///
    /// At `u == 0` the divisor is zero, so non-zero slots become infinite and
    /// zero slots become NaN.
    fn sqrt(self) -> Self {
        let root = self.v.sqrt();
        let twice_root = 2.0 * root;
        Self {
            v: root,
            d: self.d.map(|du| du / twice_root),
        }
    }

    /// `(1 / u)' = -u' / u²`.
    fn recip(self) -> Self {
        let u_squared = self.v * self.v;
        Self {
            v: self.v.recip(),
            d: self.d.map(|du| -du / u_squared),
        }
    }
}