use std::collections::HashMap;
use std::error::Error as StdError;
use std::fmt::{self, Debug};
use std::io::{self, Write};
use std::str::FromStr;
use std::result;
use lazy_static::lazy_static;
use regex::{Captures, Regex};
use serde::de;
use serde::de::IntoDeserializer;
use crate::parse::Parser;
use crate::synonym::SynonymMap;
use self::Value::{Switch, Counted, Plain, List};
use self::Error::{Usage, Argv, NoMatch, Deserialize, WithProgramUsage, Help, Version};
use crate::cap_or_empty;
#[derive(Debug)]
pub enum Error {
Usage(String),
Argv(String),
NoMatch,
Deserialize(String),
WithProgramUsage(Box<Error>, String),
Help,
Version(String),
}
impl Error {
pub fn fatal(&self) -> bool {
match *self {
Help | Version(..) => false,
Usage(..) | Argv(..) | NoMatch | Deserialize(..) => true,
WithProgramUsage(ref b, _) => b.fatal(),
}
}
pub fn exit(&self) -> ! {
if self.fatal() {
werr!("{}\n", self);
::std::process::exit(1)
} else {
let _ = writeln!(&mut io::stdout(), "{}", self);
::std::process::exit(0)
}
}
}
type Result<T> = result::Result<T, Error>;
/// Renders the message shown to the user.
///
/// `WithProgramUsage` prints the inner message, a blank line, then the attached
/// usage text; when the inner message renders empty (as `Help` does), only the
/// usage text is printed.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Usage(msg) | Argv(msg) | Deserialize(msg) | Version(msg) => f.write_str(msg),
            NoMatch => f.write_str("Invalid arguments."),
            Help => write!(f, ""),
            WithProgramUsage(inner, usage) => {
                let rendered = inner.to_string();
                if rendered.is_empty() {
                    write!(f, "{}", usage)
                } else {
                    write!(f, "{}\n\n{}", rendered, usage)
                }
            }
        }
    }
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match *self {
WithProgramUsage(ref cause, _) => Some(&**cause),
_ => None,
}
}
}
impl de::Error for Error {
fn custom<T: fmt::Display>(msg: T) -> Self {
Error::Deserialize(msg.to_string())
}
}
/// A docopt parser built from a usage string, plus options controlling `parse`.
///
/// Build one with `Docopt::new`, adjust it with the builder methods (`argv`,
/// `options_first`, `help`, `version`), then call `parse` or `deserialize`.
#[derive(Clone, Debug)]
pub struct Docopt {
    p: Parser,
    argv: Option<Vec<String>>,
    options_first: bool,
    help: bool,
    version: Option<String>,
}

impl Docopt {
    /// Parses `usage` into a new `Docopt`.
    ///
    /// Defaults: arguments are taken from `std::env::args` (minus the first
    /// element), `options_first` is off, `--help` handling is on, no version set.
    ///
    /// # Errors
    /// Returns `Error::Usage` carrying the parser's message if `usage` is invalid.
    pub fn new<S>(usage: S) -> Result<Docopt>
    where
        S: ::std::ops::Deref<Target = str>,
    {
        let parser = Parser::new(&*usage).map_err(Usage)?;
        Ok(Docopt {
            p: parser,
            argv: None,
            options_first: false,
            help: true,
            version: None,
        })
    }

