use statrs::distribution::{ChiSquared, ContinuousCDF};
/// Outcome of rescaling a test statistic by a Bartlett factor.
///
/// Values of this type are built by [`bartlett_correct`].
#[derive(Debug, Clone, Copy)]
pub struct BartlettCorrection {
/// The factor the raw statistic was divided by.
pub factor: f64,
/// The raw statistic divided by `factor`.
pub corrected_statistic: f64,
/// Upper-tail chi-squared probability of `corrected_statistic`, clamped to `[0, 1]`.
pub corrected_p_value: f64,
/// Absolute distance of `factor` from one, i.e. `|factor - 1|`.
pub relative_adjustment: f64,
}
/// Divides the statistic `w` by `factor` and evaluates the result against a
/// chi-squared reference distribution with `ref_df` degrees of freedom.
///
/// # Returns
///
/// `Some(BartlettCorrection)` holding the factor, the rescaled statistic, its
/// upper-tail p-value (clamped to `[0, 1]`) and `|factor - 1|`.
///
/// Returns `None` when any argument is non-finite, when `w` is negative, when
/// `ref_df` or `factor` is not strictly positive, or when the chi-squared
/// distribution cannot be constructed for `ref_df`.
pub fn bartlett_correct(w: f64, ref_df: f64, factor: f64) -> Option<BartlettCorrection> {
    let inputs_finite = w.is_finite() && ref_df.is_finite() && factor.is_finite();
    if !inputs_finite || w < 0.0 || ref_df <= 0.0 || factor <= 0.0 {
        return None;
    }
    let corrected = w / factor;
    let dist = ChiSquared::new(ref_df).ok()?;
    // Ask for the survival function directly instead of forming `1 - cdf`:
    // the subtraction cancels catastrophically in the far upper tail and
    // collapses every p-value below roughly 1e-16 to exactly zero.
    let p = dist.sf(corrected).clamp(0.0, 1.0);
    Some(BartlettCorrection {
        factor,
        corrected_statistic: corrected,
        corrected_p_value: p,
        relative_adjustment: (factor - 1.0).abs(),
    })
}
/// Computes a Bartlett factor as the ratio `mean_w / ref_df`.
///
/// # Returns
///
/// `Some(mean_w / ref_df)` when both arguments are finite and strictly
/// positive; `None` otherwise.
pub fn bartlett_factor_from_mean(mean_w: f64, ref_df: f64) -> Option<f64> {
    // A usable input is a finite, strictly positive number.
    let usable = |value: f64| value.is_finite() && value > 0.0;
    if usable(mean_w) && usable(ref_df) {
        Some(mean_w / ref_df)
    } else {
        None
    }
}
/// Evaluates the closed-form factor `1 + (q + 1) / (2 * residual_df)`.
///
/// # Returns
///
/// `Some(factor)` when `q` and `residual_df` are both finite and strictly
/// positive; `None` otherwise.
pub fn gaussian_linear_bartlett_factor(q: f64, residual_df: f64) -> Option<f64> {
    match (q, residual_df) {
        (k, nu) if k.is_finite() && nu.is_finite() && k > 0.0 && nu > 0.0 => {
            Some(1.0 + (k + 1.0) / (2.0 * nu))
        }
        _ => None,
    }
}
/// Four per-row derivative terms of a log-likelihood, orders one through four.
///
/// [`row_derivs_from_nll_tower`] fills these by negating its four inputs in
/// order; [`assemble_cumulants`] reads `d2`, `d3` and `d4` as row weights and
/// requires all four to be finite.
// NOTE(review): the variable the derivatives are taken with respect to is not
// visible in this file — confirm against the caller.
#[derive(Debug, Clone, Copy)]
pub struct RowLogLikDerivs {
/// First-order term.
pub d1: f64,
/// Second-order term.
pub d2: f64,
/// Third-order term.
pub d3: f64,
/// Fourth-order term.
pub d4: f64,
}
/// Builds a [`RowLogLikDerivs`] by flipping the sign of each input.
///
/// `value_grad`, `hess`, `third` and `fourth` map to `d1`, `d2`, `d3` and
/// `d4` respectively, each negated. Going by the function name, the inputs are
/// derivatives of a negative log-likelihood, so the negation yields
/// log-likelihood derivatives.
pub fn row_derivs_from_nll_tower(
    value_grad: f64,
    hess: f64,
    third: f64,
    fourth: f64,
) -> RowLogLikDerivs {
    let [d1, d2, d3, d4] = [value_grad, hess, third, fourth].map(|nll_term| -nll_term);
    RowLogLikDerivs { d1, d2, d3, d4 }
}
/// Flattened arrays of order two, three and four over a block of `q` indices.
///
/// Each array is stored row-major with every axis of extent `q`, and the
/// accessor methods translate multi-indices into flat offsets.
#[derive(Debug, Clone)]
pub struct CumulantArrays {
    /// Extent of every axis.
    pub q: usize,
    /// Order-two array, addressed as `info[a * q + b]`.
    pub info: Vec<f64>,
    /// Order-three array, addressed as `nu3[(a * q + b) * q + c]`.
    pub nu3: Vec<f64>,
    /// Order-four array, addressed as `nu4[((a * q + b) * q + c) * q + d]`.
    pub nu4: Vec<f64>,
}

