use super::jet_algebra;
/// Fourth-order jet of a scalar function of `K` variables: the value plus every partial
/// derivative up to order four. Each tensor is stored densely, with one entry per ordered
/// index tuple (no symmetric packing).
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
/// Value of the function.
pub v: f64,
/// First derivatives: `g[i] = ∂f/∂x_i`.
pub g: [f64; K],
/// Second derivatives: `h[i][j] = ∂²f/∂x_i∂x_j`.
pub h: [[f64; K]; K],
/// Third derivatives, indexed `t3[i][j][k]`.
pub t3: [[[f64; K]; K]; K],
/// Fourth derivatives, indexed `t4[i][j][k][l]`.
pub t4: [[[[f64; K]; K]; K]; K],
}
impl<const K: usize> Tower4<K> {
    /// Returns the jet whose value and every derivative channel are zero.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
            t4: [[[[0.0; K]; K]; K]; K],
        }
    }

    /// Returns the jet of a constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        Self { v: c, ..Self::zero() }
    }

    /// Returns the seed jet of independent variable `idx` at `value`
    /// (unit first derivative along `idx`, every other channel zero).
    ///
    /// # Panics
    /// Panics if `idx >= K`.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut seed = Self::constant(value);
        seed.g[idx] = 1.0;
        seed
    }

    /// Reads the channel addressed by `labels`: no labels is the value, one label a
    /// gradient entry, and so on up to four labels for a fourth-derivative entry.
    ///
    /// # Panics
    /// Panics on more than four labels or on any label `>= K`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 4,
            "Tower4 carries at most fourth-order derivatives"
        );
        match *labels {
            [] => self.v,
            [i] => self.g[i],
            [i, j] => self.h[i][j],
            [i, j, k] => self.t3[i][j][k],
            [i, j, k, l, ..] => self.t4[i][j][k][l],
        }
    }

    /// Product of two jets. The value channel is the plain product; every derivative
    /// channel is delegated to `jet_algebra::leibniz_product` over its labels.
    pub fn mul(&self, o: &Self) -> Self {
        let (lhs, rhs) = (self, o);
        let mut prod = Self::constant(lhs.v * rhs.v);
        // Every entry is computed independently, so a single nest fills all orders.
        for i in 0..K {
            let one = [i];
            prod.g[i] = jet_algebra::leibniz_product(&one, |t| lhs.deriv(t), |c| rhs.deriv(c));
            for j in 0..K {
                let two = [i, j];
                prod.h[i][j] =
                    jet_algebra::leibniz_product(&two, |t| lhs.deriv(t), |c| rhs.deriv(c));
                for k in 0..K {
                    let three = [i, j, k];
                    prod.t3[i][j][k] =
                        jet_algebra::leibniz_product(&three, |t| lhs.deriv(t), |c| rhs.deriv(c));
                    for l in 0..K {
                        let four = [i, j, k, l];
                        prod.t4[i][j][k][l] = jet_algebra::leibniz_product(
                            &four,
                            |t| lhs.deriv(t),
                            |c| rhs.deriv(c),
                        );
                    }
                }
            }
        }
        prod
    }

    /// Composes a scalar function with this jet, given its stack
    /// `d = [f, f', f'', f''', f'''']` evaluated at `self.v`.
    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
        <Self as jet_algebra::JetAlgebra<5>>::compose_unary(self, d)
    }

    /// Multiplies every channel (value and all derivatives) by `s`.
    pub fn scale(&self, s: f64) -> Self {
        let mut scaled = *self;
        scaled.v *= s;
        scaled
            .g
            .iter_mut()
            .chain(scaled.h.iter_mut().flatten())
            .chain(scaled.t3.iter_mut().flatten().flatten())
            .chain(scaled.t4.iter_mut().flatten().flatten().flatten())
            .for_each(|x| *x *= s);
        scaled
    }

    /// `exp` of the jet: every derivative of `exp` equals `exp` itself.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e; 5])
    }

    /// Natural logarithm of the jet.
    pub fn ln(&self) -> Self {
        let x0 = self.v;
        let inv = 1.0 / x0;
        self.compose_unary([
            x0.ln(),
            inv,
            -inv * inv,
            2.0 * inv * inv * inv,
            -6.0 * inv * inv * inv * inv,
        ])
    }

    /// Reciprocal `1/x` of the jet.
    pub fn recip(&self) -> Self {
        let inv = 1.0 / self.v;
        let inv2 = inv * inv;
        self.compose_unary([
            inv,
            -inv2,
            2.0 * inv2 * inv,
            -6.0 * inv2 * inv2,
            24.0 * inv2 * inv2 * inv,
        ])
    }

    /// Square root of the jet.
    pub fn sqrt(&self) -> Self {
        let x0 = self.v;
        let root = x0.sqrt();
        self.compose_unary([
            root,
            0.5 / root,
            -0.25 / (x0 * root),
            0.375 / (x0 * x0 * root),
            -0.9375 / (x0 * x0 * x0 * root),
        ])
    }

    /// Real power `x^a` of the jet.
    pub fn powf(&self, a: f64) -> Self {
        let x0 = self.v;
        self.compose_unary([
            x0.powf(a),
            a * x0.powf(a - 1.0),
            a * (a - 1.0) * x0.powf(a - 2.0),
            a * (a - 1.0) * (a - 2.0) * x0.powf(a - 3.0),
            a * (a - 1.0) * (a - 2.0) * (a - 3.0) * x0.powf(a - 4.0),
        ])
    }

    /// `ln Γ` of the jet.
    pub fn ln_gamma(&self) -> Self {
        self.compose_unary(ln_gamma_derivative_stack(self.v))
    }

    /// Digamma `ψ` of the jet.
    pub fn digamma(&self) -> Self {
        self.compose_unary(digamma_derivative_stack(self.v))
    }

    /// Trigamma `ψ'` of the jet.
    pub fn trigamma(&self) -> Self {
        self.compose_unary(trigamma_derivative_stack(self.v))
    }

    /// Third-derivative tensor contracted on its last index: `out[a][b] = Σ_c t3[a][b][c]·dir[c]`.
    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
        std::array::from_fn(|a| {
            std::array::from_fn(|b| {
                self.t3[a][b]
                    .iter()
                    .zip(dir)
                    .fold(0.0, |acc, (t, d)| acc + t * d)
            })
        })
    }

    /// Fourth-derivative tensor contracted on its last two indices:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l]·u[k]·w[l]`.
    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        std::array::from_fn(|i| {
            std::array::from_fn(|j| {
                let mut acc = 0.0;
                for (k, row) in self.t4[i][j].iter().enumerate() {
                    for (l, t) in row.iter().enumerate() {
                        acc += t * u[k] * w[l];
                    }
                }
                acc
            })
        })
    }
}
impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    /// Builds a jet channel by channel from `f`, which receives each channel's labels.
    /// `f` is called for the value, then the gradient, Hessian, third and fourth
    /// channels, each in lexicographic label order.
    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut built = Self::zero();
        built.v = f(&[]);
        for (i, gi) in built.g.iter_mut().enumerate() {
            *gi = f(&[i]);
        }
        for (i, row) in built.h.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                *cell = f(&[i, j]);
            }
        }
        for (i, plane) in built.t3.iter_mut().enumerate() {
            for (j, row) in plane.iter_mut().enumerate() {
                for (k, cell) in row.iter_mut().enumerate() {
                    *cell = f(&[i, j, k]);
                }
            }
        }
        for (i, cube) in built.t4.iter_mut().enumerate() {
            for (j, plane) in cube.iter_mut().enumerate() {
                for (k, row) in plane.iter_mut().enumerate() {
                    for (l, cell) in row.iter_mut().enumerate() {
                        *cell = f(&[i, j, k, l]);
                    }
                }
            }
        }
        built
    }
}
/// Second-order jet of a scalar function of `K` variables: value, gradient and the full
/// (dense) Hessian.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
/// Value of the function.
pub v: f64,
/// First derivatives: `g[i] = ∂f/∂x_i`.
pub g: [f64; K],
/// Second derivatives: `h[i][j] = ∂²f/∂x_i∂x_j`.
pub h: [[f64; K]; K],
}
impl<const K: usize> Tower2<K> {
    /// Returns the jet whose value, gradient and Hessian are all zero.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// Returns the jet of a constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        Self { v: c, ..Self::zero() }
    }

    /// Returns the seed jet of independent variable `idx` at `value`.
    ///
    /// # Panics
    /// Panics if `idx >= K`.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut seed = Self::constant(value);
        seed.g[idx] = 1.0;
        seed
    }

    /// Reads the channel addressed by `labels` (0, 1 or 2 labels).
    ///
    /// # Panics
    /// Panics on more than two labels or on any label `>= K`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match *labels {
            [] => self.v,
            [i] => self.g[i],
            [i, j, ..] => self.h[i][j],
        }
    }

    /// Product of two jets via the second-order Leibniz rule.
    pub fn mul(&self, o: &Self) -> Self {
        let (p, q) = (self, o);
        let mut prod = Self::constant(p.v * q.v);
        for i in 0..K {
            prod.g[i] = p.v * q.g[i] + p.g[i] * q.v;
            for j in 0..K {
                prod.h[i][j] =
                    p.v * q.h[i][j] + p.g[i] * q.g[j] + p.g[j] * q.g[i] + p.h[i][j] * q.v;
            }
        }
        prod
    }

    /// Composes a scalar function with this jet, given `d = [f, f', f'']` at `self.v`.
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        <Self as jet_algebra::JetAlgebra<3>>::compose_unary(self, d)
    }

    /// Multiplies every channel by `s`.
    pub fn scale(&self, s: f64) -> Self {
        let mut scaled = *self;
        scaled.v *= s;
        scaled
            .g
            .iter_mut()
            .chain(scaled.h.iter_mut().flatten())
            .for_each(|x| *x *= s);
        scaled
    }

    /// `exp` of the jet.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e; 3])
    }

    /// Square root of the jet.
    pub fn sqrt(&self) -> Self {
        let x0 = self.v;
        let root = x0.sqrt();
        self.compose_unary([root, 0.5 / root, -0.25 / (x0 * root)])
    }
}
impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    /// Builds a jet from `f`, called for the value, then the gradient, then the Hessian
    /// (lexicographic label order within each order).
    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut built = Self::zero();
        built.v = f(&[]);
        for (i, gi) in built.g.iter_mut().enumerate() {
            *gi = f(&[i]);
        }
        for (i, row) in built.h.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                *cell = f(&[i, j]);
            }
        }
        built
    }
}
impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;
    /// Channel-wise sum of two jets.
    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        let lhs = sum.g.iter_mut().chain(sum.h.iter_mut().flatten());
        let rhs = o.g.iter().chain(o.h.iter().flatten());
        for (x, y) in lhs.zip(rhs) {
            *x += *y;
        }
        sum
    }
}
impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;
    /// Jet product; forwards to the inherent `Tower2::mul`.
    fn mul(self, o: Self) -> Self {
        let product = Tower2::mul(&self, &o);
        product
    }
}
impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;
    /// Adds a constant: only the value channel moves.
    fn add(self, c: f64) -> Self {
        Self {
            v: self.v + c,
            ..self
        }
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
type Output = Self;
fn mul(self, c: f64) -> Self {
self.scale(c)
}
}
/// Derivative stack `[ln Γ(x), ψ(x), ψ'(x), ψ''(x), ψ'''(x)]` of `ln Γ` at `x`, in the
/// layout expected by `Tower4::compose_unary`.
///
/// The polygamma entries are `NaN` unless `x` is finite and positive.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    let log_gamma = statrs::function::gamma::ln_gamma(x);
    let psi = digamma_positive(x);
    let [psi1, psi2, psi3] = [1, 2, 3].map(|order| polygamma_positive(order, x));
    [log_gamma, psi, psi1, psi2, psi3]
}
/// Derivative stack `[ψ(x), ψ'(x), …, ψ''''(x)]` of the digamma function at `x`.
///
/// Every entry is `NaN` unless `x` is finite and positive.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    std::array::from_fn(|k| match k {
        0 => digamma_positive(x),
        order => polygamma_positive(order, x),
    })
}
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
[
polygamma_positive(1, x),
polygamma_positive(2, x),
polygamma_positive(3, x),
polygamma_positive(4, x),
polygamma_positive(5, x),
]
}
/// Digamma `ψ(x)` for finite `x > 0`; returns `NaN` otherwise.
///
/// Uses `ψ(x) = ψ(x + 1) − 1/x` to step `x` up to `POLYGAMMA_ASYMPTOTIC_MIN_X`, then
/// evaluates the asymptotic series.
fn digamma_positive(x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shifted = x;
    let mut correction = 0.0;
    while shifted < POLYGAMMA_ASYMPTOTIC_MIN_X {
        correction -= 1.0 / shifted;
        shifted += 1.0;
    }
    correction + digamma_asymptotic(shifted)
}
/// Polygamma `ψ^{(order)}(x)` for finite `x > 0` (and `order` in `1..=5`, enforced by the
/// asymptotic stage); returns `NaN` otherwise.
///
/// Shifts `x` upward with the recurrence, then evaluates the asymptotic series.
fn polygamma_positive(order: usize, x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shifted = x;
    let mut correction = 0.0;
    while shifted < POLYGAMMA_ASYMPTOTIC_MIN_X {
        correction += polygamma_recurrence_term(order, shifted);
        shifted += 1.0;
    }
    correction + polygamma_asymptotic(order, shifted)
}
/// Arguments below this threshold are shifted upward one unit at a time with the
/// polygamma recurrence before the asymptotic (Bernoulli) series is evaluated.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Pairs `(2k, B_{2k})` for `k = 1..=10`: the even-index Bernoulli numbers used as
/// coefficients by `digamma_asymptotic` and `polygamma_asymptotic`.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
(2, 1.0 / 6.0),
(4, -1.0 / 30.0),
(6, 1.0 / 42.0),
(8, -1.0 / 30.0),
(10, 5.0 / 66.0),
(12, -691.0 / 2730.0),
(14, 7.0 / 6.0),
(16, -3617.0 / 510.0),
(18, 43867.0 / 798.0),
(20, -174611.0 / 330.0),
];
/// The term `(-1)^{order+1} · order! / x^{order+1}` such that
/// `ψ^{(order)}(x) = ψ^{(order)}(x + 1) + term`.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = match order % 2 {
        1 => 1.0,
        _ => -1.0,
    };
    sign * factorial(order) / x.powi((order + 1) as i32)
}
/// Asymptotic series `ψ(x) ≈ ln x − 1/(2x) − Σ_k B_{2k} / (2k · x^{2k})`, truncated at
/// the entries of `BERNOULLI_EVEN`; intended for large `x`.
fn digamma_asymptotic(x: f64) -> f64 {
    BERNOULLI_EVEN
        .iter()
        .fold(x.ln() - 0.5 / x, |acc, &(n, b)| {
            acc - b / (n as f64 * x.powi(n as i32))
        })
}
/// Asymptotic series for `ψ^{(m)}(x)`, `m = order ∈ 1..=5` (`NaN` for any other order):
///
/// `(-1)^{m+1} [ (m-1)!/x^m + m!/(2x^{m+1}) + Σ_k B_{2k} (2k+m-1)! / ((2k)! x^{2k+m}) ]`,
///
/// where the ratio of factorials is computed as `rising_factorial(2k, m) / 2k`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if order == 0 || order > 5 {
        return f64::NAN;
    }
    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    let head = sign * factorial(order - 1) / x.powi(order as i32)
        + sign * factorial(order) / (2.0 * x.powi((order + 1) as i32));
    BERNOULLI_EVEN.iter().fold(head, |acc, &(n, b)| {
        acc + sign * b * rising_factorial(n, order) / n as f64 / x.powi((n + order) as i32)
    })
}
/// `n!` as an `f64` (with `0! = 1`).
fn factorial(n: usize) -> f64 {
    (1..=n).map(|k| k as f64).product()
}
/// Rising factorial `start · (start+1) ⋯ (start+len-1)` as an `f64` (1 when `len == 0`).
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).map(|k| k as f64).product()
}
impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;
    /// Channel-wise sum of two jets.
    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        let lhs = sum
            .g
            .iter_mut()
            .chain(sum.h.iter_mut().flatten())
            .chain(sum.t3.iter_mut().flatten().flatten())
            .chain(sum.t4.iter_mut().flatten().flatten().flatten());
        let rhs = o
            .g
            .iter()
            .chain(o.h.iter().flatten())
            .chain(o.t3.iter().flatten().flatten())
            .chain(o.t4.iter().flatten().flatten().flatten());
        for (x, y) in lhs.zip(rhs) {
            *x += *y;
        }
        sum
    }
}
impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;
    /// `self - o`, computed as `self + (-1)·o`.
    fn sub(self, o: Self) -> Self {
        self + (-o)
    }
}
impl<const K: usize> std::ops::Neg for Tower4<K> {
    type Output = Self;
    /// Negates every channel (scaling by `-1`).
    fn neg(self) -> Self {
        self * -1.0
    }
}
impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;
    /// Jet product; forwards to the inherent `Tower4::mul`.
    fn mul(self, o: Self) -> Self {
        let product = Tower4::mul(&self, &o);
        product
    }
}
impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;
    /// `self / o`, computed as `self · (1/o)`.
    fn div(self, o: Self) -> Self {
        let inverse = o.recip();
        Tower4::mul(&self, &inverse)
    }
}
impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;
    /// Adds a constant: only the value channel moves.
    fn add(self, c: f64) -> Self {
        Self {
            v: self.v + c,
            ..self
        }
    }
}
impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;
    /// Subtracts a constant: only the value channel moves.
    fn sub(self, c: f64) -> Self {
        Self {
            v: self.v - c,
            ..self
        }
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
type Output = Self;
fn mul(self, c: f64) -> Self {
self.scale(c)
}
}
/// Solves the scalar constraint `F(a, θ) = 0` for `a(θ)` and returns its fourth-order jet
/// in the `K` θ-variables.
///
/// `f` is the fourth-order jet of `F` at `(a0, θ0)`: slot 0 is `a`, slots `1..=K` are θ.
/// `a0` must (numerically) be a root of `F(·, θ0)`.
///
/// The jet is built order by order. At each order the constraint is composed with the
/// current `a(θ)` via [`substitute_intercept`], and that order's channels of `a` are
/// corrected by `-G / F_a`, cancelling the matching channels of the composition
/// `G = F∘a`. A final composition then checks that every derivative channel of `G`
/// vanishes.
///
/// # Panics
/// Panics if `K1 != K + 1`.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// * `∂F/∂a` is zero or non-finite.
/// * `F(a0, θ0)` is non-finite.
/// * The Newton correction `|F / F_a|` exceeds `1e-9 · (1 + |a0|)`.
/// * Any channel of the final residual is non-finite (including NaN).
/// * A derivative channel of the final residual exceeds `1e-7 · (1 + |F_a|)`.
pub fn implicit_solve<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a0: f64,
) -> Result<Tower4<K>, String> {
    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
    let f_a = f.g[0];
    if f_a == 0.0 || !f_a.is_finite() {
        return Err(format!(
            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
        ));
    }
    let root_tol = 1e-9;
    if !f.v.is_finite() {
        return Err(format!(
            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
            f.v
        ));
    }
    // Size of the Newton step that would move a0 onto the root, judged relative to |a0|.
    let newton_step = f.v.abs() / f_a.abs();
    if newton_step > root_tol * (1.0 + a0.abs()) {
        return Err(format!(
            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
             root_tol {root_tol:.1e} · (1 + |a0|)",
            f.v
        ));
    }
    let mut a = Tower4::<K>::constant(a0);
    for order in 1..=4 {
        let g = substitute_intercept(f, &a);
        // The order-`order` channels of `a` are still zero here; subtracting G/F_a sets them
        // so that the same channels of the composition cancel.
        match order {
            1 => {
                for i in 0..K {
                    a.g[i] -= g.g[i] / f_a;
                }
            }
            2 => {
                for i in 0..K {
                    for j in 0..K {
                        a.h[i][j] -= g.h[i][j] / f_a;
                    }
                }
            }
            3 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
                        }
                    }
                }
            }
            _ => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            for l in 0..K {
                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
                            }
                        }
                    }
                }
            }
        }
    }
    let g = substitute_intercept(f, &a);
    let resid_tol = 1e-7 * (1.0 + f_a.abs());
    // Running max of |channel| with NaN made sticky. `f64::max` returns the non-NaN
    // operand, so a plain `max` fold would silently discard a NaN channel and the
    // finiteness check below would never fire.
    let fold_abs = |acc: f64, x: f64| -> f64 {
        if acc.is_nan() || x.is_nan() {
            f64::NAN
        } else {
            acc.max(x.abs())
        }
    };
    // The value channel equals F(a0, θ0) when F's coefficients are finite, because every
    // substituted increment has zero value. The root check above already validated it with
    // an |a0|-scaled tolerance. Comparing it again against `resid_tol` would reject roots
    // that check accepted, so only its finiteness is enforced here.
    let mut worst = if g.v.is_finite() { 0.0 } else { f64::NAN };
    for i in 0..K {
        worst = fold_abs(worst, g.g[i]);
        for j in 0..K {
            worst = fold_abs(worst, g.h[i][j]);
            for k in 0..K {
                worst = fold_abs(worst, g.t3[i][j][k]);
                for l in 0..K {
                    worst = fold_abs(worst, g.t4[i][j][k][l]);
                }
            }
        }
    }
    if !worst.is_finite() || worst > resid_tol {
        return Err(format!(
            "implicit_solve: composed residual G = F∘a does not vanish: \
             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
        ));
    }
    Ok(a)
}
/// Composes the fourth-order Taylor polynomial of `f` with a substitution and returns the
/// fourth-order jet in θ of `F(a(θ), θ)`.
///
/// `f` is a jet in `K1 = K + 1` variables. The substitution is slot 0 ← the increment of
/// `a`, and slot `j ≥ 1` ← the unit seed for θ_{j-1}.
///
/// The value channel of `a` is discarded (the slot-0 increment has value 0). `a.v` is
/// therefore implicitly taken to be the expansion point of `f`.
///
/// # Panics
/// Panics if `K1 != K + 1`.
pub fn substitute_intercept<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(K1, K + 1);
    // Increment jets δ_slot: slot 0 is `a` with its value zeroed, slot j ≥ 1 is θ_{j-1}'s seed.
    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
        if slot == 0 {
            let mut d = *a;
            d.v = 0.0;
            d
        } else {
            let mut d = Tower4::<K>::zero();
            d.g[slot - 1] = 1.0;
            d
        }
    });
    let mut out = Tower4::<K>::constant(f.v);
    for a_idx in 0..K1 {
        out = out + inp[a_idx].scale(f.g[a_idx]);
    }
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let ab = inp[a_idx].mul(&inp[b_idx]);
            out = out + ab.scale(0.5 * f.h[a_idx][b_idx]);
        }
    }
    // Orders 3 and 4 reuse the partial products δ_a·δ_b and δ_a·δ_b·δ_c across their inner
    // loops instead of rebuilding them for every index tuple. The products themselves and
    // the order in which terms are accumulated into `out` are unchanged, so the result is
    // bit-identical to the unhoisted form.
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let ab = inp[a_idx].mul(&inp[b_idx]);
            for c_idx in 0..K1 {
                let abc = ab.mul(&inp[c_idx]);
                out = out + abc.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
            }
        }
    }
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let ab = inp[a_idx].mul(&inp[b_idx]);
            for c_idx in 0..K1 {
                let abc = ab.mul(&inp[c_idx]);
                for d_idx in 0..K1 {
                    let abcd = abc.mul(&inp[d_idx]);
                    out = out + abcd.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
                }
            }
        }
    }
    out
}
/// Boundary jet for a moving integration limit. It composes `z_edge` with the scalar
/// function whose stack at `z_edge.v` is `[0, B, B', B'', B''']`, where
/// `b_stack = [B, B', B'', B''']`. The value channel of the result is therefore zero.
pub fn moving_limit_boundary_tower<const K: usize>(
    z_edge: &Tower4<K>,
    b_stack: [f64; 4],
) -> Tower4<K> {
    let [b0, b1, b2, b3] = b_stack;
    z_edge.compose_unary([0.0, b0, b1, b2, b3])
}
/// Net moving-boundary contribution of a cell: the right-edge boundary jet minus the
/// left-edge boundary jet.
pub fn cell_moving_boundary_flux_tower<const K: usize>(
    z_right: &Tower4<K>,
    b_stack_right: [f64; 4],
    z_left: &Tower4<K>,
    b_stack_left: [f64; 4],
) -> Tower4<K> {
    let right_edge = moving_limit_boundary_tower(z_right, b_stack_right);
    let left_edge = moving_limit_boundary_tower(z_left, b_stack_left);
    right_edge - left_edge
}
/// Boundary jet `Φ(z(θ), θ) − Φ(z0, θ)` for an antiderivative `Φ` that also depends on θ.
///
/// `phi_jet` is Φ's jet in `(z, θ)`, with `z` in slot 0. `z_edge` is the edge jet, and
/// `z0 = z_edge.v`.
///
/// # Panics
/// Panics if `K1 != K + 1`.
pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet: &Tower4<K1>,
    z_edge: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(
        K1,
        K + 1,
        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
    );
    let pinned_edge = Tower4::<K>::constant(z_edge.v);
    let with_moving_edge = substitute_intercept(phi_jet, z_edge);
    let with_pinned_edge = substitute_intercept(phi_jet, &pinned_edge);
    with_moving_edge - with_pinned_edge
}
/// Net θ-dependent moving-boundary contribution of a cell: the right-edge boundary jet
/// minus the left-edge boundary jet.
pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet_right: &Tower4<K1>,
    z_right: &Tower4<K>,
    phi_jet_left: &Tower4<K1>,
    z_left: &Tower4<K>,
) -> Tower4<K> {
    let right_edge = moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right);
    let left_edge = moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left);
    right_edge - left_edge
}
/// A per-row negative log-likelihood written directly against [`Tower4`] jets. Drivers
/// such as `evaluate_program` seed each primary as an independent jet variable and read
/// the derivatives off the returned jet.
pub trait RowNllProgram<const K: usize>: Send + Sync {
/// Number of rows the program can evaluate.
fn n_rows(&self) -> usize;
/// The `K` primary values of `row`, i.e. the point at which derivatives are taken.
fn primaries(&self, row: usize) -> Result<[f64; K], String>;
/// Evaluates the row's NLL on the primary jets `p` (`p[a]` is the jet of primary `a`).
fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
/// Evaluates `row` of `prog` on seeded jets: primary `a` becomes `Tower4::variable(p_a, a)`.
/// The result carries the row NLL's value and all derivatives up to fourth order.
///
/// # Errors
/// Propagates any error from `primaries` or `row_nll`.
pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let point = prog.primaries(row)?;
    let seeds: [Tower4<K>; K] =
        std::array::from_fn(|slot| Tower4::variable(point[slot], slot));
    prog.row_nll(row, &seeds)
}
/// Value, gradient and Hessian of `row`, read off the full fourth-order jet.
///
/// # Errors
/// Propagates any error from evaluating the program.
pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    evaluate_program(prog, row).map(|tower| (tower.v, tower.g, tower.h))
}
/// Third-derivative tensor of `row` contracted once with `dir`.
///
/// # Errors
/// Propagates any error from evaluating the program.
pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    evaluate_program(prog, row).map(|tower| tower.third_contracted(dir))
}
/// Fourth-derivative tensor of `row` contracted with `dir_u` and `dir_v`.
///
/// # Errors
/// Propagates any error from evaluating the program.
pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    evaluate_program(prog, row).map(|tower| tower.fourth_contracted(dir_u, dir_v))
}
/// A per-row negative log-likelihood written generically over the jet scalar type. One
/// implementation can thus be driven by different jet types: the `generic_*` helpers use
/// `Order2`, `OneSeed`, `TwoSeed` and `Tower4`.
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
/// Number of rows the program can evaluate.
fn n_rows(&self) -> usize;
/// The `K` primary values of `row`, i.e. the point at which derivatives are taken.
fn primaries(&self, row: usize) -> Result<[f64; K], String>;
/// Evaluates the row's NLL on the primary jets `p`, for any jet scalar `S`.
fn row_nll_generic<S: super::jet_scalar::JetScalar<K>>(
&self,
row: usize,
p: &[S; K],
) -> Result<S, String>;
}
/// Value, gradient and Hessian of `row`, evaluated with second-order `Order2` jets.
///
/// # Errors
/// Propagates any error from `primaries` or `row_nll_generic`.
pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let point = prog.primaries(row)?;
    let seeds: [super::jet_scalar::Order2<K>; K] = std::array::from_fn(|slot| {
        <super::jet_scalar::Order2<K> as super::jet_scalar::JetScalar<K>>::variable(
            point[slot],
            slot,
        )
    });
    let jet = prog.row_nll_generic(row, &seeds)?;
    let value = super::jet_scalar::JetScalar::value(&jet);
    Ok((value, jet.g(), jet.h()))
}
/// Third-derivative tensor of `row` contracted with `dir`, using `OneSeed` jets seeded
/// along `dir`.
///
/// # Errors
/// Propagates any error from `primaries` or `row_nll_generic`.
pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let point = prog.primaries(row)?;
    let seeds: [super::jet_scalar::OneSeed<K>; K] = std::array::from_fn(|slot| {
        super::jet_scalar::OneSeed::seed_direction(point[slot], slot, dir[slot])
    });
    let jet = prog.row_nll_generic(row, &seeds)?;
    Ok(jet.contracted_third())
}
/// Fourth-derivative tensor of `row` contracted with `dir_u` and `dir_v`, using `TwoSeed`
/// jets seeded along both directions.
///
/// # Errors
/// Propagates any error from `primaries` or `row_nll_generic`.
pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let point = prog.primaries(row)?;
    let seeds: [super::jet_scalar::TwoSeed<K>; K] = std::array::from_fn(|slot| {
        super::jet_scalar::TwoSeed::seed(point[slot], slot, dir_u[slot], dir_v[slot])
    });
    let jet = prog.row_nll_generic(row, &seeds)?;
    Ok(jet.contracted_fourth())
}
/// Full fourth-order jet of `row`, obtained by driving the generic program with `Tower4`.
///
/// # Errors
/// Propagates any error from `primaries` or `row_nll_generic`.
pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let point = prog.primaries(row)?;
    let seeds: [Tower4<K>; K] =
        std::array::from_fn(|slot| Tower4::variable(point[slot], slot));
    prog.row_nll_generic(row, &seeds)
}
/// Derivative channels claimed by a hand-written row kernel, checked against a [`Tower4`]
/// by `verify_kernel_channels`.
pub struct KernelChannels<const K: usize> {
/// Claimed value.
pub value: f64,
/// Claimed gradient.
pub gradient: [f64; K],
/// Claimed Hessian.
pub hessian: [[f64; K]; K],
/// `(dir, claim)` pairs; each `claim` is compared with `tower.third_contracted(&dir)`.
pub third: Vec<([f64; K], [[f64; K]; K])>,
/// `(u, w, claim)` triples; each `claim` is compared with `tower.fourth_contracted(&u, &w)`.
pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
/// Checks every channel in `claims` against the reference jet `tower`.
///
/// Channels are checked in this order: value, gradient, Hessian, then each third and
/// fourth contraction. Finite pairs must agree within `rel_tol + rel_tol·max(|claim|,
/// |truth|)`. Non-finite pairs pass only when both are infinities of the same sign.
///
/// # Errors
/// Returns a message naming the first channel that disagrees.
pub fn verify_kernel_channels<const K: usize>(
    tower: &Tower4<K>,
    claims: &KernelChannels<K>,
    rel_tol: f64,
) -> Result<(), String> {
    // The absolute floor deliberately reuses `rel_tol`.
    let atol = rel_tol;
    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
        if claim.is_finite() && truth.is_finite() {
            let band = atol + rel_tol * claim.abs().max(truth.abs());
            // Kept as `> band` so a NaN band accepts, exactly as before.
            return if (claim - truth).abs() > band {
                Err(format!(
                    "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
                ))
            } else {
                Ok(())
            };
        }
        // At least one side is NaN or ±inf: only same-signed infinities agree.
        let same_infinity = claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
        if same_infinity {
            Ok(())
        } else {
            Err(format!(
                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
            ))
        }
    };
    // Compares a K×K block entry by entry, labelling entries `{prefix}[a][b]`.
    let check_block =
        |prefix: &str, claimed: &[[f64; K]; K], truth: &[[f64; K]; K]| -> Result<(), String> {
            for (a, (claimed_row, truth_row)) in claimed.iter().zip(truth).enumerate() {
                for (b, (&c, &t)) in claimed_row.iter().zip(truth_row).enumerate() {
                    check(&format!("{prefix}[{a}][{b}]"), c, t)?;
                }
            }
            Ok(())
        };
    check("value", claims.value, tower.v)?;
    for (a, (&claimed, &reference)) in claims.gradient.iter().zip(&tower.g).enumerate() {
        check(&format!("gradient[{a}]"), claimed, reference)?;
    }
    check_block("hessian", &claims.hessian, &tower.h)?;
    for (t_idx, (dir, claimed)) in claims.third.iter().enumerate() {
        check_block(&format!("third[{t_idx}]"), claimed, &tower.third_contracted(dir))?;
    }
    for (f_idx, (u, w, claimed)) in claims.fourth.iter().enumerate() {
        check_block(
            &format!("fourth[{f_idx}]"),
            claimed,
            &tower.fourth_contracted(u, w),
        )?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
struct LogitProgram {
eta: Vec<f64>,
y: Vec<f64>,
}
impl RowNllProgram<1> for LogitProgram {
fn n_rows(&self) -> usize {
self.eta.len()
}
fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
Ok([self.eta[row]])
}
fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
let eta = p[0];
Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
}
}
#[test]
fn logit_tower_matches_closed_forms() {
let prog = LogitProgram {
eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
};
for row in 0..prog.n_rows() {
let t = evaluate_program(&prog, row).expect("logit program");
let eta = prog.eta[row];
let y = prog.y[row];
let mu = 1.0 / (1.0 + (-eta).exp());
let w = mu * (1.0 - mu);
let expect = [
(t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
(t.g[0], mu - y, "grad"),
(t.h[0][0], w, "hess"),
(t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
(
t.t4[0][0][0][0],
w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
"fourth",
),
];
for (got, want, label) in expect {
assert!(
(got - want).abs() <= 1e-12 * want.abs().max(1.0),
"row {row} {label}: got {got:+.15e} want {want:+.15e}"
);
}
}
}
fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
let diff = (got - want).abs();
assert!(
diff <= rel_tol * want.abs().max(1.0),
"{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
);
}
#[test]
fn gamma_special_function_stacks_match_reference_values() {
const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
let cases = [
(
"x=0.1",
0.1,
-10.423_754_940_411_076,
101.433_299_150_792_75,
),
(
"x=0.5",
0.5,
-EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
pi_sq / 2.0,
),
("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
(
"x=2.5",
2.5,
-EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
),
(
"x=50",
50.0,
3.901_989_673_427_892,
0.020_201_333_226_697_128,
),
];
for (label, x, digamma_ref, trigamma_ref) in cases {
let ln_gamma_stack = ln_gamma_derivative_stack(x);
let digamma_stack = digamma_derivative_stack(x);
let trigamma_stack = trigamma_derivative_stack(x);
assert_close(
&format!("{label} ln_gamma_stack digamma"),
ln_gamma_stack[1],
digamma_ref,
1e-13,
);
assert_close(
&format!("{label} digamma value"),
digamma_stack[0],
digamma_ref,
1e-13,
);
assert_close(
&format!("{label} ln_gamma_stack trigamma"),
ln_gamma_stack[2],
trigamma_ref,
1e-13,
);
assert_close(
&format!("{label} digamma_stack trigamma"),
digamma_stack[1],
trigamma_ref,
1e-13,
);
assert_close(
&format!("{label} trigamma value"),
trigamma_stack[0],
trigamma_ref,
1e-13,
);
}
}
#[test]
fn gamma_special_function_stacks_obey_recurrences() {
for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
let digamma_x = digamma_derivative_stack(x)[0];
let digamma_next = digamma_derivative_stack(x + 1.0)[0];
let trigamma_x = trigamma_derivative_stack(x)[0];
let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
assert_close(
&format!("digamma recurrence x={x}"),
digamma_next,
digamma_x + 1.0 / x,
1e-13,
);
assert_close(
&format!("trigamma recurrence x={x}"),
trigamma_next,
trigamma_x - 1.0 / (x * x),
1e-13,
);
}
}
struct LocScaleProgram {
eta: Vec<f64>,
s: Vec<f64>,
y: Vec<f64>,
}
impl RowNllProgram<2> for LocScaleProgram {
fn n_rows(&self) -> usize {
self.eta.len()
}
fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
Ok([self.eta[row], self.s[row]])
}
fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
let r = -(p[0] - self.y[row]);
Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
}
}
#[test]
fn locscale_tower_matches_closed_forms_including_cross_blocks() {
let prog = LocScaleProgram {
eta: vec![0.3, -1.1, 2.0],
s: vec![-0.5, 0.2, 0.8],
y: vec![1.0, -2.0, 2.5],
};
let tol = 1e-12;
for row in 0..prog.n_rows() {
let t = evaluate_program(&prog, row).expect("locscale program");
let r = prog.y[row] - prog.eta[row];
let w = (-2.0 * prog.s[row]).exp();
let truth_g = [-w * r, 1.0 - w * r * r];
let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
match a + b + c {
0 => 0.0,
1 => -2.0 * w,
2 => -4.0 * w * r,
_ => -4.0 * w * r * r,
}
};
let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
match a + b + c + d {
0 | 1 => 0.0,
2 => 4.0 * w,
3 => 8.0 * w * r,
_ => 8.0 * w * r * r,
}
};
for a in 0..2 {
assert!(
(t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
"row {row} grad[{a}]"
);
for b in 0..2 {
assert!(
(t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
"row {row} hess[{a}][{b}]: got {} want {}",
t.h[a][b],
truth_h[a][b]
);
for c in 0..2 {
assert!(
(t.t3[a][b][c] - t3_truth(a, b, c)).abs()
<= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
"row {row} t3[{a}][{b}][{c}]: got {} want {}",
t.t3[a][b][c],
t3_truth(a, b, c)
);
for d in 0..2 {
assert!(
(t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
<= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
"row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
t.t4[a][b][c][d],
t4_truth(a, b, c, d)
);
}
}
}
}
let dir = [0.7, -1.3];
let third = derived_third_contracted(&prog, row, &dir).expect("third");
for a in 0..2 {
for b in 0..2 {
let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
}
}
}
}
struct GnarlyProgram {
primaries: Vec<[f64; 3]>,
tau: Vec<f64>,
}
impl GnarlyProgram {
fn fixture() -> Self {
Self {
primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
tau: vec![0.15, -0.35, 0.5],
}
}
}
impl RowNllProgram<3> for GnarlyProgram {
fn n_rows(&self) -> usize {
self.primaries.len()
}
fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
self.primaries
.get(row)
.copied()
.ok_or_else(|| format!("gnarly: row {row} out of range"))
}
fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
let tau = *self
.tau
.get(row)
.ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
let a = (p[0] * p[1]).exp();
let b = (p[2] * p[2] + 1.0).sqrt();
let c = (a + b + tau).ln();
let d = (p[1] * 0.5 + 2.0).powf(1.7);
Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
}
}
fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
struct At<'a> {
base: &'a GnarlyProgram,
row: usize,
p: [f64; 3],
}
impl RowNllProgram<3> for At<'_> {
fn n_rows(&self) -> usize {
1
}
fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
if row != 0 {
return Err(format!("gnarly-at: row {row} out of range"));
}
Ok(self.p)
}
fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
if eval_row != 0 {
return Err(format!("gnarly-at: eval row {eval_row} out of range"));
}
self.base.row_nll(self.row, vars)
}
}
evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
}
#[test]
fn gnarly_tower_is_fd_consistent_order_by_order() {
let prog = GnarlyProgram::fixture();
for row in 0..prog.n_rows() {
let base = prog.primaries(row).expect("primaries");
let t = gnarly_tower_at(&prog, row, base);
let h_step = 1e-5;
let tol = 1e-6;
for c in 0..3 {
let mut up = base;
let mut dn = base;
up[c] += h_step;
dn[c] -= h_step;
let t_up = gnarly_tower_at(&prog, row, up);
let t_dn = gnarly_tower_at(&prog, row, dn);
let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
assert!(
(t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
"grad[{c}]: analytic {} fd {}",
t.g[c],
fd_g
);
for a in 0..3 {
let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
assert!(
(t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
"hess[{a}][{c}]: analytic {} fd {}",
t.h[a][c],
fd_h
);
for b in 0..3 {
let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
assert!(
(t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
"t3[{a}][{b}][{c}]: analytic {} fd {}",
t.t3[a][b][c],
fd_t3
);
for d in 0..3 {
let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
assert!(
(t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
"t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
t.t4[a][b][d][c],
fd_t4
);
}
}
}
}
}
}
#[test]
fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
const C: f64 = 1.7;
let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
let solve = |th: [f64; 2]| -> f64 {
let mut a = 0.0_f64;
for _ in 0..100 {
let r = f_scalar(a, th);
if r.abs() < 1e-14 {
break;
}
a -= r / f_da(a, th);
}
a
};
let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
let a = Tower4::<3>::variable(a0, 0);
let t0 = Tower4::<3>::variable(th[0], 1);
let t1 = Tower4::<3>::variable(th[1], 2);
a + t0 * a.mul(&a) + t1 * a.exp() - C
};
let th0 = [0.35, 0.2];
let a0 = solve(th0);
let f = f_tower(a0, th0);
assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
let h = 1e-4;
let tol = 1e-5;
let re = |th: [f64; 2]| solve(th);
for i in 0..2 {
let mut up = th0;
let mut dn = th0;
up[i] += h;
dn[i] -= h;
let fd_g = (re(up) - re(dn)) / (2.0 * h);
assert!(
(a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
"a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
a_tower.g[i],
fd_g
);
let grad_at = |th: [f64; 2], j: usize| -> f64 {
let mut up = th;
let mut dn = th;
up[j] += h;
dn[j] -= h;
(re(up) - re(dn)) / (2.0 * h)
};
for j in 0..2 {
let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
assert!(
(a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
"a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
a_tower.h[i][j],
fd_h
);
}
}
}
#[test]
fn implicit_solve_matches_textbook_ift_recursion() {
let a0 = 0.4_f64;
let th = [0.25_f64, -0.15_f64];
let f = {
let a = Tower4::<3>::variable(a0, 0);
let t0 = Tower4::<3>::variable(th[0], 1);
let t1 = Tower4::<3>::variable(th[1], 2);
a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
};
let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
let f_a = f.g[0];
for u in 0..2 {
let want = -f.g[u + 1] / f_a;
assert!(
(a_t.g[u] - want).abs() < 1e-12,
"a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
a_t.g[u],
want
);
}
for u in 0..2 {
for v in 0..2 {
let f_uv = f.h[u + 1][v + 1];
let f_au = f.h[0][u + 1];
let f_av = f.h[0][v + 1];
let f_aa = f.h[0][0];
let want =
-(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
assert!(
(a_t.h[u][v] - want).abs() < 1e-12,
"a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
a_t.h[u][v],
want
);
}
}
}
#[test]
fn moving_boundary_flux_carries_b_zuv_term() {
use std::f64::consts::PI;
let b = |z: f64| (-0.5 * z * z).exp(); let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
let th0 = [0.7_f64, 0.5_f64];
let mut z_edge = Tower4::<2>::constant(z_r(th0));
z_edge.g[0] = 1.0; z_edge.g[1] = 2.0 * th0[1]; z_edge.h[1][1] = 2.0;
let z0 = z_edge.v;
let b0 = b(z0);
let stack = [
b0,
-z0 * b0,
(z0 * z0 - 1.0) * b0,
(3.0 * z0 - z0 * z0 * z0) * b0,
];
let flux = moving_limit_boundary_tower(&z_edge, stack);
let h = 1e-4;
let tol = 1e-6;
for i in 0..2 {
let mut up = th0;
let mut dn = th0;
up[i] += h;
dn[i] -= h;
let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
assert!(
(flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
"flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
flux.g[i],
fd_g
);
}
let grad1_at = |th: [f64; 2]| -> f64 {
let mut up = th;
let mut dn = th;
up[1] += h;
dn[1] -= h;
(integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
};
let mut up = th0;
let mut dn = th0;
up[1] += h;
dn[1] -= h;
let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
assert!(
(flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
"flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
flux.h[1][1],
fd_h11
);
let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
let b_zuv = flux.h[1][1] - pure_zu2;
assert!(
(b_zuv - b0 * 2.0).abs() < 1e-10,
"B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
b_zuv,
b0 * 2.0
);
}
#[test]
fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
    // Closed-form fixture: G(z, θ₀) = exp(z·θ₀) with z-antiderivative
    // Φ(z, θ₀) = (exp(z·θ₀) − 1)/θ₀, so Φ_z = G. The moving edge is
    // z_r(θ) = 0.6 + θ₀ + θ₁², hence z_u = [1, 2θ₁] and z_uv = diag(0, 2).
    let g = |z: f64, t0: f64| (z * t0).exp();
    let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
    let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
    let th0 = [0.4_f64, 0.5_f64];
    let z0 = z_r(th0);

    // Edge jet over θ = (θ₀, θ₁), filled in by hand from z_r.
    let mut z_edge = Tower4::<2>::constant(z0);
    z_edge.g[0] = 1.0;
    z_edge.g[1] = 2.0 * th0[1];
    z_edge.h[1][1] = 2.0;

    // Integrand jet: z is seeded on channel 0 and θ₀ on channel 1. Φ is built
    // from these two only, so nothing is seeded on channel 2.
    let z_var = Tower4::<3>::variable(z0, 0);
    let t0_var = Tower4::<3>::variable(th0[0], 1);
    let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
    assert!(
        (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
        "Φ_z {:+.8e} != G {:+.8e}",
        phi_jet.g[0],
        g(z0, th0[0])
    );

    let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
    assert!(
        flux.v.abs() < 1e-12,
        "boundary value channel {:+.3e}",
        flux.v
    );

    // Reference: the boundary term Φ(z_r(θ), θ₀) − Φ(z0, θ₀), differentiated
    // numerically with central differences of step h.
    let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
    let h = 1e-4;
    let tol = 1e-6;
    // Central difference of the boundary term along θ_j at `th`.
    let grad_at = |th: [f64; 2], j: usize| -> f64 {
        let mut up = th;
        let mut dn = th;
        up[j] += h;
        dn[j] -= h;
        (bnd(up) - bnd(dn)) / (2.0 * h)
    };

    for i in 0..2 {
        let fd_g = grad_at(th0, i);
        assert!(
            (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
            "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
            flux.g[i],
            fd_g
        );
    }

    // Second derivatives: central difference of the first-derivative FD.
    for i in 0..2 {
        for j in 0..2 {
            let mut up = th0;
            let mut dn = th0;
            up[i] += h;
            dn[i] -= h;
            let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
            assert!(
                (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
                "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
                flux.h[i][j],
                fd_h
            );
        }
    }

    // Hand closure: G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u at the edge.
    let gg = g(z0, th0[0]);
    let g_z = th0[0] * gg;
    // ∂G/∂θ at fixed z: G depends explicitly on θ₀ only.
    let g_theta = [z0 * gg, 0.0];
    for i in 0..2 {
        for j in 0..2 {
            let z_u = z_edge.g[i];
            let z_v = z_edge.g[j];
            let z_uv = z_edge.h[i][j];
            let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
            assert!(
                (flux.h[i][j] - hand).abs() < 1e-9,
                "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
                flux.h[i][j],
                hand
            );
        }
    }

    // Isolate the curvature contribution G·z_uv in the θ₁θ₁ entry (z_uv = 2 there).
    let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
    let g_zuv = flux.h[1][1] - pure_no_zuv;
    assert!(
        (g_zuv - gg * 2.0).abs() < 1e-9,
        "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
        g_zuv,
        gg * 2.0
    );
}
#[test]
fn crossing_edge_tower_matches_handpath_velocity_formulas() {
    // Crossing edge z = (τ − a)/b, with `a` an arbitrary jet and `b` the
    // coordinate seeded on channel `g_idx`. Differentiating z·b = τ − a:
    //   z_u  = −(a_u + z·b_u)/b
    //   z_uv = −(a_uv + z_u·b_v + z_v·b_u)/b      (b_uv = 0)
    // where b_u = 1 when u == g_idx and 0 otherwise.
    const TAU: f64 = 1.3;
    const K: usize = 3;
    let g_idx = 1usize;
    let g0 = 0.85_f64;
    let mut a = Tower4::<K>::constant(0.45);
    a.g[0] = 0.7;
    a.g[1] = -0.3;
    a.h[0][0] = 0.25;
    a.h[0][1] = 0.11;
    a.h[1][0] = 0.11;
    a.h[1][1] = -0.08;
    let b = Tower4::<K>::variable(g0, g_idx);
    let z_edge = (Tower4::<K>::constant(TAU) - a) / b;
    let bv = g0;
    let z0 = z_edge.v;
    assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);

    // Sweep every channel of the K-wide tower. Channel 2 is a spectator:
    // neither `a` nor `b` carries any derivative along it, so its hand-formula
    // targets are zero and any leakage from the division shows up here.
    for u in 0..K {
        let direct = if u == g_idx { z0 } else { 0.0 };
        let want = -(a.g[u] + direct) / bv;
        assert!(
            (z_edge.g[u] - want).abs() < 1e-10,
            "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
            z_edge.g[u],
            want
        );
    }
    for u in 0..K {
        for v in 0..K {
            // z_v·b_u + z_u·b_v with b seeded only on `g_idx`.
            let from_u = if u == g_idx { z_edge.g[v] } else { 0.0 };
            let from_v = if v == g_idx { z_edge.g[u] } else { 0.0 };
            let cross = from_u + from_v;
            let want = -(a.h[u][v] + cross) / bv;
            assert!(
                (z_edge.h[u][v] - want).abs() < 1e-10,
                "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                z_edge.h[u][v],
                want
            );
        }
    }
}
#[test]
fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
    // With a and b seeded as independent coordinates (channels 0 and 1),
    // z = (τ − a)/b has closed-form partials:
    //   z_a = −1/b,  z_aa = 0,  z_ab = 1/b²,  z_bb = 2(τ − a)/b³.
    const TAU: f64 = 1.3;
    let (a0, b0) = (0.45_f64, 0.85_f64);
    let numerator = Tower4::<2>::constant(TAU) - Tower4::<2>::variable(a0, 0);
    let z = numerator / Tower4::<2>::variable(b0, 1);

    let b_sq = b0 * b0;
    let expected_z_ab = 1.0 / b_sq;
    let expected_z_bb = 2.0 * (TAU - a0) / (b_sq * b0);

    assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
    assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
    assert!(
        (z.h[0][1] - expected_z_ab).abs() < 1e-12,
        "z_ab {:+.10e} vs +1/b² {:+.10e}",
        z.h[0][1],
        expected_z_ab
    );
    assert!(
        z.h[0][0].abs() < 1e-12,
        "z_aa must vanish, got {:+.10e}",
        z.h[0][0]
    );
    assert!(
        (z.h[1][1] - expected_z_bb).abs() < 1e-12,
        "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
        z.h[1][1],
        expected_z_bb
    );
}
#[test]
fn oracle_catches_planted_cross_block_sign_flip() {
    // The honest kernel channels are read straight off the evaluated tower, so
    // the oracle must accept them. Negating one off-diagonal entry of the
    // contracted third-order channel must then be rejected, and the error must
    // name exactly that entry.
    let prog = LocScaleProgram {
        eta: vec![0.3],
        s: vec![-0.5],
        y: vec![1.0],
    };
    let t = evaluate_program(&prog, 0).expect("tower");
    let dir = [0.6, -0.2];
    let third = t.third_contracted(&dir);

    // Negating an exact zero plants nothing. That would make the `expect_err`
    // below fail as if the oracle were broken, so check the fixture up front.
    assert!(
        third[0][1] != 0.0,
        "planted entry third[0][1] is exactly zero; flipping its sign would plant nothing"
    );

    let honest = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, third)],
        fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
    };
    verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

    // Plant the fault on a copy of the honest channel.
    let mut planted = third;
    planted[0][1] = -planted[0][1];
    let flipped = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, planted)],
        fourth: vec![],
    };
    let err = verify_kernel_channels(&t, &flipped, 1e-10)
        .expect_err("planted sign flip must be caught");
    assert!(
        err.contains("third[0][0][1]"),
        "oracle must name the flipped channel, got: {err}"
    );
}
#[test]
fn t3_t4_are_fully_index_symmetric() {
    // Every permutation of 0..N as an index array, in lexicographic order.
    // Enumerate all N^N digit strings in counting order (most significant digit
    // first) and keep those that use every digit exactly once.
    fn index_permutations<const N: usize>() -> Vec<[usize; N]> {
        (0..N.pow(N as u32))
            .map(|mut code| {
                let mut digits = [0usize; N];
                for slot in digits.iter_mut().rev() {
                    *slot = code % N;
                    code /= N;
                }
                digits
            })
            .filter(|digits| (0..N).all(|d| digits.contains(&d)))
            .collect()
    }

    // Largest |entry|, floored at 1.0 so the symmetry tolerance never drops
    // below an absolute one.
    fn abs_scale<'a>(entries: impl Iterator<Item = &'a f64>) -> f64 {
        entries.map(|x| x.abs()).fold(0.0_f64, f64::max).max(1.0)
    }

    let prog = GnarlyProgram::fixture();
    let perms3 = index_permutations::<3>();
    let perms4 = index_permutations::<4>();

    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("gnarly tower");
        let scale_t3 = abs_scale(t.t3.iter().flatten().flatten());
        let scale_t4 = abs_scale(t.t4.iter().flatten().flatten().flatten());
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    let idx = [i, j, k];
                    let base = t.t3[i][j][k];
                    for p in &perms3 {
                        let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
                        assert!(
                            (base - permed).abs() <= 1e-12 * scale_t3,
                            "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                             permuted {permed:+.15e} under {p:?}"
                        );
                    }
                    for l in 0..3 {
                        let idx4 = [i, j, k, l];
                        let base4 = t.t4[i][j][k][l];
                        for p in &perms4 {
                            let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
                            assert!(
                                (base4 - permed).abs() <= 1e-12 * scale_t4,
                                "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                    }
                }
            }
        }
    }
}
}
/// Value and first four derivatives of `f(x) = log Φ(x)`, the log of the
/// standard normal CDF, returned as `[f, f′, f″, f‴, f⁗]`.
///
/// The Mills-ratio value `m` from the probability helper is used directly as
/// `f′` (i.e. `φ(x)/Φ(x)`). The higher entries follow from `m′ = −m(x + m)`:
///
/// * `f″ = −m(x + m)`
/// * `f‴ = m(x² − 1 + 3xm + 2m²)`
/// * `f⁗ = −m((x³ − 3x) + (7x² − 4)m + 12xm² + 6m³)`
///
/// NOTE(review): for large negative `x`, `m ≈ −x`. The polynomial brackets
/// then cancel terms of size ~|x|³ down to a result of size ~1/|x|⁴, so the
/// third and fourth entries lose relative accuracy there.
#[inline]
pub(crate) fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    let (log_cdf, m) = crate::probability::signed_probit_logcdf_and_mills_ratio(x);
    let m2 = m * m;
    let m3 = m2 * m;
    let xx = x * x;
    let d2 = -m * (x + m);
    let d3 = m * (xx - 1.0 + 3.0 * x * m + 2.0 * m2);
    let d4 = -m * ((x * xx - 3.0 * x) + (7.0 * xx - 4.0) * m + 12.0 * x * m2 + 6.0 * m3);
    [log_cdf, m, d2, d3, d4]
}
/// Value and first four derivatives of `f(x) = log(1 − e^{−x})` for `x > 0`,
/// returned as `[f, f′, f″, f‴, f⁗]`.
///
/// The value comes from `crate::probability::log1mexp_positive`. The
/// derivatives are written in terms of `r = 1/(eˣ − 1)`, which is evaluated
/// through `exp_m1` to avoid cancellation near `x = 0`. Since `r′ = −r(1 + r)`:
///
/// * `f′ = r`
/// * `f″ = −r(1 + r)`
/// * `f‴ = r(1 + r)(1 + 2r)`
/// * `f⁗ = −r(1 + r)(1 + 6r + 6r²)`
#[inline]
pub(crate) fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = x.exp_m1().recip();
    // Common factor r(1 + r) shared by every derivative past the first.
    let shared = r * (1.0 + r);
    let value = crate::probability::log1mexp_positive(x);
    [
        value,
        r,
        -shared,
        shared * (1.0 + 2.0 * r),
        -shared * (1.0 + 6.0 * r + 6.0 * r * r),
    ]
}