    /// Parses the arguments and deserializes the matched values into `D`.
    ///
    /// # Errors
    /// Any error from `parse`, or an error raised while deserializing.
    pub fn deserialize<'a, 'de: 'a, D>(&'a self) -> Result<D>
    where
        D: de::Deserialize<'de>,
    {
        let vals = self.parse()?;
        vals.deserialize()
    }

    /// Parses the configured (or process) arguments against the usage.
    ///
    /// # Errors
    /// - `WithProgramUsage(Argv(..), usage)` if the argument vector cannot be parsed.
    /// - `WithProgramUsage(NoMatch, usage)` if no usage pattern matches.
    /// - `WithProgramUsage(Help, full_doc)` if help handling is on and `--help` matched.
    /// - `Version(v)` if a version is set and `--version` matched.
    pub fn parse(&self) -> Result<ArgvMap> {
        let argv = match self.argv {
            Some(ref given) => given.clone(),
            None => Docopt::get_argv(),
        };
        let parsed = self
            .p
            .parse_argv(argv, self.options_first)
            .map_err(|msg| self.err_with_usage(Argv(msg)))?;
        let map = match self.p.matches(&parsed) {
            Some(m) => m,
            None => return Err(self.err_with_usage(NoMatch)),
        };
        let vals = ArgvMap { map };

        // `--help` takes precedence over `--version`.
        if self.help && vals.get_bool("--help") {
            return Err(self.err_with_full_doc(Help));
        }
        if let Some(ref v) = self.version {
            if vals.get_bool("--version") {
                return Err(Version(v.clone()));
            }
        }
        Ok(vals)
    }

    /// Uses `argv` instead of the process arguments.
    ///
    /// The first element is skipped, mirroring `get_argv`, which drops the first
    /// item of `std::env::args` (the program name).
    pub fn argv<I, S>(mut self, argv: I) -> Docopt
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let collected: Vec<String> = argv
            .into_iter()
            .skip(1)
            .map(|arg| String::from(arg.as_ref()))
            .collect();
        self.argv = Some(collected);
        self
    }

    /// Sets the `options_first` flag forwarded to `Parser::parse_argv`.
    pub fn options_first(mut self, yes: bool) -> Docopt {
        self.options_first = yes;
        self
    }

    /// Enables or disables returning `Help` when `--help` is matched.
    pub fn help(mut self, yes: bool) -> Docopt {
        self.help = yes;
        self
    }

    /// Sets the version text returned as `Error::Version` when `--version` is
    /// matched; `None` disables the check.
    pub fn version(mut self, version: Option<String>) -> Docopt {
        self.version = version;
        self
    }

    #[doc(hidden)]
    pub fn parser(&self) -> &Parser {
        &self.p
    }

    /// Wraps `e` together with the trimmed usage section.
    fn err_with_usage(&self, e: Error) -> Error {
        WithProgramUsage(Box::new(e), self.p.usage.trim().to_owned())
    }

    /// Wraps `e` together with the trimmed full documentation text.
    fn err_with_full_doc(&self, e: Error) -> Error {
        WithProgramUsage(Box::new(e), self.p.full_doc.trim().to_owned())
    }

    /// Process arguments without the leading program name.
    fn get_argv() -> Vec<String> {
        let mut args = ::std::env::args();
        let _program = args.next();
        args.collect()
    }
}
/// The values matched from the command line, keyed by docopt name.
///
/// Keys are the names as written in the usage string (`--flag`, `-f`, `<arg>`,
/// `ARG`, `command`); lookups go through the underlying `SynonymMap`.
#[derive(Clone)]
pub struct ArgvMap {
    #[doc(hidden)]
    pub map: SynonymMap<String, Value>,
}

impl ArgvMap {
    /// Deserializes the matched values into `T`, whose field names are mapped to
    /// docopt keys by `struct_field_to_key` (`flag_`, `arg_` and `cmd_` prefixes).
    ///
    /// # Errors
    /// Returns an error if a field has no matching key or its value cannot be
    /// converted to the field's type.
    ///
    /// # Panics
    /// Panics if a struct field lacks a recognized prefix (see `struct_field_to_key`).
    pub fn deserialize<'de, T: de::Deserialize<'de>>(self) -> Result<T> {
        de::Deserialize::deserialize(&mut Deserializer {
            vals: self,
            stack: vec![],
        })
    }

    /// The value for `key` as a bool (`Value::as_bool`), or `false` if absent.
    pub fn get_bool(&self, key: &str) -> bool {
        self.find(key).map_or(false, |v| v.as_bool())
    }

    /// The value for `key` as a count (`Value::as_count`), or `0` if absent.
    pub fn get_count(&self, key: &str) -> u64 {
        self.find(key).map_or(0, |v| v.as_count())
    }

