use std::collections::BTreeMap;
#[derive(Debug, Clone, Default)]
pub struct ArgsSpec {
pub options: Vec<OptionDecl>,
pub positionals: Vec<PositionalDecl>,
pub lenient_unknowns: bool,
}
#[derive(Debug, Clone)]
pub struct OptionDecl {
pub long: String,
pub short: Option<char>,
pub takes_value: bool,
pub allows_bare: bool,
pub repeatable: bool,
pub choices: Option<Vec<String>>,
}
#[derive(Debug, Clone)]
pub struct PositionalDecl {
pub name: String,
pub required: bool,
pub variadic: bool,
pub choices: Option<Vec<String>>,
}
#[derive(Debug, Default)]
pub struct ParsedArgs {
pub options: BTreeMap<String, OptionState>,
pub positionals: Vec<String>,
}
#[derive(Debug, Clone)]
pub enum OptionState {
BarePresent,
WithValues(Vec<String>),
}
pub fn parse(spec: &ArgsSpec, args: &[String]) -> Result<ParsedArgs, String> {
let mut out = ParsedArgs::default();
let mut i = 1usize;
let mut stop_flags = false;
while i < args.len() {
let tok = &args[i];
if !stop_flags && tok == "--" {
stop_flags = true;
i += 1;
continue;
}
if !stop_flags && tok.starts_with("--") {
let rest = &tok[2..];
let (name, inline_value) = match rest.split_once('=') {
Some((n, v)) => (n, Some(v.to_string())),
None => (rest, None),
};
match find_long(spec, name) {
Some(decl) => {
i = consume_flag(decl, inline_value, args, i, &mut out)?;
}
None if spec.lenient_unknowns => {
i += 1;
}
None => return Err(unknown_long_error(spec, name)),
}
continue;
}
if !stop_flags && tok.starts_with('-') && tok.len() >= 2 {
let rest = &tok[1..];
if let Some((head, inline_value)) = rest.split_once('=') {
if head.chars().count() != 1 {
return Err(format!(
"invalid short-flag syntax `{tok}`: `-x=value` requires a single-letter short"
));
}
let c = head.chars().next().unwrap();
match find_short(spec, c) {
Some(decl) => {
i = consume_flag(decl, Some(inline_value.to_string()), args, i, &mut out)?;
}
None if spec.lenient_unknowns => {
i += 1;
}
None => return Err(format!("unknown short flag `-{c}`")),
}
continue;
}
let chars: Vec<char> = rest.chars().collect();
if spec.lenient_unknowns && chars.iter().any(|c| find_short(spec, *c).is_none()) {
i += 1;
continue;
}
for (pos, c) in chars.iter().enumerate() {
let decl = find_short(spec, *c)
.ok_or_else(|| format!("unknown short flag `-{c}`"))?;
let is_last = pos == chars.len() - 1;
if !is_last && decl.takes_value {
return Err(format!(
"short `-{c}` takes a value and cannot be clustered mid-token `{tok}`"
));
}
if is_last {
i = consume_flag(decl, None, args, i, &mut out)?;
} else {
record_option(&mut out, decl, None, spec)?;
}
}
if chars.is_empty() {
out.positionals.push(tok.clone());
i += 1;
}
continue;
}
out.positionals.push(tok.clone());
i += 1;
}
if !spec.lenient_unknowns {
let required_count = spec.positionals.iter().filter(|p| p.required).count();
if out.positionals.len() < required_count {
let missing = &spec.positionals[out.positionals.len()].name;
return Err(format!("missing required argument <{missing}>"));
}
}
for (idx, value) in out.positionals.iter().enumerate() {
let decl = positional_decl_for(spec, idx);
if let Some(d) = decl {
if let Some(choices) = &d.choices {
if !choices.iter().any(|c| c == value) {
return Err(format!(
"invalid value `{value}` for <{}> -- allowed: {}",
d.name,
choices.join(", ")
));
}
}
}
}
Ok(out)
}
fn consume_flag(
decl: &OptionDecl,
inline_value: Option<String>,
args: &[String],
i: usize,
out: &mut ParsedArgs,
) -> Result<usize, String> {
if !decl.takes_value {
if inline_value.is_some() {
return Err(format!("flag `--{}` takes no value", decl.long));
}
record_option(out, decl, None, &ArgsSpec::default())?;
return Ok(i + 1);
}
if let Some(v) = inline_value {
record_option(out, decl, Some(v), &ArgsSpec::default())?;
return Ok(i + 1);
}
let next_idx = i + 1;
let next_is_flag = args
.get(next_idx)
.map(|s| s.starts_with('-') && s != "-" && s != "--")
.unwrap_or(true);
if !next_is_flag {
let v = args[next_idx].clone();
record_option(out, decl, Some(v), &ArgsSpec::default())?;
return Ok(i + 2);
}
if !decl.allows_bare {
return Err(format!("`--{}` requires a value", decl.long));
}
record_option(out, decl, None, &ArgsSpec::default())?;
Ok(i + 1)
}
fn record_option(
out: &mut ParsedArgs,
decl: &OptionDecl,
value: Option<String>,
_spec: &ArgsSpec,
) -> Result<(), String> {
if let (Some(v), Some(choices)) = (&value, &decl.choices) {
for part in split_list_value(v) {
if !choices.iter().any(|c| c == part) {
return Err(format!(
"invalid value `{part}` for `--{}` -- allowed: {}",
decl.long,
choices.join(", ")
));
}
}
}
let key = decl.long.clone();
match (value, decl.repeatable) {
(None, _) => {
out.options.entry(key).or_insert(OptionState::BarePresent);
}
(Some(v), false) => {
out.options.insert(key, OptionState::WithValues(vec![v]));
}
(Some(v), true) => {
let parts: Vec<String> = split_list_value(&v).into_iter().map(String::from).collect();
let entry = out
.options
.entry(key)
.or_insert(OptionState::WithValues(Vec::new()));
if let OptionState::WithValues(list) = entry {
list.extend(parts);
}
}
}
Ok(())
}
fn split_list_value(v: &str) -> Vec<&str> {
v.split(',').map(str::trim).filter(|s| !s.is_empty()).collect()
}
fn find_long<'a>(spec: &'a ArgsSpec, name: &str) -> Option<&'a OptionDecl> {
spec.options.iter().find(|o| o.long == name)
}
fn find_short(spec: &ArgsSpec, c: char) -> Option<&OptionDecl> {
spec.options.iter().find(|o| o.short == Some(c))
}
fn positional_decl_for(spec: &ArgsSpec, idx: usize) -> Option<&PositionalDecl> {
if let Some(decl) = spec.positionals.get(idx) {
return Some(decl);
}
spec.positionals.iter().rev().find(|d| d.variadic)
}
fn unknown_long_error(spec: &ArgsSpec, name: &str) -> String {
let suggestion = spec
.options
.iter()
.filter(|o| similar(&o.long, name))
.map(|o| format!("--{}", o.long))
.next();
match suggestion {
Some(s) => format!("unknown flag `--{name}`, did you mean `{s}`?"),
None => format!("unknown flag `--{name}`"),
}
}
fn similar(candidate: &str, target: &str) -> bool {
if candidate == target {
return false;
}
levenshtein(candidate, target) <= 2
}
fn levenshtein(a: &str, b: &str) -> usize {
let a: Vec<char> = a.chars().collect();
let b: Vec<char> = b.chars().collect();
let (m, n) = (a.len(), b.len());
if m == 0 {
return n;
}
if n == 0 {
return m;
}
let mut prev: Vec<usize> = (0..=n).collect();
let mut curr = vec![0usize; n + 1];
for (i, ca) in a.iter().enumerate() {
curr[0] = i + 1;
for (j, cb) in b.iter().enumerate() {
let cost = if ca == cb { 0 } else { 1 };
curr[j + 1] = (prev[j + 1] + 1)
.min(curr[j] + 1)
.min(prev[j] + cost);
}
std::mem::swap(&mut prev, &mut curr);
}
prev[n]
}
#[cfg(test)]
mod tests {
use super::*;
fn flag(long: &str, short: Option<char>) -> OptionDecl {
OptionDecl {
long: long.into(),
short,
takes_value: false,
allows_bare: true,
repeatable: false,
choices: None,
}
}
fn value(long: &str, short: Option<char>, bare_ok: bool) -> OptionDecl {
OptionDecl {
long: long.into(),
short,
takes_value: true,
allows_bare: bare_ok,
repeatable: false,
choices: None,
}
}
fn list(long: &str, short: Option<char>) -> OptionDecl {
OptionDecl {
long: long.into(),
short,
takes_value: true,
allows_bare: false,
repeatable: true,
choices: None,
}
}
fn pos(name: &str, required: bool, variadic: bool) -> PositionalDecl {
PositionalDecl {
name: name.into(),
required,
variadic,
choices: None,
}
}
fn argv(parts: &[&str]) -> Vec<String> {
std::iter::once("prog")
.chain(parts.iter().copied())
.map(String::from)
.collect()
}
#[test]
fn parses_long_flag_with_value() {
let spec = ArgsSpec {
options: vec![value("model", None, false)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--model", "mlp"])).unwrap();
match out.options.get("model") {
Some(OptionState::WithValues(v)) => assert_eq!(v, &vec!["mlp".to_string()]),
other => panic!("expected WithValues, got {:?}", other),
}
}
#[test]
fn parses_long_flag_with_equals() {
let spec = ArgsSpec {
options: vec![value("model", None, false)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--model=mlp"])).unwrap();
match out.options.get("model") {
Some(OptionState::WithValues(v)) => assert_eq!(v, &vec!["mlp".to_string()]),
_ => panic!("expected WithValues"),
}
}
#[test]
fn bare_flag_without_default_errors() {
let spec = ArgsSpec {
options: vec![value("report", None, false)],
positionals: vec![],
..ArgsSpec::default()
};
let err = parse(&spec, &argv(&["--report"])).unwrap_err();
assert!(err.contains("requires a value"), "got: {err}");
}
#[test]
fn bare_flag_with_default_is_present() {
let spec = ArgsSpec {
options: vec![value("report", None, true)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--report"])).unwrap();
assert!(matches!(out.options.get("report"), Some(OptionState::BarePresent)));
}
#[test]
fn bool_flag_presence() {
let spec = ArgsSpec {
options: vec![flag("validate", None)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--validate"])).unwrap();
assert!(matches!(out.options.get("validate"), Some(OptionState::BarePresent)));
}
#[test]
fn bool_flag_rejects_value() {
let spec = ArgsSpec {
options: vec![flag("validate", None)],
positionals: vec![],
..ArgsSpec::default()
};
let err = parse(&spec, &argv(&["--validate=yes"])).unwrap_err();
assert!(err.contains("takes no value"), "got: {err}");
}
#[test]
fn short_flag() {
let spec = ArgsSpec {
options: vec![flag("verbose", Some('v'))],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["-v"])).unwrap();
assert!(matches!(out.options.get("verbose"), Some(OptionState::BarePresent)));
}
#[test]
fn short_clustering_for_bool_flags() {
let spec = ArgsSpec {
options: vec![flag("a", Some('a')), flag("b", Some('b'))],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["-ab"])).unwrap();
assert!(out.options.contains_key("a"));
assert!(out.options.contains_key("b"));
}
#[test]
fn short_cluster_last_may_take_value() {
let spec = ArgsSpec {
options: vec![flag("a", Some('a')), value("model", Some('m'), false)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["-am", "mlp"])).unwrap();
assert!(out.options.contains_key("a"));
match out.options.get("model") {
Some(OptionState::WithValues(v)) => assert_eq!(v, &vec!["mlp".to_string()]),
_ => panic!("expected model value"),
}
}
#[test]
fn list_option_accumulates_across_repeats_and_commas() {
let spec = ArgsSpec {
options: vec![list("tags", Some('t'))],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--tags", "a,b", "-t", "c"])).unwrap();
match out.options.get("tags") {
Some(OptionState::WithValues(v)) => {
assert_eq!(v, &vec!["a".to_string(), "b".into(), "c".into()]);
}
_ => panic!("expected list values"),
}
}
#[test]
fn positionals_in_order() {
let spec = ArgsSpec {
options: vec![],
positionals: vec![pos("first", true, false), pos("second", false, false)],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["a", "b"])).unwrap();
assert_eq!(out.positionals, vec!["a".to_string(), "b".into()]);
}
#[test]
fn missing_required_positional_errors() {
let spec = ArgsSpec {
options: vec![],
positionals: vec![pos("first", true, false)],
..ArgsSpec::default()
};
let err = parse(&spec, &argv(&[])).unwrap_err();
assert!(err.contains("missing required argument"), "got: {err}");
}
#[test]
fn variadic_positional_absorbs_tail() {
let spec = ArgsSpec {
options: vec![],
positionals: vec![pos("files", false, true)],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["a", "b", "c"])).unwrap();
assert_eq!(out.positionals, vec!["a".to_string(), "b".into(), "c".into()]);
}
#[test]
fn double_dash_stops_flag_parsing() {
let spec = ArgsSpec {
options: vec![flag("verbose", None)],
positionals: vec![pos("rest", false, true)],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--", "--verbose", "-x"])).unwrap();
assert!(!out.options.contains_key("verbose"));
assert_eq!(out.positionals, vec!["--verbose".to_string(), "-x".into()]);
}
#[test]
fn unknown_flag_suggests_similar() {
let spec = ArgsSpec {
options: vec![value("model", None, false)],
positionals: vec![],
..ArgsSpec::default()
};
let err = parse(&spec, &argv(&["--modl", "mlp"])).unwrap_err();
assert!(err.contains("did you mean"), "got: {err}");
}
#[test]
fn choices_validated_at_parse_time() {
let mut model = value("model", None, false);
model.choices = Some(vec!["mlp".into(), "lenet".into()]);
let spec = ArgsSpec {
options: vec![model],
positionals: vec![],
..ArgsSpec::default()
};
let err = parse(&spec, &argv(&["--model", "foobar"])).unwrap_err();
assert!(err.contains("allowed"), "got: {err}");
}
#[test]
fn bare_dash_is_positional() {
let spec = ArgsSpec {
options: vec![],
positionals: vec![pos("target", true, false)],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["-"])).unwrap();
assert_eq!(out.positionals, vec!["-".to_string()]);
}
#[test]
fn scalar_last_write_wins() {
let spec = ArgsSpec {
options: vec![value("model", None, false)],
positionals: vec![],
..ArgsSpec::default()
};
let out = parse(&spec, &argv(&["--model", "a", "--model", "b"])).unwrap();
match out.options.get("model") {
Some(OptionState::WithValues(v)) => assert_eq!(v, &vec!["b".to_string()]),
_ => panic!("expected last-write-wins"),
}
}
}