use crate::faer_ndarray::FaerEigh;
use ndarray::{Array1, Array2, ArrayView1, s};
use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
use std::ops::Range;
/// Selects the reference distribution used by `wood_smooth_test` to turn the
/// test statistic into a p-value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmoothTestScale {
/// The statistic is referred to a chi-squared distribution with `ref_df`
/// degrees of freedom; `residual_df` is not consulted.
Known,
/// The statistic is divided by `ref_df` and referred to an F distribution
/// with `(ref_df, residual_df)` degrees of freedom.
Estimated,
}
/// Borrowed inputs for `wood_smooth_test`, describing one term of a fitted model.
#[derive(Debug, Clone)]
pub struct SmoothTestInput<'a> {
/// Coefficient vector; only the entries in `coeff_range` are tested.
pub beta: ArrayView1<'a, f64>,
/// Coefficient covariance matrix; only the square diagonal block over
/// `coeff_range` is used.
pub covariance: &'a Array2<f64>,
/// Optional matrix whose `coeff_range` diagonal block `F` yields the
/// reference degrees of freedom `tr(F)^2 / tr(F F)`. When absent or unusable,
/// the number of eigen-directions actually tested is used instead.
pub influence_matrix: Option<&'a Array2<f64>>,
/// Half-open index range of the term's coefficients. Must be non-empty and
/// lie within both `beta` and `covariance`.
pub coeff_range: Range<usize>,
/// Effective degrees of freedom of the term. Must be finite and strictly
/// positive; `edf - nullspace_dim`, rounded, sets the truncation rank of the
/// remaining (penalized) coefficients.
pub edf: f64,
/// Number of leading coefficients in `coeff_range` that are tested at full
/// rank. Capped at the width of the range.
pub nullspace_dim: usize,
/// Denominator degrees of freedom of the F reference distribution. Only read
/// when `scale` is `SmoothTestScale::Estimated`, where it must be finite and
/// strictly positive.
pub residual_df: f64,
/// Which reference distribution to use for the p-value.
pub scale: SmoothTestScale,
}
/// Outcome of `wood_smooth_test`.
#[derive(Debug, Clone)]
pub struct SmoothTestResult {
/// Sum of the quadratic forms over the tested eigen-directions; finite and
/// non-negative. Under `SmoothTestScale::Estimated` this is the value
/// *before* division by `ref_df`.
pub statistic: f64,
/// Reference degrees of freedom used for the p-value; never smaller than the
/// number of eigen-directions that entered `statistic`.
pub ref_df: f64,
/// Upper-tail probability of the statistic, clamped to `[0, 1]`.
pub p_value: f64,
}
/// Wald-type significance test for the coefficients in `input.coeff_range`.
///
/// The coefficient block is split into its first `nullspace_dim` entries, which are tested at
/// full rank, and the remaining entries, which are tested on the leading eigen-directions of
/// their covariance block. The number of directions kept is `edf - nullspace_dim` rounded to the
/// nearest integer and clamped to `1..=` the number of remaining coefficients; the part is
/// skipped when that difference is not positive. The two quadratic forms are added to give the
/// statistic.
///
/// The reference degrees of freedom come from the influence matrix when it is usable, floored at
/// the number of eigen-directions actually used; otherwise that number is used directly.
///
/// Returns `None` when the range is empty or out of bounds, `edf` is not finite and positive,
/// an eigen-decomposition fails, no direction could be tested, `residual_df` is unusable under
/// `SmoothTestScale::Estimated`, or any intermediate value is non-finite.
pub fn wood_smooth_test(input: SmoothTestInput<'_>) -> Option<SmoothTestResult> {
    let (start, end) = (input.coeff_range.start, input.coeff_range.end);

    let range_ok = start < end
        && end <= input.beta.len()
        && end <= input.covariance.nrows()
        && end <= input.covariance.ncols();
    let edf_ok = input.edf.is_finite() && input.edf > 0.0;
    if !(range_ok && edf_ok) {
        return None;
    }

    let width = end - start;
    let coeffs = input.beta.slice(s![start..end]).to_owned();
    let cov_block = block(input.covariance, start, end)?;

    // The leading `n_unpen` coefficients of the block are treated as unpenalized.
    let n_unpen = input.nullspace_dim.min(width);
    let n_pen = width - n_unpen;

    let null_part = if n_unpen > 0 {
        let b = coeffs.slice(s![..n_unpen]).to_owned();
        let c = cov_block.slice(s![..n_unpen, ..n_unpen]).to_owned();
        Some(full_rank_quadratic(&b, &c)?)
    } else {
        None
    };

    let pen_part = if n_pen > 0 {
        match truncated_rank(input.edf - n_unpen as f64, n_pen) {
            0 => None,
            rank => {
                let b = coeffs.slice(s![n_unpen..width]).to_owned();
                let c = cov_block
                    .slice(s![n_unpen..width, n_unpen..width])
                    .to_owned();
                Some(truncated_quadratic(&b, &c, rank)?)
            }
        }
    } else {
        None
    };

    // Add up whichever parts were tested, unpenalized part first.
    let (statistic, rank_used) = [null_part, pen_part]
        .iter()
        .flatten()
        .fold((0.0_f64, 0usize), |(stat, rank), &(q, used)| {
            (stat + q, rank + used)
        });
    if rank_used == 0 {
        return None;
    }

    // Never let the reference d.f. drop below the rank that was actually tested.
    let tested_rank = rank_used as f64;
    let ref_df = reference_df(input.influence_matrix, start, end)
        .filter(|rd| rd.is_finite() && *rd > 0.0)
        .map_or(tested_rank, |rd| rd.max(tested_rank));

    let usable =
        statistic.is_finite() && statistic >= 0.0 && ref_df.is_finite() && ref_df > 0.0;
    if !usable {
        return None;
    }

    // The upper tail is formed as `1 - cdf`, so tail probabilities far below f64 epsilon
    // come out as exactly 0.
    let upper_tail = match input.scale {
        SmoothTestScale::Estimated => {
            let denom_df = input.residual_df;
            if !(denom_df.is_finite() && denom_df > 0.0) {
                return None;
            }
            let reference = FisherSnedecor::new(ref_df, denom_df).ok()?;
            1.0 - reference.cdf(statistic / ref_df)
        }
        SmoothTestScale::Known => {
            let reference = ChiSquared::new(ref_df).ok()?;
            1.0 - reference.cdf(statistic)
        }
    };

    upper_tail.is_finite().then(|| SmoothTestResult {
        statistic,
        ref_df,
        p_value: upper_tail.clamp(0.0, 1.0),
    })
}
/// Number of eigen-directions to keep for the penalized part of a term.
///
/// Rounds `edf_pen` to the nearest integer and clamps it to `1..=pen_dim`. Returns `0` when
/// there are no penalized coefficients or when `edf_pen` is non-finite or not strictly positive.
fn truncated_rank(edf_pen: f64, pen_dim: usize) -> usize {
    let usable = pen_dim > 0 && edf_pen.is_finite() && edf_pen > 0.0;
    if !usable {
        return 0;
    }
    // The float-to-usize cast saturates, so huge values simply end up capped at `pen_dim`.
    let nearest = edf_pen.round() as usize;
    nearest.max(1).min(pen_dim)
}
/// Owned copy of the square diagonal block `matrix[start..end, start..end]`.
///
/// Returns `None` when the range is empty or extends past either dimension of `matrix`.
fn block(matrix: &Array2<f64>, start: usize, end: usize) -> Option<Array2<f64>> {
    let limit = matrix.nrows().min(matrix.ncols());
    let in_bounds = start < end && end <= limit;
    in_bounds.then(|| matrix.slice(s![start..end, start..end]).to_owned())
}
/// Quadratic form of `beta` against the eigen-decomposition of `cov`, keeping as many
/// eigen-directions as there are coefficients.
///
/// Directions whose eigenvalue falls at or below the tolerance applied by `truncated_quadratic`
/// are still dropped, so the returned count can be smaller than `beta.len()`.
fn full_rank_quadratic(beta: &Array1<f64>, cov: &Array2<f64>) -> Option<(f64, usize)> {
    let every_direction = beta.len();
    truncated_quadratic(beta, cov, every_direction)
}
/// Quadratic form of `beta` restricted to the leading eigen-directions of `cov`.
///
/// Eigen-pairs of `cov` are visited from the largest eigenvalue downwards. Each direction whose
/// eigenvalue exceeds `1e-10` times the largest absolute eigenvalue contributes
/// `(v . beta)^2 / lambda`, until `rank` directions have been accumulated.
///
/// Returns the accumulated value (floored at zero) and the number of directions used, or `None`
/// when `beta` is empty, `cov` is not `beta.len()` square, `rank` is zero, the
/// eigen-decomposition fails, no direction qualified, or the sum is non-finite.
fn truncated_quadratic(beta: &Array1<f64>, cov: &Array2<f64>, rank: usize) -> Option<(f64, usize)> {
    let dim = beta.len();
    if dim == 0 || rank == 0 {
        return None;
    }
    if cov.nrows() != dim || cov.ncols() != dim {
        return None;
    }

    let (evals, evecs) = cov.to_owned().eigh(faer::Side::Lower).ok()?;

    // Eigenvalue indices, largest eigenvalue first; the stable sort keeps ties in index order.
    let mut by_size: Vec<usize> = (0..evals.len()).collect();
    by_size.sort_by(|&lhs, &rhs| evals[lhs].total_cmp(&evals[rhs]).reverse());

    let largest_magnitude = evals.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
    let tol = largest_magnitude * 1e-10;

    let (q, used) = by_size
        .into_iter()
        // Written as `!(x <= tol)` rather than `x > tol` so that a NaN eigenvalue is kept and
        // makes the sum non-finite, which is then rejected below.
        .filter(|&i| !(evals[i] <= tol))
        .take(rank)
        .fold((0.0_f64, 0usize), |(acc, count), i| {
            let proj = beta.dot(&evecs.column(i));
            (acc + proj * proj / evals[i], count + 1)
        });

    if used == 0 || !q.is_finite() {
        None
    } else {
        Some((q.max(0.0), used))
    }
}
/// Reference degrees of freedom `tr(F)^2 / tr(F F)` for the `start..end` diagonal block `F` of
/// the influence matrix.
///
/// Returns `None` when no matrix is supplied, when the block is empty or out of bounds, or when
/// either trace is non-finite or not strictly positive. A successful result is floored at
/// `1e-12`.
fn reference_df(influence: Option<&Array2<f64>>, start: usize, end: usize) -> Option<f64> {
    let f_block = block(influence?, start, end)?;
    let tr = (0..f_block.nrows()).map(|i| f_block[[i, i]]).sum::<f64>();
    // tr(F F) = sum over i, j of F[i, j] * F[j, i]. Summing those products directly costs
    // O(k^2) and no allocation, instead of forming the whole k x k product for its diagonal.
    let tr2 = f_block
        .indexed_iter()
        .map(|((i, j), &f_ij)| f_ij * f_block[[j, i]])
        .sum::<f64>();
    if tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0 {
        Some((tr * tr / tr2).max(1e-12))
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use statrs::distribution::{ChiSquared, ContinuousCDF};

    /// With F = diag(0.5, 0.25): tr(F)^2 / tr(F F) = 0.5625 / 0.3125 = 1.8, which exceeds the
    /// tested rank of 1 and must therefore be reported unchanged.
    #[test]
    fn reference_df_uses_trace_correction() {
        let coefs = array![1.0, 2.0];
        let vcov = array![[2.0, 0.0], [0.0, 3.0]];
        let influence = array![[0.5, 0.0], [0.0, 0.25]];
        let input = SmoothTestInput {
            beta: coefs.view(),
            covariance: &vcov,
            influence_matrix: Some(&influence),
            coeff_range: 0..2,
            edf: 1.0,
            nullspace_dim: 0,
            residual_df: 20.0,
            scale: SmoothTestScale::Known,
        };
        let result = wood_smooth_test(input).expect("smooth test");

        let ref_df_error = (result.ref_df - 1.8).abs();
        assert!(ref_df_error < 1e-12);
        assert!(result.statistic > 0.0);
        assert!((0.0..=1.0).contains(&result.p_value));
    }

    /// Under a known scale the p-value is the chi-squared upper tail of the reported statistic
    /// at the reported reference degrees of freedom.
    #[test]
    fn known_scale_branch_reports_plain_wald_chi_square() {
        let coefs = array![1.0, 2.0];
        let vcov = array![[2.0, 0.0], [0.0, 3.0]];
        let influence = array![[0.5, 0.0], [0.0, 0.25]];
        let input = SmoothTestInput {
            beta: coefs.view(),
            covariance: &vcov,
            influence_matrix: Some(&influence),
            coeff_range: 0..2,
            edf: 1.0,
            nullspace_dim: 0,
            residual_df: 20.0,
            scale: SmoothTestScale::Known,
        };
        let result = wood_smooth_test(input).expect("smooth test");

        let reference = ChiSquared::new(result.ref_df).expect("chi-square");
        let upper_tail = 1.0 - reference.cdf(result.statistic);
        assert!((result.p_value - upper_tail).abs() < 1e-15);
    }

    /// Rescaling the response by `c` multiplies the coefficients by `c` and their covariance by
    /// `c^2`; neither the statistic nor the p-value may change.
    #[test]
    fn estimated_scale_pvalue_is_response_unit_invariant() {
        let beta = array![2.5, -3.5, 1.8];
        let cov = array![[2.0, 0.3, 0.0], [0.3, 1.5, 0.1], [0.0, 0.1, 0.9]];
        let f = array![[0.7, 0.0, 0.0], [0.0, 0.6, 0.0], [0.0, 0.0, 0.4]];

        let evaluate_at = |c: f64| {
            let scaled_beta = &beta * c;
            let scaled_cov = &cov * (c * c);
            let input = SmoothTestInput {
                beta: scaled_beta.view(),
                covariance: &scaled_cov,
                influence_matrix: Some(&f),
                coeff_range: 0..3,
                edf: 2.0,
                nullspace_dim: 0,
                residual_df: 50.0,
                scale: SmoothTestScale::Estimated,
            };
            wood_smooth_test(input).expect("smooth test")
        };
        // Difference between `value` and `reference`, relative to `reference`.
        let relative_gap = |value: f64, reference: f64| (value - reference).abs() / reference;

        let base = evaluate_at(1.0);
        assert!(base.statistic > 0.0);
        assert!(base.p_value > 0.0 && base.p_value < 0.05);

        for c in [1e-3, 0.1, 10.0, 1e3, 1e6] {
            let scaled = evaluate_at(c);

            let rel_stat = relative_gap(scaled.statistic, base.statistic);
            assert!(
                rel_stat < 1e-9,
                "Wald statistic not scale-invariant at c={c}: {} vs {}",
                scaled.statistic,
                base.statistic
            );

            let rel_p = relative_gap(scaled.p_value, base.p_value);
            assert!(
                rel_p < 1e-9,
                "estimated-scale p-value not scale-invariant at c={c}: {} vs {}",
                scaled.p_value,
                base.p_value
            );
        }
    }

    /// A term shrunk to (almost) nothing: tiny coefficients, an edf of 1e-6 and a near-zero
    /// influence block. The test must still run on one direction and report no significance
    /// under either scale assumption.
    #[test]
    fn boundary_shrunk_term_is_not_significant() {
        let coefs = array![1e-9, -2e-9, 5e-10];
        let vcov = array![[0.04, 0.0, 0.0], [0.0, 0.05, 0.0], [0.0, 0.0, 0.06]];
        let influence = array![[1e-9, 0.0, 0.0], [0.0, -1e-9, 0.0], [0.0, 0.0, 1e-12]];

        for scale in [SmoothTestScale::Known, SmoothTestScale::Estimated] {
            let input = SmoothTestInput {
                beta: coefs.view(),
                covariance: &vcov,
                influence_matrix: Some(&influence),
                coeff_range: 0..3,
                edf: 1e-6,
                nullspace_dim: 0,
                residual_df: 500.0,
                scale,
            };
            let result =
                wood_smooth_test(input).expect("boundary term still produces a result");

            assert!(
                result.ref_df >= 1.0,
                "reference d.f. must not collapse below the tested rank: {}",
                result.ref_df
            );
            assert!(
                result.statistic < 1e-6,
                "boundary statistic should be ~0: {}",
                result.statistic
            );
            assert!(
                result.p_value > 0.5,
                "shrunk boundary term must not be significant (p={}, scale={:?})",
                result.p_value,
                scale
            );
        }
    }

    /// Two unpenalized coefficients with identity covariance give a statistic of
    /// 6^2 + 5^2 = 61; flooring the reference d.f. at the tested rank must not hide that signal.
    #[test]
    fn floor_does_not_blunt_a_real_signal() {
        let coefs = array![6.0, -5.0];
        let vcov = array![[1.0, 0.0], [0.0, 1.0]];
        let influence = array![[0.9, 0.0], [0.0, 0.9]];
        let input = SmoothTestInput {
            beta: coefs.view(),
            covariance: &vcov,
            influence_matrix: Some(&influence),
            coeff_range: 0..2,
            edf: 2.0,
            nullspace_dim: 2,
            residual_df: 500.0,
            scale: SmoothTestScale::Known,
        };
        let result = wood_smooth_test(input).expect("smooth test");

        assert!(result.statistic > 40.0, "statistic={}", result.statistic);
        assert!(
            result.p_value < 1e-6,
            "a strong term must stay significant: p={}",
            result.p_value
        );
    }
}