    /// The value for `key` as a string (`Value::as_str`), or `""` if absent.
    pub fn get_str(&self, key: &str) -> &str {
        self.find(key).map_or("", |v| v.as_str())
    }

    /// The value for `key` as a list (`Value::as_vec`), or empty if absent.
    pub fn get_vec(&self, key: &str) -> Vec<&str> {
        self.find(key).map(|v| v.as_vec()).unwrap_or_default()
    }

    /// Looks up the raw value stored for `key`.
    pub fn find(&self, key: &str) -> Option<&Value> {
        self.map.find(&key.into())
    }

    /// Number of entries in the underlying map.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when the map has no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Converts a docopt key into the struct field name used by `deserialize`.
    ///
    /// `--flag`/`-f` → `flag_flag`/`flag_f`, `ARG` → `arg_ARG`, `<arg>` → `arg_arg`,
    /// anything else without whitespace → `cmd_...`; `-` becomes `_`. Names the
    /// pattern does not match are returned unchanged.
    #[doc(hidden)]
    pub fn key_to_struct_field(name: &str) -> String {
        lazy_static! {
            static ref RE: Regex = regex!(
                r"^(?:--?(?P<flag>\S+)|(?:(?P<argu>\p{Lu}+)|<(?P<argb>[^>]+)>)|(?P<cmd>\S+))$"
            );
        }
        fn sanitize(name: &str) -> String {
            name.replace("-", "_")
        }
        RE.replace(name, |cap: &Captures<'_>| {
            let flag = cap_or_empty(cap, "flag");
            let argu = cap_or_empty(cap, "argu");
            let argb = cap_or_empty(cap, "argb");
            let cmd = cap_or_empty(cap, "cmd");
            let (prefix, body) = if !flag.is_empty() {
                ("flag_", flag)
            } else if !argu.is_empty() {
                ("arg_", argu)
            } else if !argb.is_empty() {
                ("arg_", argb)
            } else if !cmd.is_empty() {
                ("cmd_", cmd)
            } else {
                // Every alternative in RE requires at least one character, so a
                // successful match always fills one group.
                panic!("Unknown ArgvMap key: '{}'", name)
            };
            format!("{}{}", prefix, sanitize(body))
        })
        .into_owned()
    }

    /// Converts a struct field name back into a docopt key (inverse of
    /// `key_to_struct_field`); `_` becomes `-`.
    ///
    /// # Panics
    /// Panics if `field` does not start with `flag_`, `arg_` or `cmd_`.
    #[doc(hidden)]
    pub fn struct_field_to_key(field: &str) -> String {
        lazy_static! {
            static ref LETTERS: Regex = regex!(r"^\p{Lu}+$");
        }
        fn desanitize(name: &str) -> String {
            name.replace("_", "-")
        }
        const FLAG: &str = "flag_";
        const ARG: &str = "arg_";
        const CMD: &str = "cmd_";

        // The prefixes are ASCII, so slicing right after them is on a char boundary.
        let name = if field.starts_with(FLAG) {
            let flag = &field[FLAG.len()..];
            // A single *character* (not byte) means a short option, so that
            // non-ASCII short flags such as `-é` round-trip correctly.
            let dashes = if flag.chars().count() == 1 { "-" } else { "--" };
            format!("{}{}", dashes, flag)
        } else if field.starts_with(ARG) {
            let arg = &field[ARG.len()..];
            if LETTERS.is_match(arg) {
                arg.to_owned()
            } else {
                format!("<{}>", arg)
            }
        } else if field.starts_with(CMD) {
            field[CMD.len()..].to_owned()
        } else {
            panic!("Unrecognized struct field: '{}'", field)
        };
        desanitize(&name)
    }
}
/// Lists every key (sorted) with its value, one per line; a key that is the
/// target of a synonym is prefixed by that synonym (`syn, key => value`).
impl fmt::Debug for ArgvMap {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.len() == 0 {
            return write!(f, "{{EMPTY}}");
        }
        // Map each canonical key back to the synonym that points at it.
        let mut aliases: HashMap<&String, &String> = HashMap::new();
        for (from, to) in self.map.synonyms() {
            aliases.insert(to, from);
        }
        let mut sorted: Vec<&String> = self.map.keys().collect();
        sorted.sort();
        for (i, key) in sorted.into_iter().enumerate() {
            if i > 0 {
                f.write_str("\n")?;
            }
            if let Some(alias) = aliases.get(&key) {
                write!(f, "{}, {} => {:?}", alias, key, self.map.get(key))?;
            } else {
                write!(f, "{} => {:?}", key, self.map.get(key))?;
            }
        }
        Ok(())
    }
}
/// A single matched value stored in an `ArgvMap`.
///
/// The `as_*` accessors convert any variant into the requested shape.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// An on/off value.
    Switch(bool),
    /// A repetition count.
    Counted(u64),
    /// A single value that may be absent.
    Plain(Option<String>),
    /// Zero or more values.
    List(Vec<String>),
}

