cyanea_stats/
correction.rs1use cyanea_core::{CyaneaError, Result};
7
/// Multiple-testing correction procedure selectable through [`correct`].
///
/// The type is a fieldless, `Copy` enum; it also derives `Hash` so it can be
/// used directly as a key in hashed collections (for example to cache
/// results per method).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CorrectionMethod {
    /// Dispatches to [`bonferroni`]: every p-value is multiplied by the
    /// number of p-values and capped at `1.0`.
    Bonferroni,
    /// Dispatches to [`benjamini_hochberg`]: rank-based step-up adjustment
    /// of the sorted p-values, capped at `1.0`.
    BenjaminiHochberg,
}
16
17pub fn correct(p_values: &[f64], method: CorrectionMethod) -> Result<Vec<f64>> {
22 match method {
23 CorrectionMethod::Bonferroni => bonferroni(p_values),
24 CorrectionMethod::BenjaminiHochberg => benjamini_hochberg(p_values),
25 }
26}
27
28pub fn bonferroni(p_values: &[f64]) -> Result<Vec<f64>> {
30 validate_p_values(p_values)?;
31 let n = p_values.len() as f64;
32 Ok(p_values.iter().map(|&p| (p * n).min(1.0)).collect())
33}
34
35pub fn benjamini_hochberg(p_values: &[f64]) -> Result<Vec<f64>> {
40 validate_p_values(p_values)?;
41 let n = p_values.len();
42 if n == 0 {
43 return Ok(Vec::new());
44 }
45
46 let mut indices: Vec<usize> = (0..n).collect();
48 indices.sort_by(|&a, &b| p_values[a].total_cmp(&p_values[b]));
49
50 let n_f = n as f64;
51 let mut adjusted = vec![0.0; n];
52
53 let mut prev = f64::INFINITY;
55 for i in (0..n).rev() {
56 let rank = (i + 1) as f64;
57 let adj = (p_values[indices[i]] * n_f / rank).min(1.0);
58 let adj = adj.min(prev);
59 adjusted[indices[i]] = adj;
60 prev = adj;
61 }
62
63 Ok(adjusted)
64}
65
66fn validate_p_values(p_values: &[f64]) -> Result<()> {
67 for (i, &p) in p_values.iter().enumerate() {
68 if !(0.0..=1.0).contains(&p) {
69 return Err(CyaneaError::InvalidInput(format!(
70 "p-value at index {} is out of range [0, 1]: {}",
71 i, p,
72 )));
73 }
74 }
75 Ok(())
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Each value is multiplied by the number of tests (4).
    #[test]
    fn bonferroni_basic() {
        let adjusted = bonferroni(&[0.01, 0.04, 0.03, 0.005]).unwrap();
        let expected = [0.04, 0.16, 0.12, 0.02];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(adjusted[i], want);
        }
    }

    /// Products above 1 are capped at 1.
    #[test]
    fn bonferroni_clamp() {
        let adjusted = bonferroni(&[0.5, 0.8]).unwrap();
        assert_close(adjusted[0], 1.0);
        assert_close(adjusted[1], 1.0);
    }

    /// Hand-computed Benjamini–Hochberg values, listed in input order.
    #[test]
    fn bh_known() {
        let adjusted = benjamini_hochberg(&[0.01, 0.04, 0.03, 0.005]).unwrap();
        let expected = [0.02, 0.04, 0.04, 0.02];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(adjusted[i], want);
        }
    }

    /// Adjusted values must not decrease as the raw p-values increase.
    #[test]
    fn bh_monotonicity() {
        let raw = [0.1, 0.001, 0.05, 0.01, 0.5];
        let adjusted = benjamini_hochberg(&raw).unwrap();

        // Pair every raw p-value with its adjusted value, ordered by raw value.
        let mut pairs: Vec<(f64, f64)> = raw
            .iter()
            .zip(&adjusted)
            .map(|(&p, &adj)| (p, adj))
            .collect();
        pairs.sort_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0));

        for (lower, upper) in pairs.iter().zip(pairs.iter().skip(1)) {
            assert!(
                upper.1 >= lower.1 - TOL,
                "monotonicity violated: {} > {}",
                lower.1,
                upper.1
            );
        }
    }

    /// Adjusted values never exceed 1.
    #[test]
    fn bh_clamp() {
        let adjusted = benjamini_hochberg(&[0.9, 0.95]).unwrap();
        assert!(adjusted[0] <= 1.0);
        assert!(adjusted[1] <= 1.0);
    }

    /// Empty input gives empty output for both procedures.
    #[test]
    fn correction_empty() {
        let no_values: Vec<f64> = Vec::new();
        assert_eq!(bonferroni(&[]).unwrap(), no_values);
        assert_eq!(benjamini_hochberg(&[]).unwrap(), no_values);
    }

    /// With a single test the p-value is returned unchanged.
    #[test]
    fn correction_single() {
        let lone = [0.05];
        assert_close(bonferroni(&lone).unwrap()[0], 0.05);
        assert_close(benjamini_hochberg(&lone).unwrap()[0], 0.05);
    }

    /// Values outside `[0, 1]` are rejected.
    #[test]
    fn correction_invalid_p() {
        let too_large = bonferroni(&[0.5, 1.5]);
        let negative = benjamini_hochberg(&[-0.1, 0.5]);
        assert!(too_large.is_err());
        assert!(negative.is_err());
    }
}