use ndarray::{Array1, Array2, ArrayView2};
use crate::linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
use crate::linalg::matrix::FactorizedSystem;
use faer::Side;
/// Product of two second-order jets `[f, f', f'']` and `[g, g', g'']`.
///
/// Jets store values and derivatives (not Taylor coefficients), so the second
/// entry follows Leibniz: (fg)'' = f g'' + 2 f' g' + f'' g.
#[inline]
fn jet_mul(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [f, df, d2f] = a;
    let [g, dg, d2g] = b;
    let value = f * g;
    let first = f * dg + df * g;
    let second = f * d2g + 2.0 * df * dg + d2f * g;
    [value, first, second]
}
/// Quotient `a / b` of two second-order jets `[f, f', f'']`.
///
/// Solves a = q·b order by order: q = a/b, q' = (a' − q b')/b and
/// q'' = (a'' − q b'' − 2 q' b')/b. The caller must ensure `b[0] != 0`.
#[inline]
fn jet_div(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [num, num_d, num_dd] = a;
    let [den, den_d, den_dd] = b;
    let quot = num / den;
    let quot_d = (num_d - quot * den_d) / den;
    let quot_dd = (num_dd - quot * den_dd - 2.0 * quot_d * den_d) / den;
    [quot, quot_d, quot_dd]
}
/// Per-observation inputs for the Lawley cumulants.
///
/// Holds the first three derivatives of the mean μ with respect to the row's
/// linear predictor η, the variance function V and its first two
/// μ-derivatives evaluated at μ(η), and the dispersion φ.
#[derive(Debug, Clone, Copy)]
pub struct RowExpectedJets {
/// dμ/dη.
pub mu1: f64,
/// d²μ/dη².
pub mu2: f64,
/// d³μ/dη³.
pub mu3: f64,
/// Variance function V(μ). `kappas` requires it to be finite and positive.
pub var: f64,
/// V'(μ) = dV/dμ.
pub dvar_dmu: f64,
/// V''(μ) = d²V/dμ².
pub d2var_dmu2: f64,
/// Dispersion φ. `kappas` requires it to be finite and positive.
pub dispersion: f64,
}
/// Per-row cumulant coefficients in Lawley's notation, taken with respect to
/// the row's linear predictor η (as produced by `RowExpectedJets::kappas`).
///
/// `lawley_epsilon` contracts these with the design rows to form the
/// κ_rs, κ_rst and κ_rstu arrays and their parameter derivatives.
#[derive(Debug, Clone, Copy)]
pub struct RowKappas {
    /// κ₂. Its negative is used as the row's Fisher weight.
    pub k2: f64,
    /// κ₃.
    pub k3: f64,
    /// κ₄.
    pub k4: f64,
    /// κ₂' = dκ₂/dη.
    pub k2_1: f64,
    /// κ₂'' = d²κ₂/dη².
    pub k2_11: f64,
    /// κ₃' = dκ₃/dη.
    pub k3_1: f64,
}
impl RowKappas {
    /// Returns a copy with every coefficient multiplied by `w`.
    ///
    /// An integer weight `w` is intended to act like repeating the row `w`
    /// times.
    pub fn weighted(self, w: f64) -> Self {
        let Self {
            k2,
            k3,
            k4,
            k2_1,
            k2_11,
            k3_1,
        } = self;
        let scale = |coef: f64| coef * w;
        Self {
            k2: scale(k2),
            k3: scale(k3),
            k4: scale(k4),
            k2_1: scale(k2_1),
            k2_11: scale(k2_11),
            k3_1: scale(k3_1),
        }
    }
}
impl RowExpectedJets {
    /// Converts the mean/variance jets into Lawley's per-row cumulant
    /// coefficients.
    ///
    /// With c = μ'/V and u = μ'·c = μ'²/V (both carried as jets in η), the
    /// coefficients are
    ///
    /// * κ₂ = −u/φ, κ₂' = −u'/φ, κ₂'' = −u''/φ,
    /// * κ₃ = −(u' + μ'c')/φ,
    /// * κ₃' = −(u'' + μ''c' + μ'c'')/φ,
    /// * κ₄ = −(u'' + μ''c' + 2μ'c'')/φ.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the dispersion or the variance function is not finite
    /// and strictly positive.
    pub fn kappas(&self) -> Result<RowKappas, String> {
        let phi = self.dispersion;
        if !(phi.is_finite() && phi > 0.0) {
            return Err(format!(
                "RowExpectedJets::kappas: dispersion must be finite and positive; got {phi}"
            ));
        }
        if !(self.var.is_finite() && self.var > 0.0) {
            return Err(format!(
                "RowExpectedJets::kappas: variance function must be finite and positive; got {}",
                self.var
            ));
        }
        // Jet of μ'(η): [μ', μ'', μ'''].
        let mu1_jet = [self.mu1, self.mu2, self.mu3];
        // Jet of V(μ(η)) by the chain rule: [V, V'μ', V''μ'² + V'μ''].
        let v_jet = [
            self.var,
            self.dvar_dmu * self.mu1,
            self.d2var_dmu2 * self.mu1 * self.mu1 + self.dvar_dmu * self.mu2,
        ];
        // c = μ'/V and u = μ'·c = μ'²/V, each with two η-derivatives.
        let c = jet_div(mu1_jet, v_jet);
        let u0 = jet_mul(mu1_jet, c);
        let inv_phi = 1.0 / phi;
        Ok(RowKappas {
            k2: -u0[0] * inv_phi,
            k2_1: -u0[1] * inv_phi,
            k2_11: -u0[2] * inv_phi,
            k3: -(u0[1] + self.mu1 * c[1]) * inv_phi,
            k3_1: -(u0[2] + self.mu2 * c[1] + self.mu1 * c[2]) * inv_phi,
            k4: -(u0[2] + self.mu2 * c[1] + 2.0 * self.mu1 * c[2]) * inv_phi,
        })
    }

