Skip to main content

cyanea_stats/
correction.rs

//! Multiple testing correction.
//!
//! When running many hypothesis tests simultaneously, p-values must be
//! adjusted to control the family-wise error rate or the false discovery rate.
//! This module provides the Bonferroni (FWER) and Benjamini-Hochberg (FDR)
//! adjustments.
5
6use cyanea_core::{CyaneaError, Result};
7
/// Multiple testing correction method.
///
/// Selects which adjustment `correct` applies to a family of p-values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CorrectionMethod {
    /// Bonferroni correction — controls the family-wise error rate (FWER).
    ///
    /// Every p-value is multiplied by the number of tests and capped at 1.0.
    Bonferroni,
    /// Benjamini-Hochberg procedure — controls the false discovery rate (FDR).
    ///
    /// Each p-value is scaled by `n / rank` (rather than `n`) and the results
    /// are made monotone in the original p-values.
    BenjaminiHochberg,
}
16
17/// Apply a multiple testing correction to `p_values`.
18///
19/// Returns a new `Vec<f64>` of adjusted p-values in the same order as the
20/// input.
21pub fn correct(p_values: &[f64], method: CorrectionMethod) -> Result<Vec<f64>> {
22    match method {
23        CorrectionMethod::Bonferroni => bonferroni(p_values),
24        CorrectionMethod::BenjaminiHochberg => benjamini_hochberg(p_values),
25    }
26}
27
28/// Bonferroni correction: `p_adj = min(p * n, 1.0)`.
29pub fn bonferroni(p_values: &[f64]) -> Result<Vec<f64>> {
30    validate_p_values(p_values)?;
31    let n = p_values.len() as f64;
32    Ok(p_values.iter().map(|&p| (p * n).min(1.0)).collect())
33}
34
35/// Benjamini-Hochberg procedure for controlling the false discovery rate.
36///
37/// Sorts p-values, adjusts as `p * n / rank`, enforces monotonicity
38/// from right to left, and clamps to [0, 1].
39pub fn benjamini_hochberg(p_values: &[f64]) -> Result<Vec<f64>> {
40    validate_p_values(p_values)?;
41    let n = p_values.len();
42    if n == 0 {
43        return Ok(Vec::new());
44    }
45
46    // Sort indices by p-value.
47    let mut indices: Vec<usize> = (0..n).collect();
48    indices.sort_by(|&a, &b| p_values[a].total_cmp(&p_values[b]));
49
50    let n_f = n as f64;
51    let mut adjusted = vec![0.0; n];
52
53    // Compute adjusted p-values and enforce monotonicity (right to left).
54    let mut prev = f64::INFINITY;
55    for i in (0..n).rev() {
56        let rank = (i + 1) as f64;
57        let adj = (p_values[indices[i]] * n_f / rank).min(1.0);
58        let adj = adj.min(prev);
59        adjusted[indices[i]] = adj;
60        prev = adj;
61    }
62
63    Ok(adjusted)
64}
65
66fn validate_p_values(p_values: &[f64]) -> Result<()> {
67    for (i, &p) in p_values.iter().enumerate() {
68        if !(0.0..=1.0).contains(&p) {
69            return Err(CyaneaError::InvalidInput(format!(
70                "p-value at index {} is out of range [0, 1]: {}",
71                i, p,
72            )));
73        }
74    }
75    Ok(())
76}
77
// ── Tests ──────────────────────────────────────────────────────────────────
79
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-10;

    #[test]
    fn bonferroni_basic() {
        let p = [0.01, 0.04, 0.03, 0.005];
        let adj = bonferroni(&p).unwrap();
        assert!((adj[0] - 0.04).abs() < TOL);
        assert!((adj[1] - 0.16).abs() < TOL);
        assert!((adj[2] - 0.12).abs() < TOL);
        assert!((adj[3] - 0.02).abs() < TOL);
    }

    #[test]
    fn bonferroni_clamp() {
        let p = [0.5, 0.8];
        let adj = bonferroni(&p).unwrap();
        assert!((adj[0] - 1.0).abs() < TOL);
        assert!((adj[1] - 1.0).abs() < TOL);
    }

    #[test]
    fn bh_known() {
        // Classic BH example
        let p = [0.01, 0.04, 0.03, 0.005];
        let adj = benjamini_hochberg(&p).unwrap();
        // Sorted: 0.005(idx3), 0.01(idx0), 0.03(idx2), 0.04(idx1)
        // Ranks:    1            2            3            4
        // Raw adj: 0.005*4/1=0.02, 0.01*4/2=0.02, 0.03*4/3=0.04, 0.04*4/4=0.04
        // Monotonicity (R-to-L): 0.04, 0.04, 0.02, 0.02
        assert!((adj[3] - 0.02).abs() < TOL);
        assert!((adj[0] - 0.02).abs() < TOL);
        assert!((adj[2] - 0.04).abs() < TOL);
        assert!((adj[1] - 0.04).abs() < TOL);
    }

    #[test]
    fn bh_monotonicity() {
        // Ensure sorted adjusted p-values are non-decreasing.
        let p = [0.1, 0.001, 0.05, 0.01, 0.5];
        let adj = benjamini_hochberg(&p).unwrap();
        let mut sorted_adj: Vec<(f64, f64)> = p.iter().copied().zip(adj.iter().copied()).collect();
        sorted_adj.sort_by(|a, b| a.0.total_cmp(&b.0));
        for w in sorted_adj.windows(2) {
            assert!(
                w[1].1 >= w[0].1 - TOL,
                "monotonicity violated: {} > {}",
                w[0].1,
                w[1].1
            );
        }
    }

    #[test]
    fn bh_clamp() {
        // Raw adjustment for 0.9 (rank 1 of 2) is 1.8. It is capped at 1.0 and
        // then pulled down to 0.95 by the monotonicity pass.
        let p = [0.9, 0.95];
        let adj = benjamini_hochberg(&p).unwrap();
        assert!(adj.iter().all(|&a| (0.0..=1.0).contains(&a)));
        assert!((adj[0] - 0.95).abs() < TOL);
        assert!((adj[1] - 0.95).abs() < TOL);
    }

    #[test]
    fn bh_all_ones() {
        let adj = benjamini_hochberg(&[1.0, 1.0, 1.0]).unwrap();
        assert!(adj.iter().all(|&a| (a - 1.0).abs() < TOL));
    }

    #[test]
    fn bh_ties_get_equal_adjustments() {
        let p = [0.02, 0.02, 0.04];
        let adj = benjamini_hochberg(&p).unwrap();
        // Ranks 1..3: raw 0.06, 0.03, 0.04 -> monotone 0.03, 0.03, 0.04.
        assert_eq!(adj[0], adj[1]);
        assert!((adj[0] - 0.03).abs() < TOL);
        assert!((adj[2] - 0.04).abs() < TOL);
    }

    #[test]
    fn correction_empty() {
        assert_eq!(bonferroni(&[]).unwrap(), Vec::<f64>::new());
        assert_eq!(benjamini_hochberg(&[]).unwrap(), Vec::<f64>::new());
    }

    #[test]
    fn correction_single() {
        assert!((bonferroni(&[0.05]).unwrap()[0] - 0.05).abs() < TOL);
        assert!((benjamini_hochberg(&[0.05]).unwrap()[0] - 0.05).abs() < TOL);
    }

    #[test]
    fn correction_invalid_p() {
        assert!(bonferroni(&[0.5, 1.5]).is_err());
        assert!(benjamini_hochberg(&[-0.1, 0.5]).is_err());
    }

    #[test]
    fn correction_rejects_non_finite() {
        assert!(bonferroni(&[f64::NAN]).is_err());
        assert!(benjamini_hochberg(&[0.1, f64::NAN]).is_err());
        assert!(bonferroni(&[f64::INFINITY]).is_err());
        assert!(benjamini_hochberg(&[f64::NEG_INFINITY]).is_err());
    }
}