// twitcher 0.4.0
//
// Find template switch mutations in genomic data.
use std::cell::OnceCell;

use anyhow::bail;
use rust_htslib::bcf::{
    Record,
    header::{TagLength, TagType},
    record::GenotypeAllele,
};
use tracing::error;

use crate::vcf::{pipeline::clusterizer::phasing::OutputPhasing, strings::VCF_LOCAL_PHASE_KEY};

/// Writes the genotype, phase-set and allele statistics to `record`.
///
/// Every statistic is asked via `should_add` whether it applies and, if so, is written with
/// `add_to_record`. The statistics run in a fixed order; the first failing write aborts the
/// run and its error is returned.
///
/// # Errors
///
/// Returns the error of the first statistic that could not be written.
pub fn compute_statistics(
    record: &mut Record,
    old_records: Option<&[Record]>,
    phasing: Option<OutputPhasing>,
) -> anyhow::Result<()> {
    let mut ctx = StatContext::new(old_records, phasing);

    // The genotype comes first: it must be present for the record to be valid.
    let pipeline: [&dyn Statistic; 6] = [
        &Genotype,
        &PhaseSetting,
        &LocalPhase,
        &AlleleCount,
        &AlleleFrequency,
        &AlleleNumber,
    ];

    pipeline.into_iter().try_for_each(|stat| {
        if stat.should_add(record, &ctx) {
            stat.add_to_record(record, &mut ctx)
        } else {
            Ok(())
        }
    })
}

/// A single field that can be written to an output record.
///
/// `compute_statistics` calls `should_add` first and only calls `add_to_record` when it
/// returned `true`.
trait Statistic {
    /// Returns whether this statistic should be written to `record`.
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool;
    /// Writes this statistic to `record`; `ctx` carries the state shared between statistics.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()>;
}

/// Holds information that is used for more than one statistic.
struct StatContext<'r> {
    /// Lazily computed `(alt_counts, total_called_alleles)`. The cell is filled on the first
    /// request and reused afterwards; a failed computation is cached as well.
    counts: OnceCell<anyhow::Result<(Vec<i32>, i32)>>,
    /// Previously seen records (presumably the inputs the output record was built from —
    /// confirm with the caller). The first one is the fallback source for the genotype.
    old_records: Option<&'r [Record]>,
    /// Orientation/phase to write for the output record, if available.
    phasing: Option<OutputPhasing>,
}

impl<'r> StatContext<'r> {
    /// Creates a context whose allele-count cache is still empty.
    const fn new(old_records: Option<&'r [Record]>, phasing: Option<OutputPhasing>) -> Self {
        Self {
            counts: OnceCell::new(),
            old_records,
            phasing,
        }
    }

    /// Returns the cached `(alt_counts, total_called_alleles)` pair, computing it from
    /// `record` on first use.
    ///
    /// The cached error cannot be handed out by value, so a failure is reported as a new
    /// error carrying the cached error's message.
    fn cached_counts(&self, record: &Record) -> anyhow::Result<&(Vec<i32>, i32)> {
        match self.counts.get_or_init(|| count_alleles(record)) {
            Ok(counts) => Ok(counts),
            Err(e) => Err(anyhow::anyhow!("{e}")),
        }
    }

    /// Returns the per-ALT-allele counts of `record`.
    fn get_allele_counts(&self, record: &Record) -> anyhow::Result<&Vec<i32>> {
        let (alt_counts, _) = self.cached_counts(record)?;
        Ok(alt_counts)
    }

    /// Returns the number of called (non-missing) alleles of `record`.
    fn get_total_called(&self, record: &Record) -> anyhow::Result<i32> {
        let (_, total) = self.cached_counts(record)?;
        Ok(*total)
    }
}

/// Returns `true` when the record's header declares the INFO field `tag` with exactly the
/// `expected` type and length; an undeclared tag yields `false`.
fn has_header_info_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    let declared = record.header().info_type(tag);
    declared.is_ok_and(|found| found == expected)
}

