// twitcher 0.1.10
//
// Find template switch mutations in genomic data.
use std::{fmt::Debug, io::Write, sync::Arc};

use bstr::ByteSlice;
use compact_genome::{
    implementation::{alphabets::dna_alphabet_or_n::DnaAlphabetOrN, vec_sequence::VectorGenome},
    interface::sequence::{GenomeSequence, OwnedGenomeSequence},
};
use csv::Writer;
use generic_a_star::cost::AStarCost;
use lib_tsalign::{
    a_star_aligner::{
        alignment_result::alignment::Alignment, template_switch_distance::AlignmentType,
    },
    config::TemplateSwitchConfig,
    costs::U64Cost,
};
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument, trace};

use crate::common::{
    ImmutableSequence, SequencePair,
    aligner::result::{TSData, TwitcherAlignmentWithStatistics},
    alignment::{ForwardAlignment, consumed_query},
    coords::GenomeRegion,
};

/// One row of the CSV output, as built and serialized by `TwitcherCSVWriter::write`.
///
/// Fields that describe template switches hold one value per template switch, produced by
/// `TSData::to_field` with `SEP` as the separator argument.
///
/// NOTE(review): the CSV column order presumably follows the field declaration order below —
/// confirm before reordering fields.
#[derive(Serialize, Deserialize)]
pub struct CSVRecord {
    /// Some sort of running id (maybe even unique across files?), to make it easy to connect
    /// vcf and bam output.
    pub cluster_id: String,
    /// From aux data: the entire reference context region.
    pub ref_ctx_region: String,
    /// From aux data: the entire alternative (query) context region.
    pub alt_ctx_region: String,
    /// From aux data: the region of the cluster, i.e. the "focus" region; will be a subregion
    /// of `region`.
    pub cluster_region: String,
    /// Reference offset reported by the alignment statistics (`stats.reference_offset()`).
    pub ref_cluster_offset: usize,
    // pub ref_cluster_limit: usize,
    /// Query offset reported by the alignment statistics (`stats.query_offset()`).
    pub alt_cluster_offset: usize,
    // pub alt_cluster_limit: usize,
    /// For each template switch, the region spanning its positions 1 and 4 (inclusive on both
    /// ends, smaller position first); might be useful for dedup.
    pub ts_1_4_region: String,
    /// Identifier of the alternative sequence (filled from `CSVAuxData::alt_id`); null if vcf.
    pub read_id: Option<String>,
    /// Forward alignment CIGAR, for the cluster region.
    pub fw_cigar: String,
    /// Forward alignment CIGAR, for the context region.
    pub fw_cigar_ctx: String,
    /// Output of `create_mi_string` for the forward alignment, for the context region.
    pub fw_mi_ctx: String,
    /// Forward alignment cost, for the cluster region.
    pub fw_cost: u64,
    /// Forward alignment cost, for the context region.
    pub fw_cost_ctx: u64,
    /// Template switch alignment CIGAR, for the cluster region.
    pub ts_cigar: String,
    /// Template switch alignment CIGAR, for the context region.
    pub ts_cigar_ctx: String,
    /// Template switch alignment cost, for the cluster region.
    pub ts_cost: u64,
    /// Template switch alignment cost, for the context region.
    pub ts_cost_ctx: u64,
    /// Number of template switches found (`TSData` entries) for this cluster.
    pub ts_num: usize,
    /// Per template switch: `TSData::jump_1_2`.
    pub ts_1_2: String,
    /// Per template switch: `TSData::inner_len`.
    pub ts_2_3: String,
    /// Per template switch: `TSData::apg`.
    pub ts_1_4: String,
    /// Per template switch: negated `er.min_start`.
    pub ts_start_left_shift: String,
    /// Per template switch: `er.max_start`.
    pub ts_start_right_shift: String,
    /// Per template switch: negated `er.min_end`.
    pub ts_end_left_shift: String,
    /// Per template switch: `er.max_end`.
    pub ts_end_right_shift: String,
    /// Per template switch: CIGAR of the inner alignment (`inner_aln`).
    pub ts_inner_alignment_cigar: String,
    // TODO some measure of repetition in the inner TS part? could be useful
}

