twitcher 0.1.10

Find template switch mutations in genomic data
use anyhow::{Context, bail};
use iterator::{ClusterOrRecords, find_clusters};
use lib_tsalign::a_star_aligner::{
    alignment_geometry::{AlignmentCoordinates, AlignmentRange},
    alignment_result::alignment::Alignment,
    template_switch_distance::AlignmentType,
};
use rust_htslib::bcf;
use tokio::{pin, sync::mpsc};
use tokio_stream::StreamExt;
use tracing::{error, instrument, warn, warn_span};

use crate::{
    common::{
        ImmutableSequence, SequencePair,
        aligner::{AlignmentOrchestrator, AlignmentQuery, cli::CliAlignmentArgs},
        alignment::ForwardAlignment,
        contig::ContigName,
        coords::{GenomePosition, GenomeRegion},
        csv::CSVAuxData,
        list_of_regions::Regions,
        reference::{ReferenceQueryResult, ReferenceReader},
    },
    vcf::pipeline::{
        message::{Cluster, MaskedRecords},
        reader::VCFReader,
    },
};

use super::Message;

pub mod cluster;
mod iterator;

/// Pipeline stage that turns the records of a VCF input into cluster or
/// passthrough [`Message`]s and forwards them over a channel.
///
/// All fields are consumed by [`Clusterizer::run`].
pub struct Clusterizer {
    /// VCF reader whose records are grouped; handed to `find_clusters`.
    input: VCFReader,
    /// Optional list of regions, handed to `find_clusters` unchanged.
    targets: Option<Regions>,
    /// Source of the reference sequence fetched (with padding) around each cluster.
    reference: ReferenceReader,
    /// Optional indexed BAM reader, handed to `find_clusters` unchanged.
    bam_f: Option<rust_htslib::bam::IndexedReader>,
    /// Sending half of the channel on which the resulting messages are emitted.
    output: mpsc::Sender<Message>,
    /// Clustering strategy and aligner configuration.
    settings: ClusterSettings,
}

// Command-line settings of the clustering stage. Both fields are flattened
// into the surrounding argument set via `#[command(flatten)]`.
//
// Plain `//` comments are used for the additions in this block on purpose:
// `///` doc comments are attributes that the `clap::Args` derive can see.
#[derive(clap::Args, Debug, Default)]
pub struct ClusterSettings {
    /// Select how clusters are pre-selected. Currently, the default option is the only sensible one, and the other one is for debug purposes.
    #[command(flatten)]
    pub cluster_strategy: cluster::ClusteringSettings,

    // Aligner configuration. `Clusterizer::run` builds the
    // `AlignmentOrchestrator` from it and reads its `padding` value when
    // fetching the reference context of a cluster.
    #[command(flatten)]
    pub aligner: CliAlignmentArgs,
}

impl Clusterizer {
    /// Bundles the given parts into a [`Clusterizer`].
    ///
    /// Nothing is read or sent until [`Clusterizer::run`] is awaited.
    pub(super) fn new(
        input: VCFReader,
        reference: ReferenceReader,
        targets: Option<Regions>,
        bam_f: Option<rust_htslib::bam::IndexedReader>,
        output: mpsc::Sender<Message>,
        settings: ClusterSettings,
    ) -> Self {
        Self {
            input,
            targets,
            reference,
            bam_f,
            output,
            settings,
        }
    }

