use anyhow::Context as _;
use clap::ValueEnum;
use lib_tsalign::a_star_aligner::configurable_a_star_align::MinLengthStrategySelector;
use lib_tsalign::{
a_star_aligner::configurable_a_star_align::Aligner, config::TemplateSwitchConfig,
};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{fs::File, io::BufReader, time::Duration};
use tokio::sync::Semaphore;
use bytesize::ByteSize;
use crate::common::aligner::db::CliDatabaseArgs;
use crate::common::aligner::{AStarAlignerPair, AlignerSelector, fpa::FourPointAligner};
// Command-line options that configure how alignments are computed: concurrency
// and memory budget, the aligner implementation and its costs, plus the database
// options flattened in from `CliDatabaseArgs`.
//
// NOTE: plain `//` comments are used on purpose. With `#[derive(clap::Args)]`,
// `///` doc comments become user-visible `--help` text.
#[derive(clap::Args, Debug)]
pub struct CliAlignmentArgs {
    // `--threads NUM`: number of concurrent alignment permits. When absent,
    // `init_semaphore` falls back to `std::thread::available_parallelism()`.
    #[arg(long = "threads", value_name = "NUM")]
    pub num_threads: Option<usize>,
    // `-m/--memory-limit EXPR`: total memory budget, parsed as a `ByteSize`
    // expression. When absent, `default_memory_limit()` (the currently available
    // system memory) is used. `init_semaphore` divides it evenly per permit.
    #[arg(short, long, global = true, value_name = "EXPR", value_parser = ByteSize::from_str)]
    pub memory_limit: Option<ByteSize>,
    // `--aligner-timeout EXPR`: a whole number followed by `ms`, `s`, `m` or `h`,
    // parsed by `parse_duration`. How the timeout is enforced is not visible here.
    #[arg(long = "aligner-timeout", value_name = "EXPR", value_parser = parse_duration)]
    pub aligner_timeout: Option<Duration>,
    // `-p/--padding BASES`, default 200. Consumed outside this file.
    #[arg(short, long, default_value_t = 200, value_name = "BASES")]
    pub padding: usize,
    // `-c/--costs FILE`: cost file read with `TemplateSwitchConfig::read_plain`.
    // When absent, the config type's `Default` costs are used.
    #[arg(short, long, value_name = "FILE")]
    pub costs: Option<String>,
    // `--no-ts`: forwarded to `FourPointAligner::new` / `Aligner::set_no_ts`
    // (by its name, disables template switches).
    #[arg(long)]
    pub no_ts: bool,
    // `--fpa`: build a `FourPointAligner` instead of the A* aligner pair.
    #[arg(long = "fpa")]
    pub use_fpa: bool,
    // `--min-length-strategy`: only used for the A* aligner. When absent,
    // `MLSSelector::default()` (`PreprocessFilter`) is used.
    #[arg(long, value_name = "ML_STRATEGY")]
    pub min_length_strategy: Option<MLSSelector>,
    #[clap(flatten)]
    pub database: CliDatabaseArgs,
}
fn parse_duration(s: &str) -> anyhow::Result<Duration> {
if s.is_empty() {
anyhow::bail!("empty duration string");
}
let ix = s
.find(|c: char| !c.is_ascii_digit() && c != '.')
.ok_or_else(|| anyhow::anyhow!("no unit detected"))?;
let (number, unit) = s.split_at_checked(ix).context("non-utf8 duration string")?;
let value: u64 = number.trim().parse()?;
let dur = match unit.trim() {
"ms" => Duration::from_millis(value),
"s" => Duration::from_secs(value),
"m" => Duration::from_secs(value * 60),
"h" => Duration::from_secs(value * 60 * 60),
_ => anyhow::bail!("unknown duration unit: {unit}"),
};
Ok(dur)
}
impl Default for CliAlignmentArgs {
    /// Returns the alignment settings used when nothing is configured.
    ///
    /// These are: automatic thread count and memory limit, no aligner timeout,
    /// 200 bases of padding, default costs, `no_ts` off, the A* aligner, the
    /// default minimum-length strategy, and `CliDatabaseArgs::default()`.
    fn default() -> Self {
        // Keep `padding` in sync with `default_value_t` on the struct field.
        let database = CliDatabaseArgs::default();
        Self {
            padding: 200,
            use_fpa: false,
            no_ts: false,
            costs: None,
            num_threads: None,
            memory_limit: None,
            aligner_timeout: None,
            min_length_strategy: None,
            database,
        }
    }
}
impl CliAlignmentArgs {
pub(super) fn init_aligner(&self) -> anyhow::Result<AlignerSelector> {
let costs = if let Some(cost_path) = &self.costs {
tracing::info!("Loading costs from {cost_path}");
let file = BufReader::new(File::open(cost_path)?);
Some(TemplateSwitchConfig::read_plain(file)?)
} else {
tracing::info!("Using default costs");
None
};
if self.use_fpa {
let aligner = FourPointAligner::new(costs.unwrap_or_default(), self.no_ts);
Ok(super::AlignerSelector::Fpa(aligner))
} else {
let init = || {
let mut a = Aligner::new();
a.set_costs(costs.clone().unwrap_or_default());
a.set_min_length_strategy(self.min_length_strategy.unwrap_or_default().into());
a.set_no_ts(self.no_ts);
a
};
let ts = init();
let mut no_ts = init();
no_ts.set_no_ts(true);
Ok(AlignerSelector::AStar(AStarAlignerPair { ts, no_ts }))
}
}
pub fn init_semaphore(&self) -> anyhow::Result<(tokio::sync::Semaphore, usize)> {
let n = if let Some(num) = self.num_threads {
num.try_into()?
} else {
std::thread::available_parallelism()?
};
let sem = Semaphore::new(n.into());
let total_memory = usize::try_from(
self.memory_limit
.unwrap_or_else(default_memory_limit)
.as_u64(),
)
.context("The memory limit is too high")?;
let memory_per_thread = total_memory / n;
Ok((sem, memory_per_thread))
}
}
// Command-line choice of the minimum-length strategy for the A* aligner
// (`--min-length-strategy`).
//
// Each variant maps one-to-one onto the identically named
// `MinLengthStrategySelector` variant through that type's `From<MLSSelector>`
// impl. What each strategy does is defined in lib_tsalign, not here.
// `PreprocessFilter` is the default, applied when the flag is absent.
//
// NOTE: plain `//` comments are used on purpose. With `#[derive(ValueEnum)]`,
// `///` doc comments on variants become user-visible possible-value help text.
// By clap's default renaming, the variants are spelled in kebab-case on the
// command line (e.g. `preprocess-filter`).
#[derive(Default, ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MLSSelector {
    None,
    Lookahead,
    #[default]
    PreprocessFilter,
    PreprocessPrice,
    PreprocessLookahead,
}
impl From<MLSSelector> for MinLengthStrategySelector {
fn from(value: MLSSelector) -> Self {
match value {
MLSSelector::None => MinLengthStrategySelector::None,
MLSSelector::Lookahead => MinLengthStrategySelector::Lookahead,
MLSSelector::PreprocessFilter => MinLengthStrategySelector::PreprocessFilter,
MLSSelector::PreprocessPrice => MinLengthStrategySelector::PreprocessPrice,
MLSSelector::PreprocessLookahead => MinLengthStrategySelector::PreprocessLookahead,
}
}
}
/// Returns the memory the operating system currently reports as available.
///
/// This is the fallback when `--memory-limit` is not given. Only memory
/// information is loaded. `System::new_all()` would also enumerate every
/// process, disk and component, which is slow and unused here, and would make
/// the subsequent `refresh_memory()` redundant.
fn default_memory_limit() -> ByteSize {
    let mut system = sysinfo::System::new();
    system.refresh_memory();
    ByteSize::b(system.available_memory())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_duration_milliseconds() {
        assert_eq!(parse_duration("100ms").unwrap(), Duration::from_millis(100));
    }

    #[test]
    fn parse_duration_seconds() {
        assert_eq!(parse_duration("5s").unwrap(), Duration::from_secs(5));
    }

    #[test]
    fn parse_duration_minutes() {
        assert_eq!(parse_duration("2m").unwrap(), Duration::from_secs(120));
    }

    #[test]
    fn parse_duration_hours() {
        assert_eq!(parse_duration("1h").unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn parse_duration_zero() {
        assert_eq!(parse_duration("0s").unwrap(), Duration::ZERO);
    }

    #[test]
    fn parse_duration_allows_space_before_unit() {
        assert_eq!(parse_duration("5 s").unwrap(), Duration::from_secs(5));
    }

    #[test]
    fn parse_duration_empty_is_error() {
        assert!(parse_duration("").is_err());
    }

    #[test]
    fn parse_duration_no_unit_is_error() {
        assert!(parse_duration("100").is_err());
    }

    #[test]
    fn parse_duration_missing_number_is_error() {
        assert!(parse_duration("ms").is_err());
    }

    #[test]
    fn parse_duration_fractional_is_error() {
        // Only whole numbers are accepted.
        assert!(parse_duration("1.5s").is_err());
    }

    #[test]
    fn parse_duration_unknown_unit_is_error() {
        assert!(parse_duration("10d").is_err());
    }
}