twitcher 0.1.8

Find template switch mutations in genomic data
use anyhow::{Context, Error};
use iterator::{ClusterOrRecords, find_clusters};
use lib_tsalign::a_star_aligner::alignment_geometry::{AlignmentCoordinates, AlignmentRange};
use rust_htslib::bcf;
use tokio::{pin, sync::mpsc};
use tokio_stream::StreamExt;
use tracing::{error, info_span, instrument, warn};

use crate::{
    common::{
        ImmutableSequence, SequencePair,
        aligner::{AlignmentOrchestrator, AlignmentQuery, cli::CliAlignmentArgs},
        contig::ContigName,
        coords::{GenomePosition, GenomeRegion},
        list_of_regions::Regions,
        reference::{ReferenceQueryResult, ReferenceReader},
    },
    vcf::pipeline::{
        message::{Cluster, MaskedRecords},
        reader::VCFReader,
    },
};

use super::Message;

pub mod cluster;
mod iterator;

/// Pipeline stage that groups VCF records into clusters, prepares each cluster for realignment
/// against the reference, and forwards the outcome as [`Message`]s on `output`.
pub struct Clusterizer {
    /// Source of the VCF records that are clustered.
    input: VCFReader,
    /// Optional regions forwarded to `find_clusters` (presumably restricting where clusters are
    /// searched — TODO confirm in `iterator::find_clusters`).
    targets: Option<Regions>,
    /// Reader used to fetch the reference sequence (plus padding) around each cluster.
    reference: ReferenceReader,
    /// Optional indexed BAM reader forwarded to `find_clusters`.
    bam_f: Option<rust_htslib::bam::IndexedReader>,
    /// Channel on which cluster / pass-through messages are sent downstream.
    output: mpsc::Sender<Message>,
    /// Clustering strategy and aligner configuration.
    settings: ClusterSettings,
}

// Command-line settings for the clustering stage: the cluster pre-selection strategy and the
// aligner configuration, both flattened into the surrounding clap command.
// NOTE: plain `//` comments are used on purpose here; `///` doc comments on a clap-derived
// struct may be turned into user-visible help text.
#[derive(clap::Args, Debug, Default)]
pub struct ClusterSettings {
    /// Select how clusters are pre-selected. Currently, the default option is the only sensible one, and the other one is for debug purposes.
    #[command(flatten)]
    pub cluster_strategy: cluster::ClusteringSettings,

    // Aligner configuration. `Clusterizer::run` builds the alignment orchestrator from it and
    // reads `padding` from it when fetching reference context around a cluster.
    #[command(flatten)]
    pub aligner: CliAlignmentArgs,
}

impl Clusterizer {
    /// Create a new clusterizer stage.
    ///
    /// Only stores its arguments; no work is done until [`Clusterizer::run`] is awaited.
    pub(super) fn new(
        input: VCFReader,
        reference: ReferenceReader,
        targets: Option<Regions>,
        bam_f: Option<rust_htslib::bam::IndexedReader>,
        output: mpsc::Sender<Message>,
        settings: ClusterSettings,
    ) -> Self {
        Self {
            input,
            targets,
            reference,
            bam_f,
            output,
            settings,
        }
    }

