twitcher 0.1.10

Find template switch mutations in genomic data
use std::{error::Error, fmt::Display, time::Duration};

use generic_a_star::AStarResult;

use itertools::Itertools as _;
use lib_tsalign::{
    a_star_aligner::{
        alignment_result::{AlignmentResult, AlignmentStatistics, alignment::Alignment},
        template_switch_distance::{AlignmentType, EqualCostRange},
    },
    costs::U64Cost,
};
use serde::{Deserialize, Serialize};

use crate::common::{
    aligner::fpa::FpaAlignmentStatistics,
    alignment::consumed_reference,
    coords::{GenomePosition, GenomeRegion},
};

/// Outcome of a single aligner invocation; this is what the `aligner` module produces.
///
/// `Ok` carries the alignment together with its statistics; `Err` explains why no alignment
/// was produced (see [`AlignmentFailure`]).
pub type TwitcherAlignmentResult = Result<TwitcherAlignmentWithStatistics, AlignmentFailure>;

/// A successful aligner run: the alignment (paired with its cost) and the statistics reported
/// by the aligner that produced it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TwitcherAlignmentWithStatistics {
    /// The alignment and its cost.
    pub alignment: AlignmentWithCost,
    /// Aligner-specific statistics for this run.
    pub stats: TwitcherAlignmentStatistics,
}

/// Statistics returned by the aligner, tagged by which aligner produced them.
///
/// Both payloads are boxed, which keeps the enum itself small regardless of the size of the
/// statistics structs.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum TwitcherAlignmentStatistics {
    /// Statistics from the FPA aligner.
    FPAStats(Box<FpaAlignmentStatistics>),
    /// Statistics from lib_tsalign's A* aligner.
    TSAlign(Box<AlignmentStatistics<U64Cost>>),
}

/// An alignment paired with the cost assigned to it by the aligner.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AlignmentWithCost {
    /// The alignment operations.
    pub alignment: Alignment<AlignmentType>,
    /// The cost of `alignment`.
    pub cost: U64Cost,
}

impl AlignmentWithCost {
    pub fn new(alignment: Alignment<AlignmentType>, cost: U64Cost) -> Self {
        Self { alignment, cost }
    }
    pub fn has_ts(&self) -> bool {
        self.alignment
            .iter_compact()
            .any(|(_, ty)| matches!(ty, AlignmentType::TemplateSwitchEntrance { .. }))
    }
}

/// Why an aligner did not produce an alignment.
///
/// A `SoftFailure` is an anticipated outcome described by a [`SoftFailureReason`]. An `Error`
/// is an unexpected failure and carries its message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum AlignmentFailure {
    /// Expected failure.
    SoftFailure { reason: SoftFailureReason },
    /// Unexpected failure, with its message.
    Error { error: String },
}

/// The reason behind an [`AlignmentFailure::SoftFailure`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum SoftFailureReason {
    /// The aligner ran out of memory (e.g. tsalign reported exceeding its memory limit).
    OutOfMemory,
    /// The aligner timed out; carries the associated duration.
    Timeout(Duration),
    /// Any other expected failure, described by a message.
    Other(String),
}

impl TwitcherAlignmentWithStatistics {
    pub fn new(
        alignment: Alignment<AlignmentType>,
        cost: U64Cost,
        stats: TwitcherAlignmentStatistics,
    ) -> Self {
        Self {
            alignment: AlignmentWithCost { alignment, cost },
            stats,
        }
    }

    pub fn has_ts(&self) -> bool {
        self.alignment.has_ts()
    }
}

impl TwitcherAlignmentStatistics {
    pub fn reference_offset(&self) -> usize {
        match self {
            TwitcherAlignmentStatistics::FPAStats(fpa_alignment_statistics) => {
                fpa_alignment_statistics.ranges.reference_offset()
            }
            TwitcherAlignmentStatistics::TSAlign(alignment_statistics) => {
                alignment_statistics.reference_offset
            }
        }
    }

