use std::{
error::Error as StdError,
fmt::{self, Debug},
io::{self, Write},
result,
str::FromStr,
};
use ahash::AHashMap;
use regex::Captures;
use serde::{de, de::IntoDeserializer};
use self::{
Error::{Argv, Deserialize, Help, NoMatch, Usage, Version, WithProgramUsage},
Value::{Counted, List, Plain, Switch},
};
use crate::{cap_or_empty, parse::Parser, synonym::SynonymMap};
/// Everything that can go wrong while parsing a usage string, parsing argv,
/// or deserializing matched values, plus the two non-failure "errors"
/// (`Help` and `Version`) used to short-circuit normal processing.
#[derive(Debug)]
pub enum Error {
/// The usage string was rejected by the parser; carries the parser's message.
Usage(String),
/// The argv could not be parsed; carries the parser's message.
Argv(String),
/// The argv parsed, but did not match any usage pattern.
NoMatch,
/// A matched value could not be deserialized into the requested type.
Deserialize(String),
/// Another error paired with usage/help text that is displayed after it.
WithProgramUsage(Box<Error>, String),
/// `--help` was present in argv while help handling was enabled.
Help,
/// `--version` was present in argv and a version string was configured;
/// carries that version string.
Version(String),
}
impl Error {
    /// Reports whether this error represents a genuine failure.
    ///
    /// `Help` and `Version` are informational and yield `false`; every other
    /// leaf variant yields `true`. `WithProgramUsage` defers to the error it
    /// wraps.
    #[must_use]
    pub fn fatal(&self) -> bool {
        match self {
            WithProgramUsage(inner, _) => inner.fatal(),
            Usage(..) | Argv(..) | NoMatch | Deserialize(..) => true,
            Help | Version(..) => false,
        }
    }

    /// Prints this error and terminates the process.
    ///
    /// Fatal errors are reported through `werr!` and exit with status `1`;
    /// non-fatal ones are written to stdout (write failures are ignored) and
    /// exit with status `0`.
    pub fn exit(&self) -> ! {
        let fatal = self.fatal();
        if fatal {
            werr!("{}\n", self);
        } else {
            let _ = writeln!(&mut io::stdout(), "{self}");
        }
        // `true` maps to exit status 1, `false` to 0.
        ::std::process::exit(i32::from(fatal))
    }
}
/// Shorthand for a `Result` whose error type is this module's `Error`.
type Result<T> = result::Result<T, Error>;
impl fmt::Display for Error {
    /// Renders the user-facing message for this error.
    ///
    /// `Help` renders as an empty string, `NoMatch` as a fixed message, and
    /// the string-carrying variants as their payload. `WithProgramUsage`
    /// renders the wrapped error, a blank line, then the usage text; when the
    /// wrapped error renders empty, only the usage text is written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Help => f.write_str(""),
            NoMatch => f.write_str("Invalid arguments."),
            Usage(msg) | Argv(msg) | Deserialize(msg) | Version(msg) => f.write_str(msg),
            WithProgramUsage(inner, usage) => {
                let other = inner.to_string();
                if other.is_empty() {
                    f.write_str(usage)
                } else {
                    write!(f, "{other}\n\n{usage}")
                }
            }
        }
    }
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match *self {
WithProgramUsage(ref cause, _) => Some(&**cause),
_ => None,
}
}
}
impl de::Error for Error {
    /// Wraps an arbitrary serde error message in the `Deserialize` variant.
    fn custom<T: fmt::Display>(msg: T) -> Self {
        let text = msg.to_string();
        Self::Deserialize(text)
    }
}
/// A parsed usage description together with the settings that control how
/// argv is matched against it.
#[derive(Clone, Debug)]
pub struct Docopt {
// Parser built from the usage string.
p: Parser,
// Explicit argv to parse; when `None`, the process arguments are used.
argv: Option<Vec<String>>,
// Forwarded to the parser's `parse_argv` on every parse.
options_first: bool,
// When true, a matched `--help` makes `parse` return `Error::Help`.
help: bool,
// When set, a matched `--version` makes `parse` return `Error::Version`.
version: Option<String>,
}
impl Docopt {
    /// Parses `usage` and returns a `Docopt` with default settings: no
    /// explicit argv, `options_first` off, help handling on, no version.
    ///
    /// # Errors
    ///
    /// Returns `Error::Usage` when the usage string cannot be parsed.
    pub fn new<S>(usage: S) -> Result<Docopt>
    where
        S: ::std::ops::Deref<Target = str>,
    {
        let p = Parser::new(&usage).map_err(Usage)?;
        Ok(Docopt {
            p,
            argv: None,
            options_first: false,
            help: true,
            version: None,
        })
    }