/// Separator handed to `TSData::to_field` for every multi-valued template switch column
/// (presumably placed between the per-switch values inside one CSV field — confirm in `TSData`).
static SEP: &str = "|";

/// Writes `CSVRecord` rows through a `csv::Writer` wrapped around a boxed `Write + Send` sink.
pub struct TwitcherCSVWriter(Writer<Box<dyn Write + Send>>);

/// Everything `TwitcherCSVWriter::write` needs besides the alignment result itself to fill one
/// `CSVRecord`.
pub struct CSVAuxData {
    /// Written unchanged to the `cluster_id` column.
    pub cluster_id: String,
    /// Reference and query sequences the alignments are scored against.
    pub sequences: SequencePair,
    /// Region covered by the reference context; written to `ref_ctx_region`.
    pub ref_context_region: GenomeRegion,
    /// Region covered by the alternative (query) context; written to `alt_ctx_region`.
    pub alt_context_region: GenomeRegion,
    /// Region of the cluster; written to `cluster_region` and passed to `TSData::compute`.
    pub cluster_region: GenomeRegion,
    // pub ranges: AlignmentRange,
    /// Optional identifier of the alternative sequence; written to the `read_id` column.
    pub alt_id: Option<String>,
    /// Forward alignment over the whole context; cropped to the cluster and combined with the
    /// template switch alignment when a record is written.
    pub forward_alignment: ForwardAlignment,
    /// Cost configuration used to score the forward and template switch alignments.
    pub cost: Arc<TemplateSwitchConfig<DnaAlphabetOrN, U64Cost>>,
}

impl Debug for CSVAuxData {
    /// Formats the aux data as a debug struct.
    ///
    /// The forward alignment is shown as its CIGAR string and the cost configuration is
    /// replaced by the placeholder `"<cost def>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CSVAuxData");
        out.field("cluster_id", &self.cluster_id);
        out.field("sequences", &self.sequences);
        out.field("ref_context_region", &self.ref_context_region);
        out.field("alt_context_region", &self.alt_context_region);
        out.field("cluster_region", &self.cluster_region);
        out.field("alt_id", &self.alt_id);
        out.field("forward_alignment", &self.forward_alignment.cigar());
        out.field("cost", &"<cost def>");
        out.finish()
    }
}

impl TwitcherCSVWriter {
    /// Creates a CSV writer that emits its rows into `write`.
    pub fn new(write: Box<dyn Write + Send>) -> Self {
        let inner = Writer::from_writer(write);
        Self(inner)
    }

