use ndarray::{Array1, Array2, ArrayView2};
use crate::linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
use crate::linalg::matrix::FactorizedSystem;
use faer::Side;
/// Multiplies two second-order jets `[f, f′, f″]` with the Leibniz rule.
///
/// Returns `[a·b, a·b′ + a′·b, a·b″ + 2·a′·b′ + a″·b]`.
#[inline]
fn jet_mul(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [a0, a1, a2] = a;
    let [b0, b1, b2] = b;
    let value = a0 * b0;
    let first = a0 * b1 + a1 * b0;
    let second = a0 * b2 + 2.0 * a1 * b1 + a2 * b0;
    [value, first, second]
}
/// Divides two second-order jets `[f, f′, f″]`, returning the jet of `a / b`.
///
/// Each order is obtained by solving the product rule of `q · b = a` for the
/// next entry of `q`; no check is made that `b[0]` is non-zero.
#[inline]
fn jet_div(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [num0, num1, num2] = a;
    let [den0, den1, den2] = b;
    let q0 = num0 / den0;
    let q1 = (num1 - q0 * den1) / den0;
    let q2 = (num2 - q0 * den2 - 2.0 * q1 * den1) / den0;
    [q0, q1, q2]
}
/// Per-row inputs from which `RowExpectedJets::kappas` derives the row cumulants.
///
/// `kappas` treats `[mu1, mu2, mu3]` as a value-plus-two-derivatives jet in the
/// linear predictor η and chains `dvar_dmu` / `d2var_dmu2` through `mu1` and `mu2`
/// to obtain the matching jet of the variance function.
#[derive(Debug, Clone, Copy)]
pub struct RowExpectedJets {
/// First entry of the mean jet (the constructors in this file store dμ/dη here,
/// e.g. `μ(1 − μ)` for the logit link).
pub mu1: f64,
/// Second entry of the mean jet (η-derivative of `mu1`).
pub mu2: f64,
/// Third entry of the mean jet (η-derivative of `mu2`).
pub mu3: f64,
/// Variance function value; `kappas` requires it to be finite and positive.
pub var: f64,
/// First derivative of the variance function with respect to μ.
pub dvar_dmu: f64,
/// Second derivative of the variance function with respect to μ.
pub d2var_dmu2: f64,
/// Dispersion; `kappas` requires it to be finite and positive and divides by it.
pub dispersion: f64,
}
/// Per-row cumulant coefficients consumed by the ε computation in this file.
#[derive(Debug, Clone, Copy)]
pub struct RowKappas {
    /// Second-order coefficient; its negation is used as the Fisher weight.
    pub k2: f64,
    /// Third-order coefficient.
    pub k3: f64,
    /// Fourth-order coefficient.
    pub k4: f64,
    /// First η-derivative entry associated with `k2`.
    pub k2_1: f64,
    /// Second η-derivative entry associated with `k2`.
    pub k2_11: f64,
    /// First η-derivative entry associated with `k3`.
    pub k3_1: f64,
}

impl RowKappas {
    /// Returns a copy in which every coefficient has been multiplied by `w`.
    pub fn weighted(self, w: f64) -> Self {
        let RowKappas {
            k2,
            k3,
            k4,
            k2_1,
            k2_11,
            k3_1,
        } = self;
        let scale = |value: f64| value * w;
        Self {
            k2: scale(k2),
            k3: scale(k3),
            k4: scale(k4),
            k2_1: scale(k2_1),
            k2_11: scale(k2_11),
            k3_1: scale(k3_1),
        }
    }
}
impl RowExpectedJets {
    /// Converts the stored jets into the row cumulant coefficients.
    ///
    /// With `c = mu1 / V` and `u = mu1 · c` (both carried as second-order jets in η),
    /// every returned coefficient is a negated combination of jet entries divided by
    /// the dispersion.
    ///
    /// # Errors
    /// Returns an error when the dispersion or the variance function is not finite
    /// and strictly positive.
    pub fn kappas(&self) -> Result<RowKappas, String> {
        let Self {
            mu1,
            mu2,
            mu3,
            var,
            dvar_dmu,
            d2var_dmu2,
            dispersion: phi,
        } = *self;
        if !phi.is_finite() || phi <= 0.0 {
            return Err(format!(
                "RowExpectedJets::kappas: dispersion must be finite and positive; got {phi}"
            ));
        }
        if !var.is_finite() || var <= 0.0 {
            return Err(format!(
                "RowExpectedJets::kappas: variance function must be finite and positive; got {}",
                var
            ));
        }

        // Jet of the mean derivative and, via the chain rule, of V(μ(η)).
        let slope_jet = [mu1, mu2, mu3];
        let variance_jet = [
            var,
            dvar_dmu * mu1,
            d2var_dmu2 * mu1 * mu1 + dvar_dmu * mu2,
        ];
        let ratio = jet_div(slope_jet, variance_jet);
        let [_, c1, c2] = ratio;
        let [u, u1, u2] = jet_mul(slope_jet, ratio);

        let inv_phi = 1.0 / phi;
        // Every cumulant is the negated term scaled by 1/φ.
        let scaled = |term: f64| -term * inv_phi;
        Ok(RowKappas {
            k2: scaled(u),
            k2_1: scaled(u1),
            k2_11: scaled(u2),
            k3: scaled(u1 + mu1 * c1),
            k3_1: scaled(u2 + mu2 * c1 + mu1 * c2),
            k4: scaled(u2 + mu2 * c1 + 2.0 * mu1 * c2),
        })
    }

    /// Jets for a Gaussian response with identity link: unit slope, unit variance
    /// function, all higher derivatives zero, and the supplied dispersion.
    pub fn gaussian_identity(dispersion: f64) -> Self {
        Self {
            dispersion,
            mu1: 1.0,
            var: 1.0,
            mu2: 0.0,
            mu3: 0.0,
            dvar_dmu: 0.0,
            d2var_dmu2: 0.0,
        }
    }

    /// Jets for a Poisson response with log link at linear predictor `eta`
    /// (`μ = exp(η)`, variance function `μ`, dispersion fixed at 1).
    pub fn poisson_log(eta: f64) -> Self {
        let rate = eta.exp();
        Self {
            mu1: rate,
            mu2: rate,
            mu3: rate,
            var: rate,
            dvar_dmu: 1.0,
            d2var_dmu2: 0.0,
            dispersion: 1.0,
        }
    }

    /// Jets for a binomial response with logit link at linear predictor `eta`
    /// (`μ = 1 / (1 + exp(−η))`, variance function `μ(1 − μ)`, dispersion fixed at 1).
    pub fn binomial_logit(eta: f64) -> Self {
        let prob = 1.0 / (1.0 + (-eta).exp());
        // d/dμ of μ(1 − μ); also the factor linking successive η-derivatives.
        let skew = 1.0 - 2.0 * prob;
        let mu1 = prob * (1.0 - prob);
        let mu2 = mu1 * skew;
        let mu3 = mu2 * skew - 2.0 * mu1 * mu1;
        Self {
            mu1,
            mu2,
            mu3,
            var: mu1,
            dvar_dmu: skew,
            d2var_dmu2: -2.0,
            dispersion: 1.0,
        }
    }

