use crate::traits::{fp::FPVector, math::Vector, stable::StableFn, sugar::VecOps};
use peroxide_num::{ExpLogOps, PowOps, TrigOps};
use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
/// Truncated Taylor polynomial ("jet") of order `N`, used for forward-mode AD.
///
/// Coefficient 0 is `value`; coefficient `k` (1..=N) is `deriv[k - 1]` and is
/// stored normalized as f^(k) / k! (`derivative` multiplies by k! to undo this,
/// and `ddx` multiplies the second stored coefficient by 2).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Jet<const N: usize> {
/// Primal value (Taylor coefficient 0).
value: f64,
/// Taylor coefficients of orders 1..=N, i.e. f^(k) / k!.
deriv: [f64; N],
}
impl<const N: usize> Jet<N> {
pub fn new(value: f64, deriv: [f64; N]) -> Self {
Self { value, deriv }
}
pub fn var(x: f64) -> Self {
let mut deriv = [0.0f64; N];
if N >= 1 {
deriv[0] = 1.0;
}
Self { value: x, deriv }
}
pub fn constant(x: f64) -> Self {
Self {
value: x,
deriv: [0.0f64; N],
}
}
#[inline]
pub fn value(&self) -> f64 {
self.value
}
#[inline]
pub fn x(&self) -> f64 {
self.value
}
#[inline]
pub fn dx(&self) -> f64 {
if N >= 1 {
self.deriv[0]
} else {
0.0
}
}
#[inline]
pub fn ddx(&self) -> f64 {
if N >= 2 {
self.deriv[1] * 2.0
} else {
0.0
}
}
pub fn derivative(&self, order: usize) -> f64 {
if order == 0 {
self.value
} else if order <= N {
self.deriv[order - 1] * factorial(order) as f64
} else {
0.0
}
}
pub fn taylor_coeff(&self, k: usize) -> f64 {
self.coeff(k)
}
#[inline]
fn coeff(&self, k: usize) -> f64 {
if k == 0 {
self.value
} else if k <= N {
self.deriv[k - 1]
} else {
0.0
}
}
#[inline]
fn set_coeff(&mut self, k: usize, v: f64) {
if k == 0 {
self.value = v;
} else if k <= N {
self.deriv[k - 1] = v;
}
}
#[inline]
fn zero() -> Self {
Self {
value: 0.0,
deriv: [0.0f64; N],
}
}
}
/// First-order jet: value and first derivative.
pub type Dual = Jet<1>;
/// Second-order jet: value, first derivative, and f''/2 as the second stored coefficient.
pub type HyperDual = Jet<2>;
/// Builds a zeroth-order jet: only the primal value `x`, no derivative slots.
#[inline]
pub fn ad0(x: f64) -> Jet<0> {
    let deriv: [f64; 0] = [];
    Jet { value: x, deriv }
}
/// Builds a first-order jet with value `x` and first derivative `dx`.
#[inline]
pub fn ad1(x: f64, dx: f64) -> Jet<1> {
    let deriv = [dx];
    Jet { value: x, deriv }
}
/// Builds a second-order jet from a value and its raw first and second
/// derivatives; the second derivative is stored normalized as `ddx / 2!`.
#[inline]
pub fn ad2(x: f64, dx: f64, ddx: f64) -> Jet<2> {
    let half_ddx = ddx / 2.0;
    Jet {
        value: x,
        deriv: [dx, half_ddx],
    }
}
/// Returns `n!` as an `f64`.
///
/// Computed in floating point so it cannot overflow: `u64` only holds up to
/// `20!`, so an integer version panics (debug) or silently wraps (release) for
/// `n >= 21`, e.g. in `Jet::<21>::derivative(21)`. Results are exact for
/// `n <= 22`, correctly rounded products beyond that, and `f64::INFINITY`
/// for `n > 170`.
#[inline]
fn factorial(n: usize) -> f64 {
    (2..=n).fold(1.0, |acc, i| acc * i as f64)
}
impl<const N: usize> std::fmt::Display for Jet<N> {
    /// Renders `Jet(value)` when `N == 0`, otherwise `Jet(value; c1, c2, ...)`
    /// where the `c`s are the stored (normalized) coefficients.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Jet({}", self.value)?;
        for (idx, coeff) in self.deriv.iter().enumerate() {
            f.write_str(if idx == 0 { "; " } else { ", " })?;
            write!(f, "{}", coeff)?;
        }
        write!(f, ")")
    }
}
impl<const N: usize> PartialOrd for Jet<N> {
    /// Orders jets by their primal value alone; derivative coefficients are ignored.
    ///
    /// NOTE(review): `PartialEq` is derived and also compares the derivative
    /// coefficients, so two jets with equal values but different derivatives
    /// give `Some(Equal)` here while `==` returns `false`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let (lhs, rhs) = (self.value, other.value);
        lhs.partial_cmp(&rhs)
    }
}
impl<const N: usize> From<f64> for Jet<N> {
fn from(v: f64) -> Self {
Self::constant(v)
}
}
impl<const N: usize> From<Jet<N>> for f64 {
fn from(j: Jet<N>) -> f64 {
j.value
}
}
impl<const N: usize> Index<usize> for Jet<N> {
    type Output = f64;
    /// Index 0 is the value; index `k` in `1..=N` is the k-th stored coefficient.
    ///
    /// # Panics
    /// Panics if `index > N`.
    fn index(&self, index: usize) -> &Self::Output {
        match index {
            0 => &self.value,
            k if k <= N => &self.deriv[k - 1],
            _ => panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N),
        }
    }
}
impl<const N: usize> IndexMut<usize> for Jet<N> {
    /// Mutable counterpart of `Index`: 0 is the value, `1..=N` the stored coefficients.
    ///
    /// # Panics
    /// Panics if `index > N`.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        match index {
            0 => &mut self.value,
            k if k <= N => &mut self.deriv[k - 1],
            _ => panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N),
        }
    }
}
impl<const N: usize> Neg for Jet<N> {
type Output = Self;
fn neg(self) -> Self::Output {
let mut z = self;
z.value = -z.value;
for d in z.deriv.iter_mut() {
*d = -*d;
}
z
}
}
impl<const N: usize> Add<Jet<N>> for Jet<N> {
    type Output = Self;
    /// Coefficient-wise sum.
    fn add(self, rhs: Jet<N>) -> Self::Output {
        let mut total = self;
        total.value += rhs.value;
        total
            .deriv
            .iter_mut()
            .zip(rhs.deriv.iter())
            .for_each(|(lhs_c, &rhs_c)| *lhs_c += rhs_c);
        total
    }
}
impl<const N: usize> Sub<Jet<N>> for Jet<N> {
    type Output = Self;
    /// Coefficient-wise difference `self - rhs`.
    fn sub(self, rhs: Jet<N>) -> Self::Output {
        let mut diff = self;
        diff.value -= rhs.value;
        for (lhs_c, &rhs_c) in diff.deriv.iter_mut().zip(rhs.deriv.iter()) {
            *lhs_c -= rhs_c;
        }
        diff
    }
}
impl<const N: usize> Mul<Jet<N>> for Jet<N> {
    type Output = Self;
    /// Truncated Cauchy product: `(a*b)_n = sum_{k=0..=n} a_k b_{n-k}` for `n` in `0..=N`.
    fn mul(self, rhs: Jet<N>) -> Self::Output {
        let mut product = Self::zero();
        for order in 0..=N {
            let c = (0..=order).fold(0.0f64, |acc, i| {
                acc + self.coeff(i) * rhs.coeff(order - i)
            });
            product.set_coeff(order, c);
        }
        product
    }
}
impl<const N: usize> Div<Jet<N>> for Jet<N> {
    type Output = Self;
    /// Truncated series quotient `self / rhs`.
    ///
    /// From `a = b z`: `z_0 = a_0 / b_0` and
    /// `z_n = (a_n - sum_{k=1..=n} b_k z_{n-k}) / b_0`.
    ///
    /// Each step divides by `b_0` rather than multiplying by a precomputed
    /// `1 / b_0`. The rounded reciprocal makes the primal value disagree with
    /// plain `f64` division (e.g. `49 * (1/49) == 0.9999999999999999`, while
    /// `49 / 49 == 1`).
    fn div(self, rhs: Jet<N>) -> Self::Output {
        let b0 = rhs.coeff(0);
        let mut z = Self::zero();
        z.set_coeff(0, self.coeff(0) / b0);
        for n in 1..=N {
            let mut s = 0.0f64;
            for k in 1..=n {
                s += rhs.coeff(k) * z.coeff(n - k);
            }
            z.set_coeff(n, (self.coeff(n) - s) / b0);
        }
        z
    }
}
impl<const N: usize> Add<f64> for Jet<N> {
    type Output = Self;
    /// Shifts the value by `rhs`; derivatives are unchanged.
    fn add(self, rhs: f64) -> Self::Output {
        Self {
            value: self.value + rhs,
            ..self
        }
    }
}
impl<const N: usize> Sub<f64> for Jet<N> {
    type Output = Self;
    /// Shifts the value down by `rhs`; derivatives are unchanged.
    fn sub(self, rhs: f64) -> Self::Output {
        Self {
            value: self.value - rhs,
            ..self
        }
    }
}
impl<const N: usize> Mul<f64> for Jet<N> {
    type Output = Self;
    /// Scales every coefficient by `rhs`.
    fn mul(self, rhs: f64) -> Self::Output {
        let mut scaled = self;
        scaled.value *= rhs;
        scaled.deriv.iter_mut().for_each(|c| *c *= rhs);
        scaled
    }
}
impl<const N: usize> Div<f64> for Jet<N> {
    type Output = Self;
    /// Divides every coefficient by the scalar `rhs`.
    ///
    /// Uses true division instead of multiplying by `1.0 / rhs`. The rounded
    /// reciprocal would make the primal value disagree with plain `f64`
    /// division (e.g. `49.0 * (1.0 / 49.0) == 0.9999999999999999`).
    fn div(self, rhs: f64) -> Self::Output {
        let mut z = self;
        z.value /= rhs;
        for d in z.deriv.iter_mut() {
            *d /= rhs;
        }
        z
    }
}
impl<const N: usize> Add<Jet<N>> for f64 {
type Output = Jet<N>;
fn add(self, rhs: Jet<N>) -> Self::Output {
let mut z = rhs;
z.value += self;
z
}
}
impl<const N: usize> Sub<Jet<N>> for f64 {
type Output = Jet<N>;
fn sub(self, rhs: Jet<N>) -> Self::Output {
let mut z = -rhs;
z.value += self;
z
}
}
impl<const N: usize> Mul<Jet<N>> for f64 {
    type Output = Jet<N>;
    /// `scalar * jet`, delegated to `jet * scalar`.
    fn mul(self, rhs: Jet<N>) -> Self::Output {
        <Jet<N> as Mul<f64>>::mul(rhs, self)
    }
}
impl<const N: usize> Div<Jet<N>> for f64 {
    type Output = Jet<N>;
    /// `scalar / jet`: lifts the scalar to a constant jet and uses jet division.
    fn div(self, rhs: Jet<N>) -> Self::Output {
        let numerator: Jet<N> = Jet::constant(self);
        numerator / rhs
    }
}
impl<const N: usize> ExpLogOps for Jet<N> {
    type Float = f64;