impl CumulantArrays {
    /// Returns the order-two entry at `(a, b)`.
    ///
    /// # Panics
    ///
    /// Panics if the flat offset falls outside the `info` vector.
    #[inline]
    pub fn info(&self, a: usize, b: usize) -> f64 {
        let offset = a * self.q + b;
        self.info[offset]
    }

    /// Returns the order-three entry at `(a, b, c)`.
    ///
    /// # Panics
    ///
    /// Panics if the flat offset falls outside the `nu3` vector.
    #[inline]
    pub fn nu3(&self, a: usize, b: usize, c: usize) -> f64 {
        let stride = self.q;
        // Horner-style accumulation reproduces `(a * q + b) * q + c`.
        let offset = [a, b, c].iter().fold(0, |acc, &axis| acc * stride + axis);
        self.nu3[offset]
    }

    /// Returns the order-four entry at `(a, b, c, d)`.
    ///
    /// # Panics
    ///
    /// Panics if the flat offset falls outside the `nu4` vector.
    #[inline]
    pub fn nu4(&self, a: usize, b: usize, c: usize, d: usize) -> f64 {
        let stride = self.q;
        // Horner-style accumulation reproduces `((a * q + b) * q + c) * q + d`.
        let offset = [a, b, c, d]
            .iter()
            .fold(0, |acc, &axis| acc * stride + axis);
        self.nu4[offset]
    }
}
/// Sums derivative-weighted outer products of each row of `block` into
/// order-two, order-three and order-four arrays.
///
/// For every row `z` paired with its derivatives `d`, the function
/// accumulates
///
/// * `info[a, b]       -= d.d2 * z[a] * z[b]`
/// * `nu3[a, b, c]     += d.d3 * z[a] * z[b] * z[c]`
/// * `nu4[a, b, c, e]  += d.d4 * z[a] * z[b] * z[c] * z[e]`
///
/// # Returns
///
/// `None` when `rows` is empty, when `block` and `rows` differ in length,
/// when the rows of `block` are empty or ragged, when any design entry or any
/// of `d1..d4` is non-finite, or when any accumulated entry ends up
/// non-finite. Otherwise `Some(CumulantArrays)` with `q` equal to the row
/// width.
pub fn assemble_cumulants(block: &[&[f64]], rows: &[RowLogLikDerivs]) -> Option<CumulantArrays> {
    if rows.is_empty() || block.len() != rows.len() {
        return None;
    }
    let q = block.first()?.len();
    if q == 0 || !block.iter().all(|row| row.len() == q) {
        return None;
    }

    let mut info = vec![0.0_f64; q * q];
    let mut nu3 = vec![0.0_f64; q * q * q];
    let mut nu4 = vec![0.0_f64; q * q * q * q];

    for (z, d) in block.iter().zip(rows) {
        let derivs_finite = [d.d1, d.d2, d.d3, d.d4].iter().all(|v| v.is_finite());
        if !derivs_finite {
            return None;
        }
        if !z.iter().all(|v| v.is_finite()) {
            return None;
        }

        // The nested loops visit every array in row-major order, so plain
        // running offsets coincide with the flattened multi-index.
        let (mut at2, mut at3, mut at4) = (0usize, 0usize, 0usize);
        for &za in z.iter() {
            for &zb in z.iter() {
                let zab = za * zb;
                info[at2] -= d.d2 * zab;
                at2 += 1;
                for &zc in z.iter() {
                    let zabc = zab * zc;
                    nu3[at3] += d.d3 * zabc;
                    at3 += 1;
                    for &ze in z.iter() {
                        nu4[at4] += d.d4 * zabc * ze;
                        at4 += 1;
                    }
                }
            }
        }
    }

    // Overflow or cancellation during accumulation can still yield inf/NaN.
    let all_finite = info
        .iter()
        .chain(&nu3)
        .chain(&nu4)
        .all(|v| v.is_finite());
    if !all_finite {
        return None;
    }
    Some(CumulantArrays { q, info, nu3, nu4 })
}
/// Standardizes the order-three and order-four entries of a one-dimensional
/// [`CumulantArrays`] by powers of its single `info` entry.
///
/// With `i = info(0, 0)`, the result is
/// `(nu3(0, 0, 0) / i^1.5, nu4(0, 0, 0, 0) / i^2)`.
///
/// # Returns
///
/// `None` when `q != 1`, when `i` is non-finite or not strictly positive, or
/// when either ratio is non-finite.
pub fn scalar_standardized_cumulants(cumulants: &CumulantArrays) -> Option<(f64, f64)> {
    if cumulants.q != 1 {
        return None;
    }
    let scale = cumulants.info(0, 0);
    if !scale.is_finite() || scale <= 0.0 {
        return None;
    }
    let third = cumulants.nu3(0, 0, 0) / scale.powf(1.5);
    let fourth = cumulants.nu4(0, 0, 0, 0) / (scale * scale);
    Some((third, fourth)).filter(|(r3, r4)| r3.is_finite() && r4.is_finite())
}
// Unit tests for the Bartlett helpers and the cumulant assembly routines.
#[cfg(test)]
mod tests {
use super::*;
use statrs::distribution::{ContinuousCDF, FisherSnedecor};
// The factor is the plain ratio mean / df, and non-positive inputs are rejected.
#[test]
fn bartlett_factor_recovers_mean_over_df() {
let c = bartlett_factor_from_mean(6.0, 4.0).expect("factor");
assert!((c - 1.5).abs() < 1e-12);
assert!(bartlett_factor_from_mean(-1.0, 4.0).is_none());
assert!(bartlett_factor_from_mean(6.0, 0.0).is_none());
}
// Dividing by a factor above one shrinks the statistic (12 / 1.5 = 8), so the
// corrected p-value must be larger than the p-value of the raw statistic.
#[test]
fn correction_rescales_statistic_and_enlarges_p_for_inflated_stat() {
let raw_w = 12.0;
let d = 4.0;
let factor = 1.5;
let corr = bartlett_correct(raw_w, d, factor).expect("correction");
assert!((corr.corrected_statistic - 8.0).abs() < 1e-12);
let dist = ChiSquared::new(d).unwrap();
let raw_p = 1.0 - dist.cdf(raw_w);
assert!(
corr.corrected_p_value > raw_p,
"corrected p {} must exceed raw p {}",
corr.corrected_p_value,
raw_p
);
assert!((corr.relative_adjustment - 0.5).abs() < 1e-12);
}
// Checks the closed-form factor against a numerically integrated mean of
// W = n * ln(1 + (q / nu) * F), where F follows an F(q, nu) distribution and
// n = q + 1 + nu. Dividing that mean by the factor must at least halve its
// distance from q.
#[test]
fn gaussian_linear_bartlett_moves_mean_toward_truth() {
let q = 3.0_f64;
let nu = 20.0_f64; let n = (q + 1.0 + nu) as f64;
let c = gaussian_linear_bartlett_factor(q, nu).expect("factor");
// 1 + (3 + 1) / (2 * 20) = 1.1
assert!((c - 1.1).abs() < 1e-12);
let fdist = FisherSnedecor::new(q, nu).expect("F dist");
// Density approximated by a central finite difference of the CDF, with a
// step that grows with f.
let pdf = |f: f64| {
let h = 1e-5 * (1.0 + f);
(fdist.cdf(f + h) - fdist.cdf(f - h)) / (2.0 * h)
};
let w_of = |f: f64| n * (1.0 + (q / nu) * f).ln();
// Trapezoid rule on [0, f_hi]: endpoints get half weight, and each node is
// nudged by 1e-9 before evaluation.
let f_hi = 60.0_f64;
let steps = 600_000usize;
let dx = f_hi / steps as f64;
let mut e_w = 0.0;
for i in 0..=steps {
let f = (i as f64) * dx + 1e-9;
let weight = if i == 0 || i == steps { 0.5 } else { 1.0 };
e_w += weight * w_of(f) * pdf(f);
}
e_w *= dx;
// The uncorrected mean must be visibly off from q for the test to be meaningful.
let raw_bias = (e_w - q).abs();
assert!(
raw_bias > 0.1,
"first-order test should be materially biased at ν={nu}: E[W]={e_w}, q={q}"
);
let corrected_mean = e_w / c;
let corrected_bias = (corrected_mean - q).abs();
assert!(
corrected_bias < 0.5 * raw_bias,
"Bartlett correction must move the mean toward truth: \
raw_bias={raw_bias:.5} corrected_bias={corrected_bias:.5} \
(E[W]={e_w:.5}, c={c:.5})"
);
}
// The factor exceeds one for small residual df and approaches one as df grows.
#[test]
fn factor_vanishes_in_the_large_sample_limit() {
let c_small = gaussian_linear_bartlett_factor(3.0, 10.0).unwrap();
let c_large = gaussian_linear_bartlett_factor(3.0, 100_000.0).unwrap();
assert!(c_small > 1.0);
assert!((c_large - 1.0).abs() < 1e-3);
assert!(c_small > c_large);
}
// Every input is negated exactly, so equality comparisons are safe here.
#[test]
fn nll_tower_sign_flip_gives_loglik_derivatives() {
let d = row_derivs_from_nll_tower(0.5, -2.0, 0.3, -0.1);
assert_eq!(d.d1, -0.5);
assert_eq!(d.d2, 2.0);
assert_eq!(d.d3, -0.3);
assert_eq!(d.d4, 0.1);
}
// Two rows with q = 2: selected entries are compared with hand-computed row
// sums, and permuted indices must give the same value.
#[test]
fn cumulant_arrays_are_exact_row_sums_and_fully_symmetric() {
let z0 = [1.0_f64, 2.0];
let z1 = [-1.0_f64, 0.5];
let block: Vec<&[f64]> = vec![&z0, &z1];
let rows = vec![
RowLogLikDerivs {
d1: 0.0,
d2: -1.5,
d3: 0.7,
d4: -0.2,
},
RowLogLikDerivs {
d1: 0.0,
d2: -0.5,
d3: 1.1,
d4: 0.4,
},
];
let c = assemble_cumulants(&block, &rows).expect("cumulants");
assert_eq!(c.q, 2);
// info accumulates -d2 * z[a] * z[b], hence the positive 1.5 and 0.5 weights.
let info00 = 1.5 * (1.0 * 1.0) + 0.5 * (-1.0 * -1.0);
let info01 = 1.5 * (1.0 * 2.0) + 0.5 * (-1.0 * 0.5);
assert!((c.info(0, 0) - info00).abs() < 1e-12);
assert!((c.info(0, 1) - info01).abs() < 1e-12);
assert!((c.info(0, 1) - c.info(1, 0)).abs() < 1e-14);
let nu3_010 = 0.7 * (1.0 * 2.0 * 1.0) + 1.1 * (-1.0 * 0.5 * -1.0);
assert!((c.nu3(0, 1, 0) - nu3_010).abs() < 1e-12);
assert!((c.nu3(0, 1, 0) - c.nu3(1, 0, 0)).abs() < 1e-14);
assert!((c.nu3(0, 1, 0) - c.nu3(0, 0, 1)).abs() < 1e-14);
let nu4_0011 = -0.2 * (1.0 * 1.0 * 2.0 * 2.0) + 0.4 * (-1.0 * -1.0 * 0.5 * 0.5);
assert!((c.nu4(0, 0, 1, 1) - nu4_0011).abs() < 1e-12);
assert!((c.nu4(0, 0, 1, 1) - c.nu4(1, 1, 0, 0)).abs() < 1e-14);
}
// With d3 = d4 = 0 on every row both standardized values are zero, and info
// equals n / phi because each row contributes -d2 = 1 / phi.
#[test]
fn gaussian_known_variance_has_zero_standardized_cumulants() {
let phi = 2.0;
let n = 50usize;
let zcol = [1.0_f64];
let block: Vec<&[f64]> = (0..n).map(|_| &zcol[..]).collect();
let rows: Vec<RowLogLikDerivs> = (0..n)
.map(|_| RowLogLikDerivs {
d1: 0.0,
d2: -1.0 / phi,
d3: 0.0,
d4: 0.0,
})
.collect();
let c = assemble_cumulants(&block, &rows).expect("cumulants");
let (rho3, rho4) = scalar_standardized_cumulants(&c).expect("standardized");
assert!(rho3.abs() < 1e-12, "Gaussian ρ₃ must be 0, got {rho3}");
assert!(rho4.abs() < 1e-12, "Gaussian ρ₄ must be 0, got {rho4}");
assert!((c.info(0, 0) - (n as f64) / phi).abs() < 1e-10);
}
// n identical rows with (d2, d3, d4) = (-1, 2, -6) at theta = 1: sums are
// (n, 2n, -6n), so the standardized values are 2 / sqrt(n) and -6 / n.
#[test]
fn exponential_rate_standardized_cumulants_match_closed_form() {
let theta = 1.0_f64;
let n = 64usize;
let zcol = [1.0_f64];
let block: Vec<&[f64]> = (0..n).map(|_| &zcol[..]).collect();
let rows: Vec<RowLogLikDerivs> = (0..n)
.map(|_| RowLogLikDerivs {
d1: 0.0, d2: -1.0 / (theta * theta),
d3: 2.0 / theta.powi(3),
d4: -6.0 / theta.powi(4),
})
.collect();
let c = assemble_cumulants(&block, &rows).expect("cumulants");
assert!((c.info(0, 0) - n as f64).abs() < 1e-10);
let (rho3, rho4) = scalar_standardized_cumulants(&c).expect("standardized");
let nf = n as f64;
assert!(
(rho3 - 2.0 / nf.sqrt()).abs() < 1e-10,
"Exponential ρ₃ must be 2/√n = {}, got {rho3}",
2.0 / nf.sqrt()
);
assert!(
(rho4 - (-6.0 / nf)).abs() < 1e-10,
"Exponential ρ₄ must be −6/n = {}, got {rho4}",
-6.0 / nf
);
}
// An empty `rows` slice and a NaN derivative must both yield `None`.
#[test]
fn assemble_cumulants_rejects_degenerate_input() {
let z = [1.0_f64];
let block: Vec<&[f64]> = vec![&z];
assert!(assemble_cumulants(&block, &[]).is_none());
let bad = vec![RowLogLikDerivs {
d1: 0.0,
d2: f64::NAN,
d3: 0.0,
d4: 0.0,
}];
assert!(assemble_cumulants(&block, &bad).is_none());
}
}