    /// Jets for a gamma response with log link at linear predictor `eta`
    /// (`μ = exp(η)`, variance function `μ²`) and the supplied dispersion.
    pub fn gamma_log(eta: f64, dispersion: f64) -> Self {
        let mean = eta.exp();
        Self {
            mu1: mean,
            mu2: mean,
            mu3: mean,
            var: mean * mean,
            dvar_dmu: 2.0 * mean,
            d2var_dmu2: 2.0,
            dispersion,
        }
    }
}
/// Computes ε = λ₄ − λ₆ for the design `x`, per-row cumulants `kappas` and an
/// optional penalty matrix.
///
/// The information matrix is `Σᵢ (−k2ᵢ) xᵢ xᵢᵀ` plus the penalty when given. Its
/// inverse is formed by solving against the identity, and both λ terms are then
/// evaluated from the n×n matrix `X J⁻¹ Xᵀ` and its diagonal.
///
/// # Errors
/// Returns an error for an empty design, a `kappas` length different from the row
/// count, a non-finite Fisher weight, a penalty that is not k×k, a failed
/// factorization or solve, or a non-finite result.
pub fn lawley_epsilon(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
) -> Result<f64, String> {
    let (n, k) = (x.nrows(), x.ncols());
    if n == 0 || k == 0 {
        return Err(format!(
            "lawley_epsilon: empty design ({n} rows, {k} columns)"
        ));
    }
    if kappas.len() != n {
        return Err(format!(
            "lawley_epsilon: {} cumulant rows for {n} design rows",
            kappas.len()
        ));
    }

    // Accumulate the weighted cross-product row by row.
    let mut information = Array2::<f64>::zeros((k, k));
    for (i, (design_row, row_kappas)) in x.outer_iter().zip(kappas.iter()).enumerate() {
        let weight = -row_kappas.k2;
        if !weight.is_finite() {
            return Err(format!(
                "lawley_epsilon: non-finite Fisher weight at row {i}"
            ));
        }
        for (r, &x_r) in design_row.iter().enumerate() {
            let weighted_x_r = x_r * weight;
            for (s, &x_s) in design_row.iter().enumerate() {
                information[[r, s]] += weighted_x_r * x_s;
            }
        }
    }

    if let Some(s_pen) = penalty {
        let (pen_rows, pen_cols) = (s_pen.nrows(), s_pen.ncols());
        if pen_rows != k || pen_cols != k {
            return Err(format!(
                "lawley_epsilon: penalty is {}×{}, expected {k}×{k}",
                pen_rows, pen_cols
            ));
        }
        information += &s_pen;
    }

    // Invert the information matrix by solving against the identity.
    let information_view = FaerArrayView::new(&information);
    let factor = factorize_symmetricwith_fallback(information_view.as_ref(), Side::Lower)
        .map_err(|e| format!("lawley_epsilon: information factorization failed: {e:?}"))?;
    let identity = Array2::<f64>::eye(k);
    let information_inverse = FactorizedSystem::solvemulti(&factor, &identity)?;

    // n×n row-pair matrix and its diagonal.
    let pair_matrix = x.dot(&information_inverse).dot(&x.t());
    let leverage = pair_matrix.diag().to_owned();

    let lambda4 = kappas
        .iter()
        .zip(leverage.iter())
        .fold(0.0_f64, |acc, (row_kappas, &h_i)| {
            let a_i = row_kappas.k4 / 4.0 - row_kappas.k3_1 + row_kappas.k2_11;
            acc + a_i * h_i * h_i
        });

    let third: Array1<f64> = kappas.iter().map(|row| row.k3).collect();
    let slope: Array1<f64> = kappas.iter().map(|row| row.k2_1).collect();
    let mut lambda6 = 0.0;
    for i in 0..n {
        let (third_i, slope_i, h_i) = (third[i], slope[i], leverage[i]);
        for j in 0..n {
            let e_ij = pair_matrix[[i, j]];
            let cross = third_i * third[j];
            let mixed = -third_i * slope[j] + slope_i * slope[j];
            let cubic_part = e_ij * e_ij * e_ij * (cross / 6.0 + mixed);
            let leverage_part = h_i * leverage[j] * e_ij * (cross / 4.0 + mixed);
            lambda6 -= cubic_part + leverage_part;
        }
    }

    let epsilon = lambda4 - lambda6;
    if epsilon.is_finite() {
        Ok(epsilon)
    } else {
        Err(format!(
            "lawley_epsilon: non-finite ε (λ₄={lambda4}, λ₆={lambda6})"
        ))
    }
}
/// Row-count limit associated with the dense n×n row-pair matrix that
/// `lawley_epsilon` forms.
///
/// NOTE(review): nothing in the visible part of this file reads this constant;
/// presumably callers use it to decide whether the n×n matrix is affordable
/// before calling in — confirm against the call sites.
pub const LAWLEY_PAIR_MATRIX_MAX_ROWS: usize = 2048;
/// Difference between ε of the full model and ε of the null model obtained by
/// dropping the `tested` columns (and the matching penalty rows/columns).
///
/// When every column is tested there is no null model and ε of the full model is
/// returned unchanged.
///
/// # Errors
/// Returns an error when `tested` is empty or extends past the column count, or
/// when either ε evaluation fails.
pub fn lawley_lr_mean_shift(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
    tested: std::ops::Range<usize>,
) -> Result<f64, String> {
    let (n, k) = (x.nrows(), x.ncols());
    if tested.is_empty() || tested.end > k {
        return Err(format!(
            "lawley_lr_mean_shift: tested block {}..{} out of range for {k} columns",
            tested.start, tested.end
        ));
    }

    let eps_full = lawley_epsilon(x, kappas, penalty)?;

    // Columns that survive in the null model: everything outside `tested`.
    let kept: Vec<usize> = (0..tested.start).chain(tested.end..k).collect();
    if kept.is_empty() {
        return Ok(eps_full);
    }

    let m = kept.len();
    let x_null = Array2::from_shape_fn((n, m), |(row, col)| x[[row, kept[col]]]);
    let penalty_null = penalty
        .map(|s_pen| Array2::from_shape_fn((m, m), |(r, c)| s_pen[[kept[r], kept[c]]]));

    let eps_null = lawley_epsilon(
        x_null.view(),
        kappas,
        penalty_null.as_ref().map(|reduced| reduced.view()),
    )?;
    Ok(eps_full - eps_null)
}
/// Bartlett factor for the likelihood-ratio test of the `tested` columns.
///
/// The mean passed to `bartlett_factor_from_mean` is `ref_df` plus the shift from
/// [`lawley_lr_mean_shift`].
///
/// # Errors
/// Returns an error when `ref_df` is not finite and positive, when the shift cannot
/// be computed, or when the resulting factor is missing, non-finite or non-positive.
pub fn lawley_lr_bartlett_factor(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: Option<ArrayView2<'_, f64>>,
    tested: std::ops::Range<usize>,
    ref_df: f64,
) -> Result<f64, String> {
    if !ref_df.is_finite() || ref_df <= 0.0 {
        return Err(format!(
            "lawley_lr_bartlett_factor: reference df must be finite and positive; got {ref_df}"
        ));
    }

    let shift = lawley_lr_mean_shift(x, kappas, penalty, tested)?;
    let mean_w = ref_df + shift;

    match crate::inference::higher_order::bartlett_factor_from_mean(mean_w, ref_df) {
        None => Err(format!(
            "lawley_lr_bartlett_factor: degenerate mean {mean_w} (Δε = {shift}, d = {ref_df})"
        )),
        Some(factor) if factor.is_finite() && factor > 0.0 => Ok(factor),
        Some(factor) => Err(format!(
            "lawley_lr_bartlett_factor: degenerate factor {factor} (Δε = {shift}, d = {ref_df})"
        )),
    }
}
/// One penalty component whose contribution is rescaled when a log smoothing
/// parameter is perturbed in `lawley_lr_mean_shift_with_rho_variation`.
#[derive(Debug, Clone)]
pub struct RhoPenaltyComponent {
/// k×k component matrix; the perturbation adds `(exp(t) − 1)` times this matrix
/// to the total penalty.
pub s_component: Array2<f64>,
}
/// Finite-difference step used for the central second differences in
/// `lawley_lr_mean_shift_with_rho_variation`.
const RHO_VARIATION_STEP: f64 = 0.05;
/// Mean shift of the likelihood-ratio statistic with a second-order allowance for
/// variation in the smoothing parameters.
///
/// The result is the conditional shift at `penalty` plus `½ Σ_bc H_bc · rho_cov[b, c]`,
/// where `H` is the Hessian of the shift with respect to per-component perturbations
/// `t_b` that add `(exp(t_b) − 1) · S_b` to the penalty. `H` is estimated with central
/// finite differences of step [`RHO_VARIATION_STEP`]; each off-diagonal pair is visited
/// once, so its term carries no factor ½.
///
/// # Errors
/// Returns an error when `components` is empty, when `rho_cov` is not m×m, contains a
/// non-finite entry or is not symmetric, when `penalty` or a component is not k×k,
/// when any shift evaluation fails, or when a curvature or the total is non-finite.
pub fn lawley_lr_mean_shift_with_rho_variation(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: ArrayView2<'_, f64>,
    tested: std::ops::Range<usize>,
    components: &[RhoPenaltyComponent],
    rho_cov: ArrayView2<'_, f64>,
) -> Result<f64, String> {
    let k = x.ncols();
    let m = components.len();
    if components.is_empty() {
        return Err(
            "lawley_lr_mean_shift_with_rho_variation: no smoothing-parameter components"
                .to_string(),
        );
    }
    if (rho_cov.nrows(), rho_cov.ncols()) != (m, m) {
        return Err(format!(
            "lawley_lr_mean_shift_with_rho_variation: rho_cov is {}×{}, expected {m}×{m}",
            rho_cov.nrows(),
            rho_cov.ncols()
        ));
    }

    // Every entry must be finite and agree with its transpose within a relative tolerance.
    for ((b, c), &v_bc) in rho_cov.indexed_iter() {
        if !v_bc.is_finite() {
            return Err(format!(
                "lawley_lr_mean_shift_with_rho_variation: rho_cov[{b},{c}] is not finite"
            ));
        }
        let v_cb = rho_cov[[c, b]];
        let tol = 1e-10 * (1.0 + v_bc.abs().max(v_cb.abs()));
        if (v_bc - v_cb).abs() > tol {
            return Err(format!(
                "lawley_lr_mean_shift_with_rho_variation: rho_cov must be symmetric; \
                 entries [{b},{c}]={v_bc} and [{c},{b}]={v_cb} differ"
            ));
        }
    }

    if (penalty.nrows(), penalty.ncols()) != (k, k) {
        return Err(format!(
            "lawley_lr_mean_shift_with_rho_variation: penalty is {}×{}, expected {k}×{k}",
            penalty.nrows(),
            penalty.ncols()
        ));
    }
    let misshapen = components.iter().enumerate().find(|(_, comp)| {
        comp.s_component.nrows() != k || comp.s_component.ncols() != k
    });
    if let Some((b, comp)) = misshapen {
        return Err(format!(
            "lawley_lr_mean_shift_with_rho_variation: component {b} is {}×{}, expected {k}×{k}",
            comp.s_component.nrows(),
            comp.s_component.ncols()
        ));
    }

    let conditional = lawley_lr_mean_shift(x, kappas, Some(penalty), tested.clone())?;

    // Shift evaluated after perturbing the listed components by the given log-steps.
    let shift_at = |steps: &[(usize, f64)]| -> Result<f64, String> {
        let perturbed = steps
            .iter()
            .fold(penalty.to_owned(), |mut acc, &(b, t)| {
                acc.scaled_add(t.exp() - 1.0, &components[b].s_component);
                acc
            });
        lawley_lr_mean_shift(x, kappas, Some(perturbed.view()), tested.clone())
    };

    let step = RHO_VARIATION_STEP;
    let mut quad = 0.0;
    for b in 0..m {
        let forward = shift_at(&[(b, step)])?;
        let backward = shift_at(&[(b, -step)])?;
        let hbb = (forward - 2.0 * conditional + backward) / (step * step);
        if !hbb.is_finite() {
            return Err(format!(
                "lawley_lr_mean_shift_with_rho_variation: non-finite curvature H[{b},{b}]"
            ));
        }
        quad += 0.5 * hbb * rho_cov[[b, b]];

        for c in (b + 1)..m {
            let up_up = shift_at(&[(b, step), (c, step)])?;
            let up_down = shift_at(&[(b, step), (c, -step)])?;
            let down_up = shift_at(&[(b, -step), (c, step)])?;
            let down_down = shift_at(&[(b, -step), (c, -step)])?;
            let hbc = (up_up - up_down - down_up + down_down) / (4.0 * step * step);
            if !hbc.is_finite() {
                return Err(format!(
                    "lawley_lr_mean_shift_with_rho_variation: non-finite curvature H[{b},{c}]"
                ));
            }
            quad += hbc * rho_cov[[b, c]];
        }
    }

    let total = conditional + quad;
    if total.is_finite() {
        Ok(total)
    } else {
        Err(format!(
            "lawley_lr_mean_shift_with_rho_variation: non-finite total shift \
             (conditional={conditional}, rho-variation={quad})"
        ))
    }
}
/// Bartlett factor for the `tested` columns using the mean shift from
/// [`lawley_lr_mean_shift_with_rho_variation`].
///
/// # Errors
/// Returns an error when `ref_df` is not finite and positive, when the shift cannot
/// be computed, or when the resulting factor is missing, non-finite or non-positive.
pub fn lawley_lr_bartlett_factor_with_rho_variation(
    x: ArrayView2<'_, f64>,
    kappas: &[RowKappas],
    penalty: ArrayView2<'_, f64>,
    tested: std::ops::Range<usize>,
    components: &[RhoPenaltyComponent],
    rho_cov: ArrayView2<'_, f64>,
    ref_df: f64,
) -> Result<f64, String> {
    if !ref_df.is_finite() || ref_df <= 0.0 {
        return Err(format!(
            "lawley_lr_bartlett_factor_with_rho_variation: reference df must be finite and positive; got {ref_df}"
        ));
    }

    let shift =
        lawley_lr_mean_shift_with_rho_variation(x, kappas, penalty, tested, components, rho_cov)?;
    let mean_w = ref_df + shift;

    match crate::inference::higher_order::bartlett_factor_from_mean(mean_w, ref_df) {
        None => Err(format!(
            "lawley_lr_bartlett_factor_with_rho_variation: degenerate mean {mean_w} \
             (Δε(ρ̂) = {shift}, d = {ref_df})"
        )),
        Some(factor) if factor.is_finite() && factor > 0.0 => Ok(factor),
        Some(factor) => Err(format!(
            "lawley_lr_bartlett_factor_with_rho_variation: degenerate factor {factor} \
             (Δε(ρ̂) = {shift}, d = {ref_df})"
        )),
    }
}
/// Convenience wrapper around `known_scale_expected_jets_with_dispersion` that
/// passes a dispersion of `1.0`.
///
/// Returns `None` for response/link pairs that function does not handle.
pub fn known_scale_expected_jets(
family: &crate::types::LikelihoodSpec,
eta: f64,
) -> Option<RowExpectedJets> {
known_scale_expected_jets_with_dispersion(family, eta, 1.0)
}
/// Builds the row jets for the response/link pairs supported here.
///
/// Poisson/log and binomial/logit ignore `dispersion` (their constructors fix it at
/// 1). Gaussian/identity and gamma/log use it and yield `None` unless it is finite
/// and strictly positive. Every other response/link pair yields `None`.
pub fn known_scale_expected_jets_with_dispersion(
    family: &crate::types::LikelihoodSpec,
    eta: f64,
    dispersion: f64,
) -> Option<RowExpectedJets> {
    use crate::types::{InverseLink, ResponseFamily, StandardLink};

    let usable_dispersion = dispersion.is_finite() && dispersion > 0.0;
    match (&family.response, &family.link) {
        (ResponseFamily::Poisson, InverseLink::Standard(StandardLink::Log)) => {
            Some(RowExpectedJets::poisson_log(eta))
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
            Some(RowExpectedJets::binomial_logit(eta))
        }
        (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity))
            if usable_dispersion =>
        {
            Some(RowExpectedJets::gaussian_identity(dispersion))
        }
        (ResponseFamily::Gamma, InverseLink::Standard(StandardLink::Log))
            if usable_dispersion =>
        {
            Some(RowExpectedJets::gamma_log(eta, dispersion))
        }
        // Unsupported pairs, and supported pairs with an unusable dispersion.
        _ => None,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
// Brute-force reference for `lawley_epsilon`: evaluates λ₄ and λ₆ by explicit
// contraction over parameter indices (O(k⁴) and O(k⁶) loops) instead of the
// row-pair form, and returns λ₄ − λ₆. Panics if the factorization or solve fails.
fn lawley_epsilon_index_oracle(
x: &Array2<f64>,
kappas: &[RowKappas],
penalty: Option<&Array2<f64>>,
) -> f64 {
let n = x.nrows();
let k = x.ncols();
// κ_rs = Σᵢ k2ᵢ x_ir x_is, with the penalty subtracted when present.
let mut kappa2 = Array2::<f64>::zeros((k, k));
for i in 0..n {
for r in 0..k {
for s in 0..k {
kappa2[[r, s]] += kappas[i].k2 * x[[i, r]] * x[[i, s]];
}
}
}
if let Some(s_pen) = penalty {
kappa2 -= s_pen;
}
// Inverse of κ_rs, obtained by solving against the identity.
let j_view = FaerArrayView::new(&kappa2);
let factor = factorize_symmetricwith_fallback(j_view.as_ref(), faer::Side::Lower)
.expect("oracle κ_rs factorization");
let kappa_up = FactorizedSystem::solvemulti(&factor, &Array2::<f64>::eye(k))
.expect("oracle κ_rs inverse");
// Weighted third-order sum Σᵢ wᵢ x_ir x_is x_it.
let arr3 = |weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize| -> f64 {
(0..n)
.map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]])
.sum()
};
// Weighted fourth-order sum Σᵢ wᵢ x_ir x_is x_it x_iu.
let arr4 =
|weights: &dyn Fn(usize) -> f64, r: usize, s: usize, t: usize, u: usize| -> f64 {
(0..n)
.map(|i| weights(i) * x[[i, r]] * x[[i, s]] * x[[i, t]] * x[[i, u]])
.sum()
};
// Per-row weight selectors for each cumulant field.
let w_k3 = |i: usize| kappas[i].k3;
let w_k21 = |i: usize| kappas[i].k2_1;
let w_k4 = |i: usize| kappas[i].k4;
let w_k31 = |i: usize| kappas[i].k3_1;
let w_k211 = |i: usize| kappas[i].k2_11;
// λ₄: contraction with two inverse factors.
let mut lambda4 = 0.0;
for r in 0..k {
for s in 0..k {
for t in 0..k {
for u in 0..k {
let braces = arr4(&w_k4, r, s, t, u) / 4.0 - arr4(&w_k31, r, s, t, u)
+ arr4(&w_k211, r, t, s, u);
lambda4 += kappa_up[[r, s]] * kappa_up[[t, u]] * braces;
}
}
}
}
// λ₆: contraction with three inverse factors.
let mut lambda6 = 0.0;
for r in 0..k {
for s in 0..k {
for t in 0..k {
for u in 0..k {
for v in 0..k {
for w in 0..k {
let braces = arr3(&w_k3, r, t, v)
* (arr3(&w_k3, s, u, w) / 6.0 - arr3(&w_k21, s, w, u))
+ arr3(&w_k3, r, t, u)
* (arr3(&w_k3, s, v, w) / 4.0 - arr3(&w_k21, s, w, v))
+ arr3(&w_k21, r, t, v) * arr3(&w_k21, s, w, u)
+ arr3(&w_k21, r, t, u) * arr3(&w_k21, s, w, v);
lambda6 +=
kappa_up[[r, s]] * kappa_up[[t, u]] * kappa_up[[v, w]] * braces;
}
}
}
}
}
}
lambda4 - lambda6
}
// Design matrix with `n` rows and a single column of ones.
fn intercept_design(n: usize) -> Array2<f64> {
    Array2::from_elem((n, 1), 1.0_f64)
}
// Digamma function at a positive integer: ψ(n) = −γ + Σ_{j=1}^{n−1} 1/j.
fn digamma_integer(n: usize) -> f64 {
    const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
    let mut harmonic = 0.0_f64;
    for j in 1..n {
        harmonic += 1.0 / j as f64;
    }
    -EULER_GAMMA + harmonic
}
// Intercept-only exponential model (gamma/log, unit dispersion): ε must equal
// 1/(6n), and the gap to the exact E[W] = 2n(ln n − ψ(n)) must be O(n⁻²) and
// shrink as n grows.
#[test]
fn exponential_intercept_matches_exact_digamma_expansion() {
    let eta = 0.4;
    // The row cumulants do not depend on n, so they are built once.
    let row = RowExpectedJets::gamma_log(eta, 1.0)
        .kappas()
        .expect("exponential kappas");
    let mut residual_prev = f64::INFINITY;
    for n in [8usize, 16, 32] {
        let nf = n as f64;
        let kappas = vec![row; n];
        let x = intercept_design(n);
        let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");

        let analytic = 1.0 / (6.0 * nf);
        assert!(
            (eps - analytic).abs() < 1e-12,
            "n={n}: ε={eps} vs analytic 1/(6n)={analytic}"
        );

        let exact_mean = 2.0 * nf * (nf.ln() - digamma_integer(n));
        let residual = (exact_mean - 1.0 - eps).abs();
        assert!(
            residual < 0.6 / (n * n) as f64,
            "n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
        );
        assert!(
            residual < residual_prev,
            "n={n}: residual {residual} did not shrink from {residual_prev}"
        );
        residual_prev = residual;
    }
}
// A penalty on the slope column must change ε relative to the unpenalized value,
// and different penalty strengths must give distinguishable ε values.
#[test]
fn penalty_shift_term_is_consumed() {
    let n = 40usize;
    let eta = 0.2_f64;
    let row = RowExpectedJets::poisson_log(eta)
        .kappas()
        .expect("poisson kappas");
    let kappas = vec![row; n];
    // Column 0 is the intercept, column 1 a centred ramp.
    let x = Array2::from_shape_fn((n, 2), |(i, col)| {
        if col == 0 {
            1.0
        } else {
            (i as f64) / (n as f64) - 0.5
        }
    });
    let eps_unpen = lawley_epsilon(x.view(), &kappas, None).expect("ε unpenalized");

    let mut distinct = std::collections::BTreeSet::new();
    for lambda in [0.5_f64, 2.0, 8.0, 32.0] {
        let mut s = Array2::<f64>::zeros((2, 2));
        s[[1, 1]] = lambda;
        let eps_pen = lawley_epsilon(x.view(), &kappas, Some(s.view())).expect("ε penalized");
        assert!(
            (eps_pen - eps_unpen).abs() > 1e-9,
            "λ={lambda}: penalty did not move ε ({eps_pen} vs {eps_unpen}) — S is being dropped"
        );
        assert!(
            eps_pen.is_finite(),
            "λ={lambda}: ε must be finite, got {eps_pen}"
        );
        // Bucket ε at 1e-9 resolution to count distinct values.
        distinct.insert((eps_pen * 1e9) as i64);
    }
    assert!(
        distinct.len() >= 3,
        "ε must vary with λ; got {} distinct values",
        distinct.len()
    );
}
// Intercept-only Poisson/log model: ε must equal 1/(6nλ), and 1 + ε must match
// the exact E[W] obtained by summing the LR statistic over the Poisson pmf of
// the total count (rate nλ) to within O(n⁻²).
#[test]
fn poisson_intercept_matches_exact_pmf_mean() {
let lambda: f64 = 1.7;
for &n in &[20usize, 40] {
let jets = RowExpectedJets::poisson_log(lambda.ln());
let kappas = vec![jets.kappas().expect("poisson kappas"); n];
let x = intercept_design(n);
let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
let analytic = 1.0 / (6.0 * n as f64 * lambda);
assert!(
(eps - analytic).abs() < 1e-12,
"n={n}: ε={eps} vs analytic 1/(6nλ)={analytic}"
);
let total_rate = n as f64 * lambda;
// pmf starts at P(S = 0) and is advanced by the ratio total_rate / s.
let mut pmf = (-total_rate).exp();
let mut exact_mean = 0.0;
// Truncate the sum 60 standard deviations above the mean.
let s_max = (total_rate + 60.0 * total_rate.sqrt()).ceil() as usize;
for s in 0..=s_max {
if s > 0 {
pmf *= total_rate / s as f64;
}
let s_f = s as f64;
// s = 0 is handled separately so that 0 · ln(0) is never evaluated.
let w = if s == 0 {
2.0 * total_rate
} else {
2.0 * (total_rate - s_f + s_f * (s_f / total_rate).ln())
};
exact_mean += pmf * w;
}
let residual = (exact_mean - 1.0 - eps).abs();
assert!(
residual < 0.7 / (n * n) as f64,
"n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
);
}
}
// Gaussian/identity rows: the mean shift must vanish and the Bartlett factor must
// be 1 for both a one- and a two-column tested block, even with a penalty.
#[test]
fn gaussian_known_variance_lr_factor_is_exactly_one() {
    let n = 20;
    let k = 3;
    let x = Array2::from_shape_fn((n, k), |(i, col)| {
        let z = i as f64 / n as f64;
        match col {
            0 => 1.0,
            1 => (5.0 * z).sin(),
            _ => z - 0.5,
        }
    });
    let row = RowExpectedJets::gaussian_identity(1.7)
        .kappas()
        .expect("gaussian kappas");
    let kappas = vec![row; n];
    let s_pen = Array2::<f64>::eye(k) * 0.4;

    for q in [1usize, 2] {
        // The tested block is the trailing q columns.
        let block = (k - q)..k;
        let shift = lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), block.clone())
            .expect("shift");
        assert!(
            shift.abs() < 1e-13,
            "Gaussian known-variance Δε must be 0; got {shift}"
        );
        let c = lawley_lr_bartlett_factor(
            x.view(),
            &kappas,
            Some(s_pen.view()),
            block,
            q as f64,
        )
        .expect("factor");
        assert!(
            (c - 1.0).abs() < 1e-13,
            "Gaussian known-variance Bartlett factor must be exactly 1; got {c}"
        );
    }
}
// Intercept-only exponential model: the Bartlett factor with one reference degree
// of freedom must equal 1 + 1/(6n) and agree with the exact E[W] to O(n⁻²).
#[test]
fn exponential_rate_lr_factor_is_one_plus_one_sixth_n() {
    let eta = -0.7;
    let row = RowExpectedJets::gamma_log(eta, 1.0)
        .kappas()
        .expect("exponential kappas");
    for n in [8usize, 16, 32] {
        let nf = n as f64;
        let kappas = vec![row; n];
        let x = intercept_design(n);
        let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");

        let analytic = 1.0 + 1.0 / (6.0 * nf);
        assert!(
            (c - analytic).abs() < 1e-12,
            "n={n}: factor {c} vs analytic 1 + 1/(6n) = {analytic}"
        );

        let exact_mean = 2.0 * nf * (nf.ln() - digamma_integer(n));
        assert!(
            (exact_mean - c).abs() < 0.6 / (n * n) as f64,
            "n={n}: |E[W] − c| = {} is not O(n⁻²)",
            (exact_mean - c).abs()
        );
    }
}
// Intercept-only Bernoulli/logit model: the mean shift must equal (1 − u)/(6nu)
// with u = μ(1 − μ), the Bartlett factor must be 1 + that shift, and the exact
// E[W] summed over the binomial pmf must agree to O(n⁻²) with a shrinking residual.
#[test]
fn bernoulli_logit_intercept_factor_matches_exact_pmf_mean() {
let mu: f64 = 0.3;
let u = mu * (1.0 - mu);
let eta = (mu / (1.0 - mu)).ln();
let mut residual_prev = f64::INFINITY;
for &n in &[24usize, 48, 96] {
let jets = RowExpectedJets::binomial_logit(eta);
let kappas = vec![jets.kappas().expect("bernoulli kappas"); n];
let x = intercept_design(n);
let shift = lawley_lr_mean_shift(x.view(), &kappas, None, 0..1).expect("Δε");
let analytic = (1.0 - u) / (6.0 * n as f64 * u);
assert!(
(shift - analytic).abs() < 1e-12,
"n={n}: Δε = {shift} vs analytic (1−u)/(6nu) = {analytic}"
);
let c = lawley_lr_bartlett_factor(x.view(), &kappas, None, 0..1, 1.0).expect("factor");
assert!(
(c - (1.0 + analytic)).abs() < 1e-12,
"n={n}: factor {c} vs 1 + ε = {}",
1.0 + analytic
);
let nf = n as f64;
// pmf starts at P(S = 0) = (1 − μ)ⁿ and is advanced by the binomial ratio.
let mut pmf = (1.0 - mu).powi(n as i32); let mut exact_mean = 0.0;
for s in 0..=n {
if s > 0 {
pmf *= mu / (1.0 - mu) * (n - s + 1) as f64 / s as f64;
}
let s_f = s as f64;
// The boundary counts are special-cased so that 0 · ln(0) is never evaluated.
let t1 = if s == 0 {
0.0
} else {
s_f * (s_f / (nf * mu)).ln()
};
let t2 = if s == n {
0.0
} else {
(nf - s_f) * ((nf - s_f) / (nf * (1.0 - mu))).ln()
};
exact_mean += pmf * 2.0 * (t1 + t2);
}
let residual = (exact_mean - 1.0 - shift).abs();
assert!(
residual < 2.5 / (n * n) as f64,
"n={n}: |E[W] − 1 − ε| = {residual} is not O(n⁻²)"
);
assert!(
residual < residual_prev,
"n={n}: residual {residual} did not shrink from {residual_prev}"
);
residual_prev = residual;
}
}
// Two checks on `lawley_lr_mean_shift`:
// 1. it equals ε(full) − ε(null) with the null design/penalty sliced by hand;
// 2. weighting every row's cumulants by 2 gives the same shift as duplicating
//    every row.
#[test]
fn mean_shift_is_full_minus_nuisance_epsilon() {
let n = 19;
let mut x = Array2::<f64>::zeros((n, 2));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = z - 0.5;
let eta = 0.3 - 0.8 * (z - 0.5);
kappas.push(
RowExpectedJets::binomial_logit(eta)
.kappas()
.expect("binomial kappas"),
);
}
let mut s_pen = Array2::<f64>::zeros((2, 2));
s_pen[[1, 1]] = 0.6;
let shift =
lawley_lr_mean_shift(x.view(), &kappas, Some(s_pen.view()), 1..2).expect("shift");
let eps_full = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε_full");
// Null model: keep only the intercept column and its penalty entry.
let x_null = x.slice(ndarray::s![.., 0..1]).to_owned();
let s_null = s_pen.slice(ndarray::s![0..1, 0..1]).to_owned();
let eps_null = lawley_epsilon(x_null.view(), &kappas, Some(s_null.view())).expect("ε_null");
assert!(
(shift - (eps_full - eps_null)).abs() < 1e-14,
"Δε = {shift} must equal ε_full − ε_null = {}",
eps_full - eps_null
);
// Weighted rows versus each row repeated twice.
let kappas_w: Vec<RowKappas> = kappas.iter().map(|r| r.weighted(2.0)).collect();
let mut x2 = Array2::<f64>::zeros((2 * n, 2));
let mut kappas2 = Vec::with_capacity(2 * n);
for i in 0..n {
for rep in 0..2 {
let row = 2 * i + rep;
x2[[row, 0]] = x[[i, 0]];
x2[[row, 1]] = x[[i, 1]];
kappas2.push(kappas[i]);
}
}
let shift_w = lawley_lr_mean_shift(x.view(), &kappas_w, Some(s_pen.view()), 1..2)
.expect("weighted shift");
let shift_dup = lawley_lr_mean_shift(x2.view(), &kappas2, Some(s_pen.view()), 1..2)
.expect("duplicated shift");
assert!(
(shift_w - shift_dup).abs() < 1e-12 * (1.0 + shift_dup.abs()),
"weight-2 rows ({shift_w}) must equal duplicated rows ({shift_dup})"
);
}
// Cross-checks `lawley_epsilon` against the brute-force index oracle on a
// three-column gamma/log design, both without a penalty and with a penalty that
// leaves the intercept unpenalized; also requires the penalty to change ε.
#[test]
fn row_pair_reduction_matches_index_oracle() {
let n = 17;
let k = 3;
let mut x = Array2::<f64>::zeros((n, k));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64;
x[[i, 0]] = 1.0;
x[[i, 1]] = (7.3 * z).sin();
x[[i, 2]] = z * z - 0.4;
let eta = 0.2 + 0.5 * x[[i, 1]] - 0.3 * x[[i, 2]];
kappas.push(
RowExpectedJets::gamma_log(eta, 1.3)
.kappas()
.expect("gamma kappas"),
);
}
let fast = lawley_epsilon(x.view(), &kappas, None).expect("hat form");
let oracle = lawley_epsilon_index_oracle(&x, &kappas, None);
assert!(
(fast - oracle).abs() < 1e-10 * (1.0 + oracle.abs()),
"row-pair ε={fast} vs index-form ε={oracle}"
);
// Penalty 0.8·diag(0, 1, 1): the intercept entry is zeroed before scaling.
let mut s_pen = Array2::<f64>::eye(k);
s_pen[[0, 0]] = 0.0; s_pen *= 0.8;
let fast_pen = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("hat form");
let oracle_pen = lawley_epsilon_index_oracle(&x, &kappas, Some(&s_pen));
assert!(
(fast_pen - oracle_pen).abs() < 1e-10 * (1.0 + oracle_pen.abs()),
"penalized row-pair ε={fast_pen} vs index-form ε={oracle_pen}"
);
assert!(
(fast_pen - fast).abs() > 1e-6,
"penalty must move ε (got {fast} → {fast_pen})"
);
}
// For the Poisson/log and binomial/logit constructors the cumulants must satisfy
// k3 = k2_1 and k4 = k3_1 at every tested η.
#[test]
fn canonical_links_collapse_the_mixed_arrays() {
    for eta in [-1.3, 0.0, 0.7] {
        let families = [
            RowExpectedJets::poisson_log(eta),
            RowExpectedJets::binomial_logit(eta),
        ];
        for jets in families {
            let kappas = jets.kappas().expect("canonical kappas");

            let third_gap = (kappas.k3 - kappas.k2_1).abs();
            let third_tol = 1e-13 * (1.0 + kappas.k3.abs());
            assert!(
                third_gap < third_tol,
                "canonical link must satisfy κ₃ = κ₂' (η={eta}): {kappas:?}"
            );

            let fourth_gap = (kappas.k4 - kappas.k3_1).abs();
            let fourth_tol = 1e-13 * (1.0 + kappas.k4.abs());
            assert!(
                fourth_gap < fourth_tol,
                "canonical link must satisfy κ₄ = κ₃' (η={eta}): {kappas:?}"
            );
        }
    }
}
// Gaussian/identity rows must give ε = 0 even when a penalty is present.
#[test]
fn gaussian_identity_needs_no_correction_even_penalized() {
    let n = 12;
    let row = RowExpectedJets::gaussian_identity(2.3)
        .kappas()
        .expect("gaussian kappas");
    let kappas = vec![row; n];
    // Intercept plus a shifted integer ramp.
    let x = Array2::from_shape_fn((n, 2), |(i, col)| {
        if col == 0 {
            1.0
        } else {
            i as f64 - 5.0
        }
    });
    let s_pen = Array2::<f64>::eye(2) * 0.5;
    let eps = lawley_epsilon(x.view(), &kappas, Some(s_pen.view())).expect("ε");
    assert!(
        eps.abs() < 1e-14,
        "Gaussian-identity ε must be 0; got {eps}"
    );
}
// Gaussian/identity rows: the total shift including the smoothing-parameter
// variation term must be zero even for a large ρ variance.
#[test]
fn rho_variation_correction_is_zero_for_gaussian() {
    let n = 16usize;
    let row = RowExpectedJets::gaussian_identity(1.3)
        .kappas()
        .expect("gaussian kappas");
    let kappas = vec![row; n];
    let x = Array2::from_shape_fn((n, 2), |(i, col)| {
        if col == 0 {
            1.0
        } else {
            i as f64 / n as f64 - 0.5
        }
    });

    // A single component that is also the whole penalty.
    let mut penalty = Array2::<f64>::zeros((2, 2));
    penalty[[1, 1]] = 2.0;
    let components = vec![RhoPenaltyComponent {
        s_component: penalty.clone(),
    }];
    let rho_cov = Array2::from_shape_vec((1, 1), vec![5.0]).unwrap();

    let total = lawley_lr_mean_shift_with_rho_variation(
        x.view(),
        &kappas,
        penalty.view(),
        1..2,
        &components,
        rho_cov.view(),
    )
    .expect("rho-variation shift");
    assert!(
        total.abs() < 1e-12,
        "Gaussian ρ̂-variation total shift must be 0; got {total}"
    );
}
// Single smoothing parameter on a Poisson/log design: the total shift must equal
// conditional + ½·H·Var(ρ), where H is an independently computed second
// derivative of the shift in the log penalty scale; the correction must be
// non-zero, and a zero variance must recover the conditional shift.
#[test]
fn rho_variation_correction_matches_curvature_times_variance() {
let n = 50usize;
let mut x = Array2::<f64>::ones((n, 2));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64 - 0.5;
x[[i, 1]] = z;
let eta = 0.3 + 0.6 * z;
kappas.push(
RowExpectedJets::poisson_log(eta)
.kappas()
.expect("poisson kappas"),
);
}
let lambda = 3.0_f64;
let mut s_comp = Array2::<f64>::zeros((2, 2));
s_comp[[1, 1]] = lambda;
let penalty = s_comp.clone();
let components = vec![RhoPenaltyComponent {
s_component: s_comp.clone(),
}];
let tested = 1..2;
let conditional =
lawley_lr_mean_shift(x.view(), &kappas, Some(penalty.view()), tested.clone())
.expect("conditional shift");
// Shift with the penalty rescaled to λ·exp(t).
let de_at = |t: f64| {
let mut s = Array2::<f64>::zeros((2, 2));
s[[1, 1]] = lambda * t.exp();
lawley_lr_mean_shift(x.view(), &kappas, Some(s.view()), tested.clone())
.expect("perturbed shift")
};
let h = 0.05_f64;
// Central second differences at steps h and 2h, combined as (4·d2_h − d2_2h)/3.
let d2_h = (de_at(h) - 2.0 * conditional + de_at(-h)) / (h * h);
let d2_2h = (de_at(2.0 * h) - 2.0 * conditional + de_at(-2.0 * h)) / (4.0 * h * h);
let curvature = (4.0 * d2_h - d2_2h) / 3.0;
assert!(
curvature.abs() > 1e-9,
"fixture must have non-zero ρ-curvature; got {curvature}"
);
let var_rho = 0.8_f64; let rho_cov = Array2::from_shape_vec((1, 1), vec![var_rho]).unwrap();
let total = lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
rho_cov.view(),
)
.expect("rho-variation shift");
let expected = conditional + 0.5 * curvature * var_rho;
assert!(
(total - expected).abs() < 1e-6 * (1.0 + expected.abs()),
"ρ̂-variation total {total} must equal conditional + ½ H Var = {expected} \
(conditional={conditional}, H={curvature}, Var={var_rho})"
);
assert!(
(total - conditional).abs() > 1e-9,
"ρ̂-variation correction must be non-zero (H={curvature}, Var={var_rho}); \
total={total} conditional={conditional}"
);
// With zero variance the variation term must vanish.
let zero_cov = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
let total_zero = lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
zero_cov.view(),
)
.expect("zero-variance shift");
assert!(
(total_zero - conditional).abs() < 1e-12,
"zero ρ-variance must recover the conditional shift: {total_zero} vs {conditional}"
);
}
// Checks the Bartlett factor that includes smoothing-parameter variation:
// it must equal 1 + total/ref_df, differ from the fixed-penalty factor, be 1 for
// Gaussian rows, and reject a zero reference df.
#[test]
fn rho_variation_factor_folds_estimated_lambda_into_c() {
let n = 50usize;
let mut x = Array2::<f64>::ones((n, 2));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64 - 0.5;
x[[i, 1]] = z;
let eta = 0.3 + 0.6 * z;
kappas.push(
RowExpectedJets::poisson_log(eta)
.kappas()
.expect("poisson kappas"),
);
}
let lambda = 3.0_f64;
let mut s_comp = Array2::<f64>::zeros((2, 2));
s_comp[[1, 1]] = lambda;
let penalty = s_comp.clone();
let components = vec![RhoPenaltyComponent {
s_component: s_comp,
}];
let tested = 1..2;
let ref_df = 1.0_f64;
let var_rho = 0.8_f64;
let rho_cov = Array2::from_shape_vec((1, 1), vec![var_rho]).unwrap();
let total = lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
rho_cov.view(),
)
.expect("total shift");
let factor = lawley_lr_bartlett_factor_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
rho_cov.view(),
ref_df,
)
.expect("estimated-λ factor");
assert!(
(factor - (1.0 + total / ref_df)).abs() < 1e-12,
"estimated-λ factor {factor} must equal 1 + Δε(ρ̂)/d = {}",
1.0 + total / ref_df
);
// The fixed-penalty factor must not coincide with the variation-aware one.
let conditional_factor = lawley_lr_bartlett_factor(
x.view(),
&kappas,
Some(penalty.view()),
tested.clone(),
ref_df,
)
.expect("conditional factor");
assert!(
(factor - conditional_factor).abs() > 1e-9,
"estimated-λ factor {factor} must differ from the fixed-λ factor \
{conditional_factor} (ρ̂-variation is load-bearing)"
);
// Gaussian rows on the same design: the factor must stay at 1.
let g_kappas = vec![
RowExpectedJets::gaussian_identity(1.3)
.kappas()
.expect("gaussian kappas");
n
];
let big_cov = Array2::from_shape_vec((1, 1), vec![5.0]).unwrap();
let g_factor = lawley_lr_bartlett_factor_with_rho_variation(
x.view(),
&g_kappas,
penalty.view(),
tested.clone(),
&components,
big_cov.view(),
ref_df,
)
.expect("gaussian factor");
assert!(
(g_factor - 1.0).abs() < 1e-12,
"Gaussian known-variance estimated-λ factor must be exactly 1; got {g_factor}"
);
// A reference df of zero must be rejected.
assert!(
lawley_lr_bartlett_factor_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
rho_cov.view(),
0.0,
)
.is_err()
);
}
// Two smoothing parameters on a binomial/logit design: the total shift must equal
// conditional + ½(H00·Cov00 + H11·Cov11) + H01·Cov01 with the Hessian entries
// recomputed here by finite differences, and must differ from the diagonal-only sum.
#[test]
fn rho_variation_includes_symmetric_cross_terms() {
let n = 40usize;
let mut x = Array2::<f64>::ones((n, 3));
let mut kappas = Vec::with_capacity(n);
for i in 0..n {
let z = i as f64 / n as f64 - 0.5;
x[[i, 1]] = z;
x[[i, 2]] = z * z - 0.1;
let eta = 0.2 + 0.5 * z - 0.3 * x[[i, 2]];
kappas.push(
RowExpectedJets::binomial_logit(eta)
.kappas()
.expect("binomial kappas"),
);
}
let (l1, l2) = (2.0_f64, 4.0_f64);
// One component per penalized column; the total penalty is their sum.
let mut s1 = Array2::<f64>::zeros((3, 3));
s1[[1, 1]] = l1;
let mut s2 = Array2::<f64>::zeros((3, 3));
s2[[2, 2]] = l2;
let penalty = &s1 + &s2;
let components = vec![
RhoPenaltyComponent {
s_component: s1.clone(),
},
RhoPenaltyComponent {
s_component: s2.clone(),
},
];
let tested = 1..3;
let conditional =
lawley_lr_mean_shift(x.view(), &kappas, Some(penalty.view()), tested.clone())
.expect("conditional");
// Shift with the two penalties rescaled to l1·exp(t0) and l2·exp(t1).
let de = |t0: f64, t1: f64| {
let mut s = Array2::<f64>::zeros((3, 3));
s[[1, 1]] = l1 * t0.exp();
s[[2, 2]] = l2 * t1.exp();
lawley_lr_mean_shift(x.view(), &kappas, Some(s.view()), tested.clone())
.expect("perturbed")
};
let h = 0.05_f64;
let h00 = (de(h, 0.0) - 2.0 * conditional + de(-h, 0.0)) / (h * h);
let h11 = (de(0.0, h) - 2.0 * conditional + de(0.0, -h)) / (h * h);
let h01 = (de(h, h) - de(h, -h) - de(-h, h) + de(-h, -h)) / (4.0 * h * h);
let rho_cov = Array2::from_shape_vec((2, 2), vec![0.7, 0.2, 0.2, 0.5]).unwrap();
let total = lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
penalty.view(),
tested.clone(),
&components,
rho_cov.view(),
)
.expect("rho-variation shift");
let expected = conditional
+ 0.5 * (h00 * rho_cov[[0, 0]] + h11 * rho_cov[[1, 1]])
+ h01 * rho_cov[[0, 1]];
assert!(
(total - expected).abs() < 1e-6 * (1.0 + expected.abs()),
"two-parameter ρ̂-variation {total} must equal {expected} \
(H00={h00}, H11={h11}, H01={h01})"
);
// Dropping the off-diagonal term must give a visibly different value.
let diag_only = conditional + 0.5 * (h00 * rho_cov[[0, 0]] + h11 * rho_cov[[1, 1]]);
assert!(
(total - diag_only).abs() > 1e-9,
"cross term H01·Cov01 must be included (off-diagonal non-zero): \
total={total} diag_only={diag_only}"
);
}
// Input validation of `lawley_lr_mean_shift_with_rho_variation`: wrong rho_cov
// shape, empty component list, wrong component shape and an asymmetric rho_cov
// must all be rejected, while a well-formed 1×1 covariance is accepted.
#[test]
fn rho_variation_rejects_shape_mismatch() {
let n = 8usize;
let jets = RowExpectedJets::poisson_log(0.1);
let kappas = vec![jets.kappas().expect("kappas"); n];
let mut x = Array2::<f64>::ones((n, 2));
for i in 0..n {
x[[i, 1]] = i as f64 - 4.0;
}
let mut s = Array2::<f64>::zeros((2, 2));
s[[1, 1]] = 1.0;
let components = vec![RhoPenaltyComponent {
s_component: s.clone(),
}];
// 2×2 covariance for a single component.
let bad_cov = Array2::<f64>::eye(2);
assert!(
lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
s.view(),
1..2,
&components,
bad_cov.view(),
)
.is_err()
);
// No components at all.
let cov1 = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
assert!(
lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
s.view(),
1..2,
&[],
cov1.view(),
)
.is_err()
);
// Component of size 3×3 for a two-column design.
let wrong = vec![RhoPenaltyComponent {
s_component: Array2::<f64>::eye(3),
}];
assert!(
lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
s.view(),
1..2,
&wrong,
cov1.view(),
)
.is_err()
);
// Despite its name this is a 1×1 matrix, which passes the symmetry check.
let nonsymmetric_cov = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
assert!(
lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
s.view(),
1..2,
&components,
nonsymmetric_cov.view(),
)
.is_ok()
);
let components2 = vec![
RhoPenaltyComponent {
s_component: s.clone(),
},
RhoPenaltyComponent {
s_component: s.clone(),
},
];
// Off-diagonal entries 0.25 and 0.20 disagree, so the matrix is asymmetric.
let bad_sym = Array2::from_shape_vec((2, 2), vec![1.0, 0.25, 0.20, 1.0]).unwrap();
assert!(
lawley_lr_mean_shift_with_rho_variation(
x.view(),
&kappas,
s.view(),
1..2,
&components2,
bad_sym.view(),
)
.is_err()
);
}
// Without a penalty, ε must be unchanged when the design is right-multiplied by
// an invertible (here upper-triangular) matrix.
#[test]
fn epsilon_is_invariant_under_linear_reparametrization() {
    let n = 15;
    let k = 3;
    let x = Array2::from_shape_fn((n, k), |(i, col)| {
        let z = i as f64 / n as f64;
        match col {
            0 => 1.0,
            1 => (3.1 * z).cos(),
            _ => z - 0.5,
        }
    });
    // Row cumulants from a linear predictor built on the original columns.
    let kappas: Vec<RowKappas> = x
        .outer_iter()
        .map(|row| {
            let eta = -0.1 + 0.6 * row[1] + 0.4 * row[2];
            RowExpectedJets::binomial_logit(eta)
                .kappas()
                .expect("binomial kappas")
        })
        .collect();

    let t_mat = ndarray::arr2(&[[1.0, 0.3, -0.2], [0.0, 1.4, 0.5], [0.0, 0.0, 0.8]]);
    let xt = x.dot(&t_mat);

    let eps = lawley_epsilon(x.view(), &kappas, None).expect("ε");
    let eps_t = lawley_epsilon(xt.view(), &kappas, None).expect("ε reparam");
    assert!(
        (eps - eps_t).abs() < 1e-9 * (1.0 + eps.abs()),
        "ε not reparametrization-invariant: {eps} vs {eps_t}"
    );
}
}