twitcher 0.2.1

Find template switch mutations in genomic data
use anyhow::Context as _;
use clap::ValueEnum;
use compact_genome::implementation::alphabets::dna_alphabet_or_n::DnaAlphabetOrN;
use lib_tsalign::a_star_aligner::configurable_a_star_align::{
    DescendantStrategySelector, MinLengthStrategySelector,
};
use lib_tsalign::costs::U64Cost;
use lib_tsalign::{
    a_star_aligner::configurable_a_star_align::Aligner, config::TemplateSwitchConfig,
};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{fs::File, io::BufReader, time::Duration};
use tokio::sync::Semaphore;

use bytesize::ByteSize;

use crate::common::aligner::db::CliDatabaseArgs;
use crate::common::aligner::{AlignerSelector, fpa::FourPointAligner};

// Command-line options controlling how alignments are run: parallelism, resource limits,
// the choice of aligner implementation and its cost model.
#[derive(clap::Args, Debug)]
pub struct CliAlignmentArgs {
    /// Number of threads to use for the alignments (defaults to the available parallelism)
    #[arg(long = "threads", value_name = "NUM")]
    pub num_threads: Option<usize>,

    /// Set an overall memory limit, for example 4g or 500m (defaults to the currently available memory).
    #[arg(short, long, global = true, value_name = "EXPR", value_parser = ByteSize::from_str)]
    pub memory_limit: Option<ByteSize>,

    /// Set a timeout for the aligner, using one of the units ms, s, m or h (for example 500ms, 30s, 5m or 1h).
    #[arg(long = "aligner-timeout", value_name = "EXPR", value_parser = parse_duration)]
    pub aligner_timeout: Option<Duration>,

    /// How much padding should be considered in the alignment around a mutation cluster. This is the window into which a template switch may "jump" (e.g. 100 for +- 100 bp before and after a mutation cluster)
    #[arg(short, long, default_value_t = 200, value_name = "BASES")]
    pub padding: usize,

    /// A config file for the costs (in custom .tsa format), this is used to initialize the aligner.
    #[arg(short, long, value_name = "FILE")]
    pub costs: Option<String>,

    /// Force the aligner to not align with template switches. Useful for performance benchmarks, useless otherwise.
    #[arg(long)]
    pub no_ts: bool,

    /// Use the Four-Point-Aligner, based on dynamic programming. While its capabilities are limited, it is generally faster.
    /// It cannot, however, find alignments that have indels in the inner part of the template switch or that have more than one template switch.
    /// A best effort is made to keep the costs consistent with the `--costs` argument, with one exception:
    /// Since this aligner is not gap-affine, there is no gap open cost. Only the cost for extending a gap is taken into account.
    #[arg(long = "fpa")]
    pub use_fpa: bool,

    /// tsalign only: Select a strategy to determine the minimum inner length of a template switch during the alignment. See the aligner repo for details.
    #[arg(long, value_name = "ML_STRATEGY", conflicts_with = "use_fpa")]
    pub min_length_strategy: Option<MLSSelector>,

    /// tsalign only: Allow multiple template switches in the same solution to have different descendants. By default, all template switches in the solution must have the same descendant, but it can be either of the sequences (ref / alt).
    #[arg(long, conflicts_with = "use_fpa")]
    pub allow_mixed_descendants: bool,

    #[clap(flatten)]
    pub database: CliDatabaseArgs,
}

