use clap::{ArgAction, Parser};
use std::path::PathBuf;
use crate::profile;
use crate::symbol;
use crate::udf;
// Command-line arguments for `ries-rs`, declared with clap's derive API.
//
// Plain `//` comments are used in this struct on purpose: clap turns `///` doc
// comments on the struct and its fields into `--help` text, so converting these
// notes to doc comments would change the program's output.
//
// Conventions visible below:
// - `target` is the only field with neither `short` nor `long`, so it is the
//   positional argument.
// - A `bool` field without an explicit `action` is an on/off flag (clap's
//   default `ArgAction::SetTrue`).
// - `Vec<_>` fields may be repeated.
#[derive(Parser, Debug)]
#[command(name = "ries-rs")]
#[command(author = "Maxwell Santoro")]
#[command(version)]
#[command(about = "Find algebraic equations given their solution", long_about = None)]
pub struct Args {
// Positional target value; may be omitted on the command line.
pub target: Option<f64>,
// Kept as the raw string (default "2"); it is not parsed as a number here.
#[arg(short = 'l', long, default_value = "2")]
pub level: String,
#[arg(short = 'n', long = "max-matches", default_value = "16")]
pub max_matches: usize,
#[arg(short = 'x', long, alias = "absolute-roots")]
pub absolute: bool,
#[arg(short = 's', long, alias = "try-solve-for-x")]
pub solve: bool,
// NOTE(review): declared independently of `solve`; nothing here prevents both
// `-s` and `--no-solve-for-x` from being set — precedence is decided elsewhere.
#[arg(long = "no-solve-for-x")]
pub no_solve: bool,
#[arg(short = 'N', long)]
pub exclude: Option<String>,
// The value is optional: a bare `-E`/`--enable` yields `Some("all")`.
#[arg(short = 'E', long = "enable", num_args = 0..=1, default_missing_value = "all")]
pub enable: Option<String>,
// NOTE(review): `num_args = 0..=1` without `default_missing_value`; confirm
// what a bare `-S` produces (it may be indistinguishable from omitting it).
#[arg(short = 'S', long, num_args = 0..=1)]
pub only_symbols: Option<String>,
#[arg(short = 'O', long)]
pub op_limits: Option<String>,
// Long-only "-RHS" counterparts of -S / -N / -E / -O (presumably restricted to
// the right-hand side of an equation — confirm with the caller). Unlike
// `--enable`, `--E-RHS` always requires a value.
#[arg(long = "S-RHS")]
pub only_symbols_rhs: Option<String>,
#[arg(long = "N-RHS")]
pub exclude_rhs: Option<String>,
#[arg(long = "E-RHS")]
pub enable_rhs: Option<String>,
#[arg(long = "O-RHS")]
pub op_limits_rhs: Option<String>,
#[arg(long)]
pub symbol_weights: Option<String>,
#[arg(long)]
pub symbol_names: Option<String>,
// Expression-class restrictions; each also accepts its long RIES-style alias.
#[arg(short = 'a', long, alias = "algebraic-subexpressions")]
pub algebraic: bool,
#[arg(short = 'c', long, alias = "constructible-subexpressions")]
pub constructible: bool,
#[arg(short = 'r', long, alias = "rational-subexpressions")]
pub rational: bool,
#[arg(short = 'i', long, alias = "integer-subexpressions")]
pub integer: bool,
#[arg(long = "ie")]
pub integer_exact: bool,
#[arg(long = "re")]
pub rational_exact: bool,
#[arg(long = "liouvillian-subexpressions")]
pub liouvillian: bool,
// NOTE(review): as a plain `bool` flag this uses `ArgAction::SetTrue`, and with
// `default_value = "true"` the field is `true` whether or not `--parallel` is
// passed, so the flag can never turn parallelism off. Compare `report` below,
// which uses `ArgAction::Set` so it can be set to false.
#[arg(long, default_value = "true")]
pub parallel: bool,
#[arg(long)]
pub deterministic: bool,
#[arg(long)]
pub streaming: bool,
#[arg(long)]
pub adaptive: bool,
// Takes an explicit boolean value (e.g. `--report false`); defaults to true.
#[arg(long, default_value_t = true, action = ArgAction::Set)]
pub report: bool,
#[arg(long)]
pub classic: bool,
// `--parity-ranking` and `--complexity-ranking` are mutually exclusive.
#[arg(long, conflicts_with = "complexity_ranking")]
pub parity_ranking: bool,
#[arg(long, conflicts_with = "parity_ranking")]
pub complexity_ranking: bool,
// Three states: absent -> "2", bare `-F` -> "3", `-F <v>` -> "<v>".
#[arg(short = 'F', long, num_args = 0..=1, default_missing_value = "3", default_value = "2")]
pub format: String,
// Short-only flag: absent -> None, bare `-D` -> Some(""), `-D <v>` -> Some(v).
#[arg(short = 'D', num_args = 0..=1, default_missing_value = "")]
pub diagnostics: Option<String>,
#[arg(short = 'k', long = "top-k", default_value = "8")]
pub top_k: usize,
#[arg(long)]
pub no_stable: bool,
#[arg(long)]
pub stats: bool,
#[arg(long)]
pub json: bool,
#[arg(long = "list-options")]
pub list_options: bool,
#[arg(long)]
pub preset: Option<String>,
#[arg(long)]
pub list_presets: bool,
#[arg(long)]
pub stability_check: bool,
#[arg(long)]
pub stability_thorough: bool,
#[arg(long)]
pub stop_at_exact: bool,
#[arg(long)]
pub stop_below: Option<f64>,
#[arg(short = 'p', long)]
pub profile: Option<PathBuf>,
// Repeatable.
#[arg(long)]
pub include: Vec<PathBuf>,
// Repeatable; `--constant` is accepted as an alias. Presumably each entry is a
// `weight:name:description:value` spec for `parse_user_constant_from_cli` —
// confirm with the caller.
#[arg(short = 'X', long = "user-constant", alias = "constant")]
pub user_constant: Vec<String>,
// Repeatable.
#[arg(long)]
pub define: Vec<String>,
// `--mad` is accepted as a short alias.
#[arg(long, alias = "mad")]
pub max_match_distance: Option<f64>,
#[arg(long)]
pub min_match_distance: Option<f64>,
#[arg(long)]
pub one_sided: bool,
#[arg(long)]
pub no_refinement: bool,
#[arg(long)]
pub eval_expression: Option<String>,
#[arg(long)]
pub find_expression: Option<String>,
#[arg(long)]
pub at: Option<f64>,
#[arg(long, default_value = "15")]
pub newton_iterations: usize,
#[arg(long)]
pub precision: Option<u32>,
#[arg(long)]
pub zero_threshold: Option<f64>,
// `--wide` and `--wide-output` are two separate flags.
#[arg(long)]
pub wide: bool,
#[arg(long = "wide-output")]
pub wide_output: bool,
#[arg(long = "relative-roots")]
pub relative_roots: bool,
#[arg(long = "any-exponents")]
pub any_exponents: bool,
#[arg(long = "any-subexpressions")]
pub any_subexpressions: bool,
#[arg(long = "any-trig-args")]
pub any_trig_args: bool,
// Free-form string; see `canon_reduction_enabled` for how on/off is decided.
#[arg(long = "canon-reduction")]
pub canon_reduction: Option<String>,
#[arg(long = "canon-simplify")]
pub canon_simplify: bool,
#[arg(long = "derivative-margin")]
pub derivative_margin: Option<f64>,
#[arg(long = "explicit-multiply")]
pub explicit_multiply: bool,
#[arg(long = "match-all-digits")]
pub match_all_digits: bool,
#[arg(long = "max-equate-value")]
pub max_equate_value: Option<f64>,
// Raw size string; presumably parsed by `parse_memory_size_bytes` — confirm.
#[arg(long = "max-memory")]
pub max_memory: Option<String>,
#[arg(long = "memory-abort-threshold")]
pub memory_abort_threshold: Option<f64>,
#[arg(long = "max-trig-cycles")]
pub max_trig_cycles: Option<u32>,
#[arg(long = "min-equate-value")]
pub min_equate_value: Option<f64>,
// Raw size string; presumably parsed by `parse_memory_size_bytes` — confirm.
#[arg(long = "min-memory")]
pub min_memory: Option<String>,
#[arg(long = "no-canon-simplify")]
pub no_canon_simplify: bool,
#[arg(long = "no-slow-messages")]
pub no_slow_messages: bool,
#[arg(long = "numeric-anagram")]
pub numeric_anagram: bool,
#[arg(long = "rational-exponents")]
pub rational_exponents: bool,
#[arg(long = "rational-trig-args")]
pub rational_trig_args: bool,
#[arg(long = "show-work")]
pub show_work: bool,
#[arg(long = "significance-loss-margin")]
pub significance_loss_margin: Option<f64>,
#[arg(long = "trig-argument-scale")]
pub trig_argument_scale: Option<f64>,
#[arg(long)]
pub verbose: bool,
#[arg(long)]
pub pslq: bool,
#[arg(long)]
pub pslq_extended: bool,
#[arg(long, default_value = "1000")]
pub pslq_max_coeff: i64,
#[arg(long, value_name = "FILE")]
pub emit_manifest: Option<PathBuf>,
}
/// Parses a user-constant spec of the form `weight:name:description:value` and
/// registers it on `profile` via `Profile::add_constant`.
///
/// The first field is the weight, the second the name and the last the value.
/// Everything in between is the description, so a description may itself
/// contain `:` (e.g. `3:r:ratio a:b:1.5`).
///
/// # Errors
///
/// Returns a message if:
/// - the spec has fewer than four `:`-separated fields;
/// - the weight is not a `u32`;
/// - the value is not a finite `f64` (`inf`/`NaN` are rejected even though
///   `f64::from_str` accepts them);
/// - `add_constant` itself fails.
pub fn parse_user_constant_from_cli(
    profile: &mut profile::Profile,
    spec: &str,
) -> Result<(), String> {
    let parts: Vec<&str> = spec.split(':').collect();
    if parts.len() < 4 {
        return Err(format!(
            "Expected 4 colon-separated parts, got {}",
            parts.len()
        ));
    }
    let weight: u32 = parts[0]
        .parse()
        .map_err(|_| format!("Invalid weight: {}", parts[0]))?;
    let name = parts[1].to_string();
    let last = parts.len() - 1;
    // Re-join the interior fields so colons inside the description survive.
    let description = parts[2..last].join(":");
    let value_text = parts[last];
    let value: f64 = value_text
        .parse()
        .map_err(|_| format!("Invalid value: {}", value_text))?;
    // "inf" / "NaN" parse successfully but are not usable constant values.
    if !value.is_finite() {
        return Err(format!("Invalid value: {}", value_text));
    }
    profile
        .add_constant(weight, name, description, value)
        .map_err(|e| e.to_string())
}
/// Parses a user-defined function spec with `udf::UserFunction::parse` and
/// appends the result to `profile.functions`.
///
/// # Errors
///
/// Propagates the parse error unchanged. Nothing is appended on failure.
pub fn parse_user_function_from_cli(
    profile: &mut profile::Profile,
    spec: &str,
) -> Result<(), String> {
    profile.functions.push(udf::UserFunction::parse(spec)?);
    Ok(())
}
/// Applies `--symbol-names` display-name overrides to `profile.symbol_names`.
///
/// `spec` is a whitespace-separated list of `:S:name` entries, where `S` is a
/// single symbol character and `name` is a non-empty display name. Tokens that
/// do not start with `:` are ignored.
///
/// # Errors
///
/// Returns a message naming the offending token when an entry:
/// - lacks the second `:`;
/// - has an empty or multi-character symbol;
/// - has an empty name;
/// - names a symbol that `Symbol::from_byte` does not recognise.
///
/// Entries before the failing one have already been applied.
pub fn parse_symbol_names_from_cli(
    profile: &mut profile::Profile,
    spec: &str,
) -> Result<(), String> {
    for part in spec.split_whitespace() {
        let Some(inner) = part.strip_prefix(':') else {
            continue;
        };
        let Some((symbol_text, display_name)) = inner.split_once(':') else {
            return Err(format!("Invalid symbol name format: {}", part));
        };
        let mut symbol_chars = symbol_text.chars();
        let symbol_char = symbol_chars
            .next()
            .ok_or_else(|| "Empty symbol in --symbol-names".to_string())?;
        // Previously any extra characters were silently ignored.
        if symbol_chars.next().is_some() {
            return Err(format!(
                "Expected a single symbol character in --symbol-names: {}",
                part
            ));
        }
        if display_name.is_empty() {
            return Err(format!(
                "Empty replacement name in --symbol-names: {}",
                part
            ));
        }
        // `char as u8` keeps only the low byte (U+0142 'ł' would alias b'B'),
        // so only chars that fit in a byte are looked up.
        let fits_in_byte = u32::from(symbol_char) <= u32::from(u8::MAX);
        let Some(symbol) = fits_in_byte
            .then(|| symbol_char as u8)
            .and_then(symbol::Symbol::from_byte)
        else {
            return Err(format!("Unknown symbol in --symbol-names: {}", symbol_char));
        };
        profile
            .symbol_names
            .insert(symbol, display_name.to_string());
    }
    Ok(())
}
/// Applies `--symbol-weights` overrides to `profile.symbol_weights`.
///
/// `spec` is a whitespace-separated list of `:S:weight` entries, where `S` is a
/// single symbol character and `weight` is a `u32`. Tokens that do not start
/// with `:` are ignored.
///
/// # Errors
///
/// Returns a message naming the offending token when an entry:
/// - lacks the second `:`;
/// - has an empty or multi-character symbol;
/// - has a weight that is not a `u32`;
/// - names a symbol that `Symbol::from_byte` does not recognise.
///
/// Entries before the failing one have already been applied.
pub fn parse_symbol_weights_from_cli(
    profile: &mut profile::Profile,
    spec: &str,
) -> Result<(), String> {
    for part in spec.split_whitespace() {
        let Some(inner) = part.strip_prefix(':') else {
            continue;
        };
        let Some((symbol_text, weight_text)) = inner.split_once(':') else {
            return Err(format!("Invalid symbol weight format: {}", part));
        };
        let mut symbol_chars = symbol_text.chars();
        let symbol_char = symbol_chars
            .next()
            .ok_or_else(|| "Empty symbol in --symbol-weights".to_string())?;
        // Previously any extra characters were silently ignored.
        if symbol_chars.next().is_some() {
            return Err(format!(
                "Expected a single symbol character in --symbol-weights: {}",
                part
            ));
        }
        let weight: u32 = weight_text
            .parse()
            .map_err(|_| format!("Invalid weight in --symbol-weights: {}", part))?;
        // `char as u8` keeps only the low byte (U+0142 'ł' would alias b'B'),
        // so only chars that fit in a byte are looked up.
        let fits_in_byte = u32::from(symbol_char) <= u32::from(u8::MAX);
        let Some(symbol) = fits_in_byte
            .then(|| symbol_char as u8)
            .and_then(symbol::Symbol::from_byte)
        else {
            return Err(format!(
                "Unknown symbol in --symbol-weights: {}",
                symbol_char
            ));
        };
        profile.symbol_weights.insert(symbol, weight);
    }
    Ok(())
}
/// Parses an `-O`/`--op-limits` spec into a per-symbol maximum use count.
///
/// The spec is a sequence of symbol characters, each optionally preceded by a
/// decimal count. For example, `"2ab"` maps `a` to 2 and `b` to 1 (assuming
/// both are known symbols).
///
/// - Whitespace between entries is ignored.
/// - Counts saturate at `u32::MAX` instead of overflowing.
/// - A symbol given more than once keeps the last count.
///
/// # Errors
///
/// Returns a message if:
/// - a count ends the spec;
/// - a count is followed by whitespace instead of a symbol;
/// - a count is zero;
/// - a character is not a symbol recognised by `Symbol::from_byte`.
pub fn parse_symbol_count_limits(
    spec: &str,
) -> Result<std::collections::HashMap<symbol::Symbol, u32>, String> {
    let mut limits = std::collections::HashMap::new();
    let mut chars = spec.chars().peekable();
    while let Some(&next) = chars.peek() {
        if next.is_whitespace() {
            chars.next();
            continue;
        }
        // Optional decimal count prefix; `None` means no digits were given.
        let mut count: Option<u32> = None;
        while let Some(digit) = chars.peek().and_then(|c| c.to_digit(10)) {
            chars.next();
            count = Some(count.unwrap_or(0).saturating_mul(10).saturating_add(digit));
        }
        let Some(symbol_char) = chars.next() else {
            return Err("Trailing count with no symbol in -O/--op-limits".to_string());
        };
        // Leading whitespace was skipped above, so whitespace here can only
        // follow a count. Previously that count was silently discarded.
        if symbol_char.is_whitespace() {
            return Err("Count not followed by a symbol in -O/--op-limits".to_string());
        }
        let max_count = count.unwrap_or(1);
        if max_count == 0 {
            return Err(format!("Invalid zero count for symbol '{}'", symbol_char));
        }
        // `char as u8` keeps only the low byte, so only chars that fit in a
        // byte are looked up.
        let fits_in_byte = u32::from(symbol_char) <= u32::from(u8::MAX);
        let Some(sym) = fits_in_byte
            .then(|| symbol_char as u8)
            .and_then(symbol::Symbol::from_byte)
        else {
            return Err(format!(
                "Unknown symbol '{}' in -O/--op-limits",
                symbol_char
            ));
        };
        limits.insert(sym, max_count);
    }
    Ok(limits)
}
/// Builds the allowed and excluded symbol byte sets from `-S`, `-N` and `-E`.
///
/// Each input string is treated as a set of raw bytes. `None` means "not given"
/// and yields `None` for the corresponding set.
///
/// `enable_symbols` adjusts the other two sets:
/// - the literal `"all"` drops the exclusion set entirely (`None`);
/// - any other string removes its bytes from the exclusion set (if present) and
///   adds them to the allowed set (if present).
///
/// Returns `(allowed, excluded)`.
pub fn parse_symbol_sets(
    only_symbols: Option<&str>,
    exclude_symbols: Option<&str>,
    enable_symbols: Option<&str>,
) -> (
    Option<std::collections::HashSet<u8>>,
    Option<std::collections::HashSet<u8>>,
) {
    use std::collections::HashSet;

    let byte_set = |text: &str| -> HashSet<u8> { text.bytes().collect() };
    let mut allowed = only_symbols.map(byte_set);
    let mut excluded = exclude_symbols.map(byte_set);

    match enable_symbols {
        None => {}
        Some("all") => excluded = None,
        Some(extra) => {
            if let Some(blocked) = excluded.as_mut() {
                extra.bytes().for_each(|b| {
                    blocked.remove(&b);
                });
            }
            if let Some(permitted) = allowed.as_mut() {
                permitted.extend(extra.bytes());
            }
        }
    }

    (allowed, excluded)
}
/// Parses a human-readable memory size such as `"512M"`, `"4GB"` or `"1.5KiB"`
/// into a byte count.
///
/// The input is a non-negative number followed by an optional binary unit,
/// optionally separated by whitespace. The number is anything `f64::from_str`
/// accepts, e.g. `1.5` or `1e3`.
///
/// Units are case-insensitive powers of 1024:
/// - none or `B` (bytes)
/// - `K` / `KB` / `KiB`
/// - `M` / `MB` / `MiB`
/// - `G` / `GB` / `GiB`
/// - `T` / `TB` / `TiB`
///
/// Fractional byte counts are truncated. Results too large for `u64` saturate
/// to `u64::MAX`, the behaviour of an `f64 as u64` cast.
///
/// Returns `None` in any of these cases:
/// - empty input;
/// - a number that does not parse, is negative, or is not finite;
/// - an unknown unit.
pub fn parse_memory_size_bytes(spec: &str) -> Option<u64> {
    let trimmed = spec.trim();
    if trimmed.is_empty() {
        return None;
    }
    // The unit is the maximal run of trailing ASCII letters. Those are
    // single-byte chars, so the split point is always a char boundary.
    let unit_len = trimmed
        .bytes()
        .rev()
        .take_while(u8::is_ascii_alphabetic)
        .count();
    let (num_part, unit) = trimmed.split_at(trimmed.len() - unit_len);
    let number: f64 = num_part.trim().parse().ok()?;
    if !number.is_finite() || number < 0.0 {
        return None;
    }
    let mult: f64 = match unit.to_ascii_uppercase().as_str() {
        "" | "B" => 1.0,
        "K" | "KB" | "KIB" => 1024.0,
        "M" | "MB" | "MIB" => 1024.0 * 1024.0,
        "G" | "GB" | "GIB" => 1024.0 * 1024.0 * 1024.0,
        "T" | "TB" | "TIB" => 1024.0 * 1024.0 * 1024.0 * 1024.0,
        _ => return None,
    };
    Some((number * mult) as u64)
}
/// Decides whether `--canon-reduction` is switched on.
///
/// An absent option is off. A present value is off when it is, after trimming
/// and ASCII-lowercasing, one of `""`, `"off"`, `"none"`, `"0"` or `"false"`.
/// Any other value turns it on.
pub fn canon_reduction_enabled(spec: Option<&str>) -> bool {
    const OFF_WORDS: [&str; 5] = ["", "off", "none", "0", "false"];
    match spec {
        None => false,
        Some(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            !OFF_WORDS.contains(&normalized.as_str())
        }
    }
}
/// Prints the option names reported by `--list-options`, one per line.
///
/// The entries are emitted exactly in the order listed; do not sort or dedupe.
pub fn print_option_list() {
    const OPTION_NAMES: &[&str] = &[
        "--list-options",
        "-p",
        "--include",
        "--any-exponents",
        "--any-subexpressions",
        "--any-trig-args",
        "--canon-reduction",
        "--canon-simplify",
        "--derivative-margin",
        "--eval-expression",
        "--explicit-multiply",
        "--find-expression",
        "--match-all-digits",
        "--mad",
        "--max-equate-value",
        "--max-match-distance",
        "--min-match-distance",
        "--max-matches",
        "--max-memory",
        "--memory-abort-threshold",
        "-X",
        "--constant",
        "--define",
        "--min-equate-value",
        "--max-trig-cycles",
        "--min-memory",
        "--no-canon-simplify",
        "--no-refinement",
        "--no-slow-messages",
        "--no-solve-for-x",
        "--numeric-anagram",
        "--one-sided",
        "--complexity-ranking",
        "--parity-ranking",
        "--rational-exponents",
        "--rational-trig-args",
        "--show-work",
        "--significance-loss-margin",
        "--symbol-weights",
        "--symbol-names",
        "--trig-argument-scale",
        "-s",
        "--try-solve-for-x",
        "--version",
        "--wide",
        "--wide-output",
        "-a",
        "--algebraic-subexpressions",
        "-c",
        "--constructible-subexpressions",
        "-D",
        "-E",
        "-F",
        "-i",
        "--integer-subexpressions",
        "-l",
        "--liouvillian-subexpressions",
        "-N",
        "-O",
        "-r",
        "--rational-subexpressions",
        "-S",
        "-x",
        "--absolute-roots",
        "--relative-roots",
        "--N-RHS",
        "--O-RHS",
        "--S-RHS",
        "--E-RHS",
    ];

    OPTION_NAMES.iter().for_each(|name| println!("{}", name));
}
/// Returns the human-readable description shown in the symbol table for `sym`.
///
/// Symbols without a specific description map to the empty string.
pub fn sym_description(sym: symbol::Symbol) -> &'static str {
    use symbol::Symbol as S;

    let text = match sym {
        // Digits
        S::One
        | S::Two
        | S::Three
        | S::Four
        | S::Five
        | S::Six
        | S::Seven
        | S::Eight
        | S::Nine => "integer",
        // Named constants
        S::Pi => "pi = 3.14159...",
        S::E => "e = base of natural logarithms, 2.71828...",
        S::Phi => "phi = the golden ratio, (1+sqrt(5))/2",
        S::Gamma => "Euler-Mascheroni constant gamma",
        S::Plastic => "plastic constant",
        S::Apery => "Apery's constant zeta(3)",
        S::Catalan => "Catalan's constant",
        S::X => "the variable of the equation",
        // One-argument functions
        S::Neg => "negate",
        S::Recip => "reciprocal",
        S::Sqrt => "sqrt(x) = square root",
        S::Square => "^2 = square",
        S::Ln => "ln(x) = natural logarithm or log base e",
        S::Exp => "natural exponent function",
        S::SinPi => "sinpi(X) = sin(pi * x)",
        S::CosPi => "cospi(X) = cos(pi * x)",
        S::TanPi => "tanpi(X) = tan(pi * x)",
        S::LambertW => "Lambert W function",
        // Two-argument functions
        S::Add => "add",
        S::Sub => "subtract",
        S::Mul => "multiply",
        S::Div => "divide",
        S::Pow => "power",
        S::Root => "a-th root of b",
        S::Log => "log base a of b",
        S::Atan2 => "2-argument arctangent",
        _ => "",
    };
    text
}
/// Prints the built-in symbol table to stdout in three sections: explicit
/// values, one-argument functions and two-argument functions.
///
/// Each row shows the symbol character, its seft class letter, its weight, its
/// name and the text from `sym_description`.
pub fn print_symbol_table() {
    println!("Explicit values:");
    println!(" sym seft wght name description");
    for sym in symbol::Symbol::constants() {
        // Only constants with a byte code below 128 are listed.
        // NOTE(review): presumably higher codes are not meant for display — confirm.
        if (*sym as u8) < 128 {
            print_symbol_row(*sym);
        }
    }

    println!("\nFunctions of one argument:");
    println!(" sym seft wght name description");
    for sym in symbol::Symbol::unary_ops() {
        print_symbol_row(*sym);
    }

    println!("\nFunctions of two arguments:");
    println!(" sym seft wght name description");
    for sym in symbol::Symbol::binary_ops() {
        print_symbol_row(*sym);
    }
}

