1use crate::linalg::eigh;
14use crate::special::{chisq_qf, f_qf, ln_norm_cdf, t_sf};
15use ndarray::Array2;
16
/// Sign of `v` as a float: `1.0` for positive, `-1.0` for negative, and
/// `0.0` for zero (of either sign) as well as for NaN.
fn sgn(v: f64) -> f64 {
    match v.partial_cmp(&0.0) {
        Some(std::cmp::Ordering::Greater) => 1.0,
        Some(std::cmp::Ordering::Less) => -1.0,
        // Equal (covers -0.0) or unordered (NaN).
        _ => 0.0,
    }
}
28
/// Residual degrees of freedom for gene `i`.
///
/// A one-element `df` is shared by every gene; otherwise `df[i]` is used
/// (and indexing panics if `i` is out of bounds).
fn df_at(df: &[f64], i: usize) -> f64 {
    match df {
        [shared] => *shared,
        _ => df[i],
    }
}
37
38fn whitening(cor_matrix: Option<&Array2<f64>>, ntests: usize) -> (Array2<f64>, usize) {
43 match cor_matrix {
44 None => {
45 let r = ntests;
46 let s = 1.0 / (r as f64).sqrt();
47 let mut q = Array2::<f64>::zeros((ntests, r));
48 for k in 0..ntests {
49 q[[k, k]] = s;
50 }
51 (q, r)
52 }
53 Some(cor) => {
54 let (evals, evecs) = eigh(cor);
56 let n = evals.len();
57 let lmax = evals[n - 1];
58 let r = evals.iter().filter(|&&e| e / lmax > 1e-8).count();
59 let rs = (r as f64).sqrt();
60 let mut q = Array2::<f64>::zeros((n, r));
61 for jj in 0..r {
62 let col = n - 1 - jj; let scale = 1.0 / evals[col].sqrt() / rs;
64 for k in 0..n {
65 q[[k, jj]] = evecs[[k, col]] * scale;
66 }
67 }
68 (q, r)
69 }
70 }
71}
72
73fn fstat_row(q: &Array2<f64>, x: &[f64]) -> f64 {
75 let (n, r) = (q.nrows(), q.ncols());
76 let mut s = 0.0;
77 for j in 0..r {
78 let mut dot = 0.0;
79 for k in 0..n {
80 dot += q[[k, j]] * x[k];
81 }
82 s += dot * dot;
83 }
84 s
85}
86
87pub fn classify_tests_fstat(
91 tstat: &Array2<f64>,
92 cor_matrix: Option<&Array2<f64>>,
93) -> (Vec<f64>, usize) {
94 let ngenes = tstat.nrows();
95 let ntests = tstat.ncols();
96 if ntests == 1 {
97 let fstat = (0..ngenes).map(|i| tstat[[i, 0]].powi(2)).collect();
98 return (fstat, 1);
99 }
100 let (q, r) = whitening(cor_matrix, ntests);
101 let fstat = (0..ngenes)
102 .map(|i| fstat_row(&q, &tstat.row(i).to_vec()))
103 .collect();
104 (fstat, r)
105}
106
107pub fn classify_tests_f(
114 tstat: &Array2<f64>,
115 cor_matrix: Option<&Array2<f64>>,
116 df: &[f64],
117 p_value: f64,
118) -> Array2<i8> {
119 let ngenes = tstat.nrows();
120 let ntests = tstat.ncols();
121 let mut result = Array2::<i8>::zeros((ngenes, ntests));
122
123 if ntests == 1 {
124 for i in 0..ngenes {
125 let t = tstat[[i, 0]];
126 if t.is_nan() {
127 continue;
128 }
129 let dfi = df_at(df, i);
130 let p = if dfi.is_infinite() {
131 2.0 * ln_norm_cdf(-t.abs()).exp()
132 } else {
133 2.0 * t_sf(t.abs(), dfi)
134 };
135 if p < p_value {
136 result[[i, 0]] = sgn(t) as i8;
137 }
138 }
139 return result;
140 }
141
142 let (q, r) = whitening(cor_matrix, ntests);
143
144 for i in 0..ngenes {
145 let x = tstat.row(i).to_vec();
146 if x.iter().any(|v| v.is_nan()) {
147 continue;
148 }
149 let dfi = df_at(df, i);
150 let qf = if dfi.is_infinite() {
151 chisq_qf(1.0 - p_value, r as f64) / r as f64
152 } else {
153 f_qf(1.0 - p_value, r as f64, dfi)
154 };
155 if fstat_row(&q, &x) <= qf {
156 continue;
157 }
158 let mut ord: Vec<usize> = (0..ntests).collect();
160 ord.sort_by(|&a, &b| x[b].abs().partial_cmp(&x[a].abs()).unwrap().then(a.cmp(&b)));
161 result[[i, ord[0]]] = sgn(x[ord[0]]) as i8;
162 let mut xx = x.clone();
165 for c in 1..ntests {
166 let cap = x[ord[c]].abs();
167 for &k in &ord[0..c] {
168 xx[k] = sgn(x[k]) * cap;
169 }
170 if fstat_row(&q, &xx) > qf {
171 result[[i, ord[c]]] = sgn(x[ord[c]]) as i8;
172 } else {
173 break;
174 }
175 }
176 }
177 result
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Element-wise comparison: equal lengths, and every pair within
    /// `tol + tol * |expected|`.
    fn all_close(actual: &[f64], expected: &[f64], tol: f64) -> bool {
        if actual.len() != expected.len() {
            return false;
        }
        actual
            .iter()
            .zip(expected.iter())
            .all(|(&got, &want)| (got - want).abs() <= tol + tol * want.abs())
    }

    /// Five genes by three tests, shared by the multi-test cases.
    fn sample_tstat() -> Array2<f64> {
        array![
            [3.0, 0.5, -0.2],
            [2.5, 2.4, 2.6],
            [0.1, -0.1, 0.2],
            [-4.0, 1.0, 0.5],
            [2.0, -2.1, 0.0],
        ]
    }

    /// Correlation matrix between the three tests.
    fn sample_cor() -> Array2<f64> {
        array![[1.0, 0.6, 0.2], [0.6, 1.0, 0.3], [0.2, 0.3, 1.0]]
    }

    #[test]
    fn classify_null_inf_matches_r() {
        let expected = array![[0i8, 0, 0], [1, 1, 1], [0, 0, 0], [-1, 0, 0], [0, 0, 0],];
        let actual = classify_tests_f(&sample_tstat(), None, &[f64::INFINITY], 0.01);
        assert_eq!(actual, expected);
    }

    #[test]
    fn classify_null_df10_matches_r() {
        let expected = array![[0i8, 0, 0], [1, 1, 1], [0, 0, 0], [-1, 0, 0], [0, 0, 0],];
        let actual = classify_tests_f(&sample_tstat(), None, &[10.0], 0.05);
        assert_eq!(actual, expected);
    }

    #[test]
    fn classify_cor_df10_matches_r() {
        let expected = array![[1i8, 0, 0], [0, 0, 1], [0, 0, 0], [-1, 0, 0], [1, -1, 0],];
        let cor = sample_cor();
        let actual = classify_tests_f(&sample_tstat(), Some(&cor), &[10.0], 0.05);
        assert_eq!(actual, expected);
    }

    #[test]
    fn fstat_cor_matches_r() {
        let expected = [
            3.94936998854525,
            3.75549828178694,
            0.0352233676975945,
            11.397479954181,
            7.10744558991983,
        ];
        let cor = sample_cor();
        let (fstat, df1) = classify_tests_fstat(&sample_tstat(), Some(&cor));
        assert_eq!(df1, 3);
        assert!(all_close(&fstat, &expected, 1e-9));
    }

    #[test]
    fn fstat_null_matches_r() {
        let expected = [
            3.09666666666667,
            6.25666666666667,
            0.02,
            5.75,
            2.80333333333333,
        ];
        let (fstat, df1) = classify_tests_fstat(&sample_tstat(), None);
        assert_eq!(df1, 3);
        assert!(all_close(&fstat, &expected, 1e-9));
    }

    #[test]
    fn classify_single_test_matches_r() {
        let single = array![[3.0], [0.5], [-2.8], [1.0]];
        let expected = array![[1i8], [0], [-1], [0]];
        let actual = classify_tests_f(&single, None, &[10.0], 0.05);
        assert_eq!(actual, expected);
    }
}