    /// Builds one `CSVRecord` from `result` and `aux` and serializes it.
    ///
    /// Four alignments are reported, each with CIGAR and cost:
    /// * the forward alignment over the whole context (taken from `aux`),
    /// * the template switch alignment over the whole context (`insert_ts_alignment`),
    /// * the forward alignment cropped to the cluster (`crop_to_ts_region`),
    /// * the template switch alignment of the cluster (taken from `result`).
    ///
    /// # Errors
    ///
    /// Propagates failures from cost computation, from inserting/cropping the alignment, from
    /// `TSData::compute` and from CSV serialization.
    ///
    /// # Panics
    ///
    /// Panics if `GenomeRegion::from_incl_incl` fails for the 1–4 positions of a template switch,
    /// since that result is unwrapped.
    ///
    /// NOTE(review): serialization runs inside `tokio::task::block_in_place`, so this presumably
    /// has to be called from a multi-threaded tokio runtime — confirm against the callers.
    #[instrument(name = "write_csv_record", skip_all, fields(pos = %aux.cluster_region, src = aux.alt_id))]
    pub fn write(
        &mut self,
        result: &TwitcherAlignmentWithStatistics,
        aux: CSVAuxData,
    ) -> anyhow::Result<()> {
        debug!("Writing CSV Record");

        trace!(
            ref_len = aux.sequences.reference.len(),
            query_len = aux.sequences.query.len(),
            ref_offset = result.stats.reference_offset(),
            query_offset = result.stats.query_offset(),
            fw_cigar = aux.forward_alignment.cigar(),
            "Input sequences and forward alignment"
        );

        trace!(
            ref = %aux.sequences.reference.as_bstr(),
            alt = %aux.sequences.query.as_bstr(),
        );

        // Context-wide forward alignment.
        let ctx_fw = &aux.forward_alignment;
        let ctx_fw_cost = compute_cost(&aux.sequences, None, ctx_fw.0.clone(), &aux.cost)?;
        trace!(
            cost = ctx_fw_cost,
            "Computed entire forward alignment cost"
        );

        // Context-wide alignment with the template switch alignment spliced in.
        let ctx_ts = aux.forward_alignment.insert_ts_alignment(result)?;
        trace!(
            cigar = ctx_ts.cigar(),
            "Entire TS alignment after insert"
        );
        let ctx_ts_cost = compute_cost(&aux.sequences, None, ctx_ts.clone(), &aux.cost)?;
        trace!(
            cost = ctx_ts_cost,
            "Computed entire TS alignment cost"
        );

        // Forward alignment restricted to the cluster.
        let cluster_fw = aux.forward_alignment.crop_to_ts_region(result)?;
        trace!(
            cigar = cluster_fw.cigar(),
            ref_offset = result.stats.reference_offset(),
            query_offset = result.stats.query_offset(),
            ref_remaining = aux
                .sequences
                .reference
                .len()
                .saturating_sub(result.stats.reference_offset()),
            query_remaining = aux
                .sequences
                .query
                .len()
                .saturating_sub(result.stats.query_offset()),
            "Cluster forward alignment after crop"
        );
        let cluster_offsets = Some((result.stats.reference_offset(), result.stats.query_offset()));
        let cluster_fw_cost = compute_cost(
            &aux.sequences,
            cluster_offsets,
            cluster_fw.clone(),
            &aux.cost,
        )?;
        trace!(
            cost = cluster_fw_cost,
            "Computed cluster forward alignment cost"
        );

        // Template switch alignment of the cluster; its cost comes with the result.
        let cluster_ts = &result.alignment.alignment;
        let cluster_ts_cost = result.alignment.cost.as_primitive();
        trace!(
            cigar = cluster_ts.cigar(),
            cost = cluster_ts_cost,
            "Cluster TS alignment"
        );

        let switches = TSData::compute(&aux.cluster_region, result)?;
        let span_1_4 = TSData::to_field(&switches, SEP, |d| {
            // Order the two positions so the region always runs from the smaller to the larger.
            let (first, last) = if d.pos_1 <= d.pos_4 {
                (&d.pos_1, &d.pos_4)
            } else {
                (&d.pos_4, &d.pos_1)
            };
            GenomeRegion::from_incl_incl(first.clone(), Some(last.clone())).unwrap()
        });

        let row = CSVRecord {
            cluster_id: aux.cluster_id,
            ref_ctx_region: aux.ref_context_region.to_string(),
            alt_ctx_region: aux.alt_context_region.to_string(),
            cluster_region: aux.cluster_region.to_string(),
            ts_1_4_region: span_1_4,
            read_id: aux.alt_id,
            fw_cigar: cluster_fw.cigar(),
            fw_cigar_ctx: ctx_fw.cigar(),
            fw_mi_ctx: create_mi_string(ctx_fw, &aux.sequences.query),
            fw_cost: cluster_fw_cost,
            fw_cost_ctx: ctx_fw_cost,
            ts_cigar: cluster_ts.cigar(),
            ts_cigar_ctx: ctx_ts.cigar(),
            ts_cost: cluster_ts_cost,
            ts_cost_ctx: ctx_ts_cost,
            ts_num: switches.len(),
            ts_1_2: TSData::to_field(&switches, SEP, |d| d.jump_1_2),
            ts_2_3: TSData::to_field(&switches, SEP, |d| d.inner_len),
            ts_1_4: TSData::to_field(&switches, SEP, |d| d.apg),
            ts_start_left_shift: TSData::to_field(&switches, SEP, |d| -d.er.min_start),
            ts_start_right_shift: TSData::to_field(&switches, SEP, |d| d.er.max_start),
            ts_end_left_shift: TSData::to_field(&switches, SEP, |d| -d.er.min_end),
            ts_end_right_shift: TSData::to_field(&switches, SEP, |d| d.er.max_end),
            ts_inner_alignment_cigar: TSData::to_field(&switches, SEP, |d| d.inner_aln.cigar()),
            ref_cluster_offset: result.stats.reference_offset(),
            // ref_cluster_limit: aux.ranges.reference_limit(),
            alt_cluster_offset: result.stats.query_offset(),
            // alt_cluster_limit: aux.ranges.query_limit(),
        };

        // Hand the row to the underlying CSV writer.
        let csv = &mut self.0;
        tokio::task::block_in_place(|| csv.serialize(row))?;

        Ok(())
    }