/// Prints one symbol-table row. Shared by all three sections of
/// `print_symbol_table` so the column layout is defined once.
fn print_symbol_row(sym: symbol::Symbol) {
    let seft = match sym.seft() {
        symbol::Seft::A => "a",
        symbol::Seft::B => "b",
        symbol::Seft::C => "c",
    };
    println!(
        " {:<2} {:<1} {:<3} {:<6} {}",
        sym as u8 as char,
        seft,
        sym.weight(),
        sym.name(),
        sym_description(sym)
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_parse_memory_size() {
        assert_eq!(parse_memory_size_bytes("512M"), Some(512 * 1024 * 1024));
        assert_eq!(parse_memory_size_bytes("2G"), Some(2 * 1024 * 1024 * 1024));
        assert_eq!(parse_memory_size_bytes("1024"), Some(1024));
        assert_eq!(parse_memory_size_bytes("1k"), Some(1024));
    }

    /// Fractions, exponent notation and surrounding/internal whitespace.
    #[test]
    fn test_parse_memory_size_numeric_forms() {
        assert_eq!(parse_memory_size_bytes("1.5K"), Some(1536));
        assert_eq!(parse_memory_size_bytes(" 2 M "), Some(2 * 1024 * 1024));
        assert_eq!(parse_memory_size_bytes("1e3"), Some(1000));
    }

    #[test]
    fn test_parse_memory_size_rejects_invalid() {
        for bad in ["", "   ", "M", "abc", "-1M", "1X", "inf", "NaN"] {
            assert_eq!(parse_memory_size_bytes(bad), None, "input {:?}", bad);
        }
    }

    #[test]
    fn test_parse_symbol_sets() {
        let (allowed, excluded) = parse_symbol_sets(Some("abc"), Some("d"), None);
        assert_eq!(allowed, Some(vec![b'a', b'b', b'c'].into_iter().collect()));
        assert_eq!(excluded, Some(vec![b'd'].into_iter().collect()));
        let (_, excluded) = parse_symbol_sets(None, Some("abc"), Some("all"));
        assert!(excluded.is_none());
    }

    /// A non-"all" enable list both un-excludes and adds to the allowed set.
    #[test]
    fn test_parse_symbol_sets_enable_adjusts_both_sets() {
        let (allowed, excluded) = parse_symbol_sets(Some("ab"), Some("cd"), Some("cx"));
        assert_eq!(allowed, Some(HashSet::from([b'a', b'b', b'c', b'x'])));
        assert_eq!(excluded, Some(HashSet::from([b'd'])));
        let (allowed, excluded) = parse_symbol_sets(None, Some("ab"), Some("b"));
        assert_eq!(allowed, None);
        assert_eq!(excluded, Some(HashSet::from([b'a'])));
    }

    #[test]
    fn test_canon_reduction_enabled() {
        assert!(!canon_reduction_enabled(None));
        for off in ["", "  ", "off", "OFF", " none ", "0", "False"] {
            assert!(!canon_reduction_enabled(Some(off)), "input {:?}", off);
        }
        for on in ["on", "1", "true", "aggressive"] {
            assert!(canon_reduction_enabled(Some(on)), "input {:?}", on);
        }
    }
}