twitcher 0.1.8

Find template switch mutations in genomic data
use std::sync::Arc;

use anyhow::Context;
use bstr::ByteSlice;
use rust_htslib::bam::{HeaderView, Record, record::Cigar};
use tokio::sync::mpsc::Sender;
use tokio_stream::{Stream, wrappers::ReceiverStream};
use tracing::error;

use crate::common::{
    ImmutableSequence,
    cluster_settings::{ClusteringSettings, variant_complexity},
    contig::ContigName,
    coords::{GenomePosition, GenomeRegion},
    reference::ReferenceReader,
};

/// A group of nearby alignment differences found in a single read.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Cluster {
    /// Reference span of the cluster: start inclusive, end exclusive.
    pub region: GenomeRegion,
    /// Half-open range of offsets into the read's stored sequence (soft-clipped bases are
    /// counted in these offsets).
    pub read_range: std::ops::Range<usize>,
}

/// Mutable accumulator for the cluster currently being built while walking a CIGAR.
#[derive(Debug)]
struct WorkingCluster {
    /// First reference position of the cluster (0-based, inclusive).
    ref_start: usize,
    /// Reference position just past the cluster (0-based, exclusive).
    ref_end: usize,
    /// First read offset of the cluster (inclusive).
    read_start: usize,
    /// Read offset just past the cluster (exclusive).
    read_end: usize,
    /// Summed complexity of every event folded into the cluster.
    complexity: f64,
    /// Number of difference events (mismatch runs, insertions, deletions) in the cluster.
    event_count: usize,
    /// Matched bases seen since the most recent event; compared against the maximum gap.
    matches_since_last_mut: usize,
}

impl WorkingCluster {
    /// Starts an empty cluster anchored at the given reference position and read offset.
    fn new(ref_pos: usize, read_pos: usize) -> Self {
        WorkingCluster {
            ref_start: ref_pos,
            ref_end: ref_pos,
            read_start: read_pos,
            read_end: read_pos,
            complexity: 0.0,
            event_count: 0,
            matches_since_last_mut: 0,
        }
    }

    /// Asks `settings` whether this cluster should be reported.
    fn is_valid(&self, settings: &ClusteringSettings) -> bool {
        // NOTE(review): `ref_end` is exclusive, so this is one more than the number of reference
        // bases covered — presumably so an insertion-only cluster never has a zero span. Confirm.
        let span = 1.0 + self.ref_start.abs_diff(self.ref_end) as f64;
        settings.is_valid_cluster(self.complexity, span, self.event_count)
    }
}

#[derive(Debug)]
struct ClusterFinder {
    record: Arc<Record>,
    ref_seq: ImmutableSequence,
    chr: ContigName,
    settings: ClusteringSettings,
    ref_start: usize,
    ref_pos: usize,
    read_pos: usize,
    expected_read_len: usize,
    current: Option<WorkingCluster>,
    out: Sender<Cluster>,
}

impl ClusterFinder {
    fn new(
        record: Arc<Record>,
        ref_seq: ImmutableSequence,
        chr: ContigName,
        settings: ClusteringSettings,
        out: Sender<Cluster>,
    ) -> anyhow::Result<Self> {
        let ref_pos = usize::try_from(record.pos())?;
        Ok(Self {
            record,
            ref_seq,
            chr,
            settings,
            ref_start: ref_pos,
            ref_pos,
            read_pos: 0,
            expected_read_len: 0,
            current: None,
            out,
        })
    }

    async fn find(mut self) -> anyhow::Result<()> {
        for cig in &self.record.cigar() {
            match *cig {
                Cigar::Match(n) => {
                    let mut n = n as usize;
                    while n > 0 {
                        let qry_seq = self.record.seq();
                        let ref_char = self.ref_seq[self.ref_pos - self.ref_start];
                        let qry_char = qry_seq[self.read_pos];
                        let is_equal = qry_char == ref_char;

                        let mut run_len = 1;
                        while run_len < n {
                            let r = self.ref_seq[self.ref_pos + run_len - self.ref_start];
                            let q = qry_seq[self.read_pos + run_len];
                            if (q == r) != is_equal {
                                break;
                            }
                            run_len += 1;
                        }

                        if is_equal {
                            self.consume_equal(run_len).await?;
                        } else {
                            self.consume_diff(run_len);
                        }
                        n -= run_len;
                    }
                }
                Cigar::Equal(n) => self.consume_equal(n as usize).await?,
                Cigar::Diff(n) => self.consume_diff(n as usize),
                Cigar::Ins(n) => self.consume_ins(n as usize),
                Cigar::Del(n) => self.consume_del(n as usize),
                Cigar::RefSkip(n) => self.consume_refskip(n as usize).await?,
                Cigar::SoftClip(n) => self.consume_softclip(n as usize).await?,
                Cigar::HardClip(_) | Cigar::Pad(_) => self.flush().await?,
            }
        }
        self.flush().await?;
        Ok(())
    }