impl Value {
    /// `true` for an on switch, a non-zero count, a present plain value, or a
    /// non-empty list.
    pub fn as_bool(&self) -> bool {
        match self {
            Value::Switch(on) => *on,
            Value::Counted(count) => *count > 0,
            Value::Plain(opt) => opt.is_some(),
            Value::List(items) => !items.is_empty(),
        }
    }

    /// A switch counts as 0/1, a plain value as 0/1 by presence, a list by its length.
    pub fn as_count(&self) -> u64 {
        match self {
            Value::Switch(on) => u64::from(*on),
            Value::Counted(count) => *count,
            Value::Plain(opt) => {
                if opt.is_some() {
                    1
                } else {
                    0
                }
            }
            Value::List(items) => items.len() as u64,
        }
    }

    /// The string of a present plain value; `""` for every other variant.
    pub fn as_str(&self) -> &str {
        if let Value::Plain(Some(s)) = self {
            s.as_str()
        } else {
            ""
        }
    }

    /// A list's items, or a one-element vec for a present plain value; otherwise empty.
    pub fn as_vec(&self) -> Vec<&str> {
        match self {
            Value::Plain(Some(s)) => vec![s.as_str()],
            Value::List(items) => items.iter().map(String::as_str).collect(),
            Value::Switch(_) | Value::Counted(_) | Value::Plain(None) => Vec::new(),
        }
    }
}
/// Serde deserializer that reads struct fields out of a parsed `ArgvMap`.
pub struct Deserializer<'de> {
/// The matched values being deserialized.
vals: ArgvMap,
/// Pending values: `push` adds one per struct field (and `deserialize_seq`
/// one per list element); the `pop*` helpers consume them last-in first-out.
stack: Vec<DeserializerItem<'de>>,
}
/// One pending value on the deserializer stack.
#[derive(Debug)]
struct DeserializerItem<'de> {
/// Docopt key derived from the struct field name (via `struct_field_to_key`).
key: String,
/// Struct field name this value belongs to; used in error messages.
struct_field: &'de str,
/// Matched value, or `None` if the key was not found in the `ArgvMap`.
val: Option<Value>,
}
/// Returns early from the enclosing function with an `Error::Deserialize`
/// built from `format!`-style arguments.
macro_rules! derr {
    ($($arg:tt)*) => {
        return Err(Error::Deserialize(format!($($arg)*)))
    };
}
impl<'de> Deserializer<'de> {
fn push(&mut self, struct_field: &'de str) {
let key = ArgvMap::struct_field_to_key(struct_field);
self.stack
.push(DeserializerItem {
key: key.clone(),
struct_field: struct_field,
val: self.vals.find(&*key).cloned(),
});
}
fn pop(&mut self) -> Result<DeserializerItem<'_>> {
match self.stack.pop() {
None => derr!("Could not deserialize value into unknown key."),
Some(it) => Ok(it),
}
}
fn pop_key_val(&mut self) -> Result<(String, Value)> {
let it = self.pop()?;
match it.val {
None => {
derr!("Could not find argument '{}' (from struct field '{}').
Note that each struct field must have the right key prefix, which must
be one of `cmd_`, `flag_` or `arg_`.",
it.key,
it.struct_field)
}
Some(v) => Ok((it.key, v)),
}
}
fn pop_val(&mut self) -> Result<Value> {
let (_, v) = self.pop_key_val()?;
Ok(v)
}
fn to_number<T>(&mut self, expect: &str) -> Result<T>
where T: FromStr + ToString,
<T as FromStr>::Err: Debug
{
let (k, v) = self.pop_key_val()?;
match v {
Counted(n) => Ok(n.to_string().parse().unwrap()),
_ => {
if v.as_str().trim().is_empty() {
Ok("0".parse().unwrap())
} else {
match v.as_str().parse() {
Err(_) => {
derr!("Could not deserialize '{}' to {} for '{}'.",
v.as_str(),
expect,
k)
}
Ok(v) => Ok(v),
}
}
}
}
}
fn to_float(&mut self, expect: &str) -> Result<f64> {
let (k, v) = self.pop_key_val()?;
match v {
Counted(n) => Ok(n as f64),
_ => {
match v.as_str().parse() {
Err(_) => {
derr!("Could not deserialize '{}' to {} for '{}'.",
v.as_str(),
expect,
k)
}
Ok(v) => Ok(v),
}
}
}
}
}
/// Generates a `deserialize_<int>` method that parses the popped value with
/// `Deserializer::to_number` and hands it to the matching visitor method.
macro_rules! deserialize_num {
    ($name:ident, $method:ident, $ty:ty) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value>
        where
            V: de::Visitor<'de>,
        {
            // `to_number` already returns `$ty`, so no cast is needed.
            let n = self.to_number::<$ty>(stringify!($ty))?;
            visitor.$method(n)
        }
    };
}
/// Serde entry points. Each scalar method pops the value pushed for the current
/// struct field. Types docopt cannot represent (maps, tuples, bytes, unit,
/// self-describing input) produce `Error::Deserialize` instead of panicking.
impl<'a, 'de> ::serde::Deserializer<'de> for &'a mut Deserializer<'de> {
    type Error = Error;

    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!(
            "Could not deserialize value: docopt does not support self-describing \
             deserialization (deserialize_any); use a concrete type."
        )
    }

    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let v = self.pop_val()?;
        visitor.visit_bool(v.as_bool())
    }

    deserialize_num!(deserialize_i8, visit_i8, i8);
    deserialize_num!(deserialize_i16, visit_i16, i16);
    deserialize_num!(deserialize_i32, visit_i32, i32);
    deserialize_num!(deserialize_i64, visit_i64, i64);
    deserialize_num!(deserialize_u8, visit_u8, u8);
    deserialize_num!(deserialize_u16, visit_u16, u16);
    deserialize_num!(deserialize_u32, visit_u32, u32);
    deserialize_num!(deserialize_u64, visit_u64, u64);

    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let n = self.to_float("f32")?;
        visitor.visit_f32(n as f32)
    }

    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_f64(self.to_float("f64")?)
    }

    /// Accepts exactly one character.
    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let (k, v) = self.pop_key_val()?;
        let vstr = v.as_str();
        let mut chars = vstr.chars();
        match (chars.next(), chars.next()) {
            (Some(c), None) => visitor.visit_char(c),
            _ => derr!("Could not deserialize '{}' into char for '{}'.", vstr, k),
        }
    }

    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let s = self.pop_val()?;
        visitor.visit_str(s.as_str())
    }

    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!("Could not deserialize value: docopt does not support byte arrays.")
    }

    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!("Could not deserialize value: docopt does not support byte buffers.")
    }

    /// `Some` when the pending value is truthy (`Value::as_bool`); the value is
    /// only peeked here and popped by the inner deserialization.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let is_some = match self.stack.last() {
            None => derr!("Could not deserialize value into unknown key."),
            Some(it) => it.val.as_ref().map_or(false, |v| v.as_bool()),
        };
        if is_some {
            visitor.visit_some(self)
        } else {
            visitor.visit_none()
        }
    }

    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!("I don't know how to read into a nil value.")
    }

    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_unit()
    }

    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!("Could not deserialize value: docopt does not support tuples.")
    }

    fn deserialize_tuple_struct<V>(
        self,
        name: &'static str,
        _len: usize,
        _visitor: V,
    ) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!(
            "Could not deserialize value into tuple struct '{}': \
             docopt does not support tuple structs.",
            name
        )
    }

    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        derr!("Could not deserialize value: docopt does not support maps.")
    }

    /// Replaces the current item with one item per list element, then lets the
    /// visitor consume them in order.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let DeserializerItem { key, struct_field, val } = match self.stack.pop() {
            Some(item) => item,
            None => derr!("Could not deserialize value into unknown key."),
        };
        // A missing key deserializes as an empty sequence.
        let list = val.unwrap_or(List(vec![]));
        let items = list.as_vec();
        // Push in reverse so the stack pops the elements front to back.
        for item in items.iter().rev() {
            self.stack.push(DeserializerItem {
                key: key.clone(),
                struct_field,
                val: Some(Plain(Some((*item).to_owned()))),
            });
        }
        let len = items.len();
        visitor.visit_seq(SeqDeserializer::new(self, len))
    }

    /// Presents the struct's fields to the visitor as a sequence, in declaration order.
    fn deserialize_struct<V>(
        self,
        _: &str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_seq(StructDeserializer::new(self, fields))
    }

    /// Matches the popped value against `variants` case-insensitively.
    fn deserialize_enum<V>(self, _name: &str, variants: &[&str], visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let v = self.pop_val()?.as_str().to_lowercase();
        let s = match variants.iter().find(|&n| n.to_lowercase() == v) {
            Some(s) => s,
            None => derr!(
                "Could not match '{}' with any of \
                 the allowed variants: {:?}",
                v,
                variants
            ),
        };
        visitor.visit_enum(s.into_deserializer())
    }

    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }
}
/// `SeqAccess` over the `len` elements that `deserialize_seq` pushed onto the stack.
struct SeqDeserializer<'a, 'de: 'a> {
    de: &'a mut Deserializer<'de>,
    /// Number of elements still to be yielded.
    len: usize,
}

