twitcher 0.1.8

Find template switch mutations in genomic data
use std::{error::Error, fmt::Display, time::Duration};

use generic_a_star::{AStarResult, cost::AStarCost};

use itertools::Itertools as _;
use lib_tsalign::{
    a_star_aligner::{
        alignment_result::{AlignmentResult, AlignmentStatistics, alignment::Alignment},
        template_switch_distance::{AlignmentType, EqualCostRange},
    },
    costs::U64Cost,
};
use serde::{Deserialize, Serialize};

use crate::common::aligner::fpa::FpaAlignmentStatistics;

/// The result of an alignment. This is the output of the whole `aligner` module.
///
/// `Ok` carries a [`TwitcherAlignment`]; `Err` carries an [`AlignmentFailure`].
pub type TwitcherAlignmentResult = Result<TwitcherAlignment, AlignmentFailure>;

/// The result of a call to the aligner which completed successfully.
///
/// Built through [`TwitcherAlignment::found_ts`] or [`TwitcherAlignment::no_ts`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TwitcherAlignment {
    /// Which of the two outcomes was reached, together with its alignment and cost(s).
    pub result: TwitcherAlignmentCase,
    /// Statistics reported by the aligner backend that produced `result`.
    pub stats: TwitcherAlignmentStatistics,
}

/// Statistics that are returned by the aligner.
///
/// One variant per statistics type; both payloads are boxed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum TwitcherAlignmentStatistics {
    /// Statistics of type `FpaAlignmentStatistics` (from `crate::common::aligner::fpa`).
    FPAStats(Box<FpaAlignmentStatistics>),
    /// Statistics of type `AlignmentStatistics<U64Cost>` (from `lib_tsalign`).
    TSAlign(Box<AlignmentStatistics<U64Cost>>),
}

/// The two outcomes of a realignment: Either we found a template switch, or we didn't.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum TwitcherAlignmentCase {
    /// A template switch was found.
    FoundTS {
        /// The alignment that contains the template switch(es).
        alignment_with_ts: Alignment<AlignmentType>,
        /// Cost of `alignment_with_ts`.
        cost_with_ts: U64Cost,
        /// Cost of the alternative without template switches, if one was computed.
        /// `from_tsalign` stores `u64::MAX` here when the alternative run had no target.
        cost_without_ts: Option<U64Cost>,
    },
    /// No template switch was found.
    NoTS {
        // NOTE(review): the reason for this `allow` is not visible in this file;
        // presumably the field is only written, never read — confirm.
        #[allow(dead_code)]
        /// The alignment that contains no template switch.
        alignment_without_ts: Alignment<AlignmentType>,
        /// Cost of `alignment_without_ts`.
        cost_without_ts: U64Cost,
    },
}
// NOTE(review): empty inherent impl; it adds no items and is a candidate for removal.
impl TwitcherAlignmentCase {}

/// The two ways an aligner can fail: expectedly or unexpectedly.
///
/// The `Display` impl renders `SoftFailure` as "Expected failure: …" and
/// `Error` as "Unexpected failure: …".
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum AlignmentFailure {
    /// An expected failure, classified by `reason`.
    SoftFailure { reason: SoftFailureReason },
    /// An unexpected failure; `error` holds the stringified underlying error.
    Error { error: String },
}

/// Why an aligner call ended in an [`AlignmentFailure::SoftFailure`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum SoftFailureReason {
    /// The aligner exceeded its memory limit.
    OutOfMemory,
    /// The aligner timed out; carries the `Duration` passed to `AlignmentFailure::timeout`.
    Timeout(Duration),
    /// Any other soft failure, described by a free-form message.
    Other(String),
}

impl TwitcherAlignment {
    /// Builds a successful result for the case that a template switch was found.
    ///
    /// `alignment` and `cost_with_ts` describe the alignment containing the template
    /// switch; `cost_without_ts` is the cost of the alternative without one, if known.
    pub fn found_ts(
        alignment: Alignment<AlignmentType>,
        cost_with_ts: U64Cost,
        cost_without_ts: Option<U64Cost>,
        stats: TwitcherAlignmentStatistics,
    ) -> Self {
        let result = TwitcherAlignmentCase::FoundTS {
            alignment_with_ts: alignment,
            cost_with_ts,
            cost_without_ts,
        };
        Self { result, stats }
    }

