use crate::jet_scalar::{JetScalar, OneSeed, Order2, TwoSeed};
use crate::jet_tower::Tower4;
/// One observation of the two-parameter Poisson model exercised by this module.
///
/// The linear predictor evaluated by `poisson_row_nll` is
/// `η = a·θ₀ + b·θ₁ + d·θ₀·θ₁`, with Poisson mean `exp(η)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PoissonRow {
    /// Observed count (the tests draw non-negative integers stored as `f64`).
    pub y: f64,
    /// Coefficient of θ₀ in η.
    pub a: f64,
    /// Coefficient of θ₁ in η.
    pub b: f64,
    /// Coefficient of the θ₀·θ₁ interaction term in η.
    pub d: f64,
}
/// Negative log-likelihood of a single Poisson row, generic over the jet scalar type `S`.
///
/// With the predictor `η = a·θ₀ + b·θ₁ + d·θ₀·θ₁` and mean `μ = exp(η)`, this returns
/// `μ − y·η + ln_gamma_real(y + 1)`. The normalising term does not depend on `p`, so it is
/// wrapped with `S::constant`.
pub fn poisson_row_nll<S: JetScalar<2>>(row: &PoissonRow, p: &[S; 2]) -> S {
    let [theta0, theta1] = p;
    // Linear part first, then the bilinear interaction (same operation order as before).
    let linear = theta0.scale(row.a).add(&theta1.scale(row.b));
    let interaction = theta0.mul(theta1).scale(row.d);
    let eta = linear.add(&interaction);
    let mean = eta.exp();
    let log_norm = S::constant(ln_gamma_real(row.y + 1.0));
    mean.sub(&eta.scale(row.y)).add(&log_norm)
}
/// Evaluates `poisson_row_nll` on fourth-order jets seeded at `p0`.
///
/// The resulting tower carries the value and derivatives up to order four with respect
/// to both parameters. Coordinate `k` of `p0` is seeded as the variable on axis `k`.
pub fn poisson_jet_tower(row: &PoissonRow, p0: [f64; 2]) -> Tower4<2> {
    let [q0, q1] = p0;
    let seeded: [Tower4<2>; 2] = [Tower4::variable(q0, 0), Tower4::variable(q1, 1)];
    poisson_row_nll(row, &seeded)
}
pub fn poisson_closed_form_tower(row: &PoissonRow, p0: [f64; 2]) -> Tower4<2> {
let (a, b, d, y) = (row.a, row.b, row.d, row.y);
let (q0, q1) = (p0[0], p0[1]);
let eta = a * q0 + b * q1 + d * q0 * q1;
let m = eta.exp();
let c = ln_gamma_real(y + 1.0);
let g = [a + d * q1, b + d * q0];
let mut hh = [[0.0_f64; 2]; 2];
hh[0][1] = d;
hh[1][0] = d;
let mut t = Tower4::<2>::zero();
t.v = m - y * eta + c;
for av in 0..2 {
t.g[av] = (m - y) * g[av];
for bv in 0..2 {
t.h[av][bv] = m * g[av] * g[bv] + (m - y) * hh[av][bv];
for cv in 0..2 {
t.t3[av][bv][cv] = m
* (g[av] * g[bv] * g[cv]
+ hh[av][bv] * g[cv]
+ hh[av][cv] * g[bv]
+ hh[bv][cv] * g[av]);
for dv in 0..2 {
t.t4[av][bv][cv][dv] = m
* (g[dv]
* (g[av] * g[bv] * g[cv]
+ hh[av][bv] * g[cv]
+ hh[av][cv] * g[bv]
+ hh[bv][cv] * g[av])
+ hh[av][dv] * g[bv] * g[cv]
+ g[av] * hh[bv][dv] * g[cv]
+ g[av] * g[bv] * hh[cv][dv]
+ hh[av][bv] * hh[cv][dv]
+ hh[av][cv] * hh[bv][dv]
+ hh[bv][cv] * hh[av][dv]);
}
}
}
}
t
}
/// Scalar log-gamma helper used for the NLL normalising constant `ln_gamma_real(y + 1)`.
///
/// Returns entry 0 of `crate::jet_tower::ln_gamma_derivative_stack(x)`. That entry is
/// presumably ln Γ(x) itself (the zeroth derivative). NOTE(review): confirm against `jet_tower`.
#[inline]
fn ln_gamma_real(x: f64) -> f64 {
    let stack = crate::jet_tower::ln_gamma_derivative_stack(x);
    stack[0]
}
/// Tiny 64-bit linear congruential generator that gives the tests reproducible inputs
/// without external crates. It uses Knuth's MMIX multiplier and increment.
struct Lcg(u64);
impl Lcg {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    /// Advances the state and returns a uniform sample in `[0, 1)`.
    /// The sample is built from the top 53 bits of the new state.
    fn next_f64(&mut self) -> f64 {
        let state = self
            .0
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.0 = state;
        let top53 = state >> 11;
        top53 as f64 / (1u64 << 53) as f64
    }

    /// Returns a uniform sample in `[lo, hi)`, obtained by affinely mapping `next_f64`.
    fn uniform(&mut self, lo: f64, hi: f64) -> f64 {
        let unit = self.next_f64();
        lo + (hi - lo) * unit
    }
}
/// Tolerance used both as an absolute floor and as a relative factor in [`close`].
const REL_TOL: f64 = 1e-9;