/// Parses a human-readable duration such as `500ms`, `30s`, `5m` or `1.5h`.
///
/// The value is a non-negative decimal number, optionally with a fractional part, followed by
/// one of the units `ms`, `s`, `m` (minutes) or `h`. Whitespace around the whole string and
/// between the number and the unit is ignored. Fractional values are converted with exact
/// integer arithmetic and truncated to nanosecond precision.
///
/// # Errors
/// Fails on an empty string, a missing or unknown unit, a malformed number, or a duration
/// too large to be represented by [`Duration`].
fn parse_duration(s: &str) -> anyhow::Result<Duration> {
    const NANOS_PER_SEC: u128 = 1_000_000_000;
    // Fraction digits beyond this are below one nanosecond even for hours
    // (3.6e12 ns * 1e-18 < 1 ns); capping them keeps all arithmetic well within `u128`.
    const MAX_FRACTION_DIGITS: usize = 18;

    let s = s.trim();
    if s.is_empty() {
        anyhow::bail!("empty duration string");
    }

    let ix = s
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .ok_or_else(|| anyhow::anyhow!("no unit detected"))?;
    // `find` returns the start index of a char, so this split never lands inside a UTF-8 sequence.
    let (number, unit) = s.split_at(ix);
    let unit = unit.trim();

    let nanos_per_unit: u128 = match unit {
        "ms" => NANOS_PER_SEC / 1_000,
        "s" => NANOS_PER_SEC,
        "m" => 60 * NANOS_PER_SEC,
        "h" => 60 * 60 * NANOS_PER_SEC,
        _ => anyhow::bail!("unknown duration unit: {unit}"),
    };

    // `number` consists only of ASCII digits and '.'; allow at most one decimal point.
    let (whole, fraction) = number.split_once('.').unwrap_or((number, ""));
    if (whole.is_empty() && fraction.is_empty()) || fraction.contains('.') {
        anyhow::bail!("invalid duration value: {number:?}");
    }

    let whole: u128 = if whole.is_empty() { 0 } else { whole.parse()? };
    let fraction = &fraction[..fraction.len().min(MAX_FRACTION_DIGITS)];
    let fraction_nanos = if fraction.is_empty() {
        0
    } else {
        let digits: u128 = fraction.parse()?;
        // `fraction.len() <= 18`, so the exponent cast is lossless and the product cannot overflow.
        digits * nanos_per_unit / 10u128.pow(fraction.len() as u32)
    };

    // Checked arithmetic: large inputs must produce an error, not a panic or a wrapped value.
    let total_nanos = whole
        .checked_mul(nanos_per_unit)
        .and_then(|n| n.checked_add(fraction_nanos))
        .context("duration is too large")?;
    let secs = u64::try_from(total_nanos / NANOS_PER_SEC).context("duration is too large")?;
    // The remainder is below 1e9 and therefore always fits into `u32`.
    let nanos = (total_nanos % NANOS_PER_SEC) as u32;
    Ok(Duration::new(secs, nanos))
}

impl Default for CliAlignmentArgs {
    fn default() -> Self {
        Self {
            num_threads: None,
            memory_limit: None,
            aligner_timeout: None,
            padding: 200,
            costs: None,
            no_ts: false,
            use_fpa: false,
            min_length_strategy: None,
            allow_mixed_descendants: false,
            database: CliDatabaseArgs::default(),
        }
    }
}

/// Parses the cost configuration that is embedded into the binary at compile time.
///
/// # Panics
/// Panics if the embedded `default_costs.tsa` cannot be parsed. That indicates a broken
/// build rather than a user error, and the `parse_default_costs` test guards against it.
fn default_costs() -> TemplateSwitchConfig<DnaAlphabetOrN, U64Cost> {
    const EMBEDDED: &[u8] = include_bytes!("../../../default_costs.tsa");
    TemplateSwitchConfig::read_plain(EMBEDDED).expect("unable to parse the default cost file")
}

impl CliAlignmentArgs {
    pub(super) fn init_aligner(
        &self,
    ) -> anyhow::Result<(
        AlignerSelector,
        TemplateSwitchConfig<DnaAlphabetOrN, U64Cost>,
    )> {
        let costs = if let Some(cost_path) = &self.costs {
            tracing::info!("Loading costs from {cost_path}");
            let file = BufReader::new(File::open(cost_path)?);
            TemplateSwitchConfig::read_plain(file)?
        } else {
            tracing::info!("Using default costs");
            default_costs()
        };

        if self.use_fpa {
            let aligner = FourPointAligner::new(costs.clone(), self.no_ts);
            Ok((super::AlignerSelector::Fpa(aligner), costs))
        } else {
            let init = || {
                let mut a = Aligner::new();
                a.set_costs(costs.clone());
                a.set_min_length_strategy(self.min_length_strategy.unwrap_or_default().into());
                a.set_descendant_strategy(if self.allow_mixed_descendants {
                    DescendantStrategySelector::AllowAny
                } else {
                    DescendantStrategySelector::AllowOnlyAllEqual
                });
                a.set_no_ts(self.no_ts);
                a
            };
            let ts = init();
            Ok((AlignerSelector::AStar(ts), costs))
        }
    }

