use crate::linalg::eigh;
use crate::special::{chisq_qf, f_qf, ln_norm_cdf, t_sf};
use ndarray::Array2;
/// Sign of `v` as `1.0`, `-1.0` or `0.0`.
///
/// Unlike `f64::signum`, both `0.0` and `-0.0` map to `0.0`; NaN also maps to `0.0`.
fn sgn(v: f64) -> f64 {
    match v.partial_cmp(&0.0) {
        Some(std::cmp::Ordering::Greater) => 1.0,
        Some(std::cmp::Ordering::Less) => -1.0,
        // Equal (either zero) or unordered (NaN).
        _ => 0.0,
    }
}
/// Residual degrees of freedom for gene `i`.
///
/// A one-element `df` is shared by every gene. Otherwise `df` is indexed per gene, and the
/// function panics if `i` is out of range (including when `df` is empty).
fn df_at(df: &[f64], i: usize) -> f64 {
    match df {
        [shared] => *shared,
        _ => df[i],
    }
}
/// Builds the projection matrix `Q` that turns a row of t-statistics into an F-statistic.
///
/// `fstat_row(&q, x)` computes `sum_j (Q[:, j] . x)^2`; with the `Q` returned here that sum is
/// the F-statistic on `r` numerator degrees of freedom.
///
/// * `None`: the tests are treated as independent, `Q = I / sqrt(ntests)` and `r = ntests`.
/// * `Some(cor)`: `Q` keeps the eigenvectors of `cor` whose eigenvalue exceeds `1e-8` times the
///   largest one. Each is scaled by `1 / sqrt(eigenvalue)`, and the whole matrix is divided by
///   `sqrt(r)`.
///
/// Returns `(Q, r)`, where `Q` has `ntests` rows and `r` columns.
///
/// # Panics
/// Panics if `cor` is not `ntests x ntests`.
fn whitening(cor_matrix: Option<&Array2<f64>>, ntests: usize) -> (Array2<f64>, usize) {
    match cor_matrix {
        None => {
            let r = ntests;
            let s = 1.0 / (r as f64).sqrt();
            let mut q = Array2::<f64>::zeros((ntests, r));
            for k in 0..ntests {
                q[[k, k]] = s;
            }
            (q, r)
        }
        Some(cor) => {
            // A mismatched matrix would otherwise silently drop t-statistic columns (smaller
            // matrix) or fail later with an unrelated index panic (larger matrix).
            assert!(
                cor.nrows() == ntests && cor.ncols() == ntests,
                "cor_matrix must be {} x {} to match tstat, got {} x {}",
                ntests,
                ntests,
                cor.nrows(),
                cor.ncols()
            );
            if ntests == 0 {
                // Empty spectrum: nothing to keep, and `evals[n - 1]` below would underflow.
                return (Array2::<f64>::zeros((0, 0)), 0);
            }
            // NOTE(review): assumes `eigh` returns eigenvalues in ascending order with matching
            // eigenvector columns (largest last). TODO: confirm against crate::linalg::eigh.
            let (evals, evecs) = eigh(cor);
            let n = evals.len();
            let lmax = evals[n - 1];
            // Keep only directions whose eigenvalue is non-negligible relative to the largest.
            let r = evals.iter().filter(|&&e| e / lmax > 1e-8).count();
            let rs = (r as f64).sqrt();
            let mut q = Array2::<f64>::zeros((n, r));
            // Column jj of Q holds the (jj + 1)-th largest eigenvector, scaled by
            // 1 / sqrt(eigenvalue) / sqrt(r).
            for jj in 0..r {
                let col = n - 1 - jj;
                let scale = 1.0 / evals[col].sqrt() / rs;
                for k in 0..n {
                    q[[k, jj]] = evecs[[k, col]] * scale;
                }
            }
            (q, r)
        }
    }
}
/// Sum over the columns of `q` of the squared projection of `x` onto that column,
/// i.e. `sum_j (q[:, j] . x)^2`.
///
/// Only the first `q.nrows()` entries of `x` are read. If `x` is shorter than that and `q` has
/// at least one column, this panics.
fn fstat_row(q: &Array2<f64>, x: &[f64]) -> f64 {
    let n = q.nrows();
    (0..q.ncols()).fold(0.0, |total, j| {
        // Slice inside the closure so a zero-column `q` never touches `x`.
        let proj = q
            .column(j)
            .iter()
            .zip(&x[..n])
            .fold(0.0, |acc, (w, v)| acc + w * v);
        total + proj * proj
    })
}
/// Overall F-statistic for each gene (row of `tstat`), plus its numerator degrees of freedom.
///
/// With a single contrast (one column), the statistic is `t^2` on 1 df. Otherwise each row is
/// projected through the matrix built by `whitening` (a scaled identity when `cor_matrix` is
/// `None`) and the squared projections are summed. The returned df is the number of retained
/// directions `r`.
pub fn classify_tests_fstat(
    tstat: &Array2<f64>,
    cor_matrix: Option<&Array2<f64>>,
) -> (Vec<f64>, usize) {
    if tstat.ncols() == 1 {
        let squared = tstat.column(0).iter().map(|t| t.powi(2)).collect();
        return (squared, 1);
    }
    let (q, df1) = whitening(cor_matrix, tstat.ncols());
    let fstat = tstat
        .outer_iter()
        .map(|row| fstat_row(&q, &row.to_vec()))
        .collect();
    (fstat, df1)
}
/// Classifies every gene x contrast cell of `tstat` as up (`1`), down (`-1`) or not
/// significant (`0`), using a nested F-test.
///
/// # Arguments
/// * `tstat` - t-statistics, one row per gene and one column per contrast.
/// * `cor_matrix` - optional contrast x contrast correlation matrix; `None` treats the
///   contrasts as independent.
/// * `df` - residual degrees of freedom, either a single value shared by all genes or one per
///   gene. `f64::INFINITY` switches to the normal / chi-square limit.
/// * `p_value` - significance threshold.
///
/// # Procedure
/// * With one contrast, each cell is called when its two-sided t p-value is below `p_value`.
/// * Otherwise, a gene is examined only if its overall F-statistic exceeds the critical value.
///   The contrast with the largest |t| is called first. The remaining contrasts are tried in
///   decreasing |t| order. Before each try, every already-called contrast is shrunk to the
///   candidate's |t| (keeping its sign) and the F-statistic is recomputed. Calling stops at
///   the first candidate whose recomputed statistic does not exceed the critical value.
///
/// Rows containing any NaN are left as all zeros, as are results with no contrasts.
///
/// # Panics
/// Panics when a gene that is tested has no entry in `df`: `df` is empty, or it has more than
/// one entry but fewer than the number of genes.
pub fn classify_tests_f(
    tstat: &Array2<f64>,
    cor_matrix: Option<&Array2<f64>>,
    df: &[f64],
    p_value: f64,
) -> Array2<i8> {
    let ngenes = tstat.nrows();
    let ntests = tstat.ncols();
    let mut result = Array2::<i8>::zeros((ngenes, ntests));
    if ntests == 0 {
        // No contrasts to classify. The F path below would otherwise index `ord[0]` on an
        // empty ordering and evaluate quantiles with zero numerator df.
        return result;
    }
    if ntests == 1 {
        for i in 0..ngenes {
            let t = tstat[[i, 0]];
            if t.is_nan() {
                continue;
            }
            let dfi = df_at(df, i);
            // Two-sided p-value; infinite df uses the normal tail.
            let p = if dfi.is_infinite() {
                2.0 * ln_norm_cdf(-t.abs()).exp()
            } else {
                2.0 * t_sf(t.abs(), dfi)
            };
            if p < p_value {
                result[[i, 0]] = sgn(t) as i8;
            }
        }
        return result;
    }
    let (q, r) = whitening(cor_matrix, ntests);
    if r == 0 {
        // No direction survived the eigenvalue cut-off (degenerate `cor_matrix`). The F-test
        // has zero numerator df, so nothing can be called significant.
        return result;
    }
    let r_f = r as f64;
    // Critical F value at `p_value` for a given residual df.
    // NOTE(review): `1.0 - p_value` rounds to 1.0 for p_value below ~1e-16. An upper-tail
    // quantile would avoid this if `crate::special` offers one. TODO: confirm.
    let critical = |dfi: f64| {
        if dfi.is_infinite() {
            chisq_qf(1.0 - p_value, r_f) / r_f
        } else {
            f_qf(1.0 - p_value, r_f, dfi)
        }
    };
    // A single shared df gives the same critical value for every gene, so compute it at most once.
    let mut shared_critical: Option<f64> = None;
    for i in 0..ngenes {
        let x = tstat.row(i).to_vec();
        if x.iter().any(|v| v.is_nan()) {
            continue;
        }
        let qf = match df {
            [shared] => *shared_critical.get_or_insert_with(|| critical(*shared)),
            _ => critical(df[i]),
        };
        if fstat_row(&q, &x) <= qf {
            continue;
        }
        // Contrasts in decreasing |t| order; ties keep the lower column index first.
        let mut ord: Vec<usize> = (0..ntests).collect();
        ord.sort_by(|&a, &b| {
            x[b].abs()
                .partial_cmp(&x[a].abs())
                .expect("NaN rows are skipped above")
                .then(a.cmp(&b))
        });
        result[[i, ord[0]]] = sgn(x[ord[0]]) as i8;
        let mut xx = x.clone();
        for c in 1..ntests {
            // Shrink every already-called contrast to the candidate's magnitude, keeping signs.
            let cap = x[ord[c]].abs();
            for &k in &ord[0..c] {
                xx[k] = sgn(x[k]) * cap;
            }
            if fstat_row(&q, &xx) > qf {
                result[[i, ord[c]]] = sgn(x[ord[c]]) as i8;
            } else {
                break;
            }
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// True when `got` and `want` have the same length and each pair satisfies
    /// `|got - want| <= tol + tol * |want|` (NaN never compares as close).
    fn close(got: &[f64], want: &[f64], tol: f64) -> bool {
        let within = |g: f64, w: f64| (g - w).abs() <= tol + tol * w.abs();
        got.len() == want.len() && got.iter().zip(want).all(|(&g, &w)| within(g, w))
    }

    /// Five genes x three contrasts of t-statistics shared by the tests below.
    fn sample_tstat() -> Array2<f64> {
        array![
            [3.0, 0.5, -0.2],
            [2.5, 2.4, 2.6],
            [0.1, -0.1, 0.2],
            [-4.0, 1.0, 0.5],
            [2.0, -2.1, 0.0],
        ]
    }

    /// Correlation matrix between the three contrasts of `sample_tstat`.
    fn sample_cor() -> Array2<f64> {
        array![
            [1.0, 0.6, 0.2],
            [0.6, 1.0, 0.3],
            [0.2, 0.3, 1.0],
        ]
    }

    #[test]
    fn classify_null_inf_matches_r() {
        let got = classify_tests_f(&sample_tstat(), None, &[f64::INFINITY], 0.01);
        assert_eq!(got, array![[0i8, 0, 0], [1, 1, 1], [0, 0, 0], [-1, 0, 0], [0, 0, 0]]);
    }

    #[test]
    fn classify_null_df10_matches_r() {
        let got = classify_tests_f(&sample_tstat(), None, &[10.0], 0.05);
        assert_eq!(got, array![[0i8, 0, 0], [1, 1, 1], [0, 0, 0], [-1, 0, 0], [0, 0, 0]]);
    }

    #[test]
    fn classify_cor_df10_matches_r() {
        let cor = sample_cor();
        let got = classify_tests_f(&sample_tstat(), Some(&cor), &[10.0], 0.05);
        assert_eq!(got, array![[1i8, 0, 0], [0, 0, 1], [0, 0, 0], [-1, 0, 0], [1, -1, 0]]);
    }

    #[test]
    fn fstat_cor_matches_r() {
        let cor = sample_cor();
        let (fstat, df1) = classify_tests_fstat(&sample_tstat(), Some(&cor));
        assert_eq!(df1, 3);
        let want = [
            3.94936998854525,
            3.75549828178694,
            0.0352233676975945,
            11.397479954181,
            7.10744558991983,
        ];
        assert!(close(&fstat, &want, 1e-9));
    }

    #[test]
    fn fstat_null_matches_r() {
        let (fstat, df1) = classify_tests_fstat(&sample_tstat(), None);
        assert_eq!(df1, 3);
        let want = [
            3.09666666666667,
            6.25666666666667,
            0.02,
            5.75,
            2.80333333333333,
        ];
        assert!(close(&fstat, &want, 1e-9));
    }

    #[test]
    fn classify_single_test_matches_r() {
        let single = array![[3.0], [0.5], [-2.8], [1.0]];
        let got = classify_tests_f(&single, None, &[10.0], 0.05);
        assert_eq!(got, array![[1i8], [0], [-1], [0]]);
    }
}