use crate::jet_algebra;
/// Fourth-order derivative tower over `K` inputs: value plus dense derivative tensors up to
/// order four.
///
/// Every index permutation is stored explicitly (`t3` holds `K³` entries, `t4` holds `K⁴`).
/// The arithmetic below propagates plain partial derivatives, not normalized Taylor
/// coefficients: no `1/n!` factors appear in `mul` or `compose_unary`.
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
/// Function value.
pub v: f64,
/// First derivatives: `g[i] = ∂f/∂x_i`.
pub g: [f64; K],
/// Second derivatives: `h[i][j] = ∂²f/∂x_i∂x_j`.
pub h: [[f64; K]; K],
/// Third derivatives: `t3[i][j][k] = ∂³f/∂x_i∂x_j∂x_k`.
pub t3: [[[f64; K]; K]; K],
/// Fourth derivatives: `t4[i][j][k][l] = ∂⁴f/∂x_i∂x_j∂x_k∂x_l`.
pub t4: [[[[f64; K]; K]; K]; K],
}
impl<const K: usize> Tower4<K> {
/// Tower whose value and every derivative channel are `0.0`.
pub fn zero() -> Self {
Self {
v: 0.0,
g: [0.0; K],
h: [[0.0; K]; K],
t3: [[[0.0; K]; K]; K],
t4: [[[[0.0; K]; K]; K]; K],
}
}
/// Constant `c`: value `c`, all derivatives zero.
pub fn constant(c: f64) -> Self {
let mut out = Self::zero();
out.v = c;
out
}
/// Independent input number `idx` with value `value` (seeds `g[idx] = 1`).
///
/// # Panics
/// If `idx >= K`.
pub fn variable(value: f64, idx: usize) -> Self {
let mut out = Self::constant(value);
out.g[idx] = 1.0;
out
}
/// Reads the derivative addressed by `labels`; `labels.len()` is the derivative order
/// (0 = value, up to 4).
///
/// # Panics
/// If `labels.len() > 4` or any label is `>= K`.
#[inline]
fn deriv(&self, labels: &[usize]) -> f64 {
assert!(
labels.len() <= 4,
"Tower4 carries at most fourth-order derivatives"
);
match labels.len() {
0 => self.v,
1 => self.g[labels[0]],
2 => self.h[labels[0]][labels[1]],
3 => self.t3[labels[0]][labels[1]][labels[2]],
_ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
}
}
/// Product `self · o` via the general Leibniz rule.
///
/// The order-n entry sums `a_S · b_{S^c}` over every subset `S` of its index tuple
/// (2, 4, 8 and 16 terms for orders 1–4).
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v * b.v;
// Order 1: a·b_i + a_i·b.
for i in 0..K {
let mut acc = 0.0;
acc += a.v * b.g[i];
acc += a.g[i] * b.v;
out.g[i] = acc;
}
// Order 2: all 4 subsets of {i, j}.
for i in 0..K {
for j in 0..K {
let mut acc = 0.0;
acc += a.v * b.h[i][j];
acc += a.g[i] * b.g[j];
acc += a.g[j] * b.g[i];
acc += a.h[i][j] * b.v;
out.h[i][j] = acc;
}
}
// Order 3: all 8 subsets of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = 0.0;
acc += a.v * b.t3[i][j][k];
acc += a.g[i] * b.h[j][k];
acc += a.g[j] * b.h[i][k];
acc += a.h[i][j] * b.g[k];
acc += a.g[k] * b.h[i][j];
acc += a.h[i][k] * b.g[j];
acc += a.h[j][k] * b.g[i];
acc += a.t3[i][j][k] * b.v;
out.t3[i][j][k] = acc;
}
}
}
// Order 4: all 16 subsets of {i, j, k, l}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let mut acc = 0.0;
acc += a.v * b.t4[i][j][k][l];
acc += a.g[i] * b.t3[j][k][l];
acc += a.g[j] * b.t3[i][k][l];
acc += a.h[i][j] * b.h[k][l];
acc += a.g[k] * b.t3[i][j][l];
acc += a.h[i][k] * b.h[j][l];
acc += a.h[j][k] * b.h[i][l];
acc += a.t3[i][j][k] * b.g[l];
acc += a.g[l] * b.t3[i][j][k];
acc += a.h[i][l] * b.h[j][k];
acc += a.h[j][l] * b.h[i][k];
acc += a.t3[i][j][l] * b.g[k];
acc += a.h[k][l] * b.h[i][j];
acc += a.t3[i][k][l] * b.g[j];
acc += a.t3[j][k][l] * b.g[i];
acc += a.t4[i][j][k][l] * b.v;
out.t4[i][j][k][l] = acc;
}
}
}
}
out
}
/// Sum of two towers (delegates to the `Add` operator impl).
pub fn add(&self, o: &Self) -> Self {
*self + *o
}
/// Difference `self - o`, computed as `self + o·(-1)`.
pub fn sub(&self, o: &Self) -> Self {
*self + o.scale(-1.0)
}
/// Chain rule (Faà di Bruno) for `φ ∘ self`.
///
/// `d = [φ(u), φ'(u), φ''(u), φ'''(u), φ''''(u)]`, where the caller evaluates the stack at
/// `u = self.v`. `d[0]` is copied verbatim into the result's value.
pub fn compose_unary(&self, d: [f64; 5]) -> Self {
let mut out = Self::zero();
out.v = d[0];
// Order 1: φ'·u_i.
for i in 0..K {
let mut acc = 0.0;
acc += d[1] * self.g[i];
out.g[i] = acc;
}
// Order 2: φ'·u_ij + φ''·u_i·u_j.
for i in 0..K {
for j in 0..K {
let mut acc = 0.0;
acc += d[1] * self.h[i][j];
acc += d[2] * self.g[i] * self.g[j];
out.h[i][j] = acc;
}
}
// Order 3: one φ' term, three φ'' partition terms, one φ''' term.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = 0.0;
acc += d[1] * self.t3[i][j][k];
acc += d[2] * self.h[i][j] * self.g[k];
acc += d[2] * self.h[i][k] * self.g[j];
acc += d[2] * self.g[i] * self.h[j][k];
acc += d[3] * self.g[i] * self.g[j] * self.g[k];
out.t3[i][j][k] = acc;
}
}
}
// Order 4: 15 partition terms — φ'·u_ijkl; φ'' with four (3+1) and three (2+2)
// partitions; φ''' with six (2+1+1) partitions; φ''''·u_i·u_j·u_k·u_l.
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let mut acc = 0.0;
acc += d[1] * self.t4[i][j][k][l];
acc += d[2] * self.t3[i][j][k] * self.g[l];
acc += d[2] * self.t3[i][j][l] * self.g[k];
acc += d[2] * self.h[i][j] * self.h[k][l];
acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
acc += d[2] * self.t3[i][k][l] * self.g[j];
acc += d[2] * self.h[i][k] * self.h[j][l];
acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
acc += d[2] * self.h[i][l] * self.h[j][k];
acc += d[2] * self.g[i] * self.t3[j][k][l];
acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
out.t4[i][j][k][l] = acc;
}
}
}
}
out
}
/// `compose_unary` with the derivative stack computed from `self.v` by `stack_fn`.
#[inline]
pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
self.compose_unary(stack_fn(self.v))
}
/// Chain rule restricted to the diagonal entries of one input `slot`.
///
/// Reads only `self`'s diagonal `slot` entries and writes only the result's diagonal `slot`
/// entries; every other derivative channel of the result is zero.
/// NOTE(review): this matches `compose_unary` only when `self` depends on input `slot`
/// alone — presumably callers guarantee that; confirm at call sites.
///
/// # Panics
/// If `slot >= K`.
#[inline]
pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
let mut out = Self::zero();
let s = slot;
let g = self.g[s];
let h = self.h[s][s];
let t3 = self.t3[s][s][s];
let t4 = self.t4[s][s][s][s];
out.v = d[0];
out.g[s] = {
let mut acc = 0.0;
acc += d[1] * g;
acc
};
out.h[s][s] = {
let mut acc = 0.0;
acc += d[1] * h;
acc += d[2] * g * g;
acc
};
// The repeated lines are the individual Faà di Bruno partition terms with i=j=k=slot.
// They are summed one by one rather than merged (e.g. `3.0 * d[2] * h * g`), which keeps
// the rounding identical to `compose_unary`.
out.t3[s][s][s] = {
let mut acc = 0.0;
acc += d[1] * t3;
acc += d[2] * h * g;
acc += d[2] * h * g;
acc += d[2] * g * h;
acc += d[3] * g * g * g;
acc
};
// Fourth order in the same term order as `compose_unary`, with i=j=k=l=slot.
out.t4[s][s][s][s] = {
let mut acc = 0.0;
acc += d[1] * t4;
acc += d[2] * t3 * g;
acc += d[2] * t3 * g;
acc += d[2] * h * h;
acc += d[3] * h * g * g;
acc += d[2] * t3 * g;
acc += d[2] * h * h;
acc += d[3] * h * g * g;
acc += d[2] * h * h;
acc += d[2] * g * t3;
acc += d[3] * g * h * g;
acc += d[3] * h * g * g;
acc += d[3] * g * h * g;
acc += d[3] * g * g * h;
acc += d[4] * g * g * g * g;
acc
};
out
}
/// Multiplies the value and every derivative channel by `s`.
pub fn scale(&self, s: f64) -> Self {
let mut out = *self;
out.v *= s;
for i in 0..K {
out.g[i] *= s;
for j in 0..K {
out.h[i][j] *= s;
for k in 0..K {
out.t3[i][j][k] *= s;
for l in 0..K {
out.t4[i][j][k][l] *= s;
}
}
}
}
out
}
/// `exp(self)`; every derivative of `exp` equals `exp(u)`.
pub fn exp(&self) -> Self {
let e = self.v.exp();
self.compose_unary([e, e, e, e, e])
}
/// Natural log; derivative stack `[ln u, 1/u, -1/u², 2/u³, -6/u⁴]`.
/// Non-positive `self.v` yields the NaN/inf values produced by `f64::ln` and `1/u`.
pub fn ln(&self) -> Self {
let u = self.v;
let r = 1.0 / u;
self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
}
/// Reciprocal `1/self`; derivative stack `[1/u, -1/u², 2/u³, -6/u⁴, 24/u⁵]`.
pub fn recip(&self) -> Self {
let r = 1.0 / self.v;
let r2 = r * r;
self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
}
/// Square root; derivative stack of `u^(1/2)` written in terms of `s = √u`:
/// coefficients 1/2, -1/4, 3/8, -15/16.
pub fn sqrt(&self) -> Self {
let u = self.v;
let s = u.sqrt();
self.compose_unary([
s,
0.5 / s,
-0.25 / (u * s),
0.375 / (u * u * s),
-0.9375 / (u * u * u * s),
])
}
/// Real power `self^a` via the falling-factorial derivative stack of `u^a`.
pub fn powf(&self, a: f64) -> Self {
let u = self.v;
let f0 = u.powf(a);
let f1 = a * u.powf(a - 1.0);
let f2 = a * (a - 1.0) * u.powf(a - 2.0);
let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
self.compose_unary([f0, f1, f2, f3, f4])
}
/// `ln Γ(self)` using `ln_gamma_derivative_stack` (ln Γ, ψ, ψ₁, ψ₂, ψ₃).
pub fn ln_gamma(&self) -> Self {
self.compose_unary(ln_gamma_derivative_stack(self.v))
}
/// Digamma `ψ(self)` using `digamma_derivative_stack` (ψ, ψ₁, …, ψ₄).
pub fn digamma(&self) -> Self {
self.compose_unary(digamma_derivative_stack(self.v))
}
/// Trigamma `ψ₁(self)` using `trigamma_derivative_stack` (ψ₁, …, ψ₅).
pub fn trigamma(&self) -> Self {
self.compose_unary(trigamma_derivative_stack(self.v))
}
/// Contracts the third-derivative tensor along its last index:
/// `out[a][b] = Σ_c t3[a][b][c] · dir[c]`.
pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
let mut out = [[0.0; K]; K];
for a in 0..K {
for b in 0..K {
let mut acc = 0.0;
for c in 0..K {
acc += self.t3[a][b][c] * dir[c];
}
out[a][b] = acc;
}
}
out
}
/// Contracts the fourth-derivative tensor along its last two indices:
/// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`.
pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
let mut out = [[0.0; K]; K];
for i in 0..K {
for j in 0..K {
let mut acc = 0.0;
for k in 0..K {
for l in 0..K {
acc += self.t4[i][j][k][l] * u[k] * w[l];
}
}
out[i][j] = acc;
}
}
out
}
}
impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
#[inline]
fn derivative(&self, labels: &[usize]) -> f64 {
self.deriv(labels)
}
fn map_derivatives<F>(&self, mut f: F) -> Self
where
F: FnMut(&[usize]) -> f64,
{
let mut out = Self::zero();
out.v = f(&[]);
for i in 0..K {
let labels = [i];
out.g[i] = f(&labels);
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = f(&labels);
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
let labels = [i, j, k];
out.t3[i][j][k] = f(&labels);
}
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let labels = [i, j, k, l];
out.t4[i][j][k][l] = f(&labels);
}
}
}
}
out
}
}
/// Second-order derivative tower over `K` inputs: value, gradient and dense Hessian.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Function value.
    pub v: f64,
    /// First derivatives: `g[i] = ∂f/∂x_i`.
    pub g: [f64; K],
    /// Second derivatives: `h[i][j] = ∂²f/∂x_i∂x_j`.
    pub h: [[f64; K]; K],
}

impl<const K: usize> Tower2<K> {
    /// Tower whose value and every derivative are `0.0`.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// Constant `c` with zero derivatives.
    pub fn constant(c: f64) -> Self {
        Self { v: c, ..Self::zero() }
    }

