ltl_args/
opt.rs

1use crate::error::ParseError;
2use crate::matches::Matches;
3use crate::matches::Value;
4use std::collections::HashMap;
5use std::collections::HashSet;
6
/// What [`Opts::parse`] does when it encounters an option on the command line.
///
/// This is a plain fieldless enum, so it is `Copy` and `Eq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Action {
    /// Consume the next argument and store it as the option's single value,
    /// replacing any earlier value.
    Set,
    /// Consume the next argument and add it to the option's list of values.
    Append,
    /// Store `true`; the option takes no value argument.
    SetTrue,
    /// Store `false`; the option takes no value argument.
    SetFalse,
}
14
15#[derive(Clone, Debug, PartialEq)]
16pub struct Opt {
17    pub name: String,
18    pub short: Option<char>,
19    pub long: Option<String>,
20    pub help: Option<String>,
21    pub default: Option<String>,
22    pub action: Action,
23    pub required: bool,
24}
25
26impl Opt {
27    pub fn name(name: &str) -> Opt {
28        Opt {
29            name: name.into(),
30            short: None,
31            long: None,
32            help: None,
33            default: None,
34            action: Action::Set,
35            required: false,
36        }
37    }
38
39    pub fn short(mut self, short: char) -> Opt {
40        self.short = Some(short);
41        self
42    }
43
44    pub fn long(mut self, long: &str) -> Opt {
45        self.long = Some(long.to_string());
46        self
47    }
48
49    pub fn help(mut self, help: &str) -> Opt {
50        self.help = Some(help.to_string());
51        self
52    }
53
54    pub fn default(mut self, default: &str) -> Opt {
55        self.default = Some(default.to_string());
56        self
57    }
58
59    pub fn action(mut self, action: Action) -> Opt {
60        self.action = action;
61        self
62    }
63
64    pub fn required(mut self) -> Opt {
65        self.required = true;
66        self
67    }
68}
69
70#[derive(Debug, PartialEq)]
71pub struct Opts {
72    opts: Vec<Opt>,
73}
74
75impl Opts {
76    pub fn new(opts: Vec<Opt>) -> Result<Opts, String> {
77        let args = Opts { opts };
78        args.validate()?;
79        Ok(args)
80    }
81
82    pub fn add(&mut self, arg: Opt) -> Result<(), String> {
83        self.opts.push(arg);
84        self.validate()
85    }
86
87    pub fn parse(&self, args: Vec<String>) -> Result<Matches, ParseError> {
88        let mut args_iter = args.into_iter();
89        let exec_name = match args_iter.next() {
90            Some(s) => s,
91            None => return Err(ParseError::MissingProgramName),
92        };
93
94        let mut positional = vec![];
95        let mut named = HashMap::new();
96
97        self.populate_defaults(&mut named);
98
99        while let Some(arg) = args_iter.next() {
100            if arg.starts_with("-") {
101                let opt = self.find_opt(&arg)?;
102
103                match opt.action {
104                    Action::Set => {
105                        if let Some(value) = args_iter.next() {
106                            named.insert(opt.name.clone(), Value::Single(value));
107                        } else {
108                            return Err(ParseError::MissingValue(opt.name.clone()));
109                        }
110                    }
111                    Action::Append => {
112                        match (args_iter.next(), named.get_mut(&opt.name)) {
113                            (None, _) => return Err(ParseError::MissingValue(opt.name.clone())),
114                            (Some(val), Some(Value::Multi(vals))) => {
115                                vals.push(val);
116                            }
117                            (Some(val), None) => {
118                                named.insert(opt.name.clone(), Value::Multi(vec![val]));
119                            }
120                            _ => return Err(ParseError::BadInternalState), // unexpected case
121                        };
122                    }
123                    Action::SetTrue => {
124                        named.insert(opt.name.clone(), Value::Flag(true));
125                    }
126                    Action::SetFalse => {
127                        named.insert(opt.name.clone(), Value::Flag(false));
128                    }
129                };
130            } else {
131                positional.push(arg);
132            }
133        }
134
135        Ok(Matches::new(exec_name, positional, named))
136    }
137
138    fn populate_defaults(&self, named: &mut HashMap<String, Value>) {
139        for opt in self.opts.iter() {
140            if let Some(default) = &opt.default {
141                named.insert(opt.name.clone(), Value::Single(default.to_owned()));
142            } else {
143                match opt.action {
144                    Action::Append => {
145                        named.insert(opt.name.clone(), Value::Multi(vec![]));
146                    }
147                    Action::SetTrue => {
148                        named.insert(opt.name.clone(), Value::Flag(false));
149                    }
150                    Action::SetFalse => {
151                        named.insert(opt.name.clone(), Value::Flag(false));
152                    }
153                    _ => {}
154                }
155            }
156        }
157    }
158
159    fn find_opt(&self, arg: &str) -> Result<&Opt, ParseError> {
160        let opt = if arg.starts_with("--") {
161            let long = arg.strip_prefix("--").unwrap();
162            self.opts.iter().find(|o| o.long.as_deref() == Some(long))
163        } else if arg.starts_with("-") {
164            if arg.chars().count() != 2 {
165                return Err(ParseError::MalformedOption(arg.to_string()));
166            }
167            let short = arg.chars().nth(1);
168            self.opts.iter().find(|o| o.short == short)
169        } else {
170            return Err(ParseError::UnexpectedOption(arg.to_string()));
171        };
172
173        if let Some(opt) = opt {
174            Ok(opt)
175        } else {
176            Err(ParseError::UnexpectedOption(arg.to_string()))
177        }
178    }
179
180    fn validate(&self) -> Result<(), String> {
181        let mut names: HashSet<String> = HashSet::new();
182        let mut short: HashSet<char> = HashSet::new();
183        let mut long: HashSet<String> = HashSet::new();
184
185        for arg in &self.opts {
186            if names.contains(&arg.name) {
187                return Err(format!(
188                    "Optument names must be unique; found two with name {}",
189                    arg.name
190                ));
191            } else if arg.short.is_some() && short.contains(&arg.short.unwrap()) {
192                return Err(format!(
193                    "Short flags must be unique; found two with short flag -{}",
194                    arg.short.unwrap()
195                ));
196            } else if arg.long.is_some() && long.contains(arg.long.as_ref().unwrap()) {
197                return Err(format!(
198                    "Long flags must be unique; found two with long flag --{}",
199                    arg.long.as_ref().unwrap()
200                ));
201            }
202
203            names.insert(arg.name.to_string());
204            if let Some(c) = arg.short {
205                short.insert(c);
206            }
207            if let Some(s) = &arg.long {
208                long.insert(s.to_string());
209            }
210        }
211
212        Ok(())
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned argument vector `Opts::parse` takes.
    fn to_args(args: &[&str]) -> Vec<String> {
        args.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_validates_empty_args() {
        let _ = Opts::new(vec![]).expect("should validate");
    }

    #[test]
    fn detects_duplicate_names() {
        let opts = Opts::new(vec![
            Opt::name("host"),
            Opt::name("port"),
            Opt::name("port"),
        ]);
        assert_eq!(
            opts,
            Err("Optument names must be unique; found two with name port".to_string())
        );
    }

    #[test]
    fn detects_duplicate_short() {
        let opts = Opts::new(vec![
            Opt::name("host").short('p'),
            Opt::name("port").short('p'),
            Opt::name("threads").short('t'),
        ]);
        assert_eq!(
            opts,
            Err("Short flags must be unique; found two with short flag -p".to_string())
        );
    }

    #[test]
    fn detects_duplicate_long() {
        let opts = Opts::new(vec![
            Opt::name("host").long("host"),
            Opt::name("port").long("host"),
            Opt::name("threads").long("threads"),
        ]);
        assert_eq!(
            opts,
            Err("Long flags must be unique; found two with long flag --host".to_string())
        );
    }

    #[test]
    fn add_rejects_conflicting_option() {
        let mut opts = Opts::new(vec![Opt::name("host").short('h')]).unwrap();
        assert!(opts.add(Opt::name("help").short('h')).is_err());
    }

    #[test]
    fn parses_positional_args() {
        let opts = Opts::new(vec![Opt::name("host").long("host")]).unwrap();
        let args = to_args(&["myprogram", "1", "2", "blue"]);
        let expected_positional: Vec<_> = args.iter().skip(1).cloned().collect();

        let matches = opts.parse(args).expect("should parse");

        assert_eq!(matches.positional(), expected_positional);
    }

    #[test]
    fn parses_named_args() {
        let opts = Opts::new(vec![
            Opt::name("host").long("host"),
            Opt::name("verbose").long("verbose").action(Action::SetTrue),
            Opt::name("queue").short('q').action(Action::Append),
            Opt::name("nocolor")
                .short('n')
                .long("nocolor")
                .action(Action::SetFalse),
            Opt::name("missing").default("something"),
        ])
        .unwrap();
        let args = to_args(&[
            "myprogram",
            "1",
            "2",
            "--verbose",
            "-q",
            "items",
            "--host",
            "localhost",
            "-q",
            "-queue-name-with-dash",
            "-n",
            "blue",
        ]);

        let expected_positional: Vec<_> = vec!["1", "2", "blue"];

        let matches = opts.parse(args).expect("should parse");

        assert_eq!(matches.positional(), expected_positional);
        assert_eq!(matches.flag("verbose").unwrap(), Some(true));
        assert_eq!(matches.flag("nocolor").unwrap(), Some(false));
        assert_eq!(matches.one("host").unwrap(), Some("localhost".to_string()));
        let queues: Vec<String> = matches.all("queue").unwrap();
        assert_eq!(
            queues,
            vec!["items".to_string(), "-queue-name-with-dash".to_string()]
        );

        assert_eq!(
            matches.one("missing").unwrap(),
            Some("something".to_string())
        );
    }

    #[test]
    fn rejects_option_missing_its_value() {
        let opts = Opts::new(vec![Opt::name("host").long("host")]).unwrap();
        let result = opts.parse(to_args(&["myprogram", "--host"]));
        assert!(matches!(result, Err(ParseError::MissingValue(ref name)) if name == "host"));
    }

    #[test]
    fn rejects_malformed_short_option() {
        let opts = Opts::new(vec![Opt::name("all").short('a')]).unwrap();
        let result = opts.parse(to_args(&["myprogram", "-ab"]));
        assert!(matches!(result, Err(ParseError::MalformedOption(ref arg)) if arg == "-ab"));
    }

    #[test]
    fn rejects_unknown_option() {
        let opts = Opts::new(vec![Opt::name("host").long("host")]).unwrap();
        let result = opts.parse(to_args(&["myprogram", "--port"]));
        assert!(matches!(result, Err(ParseError::UnexpectedOption(ref arg)) if arg == "--port"));
    }
}