/// Returns `true` when the record's header declares the FORMAT field `tag` with exactly the
/// `expected` type and length; an undeclared tag yields `false`.
fn has_header_format_field(record: &Record, tag: &[u8], expected: (TagType, TagLength)) -> bool {
    let declared = record.header().format_type(tag);
    declared.is_ok_and(|found| found == expected)
}

/// Returns `(alt_counts, total_called_alleles)`.
/// `alt_counts[i]` is the count of ALT allele `i+1` across all samples.
/// `total_called_alleles` is the count of all non-missing alleles (REF + ALT).
///
/// A record without any alleles yields an empty `alt_counts`.
///
/// # Errors
///
/// Fails when the genotypes of `record` cannot be read.
fn count_alleles(record: &Record) -> anyhow::Result<(Vec<i32>, i32)> {
    // `allele_count()` includes REF. Saturate so that a record with no alleles gives zero
    // ALT slots instead of underflowing the unsigned count (a panic in debug builds, an
    // enormous allocation in release builds).
    let alt_alleles = record.allele_count().saturating_sub(1);
    let genotypes = record.genotypes()?;
    let mut acs = vec![0i32; alt_alleles as usize];
    let mut total = 0i32;
    for i in 0..record.sample_count() {
        let gt = genotypes.get(i as usize);
        // Alleles without an index are missing calls and are not counted at all.
        for index in gt.iter().filter_map(|allele| allele.index()) {
            total += 1;
            // Index 0 is REF; ALT allele `k` lives in `acs[k - 1]`. An index beyond the
            // record's alleles still counts as called but has no slot to increment.
            if index > 0
                && let Some(slot) = acs.get_mut(index as usize - 1)
            {
                *slot += 1;
            }
        }
    }
    Ok((acs, total))
}

/// Writes the `GT` FORMAT field of the output record.
struct Genotype;

impl Statistic for Genotype {
    /// `GT` is written when the header declares it and a genotype source exists: either
    /// the context's phasing or at least one old record. Logs an error otherwise.
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
        // TODO: For some reason, TagLength is not "Genotype" but "Fixed(1)". Why?
        // NOTE(review): presumably because the header declares GT with `Number=1`, while
        // the per-genotype length would correspond to `Number=G` — confirm.
        let gt_declared =
            has_header_format_field(record, b"GT", (TagType::String, TagLength::Fixed(1)));
        let has_source =
            ctx.phasing.is_some() || ctx.old_records.is_some_and(|old| !old.is_empty());
        if gt_declared && has_source {
            return true;
        }
        error!("Not adding genotype field GT, this will cause errors!");
        false
    }

    /// Writes the genotype from the context's phasing, or copies it from the first old
    /// record when no phasing is available.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let Some(op) = ctx.phasing else {
            // Passthrough safety only; cluster outputs always carry phasing.
            return copy_gt_from_old(record, ctx);
        };

        // The output record is always biallelic (slot value 0 = ref, 1 = the single
        // alt), so emit the orientation/phase carried by the sub-cluster directly.
        #[allow(clippy::cast_possible_wrap)]
        let (first, second) = (op.alleles[0] as i32, op.alleles[1] as i32);
        // Only the second allele carries the phased/unphased marker.
        let alleles = if op.is_phased() {
            [GenotypeAllele::Unphased(first), GenotypeAllele::Phased(second)]
        } else {
            [GenotypeAllele::Unphased(first), GenotypeAllele::Unphased(second)]
        };
        record.push_genotypes(&alleles)?;
        Ok(())
    }
}

/// Copies the genotype of the first sample of the first old record into `record`.
///
/// # Errors
///
/// Fails when the context holds no old record, or when the genotype cannot be read or
/// written.
fn copy_gt_from_old(record: &mut Record, ctx: &StatContext<'_>) -> anyhow::Result<()> {
    let Some([template, ..]) = ctx.old_records else {
        bail!("Implementation error: there should be old records available.");
    };
    let genotypes = template.genotypes()?;
    let first_sample = genotypes.get(0);
    record.push_genotypes(&first_sample)?;
    Ok(())
}