    /// `exp(a)`: from `z' = a' z`, `n z_n = sum_{k=1..=n} k a_k z_{n-k}`.
    fn exp(&self) -> Self {
        let mut out = Self::constant(self.value.exp());
        for n in 1..=N {
            let acc = (1..=n).fold(0.0f64, |s, k| {
                s + (k as f64) * self.coeff(k) * out.coeff(n - k)
            });
            out.set_coeff(n, acc / (n as f64));
        }
        out
    }

    /// Natural log: from `a z' = a'`,
    /// `z_n = (a_n - (1/n) sum_{k=1..n-1} k z_k a_{n-k}) / a_0`.
    fn ln(&self) -> Self {
        let a0 = self.value;
        let inv_a0 = 1.0 / a0;
        let mut out = Self::constant(a0.ln());
        for n in 1..=N {
            let acc = (1..n).fold(0.0f64, |s, k| {
                s + (k as f64) * out.coeff(k) * self.coeff(n - k)
            });
            out.set_coeff(n, inv_a0 * (self.coeff(n) - acc / (n as f64)));
        }
        out
    }

    /// Logarithm in an arbitrary base: every coefficient of `ln` divided by `ln(base)`.
    fn log(&self, base: f64) -> Self {
        let ln_base = base.ln();
        let natural = self.ln();
        let mut out = Self::zero();
        for k in 0..=N {
            out.set_coeff(k, natural.coeff(k) / ln_base);
        }
        out
    }