    /// Run the stage to completion, consuming `self`.
    ///
    /// First builds the [`AlignmentOrchestrator`] from the aligner settings. If that fails, an
    /// error is logged and the function returns without sending anything.
    ///
    /// Then, for every item produced by `find_clusters`:
    /// - a cluster has all of its records applied to the reference (see `prepare_cluster`).
    ///   If an alignment can be looked up or started for it, a `Message::cluster` carrying the
    ///   pending alignment is sent. If preparation fails, the reference yields no sequence, or
    ///   the alignment cannot be started, the records are sent unchanged via
    ///   `Message::passthrough`;
    /// - plain records are sent unchanged via `Message::passthrough`;
    /// - `SequenceDone` clears the aligner cache and sends nothing.
    ///
    /// # Panics
    /// Panics if the output channel has been closed by its receiver.
    pub async fn run(self) {
        let aligner = match AlignmentOrchestrator::try_from(&self.settings.aligner) {
            Ok(aligner) => aligner,
            Err(e) => {
                error!("Cannot initialize aligners: {e}");
                return;
            }
        };

        let strategy = self.settings.cluster_strategy;

        let clusters = find_clusters(strategy, self.input, self.targets, self.bam_f);

        pin!(clusters);

        while let Some(e) = clusters.next().await {
            let m = match e {
                ClusterOrRecords::Cluster(records) => {
                    // Select every record of the cluster (mask of all ones).
                    let cluster = Cluster::new(bitvec::bitvec![1; records.len()]);
                    let masked = cluster.apply_to_records(&records);

                    match prepare_cluster(&self.reference, masked, self.settings.aligner.padding)
                        .await
                    {
                        Ok(Some(prepared)) => {
                            match aligner.get_or_compute_alignment(
                                self.reference.get_name(),
                                // Borrow directly; no need to clone the region just to lend it.
                                &prepared.reference_region,
                                prepared.cluster_region,
                                // Cloned because the sequences are handed on in the message below.
                                prepared.query.clone(),
                            ) {
                                Ok(pending) => Message::cluster(
                                    records,
                                    vec![(
                                        cluster,
                                        prepared.reference_region.start().clone(),
                                        prepared.query.sequences,
                                        pending,
                                    )],
                                ),
                                Err(e) => {
                                    error!(
                                        "Could not get or start the computation of an alignment: {e}"
                                    );
                                    Message::passthrough(records)
                                }
                            }
                        }
                        Ok(None) => Message::passthrough(records),
                        Err(err) => {
                            error!(
                                "Could not prepare cluster for realignment: {err}. Passing records through.",
                            );
                            Message::passthrough(records)
                        }
                    }
                }
                ClusterOrRecords::Records(records) => Message::passthrough(records),
                ClusterOrRecords::SequenceDone => {
                    aligner.clear_cache();
                    continue;
                }
            };
            self.output.send(m).await.expect("Channel closed !?");
        }
    }
}

/// A cluster that is ready to be realigned, as produced by `prepare_cluster`.
#[derive(Clone, Debug)]
struct PreparedCluster {
    /// Reference/alt sequence pair plus the alignment ranges restricted to the cluster itself.
    query: AlignmentQuery,
    /// Region actually covered by the reference sequence in `query`: the cluster region plus
    /// whatever padding the reference reader returned.
    reference_region: GenomeRegion,
    /// Region spanned by the cluster's records, without padding.
    cluster_region: GenomeRegion,
}

/// Prepare a cluster for realignment.
///
/// Computes the region spanned by the cluster's records and fetches the reference around it,
/// requesting `padding` extra bases on each side. It then builds the alt sequence by applying
/// the cluster's records to that reference window. The resulting alignment range covers only
/// the cluster itself: the padding actually returned by the reference reader is excluded on
/// both sides.
///
/// Returns `Ok(None)` when the reference reader returns no sequence for the region.
///
/// # Errors
/// Fails if the cluster region cannot be determined, the reference lookup fails, the returned
/// reference window does not fully cover the cluster region, or applying the mutations fails.
async fn prepare_cluster(
    reference: &ReferenceReader,
    cluster: MaskedRecords<'_, '_>,
    padding: usize,
) -> Result<Option<PreparedCluster>, Error> {
    let query_region = extract_region(cluster)?;
    let Some(ReferenceQueryResult {
        region: actual_region,
        sequence: reference_sequence,
        ..
    }) = reference
        .get_seq(query_region.clone(), padding, padding)
        .await?
    else {
        return Ok(None);
    };

    // Entered only after the last `.await`, so the guard is never held across a suspension point.
    let _span = info_span!("prepare_cluster", pos = %query_region).entered();

    let cluster_len = query_region
        .len()
        .expect("cluster regions are always built with GenomeRegion::new_bounded");
    // Checked arithmetic: a window that does not cover the cluster (e.g. a record reaching past
    // the contig end) would otherwise underflow — a panic in debug, silent garbage in release.
    let actual_padding_left = query_region
        .start()
        .position_0()
        .checked_sub(actual_region.start().position_0())
        .context("Reference window starts after the cluster region")?;
    let actual_padding_right = reference_sequence
        .len()
        .checked_sub(actual_padding_left + cluster_len)
        .context("Reference window does not cover the whole cluster region")?;

    let query_sequence = apply_mutations(
        &reference_sequence,
        actual_region.start().position_0(),
        cluster.masked_iter(),
    )?;

    // `query_sequence.len() - actual_padding_right` cannot underflow: every record ends inside
    // the cluster region, and `apply_mutations` copies the reference after the last applied
    // record verbatim, so the right padding is always present at the end of the query.
    let ranges = AlignmentRange::new_offset_limit(
        AlignmentCoordinates::new(actual_padding_left, actual_padding_left),
        AlignmentCoordinates::new(
            reference_sequence.len() - actual_padding_right,
            query_sequence.len() - actual_padding_right,
        ),
    );
    Ok(Some(PreparedCluster {
        query: AlignmentQuery {
            sequences: SequencePair {
                reference: reference_sequence,
                query: query_sequence,
            },
            ranges,
        },
        reference_region: actual_region,
        cluster_region: query_region,
    }))
}

