twitcher 0.1.8

Find template switch mutations in genomic data
use std::cell::OnceCell;

use anyhow::bail;
use rust_htslib::bcf::{
    Record,
    header::{TagLength, TagType},
};
use tracing::error;

use crate::vcf::pipeline::message::MaskedRecords;

/// Recomputes the supported statistics for `record` and writes them into it.
///
/// Every statistic is first asked via [`Statistic::should_add`] whether it applies to this
/// record; only statistics that answer `true` are written. `old_records` is stored in the
/// [`StatContext`] that is shared by all statistics.
///
/// # Errors
///
/// Stops at, and returns the error of, the first statistic whose `add_to_record` fails.
pub fn compute_statistics(
    record: &mut Record,
    old_records: Option<MaskedRecords>,
) -> anyhow::Result<()> {
    let mut shared = StatContext::new(old_records);

    // Order matters: the genotype must be written first, because the allele counts are
    // derived from the genotypes stored in `record`.
    let pipeline: [&dyn Statistic; 4] = [
        &Genotype,
        &AlleleCount,
        &AlleleFrequency,
        &AlleleNumber,
    ];

    for statistic in pipeline {
        if !statistic.should_add(record, &shared) {
            continue;
        }
        statistic.add_to_record(record, &mut shared)?;
    }

    Ok(())
}

/// A single statistic that can be derived for a record and written into it.
///
/// `compute_statistics` calls `should_add` first and only calls `add_to_record` when it
/// returned `true`.
trait Statistic {
    /// Reports whether this statistic should be written to `record`.
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool;
    /// Writes this statistic into `record`; `ctx` provides the data shared between statistics.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()>;
}

/// Holds information that is used for more than one statistic.
///
/// The `OnceCell` fields are caches: each is filled on first access and reused afterwards.
struct StatContext<'r, 'c> {
    /// Cached result of `get_num_alleles`.
    num_alleles: OnceCell<u32>,
    /// Cached outcome of `count_alleles`, including a failure if counting did not succeed.
    counts: OnceCell<anyhow::Result<Vec<i32>>>,
    /// Records handed to `compute_statistics`; the genotype statistic takes the first
    /// masked one as its template.
    old_records: Option<MaskedRecords<'r, 'c>>,
}

impl<'r, 'c> StatContext<'r, 'c> {
    /// Creates a context around `old_records` with both caches still empty.
    fn new(old_records: Option<MaskedRecords<'r, 'c>>) -> Self {
        Self {
            counts: OnceCell::new(),
            num_alleles: OnceCell::new(),
            old_records,
        }
    }

    /// Returns how often each ALT allele occurs in the genotypes of `record`.
    ///
    /// The counts are computed by `count_alleles` on the first call and cached; later calls
    /// return the cached outcome and do not look at `record` again.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the message of the cached failure if counting failed.
    fn get_allele_counts(&self, record: &Record) -> anyhow::Result<&Vec<i32>> {
        self.counts
            .get_or_init(|| count_alleles(record))
            .as_ref()
            // The cached error stays inside the cell, so a new error is built from its message.
            .map_err(|e| anyhow::anyhow!("{e}"))
    }

    /// Returns the total number of called alleles in the genotypes of `record`.
    ///
    /// This is the quantity the VCF specification defines as `AN` and the denominator that
    /// turns an allele count into a frequency. It is *not* the number of alleles listed in
    /// the record: a single diploid sample with genotype `0/1` yields 2.
    ///
    /// The value is computed on the first call and cached. It is 0 when no allele is called
    /// or when the genotypes cannot be read, so callers that divide by it must handle zero.
    fn get_num_alleles(&self, record: &Record) -> u32 {
        *self
            .num_alleles
            .get_or_init(|| Self::count_called_alleles(record))
    }

    /// Counts every non-missing allele over the genotypes of all samples of `record`.
    fn count_called_alleles(record: &Record) -> u32 {
        // Without readable genotypes there are no called alleles to count.
        let Ok(genotypes) = record.genotypes() else {
            return 0;
        };
        let called: usize = (0..record.sample_count() as usize)
            .map(|sample| {
                genotypes
                    .get(sample)
                    .iter()
                    // Missing alleles (`.`) have no index and do not count as called.
                    .filter(|allele| allele.index().is_some())
                    .count()
            })
            .sum();
        u32::try_from(called).unwrap_or(u32::MAX)
    }
}