/// Write the PS FORMAT field when the sub-cluster has a known phaseset.
struct PhaseSetting;

impl Statistic for PhaseSetting {
    fn should_add(&self, record: &Record, ctx: &StatContext) -> bool {
        ctx.phasing.is_some_and(|op| op.is_phased() && !op.is_hom())
            && has_header_format_field(record, b"PS", (TagType::Integer, TagLength::Fixed(1)))
    }

    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let Some(op) = ctx.phasing else {
            bail!("Implementation error");
        };
        let phaseset = op.phaseset.unwrap_or(0);
        record.push_format_integer(b"PS", &[phaseset])?;
        Ok(())
    }
}

/// Emit the LOCALPHASE INFO flag when the phaseset was invented by local read-based phasing.
struct LocalPhase;

impl Statistic for LocalPhase {
    fn should_add(&self, _record: &Record, _ctx: &StatContext) -> bool {
        // local_phase tracking was removed; flag is never emitted.
        false
    }

    fn add_to_record(&self, record: &mut Record, _ctx: &mut StatContext) -> anyhow::Result<()> {
        record.push_info_flag(VCF_LOCAL_PHASE_KEY.as_bytes())?;
        Ok(())
    }
}

/// Writes the `AC` INFO field: one count per ALT allele.
struct AlleleCount;

impl Statistic for AlleleCount {
    /// `AC` applies when the header declares it as one integer per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::AltAlleles);
        has_header_info_field(record, b"AC", expected)
    }

    /// Writes the (cached) per-ALT-allele counts of `record`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let alt_counts: &[i32] = ctx.get_allele_counts(record)?;
        record.push_info_integer(b"AC", alt_counts)?;
        Ok(())
    }
}

/// Writes the `AF` INFO field: the frequency of each ALT allele among all called alleles.
struct AlleleFrequency;

impl Statistic for AlleleFrequency {
    /// `AF` applies when the header declares it as one float per ALT allele.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        has_header_info_field(record, b"AF", (TagType::Float, TagLength::AltAlleles))
    }

    /// Writes `alt_count / total_called_alleles` for every ALT allele.
    ///
    /// When no allele is called at all the frequency is undefined; `AF` is then left out
    /// instead of being written as NaN.
    #[allow(clippy::cast_precision_loss)]
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let total = ctx.get_total_called(record)?;
        if total == 0 {
            // Dividing by zero would put NaN values into the record.
            return Ok(());
        }
        let total = total as f32;
        let counts = ctx.get_allele_counts(record)?;
        let freqs: Vec<f32> = counts.iter().map(|&count| count as f32 / total).collect();
        record.push_info_float(b"AF", &freqs)?;
        Ok(())
    }
}

/// Writes the `AN` INFO field: the number of called alleles.
struct AlleleNumber;

impl Statistic for AlleleNumber {
    /// `AN` applies when the header declares it as a single integer.
    fn should_add(&self, record: &Record, _ctx: &StatContext) -> bool {
        let expected = (TagType::Integer, TagLength::Fixed(1));
        has_header_info_field(record, b"AN", expected)
    }

    /// Writes the (cached) total of called alleles of `record`.
    fn add_to_record(&self, record: &mut Record, ctx: &mut StatContext) -> anyhow::Result<()> {
        let called = ctx.get_total_called(record)?;
        record.push_info_integer(b"AN", std::slice::from_ref(&called))?;
        Ok(())
    }
}

// Tests for the genotype (GT) and phase-set (PS) synthesis done by `compute_statistics`.
#[cfg(test)]
mod tests {
    use rust_htslib::bcf::{self, Header, Writer, record::GenotypeAllele};