    async fn consume_equal(&mut self, n: usize) -> anyhow::Result<()> {
        if let Some(c) = self.current.as_mut() {
            if c.matches_since_last_mut + n > self.settings.max_gap {
                self.flush().await?;
            } else {
                c.matches_since_last_mut += n;
                c.ref_end += n;
                c.read_end += n;
            }
        }
        self.expected_read_len += n;
        self.ref_pos += n;
        self.read_pos += n;
        Ok(())
    }

    fn consume_diff(&mut self, n: usize) {
        // Each contiguous mismatch run is one event; complexity accumulates per-base (SNP = 1.0).
        let c = self
            .current
            .get_or_insert_with(|| WorkingCluster::new(self.ref_pos, self.read_pos));
        c.complexity += variant_complexity(1, 1) * n as f64;
        c.event_count += 1;
        c.matches_since_last_mut = 0;
        c.ref_end += n;
        c.read_end += n;

        self.expected_read_len += n;
        self.ref_pos += n;
        self.read_pos += n;
    }

    fn consume_ins(&mut self, n: usize) {
        let c = self
            .current
            .get_or_insert_with(|| WorkingCluster::new(self.ref_pos, self.read_pos));
        c.complexity += variant_complexity(1, n + 1);
        c.event_count += 1;
        c.matches_since_last_mut = 0;
        c.read_end += n;

        self.expected_read_len += n;
        self.read_pos += n;
    }

    fn consume_del(&mut self, n: usize) {
        let c = self
            .current
            .get_or_insert_with(|| WorkingCluster::new(self.ref_pos, self.read_pos));
        c.complexity += variant_complexity(n + 1, 1);
        c.event_count += 1;
        c.matches_since_last_mut = 0;
        c.ref_end += n;

        self.ref_pos += n;
    }

    async fn consume_softclip(&mut self, n: usize) -> anyhow::Result<()> {
        self.flush().await?;
        self.expected_read_len += n;
        self.read_pos += n;
        Ok(())
    }

    async fn consume_refskip(&mut self, n: usize) -> anyhow::Result<()> {
        self.flush().await?;
        self.ref_pos += n;
        Ok(())
    }

    async fn flush(&mut self) -> anyhow::Result<()> {
        if let Some(c) = self.current.take() {
            if c.is_valid(&self.settings) {
                let start = GenomePosition::new_0(self.chr.clone(), c.ref_start);
                let end = GenomePosition::new_0(self.chr.clone(), c.ref_end);

                let cluster = Cluster {
                    region: GenomeRegion::from_incl_excl(start, Some(end)).unwrap(),
                    read_range: c.read_start..c.read_end,
                };
                self.out.send(cluster).await?;
            }
        }
        Ok(())
    }
}