    pub fn query_offset(&self) -> usize {
        match self {
            TwitcherAlignmentStatistics::FPAStats(fpa_alignment_statistics) => {
                fpa_alignment_statistics.ranges.query_offset()
            }
            TwitcherAlignmentStatistics::TSAlign(alignment_statistics) => {
                alignment_statistics.query_offset
            }
        }
    }
}

impl AlignmentFailure {
    /// Soft failure: the aligner ran out of memory.
    pub fn oom() -> Self {
        Self::SoftFailure {
            reason: SoftFailureReason::OutOfMemory,
        }
    }

    /// Soft failure recording a timeout; `duration` is stored in [`SoftFailureReason::Timeout`].
    pub fn timeout(duration: Duration) -> Self {
        Self::SoftFailure {
            reason: SoftFailureReason::Timeout(duration),
        }
    }

    /// Soft failure whose reason is the `to_string()` of `error`.
    ///
    /// Accepts unsized values such as `str`, matching [`AlignmentFailure::error`]. Both
    /// `soft_fail("msg")` and `soft_fail(&err)` therefore work.
    pub fn soft_fail<S: ToString + ?Sized>(error: &S) -> Self {
        Self::SoftFailure {
            reason: SoftFailureReason::Other(error.to_string()),
        }
    }

    /// Unexpected failure whose message is the `to_string()` of `error`.
    pub fn error<S: ToString + ?Sized>(error: &S) -> Self {
        Self::Error {
            error: error.to_string(),
        }
    }
}

/// Converts the raw output of lib_tsalign into a [`TwitcherAlignmentResult`].
///
/// The `reference`, `reference_rc`, `query` and `query_rc` strings in the statistics are
/// blanked first. Twitcher handles sequences as shared references (`Arc<[u8]>`) elsewhere, so
/// these copies are not needed.
///
/// If tsalign produced no alignment, the A* outcome is turned into a failure. Exceeding the
/// memory limit becomes [`AlignmentFailure::oom`]. Every other outcome becomes a soft failure
/// whose message describes it.
pub fn from_tsalign(
    tsalign_result: AlignmentResult<AlignmentType, U64Cost>,
) -> TwitcherAlignmentResult {
    let (maybe_alignment, mut statistics) = match tsalign_result {
        AlignmentResult::WithTarget {
            alignment,
            statistics,
        } => (Some(alignment), statistics),
        AlignmentResult::WithoutTarget { statistics } => (None, statistics),
    };

    // Drop the sequence copies held by the statistics (see doc comment above).
    let sequences = &mut statistics.sequences;
    for sequence in [
        &mut sequences.reference,
        &mut sequences.reference_rc,
        &mut sequences.query,
        &mut sequences.query_rc,
    ] {
        *sequence = String::new();
    }

    let Some(alignment) = maybe_alignment else {
        let reason = match statistics.result {
            AStarResult::FoundTarget { .. } => {
                String::from("There should be an alignment available, but it is not reported")
            }
            AStarResult::ExceededCostLimit { cost_limit } => {
                format!("Exceeded cost limit. Lower bound for the cost: {cost_limit}")
            }
            AStarResult::ExceededMemoryLimit { .. } => return Err(AlignmentFailure::oom()),
            AStarResult::NoTarget => {
                String::from("There was no target (implementation error?)")
            }
        };
        return Err(AlignmentFailure::soft_fail(&reason));
    };

    let cost = statistics.result.cost();
    Ok(TwitcherAlignmentWithStatistics::new(
        alignment,
        cost,
        statistics.into(),
    ))
}

impl From<FpaAlignmentStatistics> for TwitcherAlignmentStatistics {
    fn from(value: FpaAlignmentStatistics) -> Self {
        Self::FPAStats(value.into())
    }
}

impl From<AlignmentStatistics<U64Cost>> for TwitcherAlignmentStatistics {
    fn from(value: AlignmentStatistics<U64Cost>) -> Self {
        Self::TSAlign(value.into())
    }
}