    /// Base-2 logarithm.
    fn log2(&self) -> Self {
        self.log(2.0)
    }

    /// Base-10 logarithm.
    fn log10(&self) -> Self {
        self.log(10.0)
    }
}
impl<const N: usize> PowOps for Jet<N> {
    type Float = f64;

    /// Integer power `self^n`.
    ///
    /// Uses square-and-multiply, so the number of jet products is O(log |n|)
    /// rather than O(|n|). `n = i32::MIN` would otherwise take ~2^31 products.
    /// Negative `n` returns `1 / self^|n|`, and `n == 0` returns the constant
    /// jet `1`.
    fn powi(&self, n: i32) -> Self {
        if n == 0 {
            return Self::constant(1.0);
        }
        let mut remaining = n.unsigned_abs();
        let mut square = *self;
        // `None` until the first set bit is consumed, so we never multiply by a
        // synthetic constant 1 (0 * inf in the Cauchy product would yield NaN).
        let mut acc: Option<Self> = None;
        loop {
            if remaining & 1 == 1 {
                acc = Some(match acc {
                    Some(partial) => partial * square,
                    None => square,
                });
            }
            remaining >>= 1;
            if remaining == 0 {
                break;
            }
            square = square * square;
        }
        let result = acc.expect("n != 0, so at least one bit of |n| is set");
        if n < 0 {
            Self::constant(1.0) / result
        } else {
            result
        }
    }

