use std::cell::OnceCell;
use anyhow::bail;
use rust_htslib::bcf::{
Record,
header::{TagLength, TagType},
record::GenotypeAllele,
};
use tracing::error;
use crate::vcf::{pipeline::clusterizer::phasing::OutputPhasing, strings::VCF_LOCAL_PHASE_KEY};
/// Annotates `record` with every statistic that applies to it.
///
/// The statistics are tried in a fixed order: genotype, phase set, local-phase flag,
/// allele count, allele frequency and allele number. Each one is written only when its
/// own `should_add` check passes. The first write that fails stops the run and its
/// error is returned.
///
/// * `record` - record that is annotated in place.
/// * `old_records` - optional source records, made available to the statistics.
/// * `phasing` - optional phasing result, made available to the statistics.
pub fn compute_statistics(
    record: &mut Record,
    old_records: Option<&[Record]>,
    phasing: Option<OutputPhasing>,
) -> anyhow::Result<()> {
    let mut ctx = StatContext::new(old_records, phasing);
    // The statistics are stateless unit structs, so borrowing them is enough.
    let pipeline: [&dyn Statistic; 6] = [
        &Genotype,
        &PhaseSetting,
        &LocalPhase,
        &AlleleCount,
        &AlleleFrequency,
        &AlleleNumber,
    ];
    for statistic in pipeline {
        if !statistic.should_add(record, &ctx) {
            continue;
        }
        statistic.add_to_record(record, &mut ctx)?;
    }
    Ok(())
}
/// A single annotation that can be written onto a record.
///
/// `compute_statistics` calls `should_add` first and only calls `add_to_record` when it
/// returned `true`.
trait Statistic {
/// Returns `true` when this statistic should be written to `record`.
fn should_add(&self, record: &Record, ctx: &StatContext) -> bool;
/// Writes this statistic to `record`, using the shared inputs and caches held by `ctx`.
fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()>;
}
/// State shared by all statistics while a single record is annotated.
struct StatContext<'r> {
/// Lazily filled result of `count_alleles`: the per-ALT allele counts and the total
/// number of called alleles. It is computed at most once and then reused.
counts: OnceCell<anyhow::Result<(Vec<i32>, i32)>>,
/// Source records, if any; the first one serves as the genotype template.
old_records: Option<&'r [Record]>,
/// Phasing to synthesise the genotype and phase set from, if any.
phasing: Option<OutputPhasing>,
}
impl<'r> StatContext<'r> {
    /// Creates a context with an empty allele-count cache.
    const fn new(old_records: Option<&'r [Record]>, phasing: Option<OutputPhasing>) -> Self {
        Self {
            counts: OnceCell::new(),
            old_records,
            phasing,
        }
    }

    /// Returns the cached `(per-ALT counts, total called)` pair, computing it from
    /// `record` on the first call.
    ///
    /// A failure of that first computation is cached as well; every later call reports
    /// it again as a fresh error carrying the same message.
    fn cached_counts(&self, record: &Record) -> anyhow::Result<&(Vec<i32>, i32)> {
        match self.counts.get_or_init(|| count_alleles(record)) {
            Ok(pair) => Ok(pair),
            Err(e) => Err(anyhow::anyhow!("{e}")),
        }
    }

    /// Returns the number of calls per ALT allele of `record`.
    fn get_allele_counts(&self, record: &Record) -> anyhow::Result<&Vec<i32>> {
        let (per_alt, _) = self.cached_counts(record)?;
        Ok(per_alt)
    }

    /// Returns the total number of called alleles of `record`.
    fn get_total_called(&self, record: &Record) -> anyhow::Result<i32> {
        let (_, called) = self.cached_counts(record)?;
        Ok(*called)
    }
}
/// Reports whether the header of `record` declares the INFO field `tag` with exactly the
/// `expected` type and length.
///
/// A tag for which the header lookup fails counts as not declared.
fn has_header_info_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .info_type(tag)
        .is_ok_and(|declared| declared == expected)
}
/// Reports whether the header of `record` declares the FORMAT field `tag` with exactly
/// the `expected` type and length.
///
/// A tag for which the header lookup fails counts as not declared.
fn has_header_format_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .format_type(tag)
        .is_ok_and(|declared| declared == expected)
}
/// Counts the called alleles over all samples of `record`.
///
/// Returns `(alt_counts, total)`: `alt_counts[i]` is the number of calls of allele
/// `i + 1`, and `total` is the number of calls that carry any allele index (index 0
/// included). Calls without an allele index are skipped.
///
/// # Errors
///
/// Fails when the genotypes of `record` cannot be read.
fn count_alleles(record: &Record) -> anyhow::Result<(Vec<i32>, i32)> {
    // A record without any allele must not underflow the unsigned subtraction; it is
    // treated as having no ALT allele.
    let alt_alleles = record.allele_count().saturating_sub(1);
    let genotypes = record.genotypes()?;
    let mut acs = vec![0i32; alt_alleles as usize];
    let mut total = 0i32;
    for i in 0..record.sample_count() {
        let sample = genotypes.get(i as usize);
        for allele in sample.iter() {
            let Some(index) = allele.index() else {
                continue;
            };
            total += 1;
            // Index 0 has no slot in `acs`; an index beyond the declared alleles is
            // counted in `total` only.
            if index > 0 {
                if let Some(slot) = acs.get_mut(index as usize - 1) {
                    *slot += 1;
                }
            }
        }
    }
    Ok((acs, total))
}
/// Writes the `GT` FORMAT field, synthesised from the phasing when there is one and
/// copied from the old records otherwise.
struct Genotype;

