twitcher 0.1.8

Find template switch mutations in genomic data
use std::cmp::Ordering;

use async_stream::stream;
use bstr::ByteSlice;
use rust_htslib::{
    bam::Read as BamRead,
    bcf::{self, Read, record::GenotypeAllele},
};
use tokio_stream::Stream;
use tracing::{debug, error, trace};

use crate::{
    common::{coords::GenomeRegion, list_of_regions::Regions},
    counter,
    vcf::pipeline::{clusterizer::cluster::ClusteringSettings, reader::VCFReader},
};

/// Mutable bookkeeping for `find_clusters`: the records collected for the cluster under
/// construction, plus the scratch record the VCF reader loop reads into.
struct State {
    /// Records accumulated so far for the cluster currently being built.
    buffer: Vec<bcf::Record>,
    /// Sequence id (`rid`) of the most recently consumed record; `None` before the first one.
    last_rid: Option<u32>,
    /// Decides whether a record joins the buffer and whether a buffer counts as a cluster.
    strategy: ClusteringSettings,
    /// Reusable record filled by the reader loop before being cloned into `buffer`.
    record_buf: bcf::Record,
}

impl State {
    /// Creates an empty clustering state.
    ///
    /// `record_buf` is the scratch record that the reader loop reads into; keeping it here
    /// lets that loop reuse a single record for every read.
    fn new(strategy: ClusteringSettings, record_buf: bcf::Record) -> Self {
        Self {
            strategy,
            record_buf,
            buffer: vec![],
            last_rid: None,
        }
    }

    /// Feeds one record into the clusterer.
    ///
    /// Returns the previously buffered records (as a cluster or as plain records) whenever
    /// `record` cannot extend the current buffer. That happens when it lies on a different
    /// sequence (`rid`) than the last record, or when the clustering strategy rejects it.
    /// `record` itself is never dropped: it either extends the current buffer or starts a new one.
    fn consume_record(&mut self, record: bcf::Record) -> Option<ClusterOrRecords> {
        let rid = record.rid();
        if rid != self.last_rid {
            // New sequence: records from the previous sequence can never share a cluster with
            // this one. Flush unconditionally and start a fresh buffer with `record`. (Asking
            // `belongs` here would compare against the old sequence's buffer, and a `false`
            // answer used to drop the record entirely.)
            self.last_rid = rid;
            let flushed = self.flush();
            self.buffer.push(record);
            return flushed;
        }

        let extends_buffer = self
            .strategy
            .belongs(&self.buffer, &record)
            .unwrap_or(false);
        if extends_buffer {
            self.buffer.push(record);
            None
        } else {
            let flushed = self.flush();
            self.buffer.push(record);
            flushed
        }
    }

    /// Empties the buffer and returns its records.
    ///
    /// They come back as [`ClusterOrRecords::Cluster`] when the strategy classifies the buffer
    /// as a cluster, and as [`ClusterOrRecords::Records`] otherwise. An empty buffer yields
    /// `None`, so an empty `Cluster` is never produced.
    fn flush(&mut self) -> Option<ClusterOrRecords> {
        if self.buffer.is_empty() {
            return None;
        }
        let is_cluster = self.strategy.is_cluster(&self.buffer);
        let records = std::mem::take(&mut self.buffer);
        Some(if is_cluster {
            ClusterOrRecords::Cluster(records)
        } else {
            ClusterOrRecords::Records(records)
        })
    }
}

/// One event produced by `find_clusters`.
pub enum ClusterOrRecords {
    /// Records that are not reported as a cluster.
    Records(Vec<bcf::Record>),
    /// Records reported together as one cluster.
    Cluster(Vec<bcf::Record>),
    /// The reader returned no more records for now, but it is not finished yet
    /// (e.g. more regions may follow).
    SequenceDone,
}

impl std::fmt::Debug for ClusterOrRecords {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusterOrRecords::Records(records) => write!(f, "{} records", records.len()),
            ClusterOrRecords::Cluster(records) => write!(f, "cluster of {}", records.len()),
            ClusterOrRecords::SequenceDone => write!(f, "sequence done"),
        }
    }
}

