use super::jet_algebra;
/// Value and partial derivatives, up to fourth order, of a scalar quantity in
/// `K` variables.
///
/// Every order is stored densely as a nested `K`-sized array, so an entry is
/// addressed by one index per differentiation label.
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Zeroth-order entry: the value itself.
    pub v: f64,
    /// First-order entries, indexed `[i]`.
    pub g: [f64; K],
    /// Second-order entries, indexed `[i][j]`.
    pub h: [[f64; K]; K],
    /// Third-order entries, indexed `[i][j][k]`.
    pub t3: [[[f64; K]; K]; K],
    /// Fourth-order entries, indexed `[i][j][k][l]`.
    pub t4: [[[[f64; K]; K]; K]; K],
}
impl<const K: usize> Tower4<K> {
/// Returns the tower whose value and every derivative entry are `0.0`.
pub fn zero() -> Self {
    // Build each order from the previous one; nested `f64` arrays are `Copy`.
    let g = [0.0_f64; K];
    let h = [g; K];
    let t3 = [h; K];
    let t4 = [t3; K];
    Self {
        v: 0.0,
        g,
        h,
        t3,
        t4,
    }
}
pub fn constant(c: f64) -> Self {
let mut out = Self::zero();
out.v = c;
out
}
pub fn variable(value: f64, idx: usize) -> Self {
let mut out = Self::constant(value);
out.g[idx] = 1.0;
out
}
/// Looks up the stored entry addressed by `labels`: no labels select `v`,
/// one label selects `g`, two `h`, three `t3`, four `t4`.
///
/// # Panics
///
/// Panics if more than four labels are given, or if any label is `>= K`.
#[inline]
fn deriv(&self, labels: &[usize]) -> f64 {
    match *labels {
        [] => self.v,
        [i] => self.g[i],
        [i, j] => self.h[i][j],
        [i, j, k] => self.t3[i][j][k],
        [i, j, k, l] => self.t4[i][j][k][l],
        _ => panic!("Tower4 carries at most fourth-order derivatives"),
    }
}
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v * b.v;
for i in 0..K {
let labels = [i];
out.g[i] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
let labels = [i, j, k];
out.t3[i][j][k] =
jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
}
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let labels = [i, j, k, l];
out.t4[i][j][k][l] =
jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
}
}
}
}
out
}
pub fn compose_unary(&self, d: [f64; 5]) -> Self {
<Self as jet_algebra::JetAlgebra<5>>::compose_unary(self, d)
}
pub fn scale(&self, s: f64) -> Self {
let mut out = *self;
out.v *= s;
for i in 0..K {
out.g[i] *= s;
for j in 0..K {
out.h[i][j] *= s;
for k in 0..K {
out.t3[i][j][k] *= s;
for l in 0..K {
out.t4[i][j][k][l] *= s;
}
}
}
}
out
}
/// Composes the tower with `exp`; all five stack entries equal `exp(self.v)`.
pub fn exp(&self) -> Self {
    self.compose_unary([self.v.exp(); 5])
}
/// Composes the tower with the natural logarithm.
///
/// With `r = 1 / self.v`, the stack passed on is
/// `[ln(v), r, -r^2, 2 r^3, -6 r^4]`.
pub fn ln(&self) -> Self {
    let x = self.v;
    let inv = 1.0 / x;
    let value = x.ln();
    let d1 = inv;
    let d2 = -inv * inv;
    let d3 = 2.0 * inv * inv * inv;
    let d4 = -6.0 * inv * inv * inv * inv;
    self.compose_unary([value, d1, d2, d3, d4])
}
/// Composes the tower with `x -> 1 / x`.
///
/// With `r = 1 / self.v`, the stack passed on is
/// `[r, -r^2, 2 r^3, -6 r^4, 24 r^5]`.
pub fn recip(&self) -> Self {
    let inv = 1.0 / self.v;
    let inv_sq = inv * inv;
    let d1 = -inv_sq;
    let d2 = 2.0 * inv_sq * inv;
    let d3 = -6.0 * inv_sq * inv_sq;
    let d4 = 24.0 * inv_sq * inv_sq * inv;
    self.compose_unary([inv, d1, d2, d3, d4])
}
/// Composes the tower with the square root.
///
/// With `u = self.v` and `s = sqrt(u)`, the stack passed on is
/// `[s, 1/(2s), -1/(4 u s), 3/(8 u^2 s), -15/(16 u^3 s)]`.
pub fn sqrt(&self) -> Self {
    let u = self.v;
    let root = u.sqrt();
    let d1 = 0.5 / root;
    let d2 = -0.25 / (u * root);
    let d3 = 0.375 / (u * u * root);
    let d4 = -0.9375 / (u * u * u * root);
    self.compose_unary([root, d1, d2, d3, d4])
}
/// Composes the tower with `x -> x^a` for a constant exponent `a`.
///
/// The n-th stack entry is `a (a-1) ... (a-n+1) * u^(a-n)` with `u = self.v`.
///
/// When the falling-factorial coefficient is exactly zero (non-negative
/// integer `a` smaller than the derivative order), the derivative is
/// identically zero and is returned as `0.0` without evaluating the power.
/// Multiplying instead would give `0 * inf = NaN` at `u = 0` (for example the
/// third derivative of `x^2`), which then poisons every higher-order entry.
pub fn powf(&self, a: f64) -> Self {
    let u = self.v;
    // A vanishing coefficient wins over a possibly infinite power of `u`.
    let term = |coeff: f64, exponent: f64| -> f64 {
        if coeff == 0.0 {
            0.0
        } else {
            coeff * u.powf(exponent)
        }
    };
    // Falling-factorial coefficients, accumulated left to right.
    let c1 = a;
    let c2 = c1 * (a - 1.0);
    let c3 = c2 * (a - 2.0);
    let c4 = c3 * (a - 3.0);
    self.compose_unary([
        u.powf(a),
        term(c1, a - 1.0),
        term(c2, a - 2.0),
        term(c3, a - 3.0),
        term(c4, a - 4.0),
    ])
}
/// Composes the tower with the stack from `ln_gamma_derivative_stack(self.v)`.
pub fn ln_gamma(&self) -> Self {
    let stack = ln_gamma_derivative_stack(self.v);
    self.compose_unary(stack)
}
/// Composes the tower with the stack from `digamma_derivative_stack(self.v)`.
pub fn digamma(&self) -> Self {
    let stack = digamma_derivative_stack(self.v);
    self.compose_unary(stack)
}
/// Composes the tower with the stack from `trigamma_derivative_stack(self.v)`.
pub fn trigamma(&self) -> Self {
    let stack = trigamma_derivative_stack(self.v);
    self.compose_unary(stack)
}
/// Contracts the last index of `t3` with `dir`:
/// `out[a][b] = sum_c t3[a][b][c] * dir[c]`.
pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
    std::array::from_fn(|a| {
        std::array::from_fn(|b| {
            // Explicit fold from 0.0 keeps the left-to-right summation order.
            self.t3[a][b]
                .iter()
                .zip(dir)
                .fold(0.0, |acc, (entry, weight)| acc + entry * weight)
        })
    })
}
/// Contracts the last two indices of `t4` with `u` and `w`:
/// `out[i][j] = sum_{k,l} t4[i][j][k][l] * u[k] * w[l]`.
pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
    std::array::from_fn(|i| {
        std::array::from_fn(|j| {
            let mut total = 0.0;
            for (slab, &uk) in self.t4[i][j].iter().zip(u) {
                for (&entry, &wl) in slab.iter().zip(w) {
                    total += entry * uk * wl;
                }
            }
            total
        })
    })
}
}
impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
#[inline]
fn derivative(&self, labels: &[usize]) -> f64 {
self.deriv(labels)
}
fn map_derivatives<F>(&self, mut f: F) -> Self
where
F: FnMut(&[usize]) -> f64,
{
let mut out = Self::zero();
out.v = f(&[]);
for i in 0..K {
let labels = [i];
out.g[i] = f(&labels);
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = f(&labels);
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
let labels = [i, j, k];
out.t3[i][j][k] = f(&labels);
}
}
}
for i in 0..K {
for j in 0..K {
for k in 0..K {
for l in 0..K {
let labels = [i, j, k, l];
out.t4[i][j][k][l] = f(&labels);
}
}
}
}
out
}
}
/// Value and partial derivatives, up to second order, of a scalar quantity in
/// `K` variables, stored densely.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Zeroth-order entry: the value itself.
    pub v: f64,
    /// First-order entries, indexed `[i]`.
    pub g: [f64; K],
    /// Second-order entries, indexed `[i][j]`.
    pub h: [[f64; K]; K],
}
impl<const K: usize> Tower2<K> {
/// Returns the tower whose value and every derivative entry are `0.0`.
pub fn zero() -> Self {
    let g = [0.0_f64; K];
    let h = [g; K];
    Self { v: 0.0, g, h }
}
pub fn constant(c: f64) -> Self {
let mut out = Self::zero();
out.v = c;
out
}
pub fn variable(value: f64, idx: usize) -> Self {
let mut out = Self::constant(value);
out.g[idx] = 1.0;
out
}
/// Looks up the stored entry addressed by `labels`: no labels select `v`,
/// one label selects `g`, two select `h`.
///
/// # Panics
///
/// Panics if more than two labels are given, or if any label is `>= K`.
#[inline]
fn deriv(&self, labels: &[usize]) -> f64 {
    match *labels {
        [] => self.v,
        [i] => self.g[i],
        [i, j] => self.h[i][j],
        _ => panic!("Tower2 carries at most second-order derivatives"),
    }
}
pub fn mul(&self, o: &Self) -> Self {
let a = self;
let b = o;
let mut out = Self::zero();
out.v = a.v * b.v;
for i in 0..K {
out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
}
for i in 0..K {
for j in 0..K {
out.h[i][j] = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
}
}
out
}
pub fn compose_unary(&self, d: [f64; 3]) -> Self {
<Self as jet_algebra::JetAlgebra<3>>::compose_unary(self, d)
}
pub fn scale(&self, s: f64) -> Self {
let mut out = *self;
out.v *= s;
for i in 0..K {
out.g[i] *= s;
for j in 0..K {
out.h[i][j] *= s;
}
}
out
}
/// Composes the tower with the square root.
///
/// With `u = self.v` and `s = sqrt(u)`, the stack passed on is
/// `[s, 1/(2s), -1/(4 u s)]`.
pub fn sqrt(&self) -> Self {
    let u = self.v;
    let root = u.sqrt();
    let d1 = 0.5 / root;
    let d2 = -0.25 / (u * root);
    self.compose_unary([root, d1, d2])
}
}
impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
#[inline]
fn derivative(&self, labels: &[usize]) -> f64 {
self.deriv(labels)
}
fn map_derivatives<F>(&self, mut f: F) -> Self
where
F: FnMut(&[usize]) -> f64,
{
let mut out = Self::zero();
out.v = f(&[]);
for i in 0..K {
let labels = [i];
out.g[i] = f(&labels);
}
for i in 0..K {
for j in 0..K {
let labels = [i, j];
out.h[i][j] = f(&labels);
}
}
out
}
}
/// Entry-wise sum of two towers.
impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;

    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        for (dst, src) in sum.g.iter_mut().zip(o.g.iter()) {
            *dst += *src;
        }
        for (dst, src) in sum.h.iter_mut().flatten().zip(o.h.iter().flatten()) {
            *dst += *src;
        }
        sum
    }
}
/// `a * b` on owned towers; forwards to the inherent `Tower2::mul`.
impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;

    fn mul(self, o: Self) -> Self {
        let (lhs, rhs) = (&self, &o);
        // Path call so the inherent by-reference `mul` is selected.
        Tower2::<K>::mul(lhs, rhs)
    }
}
/// Adds a scalar to the value; derivative entries are unchanged.
impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;

    fn add(self, c: f64) -> Self {
        Self {
            v: self.v + c,
            ..self
        }
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
type Output = Self;
fn mul(self, c: f64) -> Self {
self.scale(c)
}
}
/// Stack `[ln_gamma(x), digamma(x), polygamma(1, x), polygamma(2, x), polygamma(3, x)]`.
///
/// The first entry comes from `statrs`; the rest come from this file's
/// `digamma_positive` / `polygamma_positive`, which return NaN unless `x` is
/// finite and positive.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    let value = statrs::function::gamma::ln_gamma(x);
    let first = digamma_positive(x);
    let second = polygamma_positive(1, x);
    let third = polygamma_positive(2, x);
    let fourth = polygamma_positive(3, x);
    [value, first, second, third, fourth]
}
/// Stack `[digamma(x), polygamma(1, x), ..., polygamma(4, x)]`.
///
/// Every entry is NaN unless `x` is finite and positive.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    std::array::from_fn(|slot| match slot {
        0 => digamma_positive(x),
        order => polygamma_positive(order, x),
    })
}
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
[
polygamma_positive(1, x),
polygamma_positive(2, x),
polygamma_positive(3, x),
polygamma_positive(4, x),
polygamma_positive(5, x),
]
}
/// Digamma for finite `x > 0`; returns NaN for any other input.
///
/// Arguments below `POLYGAMMA_ASYMPTOTIC_MIN_X` are shifted upward with
/// `psi(x) = psi(x + 1) - 1/x`, then `digamma_asymptotic` is evaluated at the
/// shifted argument.
fn digamma_positive(mut x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shift_sum = 0.0;
    loop {
        if x >= POLYGAMMA_ASYMPTOTIC_MIN_X {
            break;
        }
        shift_sum -= 1.0 / x;
        x += 1.0;
    }
    shift_sum + digamma_asymptotic(x)
}
/// Polygamma of the given `order` for finite `x > 0`; returns NaN for any
/// other `x`.
///
/// Arguments below `POLYGAMMA_ASYMPTOTIC_MIN_X` are shifted upward one unit at
/// a time, accumulating `polygamma_recurrence_term`, and
/// `polygamma_asymptotic` is evaluated at the shifted argument. That helper
/// returns NaN for orders outside `1..=5`, so the result is NaN for those too.
fn polygamma_positive(order: usize, mut x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return f64::NAN;
    }
    let mut shift_sum = 0.0;
    loop {
        if x >= POLYGAMMA_ASYMPTOTIC_MIN_X {
            break;
        }
        shift_sum += polygamma_recurrence_term(order, x);
        x += 1.0;
    }
    shift_sum + polygamma_asymptotic(order, x)
}
/// Smallest argument at which the asymptotic series are evaluated directly;
/// smaller arguments are first shifted up by the recurrence.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even-index Bernoulli numbers as `(index, value)` pairs, `B_2` through `B_20`.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
/// Single-step shift term `(-1)^(order + 1) * order! / x^(order + 1)`, i.e.
/// `psi^(n)(x) - psi^(n)(x + 1)`.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let exponent = (order + 1) as i32;
    let sign = match order % 2 {
        1 => 1.0,
        _ => -1.0,
    };
    sign * factorial(order) / x.powi(exponent)
}
/// Asymptotic series `ln(x) - 1/(2x) - sum_k B_2k / (2k * x^(2k))`, truncated
/// to the terms listed in `BERNOULLI_EVEN`.
fn digamma_asymptotic(x: f64) -> f64 {
    let leading = x.ln() - 0.5 / x;
    BERNOULLI_EVEN.iter().fold(leading, |acc, &(two_k, b)| {
        acc - b / (two_k as f64 * x.powi(two_k as i32))
    })
}
/// Asymptotic series for the polygamma function of order `n = order`:
///
/// `(-1)^(n+1) * [ (n-1)!/x^n + n!/(2 x^(n+1))
///                 + sum_k B_2k * (2k)(2k+1)...(2k+n-1) / (2k) / x^(2k+n) ]`,
///
/// truncated to the terms listed in `BERNOULLI_EVEN`. Returns NaN for orders
/// outside `1..=5`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if order == 0 || order > 5 {
        return f64::NAN;
    }
    // One overall sign applies to the leading terms and the Bernoulli tail.
    let sign = if order % 2 == 0 { -1.0 } else { 1.0 };
    let n_factorial = factorial(order);
    let leading = sign * factorial(order - 1) / x.powi(order as i32)
        + sign * n_factorial / (2.0 * x.powi((order + 1) as i32));
    BERNOULLI_EVEN.iter().fold(leading, |acc, &(two_k, b)| {
        let rising = rising_factorial(two_k, order);
        acc + sign * b * rising / two_k as f64 / x.powi((two_k + order) as i32)
    })
}
/// `n!` as an `f64`; `factorial(0)` is `1.0`.
fn factorial(n: usize) -> f64 {
    let mut product = 1.0;
    for k in 1..=n {
        product *= k as f64;
    }
    product
}
/// Product of the `len` consecutive integers `start, start + 1, ...` as an
/// `f64`; an empty product (`len == 0`) is `1.0`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    let mut product = 1.0;
    for k in start..start + len {
        product *= k as f64;
    }
    product
}
/// Entry-wise sum of two towers.
impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;

    fn add(self, o: Self) -> Self {
        let mut sum = self;
        sum.v += o.v;
        for (dst, src) in sum.g.iter_mut().zip(o.g.iter()) {
            *dst += *src;
        }
        for (dst, src) in sum.h.iter_mut().flatten().zip(o.h.iter().flatten()) {
            *dst += *src;
        }
        for (dst, src) in sum
            .t3
            .iter_mut()
            .flatten()
            .flatten()
            .zip(o.t3.iter().flatten().flatten())
        {
            *dst += *src;
        }
        for (dst, src) in sum
            .t4
            .iter_mut()
            .flatten()
            .flatten()
            .flatten()
            .zip(o.t4.iter().flatten().flatten().flatten())
        {
            *dst += *src;
        }
        sum
    }
}
/// Entry-wise difference, computed as `self + o.scale(-1.0)`.
impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;

    fn sub(self, o: Self) -> Self {
        let negated = o.scale(-1.0);
        self + negated
    }
}
impl<const K: usize> std::ops::Neg for Tower4<K> {
type Output = Self;
fn neg(self) -> Self {
self.scale(-1.0)
}
}
/// `a * b` on owned towers; forwards to the inherent `Tower4::mul`.
impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;

    fn mul(self, o: Self) -> Self {
        let (lhs, rhs) = (&self, &o);
        // Path call so the inherent by-reference `mul` is selected.
        Tower4::<K>::mul(lhs, rhs)
    }
}
/// `a / b`, computed as `a * b.recip()`.
impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;

    fn div(self, o: Self) -> Self {
        let inverse = o.recip();
        Tower4::<K>::mul(&self, &inverse)
    }
}
/// Adds a scalar to the value; derivative entries are unchanged.
impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;

    fn add(self, c: f64) -> Self {
        Self {
            v: self.v + c,
            ..self
        }
    }
}
/// Subtracts a scalar from the value by adding its negation; derivative
/// entries are unchanged.
impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;

    fn sub(self, c: f64) -> Self {
        let negated = -c;
        self + negated
    }
}
impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
type Output = Self;
fn mul(self, c: f64) -> Self {
self.scale(c)
}
}
/// A per-row scalar program in `K` primary variables that can be evaluated on
/// `Tower4` inputs.
///
/// `evaluate_program` seeds one tower variable per primary and hands them to
/// `row_nll`, so the returned tower carries the program's derivatives with
/// respect to the primaries.
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of rows the program exposes.
    fn n_rows(&self) -> usize;

    /// Values of the `K` primaries at `row`, or an error message.
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// Evaluates the row's scalar on tower-valued primaries `p`, or returns an
    /// error message.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