    /// Parses argv and deserializes the matched values into `D`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Docopt::parse`] or from deserialization.
    pub fn deserialize<'a, 'de: 'a, D>(&'a self) -> Result<D>
    where
        D: de::Deserialize<'de>,
    {
        let matched = self.parse()?;
        ArgvMap::deserialize(matched)
    }

    /// Parses argv (the configured one, or the process arguments when none
    /// was set) and matches it against the usage patterns.
    ///
    /// # Errors
    ///
    /// * `Argv` / `NoMatch`, wrapped with the usage text, when argv cannot be
    ///   parsed or matches no pattern.
    /// * `Help`, wrapped with the full doc, when help handling is enabled and
    ///   `--help` is set.
    /// * `Version` when a version is configured and `--version` is set.
    pub fn parse(&self) -> Result<ArgvMap> {
        let argv = match self.argv {
            Some(ref configured) => configured.clone(),
            None => Docopt::get_argv(),
        };
        let parsed = self
            .p
            .parse_argv(&argv, self.options_first)
            .map_err(|s| self.err_with_usage(Argv(s)))?;
        let Some(map) = self.p.matches(&parsed) else {
            return Err(self.err_with_usage(NoMatch));
        };
        let vals = ArgvMap { map };

        // Help takes precedence over version when both flags are present.
        if self.help && vals.get_bool("--help") {
            return Err(self.err_with_full_doc(Help));
        }
        if let Some(ref v) = self.version {
            if vals.get_bool("--version") {
                return Err(Version(v.clone()));
            }
        }
        Ok(vals)
    }

    /// Sets the argv to parse. The first item is skipped, so callers should
    /// pass the full argument list including the leading program name.
    pub fn argv<I, S>(mut self, argv: I) -> Docopt
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let rest: Vec<String> = argv
            .into_iter()
            .skip(1)
            .map(|s| s.as_ref().to_owned())
            .collect();
        self.argv = Some(rest);
        self
    }

    /// Sets the `options_first` flag forwarded to the argv parser.
    #[must_use]
    pub const fn options_first(mut self, yes: bool) -> Docopt {
        self.options_first = yes;
        self
    }

    /// Enables or disables the automatic `--help` handling in `parse`.
    #[must_use]
    pub const fn help(mut self, yes: bool) -> Docopt {
        self.help = yes;
        self
    }

    /// Sets the version string; `Some(..)` enables `--version` handling in
    /// `parse`, `None` disables it.
    #[must_use]
    pub fn version(mut self, version: Option<String>) -> Docopt {
        self.version = version;
        self
    }

    /// Exposes the underlying parser.
    #[doc(hidden)]
    #[must_use]
    pub const fn parser(&self) -> &Parser {
        &self.p
    }

    /// Wraps `e` together with the trimmed usage text.
    fn err_with_usage(&self, e: Error) -> Error {
        let usage = self.p.usage.trim().into();
        WithProgramUsage(Box::new(e), usage)
    }

    /// Wraps `e` together with the trimmed full documentation text.
    fn err_with_full_doc(&self, e: Error) -> Error {
        let doc = self.p.full_doc.trim().into();
        WithProgramUsage(Box::new(e), doc)
    }

    /// Collects the process arguments (minus the program name), lossily
    /// converting any non-UTF-8 argument.
    fn get_argv() -> Vec<String> {
        let mut args = Vec::new();
        for raw in ::std::env::args_os().skip(1) {
            args.push(raw.to_string_lossy().into_owned());
        }
        args
    }
}
/// The values matched from argv, looked up by key strings such as `--help`.
#[derive(Clone)]
pub struct ArgvMap {
// Underlying map from key strings to matched values; `find`, `len` and the
// `Debug` impl all read from it.
#[doc(hidden)]
pub map: SynonymMap<String, Value>,
}
impl ArgvMap {
    /// Deserializes the matched values into `T`, consuming the map.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deserialize` when a value cannot be converted.
    pub fn deserialize<'de, T: de::Deserialize<'de>>(self) -> Result<T> {
        de::Deserialize::deserialize(&mut Deserializer {
            vals: self,
            stack: vec![],
        })
    }