    /// Gaussian response with identity link: μ = η, V ≡ 1, dispersion φ as
    /// given. The dispersion is validated later by `kappas`.
    pub fn gaussian_identity(dispersion: f64) -> Self {
        Self {
            mu1: 1.0,
            mu2: 0.0,
            mu3: 0.0,
            var: 1.0,
            dvar_dmu: 0.0,
            d2var_dmu2: 0.0,
            dispersion,
        }
    }

    /// Poisson response with log link: μ = e^η (so every η-derivative equals
    /// μ), V(μ) = μ, φ = 1.
    pub fn poisson_log(eta: f64) -> Self {
        let mu = eta.exp();
        Self {
            mu1: mu,
            mu2: mu,
            mu3: mu,
            var: mu,
            dvar_dmu: 1.0,
            d2var_dmu2: 0.0,
            dispersion: 1.0,
        }
    }

    /// Binomial response with logit link, per single trial:
    /// μ = 1/(1 + e^−η), V(μ) = μ(1 − μ), φ = 1.
    ///
    /// μ and 1 − μ are both formed from e^−|η|. Computing `1.0 - mu` after μ
    /// has rounded toward 1 would cancel catastrophically. For η ≳ 36.7 that
    /// makes the variance exactly 0, so `kappas` would reject a perfectly
    /// valid row. The result is now symmetric in ±η.
    pub fn binomial_logit(eta: f64) -> Self {
        let tail = (-eta.abs()).exp();
        let denom = 1.0 + tail;
        let (mu, one_minus_mu) = if eta >= 0.0 {
            (1.0 / denom, tail / denom)
        } else {
            (tail / denom, 1.0 / denom)
        };
        let mu1 = mu * one_minus_mu;
        // 1 − 2μ, which is also V'(μ); reusing one value keeps the canonical-link
        // identities (κ₃ = κ₂', κ₄ = κ₃') exact in floating point.
        let slope = one_minus_mu - mu;
        let mu2 = mu1 * slope;
        let mu3 = mu2 * slope - 2.0 * mu1 * mu1;
        Self {
            mu1,
            mu2,
            mu3,
            var: mu1,
            dvar_dmu: slope,
            d2var_dmu2: -2.0,
            dispersion: 1.0,
        }
    }