    /// Builds a successful result for the case that no template switch was found.
    pub fn no_ts(
        alignment: Alignment<AlignmentType>,
        cost_without_ts: U64Cost,
        stats: TwitcherAlignmentStatistics,
    ) -> Self {
        let result = TwitcherAlignmentCase::NoTS {
            alignment_without_ts: alignment,
            cost_without_ts,
        };
        Self { result, stats }
    }

    /// Returns `true` iff this result is the [`TwitcherAlignmentCase::FoundTS`] case.
    pub fn has_ts(&self) -> bool {
        match self.result {
            TwitcherAlignmentCase::FoundTS { .. } => true,
            TwitcherAlignmentCase::NoTS { .. } => false,
        }
    }
}

impl TwitcherAlignmentStatistics {
    pub fn reference_offset(&self) -> usize {
        match self {
            TwitcherAlignmentStatistics::FPAStats(fpa_alignment_statistics) => {
                fpa_alignment_statistics.ranges.reference_offset()
            }
            TwitcherAlignmentStatistics::TSAlign(alignment_statistics) => {
                alignment_statistics.reference_offset
            }
        }
    }

    pub fn query_offset(&self) -> usize {
        match self {
            TwitcherAlignmentStatistics::FPAStats(fpa_alignment_statistics) => {
                fpa_alignment_statistics.ranges.query_offset()
            }
            TwitcherAlignmentStatistics::TSAlign(alignment_statistics) => {
                alignment_statistics.query_offset
            }
        }
    }
}

impl AlignmentFailure {
    /// Soft failure: the aligner ran out of memory.
    pub fn oom() -> Self {
        AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::OutOfMemory,
        }
    }

    /// Soft failure: the aligner timed out. `duration` is stored as given.
    pub fn timeout(duration: Duration) -> Self {
        AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::Timeout(duration),
        }
    }

    /// Soft failure with a free-form reason, taken from `error.to_string()`.
    ///
    /// `S` may be unsized, so a plain `&str` is accepted just like in
    /// [`AlignmentFailure::error`].
    pub fn soft_fail<S: ToString + ?Sized>(error: &S) -> Self {
        AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::Other(error.to_string()),
        }
    }

    /// Unexpected failure, with the message taken from `error.to_string()`.
    pub fn error<S: ToString + ?Sized>(error: &S) -> Self {
        AlignmentFailure::Error {
            error: error.to_string(),
        }
    }
}

pub fn from_tsalign(
    with_ts: AlignmentResult<AlignmentType, U64Cost>,
    without_ts: Option<AlignmentResult<AlignmentType, U64Cost>>,
) -> TwitcherAlignmentResult {
    let (ts_alignment_cigar, mut statistics) = match with_ts {
        AlignmentResult::WithTarget {
            alignment,
            statistics,
        } => (Some(alignment), statistics),
        AlignmentResult::WithoutTarget { statistics } => (None, statistics),
    };

    // Remove the sequences from the statistics, since we handle sequences throughout twitcher as shared references (Arc<[u8]>) elsewhere.
    statistics.sequences.reference = String::new();
    statistics.sequences.reference_rc = String::new();
    statistics.sequences.query = String::new();
    statistics.sequences.query_rc = String::new();

    if let Some(ts_alignment_cigar) = ts_alignment_cigar {
        let cost = statistics.result.cost();
        if statistics.template_switch_amount.const_raw() > 0.0 {
            // We actually have a TS, report that.
            let cost_without_ts = without_ts.map(|wt| match wt {
                wt @ AlignmentResult::WithTarget { .. } => wt.statistics().result.cost(),
                AlignmentResult::WithoutTarget { .. } => U64Cost::from_primitive(u64::MAX),
            });
            Ok(TwitcherAlignment::found_ts(
                ts_alignment_cigar,
                cost,
                cost_without_ts,
                statistics.into(),
            ))
        } else {
            Ok(TwitcherAlignment::no_ts(
                ts_alignment_cigar,
                cost,
                statistics.into(),
            ))
        }
    } else {
        let reason = match statistics.result {
            AStarResult::FoundTarget { .. } => {
                anyhow::anyhow!("There should be an alignment available, but it is not reported")
            }
            AStarResult::ExceededCostLimit { cost_limit } => {
                anyhow::anyhow!("Exceeded cost limit. Lower bound for the cost: {cost_limit}")
            }
            AStarResult::ExceededMemoryLimit { .. } => {
                return Err(AlignmentFailure::oom());
            }
            AStarResult::NoTarget => {
                anyhow::anyhow!("There was no target (implementation error?)")
            }
        };
        Err(AlignmentFailure::soft_fail(&reason))
    }
}

