use crate::call::types::{Filter, Genotype, GermlineCall, GermlineParams, ACGT};
use crate::core::allele_index;
use crate::pileup::PileupColumn;
/// Natural logarithm of 10; dividing a natural-log value by this converts it
/// to log10, which the phred-scaling expressions in this file rely on.
const LN10: f64 = std::f64::consts::LN_10;
/// Genotype likelihoods and allele tallies for one pileup column, as produced
/// by `site_likelihoods`.
struct SiteLikelihoods {
/// Natural-log likelihoods of the three genotypes, indexed
/// 0 = hom-ref, 1 = het, 2 = hom-alt.
log_l: [f64; 3],
/// Index of the selected alternate allele (the non-reference allele with the
/// highest count); used to index both the allele counts and `ACGT`.
alt_idx: usize,
/// Allele counts as `[reference, alternate]`.
ad: [u32; 2],
/// Depth as reported by `PileupColumn::depth()`.
dp: u32,
}
/// Computes diploid genotype log-likelihoods for the reference allele versus
/// the most frequently observed non-reference allele of `column`.
///
/// Returns `None` when `allele_index` cannot map the reference base, or when
/// every non-reference allele has a count of zero (there is no alternate
/// allele to test).
fn site_likelihoods(column: &PileupColumn) -> Option<SiteLikelihoods> {
    let ref_idx = allele_index(column.ref_base)?;
    let counts = column.allele_counts();

    // Pick the non-reference allele with the highest non-zero count. The
    // strict `>` means the lowest index wins when counts are tied.
    let mut top: Option<(usize, u32)> = None;
    for (idx, &n) in counts.iter().enumerate() {
        if idx != ref_idx && n > top.map_or(0, |(_, c)| c) {
            top = Some((idx, n));
        }
    }
    let (alt_idx, _) = top?;

    let mut log_l = [0.0f64; 3];
    for obs in &column.obs {
        // Phred base quality -> error probability, capped at 0.75 so that a
        // matching base never has probability below 0.25.
        let err = 10f64.powf(-f64::from(obs.base_qual) / 10.0).min(0.75);
        let seen = usize::from(obs.allele);
        // Probability of this observation given the true base is `target`:
        // a match has probability 1 - err, any specific mismatch err / 3.
        let emit = |target: usize| if seen == target { 1.0 - err } else { err / 3.0 };
        let (p_ref, p_alt) = (emit(ref_idx), emit(alt_idx));

        // Hom-ref, het (equal mixture of both alleles), hom-alt.
        let per_genotype = [p_ref, 0.5 * p_ref + 0.5 * p_alt, p_alt];
        for (acc, p) in log_l.iter_mut().zip(per_genotype) {
            *acc += p.ln();
        }
    }

    Some(SiteLikelihoods {
        log_l,
        alt_idx,
        ad: [counts[ref_idx], counts[alt_idx]],
        dp: column.depth(),
    })
}
/// Calls a germline variant genotype at a single pileup column.
///
/// The genotype is the maximum-a-posteriori choice among hom-ref, het and
/// hom-alt, using the likelihoods from `site_likelihoods` and a prior built
/// from `params.heterozygosity`.
///
/// Returns `None` when `site_likelihoods` yields nothing (unusable reference
/// base or no alternate allele observed) or when hom-ref has the highest
/// posterior. Otherwise returns a call whose `filter` is `LowDepth` if the
/// depth is below `params.min_depth`, else `LowQual` if `qual` is below
/// `params.min_qual`, else `Pass`.
pub fn call_germline(column: &PileupColumn, params: &GermlineParams) -> Option<GermlineCall> {
    let sl = site_likelihoods(column)?;

    // Log priors for (hom-ref, het, hom-alt) parameterised by heterozygosity.
    let theta = params.heterozygosity;
    let log_prior = [(1.0 - 1.5 * theta).ln(), theta.ln(), (theta / 2.0).ln()];
    let log_post = [
        sl.log_l[0] + log_prior[0],
        sl.log_l[1] + log_prior[1],
        sl.log_l[2] + log_prior[2],
    ];

    let gt_idx = argmax3(&log_post);
    if gt_idx == 0 {
        return None;
    }

    // PL: phred-scaled likelihoods normalised so the most likely genotype is
    // 0, each value capped at 255.
    let max_log_l = sl.log_l.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut pl = [0u32; 3];
    for (slot, &ll) in pl.iter_mut().zip(sl.log_l.iter()) {
        let phred = -10.0 * (ll - max_log_l) / LN10;
        *slot = phred.round().min(255.0) as u32;
    }

    // GQ: PL gap between the best competing genotype and the genotype that
    // was actually called, capped at 99. The call comes from the posterior
    // while PL is likelihood-only, so the called genotype can have a
    // non-zero PL; in that case the data alone do not favour the call and
    // the gap is clamped to 0 rather than reporting the call's own PL as
    // its confidence.
    let runner_up = pl
        .iter()
        .enumerate()
        .filter(|&(g, _)| g != gt_idx)
        .map(|(_, &p)| p)
        .min()
        .unwrap_or(0);
    let gq = runner_up.saturating_sub(pl[gt_idx]).min(99) as u8;

    // QUAL: phred-scaled posterior probability of hom-ref, normalised with a
    // max-shifted log-sum-exp for numerical stability and floored at 0.
    let max_lp = log_post.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_z = max_lp
        + log_post
            .iter()
            .map(|lp| (lp - max_lp).exp())
            .sum::<f64>()
            .ln();
    let log_p_homref = log_post[0] - log_z;
    let qual = (-10.0 * log_p_homref / LN10).max(0.0);

    let genotype = if gt_idx == 1 {
        Genotype::Het
    } else {
        Genotype::HomAlt
    };

    // Depth is checked before quality, so a shallow site is always LowDepth.
    let filter = if sl.dp < params.min_depth {
        Filter::LowDepth
    } else if qual < params.min_qual {
        Filter::LowQual
    } else {
        Filter::Pass
    };

    Some(GermlineCall {
        genotype,
        alt_base: ACGT[sl.alt_idx],
        qual,
        gq,
        pl,
        ad: sl.ad,
        dp: sl.dp,
        filter,
    })
}
/// Classification of a single pileup column returned by
/// `genotype_column_gvcf`; every column maps to exactly one variant.
#[derive(Debug, Clone, PartialEq)]
pub enum GvcfGenotype {
/// Hom-ref had the highest posterior at this column.
HomRef {
/// Genotype quality, capped at 99.
gq: u8,
/// Column depth.
dp: u32,
},
/// Het or hom-alt had the highest posterior; carries the full call record.
Variant(GermlineCall),
/// The column could not be genotyped: the reference base has no allele
/// index, or the depth is zero.
NoCall {
/// Column depth (0 when the column is empty).
dp: u32,
},
}
/// Genotypes a pileup column and always returns a classification, including
/// for hom-ref and uncallable columns.
///
/// * `NoCall` when `allele_index` cannot map the reference base, or when the
///   column depth is 0.
/// * `HomRef` when hom-ref has the highest posterior.
/// * `Variant` otherwise, with `filter` set to `LowDepth` if the depth is
///   below `params.min_depth`, else `LowQual` if `qual` is below
///   `params.min_qual`, else `Pass`.
///
/// GQ is the PL gap between the best competing genotype and the genotype
/// that was called, clamped to `[0, 99]`.
pub fn genotype_column_gvcf(column: &PileupColumn, params: &GermlineParams) -> GvcfGenotype {
    let Some(ref_idx) = allele_index(column.ref_base) else {
        return GvcfGenotype::NoCall { dp: column.depth() };
    };
    let dp = column.depth();
    if dp == 0 {
        return GvcfGenotype::NoCall { dp: 0 };
    }

    // Choose the most frequent non-reference allele (lowest index on ties).
    // When no alternate allele was observed, fall back to the allele after
    // the reference so that likelihoods can still be computed.
    let counts = column.allele_counts();
    let mut alt_idx = (ref_idx + 1) % 4;
    let mut best = 0u32;
    for (i, &cnt) in counts.iter().enumerate() {
        if i != ref_idx && cnt > best {
            best = cnt;
            alt_idx = i;
        }
    }

    // Natural-log likelihoods for hom-ref, het and hom-alt.
    let mut log_l = [0.0f64; 3];
    for o in &column.obs {
        // Phred base quality -> error probability, capped at 0.75.
        let eps = (10f64.powf(-(o.base_qual as f64) / 10.0)).min(0.75);
        let a = o.allele as usize;
        let p_ref = if a == ref_idx { 1.0 - eps } else { eps / 3.0 };
        let p_alt = if a == alt_idx { 1.0 - eps } else { eps / 3.0 };
        log_l[0] += p_ref.ln();
        log_l[1] += (0.5 * p_ref + 0.5 * p_alt).ln();
        log_l[2] += p_alt.ln();
    }

    // Log priors for (hom-ref, het, hom-alt) parameterised by heterozygosity.
    let theta = params.heterozygosity;
    let log_prior = [(1.0 - 1.5 * theta).ln(), theta.ln(), (theta / 2.0).ln()];
    let log_post = [
        log_l[0] + log_prior[0],
        log_l[1] + log_prior[1],
        log_l[2] + log_prior[2],
    ];

    // PL: phred-scaled likelihoods normalised so the most likely genotype is
    // 0, each value capped at 255.
    let max_log_l = log_l.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut pl = [0u32; 3];
    for (g, &ll) in log_l.iter().enumerate() {
        pl[g] = (-10.0 * (ll - max_log_l) / LN10).round().min(255.0) as u32;
    }

    let gt_idx = argmax3(&log_post);

    // GQ is measured relative to the genotype that was called. The call is
    // posterior-based while PL is likelihood-only, so the prior can select a
    // genotype whose PL is non-zero (for example hom-ref at a column with a
    // single alternate read). The gap is then clamped to 0 instead of
    // reporting the called genotype's own PL as confidence in it.
    let runner_up = pl
        .iter()
        .enumerate()
        .filter(|&(g, _)| g != gt_idx)
        .map(|(_, &p)| p)
        .min()
        .unwrap_or(0);
    let gq = runner_up.saturating_sub(pl[gt_idx]).min(99) as u8;

    if gt_idx == 0 {
        return GvcfGenotype::HomRef { gq, dp };
    }

    // QUAL: phred-scaled posterior probability of hom-ref, normalised with a
    // max-shifted log-sum-exp for numerical stability and floored at 0.
    let max_lp = log_post.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_z = max_lp
        + log_post
            .iter()
            .map(|lp| (lp - max_lp).exp())
            .sum::<f64>()
            .ln();
    let qual = (-10.0 * (log_post[0] - log_z) / LN10).max(0.0);

    // Depth is checked before quality, so a shallow site is always LowDepth.
    let filter = if dp < params.min_depth {
        Filter::LowDepth
    } else if qual < params.min_qual {
        Filter::LowQual
    } else {
        Filter::Pass
    };

    GvcfGenotype::Variant(GermlineCall {
        genotype: if gt_idx == 1 {
            Genotype::Het
        } else {
            Genotype::HomAlt
        },
        alt_base: ACGT[alt_idx],
        qual,
        gq,
        pl,
        ad: [counts[ref_idx], counts[alt_idx]],
        dp,
        filter,
    })
}
/// Returns the index of the largest of three values.
///
/// Ties resolve to the lowest index, and a candidate only displaces the
/// current best when it compares strictly greater, so a NaN never wins
/// against an earlier entry and an earlier NaN is never displaced.
fn argmax3(v: &[f64; 3]) -> usize {
    v.iter()
        .enumerate()
        .skip(1)
        .fold(0, |best, (i, &x)| if x > v[best] { i } else { best })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Locus, Position};
    use crate::pileup::Obs;

    /// Builds a pileup column at a fixed locus from `(allele, base_qual)`
    /// pairs; every observation gets mapq 60 on the forward strand.
    fn col(ref_base: u8, obs_spec: &[(u8, u8)]) -> PileupColumn {
        let obs: Vec<Obs> = obs_spec
            .iter()
            .map(|&(allele, base_qual)| Obs {
                allele,
                base_qual,
                mapq: 60,
                reverse: false,
            })
            .collect();
        PileupColumn {
            locus: Locus {
                contig: 0,
                pos: Position(100),
            },
            ref_base,
            raw_depth: obs.len() as u32,
            obs,
        }
    }

    #[test]
    fn likelihoods_rank_genotypes_correctly() {
        // Balanced ref/alt evidence: het must be the most likely genotype.
        let het: Vec<(u8, u8)> = (0..15)
            .map(|_| (0u8, 30u8))
            .chain((0..15).map(|_| (1u8, 30u8)))
            .collect();
        let sl = site_likelihoods(&col(b'A', &het)).unwrap();
        assert_eq!(sl.alt_idx, 1);
        assert_eq!(sl.ad, [15, 15]);
        assert_eq!(sl.dp, 30);
        assert!(sl.log_l[1] > sl.log_l[0]);
        assert!(sl.log_l[1] > sl.log_l[2]);

        // Alt-only evidence: hom-alt must be the most likely genotype.
        let homalt: Vec<(u8, u8)> = (0..20).map(|_| (1u8, 30u8)).collect();
        let sl = site_likelihoods(&col(b'A', &homalt)).unwrap();
        assert!(sl.log_l[2] > sl.log_l[1]);
        assert!(sl.log_l[2] > sl.log_l[0]);
    }

    #[test]
    fn no_alt_or_non_callable_ref_yields_none() {
        let allref: Vec<(u8, u8)> = (0..20).map(|_| (0u8, 30u8)).collect();
        assert!(site_likelihoods(&col(b'A', &allref)).is_none());
        assert!(site_likelihoods(&col(b'N', &[(1, 30), (1, 30)])).is_none());
    }

    #[test]
    fn gvcf_genotype_classifies_every_column() {
        let params = GermlineParams::default();

        let allref: Vec<(u8, u8)> = (0..20).map(|_| (0u8, 30u8)).collect();
        match genotype_column_gvcf(&col(b'A', &allref), &params) {
            GvcfGenotype::HomRef { gq, dp } => {
                assert_eq!(dp, 20);
                assert!(gq > 0, "deep all-ref should be confident hom-ref, gq={gq}");
            }
            other => panic!("expected HomRef, got {other:?}"),
        }

        let het: Vec<(u8, u8)> = (0..15)
            .map(|_| (0u8, 30u8))
            .chain((0..15).map(|_| (1u8, 30u8)))
            .collect();
        assert!(matches!(
            genotype_column_gvcf(&col(b'A', &het), &params),
            GvcfGenotype::Variant(_)
        ));

        // A reference base without an allele index cannot be genotyped.
        assert!(matches!(
            genotype_column_gvcf(&col(b'N', &[(1, 30)]), &params),
            GvcfGenotype::NoCall { .. }
        ));

        // An empty column is a no-call with zero depth.
        assert!(matches!(
            genotype_column_gvcf(&col(b'A', &[]), &params),
            GvcfGenotype::NoCall { dp: 0 }
        ));
    }

    #[test]
    fn hom_ref_site_abstains() {
        let allref: Vec<(u8, u8)> = (0..20).map(|_| (0u8, 30u8)).collect();
        assert!(call_germline(&col(b'A', &allref), &GermlineParams::default()).is_none());
    }

    #[test]
    fn thin_alt_evidence_abstains() {
        // A single alt read among three ref reads must not produce a call.
        let thin = [(0u8, 30u8), (0, 30), (0, 30), (1, 30)];
        assert!(call_germline(&col(b'A', &thin), &GermlineParams::default()).is_none());
    }

    #[test]
    fn clear_het_passes() {
        let het: Vec<(u8, u8)> = (0..15)
            .map(|_| (0u8, 30u8))
            .chain((0..15).map(|_| (1u8, 30u8)))
            .collect();
        let call = call_germline(&col(b'A', &het), &GermlineParams::default()).unwrap();
        assert_eq!(call.genotype, Genotype::Het);
        assert_eq!(call.alt_base, b'C');
        assert_eq!(call.ad, [15, 15]);
        assert_eq!(call.dp, 30);
        assert_eq!(call.pl[1], 0);
        assert!(call.pl[0] > 0 && call.pl[2] > 0);
        assert_eq!(call.filter, Filter::Pass);
        assert!(call.gq > 0);
    }

    #[test]
    fn clear_hom_alt_passes() {
        let homalt: Vec<(u8, u8)> = (0..20).map(|_| (1u8, 30u8)).collect();
        let call = call_germline(&col(b'A', &homalt), &GermlineParams::default()).unwrap();
        assert_eq!(call.genotype, Genotype::HomAlt);
        assert_eq!(call.pl[2], 0);
        assert_eq!(call.filter, Filter::Pass);
    }

    #[test]
    fn shallow_variant_is_flagged_low_depth_not_dropped() {
        let shallow: Vec<(u8, u8)> = (0..3).map(|_| (1u8, 30u8)).collect();
        let call = call_germline(&col(b'A', &shallow), &GermlineParams::default()).unwrap();
        assert_eq!(call.genotype, Genotype::HomAlt);
        assert_eq!(call.dp, 3);
        assert_eq!(call.filter, Filter::LowDepth);
    }

    #[test]
    fn qual_increases_with_evidence() {
        // QUAL for a balanced het built from `n` ref and `n` alt reads.
        let mk = |n: usize| -> f64 {
            let obs: Vec<(u8, u8)> = (0..n)
                .map(|_| (0u8, 30u8))
                .chain((0..n).map(|_| (1u8, 30u8)))
                .collect();
            call_germline(&col(b'A', &obs), &GermlineParams::default())
                .unwrap()
                .qual
        };
        assert!(mk(20) > mk(8));
    }
}