1use gam_linalg::faer_ndarray::FaerEigh;
16use ndarray::{Array1, Array2, ArrayView1, s};
17use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
18use std::ops::Range;
19
/// Selects the reference distribution `wood_smooth_test` uses to turn the
/// test statistic into a p-value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SmoothTestScale {
    /// The statistic is referred to a chi-squared distribution with the
    /// reference degrees of freedom.
    Known,
    /// The statistic divided by the reference degrees of freedom is referred
    /// to an F distribution whose second parameter is `residual_df`.
    Estimated,
}
30
31#[derive(Debug, Clone)]
46pub struct SmoothTestInput<'a> {
47 pub beta: ArrayView1<'a, f64>,
48 pub covariance: &'a Array2<f64>,
49 pub influence_matrix: Option<&'a Array2<f64>>,
50 pub coeff_range: Range<usize>,
51 pub edf: f64,
52 pub nullspace_dim: usize,
53 pub residual_df: f64,
54 pub scale: SmoothTestScale,
55}
56
/// Outcome of `wood_smooth_test`.
#[derive(Clone, Debug)]
pub struct SmoothTestResult {
    /// Sum of the quadratic forms evaluated for the tested coefficient block;
    /// finite and non-negative.
    pub statistic: f64,
    /// Reference degrees of freedom the p-value was computed with; finite
    /// and positive.
    pub ref_df: f64,
    /// Upper-tail probability, clamped to `[0, 1]`.
    pub p_value: f64,
}
67
68pub fn wood_smooth_test(input: SmoothTestInput<'_>) -> Option<SmoothTestResult> {
86 let start = input.coeff_range.start;
87 let end = input.coeff_range.end;
88 if start >= end
89 || end > input.beta.len()
90 || end > input.covariance.nrows()
91 || end > input.covariance.ncols()
92 || !input.edf.is_finite()
93 || input.edf <= 0.0
94 {
95 return None;
96 }
97 let k = end - start;
98 let beta = input.beta.slice(s![start..end]).to_owned();
99 let cov = block(input.covariance, start, end)?;
100 let null_dim = input.nullspace_dim.min(k);
101 let pen_dim = k.saturating_sub(null_dim);
102
103 let mut statistic = 0.0;
104 let mut rank_used = 0usize;
111 if null_dim > 0 {
112 let beta_null = beta.slice(s![0..null_dim]).to_owned();
113 let cov_null = cov.slice(s![0..null_dim, 0..null_dim]).to_owned();
114 let (q, used) = full_rank_quadratic(&beta_null, &cov_null)?;
115 statistic += q;
116 rank_used += used;
117 }
118 if pen_dim > 0 {
119 let beta_pen = beta.slice(s![null_dim..k]).to_owned();
120 let cov_pen = cov.slice(s![null_dim..k, null_dim..k]).to_owned();
121 let rank = truncated_rank(input.edf - null_dim as f64, pen_dim);
122 if rank > 0 {
123 let (q, used) = truncated_quadratic(&beta_pen, &cov_pen, rank)?;
124 statistic += q;
125 rank_used += used;
126 }
127 }
128
129 if rank_used == 0 {
130 return None;
133 }
134 let ref_df = match reference_df(input.influence_matrix, start, end) {
138 Some(rd) if rd.is_finite() && rd > 0.0 => rd.max(rank_used as f64),
139 _ => rank_used as f64,
140 };
141 if !statistic.is_finite() || statistic < 0.0 || !ref_df.is_finite() || ref_df <= 0.0 {
142 return None;
143 }
144 let p_value = match input.scale {
145 SmoothTestScale::Known => {
146 let dist = ChiSquared::new(ref_df).ok()?;
147 1.0 - dist.cdf(statistic)
148 }
149 SmoothTestScale::Estimated => {
150 if !input.residual_df.is_finite() || input.residual_df <= 0.0 {
151 return None;
152 }
153 let f_stat = statistic / ref_df;
158 let dist = FisherSnedecor::new(ref_df, input.residual_df).ok()?;
159 1.0 - dist.cdf(f_stat)
160 }
161 };
162 if !p_value.is_finite() {
163 return None;
164 }
165 Some(SmoothTestResult {
166 statistic,
167 ref_df,
168 p_value: p_value.clamp(0.0, 1.0),
169 })
170}
171
/// Number of eigen-directions to keep for a block of `pen_dim` coefficients
/// whose effective degrees of freedom are `edf_pen`.
///
/// Returns `0` when the block is empty or `edf_pen` is not finite and
/// positive. Otherwise `edf_pen` is rounded to the nearest integer and
/// limited to `1..=pen_dim`, so any positive edf keeps at least one direction.
fn truncated_rank(edf_pen: f64, pen_dim: usize) -> usize {
    let usable = pen_dim > 0 && edf_pen.is_finite() && edf_pen > 0.0;
    if !usable {
        return 0;
    }
    // Float-to-integer `as` saturates, so huge values simply hit the cap below.
    let rounded = edf_pen.round() as usize;
    rounded.max(1).min(pen_dim)
}
178
179fn block(matrix: &Array2<f64>, start: usize, end: usize) -> Option<Array2<f64>> {
180 if start >= end || end > matrix.nrows() || end > matrix.ncols() {
181 return None;
182 }
183 Some(matrix.slice(s![start..end, start..end]).to_owned())
184}
185
186fn full_rank_quadratic(beta: &Array1<f64>, cov: &Array2<f64>) -> Option<(f64, usize)> {
187 truncated_quadratic(beta, cov, beta.len())
188}
189
190fn truncated_quadratic(beta: &Array1<f64>, cov: &Array2<f64>, rank: usize) -> Option<(f64, usize)> {
197 if beta.is_empty() || cov.nrows() != beta.len() || cov.ncols() != beta.len() || rank == 0 {
198 return None;
199 }
200 let (evals, evecs) = cov.to_owned().eigh(faer::Side::Lower).ok()?;
201 let mut order: Vec<usize> = (0..evals.len()).collect();
202 order.sort_by(|&a, &b| evals[b].total_cmp(&evals[a]));
203 let tol = evals
204 .iter()
205 .copied()
206 .fold(0.0_f64, |acc, v| acc.max(v.abs()))
207 * 1e-10;
208 let mut q = 0.0;
209 let mut used = 0usize;
210 for idx in order {
211 let lambda = evals[idx];
212 if lambda <= tol {
213 continue;
214 }
215 let v = evecs.column(idx);
216 let proj = beta.dot(&v);
217 q += proj * proj / lambda;
218 used += 1;
219 if used >= rank {
220 break;
221 }
222 }
223 (used > 0 && q.is_finite()).then_some((q.max(0.0), used))
224}
225
226fn reference_df(influence: Option<&Array2<f64>>, start: usize, end: usize) -> Option<f64> {
227 let f = influence?;
228 let f_block = block(f, start, end)?;
229 let tr = (0..f_block.nrows()).map(|i| f_block[[i, i]]).sum::<f64>();
230 let tr2 = f_block.dot(&f_block).diag().sum();
231 if tr.is_finite() && tr2.is_finite() && tr > 0.0 && tr2 > 0.0 {
232 Some((tr * tr / tr2).max(1e-12))
233 } else {
234 None
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use statrs::distribution::{ChiSquared, ContinuousCDF};

    /// Runs the known-scale test on a two-coefficient term with diagonal
    /// covariance `diag(2, 3)`, influence block `diag(0.5, 0.25)`, `edf = 1`
    /// and no null space.
    fn known_scale_diagonal_result() -> SmoothTestResult {
        let beta = array![1.0, 2.0];
        let cov = array![[2.0, 0.0], [0.0, 3.0]];
        let f = array![[0.5, 0.0], [0.0, 0.25]];
        wood_smooth_test(SmoothTestInput {
            beta: beta.view(),
            covariance: &cov,
            influence_matrix: Some(&f),
            coeff_range: 0..2,
            edf: 1.0,
            nullspace_dim: 0,
            residual_df: 20.0,
            scale: SmoothTestScale::Known,
        })
        .expect("smooth test")
    }

    /// tr(F)^2 / tr(F^2) = 0.75^2 / 0.3125 = 1.8 exceeds the tested rank of
    /// one, so it must be reported unchanged.
    #[test]
    fn reference_df_uses_trace_correction() {
        let result = known_scale_diagonal_result();
        assert!((result.ref_df - 1.8).abs() < 1e-12);
        assert!(result.statistic > 0.0);
        assert!((0.0..=1.0).contains(&result.p_value));
    }

    /// With a known scale the p-value is the chi-squared upper tail at the
    /// reported statistic and reference d.f.
    #[test]
    fn known_scale_branch_reports_plain_wald_chi_square() {
        let result = known_scale_diagonal_result();

        let reference = ChiSquared::new(result.ref_df).expect("chi-square");
        let expected = 1.0 - reference.cdf(result.statistic);
        assert!((result.p_value - expected).abs() < 1e-15);
    }

    /// Multiplying the coefficients by `c` and the covariance by `c^2` must
    /// leave both the statistic and the estimated-scale p-value unchanged.
    #[test]
    fn estimated_scale_pvalue_is_response_unit_invariant() {
        let beta = array![2.5, -3.5, 1.8];
        let cov = array![[2.0, 0.3, 0.0], [0.3, 1.5, 0.1], [0.0, 0.1, 0.9]];
        let f = array![[0.7, 0.0, 0.0], [0.0, 0.6, 0.0], [0.0, 0.0, 0.4]];

        let rescaled = |c: f64| {
            let beta_c = &beta * c;
            let cov_c = &cov * (c * c);
            wood_smooth_test(SmoothTestInput {
                beta: beta_c.view(),
                covariance: &cov_c,
                influence_matrix: Some(&f),
                coeff_range: 0..3,
                edf: 2.0,
                nullspace_dim: 0,
                residual_df: 50.0,
                scale: SmoothTestScale::Estimated,
            })
            .expect("smooth test")
        };
        let relative_gap = |value: f64, reference: f64| (value - reference).abs() / reference;

        let base = rescaled(1.0);
        assert!(base.statistic > 0.0);
        assert!(base.p_value > 0.0 && base.p_value < 0.05);

        for c in [1e-3, 0.1, 10.0, 1e3, 1e6] {
            let scaled = rescaled(c);
            let rel_stat = relative_gap(scaled.statistic, base.statistic);
            assert!(
                rel_stat < 1e-9,
                "Wald statistic not scale-invariant at c={c}: {} vs {}",
                scaled.statistic,
                base.statistic
            );
            let rel_p = relative_gap(scaled.p_value, base.p_value);
            assert!(
                rel_p < 1e-9,
                "estimated-scale p-value not scale-invariant at c={c}: {} vs {}",
                scaled.p_value,
                base.p_value
            );
        }
    }

    /// A term with near-zero coefficients, near-zero edf and a near-zero
    /// influence block must keep a reference d.f. of at least one and must
    /// not come out significant under either scale treatment.
    #[test]
    fn boundary_shrunk_term_is_not_significant() {
        let beta = array![1e-9, -2e-9, 5e-10];
        let cov = array![[0.04, 0.0, 0.0], [0.0, 0.05, 0.0], [0.0, 0.0, 0.06]];
        let f = array![[1e-9, 0.0, 0.0], [0.0, -1e-9, 0.0], [0.0, 0.0, 1e-12]];

        for scale in [SmoothTestScale::Known, SmoothTestScale::Estimated] {
            let input = SmoothTestInput {
                beta: beta.view(),
                covariance: &cov,
                influence_matrix: Some(&f),
                coeff_range: 0..3,
                edf: 1e-6,
                nullspace_dim: 0,
                residual_df: 500.0,
                scale,
            };
            let out = wood_smooth_test(input).expect("boundary term still produces a result");
            assert!(
                out.ref_df >= 1.0,
                "reference d.f. must not collapse below the tested rank: {}",
                out.ref_df
            );
            assert!(
                out.statistic < 1e-6,
                "boundary statistic should be ~0: {}",
                out.statistic
            );
            assert!(
                out.p_value > 0.5,
                "shrunk boundary term must not be significant (p={}, scale={:?})",
                out.p_value,
                scale
            );
        }
    }

    /// Flooring the reference d.f. at the tested rank must not hide a large
    /// effect: here both coefficients are tested at full rank.
    #[test]
    fn floor_does_not_blunt_a_real_signal() {
        let beta = array![6.0, -5.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let f = array![[0.9, 0.0], [0.0, 0.9]];
        let input = SmoothTestInput {
            beta: beta.view(),
            covariance: &cov,
            influence_matrix: Some(&f),
            coeff_range: 0..2,
            edf: 2.0,
            nullspace_dim: 2,
            residual_df: 500.0,
            scale: SmoothTestScale::Known,
        };
        let out = wood_smooth_test(input).expect("smooth test");
        assert!(out.statistic > 40.0, "statistic={}", out.statistic);
        assert!(
            out.p_value < 1e-6,
            "a strong term must stay significant: p={}",
            out.p_value
        );
    }
}