    /// Returns the value for `key` as a bool, or `false` if the key is absent.
    pub fn get_bool(&self, key: &str) -> bool {
        self.find(key).is_some_and(Value::as_bool)
    }

    /// Returns the value for `key` as a count, or `0` if the key is absent.
    pub fn get_count(&self, key: &str) -> u64 {
        self.find(key).map_or(0, Value::as_count)
    }

    /// Returns the value for `key` as a string, or `""` if the key is absent.
    pub fn get_str(&self, key: &str) -> &str {
        self.find(key).map_or("", Value::as_str)
    }

    /// Returns the value for `key` as a list, or an empty list if the key is
    /// absent.
    pub fn get_vec(&self, key: &str) -> Vec<&str> {
        self.find(key).map(Value::as_vec).unwrap_or_default()
    }

    /// Looks up the raw value stored for `key`.
    #[must_use]
    pub fn find(&self, key: &str) -> Option<&Value> {
        self.map.find(&key.into())
    }

    /// Returns the length reported by the underlying map.
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when the underlying map reports a length of zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.len() == 0
    }

    /// Converts a Docopt key into the corresponding struct field name.
    ///
    /// Flags (`-x`, `--long`) get a `flag_` prefix, positional arguments
    /// (`ARG`, `<arg>`) get `arg_`, and anything else (commands) gets `cmd_`.
    /// Dashes in the name are replaced by underscores. A key that the pattern
    /// does not match at all is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the pattern matches but none of its groups captured text.
    #[doc(hidden)]
    pub fn key_to_struct_field(name: &str) -> String {
        decl_regex! {
            RE :
            r"^(?:--?(?P<flag>\S+)|(?:(?P<argu>\p{Lu}+)|<(?P<argb>[^>]+)>)|(?P<cmd>\S+))$"
            ;
        }
        fn sanitize(name: &str) -> String {
            name.replace('-', "_")
        }
        RE.replace(name, |cap: &Captures<'_>| {
            let (flag, cmd) = (cap_or_empty(cap, "flag"), cap_or_empty(cap, "cmd"));
            let (argu, argb) = (cap_or_empty(cap, "argu"), cap_or_empty(cap, "argb"));
            let (prefix, bare) = if !flag.is_empty() {
                ("flag_", flag)
            } else if !argu.is_empty() {
                ("arg_", argu)
            } else if !argb.is_empty() {
                ("arg_", argb)
            } else if !cmd.is_empty() {
                ("cmd_", cmd)
            } else {
                // `name` here is the function argument (the full key).
                panic!("Unknown ArgvMap key: '{name}'")
            };
            let mut field = prefix.to_owned();
            field.push_str(&sanitize(bare));
            field
        })
        .into_owned()
    }

    /// Converts a struct field name back into the Docopt key it stands for.
    ///
    /// * `flag_x` becomes `-x` when the name is a single character, and
    ///   `--name` otherwise.
    /// * `arg_NAME` (all uppercase letters) becomes `NAME`; any other
    ///   `arg_name` becomes `<name>`.
    /// * `cmd_name` becomes `name`.
    ///
    /// Underscores in the result are replaced by dashes.
    ///
    /// # Panics
    ///
    /// Panics if `field` has none of the `flag_`, `arg_` or `cmd_` prefixes.
    #[doc(hidden)]
    pub fn struct_field_to_key(field: &str) -> String {
        decl_regex! {
            LETTERS: r"^\p{Lu}+$";
        }
        fn desanitize(name: &str) -> String {
            name.replace('_', "-")
        }
        let name = if let Some(flag) = field.strip_prefix("flag_") {
            // Short flags are a single *character*; counting bytes would
            // misclassify a one-character non-ASCII flag as a long flag.
            let dashes = if flag.chars().count() == 1 { "-" } else { "--" };
            let mut key = dashes.to_owned();
            key.push_str(flag);
            key
        } else if let Some(arg) = field.strip_prefix("arg_") {
            if LETTERS.is_match(arg) {
                arg.to_owned()
            } else {
                let mut key = "<".to_owned();
                key.push_str(arg);
                key.push('>');
                key
            }
        } else if let Some(cmd) = field.strip_prefix("cmd_") {
            cmd.to_owned()
        } else {
            panic!("Unrecognized struct field: '{field}'")
        };
        desanitize(&name)
    }
}
impl fmt::Debug for ArgvMap {
    /// Writes one `key => value` line per key, in sorted key order. A key
    /// that is the target of a synonym is printed as `synonym, key => value`.
    /// An empty map prints `{EMPTY}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_empty() {
            return write!(f, "{{EMPTY}}");
        }

        // Invert the synonym table so each key can find the synonym that
        // points at it.
        let reverse: AHashMap<&String, &String> =
            self.map.synonyms().map(|(from, to)| (to, from)).collect();

        let mut keys: Vec<&String> = self.map.keys().collect();
        keys.sort();

        for (position, &k) in keys.iter().enumerate() {
            // Separate entries with a newline, without a trailing one.
            if position != 0 {
                writeln!(f)?;
            }
            if let Some(s) = reverse.get(&k) {
                write!(f, "{s}, {k} => {:?}", self.map.get(k))?;
            } else {
                write!(f, "{k} => {:?}", self.map.get(k))?;
            }
        }
        Ok(())
    }
}
/// A single value matched from argv. The `as_*` accessors convert any
/// variant into a bool, count, string or list.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
/// A boolean value.
Switch(bool),
/// A numeric count.
Counted(u64),
/// An optional single string value.
Plain(Option<String>),
/// Zero or more string values.
List(Vec<String>),
}
impl Value {
#[must_use]
pub fn as_bool(&self) -> bool {
match *self {
Switch(b) => b,
Counted(n) => n > 0,
Plain(None) => false,
Plain(Some(_)) => true,
List(ref vs) => !vs.is_empty(),
}
}
#[must_use]
pub fn as_count(&self) -> u64 {
match *self {
Switch(b) => u64::from(b), Counted(n) => n,
Plain(None) => 0,
Plain(Some(_)) => 1,
List(ref vs) => vs.len() as u64,
}
}
#[must_use]
pub fn as_str(&self) -> &str {
match *self {
Switch(_) | Counted(_) | Plain(None) | List(_) => "",
Plain(Some(ref s)) => s,
}
}
#[must_use]
pub fn as_vec(&self) -> Vec<&str> {
match *self {
Switch(_) | Counted(_) | Plain(None) => vec![],
Plain(Some(ref s)) => vec![&**s],
List(ref vs) => vs.iter().map(|s| &**s).collect(),
}
}
}
/// Serde deserializer that reads struct fields out of an `ArgvMap`.
pub struct Deserializer<'de> {
// Matched values; `push` looks each struct field's key up in here.
vals: ArgvMap,
// Pending items: `push` adds one per struct field and the `deserialize_*`
// methods pop them off again.
stack: Vec<DeserializerItem<'de>>,
}
/// One pending value on the deserializer stack.
#[derive(Debug)]
struct DeserializerItem<'de> {
// Docopt key derived from `struct_field`.
key: String,
// Name of the struct field being deserialized (used in error messages).
struct_field: &'de str,
// Value found for `key`, or `None` when the key was not in the map.
val: Option<Value>,
}
// Returns early from the enclosing function with an `Error::Deserialize`
// whose message is built from `format!`-style arguments.
macro_rules! derr(
($($arg:tt)*) => (return Err(Deserialize(format!($($arg)*))))
);
impl<'de> Deserializer<'de> {
    /// Pushes the value for `struct_field` onto the stack, looking it up
    /// under the Docopt key derived from the field name. A missing key is
    /// pushed with `val: None` and reported when the value is popped.
    #[inline]
    fn push(&mut self, struct_field: &'de str) {
        let key = ArgvMap::struct_field_to_key(struct_field);
        // Look the value up first so `key` can be moved into the item.
        let val = self.vals.find(&key).cloned();
        self.stack.push(DeserializerItem {
            key,
            struct_field,
            val,
        });
    }