impl Statistic for Genotype {
    /// `GT` is added when the header declares it as a single string value and a source
    /// for it exists: a phasing, or a non-empty list of old records.
    ///
    /// Logs an error whenever the field is not going to be added.
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
        let declared =
            has_header_format_field(record, b"GT", (TagType::String, TagLength::Fixed(1)));
        let has_source =
            || ctx.phasing.is_some() || ctx.old_records.is_some_and(|old| !old.is_empty());
        if declared && has_source() {
            return true;
        }
        error!("Not adding genotype field GT, this will cause errors!");
        false
    }

    /// Pushes a two-allele genotype built from the phasing, or falls back to copying the
    /// genotype of the old records when no phasing is available.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let Some(op) = ctx.phasing else {
            return copy_gt_from_old(record, ctx);
        };
        #[allow(clippy::cast_possible_wrap)]
        let first = op.alleles[0] as i32;
        #[allow(clippy::cast_possible_wrap)]
        let second = op.alleles[1] as i32;
        // Only the second allele carries the phase marker.
        let calls = [
            GenotypeAllele::Unphased(first),
            if op.is_phased() {
                GenotypeAllele::Phased(second)
            } else {
                GenotypeAllele::Unphased(second)
            },
        ];
        record.push_genotypes(&calls)?;
        Ok(())
    }
}
fn copy_gt_from_old(record: &mut Record, ctx: &StatContext<'_>) -> anyhow::Result<()> {
let Some(tpl) = ctx.old_records.and_then(|r| r.first()) else {
bail!("Implementation error: there should be old records available.");
};
let gt = tpl.genotypes()?;
record.push_genotypes(>.get(0))?;
Ok(())
}
struct PhaseSetting;
impl Statistic for PhaseSetting {
fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
ctx.phasing.is_some_and(|op| op.is_phased() && !op.is_hom())
&& has_header_format_field(record, b"PS", (TagType::Integer, TagLength::Fixed(1)))
}
fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
let Some(op) = ctx.phasing else {
bail!("Implementation error");
};
let phaseset = op.phaseset.unwrap_or(0);
record.push_format_integer(b"PS", &[phaseset])?;
Ok(())
}
}
/// Sets the INFO flag named by `VCF_LOCAL_PHASE_KEY`.
///
/// `should_add` currently always returns `false`, so `compute_statistics` never writes
/// this flag.
struct LocalPhase;

impl Statistic for LocalPhase {
    /// Always `false`: the flag is never requested.
    fn should_add(&self, _record: &Record, _ctx: &StatContext) -> bool {
        false
    }

    /// Pushes the local-phase flag onto `record`.
    fn add_to_record(&self, record: &mut Record, _ctx: &mut StatContext) -> anyhow::Result<()> {
        let key = VCF_LOCAL_PHASE_KEY.as_bytes();
        Ok(record.push_info_flag(key)?)
    }
}
/// Writes the `AC` INFO field: the number of calls per ALT allele.
struct AlleleCount;

impl Statistic for AlleleCount {
    /// `AC` is added when the header declares it as one integer per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::AltAlleles);
        has_header_info_field(record, b"AC", expected)
    }

    /// Pushes the (cached) per-ALT allele counts as `AC`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let per_alt: &[i32] = ctx.get_allele_counts(record)?;
        Ok(record.push_info_integer(b"AC", per_alt)?)
    }
}
/// Writes the `AF` INFO field: the share of called alleles taken by each ALT allele.
struct AlleleFrequency;