    /// Independent input number `idx` with value `value` (seeds `g[idx] = 1`).
    ///
    /// # Panics
    /// If `idx >= K`.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut seeded = Self::constant(value);
        seeded.g[idx] = 1.0;
        seeded
    }

    /// Derivative addressed by `labels` (length 0, 1 or 2).
    ///
    /// # Panics
    /// If `labels.len() > 2` or any label is `>= K`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match labels {
            [] => self.v,
            [i] => self.g[*i],
            [i, j, ..] => self.h[*i][*j],
        }
    }

    /// Product `self · o` via the Leibniz rule.
    pub fn mul(&self, o: &Self) -> Self {
        let (lhs, rhs) = (self, o);
        Self {
            v: lhs.v * rhs.v,
            g: std::array::from_fn(|i| lhs.v * rhs.g[i] + lhs.g[i] * rhs.v),
            h: std::array::from_fn(|i| {
                std::array::from_fn(|j| {
                    lhs.v * rhs.h[i][j]
                        + lhs.g[i] * rhs.g[j]
                        + lhs.g[j] * rhs.g[i]
                        + lhs.h[i][j] * rhs.v
                })
            }),
        }
    }

    /// Chain rule for `φ ∘ self` given `d = [φ(u), φ'(u), φ''(u)]` at `u = self.v`.
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        /// Left-to-right sum seeded with `+0.0`, matching a running `acc += …` accumulator.
        fn accumulate(terms: &[f64]) -> f64 {
            terms.iter().fold(0.0_f64, |total, &term| total + term)
        }
        let [value, d1, d2] = d;
        Self {
            v: value,
            g: std::array::from_fn(|i| accumulate(&[d1 * self.g[i]])),
            h: std::array::from_fn(|i| {
                std::array::from_fn(|j| {
                    accumulate(&[d1 * self.h[i][j], d2 * self.g[i] * self.g[j]])
                })
            }),
        }
    }

    /// Multiplies the value and every derivative by `s`.
    pub fn scale(&self, s: f64) -> Self {
        let mut scaled = *self;
        scaled.v *= s;
        scaled.g.iter_mut().for_each(|x| *x *= s);
        scaled.h.iter_mut().flatten().for_each(|x| *x *= s);
        scaled
    }

    /// `exp(self)`; all derivatives of `exp` equal `exp(u)`.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e; 3])
    }

    /// Square root with derivative stack `[√u, 1/(2√u), -1/(4u√u)]`.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let root = u.sqrt();
        self.compose_unary([root, 0.5 / root, -0.25 / (u * root)])
    }
}
impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
#[inline]
fn derivative(&self, labels: &[usize]) -> f64 {
self.deriv(labels)
}
fn map_derivatives<F>(&self, mut f: F) -> Self
where
F: FnMut(&[usize]) -> f64,
{
let mut out = Self::zero();
out.v = f(&[]);
for i in 0..K {
let labels = [i];
out.g[i] = f(&labels);
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = f(&labels);
}
}
out
}
}
impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;
    /// Channel-wise sum of two towers.
    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        for (lhs, rhs) in sum.g.iter_mut().zip(o.g.iter()) {
            *lhs += *rhs;
        }
        for (lhs, rhs) in sum.h.iter_mut().flatten().zip(o.h.iter().flatten()) {
            *lhs += *rhs;
        }
        sum
    }
}
impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;
    /// Leibniz product; delegates to the inherent `Tower2::mul`.
    fn mul(self, o: Self) -> Self {
        Tower2::<K>::mul(&self, &o)
    }
}
impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;
    /// Shifts the value by `c`; derivatives are unchanged.
    fn add(self, c: f64) -> Self {
        Self { v: self.v + c, ..self }
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
    type Output = Self;
    /// Scales every channel by `c`.
    fn mul(self, c: f64) -> Self {
        Tower2::<K>::scale(&self, c)
    }
}
/// Third-order derivative tower over `K` inputs: value plus dense derivative tensors up to
/// order three.
///
/// Every index permutation is stored explicitly (`t3` holds `K³` entries). The arithmetic
/// propagates plain partial derivatives; no `1/n!` factors appear in `mul` or
/// `compose_unary`.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
/// Function value.
pub v: f64,
/// First derivatives: `g[i] = ∂f/∂x_i`.
pub g: [f64; K],
/// Second derivatives: `h[i][j] = ∂²f/∂x_i∂x_j`.
pub h: [[f64; K]; K],
/// Third derivatives: `t3[i][j][k] = ∂³f/∂x_i∂x_j∂x_k`.
pub t3: [[[f64; K]; K]; K],
}
impl<const K: usize> Tower3<K> {
/// Tower whose value and every derivative channel are `0.0`.
pub fn zero() -> Self {
Self {
v: 0.0,
g: [0.0; K],
h: [[0.0; K]; K],
t3: [[[0.0; K]; K]; K],
}
}
/// Constant `c`: value `c`, all derivatives zero.
pub fn constant(c: f64) -> Self {
let mut out = Self::zero();
out.v = c;
out
}
/// Independent input number `idx` with value `value` (seeds `g[idx] = 1`).
///
/// # Panics
/// If `idx >= K`.
pub fn variable(value: f64, idx: usize) -> Self {
let mut out = Self::constant(value);
out.g[idx] = 1.0;
out
}
/// Reads the derivative addressed by `labels`; `labels.len()` is the order (0..=3).
///
/// # Panics
/// If `labels.len() > 3` or any label is `>= K`.
#[inline]
fn deriv(&self, labels: &[usize]) -> f64 {
assert!(
labels.len() <= 3,
"Tower3 carries at most third-order derivatives"
);
match labels.len() {
0 => self.v,
1 => self.g[labels[0]],
2 => self.h[labels[0]][labels[1]],
_ => self.t3[labels[0]][labels[1]][labels[2]],
}
}
/// Product `self · o` via the general Leibniz rule: the order-n entry sums
/// `a_S · b_{S^c}` over every subset `S` of its index tuple.
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v * b.v;
// Order 1: a·b_i + a_i·b.
for i in 0..K {
let mut acc = 0.0;
acc += a.v * b.g[i];
acc += a.g[i] * b.v;
out.g[i] = acc;
}
// Order 2: all 4 subsets of {i, j}.
for i in 0..K {
for j in 0..K {
let mut acc = 0.0;
acc += a.v * b.h[i][j];
acc += a.g[i] * b.g[j];
acc += a.g[j] * b.g[i];
acc += a.h[i][j] * b.v;
out.h[i][j] = acc;
}
}
// Order 3: all 8 subsets of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = 0.0;
acc += a.v * b.t3[i][j][k];
acc += a.g[i] * b.h[j][k];
acc += a.g[j] * b.h[i][k];
acc += a.h[i][j] * b.g[k];
acc += a.g[k] * b.h[i][j];
acc += a.h[i][k] * b.g[j];
acc += a.h[j][k] * b.g[i];
acc += a.t3[i][j][k] * b.v;
out.t3[i][j][k] = acc;
}
}
}
out
}
/// Sum of two towers (delegates to the `Add` operator impl).
pub fn add(&self, o: &Self) -> Self {
*self + *o
}
/// Difference `self - o`, computed as `self + o·(-1)`.
pub fn sub(&self, o: &Self) -> Self {
*self + o.scale(-1.0)
}
/// Chain rule (Faà di Bruno) for `φ ∘ self` with `d = [φ, φ', φ'', φ''']` evaluated by the
/// caller at `u = self.v`. `d[0]` is copied verbatim into the result's value.
pub fn compose_unary(&self, d: [f64; 4]) -> Self {
let mut out = Self::zero();
out.v = d[0];
// Order 1: φ'·u_i.
for i in 0..K {
let mut acc = 0.0;
acc += d[1] * self.g[i];
out.g[i] = acc;
}
// Order 2: φ'·u_ij + φ''·u_i·u_j.
for i in 0..K {
for j in 0..K {
let mut acc = 0.0;
acc += d[1] * self.h[i][j];
acc += d[2] * self.g[i] * self.g[j];
out.h[i][j] = acc;
}
}
// Order 3: one φ' term, three φ'' partition terms, one φ''' term.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = 0.0;
acc += d[1] * self.t3[i][j][k];
acc += d[2] * self.h[i][j] * self.g[k];
acc += d[2] * self.h[i][k] * self.g[j];
acc += d[2] * self.g[i] * self.h[j][k];
acc += d[3] * self.g[i] * self.g[j] * self.g[k];
out.t3[i][j][k] = acc;
}
}
}
out
}
/// `compose_unary` with the derivative stack computed from `self.v` by `stack_fn`.
#[inline]
pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
self.compose_unary(stack_fn(self.v))
}
/// Chain rule restricted to the diagonal entries of one input `slot`.
///
/// Only the result's diagonal `slot` entries are written; every other derivative channel is
/// zero. NOTE(review): matches `compose_unary` only when `self` depends on input `slot`
/// alone — presumably guaranteed by callers; confirm.
///
/// # Panics
/// If `slot >= K`.
#[inline]
pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
let mut out = Self::zero();
let s = slot;
let g = self.g[s];
let h = self.h[s][s];
let t3 = self.t3[s][s][s];
out.v = d[0];
out.g[s] = {
let mut acc = 0.0;
acc += d[1] * g;
acc
};
out.h[s][s] = {
let mut acc = 0.0;
acc += d[1] * h;
acc += d[2] * g * g;
acc
};
// The three φ'' partition terms are added one by one (not as `3·d[2]·h·g`), which keeps
// the rounding identical to `compose_unary`.
out.t3[s][s][s] = {
let mut acc = 0.0;
acc += d[1] * t3;
acc += d[2] * h * g;
acc += d[2] * h * g;
acc += d[2] * g * h;
acc += d[3] * g * g * g;
acc
};
out
}
/// Multiplies the value and every derivative channel by `s`.
pub fn scale(&self, s: f64) -> Self {
let mut out = *self;
out.v *= s;
for i in 0..K {
out.g[i] *= s;
for j in 0..K {
out.h[i][j] *= s;
for k in 0..K {
out.t3[i][j][k] *= s;
}
}
}
out
}
}
impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
#[inline]
fn derivative(&self, labels: &[usize]) -> f64 {
self.deriv(labels)
}
fn map_derivatives<F>(&self, mut f: F) -> Self
where
F: FnMut(&[usize]) -> f64,
{
let mut out = Self::zero();
out.v = f(&[]);
for i in 0..K {
let labels = [i];
out.g[i] = f(&labels);
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = f(&labels);
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
let labels = [i, j, k];
out.t3[i][j][k] = f(&labels);
}
}
}
out
}
}
impl<const K: usize> std::ops::Add for Tower3<K> {
    type Output = Self;
    /// Channel-wise sum of two third-order towers.
    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        sum.g
            .iter_mut()
            .zip(o.g.iter())
            .for_each(|(lhs, rhs)| *lhs += *rhs);
        sum.h
            .iter_mut()
            .flatten()
            .zip(o.h.iter().flatten())
            .for_each(|(lhs, rhs)| *lhs += *rhs);
        sum.t3
            .iter_mut()
            .flatten()
            .flatten()
            .zip(o.t3.iter().flatten().flatten())
            .for_each(|(lhs, rhs)| *lhs += *rhs);
        sum
    }
}
/// Value and first four derivatives of `ln Γ` at `x`: `[ln Γ, ψ, ψ₁, ψ₂, ψ₃]`.
///
/// The ψ entries are NaN unless `x` is finite and positive.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    std::array::from_fn(|n| match n {
        0 => statrs::function::gamma::ln_gamma(x),
        1 => digamma_positive(x),
        _ => polygamma_positive(n - 1, x),
    })
}
/// Value and first two derivatives of `ln Γ` at `x`: `[ln Γ, ψ, ψ₁]`.
pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
    std::array::from_fn(|n| match n {
        0 => statrs::function::gamma::ln_gamma(x),
        1 => digamma_positive(x),
        _ => polygamma_positive(n - 1, x),
    })
}
/// Digamma and its first four derivatives at `x`: `[ψ, ψ₁, ψ₂, ψ₃, ψ₄]`.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    std::array::from_fn(|n| {
        if n == 0 {
            digamma_positive(x)
        } else {
            polygamma_positive(n, x)
        }
    })
}
/// Trigamma and its first four derivatives at `x`: `[ψ₁, ψ₂, ψ₃, ψ₄, ψ₅]`.
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
    std::array::from_fn(|n| polygamma_positive(n + 1, x))
}
/// Digamma `ψ(x)` for finite `x > 0`; NaN otherwise.
///
/// Applies `ψ(x) = ψ(x + 1) − 1/x` until the argument reaches
/// `POLYGAMMA_ASYMPTOTIC_MIN_X`, then evaluates the asymptotic series.
fn digamma_positive(x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shifted = x;
    let mut correction = 0.0;
    while shifted < POLYGAMMA_ASYMPTOTIC_MIN_X {
        correction -= 1.0 / shifted;
        shifted += 1.0;
    }
    correction + digamma_asymptotic(shifted)
}
/// Polygamma `ψ⁽ᵐ⁾(x)` of order `order` in `1..=5` for finite `x > 0`; NaN otherwise.
///
/// Uses the upward recurrence `ψ⁽ᵐ⁾(x) = ψ⁽ᵐ⁾(x + 1) + (−1)^{m+1} m!/x^{m+1}` followed by
/// the asymptotic series.
fn polygamma_positive(order: usize, x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shifted = x;
    let mut correction = 0.0;
    while shifted < POLYGAMMA_ASYMPTOTIC_MIN_X {
        correction += polygamma_recurrence_term(order, shifted);
        shifted += 1.0;
    }
    correction + polygamma_asymptotic(order, shifted)
}
/// Argument at or above which the asymptotic expansions are used.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Pairs `(2n, B₂ₙ)` of even-index Bernoulli numbers for `n = 1..=10`.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
/// Recurrence term `(−1)^{m+1} m! / x^{m+1}` linking `ψ⁽ᵐ⁾(x)` to `ψ⁽ᵐ⁾(x + 1)`.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = if order % 2 == 0 { -1.0 } else { 1.0 };
    sign * factorial(order) / x.powi((order + 1) as i32)
}
/// Asymptotic series `ψ(x) ≈ ln x − 1/(2x) − Σ B₂ₙ / (2n · x²ⁿ)`.
fn digamma_asymptotic(x: f64) -> f64 {
    BERNOULLI_EVEN
        .iter()
        .fold(x.ln() - 0.5 / x, |series, &(n, bernoulli)| {
            series - bernoulli / (n as f64 * x.powi(n as i32))
        })
}
/// Asymptotic series for `ψ⁽ᵐ⁾(x)`, `m` in `1..=5` (NaN otherwise):
/// `(−1)^{m+1} [ (m−1)!/xᵐ + m!/(2x^{m+1}) + Σ B₂ₙ (2n+m−1)! / ((2n)! x^{2n+m}) ]`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if order == 0 || order > 5 {
        return f64::NAN;
    }
    let sign = if order % 2 == 0 { -1.0 } else { 1.0 };
    let head = sign * factorial(order - 1) / x.powi(order as i32)
        + sign * factorial(order) / (2.0 * x.powi((order + 1) as i32));
    BERNOULLI_EVEN.iter().fold(head, |series, &(n, bernoulli)| {
        // rising_factorial(n, m) / n == (n + m − 1)! / n!
        let rising = rising_factorial(n, order);
        series + sign * bernoulli * rising / n as f64 / x.powi((n + order) as i32)
    })
}
/// `n!` as `f64`.
fn factorial(n: usize) -> f64 {
    (1..=n).map(|k| k as f64).product()
}
/// Rising product `start · (start+1) · … · (start+len−1)` as `f64`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).map(|k| k as f64).product()
}
impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;
    /// Channel-wise sum of two fourth-order towers.
    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        for (lhs, rhs) in sum.g.iter_mut().zip(o.g.iter()) {
            *lhs += *rhs;
        }
        for (lhs, rhs) in sum.h.iter_mut().flatten().zip(o.h.iter().flatten()) {
            *lhs += *rhs;
        }
        for (lhs, rhs) in sum
            .t3
            .iter_mut()
            .flatten()
            .flatten()
            .zip(o.t3.iter().flatten().flatten())
        {
            *lhs += *rhs;
        }
        for (lhs, rhs) in sum
            .t4
            .iter_mut()
            .flatten()
            .flatten()
            .flatten()
            .zip(o.t4.iter().flatten().flatten().flatten())
        {
            *lhs += *rhs;
        }
        sum
    }
}
impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;
    /// `self - o`, computed as `self + (−o)`.
    fn sub(self, o: Self) -> Self {
        self + (-o)
    }
}
impl<const K: usize> std::ops::Neg for Tower4<K> {
    type Output = Self;
    /// Negates every channel (scale by −1).
    fn neg(self) -> Self {
        Tower4::<K>::scale(&self, -1.0)
    }
}
impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;
    /// Leibniz product; delegates to the inherent `Tower4::mul`.
    fn mul(self, o: Self) -> Self {
        Tower4::<K>::mul(&self, &o)
    }
}
impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;
    /// `self / o`, computed as `self · recip(o)`.
    fn div(self, o: Self) -> Self {
        let inverse = o.recip();
        Tower4::<K>::mul(&self, &inverse)
    }
}
impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;
    /// Shifts the value by `c`; derivatives are unchanged.
    fn add(self, c: f64) -> Self {
        Self { v: self.v + c, ..self }
    }
}
impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;
    /// Shifts the value by `−c`.
    fn sub(self, c: f64) -> Self {
        let shift = -c;
        self + shift
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
    type Output = Self;
    /// Scales every channel by `c`.
    fn mul(self, c: f64) -> Self {
        Tower4::<K>::scale(&self, c)
    }
}
/// Solves `F(a, θ) = 0` for the intercept `a(θ)` as a fourth-order tower in `θ`.
///
/// `f` is the fourth-order derivative tower of `F` at `(a0, θ0)`: slot 0 is `a` and slots
/// `1..=K` are the components of `θ`. The solution is built order by order. At each order
/// `n`, the composition `G = F ∘ (a, θ)` is re-expanded with [`substitute_intercept`], and
/// the order-`n` channel of `a` is corrected by `−Gₙ / (∂F/∂a)`. This zeroes that channel of
/// `G` without disturbing lower orders.
///
/// # Errors
/// Returns `Err` when:
/// - `∂F/∂a` is zero or non-finite;
/// - `a0` or `F(a0, θ0)` is not finite;
/// - `a0` is not numerically a root of `F`;
/// - the composed residual fails to vanish, including when any residual channel is NaN.
///
/// # Panics
/// If `K1 != K + 1`.
pub fn implicit_solve<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a0: f64,
) -> Result<Tower4<K>, String> {
    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
    let f_a = f.g[0];
    if f_a == 0.0 || !f_a.is_finite() {
        return Err(format!(
            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
        ));
    }
    // A non-finite a0 makes the root tolerance below NaN or +inf, which would let the
    // "a0 is a root" comparison pass silently.
    if !a0.is_finite() {
        return Err(format!(
            "implicit_solve: expansion point a0 = {a0:+.6e} is not finite"
        ));
    }
    let root_tol = 1e-9;
    if !f.v.is_finite() {
        return Err(format!(
            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
            f.v
        ));
    }
    // One Newton step |F/F_a| measures how far a0 is from the true root.
    let newton_step = f.v.abs() / f_a.abs();
    if newton_step > root_tol * (1.0 + a0.abs()) {
        return Err(format!(
            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
root_tol {root_tol:.1e} · (1 + |a0|)",
            f.v
        ));
    }
    let mut a = Tower4::<K>::constant(a0);
    for order in 1..=4 {
        // The order-n channel of G is linear in a's order-n channel with slope f_a.
        let g = substitute_intercept(f, &a);
        match order {
            1 => {
                for i in 0..K {
                    a.g[i] -= g.g[i] / f_a;
                }
            }
            2 => {
                for i in 0..K {
                    for j in 0..K {
                        a.h[i][j] -= g.h[i][j] / f_a;
                    }
                }
            }
            3 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
                        }
                    }
                }
            }
            4 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            for l in 0..K {
                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
                            }
                        }
                    }
                }
            }
            _ => unreachable!("implicit_solve corrects orders 1..=4 only"),
        }
    }
    let g = substitute_intercept(f, &a);
    let resid_tol = 1e-7 * (1.0 + f_a.abs());
    // `f64::max` returns the non-NaN operand, so a plain max-fold would silently drop NaN
    // channels. Propagate NaN explicitly so it is rejected below.
    let worst = std::iter::once(g.v)
        .chain(g.g.iter().copied())
        .chain(g.h.iter().flatten().copied())
        .chain(g.t3.iter().flatten().flatten().copied())
        .chain(g.t4.iter().flatten().flatten().flatten().copied())
        .map(f64::abs)
        .fold(0.0_f64, |acc, m| {
            if acc.is_nan() || m.is_nan() {
                f64::NAN
            } else {
                acc.max(m)
            }
        });
    if !worst.is_finite() || worst > resid_tol {
        return Err(format!(
            "implicit_solve: composed residual G = F∘a does not vanish: \
worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
        ));
    }
    Ok(a)
}
/// Composes the fourth-order expansion of `F` at `(a0, θ0)` with the intercept tower `a(θ)`.
///
/// Slot 0 of `f` is replaced by `a − a0` (`a` with its value zeroed). Each remaining slot
/// `s` is replaced by the perturbation of input `s − 1` (a tower with `g[s − 1] = 1`). The
/// Taylor sum `Σ f_α / α! · δ^α` is then evaluated through fourth order. Every `δ` has zero
/// value, so products of five or more vanish and the truncation is exact for a `Tower4`.
///
/// Partial products `δa·δb` and `δa·δb·δc` are computed once per outer loop iteration and
/// reused by the inner loops. The sequence of tower multiplications and additions into the
/// result is unchanged, so the output is bit-identical to recomputing them.
///
/// # Panics
/// If `K1 != K + 1`.
pub fn substitute_intercept<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(K1, K + 1);
    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
        if slot == 0 {
            let mut d = *a;
            d.v = 0.0;
            d
        } else {
            let mut d = Tower4::<K>::zero();
            d.g[slot - 1] = 1.0;
            d
        }
    });
    let mut out = Tower4::<K>::constant(f.v);
    // First-order terms.
    for a_idx in 0..K1 {
        out = out + inp[a_idx].scale(f.g[a_idx]);
    }
    // Second-order terms (full symmetric sum, hence the 1/2!).
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let prod = inp[a_idx].mul(&inp[b_idx]);
            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
        }
    }
    // Third-order terms (1/3!).
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let ab = inp[a_idx].mul(&inp[b_idx]);
            for c_idx in 0..K1 {
                let prod = ab.mul(&inp[c_idx]);
                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
            }
        }
    }
    // Fourth-order terms (1/4!).
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let ab = inp[a_idx].mul(&inp[b_idx]);
            for c_idx in 0..K1 {
                let abc = ab.mul(&inp[c_idx]);
                for d_idx in 0..K1 {
                    let prod = abc.mul(&inp[d_idx]);
                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
                }
            }
        }
    }
    out
}
/// Chain-rule tower for a moving integration limit.
///
/// Composes `z_edge` with a scalar function whose value is taken as `0` and whose first
/// four derivatives at `z_edge.v` are `b_stack` (in order).
/// NOTE(review): presumably `b_stack` holds derivatives of the boundary primitive at the
/// edge — confirm with callers.
pub fn moving_limit_boundary_tower<const K: usize>(
    z_edge: &Tower4<K>,
    b_stack: [f64; 4],
) -> Tower4<K> {
    let [d1, d2, d3, d4] = b_stack;
    z_edge.compose_unary([0.0, d1, d2, d3, d4])
}
/// Right-edge minus left-edge moving-limit tower for one cell.
pub fn cell_moving_boundary_flux_tower<const K: usize>(
    z_right: &Tower4<K>,
    b_stack_right: [f64; 4],
    z_left: &Tower4<K>,
    b_stack_left: [f64; 4],
) -> Tower4<K> {
    let right_edge = moving_limit_boundary_tower(z_right, b_stack_right);
    let left_edge = moving_limit_boundary_tower(z_left, b_stack_left);
    right_edge - left_edge
}
/// Difference between `Φ` evaluated along the moving edge `z(θ)` and along the frozen edge
/// value `z_edge.v`.
///
/// `phi_jet` carries `z` in slot 0 followed by the `K` θ inputs.
///
/// # Panics
/// If `K1 != K + 1`.
pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet: &Tower4<K1>,
    z_edge: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(
        K1,
        K + 1,
        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
    );
    let moving = substitute_intercept(phi_jet, z_edge);
    let frozen = substitute_intercept(phi_jet, &Tower4::<K>::constant(z_edge.v));
    moving - frozen
}
/// Right-edge minus left-edge θ-integrand tower for one cell.
pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet_right: &Tower4<K1>,
    z_right: &Tower4<K>,
    phi_jet_left: &Tower4<K1>,
    z_left: &Tower4<K>,
) -> Tower4<K> {
    let right_edge = moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right);
    let left_edge = moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left);
    right_edge - left_edge
}
pub trait RowNllProgram<const K: usize>: Send + Sync {
fn n_rows(&self) -> usize;
fn primaries(&self, row: usize) -> Result<[f64; K], String>;
fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
prog: &P,
row: usize,
) -> Result<Tower4<K>, String> {
let p = prog.primaries(row)?;
let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
prog.row_nll(row, &vars)
}
pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
prog: &P,
row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
let t = evaluate_program(prog, row)?;
Ok((t.v, t.g, t.h))
}
pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
prog: &P,
row: usize,
dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
Ok(evaluate_program(prog, row)?.third_contracted(dir))
}
pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
prog: &P,
row: usize,
dir_u: &[f64; K],
dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
}
/// Per-row NLL program that is generic over the jet scalar type.
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
    /// Number of rows available.
    fn n_rows(&self) -> usize;
    /// Values of the `K` primary inputs for `row`.
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
    /// NLL of `row` evaluated in jet scalar `S` from the seeded primaries `p`.
    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
        &self,
        row: usize,
        p: &[S; K],
    ) -> Result<S, String>;
}
/// Value, gradient and Hessian of the row NLL, computed with `Order2` jets.
pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let base = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|slot| {
        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(
            base[slot],
            slot,
        )
    });
    let jet = prog.row_nll_generic(row, &seeded)?;
    let value = crate::jet_scalar::JetScalar::value(&jet);
    Ok((value, jet.g(), jet.h()))
}
/// Third-derivative tensor contracted with `dir`, computed with single-direction seeds.
pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::OneSeed<K>; K] = std::array::from_fn(|slot| {
        crate::jet_scalar::OneSeed::seed_direction(base[slot], slot, dir[slot])
    });
    let jet = prog.row_nll_generic(row, &seeded)?;
    Ok(jet.contracted_third())
}
/// Fourth-derivative tensor contracted with `dir_u` and `dir_v`, using two-direction seeds.
pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::TwoSeed<K>; K] = std::array::from_fn(|slot| {
        crate::jet_scalar::TwoSeed::seed(base[slot], slot, dir_u[slot], dir_v[slot])
    });
    let jet = prog.row_nll_generic(row, &seeded)?;
    Ok(jet.contracted_fourth())
}
/// Full fourth-order tower of the row NLL, evaluated with `Tower4` as the jet scalar.
pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let base = prog.primaries(row)?;
    let seeded: [Tower4<K>; K] =
        std::array::from_fn(|slot| Tower4::variable(base[slot], slot));
    prog.row_nll_generic(row, &seeded)
}
/// Outcome of a domain guard evaluated on one scalar row or on a 4-lane batch.
///
/// Bit `i` of the failed mask is set when lane `i` failed. Bits at or above `lanes` are
/// always clear.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GuardVerdict {
    lanes: u8,
    failed_mask: u8,
}
impl GuardVerdict {
    /// Single-lane verdict: failed iff `pass` is false.
    #[inline]
    pub fn scalar(pass: bool) -> Self {
        Self {
            lanes: 1,
            failed_mask: u8::from(!pass),
        }
    }
    /// Four-lane verdict. Bits above lane 3 in `failed_mask` are discarded.
    #[inline]
    pub fn lanes4(failed_mask: u8) -> Self {
        Self {
            lanes: 4,
            failed_mask: failed_mask & 0x0f,
        }
    }
    /// Number of lanes covered by this verdict (1 or 4).
    #[inline]
    pub fn lanes(self) -> usize {
        usize::from(self.lanes)
    }
    /// True when no lane failed.
    #[inline]
    pub fn all_pass(self) -> bool {
        self.failed_mask == 0
    }
    /// True when at least one lane failed.
    #[inline]
    pub fn any_failed(self) -> bool {
        self.failed_mask != 0
    }
    /// True when lane `i` failed. Lanes outside this verdict (`i >= lanes()`) report `false`.
    ///
    /// The bounds check also keeps the `u8` shift in range. The previous form panicked in
    /// debug builds (and wrapped in release) for `i >= 8`.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes() && (self.failed_mask >> i) & 1 == 1
    }
    /// Raw failure bitmask (bit `i` = lane `i` failed).
    #[inline]
    pub fn failed_mask(self) -> u8 {
        self.failed_mask
    }
}
/// Pads with zeros or truncates a derivative stack of length `N` to length `M`.
#[inline]
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    std::array::from_fn(|k| s.get(k).copied().unwrap_or(0.0))
}
pub trait RowJet<const K: usize>: Copy {
type Value: Copy;
fn constant(c: f64) -> Self;
fn variable(x: f64, slot: usize) -> Self;
fn values(&self) -> Self::Value;
fn add(&self, o: &Self) -> Self;
fn sub(&self, o: &Self) -> Self;
fn mul(&self, o: &Self) -> Self;
fn scale(&self, s: f64) -> Self;
fn neg(&self) -> Self {
self.scale(-1.0)
}
fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
fn scale_rows(&self, s: Self::Value) -> Self;
fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
fn exp(&self) -> Self {
self.compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
})
}
fn ln(&self) -> Self {
self.compose_unary_with(|u| {
let r = 1.0 / u;
[u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
})
}
fn sqrt(&self) -> Self {
self.compose_unary_with(|u| {
let s = u.sqrt();
[s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
})
}
fn recip(&self) -> Self {
self.compose_unary_with(|u| {
let r = 1.0 / u;
let r2 = r * r;
[r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
})
}
fn powf(&self, a: f64) -> Self {
self.compose_unary_with(move |u| {
[
u.powf(a),
a * u.powf(a - 1.0),
a * (a - 1.0) * u.powf(a - 2.0),
a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
]
})
}
fn ln_gamma(&self) -> Self {
self.compose_unary_with(ln_gamma_derivative_stack)
}
fn digamma(&self) -> Self {
self.compose_unary_with(digamma_derivative_stack)
}
}
/// Every single-row `JetScalar` is a `RowJet` over exactly one row.
impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
    type Value = f64;
    #[inline]
    fn constant(c: f64) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
    }
    #[inline]
    fn values(&self) -> f64 {
        <S as crate::jet_scalar::JetScalar<K>>::value(self)
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::scale(self, s)
    }
    #[inline]
    fn neg(&self) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::neg(self)
    }
    /// Pads or truncates the caller's stack to the five entries `JetScalar` expects.
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::compose_unary_with(self, |u| {
            resize_stack::<N, 5>(stack_fn(u))
        })
    }
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        let value = <S as crate::jet_scalar::JetScalar<K>>::value(self);
        GuardVerdict::scalar(pred(value))
    }
    /// With a single row, per-row scaling is ordinary scaling.
    #[inline]
    fn scale_rows(&self, s: f64) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::scale(self, s)
    }
    /// Reads only the first row.
    ///
    /// # Panics
    /// If `rows` is empty.
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
        let first = rows[0];
        value_of(first)
    }
}
/// Four rows at a time: each `f64x4` lane carries one row's fourth-order tower.
impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
    type Value = [f64; 4];
    #[inline]
    fn constant(c: f64) -> Self {
        let splat = <wide::f64x4 as crate::jet_scalar::Lane>::splat(c);
        Tower4Lane::constant(splat)
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        let splat = <wide::f64x4 as crate::jet_scalar::Lane>::splat(x);
        Tower4Lane::variable(splat, slot)
    }
    #[inline]
    fn values(&self) -> [f64; 4] {
        self.v.to_array()
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        Tower4Lane::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        Tower4Lane::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        Tower4Lane::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        Tower4Lane::scale(self, s)
    }
    /// Pads or truncates the caller's stack to the five entries a fourth-order tower uses.
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
    }
    /// Bit `lane` of the verdict is set when `pred` rejects that lane's value.
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        let lane_values = self.v.to_array();
        let failed = (0..lane_values.len())
            .filter(|&lane| !pred(lane_values[lane]))
            .fold(0u8, |mask, lane| mask | (1 << lane));
        GuardVerdict::lanes4(failed)
    }
    /// Multiplies every channel of lane `i` by `s[i]`.
    #[inline]
    fn scale_rows(&self, s: [f64; 4]) -> Self {
        let factor = wide::f64x4::new(s);
        let mut scaled = *self;
        scaled.v = scaled.v * factor;
        for x in scaled.g.iter_mut() {
            *x = *x * factor;
        }
        for x in scaled.h.iter_mut().flatten() {
            *x = *x * factor;
        }
        for x in scaled.t3.iter_mut().flatten().flatten() {
            *x = *x * factor;
        }
        for x in scaled.t4.iter_mut().flatten().flatten().flatten() {
            *x = *x * factor;
        }
        scaled
    }
    /// Gathers `value_of` for the first four rows, one per lane.
    ///
    /// # Panics
    /// If `rows.len() < 4`.
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        std::array::from_fn(|lane| value_of(rows[lane]))
    }
}
/// Four rows at a time: each `f64x4` lane carries one row's third-order tower.
impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
    type Value = [f64; 4];
    #[inline]
    fn constant(c: f64) -> Self {
        let splat = <wide::f64x4 as crate::jet_scalar::Lane>::splat(c);
        Tower3Lane::constant(splat)
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        let splat = <wide::f64x4 as crate::jet_scalar::Lane>::splat(x);
        Tower3Lane::variable(splat, slot)
    }
    #[inline]
    fn values(&self) -> [f64; 4] {
        self.v.to_array()
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        Tower3Lane::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        Tower3Lane::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        Tower3Lane::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        Tower3Lane::scale(self, s)
    }
    /// Pads or truncates the caller's stack to the four entries a third-order tower uses.
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
    }
    /// Bit `lane` of the verdict is set when `pred` rejects that lane's value.
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        let lane_values = self.v.to_array();
        let failed = (0..lane_values.len())
            .filter(|&lane| !pred(lane_values[lane]))
            .fold(0u8, |mask, lane| mask | (1 << lane));
        GuardVerdict::lanes4(failed)
    }
    /// Multiplies every channel of lane `i` by `s[i]`.
    #[inline]
    fn scale_rows(&self, s: [f64; 4]) -> Self {
        let factor = wide::f64x4::new(s);
        let mut scaled = *self;
        scaled.v = scaled.v * factor;
        for x in scaled.g.iter_mut() {
            *x = *x * factor;
        }
        for x in scaled.h.iter_mut().flatten() {
            *x = *x * factor;
        }
        for x in scaled.t3.iter_mut().flatten().flatten() {
            *x = *x * factor;
        }
        scaled
    }
    /// Gathers `value_of` for the first four rows, one per lane.
    ///
    /// # Panics
    /// If `rows.len() < 4`.
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        std::array::from_fn(|lane| value_of(rows[lane]))
    }
}
pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
fn n_rows(&self) -> usize;
fn primaries(&self, row: usize) -> Result<[f64; K], String>;
fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
}
/// Value, gradient and Hessian of row `row`, obtained by seeding every primary
/// as an independent variable of a second-order jet.
///
/// # Errors
/// Propagates errors from `prog.primaries` and `prog.row_nll`.
pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let primaries = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|slot| {
        <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(primaries[slot], slot)
    });
    let jet = prog.row_nll(&[row], &seeded)?;
    let value = crate::jet_scalar::JetScalar::value(&jet);
    Ok((value, jet.g(), jet.h()))
}
/// Third derivative of row `row` contracted once along `dir`, computed by
/// seeding each primary with its `dir` component.
///
/// # Errors
/// Propagates errors from `prog.primaries` and `prog.row_nll`.
pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let primaries = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::OneSeed<K>; K] = std::array::from_fn(|slot| {
        crate::jet_scalar::OneSeed::seed_direction(primaries[slot], slot, dir[slot])
    });
    prog.row_nll(&[row], &seeded)
        .map(|jet| jet.contracted_third())
}
/// Fourth derivative of row `row` contracted along `dir_u` and `dir_v`,
/// computed by seeding each primary with both direction components.
///
/// # Errors
/// Propagates errors from `prog.primaries` and `prog.row_nll`.
pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let primaries = prog.primaries(row)?;
    let seeded: [crate::jet_scalar::TwoSeed<K>; K] = std::array::from_fn(|slot| {
        crate::jet_scalar::TwoSeed::seed(primaries[slot], slot, dir_u[slot], dir_v[slot])
    });
    prog.row_nll(&[row], &seeded)
        .map(|jet| jet.contracted_fourth())
}
/// Evaluates four rows at once as a fourth-order batch tower, row `rows[k]`
/// occupying lane `k`.
///
/// # Errors
/// Returns the first error from `prog.primaries` (rows queried in order) or
/// from `prog.row_nll`.
pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
    prog: &P,
    rows: [usize; 4],
) -> Result<Tower4Batch<K>, String> {
    let mut per_lane = [[0.0_f64; K]; 4];
    for (dst, &row) in per_lane.iter_mut().zip(rows.iter()) {
        *dst = prog.primaries(row)?;
    }
    let vars: [Tower4Batch<K>; K] = std::array::from_fn(|slot| {
        let lanes = wide::f64x4::new(std::array::from_fn(|lane| per_lane[lane][slot]));
        Tower4Batch::variable(lanes, slot)
    });
    prog.row_nll(&rows, &vars)
}
/// Evaluates four rows at once as a third-order batch tower, row `rows[k]`
/// occupying lane `k`.
///
/// # Errors
/// Returns the first error from `prog.primaries` (rows queried in order) or
/// from `prog.row_nll`.
pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
    prog: &P,
    rows: [usize; 4],
) -> Result<Tower3Batch<K>, String> {
    let mut per_lane = [[0.0_f64; K]; 4];
    for (dst, &row) in per_lane.iter_mut().zip(rows.iter()) {
        *dst = prog.primaries(row)?;
    }
    let vars: [Tower3Batch<K>; K] = std::array::from_fn(|slot| {
        let lanes = wide::f64x4::new(std::array::from_fn(|lane| per_lane[lane][slot]));
        Tower3Batch::variable(lanes, slot)
    });
    prog.row_nll(&rows, &vars)
}
/// Hand-derived row-kernel channels to be cross-checked against a reference
/// fourth-order tower by `verify_kernel_channels`.
#[derive(Clone, Debug, PartialEq)]
pub struct KernelChannels<const K: usize> {
    /// Claimed row value.
    pub value: f64,
    /// Claimed gradient with respect to the `K` primaries.
    pub gradient: [f64; K],
    /// Claimed `K x K` Hessian.
    pub hessian: [[f64; K]; K],
    /// `(dir, claim)` pairs; `claim` is compared with the tower's
    /// `third_contracted(dir)`.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// `(u, w, claim)` triples; `claim` is compared with the tower's
    /// `fourth_contracted(u, w)`.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
/// Cross-checks hand-derived row-kernel channels against a reference
/// fourth-order tower.
///
/// Value, gradient, Hessian, every contracted third derivative
/// (`tower.third_contracted(dir)`) and every contracted fourth derivative
/// (`tower.fourth_contracted(u, w)`) are compared entry by entry. Two finite
/// numbers agree when `|claim - truth| <= atol + rel_tol * max(|claim|, |truth|)`;
/// here the absolute floor `atol` equals `rel_tol` (see
/// `verify_kernel_channels_with_atol` to choose it independently).
/// Non-finite values agree only when both are infinities of the same sign.
///
/// # Errors
/// Returns a message naming the first disagreeing entry. Also returns an error
/// when `rel_tol` is NaN or negative: a NaN band would otherwise make every
/// comparison false and silently accept any claim.
pub fn verify_kernel_channels<const K: usize>(
    tower: &Tower4<K>,
    claims: &KernelChannels<K>,
    rel_tol: f64,
) -> Result<(), String> {
    verify_kernel_channels_with_atol(tower, claims, rel_tol, rel_tol)
}

/// Same check as `verify_kernel_channels`, with an explicit absolute floor
/// `atol` instead of reusing `rel_tol`.
///
/// # Errors
/// Returns a message naming the first disagreeing entry, or an error when
/// either tolerance is NaN or negative.
pub fn verify_kernel_channels_with_atol<const K: usize>(
    tower: &Tower4<K>,
    claims: &KernelChannels<K>,
    rel_tol: f64,
    atol: f64,
) -> Result<(), String> {
    // `!(x >= 0.0)` is true for negatives *and* NaN.
    if !(rel_tol >= 0.0) || !(atol >= 0.0) {
        return Err(format!(
            "row-kernel oracle: invalid tolerance (rel_tol {rel_tol:e}, atol {atol:e}); both must be non-negative and not NaN"
        ));
    }
    // Labels are closures so that no String is built for entries that pass.
    let check = |claim: f64, truth: f64, label: &dyn Fn() -> String| -> Result<(), String> {
        kernel_channel_check(claim, truth, rel_tol, atol, label)
    };
    check(claims.value, tower.v, &|| "value".to_string())?;
    for a in 0..K {
        check(claims.gradient[a], tower.g[a], &|| format!("gradient[{a}]"))?;
    }
    for a in 0..K {
        for b in 0..K {
            check(claims.hessian[a][b], tower.h[a][b], &|| {
                format!("hessian[{a}][{b}]")
            })?;
        }
    }
    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
        let truth = tower.third_contracted(dir);
        for a in 0..K {
            for b in 0..K {
                check(claim[a][b], truth[a][b], &|| {
                    format!("third[{t_idx}][{a}][{b}]")
                })?;
            }
        }
    }
    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
        let truth = tower.fourth_contracted(u, w);
        for a in 0..K {
            for b in 0..K {
                check(claim[a][b], truth[a][b], &|| {
                    format!("fourth[{f_idx}][{a}][{b}]")
                })?;
            }
        }
    }
    Ok(())
}

