twitcher 0.1.8

Find template switch mutations in genomic data
use anyhow::Context as _;
use clap::ValueEnum;
use lib_tsalign::a_star_aligner::configurable_a_star_align::MinLengthStrategySelector;
use lib_tsalign::{
    a_star_aligner::configurable_a_star_align::Aligner, config::TemplateSwitchConfig,
};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{fs::File, io::BufReader, time::Duration};
use tokio::sync::Semaphore;

use bytesize::ByteSize;

use crate::common::aligner::db::CliDatabaseArgs;
use crate::common::aligner::{AStarAlignerPair, AlignerSelector, fpa::FourPointAligner};

#[derive(clap::Args, Debug)]
pub struct CliAlignmentArgs {
    /// Number of threads to use for the alignments
    #[arg(long = "threads", value_name = "NUM")]
    pub num_threads: Option<usize>,

    /// Set an overall memory limit, for example 4g or 500m.
    #[arg(short, long, global = true, value_name = "EXPR", value_parser = ByteSize::from_str)]
    pub memory_limit: Option<ByteSize>,

    /// Set an overall memory limit, for example 4g or 500m.
    #[arg(long = "aligner-timeout", value_name = "EXPR", value_parser = parse_duration)]
    pub aligner_timeout: Option<Duration>,

    /// How much padding should be considered in the alignment around a mutation cluster. This is the window to where a template switch may "jump" (e.g. 100 for +- 100 bps before and after a mutation cluster)
    #[arg(short, long, default_value_t = 200, value_name = "BASES")]
    pub padding: usize,

    /// A config file for the costs (in custom .tsa format), this is used to initialize the aligner.
    #[arg(short, long, value_name = "FILE")]
    pub costs: Option<String>,

    /// Force the aligner to not align with template switches. Useful for performance benchmarks, useless otherwise.
    #[arg(long)]
    pub no_ts: bool,

    /// Use the Four-Point-Aligner, based on dynamic programming. While it is limited capabilities, it is generally faster.
    /// It cannot, however, find alignments that have indels in the inner part of the template switch or that have more than one template switch.
    /// A best effort is made to keep the costs consistent with the `--cost` argument, with one exception:
    /// Since this aligner is not gap-affine, there is no gap open cost. Only the cost for extending a gap is taken into account.
    #[arg(long = "fpa")]
    pub use_fpa: bool,

    /// Select a strategy to determine the minimum inner length of a template switch during the alignment. See the aligner repo for details.
    #[arg(long, value_name = "ML_STRATEGY")]
    pub min_length_strategy: Option<MLSSelector>,

    #[clap(flatten)]
    pub database: CliDatabaseArgs,
}

fn parse_duration(s: &str) -> anyhow::Result<Duration> {
    if s.is_empty() {
        anyhow::bail!("empty duration string");
    }

    let ix = s
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .ok_or_else(|| anyhow::anyhow!("no unit detected"))?;
    let (number, unit) = s.split_at_checked(ix).context("non-utf8 duration string")?;
    let value: u64 = number.trim().parse()?;

    let dur = match unit.trim() {
        "ms" => Duration::from_millis(value),
        "s" => Duration::from_secs(value),
        "m" => Duration::from_secs(value * 60),
        "h" => Duration::from_secs(value * 60 * 60),
        // This only works on Rust 1.91.0 and up
        // "m" => Duration::from_mins(value),
        // "h" => Duration::from_hours(value),
        _ => anyhow::bail!("unknown duration unit: {unit}"),
    };

    Ok(dur)
}

impl Default for CliAlignmentArgs {
    fn default() -> Self {
        Self {
            num_threads: None,
            memory_limit: None,
            aligner_timeout: None,
            padding: 200,
            costs: None,
            no_ts: false,
            use_fpa: false,
            min_length_strategy: None,
            database: CliDatabaseArgs::default(),
        }
    }
}

