use faer::Mat;
/// Per-annotation results of [`enrichment_test`], stored as parallel vectors.
///
/// Entry `i` of every vector refers to the same annotation, `annotations[i]`.
#[derive(Debug, Clone, PartialEq)]
pub struct EnrichResult {
    /// Annotation names, in the order they were supplied.
    pub annotations: Vec<String>,
    /// Enrichment ratio: proportion of heritability divided by proportion of SNPs.
    pub enrichment: Vec<f64>,
    /// Standard error of each enrichment ratio.
    pub se: Vec<f64>,
    /// One-sided (upper-tail) p-value for the hypothesis that enrichment exceeds 1.
    pub p: Vec<f64>,
}
/// Tests each annotation for enrichment of heritability relative to its share of SNPs.
///
/// Heritability is summarised as the mean of the first `k` diagonal entries of a matrix,
/// where `k = s_baseline.nrows()`: `h2_total` comes from `s_baseline` and `h2_annot` from
/// `s_annot[a]`. For each annotation `a` (one per entry of `annotation_names`):
///
/// * `prop_h2 = h2_annot / h2_total`, or 0 when `|h2_total| <= 1e-30` (or is NaN);
/// * `prop_snps = m_annot[a] / m_total`;
/// * enrichment `= prop_h2 / prop_snps`, or 0 when `prop_snps` is not positive;
/// * SE `= SE(h2_annot) / |h2_total| / prop_snps`, where `SE(h2_annot)` is computed from
///   `v_annot[a]` by `compute_diag_se` and `h2_total` is treated as a fixed quantity.
///   The SE is 0 when either guard above fails;
/// * p is the one-sided upper-tail standard-normal p-value `Φ(-z)` with
///   `z = (enrichment - 1) / SE`, i.e. a test of enrichment > 1. When the SE is not above
///   1e-30, `z` is taken as 0, giving `p = 0.5`.
///
/// NOTE(review): `v_annot[a]` is presumably the sampling covariance of `vech(s_annot[a])`;
/// confirm against the caller.
///
/// # Panics
///
/// Panics if `s_annot`, `v_annot` or `m_annot` has fewer entries than `annotation_names`,
/// or if `s_baseline` or a used `s_annot[a]` has fewer than `k` rows or columns.
pub fn enrichment_test(
    s_baseline: &Mat<f64>,
    s_annot: &[Mat<f64>],
    v_annot: &[Mat<f64>],
    annotation_names: &[String],
    m_annot: &[f64],
    m_total: f64,
) -> EnrichResult {
    use statrs::distribution::{ContinuousCDF, Normal};

    // |h2_total| and SE values at or below this are treated as zero to avoid dividing by ~0.
    const EPS: f64 = 1e-30;

    let n_annot = annotation_names.len();
    assert!(
        s_annot.len() >= n_annot && v_annot.len() >= n_annot && m_annot.len() >= n_annot,
        "enrichment_test: {} annotation names but s_annot has {}, v_annot has {}, m_annot has {} entries",
        n_annot,
        s_annot.len(),
        v_annot.len(),
        m_annot.len()
    );

    let k = s_baseline.nrows();
    // Mean of the first k diagonal entries; NaN when k == 0 (0 / 0).
    let mean_diag = |m: &Mat<f64>| -> f64 { (0..k).map(|i| m[(i, i)]).sum::<f64>() / k as f64 };

    let h2_total = mean_diag(s_baseline);
    // False for zero, tiny and NaN totals alike, which all disable the ratio.
    let h2_total_usable = h2_total.abs() > EPS;
    // Loop-invariant: build the reference distribution once.
    let std_normal = Normal::standard();

    let mut enrichment = Vec::with_capacity(n_annot);
    let mut se = Vec::with_capacity(n_annot);
    let mut p_values = Vec::with_capacity(n_annot);

    for ((s_a, v_a), &m_a) in s_annot.iter().zip(v_annot).zip(m_annot).take(n_annot) {
        let h2_annot = mean_diag(s_a);
        let prop_snps = m_a / m_total;

        let prop_h2 = if h2_total_usable { h2_annot / h2_total } else { 0.0 };
        let enrich = if prop_snps > 0.0 { prop_h2 / prop_snps } else { 0.0 };

        let annot_se = compute_diag_se(v_a, k);
        let se_val = if h2_total_usable && prop_snps > 0.0 {
            annot_se / h2_total.abs() / prop_snps
        } else {
            0.0
        };

        // `se_val > EPS` is false for NaN, so an undefined SE yields z = 0 rather than NaN.
        let z = if se_val > EPS { (enrich - 1.0) / se_val } else { 0.0 };

        enrichment.push(enrich);
        se.push(se_val);
        // One-sided: small p means enrichment significantly above 1.
        p_values.push(std_normal.cdf(-z));
    }

    EnrichResult {
        annotations: annotation_names.to_vec(),
        enrichment,
        se,
        p: p_values,
    }
}
/// Standard error of the mean of the first `k` diagonal entries of a symmetric `k × k`
/// estimate `S`, given `v`, a covariance matrix indexed by `vech` position (the ordering
/// produced by `vech_diag_index`).
///
/// The mean is `w' vech(S)` with weight `1/k` on each diagonal position `d_i`, so its
/// variance is `w' V w = (1/k²) · Σ_i Σ_j v[d_i, d_j]`. The cross terms `v[d_i, d_j]`
/// (`i != j`) are included: diagonal entries estimated from the same data are generally
/// correlated, and summing only the variances mis-states the SE unless they are independent.
///
/// Diagonal positions that fall outside `v` are skipped rather than causing a panic.
/// A negative variance (reachable only through round-off when `v` is positive semi-definite)
/// is clamped to zero; NaN entries still propagate to a NaN result, and `k == 0` yields NaN.
fn compute_diag_se(v: &Mat<f64>, k: usize) -> f64 {
    let positions: Vec<usize> = (0..k)
        .map(|i| vech_diag_index(i, k))
        .filter(|&d| d < v.nrows() && d < v.ncols())
        .collect();

    // Full quadratic form over the diagonal positions, including covariances.
    let mut var_sum = 0.0_f64;
    for &r in &positions {
        for &c in &positions {
            var_sum += v[(r, c)];
        }
    }

    let variance = var_sum / (k as f64).powi(2);
    // `variance < 0.0` is false for NaN, so NaN is passed through untouched.
    if variance < 0.0 {
        0.0
    } else {
        variance.sqrt()
    }
}
/// Position of the diagonal entry `(i, i)` inside the half-vectorisation (`vech`) of a
/// `k × k` symmetric matrix whose lower triangle is stacked column by column.
///
/// Column `j` of the lower triangle holds `k - j` entries, so the diagonal of column `i`
/// is preceded by `Σ_{j<i} (k - j) = i·k - i·(i-1)/2` entries. For `k = 3` this gives
/// positions 0, 3 and 5.
fn vech_diag_index(i: usize, k: usize) -> usize {
    // As if every earlier column were full length...
    let full_columns = i * k;
    // ...minus the above-diagonal cells those columns do not store: 0 + 1 + ... + (i-1).
    let skipped_upper = i * i.saturating_sub(1) / 2;
    full_columns - skipped_upper
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Diagonal `n × n` matrix with `value` on the diagonal.
    fn diag_mat(n: usize, value: f64) -> Mat<f64> {
        Mat::from_fn(n, n, |i, j| if i == j { value } else { 0.0 })
    }

    #[test]
    fn test_enrichment_basic() {
        // h2_total = mean(0.3, 0.4) = 0.35 and h2_annot = mean(0.15, 0.2) = 0.175,
        // so prop_h2 = 0.5; prop_snps = 0.1, hence enrichment = 5.
        // V is diagonal (0.02), so SE(h2_annot) = sqrt((0.02 + 0.02) / 4) = 0.1 and
        // SE(enrichment) = 0.1 / 0.35 / 0.1 ≈ 2.857, giving z = 4 / 2.857 = 1.4.
        let s_baseline = faer::mat![[0.3, 0.1], [0.1, 0.4]];
        let s_annot = vec![faer::mat![[0.15, 0.05], [0.05, 0.2]]];
        let v_annot = vec![diag_mat(3, 0.02)];
        let names = vec!["Annot1".to_string()];
        let m_annot = vec![100000.0];
        let m_total = 1000000.0;
        let result = enrichment_test(&s_baseline, &s_annot, &v_annot, &names, &m_annot, m_total);

        assert_eq!(result.annotations, names);
        assert!((result.enrichment[0] - 5.0).abs() < 1e-9, "{:?}", result.enrichment);
        assert!((result.se[0] - 0.1 / 0.35 / 0.1).abs() < 1e-9, "{:?}", result.se);
        // One-sided upper-tail p-value: Phi(-1.4).
        assert!((result.p[0] - 0.080_756_659_2).abs() < 1e-6, "{:?}", result.p);
    }

    #[test]
    fn test_enrichment_zero_baseline_is_neutral() {
        // A zero baseline disables the h2 ratio: enrichment and SE fall back to 0, p to 0.5.
        let zero = Mat::from_fn(2, 2, |_, _| 0.0);
        let result = enrichment_test(
            &zero,
            &[zero.clone()],
            &[diag_mat(3, 0.02)],
            &["A".to_string()],
            &[10.0],
            100.0,
        );
        assert_eq!(result.enrichment, vec![0.0]);
        assert_eq!(result.se, vec![0.0]);
        assert!((result.p[0] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_enrichment_zero_snps() {
        // No SNPs in the annotation: proportion is 0, so enrichment and SE are 0.
        let s = faer::mat![[0.3, 0.1], [0.1, 0.4]];
        let result = enrichment_test(
            &s,
            &[s.clone()],
            &[diag_mat(3, 0.02)],
            &["A".to_string()],
            &[0.0],
            100.0,
        );
        assert_eq!(result.enrichment, vec![0.0]);
        assert_eq!(result.se, vec![0.0]);
    }

    #[test]
    fn test_compute_diag_se_diagonal_and_truncated() {
        // Independent diagonal entries: sqrt((0.02 + 0.02) / 2^2) = 0.1.
        assert!((compute_diag_se(&diag_mat(3, 0.02), 2) - 0.1).abs() < 1e-12);
        // Positions beyond `v` are skipped: only position 0 of a 1×1 V contributes.
        assert!((compute_diag_se(&diag_mat(1, 0.04), 2) - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_vech_diag_index() {
        assert_eq!(vech_diag_index(0, 3), 0);
        assert_eq!(vech_diag_index(1, 3), 3);
        assert_eq!(vech_diag_index(2, 3), 5);
        assert_eq!(vech_diag_index(0, 1), 0);
        assert_eq!(vech_diag_index(1, 4), 4);
        assert_eq!(vech_diag_index(2, 4), 7);
        // Last diagonal entry is the last vech element: k(k+1)/2 - 1.
        assert_eq!(vech_diag_index(3, 4), 9);
    }
}