impl<'a, 'de> SeqDeserializer<'a, 'de> {
    fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self {
        SeqDeserializer { de, len }
    }
}

impl<'a, 'de> de::SeqAccess<'de> for SeqDeserializer<'a, 'de> {
    type Error = Error;

    /// Deserializes the next element from the top of the stack, or returns
    /// `None` once `len` elements have been produced.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        if self.len == 0 {
            return Ok(None);
        }
        self.len -= 1;
        seed.deserialize(&mut *self.de).map(Some)
    }

    fn size_hint(&self) -> Option<usize> {
        Some(self.len)
    }
}
/// `SeqAccess` that yields one element per struct field, pushing each field's
/// value onto the stack just before it is deserialized.
struct StructDeserializer<'a, 'de: 'a> {
    de: &'a mut Deserializer<'de>,
    /// Fields not yet deserialized, in declaration order.
    fields: &'static [&'static str],
}

impl<'a, 'de> StructDeserializer<'a, 'de> {
    fn new(de: &'a mut Deserializer<'de>, fields: &'static [&'static str]) -> Self {
        StructDeserializer { de, fields }
    }
}

impl<'a, 'de> de::SeqAccess<'de> for StructDeserializer<'a, 'de> {
    type Error = Error;

    /// Pushes the next field and deserializes it; `None` when all fields are done.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let fields: &'static [&'static str] = self.fields;
        let (&field, rest) = match fields.split_first() {
            Some(pair) => pair,
            None => return Ok(None),
        };
        self.fields = rest;
        self.de.push(field);
        seed.deserialize(&mut *self.de).map(Some)
    }

    fn size_hint(&self) -> Option<usize> {
        Some(self.fields.len())
    }
}