impl From<FpaAlignmentStatistics> for TwitcherAlignmentStatistics {
    fn from(value: FpaAlignmentStatistics) -> Self {
        Self::FPAStats(value.into())
    }
}

impl From<AlignmentStatistics<U64Cost>> for TwitcherAlignmentStatistics {
    fn from(value: AlignmentStatistics<U64Cost>) -> Self {
        Self::TSAlign(value.into())
    }
}

impl From<anyhow::Error> for AlignmentFailure {
    fn from(value: anyhow::Error) -> Self {
        Self::Error {
            error: value.to_string(),
        }
    }
}

impl From<std::io::Error> for AlignmentFailure {
    fn from(value: std::io::Error) -> Self {
        Self::Error {
            error: value.to_string(),
        }
    }
}

impl Display for AlignmentFailure {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::OutOfMemory,
            } => write!(f, "Expected failure: Out of memory."),
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::Timeout(dur),
            } => write!(f, "Expected failure: Timeout ({dur:?})."),
            AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::Other(e),
            } => write!(f, "Expected failure: {e}."),
            AlignmentFailure::Error { error } => write!(f, "Unexpected failure: {error}"),
        }
    }
}

// Marker impl: no `Error` methods are overridden, so the trait's defaults apply.
impl Error for AlignmentFailure {}

/// Data about one template switch, extracted from an alignment by [`TSData::compute`].
#[derive(Debug)]
pub struct TSData {
    /// Summed run lengths of the secondary insertions, substitutions and matches
    /// between the entrance and the exit (secondary deletions are not counted).
    pub inner_len: usize,
    /// The `first_offset` of the template switch entrance.
    pub jump_1_2: isize,
    /// The `equal_cost_range` of the template switch entrance.
    pub er: EqualCostRange,
    /// All secondary operations (including deletions) between the entrance and the exit.
    pub inner_aln: Alignment<AlignmentType>,
    /// The `anti_primary_gap` of the template switch exit.
    pub apg: isize,
}

impl TSData {
    /// Extracts one [`TSData`] per completed template switch from `result`.
    ///
    /// Returns an empty vector when `result` is not the `FoundTS` case. A template
    /// switch is only emitted once its exit has been seen.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid alignment" if an exit or a secondary operation occurs
    /// without a preceding, still-open entrance.
    pub fn compute(result: &TwitcherAlignment) -> Vec<TSData> {
        let alignment = match &result.result {
            TwitcherAlignmentCase::FoundTS {
                alignment_with_ts, ..
            } => alignment_with_ts,
            TwitcherAlignmentCase::NoTS { .. } => return Vec::new(),
        };

        let mut finished = Vec::new();
        // The template switch that has been entered but not yet exited, if any.
        let mut open: Option<TSData> = None;

        for (count, op) in alignment.iter_compact() {
            match op {
                AlignmentType::TemplateSwitchEntrance {
                    first_offset,
                    equal_cost_range,
                    ..
                } => {
                    // Replaces any still-open template switch.
                    open = Some(TSData {
                        inner_len: 0,
                        jump_1_2: *first_offset,
                        er: *equal_cost_range,
                        inner_aln: Alignment::new(),
                        apg: 0,
                    });
                }
                AlignmentType::TemplateSwitchExit { anti_primary_gap } => match open.take() {
                    Some(mut ts) => {
                        ts.apg = *anti_primary_gap;
                        finished.push(ts);
                    }
                    None => panic!("Invalid alignment"),
                },
                AlignmentType::SecondaryInsertion
                | AlignmentType::SecondarySubstitution
                | AlignmentType::SecondaryMatch
                | AlignmentType::SecondaryDeletion => {
                    let ts = match open.as_mut() {
                        Some(ts) => ts,
                        None => panic!("Invalid alignment"),
                    };
                    // Deletions are recorded in the inner alignment but do not add to
                    // the inner length.
                    if !matches!(op, AlignmentType::SecondaryDeletion) {
                        ts.inner_len += count;
                    }
                    ts.inner_aln.push_n(count, *op);
                }
                _ => {}
            }
        }

        finished
    }

    /// Renders one value per entry of `data`, obtained through `f`, stringified and
    /// joined with `sep`.
    pub fn to_field<S: ToString>(data: &[Self], sep: &str, f: impl Fn(&Self) -> S) -> String {
        let mut rendered = data.iter().map(|entry| {
            let value = f(entry);
            value.to_string()
        });
        rendered.join(sep)
    }
}