use std::collections::HashMap;
use anyhow::Result;
use statrs::distribution::{ContinuousCDF, Normal};
use crate::io::gwas_reader::{GwasRecord, MungedRecord};
use super::allele::{AlleleMatch, alleles_match};
use super::{MungeConfig, RefSnp};
pub fn run_qc_pipeline(
records: Vec<GwasRecord>,
reference: &HashMap<String, RefSnp>,
config: &MungeConfig,
) -> Result<Vec<MungedRecord>> {
let normal = Normal::standard();
let mut output = Vec::new();
let mut n_no_ref = 0usize;
let mut n_missing = 0usize;
let mut n_allele_mismatch = 0usize;
let mut n_info_filtered = 0usize;
let mut n_maf_filtered = 0usize;
let mut n_bad_alleles = 0usize;
let is_or = detect_or(&records);
if is_or {
log::info!("Detected OR format (median effect ≈ 1), converting to log(OR)");
}
for mut rec in records {
if let Some(n_override) = config.n_override {
rec.n = Some(n_override);
}
let (Some(mut effect), Some(p)) = (rec.effect, rec.p) else {
n_missing += 1;
continue;
};
if !(0.0..=1.0).contains(&p) || p.is_nan() {
n_missing += 1;
continue;
}
let n = rec.n.unwrap_or(0.0);
if n <= 0.0 {
n_missing += 1;
continue;
}
let (Some(a1_str), Some(a2_str)) = (rec.a1.as_ref(), rec.a2.as_ref()) else {
n_bad_alleles += 1;
continue;
};
let a1 = a1_str.to_uppercase();
let a2 = a2_str.to_uppercase();
if !is_valid_allele(&a1) || !is_valid_allele(&a2) {
n_bad_alleles += 1;
continue;
}
let Some(ref_snp) = reference.get(&rec.snp) else {
n_no_ref += 1;
continue;
};
#[allow(clippy::collapsible_if)]
if let Some(info) = rec.info {
if info < config.info_filter {
n_info_filtered += 1;
continue;
}
}
if let Some(mut maf) = rec.maf {
if maf > 0.5 {
maf = 1.0 - maf;
}
if maf < config.maf_filter {
n_maf_filtered += 1;
continue;
}
}
if is_or {
if effect <= 0.0 {
n_missing += 1;
continue;
}
effect = effect.ln();
}
let amatch = alleles_match(&a1, &a2, &ref_snp.a1, &ref_snp.a2);
match amatch {
AlleleMatch::Match => {}
AlleleMatch::Flipped => {
effect = -effect;
}
AlleleMatch::NoMatch => {
n_allele_mismatch += 1;
continue;
}
}
let z = compute_z(effect, p, &normal);
output.push(MungedRecord {
snp: rec.snp,
n,
z,
a1: ref_snp.a1.clone(),
a2: ref_snp.a2.clone(),
});
}
log::info!(
"QC summary: {} no ref, {} missing P/effect/N, {} bad alleles, {} allele mismatch, {} INFO filtered, {} MAF filtered",
n_no_ref,
n_missing,
n_bad_alleles,
n_allele_mismatch,
n_info_filtered,
n_maf_filtered
);
Ok(output)
}
fn detect_or(records: &[GwasRecord]) -> bool {
let effects: Vec<f64> = records
.iter()
.filter_map(|r| r.effect)
.filter(|e| e.is_finite())
.collect();
if effects.is_empty() {
return false;
}
let mut sorted = effects;
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let median = sorted[sorted.len() / 2];
median.round() == 1.0
}
fn compute_z(effect: f64, p: f64, normal: &Normal) -> f64 {
if effect == 0.0 {
return 0.0;
}
if p == 0.0 {
return if effect > 0.0 { 37.0 } else { -37.0 };
}
let z_abs = normal.inverse_cdf(1.0 - p / 2.0);
if !z_abs.is_finite() {
return if effect > 0.0 { 37.0 } else { -37.0 };
}
if effect > 0.0 { z_abs } else { -z_abs }
}
fn is_valid_allele(a: &str) -> bool {
matches!(a, "A" | "C" | "G" | "T")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compute_z_positive() {
let normal = Normal::standard();
let z = compute_z(0.5, 0.05, &normal);
assert!(z > 0.0);
assert!((z - 1.96).abs() < 0.01);
}
#[test]
fn test_compute_z_negative() {
let normal = Normal::standard();
let z = compute_z(-0.5, 0.05, &normal);
assert!(z < 0.0);
assert!((z + 1.96).abs() < 0.01);
}
#[test]
fn test_compute_z_large_p() {
let normal = Normal::standard();
for &p in &[0.9722_f64, 0.98, 0.99, 0.999] {
let z = compute_z(-0.001, p, &normal);
assert!(z.is_finite(), "z non-finite at p={p}");
let expected = -match p {
p if (p - 0.9722).abs() < 1e-9 => 0.034849,
p if (p - 0.98).abs() < 1e-9 => 0.025069,
p if (p - 0.99).abs() < 1e-9 => 0.012533,
p if (p - 0.999).abs() < 1e-9 => 0.001253,
_ => unreachable!(),
};
assert!(
(z - expected).abs() < 1e-4,
"p={p} z={z} expected≈{expected}"
);
}
}
#[test]
fn test_compute_z_tiny_p() {
let normal = Normal::standard();
let z = compute_z(1.0, 1e-300, &normal);
assert!(z > 30.0);
}
#[test]
fn test_compute_z_zero_p() {
let normal = Normal::standard();
let z = compute_z(1.0, 0.0, &normal);
assert_eq!(z, 37.0);
}
#[test]
fn test_detect_or_beta() {
let records: Vec<GwasRecord> = (0..100)
.map(|_| GwasRecord {
snp: "rs1".to_string(),
effect: Some(0.05),
p: Some(0.5),
a1: None,
a2: None,
se: None,
n: None,
z: None,
info: None,
maf: None,
})
.collect();
assert!(!detect_or(&records));
}
#[test]
fn test_detect_or_or() {
let records: Vec<GwasRecord> = (0..100)
.map(|_| GwasRecord {
snp: "rs1".to_string(),
effect: Some(1.05),
p: Some(0.5),
a1: None,
a2: None,
se: None,
n: None,
z: None,
info: None,
maf: None,
})
.collect();
assert!(detect_or(&records));
}
#[test]
fn test_is_valid_allele() {
assert!(is_valid_allele("A"));
assert!(is_valid_allele("C"));
assert!(is_valid_allele("G"));
assert!(is_valid_allele("T"));
assert!(!is_valid_allele("N"));
assert!(!is_valid_allele("AC"));
assert!(!is_valid_allele(""));
}
}