impl Statistic for AlleleFrequency {
    /// `AF` is added when the header declares it as one float per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        has_header_info_field(record, b"AF", (TagType::Float, TagLength::AltAlleles))
    }

    /// Pushes `count / total` for every ALT allele as `AF`.
    ///
    /// When no allele is called at all, the frequency is undefined and `AF` is left
    /// unset instead of being written as NaN.
    // AF is pushed as `f32`, so the lossy integer-to-float conversion is intended.
    #[allow(clippy::cast_precision_loss)]
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let total = ctx.get_total_called(record)?;
        if total == 0 {
            // 0 / 0 would yield NaN for every ALT allele.
            return Ok(());
        }
        let total = total as f32;
        let freqs: Vec<f32> = ctx
            .get_allele_counts(record)?
            .iter()
            .map(|&count| count as f32 / total)
            .collect();
        record.push_info_float(b"AF", &freqs)?;
        Ok(())
    }
}
/// Writes the `AN` INFO field: the total number of called alleles.
struct AlleleNumber;

impl Statistic for AlleleNumber {
    /// `AN` is added when the header declares it as a single integer.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::Fixed(1));
        has_header_info_field(record, b"AN", expected)
    }

    /// Pushes the (cached) total of called alleles as `AN`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let called = [ctx.get_total_called(record)?];
        Ok(record.push_info_integer(b"AN", &called)?)
    }
}
#[cfg(test)]
mod tests {
    use rust_htslib::bcf::{self, Header, Writer, record::GenotypeAllele};

    use super::*;
    use crate::vcf::pipeline::clusterizer::phasing::Haplotype;

    /// Builds a record on `chr1` with alleles `A` and `T` and one sample `S1`; the header
    /// declares the `GT` and `PS` FORMAT fields.
    fn make_record() -> Record {
        let header_lines: [&[u8]; 3] = [
            b"##contig=<ID=chr1,length=1000000>",
            b"##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">",
            b"##FORMAT=<ID=PS,Number=1,Type=Integer,Description=\"Phase set\">",
        ];
        let mut header = Header::new();
        for line in header_lines {
            header.push_record(line);
        }
        header.push_sample(b"S1");

        let writer = Writer::from_stdout(&header, true, bcf::Format::Vcf).unwrap();
        let mut rec = writer.empty_record();
        let rid = rec.header().name2rid(b"chr1").unwrap();
        rec.set_rid(Some(rid));
        rec.set_pos(100);
        rec.set_alleles(&[b"A", b"T"]).unwrap();
        rec
    }

    /// Returns the genotype alleles of the first sample.
    fn gt(record: &Record) -> Vec<GenotypeAllele> {
        let genotypes = record.genotypes().unwrap();
        let first_sample = genotypes.get(0);
        first_sample.to_vec()
    }

    /// Returns the first `PS` value of the first sample, or `None` when it cannot be read.
    fn ps(record: &Record) -> Option<i32> {
        let values = record.format(b"PS").integer().ok()?;
        let first_sample = values.first()?;
        first_sample.first().copied()
    }

    /// Runs `compute_statistics` with `op` on a fresh record and returns that record.
    fn annotated(op: OutputPhasing) -> Record {
        let mut rec = make_record();
        compute_statistics(&mut rec, None, Some(op)).unwrap();
        rec
    }

    /// Asserts the genotype and the phase set that ended up on `rec`.
    fn assert_annotation(
        rec: &Record,
        expected_gt: [GenotypeAllele; 2],
        expected_ps: Option<i32>,
    ) {
        assert_eq!(gt(rec), expected_gt);
        assert_eq!(ps(rec), expected_ps);
    }

    #[test]
    fn synthesize_phased_h0_is_1_0() {
        let rec = annotated(OutputPhasing::from_subcluster(Haplotype::H0, Some(42)));
        assert_annotation(
            &rec,
            [GenotypeAllele::Unphased(1), GenotypeAllele::Phased(0)],
            Some(42),
        );
    }

    #[test]
    fn synthesize_phased_h1_is_0_1() {
        let rec = annotated(OutputPhasing::from_subcluster(Haplotype::H1, Some(42)));
        assert_annotation(
            &rec,
            [GenotypeAllele::Unphased(0), GenotypeAllele::Phased(1)],
            Some(42),
        );
    }

    #[test]
    fn synthesize_hom_alt_is_1_1_no_ps() {
        let rec = annotated(OutputPhasing::from_subcluster(Haplotype::Both, None));
        assert_annotation(
            &rec,
            [GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(1)],
            None,
        );
    }

    #[test]
    fn synthesize_unphased_single_het_is_0_1_no_ps() {
        let rec = annotated(OutputPhasing::from_subcluster(Haplotype::H1, None));
        assert_annotation(
            &rec,
            [GenotypeAllele::Unphased(0), GenotypeAllele::Unphased(1)],
            None,
        );
    }
}