    /// Real power `self^f`.
    ///
    /// An integral `f` within `i32` range is delegated to `powi`, which stays
    /// finite for zero and negative bases. Otherwise the coefficients come
    /// from `a z' = f a' z` (with `a = self`, `z = a^f`):
    /// `z_n = 1/(n a_0) * sum_{k=1..=n} (f k - (n - k)) a_k z_{n-k}`.
    /// This form never evaluates `ln(a_0)`, whereas `exp(f * ln(a))` gives
    /// NaN for every negative base, even `(-2)^2`.
    fn powf(&self, f: f64) -> Self {
        if f.fract() == 0.0 && f.abs() <= i32::MAX as f64 {
            return self.powi(f as i32);
        }
        let a0 = self.coeff(0);
        let mut z = Self::constant(a0.powf(f));
        for n in 1..=N {
            let n_f = n as f64;
            let mut s = 0.0f64;
            for k in 1..=n {
                let weight = f * (k as f64) - (n - k) as f64;
                s += weight * self.coeff(k) * z.coeff(n - k);
            }
            z.set_coeff(n, s / (n_f * a0));
        }
        z
    }

    /// Jet-valued exponent: `exp(rhs * ln(self))` (requires a positive base).
    fn pow(&self, rhs: Self) -> Self {
        (self.ln() * rhs).exp()
    }

    /// Square root. From `z^2 = a`:
    /// `z_n = (a_n - sum_{k=1..n-1} z_k z_{n-k}) / (2 z_0)`.
    fn sqrt(&self) -> Self {
        let a0 = self.coeff(0);
        let z0 = a0.sqrt();
        let inv_2z0 = 1.0 / (2.0 * z0);
        let mut z = Self::zero();
        z.set_coeff(0, z0);
        for n in 1..=N {
            let mut s = 0.0f64;
            for k in 1..n {
                s += z.coeff(k) * z.coeff(n - k);
            }
            z.set_coeff(n, inv_2z0 * (self.coeff(n) - s));
        }
        z
    }
}
impl<const N: usize> TrigOps for Jet<N> {
    /// Sine and cosine together, because each recurrence needs the other:
    /// `s' = a' c` and `c' = -a' s`.
    fn sin_cos(&self) -> (Self, Self) {
        let a0 = self.value;
        let mut s_jet = Self::constant(a0.sin());
        let mut c_jet = Self::constant(a0.cos());
        for n in 1..=N {
            let n_f = n as f64;
            let (mut s_acc, mut c_acc) = (0.0f64, 0.0f64);
            for k in 1..=n {
                let weighted = (k as f64) * self.coeff(k);
                s_acc += weighted * c_jet.coeff(n - k);
                c_acc += weighted * s_jet.coeff(n - k);
            }
            s_jet.set_coeff(n, s_acc / n_f);
            c_jet.set_coeff(n, -c_acc / n_f);
        }
        (s_jet, c_jet)
    }

    fn sin(&self) -> Self {
        let (s_jet, _) = self.sin_cos();
        s_jet
    }