fn phase_cluster(
    bam_reader: &mut rust_htslib::bam::IndexedReader,
    cluster: Vec<bcf::Record>,
) -> ClusterOrRecords {
    for rec in &cluster {
        if rec.allele_count() > 2 {
            error!("Record with too many alleles... discarding cluster");
            return ClusterOrRecords::Records(cluster);
        }
    }
    // Location for pileup. The left-most record in the cluster.
    let a = cluster
        .iter()
        .map(rust_htslib::bcf::Record::pos)
        .min()
        .unwrap();

    // All elements in cluster should have the same rid, so It's sufficient to get the first one.
    let chrom = cluster[0].rid().unwrap();

    bam_reader.fetch((chrom, a, a + 1)).unwrap();

    // Accumulate the number of reads that support the cluster in hits.
    // And compare to magic number to either accept or reject cluster.
    let mut hits = 0;
    let magic_number = 4;

    debug!("Phasing cluster {}, {}", a, chrom);
    for rec in &cluster {
        let alleles = rec.alleles();
        trace!(
            "{}, {}, {}",
            alleles[0].as_bstr(),
            alleles[1].as_bstr(),
            rec.pos() - a
        );
    }
    let u_a: u32 = a.try_into().unwrap();
    // Access the first column of the pileup (in a for loop to ensure iterator stays alive)
    // Note the break at the end of the loop
    for p in bam_reader.pileup() {
        let Ok(p) = p else {
            continue;
        };
        if p.pos() != u_a {
            continue;
        }
        // Access all the reads in the pileup column.
        for aln in p.alignments() {
            if hits >= magic_number {
                break;
            }
            let q_pos = aln.qpos();
            if q_pos.is_none() {
                continue;
            }
            let off: u32 = q_pos.unwrap().try_into().unwrap();
            let read = aln.record();
            let cig = read.cigar();
            let r_seq = read.seq();
            let seq = r_seq.as_bytes();
            let mut off: usize = off.try_into().unwrap();
            // Match cigar string to catch insertions and deletions.
            let indel_match = cluster.iter().all(|rec| {
                let ref_s = rec.alleles()[0];
                let var_s = rec.alleles()[1];
                let ref_l = ref_s.len() as u32;
                let var_l = var_s.len() as u32;
                let pos: u32 = rec.pos().try_into().unwrap();
                match ref_l.cmp(&var_l) {
                    Ordering::Equal => true,
                    Ordering::Less | Ordering::Greater =>
                    // Insertion or deletion, anchor positions before and after vcf record need to be used.
                    {
                        match cig.read_pos(pos - 1, true, true) {
                            Ok(Some(a_pos)) => match cig.read_pos(pos + ref_l, true, true) {
                                Ok(Some(b_pos)) => a_pos + 1 + var_l == b_pos,
                                _ => false,
                            },
                            _ => false,
                        }
                    }
                }
            });
            if !indel_match {
                continue;
            }
            // Match actual read sequence to check modifications and actual values.
            // Just check bytes in variants against what should be corresponding bytes in the read.
            let c_match = cluster.iter().all(|rec| {
                let s_off: usize = (rec.pos() - a).try_into().unwrap();
                let ref_s = rec.alleles()[0];
                let var_s = rec.alleles()[1];
                let ref_l = ref_s.len();
                let var_l = var_s.len();
                match ref_l.cmp(&var_l) {
                    Ordering::Less => {
                        // Insertion
                        for i in 0..var_l {
                            if i + off + s_off >= seq.len() || var_s[i] != seq[i + off + s_off] {
                                return false;
                            }
                        }
                        off += var_l - ref_l;
                    }
                    Ordering::Equal => {
                        // Just changed
                        for i in 0..var_l {
                            if i + off + s_off >= seq.len() {
                                return false;
                            }
                            if var_s[i] != seq[i + off + s_off] {
                                return false;
                            }
                        }
                    }
                    Ordering::Greater => {
                        // Deletion
                        for i in 0..var_l {
                            if i + off + s_off >= seq.len() {
                                return false;
                            }
                            if seq[i + off + s_off] != var_s[i] {
                                return false;
                            }
                        }
                        off += var_l - ref_l;
                    }
                }

                true
            });
            hits += i32::from(c_match);
        }
        break;
    }

    if hits >= 4 {
        debug!("Found consensus");
        ClusterOrRecords::Cluster(cluster)
    } else {
        debug!("No consensus... discarding cluster");
        ClusterOrRecords::Records(cluster)
    }
}

pub(super) fn find_clusters(
    strategy: ClusteringSettings,
    mut reader: VCFReader,
    targets: Option<Regions>,
    mut bam_f: Option<rust_htslib::bam::IndexedReader>,
) -> impl Stream<Item = ClusterOrRecords>
where
{
    let record_buf = reader.empty_record();

    let mut state = State::new(strategy, record_buf);

    stream! {
        loop {
            let read =
                tokio::task::block_in_place(|| bcf::Read::read(&mut reader, &mut state.record_buf));

            let event = match read {
                Some(Ok(())) => {
                    let reg = GenomeRegion::try_from(&state.record_buf);
                    if targets
                        .as_ref()
                        .is_none_or(|t| reg.is_ok_and(|r| t.contains(&r)))
                    {
                        state.consume_record(state.record_buf.clone())
                    } else {
                        None // Hole because excluded by targets
                    }
                }
                None if state.buffer.is_empty() => {
                    if reader.done() {
                        // the iterator is finished
                        break;
                    }
                    // there might be more to get, e.g. the next region.
                    Some(ClusterOrRecords::SequenceDone)
                }
                None => state.flush(),
                Some(Err(e)) => {
                    error!("Error reading record: {e}");
                    None // we have no entry here, but continue reading
                }
            };

            // Ignore holes
            let Some(event) = event else {
                continue;
            };

            // Check phasing
            let event = match event {
                ClusterOrRecords::Cluster(cluster) => {
                    counter!("clusters").inc(1);
                    let already_phase_resolved = cluster.windows(2).all(|elems| {
                        let gt1 = elems[0].genotypes().unwrap().get(0);
                        let gt2 = elems[1].genotypes().unwrap().get(0);
                        gt1[0] != gt1[1] && gt1 == gt2 && matches!(gt1[1], GenotypeAllele::Phased(_))
                    }) || cluster.windows(2).all(|elems| {
                        let gt1 = elems[0].genotypes().unwrap().get(0);
                        let gt2 = elems[1].genotypes().unwrap().get(0);
                        gt1[0] == gt1[1] && gt1 == gt2 && gt1[0].index() != Some(0)
                    }) || cluster.len() == 1;

                    if already_phase_resolved {
                        counter!("clusters.trivial").inc(1);
                        ClusterOrRecords::Cluster(cluster)
                    } else if let Some(bam_f_r) = &mut bam_f {
                        match phase_cluster(bam_f_r, cluster) {
                            c @ ClusterOrRecords::Cluster(_) => {
                                counter!("clusters.phased.successfully").inc(1);
                                c
                            }
                            r @ ClusterOrRecords::Records(_) => {
                                counter!("clusters.phased.unsuccessfully").inc(1);
                                r
                            }
                            d => d,
                        }
                    } else {
                        ClusterOrRecords::Records(cluster)
                    }
                }
                records_or_done => records_or_done,
            };
            yield event;
        }
    }
}