/// Finds mutation clusters in a single aligned read.
///
/// Fetches the reference sequence spanned by `record`, then walks the read's CIGAR on a spawned
/// task, streaming every accepted [`Cluster`] through the returned stream as it is found.
///
/// # Errors
///
/// Returns an error if the record has no target id, if its reference region cannot be built or
/// fetched, or if its alignment position is negative. Errors hit while walking the CIGAR occur
/// on the background task: they are logged together with the read name, and the stream ends.
pub async fn find_clusters(
    header: &HeaderView,
    record: Arc<Record>,
    reference_reader: &ReferenceReader,
    settings: ClusteringSettings,
) -> anyhow::Result<impl Stream<Item = Cluster>> {
    let chr = u32::try_from(record.tid())
        .ok()
        .map(|tid| ContigName::new(header.tid2name(tid)))
        .context("Cannot find clusters without target id (can't extract reference)")?;
    let rec_reg = GenomeRegion::try_from_bam_record(&record, chr.clone())?;
    let ref_seq = reference_reader.get_seq_exact_unmasked(rec_reg).await?;
    let (tx, rx) = tokio::sync::mpsc::channel(32);
    // Build the finder before spawning so setup errors reach the caller like the ones above,
    // rather than being logged while the caller silently receives an empty stream.
    let finder = ClusterFinder::new(Arc::clone(&record), ref_seq, chr, settings, tx)?;
    tokio::spawn(async move {
        if let Err(e) = finder.find().await {
            error!(
                "While finding clusters on {}: {e}",
                record.qname().as_bstr()
            );
        }
    });
    Ok(ReceiverStream::new(rx))
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rust_htslib::bam::{
        Record,
        record::{Cigar, CigarString},
    };
    use tokio_stream::StreamExt as _;

    use crate::common::cluster_settings::ClusteringSettings;

    use super::{Cluster, ClusterFinder};

    /// Runs `ClusterFinder` directly on a record built from the given CIGAR. The read and the
    /// reference are both all-`A`, so every `M` base compares as a match; mismatches must be
    /// written as explicit `X` ops.
    async fn clusters_from_cigar(
        cigar: Vec<Cigar>,
        pos: i64,
        settings: ClusteringSettings,
    ) -> Vec<Cluster> {
        let cigar_string = CigarString(cigar);
        // Lengths follow the SAM spec: M/I/S/=/X consume the query, M/D/N/=/X the reference.
        let read_len: usize = cigar_string
            .iter()
            .map(|op| match op {
                Cigar::Match(n)
                | Cigar::Equal(n)
                | Cigar::Diff(n)
                | Cigar::Ins(n)
                | Cigar::SoftClip(n) => *n as usize,
                _ => 0,
            })
            .sum();
        let ref_len: usize = cigar_string
            .iter()
            .map(|op| match op {
                Cigar::Equal(n)
                | Cigar::Diff(n)
                | Cigar::Del(n)
                | Cigar::RefSkip(n)
                | Cigar::Match(n) => *n as usize,
                _ => 0,
            })
            .sum();

        let mut record = Record::new();
        let seq = vec![b'A'; read_len];
        let qual = vec![40u8; read_len];
        record.set(b"test", Some(&cigar_string), &seq, &qual);
        record.set_pos(pos);
        let record = Arc::new(record);

        let chr = crate::common::contig::ContigName::new(b"chr1");
        let ref_seq: crate::common::ImmutableSequence = vec![b'A'; ref_len].into();

        let (tx, rx) = tokio::sync::mpsc::channel(32);
        let finder = ClusterFinder::new(record, ref_seq, chr, settings, tx).unwrap();
        finder.find().await.unwrap();

        let mut stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        let mut results = Vec::new();
        while let Some(c) = stream.next().await {
            results.push(c);
        }
        results
    }

    #[tokio::test]
    async fn two_diff_runs_close_together_form_cluster() {
        // 5= 2X 3= 2X 5=: two mismatch events separated by 3 matches (within max_gap)
        let cigar = vec![
            Cigar::Equal(5),
            Cigar::Diff(2),
            Cigar::Equal(3),
            Cigar::Diff(2),
            Cigar::Equal(5),
        ];
        let clusters = clusters_from_cigar(cigar, 0, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 1);
        let c = &clusters[0];
        // Cluster begins at the first mismatch (ref offset 5 from pos=0)
        assert_eq!(c.region.start().position_0(), 5);
        // Trailing equal ops extend both boundaries
        assert_eq!(c.region.end_excl().unwrap().position_0(), 17);
        assert_eq!(c.read_range, 5..17);
    }

    #[tokio::test]
    async fn match_ops_behave_like_equal_when_bases_agree() {
        // Same layout as `two_diff_runs_close_together_form_cluster`, with `M` in place of `=`.
        // Read and reference are both all-A, so every M base is compared and found equal.
        let cigar = vec![
            Cigar::Match(5),
            Cigar::Diff(2),
            Cigar::Match(3),
            Cigar::Diff(2),
            Cigar::Match(5),
        ];
        let clusters = clusters_from_cigar(cigar, 0, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 1);
        let c = &clusters[0];
        assert_eq!(c.region.start().position_0(), 5);
        assert_eq!(c.region.end_excl().unwrap().position_0(), 17);
        assert_eq!(c.read_range, 5..17);
    }

    #[tokio::test]
    async fn diff_runs_separated_by_overlong_gap_are_rejected() {
        // 3X 25= 3X: the 25-match gap exceeds max_gap=20, breaking the cluster each time.
        // Each resulting fragment has only 1 event — below min_records=2 → both discarded.
        let cigar = vec![Cigar::Diff(3), Cigar::Equal(25), Cigar::Diff(3)];
        let clusters = clusters_from_cigar(cigar, 0, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 0);
    }

    #[tokio::test]
    async fn insertion_between_diffs_contributes_complexity() {
        // 2= 1X 1I 1X 2= at pos=10: SNP, 1-base insertion, SNP — all within gap
        let cigar = vec![
            Cigar::Equal(2),
            Cigar::Diff(1),
            Cigar::Ins(1),
            Cigar::Diff(1),
            Cigar::Equal(2),
        ];
        let clusters = clusters_from_cigar(cigar, 10, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 1);
        let c = &clusters[0];
        // Cluster starts two bases into the read (after the leading 2=), at absolute ref pos 12
        assert_eq!(c.region.start().position_0(), 12);
        assert_eq!(c.read_range, 2..7);
    }

    #[tokio::test]
    async fn single_diff_event_rejected_below_min_records() {
        // 3= 1X 3=: only one event, below the default min_records=2
        let cigar = vec![Cigar::Equal(3), Cigar::Diff(1), Cigar::Equal(3)];
        let clusters = clusters_from_cigar(cigar, 0, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 0);
    }

    #[tokio::test]
    async fn density_filter_rejects_sparse_cluster() {
        // 1X 20= 1X: exactly 20 matches between events — not over max_gap, so gap is NOT broken.
        // But complexity=2.0 over span=23 → density≈0.087, below min_density=0.2.
        let cigar = vec![Cigar::Diff(1), Cigar::Equal(20), Cigar::Diff(1)];
        let clusters = clusters_from_cigar(cigar, 0, ClusteringSettings::default()).await;
        assert_eq!(clusters.len(), 0);
    }
}