    fn cos(&self) -> Self {
        let (_, c_jet) = self.sin_cos();
        c_jet
    }

    /// `tan = sin / cos` via jet division.
    fn tan(&self) -> Self {
        let (s_jet, c_jet) = self.sin_cos();
        s_jet / c_jet
    }

    fn sinh(&self) -> Self {
        let (sh, _) = self.sinh_cosh();
        sh
    }

    fn cosh(&self) -> Self {
        let (_, ch) = self.sinh_cosh();
        ch
    }

    /// `tanh = sinh / cosh` via jet division.
    fn tanh(&self) -> Self {
        let (sh, ch) = self.sinh_cosh();
        sh / ch
    }

    /// `asin`, integrating `z' = a' / sqrt(1 - a^2)`.
    fn asin(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = one / (one - self.powi(2)).sqrt();
        self.integrate_derivative(self.value.asin(), &slope)
    }

    /// `acos`, integrating `z' = -a' / sqrt(1 - a^2)`.
    fn acos(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = -(one / (one - self.powi(2)).sqrt());
        self.integrate_derivative(self.value.acos(), &slope)
    }

    /// `atan`, integrating `z' = a' / (1 + a^2)`.
    fn atan(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = one / (one + self.powi(2));
        self.integrate_derivative(self.value.atan(), &slope)
    }

    /// `asinh`, integrating `z' = a' / sqrt(1 + a^2)`.
    fn asinh(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = one / (one + self.powi(2)).sqrt();
        self.integrate_derivative(self.value.asinh(), &slope)
    }

    /// `acosh`, integrating `z' = a' / sqrt(a^2 - 1)`.
    fn acosh(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = one / (self.powi(2) - one).sqrt();
        self.integrate_derivative(self.value.acosh(), &slope)
    }

    /// `atanh`, integrating `z' = a' / (1 - a^2)`.
    fn atanh(&self) -> Self {
        let one = Self::constant(1.0);
        let slope = one / (one - self.powi(2));
        self.integrate_derivative(self.value.atanh(), &slope)
    }
}
impl<const N: usize> Jet<N> {
    /// Hyperbolic sine and cosine together, via `s' = a' c` and `c' = a' s`.
    pub fn sinh_cosh(&self) -> (Self, Self) {
        let a0 = self.value;
        let mut sh = Self::constant(a0.sinh());
        let mut ch = Self::constant(a0.cosh());
        for n in 1..=N {
            let n_f = n as f64;
            let (mut sh_acc, mut ch_acc) = (0.0f64, 0.0f64);
            for k in 1..=n {
                let weighted = (k as f64) * self.coeff(k);
                sh_acc += weighted * ch.coeff(n - k);
                ch_acc += weighted * sh.coeff(n - k);
            }
            sh.set_coeff(n, sh_acc / n_f);
            ch.set_coeff(n, ch_acc / n_f);
        }
        (sh, ch)
    }

    /// Builds `z` with `z_0 = z0` and `z' = a' q` (where `a = self`):
    /// `n z_n = sum_{k=1..=n} k a_k q_{n-k}`.
    fn integrate_derivative(&self, z0: f64, q: &Self) -> Self {
        let mut z = Self::constant(z0);
        for n in 1..=N {
            let acc = (1..=n).fold(0.0f64, |s, k| {
                s + (k as f64) * self.coeff(k) * q.coeff(n - k)
            });
            z.set_coeff(n, acc / (n as f64));
        }
        z
    }
}
/// Wraps a function of jets so it can be evaluated through `StableFn`,
/// returning the value or a derivative selected by `grad_level`.
pub struct ADFn<F> {
/// The wrapped function.
f: Box<F>,
/// What `call_stable(f64)` returns: 0 = value, 1 = first derivative,
/// 2 = second derivative (`grad` refuses to go above 2).
grad_level: usize,
}
impl<F> ADFn<F> {
    /// Wraps `f` at grad level 0, so evaluating the wrapper yields the function value.
    ///
    /// Does not require `F: Clone`; only `grad` needs to duplicate the function.
    pub fn new(f: F) -> Self {
        Self {
            f: Box::new(f),
            grad_level: 0,
        }
    }
}