/// Evaluates `prog` at `row` with every primary seeded as an independent tower
/// variable, returning the resulting tower.
///
/// # Errors
///
/// Propagates any error from `primaries` or `row_nll`.
pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let values = match prog.primaries(row) {
        Ok(values) => values,
        Err(msg) => return Err(msg),
    };
    // Variable `idx` gets value `values[idx]` and a unit first derivative.
    let mut seeds = [Tower4::<K>::zero(); K];
    for (idx, (seed, &value)) in seeds.iter_mut().zip(values.iter()).enumerate() {
        *seed = Tower4::variable(value, idx);
    }
    prog.row_nll(row, &seeds)
}
/// Evaluates `prog` at `row` and returns the tower's `(v, g, h)` entries.
///
/// # Errors
///
/// Propagates any error from `evaluate_program`.
pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    evaluate_program::<K, P>(prog, row).map(|tower| (tower.v, tower.g, tower.h))
}
/// Evaluates `prog` at `row` and contracts the third-order entries with `dir`
/// (see `Tower4::third_contracted`).
///
/// # Errors
///
/// Propagates any error from `evaluate_program`.
pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    evaluate_program::<K, P>(prog, row).map(|tower| tower.third_contracted(dir))
}
/// Evaluates `prog` at `row` and contracts the fourth-order entries with
/// `dir_u` and `dir_v` (see `Tower4::fourth_contracted`).
///
/// # Errors
///
/// Propagates any error from `evaluate_program`.
pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    evaluate_program::<K, P>(prog, row).map(|tower| tower.fourth_contracted(dir_u, dir_v))
}
/// Claimed kernel outputs to be compared against a `Tower4` by
/// `verify_kernel_channels`.
pub struct KernelChannels<const K: usize> {
    /// Claimed value, compared with the tower's `v`.
    pub value: f64,
    /// Claimed gradient, compared with the tower's `g`.
    pub gradient: [f64; K],
    /// Claimed Hessian, compared with the tower's `h`.
    pub hessian: [[f64; K]; K],
    /// `(dir, claim)` pairs; each `claim` is compared with
    /// `tower.third_contracted(dir)`.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// `(u, w, claim)` triples; each `claim` is compared with
    /// `tower.fourth_contracted(u, w)`.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
/// Checks every channel in `claims` against the corresponding entries of
/// `tower`, stopping at the first disagreement.
///
/// A claim agrees with the tower entry `truth` when it equals it exactly, or
/// when `|claim - truth|` is finite and at most `rel_tol * scale`, where
/// `scale = max(|truth|, floor, 1e-300)`. The floor is `1.0` for the value and
/// the largest absolute entry of the same channel otherwise, so near-zero
/// entries are judged relative to the channel's magnitude.
///
/// NaN never agrees: a NaN claim or a NaN tower entry is reported as a
/// disagreement instead of slipping through a comparison that is false for
/// NaN. Likewise an infinite tower entry only accepts an identical claim.
///
/// # Errors
///
/// Returns a message naming the first disagreeing channel entry (for example
/// `third[0][0][1]`) together with the claimed and tower values.
pub fn verify_kernel_channels<const K: usize>(
    tower: &Tower4<K>,
    claims: &KernelChannels<K>,
    rel_tol: f64,
) -> Result<(), String> {
    let check = |label: &str, claim: f64, truth: f64, floor: f64| -> Result<(), String> {
        let scale = truth.abs().max(floor).max(1e-300);
        let diff = (claim - truth).abs();
        // Phrased as "agrees" so that NaN (all comparisons false) and
        // non-finite differences are rejected rather than accepted.
        let agrees = claim == truth || (diff.is_finite() && diff <= rel_tol * scale);
        if !agrees {
            return Err(format!(
                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, scale {scale:.3e})"
            ));
        }
        Ok(())
    };