    /// Pops the most recently pushed item.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deserialize` when the stack is empty.
    #[inline]
    fn pop(&mut self) -> Result<DeserializerItem<'_>> {
        match self.stack.pop() {
            None => derr!("Could not deserialize value into unknown key."),
            Some(it) => Ok(it),
        }
    }

    /// Pops the most recent item and returns its key and value.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deserialize` when the stack is empty or when the
    /// item's key had no value in the argv map.
    #[inline]
    fn pop_key_val(&mut self) -> Result<(String, Value)> {
        let it = self.pop()?;
        match it.val {
            None => {
                derr!(
                    "Could not find argument '{}' (from struct field '{}').
Note that each struct field must have the right key prefix, which must
be one of `cmd_`, `flag_` or `arg_`.",
                    it.key,
                    it.struct_field
                )
            }
            Some(v) => Ok((it.key, v)),
        }
    }

    /// Pops the most recent item and returns only its value.
    #[inline]
    fn pop_val(&mut self) -> Result<Value> {
        let (_, v) = self.pop_key_val()?;
        Ok(v)
    }

    /// Pops the most recent value and parses it as a number of type `T`.
    ///
    /// A `Counted` value is converted from its count; any other value is
    /// parsed from its string form, with a blank string treated as `0`.
    /// `expect` names the target type in error messages.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deserialize` when the value does not parse as `T`,
    /// including a count that does not fit in `T` (this used to panic).
    fn pop_number<T>(&mut self, expect: &str) -> Result<T>
    where
        T: FromStr + ToString,
        <T as FromStr>::Err: Debug,
    {
        let (k, v) = self.pop_key_val()?;
        // Only initialised for `Counted`, so the other branches stay
        // allocation-free.
        let counted;
        let text: &str = if let Counted(n) = v {
            counted = n.to_string();
            &counted
        } else if v.as_str().trim().is_empty() {
            "0"
        } else {
            // Parsed untrimmed on purpose: surrounding whitespace is an error.
            v.as_str()
        };
        match text.parse() {
            Ok(parsed) => Ok(parsed),
            Err(_) => derr!("Could not deserialize '{text}' to {expect} for '{k}'."),
        }
    }

    /// Pops the most recent value and converts it to an `f64`.
    ///
    /// A `Counted` value is cast from its count; any other value is parsed
    /// from its string form. `expect` names the target type in error
    /// messages.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deserialize` when the string does not parse as a
    /// float (this includes the empty string).
    fn pop_float(&mut self, expect: &str) -> Result<f64> {
        let (k, v) = self.pop_key_val()?;
        if let Counted(n) = v {
            return Ok(n as f64);
        }
        let vstr = v.as_str();
        match vstr.parse() {
            Ok(parsed) => Ok(parsed),
            Err(_) => derr!("Could not deserialize '{vstr}' to {expect} for '{k}'."),
        }
    }
}
// Generates a `serde::Deserializer` method named `$name` that pops the next
// value as a `$ty` (via `pop_number`) and hands it to the visitor's `$method`.
macro_rules! deserialize_num {
    ($name:ident, $method:ident, $ty:ty) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value>
        where
            V: de::Visitor<'de>,
        {
            let number: $ty = self.pop_number(stringify!($ty))?;
            visitor.$method(number)
        }
    };
}
impl<'de> ::serde::Deserializer<'de> for &mut Deserializer<'de> {
type Error = Error;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_bool(self.pop_val().map(|v| v.as_bool())?)
}
deserialize_num!(deserialize_i8, visit_i8, i8);
deserialize_num!(deserialize_i16, visit_i16, i16);
deserialize_num!(deserialize_i32, visit_i32, i32);
deserialize_num!(deserialize_i64, visit_i64, i64);
deserialize_num!(deserialize_u8, visit_u8, u8);
deserialize_num!(deserialize_u16, visit_u16, u16);
deserialize_num!(deserialize_u32, visit_u32, u32);
deserialize_num!(deserialize_u64, visit_u64, u64);
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_f32(self.pop_float("f32").map(|n| n as f32)?)
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_f64(self.pop_float("f64")?)
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let (k, v) = self.pop_key_val()?;
let vstr = v.as_str();
match vstr.chars().count() {
1 => visitor.visit_char(vstr.chars().next().unwrap()),
_ => derr!("Could not deserialize '{vstr}' into char for '{k}'."),
}
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let s = self.pop_val()?;
visitor.visit_str(s.as_str())
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let is_some = match self.stack.last() {
None => derr!("Could not deserialize value into unknown key."),
Some(it) => it.val.as_ref().is_some_and(Value::as_bool),
};
if is_some {
visitor.visit_some(self)
} else {
visitor.visit_none()
}
}
fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
derr!("Cannot deserialize a Docopt value into `()` (unit type).")
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
_visitor: V,
) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let (key, struct_field, val) = match self.stack.pop() {
None => derr!("Could not deserialize value into unknown key."),
Some(DeserializerItem {
key,
struct_field,
val,
}) => (key, struct_field, val),
};
let list = val.unwrap_or(List(vec![]));
let vals = list.as_vec();
for val in vals.iter().rev() {
self.stack.push(DeserializerItem {
key: key.clone(),
struct_field,
val: Some(Plain(Some((*val).into()))),
});
}
visitor.visit_seq(SeqDeserializer::new(self, vals.len()))
}
fn deserialize_struct<V>(
self,
_: &str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(StructDeserializer::new(self, fields))
}
fn deserialize_enum<V>(self, _name: &str, variants: &[&str], visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
let v = self.pop_val()?.as_str().to_lowercase();
let Some(s) = variants.iter().find(|&n| n.to_lowercase() == v) else {
derr!("Could not match '{v}' with any of the allowed variants: {variants:?}")
};
visitor.visit_enum(s.into_deserializer())
}
fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.deserialize_any(visitor)
}
}
/// Sequence access that hands out a fixed number of elements from the
/// deserializer.
struct SeqDeserializer<'a, 'de: 'a> {
// Deserializer each element is read from.
de: &'a mut Deserializer<'de>,
// Number of elements still to be handed out.
len: usize,
}
impl<'a, 'de> SeqDeserializer<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self {
SeqDeserializer { de, len }
}
}
impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_, 'de> {
    type Error = Error;

    /// Deserializes the next element, or returns `None` once the remaining
    /// count has reached zero.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.len.checked_sub(1) {
            None => Ok(None),
            Some(remaining) => {
                // The count drops before deserializing, even if that fails.
                self.len = remaining;
                seed.deserialize(&mut *self.de).map(Some)
            }
        }
    }

    /// Reports the number of elements still to be handed out.
    fn size_hint(&self) -> Option<usize> {
        let remaining = self.len;
        Some(remaining)
    }
}
/// Sequence access that walks a struct's fields in declaration order.
struct StructDeserializer<'a, 'de: 'a> {
// Deserializer each field's value is pushed onto and read from.
de: &'a mut Deserializer<'de>,
// Field names not yet visited.
fields: &'static [&'static str],
}
impl<'a, 'de> StructDeserializer<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>, fields: &'static [&'static str]) -> Self {
StructDeserializer { de, fields }
}
}
impl<'de> de::SeqAccess<'de> for StructDeserializer<'_, 'de> {
    type Error = Error;

    /// Pushes the next field onto the deserializer stack and deserializes
    /// it, or returns `None` once every field has been visited.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let pending = self.fields;
        let Some((&field, rest)) = pending.split_first() else {
            return Ok(None);
        };
        self.de.push(field);
        // The field counts as consumed even if deserializing it fails.
        self.fields = rest;
        seed.deserialize(&mut *self.de).map(Some)
    }

    /// Reports the number of fields not yet visited.
    fn size_hint(&self) -> Option<usize> {
        let remaining = self.fields.len();
        Some(remaining)
    }
}