use std::cell::OnceCell;
use anyhow::bail;
use rust_htslib::bcf::{
Record,
header::{TagLength, TagType},
};
use tracing::error;
use crate::vcf::pipeline::message::MaskedRecords;
/// Adds every applicable statistic to `record`.
///
/// The statistics are tried in a fixed order: `GT`, `AC`, `AF`, `AN`. Each one
/// is first asked whether it applies (`should_add`) and only then written
/// (`add_to_record`). `GT` comes first; the allele counting used by the later
/// statistics reads genotypes from `record` itself.
///
/// `old_records` is handed to the shared [`StatContext`]; the `GT` statistic
/// uses it as the source of the genotype to copy.
///
/// # Errors
///
/// Returns the first error produced by any statistic's `add_to_record`;
/// statistics after the failing one are not attempted.
pub fn compute_statistics(
    record: &mut Record,
    old_records: Option<MaskedRecords>,
) -> anyhow::Result<()> {
    let mut context = StatContext::new(old_records);

    // All statistics are zero-sized marker types, so borrowing them as trait
    // objects is enough to get dynamic dispatch over the fixed list.
    let pipeline: [&dyn Statistic; 4] = [
        &Genotype,
        &AlleleCount,
        &AlleleFrequency,
        &AlleleNumber,
    ];

    for statistic in pipeline {
        if !statistic.should_add(record, &context) {
            continue;
        }
        statistic.add_to_record(record, &mut context)?;
    }

    Ok(())
}
/// A single value (INFO or FORMAT field) that can be computed for a record and
/// written onto it.
///
/// `compute_statistics` calls `should_add` first and only calls
/// `add_to_record` when it returned `true`.
trait Statistic {
/// Returns `true` if this statistic should be written to `record`.
fn should_add(&self, record: &Record, ctx: &StatContext) -> bool;
/// Computes the statistic and writes it to `record`, using `ctx` for values
/// shared between statistics.
fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()>;
}
/// State shared by all statistics while a single record is being processed.
///
/// The two `OnceCell` fields are filled on first use, so each derived value is
/// computed at most once per context.
struct StatContext<'r, 'c> {
// Cached value returned by `get_num_alleles`.
num_alleles: OnceCell<u32>,
// Cached outcome (success or error) of `count_alleles`, served by `get_allele_counts`.
counts: OnceCell<anyhow::Result<Vec<i32>>>,
// Previously seen records, if any; the `GT` statistic reads its genotype from here.
old_records: Option<MaskedRecords<'r, 'c>>,
}
impl<'r, 'c> StatContext<'r, 'c> {
    /// Creates a context with empty caches around the given `old_records`.
    fn new(old_records: Option<MaskedRecords<'r, 'c>>) -> Self {
        Self {
            counts: OnceCell::new(),
            num_alleles: OnceCell::new(),
            old_records,
        }
    }

    /// Returns the per-ALT-allele counts of `record`, computing them on first use.
    ///
    /// The outcome of the first call is cached, including a failure: later
    /// calls return the cached outcome and do not look at `record` again.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the message of the cached `count_alleles`
    /// failure.
    fn get_allele_counts(&self, record: &Record) -> anyhow::Result<&Vec<i32>> {
        self.counts
            .get_or_init(|| count_alleles(record))
            .as_ref()
            // The cached error is only borrowed here (and `anyhow::Error` is
            // not `Clone`), so a fresh error is built from its message.
            .map_err(|e| anyhow::anyhow!("{e}"))
    }

    /// Returns the allele number of `record`: the total number of called
    /// (non-missing) alleles over all samples' genotypes.
    ///
    /// This is the value the VCF specification defines for `AN` and the
    /// denominator of `AF`. The number of ALT alleles at the site
    /// (`allele_count() - 1`) is a different quantity and must not be used
    /// for either.
    ///
    /// The result is cached after the first call. It is `0` when no genotype
    /// is called or when the genotypes cannot be read, so callers that divide
    /// by it have to handle zero.
    fn get_num_alleles(&self, record: &Record) -> u32 {
        *self
            .num_alleles
            .get_or_init(|| Self::count_called_alleles(record))
    }