    /// Gamma response with log link: μ = e^η (every η-derivative equals μ),
    /// V(μ) = μ², dispersion φ as given. The dispersion is validated later by
    /// `kappas`.
    pub fn gamma_log(eta: f64, dispersion: f64) -> Self {
        let mu = eta.exp();
        Self {
            mu1: mu,
            mu2: mu,
            mu3: mu,
            var: mu * mu,
            dvar_dmu: 2.0 * mu,
            d2var_dmu2: 2.0,
            dispersion,
        }
    }
}
/// Lawley's O(n⁻¹) expected-log-likelihood-ratio term ε for a model whose
/// linear predictor is η = Xβ.
///
/// The inputs are the per-row cumulant coefficients `kappas` and an optional
/// quadratic penalty `S`. Let J = Σᵢ (−κ₂ᵢ) xᵢxᵢᵀ + S, eᵢⱼ = xᵢᵀ J⁻¹ xⱼ and
/// hᵢ = eᵢᵢ. Lawley's index-notation sums then collapse to row sums:
///
/// * λ₄ = Σᵢ (κ₄ᵢ/4 − κ₃ᵢ' + κ₂ᵢ'') hᵢ²
/// * λ₆ = −Σᵢⱼ [eᵢⱼ³ (κ₃ᵢκ₃ⱼ/6 + mᵢⱼ) + hᵢhⱼeᵢⱼ (κ₃ᵢκ₃ⱼ/4 + mᵢⱼ)],
///   with mᵢⱼ = −κ₃ᵢκ₂ⱼ' + κ₂ᵢ'κ₂ⱼ'
///
/// The result is ε = λ₄ − λ₆.
///
/// The work is O(n²k), but storage is O(nk): B = XJ⁻¹ is kept, and each eᵢⱼ is
/// formed on demand instead of materializing the n × n matrix XJ⁻¹Xᵀ.
///
/// # Errors
///
/// Returns `Err` if:
/// * the design is empty;
/// * `kappas.len() != x.nrows()`;
/// * some Fisher weight −κ₂ᵢ is non-finite;
/// * the penalty is not k × k;
/// * J cannot be factorized or solved;
/// * the resulting ε is non-finite.
pub fn lawley_epsilon(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
) -> Result<f64, String> {
    let n = x.nrows();
    let k = x.ncols();
    if n == 0 || k == 0 {
        return Err(format!(
            "lawley_epsilon: empty design ({n} rows, {k} columns)"
        ));
    }
    if kappas.len() != n {
        return Err(format!(
            "lawley_epsilon: {} cumulant rows for {n} design rows",
            kappas.len()
        ));
    }

    // Expected information J = Σᵢ (−κ₂ᵢ) xᵢ xᵢᵀ (+ S).
    let mut j_mat = Array2::<f64>::zeros((k, k));
    for (i, row_kappas) in kappas.iter().enumerate() {
        let weight = -row_kappas.k2;
        if !weight.is_finite() {
            return Err(format!(
                "lawley_epsilon: non-finite Fisher weight at row {i}"
            ));
        }
        for r in 0..k {
            let xr = x[[i, r]] * weight;
            for s in 0..k {
                j_mat[[r, s]] += xr * x[[i, s]];
            }
        }
    }
    if let Some(s_pen) = penalty {
        if s_pen.nrows() != k || s_pen.ncols() != k {
            return Err(format!(
                "lawley_epsilon: penalty is {}×{}, expected {k}×{k}",
                s_pen.nrows(),
                s_pen.ncols()
            ));
        }
        j_mat += &s_pen;
    }

    let j_view = FaerArrayView::new(&j_mat);
    let factor = factorize_symmetricwith_fallback(j_view.as_ref(), Side::Lower)
        .map_err(|e| format!("lawley_epsilon: information factorization failed: {e:?}"))?;
    let j_inv = FactorizedSystem::solvemulti(&factor, &Array2::<f64>::eye(k))?;

    // B = X J⁻¹ (n × k). Row i of B dotted with row j of X gives eᵢⱼ.
    let b = x.dot(&j_inv);
    let h: Array1<f64> = (0..n).map(|i| b.row(i).dot(&x.row(i))).collect();

    let mut lambda4 = 0.0;
    for (i, row_kappas) in kappas.iter().enumerate() {
        let a_i = row_kappas.k4 / 4.0 - row_kappas.k3_1 + row_kappas.k2_11;
        lambda4 += a_i * h[i] * h[i];
    }

    let k3: Array1<f64> = kappas.iter().map(|r| r.k3).collect();
    let k21: Array1<f64> = kappas.iter().map(|r| r.k2_1).collect();
    // λ₆ summand for the ordered pair (i, j). It is not symmetric in (i, j)
    // because of the mixed κ₃/κ₂' products, but eᵢⱼ = eⱼᵢ is, so each unordered
    // pair needs only one eᵢⱼ evaluation.
    let pair_term = |i: usize, j: usize, e_ij: f64| -> f64 {
        let cross = k3[i] * k3[j];
        let mixed = -k3[i] * k21[j] + k21[i] * k21[j];
        e_ij * e_ij * e_ij * (cross / 6.0 + mixed) + h[i] * h[j] * e_ij * (cross / 4.0 + mixed)
    };
    let mut lambda6 = 0.0;
    for i in 0..n {
        let b_i = b.row(i);
        lambda6 -= pair_term(i, i, h[i]);
        for j in (i + 1)..n {
            let e_ij = b_i.dot(&x.row(j));
            lambda6 -= pair_term(i, j, e_ij) + pair_term(j, i, e_ij);
        }
    }

    let epsilon = lambda4 - lambda6;
    if !epsilon.is_finite() {
        return Err(format!(
            "lawley_epsilon: non-finite ε (λ₄={lambda4}, λ₆={lambda6})"
        ));
    }
    Ok(epsilon)
}
/// Row-count cap tied to the pairwise (O(n²)) λ₆ evaluation in
/// `lawley_epsilon`.
///
/// NOTE(review): this constant is not referenced in this part of the file.
/// Presumably callers skip the Lawley correction above this many rows;
/// confirm against the call sites.
pub const LAWLEY_PAIR_MATRIX_MAX_ROWS: usize = 2_048;
/// Lawley mean shift Δε = ε_full − ε_null for the likelihood-ratio test of the
/// contiguous column block `tested` of `x`.
///
/// The null model keeps only the remaining (nuisance) columns, in their
/// original order, and the matching sub-block of `penalty`. Both fits use the
/// same per-row `kappas`. If `tested` covers every column, the null model has
/// no parameters and ε_full is returned.
///
/// # Errors
///
/// Returns `Err` if `tested` is empty or extends past `x.ncols()`. Also
/// propagates any error from `lawley_epsilon`.
pub fn lawley_lr_mean_shift(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
    tested: std::ops::Range<usize>,
) -> Result<f64, String> {
    let k = x.ncols();
    let block_in_range = tested.start < tested.end && tested.end <= k;
    if !block_in_range {
        return Err(format!(
            "lawley_lr_mean_shift: tested block {}..{} out of range for {k} columns",
            tested.start, tested.end
        ));
    }

    let eps_full = lawley_epsilon(x, kappas, penalty)?;

    // Columns left free under the null, in ascending order.
    let nuisance: Vec<usize> = (0..tested.start).chain(tested.end..k).collect();
    let m = nuisance.len();
    if m == 0 {
        return Ok(eps_full);
    }

    let x_null = Array2::from_shape_fn((x.nrows(), m), |(row, col)| x[[row, nuisance[col]]]);
    let penalty_null = penalty.map(|full| {
        Array2::from_shape_fn((m, m), |(r, c)| full[[nuisance[r], nuisance[c]]])
    });
    let eps_null = lawley_epsilon(
        x_null.view(),
        kappas,
        penalty_null.as_ref().map(|sub| sub.view()),
    )?;
    Ok(eps_full - eps_null)
}
/// Bartlett scaling factor for the likelihood-ratio test of the columns in
/// `tested`.
///
/// The approximate null mean of the LR statistic is `ref_df + Δε`, where Δε
/// comes from `lawley_lr_mean_shift`. That mean is converted to a factor by
/// `crate::inference::higher_order::bartlett_factor_from_mean`. Presumably
/// the factor is `mean / ref_df`; see that helper.
///
/// # Errors
///
/// Returns `Err` if:
/// * `ref_df` is not finite and positive;
/// * the mean-shift computation fails;
/// * the helper rejects the mean (returns `None`);
/// * the factor is non-finite or not strictly positive.
pub fn lawley_lr_bartlett_factor(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
    tested: std::ops::Range<usize>,
    ref_df: f64,
) -> Result<f64, String> {
    let ref_df_usable = ref_df.is_finite() && ref_df > 0.0;
    if !ref_df_usable {
        return Err(format!(
            "lawley_lr_bartlett_factor: reference df must be finite and positive; got {ref_df}"
        ));
    }
    let shift = lawley_lr_mean_shift(x, kappas, penalty, tested)?;
    let mean_w = ref_df + shift;
    match crate::inference::higher_order::bartlett_factor_from_mean(mean_w, ref_df) {
        None => Err(format!(
            "lawley_lr_bartlett_factor: degenerate mean {mean_w} (Δε = {shift}, d = {ref_df})"
        )),
        Some(factor) if factor.is_finite() && factor > 0.0 => Ok(factor),
        Some(factor) => Err(format!(
            "lawley_lr_bartlett_factor: degenerate factor {factor} (Δε = {shift}, d = {ref_df})"
        )),
    }
}
/// Expected jets for the supported (response, link) pairs, evaluated at unit
/// dispersion.
///
/// This is `known_scale_expected_jets_with_dispersion(family, eta, 1.0)`. It
/// returns `None` for combinations without a closed-form jet.
pub fn known_scale_expected_jets(
    family: &crate::types::LikelihoodSpec,
    eta: f64,
) -> Option<RowExpectedJets> {
    const UNIT_DISPERSION: f64 = 1.0;
    known_scale_expected_jets_with_dispersion(family, eta, UNIT_DISPERSION)
}
/// Closed-form expected jets for the supported (response, link) pairs.
///
/// * Poisson/log and Binomial/logit always use unit dispersion; `dispersion`
///   is ignored for them.
/// * Gaussian/identity and Gamma/log use `dispersion`. They yield `None`
///   unless it is finite and strictly positive.
/// * Every other combination yields `None`.
pub fn known_scale_expected_jets_with_dispersion(
    family: &crate::types::LikelihoodSpec,
    eta: f64,
    dispersion: f64,
) -> Option<RowExpectedJets> {
    use crate::types::{InverseLink, ResponseFamily, StandardLink};
    let usable_dispersion = dispersion.is_finite() && dispersion > 0.0;
    match (&family.response, &family.link) {
        (ResponseFamily::Poisson, InverseLink::Standard(StandardLink::Log)) => {
            Some(RowExpectedJets::poisson_log(eta))
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
            Some(RowExpectedJets::binomial_logit(eta))
        }
        (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity))
            if usable_dispersion =>
        {
            Some(RowExpectedJets::gaussian_identity(dispersion))
        }
        (ResponseFamily::Gamma, InverseLink::Standard(StandardLink::Log))
            if usable_dispersion =>
        {
            Some(RowExpectedJets::gamma_log(eta, dispersion))
        }
        // Unsupported pairs, and Gaussian/Gamma with an unusable dispersion.
        _ => None,
    }
}
// Unit tests: closed-form ε values for intercept-only models, comparisons of
// 1 + ε against exact LR-statistic means, and a brute-force index-notation
// oracle for the row-pair reduction in `lawley_epsilon`.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
/// Brute-force ε in Lawley's index notation, O(n·k⁶), so only for tiny k.
///
/// `kappa2` is κ_rs = Σᵢ κ₂ᵢ xᵢ xᵢᵀ − S and `kappa_up` is its inverse κ^{rs}.
/// `arr3`/`arr4` build the three- and four-index arrays Σᵢ wᵢ x_ir x_is x_it (x_iu)
/// for a given per-row weight.
fn lawley_epsilon_index_oracle(
x: &Array2<f64>,
kappas: &[RowKappas],
penalty: Option<&Array2<f64>>,
) -> f64 {
let n = x.nrows();
let k = x.ncols();
let mut kappa2 = Array2::<f64>::zeros((k, k));
for i in 0..n {
for r in 0..k {
for s in 0..k {
kappa2[[r, s]] += kappas[i].k2 * x[[i, r]] * x[[i, s]];
}
}
}
if let Some(s_pen) = penalty {
kappa2 -= s_pen;
}
let j_view = FaerArrayView::new(&kappa2);
let factor = factorize_symmetricwith_fallback(j_view.as_ref(), faer::Side::Lower)
.expect("oracle κ_rs factorization");
let kappa_up = FactorizedSystem::solvemulti(&factor, &Array2::<f64>::eye(k))
.expect("oracle κ_rs inverse");
let arr3 = |weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize| -> f64 {
(0..n)
.map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]])
.sum()
};
let arr4 =
|weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize, u: usize| -> f64 {
(0..n)
.map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]] * x[[i, u]])
.sum()
};
let w_k3 = |i: usize| kappas[i].k3;
let w_k21 = |i: usize| kappas[i].k2_1;
let w_k4 = |i: usize| kappas[i].k4;
let w_k31 = |i: usize| kappas[i].k3_1;
let w_k211 = |i: usize| kappas[i].k2_11;
// λ₄ = Σ κ^{rs} κ^{tu} (κ_rstu/4 − κ_rst^(u) + κ_rt^(su)).
let mut lambda4 = 0.0;
for r in 0..k {
for s in 0..k {
for t in 0..k {
for u in 0..k {
let braces = arr4(&w_k4, r, s, t, u) / 4.0 - arr4(&w_k31, r, s, t, u)
+ arr4(&w_k211, r, t, s, u);
lambda4 += kappa_up[[r, s]] * kappa_up[[t, u]] * braces;
}
}
}
}
// λ₆ = Σ κ^{rs} κ^{tu} κ^{vw} {…}, with the four products spelled out below.
let mut lambda6 = 0.0;
for r in 0..k {
for s in 0..k {
for t in 0..k {
for u in 0..k {
for v in 0..k {
for w in 0..k {
let braces = arr3(&w_k3, r, t, v)
* (arr3(&w_k3, s, u, w) / 6.0 - arr3(&w_k21, s, w, u))
+ arr3(&w_k3, r, t, u)
* (arr3(&w_k3, s, v, w) / 4.0 - arr3(&w_k21, s, w, v))
+ arr3(&w_k21, r, t, v) * arr3(&w_k21, s, w, u)
+ arr3(&w_k21, r, t, u) * arr3(&w_k21, s, w, v);
lambda6 +=
kappa_up[[r, s]] * kappa_up[[t, u]] * kappa_up[[v, w]] * braces;
}
}
}
}
}
}
lambda4 - lambda6
}
/// n × 1 all-ones design (intercept-only model).
fn intercept_design(n: usize) -> Array2<f64> {
Array2::<f64>::ones((n, 1))
}
/// ψ(n) for a positive integer n, via ψ(n) = −γ + Σ_{j=1}^{n−1} 1/j.
fn digamma_integer(n: usize) -> f64 {
const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
-EULER_GAMMA + (1..n).map(|j| 1.0 / j as f64).sum::<f64>()
}
// Exponential (gamma_log, φ = 1) intercept: ε = 1/(6n). The exact
// E[W] = 2n(ln n − ψ(n)) must differ from 1 + ε by O(n⁻²), shrinking with n.
#[test]
fn exponential_intercept_matches_exact_digamma_expansion() {
let eta = 0.4;
let mut residual_prev = f64::INFINITY;
for &n in &[8usize, 16, 32] {
let jets = RowExpectedJets::gamma_log(eta, 1.0);
let kappas = vec![jets.kappas().expect("exponential kappas"); n];
let x = intercept_design(n);
let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
let analytic = 1.0 / (6.0 * n as f64);
assert!(
(eps - analytic).abs() < 1e-12,
"n={n}: ε={eps} vs analytic 1/(6n)={analytic}"
);
let exact_mean = 2.0 * n as f64 * ((n as f64).ln() - digamma_integer(n));
let residual = (exact_mean - 1.0 - eps).abs();
assert!(
residual < 0.6 / (n * n) as f64,
"n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
);
assert!(
residual < residual_prev,
"n={n}: residual {residual} did not shrink from {residual_prev}"
);
residual_prev = residual;
}
}
// Poisson intercept: ε = 1/(6nλ). The exact E[W] is obtained by summing W over
// the Poisson(nλ) distribution of the total count s.
#[test]
fn poisson_intercept_matches_exact_pmf_mean() {
let lambda: f64 = 1.7;
for &n in &[20usize, 40] {
let jets = RowExpectedJets::poisson_log(lambda.ln());
let kappas = vec![jets.kappas().expect("poisson kappas"); n];
let x = intercept_design(n);
let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
let analytic = 1.0 / (6.0 * n as f64 * lambda);
assert!(
(eps - analytic).abs() < 1e-12,
"n={n}: ε={eps} vs analytic 1/(6nλ)={analytic}"
);
let total_rate = n as f64 * lambda;
let mut pmf = (-total_rate).exp();
let mut exact_mean = 0.0;
// Truncate the pmf sum 60 standard deviations above the mean.
let s_max = (total_rate + 60.0 * total_rate.sqrt()).ceil() as usize;
for s in 0..=s_max {
if s > 0 {
pmf *= total_rate / s as f64;
}
let s_f = s as f64;
// W(s) = 2(nλ − s + s ln(s / nλ)); the s ln s term is taken as 0 at s = 0.
let w = if s == 0 {
2.0 * total_rate
} else {
2.0 * (total_rate - s_f + s_f * (s_f / total_rate).ln())
};
exact_mean += pmf * w;
}
let residual = (exact_mean - 1.0 - eps).abs();
assert!(
residual < 0.7 / (n * n) as f64,
"n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
);
}
}
// Gaussian identity with known variance: all higher cumulants vanish, so Δε = 0
// and the Bartlett factor is exactly 1, even with a ridge penalty.
#[test]
fn gaussian_known_variance_lr_factor_is_exactly_one() {
let n = 20;
let k = 3;
let mut x = Array2::<f64>::zeros((n, k));
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = (5.0 * z).sin();
x[[i, 2]] = z - 0.5;
}
let kappas = vec![
RowExpectedJets::gaussian_identity(1.7)
.kappas()
.expect("gaussian kappas");
n
];
let s_pen = Array2::<f64>::eye(k) * 0.4;
for q in [1usize, 2] {
let shift = lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), k - q..k)
.expect("shift");
assert!(
shift.abs() < 1e-13,
"Gaussian known-variance Δε must be 0; got {shift}"
);
let c = lawley_lr_bartlett_factor(
x.view(),
&kappas,
Some(s_pen.view()),
k - q..k,
q as f64,
)
.expect("factor");
assert!(
(c - 1.0).abs() < 1e-13,
"Gaussian known-variance Bartlett factor must be exactly 1; got {c}"
);
}
}
// Exponential rate, single parameter: the Bartlett factor is 1 + 1/(6n), and it
// matches the exact E[W] to O(n⁻²).
#[test]
fn exponential_rate_lr_factor_is_one_plus_one_sixth_n() {
let eta = -0.7; for &n in &[8usize, 16, 32] {
let jets = RowExpectedJets::gamma_log(eta, 1.0);
let kappas = vec![jets.kappas().expect("exponential kappas"); n];
let x = intercept_design(n);
let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");
let analytic = 1.0 + 1.0 / (6.0 * n as f64);
assert!(
(c - analytic).abs() < 1e-12,
"n={n}: factor {c} vs analytic 1 + 1/(6n) = {analytic}"
);
let exact_mean = 2.0 * n as f64 * ((n as f64).ln() - digamma_integer(n));
assert!(
(exact_mean - c).abs() < 0.6 / (n * n) as f64,
"n={n}: |E[W] − c| = {} is not O(n⁻²)",
(exact_mean - c).abs()
);
}
}
// Bernoulli logit intercept: Δε = (1 − u)/(6nu) with u = μ(1 − μ), and the
// factor is 1 + Δε. The exact E[W] comes from the Binomial(n, μ) pmf.
#[test]
fn bernoulli_logit_intercept_factor_matches_exact_pmf_mean() {
let mu: f64 = 0.3;
let u = mu * (1.0 - mu);
let eta = (mu / (1.0 - mu)).ln();
let mut residual_prev = f64::INFINITY;
for &n in &[24usize, 48, 96] {
let jets = RowExpectedJets::binomial_logit(eta);
let kappas = vec![jets.kappas().expect("bernoulli kappas"); n];
let x = intercept_design(n);
let shift = lawley_lr_mean_shift(x.view(), &kappas, None, 0..1).expect("Δε");
let analytic = (1.0 - u) / (6.0 * n as f64 * u);
assert!(
(shift - analytic).abs() < 1e-12,
"n={n}: Δε = {shift} vs analytic (1−u)/(6nu) = {analytic}"
);
let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");
assert!(
(c - (1.0 + analytic)).abs() < 1e-12,
"n={n}: factor {c} vs 1 + ε = {}",
1.0 + analytic
);
let nf = n as f64;
// P(S = s) by the ratio recursion P(s)/P(s−1) = μ/(1−μ) · (n−s+1)/s.
let mut pmf = (1.0 - mu).powi(n as i32); let mut exact_mean = 0.0;
for s in 0..=n {
if s > 0 {
pmf *= mu / (1.0 - mu) * (n - s + 1) as f64 / s as f64;
}
let s_f = s as f64;
// W(s) = 2[s ln(s/nμ) + (n−s) ln((n−s)/n(1−μ))], using 0·ln 0 = 0.
let t1 = if s == 0 {
0.0
} else {
s_f * (s_f / (nf * mu)).ln()
};
let t2 = if s == n {
0.0
} else {
(nf - s_f) * ((nf - s_f) / (nf * (1.0 - mu))).ln()
};
exact_mean += pmf * 2.0 * (t1 + t2);
}
let residual = (exact_mean - 1.0 - shift).abs();
assert!(
residual < 2.5 / (n * n) as f64,
"n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
);
assert!(
residual < residual_prev,
"n={n}: residual {residual} did not shrink from {residual_prev}"
);
residual_prev = residual;
}
}
// Checks two properties of the mean shift:
// 1. Δε equals ε_full − ε_null computed by hand.
// 2. Giving every row weight 2 matches duplicating every row.
#[test]
fn mean_shift_is_full_minus_nuisance_epsilon() {
let n = 19;
let mut x = Array2::<f64>::zeros((n, 2));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = z - 0.5;
let eta = 0.3 - 0.8 * (z - 0.5);
kappas.push(
RowExpectedJets::binomial_logit(eta)
.kappas()
.expect("binomial kappas"),
);
}
let mut s_pen = Array2::<f64>::zeros((2, 2));
s_pen[[1, 1]] = 0.6;
let shift =
lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), 1..2).expect("shift");
let eps_full = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε_full");
let x_null = x.slice(ndarray::s![.., 0..1]).to_owned();
let s_null = s_pen.slice(ndarray::s![0..1, 0..1]).to_owned();
let eps_null = lawley_epsilon(x_null.view(), &kappas, Some(s_null.view())).expect("ε_null");
assert!(
(shift - (eps_full - eps_null)).abs() < 1e-14,
"Δε = {shift} must equal ε_full − ε_null = {}",
eps_full - eps_null
);
let kappas_w: Vec<RowKappas> = kappas.iter().map(|r| r.weighted(2.0)).collect();
let mut x2 = Array2::<f64>::zeros((2 * n, 2));
let mut kappas2 = Vec::with_capacity(2 * n);
for i in 0..n {
for rep in 0..2 {
let row = 2 * i + rep;
x2[[row, 0]] = x[[i, 0]];
x2[[row, 1]] = x[[i, 1]];
kappas2.push(kappas[i]);
}
}
let shift_w = lawley_lr_mean_shift(x.view(), &kappas_w, Some(s_pen.view()), 1..2)
.expect("weighted shift");
let shift_dup = lawley_lr_mean_shift(x2.view(), &kappas2, Some(s_pen.view()), 1..2)
.expect("duplicated shift");
assert!(
(shift_w - shift_dup).abs() < 1e-12 * (1.0 + shift_dup.abs()),
"weight-2 rows ({shift_w}) must equal duplicated rows ({shift_dup})"
);
}
// The row-pair form in `lawley_epsilon` agrees with the index-notation oracle,
// both without and with a penalty, and the penalty actually changes ε.
#[test]
fn row_pair_reduction_matches_index_oracle() {
let n = 17;
let k = 3;
let mut x = Array2::<f64>::zeros((n, k));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = (7.3 * z).sin();
x[[i, 2]] = z * z - 0.4;
let eta = 0.2 + 0.5 * x[[i, 1]] - 0.3 * x[[i, 2]];
kappas.push(
RowExpectedJets::gamma_log(eta, 1.3)
.kappas()
.expect("gamma kappas"),
);
}
let fast = lawley_epsilon(x.view(), &kappas, None).expect("hat form");
let oracle = lawley_epsilon_index_oracle(&x, &kappas, None);
assert!(
(fast - oracle).abs() < 1e-10 * (1.0 + oracle.abs()),
"row-pair ε={fast} vs index-form ε={oracle}"
);
// Ridge penalty on the non-intercept columns only.
let mut s_pen = Array2::<f64>::eye(k);
s_pen[[0, 0]] = 0.0; s_pen *= 0.8;
let fast_pen = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("hat form");
let oracle_pen = lawley_epsilon_index_oracle(&x, &kappas, Some(&s_pen));
assert!(
(fast_pen - oracle_pen).abs() < 1e-10 * (1.0 + oracle_pen.abs()),
"penalized row-pair ε={fast_pen} vs index-form ε={oracle_pen}"
);
assert!(
(fast_pen - fast).abs() > 1e-6,
"penalty must move ε (got {fast} → {fast_pen})"
);
}
// For canonical links (Poisson/log, Binomial/logit): κ₃ = κ₂' and κ₄ = κ₃'.
#[test]
fn canonical_links_collapse_the_mixed_arrays() {
for eta in [-1.3, 0.0, 0.7] {
for jets in [
RowExpectedJets::poisson_log(eta),
RowExpectedJets::binomial_logit(eta),
] {
let kappas = jets.kappas().expect("canonical kappas");
assert!(
(kappas.k3 - kappas.k2_1).abs() < 1e-13 * (1.0 + kappas.k3.abs()),
"canonical link must satisfy κ₃ = κ₂' (η={eta}): {kappas:?}"
);
assert!(
(kappas.k4 - kappas.k3_1).abs() < 1e-13 * (1.0 + kappas.k4.abs()),
"canonical link must satisfy κ₄ = κ₃' (η={eta}): {kappas:?}"
);
}
}
}
// Gaussian identity: ε = 0 for a full penalized design as well.
#[test]
fn gaussian_identity_needs_no_correction_even_penalized() {
let n = 12;
let jets = RowExpectedJets::gaussian_identity(2.3);
let kappas = vec![jets.kappas().expect("gaussian kappas"); n];
let mut x = Array2::<f64>::ones((n, 2));
for i in 0..n {
x[[i, 1]] = i as f64 - 5.0;
}
let s_pen = Array2::<f64>::eye(2) * 0.5;
let eps = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε");
assert!(
eps.abs() < 1e-14,
"Gaussian-identity ε must be 0; got {eps}"
);
}
// Unpenalized ε is invariant under X → XT for an invertible T.
#[test]
fn epsilon_is_invariant_under_linear_reparametrization() {
let n = 15;
let k = 3;
let mut x = Array2::<f64>::zeros((n, k));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = (3.1 * z).cos();
x[[i, 2]] = z - 0.5;
let eta = -0.1 + 0.6 * x[[i, 1]] + 0.4 * x[[i, 2]];
kappas.push(
RowExpectedJets::binomial_logit(eta)
.kappas()
.expect("binomial kappas"),
);
}
let t_mat = ndarray::arr2(&[[1.0, 0.3, -0.2], [0.0, 1.4, 0.5], [0.0, 0.0, 0.8]]);
let xt = x.dot(&t_mat);
let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
let eps_t = lawley_epsilon(xt.view(), &kappas, None).expect("ε reparam");
assert!(
(eps - eps_t).abs() < 1e-9 * (1.0 + eps.abs()),
"ε not reparametrization-invariant: {eps} vs {eps_t}"
);
}
}