/// Scalar algebra shared by the jet types in this module, so that one generic
/// expression written against this trait can be evaluated as an [`Order1`],
/// [`Order2`], [`OneSeed`], [`TwoSeed`], or a `jet_tower` tower.
///
/// `K` is the number of seeded primary variables; `variable(x, axis)` seeds one
/// of them.
///
/// Implementors provide the ring operations and [`JetScalar::compose_unary`].
/// The elementary functions below are default methods: each builds the
/// univariate derivative stack `[f, f', f'', f''', f'''']` at the current value
/// and hands it to `compose_unary`.
pub trait JetScalar<const K: usize>: Copy {
    /// A jet with primal value `c` and every derivative channel zero.
    fn constant(c: f64) -> Self;
    /// A jet with primal value `x`, seeded as the independent variable `axis`.
    fn variable(x: f64, axis: usize) -> Self;
    /// The primal (zeroth-order) value.
    fn value(&self) -> f64;
    /// Sum `self + o`.
    fn add(&self, o: &Self) -> Self;
    /// Difference `self - o`.
    fn sub(&self, o: &Self) -> Self;
    /// Product `self * o`.
    fn mul(&self, o: &Self) -> Self;
    /// Negation `-self`.
    fn neg(&self) -> Self;
    /// Multiply every channel by the plain scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Apply a univariate function `f`, given its derivative stack at the
    /// current value: `d = [f(u), f'(u), f''(u), f'''(u), f''''(u)]`.
    /// Lower-order implementors ignore the trailing entries they do not need.
    fn compose_unary(&self, d: [f64; 5]) -> Self;
    /// Evaluate `stack_fn` at `self.value()` and compose with the result;
    /// exactly `self.compose_unary(stack_fn(self.value()))`.
    fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(stack_fn(self.value()))
    }
    /// `exp(u)`: every derivative equals `e^u`.
    fn exp(&self) -> Self {
        let e = self.value().exp();
        self.compose_unary([e, e, e, e, e])
    }
    /// `sqrt(u)`: the derivatives of `u^(1/2)`, written in terms of
    /// `s = sqrt(u)`:
    /// `1/2 u^(-1/2)`, `-1/4 u^(-3/2)`, `3/8 u^(-5/2)`, `-15/16 u^(-7/2)`.
    /// The derivative entries are non-finite for `u <= 0`.
    fn sqrt(&self) -> Self {
        let u = self.value();
        let s = u.sqrt();
        self.compose_unary([
            s,
            0.5 / s,
            -0.25 / (u * s),
            0.375 / (u * u * s),
            -0.9375 / (u * u * u * s),
        ])
    }
    /// `ln(u)`: derivatives `1/u`, `-1/u^2`, `2/u^3`, `-6/u^4`.
    fn ln(&self) -> Self {
        let u = self.value();
        let r = 1.0 / u;
        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }
    /// `1/u`: derivatives `-1/u^2`, `2/u^3`, `-6/u^4`, `24/u^5`.
    fn recip(&self) -> Self {
        let r = 1.0 / self.value();
        let r2 = r * r;
        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
    }
    /// `u^a` for a constant exponent `a`. The k-th derivative is
    /// `a (a-1) … (a-k+1) u^(a-k)`.
    fn powf(&self, a: f64) -> Self {
        let u = self.value();
        self.compose_unary([
            u.powf(a),
            a * u.powf(a - 1.0),
            a * (a - 1.0) * u.powf(a - 2.0),
            a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
            a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
        ])
    }
    /// Log-gamma. The derivative stack comes from
    /// `crate::jet_tower::ln_gamma_derivative_stack` (presumably
    /// `[lnΓ, ψ, ψ', ψ'', ψ''']` — confirm there).
    fn ln_gamma(&self) -> Self {
        self.compose_unary(crate::jet_tower::ln_gamma_derivative_stack(self.value()))
    }
    /// Digamma. The derivative stack comes from
    /// `crate::jet_tower::digamma_derivative_stack`.
    fn digamma(&self) -> Self {
        self.compose_unary(crate::jet_tower::digamma_derivative_stack(self.value()))
    }
}
impl<const K: usize> std::ops::Add for Order2<K> {
type Output = Self;
#[inline]
fn add(self, o: Self) -> Self {
Order2(self.0 + o.0)
}
}
impl<const K: usize> std::ops::Add<f64> for Order2<K> {
type Output = Self;
#[inline]
fn add(self, c: f64) -> Self {
Order2(self.0 + c)
}
}
impl<const K: usize> std::ops::Sub for Order2<K> {
type Output = Self;
#[inline]
fn sub(self, o: Self) -> Self {
Order2(self.0 + o.0.scale(-1.0))
}
}
impl<const K: usize> std::ops::Sub<f64> for Order2<K> {
type Output = Self;
#[inline]
fn sub(self, c: f64) -> Self {
Order2(self.0 + (-c))
}
}
impl<const K: usize> std::ops::Mul for Order2<K> {
type Output = Self;
#[inline]
fn mul(self, o: Self) -> Self {
Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
}
}
impl<const K: usize> std::ops::Mul<f64> for Order2<K> {
type Output = Self;
#[inline]
fn mul(self, c: f64) -> Self {
Order2(self.0.scale(c))
}
}
impl<const K: usize> std::ops::Neg for Order2<K> {
type Output = Self;
#[inline]
fn neg(self) -> Self {
Order2(self.0.scale(-1.0))
}
}
/// Solve the scalar implicit equation `f(a) = 0` with a fixed number of
/// fixed-slope (chord-style) steps `a <- a - f(a) * inv_fa`, carrying jet
/// derivatives through every step.
///
/// * `a0` — starting value. The iteration begins from `S::constant(a0)`, so any
///   derivative content of the result enters only through `f`.
/// * `inv_fa` — reciprocal slope used unchanged for every step (presumably
///   `1 / f'(a*)` computed once by the caller — confirm at call sites).
/// * `iters` — number of steps. `0` returns `S::constant(a0)` as-is.
pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
    a0: f64,
    inv_fa: f64,
    iters: usize,
    f: impl Fn(&S) -> S,
) -> S {
    (0..iters).fold(S::constant(a0), |current, _| {
        let step = f(&current).scale(inv_fa);
        current.sub(&step)
    })
}
/// Second-order jet in `K` variables: value, gradient and Hessian, stored in a
/// `crate::jet_tower::Tower2`.
#[derive(Clone, Copy, Debug)]
pub struct Order2<const K: usize>(pub crate::jet_tower::Tower2<K>);
impl<const K: usize> Order2<K> {
    /// Gradient channel, copied out of the inner tower.
    #[inline]
    pub fn g(&self) -> [f64; K] {
        let Self(tower) = self;
        tower.g
    }
    /// Hessian channel, copied out of the inner tower.
    #[inline]
    pub fn h(&self) -> [[f64; K]; K] {
        let Self(tower) = self;
        tower.h
    }
}
/// [`JetScalar`] for [`Order2`]. The ring operations go through the
/// `std::ops` impls for `Order2`, which perform the same `Tower2` calls.
impl<const K: usize> JetScalar<K> for Order2<K> {
    fn constant(c: f64) -> Self {
        Self(crate::jet_tower::Tower2::constant(c))
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self(crate::jet_tower::Tower2::variable(x, axis))
    }
    fn value(&self) -> f64 {
        let Self(tower) = self;
        tower.v
    }
    fn add(&self, o: &Self) -> Self {
        *self + *o
    }
    fn sub(&self, o: &Self) -> Self {
        *self - *o
    }
    fn mul(&self, o: &Self) -> Self {
        *self * *o
    }
    fn neg(&self) -> Self {
        -*self
    }
    fn scale(&self, s: f64) -> Self {
        *self * s
    }
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        // Second order needs only f, f', f''.
        let [f0, f1, f2, _, _] = d;
        Self(self.0.compose_unary([f0, f1, f2]))
    }
}
/// Numeric "lane" type underlying the lane-generic jets (`Order2Lane`,
/// `OneSeedLane`, `TwoSeedLane`). Implemented below for a plain `f64` (one lane)
/// and for `wide::f64x4` (four lanes).
pub trait Lane: Copy {
    /// Broadcast `x` into every lane.
    fn splat(x: f64) -> Self;
    /// Lane-wise addition.
    fn add(self, o: Self) -> Self;
    /// Lane-wise subtraction.
    fn sub(self, o: Self) -> Self;
    /// Lane-wise multiplication.
    fn mul(self, o: Self) -> Self;
    /// Read lane `i` as an `f64`.
    fn lane(self, i: usize) -> f64;
    /// Evaluate a scalar 3-entry derivative stack on every lane and return it
    /// transposed: entry `k` of the result holds `stack(lane)[k]` for each lane.
    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3];
    /// Five-entry variant of [`Lane::unary3`].
    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5];
    /// Arbitrary-length variant of [`Lane::unary3`].
    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N];
}
/// One-lane implementation: a bare `f64` is its own single lane, so the
/// lane-generic jets can also run unbatched.
impl Lane for f64 {
    #[inline]
    fn splat(value: f64) -> Self {
        value
    }
    #[inline]
    fn add(self, rhs: Self) -> Self {
        self + rhs
    }
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }
    /// There is only one lane, so the index is ignored.
    #[inline]
    fn lane(self, _index: usize) -> f64 {
        self
    }
    // The fixed-width helpers are the generic one at N = 3 and N = 5.
    #[inline]
    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
        self.unary_with(stack)
    }
    #[inline]
    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
        self.unary_with(stack)
    }
    #[inline]
    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
        stack(self)
    }
}
/// Four-lane implementation on `wide::f64x4`. Arithmetic uses the SIMD
/// operators. The derivative-stack helpers call the scalar `stack` once per
/// lane (lane 0 first) and transpose the results back into vectors.
impl Lane for wide::f64x4 {
    #[inline]
    fn splat(value: f64) -> Self {
        wide::f64x4::splat(value)
    }
    #[inline]
    fn add(self, rhs: Self) -> Self {
        self + rhs
    }
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }
    /// Panics if `i >= 4`.
    #[inline]
    fn lane(self, i: usize) -> f64 {
        let lanes = self.to_array();
        lanes[i]
    }
    // Both fixed-width helpers share the generic transposition below.
    #[inline]
    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
        self.unary_with(stack)
    }
    #[inline]
    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
        self.unary_with(stack)
    }
    #[inline]
    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
        let inputs = self.to_array();
        // columns[k][lane] = stack(inputs[lane])[k]
        let mut columns = [[0.0_f64; 4]; N];
        for (lane_idx, &x) in inputs.iter().enumerate() {
            for (column, value) in columns.iter_mut().zip(stack(x)) {
                column[lane_idx] = value;
            }
        }
        columns.map(|column| wide::f64x4::new(column))
    }
}
/// Lane-generic second-order jet in `K` variables: value `v`, gradient `g` and
/// Hessian `h`, with each entry an `L` (one `f64`, or an `f64x4` holding four
/// independent rows).
#[derive(Clone, Copy, Debug)]
pub struct Order2Lane<L: Lane, const K: usize> {
    /// Primal value.
    pub v: L,
    /// Gradient.
    pub g: [L; K],
    /// Hessian, stored as a full `K x K` array.
    pub h: [[L; K]; K],
}
/// Four-row SIMD batch of order-2 jets; see `Order2Batch::lane` to extract one row.
pub type Order2Batch<const K: usize> = Order2Lane<wide::f64x4, K>;
/// Lane-generic arithmetic for [`Order2Lane`].
///
/// The batch tests in this file require the operations they exercise to be
/// bit-identical to the scalar `Order2` path. The operand order and
/// accumulation order of each expression below is therefore deliberate; do not
/// rearrange it.
impl<L: Lane, const K: usize> Order2Lane<L, K> {
    /// Jet with value `c` and zero gradient and Hessian.
    #[inline]
    pub fn constant(c: L) -> Self {
        Order2Lane {
            v: c,
            g: [L::splat(0.0); K],
            h: [[L::splat(0.0); K]; K],
        }
    }
    /// Jet with value `value`, seeded as variable `axis` (gradient entry `axis`
    /// set to one). Panics if `axis >= K`.
    #[inline]
    pub fn variable(value: L, axis: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[axis] = L::splat(1.0);
        out
    }
    /// Channel-wise sum.
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = self.v.add(o.v);
        for i in 0..K {
            out.g[i] = self.g[i].add(o.g[i]);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
            }
        }
        out
    }
    /// Multiply every channel by `s`, broadcast to all lanes.
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        let sl = L::splat(s);
        let mut out = *self;
        out.v = self.v.mul(sl);
        for i in 0..K {
            out.g[i] = self.g[i].mul(sl);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].mul(sl);
            }
        }
        out
    }
    /// `self + (-1 * o)`, the same formulation the scalar `Order2` subtraction uses.
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        self.add(&o.scale(-1.0))
    }
    /// Negation, as scaling by `-1`.
    #[inline]
    pub fn neg(&self) -> Self {
        self.scale(-1.0)
    }
    /// Product rule up to second order:
    /// `v = a.v*b.v`,
    /// `g_i = a.v*b.g_i + a.g_i*b.v`,
    /// `h_ij = a.v*b.h_ij + a.g_i*b.g_j + a.g_j*b.g_i + a.h_ij*b.v`.
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::constant(a.v.mul(b.v));
        for i in 0..K {
            out.g[i] = a.v.mul(b.g[i]).add(a.g[i].mul(b.v));
        }
        for i in 0..K {
            for j in 0..K {
                out.h[i][j] = a
                    .v
                    .mul(b.h[i][j])
                    .add(a.g[i].mul(b.g[j]))
                    .add(a.g[j].mul(b.g[i]))
                    .add(a.h[i][j].mul(b.v));
            }
        }
        out
    }
    /// Chain rule for a univariate `f` with `d = [f(v), f'(v), f''(v)]`:
    /// `g_i = f' * g_i`, `h_ij = f' * h_ij + f'' * g_i * g_j`.
    #[inline]
    pub fn compose_unary(&self, d: [L; 3]) -> Self {
        let mut out = Self::constant(d[0]);
        for i in 0..K {
            // The accumulation starts from an explicit zero instead of assigning
            // the product directly. Presumably this mirrors the scalar `Tower2`
            // accumulation for bit-identity; confirm before simplifying.
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(self.g[i]));
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = L::splat(0.0);
                acc = acc.add(d[1].mul(self.h[i][j]));
                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
                out.h[i][j] = acc;
            }
        }
        out
    }
    /// `exp`: stack `[e^v, e^v, e^v]`.
    #[inline]
    pub fn exp(&self) -> Self {
        let d = self.v.unary3(|u| {
            let e = u.exp();
            [e, e, e]
        });
        self.compose_unary(d)
    }
    /// `ln`: stack `[ln v, 1/v, -1/v^2]`.
    #[inline]
    pub fn ln(&self) -> Self {
        let d = self.v.unary3(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r]
        });
        self.compose_unary(d)
    }
    /// `sqrt`: stack `[s, 1/(2s), -1/(4 v s)]` with `s = sqrt(v)`.
    #[inline]
    pub fn sqrt(&self) -> Self {
        let d = self.v.unary3(|u| {
            let s = u.sqrt();
            [s, 0.5 / s, -0.25 / (u * s)]
        });
        self.compose_unary(d)
    }
    /// `1/v`: stack `[1/v, -1/v^2, 2/v^3]`.
    #[inline]
    pub fn recip(&self) -> Self {
        let d = self.v.unary3(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r]
        });
        self.compose_unary(d)
    }
    /// `v^a` for constant `a`: stack `[v^a, a v^(a-1), a(a-1) v^(a-2)]`.
    #[inline]
    pub fn powf(&self, a: f64) -> Self {
        let d = self.v.unary3(|u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
            ]
        });
        self.compose_unary(d)
    }
}
impl<const K: usize> Order2Batch<K> {
    /// Extract row `i` of the four-row batch as a scalar [`Order2`] jet.
    ///
    /// # Panics
    /// Panics if `i >= 4` (lane index out of range for `f64x4`).
    #[inline]
    #[must_use]
    pub fn lane(&self, i: usize) -> Order2<K> {
        let mut tower = crate::jet_tower::Tower2::<K>::constant(self.v.lane(i));
        tower.g = std::array::from_fn(|a| self.g[a].lane(i));
        tower.h = std::array::from_fn(|a| std::array::from_fn(|b| self.h[a][b].lane(i)));
        Order2(tower)
    }
}
/// First-order (gradient-only) jet in `K` variables.
#[derive(Clone, Copy, Debug)]
pub struct Order1<const K: usize> {
    /// Primal value.
    pub v: f64,
    /// Gradient with respect to the `K` seeded variables.
    pub g: [f64; K],
}
impl<const K: usize> Order1<K> {
    /// Copy of the gradient, matching the `Order2::g` accessor.
    #[inline]
    pub fn g(&self) -> [f64; K] {
        let Self { g, .. } = *self;
        g
    }
}
/// [`JetScalar`] for [`Order1`]: value plus gradient, so `compose_unary` reads
/// only `f` and `f'` from the derivative stack.
impl<const K: usize> JetScalar<K> for Order1<K> {
    fn constant(c: f64) -> Self {
        Self { v: c, g: [0.0; K] }
    }
    /// Panics if `axis >= K`.
    fn variable(x: f64, axis: usize) -> Self {
        let mut seeded = Self::constant(x);
        seeded.g[axis] = 1.0;
        seeded
    }
    fn value(&self) -> f64 {
        self.v
    }
    fn add(&self, o: &Self) -> Self {
        Self {
            v: self.v + o.v,
            g: std::array::from_fn(|i| self.g[i] + o.g[i]),
        }
    }
    /// Formed as `self + (-1) * o`.
    fn sub(&self, o: &Self) -> Self {
        self.add(&o.neg())
    }
    /// Product rule: `g_i = a.v * b.g_i + a.g_i * b.v`.
    fn mul(&self, o: &Self) -> Self {
        Self {
            v: self.v * o.v,
            g: std::array::from_fn(|i| self.v * o.g[i] + self.g[i] * o.v),
        }
    }
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }
    fn scale(&self, s: f64) -> Self {
        Self {
            v: self.v * s,
            g: self.g.map(|gi| gi * s),
        }
    }
    /// Chain rule: value `f(v)`, gradient `f'(v) * g`.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, ..] = d;
        Self {
            v: f0,
            g: self.g.map(|gi| f1 * gi),
        }
    }
}
#[derive(Clone, Copy, Debug)]
pub struct OneSeed<const K: usize> {
pub base: Order2<K>,
pub eps: Order2<K>,
}
impl<const K: usize> OneSeed<K> {
pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
OneSeed {
base: Order2::variable(x, axis),
eps: Order2::constant(u_axis),
}
}
pub fn contracted_third(&self) -> [[f64; K]; K] {
self.eps.h()
}
}
/// [`JetScalar`] for [`OneSeed`]: dual-number arithmetic (`ε² = 0`) whose
/// coefficients are [`Order2`] jets.
impl<const K: usize> JetScalar<K> for OneSeed<K> {
    fn constant(c: f64) -> Self {
        Self {
            base: Order2::constant(c),
            eps: Order2::constant(0.0),
        }
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self {
            base: Order2::variable(x, axis),
            eps: Order2::constant(0.0),
        }
    }
    fn value(&self) -> f64 {
        self.base.value()
    }
    fn add(&self, o: &Self) -> Self {
        let (lhs, rhs) = (self, o);
        Self {
            base: lhs.base.add(&rhs.base),
            eps: lhs.eps.add(&rhs.eps),
        }
    }
    fn sub(&self, o: &Self) -> Self {
        let (lhs, rhs) = (self, o);
        Self {
            base: lhs.base.sub(&rhs.base),
            eps: lhs.eps.sub(&rhs.eps),
        }
    }
    /// `(a + a'ε)(b + b'ε) = ab + (a·b' + a'·b)ε`; the `ε²` term vanishes.
    fn mul(&self, o: &Self) -> Self {
        let Self { base: a, eps: a_eps } = *self;
        let Self { base: b, eps: b_eps } = *o;
        Self {
            base: a.mul(&b),
            eps: a.mul(&b_eps).add(&a_eps.mul(&b)),
        }
    }
    fn neg(&self) -> Self {
        Self {
            base: self.base.neg(),
            eps: self.eps.neg(),
        }
    }
    fn scale(&self, s: f64) -> Self {
        Self {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
        }
    }
    /// `f(base + eps·ε) = f(base) + f'(base)·eps·ε`. `f'(base)` is the order-2
    /// composition of the stack shifted down one order; the repeated last
    /// entry is padding.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, f2, f3, f4] = d;
        let base = self.base.compose_unary([f0, f1, f2, f3, f4]);
        let fprime = self.base.compose_unary([f1, f2, f3, f4, f4]);
        Self {
            base,
            eps: fprime.mul(&self.eps),
        }
    }
}
/// Lane-generic counterpart of [`OneSeed`]: `base + eps·ε` with `ε² = 0`, where
/// both coefficients are [`Order2Lane`] jets.
#[derive(Clone, Copy, Debug)]
pub struct OneSeedLane<L: Lane, const K: usize> {
    /// Unperturbed order-2 jet.
    pub base: Order2Lane<L, K>,
    /// Coefficient of `ε`.
    pub eps: Order2Lane<L, K>,
}
/// Four-row SIMD batch of one-seed jets; see `OneSeedBatch::lane`.
pub type OneSeedBatch<const K: usize> = OneSeedLane<wide::f64x4, K>;
/// Lane-generic arithmetic for [`OneSeedLane`]. These are the same dual-number
/// rules as the scalar `OneSeed` (`ε² = 0`), applied to `Order2Lane` channels.
/// The derivative-stack formulas repeat the `JetScalar` defaults expression
/// for expression; keep them in step.
impl<L: Lane, const K: usize> OneSeedLane<L, K> {
    /// Constant `c` with a zero `eps` channel.
    #[inline]
    pub fn constant(c: L) -> Self {
        OneSeedLane {
            base: Order2Lane::constant(c),
            eps: Order2Lane::constant(L::splat(0.0)),
        }
    }
    /// Variable `axis` at `value` with a zero `eps` channel.
    #[inline]
    pub fn variable(value: L, axis: usize) -> Self {
        OneSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(L::splat(0.0)),
        }
    }
    /// Variable `axis` at `value` whose `eps` channel holds the direction
    /// component `u_axis`; the lane-wise counterpart of `OneSeed::seed_direction`.
    #[inline]
    pub fn seed_direction(value: L, axis: usize, u_axis: L) -> Self {
        OneSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(u_axis),
        }
    }
    /// Hessian of the `eps` channel; the lane-wise counterpart of
    /// `OneSeed::contracted_third`.
    #[inline]
    #[must_use]
    pub fn contracted_third(&self) -> [[L; K]; K] {
        self.eps.h
    }
    /// Channel-wise sum.
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.add(&o.base),
            eps: self.eps.add(&o.eps),
        }
    }
    /// Channel-wise difference.
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.sub(&o.base),
            eps: self.eps.sub(&o.eps),
        }
    }
    /// `(a + a'ε)(b + b'ε) = ab + (a·b' + a'·b)ε`; the `ε²` term vanishes.
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.mul(&o.base),
            eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
        }
    }
    /// Channel-wise negation.
    #[inline]
    pub fn neg(&self) -> Self {
        OneSeedLane {
            base: self.base.neg(),
            eps: self.eps.neg(),
        }
    }
    /// Multiply both channels by `s`.
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        OneSeedLane {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
        }
    }
    /// `f(base + eps·ε) = f(base) + f'(base)·eps·ε`. `f'(base)` is the order-2
    /// composition of the stack shifted down one order. Only `d[0..=3]` are
    /// read; `d[4]` is unused at this order.
    #[inline]
    pub fn compose_unary(&self, d: [L; 5]) -> Self {
        let base = self.base.compose_unary([d[0], d[1], d[2]]);
        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
        let eps = fprime.mul(&self.eps);
        OneSeedLane { base, eps }
    }
    /// `exp`: every derivative equals `e^v`.
    #[inline]
    pub fn exp(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        });
        self.compose_unary(d)
    }
    /// `ln`: derivatives `1/v, -1/v^2, 2/v^3, -6/v^4`.
    #[inline]
    pub fn ln(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        });
        self.compose_unary(d)
    }
    /// `sqrt`: derivatives of `v^(1/2)` in terms of `s = sqrt(v)`.
    #[inline]
    pub fn sqrt(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let s = u.sqrt();
            [
                s,
                0.5 / s,
                -0.25 / (u * s),
                0.375 / (u * u * s),
                -0.9375 / (u * u * u * s),
            ]
        });
        self.compose_unary(d)
    }
    /// `1/v`: derivatives `-1/v^2, 2/v^3, -6/v^4, 24/v^5`.
    #[inline]
    pub fn recip(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        });
        self.compose_unary(d)
    }
    /// `v^a` for constant `a`: falling-factorial power rule.
    #[inline]
    pub fn powf(&self, a: f64) -> Self {
        let d = self.base.v.unary5(|u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        });
        self.compose_unary(d)
    }
    /// Log-gamma, with the stack from `crate::jet_tower::ln_gamma_derivative_stack`.
    #[inline]
    pub fn ln_gamma(&self) -> Self {
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
        self.compose_unary(d)
    }
    /// Digamma, with the stack from `crate::jet_tower::digamma_derivative_stack`.
    #[inline]
    pub fn digamma(&self) -> Self {
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::digamma_derivative_stack);
        self.compose_unary(d)
    }
}
impl<const K: usize> OneSeedBatch<K> {
    /// Extract row `i` of the four-row batch as a scalar [`OneSeed`].
    ///
    /// # Panics
    /// Panics if `i >= 4`.
    #[inline]
    #[must_use]
    pub fn lane(&self, i: usize) -> OneSeed<K> {
        let Self { base, eps } = self;
        OneSeed {
            base: base.lane(i),
            eps: eps.lane(i),
        }
    }
}
#[derive(Clone, Copy, Debug)]
pub struct TwoSeed<const K: usize> {
pub base: Order2<K>,
pub eps: Order2<K>,
pub del: Order2<K>,
pub eps_del: Order2<K>,
}
impl<const K: usize> TwoSeed<K> {
pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
TwoSeed {
base: Order2::variable(x, axis),
eps: Order2::constant(u_axis),
del: Order2::constant(v_axis),
eps_del: Order2::constant(0.0),
}
}
pub fn contracted_fourth(&self) -> [[f64; K]; K] {
self.eps_del.h()
}
}
/// [`JetScalar`] for [`TwoSeed`]: arithmetic in `ε`, `δ` with `ε² = δ² = 0`,
/// whose coefficients are [`Order2`] jets.
impl<const K: usize> JetScalar<K> for TwoSeed<K> {
    fn constant(c: f64) -> Self {
        TwoSeed {
            base: Order2::constant(c),
            eps: Order2::constant(0.0),
            del: Order2::constant(0.0),
            eps_del: Order2::constant(0.0),
        }
    }
    fn variable(x: f64, axis: usize) -> Self {
        TwoSeed {
            base: Order2::variable(x, axis),
            eps: Order2::constant(0.0),
            del: Order2::constant(0.0),
            eps_del: Order2::constant(0.0),
        }
    }
    fn value(&self) -> f64 {
        self.base.value()
    }
    fn add(&self, o: &Self) -> Self {
        TwoSeed {
            base: self.base.add(&o.base),
            eps: self.eps.add(&o.eps),
            del: self.del.add(&o.del),
            eps_del: self.eps_del.add(&o.eps_del),
        }
    }
    fn sub(&self, o: &Self) -> Self {
        TwoSeed {
            base: self.base.sub(&o.base),
            eps: self.eps.sub(&o.eps),
            del: self.del.sub(&o.del),
            eps_del: self.eps_del.sub(&o.eps_del),
        }
    }
    /// Product, keeping every term that is at most linear in each of `ε` and
    /// `δ`. The `εδ` coefficient collects
    /// `a·b_εδ + a_ε·b_δ + a_δ·b_ε + a_εδ·b`.
    fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let base = a.base.mul(&b.base);
        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
        let eps_del = a
            .base
            .mul(&b.eps_del)
            .add(&a.eps.mul(&b.del))
            .add(&a.del.mul(&b.eps))
            .add(&a.eps_del.mul(&b.base));
        TwoSeed {
            base,
            eps,
            del,
            eps_del,
        }
    }
    fn neg(&self) -> Self {
        TwoSeed {
            base: self.base.neg(),
            eps: self.eps.neg(),
            del: self.del.neg(),
            eps_del: self.eps_del.neg(),
        }
    }
    fn scale(&self, s: f64) -> Self {
        TwoSeed {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
            del: self.del.scale(s),
            eps_del: self.eps_del.scale(s),
        }
    }
    /// Chain rule through the perturbation channels:
    /// `f(base + eps·ε + del·δ + eps_del·εδ)
    ///  = f + f'·eps·ε + f'·del·δ + (f''·eps·del + f'·eps_del)·εδ`,
    /// with `f`, `f'`, `f''` evaluated at `base` as order-2 jets.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
        // f'(base) and f''(base): the stack shifted down one and two orders.
        // The repeated trailing d[4] is padding that Order2 does not read.
        let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
        let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]);
        let eps = fprime.mul(&self.eps);
        let del = fprime.mul(&self.del);
        let eps_del = fsecond
            .mul(&self.eps)
            .mul(&self.del)
            .add(&fprime.mul(&self.eps_del));
        TwoSeed {
            base,
            eps,
            del,
            eps_del,
        }
    }
}
/// Lane-generic counterpart of [`TwoSeed`]:
/// `base + eps·ε + del·δ + eps_del·εδ` with `ε² = δ² = 0`, where every
/// coefficient is an [`Order2Lane`].
#[derive(Clone, Copy, Debug)]
pub struct TwoSeedLane<L: Lane, const K: usize> {
    /// Unperturbed order-2 jet.
    pub base: Order2Lane<L, K>,
    /// Coefficient of `ε`.
    pub eps: Order2Lane<L, K>,
    /// Coefficient of `δ`.
    pub del: Order2Lane<L, K>,
    /// Coefficient of `εδ`.
    pub eps_del: Order2Lane<L, K>,
}
/// Four-row SIMD batch of two-seed jets; see `TwoSeedBatch::lane`.
pub type TwoSeedBatch<const K: usize> = TwoSeedLane<wide::f64x4, K>;
/// Lane-generic arithmetic for [`TwoSeedLane`]. These are the same rules as the
/// scalar `TwoSeed` (`ε² = δ² = 0`), applied to `Order2Lane` channels. The
/// derivative-stack formulas repeat the `JetScalar` defaults expression for
/// expression; keep them in step.
impl<L: Lane, const K: usize> TwoSeedLane<L, K> {
    /// Constant `c` with all perturbation channels zero.
    #[inline]
    pub fn constant(c: L) -> Self {
        let z = Order2Lane::constant(L::splat(0.0));
        TwoSeedLane {
            base: Order2Lane::constant(c),
            eps: z,
            del: z,
            eps_del: z,
        }
    }
    /// Variable `axis` at `value` with all perturbation channels zero.
    #[inline]
    pub fn variable(value: L, axis: usize) -> Self {
        let z = Order2Lane::constant(L::splat(0.0));
        TwoSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: z,
            del: z,
            eps_del: z,
        }
    }
    /// Variable `axis` at `value` with direction components `u_axis` (into
    /// `eps`) and `v_axis` (into `del`); the lane-wise counterpart of `TwoSeed::seed`.
    #[inline]
    pub fn seed(value: L, axis: usize, u_axis: L, v_axis: L) -> Self {
        TwoSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(u_axis),
            del: Order2Lane::constant(v_axis),
            eps_del: Order2Lane::constant(L::splat(0.0)),
        }
    }
    /// Hessian of the `eps_del` channel; the lane-wise counterpart of
    /// `TwoSeed::contracted_fourth`.
    #[inline]
    #[must_use]
    pub fn contracted_fourth(&self) -> [[L; K]; K] {
        self.eps_del.h
    }
    /// Channel-wise sum.
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        TwoSeedLane {
            base: self.base.add(&o.base),
            eps: self.eps.add(&o.eps),
            del: self.del.add(&o.del),
            eps_del: self.eps_del.add(&o.eps_del),
        }
    }
    /// Channel-wise difference.
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        TwoSeedLane {
            base: self.base.sub(&o.base),
            eps: self.eps.sub(&o.eps),
            del: self.del.sub(&o.del),
            eps_del: self.eps_del.sub(&o.eps_del),
        }
    }
    /// Product, keeping terms at most linear in each of `ε` and `δ`. The `εδ`
    /// coefficient is `a·b_εδ + a_ε·b_δ + a_δ·b_ε + a_εδ·b`.
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let base = a.base.mul(&b.base);
        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
        let eps_del = a
            .base
            .mul(&b.eps_del)
            .add(&a.eps.mul(&b.del))
            .add(&a.del.mul(&b.eps))
            .add(&a.eps_del.mul(&b.base));
        TwoSeedLane {
            base,
            eps,
            del,
            eps_del,
        }
    }
    /// Channel-wise negation.
    #[inline]
    pub fn neg(&self) -> Self {
        TwoSeedLane {
            base: self.base.neg(),
            eps: self.eps.neg(),
            del: self.del.neg(),
            eps_del: self.eps_del.neg(),
        }
    }
    /// Multiply every channel by `s`.
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        TwoSeedLane {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
            del: self.del.scale(s),
            eps_del: self.eps_del.scale(s),
        }
    }
    /// Chain rule: `f + f'·eps·ε + f'·del·δ + (f''·eps·del + f'·eps_del)·εδ`.
    /// `f'` and `f''` at `base` come from the stack shifted down one and two
    /// orders. All five entries of `d` are read.
    #[inline]
    pub fn compose_unary(&self, d: [L; 5]) -> Self {
        let base = self.base.compose_unary([d[0], d[1], d[2]]);
        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
        let fsecond = self.base.compose_unary([d[2], d[3], d[4]]);
        let eps = fprime.mul(&self.eps);
        let del = fprime.mul(&self.del);
        let eps_del = fsecond
            .mul(&self.eps)
            .mul(&self.del)
            .add(&fprime.mul(&self.eps_del));
        TwoSeedLane {
            base,
            eps,
            del,
            eps_del,
        }
    }
    /// `exp`: every derivative equals `e^v`.
    #[inline]
    pub fn exp(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        });
        self.compose_unary(d)
    }
    /// `ln`: derivatives `1/v, -1/v^2, 2/v^3, -6/v^4`.
    #[inline]
    pub fn ln(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        });
        self.compose_unary(d)
    }
    /// `sqrt`: derivatives of `v^(1/2)` in terms of `s = sqrt(v)`.
    #[inline]
    pub fn sqrt(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let s = u.sqrt();
            [
                s,
                0.5 / s,
                -0.25 / (u * s),
                0.375 / (u * u * s),
                -0.9375 / (u * u * u * s),
            ]
        });
        self.compose_unary(d)
    }
    /// `1/v`: derivatives `-1/v^2, 2/v^3, -6/v^4, 24/v^5`.
    #[inline]
    pub fn recip(&self) -> Self {
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        });
        self.compose_unary(d)
    }
    /// `v^a` for constant `a`: falling-factorial power rule.
    #[inline]
    pub fn powf(&self, a: f64) -> Self {
        let d = self.base.v.unary5(|u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        });
        self.compose_unary(d)
    }
    /// Log-gamma, with the stack from `crate::jet_tower::ln_gamma_derivative_stack`.
    #[inline]
    pub fn ln_gamma(&self) -> Self {
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
        self.compose_unary(d)
    }
    /// Digamma, with the stack from `crate::jet_tower::digamma_derivative_stack`.
    #[inline]
    pub fn digamma(&self) -> Self {
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::digamma_derivative_stack);
        self.compose_unary(d)
    }
}
impl<const K: usize> TwoSeedBatch<K> {
    /// Extract row `i` of the four-row batch as a scalar [`TwoSeed`].
    ///
    /// # Panics
    /// Panics if `i >= 4`.
    #[inline]
    #[must_use]
    pub fn lane(&self, i: usize) -> TwoSeed<K> {
        let Self {
            base,
            eps,
            del,
            eps_del,
        } = self;
        TwoSeed {
            base: base.lane(i),
            eps: eps.lane(i),
            del: del.lane(i),
            eps_del: eps_del.lane(i),
        }
    }
}
/// Exposes the third-order `jet_tower::Tower3` through [`JetScalar`].
///
/// Each method forwards to the tower's own associated functions (presumably
/// inherent `Tower3` items, which `Self::…` resolves before trait items).
/// `compose_unary` forwards only `f … f'''`.
impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower3<K> {
    fn constant(c: f64) -> Self {
        Self::constant(c)
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self::variable(x, axis)
    }
    fn value(&self) -> f64 {
        self.v
    }
    fn add(&self, o: &Self) -> Self {
        *self + *o
    }
    /// Formed as `self + (-1) * o`.
    fn sub(&self, o: &Self) -> Self {
        let negated = o.scale(-1.0);
        *self + negated
    }
    fn mul(&self, o: &Self) -> Self {
        Self::mul(self, o)
    }
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }
    fn scale(&self, s: f64) -> Self {
        Self::scale(self, s)
    }
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, f2, f3, _] = d;
        Self::compose_unary(self, [f0, f1, f2, f3])
    }
}
/// Exposes the fourth-order `jet_tower::Tower4` through [`JetScalar`]. Each
/// method forwards to the tower's own associated functions (presumably inherent
/// `Tower4` items, which `Self::…` resolves before trait items). The whole
/// five-entry stack is passed to `compose_unary`.
impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower4<K> {
    fn constant(c: f64) -> Self {
        Self::constant(c)
    }
    fn variable(x: f64, axis: usize) -> Self {
        Self::variable(x, axis)
    }
    fn value(&self) -> f64 {
        self.v
    }
    fn add(&self, o: &Self) -> Self {
        let (lhs, rhs) = (*self, *o);
        lhs + rhs
    }
    fn sub(&self, o: &Self) -> Self {
        let (lhs, rhs) = (*self, *o);
        lhs - rhs
    }
    fn mul(&self, o: &Self) -> Self {
        Self::mul(self, o)
    }
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }
    fn scale(&self, s: f64) -> Self {
        Self::scale(self, s)
    }
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        Self::compose_unary(self, d)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet_tower::{RowNllProgram, Tower4, evaluate_program};

    /// Shared test expression, written once against `JetScalar`:
    /// `(exp(p0 * p1) + 2) * sqrt(p0^2 + 1) - 0.5 * p1^2`.
    fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
        let g = p[0].mul(&p[1]).exp();
        let inner = g.add(&S::constant(2.0));
        let radic = p[0].mul(&p[0]).add(&S::constant(1.0)).sqrt();
        inner.mul(&radic).sub(&p[1].mul(&p[1]).scale(0.5))
    }

    /// One-row program that evaluates `row_expr` on `Tower4` as the reference.
    struct ExprProgram {
        p: [f64; 2],
    }

    impl RowNllProgram<2> for ExprProgram {
        fn n_rows(&self) -> usize {
            1
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(self.p)
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(row_expr(p))
        }
    }

    const SEED: [f64; 2] = [0.37, -0.81];
    const U: [f64; 2] = [0.6, -0.2];
    const V: [f64; 2] = [-0.4, 1.1];
    const TOL: f64 = 1e-10;

    /// Mixed absolute/relative closeness check with tolerance `TOL`.
    fn close(a: f64, b: f64, label: &str) {
        let band = TOL + TOL * a.abs().max(b.abs());
        assert!(
            (a - b).abs() <= band,
            "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
        );
    }

    /// Reference fourth-order tower of `row_expr` at `SEED`.
    fn tower() -> Tower4<2> {
        evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
    }

    #[test]
    fn order2_matches_tower_value_grad_hessian() {
        let t = tower();
        let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        for a in 0..2 {
            close(s.g()[a], t.g[a], &format!("grad[{a}]"));
            for b in 0..2 {
                close(s.h()[a][b], t.h[a][b], &format!("hess[{a}][{b}]"));
            }
        }
    }

    #[test]
    fn compose_unary_with_scalar_seam_bit_identical() {
        fn rand_unit(state: &mut u64) -> f64 {
            let mut x = *state;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            *state = x;
            2.0 * ((x >> 11) as f64 / ((1u64 << 53) as f64)) - 1.0
        }
        fn stack(u: f64) -> [f64; 5] {
            [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
        }
        fn run<const K: usize>(state: &mut u64, n: usize) -> usize {
            for _ in 0..n {
                let base = rand_unit(state);
                let mut s = Order2::<K>::variable(base, 0);
                for a in 1..K {
                    s = JetScalar::mul(&s, &Order2::<K>::variable(rand_unit(state), a));
                }
                let with = s.compose_unary_with(stack);
                let explicit = s.compose_unary(stack(s.value()));
                assert_eq!(with.value().to_bits(), explicit.value().to_bits(), "value");
                for a in 0..K {
                    assert_eq!(with.g()[a].to_bits(), explicit.g()[a].to_bits(), "g[{a}]");
                    for b in 0..K {
                        assert_eq!(
                            with.h()[a][b].to_bits(),
                            explicit.h()[a][b].to_bits(),
                            "h[{a}][{b}]"
                        );
                    }
                }
            }
            n
        }
        let mut st = 0x9e37_79b9_7f4a_7c15u64;
        let total = run::<2>(&mut st, 1100)
            + run::<3>(&mut st, 1100)
            + run::<4>(&mut st, 1100)
            + run::<9>(&mut st, 1100);
        assert_eq!(total, 4400);
    }

    #[test]
    fn one_seed_matches_tower_third_contracted() {
        let t = tower();
        let truth = t.third_contracted(&U);
        let vars: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        for a in 0..2 {
            for b in 0..2 {
                close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
            }
        }
        let third = s.contracted_third();
        for a in 0..2 {
            for b in 0..2 {
                close(third[a][b], truth[a][b], &format!("third[{a}][{b}]"));
            }
        }
    }

    #[test]
    fn two_seed_matches_tower_fourth_contracted() {
        let t = tower();
        let truth4 = t.fourth_contracted(&U, &V);
        let truth3_u = t.third_contracted(&U);
        let truth3_v = t.third_contracted(&V);
        let vars: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        for a in 0..2 {
            close(s.base.g()[a], t.g[a], &format!("grad[{a}]"));
            for b in 0..2 {
                close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
                close(
                    s.eps.h()[a][b],
                    truth3_u[a][b],
                    &format!("eps third_u[{a}][{b}]"),
                );
                close(
                    s.del.h()[a][b],
                    truth3_v[a][b],
                    &format!("del third_v[{a}][{b}]"),
                );
            }
        }
        let fourth = s.contracted_fourth();
        for a in 0..2 {
            for b in 0..2 {
                close(fourth[a][b], truth4[a][b], &format!("fourth[{a}][{b}]"));
            }
        }
    }

    /// Every jet type driven through the same generic `row_expr` must agree
    /// with the reference tower on its channels: order-2 value/grad/Hessian,
    /// third contracted with `U`, and fourth contracted with `U`, `V`.
    #[test]
    fn generic_program_seam_matches_tower_for_every_channel() {
        let t = tower();
        let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let so2 = row_expr(&o2);
        close(so2.value(), t.v, "seam order2 value");
        for a in 0..2 {
            close(so2.g()[a], t.g[a], &format!("seam order2 grad[{a}]"));
            for b in 0..2 {
                close(
                    so2.h()[a][b],
                    t.h[a][b],
                    &format!("seam order2 hess[{a}][{b}]"),
                );
            }
        }
        let os: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let third = row_expr(&os).contracted_third();
        let truth3 = t.third_contracted(&U);
        for a in 0..2 {
            for b in 0..2 {
                close(third[a][b], truth3[a][b], &format!("seam third[{a}][{b}]"));
            }
        }
        let ts: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let fourth = row_expr(&ts).contracted_fourth();
        let truth4 = t.fourth_contracted(&U, &V);
        for a in 0..2 {
            for b in 0..2 {
                close(
                    fourth[a][b],
                    truth4[a][b],
                    &format!("seam fourth[{a}][{b}]"),
                );
            }
        }
    }

    #[test]
    fn tower4_as_jetscalar_matches_program_tower_all_channels() {
        let t = tower();
        let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.v, t.v, "tower-jetscalar value");
        for a in 0..2 {
            close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
            for b in 0..2 {
                close(
                    s.h[a][b],
                    t.h[a][b],
                    &format!("tower-jetscalar hess[{a}][{b}]"),
                );
                for c in 0..2 {
                    close(
                        s.t3[a][b][c],
                        t.t3[a][b][c],
                        &format!("tower-jetscalar t3[{a}][{b}][{c}]"),
                    );
                    for d in 0..2 {
                        close(
                            s.t4[a][b][c][d],
                            t.t4[a][b][c][d],
                            &format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]"),
                        );
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod batch_tests {
use super::{
JetScalar, Lane, OneSeed, OneSeedBatch, OneSeedLane, Order2, Order2Batch, Order2Lane,
TwoSeed, TwoSeedBatch, TwoSeedLane,
};
/// Minimal algebra used by `row_expr`. It is implemented both by the scalar
/// jets (through `JetScalar`) and by the lane jets (through their inherent
/// methods), so the same expression runs on both paths.
trait RowAlg<const K: usize>: Copy {
    /// Constant `c`; lane implementations broadcast it to every lane.
    fn constant(c: f64) -> Self;
    fn add(&self, o: &Self) -> Self;
    fn sub(&self, o: &Self) -> Self;
    fn mul(&self, o: &Self) -> Self;
    fn scale(&self, s: f64) -> Self;
    fn exp(&self) -> Self;
    fn sqrt(&self) -> Self;
    fn recip(&self) -> Self;
}
// Scalar path: every operation forwards to the `JetScalar` implementation.
impl<const K: usize> RowAlg<K> for Order2<K> {
    fn constant(c: f64) -> Self {
        <Self as JetScalar<K>>::constant(c)
    }
    fn add(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::add(self, o)
    }
    fn sub(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::sub(self, o)
    }
    fn mul(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::mul(self, o)
    }
    fn scale(&self, s: f64) -> Self {
        <Self as JetScalar<K>>::scale(self, s)
    }
    fn exp(&self) -> Self {
        <Self as JetScalar<K>>::exp(self)
    }
    fn sqrt(&self) -> Self {
        <Self as JetScalar<K>>::sqrt(self)
    }
    fn recip(&self) -> Self {
        <Self as JetScalar<K>>::recip(self)
    }
}
// Lane path: forward to the inherent `Order2Lane` methods, which `Self::…`
// resolves before trait items. Constants are broadcast to every lane.
impl<L: Lane, const K: usize> RowAlg<K> for Order2Lane<L, K> {
    fn constant(c: f64) -> Self {
        Self::constant(L::splat(c))
    }
    fn add(&self, o: &Self) -> Self {
        Self::add(self, o)
    }
    fn sub(&self, o: &Self) -> Self {
        Self::sub(self, o)
    }
    fn mul(&self, o: &Self) -> Self {
        Self::mul(self, o)
    }
    fn scale(&self, s: f64) -> Self {
        Self::scale(self, s)
    }
    fn exp(&self) -> Self {
        Self::exp(self)
    }
    fn sqrt(&self) -> Self {
        Self::sqrt(self)
    }
    fn recip(&self) -> Self {
        Self::recip(self)
    }
}
/// Test expression over `K` primaries, built only from `RowAlg` operations:
/// `s = 0.3 + Σ_a (0.1 + 0.05·a)·p_a·p_{(a+1) mod K}`, then
/// `(e^s·sqrt(s² + 1) - 0.5·s) · 1/(e^s + 2)`.
/// The evaluation order is fixed because the callers compare results bit-for-bit.
fn row_expr<const K: usize, A: RowAlg<K>>(p: &[A; K]) -> A {
    let mut s = A::constant(0.3);
    for a in 0..K {
        // Cyclic neighbour, so every primary couples to the next one.
        let b = (a + 1) % K;
        s = s.add(&p[a].mul(&p[b]).scale(0.1 + 0.05 * a as f64));
    }
    let e = s.exp();
    let r = s.mul(&s).add(&A::constant(1.0)).sqrt();
    let denom = e.add(&A::constant(2.0));
    e.mul(&r).sub(&s.scale(0.5)).mul(&denom.recip())
}
/// Advance the xorshift64 generator (shift triple 13/7/17) in `state` and map
/// the new state to a uniform `f64` in `[-1, 1)`.
///
/// The top 53 bits of the state form a value in `[0, 1)`, which is then mapped
/// affinely onto `[-1, 1)`. A zero state is the generator's fixed point: it
/// stays zero and always yields `-1.0`, so seed with a non-zero value.
fn rand_unit(state: &mut u64) -> f64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    let u = (x >> 11) as f64 / ((1u64 << 53) as f64);
    2.0 * u - 1.0
}
/// For `batches` groups of four random rows, evaluate `row_expr` three ways:
/// scalar `Order2`, one-lane `Order2Lane<f64, K>`, and four-lane `Order2Batch`.
/// Assert that value, gradient and Hessian of the latter two match the scalar
/// result bit-for-bit. Returns the number of rows checked (`4 * batches`).
fn check_k<const K: usize>(state: &mut u64, batches: usize) -> usize {
    let mut verified_rows = 0usize;
    for _ in 0..batches {
        // Four rows of K primaries each, drawn from [-1, 1).
        let rows: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
        // Reference: the scalar Order2 path, row by row.
        let prod: [Order2<K>; 4] = std::array::from_fn(|r| {
            let p: [Order2<K>; K] = std::array::from_fn(|a| Order2::variable(rows[r][a], a));
            row_expr(&p)
        });
        // The lane-generic code instantiated with a single f64 lane.
        let scal: [Order2Lane<f64, K>; 4] = std::array::from_fn(|r| {
            let p: [Order2Lane<f64, K>; K] =
                std::array::from_fn(|a| Order2Lane::variable(rows[r][a], a));
            row_expr(&p)
        });
        // All four rows packed into f64x4 lanes (lane r = row r).
        let pbatch: [Order2Batch<K>; K] = std::array::from_fn(|a| {
            let packed =
                wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
            Order2Batch::variable(packed, a)
        });
        let batch = row_expr(&pbatch);
        for r in 0..4 {
            let g = prod[r].0;
            assert_eq!(scal[r].v.to_bits(), g.v.to_bits(), "K={K} scalar v");
            let lr = batch.lane(r).0;
            assert_eq!(lr.v.to_bits(), g.v.to_bits(), "K={K} batch lane {r} v");
            for a in 0..K {
                assert_eq!(
                    scal[r].g[a].to_bits(),
                    g.g[a].to_bits(),
                    "K={K} scalar g[{a}]"
                );
                assert_eq!(
                    lr.g[a].to_bits(),
                    g.g[a].to_bits(),
                    "K={K} batch lane {r} g[{a}]"
                );
                for b in 0..K {
                    assert_eq!(
                        scal[r].h[a][b].to_bits(),
                        g.h[a][b].to_bits(),
                        "K={K} scalar h[{a}][{b}]"
                    );
                    assert_eq!(
                        lr.h[a][b].to_bits(),
                        g.h[a][b].to_bits(),
                        "K={K} batch lane {r} h[{a}][{b}]"
                    );
                }
            }
            verified_rows += 1;
        }
    }
    verified_rows
}
/// Batch and one-lane Order2 evaluation must reproduce the scalar path exactly
/// for several variable counts, including a non-power-of-two `K = 9`.
#[test]
fn batch_lanes_bit_identical_to_scalar_per_row() {
    const BATCHES: usize = 2000;
    let mut state = 0x9E37_79B9_7F4A_7C15_u64;
    let verified = check_k::<2>(&mut state, BATCHES)
        + check_k::<3>(&mut state, BATCHES)
        + check_k::<4>(&mut state, BATCHES)
        + check_k::<9>(&mut state, BATCHES);
    assert_eq!(verified, 4 * BATCHES * 4, "every batch row must be verified");
}
impl<const K: usize> RowAlg<K> for OneSeed<K> {
fn constant(c: f64) -> Self {
<Self as JetScalar<K>>::constant(c)
}
fn add(&self, o: &Self) -> Self {
JetScalar::add(self, o)
}
fn sub(&self, o: &Self) -> Self {
JetScalar::sub(self, o)
}
fn mul(&self, o: &Self) -> Self {
JetScalar::mul(self, o)
}
fn scale(&self, s: f64) -> Self {
JetScalar::scale(self, s)
}
fn exp(&self) -> Self {
JetScalar::exp(self)
}
fn sqrt(&self) -> Self {
JetScalar::sqrt(self)
}
fn recip(&self) -> Self {
JetScalar::recip(self)
}
}
/// `RowAlg` adapter for the lane-generic `OneSeedLane<L, K>`.
///
/// Scalar constants are broadcast to every lane with `L::splat`. All other
/// operations forward to the type's own `OneSeedLane` associated functions.
impl<L: Lane, const K: usize> RowAlg<K> for OneSeedLane<L, K> {
    fn constant(c: f64) -> Self {
        let broadcast = L::splat(c);
        Self::constant(broadcast)
    }
    fn add(&self, o: &Self) -> Self {
        Self::add(self, o)
    }
    fn sub(&self, o: &Self) -> Self {
        Self::sub(self, o)
    }
    fn mul(&self, o: &Self) -> Self {
        Self::mul(self, o)
    }
    fn scale(&self, s: f64) -> Self {
        Self::scale(self, s)
    }
    fn exp(&self) -> Self {
        Self::exp(self)
    }
    fn sqrt(&self) -> Self {
        Self::sqrt(self)
    }
    fn recip(&self) -> Self {
        Self::recip(self)
    }
}
/// `RowAlg` adapter for `TwoSeed<K>`.
///
/// Each operation forwards unchanged to the type's `JetScalar<K>`
/// implementation. The fully qualified form makes the target trait explicit.
impl<const K: usize> RowAlg<K> for TwoSeed<K> {
    fn constant(c: f64) -> Self {
        <Self as JetScalar<K>>::constant(c)
    }
    fn add(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::add(self, o)
    }
    fn sub(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::sub(self, o)
    }
    fn mul(&self, o: &Self) -> Self {
        <Self as JetScalar<K>>::mul(self, o)
    }
    fn scale(&self, s: f64) -> Self {
        <Self as JetScalar<K>>::scale(self, s)
    }
    fn exp(&self) -> Self {
        <Self as JetScalar<K>>::exp(self)
    }
    fn sqrt(&self) -> Self {
        <Self as JetScalar<K>>::sqrt(self)
    }
    fn recip(&self) -> Self {
        <Self as JetScalar<K>>::recip(self)
    }
}
/// `RowAlg` adapter for the lane-generic `TwoSeedLane<L, K>`.
///
/// Scalar constants are broadcast to every lane with `L::splat`. All other
/// operations forward to the type's own `TwoSeedLane` associated functions.
impl<L: Lane, const K: usize> RowAlg<K> for TwoSeedLane<L, K> {
    fn constant(c: f64) -> Self {
        let broadcast = L::splat(c);
        Self::constant(broadcast)
    }
    fn add(&self, o: &Self) -> Self {
        Self::add(self, o)
    }
    fn sub(&self, o: &Self) -> Self {
        Self::sub(self, o)
    }
    fn mul(&self, o: &Self) -> Self {
        Self::mul(self, o)
    }
    fn scale(&self, s: f64) -> Self {
        Self::scale(self, s)
    }
    fn exp(&self) -> Self {
        Self::exp(self)
    }
    fn sqrt(&self) -> Self {
        Self::sqrt(self)
    }
    fn recip(&self) -> Self {
        Self::recip(self)
    }
}
/// Cross-checks three evaluations of `row_expr` on `batches` random 4-row
/// batches of dimension `K`:
///
/// - the `OneSeed<K>` form, used as the reference;
/// - the generic lane code on a plain `f64` lane (`OneSeedLane<f64, K>`);
/// - the matching lane of the `OneSeedBatch<K>` SIMD evaluation.
///
/// For every row, the primal value and every entry of `contracted_third()`
/// must agree bit for bit. Returns the number of rows verified, which is
/// `4 * batches` when no assertion fires.
fn check_oneseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
    let mut rows_checked = 0;
    for _ in 0..batches {
        // Draw order is part of the test: all base points first, then all
        // seed directions, each in row-major order.
        let rows: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
        let u: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));

        let prod: [OneSeed<K>; 4] = std::array::from_fn(|r| {
            let seeded: [OneSeed<K>; K] =
                std::array::from_fn(|a| OneSeed::seed_direction(rows[r][a], a, u[r][a]));
            row_expr(&seeded)
        });
        let scal: [OneSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
            let seeded: [OneSeedLane<f64, K>; K] =
                std::array::from_fn(|a| OneSeedLane::seed_direction(rows[r][a], a, u[r][a]));
            row_expr(&seeded)
        });

        // Pack column `a` of a 4 x K matrix into one SIMD vector (lane r = row r).
        let column =
            |m: &[[f64; K]; 4], a: usize| wide::f64x4::new(std::array::from_fn(|r| m[r][a]));
        let seeded_batch: [OneSeedBatch<K>; K] = std::array::from_fn(|a| {
            OneSeedBatch::seed_direction(column(&rows, a), a, column(&u, a))
        });
        let batch = row_expr(&seeded_batch);

        for (r, (reference, scalar)) in prod.iter().zip(scal.iter()).enumerate() {
            let lane = batch.lane(r);
            let want = reference.contracted_third();
            let got_scal = scalar.contracted_third();
            let got_batch = lane.contracted_third();

            let want_value = reference.base.value().to_bits();
            assert_eq!(
                scalar.base.v.to_bits(),
                want_value,
                "OneSeed K={K} scalar value"
            );
            assert_eq!(
                lane.base.value().to_bits(),
                want_value,
                "OneSeed K={K} batch lane {r} value"
            );

            for a in 0..K {
                for b in 0..K {
                    let expected = want[a][b].to_bits();
                    assert_eq!(
                        got_scal[a][b].to_bits(),
                        expected,
                        "OneSeed K={K} scalar third[{a}][{b}]"
                    );
                    assert_eq!(
                        got_batch[a][b].to_bits(),
                        expected,
                        "OneSeed K={K} batch lane {r} third[{a}][{b}]"
                    );
                }
            }
            rows_checked += 1;
        }
    }
    rows_checked
}
/// Cross-checks three evaluations of `row_expr` on `batches` random 4-row
/// batches of dimension `K`:
///
/// - the `TwoSeed<K>` form, used as the reference;
/// - the generic lane code on a plain `f64` lane (`TwoSeedLane<f64, K>`);
/// - the matching lane of the `TwoSeedBatch<K>` SIMD evaluation.
///
/// For every row, the primal value and every entry of `contracted_fourth()`
/// must agree bit for bit. Returns the number of rows verified, which is
/// `4 * batches` when no assertion fires.
fn check_twoseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
    let mut rows_checked = 0;
    for _ in 0..batches {
        // Draw order is part of the test: base points, then the `u` seeds,
        // then the `v` seeds, each in row-major order.
        let rows: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
        let u: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
        let v: [[f64; K]; 4] =
            std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));

        let prod: [TwoSeed<K>; 4] = std::array::from_fn(|r| {
            let seeded: [TwoSeed<K>; K] =
                std::array::from_fn(|a| TwoSeed::seed(rows[r][a], a, u[r][a], v[r][a]));
            row_expr(&seeded)
        });
        let scal: [TwoSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
            let seeded: [TwoSeedLane<f64, K>; K] =
                std::array::from_fn(|a| TwoSeedLane::seed(rows[r][a], a, u[r][a], v[r][a]));
            row_expr(&seeded)
        });

        // Pack column `a` of a 4 x K matrix into one SIMD vector (lane r = row r).
        let column =
            |m: &[[f64; K]; 4], a: usize| wide::f64x4::new(std::array::from_fn(|r| m[r][a]));
        let seeded_batch: [TwoSeedBatch<K>; K] = std::array::from_fn(|a| {
            TwoSeedBatch::seed(column(&rows, a), a, column(&u, a), column(&v, a))
        });
        let batch = row_expr(&seeded_batch);

        for (r, (reference, scalar)) in prod.iter().zip(scal.iter()).enumerate() {
            let lane = batch.lane(r);
            let want = reference.contracted_fourth();
            let got_scal = scalar.contracted_fourth();
            let got_batch = lane.contracted_fourth();

            let want_value = reference.base.value().to_bits();
            assert_eq!(
                scalar.base.v.to_bits(),
                want_value,
                "TwoSeed K={K} scalar value"
            );
            assert_eq!(
                lane.base.value().to_bits(),
                want_value,
                "TwoSeed K={K} batch lane {r} value"
            );

            for a in 0..K {
                for b in 0..K {
                    let expected = want[a][b].to_bits();
                    assert_eq!(
                        got_scal[a][b].to_bits(),
                        expected,
                        "TwoSeed K={K} scalar fourth[{a}][{b}]"
                    );
                    assert_eq!(
                        got_batch[a][b].to_bits(),
                        expected,
                        "TwoSeed K={K} batch lane {r} fourth[{a}][{b}]"
                    );
                }
            }
            rows_checked += 1;
        }
    }
    rows_checked
}
#[test]
fn oneseed_lanes_contracted_third_bit_identical() {
    // Dimensions run in this order, sharing one RNG stream.
    let checks: [fn(&mut u64, usize) -> usize; 4] = [
        check_oneseed::<2>,
        check_oneseed::<3>,
        check_oneseed::<4>,
        check_oneseed::<9>,
    ];
    let mut state = 0x1234_5678_9ABC_DEF0_u64;
    let batches = 2000;
    let rows_checked: usize = checks
        .iter()
        .map(|check| check(&mut state, batches))
        .sum();
    // 4 dimensions x `batches` batches x 4 rows per batch.
    assert_eq!(rows_checked, 4 * batches * 4);
}
#[test]
fn twoseed_lanes_contracted_fourth_bit_identical() {
    // Dimensions run in this order, sharing one RNG stream.
    let checks: [fn(&mut u64, usize) -> usize; 4] = [
        check_twoseed::<2>,
        check_twoseed::<3>,
        check_twoseed::<4>,
        check_twoseed::<9>,
    ];
    let mut state = 0x0FED_CBA9_8765_4321_u64;
    let batches = 2000;
    let rows_checked: usize = checks
        .iter()
        .map(|check| check(&mut state, batches))
        .sum();
    // 4 dimensions x `batches` batches x 4 rows per batch.
    assert_eq!(rows_checked, 4 * batches * 4);
}
}
#[cfg(test)]
mod unit_tests {
    //! Small hand-checkable tests of the `Order1` / `Order2` jet scalars and of
    //! `filtered_implicit_solve_scalar`.
    //!
    //! Inputs are chosen so most expected values are exactly representable,
    //! which lets the arithmetic tests use `assert_eq!`. Transcendental results
    //! are compared against a small absolute tolerance.
    use super::{JetScalar, Order1, Order2, filtered_implicit_solve_scalar};