    /// Consumes the clusterizer and processes the input until the stream
    /// produced by `find_clusters` is exhausted.
    ///
    /// For every stream item:
    /// * `Cluster(records)`: the records are prepared via `prepare_cluster`
    ///   and an alignment is requested from the aligner. On success a cluster
    ///   message is sent; if preparation or the alignment request fails (or
    ///   preparation yields nothing), the records are passed through instead.
    /// * `Records(records)`: the records are passed through unchanged.
    /// * `SequenceDone`: the aligner cache is cleared and no message is sent.
    ///
    /// If the aligners cannot be initialised, the error is logged and the
    /// function returns without processing anything.
    ///
    /// # Panics
    ///
    /// Panics if sending on the output channel fails.
    pub async fn run(self) {
        let aligner = match AlignmentOrchestrator::try_from(&self.settings.aligner) {
            Ok(aligner) => aligner,
            Err(e) => {
                error!("Cannot initialize aligners: {e}");
                return;
            }
        };

        let strategy = self.settings.cluster_strategy;

        let vcf_name = self.input.name().to_string();
        let clusters = find_clusters(strategy, self.input, self.targets, self.bam_f);

        pin!(clusters);
        // Running counter; incremented for every cluster item, including
        // those that end up being passed through.
        let mut cluster_id = 0;

        while let Some(e) = clusters.next().await {
            let m = match e {
                ClusterOrRecords::Cluster(records) => {
                    cluster_id += 1;
                    // The mask selects every record of the cluster.
                    let cluster = Cluster::new(bitvec::bitvec![1; records.len()]);
                    let masked = cluster.apply_to_records(&records);

                    match prepare_cluster(&self.reference, masked, self.settings.aligner.padding)
                        .await
                    {
                        Ok(Some(prepared)) => {
                            match aligner.get_or_compute_alignment(
                                self.reference.get_name(),
                                // Borrow the region directly; cloning it just
                                // to take a reference was a wasted copy.
                                &prepared.reference_region,
                                prepared.cluster_region.clone(),
                                prepared.query.clone(),
                            ) {
                                Ok(pending) => Message::cluster(
                                    records,
                                    vec![(
                                        cluster,
                                        pending,
                                        CSVAuxData {
                                            cluster_id: cluster_id.to_string(),
                                            sequences: prepared.query.sequences,
                                            ref_context_region: prepared.reference_region.clone(),
                                            alt_context_region: prepared.reference_region,
                                            cluster_region: prepared.cluster_region,
                                            alt_id: Some(vcf_name.clone()),
                                            forward_alignment: ForwardAlignment(
                                                prepared.fw_alignment,
                                            ),
                                            cost: aligner.costs.clone(),
                                        },
                                    )],
                                ),
                                Err(e) => {
                                    error!(
                                        "Could not get or start the computation of an alignment: {e}"
                                    );
                                    Message::passthrough(records)
                                }
                            }
                        }
                        Ok(None) => Message::passthrough(records),
                        Err((reg, err)) => {
                            // Attach the region (when known) to the warning
                            // by entering a span for the duration of this arm.
                            let _s =
                                reg.map(|r| warn_span!("Failed to prepare", pos = %r).entered());
                            warn!("{err}. Passing records through.",);
                            Message::passthrough(records)
                        }
                    }
                }
                ClusterOrRecords::Records(records) => Message::passthrough(records),
                ClusterOrRecords::SequenceDone => {
                    aligner.clear_cache();
                    continue;
                }
            };
            self.output.send(m).await.expect("Channel closed !?");
        }
    }
}

/// Everything `prepare_cluster` derives from one cluster of records.
#[derive(Clone, Debug)]
struct PreparedCluster {
    /// Reference/query sequence pair plus the alignment ranges within them.
    query: AlignmentQuery,
    /// Region that the reference reader actually returned for the padded request.
    reference_region: GenomeRegion,
    /// Region spanned by the masked records themselves, as computed by `extract_region`.
    cluster_region: GenomeRegion,
    /// Alignment built by `apply_mutations` while constructing the query sequence.
    fw_alignment: Alignment<AlignmentType>,
}