    check("value", claims.value, tower.v, 1.0)?;

    let g_floor = tower.g.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
    for a in 0..K {
        check(
            &format!("gradient[{a}]"),
            claims.gradient[a],
            tower.g[a],
            g_floor,
        )?;
    }

    let h_floor = tower
        .h
        .iter()
        .flatten()
        .fold(0.0_f64, |m, x| m.max(x.abs()));
    for a in 0..K {
        for b in 0..K {
            check(
                &format!("hessian[{a}][{b}]"),
                claims.hessian[a][b],
                tower.h[a][b],
                h_floor,
            )?;
        }
    }

    // Each contracted claim is compared against the tower contracted with the
    // same direction(s), using that contraction's own largest entry as floor.
    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
        let truth = tower.third_contracted(dir);
        let floor = truth.iter().flatten().fold(0.0_f64, |m, x| m.max(x.abs()));
        for a in 0..K {
            for b in 0..K {
                check(
                    &format!("third[{t_idx}][{a}][{b}]"),
                    claim[a][b],
                    truth[a][b],
                    floor,
                )?;
            }
        }
    }

    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
        let truth = tower.fourth_contracted(u, w);
        let floor = truth.iter().flatten().fold(0.0_f64, |m, x| m.max(x.abs()));
        for a in 0..K {
            for b in 0..K {
                check(
                    &format!("fourth[{f_idx}][{a}][{b}]"),
                    claim[a][b],
                    truth[a][b],
                    floor,
                )?;
            }
        }
    }

    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