impl<F: Clone> ADFn<F> {
    /// Returns a copy of this wrapper that evaluates one derivative order higher.
    ///
    /// # Panics
    /// Panics if the wrapper is already at grad level 2.
    pub fn grad(&self) -> Self {
        assert!(self.grad_level < 2, "Higher order AD is not allowed");
        ADFn {
            f: self.f.clone(),
            grad_level: self.grad_level + 1,
        }
    }
}
impl<F: Fn(Jet<2>) -> Jet<2>> StableFn<f64> for ADFn<F> {
    type Output = f64;
    /// Evaluates at `target`. Level 0 returns the value at a constant jet;
    /// levels 1 and 2 seed `dx = 1` and return the first or second derivative.
    fn call_stable(&self, target: f64) -> f64 {
        let level = self.grad_level;
        if level == 0 {
            return (self.f)(Jet::<2>::constant(target)).value();
        }
        let seeded = Jet::<2>::new(target, [1.0, 0.0]);
        match level {
            1 => (self.f)(seeded).dx(),
            2 => (self.f)(seeded).ddx(),
            _ => unreachable!("grad_level > 2 is not allowed"),
        }
    }
}
impl<F: Fn(Jet<2>) -> Jet<2>> StableFn<Jet<2>> for ADFn<F> {
    type Output = Jet<2>;
    /// Applies the wrapped function directly to a caller-supplied jet.
    fn call_stable(&self, target: Jet<2>) -> Jet<2> {
        let func: &F = &self.f;
        func(target)
    }
}
impl<F: Fn(Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<Vec<f64>> for ADFn<F> {
    type Output = Vec<f64>;
    /// Lifts each input to a constant jet, applies the function, and returns
    /// the primal values of the outputs.
    fn call_stable(&self, target: Vec<f64>) -> Vec<f64> {
        let inputs: Vec<Jet<1>> = target.into_iter().map(Jet::<1>::constant).collect();
        let outputs = (self.f)(inputs);
        outputs.iter().map(Jet::<1>::value).collect()
    }
}
impl<F: Fn(Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<Vec<Jet<1>>> for ADFn<F> {
    type Output = Vec<Jet<1>>;
    /// Applies the wrapped function directly to caller-supplied jets.
    fn call_stable(&self, target: Vec<Jet<1>>) -> Vec<Jet<1>> {
        let func: &F = &self.f;
        func(target)
    }
}
impl<'a, F: Fn(&Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<&'a Vec<f64>> for ADFn<F> {
    type Output = Vec<f64>;
    /// Borrowing variant: lifts each input to a constant jet, applies the
    /// function, and returns the primal values of the outputs.
    fn call_stable(&self, target: &'a Vec<f64>) -> Vec<f64> {
        let lifted: Vec<Jet<1>> = target.iter().copied().map(Jet::<1>::constant).collect();
        let func: &F = &self.f;
        func(&lifted).iter().map(Jet::<1>::value).collect()
    }
}
impl<'a, F: Fn(&Vec<Jet<1>>) -> Vec<Jet<1>>> StableFn<&'a Vec<Jet<1>>> for ADFn<F> {
    type Output = Vec<Jet<1>>;
    /// Borrowing variant: applies the wrapped function to caller-supplied jets.
    fn call_stable(&self, target: &'a Vec<Jet<1>>) -> Vec<Jet<1>> {
        let func: &F = &self.f;
        func(target)
    }
}
/// Conversions between vectors of plain `f64`s and first-order jets.
pub trait JetVec {
/// Converts to `Vec<Jet<1>>` (the `Vec<f64>` impl lifts values to constant jets).
fn to_jet_vec(&self) -> Vec<Jet<1>>;
/// Extracts the primal values as `Vec<f64>`.
fn to_f64_vec(&self) -> Vec<f64>;
}
impl JetVec for Vec<f64> {
    /// Lifts every number to a constant first-order jet.
    fn to_jet_vec(&self) -> Vec<Jet<1>> {
        self.iter().copied().map(Jet::<1>::constant).collect()
    }
    /// Already plain numbers: returns a copy.
    fn to_f64_vec(&self) -> Vec<f64> {
        Clone::clone(self)
    }
}
impl JetVec for Vec<Jet<1>> {
fn to_jet_vec(&self) -> Vec<Jet<1>> {
self.clone()
}
fn to_f64_vec(&self) -> Vec<f64> {
self.iter().map(|j| j.value()).collect()
}
}
impl FPVector for Vec<Jet<1>> {
    type Scalar = Jet<1>;

    /// Applies `f` to every element.
    fn fmap<F>(&self, f: F) -> Self
    where
        F: Fn(Self::Scalar) -> Self::Scalar,
    {
        self.iter().map(|&x| f(x)).collect()
    }

    /// Left fold of every element onto `init` using `f`.
    fn reduce<F, T>(&self, init: T, f: F) -> Self::Scalar
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
        T: Into<Self::Scalar>,
    {
        self.iter().fold(init.into(), |acc, &x| f(acc, x))
    }

    /// Pairwise combination with `other`; stops at the shorter length.
    fn zip_with<F>(&self, f: F, other: &Self) -> Self
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
    {
        self.iter()
            .zip(other.iter())
            .map(|(&x, &y)| f(x, y))
            .collect()
    }

    /// Keeps the elements for which `f` returns `true`.
    fn filter<F>(&self, f: F) -> Self
    where
        F: Fn(Self::Scalar) -> bool,
    {
        self.iter().filter(|&&x| f(x)).cloned().collect()
    }

    /// First `n` elements (or fewer).
    fn take(&self, n: usize) -> Self {
        self.iter().take(n).cloned().collect()
    }

    /// All but the first `n` elements.
    fn skip(&self, n: usize) -> Self {
        self.iter().skip(n).cloned().collect()
    }

    /// Sum of all elements; the constant `0` for an empty vector.
    ///
    /// The first element seeds the fold and only the remaining elements are
    /// added. Seeding `reduce` with `self[0]` and then folding over the whole
    /// vector would count the first element twice.
    fn sum(&self) -> Self::Scalar {
        let mut rest = self.iter().copied();
        match rest.next() {
            Some(first) => rest.fold(first, |acc, x| acc + x),
            None => Jet::<1>::constant(0.0),
        }
    }

    /// Product of all elements; the constant `1` for an empty vector.
    /// The first element is counted exactly once, as in `sum`.
    fn prod(&self) -> Self::Scalar {
        let mut rest = self.iter().copied();
        match rest.next() {
            Some(first) => rest.fold(first, |acc, x| acc * x),
            None => Jet::<1>::constant(1.0),
        }
    }
}
// `Vector` for `Vec<Jet<1>>`: each method forwards to a sugar helper
// (`add_v` / `sub_v` / `mul_s`), presumably provided by the `VecOps` impl for
// this type — confirm against `crate::traits::sugar`.
impl Vector for Vec<Jet<1>> {
type Scalar = Jet<1>;
/// Vector addition; delegates to `add_v`.
fn add_vec(&self, rhs: &Self) -> Self {
self.add_v(rhs)
}
/// Vector subtraction `self - rhs`; delegates to `sub_v`.
fn sub_vec(&self, rhs: &Self) -> Self {
self.sub_v(rhs)
}
/// Scaling by the jet `rhs`; delegates to `mul_s`.
fn mul_scalar(&self, rhs: Self::Scalar) -> Self {
self.mul_s(rhs)
}
}
// Empty impl: `Vec<Jet<1>>` relies entirely on `VecOps`' provided methods
// (presumably the source of `add_v`/`sub_v`/`mul_s` used by `Vector` — confirm).
impl VecOps for Vec<Jet<1>> {}
/// Second-order jet alias, used by the `AD0`/`AD1`/`AD2` constructors and `ADVec`.
pub type AD = Jet<2>;
/// Second-order jet for a constant `x`: both derivative slots are zero.
// Upper-case name presumably kept for compatibility with an older `AD` API — confirm.
#[inline]
#[allow(non_snake_case)]
pub fn AD0(x: f64) -> Jet<2> {
    Jet::<2>::new(x, [0.0, 0.0])
}
#[inline]
#[allow(non_snake_case)]
pub fn AD1(x: f64, dx: f64) -> Jet<2> {
Jet::<2>::new(x, [dx, 0.0])
}
/// Second-order jet from raw derivatives; `ddx` is stored normalized as `ddx / 2!`.
// Upper-case name presumably kept for compatibility with an older `AD` API — confirm.
#[inline]
#[allow(non_snake_case)]
pub fn AD2(x: f64, dx: f64, ddx: f64) -> Jet<2> {
    let second = ddx / 2.0;
    Jet::<2>::new(x, [dx, second])
}
/// Conversions for vectors of second-order jets (`AD`).
pub trait ADVec: JetVec {
    /// Converts to `Vec<AD>`.
    fn to_ad_vec(&self) -> Vec<AD>;

    /// Primal values as `Vec<f64>`; forwards to `JetVec::to_f64_vec`.
    fn to_f64_vec_compat(&self) -> Vec<f64> {
        <Self as JetVec>::to_f64_vec(self)
    }
}
impl ADVec for Vec<f64> {
    /// Lifts every number to a constant second-order jet.
    fn to_ad_vec(&self) -> Vec<AD> {
        self.iter().copied().map(Jet::<2>::constant).collect()
    }
}
impl ADVec for Vec<AD> {
    /// Already second-order jets: returns a copy.
    fn to_ad_vec(&self) -> Vec<AD> {
        self.iter().copied().collect()
    }
}
impl JetVec for Vec<AD> {
fn to_jet_vec(&self) -> Vec<Jet<1>> {
self.iter()
.map(|j| Jet::<1>::new(j.value(), [j.dx()]))
.collect()
}
fn to_f64_vec(&self) -> Vec<f64> {
self.iter().map(|j| j.value()).collect()
}
}
impl FPVector for Vec<AD> {
    type Scalar = AD;

    /// Applies `f` to every element.
    fn fmap<F>(&self, f: F) -> Self
    where
        F: Fn(Self::Scalar) -> Self::Scalar,
    {
        self.iter().map(|&x| f(x)).collect()
    }

    /// Left fold of every element onto `init` using `f`.
    fn reduce<F, T>(&self, init: T, f: F) -> Self::Scalar
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
        T: Into<Self::Scalar>,
    {
        self.iter().fold(init.into(), |acc, &x| f(acc, x))
    }

    /// Pairwise combination with `other`; stops at the shorter length.
    fn zip_with<F>(&self, f: F, other: &Self) -> Self
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
    {
        self.iter()
            .zip(other.iter())
            .map(|(&x, &y)| f(x, y))
            .collect()
    }

    /// Keeps the elements for which `f` returns `true`.
    fn filter<F>(&self, f: F) -> Self
    where
        F: Fn(Self::Scalar) -> bool,
    {
        self.iter().filter(|&&x| f(x)).cloned().collect()
    }

    /// First `n` elements (or fewer).
    fn take(&self, n: usize) -> Self {
        self.iter().take(n).cloned().collect()
    }

    /// All but the first `n` elements.
    fn skip(&self, n: usize) -> Self {
        self.iter().skip(n).cloned().collect()
    }

    /// Sum of all elements; the constant `0` for an empty vector.
    ///
    /// The first element seeds the fold and only the remaining elements are
    /// added. Seeding `reduce` with `self[0]` and then folding over the whole
    /// vector would count the first element twice.
    fn sum(&self) -> Self::Scalar {
        let mut rest = self.iter().copied();
        match rest.next() {
            Some(first) => rest.fold(first, |acc, x| acc + x),
            None => Jet::<2>::constant(0.0),
        }
    }

    /// Product of all elements; the constant `1` for an empty vector.
    /// The first element is counted exactly once, as in `sum`.
    fn prod(&self) -> Self::Scalar {
        let mut rest = self.iter().copied();
        match rest.next() {
            Some(first) => rest.fold(first, |acc, x| acc * x),
            None => Jet::<2>::constant(1.0),
        }
    }
}
// `Vector` for `Vec<AD>`: each method forwards to a sugar helper
// (`add_v` / `sub_v` / `mul_s`), presumably provided by the `VecOps` impl for
// this type — confirm against `crate::traits::sugar`.
impl Vector for Vec<AD> {
type Scalar = AD;
/// Vector addition; delegates to `add_v`.
fn add_vec(&self, rhs: &Self) -> Self {
self.add_v(rhs)
}
/// Vector subtraction `self - rhs`; delegates to `sub_v`.
fn sub_vec(&self, rhs: &Self) -> Self {
self.sub_v(rhs)
}
/// Scaling by the jet `rhs`; delegates to `mul_s`.
fn mul_scalar(&self, rhs: Self::Scalar) -> Self {
self.mul_s(rhs)
}
}
// Empty impl: `Vec<AD>` relies entirely on `VecOps`' provided methods
// (presumably the source of `add_v`/`sub_v`/`mul_s` used by `Vector` — confirm).
impl VecOps for Vec<AD> {}