    use super::*;
    use crate::vcf::pipeline::clusterizer::phasing::Haplotype;

    /// Builds an empty biallelic record (`A` -> `T`) on `chr1` with one sample `S1`.
    ///
    /// The header declares only the `GT` and `PS` FORMAT fields and no INFO fields, so
    /// presumably the AC/AF/AN statistics are skipped for this record.
    fn make_record() -> Record {
        let mut h = Header::new();
        h.push_record(b"##contig=<ID=chr1,length=1000000>");
        h.push_record(b"##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">");
        h.push_record(b"##FORMAT=<ID=PS,Number=1,Type=Integer,Description=\"Phase set\">");
        h.push_sample(b"S1");
        // The writer is only used to obtain an empty record that is bound to the header.
        let w = Writer::from_stdout(&h, true, bcf::Format::Vcf).unwrap();
        let mut rec = w.empty_record();
        let rid = rec.header().name2rid(b"chr1").unwrap();
        rec.set_rid(Some(rid));
        rec.set_pos(100);
        // Output records are always biallelic: a single synthetic alt.
        rec.set_alleles(&[b"A", b"T"]).unwrap();
        rec
    }

    /// Returns the genotype alleles of the first sample.
    fn gt(record: &Record) -> Vec<GenotypeAllele> {
        record.genotypes().unwrap().get(0).iter().copied().collect()
    }

    /// Returns the first `PS` value of the first sample, or `None` when `PS` cannot be read.
    fn ps(record: &Record) -> Option<i32> {
        record
            .format(b"PS")
            .integer()
            .ok()
            .and_then(|d| d.first().and_then(|s| s.first().copied()))
    }

    // Phased on haplotype H0: expects `1|0` with the phaseset written to PS.
    #[test]
    fn synthesize_phased_h0_is_1_0() {
        let mut rec = make_record();
        let op = OutputPhasing::from_subcluster(Haplotype::H0, Some(42));
        compute_statistics(&mut rec, None, Some(op)).unwrap();
        assert_eq!(
            gt(&rec),
            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Phased(0)]
        );
        assert_eq!(ps(&rec), Some(42));
    }

    // Phased on haplotype H1: expects `0|1` with the phaseset written to PS.
    #[test]
    fn synthesize_phased_h1_is_0_1() {
        let mut rec = make_record();
        let op = OutputPhasing::from_subcluster(Haplotype::H1, Some(42));
        compute_statistics(&mut rec, None, Some(op)).unwrap();
        assert_eq!(
            gt(&rec),
            vec![GenotypeAllele::Unphased(0), GenotypeAllele::Phased(1)]
        );
        assert_eq!(ps(&rec), Some(42));
    }

    // Both haplotypes without a phaseset: expects `1/1` and no PS.
    #[test]
    fn synthesize_hom_alt_is_1_1_no_ps() {
        let mut rec = make_record();
        let op = OutputPhasing::from_subcluster(Haplotype::Both, None);
        compute_statistics(&mut rec, None, Some(op)).unwrap();
        assert_eq!(
            gt(&rec),
            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(1)]
        );
        // Homozygous => no phase set emitted.
        assert_eq!(ps(&rec), None);
    }

    // Haplotype H1 without a phaseset: expects `0/1` and no PS.
    #[test]
    fn synthesize_unphased_single_het_is_0_1_no_ps() {
        // A single 1/2 (or 0/1) unphased het is split per-allele; each output is biallelic
        // and unphased, so it must reference only allele 1 (never the input's allele 2).
        let mut rec = make_record();
        let op = OutputPhasing::from_subcluster(Haplotype::H1, None);
        compute_statistics(&mut rec, None, Some(op)).unwrap();
        assert_eq!(
            gt(&rec),
            vec![GenotypeAllele::Unphased(0), GenotypeAllele::Unphased(1)]
        );
        assert_eq!(ps(&rec), None);
    }
}