    #[test]
    fn order2_constant_has_zero_derivatives() {
        let s = Order2::<3>::constant(7.5);
        assert_eq!(s.value(), 7.5);
        for a in 0..3 {
            assert_eq!(s.g()[a], 0.0, "grad[{a}] should be zero");
            for b in 0..3 {
                assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
            }
        }
    }

    #[test]
    fn order2_variable_has_unit_gradient_in_seeded_slot() {
        let x = -2.5_f64;
        let s = Order2::<4>::variable(x, 2);
        assert_eq!(s.value(), x);
        for a in 0..4 {
            let expected_g = if a == 2 { 1.0 } else { 0.0 };
            assert_eq!(s.g()[a], expected_g, "grad[{a}]");
            for b in 0..4 {
                assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
            }
        }
    }

    #[test]
    fn order2_add_sub_roundtrip() {
        let p = Order2::<2>::variable(3.0, 0);
        let q = Order2::<2>::variable(2.0, 1);
        let pq = JetScalar::add(&p, &q);
        assert_eq!(pq.value(), 5.0, "add value");
        assert_eq!(pq.g()[0], 1.0, "add grad[0]");
        assert_eq!(pq.g()[1], 1.0, "add grad[1]");
        let back = JetScalar::sub(&pq, &q);
        // (p + q) - q must reproduce p in every channel, not only the gradient.
        assert_eq!(back.value(), p.value(), "value roundtrip");
        for a in 0..2 {
            assert_eq!(back.g()[a], p.g()[a], "grad[{a}] roundtrip");
            for b in 0..2 {
                assert_eq!(back.h()[a][b], p.h()[a][b], "hess[{a}][{b}] roundtrip");
            }
        }
    }