/// Compares one claimed entry with its tower value; `label` is rendered only
/// when the comparison fails.
fn kernel_channel_check(
    claim: f64,
    truth: f64,
    rel_tol: f64,
    atol: f64,
    label: &dyn Fn() -> String,
) -> Result<(), String> {
    if !claim.is_finite() || !truth.is_finite() {
        let same_infinity = claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
        if same_infinity {
            return Ok(());
        }
        let label = label();
        return Err(format!(
            "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
        ));
    }
    let band = atol + rel_tol * claim.abs().max(truth.abs());
    if (claim - truth).abs() > band {
        let label = label();
        return Err(format!(
            "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
        ));
    }
    Ok(())
}
use crate::jet_scalar::Lane;
/// Fourth-order derivative tower whose entries are lane vectors of type `L`
/// (for example `wide::f64x4` in [`Tower4Batch`]), so one value carries several
/// independent rows side by side.
///
/// The field layout mirrors the scalar `Tower4`: value, gradient, Hessian and
/// the third/fourth derivative tensors over the `K` primaries.
#[derive(Clone, Copy, Debug)]
pub struct Tower4Lane<L: Lane, const K: usize> {
    pub v: L,
    pub g: [L; K],
    pub h: [[L; K]; K],
    pub t3: [[[L; K]; K]; K],
    pub t4: [[[[L; K]; K]; K]; K],
}
/// Four-row fourth-order tower over `wide::f64x4` lanes.
pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
/// Lane-wise arithmetic on [`Tower4Lane`]: each operation applies the scalar
/// tower rule to every lane independently.
///
/// The sums in `mul`, `compose_unary` and `compose_unary_single_slot` are
/// written term by term on purpose. The batch tests compare each lane with the
/// scalar `Tower4` result bit for bit (`to_bits`), so reordering any
/// accumulation here can break that equality.
impl<L: Lane, const K: usize> Tower4Lane<L, K> {
/// All-zero tower (every entry is `L::splat(0.0)`).
#[inline]
pub fn zero() -> Self {
let z = L::splat(0.0);
Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
}
/// Tower of a constant: value `c`, all derivatives zero.
#[inline]
pub fn constant(c: L) -> Self {
let mut o = Self::zero();
o.v = c;
o
}
/// Seeds primary `idx`: value `value` and a unit gradient in slot `idx` for
/// every lane. Panics if `idx >= K`.
#[inline]
pub fn variable(value: L, idx: usize) -> Self {
let mut o = Self::constant(value);
o.g[idx] = L::splat(1.0);
o
}
/// Extracts lane `i` of every entry as a scalar `Tower4`.
#[inline]
pub fn lane(&self, i: usize) -> Tower4<K> {
let mut out = Tower4::<K>::zero();
out.v = self.v.lane(i);
for a in 0..K {
out.g[a] = self.g[a].lane(i);
for b in 0..K {
out.h[a][b] = self.h[a][b].lane(i);
for c in 0..K {
out.t3[a][b][c] = self.t3[a][b][c].lane(i);
for d in 0..K {
out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
}
}
}
}
out
}
/// Entry-wise sum.
#[inline]
pub fn add(&self, o: &Self) -> Self {
let mut out = *self;
out.v = self.v.add(o.v);
for i in 0..K {
out.g[i] = self.g[i].add(o.g[i]);
for j in 0..K {
out.h[i][j] = self.h[i][j].add(o.h[i][j]);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
for l in 0..K {
out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
}
}
}
}
out
}
/// Entry-wise difference `self - o`.
#[inline]
pub fn sub(&self, o: &Self) -> Self {
let mut out = *self;
out.v = self.v.sub(o.v);
for i in 0..K {
out.g[i] = self.g[i].sub(o.g[i]);
for j in 0..K {
out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
for l in 0..K {
out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
}
}
}
}
out
}
/// Multiplies every entry by the scalar `s`, broadcast to all lanes.
#[inline]
pub fn scale(&self, s: f64) -> Self {
let sl = L::splat(s);
let mut out = *self;
out.v = self.v.mul(sl);
for i in 0..K {
out.g[i] = self.g[i].mul(sl);
for j in 0..K {
out.h[i][j] = self.h[i][j].mul(sl);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
for l in 0..K {
out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
}
}
}
}
out
}
/// Product tower via the general Leibniz rule: an entry with index set `I`
/// is the sum over every subset `S` of `I` of `d_S a * d_{I \ S} b`.
#[inline]
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v.mul(b.v);
// Order 1: subsets of {i}.
for i in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.g[i]));
acc = acc.add(a.g[i].mul(b.v));
out.g[i] = acc;
}
// Order 2: the 4 subsets of {i, j}.
for i in 0..K {
for j in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.h[i][j]));
acc = acc.add(a.g[i].mul(b.g[j]));
acc = acc.add(a.g[j].mul(b.g[i]));
acc = acc.add(a.h[i][j].mul(b.v));
out.h[i][j] = acc;
}
}
// Order 3: the 8 subsets of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.t3[i][j][k]));
acc = acc.add(a.g[i].mul(b.h[j][k]));
acc = acc.add(a.g[j].mul(b.h[i][k]));
acc = acc.add(a.h[i][j].mul(b.g[k]));
acc = acc.add(a.g[k].mul(b.h[i][j]));
acc = acc.add(a.h[i][k].mul(b.g[j]));
acc = acc.add(a.h[j][k].mul(b.g[i]));
acc = acc.add(a.t3[i][j][k].mul(b.v));
out.t3[i][j][k] = acc;
}
}
}
// Order 4: the 16 subsets of {i, j, k, l}, in the same order as the scalar
// `Tower4::mul`.
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
acc = acc.add(a.h[i][j].mul(b.h[k][l]));
acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
acc = acc.add(a.h[i][k].mul(b.h[j][l]));
acc = acc.add(a.h[j][k].mul(b.h[i][l]));
acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
acc = acc.add(a.h[i][l].mul(b.h[j][k]));
acc = acc.add(a.h[j][l].mul(b.h[i][k]));
acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
acc = acc.add(a.h[k][l].mul(b.h[i][j]));
acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
acc = acc.add(a.t4[i][j][k][l].mul(b.v));
out.t4[i][j][k][l] = acc;
}
}
}
}
out
}
/// Chain rule `f(self)` given `d = [f, f', f'', f''', f'''']` evaluated at the
/// value (Faà di Bruno). Each entry sums over the set partitions of its
/// indices: a partition with `m` blocks contributes `d[m]` times the product of
/// the blocks' derivatives of `self`.
#[inline]
pub fn compose_unary(&self, d: [L; 5]) -> Self {
let mut out = Self::zero();
out.v = d[0];
for i in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.g[i]));
out.g[i] = acc;
}
for i in 0..K {
for j in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.h[i][j]));
acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
out.h[i][j] = acc;
}
}
// Order 3: the 5 set partitions of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.t3[i][j][k]));
acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
out.t3[i][j][k] = acc;
}
}
}
// Order 4: the 15 set partitions of {i, j, k, l} (1 one-block, 7 two-block,
// 6 three-block, 1 four-block).
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
out.t4[i][j][k][l] = acc;
}
}
}
}
out
}
/// `compose_unary` with the derivative stack produced by `stack_fn` from the
/// value lanes (via `Lane::unary_with`; presumably one call per lane).
#[inline]
pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
self.compose_unary(self.v.unary_with(stack_fn))
}
/// Single-primary variant of `compose_unary`: only the diagonal entries at
/// `slot` are propagated and every other entry of the result is zero, so this
/// equals `compose_unary` only when `self` has no dependence on other
/// primaries (NOTE(review): presumably guaranteed by callers). The terms repeat
/// the order of `compose_unary` with `i = j = k = l = slot`.
#[inline]
pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
let mut out = Self::zero();
let s = slot;
let g = self.g[s];
let h = self.h[s][s];
let t3 = self.t3[s][s][s];
let t4 = self.t4[s][s][s][s];
out.v = d[0];
out.g[s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(g));
acc
};
out.h[s][s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(h));
acc = acc.add(d[2].mul(g).mul(g));
acc
};
out.t3[s][s][s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(t3));
acc = acc.add(d[2].mul(h).mul(g));
acc = acc.add(d[2].mul(h).mul(g));
acc = acc.add(d[2].mul(g).mul(h));
acc = acc.add(d[3].mul(g).mul(g).mul(g));
acc
};
// 15 terms, one per set partition, kept separate (not collapsed into
// multiplicities 4/3/6) so the rounding matches `compose_unary`.
out.t4[s][s][s][s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(t4));
acc = acc.add(d[2].mul(t3).mul(g));
acc = acc.add(d[2].mul(t3).mul(g));
acc = acc.add(d[2].mul(h).mul(h));
acc = acc.add(d[3].mul(h).mul(g).mul(g));
acc = acc.add(d[2].mul(t3).mul(g));
acc = acc.add(d[2].mul(h).mul(h));
acc = acc.add(d[3].mul(h).mul(g).mul(g));
acc = acc.add(d[2].mul(h).mul(h));
acc = acc.add(d[2].mul(g).mul(t3));
acc = acc.add(d[3].mul(g).mul(h).mul(g));
acc = acc.add(d[3].mul(h).mul(g).mul(g));
acc = acc.add(d[3].mul(g).mul(h).mul(g));
acc = acc.add(d[3].mul(g).mul(g).mul(h));
acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
acc
};
out
}
/// Contracts the third-derivative tensor on its last index:
/// `out[a][b] = sum_c t3[a][b][c] * dir[c]`.
#[inline]
pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
let mut out = [[L::splat(0.0); K]; K];
for a in 0..K {
for b in 0..K {
let mut acc = L::splat(0.0);
for c in 0..K {
acc = acc.add(self.t3[a][b][c].mul(dir[c]));
}
out[a][b] = acc;
}
}
out
}
/// Contracts the fourth-derivative tensor on its last two indices:
/// `out[i][j] = sum_{k,l} t4[i][j][k][l] * u[k] * w[l]`.
#[inline]
pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
let mut out = [[L::splat(0.0); K]; K];
for i in 0..K {
for j in 0..K {
let mut acc = L::splat(0.0);
for k in 0..K {
for l in 0..K {
acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
}
}
out[i][j] = acc;
}
}
out
}
}
/// Third-order derivative tower whose entries are lane vectors of type `L`
/// (for example `wide::f64x4` in [`Tower3Batch`]), so one value carries several
/// independent rows side by side.
///
/// The field layout mirrors the scalar `Tower3`: value, gradient, Hessian and
/// the third-derivative tensor over the `K` primaries.
#[derive(Clone, Copy, Debug)]
pub struct Tower3Lane<L: Lane, const K: usize> {
    pub v: L,
    pub g: [L; K],
    pub h: [[L; K]; K],
    pub t3: [[[L; K]; K]; K],
}
/// Four-row third-order tower over `wide::f64x4` lanes.
pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
/// Lane-wise arithmetic on [`Tower3Lane`]: each operation applies the scalar
/// third-order tower rule to every lane independently.
///
/// The sums in `mul`, `compose_unary` and `compose_unary_single_slot` are
/// written term by term on purpose. The batch tests compare each lane with the
/// scalar `Tower3` result bit for bit, so reordering any accumulation can
/// break that equality.
impl<L: Lane, const K: usize> Tower3Lane<L, K> {
/// All-zero tower (every entry is `L::splat(0.0)`).
#[inline]
pub fn zero() -> Self {
let z = L::splat(0.0);
Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
}
/// Tower of a constant: value `c`, all derivatives zero.
#[inline]
pub fn constant(c: L) -> Self {
let mut o = Self::zero();
o.v = c;
o
}
/// Seeds primary `idx`: value `value` and a unit gradient in slot `idx` for
/// every lane. Panics if `idx >= K`.
#[inline]
pub fn variable(value: L, idx: usize) -> Self {
let mut o = Self::constant(value);
o.g[idx] = L::splat(1.0);
o
}
/// Extracts lane `i` of every entry as a scalar `Tower3`.
#[inline]
pub fn lane(&self, i: usize) -> Tower3<K> {
let mut out = Tower3::<K>::zero();
out.v = self.v.lane(i);
for a in 0..K {
out.g[a] = self.g[a].lane(i);
for b in 0..K {
out.h[a][b] = self.h[a][b].lane(i);
for c in 0..K {
out.t3[a][b][c] = self.t3[a][b][c].lane(i);
}
}
}
out
}
/// Entry-wise sum.
#[inline]
pub fn add(&self, o: &Self) -> Self {
let mut out = *self;
out.v = self.v.add(o.v);
for i in 0..K {
out.g[i] = self.g[i].add(o.g[i]);
for j in 0..K {
out.h[i][j] = self.h[i][j].add(o.h[i][j]);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
}
}
}
out
}
/// Entry-wise difference `self - o`.
#[inline]
pub fn sub(&self, o: &Self) -> Self {
let mut out = *self;
out.v = self.v.sub(o.v);
for i in 0..K {
out.g[i] = self.g[i].sub(o.g[i]);
for j in 0..K {
out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
}
}
}
out
}
/// Multiplies every entry by the scalar `s`, broadcast to all lanes.
#[inline]
pub fn scale(&self, s: f64) -> Self {
let sl = L::splat(s);
let mut out = *self;
out.v = self.v.mul(sl);
for i in 0..K {
out.g[i] = self.g[i].mul(sl);
for j in 0..K {
out.h[i][j] = self.h[i][j].mul(sl);
for k in 0..K {
out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
}
}
}
out
}
/// Product tower via the general Leibniz rule: an entry with index set `I`
/// is the sum over every subset `S` of `I` of `d_S a * d_{I \ S} b`.
#[inline]
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v.mul(b.v);
// Order 1: subsets of {i}.
for i in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.g[i]));
acc = acc.add(a.g[i].mul(b.v));
out.g[i] = acc;
}
// Order 2: the 4 subsets of {i, j}.
for i in 0..K {
for j in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.h[i][j]));
acc = acc.add(a.g[i].mul(b.g[j]));
acc = acc.add(a.g[j].mul(b.g[i]));
acc = acc.add(a.h[i][j].mul(b.v));
out.h[i][j] = acc;
}
}
// Order 3: the 8 subsets of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(a.v.mul(b.t3[i][j][k]));
acc = acc.add(a.g[i].mul(b.h[j][k]));
acc = acc.add(a.g[j].mul(b.h[i][k]));
acc = acc.add(a.h[i][j].mul(b.g[k]));
acc = acc.add(a.g[k].mul(b.h[i][j]));
acc = acc.add(a.h[i][k].mul(b.g[j]));
acc = acc.add(a.h[j][k].mul(b.g[i]));
acc = acc.add(a.t3[i][j][k].mul(b.v));
out.t3[i][j][k] = acc;
}
}
}
out
}
/// Chain rule `f(self)` given `d = [f, f', f'', f''']` evaluated at the value
/// (Faà di Bruno). Each entry sums over the set partitions of its indices:
/// a partition with `m` blocks contributes `d[m]` times the product of the
/// blocks' derivatives of `self`.
#[inline]
pub fn compose_unary(&self, d: [L; 4]) -> Self {
let mut out = Self::zero();
out.v = d[0];
for i in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.g[i]));
out.g[i] = acc;
}
for i in 0..K {
for j in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.h[i][j]));
acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
out.h[i][j] = acc;
}
}
// Order 3: the 5 set partitions of {i, j, k}.
for i in 0..K {
for j in 0..K {
for k in 0..K {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(self.t3[i][j][k]));
acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
out.t3[i][j][k] = acc;
}
}
}
out
}
/// `compose_unary` with the derivative stack produced by `stack_fn` from the
/// value lanes (via `Lane::unary_with`; presumably one call per lane).
#[inline]
pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
self.compose_unary(self.v.unary_with(stack_fn))
}
/// Single-primary variant of `compose_unary`: only the diagonal entries at
/// `slot` are propagated and every other entry of the result is zero, so this
/// equals `compose_unary` only when `self` has no dependence on other
/// primaries (NOTE(review): presumably guaranteed by callers). The terms repeat
/// the order of `compose_unary` with `i = j = k = slot`.
#[inline]
pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
let mut out = Self::zero();
let s = slot;
let g = self.g[s];
let h = self.h[s][s];
let t3 = self.t3[s][s][s];
out.v = d[0];
out.g[s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(g));
acc
};
out.h[s][s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(h));
acc = acc.add(d[2].mul(g).mul(g));
acc
};
// Three separate `d2 * h * g` terms (not `3 * ...`) to match the rounding
// of `compose_unary`.
out.t3[s][s][s] = {
let mut acc = L::splat(0.0);
acc = acc.add(d[1].mul(t3));
acc = acc.add(d[2].mul(h).mul(g));
acc = acc.add(d[2].mul(h).mul(g));
acc = acc.add(d[2].mul(g).mul(h));
acc = acc.add(d[3].mul(g).mul(g).mul(g));
acc
};
out
}
}
/// Checks that the lane-batched towers reproduce the scalar towers bit for bit,
/// lane by lane, including the `compose_unary_with` seam at signed zeros.
#[cfg(test)]
mod batch_tests {
use super::*;
/// Minimal 64-bit LCG (Knuth's MMIX constants) so the tests are deterministic
/// without external crates.
struct Rng(u64);
impl Rng {
/// Next draw in `[-2, 2)`, built from the top 53 bits of the state.
fn f(&mut self) -> f64 {
self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
}
}
/// `Tower4` whose entries are all drawn independently (the tensors are not
/// symmetric, which is fine for a lane-vs-scalar comparison).
fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
let mut t = Tower4::<K>::zero();
t.v = r.f();
for i in 0..K {
t.g[i] = r.f();
for j in 0..K {
t.h[i][j] = r.f();
for k in 0..K {
t.t3[i][j][k] = r.f();
for l in 0..K {
t.t4[i][j][k][l] = r.f();
}
}
}
}
t
}
/// `Tower3` whose entries are all drawn independently.
fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
let mut t = Tower3::<K>::zero();
t.v = r.f();
for i in 0..K {
t.g[i] = r.f();
for j in 0..K {
t.h[i][j] = r.f();
for k in 0..K {
t.t3[i][j][k] = r.f();
}
}
}
t
}
/// Packs four scalar towers into one batch, `rows[r]` going to lane `r`.
fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
let mut b = Tower4Batch::<K>::zero();
let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
};
b.v = lane(&|t| t.v);
for i in 0..K {
b.g[i] = lane(&|t| t.g[i]);
for j in 0..K {
b.h[i][j] = lane(&|t| t.h[i][j]);
for k in 0..K {
b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
for l in 0..K {
b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
}
}
}
}
b
}
/// Packs four scalar third-order towers into one batch, `rows[r]` in lane `r`.
fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
let mut b = Tower3Batch::<K>::zero();
let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
};
b.v = lane(&|t| t.v);
for i in 0..K {
b.g[i] = lane(&|t| t.g[i]);
for j in 0..K {
b.h[i][j] = lane(&|t| t.h[i][j]);
for k in 0..K {
b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
}
}
}
b
}
/// Asserts bitwise equality (`to_bits`) of every entry, so even the sign of
/// zero and NaN payloads must match.
fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
for i in 0..K {
assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
for j in 0..K {
assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
for k in 0..K {
assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
for l in 0..K {
assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
}
}
}
}
}
/// Third-order counterpart of `assert_t4_eq`.
fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
for i in 0..K {
assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
for j in 0..K {
assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
for k in 0..K {
assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
}
}
}
}
/// Runs `batches` rounds of the chain mul -> compose_unary -> add/sub/scale ->
/// compose_unary_single_slot, plus both contractions, on four random rows.
/// Each round runs both scalar and batched and asserts bit equality per lane.
/// Returns the number of rows checked.
fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
let mut r = Rng(seed);
let mut rows_checked = 0;
for _ in 0..batches {
// Draw order matters for reproducibility: a, b, d, dir, dir2, s.
let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
let s = r.f();
// Scalar reference, one row at a time.
let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
let prod = a[rw].mul(&b[rw]);
let comp = prod.compose_unary(d[rw]);
let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
summed.compose_unary_single_slot(d[rw], 0)
});
let third: [[[f64; K]; K]; 4] =
std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
let fourth: [[[f64; K]; K]; 4] =
std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
// Same computation on the packed batch (transpose row-major draws into lanes).
let ab = pack4_t4(&a);
let bb = pack4_t4(&b);
let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
});
let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
});
let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
});
let prodb = ab.mul(&bb);
let compb = prodb.compose_unary(db);
let summedb = compb.add(&ab).sub(&bb).scale(s);
let finalb = summedb.compose_unary_single_slot(db, 0);
let thirdb = ab.third_contracted(&dirb);
let fourthb = ab.fourth_contracted(&dirb, &dir2b);
for rw in 0..4 {
assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
for i in 0..K {
for j in 0..K {
assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
}
}
rows_checked += 1;
}
}
rows_checked
}
/// Third-order counterpart of `run4` (no contractions); returns rows checked.
fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
let mut r = Rng(seed);
let mut rows_checked = 0;
for _ in 0..batches {
let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
let s = r.f();
let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
let prod = a[rw].mul(&b[rw]);
let comp = prod.compose_unary(d[rw]);
let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
summed.compose_unary_single_slot(d[rw], 0)
});
let ab = pack4_t3(&a);
let bb = pack4_t3(&b);
let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
});
let prodb = ab.mul(&bb);
let compb = prodb.compose_unary(db);
let summedb = compb.add(&ab).sub(&bb).scale(s);
let finalb = summedb.compose_unary_single_slot(db, 0);
for rw in 0..4 {
assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
rows_checked += 1;
}
}
rows_checked
}
/// Runs `f` on a fresh thread with a 512 MiB stack and returns its result.
/// The K = 9 towers are large by-value arrays; presumably they overflow the
/// default test-thread stack.
fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
std::thread::Builder::new()
.stack_size(512 << 20)
.spawn(f)
.unwrap()
.join()
.unwrap()
}
#[test]
fn tower4_batch_lane_bit_identical() {
let batches = 2000;
let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
+ big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
+ big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
+ big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
// 4 sizes x `batches` rounds x 4 lanes.
assert_eq!(rows_checked, 4 * batches * 4);
}
#[test]
fn tower3_batch_lane_bit_identical() {
let batches = 2000;
let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
+ big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
+ big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
+ big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
assert_eq!(rows_checked, 4 * batches * 4);
}
/// Arbitrary smooth 5-entry stack used only to exercise the composition seam
/// (its entries are not derivatives of one function).
fn seam_stack5(u: f64) -> [f64; 5] {
[u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
}
/// First four entries of `seam_stack5`.
fn seam_stack4(u: f64) -> [f64; 4] {
let s = seam_stack5(u);
[s[0], s[1], s[2], s[3]]
}
/// Base value selector: 0 -> -0.0, 1 -> +0.0, 2 -> draw in [-2, 2),
/// otherwise a draw shifted into [1, 5).
fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
match which {
0 => -0.0,
1 => 0.0,
2 => r.f(),
_ => r.f() + 3.0,
}
}
/// Scalar `compose_unary_with(stack)` must equal `compose_unary(stack(v))`
/// bitwise, including at signed zeros. Returns the number of cases checked.
fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
let mut r = Rng(seed);
for _ in 0..n {
let mut t = rand_t4::<K>(&mut r);
// Reuse the random value's bits to pick which kind of base to test.
t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
assert_t4_eq(
&t.compose_unary_with(seam_stack5),
&t.compose_unary(seam_stack5(t.v)),
"scalar t4 seam",
);
}
n
}
/// Third-order counterpart of `scalar_seam_t4`.
fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
let mut r = Rng(seed);
for _ in 0..n {
let mut t = rand_t3::<K>(&mut r);
t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
assert_t3_eq(
&t.compose_unary_with(seam_stack4),
&t.compose_unary(seam_stack4(t.v)),
"scalar t3 seam",
);
}
n
}
/// Batched `compose_unary_with` must match the scalar one in every lane,
/// with each lane given a different kind of base (-0.0, +0.0, small, shifted).
/// Returns the number of lanes checked.
fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
let mut r = Rng(seed);
let mut verified = 0usize;
for _ in 0..batches {
let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
for (rw, row) in rows.iter_mut().enumerate() {
row.v = seam_edge_base(&mut r, rw);
}
let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
for (rw, row) in rows.iter().enumerate() {
assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
verified += 1;
}
}
verified
}
/// Third-order counterpart of `lane_seam_t4`.
fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
let mut r = Rng(seed);
let mut verified = 0usize;
for _ in 0..batches {
let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
for (rw, row) in rows.iter_mut().enumerate() {
row.v = seam_edge_base(&mut r, rw);
}
let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
for (rw, row) in rows.iter().enumerate() {
assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
verified += 1;
}
}
verified
}
#[test]
fn compose_unary_with_scalar_bit_identical() {
let n = 1100;
let total = scalar_seam_t4::<2>(0x2200_0001, n)
+ scalar_seam_t4::<3>(0x2200_0002, n)
+ scalar_seam_t4::<4>(0x2200_0003, n)
+ big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
+ scalar_seam_t3::<2>(0x3300_0001, n)
+ scalar_seam_t3::<3>(0x3300_0002, n)
+ scalar_seam_t3::<4>(0x3300_0003, n)
+ big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
assert_eq!(total, 8 * n);
}
#[test]
fn compose_unary_with_lane_matches_scalar() {
let b = 600;
let total = lane_seam_t4::<2>(0x4400_0001, b)
+ lane_seam_t4::<3>(0x4400_0002, b)
+ lane_seam_t4::<4>(0x4400_0003, b)
+ big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
+ lane_seam_t3::<2>(0x5500_0001, b)
+ lane_seam_t3::<3>(0x5500_0002, b)
+ lane_seam_t3::<4>(0x5500_0003, b)
+ big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
// 8 configurations x `b` batches x 4 lanes.
assert_eq!(total, 8 * b * 4);
}
}
#[cfg(test)]
mod tests {
use super::*;
/// The same three-primary program is evaluated with `Tower3` and with
/// `Tower4`, using identical operations in identical order. Value, gradient,
/// Hessian and third-order tensor must agree bit for bit, so carrying the
/// fourth order must not perturb the lower orders.
#[test]
fn tower3_matches_tower4_through_third_order() {
// Derivative stack of sin at 0.3: [sin, cos, -sin, -cos, sin].
let s_a: [f64; 5] = [
0.3_f64.sin(),
0.3_f64.cos(),
-0.3_f64.sin(),
-0.3_f64.cos(),
0.3_f64.sin(),
];
// Arbitrary constant stack.
let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
// Tower3 consumes only the first four stack entries.
let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
let a4 = Tower4::<3>::variable(0.4, 0);
let b4 = Tower4::<3>::variable(-0.7, 1);
let c4 = Tower4::<3>::variable(0.9, 2);
let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
+ a4.mul(&c4).scale(-0.7)
+ b4.compose_unary(s_b).scale(0.25);
let a3 = Tower3::<3>::variable(0.4, 0);
let b3 = Tower3::<3>::variable(-0.7, 1);
let c3 = Tower3::<3>::variable(0.9, 2);
let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
+ a3.mul(&c3).scale(-0.7)
+ b3.compose_unary(s4(s_b)).scale(0.25);
assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
for i in 0..3 {
assert_eq!(
prog3.g[i].to_bits(),
prog4.g[i].to_bits(),
"g[{i}] mismatch"
);
for j in 0..3 {
assert_eq!(
prog3.h[i][j].to_bits(),
prog4.h[i][j].to_bits(),
"h[{i}][{j}] mismatch"
);
for k in 0..3 {
assert_eq!(
prog3.t3[i][j][k].to_bits(),
prog4.t3[i][j][k].to_bits(),
"t3[{i}][{j}][{k}] mismatch"
);
}
}
}
}
/// One-primary logistic-regression fixture: row `i` has linear predictor
/// `eta[i]` and response `y[i]` (0/1 in the fixture data).
struct LogitProgram {
    // Response per row.
    y: Vec<f64>,
    // Linear predictor per row; this is the single primary.
    eta: Vec<f64>,
}
impl RowNllProgram<1> for LogitProgram {
    fn n_rows(&self) -> usize {
        let predictors = &self.eta;
        predictors.len()
    }
    fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
        let eta = self.eta[row];
        Ok([eta])
    }
    // Bernoulli-logit NLL: softplus(eta) - y * eta.
    fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
        let [eta] = *p;
        let softplus = (eta.exp() + 1.0).ln();
        Ok(softplus - eta * self.y[row])
    }
}
#[test]
fn logit_tower_matches_closed_forms() {
    let prog = LogitProgram {
        eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
        y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
    };
    for (row, (&eta, &y)) in prog.eta.iter().zip(prog.y.iter()).enumerate() {
        let t = evaluate_program(&prog, row).expect("logit program");
        // Closed forms in terms of mu = sigmoid(eta) and w = mu (1 - mu).
        let mu = 1.0 / (1.0 + (-eta).exp());
        let w = mu * (1.0 - mu);
        let checks: [(f64, f64, &str); 5] = [
            (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
            (t.g[0], mu - y, "grad"),
            (t.h[0][0], w, "hess"),
            (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
            (
                t.t4[0][0][0][0],
                w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
                "fourth",
            ),
        ];
        for (got, want, label) in checks {
            let scale = want.abs().max(1.0);
            assert!(
                (got - want).abs() <= 1e-12 * scale,
                "row {row} {label}: got {got:+.15e} want {want:+.15e}"
            );
        }
    }
}
/// Panics unless `|got - want| <= rel_tol * max(|want|, 1)`, i.e. a relative
/// tolerance with an absolute floor of `rel_tol` near zero.
fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
    let scale = want.abs().max(1.0);
    let diff = (got - want).abs();
    let within = diff <= rel_tol * scale;
    assert!(within, "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}");
}
#[test]
fn gamma_special_function_stacks_match_reference_values() {
    const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
    const TOL: f64 = 1e-13;
    let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
    let ln2 = std::f64::consts::LN_2;
    // (label, x, digamma(x), trigamma(x)); the x = 0.5 and x = 2.5 rows use
    // the closed forms at half-integers.
    let cases = [
        ("x=0.1", 0.1, -10.423_754_940_411_076, 101.433_299_150_792_75),
        ("x=0.5", 0.5, -EULER_GAMMA - 2.0 * ln2, pi_sq / 2.0),
        ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
        (
            "x=2.5",
            2.5,
            -EULER_GAMMA - 2.0 * ln2 + 2.0 + 2.0 / 3.0,
            pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
        ),
        ("x=50", 50.0, 3.901_989_673_427_892, 0.020_201_333_226_697_128),
    ];
    for (label, x, digamma_ref, trigamma_ref) in cases {
        let ln_gamma = ln_gamma_derivative_stack(x);
        let digamma = digamma_derivative_stack(x);
        let trigamma = trigamma_derivative_stack(x);
        // Each stack is checked where it overlaps the references.
        let checks = [
            ("ln_gamma_stack digamma", ln_gamma[1], digamma_ref),
            ("digamma value", digamma[0], digamma_ref),
            ("ln_gamma_stack trigamma", ln_gamma[2], trigamma_ref),
            ("digamma_stack trigamma", digamma[1], trigamma_ref),
            ("trigamma value", trigamma[0], trigamma_ref),
        ];
        for (what, got, want) in checks {
            assert_close(&format!("{label} {what}"), got, want, TOL);
        }
    }
}
#[test]
fn gamma_special_function_stacks_obey_recurrences() {
    // psi(x + 1) = psi(x) + 1/x and psi1(x + 1) = psi1(x) - 1/x^2.
    for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
        let (psi, psi_next) = (
            digamma_derivative_stack(x)[0],
            digamma_derivative_stack(x + 1.0)[0],
        );
        let (psi1, psi1_next) = (
            trigamma_derivative_stack(x)[0],
            trigamma_derivative_stack(x + 1.0)[0],
        );
        assert_close(
            &format!("digamma recurrence x={x}"),
            psi_next,
            psi + 1.0 / x,
            1e-13,
        );
        assert_close(
            &format!("trigamma recurrence x={x}"),
            psi1_next,
            psi1 - 1.0 / (x * x),
            1e-13,
        );
    }
}
/// Two-primary Gaussian location/log-scale fixture: row `i` has location
/// `eta[i]`, log-scale `s[i]` and observation `y[i]`.
struct LocScaleProgram {
    // Observation per row.
    y: Vec<f64>,
    // Location primary per row.
    eta: Vec<f64>,
    // Log-scale primary per row.
    s: Vec<f64>,
}
impl RowNllProgram<2> for LocScaleProgram {
    fn n_rows(&self) -> usize {
        let locations = &self.eta;
        locations.len()
    }
    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
        let (eta, log_scale) = (self.eta[row], self.s[row]);
        Ok([eta, log_scale])
    }
    // Gaussian NLL up to a constant: s + 0.5 * exp(-2 s) * (y - eta)^2.
    fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
        let [eta, log_scale] = *p;
        let resid = -(eta - self.y[row]);
        let precision = (log_scale * (-2.0)).exp();
        Ok(log_scale + precision * resid * resid * 0.5)
    }
}
#[test]
fn locscale_tower_matches_closed_forms_including_cross_blocks() {
    let prog = LocScaleProgram {
        eta: vec![0.3, -1.1, 2.0],
        s: vec![-0.5, 0.2, 0.8],
        y: vec![1.0, -2.0, 2.5],
    };
    let tol = 1e-12;
    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("locscale program");
        let r = prog.y[row] - prog.eta[row];
        let w = (-2.0 * prog.s[row]).exp();
        let truth_g = [-w * r, 1.0 - w * r * r];
        let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
        // Higher-order truths depend only on how many indices hit slot 1
        // (the log-scale), i.e. on the index sum.
        let t3_by_count = [0.0, -2.0 * w, -4.0 * w * r, -4.0 * w * r * r];
        let t4_by_count = [0.0, 0.0, 4.0 * w, 8.0 * w * r, 8.0 * w * r * r];
        let h_tol = tol * w.max(1.0) * (1.0 + r.abs());
        let t3_tol = tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r);
        let t4_tol = tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r);
        for a in 0..2 {
            let grad_scale = truth_g[a].abs().max(1.0);
            assert!(
                (t.g[a] - truth_g[a]).abs() <= tol * grad_scale,
                "row {row} grad[{a}]"
            );
            for b in 0..2 {
                assert!(
                    (t.h[a][b] - truth_h[a][b]).abs() <= h_tol,
                    "row {row} hess[{a}][{b}]: got {} want {}",
                    t.h[a][b],
                    truth_h[a][b]
                );
                for c in 0..2 {
                    let want3 = t3_by_count[a + b + c];
                    assert!(
                        (t.t3[a][b][c] - want3).abs() <= t3_tol,
                        "row {row} t3[{a}][{b}][{c}]: got {} want {}",
                        t.t3[a][b][c],
                        want3
                    );
                    for d in 0..2 {
                        let want4 = t4_by_count[a + b + c + d];
                        assert!(
                            (t.t4[a][b][c][d] - want4).abs() <= t4_tol,
                            "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
                            t.t4[a][b][c][d],
                            want4
                        );
                    }
                }
            }
        }
        // The derived contraction must equal the tower's t3 contracted on its last index.
        let dir = [0.7, -1.3];
        let third = derived_third_contracted(&prog, row, &dir).expect("third");
        for a in 0..2 {
            for b in 0..2 {
                let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
                assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
            }
        }
    }
}
/// Three-primary fixture mixing exp, sqrt, ln, powf and a quadratic term;
/// `tau[i]` is a per-row constant shift inside the logarithm.
struct GnarlyProgram {
    // Per-row additive constant inside the ln.
    tau: Vec<f64>,
    // Per-row primary values.
    primaries: Vec<[f64; 3]>,
}
impl GnarlyProgram {
    /// Three rows of primaries with matching `tau` shifts.
    fn fixture() -> Self {
        let primaries = vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]];
        let tau = vec![0.15, -0.35, 0.5];
        Self { primaries, tau }
    }
}
impl RowNllProgram<3> for GnarlyProgram {
    fn n_rows(&self) -> usize {
        self.primaries.len()
    }
    fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
        match self.primaries.get(row) {
            Some(&values) => Ok(values),
            None => Err(format!("gnarly: row {row} out of range")),
        }
    }
    // ln(exp(x0 x1) + sqrt(x2^2 + 1) + tau) / (0.5 x1 + 2)^1.7 + 0.25 (x0 - x2)^2
    fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
        let tau = match self.tau.get(row) {
            Some(&shift) => shift,
            None => return Err(format!("gnarly: tau row {row} out of range")),
        };
        let [x0, x1, x2] = *p;
        let a = (x0 * x1).exp();
        let b = (x2 * x2 + 1.0).sqrt();
        let c = (a + b + tau).ln();
        let d = (x1 * 0.5 + 2.0).powf(1.7);
        Ok(c / d + (x0 - x2) * (x0 - x2) * 0.25)
    }
}
/// Tower of `prog`'s row `row` NLL expanded at the primaries `p` instead of
/// the fixture's stored primaries.
fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
    // One-row adapter: row 0 maps to `target_row` of `inner`, evaluated at `at`.
    struct Pinned<'a> {
        inner: &'a GnarlyProgram,
        target_row: usize,
        at: [f64; 3],
    }
    impl RowNllProgram<3> for Pinned<'_> {
        fn n_rows(&self) -> usize {
            1
        }
        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
            if row == 0 {
                Ok(self.at)
            } else {
                Err(format!("gnarly-at: row {row} out of range"))
            }
        }
        fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
            if eval_row == 0 {
                self.inner.row_nll(self.target_row, vars)
            } else {
                Err(format!("gnarly-at: eval row {eval_row} out of range"))
            }
        }
    }
    let pinned = Pinned { inner: prog, target_row: row, at: p };
    evaluate_program(&pinned, 0).expect("gnarly tower")
}
/// Finite-difference consistency of the gnarly tower, one derivative order at a time.
///
/// For every fixture row and every primary `c`, central differences of the tower
/// evaluated at `base ± h_step·e_c` must reproduce the next-higher channel:
/// `g[c] ≈ Δv`, `h[a][c] ≈ Δg[a]`, `t3[a][b][c] ≈ Δh[a][b]` and
/// `t4[a][b][d][c] ≈ Δt3[a][b][d]`. Each comparison uses a relative tolerance
/// `tol`, floored at an absolute `tol` for entries smaller than 1.
/// Failure messages name the row so a failure can be traced to its fixture row.
#[test]
fn gnarly_tower_is_fd_consistent_order_by_order() {
    // Step and tolerance do not depend on the row; set them once.
    let h_step = 1e-5;
    let tol = 1e-6;
    let prog = GnarlyProgram::fixture();
    for row in 0..prog.n_rows() {
        let base = prog.primaries(row).expect("primaries");
        let t = gnarly_tower_at(&prog, row, base);
        for c in 0..3 {
            let mut up = base;
            let mut dn = base;
            up[c] += h_step;
            dn[c] -= h_step;
            let t_up = gnarly_tower_at(&prog, row, up);
            let t_dn = gnarly_tower_at(&prog, row, dn);
            // Central difference of one channel along primary `c`.
            let fd = |hi: f64, lo: f64| (hi - lo) / (2.0 * h_step);
            let fd_g = fd(t_up.v, t_dn.v);
            assert!(
                (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
                "row {row} grad[{c}]: analytic {} fd {}",
                t.g[c],
                fd_g
            );
            for a in 0..3 {
                let fd_h = fd(t_up.g[a], t_dn.g[a]);
                assert!(
                    (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
                    "row {row} hess[{a}][{c}]: analytic {} fd {}",
                    t.h[a][c],
                    fd_h
                );
                for b in 0..3 {
                    let fd_t3 = fd(t_up.h[a][b], t_dn.h[a][b]);
                    assert!(
                        (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
                        "row {row} t3[{a}][{b}][{c}]: analytic {} fd {}",
                        t.t3[a][b][c],
                        fd_t3
                    );
                    for d in 0..3 {
                        let fd_t4 = fd(t_up.t3[a][b][d], t_dn.t3[a][b][d]);
                        assert!(
                            (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
                            "row {row} t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
                            t.t4[a][b][d][c],
                            fd_t4
                        );
                    }
                }
            }
        }
    }
}
/// Solves F(a; θ) = a + θ₀·a² + θ₁·eᵃ − C = 0 for a(θ) and checks the
/// `implicit_solve` tower of a(θ) through fourth order:
///
/// * orders 1–2 against central differences of an independent Newton re-solve;
/// * orders 3–4 against central differences of the tower's own second- and
///   third-order channels, re-solved at θ ± h·e_k. Nested scalar differences
///   are too noisy past order 2, hence the switch.
#[test]
fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
    const C: f64 = 1.7;
    let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
    let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
    // Newton iteration for the root a(θ), started from a = 0.
    let solve = |th: [f64; 2]| -> f64 {
        let mut a = 0.0_f64;
        for _ in 0..100 {
            let r = f_scalar(a, th);
            if r.abs() < 1e-14 {
                break;
            }
            a -= r / f_da(a, th);
        }
        a
    };
    // Constraint tower: a in slot 0, θ₀ and θ₁ in slots 1 and 2.
    let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
        let a = Tower4::<3>::variable(a0, 0);
        let t0 = Tower4::<3>::variable(th[0], 1);
        let t1 = Tower4::<3>::variable(th[1], 2);
        a + t0 * a.mul(&a) + t1 * a.exp() - C
    };
    // Tower of a(θ) at θ, built around a converged Newton root.
    let a_tower_at = |th: [f64; 2]| -> Tower4<2> {
        let a0 = solve(th);
        let f = f_tower(a0, th);
        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
        implicit_solve::<3, 2>(&f, a0).expect("implicit solve")
    };
    let th0 = [0.35, 0.2];
    let a_tower: Tower4<2> = a_tower_at(th0);
    let h = 1e-4;
    let tol = 1e-5;
    // Nested scalar differences lose precision, so the Hessian gets a looser bound.
    let tol_hess = 1e-3;
    // Central difference of the re-solved root along θ_j, taken at θ.
    let grad_at = |th: [f64; 2], j: usize| -> f64 {
        let mut up = th;
        let mut dn = th;
        up[j] += h;
        dn[j] -= h;
        (solve(up) - solve(dn)) / (2.0 * h)
    };
    for i in 0..2 {
        let mut up = th0;
        let mut dn = th0;
        up[i] += h;
        dn[i] -= h;
        let fd_g = (solve(up) - solve(dn)) / (2.0 * h);
        assert!(
            (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
            "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
            a_tower.g[i],
            fd_g
        );
        for j in 0..2 {
            let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
            assert!(
                (a_tower.h[i][j] - fd_h).abs() <= tol_hess * fd_h.abs().max(1.0),
                "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
                a_tower.h[i][j],
                fd_h
            );
        }
    }
    // Orders 3 and 4: t3[i][j][k] = ∂h[i][j]/∂θ_k and t4[i][j][l][k] = ∂t3[i][j][l]/∂θ_k.
    for k in 0..2 {
        let mut up = th0;
        let mut dn = th0;
        up[k] += h;
        dn[k] -= h;
        let t_up = a_tower_at(up);
        let t_dn = a_tower_at(dn);
        for i in 0..2 {
            for j in 0..2 {
                let fd_t3 = (t_up.h[i][j] - t_dn.h[i][j]) / (2.0 * h);
                assert!(
                    (a_tower.t3[i][j][k] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
                    "a_θθθ[{i}][{j}][{k}]: analytic {:+.6e} fd {:+.6e}",
                    a_tower.t3[i][j][k],
                    fd_t3
                );
                for l in 0..2 {
                    let fd_t4 = (t_up.t3[i][j][l] - t_dn.t3[i][j][l]) / (2.0 * h);
                    assert!(
                        (a_tower.t4[i][j][l][k] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
                        "a_θθθθ[{i}][{j}][{l}][{k}]: analytic {:+.6e} fd {:+.6e}",
                        a_tower.t4[i][j][l][k],
                        fd_t4
                    );
                }
            }
        }
    }
}
/// Pins the first two orders of `implicit_solve` to the textbook implicit
/// function theorem recursion.
///
/// The constraint is F(a; θ) = a·(θ₀ + 1) + θ₁·a² + θ₀·θ₁ − 0.4385, which vanishes
/// at a = 0.4, θ = (0.25, −0.15). `f` carries a in slot 0 and θ₀, θ₁ in slots 1
/// and 2, and the recursion is:
///
/// * a_u = −F_u / F_a
/// * a_uv = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) / F_a
#[test]
fn implicit_solve_matches_textbook_ift_recursion() {
    let a0 = 0.4_f64;
    let th = [0.25_f64, -0.15_f64];
    let a = Tower4::<3>::variable(a0, 0);
    let t0 = Tower4::<3>::variable(th[0], 1);
    let t1 = Tower4::<3>::variable(th[1], 2);
    let f = a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385;
    let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
    let f_a = f.g[0];
    let f_aa = f.h[0][0];
    for u in 0..2 {
        let want = -f.g[u + 1] / f_a;
        assert!(
            (a_t.g[u] - want).abs() < 1e-12,
            "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
            a_t.g[u],
            want
        );
    }
    let index_pairs = (0..2).flat_map(|u| (0..2).map(move |v| (u, v)));
    for (u, v) in index_pairs {
        let f_uv = f.h[u + 1][v + 1];
        let f_au = f.h[0][u + 1];
        let f_av = f.h[0][v + 1];
        let want =
            -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
        assert!(
            (a_t.h[u][v] - want).abs() < 1e-12,
            "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
            a_t.h[u][v],
            want
        );
    }
}
/// Checks `moving_limit_boundary_tower` against finite differences of
/// I(θ) = ∫₀^{z_r(θ)} e^{−z²/2} dz = √(π/2)·erf(z_r/√2) on the curved edge
/// z_r(θ) = θ₀ + θ₁².
///
/// Because z_{θ₁θ₁} = 2 is non-zero, the Hessian entry [1][1] must contain the
/// B·z_uv term on top of the pure B′·z_u² term.
#[test]
fn moving_boundary_flux_carries_b_zuv_term() {
    use std::f64::consts::PI;
    let b = |z: f64| (-0.5 * z * z).exp();
    let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
    let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
    let th0 = [0.7_f64, 0.5_f64];
    // Hand-built tower of z_r at th0: ∂₀ = 1, ∂₁ = 2θ₁, ∂₁∂₁ = 2.
    let mut z_edge = Tower4::<2>::constant(z_r(th0));
    z_edge.g[0] = 1.0;
    z_edge.g[1] = 2.0 * th0[1];
    z_edge.h[1][1] = 2.0;
    let z0 = z_edge.v;
    let b0 = b(z0);
    // b, b′, b″, b‴ evaluated at z0.
    let stack = [
        b0,
        -z0 * b0,
        (z0 * z0 - 1.0) * b0,
        (3.0 * z0 - z0 * z0 * z0) * b0,
    ];
    let flux = moving_limit_boundary_tower(&z_edge, stack);
    let h = 1e-4;
    let tol = 1e-6;
    // Central difference of I along θ_i, taken at th.
    let fd_along = |th: [f64; 2], i: usize| -> f64 {
        let mut up = th;
        let mut dn = th;
        up[i] += h;
        dn[i] -= h;
        (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
    };
    for i in 0..2 {
        let fd_g = fd_along(th0, i);
        assert!(
            (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
            "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
            flux.g[i],
            fd_g
        );
    }
    let mut up = th0;
    let mut dn = th0;
    up[1] += h;
    dn[1] -= h;
    let fd_h11 = (fd_along(up, 1) - fd_along(dn, 1)) / (2.0 * h);
    assert!(
        (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
        "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
        flux.h[1][1],
        fd_h11
    );
    // Strip the B′·z_u² part; what remains must be exactly B₀·z_uv with z_uv = 2.
    let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
    let b_zuv = flux.h[1][1] - pure_zu2;
    let want_b_zuv = b0 * 2.0;
    assert!(
        (b_zuv - want_b_zuv).abs() < 1e-10,
        "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
        b_zuv,
        want_b_zuv
    );
}
/// Checks `moving_limit_boundary_tower_theta_integrand` for an integrand that
/// itself depends on θ.
///
/// The integrand is G(z; θ₀) = e^{z·θ₀}, with antiderivative
/// Φ(z; θ₀) = (e^{z·θ₀} − 1)/θ₀, on the edge z_r(θ) = 0.6 + θ₀ + θ₁².
/// The boundary tower must:
///
/// * have a zero value channel;
/// * match finite differences of Φ(z_r(θ); θ₀) − Φ(z₀; θ₀);
/// * match the hand closure G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u.
#[test]
fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
    let g = |z: f64, t0: f64| (z * t0).exp();
    let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
    let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
    let th0 = [0.4_f64, 0.5_f64];
    let z0 = z_r(th0);
    let mut z_edge = Tower4::<2>::constant(z0);
    z_edge.g[0] = 1.0;
    z_edge.g[1] = 2.0 * th0[1];
    z_edge.h[1][1] = 2.0;
    // Φ as a jet in (z, θ₀, θ₁). θ₁ is declared for slot 2, but Φ does not depend on it.
    let z_var = Tower4::<3>::variable(z0, 0);
    let t0_var = Tower4::<3>::variable(th0[0], 1);
    let _t1_var = Tower4::<3>::variable(th0[1], 2);
    let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
    let gg = g(z0, th0[0]);
    assert!(
        (phi_jet.g[0] - gg).abs() < 1e-12,
        "Φ_z {:+.8e} != G {:+.8e}",
        phi_jet.g[0],
        gg
    );
    let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
    assert!(
        flux.v.abs() < 1e-12,
        "boundary value channel {:+.3e}",
        flux.v
    );
    let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
    let h = 1e-4;
    let tol = 1e-6;
    // Central difference of the boundary term along θ_j, taken at th.
    let fd_along = |th: [f64; 2], j: usize| -> f64 {
        let mut up = th;
        let mut dn = th;
        up[j] += h;
        dn[j] -= h;
        (bnd(up) - bnd(dn)) / (2.0 * h)
    };
    for i in 0..2 {
        let fd_g = fd_along(th0, i);
        assert!(
            (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
            "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
            flux.g[i],
            fd_g
        );
    }
    for i in 0..2 {
        let mut up = th0;
        let mut dn = th0;
        up[i] += h;
        dn[i] -= h;
        for j in 0..2 {
            let fd_h = (fd_along(up, j) - fd_along(dn, j)) / (2.0 * h);
            assert!(
                (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
                "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
                flux.h[i][j],
                fd_h
            );
        }
    }
    // Partials of G at (z0, θ₀): G_z = θ₀·G, G_θ₀ = z0·G, G_θ₁ = 0.
    let g_z = th0[0] * gg;
    let g_theta = [z0 * gg, 0.0];
    for i in 0..2 {
        for j in 0..2 {
            let (z_u, z_v, z_uv) = (z_edge.g[i], z_edge.g[j], z_edge.h[i][j]);
            let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
            assert!(
                (flux.h[i][j] - hand).abs() < 1e-9,
                "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
                flux.h[i][j],
                hand
            );
        }
    }
    // Remove every term except G·z_uv from h[1][1]; it must equal G₀·2.
    let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
    let g_zuv = flux.h[1][1] - pure_no_zuv;
    let want_g_zuv = gg * 2.0;
    assert!(
        (g_zuv - want_g_zuv).abs() < 1e-9,
        "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
        g_zuv,
        want_g_zuv
    );
}
/// Checks the crossing-edge tower z = (τ − a)/b against hand-derived velocity
/// formulas.
///
/// `a` is a hand-filled tower and b is the free variable in slot `g_idx`, so
/// b_u = δ(u, g_idx) and b_uv = 0. The formulas are:
///
/// * z_u = −(a_u + z·b_u)/b
/// * z_uv = −(a_uv + z_u·b_v + z_v·b_u)/b
#[test]
fn crossing_edge_tower_matches_handpath_velocity_formulas() {
    const TAU: f64 = 1.3;
    let g_idx = 1usize;
    let g0 = 0.85_f64;
    let mut a = Tower4::<3>::constant(0.45);
    a.g[0] = 0.7;
    a.g[1] = -0.3;
    a.h[0][0] = 0.25;
    a.h[0][1] = 0.11;
    a.h[1][0] = 0.11;
    a.h[1][1] = -0.08;
    let b = Tower4::<3>::variable(g0, g_idx);
    let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
    let bv = g0;
    let z0 = z_edge.v;
    assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
    // x·b_u for the Kronecker-delta derivative of b: keeps x only in slot g_idx.
    let times_b_u = |u: usize, x: f64| if u == g_idx { x } else { 0.0 };
    for u in 0..2 {
        let want = -(a.g[u] + times_b_u(u, z0)) / bv;
        assert!(
            (z_edge.g[u] - want).abs() < 1e-10,
            "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
            z_edge.g[u],
            want
        );
    }
    for u in 0..2 {
        for v in 0..2 {
            let cross = times_b_u(u, z_edge.g[v]) + times_b_u(v, z_edge.g[u]);
            let want = -(a.h[u][v] + cross) / bv;
            assert!(
                (z_edge.h[u][v] - want).abs() < 1e-10,
                "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                z_edge.h[u][v],
                want
            );
        }
    }
}
/// Constraint-frame check with both a (slot 0) and b (slot 1) free.
///
/// For z = (τ − a)/b the tower must give:
///
/// * z = (τ − a)/b
/// * z_a = −1/b
/// * z_ab = 1/b²
/// * z_aa = 0
/// * z_bb = 2(τ − a)/b³
#[test]
fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
    const TAU: f64 = 1.3;
    let (a0, b0) = (0.45_f64, 0.85_f64);
    let z = {
        let a = Tower4::<2>::variable(a0, 0);
        let b = Tower4::<2>::variable(b0, 1);
        (Tower4::<2>::constant(TAU) - a) / b
    };
    let inv_b2 = 1.0 / (b0 * b0);
    let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
    assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
    assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
    assert!(
        (z.h[0][1] - inv_b2).abs() < 1e-12,
        "z_ab {:+.10e} vs +1/b² {:+.10e}",
        z.h[0][1],
        inv_b2
    );
    assert!(
        z.h[0][0].abs() < 1e-12,
        "z_aa must vanish, got {:+.10e}",
        z.h[0][0]
    );
    assert!(
        (z.h[1][1] - want_zbb).abs() < 1e-12,
        "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
        z.h[1][1],
        want_zbb
    );
}
/// The kernel-channel oracle must accept channels copied straight from the tower.
/// It must also reject a kernel whose third-order contraction has a single
/// off-diagonal entry sign-flipped, and its error must name that entry (the test
/// expects the label `third[0][0][1]`).
#[test]
fn oracle_catches_planted_cross_block_sign_flip() {
    let prog = LocScaleProgram {
        eta: vec![0.3],
        s: vec![-0.5],
        y: vec![1.0],
    };
    let t = evaluate_program(&prog, 0).expect("tower");
    let dir = [0.6, -0.2];
    let fourth_dir = [1.0, 0.5];
    let third = t.third_contracted(&dir);
    let honest = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, third)],
        fourth: vec![(dir, fourth_dir, t.fourth_contracted(&dir, &fourth_dir))],
    };
    verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
    // Plant a sign error in one cross entry only (its mirror [1][0] is left intact).
    let mut planted = third;
    planted[0][1] = -planted[0][1];
    let flipped = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, planted)],
        fourth: vec![],
    };
    let err = verify_kernel_channels(&t, &flipped, 1e-10)
        .expect_err("planted sign flip must be caught");
    assert!(
        err.contains("third[0][0][1]"),
        "oracle must name the flipped channel, got: {err}"
    );
}
/// Every third- and fourth-order channel of each gnarly row must be invariant
/// under all permutations of its indices (S₃ and S₄).
///
/// The tolerance is `1e-12` times the largest absolute entry of that order,
/// floored at 1.
#[test]
fn t3_t4_are_fully_index_symmetric() {
    let prog = GnarlyProgram::fixture();
    // All 3! orderings of three index slots, in lexicographic order.
    let perms3: [[usize; 3]; 6] = [
        [0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0],
    ];
    // All 4! orderings of four index slots, in lexicographic order.
    let perms4: [[usize; 4]; 24] = [
        [0, 1, 2, 3], [0, 1, 3, 2], [0, 2, 1, 3], [0, 2, 3, 1], [0, 3, 1, 2], [0, 3, 2, 1],
        [1, 0, 2, 3], [1, 0, 3, 2], [1, 2, 0, 3], [1, 2, 3, 0], [1, 3, 0, 2], [1, 3, 2, 0],
        [2, 0, 1, 3], [2, 0, 3, 1], [2, 1, 0, 3], [2, 1, 3, 0], [2, 3, 0, 1], [2, 3, 1, 0],
        [3, 0, 1, 2], [3, 0, 2, 1], [3, 1, 0, 2], [3, 1, 2, 0], [3, 2, 0, 1], [3, 2, 1, 0],
    ];
    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("gnarly tower");
        // Largest |entry| of each order, never below 1.0; used as the tolerance scale.
        let scale_t3 = t
            .t3
            .iter()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(1.0_f64, f64::max);
        let scale_t4 = t
            .t4
            .iter()
            .flatten()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(1.0_f64, f64::max);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    let idx = [i, j, k];
                    let base = t.t3[i][j][k];
                    for p in &perms3 {
                        let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
                        assert!(
                            (base - permed).abs() <= 1e-12 * scale_t3,
                            "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                             permuted {permed:+.15e} under {p:?}"
                        );
                    }
                    for l in 0..3 {
                        let idx4 = [i, j, k, l];
                        let base4 = t.t4[i][j][k][l];
                        for p in &perms4 {
                            let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
                            assert!(
                                (base4 - permed).abs() <= 1e-12 * scale_t4,
                                "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                    }
                }
            }
        }
    }
}
}
/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)`, intended for `x ≥ 0`.
///
/// Special inputs:
///
/// * NaN propagates as NaN.
/// * `+∞` returns `0.0` (the limit) and `−∞` returns `+∞`.
/// * Any finite `x ≤ 0` returns `1.0`, i.e. `erfcx(0)`. Negative inputs are outside
///   this helper's domain and are *not* evaluated as `erfcx(x)`.
///
/// Evaluation:
///
/// * For `0 < x < 26` the product `exp(x²)·erfc(x)` is formed directly (x² < 676,
///   so `exp(x²)` stays finite).
/// * Beyond that, `exp(x²)` would approach overflow, so the asymptotic series
///   `erfcx(x) ≈ (1/(x√π))·(1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))` is used.
#[inline]
fn erfcx_nonnegative(x: f64) -> f64 {
    if x.is_nan() {
        // Propagate NaN rather than letting the infinity branch map it to 0 or ∞
        // according to its sign bit.
        return x;
    }
    if x.is_infinite() {
        return if x > 0.0 { 0.0 } else { f64::INFINITY };
    }
    if x <= 0.0 {
        return 1.0;
    }
    if x < 26.0 {
        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
    } else {
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        inv * poly / std::f64::consts::PI.sqrt()
    }
}
/// Computes `log(1 − exp(−a))` for `a ≥ 0` without catastrophic cancellation.
///
/// The computation splits at `a = ln 2`:
///
/// * for `0 < a ≤ ln 2`, `ln(−expm1(−a))` is used;
/// * above `ln 2`, `ln_1p(−exp(−a))` is used.
///
/// `a == 0` (either sign of zero) yields `−∞`, and `a == +∞` yields `0`.
///
/// # Panics
/// Panics if `a` is negative or NaN.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        return f64::NEG_INFINITY;
    }
    if a <= core::f64::consts::LN_2 {
        (-(-a).exp_m1()).ln()
    } else {
        (-(-a).exp()).ln_1p()
    }
}
/// Returns `(log Φ(x), φ(x)/Φ(x))` for the standard normal, staying stable in the
/// lower tail.
///
/// * `x < 0`: with `u = −x/√2`, `Φ(x) = ½·erfcx(u)·e^{−u²}`. This gives
///   `log Φ(x) = −u² + ln(½·erfcx(u))` and `φ(x)/Φ(x) = √(2/π)/erfcx(u)`.
///   `erfcx(u)` is floored at `1e-300` before the log and the reciprocal.
/// * `x ≥ 0`: the crate's `normal_pdf` is divided by `normal_cdf`, with the
///   CDF clamped to `[1e-300, 1]`.
/// * `+∞` returns `(0, 0)`, `−∞` returns `(−∞, +∞)`, and NaN returns `(NaN, NaN)`.
#[inline]
fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
    if x.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if x.is_infinite() {
        return if x > 0.0 {
            (0.0, 0.0)
        } else {
            (f64::NEG_INFINITY, f64::INFINITY)
        };
    }
    if x >= 0.0 {
        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
        let ratio = crate::probability::normal_pdf(x) / cdf;
        return (cdf.ln(), ratio);
    }
    let u = -x / std::f64::consts::SQRT_2;
    let scaled = erfcx_nonnegative(u).max(1e-300);
    let log_cdf = -u * u + (0.5 * scaled).ln();
    let ratio = (2.0 / std::f64::consts::PI).sqrt() / scaled;
    (log_cdf, ratio)
}
/// Value and first four derivatives of `log Φ(x)`, the log standard-normal CDF.
///
/// With `λ = φ(x)/Φ(x)`:
///
/// * `d/dx   = λ`
/// * `d²/dx² = −λ(x + λ)`
/// * `d³/dx³ = λ(x² − 1 + 3xλ + 2λ²)`
/// * `d⁴/dx⁴ = −λ(x³ − 3x + (7x² − 4)λ + 12xλ² + 6λ³)`
///
/// At `x = ±∞` these closed forms would evaluate `0·∞` (at `+∞`, λ = 0) or `∞ − ∞`
/// (at `−∞`, λ = ∞) and produce NaN. Their analytic limits are returned instead:
///
/// * `[0, 0, 0, 0, 0]` at `+∞`;
/// * `[−∞, +∞, −1, 0, 0]` at `−∞`. As x → −∞, λ ≈ −x − 1/x, so λ(x + λ) → 1 and
///   the higher derivatives decay like powers of `1/x`.
///
/// NaN input yields NaN in every slot.
#[inline]
pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    if x == f64::INFINITY {
        return [0.0; 5];
    }
    if x == f64::NEG_INFINITY {
        return [f64::NEG_INFINITY, f64::INFINITY, -1.0, 0.0, 0.0];
    }
    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
    let lambda2 = lambda * lambda;
    let lambda3 = lambda2 * lambda;
    let x2 = x * x;
    [
        log_cdf,
        lambda,
        -lambda * (x + lambda),
        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
        -lambda
            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
    ]
}
/// Value and first four derivatives of `f(x) = log(1 − e^{−x})` for `x ≥ 0`.
///
/// With `r = 1/(eˣ − 1)`, so that `f′ = r` and `r′ = −r(1 + r)`:
///
/// * `f″ = −r(1 + r)`
/// * `f‴ = r(1 + r)(1 + 2r)`
/// * `f⁗ = −r(1 + r)(1 + 6r + 6r²)`
///
/// The value slot comes from `log1mexp_positive`, which panics for negative or NaN `x`.
#[inline]
pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    // Shared factor r(1 + r). Negating a product rounds symmetrically, so the
    // negated slots are bit-identical to the (−r)·… form.
    let r_one_r = r * (1.0 + r);
    let value = log1mexp_positive(x);
    [
        value,
        r,
        -r_one_r,
        r_one_r * (1.0 + 2.0 * r),
        -(r_one_r * (1.0 + 6.0 * r + 6.0 * r * r)),
    ]
}
/// Bridge tests for the `RowJet` abstraction.
///
/// * The batched four-lane path (`Tower4Batch`, `generic_batched_fourth_tower`,
///   `generic_batched_third_tower`) must reproduce the scalar `Tower4` / `Tower3`
///   evaluation of every row bit for bit.
/// * The scalar `RowJet` path must agree bit for bit with the hand-written
///   `JetScalar` path.
/// * `rowjet_row_kernel` must agree with the tower's value, gradient and Hessian.
#[cfg(test)]
mod rowjet_bridge_tests {
use super::*;
use crate::jet_scalar::{JetScalar, Order2};
/// Toy two-primary program. Each row carries three auxiliary scalars
/// `aux[r] = [cov, z, wi]` that enter the body through `pack_rows`.
struct ToyProgram {
primaries: Vec<[f64; 2]>,
aux: Vec<[f64; 3]>,
}
impl ToyProgram {
/// Jet-generic NLL body, written once for every `RowJet` implementor.
///
/// Computes `e = wi·sqrt(1 + (p0 − z·exp(cov·p0·p1 + 0.5 − 0.25·p0))²)` and
/// returns `e·p1 + e`. Each `compose_unary_with` closure supplies the value and
/// first four derivatives of its outer function: `exp`, then `sqrt(1 + u²)`.
fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
// Per-row auxiliaries for the rows in `rows`. The tests pass one row for
// scalar jets and four rows for the batched path.
let cov = R::pack_rows(rows, |r| self.aux[r][0]);
let z = R::pack_rows(rows, |r| self.aux[r][1]);
let wi = R::pack_rows(rows, |r| self.aux[r][2]);
let a = p[0].mul(&p[1]).scale_rows(cov);
let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
let c = b
.compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
})
.scale_rows(z);
let d = c.neg().add(&p[0]);
let e = d
.compose_unary_with(|u| {
let s = (1.0 + u * u).sqrt();
let s3 = s * s * s;
let s5 = s3 * s * s;
let s7 = s5 * s * s;
// [s, s', s'', s''', s''''] for s(u) = sqrt(1 + u²).
[s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
})
.scale_rows(wi);
e.mul(&p[1]).add(&e)
}
}
impl RowNllProgramRowJet<2> for ToyProgram {
fn n_rows(&self) -> usize {
self.primaries.len()
}
// Indexes directly, so an out-of-range row panics instead of returning Err.
fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
Ok(self.primaries[row])
}
fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
Ok(self.body(rows, p))
}
}
/// Asserts that every channel (v, g, h, t3, t4) of two `Tower4<2>` values has
/// identical bits. This distinguishes `+0.0` from `-0.0` and compares NaN payloads.
fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
for i in 0..2 {
assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
for j in 0..2 {
assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
for k in 0..2 {
assert_eq!(
a.t3[i][j][k].to_bits(),
b.t3[i][j][k].to_bits(),
"{ctx}: t3[{i}][{j}][{k}]"
);
for l in 0..2 {
assert_eq!(
a.t4[i][j][k][l].to_bits(),
b.t4[i][j][k][l].to_bits(),
"{ctx}: t4[{i}][{j}][{k}][{l}]"
);
}
}
}
}
}
/// Same bit-level comparison as `assert_t4_bits_eq`, for `Tower3<2>` (v, g, h, t3).
fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
for i in 0..2 {
assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
for j in 0..2 {
assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
for k in 0..2 {
assert_eq!(
a.t3[i][j][k].to_bits(),
b.t3[i][j][k].to_bits(),
"{ctx}: t3[{i}][{j}][{k}]"
);
}
}
}
}
/// Deterministic 64-bit linear congruential generator, so the randomized tests
/// are reproducible from their seed.
struct Lcg(u64);
impl Lcg {
/// Advances the state and returns its top 53 bits scaled into `[0, 1)`.
fn next(&mut self) -> f64 {
self.0 = self
.0
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(self.0 >> 11) as f64 / (1u64 << 53) as f64
}
/// Test value: about 4% `+0.0`, about 4% `-0.0` (to exercise signed-zero
/// handling), otherwise uniform in `[-2.5, 2.5)`.
fn val(&mut self) -> f64 {
let u = self.next();
if u < 0.04 {
return 0.0;
}
if u < 0.08 {
return -0.0;
}
(self.next() - 0.5) * 5.0
}
}
/// For random four-row batches, lane `i` of the batched third- and fourth-order
/// towers must be bit-identical to the scalar tower of row `i`.
#[test]
fn batched_lane_i_matches_scalar_row_i_bit_identical() {
let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
let mut batches = 0usize;
for _ in 0..2500 {
let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
let rows = [0usize, 1, 2, 3];
let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
for (row, base) in bases.iter().enumerate() {
let vars: [Tower4<2>; 2] =
std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
}
let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
for (row, base) in bases.iter().enumerate() {
let vars: [Tower3<2>; 2] =
std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
}
batches += 1;
}
assert_eq!(batches, 2500);
}
/// On a single row, the scalar `RowJet` path must match the same body written
/// directly against `JetScalar` bit for bit. `rowjet_row_kernel` must then match
/// the tower: its value by bits, and its gradient and Hessian by `==` (so
/// `+0.0` and `-0.0` compare equal there). The `Order2` call only checks that the
/// body runs for that jet type.
#[test]
fn blanket_scalar_path_is_unchanged_and_consistent() {
let mut rng = Lcg(0x0BAD_F00D_1357_2468);
for _ in 0..3000 {
let base: [f64; 2] = std::array::from_fn(|_| rng.val());
let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };
let via_rowjet: Tower4<2> = {
let vars: [Tower4<2>; 2] =
std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
prog.row_nll(&[0], &vars).expect("rowjet")
};
// Mirror of `ToyProgram::body` with `scale` in place of `scale_rows`; the
// operation order must stay identical for the bit comparison to hold.
let via_jetscalar: Tower4<2> = {
let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
<Tower4<2> as JetScalar<2>>::variable(base[a], a)
});
let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
let a = vars[0].mul(&vars[1]).scale(cov);
let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
let c = b
.compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
})
.scale(z);
let d = JetScalar::neg(&c).add(&vars[0]);
let e = d
.compose_unary_with(|u| {
let s = (1.0 + u * u).sqrt();
let s3 = s * s * s;
let s5 = s3 * s * s;
let s7 = s5 * s * s;
[s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
})
.scale(wi);
e.mul(&vars[1]).add(&e)
};
assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");
let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
for i in 0..2 {
assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
for j in 0..2 {
assert!(
h[i][j] == via_rowjet.h[i][j],
"kernel h[{i}][{j}]: {} vs {}",
h[i][j],
via_rowjet.h[i][j]
);
}
}
let o2: [Order2<2>; 2] =
std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
let _ = prog.body(&[0], &o2);
}
}
/// For a scalar `Tower4`, `scale_rows(s)` must be bit-identical to `scale(s)`.
#[test]
fn scale_rows_scalar_is_bit_identical_to_scale() {
let mut rng = Lcg(0xFEED_FACE_0042_1001);
for _ in 0..3000 {
let base: [f64; 2] = std::array::from_fn(|_| rng.val());
let s = rng.val();
let vars: [Tower4<2>; 2] =
std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
});
let via_scale = RowJet::scale(&jet, s);
let via_scale_rows = RowJet::scale_rows(&jet, s);
assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
}
}
/// For a `Tower4Batch`, `scale_rows([s0, s1, s2, s3])` must scale lane `i` exactly
/// like scalar `scale(s_i)` applied to the same expression built from row `i`.
#[test]
fn batched_scale_rows_matches_per_row_scalar_scale() {
let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
for _ in 0..2500 {
let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
let s: [f64; 4] = std::array::from_fn(|_| rng.val());
let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
Tower4Batch::variable(
wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
a,
)
});
let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
});
let scaled = prod.scale_rows(s);
for (row, base) in bases.iter().enumerate() {
let v: [Tower4<2>; 2] =
std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
let e = u.exp();
[e, e, e, e, e]
});
let ref_s = RowJet::scale(&prod_s, s[row]);
assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
}
}
}
/// `guard` reports a per-lane verdict. Lanes whose value fails the predicate
/// (here `-2.0` and `-0.0`, since `-0.0 > 0.0` is false) appear in `failed_mask`
/// as bits 1 and 3. A scalar tower reports a single lane.
#[test]
fn guard_reports_per_lane_failures() {
let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
Tower4Batch::variable(
wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
a,
)
});
let verdict = vars[0].guard(|v| v > 0.0);
assert_eq!(verdict.lanes(), 4);
assert!(verdict.any_failed());
assert!(!verdict.all_pass());
assert!(!verdict.lane_failed(0));
assert!(verdict.lane_failed(1));
assert!(!verdict.lane_failed(2));
assert!(verdict.lane_failed(3));
assert_eq!(verdict.failed_mask(), 0b1010);
let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
}
}