/// Returns `true` if the header of `record` declares the INFO field `tag` with exactly the
/// `expected` type and length; an undeclared field or a lookup error yields `false`.
fn has_header_info_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .info_type(tag)
        .is_ok_and(|declared| declared == expected)
}

/// Returns `true` if the header of `record` declares the FORMAT field `tag` with exactly the
/// `expected` type and length; an undeclared field or a lookup error yields `false`.
fn has_header_format_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    record
        .header()
        .format_type(tag)
        .is_ok_and(|declared| declared == expected)
}

fn count_alleles(record: &Record) -> anyhow::Result<Vec<i32>> {
    let alt_alleles = record.allele_count() - 1;
    let genotypes = record.genotypes()?;
    let mut acs = vec![0; alt_alleles as usize];
    for i in 0..record.sample_count() {
        let gt = genotypes.get(i as usize);
        for a in &*gt {
            if let Some(index) = a.index() {
                if index > 0 {
                    acs[index as usize - 1] += 1;
                }
            }
        }
    }
    Ok(acs)
}

struct Genotype;

impl Statistic for Genotype {
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
        // TODO: For some reason, TagLength is not "Genotype" but "Fixed(1)". Why?
        let to_add = has_header_format_field(record, b"GT", (TagType::String, TagLength::Fixed(1)))
            && ctx.old_records.is_some_and(|s| s.masked_len() > 0);
        if !to_add {
            error!("Not adding genotype field GT, this will cause errors!");
        }
        to_add
    }

    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let Some(tpl) = ctx.old_records.and_then(|r| r.masked_iter().next()) else {
            bail!("Implementation error: there should be old records available.");
        };
        let gt = tpl.genotypes()?;
        record.push_genotypes(&gt.get(0))?;
        Ok(())
    }
}

/// Statistic that writes the per-ALT allele counts as the INFO field `AC`.
struct AlleleCount;

impl Statistic for AlleleCount {
    /// `AC` is added when the header declares it as an integer with one value per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::AltAlleles);
        has_header_info_field(record, b"AC", expected)
    }

    /// Writes the allele counts provided by the context into `AC`.
    ///
    /// # Errors
    ///
    /// Fails if the counts are unavailable or the field cannot be written.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let per_alt_counts = ctx.get_allele_counts(record)?;
        Ok(record.push_info_integer(b"AC", per_alt_counts)?)
    }
}

/// Statistic that writes the per-ALT allele frequencies as the INFO field `AF`.
struct AlleleFrequency;

impl Statistic for AlleleFrequency {
    /// `AF` is added when the header declares it as a float with one value per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        has_header_info_field(record, b"AF", (TagType::Float, TagLength::AltAlleles))
    }

    /// Writes each allele count divided by the context's allele total into `AF`.
    ///
    /// When the total is 0 the frequencies are undefined; an empty value list is pushed
    /// instead of dividing by zero, so no NaN/inf values end up in the record.
    /// NOTE(review): pushing an empty slice is expected to leave `AF` unset — confirm
    /// against rust_htslib.
    ///
    /// # Errors
    ///
    /// Fails if the counts are unavailable or the field cannot be written.
    // `f32` represents integers exactly only up to 2^24; larger counts lose precision.
    #[allow(clippy::cast_precision_loss)]
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let counts = ctx.get_allele_counts(record)?;
        let total = ctx.get_num_alleles(record);
        let freqs: Vec<f32> = if total == 0 {
            Vec::new()
        } else {
            let total = total as f32;
            counts.iter().map(|&c| c as f32 / total).collect()
        };
        record.push_info_float(b"AF", &freqs)?;
        Ok(())
    }
}

/// Statistic that writes the context's allele total as the INFO field `AN`.
struct AlleleNumber;

impl Statistic for AlleleNumber {
    /// `AN` is added when the header declares it as a single integer.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::Fixed(1));
        has_header_info_field(record, b"AN", expected)
    }

    /// Writes the value of `StatContext::get_num_alleles` into `AN`.
    ///
    /// # Errors
    ///
    /// Fails if the value does not fit into an `i32` or the field cannot be written.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let allele_number: i32 = ctx.get_num_alleles(record).try_into()?;
        record.push_info_integer(b"AN", std::slice::from_ref(&allele_number))?;
        Ok(())
    }
}