    #[test]
    fn order2_mul_satisfies_leibniz_rule() {
        let pv = 3.0_f64;
        let qv = -2.0_f64;
        let p = Order2::<2>::variable(pv, 0);
        let q = Order2::<2>::variable(qv, 1);
        let pq = JetScalar::mul(&p, &q);
        assert_eq!(pq.value(), pv * qv, "value = p·q");
        assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
        assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
        assert_eq!(pq.h()[0][1], 1.0, "∂²(p·q)/∂p∂q = 1");
        assert_eq!(pq.h()[1][0], 1.0, "∂²(p·q)/∂q∂p = 1 (symmetric)");
        assert_eq!(pq.h()[0][0], 0.0, "∂²(p·q)/∂p² = 0");
        assert_eq!(pq.h()[1][1], 0.0, "∂²(p·q)/∂q² = 0");
    }

    #[test]
    fn order2_scale_multiplies_all_channels() {
        let p = Order2::<2>::variable(4.0, 0);
        let s = 2.5_f64;
        let ps = JetScalar::scale(&p, s);
        assert_eq!(ps.value(), 4.0 * s);
        assert_eq!(ps.g()[0], 1.0 * s);
        assert_eq!(ps.g()[1], 0.0);

        // A seeded variable has an all-zero Hessian, so it cannot show whether
        // `scale` touches that channel. Use p² instead: value 16, grad [8, 0],
        // Hessian [[2, 0], [0, 0]]. All of these are exact in f64.
        let pp = JetScalar::mul(&p, &p);
        let pps = JetScalar::scale(&pp, s);
        assert_eq!(pps.value(), 16.0 * s, "scaled p² value");
        assert_eq!(pps.g()[0], 8.0 * s, "scaled p² grad[0]");
        assert_eq!(pps.g()[1], 0.0, "scaled p² grad[1]");
        assert_eq!(pps.h()[0][0], 2.0 * s, "scaled p² hess[0][0]");
        assert_eq!(pps.h()[0][1], 0.0, "scaled p² hess[0][1]");
        assert_eq!(pps.h()[1][0], 0.0, "scaled p² hess[1][0]");
        assert_eq!(pps.h()[1][1], 0.0, "scaled p² hess[1][1]");
    }