/// Fetches the padded reference context of a cluster and builds the
/// alignment query for it.
///
/// The region spanned by the masked records is computed with
/// `extract_region`, the reference is queried for that region with `padding`
/// requested on both sides, and the records are applied to the returned
/// reference sequence via `apply_mutations` to obtain the query sequence.
///
/// # Returns
///
/// * `Ok(Some(_))` with the prepared cluster,
/// * `Ok(None)` if the reference reader returns no result for the region.
///
/// # Errors
///
/// Returns the failing error together with the cluster region, when it is
/// already known (`None` if `extract_region` itself failed). This includes the
/// case where the returned reference sequence does not cover the whole
/// cluster region: instead of underflowing in the padding arithmetic, the
/// cluster is reported as an error so that the caller can pass it through.
///
/// # Panics
///
/// Panics if the region built by `extract_region` reports no length.
async fn prepare_cluster(
    reference: &ReferenceReader,
    cluster: MaskedRecords<'_, '_>,
    padding: usize,
) -> Result<Option<PreparedCluster>, (Option<GenomeRegion>, anyhow::Error)> {
    let query_region = extract_region(cluster).map_err(|e| (None, e))?;
    let Some(ReferenceQueryResult {
        region: actual_region,
        sequence: reference_sequence,
        ..
    }) = reference
        .get_seq(query_region.clone(), padding, padding)
        .await
        .map_err(|e| (Some(query_region.clone()), e))?
    else {
        return Ok(None);
    };

    let _span = warn_span!("prepare_cluster", pos = %query_region).entered();

    let region_len = query_region
        .len()
        .expect("extract_region builds the region with `new_bounded`, so it has a length");

    // The padding actually granted may differ from the requested one, so it
    // is derived from the returned region and sequence. All subtractions are
    // checked: a reference that does not fully cover the cluster region must
    // not wrap around (release) or panic (debug).
    let actual_padding_left = query_region
        .start()
        .position_0()
        .checked_sub(actual_region.start().position_0())
        .context("Reference region starts after the cluster region")
        .map_err(|e| (Some(query_region.clone()), e))?;
    let actual_padding_right = reference_sequence
        .len()
        .checked_sub(actual_padding_left)
        .and_then(|rest| rest.checked_sub(region_len))
        .context("Reference sequence does not cover the whole cluster region")
        .map_err(|e| (Some(query_region.clone()), e))?;

    let (query_sequence, fw_alignment) = apply_mutations(
        &reference_sequence,
        actual_region.start().position_0(),
        cluster.masked_iter(),
    )
    .map_err(|e| (Some(query_region.clone()), e))?;

    // Offset: first position after the left padding in both sequences.
    // Limit: start of the right padding in the reference and in the query.
    let ranges = AlignmentRange::new_offset_limit(
        AlignmentCoordinates::new(actual_padding_left, actual_padding_left),
        AlignmentCoordinates::new(
            reference_sequence.len() - actual_padding_right,
            query_sequence.len() - actual_padding_right,
        ),
    );

    Ok(Some(PreparedCluster {
        query: AlignmentQuery {
            sequences: SequencePair {
                reference: reference_sequence,
                query: query_sequence,
            },
            ranges,
        },
        reference_region: actual_region,
        cluster_region: query_region,
        fw_alignment,
    }))
}

/// Computes the genome region spanned by the masked records of a cluster.
///
/// The region starts at the smallest `pos()` and ends at the largest `end()`
/// among the records. The contig name is taken from the first record that
/// carries an rid.
///
/// # Errors
///
/// Fails if an rid cannot be resolved to a name, if the start or end does not
/// fit into `usize`, or if no record provides an rid.
fn extract_region(cluster: MaskedRecords<'_, '_>) -> anyhow::Result<GenomeRegion> {
    let mut lowest_pos = i64::MAX;
    let mut highest_end = 0;
    let mut contig_name = None;

    for record in cluster.masked_iter() {
        lowest_pos = lowest_pos.min(record.pos());
        highest_end = highest_end.max(record.end());

        // Only the first rid encountered is resolved to a contig name.
        if contig_name.is_none()
            && let Some(rid) = record.rid()
        {
            let name = record.header().rid2name(rid)?;
            contig_name = Some(ContigName::new(name));
        }
    }

    let start = usize::try_from(lowest_pos)?;
    let end = usize::try_from(highest_end)?;
    let contig = contig_name.context("No rid present in records??")?;

    Ok(GenomeRegion::new_bounded(
        GenomePosition::new_0(contig, start),
        end - start,
    ))
}