    /// Flushes the underlying CSV writer.
    pub fn flush(&mut self) -> std::io::Result<()> {
        let Self(inner) = self;
        inner.flush()
    }
}

/// Computes the cost of `alignment` on `sequences` under the cost configuration `costs`.
///
/// `sequence_offsets` is an optional `(reference_offset, query_offset)` pair. The alignment is
/// scored against the suffixes of the reference and the query that start at these offsets;
/// `None` means both suffixes start at 0.
///
/// # Errors
///
/// Returns an error if an offset lies past the end of its sequence, or if one of the sequences
/// cannot be converted into a `VectorGenome`.
fn compute_cost(
    sequences: &SequencePair,
    sequence_offsets: Option<(usize, usize)>,
    mut alignment: Alignment<AlignmentType>,
    costs: &TemplateSwitchConfig<DnaAlphabetOrN, U64Cost>,
) -> anyhow::Result<u64> {
    let (r_off, q_off) = sequence_offsets.unwrap_or_default();

    // Slicing with a start beyond the end would panic; report it through the `Result` instead.
    anyhow::ensure!(
        r_off <= sequences.reference.len(),
        "reference offset {} exceeds reference length {}",
        r_off,
        sequences.reference.len()
    );
    anyhow::ensure!(
        q_off <= sequences.query.len(),
        "query offset {} exceeds query length {}",
        q_off,
        sequences.query.len()
    );

    // TODO this is copying data, which is a bit stupid but at the moment unavoidable due to the interfaces in lib_tsalign
    let r = VectorGenome::<DnaAlphabetOrN>::from_slice_u8(&sequences.reference[r_off..])?;
    let q = VectorGenome::from_slice_u8(&sequences.query[q_off..])?;

    // NOTE(review): the two zeros are presumably the start positions inside the already
    // offset-trimmed sequences — confirm against `Alignment::compute_cost`.
    let cost = alignment.compute_cost(
        r.as_genome_subsequence(),
        q.as_genome_subsequence(),
        0,
        0,
        costs,
    );
    Ok(cost.as_primitive())
}

/// Renders the forward alignment as a compact string of match-run lengths and query bases.
///
/// Walking the flattened alignment, every primary match increases a pending run length. A
/// primary insertion or substitution first writes the pending run length (if non-zero) and then
/// the query base at the current query position. All other alignment types write nothing. After
/// every step the query position advances by `consumed_query(1, ty, None)`.
///
/// NOTE(review): a run of matches at the very end of the alignment is never written out —
/// confirm that this is the intended format.
///
/// # Panics
///
/// Panics if `consumed_query` yields no value or if the query position runs past the end of
/// `query`.
fn create_mi_string(forward_alignment: &ForwardAlignment, query: &ImmutableSequence) -> String {
    let mut encoded = String::new();
    let mut query_pos = 0;
    let mut pending_matches = 0;

    for op in forward_alignment.iter_flat() {
        if matches!(op, AlignmentType::PrimaryMatch) {
            pending_matches += 1;
        } else if matches!(
            op,
            AlignmentType::PrimaryInsertion | AlignmentType::PrimarySubstitution
        ) {
            // Flush the matches seen since the last emitted base, then emit the query base.
            if pending_matches > 0 {
                encoded += &pending_matches.to_string();
                pending_matches = 0;
            }
            encoded.push(char::from(query[query_pos]));
        }

        let step = consumed_query(1, op, None).unwrap();
        query_pos = (query_pos as isize + step) as usize;
    }

    encoded
}