    #[test]
    fn order2_exp_derivative_stack_correct() {
        let p0 = 1.0_f64;
        let p = Order2::<1>::variable(p0, 0);
        let ep = JetScalar::exp(&p);
        let e = p0.exp();
        assert!((ep.value() - e).abs() < 1e-15, "exp value");
        assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p) = exp(p)");
        assert!((ep.h()[0][0] - e).abs() < 1e-15, "d²/dp² exp(p) = exp(p)");
    }

    #[test]
    fn order2_ln_derivative_stack_correct() {
        let p0 = 2.0_f64;
        let p = Order2::<1>::variable(p0, 0);
        let lnp = JetScalar::ln(&p);
        assert!((lnp.value() - p0.ln()).abs() < 1e-15, "ln value");
        assert!((lnp.g()[0] - 1.0 / p0).abs() < 1e-15, "d/dp ln(p) = 1/p");
        assert!((lnp.h()[0][0] - (-1.0 / (p0 * p0))).abs() < 1e-15, "d²/dp² ln(p) = -1/p²");
    }

    #[test]
    fn order2_exp_ln_roundtrip_at_value() {
        let p0 = 0.8_f64;
        let p = Order2::<1>::variable(p0, 0);
        let roundtrip = JetScalar::ln(&JetScalar::exp(&p));
        // ln(exp(p)) is the identity, so its first derivative is 1 and its
        // second derivative is 0. All three channels are checked.
        assert!((roundtrip.value() - p0).abs() < 1e-14, "ln(exp(p)) ≈ p");
        assert!((roundtrip.g()[0] - 1.0).abs() < 1e-14, "d/dp ln(exp(p)) ≈ 1");
        assert!(roundtrip.h()[0][0].abs() < 1e-14, "d²/dp² ln(exp(p)) ≈ 0");
    }

    #[test]
    fn order1_constant_has_zero_gradient() {
        let s = Order1::<3>::constant(-5.0);
        assert_eq!(s.value(), -5.0);
        for a in 0..3 {
            assert_eq!(s.g()[a], 0.0, "g[{a}] should be zero");
        }
    }

    #[test]
    fn order1_variable_has_unit_gradient_in_seeded_slot() {
        let s = Order1::<3>::variable(2.0, 1);
        assert_eq!(s.value(), 2.0);
        assert_eq!(s.g()[0], 0.0);
        assert_eq!(s.g()[1], 1.0);
        assert_eq!(s.g()[2], 0.0);
    }

    #[test]
    fn order1_mul_satisfies_product_rule() {
        let pv = 3.0_f64;
        let qv = -2.0_f64;
        let p = Order1::<2>::variable(pv, 0);
        let q = Order1::<2>::variable(qv, 1);
        let pq = JetScalar::mul(&p, &q);
        assert_eq!(pq.value(), pv * qv);
        assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
        assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
    }

    #[test]
    fn order1_exp_has_correct_value_and_gradient() {
        let p0 = 0.5_f64;
        let p = Order1::<2>::variable(p0, 0);
        let ep = JetScalar::exp(&p);
        let e = p0.exp();
        assert!((ep.value() - e).abs() < 1e-15, "exp value");
        assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p)");
        assert_eq!(ep.g()[1], 0.0, "irrelevant gradient slot is zero");
    }

    #[test]
    fn order1_and_order2_agree_on_value_and_gradient() {
        // Evaluate exp(p·q + p) with both jet orders. The shared channels
        // (value and gradient) must agree.
        let p0 = 1.3_f64;
        let q0 = -0.7_f64;
        let p1 = Order1::<2>::variable(p0, 0);
        let q1 = Order1::<2>::variable(q0, 1);
        let expr1 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p1, &q1), &p1));
        let p2 = Order2::<2>::variable(p0, 0);
        let q2 = Order2::<2>::variable(q0, 1);
        let expr2 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p2, &q2), &p2));
        assert!((expr1.value() - expr2.value()).abs() < 1e-14, "value mismatch");
        for a in 0..2 {
            assert!(
                (expr1.g()[a] - expr2.g()[a]).abs() < 1e-14,
                "gradient[{a}] mismatch"
            );
        }
    }

    #[test]
    fn filtered_implicit_solve_linear_constraint_gives_exact_jet() {
        // Residual a - theta = 0 has solution a = theta: value theta0,
        // gradient 1, Hessian 0.
        let theta0 = 3.0_f64;
        let theta = Order2::<1>::variable(theta0, 0);
        let a = filtered_implicit_solve_scalar::<1, Order2<1>>(
            theta0,
            1.0,
            2,
            |a_jet| JetScalar::sub(a_jet, &theta),
        );
        assert!((a.value() - theta0).abs() < 1e-14, "value = theta0");
        assert!((a.g()[0] - 1.0).abs() < 1e-14, "gradient = 1");
        assert!(a.h()[0][0].abs() < 1e-14, "hessian = 0");
    }

    #[test]
    fn filtered_implicit_solve_quadratic_constraint_matches_analytic_derivatives() {
        // Residual a² - theta = 0 has solution a = sqrt(theta).
        // `inv_fa` = 1 / (dF/da) = 1 / (2a) evaluated at the solution.
        let theta0 = 4.0_f64;
        let a0 = theta0.sqrt();
        let inv_fa = 1.0 / (2.0 * a0);
        let theta = Order2::<1>::variable(theta0, 0);
        let a = filtered_implicit_solve_scalar::<1, Order2<1>>(a0, inv_fa, 2, |a_jet| {
            let aa = JetScalar::mul(a_jet, a_jet);
            JetScalar::sub(&aa, &theta)
        });
        let tol = 1e-12;
        assert!((a.value() - a0).abs() < tol, "value = sqrt(theta0)");
        let expected_g = 0.5 / a0;
        assert!((a.g()[0] - expected_g).abs() < tol, "da/dtheta = 1/(2*sqrt)");
        let expected_h = -0.25 / (theta0 * a0);
        assert!((a.h()[0][0] - expected_h).abs() < tol, "d2a/dtheta2 = -1/(4*theta^1.5)");
    }
}