/// Build the alt sequence.
///
/// Walks over `mutations` in the order given and splices the first alt allele
/// of each record into `reference`, which starts at genome position
/// `reference_start`. Alongside the sequence, an alignment of the reference
/// against the produced sequence is recorded.
///
/// Records without an alt allele, records that are empty after removing a
/// shared leading base, records ending beyond the reference and records whose
/// position lies behind their end are skipped with a warning.
///
/// # Errors
///
/// Fails if a record position or end does not fit into `usize`, or if a
/// record starts before the end of the previously applied one.
#[instrument(name = "build_query_sequence", skip_all)]
fn apply_mutations<'a>(
    reference: &[u8],
    reference_start: usize,
    mutations: impl Iterator<Item = &'a bcf::Record>,
) -> anyhow::Result<(ImmutableSequence, Alignment<AlignmentType>)> {
    let reference_end = reference_start + reference.len();
    let mut query = Vec::new();
    let mut alignment = Alignment::new();
    // Genome position up to which the reference has been consumed so far.
    let mut last_end = reference_start;

    for record in mutations {
        let mut pos = usize::try_from(record.pos())?;
        let end = usize::try_from(record.end())?;
        let alleles = record.alleles();
        if alleles.len() < 2 {
            warn!("Record at {pos} has no alt allele");
            continue;
        }

        // VCF insertions / deletions (A -> AC or AC -> A) carry a shared
        // leading "anchor" base. It is dropped to get the raw event, which
        // then starts one position later.
        let (ref_allele, alt_allele) =
            match (alleles[0].split_first(), alleles[1].split_first()) {
                (Some((ref_anchor, ref_rest)), Some((alt_anchor, alt_rest)))
                    if ref_anchor == alt_anchor =>
                {
                    pos += 1;
                    (ref_rest, alt_rest)
                }
                _ => (alleles[0], alleles[1]),
            };

        if ref_allele.is_empty() && alt_allele.is_empty() {
            warn!("Empty record; Skipping.");
            continue;
        }

        if pos < last_end {
            bail!(
                "Overlapping mutation: last mutation ended at {last_end}, \
                 but this one starts at {pos}"
            );
        }

        if end > reference_end {
            warn!(
                "Skipping out-of-bounds mutation at {pos}: \
                 record end {end} exceeds reference end {reference_end}"
            );
            continue;
        }

        if pos > end {
            warn!("Skip negative-sized mutation: the position is {pos} but the end is {end}",);
            continue;
        }

        // Copy the untouched reference stretch in front of this record.
        if pos > last_end {
            let untouched = &reference[last_end - reference_start..pos - reference_start];
            query.extend_from_slice(untouched);
            alignment.push_n(untouched.len(), AlignmentType::PrimaryMatch);
        }

        // TODO should we take anything else than the first alt allele?
        query.extend_from_slice(alt_allele);

        match (ref_allele.len(), alt_allele.len()) {
            (0, 0) => {
                bail!("Empty record; should have been caught earlier!");
            }
            // TODO the special cases below could all be expressed through the
            // general case; they are kept separate for clarity.
            // SNV
            (1, 1) => alignment.push(AlignmentType::PrimarySubstitution),
            // insertion
            (0, inserted) => alignment.push_n(inserted, AlignmentType::PrimaryInsertion),
            // deletion
            (deleted, 0) => alignment.push_n(deleted, AlignmentType::PrimaryDeletion),
            (ref_len, alt_len) => {
                // TODO perhaps run some simple local aligner here? For now the
                // common prefix length is compared base by base and the
                // length difference becomes one gap.
                for (ref_base, alt_base) in ref_allele.iter().zip(alt_allele) {
                    alignment.push(if ref_base == alt_base {
                        AlignmentType::PrimaryMatch
                    } else {
                        AlignmentType::PrimarySubstitution
                    });
                }

                let (gap_len, gap_kind) = if ref_len > alt_len {
                    (ref_len - alt_len, AlignmentType::PrimaryDeletion)
                } else {
                    (alt_len - ref_len, AlignmentType::PrimaryInsertion)
                };
                alignment.push_n(gap_len, gap_kind);
            }
        }

        last_end = end;
    }

    // Everything behind the last applied record is copied verbatim.
    let tail_offset = last_end - reference_start;
    query.extend_from_slice(&reference[tail_offset..]);
    alignment.push_n(reference.len() - tail_offset, AlignmentType::PrimaryMatch);

    Ok((query.into(), alignment))
}