/// Compute the genomic region spanned by the selected records of `cluster`.
///
/// The region starts at the smallest record position and ends at the largest record end
/// (0-based, end exclusive as reported by the records). Its contig is taken from the first
/// record that carries a contig id.
///
/// # Errors
/// Fails if a contig id cannot be resolved through the record header, if a position does not
/// fit into `usize`, or if no selected record carries a contig id (including an empty cluster).
fn extract_region(cluster: MaskedRecords<'_, '_>) -> anyhow::Result<GenomeRegion> {
    let mut lowest_pos = i64::MAX;
    let mut highest_end = 0;
    let mut contig_name = None;

    for record in cluster.masked_iter() {
        lowest_pos = lowest_pos.min(record.pos());
        highest_end = highest_end.max(record.end());

        // The contig only needs to be resolved once.
        if contig_name.is_some() {
            continue;
        }
        if let Some(rid) = record.rid() {
            let raw_name = record.header().rid2name(rid)?;
            contig_name = Some(ContigName::new(raw_name));
        }
    }

    let start = usize::try_from(lowest_pos)?;
    let end = usize::try_from(highest_end)?;
    let contig = contig_name.context("No rid present in records??")?;

    Ok(GenomeRegion::new_bounded(
        GenomePosition::new_0(contig, start),
        end - start,
    ))
}

/// Build the alt (query) sequence by applying `mutations` to a reference window.
///
/// `reference` holds the bases of the window starting at the 0-based position
/// `reference_start`. Each record's reference span (`pos()..end()`) is replaced by its first
/// ALT allele. The bases between records, and after the last one, are copied unchanged.
///
/// Records are expected in ascending position order. A record is skipped, with a warning, when:
/// - it overlaps the previously applied record;
/// - its span is inverted or reaches past the window;
/// - it has no ALT allele;
/// - its ALT allele is not a literal base sequence (symbolic `<...>`, the spanning-deletion
///   allele `*`, a missing `.`, or breakend notation).
///
/// # Errors
/// Returns an error if a record position does not fit into `usize`.
#[instrument(name = "build_query_sequence", skip_all)]
fn apply_mutations<'a>(
    reference: &[u8],
    reference_start: usize,
    mutations: impl Iterator<Item = &'a bcf::Record>,
) -> anyhow::Result<ImmutableSequence> {
    // The alt sequence is usually close to the reference length.
    let mut result = Vec::with_capacity(reference.len());
    let mut last_pos = reference_start;
    let reference_end = reference_start + reference.len();

    for m in mutations {
        let pos = usize::try_from(m.pos())?;
        let end = usize::try_from(m.end())?;

        if pos < last_pos {
            warn!(
                "Overlapping Mutations -- last mutation ends at {last_pos}, but this starts at {}",
                pos
            );
            continue;
        }
        if end > reference_end || pos > end {
            warn!(
                "Skipping mutation spanning {pos}..{end}: not within the reference window {reference_start}..{reference_end}"
            );
            continue;
        }

        // TODO don't use hardcoded alleles! Only the first ALT allele is applied.
        let alleles = m.alleles();
        let Some(&alt) = alleles.get(1) else {
            // Previously `alleles()[1]` panicked here for records without an ALT allele.
            warn!("Skipping mutation at {pos}: record has no ALT allele");
            continue;
        };
        if !is_literal_allele(alt) {
            warn!("Skipping mutation at {pos}: ALT allele is not a literal base sequence");
            continue;
        }

        result.extend_from_slice(&reference[last_pos - reference_start..pos - reference_start]);
        result.extend_from_slice(alt);
        last_pos = end;
    }

    result.extend_from_slice(&reference[last_pos - reference_start..]);

    Ok(result.into())
}

/// Whether `allele` is a plain base sequence that can be spliced into a sequence.
///
/// Returns `false` for symbolic alleles (`<DEL>`, `<NON_REF>`, ...), the spanning-deletion
/// allele `*`, the missing value `.`, and breakend notation (contains `[` or `]`).
fn is_literal_allele(allele: &[u8]) -> bool {
    !(allele.starts_with(b"<")
        || allele == b"*"
        || allele == b"."
        || allele.iter().any(|&b| b == b'[' || b == b']'))
}