impl CliAlignmentArgs {
    pub(super) fn init_aligner(&self) -> anyhow::Result<AlignerSelector> {
        let costs = if let Some(cost_path) = &self.costs {
            tracing::info!("Loading costs from {cost_path}");
            let file = BufReader::new(File::open(cost_path)?);
            Some(TemplateSwitchConfig::read_plain(file)?)
        } else {
            tracing::info!("Using default costs");
            None
        };

        if self.use_fpa {
            let aligner = FourPointAligner::new(costs.unwrap_or_default(), self.no_ts);
            Ok(super::AlignerSelector::Fpa(aligner))
        } else {
            let init = || {
                let mut a = Aligner::new();
                a.set_costs(costs.clone().unwrap_or_default());
                a.set_min_length_strategy(self.min_length_strategy.unwrap_or_default().into());
                a.set_no_ts(self.no_ts);
                a
            };
            let ts = init();
            let mut no_ts = init();
            no_ts.set_no_ts(true);
            Ok(AlignerSelector::AStar(AStarAlignerPair { ts, no_ts }))
        }
    }

    pub fn init_semaphore(&self) -> anyhow::Result<(tokio::sync::Semaphore, usize)> {
        let n = if let Some(num) = self.num_threads {
            num.try_into()?
        } else {
            std::thread::available_parallelism()?
        };
        let sem = Semaphore::new(n.into());
        let total_memory = usize::try_from(
            self.memory_limit
                .unwrap_or_else(default_memory_limit)
                .as_u64(),
        )
        .context("The memory limit is too high")?;
        let memory_per_thread = total_memory / n;
        Ok((sem, memory_per_thread))
    }
}

// Command-line choice of the strategy used by the A* aligner to determine the minimum inner
// length of a template switch (the `min_length_strategy` field of `CliAlignmentArgs`).
// Each variant maps one-to-one onto `lib_tsalign`'s `MinLengthStrategySelector` through the
// `From` impl that follows. When the flag is absent, `init_aligner` falls back to the
// `#[default]` variant, `PreprocessFilter`.
//
// NOTE(review): plain `//` comments are used here on purpose; `///` docs on `ValueEnum`
// variants are presumably picked up by clap as possible-value help text — confirm before
// converting them.
#[derive(Default, ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MLSSelector {
    None,
    Lookahead,
    #[default]
    PreprocessFilter,
    PreprocessPrice,
    PreprocessLookahead,
}

impl From<MLSSelector> for MinLengthStrategySelector {
    /// Translates the CLI-facing selector into the aligner library's selector; every
    /// variant has a same-named counterpart.
    fn from(selector: MLSSelector) -> Self {
        match selector {
            MLSSelector::None => Self::None,
            MLSSelector::Lookahead => Self::Lookahead,
            MLSSelector::PreprocessFilter => Self::PreprocessFilter,
            MLSSelector::PreprocessPrice => Self::PreprocessPrice,
            MLSSelector::PreprocessLookahead => Self::PreprocessLookahead,
        }
    }
}

/// Returns the memory the OS currently reports as available, used as the overall memory
/// limit when `--memory-limit` is not given.
///
/// Only memory information is loaded. `System::new_all` would additionally gather process
/// and CPU information (and more, depending on the sysinfo version) that is never read here.
fn default_memory_limit() -> ByteSize {
    let mut system = sysinfo::System::new();
    system.refresh_memory();
    ByteSize::b(system.available_memory())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_duration_milliseconds() {
        assert_eq!(parse_duration("100ms").unwrap(), Duration::from_millis(100));
    }

    #[test]
    fn parse_duration_seconds() {
        assert_eq!(parse_duration("5s").unwrap(), Duration::from_secs(5));
    }

    #[test]
    fn parse_duration_minutes() {
        assert_eq!(parse_duration("2m").unwrap(), Duration::from_secs(120));
    }

    #[test]
    fn parse_duration_hours() {
        assert_eq!(parse_duration("1h").unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn parse_duration_zero() {
        assert_eq!(parse_duration("0ms").unwrap(), Duration::ZERO);
    }

    #[test]
    fn parse_duration_allows_space_before_unit() {
        assert_eq!(parse_duration("5 s").unwrap(), Duration::from_secs(5));
    }

    #[test]
    fn parse_duration_empty_is_error() {
        assert!(parse_duration("").is_err());
    }

    #[test]
    fn parse_duration_no_unit_is_error() {
        assert!(parse_duration("100").is_err());
    }

    #[test]
    fn parse_duration_no_number_is_error() {
        assert!(parse_duration("s").is_err());
    }

    #[test]
    fn parse_duration_negative_is_error() {
        assert!(parse_duration("-1s").is_err());
    }

    #[test]
    fn parse_duration_unknown_unit_is_error() {
        assert!(parse_duration("10d").is_err());
    }
}