use crate::jet_scalar::JetScalar;
use crate::jet_tower::{
Tower4, digamma_derivative_stack, generic_fourth_contracted, generic_full_tower,
generic_row_kernel, generic_third_contracted, ln_gamma_derivative_stack,
};
/// Tiny deterministic 64-bit linear congruential generator used to build reproducible
/// test fixtures. The wrapped `u64` is the full generator state.
struct Lcg(u64);

impl Lcg {
    /// LCG multiplier (Knuth's MMIX constant).
    const MULTIPLIER: u64 = 6364136223846793005;
    /// LCG increment (Knuth's MMIX constant).
    const INCREMENT: u64 = 1442695040888963407;
    /// 2^53, the number of distinct values the top 53 state bits can take.
    const TWO_POW_53: f64 = 9_007_199_254_740_992.0;

    /// Advances the state once and returns a draw in `[0, 1)` built from the top 53 bits
    /// of the new state (so every result is exactly representable as an `f64`).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        let top_bits = self.0 >> 11;
        top_bits as f64 / Self::TWO_POW_53
    }

    /// Maps a single `next_f64` draw affinely onto the interval from `lo` towards `hi`.
    fn uniform(&mut self, lo: f64, hi: f64) -> f64 {
        let width = hi - lo;
        lo + width * self.next_f64()
    }
}
/// One synthetic Gamma observation together with the primary parameters at which its
/// negative log-likelihood is evaluated.
///
/// The NLL code in this file uses shape `exp(-2·p1)` and log-scale `p0 + 2·p1`, so the
/// implied mean `shape · exp(log-scale)` equals `exp(p0)`.
#[derive(Clone, Copy, Debug)]
struct GammaRow {
/// Observed response. Its natural log is taken in the NLL, so it should be positive
/// for a finite value (the fixtures draw it uniformly from 0.5 to 3.0).
y: f64,
/// First primary parameter (log of the implied mean).
p0: f64,
/// Second primary parameter; the Gamma shape is `exp(-2·p1)`.
p1: f64,
}
/// Row program over a batch of Gamma observations: program row `i` is `rows[i]`,
/// and the program has exactly `rows.len()` rows.
struct GammaLocScaleRow {
/// Per-row observation and primary parameters.
rows: Vec<GammaRow>,
}
/// Each row contributes the Gamma negative log-likelihood of its observation `y` with
/// shape `a = exp(-2·p1)` and log-scale `q = p0 + 2·p1`:
///
/// `NLL = -a·ln y + ln y + y·exp(-q) + a·q + lnΓ(a)`.
impl crate::jet_tower::RowNllProgramGeneric<2> for GammaLocScaleRow {
    fn n_rows(&self) -> usize {
        self.rows.len()
    }

    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
        self.rows
            .get(row)
            .map(|r| [r.p0, r.p1])
            .ok_or_else(|| format!("GammaLocScaleRow: row {row} out of range"))
    }

    fn row_nll_generic<S: JetScalar<2>>(&self, row: usize, p: &[S; 2]) -> Result<S, String> {
        let obs = match self.rows.get(row) {
            Some(r) => r,
            None => return Err(format!("GammaLocScaleRow: row {row} out of range")),
        };
        let [log_mean, log_disp] = p;
        let log_y = obs.y.ln();

        // Shape a = exp(-2·p1) and log-scale q = p0 + 2·p1.
        let shape = log_disp.scale(-2.0).exp();
        let log_scale = log_mean.add(&log_disp.scale(2.0));

        let y_over_scale = S::constant(obs.y).mul(&log_scale.neg().exp());
        let shape_log_scale = shape.mul(&log_scale);
        let ln_gamma_shape = shape.ln_gamma();

        // Summed in the same order as the plain-f64 reference:
        // -a·ln y, ln y, y·exp(-q), a·q, lnΓ(a).
        Ok(shape
            .scale(-log_y)
            .add(&S::constant(log_y))
            .add(&y_over_scale)
            .add(&shape_log_scale)
            .add(&ln_gamma_shape))
    }
}
/// Plain-`f64` reference for the per-row Gamma NLL. The finite-difference oracle samples
/// this function.
///
/// With `shape = exp(-2·p1)` and `log_scale = p0 + 2·p1` it returns
/// `-shape·ln y + ln y + y·exp(-log_scale) + shape·log_scale + lnΓ(shape)`,
/// accumulated left to right in exactly that order.
fn gamma_nll_f64(row: &GammaRow, p0: f64, p1: f64) -> f64 {
    let log_y = row.y.ln();
    let shape = (-2.0 * p1).exp();
    let log_scale = p0 + 2.0 * p1;

    let data_term = -shape * log_y + log_y;
    let scale_term = row.y * (-log_scale).exp();
    let mut nll = data_term + scale_term;
    nll += shape * log_scale;
    nll + statrs::function::gamma::ln_gamma(shape)
}
/// Five-point central-difference weights for derivative orders `n = 0..=4`.
///
/// Row `n` holds the weights for samples at offsets `-2h, -h, 0, +h, +2h` (the sample
/// index minus 2). Dividing the weighted sum by `STENCIL_NORM[n] * h^n` approximates the
/// `n`-th derivative.
const STENCIL_COEFF: [[f64; 5]; 5] = [
    // n = 0: the centre sample itself.
    [0.0, 0.0, 1.0, 0.0, 0.0],
    // n = 1: (f(-2h) - 8 f(-h) + 8 f(h) - f(2h)) / 12h
    [1.0, -8.0, 0.0, 8.0, -1.0],
    // n = 2: (-f(-2h) + 16 f(-h) - 30 f(0) + 16 f(h) - f(2h)) / 12h²
    [-1.0, 16.0, -30.0, 16.0, -1.0],
    // n = 3: (-f(-2h) + 2 f(-h) - 2 f(h) + f(2h)) / 2h³
    [-1.0, 2.0, 0.0, -2.0, 1.0],
    // n = 4: (f(-2h) - 4 f(-h) + 6 f(0) - 4 f(h) + f(2h)) / h⁴
    [1.0, -4.0, 6.0, -4.0, 1.0],
];