    pub fn init_semaphore(&self) -> anyhow::Result<(tokio::sync::Semaphore, usize)> {
        let n = if let Some(num) = self.num_threads {
            num.try_into()?
        } else {
            std::thread::available_parallelism()?
        };
        let sem = Semaphore::new(n.into());
        let total_memory = usize::try_from(
            self.memory_limit
                .unwrap_or_else(default_memory_limit)
                .as_u64(),
        )
        .context("The memory limit is too high")?;
        let memory_per_thread = total_memory / n;
        Ok((sem, memory_per_thread))
    }
}

// Command-line counterpart of lib_tsalign's `MinLengthStrategySelector` (see the `From` impl
// below). It can be parsed by clap and (de)serialized with serde. Plain `//` comments are used
// on purpose so that no extra help text is attached to the individual values.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, ValueEnum)]
pub enum MLSSelector {
    None,
    Lookahead,
    // Chosen when `--min-length-strategy` is not given (via `unwrap_or_default`).
    #[default]
    PreprocessFilter,
    PreprocessPrice,
    PreprocessLookahead,
}

impl From<MLSSelector> for MinLengthStrategySelector {
    fn from(value: MLSSelector) -> Self {
        match value {
            MLSSelector::None => Self::None,
            MLSSelector::Lookahead => Self::Lookahead,
            MLSSelector::PreprocessFilter => Self::PreprocessFilter,
            MLSSelector::PreprocessPrice => Self::PreprocessPrice,
            MLSSelector::PreprocessLookahead => Self::PreprocessLookahead,
        }
    }
}

/// Returns the amount of memory the operating system currently reports as available.
///
/// `init_semaphore` uses this as the overall memory limit when `--memory-limit` is not given.
fn default_memory_limit() -> ByteSize {
    // Only memory statistics are needed. `System::new_all()` would also load CPU and
    // process information, which is comparatively expensive and immediately discarded.
    let mut system = sysinfo::System::new();
    system.refresh_memory();
    ByteSize::b(system.available_memory())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, failing the test with the offending input and the error on failure.
    fn parsed(input: &str) -> Duration {
        match parse_duration(input) {
            Ok(duration) => duration,
            Err(err) => panic!("failed to parse {input:?}: {err}"),
        }
    }

    /// Asserts that `input` is rejected by `parse_duration`.
    fn rejected(input: &str) {
        assert!(
            parse_duration(input).is_err(),
            "expected {input:?} to be rejected"
        );
    }

    #[test]
    fn parse_duration_milliseconds() {
        assert_eq!(parsed("100ms"), Duration::from_millis(100));
    }

    #[test]
    fn parse_duration_seconds() {
        assert_eq!(parsed("5s"), Duration::from_secs(5));
    }

    #[test]
    fn parse_duration_minutes() {
        assert_eq!(parsed("2m"), Duration::from_secs(2 * 60));
    }

    #[test]
    fn parse_duration_hours() {
        assert_eq!(parsed("1h"), Duration::from_secs(60 * 60));
    }

    #[test]
    fn parse_duration_empty_is_error() {
        rejected("");
    }

    #[test]
    fn parse_duration_no_unit_is_error() {
        rejected("100");
    }

    #[test]
    fn parse_duration_unknown_unit_is_error() {
        rejected("10d");
    }

    #[test]
    fn parse_default_costs() {
        // `default_costs` panics if the embedded cost file is malformed.
        let _costs = default_costs();
    }
}