/// Single-primary test program; row `i` uses `eta[i]` as its primary and
/// `y[i]` as the coefficient of the linear term.
struct LogitProgram {
    eta: Vec<f64>,
    y: Vec<f64>,
}
/// Row scalar `ln(exp(eta) + 1) - eta * y[row]`.
impl RowNllProgram<1> for LogitProgram {
    fn n_rows(&self) -> usize {
        self.eta.len()
    }

    fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
        let eta = self.eta[row];
        Ok([eta])
    }

    fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
        let [eta] = *p;
        let softplus = (eta.exp() + 1.0).ln();
        let linear = eta * self.y[row];
        Ok(softplus - linear)
    }
}
/// The tower of the logit program must match the closed-form value and first
/// four derivatives in `eta`, expressed through `mu = 1 / (1 + exp(-eta))`.
#[test]
fn logit_tower_matches_closed_forms() {
    let prog = LogitProgram {
        eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
        y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
    };
    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("logit program");
        let eta = prog.eta[row];
        let y = prog.y[row];
        let mu = 1.0 / (1.0 + (-eta).exp());
        let w = mu * (1.0 - mu);
        let check = |label: &str, got: f64, want: f64| {
            assert!(
                (got - want).abs() <= 1e-12 * want.abs().max(1.0),
                "row {row} {label}: got {got:+.15e} want {want:+.15e}"
            );
        };
        check("value", t.v, (1.0 + eta.exp()).ln() - y * eta);
        check("grad", t.g[0], mu - y);
        check("hess", t.h[0][0], w);
        check("third", t.t3[0][0][0], w * (1.0 - 2.0 * mu));
        check(
            "fourth",
            t.t4[0][0][0][0],
            w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
        );
    }
}
/// Panics unless `|got - want| <= rel_tol * max(|want|, 1)`; the panic message
/// starts with `label`.
fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
    let diff = (got - want).abs();
    let allowed = rel_tol * want.abs().max(1.0);
    assert!(
        diff <= allowed,
        "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
    );
}
/// Digamma and trigamma entries of all three stacks must match reference
/// values at small, half-integer, unit, and large arguments.
#[test]
fn gamma_special_function_stacks_match_reference_values() {
    const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
    let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
    // (label, x, digamma reference, trigamma reference)
    let cases = [
        (
            "x=0.1",
            0.1,
            -10.423_754_940_411_076,
            101.433_299_150_792_75,
        ),
        (
            "x=0.5",
            0.5,
            -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
            pi_sq / 2.0,
        ),
        ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
        (
            "x=2.5",
            2.5,
            -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
            pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
        ),
        (
            "x=50",
            50.0,
            3.901_989_673_427_892,
            0.020_201_333_226_697_128,
        ),
    ];
    for (label, x, digamma_ref, trigamma_ref) in cases {
        let from_ln_gamma = ln_gamma_derivative_stack(x);
        let from_digamma = digamma_derivative_stack(x);
        let from_trigamma = trigamma_derivative_stack(x);
        // The same function shows up at different offsets in each stack.
        let checks = [
            ("ln_gamma_stack digamma", from_ln_gamma[1], digamma_ref),
            ("digamma value", from_digamma[0], digamma_ref),
            ("ln_gamma_stack trigamma", from_ln_gamma[2], trigamma_ref),
            ("digamma_stack trigamma", from_digamma[1], trigamma_ref),
            ("trigamma value", from_trigamma[0], trigamma_ref),
        ];
        for (what, got, want) in checks {
            assert_close(&format!("{label} {what}"), got, want, 1e-13);
        }
    }
}
/// Checks `psi(x + 1) = psi(x) + 1/x` and `psi'(x + 1) = psi'(x) - 1/x^2` on
/// the leading entries of the digamma and trigamma stacks.
#[test]
fn gamma_special_function_stacks_obey_recurrences() {
    for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
        let next = x + 1.0;
        let digamma_x = digamma_derivative_stack(x)[0];
        let digamma_next = digamma_derivative_stack(next)[0];
        let trigamma_x = trigamma_derivative_stack(x)[0];
        let trigamma_next = trigamma_derivative_stack(next)[0];

        let digamma_want = digamma_x + 1.0 / x;
        assert_close(
            &format!("digamma recurrence x={x}"),
            digamma_next,
            digamma_want,
            1e-13,
        );

        let trigamma_want = trigamma_x - 1.0 / (x * x);
        assert_close(
            &format!("trigamma recurrence x={x}"),
            trigamma_next,
            trigamma_want,
            1e-13,
        );
    }
}
/// Two-primary test program; row `i` uses `[eta[i], s[i]]` as its primaries
/// and `y[i]` as the observation.
struct LocScaleProgram {
    eta: Vec<f64>,
    s: Vec<f64>,
    y: Vec<f64>,
}
/// Row scalar `s + exp(-2 s) * r^2 / 2` with `r = -(eta - y[row])`, where
/// `eta = p[0]` and `s = p[1]`.
impl RowNllProgram<2> for LocScaleProgram {
    fn n_rows(&self) -> usize {
        self.eta.len()
    }

    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
        let pair = [self.eta[row], self.s[row]];
        Ok(pair)
    }

    fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
        let [eta, s] = *p;
        let r = -(eta - self.y[row]);
        let weight = (s * (-2.0)).exp();
        let quadratic = weight * r * r * 0.5;
        Ok(s + quadratic)
    }
}
/// Every gradient, Hessian, third- and fourth-order entry of the loc-scale
/// tower must match its closed form, including the mixed `eta`/`s` entries;
/// `derived_third_contracted` must agree with a manual contraction of `t3`.
#[test]
fn locscale_tower_matches_closed_forms_including_cross_blocks() {
    let prog = LocScaleProgram {
        eta: vec![0.3, -1.1, 2.0],
        s: vec![-0.5, 0.2, 0.8],
        y: vec![1.0, -2.0, 2.5],
    };
    let tol = 1e-12;
    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("locscale program");
        let r = prog.y[row] - prog.eta[row];
        let w = (-2.0 * prog.s[row]).exp();
        let grad_want = [-w * r, 1.0 - w * r * r];
        let hess_want = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
        // Third- and fourth-order closed forms depend only on the sum of the
        // indices, i.e. on how many of them select primary 1 (`s`).
        let third_want = |index_sum: usize| -> f64 {
            match index_sum {
                0 => 0.0,
                1 => -2.0 * w,
                2 => -4.0 * w * r,
                _ => -4.0 * w * r * r,
            }
        };
        let fourth_want = |index_sum: usize| -> f64 {
            match index_sum {
                0 | 1 => 0.0,
                2 => 4.0 * w,
                3 => 8.0 * w * r,
                _ => 8.0 * w * r * r,
            }
        };
        let hess_bound = tol * w.max(1.0) * (1.0 + r.abs());
        let t3_bound = tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r);
        let t4_bound = tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r);
        for a in 0..2 {
            assert!(
                (t.g[a] - grad_want[a]).abs() <= tol * grad_want[a].abs().max(1.0),
                "row {row} grad[{a}]"
            );
            for b in 0..2 {
                assert!(
                    (t.h[a][b] - hess_want[a][b]).abs() <= hess_bound,
                    "row {row} hess[{a}][{b}]: got {} want {}",
                    t.h[a][b],
                    hess_want[a][b]
                );
                for c in 0..2 {
                    let want3 = third_want(a + b + c);
                    assert!(
                        (t.t3[a][b][c] - want3).abs() <= t3_bound,
                        "row {row} t3[{a}][{b}][{c}]: got {} want {}",
                        t.t3[a][b][c],
                        want3
                    );
                    for d in 0..2 {
                        let want4 = fourth_want(a + b + c + d);
                        assert!(
                            (t.t4[a][b][c][d] - want4).abs() <= t4_bound,
                            "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
                            t.t4[a][b][c][d],
                            want4
                        );
                    }
                }
            }
        }
        let dir = [0.7, -1.3];
        let third = derived_third_contracted(&prog, row, &dir).expect("third");
        for (a, third_row) in third.iter().enumerate() {
            for (b, &got) in third_row.iter().enumerate() {
                let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
                assert!((got - want).abs() <= 1e-13 * want.abs().max(1.0));
            }
        }
    }
}
/// Three-primary test program; row `i` uses `primaries[i]` as its primaries
/// and `tau[i]` as an additive constant inside the logarithm.
struct GnarlyProgram {
    primaries: Vec<[f64; 3]>,
    tau: Vec<f64>,
}
impl GnarlyProgram {
    /// Fixed three-row fixture shared by the tests in this module.
    fn fixture() -> Self {
        let primaries = vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]];
        let tau = vec![0.15, -0.35, 0.5];
        Self { primaries, tau }
    }
}
/// Row scalar `ln(exp(x y) + sqrt(z^2 + 1) + tau) / (y/2 + 2)^1.7 + (x - z)^2 / 4`
/// with `[x, y, z] = p`; it chains `mul`, `exp`, `sqrt`, `ln`, `powf` and
/// division.
impl RowNllProgram<3> for GnarlyProgram {
    fn n_rows(&self) -> usize {
        self.primaries.len()
    }

    fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
        match self.primaries.get(row) {
            Some(values) => Ok(*values),
            None => Err(format!("gnarly: row {row} out of range")),
        }
    }

    fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
        let tau = match self.tau.get(row) {
            Some(&tau) => tau,
            None => return Err(format!("gnarly: tau row {row} out of range")),
        };
        let [x, y, z] = *p;
        let a = (x * y).exp();
        let b = (z * z + 1.0).sqrt();
        let c = (a + b + tau).ln();
        let d = (y * 0.5 + 2.0).powf(1.7);
        let gap = x - z;
        Ok(c / d + gap * gap * 0.25)
    }
}
/// Evaluates `prog`'s row `row` at the arbitrary point `p` instead of the
/// row's stored primaries.
///
/// A one-row wrapper program reports `p` as its primaries and forwards
/// `row_nll` to `prog` with the original `row`.
fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
    struct Pinned<'a> {
        inner: &'a GnarlyProgram,
        inner_row: usize,
        point: [f64; 3],
    }

    impl RowNllProgram<3> for Pinned<'_> {
        fn n_rows(&self) -> usize {
            1
        }

        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
            match row {
                0 => Ok(self.point),
                _ => Err(format!("gnarly-at: row {row} out of range")),
            }
        }

        fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
            match eval_row {
                0 => self.inner.row_nll(self.inner_row, vars),
                _ => Err(format!("gnarly-at: eval row {eval_row} out of range")),
            }
        }
    }

    let pinned = Pinned {
        inner: prog,
        inner_row: row,
        point: p,
    };
    evaluate_program(&pinned, 0).expect("gnarly tower")
}
// Finite-difference consistency check: each derivative order of the tower is
// compared with a central difference of the next-lower order, obtained by
// re-evaluating the tower at points shifted by +/- h_step along one primary.
#[test]
fn gnarly_tower_is_fd_consistent_order_by_order() {
let prog = GnarlyProgram::fixture();
for row in 0..prog.n_rows() {
let base = prog.primaries(row).expect("primaries");
let t = gnarly_tower_at(&prog, row, base);
let h_step = 1e-5;
let tol = 1e-6;
// `c` is the primary being perturbed; it becomes the LAST index of every
// analytic entry compared below.
for c in 0..3 {
let mut up = base;
let mut dn = base;
up[c] += h_step;
dn[c] -= h_step;
let t_up = gnarly_tower_at(&prog, row, up);
let t_dn = gnarly_tower_at(&prog, row, dn);
// Value difference -> gradient entry.
let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
assert!(
(t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
"grad[{c}]: analytic {} fd {}",
t.g[c],
fd_g
);
for a in 0..3 {
// Gradient difference -> Hessian entry [a][c].
let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
assert!(
(t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
"hess[{a}][{c}]: analytic {} fd {}",
t.h[a][c],
fd_h
);
for b in 0..3 {
// Hessian difference -> third-order entry [a][b][c].
let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
assert!(
(t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
"t3[{a}][{b}][{c}]: analytic {} fd {}",
t.t3[a][b][c],
fd_t3
);
for d in 0..3 {
// Third-order difference -> fourth-order entry [a][b][d][c].
let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
assert!(
(t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
"t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
t.t4[a][b][d][c],
fd_t4
);
}
}
}
}
}
}
/// `verify_kernel_channels` must accept channels copied from the tower and
/// must reject, by name, a contracted third-order entry whose sign is flipped.
#[test]
fn oracle_catches_planted_cross_block_sign_flip() {
    let prog = LocScaleProgram {
        eta: vec![0.3],
        s: vec![-0.5],
        y: vec![1.0],
    };
    let t = evaluate_program(&prog, 0).expect("tower");
    let dir = [0.6, -0.2];
    let second_dir = [1.0, 0.5];
    let honest_third = t.third_contracted(&dir);
    let honest_fourth = t.fourth_contracted(&dir, &second_dir);

    let honest = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, honest_third)],
        fourth: vec![(dir, second_dir, honest_fourth)],
    };
    verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

    // Flip one off-diagonal (cross-block) entry of the contracted third order.
    let mut tampered_third = honest_third;
    tampered_third[0][1] = -tampered_third[0][1];
    let flipped = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, tampered_third)],
        fourth: Vec::new(),
    };
    let err = verify_kernel_channels(&t, &flipped, 1e-10)
        .expect_err("planted sign flip must be caught");
    assert!(
        err.contains("third[0][0][1]"),
        "oracle must name the flipped channel, got: {err}"
    );
}
// Symmetry check: every third- and fourth-order entry of the gnarly tower
// must be unchanged (up to 1e-12 relative to the largest entry of that order)
// under every permutation of its indices.
#[test]
fn t3_t4_are_fully_index_symmetric() {
let prog = GnarlyProgram::fixture();
// All 3! orderings of three index positions.
let perms3: [[usize; 3]; 6] = [
[0, 1, 2],
[0, 2, 1],
[1, 0, 2],
[1, 2, 0],
[2, 0, 1],
[2, 1, 0],
];
// All 4! orderings of four index positions.
let perms4: [[usize; 4]; 24] = [
[0, 1, 2, 3],
[0, 1, 3, 2],
[0, 2, 1, 3],
[0, 2, 3, 1],
[0, 3, 1, 2],
[0, 3, 2, 1],
[1, 0, 2, 3],
[1, 0, 3, 2],
[1, 2, 0, 3],
[1, 2, 3, 0],
[1, 3, 0, 2],
[1, 3, 2, 0],
[2, 0, 1, 3],
[2, 0, 3, 1],
[2, 1, 0, 3],
[2, 1, 3, 0],
[2, 3, 0, 1],
[2, 3, 1, 0],
[3, 0, 1, 2],
[3, 0, 2, 1],
[3, 1, 0, 2],
[3, 1, 2, 0],
[3, 2, 0, 1],
[3, 2, 1, 0],
];
for row in 0..prog.n_rows() {
let t = evaluate_program(&prog, row).expect("gnarly tower");
// Comparison scales: largest absolute entry of each order, at least 1.0.
let scale_t3 =
t.t3.iter()
.flatten()
.flatten()
.fold(0.0_f64, |m, x| m.max(x.abs()))
.max(1.0);
let scale_t4 =
t.t4.iter()
.flatten()
.flatten()
.flatten()
.fold(0.0_f64, |m, x| m.max(x.abs()))
.max(1.0);
for i in 0..3 {
for j in 0..3 {
for k in 0..3 {
let base = t.t3[i][j][k];
let idx = [i, j, k];
// `p` reorders the positions of `idx`; the permuted entry must equal `base`.
for p in &perms3 {
let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
assert!(
(base - permed).abs() <= 1e-12 * scale_t3,
"row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
permuted {permed:+.15e} under {p:?}"
);
}
for l in 0..3 {
let base4 = t.t4[i][j][k][l];
let idx4 = [i, j, k, l];
for p in &perms4 {
let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
assert!(
(base4 - permed).abs() <= 1e-12 * scale_t4,
"row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
permuted {permed:+.15e} under {p:?}"
);
}
}
}
}
}
}
}
}
/// Five-entry stack built from the pair returned by
/// `crate::probability::signed_probit_logcdf_and_mills_ratio(x)`.
///
/// Entry 0 is the first returned component and entry 1 the second (`lambda`);
/// entries 2..=4 are polynomials in `x` and `lambda`, starting with
/// `-lambda * (x + lambda)`.
///
/// NOTE(review): presumably the pair is `(ln Phi(x), phi(x) / Phi(x))`, which
/// makes the entries the value and first four derivatives of the standard
/// normal log-CDF — confirm against `crate::probability`.
#[inline]
pub(crate) fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    let (value, lam) = crate::probability::signed_probit_logcdf_and_mills_ratio(x);
    let lam_sq = lam * lam;
    let lam_cu = lam_sq * lam;
    let x_sq = x * x;
    let d1 = lam;
    let d2 = -lam * (x + lam);
    let d3 = lam * (x_sq - 1.0 + 3.0 * x * lam + 2.0 * lam_sq);
    let d4 = -lam
        * ((x * x_sq - 3.0 * x) + (7.0 * x_sq - 4.0) * lam + 12.0 * x * lam_sq + 6.0 * lam_cu);
    [value, d1, d2, d3, d4]
}
/// Five-entry stack whose first entry is
/// `crate::probability::log1mexp_positive(x)` and whose remaining entries are
/// polynomials in `r = 1 / (exp(x) - 1)`:
/// `[r, -r(1+r), r(1+r)(1+2r), -r(1+r)(1+6r+6r^2)]`.
///
/// NOTE(review): these are the first four derivatives of `ln(1 - exp(-x))`;
/// that the helper computes that function is an assumption — confirm against
/// `crate::probability`.
#[inline]
pub(crate) fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    let value = crate::probability::log1mexp_positive(x);
    let d1 = r;
    let d2 = -r * (1.0 + r);
    let d3 = r * (1.0 + r) * (1.0 + 2.0 * r);
    let d4 = -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r);
    [value, d1, d2, d3, d4]
}