    /// Counts the genotype alleles of `record` that refer to an actual allele
    /// (REF or ALT), skipping missing calls.
    fn count_called_alleles(record: &Record) -> u32 {
        // The signature of `get_num_alleles` cannot carry an error. A record
        // whose genotypes are unreadable has no called alleles to report.
        let Ok(genotypes) = record.genotypes() else {
            return 0;
        };

        let mut called: u32 = 0;
        for sample in 0..record.sample_count() {
            let gt = genotypes.get(sample as usize);
            for allele in &*gt {
                // `index()` is `None` for missing alleles.
                if allele.index().is_some() {
                    called = called.saturating_add(1);
                }
            }
        }
        called
    }
}
/// Reports whether the header of `record` declares the INFO field `tag` with
/// exactly the `expected` type and length.
///
/// Returns `false` both when the header lookup fails and when the declared
/// type/length pair differs from `expected`.
fn has_header_info_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .info_type(tag)
        .is_ok_and(|declared| declared == expected)
}
/// Reports whether the header of `record` declares the FORMAT field `tag` with
/// exactly the `expected` type and length.
///
/// Returns `false` both when the header lookup fails and when the declared
/// type/length pair differs from `expected`.
fn has_header_format_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .format_type(tag)
        .is_ok_and(|declared| declared == expected)
}
/// Counts, for every ALT allele of `record`, how often it occurs in the
/// genotypes of all samples.
///
/// The returned vector has one entry per ALT allele; entry `i` belongs to
/// allele index `i + 1` (index 0 is the reference allele and is not counted).
/// Missing genotype alleles are skipped.
///
/// # Errors
///
/// Fails if the genotypes cannot be read from `record`, or if a genotype
/// refers to an allele index that the record does not have.
fn count_alleles(record: &Record) -> anyhow::Result<Vec<i32>> {
    // Saturating: a record without any allele must not underflow into a
    // gigantic allocation (release) or a panic (debug).
    let alt_alleles = record.allele_count().saturating_sub(1) as usize;
    let genotypes = record.genotypes()?;
    let mut acs = vec![0; alt_alleles];

    for sample in 0..record.sample_count() {
        let gt = genotypes.get(sample as usize);
        for allele in &*gt {
            // `None` is a missing allele; index 0 is REF, and `checked_sub`
            // turns that into `None` as well.
            let Some(alt) = allele.index().and_then(|index| index.checked_sub(1)) else {
                continue;
            };
            // A malformed record may reference an allele it does not declare;
            // report that instead of panicking on an out-of-bounds index.
            match acs.get_mut(alt as usize) {
                Some(count) => *count += 1,
                None => anyhow::bail!(
                    "genotype of sample {sample} references ALT allele {} but the record only has {alt_alleles}",
                    u64::from(alt) + 1
                ),
            }
        }
    }

    Ok(acs)
}
struct Genotype;
impl Statistic for Genotype {
fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
let to_add = has_header_format_field(record, b"GT", (TagType::String, TagLength::Fixed(1)))
&& ctx.old_records.is_some_and(|s| s.masked_len() > 0);
if !to_add {
error!("Not adding genotype field GT, this will cause errors!");
}
to_add
}
fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
let Some(tpl) = ctx.old_records.and_then(|r| r.masked_iter().next()) else {
bail!("Implementation error: there should be old records available.");
};
let gt = tpl.genotypes()?;
record.push_genotypes(>.get(0))?;
Ok(())
}
}
/// Statistic that writes the `AC` INFO field (one count per ALT allele).
struct AlleleCount;

impl Statistic for AlleleCount {
    /// `AC` is added when the header declares it as an integer INFO field
    /// with one value per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::AltAlleles);
        has_header_info_field(record, b"AC", expected)
    }

    /// Writes the (cached) per-ALT-allele counts of `record` as `AC`.
    ///
    /// # Errors
    ///
    /// Fails if the allele counts cannot be determined or the field cannot be
    /// pushed onto `record`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let per_alt: &[i32] = ctx.get_allele_counts(record)?;
        record.push_info_integer(b"AC", per_alt)?;
        Ok(())
    }
}
/// Statistic that writes the `AF` INFO field (one frequency per ALT allele).
struct AlleleFrequency;

impl Statistic for AlleleFrequency {
    /// `AF` is added when the header declares it as a float INFO field with
    /// one value per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        has_header_info_field(record, b"AF", (TagType::Float, TagLength::AltAlleles))
    }

    /// Writes each ALT allele count divided by `ctx.get_num_alleles(record)`
    /// as `AF`.
    ///
    /// When the denominator is zero no frequency is defined, so an empty value
    /// list is pushed instead of `NaN`/infinite values. That is the same call
    /// the division would make for a record without ALT alleles.
    ///
    /// # Errors
    ///
    /// Fails if the allele counts cannot be determined or the field cannot be
    /// pushed onto `record`.
    // Counts are converted to `f32` on purpose; the precision loss for very
    // large counts is accepted for a frequency value.
    #[allow(clippy::cast_precision_loss)]
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let counts = ctx.get_allele_counts(record)?;
        let total = ctx.get_num_alleles(record);

        let freqs: Vec<f32> = if total == 0 {
            // Guard the division: never emit NaN or infinity.
            Vec::new()
        } else {
            let total = total as f32;
            counts.iter().map(|&count| count as f32 / total).collect()
        };

        record.push_info_float(b"AF", &freqs)?;
        Ok(())
    }
}
/// Statistic that writes the single-valued `AN` INFO field.
struct AlleleNumber;

impl Statistic for AlleleNumber {
    /// `AN` is added when the header declares it as an integer INFO field
    /// holding exactly one value.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::Fixed(1));
        has_header_info_field(record, b"AN", expected)
    }

    /// Writes the value reported by `ctx.get_num_alleles(record)` as `AN`.
    ///
    /// # Errors
    ///
    /// Fails if that value does not fit into an `i32` or the field cannot be
    /// pushed onto `record`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let allele_number = ctx.get_num_alleles(record);
        let allele_number = i32::try_from(allele_number)?;
        record.push_info_integer(b"AN", &[allele_number])?;
        Ok(())
    }
}