use std::ffi::{OsStr, OsString};
use std::fmt;
use crate::options::error::{Choices, OptionsError};
/// A short argument is a single byte; it is displayed as `-` followed by
/// that byte interpreted as a character.
pub type ShortArg = u8;
/// A long argument is a static string; it is displayed as `--` followed by
/// the string.
pub type LongArg = &'static str;
/// A static list of strings: the values associated with an argument, used
/// for error suggestions and for recognising optional values.
pub type Values = &'static [&'static str];
/// A flag as it was recorded by the parser: either the short form or the
/// long form of some argument.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum Flag {
/// The short form, holding the single byte that followed the dash.
Short(ShortArg),
/// The long form, holding the argument's long name (without the dashes).
Long(LongArg),
}
impl Flag {
    /// Reports whether this flag is a spelling of the given argument.
    ///
    /// A short flag matches when the argument has that exact short byte; a
    /// long flag matches when the argument's long name is equal to it.
    pub fn matches(&self, arg: &Arg) -> bool {
        match *self {
            Self::Long(name) => name == arg.long,
            Self::Short(letter) => arg.short.map_or(false, |expected| expected == letter),
        }
    }
}
impl fmt::Display for Flag {
    /// Writes the flag with its dashes: `-x` for a short flag, `--name` for a
    /// long one.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match *self {
            Self::Long(name) => write!(f, "--{}", name),
            Self::Short(letter) => write!(f, "-{}", char::from(letter)),
        }
    }
}
/// How `MatchedFlags` lookups treat a flag that was matched more than once.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum Strictness {
/// Lookups return a `Duplicate` error when two or more matched flags
/// satisfy the query.
ComplainAboutRedundantArguments,
/// Lookups search from the end of the matched flags, so the last
/// occurrence wins.
UseLastArguments,
}
/// Whether an argument takes a value, and how the parser should obtain it.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum TakesValue {
/// The flag must be given a value; parsing fails with `NeedsValue` when
/// none is available. The optional list is carried into that error.
Necessary(Option<Values>),
/// The flag never takes a value; supplying one with `=` fails with
/// `ForbiddenValue`.
Forbidden,
/// The flag may be given a value. A value found in the following input (or
/// attached directly to a short flag) is only accepted when it appears in
/// the list; otherwise the second field is recorded as the default value.
Optional(Option<Values>, &'static str),
}
/// The definition of one argument that the parser can recognise.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct Arg {
/// The short form of the argument, if it has one.
pub short: Option<ShortArg>,
/// The long form of the argument (without the leading dashes).
pub long: LongArg,
/// Whether this argument takes a value.
pub takes_value: TakesValue,
}
impl fmt::Display for Arg {
    /// Writes the long form, followed by the short form in parentheses when
    /// the argument has one, such as `--long (-l)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self.short {
            None => write!(f, "--{}", self.long),
            Some(letter) => write!(f, "--{} (-{})", self.long, char::from(letter)),
        }
    }
}
/// The static list of argument definitions that `parse` looks flags up in.
#[derive(PartialEq, Eq, Debug)]
pub struct Args(pub &'static [&'static Arg]);
impl Args {
    /// Walks over the given command-line inputs, sorting each one into either
    /// a matched flag (with its value, if it has one) or a free string.
    ///
    /// The recognised shapes are:
    ///
    /// - `--`: stops flag parsing; every later input is a free string.
    /// - `--name`, `--name=value`, `--name value`: long arguments.
    /// - `-a`, `-abc`, `-a=value`, `-avalue`, `-a value`: short arguments,
    ///   possibly clustered behind a single dash.
    /// - anything else, including a lone `-`: a free string.
    ///
    /// Flags are recorded in the order they appear, together with the
    /// requested `strictness`, which the resulting `MatchedFlags` consults
    /// when it is queried.
    ///
    /// # Errors
    ///
    /// - `UnknownArgument` / `UnknownShortArgument` when a flag is not in
    ///   this list of arguments.
    /// - `NeedsValue` when a flag that requires a value is not given one.
    /// - `ForbiddenValue` when a value is supplied with `=` to a flag that
    ///   takes none, or when the text attached directly to a short flag with
    ///   an optional value is not one of its accepted values.
    pub fn parse<'args, I>(
        &self,
        inputs: I,
        strictness: Strictness,
    ) -> Result<Matches<'args>, ParseError>
    where
        I: IntoIterator<Item = &'args OsStr>,
    {
        // Becomes false once a bare `--` has been seen.
        let mut parsing = true;

        let mut result_flags = Vec::new();
        let mut frees: Vec<&OsStr> = Vec::new();

        // Peekable, so that the input following a flag with an optional value
        // can be inspected before deciding whether to consume it.
        let mut inputs = inputs.into_iter().peekable();
        while let Some(arg) = inputs.next() {
            let bytes = os_str_to_bytes(arg);

            if !parsing {
                frees.push(arg);
            } else if arg == "--" {
                parsing = false;
            } else if bytes.starts_with(b"--") {
                let long_arg_name = bytes_to_os_str(&bytes[2..]);

                if let Some((before, after)) = split_on_equals(long_arg_name) {
                    // `--name=value`: the value is taken verbatim.
                    let arg = self.lookup_long(before)?;
                    let flag = Flag::Long(arg.long);
                    match arg.takes_value {
                        TakesValue::Necessary(_) | TakesValue::Optional(_, _) => {
                            result_flags.push((flag, Some(after)));
                        }
                        TakesValue::Forbidden => return Err(ParseError::ForbiddenValue { flag }),
                    }
                } else {
                    // `--name`: any value has to come from the next input.
                    let arg = self.lookup_long(long_arg_name)?;
                    let flag = Flag::Long(arg.long);
                    match arg.takes_value {
                        TakesValue::Forbidden => result_flags.push((flag, None)),
                        TakesValue::Necessary(values) => {
                            if let Some(next_arg) = inputs.next() {
                                result_flags.push((flag, Some(next_arg)));
                            } else {
                                return Err(ParseError::NeedsValue { flag, values });
                            }
                        }
                        TakesValue::Optional(values, default) => match inputs.peek() {
                            Some(peeked) if is_optional_arg(peeked, values) => {
                                let next_arg = inputs
                                    .next()
                                    .expect("an input that was just peeked must still be present");
                                result_flags.push((flag, Some(next_arg)));
                            }
                            // Leave the next input alone and fall back to the default.
                            _ => result_flags.push((flag, Some(OsStr::new(default)))),
                        },
                    }
                }
            } else if bytes.starts_with(b"-") && arg != "-" {
                let short_arg = bytes_to_os_str(&bytes[1..]);

                if let Some((before, after)) = split_on_equals(short_arg) {
                    // `-abc=value`: only the last flag of the cluster receives
                    // the value; the ones before it must work without one.
                    let (arg_with_value, other_args) = os_str_to_bytes(before)
                        .split_last()
                        .expect("split_on_equals never returns an empty left-hand side");

                    for byte in other_args {
                        let arg = self.lookup_short(*byte)?;
                        let flag = Flag::Short(*byte);
                        match arg.takes_value {
                            TakesValue::Forbidden => result_flags.push((flag, None)),
                            TakesValue::Optional(_, default) => {
                                result_flags.push((flag, Some(OsStr::new(default))));
                            }
                            TakesValue::Necessary(values) => {
                                return Err(ParseError::NeedsValue { flag, values });
                            }
                        }
                    }

                    let arg = self.lookup_short(*arg_with_value)?;
                    let flag = Flag::Short(*arg_with_value);
                    match arg.takes_value {
                        TakesValue::Necessary(_) | TakesValue::Optional(_, _) => {
                            result_flags.push((flag, Some(after)));
                        }
                        TakesValue::Forbidden => {
                            return Err(ParseError::ForbiddenValue { flag });
                        }
                    }
                } else {
                    // `-abc`: each byte after the dash is a flag, until one of
                    // them swallows the rest of the cluster as its value.
                    for (index, byte) in bytes.iter().enumerate().skip(1) {
                        let arg = self.lookup_short(*byte)?;
                        let flag = Flag::Short(*byte);

                        // Whatever follows this flag inside the same input.
                        let remnants = &bytes[index + 1..];

                        match arg.takes_value {
                            TakesValue::Forbidden => result_flags.push((flag, None)),
                            TakesValue::Necessary(values) => {
                                if !remnants.is_empty() {
                                    result_flags.push((flag, Some(bytes_to_os_str(remnants))));
                                    break;
                                } else if let Some(next_arg) = inputs.next() {
                                    result_flags.push((flag, Some(next_arg)));
                                } else {
                                    return Err(ParseError::NeedsValue { flag, values });
                                }
                            }
                            TakesValue::Optional(values, default) => {
                                if !remnants.is_empty() {
                                    // Attached text must be an accepted value;
                                    // it is never read as further flags.
                                    let attached = bytes_to_os_str(remnants);
                                    if !is_optional_arg(attached, values) {
                                        return Err(ParseError::ForbiddenValue { flag });
                                    }
                                    result_flags.push((flag, Some(attached)));
                                    break;
                                }

                                match inputs.peek() {
                                    Some(peeked) if is_optional_arg(peeked, values) => {
                                        let next_arg = inputs.next().expect(
                                            "an input that was just peeked must still be present",
                                        );
                                        result_flags.push((flag, Some(next_arg)));
                                    }
                                    _ => result_flags.push((flag, Some(OsStr::new(default)))),
                                }
                            }
                        }
                    }
                }
            } else {
                // Not a flag at all (this includes a lone `-`).
                frees.push(arg);
            }
        }

        Ok(Matches {
            frees,
            flags: MatchedFlags {
                flags: result_flags,
                strictness,
            },
        })
    }

    /// Finds the argument whose short form is the given byte, or returns an
    /// `UnknownShortArgument` error carrying that byte.
    fn lookup_short(&self, short: ShortArg) -> Result<&Arg, ParseError> {
        self.0
            .iter()
            .find(|arg| arg.short == Some(short))
            .copied()
            .ok_or(ParseError::UnknownShortArgument { attempt: short })
    }

    /// Finds the argument whose long form equals the given string, or returns
    /// an `UnknownArgument` error carrying an owned copy of that string.
    fn lookup_long(&self, long: &OsStr) -> Result<&Arg, ParseError> {
        self.0
            .iter()
            .find(|arg| arg.long == long)
            .copied()
            .ok_or_else(|| ParseError::UnknownArgument {
                attempt: long.to_os_string(),
            })
    }
}
/// Reports whether `value` is one of the strings in `values`.
///
/// Returns `false` when there is no list to check against, or when `value`
/// is not valid Unicode.
fn is_optional_arg(value: &OsStr, values: Option<&[&str]>) -> bool {
    match value.to_str() {
        Some(candidate) => values.map_or(false, |allowed| allowed.contains(&candidate)),
        None => false,
    }
}
/// The result of a successful parse: the flags that were matched and the
/// inputs that were not flags.
#[derive(PartialEq, Eq, Debug)]
pub struct Matches<'args> {
/// The flags that were parsed from the inputs.
pub flags: MatchedFlags<'args>,
/// The inputs that were not flags or flag values, in input order.
pub frees: Vec<&'args OsStr>,
}
/// The flags recorded by the parser, together with the strictness that
/// governs how repeated flags are treated when they are queried.
#[derive(PartialEq, Eq, Debug)]
pub struct MatchedFlags<'args> {
// Each matched flag paired with its value, if it was given one.
flags: Vec<(Flag, Option<&'args OsStr>)>,
// Whether duplicate matches are an error or the last one wins.
strictness: Strictness,
}
impl MatchedFlags<'_> {
    /// Reports whether the given argument was matched as a flag without a
    /// value.
    ///
    /// In strict mode, returns a `Duplicate` error when it was matched more
    /// than once.
    pub fn has(&self, arg: &'static Arg) -> Result<bool, OptionsError> {
        let found = self.has_where(|flag| flag.matches(arg))?;
        Ok(found.is_some())
    }

    /// Looks for a value-less flag that satisfies the predicate.
    ///
    /// In strict mode, two or more such flags produce a `Duplicate` error
    /// naming the first two; otherwise the last such flag is returned.
    pub fn has_where<P>(&self, predicate: P) -> Result<Option<&Flag>, OptionsError>
    where
        P: Fn(&Flag) -> bool,
    {
        if !self.is_strict() {
            return Ok(self.has_where_any(predicate));
        }

        let hits: Vec<_> = self
            .flags
            .iter()
            .filter(|(flag, value)| value.is_none() && predicate(flag))
            .collect();

        match hits.as_slice() {
            [first, second, ..] => Err(OptionsError::Duplicate(first.0, second.0)),
            at_most_one => Ok(at_most_one.first().copied().map(|entry| &entry.0)),
        }
    }

    /// Returns the last value-less flag that satisfies the predicate,
    /// regardless of strictness.
    pub fn has_where_any<P>(&self, predicate: P) -> Option<&Flag>
    where
        P: Fn(&Flag) -> bool,
    {
        self.flags
            .iter()
            .rev()
            .find(|(flag, value)| value.is_none() && predicate(flag))
            .map(|(flag, _)| flag)
    }

    /// Returns the value that the given argument was matched with, if any.
    ///
    /// In strict mode, returns a `Duplicate` error when it was given a value
    /// more than once.
    pub fn get(&self, arg: &'static Arg) -> Result<Option<&OsStr>, OptionsError> {
        self.get_where(|flag| flag.matches(arg))
    }

    /// Looks for the value of a flag that satisfies the predicate; flags
    /// without a value are ignored.
    ///
    /// In strict mode, two or more such flags produce a `Duplicate` error
    /// naming the first two; otherwise the value of the last one is returned.
    pub fn get_where<P>(&self, predicate: P) -> Result<Option<&OsStr>, OptionsError>
    where
        P: Fn(&Flag) -> bool,
    {
        if !self.is_strict() {
            let last_value = self
                .flags
                .iter()
                .rev()
                .find(|(flag, value)| value.is_some() && predicate(flag))
                .and_then(|(_, value)| *value);
            return Ok(last_value);
        }

        let valued: Vec<_> = self
            .flags
            .iter()
            .filter(|(flag, value)| value.is_some() && predicate(flag))
            .collect();

        match valued.as_slice() {
            [first, second, ..] => Err(OptionsError::Duplicate(first.0, second.0)),
            at_most_one => Ok(at_most_one.first().copied().and_then(|(_, value)| *value)),
        }
    }

    /// Counts how many matched flags, with or without a value, are a
    /// spelling of the given argument.
    pub fn count(&self, arg: &Arg) -> usize {
        self.flags
            .iter()
            .filter(|(flag, _)| flag.matches(arg))
            .count()
    }

    /// Reports whether repeated flags are treated as an error.
    pub fn is_strict(&self) -> bool {
        matches!(
            self.strictness,
            Strictness::ComplainAboutRedundantArguments
        )
    }
}
/// The ways in which parsing the command-line inputs can fail.
#[derive(PartialEq, Eq, Debug)]
pub enum ParseError {
/// A flag that requires a value was not given one. Carries the flag and
/// the list of values from its definition, if it has one.
NeedsValue { flag: Flag, values: Option<Values> },
/// A flag was given a value that it cannot take.
ForbiddenValue { flag: Flag },
/// A short argument byte did not match any known argument.
UnknownShortArgument { attempt: ShortArg },
/// A long argument name did not match any known argument.
UnknownArgument { attempt: OsString },
}
impl fmt::Display for ParseError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnknownArgument { attempt } => {
                // The attempted name may not be valid Unicode, so it is shown lossily.
                write!(f, "Unknown argument --{}", attempt.to_string_lossy())
            }
            Self::UnknownShortArgument { attempt } => {
                write!(f, "Unknown argument -{}", char::from(*attempt))
            }
            Self::ForbiddenValue { flag } => write!(f, "Flag {flag} cannot take a value"),
            Self::NeedsValue { flag, values } => match values {
                Some(cs) => write!(f, "Flag {flag} needs a value ({})", Choices(cs)),
                None => write!(f, "Flag {flag} needs a value"),
            },
        }
    }
}
/// Views an OS string as its underlying bytes (Unix implementation).
#[cfg(unix)]
fn os_str_to_bytes(s: &OsStr) -> &[u8] {
    std::os::unix::ffi::OsStrExt::as_bytes(s)
}
/// Views a byte slice as an OS string (Unix implementation).
#[cfg(unix)]
fn bytes_to_os_str(b: &[u8]) -> &OsStr {
    <OsStr as std::os::unix::ffi::OsStrExt>::from_bytes(b)
}
/// Views an OS string as UTF-8 bytes (Windows implementation).
///
/// # Panics
///
/// Panics when the string is not valid Unicode.
#[cfg(windows)]
fn os_str_to_bytes(s: &OsStr) -> &[u8] {
    let utf8: &str = s.to_str().unwrap();
    utf8.as_bytes()
}
/// Views UTF-8 bytes as an OS string (Windows implementation).
///
/// # Panics
///
/// Panics when the bytes are not valid UTF-8.
#[cfg(windows)]
fn bytes_to_os_str(b: &[u8]) -> &OsStr {
    let utf8 = std::str::from_utf8(b).unwrap();
    OsStr::new(utf8)
}
/// Splits a string around its first `=` sign, returning the parts before
/// and after it.
///
/// Returns `None` when there is no `=`, or when either side of it would be
/// empty.
fn split_on_equals(input: &OsStr) -> Option<(&OsStr, &OsStr)> {
    let bytes = os_str_to_bytes(input);
    let equals_at = bytes.iter().position(|&byte| byte == b'=')?;

    let name = &bytes[..equals_at];
    let value = &bytes[equals_at + 1..];
    if name.is_empty() || value.is_empty() {
        return None;
    }

    Some((bytes_to_os_str(name), bytes_to_os_str(value)))
}
// Tests for `split_on_equals`.
#[cfg(test)]
mod split_test {
use super::split_on_equals;
use std::ffi::{OsStr, OsString};
// Generates one `#[test]` function per invocation. The first form asserts
// that the input does not split; the second asserts the exact two halves.
macro_rules! test_split {
($name:ident: $input:expr => None) => {
#[test]
fn $name() {
assert_eq!(split_on_equals(&OsString::from($input)), None);
}
};
($name:ident: $input:expr => $before:expr, $after:expr) => {
#[test]
fn $name() {
assert_eq!(
split_on_equals(&OsString::from($input)),
Some((OsStr::new($before), OsStr::new($after)))
);
}
};
}
// Inputs with no `=`, or with nothing on one side of it, do not split.
test_split!(empty: "" => None);
test_split!(letter: "a" => None);
test_split!(just: "=" => None);
test_split!(intro: "=bbb" => None);
test_split!(denou: "aaa=" => None);
// Only the first `=` is used as the split point.
test_split!(equals: "aaa=bbb" => "aaa", "bbb");
test_split!(sort: "--sort=size" => "--sort", "size");
test_split!(more: "this=that=other" => "this", "that=other");
}
// Tests for `Args::parse`, run against the small argument table below.
#[cfg(test)]
mod parse_test {
use super::*;
// Generates one `#[test]` function per invocation. The first form asserts
// the exact free strings and flags of a successful parse; the second
// asserts the exact error. Both parse with `Strictness::UseLastArguments`.
macro_rules! test {
($name:ident: $inputs:expr => frees: $frees:expr, flags: $flags:expr) => {
#[test]
fn $name() {
let inputs: &[&'static str] = $inputs.as_ref();
let inputs = inputs.iter().map(OsStr::new);
let frees: &[&'static str] = $frees.as_ref();
let frees = frees.iter().map(OsStr::new).collect();
let flags = <[_]>::into_vec(Box::new($flags));
let strictness = Strictness::UseLastArguments; let got = Args(TEST_ARGS).parse(inputs, strictness);
let flags = MatchedFlags { flags, strictness };
let expected = Ok(Matches { frees, flags });
assert_eq!(got, expected);
}
};
($name:ident: $inputs:expr => error $error:expr) => {
#[test]
fn $name() {
use self::ParseError::*;
let inputs = $inputs.iter().map(OsStr::new);
let strictness = Strictness::UseLastArguments; let got = Args(TEST_ARGS).parse(inputs, strictness);
assert_eq!(got, Err($error));
}
};
}
// The value list attached to the `--type` test argument.
const SUGGESTIONS: Values = &["example"];
// The arguments known to the parser under test: two that take no value, two
// that require one (with and without a value list), and one optional.
#[rustfmt::skip]
static TEST_ARGS: &[&Arg] = &[
&Arg { short: Some(b'l'), long: "long", takes_value: TakesValue::Forbidden },
&Arg { short: Some(b'v'), long: "verbose", takes_value: TakesValue::Forbidden },
&Arg { short: Some(b'c'), long: "count", takes_value: TakesValue::Necessary(None) },
&Arg { short: Some(b't'), long: "type", takes_value: TakesValue::Necessary(Some(SUGGESTIONS))},
&Arg { short: Some(b'o'), long: "optional", takes_value: TakesValue::Optional(Some(&["all", "some", "none"]), "all")}
];
// Free strings, and the `--` separator.
test!(empty: [] => frees: [], flags: []);
test!(one_arg: ["exa"] => frees: [ "exa" ], flags: []);
test!(one_dash: ["-"] => frees: [ "-" ], flags: []);
test!(two_dashes: ["--"] => frees: [], flags: []);
test!(two_file: ["--", "file"] => frees: [ "file" ], flags: []);
test!(two_arg_l: ["--", "--long"] => frees: [ "--long" ], flags: []);
test!(two_arg_s: ["--", "-l"] => frees: [ "-l" ], flags: []);
// Long arguments that take no value.
test!(long: ["--long"] => frees: [], flags: [ (Flag::Long("long"), None) ]);
test!(long_then: ["--long", "4"] => frees: [ "4" ], flags: [ (Flag::Long("long"), None) ]);
test!(long_two: ["--long", "--verbose"] => frees: [], flags: [ (Flag::Long("long"), None), (Flag::Long("verbose"), None) ]);
// Long arguments that require a value.
test!(bad_equals: ["--long=equals"] => error ForbiddenValue { flag: Flag::Long("long") });
test!(no_arg: ["--count"] => error NeedsValue { flag: Flag::Long("count"), values: None });
test!(arg_equals: ["--count=4"] => frees: [], flags: [ (Flag::Long("count"), Some(OsStr::new("4"))) ]);
test!(arg_then: ["--count", "4"] => frees: [], flags: [ (Flag::Long("count"), Some(OsStr::new("4"))) ]);
// Long arguments that require a value and carry a value list.
test!(no_arg_s: ["--type"] => error NeedsValue { flag: Flag::Long("type"), values: Some(SUGGESTIONS) });
test!(arg_equals_s: ["--type=exa"] => frees: [], flags: [ (Flag::Long("type"), Some(OsStr::new("exa"))) ]);
test!(arg_then_s: ["--type", "exa"] => frees: [], flags: [ (Flag::Long("type"), Some(OsStr::new("exa"))) ]);
// Short arguments that take no value.
test!(short: ["-l"] => frees: [], flags: [ (Flag::Short(b'l'), None) ]);
test!(short_then: ["-l", "4"] => frees: [ "4" ], flags: [ (Flag::Short(b'l'), None) ]);
test!(short_two: ["-lv"] => frees: [], flags: [ (Flag::Short(b'l'), None), (Flag::Short(b'v'), None) ]);
test!(mixed: ["-v", "--long"] => frees: [], flags: [ (Flag::Short(b'v'), None), (Flag::Long("long"), None) ]);
// Short arguments that require a value.
test!(bad_short: ["-l=equals"] => error ForbiddenValue { flag: Flag::Short(b'l') });
test!(short_none: ["-c"] => error NeedsValue { flag: Flag::Short(b'c'), values: None });
test!(short_arg_eq: ["-c=4"] => frees: [], flags: [(Flag::Short(b'c'), Some(OsStr::new("4"))) ]);
test!(short_arg_then: ["-c", "4"] => frees: [], flags: [(Flag::Short(b'c'), Some(OsStr::new("4"))) ]);
test!(short_two_together: ["-lctwo"] => frees: [], flags: [(Flag::Short(b'l'), None), (Flag::Short(b'c'), Some(OsStr::new("two"))) ]);
test!(short_two_equals: ["-lc=two"] => frees: [], flags: [(Flag::Short(b'l'), None), (Flag::Short(b'c'), Some(OsStr::new("two"))) ]);
test!(short_two_next: ["-lc", "two"] => frees: [], flags: [(Flag::Short(b'l'), None), (Flag::Short(b'c'), Some(OsStr::new("two"))) ]);
// Short arguments that require a value and carry a value list.
test!(short_none_s: ["-t"] => error NeedsValue { flag: Flag::Short(b't'), values: Some(SUGGESTIONS) });
test!(short_two_together_s: ["-texa"] => frees: [], flags: [(Flag::Short(b't'), Some(OsStr::new("exa"))) ]);
test!(short_two_equals_s: ["-t=exa"] => frees: [], flags: [(Flag::Short(b't'), Some(OsStr::new("exa"))) ]);
test!(short_two_next_s: ["-t", "exa"] => frees: [], flags: [(Flag::Short(b't'), Some(OsStr::new("exa"))) ]);
// Arguments that are not in the table.
test!(unknown_long: ["--quiet"] => error UnknownArgument { attempt: OsString::from("quiet") });
test!(unknown_long_eq: ["--quiet=shhh"] => error UnknownArgument { attempt: OsString::from("quiet") });
test!(unknown_short: ["-q"] => error UnknownShortArgument { attempt: b'q' });
test!(unknown_short_2nd: ["-lq"] => error UnknownShortArgument { attempt: b'q' });
test!(unknown_short_eq: ["-q=shhh"] => error UnknownShortArgument { attempt: b'q' });
test!(unknown_short_2nd_eq: ["-lq=shhh"] => error UnknownShortArgument { attempt: b'q' });
// Arguments with an optional value: a following input is only consumed when
// it is one of the accepted values; otherwise the default is recorded.
test!(optional: ["--optional"] => frees: [], flags: [(Flag::Long("optional"), Some(OsStr::new("all")))]);
test!(optional_2: ["--optional", "-l"] => frees: [], flags: [ (Flag::Long("optional"), Some(OsStr::new("all"))), (Flag::Short(b'l'), None)]);
test!(optional_3: ["--optional", "path"] => frees: ["path"], flags: [(Flag::Long("optional"), Some(OsStr::new("all")))]);
test!(optional_with_eq: ["--optional=none"] => frees: [], flags: [(Flag::Long("optional"), Some(OsStr::new("none")))]);
test!(optional_wo_eq: ["--optional", "none"] => frees: [], flags: [(Flag::Long("optional"), Some(OsStr::new("none")))]);
test!(short_opt: ["-o"] => frees: [], flags: [(Flag::Short(b'o'), Some(OsStr::new("all")))]);
test!(short_opt_value: ["-onone"] => frees: [], flags: [(Flag::Short(b'o'), Some(OsStr::new("none")))]);
test!(short_forbidden: ["-opath"] => error ForbiddenValue { flag: Flag::Short(b'o') });
test!(short_allowed: ["-o","path"] => frees: ["path"], flags: [(Flag::Short(b'o'), Some(OsStr::new("all")))]);
}
// Tests for querying `MatchedFlags` with `has` and `get`.
#[cfg(test)]
mod matches_test {
use super::*;
// Generates a `#[test]` function that builds a `MatchedFlags` from the given
// flags, using `Strictness::UseLastArguments`, and asserts the result of
// `has` for the given argument.
macro_rules! test {
($name:ident: $input:expr, has $param:expr => $result:expr) => {
#[test]
fn $name() {
let flags = MatchedFlags {
flags: $input.to_vec(),
strictness: Strictness::UseLastArguments,
};
assert_eq!(flags.has(&$param), Ok($result));
}
};
}
// An argument that takes no value.
static VERBOSE: Arg = Arg {
short: Some(b'v'),
long: "verbose",
takes_value: TakesValue::Forbidden,
};
// An argument that requires a value.
static COUNT: Arg = Arg {
short: Some(b'c'),
long: "count",
takes_value: TakesValue::Necessary(None),
};
// In non-strict mode, repeated flags are still simply "present".
test!(short_never: [], has VERBOSE => false);
test!(short_once: [(Flag::Short(b'v'), None)], has VERBOSE => true);
test!(short_twice: [(Flag::Short(b'v'), None), (Flag::Short(b'v'), None)], has VERBOSE => true);
test!(long_once: [(Flag::Long("verbose"), None)], has VERBOSE => true);
test!(long_twice: [(Flag::Long("verbose"), None), (Flag::Long("verbose"), None)], has VERBOSE => true);
test!(long_mixed: [(Flag::Long("verbose"), None), (Flag::Short(b'v'), None)], has VERBOSE => true);
// A single flag with a value is returned by `get`.
#[test]
fn only_count() {
let everything = OsString::from("everything");
let flags = MatchedFlags {
flags: vec![(Flag::Short(b'c'), Some(&*everything))],
strictness: Strictness::UseLastArguments,
};
assert_eq!(flags.get(&COUNT), Ok(Some(&*everything)));
}
// When a flag is given twice in non-strict mode, the last value wins.
#[test]
fn rightmost_count() {
let everything = OsString::from("everything");
let nothing = OsString::from("nothing");
let flags = MatchedFlags {
flags: vec![
(Flag::Short(b'c'), Some(&*everything)),
(Flag::Short(b'c'), Some(&*nothing)),
],
strictness: Strictness::UseLastArguments,
};
assert_eq!(flags.get(&COUNT), Ok(Some(&*nothing)));
}
// With no flags recorded at all, `has` reports false.
#[test]
fn no_count() {
let flags = MatchedFlags {
flags: Vec::new(),
strictness: Strictness::UseLastArguments,
};
assert!(!flags.has(&COUNT).unwrap());
}
}