/// Converts an [`anyhow::Error`] into an unexpected [`AlignmentFailure::Error`].
///
/// The alternate (`{:#}`) formatting keeps the whole context chain, e.g.
/// `"outer context: underlying cause"`. Plain `to_string()` would keep only the outermost
/// message and silently drop the root cause.
impl From<anyhow::Error> for AlignmentFailure {
    fn from(value: anyhow::Error) -> Self {
        Self::Error {
            error: format!("{value:#}"),
        }
    }
}

impl From<std::io::Error> for AlignmentFailure {
    fn from(value: std::io::Error) -> Self {
        Self::Error {
            error: value.to_string(),
        }
    }
}

impl Display for AlignmentFailure {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::OutOfMemory,
            } => write!(f, "Expected failure: Out of memory."),
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::Timeout(dur),
            } => write!(f, "Expected failure: Timeout ({dur:?})."),
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::Other(e),
            } => write!(f, "Expected failure: {e}."),
            AlignmentFailure::Error { error } => write!(f, "Unexpected failure: {error}"),
        }
    }
}

impl Error for AlignmentFailure {}

/// Summary of one template switch, as extracted from an alignment by [`TSData::compute`].
#[derive(Debug)]
pub struct TSData {
    /// Summed length of the secondary match, substitution and insertion operations inside the
    /// template switch. Secondary deletions are not counted.
    pub inner_len: usize,
    /// The `first_offset` of the template switch entrance.
    pub jump_1_2: isize,
    /// The `equal_cost_range` of the template switch entrance.
    pub er: EqualCostRange,
    /// The secondary operations inside the template switch, in alignment order.
    pub inner_aln: Alignment<AlignmentType>,
    /// The `anti_primary_gap` of the template switch exit.
    pub apg: isize,
    /// Genome position after the template switch entrance operation.
    pub pos_1: GenomePosition,
    /// Genome position after the template switch exit operation.
    pub pos_4: GenomePosition,
}

impl TSData {
    pub fn compute(
        region: &GenomeRegion,
        result: &TwitcherAlignmentWithStatistics,
    ) -> anyhow::Result<Vec<TSData>> {
        let alignment = &result.alignment.alignment;

        let mut pos = region.start().clone();

        let mut results = Vec::new();
        let mut curr_data = None;
        let mut curr_ts_primary = None;
        for (n, ty) in alignment.iter_compact() {
            pos += consumed_reference(n, ty, curr_ts_primary)?;
            match ty {
                AlignmentType::TemplateSwitchEntrance {
                    first_offset,
                    equal_cost_range,
                    primary,
                    ..
                } => {
                    curr_data = Some(TSData {
                        inner_len: 0, // will change
                        jump_1_2: *first_offset,
                        er: *equal_cost_range,
                        inner_aln: Alignment::new(), // will change
                        apg: 0,                      // will change
                        pos_1: pos.clone(),
                        pos_4: pos.clone(), // Will change
                    });
                    curr_ts_primary = Some(*primary);
                }
                AlignmentType::TemplateSwitchExit { anti_primary_gap } => {
                    let Some(mut current) = curr_data.take() else {
                        panic!("Invalid alignment");
                    };
                    current.apg = *anti_primary_gap;
                    current.pos_4 = pos.clone();
                    results.push(current);
                    curr_ts_primary = None;
                }
                AlignmentType::SecondaryInsertion
                | AlignmentType::SecondarySubstitution
                | AlignmentType::SecondaryMatch => {
                    let Some(current) = &mut curr_data else {
                        panic!("Invalid alignment");
                    };
                    current.inner_len += n;
                    current.inner_aln.push_n(n, *ty);
                }
                AlignmentType::SecondaryDeletion => {
                    let Some(current) = &mut curr_data else {
                        panic!("Invalid alignment");
                    };
                    current.inner_aln.push_n(n, *ty);
                }
                _ => {}
            }
        }
        Ok(results)
    }

    pub fn to_field<S: ToString>(data: &[Self], sep: &str, f: impl Fn(&Self) -> S) -> String {
        data.iter().map(|d| f(d).to_string()).join(sep)
    }
}