/// Denominator (excluding the `h^n` factor) for each row of `STENCIL_COEFF`.
const STENCIL_NORM: [f64; 5] = [1.0, 12.0, 12.0, 2.0, 1.0];

/// Power `p` of the leading truncation error (`~ h^p`) of each stencil, used for
/// Richardson extrapolation. Entry 0 is a large placeholder, because the order-0
/// "stencil" is exact. The Richardson step only reads entries for orders `>= 1`.
const STENCIL_TRUNC: [i32; 5] = [64, 4, 4, 2, 2];
/// Tensor-product five-point stencil estimate of `∂^{a+b} NLL / ∂p0^a ∂p1^b` at step `h`,
/// centred on `(row.p0, row.p1)`.
///
/// Zero-weight samples are skipped. The weighted sum runs over the `p0` offsets in the
/// outer loop and the `p1` offsets in the inner loop. It is then divided by the product of
/// both stencil norms and `h^(a+b)`.
fn fd_partial_at(row: &GammaRow, a: usize, b: usize, h: f64) -> f64 {
    let (stencil_p0, stencil_p1) = (STENCIL_COEFF[a], STENCIL_COEFF[b]);
    // Sample index 0..5 maps to offset (index - 2)·h.
    let shift = |idx: usize| (idx as f64 - 2.0) * h;

    let mut weighted = 0.0;
    for (i, w0) in stencil_p0
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, w)| w != 0.0)
    {
        for (j, w1) in stencil_p1
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, w)| w != 0.0)
        {
            let sample = gamma_nll_f64(row, row.p0 + shift(i), row.p1 + shift(j));
            weighted += w0 * w1 * sample;
        }
    }

    let scale = STENCIL_NORM[a] * STENCIL_NORM[b] * h.powi((a + b) as i32);
    weighted / scale
}
/// Richardson-extrapolated finite-difference estimate of `∂^{a+b} NLL / ∂p0^a ∂p1^b` at the
/// row's own primaries.
///
/// `(0, 0)` returns the exact NLL value. For other orders, the stencil estimate is taken at
/// step `1e-2` and at half that step. The two estimates are combined as
/// `(2^p·fine - coarse) / (2^p - 1)` to cancel the leading `h^p` error. Here `p` is the
/// smallest truncation order among the stencils of the differentiated axes.
fn fd_partial(row: &GammaRow, a: usize, b: usize) -> f64 {
    if a == 0 && b == 0 {
        return gamma_nll_f64(row, row.p0, row.p1);
    }
    let leading_order = [a, b]
        .iter()
        .filter(|&&order| order >= 1)
        .map(|&order| STENCIL_TRUNC[order])
        .min()
        .unwrap_or(i32::MAX);

    let step = 1e-2;
    let coarse = fd_partial_at(row, a, b, step);
    let fine = fd_partial_at(row, a, b, step * 0.5);

    let ratio = 2f64.powi(leading_order);
    (ratio * fine - coarse) / (ratio - 1.0)
}
/// Deterministically draws `n` synthetic Gamma rows from an `Lcg` seeded with `seed`.
///
/// Each row consumes three draws, in order: `y` over 0.5..3.0, `p0` over -0.5..0.5, and
/// `p1` over -0.3..0.3.
fn gamma_fixtures(seed: u64, n: usize) -> Vec<GammaRow> {
    let mut rng = Lcg(seed);
    let mut fixtures = Vec::with_capacity(n);
    for _ in 0..n {
        let y = rng.uniform(0.5, 3.0);
        let p0 = rng.uniform(-0.5, 0.5);
        let p1 = rng.uniform(-0.3, 0.3);
        fixtures.push(GammaRow { y, p0, p1 });
    }
    fixtures
}
/// Checks every entry of the order-4 jet tower (value, gradient, Hessian, third and fourth
/// derivative tensors) of the Gamma location/scale row against the independent
/// Richardson-extrapolated finite-difference oracle, on 20 seeded fixtures.
#[test]
fn gamma_dispersion_jet_tower_matches_independent_fd_oracle() {
    const REL_TOL: f64 = 5e-6;
    const ATOL: f64 = 1e-7;
    let rows = gamma_fixtures(0x9322_0203_6a6d_6d61, 20);
    let program = GammaLocScaleRow { rows: rows.clone() };

    // Mixed absolute + relative band: FD error dominates, so the tolerance is loose.
    let close = |jet: f64, fd: f64, label: &str| {
        let gap = (jet - fd).abs();
        let band = ATOL + REL_TOL * jet.abs().max(fd.abs());
        assert!(
            gap <= band,
            "{label}: jet {jet:+.12e} vs FD {fd:+.12e} (|Δ| {:.3e} > band {band:.3e})",
            gap
        );
    };

    for (row, fixture) in rows.iter().enumerate() {
        let tower: Tower4<2> = generic_full_tower(&program, row).expect("gamma jet tower");

        // Compare one tower entry, addressed by the axes it differentiates along,
        // against the FD oracle for the matching (p0-order, p1-order) pair.
        let check = |jet: f64, axes: &[usize], label: String| {
            let (a, b) = order_counts(axes);
            close(jet, fd_partial(fixture, a, b), &label);
        };

        close(tower.v, fd_partial(fixture, 0, 0), &format!("row {row} value"));

        for i in 0..2 {
            check(tower.g[i], &[i], format!("row {row} grad[{i}]"));
        }

        for i in 0..2 {
            for j in 0..2 {
                check(tower.h[i][j], &[i, j], format!("row {row} hess[{i}][{j}]"));
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    check(
                        tower.t3[i][j][k],
                        &[i, j, k],
                        format!("row {row} t3[{i}][{j}][{k}]"),
                    );
                }
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    for l in 0..2 {
                        check(
                            tower.t4[i][j][k][l],
                            &[i, j, k, l],
                            format!("row {row} t4[{i}][{j}][{k}][{l}]"),
                        );
                    }
                }
            }
        }
    }
}
/// Converts a list of differentiation axes into per-parameter derivative orders.
///
/// Returns `(how many entries are 0, how many entries are 1)`. Any other axis value is
/// ignored.
fn order_counts(axes: &[usize]) -> (usize, usize) {
    axes.iter()
        .fold((0, 0), |(on_p0, on_p1), &axis| match axis {
            0 => (on_p0 + 1, on_p1),
            1 => (on_p0, on_p1 + 1),
            _ => (on_p0, on_p1),
        })
}
/// Checks that the packed derivative channels agree with contractions of the dense
/// `Tower4` to ~1e-11 on 16 seeded Gamma fixtures. The channels are the Order2 row
/// kernel, the single-direction third contraction and the two-direction fourth
/// contraction.
#[test]
fn gamma_dispersion_packed_scalars_match_dense_tower_contractions() {
    const REL_TOL: f64 = 1e-11;
    // Absolute floor of the acceptance band; deliberately the same magnitude as REL_TOL.
    const ABS_TOL: f64 = REL_TOL;
    let rows = gamma_fixtures(0x0bad_c0de_6a6d_6d61, 16);
    let program = GammaLocScaleRow { rows: rows.clone() };
    let third_dirs: [[f64; 2]; 3] = [[0.9, -0.5], [-1.1, 0.3], [0.2, 1.4]];
    let n_dirs = third_dirs.len();

    let close = |a: f64, b: f64, label: &str| {
        let band = ABS_TOL + REL_TOL * a.abs().max(b.abs());
        assert!(
            (a - b).abs() <= band,
            "{label}: packed {a:+.15e} vs dense {b:+.15e} (band {band:.3e})"
        );
    };

    for row in 0..rows.len() {
        let tower: Tower4<2> = generic_full_tower(&program, row).expect("tower");

        // Order-2 kernel: value, gradient and Hessian.
        let (value, grad, hess) = generic_row_kernel(&program, row).expect("Order2 channel");
        close(value, tower.v, &format!("row {row} Order2 value"));
        for i in 0..2 {
            close(grad[i], tower.g[i], &format!("row {row} Order2 grad[{i}]"));
            for j in 0..2 {
                close(
                    hess[i][j],
                    tower.h[i][j],
                    &format!("row {row} Order2 hess[{i}][{j}]"),
                );
            }
        }

        // Third derivative contracted with one direction.
        for (di, dir) in third_dirs.iter().enumerate() {
            let packed = generic_third_contracted(&program, row, dir).expect("OneSeed third");
            let dense = tower.third_contracted(dir);
            for i in 0..2 {
                for j in 0..2 {
                    close(
                        packed[i][j],
                        dense[i][j],
                        &format!("row {row} dir {di} OneSeed third[{i}][{j}]"),
                    );
                }
            }
        }

        // Fourth derivative contracted with each direction and its cyclic successor.
        for (ui, u) in third_dirs.iter().enumerate() {
            let partner = third_dirs[(ui + 1) % n_dirs];
            let packed =
                generic_fourth_contracted(&program, row, u, &partner).expect("TwoSeed fourth");
            let dense = tower.fourth_contracted(u, &partner);
            for i in 0..2 {
                for j in 0..2 {
                    close(
                        packed[i][j],
                        dense[i][j],
                        &format!("row {row} pair {ui} TwoSeed fourth[{i}][{j}]"),
                    );
                }
            }
        }
    }
}
/// Single-row test program whose row value is `ln_gamma` or `digamma` evaluated at the
/// affine argument `x = c0·p0 + c1·p1 + c2`.
///
/// The argument is affine in the primaries, so each order-`n` derivative is the product of
/// the `c` coefficients along its axes times the `n`-th derivative of the special function
/// at `x`.
struct AffineComposeRow {
/// Coefficient of `p0` in the affine argument.
c0: f64,
/// Coefficient of `p1` in the affine argument.
c1: f64,
/// Constant offset of the affine argument.
c2: f64,
/// First primary parameter returned by `primaries`.
p0: f64,
/// Second primary parameter returned by `primaries`.
p1: f64,
/// When `true` the row value is `digamma(x)`; otherwise it is `ln_gamma(x)`.
digamma: bool,
}
/// Exposes `AffineComposeRow` as a one-row program whose value is `digamma` or `ln_gamma`
/// of `c0·p0 + c1·p1 + c2`. Any row index other than 0 is rejected.
impl crate::jet_tower::RowNllProgramGeneric<2> for AffineComposeRow {
    fn n_rows(&self) -> usize {
        1
    }

    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
        if row < self.n_rows() {
            Ok([self.p0, self.p1])
        } else {
            Err(format!("AffineComposeRow: row {row} out of range"))
        }
    }

    fn row_nll_generic<S: JetScalar<2>>(&self, row: usize, p: &[S; 2]) -> Result<S, String> {
        if row >= self.n_rows() {
            return Err(format!("AffineComposeRow: row {row} out of range"));
        }
        let [first, second] = p;
        let argument = first
            .scale(self.c0)
            .add(&second.scale(self.c1))
            .add(&S::constant(self.c2));
        let value = match self.digamma {
            true => argument.digamma(),
            false => argument.ln_gamma(),
        };
        Ok(value)
    }
}
/// Checks a chain-rule identity for `f(c0·p0 + c1·p1 + c2)` with `f` either `ln_gamma` or
/// `digamma`. Every order-`n` tower entry must equal the product of the `c` coefficients
/// along its axes times `f^(n)(x0)`, where `f^(n)(x0)` comes from the certified derivative
/// stack.
#[test]
fn affine_special_function_composition_places_certified_stack_by_order() {
    const REL_TOL: f64 = 1e-11;
    let cases: [(bool, f64, f64, f64, f64, f64); 4] = [
        (false, 0.7, -1.3, 4.0, 0.5, -0.25),
        (false, 1.1, 0.4, 3.0, -0.3, 0.6),
        (true, 0.6, -0.9, 5.0, 0.2, 0.4),
        (true, -0.8, 1.2, 6.0, -0.4, 0.1),
    ];
    let close = |jet: f64, closed: f64, label: &str| {
        let band = REL_TOL + REL_TOL * jet.abs().max(closed.abs());
        assert!(
            (jet - closed).abs() <= band,
            "{label}: jet {jet:+.15e} vs closed {closed:+.15e} (band {band:.3e})"
        );
    };

    for (ci, &(digamma, c0, c1, c2, p0, p1)) in cases.iter().enumerate() {
        let program = AffineComposeRow {
            c0,
            c1,
            c2,
            p0,
            p1,
            digamma,
        };
        let x0 = c0 * p0 + c1 * p1 + c2;
        let (stack, tag) = if digamma {
            (digamma_derivative_stack(x0), "digamma")
        } else {
            (ln_gamma_derivative_stack(x0), "ln_gamma")
        };
        let c = [c0, c1];

        // Entry for axes (i1, …, in) is c[i1]·…·c[in]·stack[n]. `product` folds from 1.0,
        // so the multiplication order matches writing the factors out left to right.
        let expected = |axes: &[usize]| -> f64 {
            axes.iter().map(|&axis| c[axis]).product::<f64>() * stack[axes.len()]
        };

        let tower: Tower4<2> = generic_full_tower(&program, 0).expect("affine tower");
        close(tower.v, stack[0], &format!("case {ci} {tag} value"));

        for i in 0..2 {
            close(
                tower.g[i],
                expected(&[i]),
                &format!("case {ci} {tag} grad[{i}]"),
            );
        }

        for i in 0..2 {
            for j in 0..2 {
                close(
                    tower.h[i][j],
                    expected(&[i, j]),
                    &format!("case {ci} {tag} hess[{i}][{j}]"),
                );
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    close(
                        tower.t3[i][j][k],
                        expected(&[i, j, k]),
                        &format!("case {ci} {tag} t3[{i}][{j}][{k}]"),
                    );
                }
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    for l in 0..2 {
                        close(
                            tower.t4[i][j][k][l],
                            expected(&[i, j, k, l]),
                            &format!("case {ci} {tag} t4[{i}][{j}][{k}][{l}]"),
                        );
                    }
                }
            }
        }
    }
}