/// Asserts that `a` (the jet result) and `b` (the hand-derived result) agree within
/// `REL_TOL + REL_TOL·max(|a|, |b|)`. On failure it panics with `label` and both values.
/// A NaN in either input makes the comparison fail.
fn close(a: f64, b: f64, label: &str) {
    let magnitude = a.abs().max(b.abs());
    let band = REL_TOL + REL_TOL * magnitude;
    let gap = (a - b).abs();
    assert!(
        gap <= band,
        "{label}: jet {a:+.15e} vs hand {b:+.15e} (band {band:.3e})"
    );
}
/// Cross-checks the automatic fourth-order tower against the hand-derived closed form.
/// It runs sixteen pseudo-random rows and evaluation points and compares every component.
#[test]
fn poisson_jet_tower_matches_hand_derived_closed_form() {
    let mut rng = Lcg(0x9322_0451_dead_beef);
    for trial in 0..16 {
        // Draw order is part of the reproducible sequence: y, a, b, d, then θ₀, θ₁.
        let y = rng.uniform(0.0, 8.0).floor();
        let a = rng.uniform(-1.2, 1.2);
        let b = rng.uniform(-1.2, 1.2);
        let d = rng.uniform(-0.9, 0.9);
        let row = PoissonRow { y, a, b, d };
        let x0 = rng.uniform(-0.8, 0.8);
        let x1 = rng.uniform(-0.8, 0.8);
        let p0 = [x0, x1];

        let (jet, hand) = (
            poisson_jet_tower(&row, p0),
            poisson_closed_form_tower(&row, p0),
        );
        // Prefixes every label with the trial number before delegating to `close`.
        let check = |got: f64, want: f64, what: String| {
            close(got, want, &format!("trial {trial} {what}"))
        };

        check(jet.v, hand.v, "value".to_string());
        for i in 0..2 {
            check(jet.g[i], hand.g[i], format!("grad[{i}]"));
            for j in 0..2 {
                check(jet.h[i][j], hand.h[i][j], format!("hess[{i}][{j}]"));
                for k in 0..2 {
                    check(
                        jet.t3[i][j][k],
                        hand.t3[i][j][k],
                        format!("t3[{i}][{j}][{k}]"),
                    );
                    for l in 0..2 {
                        check(
                            jet.t4[i][j][k][l],
                            hand.t4[i][j][k][l],
                            format!("t4[{i}][{j}][{k}][{l}]"),
                        );
                    }
                }
            }
        }
    }
}
/// Checks the cheaper packed scalars against the hand-derived tower.
/// - `Order2`: value, gradient and Hessian.
/// - `OneSeed`: third derivative contracted with a single direction.
/// - `TwoSeed`: fourth derivative contracted with a pair of directions.
#[test]
fn poisson_packed_scalars_match_hand_derived_contractions() {
    let mut rng = Lcg(0x0bad_f00d_1234_5678);
    let dirs: [[f64; 2]; 3] = [[0.7, -0.4], [-0.9, 1.3], [1.1, 0.6]];
    for trial in 0..12 {
        // Draw order is part of the reproducible sequence: y, a, b, d, then θ₀, θ₁.
        let y = rng.uniform(0.0, 8.0).floor();
        let a = rng.uniform(-1.2, 1.2);
        let b = rng.uniform(-1.2, 1.2);
        let d = rng.uniform(-0.9, 0.9);
        let row = PoissonRow { y, a, b, d };
        let x0 = rng.uniform(-0.8, 0.8);
        let x1 = rng.uniform(-0.8, 0.8);
        let p0 = [x0, x1];

        let hand = poisson_closed_form_tower(&row, p0);
        // Prefixes every label with the trial number before delegating to `close`.
        let check = |got: f64, want: f64, what: String| {
            close(got, want, &format!("trial {trial} {what}"))
        };

        // Second-order scalar: value, gradient, Hessian.
        let order2: [Order2<2>; 2] = [Order2::variable(x0, 0), Order2::variable(x1, 1)];
        let s2 = poisson_row_nll(&row, &order2);
        check(s2.value(), hand.v, "Order2 value".to_string());
        for i in 0..2 {
            check(s2.g()[i], hand.g[i], format!("Order2 grad[{i}]"));
            for j in 0..2 {
                check(s2.h()[i][j], hand.h[i][j], format!("Order2 hess[{i}][{j}]"));
            }
        }

        // One seeded direction: third derivative contracted once.
        for (di, dir) in dirs.iter().enumerate() {
            let seeds: [OneSeed<2>; 2] = [
                OneSeed::seed_direction(x0, 0, dir[0]),
                OneSeed::seed_direction(x1, 1, dir[1]),
            ];
            let third = poisson_row_nll(&row, &seeds).contracted_third();
            let truth = hand.third_contracted(dir);
            for i in 0..2 {
                for j in 0..2 {
                    check(
                        third[i][j],
                        truth[i][j],
                        format!("dir {di} OneSeed third[{i}][{j}]"),
                    );
                }
            }
        }

        // Two seeded directions (each paired with the next, cyclically): fourth derivative.
        for (ui, u) in dirs.iter().enumerate() {
            let v = dirs[(ui + 1) % dirs.len()];
            let seeds: [TwoSeed<2>; 2] = [
                TwoSeed::seed(x0, 0, u[0], v[0]),
                TwoSeed::seed(x1, 1, u[1], v[1]),
            ];
            let fourth = poisson_row_nll(&row, &seeds).contracted_fourth();
            let truth = hand.fourth_contracted(u, &v);
            for i in 0..2 {
                for j in 0..2 {
                    check(
                        fourth[i][j],
                        truth[i][j],
                        format!("pair {ui} TwoSeed fourth[{i}][{j}]"),
                    );
                }
            }
        }
    }
}