1use crate::error::ParseError;
2use crate::matches::Matches;
3use crate::matches::Value;
4use std::collections::HashMap;
5use std::collections::HashSet;
6
/// What `Opts::parse` does when it encounters an option on the command line.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Consume the next argument and store it as the option's single value;
    /// a later occurrence overwrites an earlier one.
    Set,
    /// Consume the next argument and add it to the option's list of values.
    Append,
    /// Take no argument; store the flag value `true`.
    SetTrue,
    /// Take no argument; store the flag value `false`.
    SetFalse,
}
14
/// Definition of a single command-line option, typically built with the
/// builder methods starting from `Opt::name`.
#[derive(Clone, Debug, PartialEq)]
pub struct Opt {
    /// Key the parsed value is stored under; must be unique within an `Opts`.
    pub name: String,
    /// Single-character flag, matched on the command line as `-c`.
    pub short: Option<char>,
    /// Long flag without its leading dashes, matched as `--long`.
    pub long: Option<String>,
    /// Help text. NOTE(review): not read anywhere in this file.
    pub help: Option<String>,
    /// Value recorded for this option when it is not given on the command line.
    pub default: Option<String>,
    /// What parsing does when the option is seen.
    pub action: Action,
    /// Whether the option is marked required.
    /// NOTE(review): `Opts::parse` in this file does not check this flag.
    pub required: bool,
}
25
26impl Opt {
27 pub fn name(name: &str) -> Opt {
28 Opt {
29 name: name.into(),
30 short: None,
31 long: None,
32 help: None,
33 default: None,
34 action: Action::Set,
35 required: false,
36 }
37 }
38
39 pub fn short(mut self, short: char) -> Opt {
40 self.short = Some(short);
41 self
42 }
43
44 pub fn long(mut self, long: &str) -> Opt {
45 self.long = Some(long.to_string());
46 self
47 }
48
49 pub fn help(mut self, help: &str) -> Opt {
50 self.help = Some(help.to_string());
51 self
52 }
53
54 pub fn default(mut self, default: &str) -> Opt {
55 self.default = Some(default.to_string());
56 self
57 }
58
59 pub fn action(mut self, action: Action) -> Opt {
60 self.action = action;
61 self
62 }
63
64 pub fn required(mut self) -> Opt {
65 self.required = true;
66 self
67 }
68}
69
/// The set of options a program accepts.
///
/// Constructed through `Opts::new`, which rejects duplicate option names,
/// short flags and long flags.
#[derive(Debug, PartialEq)]
pub struct Opts {
    /// Options in declaration order; lookups and validation scan this in order.
    opts: Vec<Opt>,
}
74
75impl Opts {
76 pub fn new(opts: Vec<Opt>) -> Result<Opts, String> {
77 let args = Opts { opts };
78 args.validate()?;
79 Ok(args)
80 }
81
82 pub fn add(&mut self, arg: Opt) -> Result<(), String> {
83 self.opts.push(arg);
84 self.validate()
85 }
86
87 pub fn parse(&self, args: Vec<String>) -> Result<Matches, ParseError> {
88 let mut args_iter = args.into_iter();
89 let exec_name = match args_iter.next() {
90 Some(s) => s,
91 None => return Err(ParseError::MissingProgramName),
92 };
93
94 let mut positional = vec![];
95 let mut named = HashMap::new();
96
97 self.populate_defaults(&mut named);
98
99 while let Some(arg) = args_iter.next() {
100 if arg.starts_with("-") {
101 let opt = self.find_opt(&arg)?;
102
103 match opt.action {
104 Action::Set => {
105 if let Some(value) = args_iter.next() {
106 named.insert(opt.name.clone(), Value::Single(value));
107 } else {
108 return Err(ParseError::MissingValue(opt.name.clone()));
109 }
110 }
111 Action::Append => {
112 match (args_iter.next(), named.get_mut(&opt.name)) {
113 (None, _) => return Err(ParseError::MissingValue(opt.name.clone())),
114 (Some(val), Some(Value::Multi(vals))) => {
115 vals.push(val);
116 }
117 (Some(val), None) => {
118 named.insert(opt.name.clone(), Value::Multi(vec![val]));
119 }
120 _ => return Err(ParseError::BadInternalState), };
122 }
123 Action::SetTrue => {
124 named.insert(opt.name.clone(), Value::Flag(true));
125 }
126 Action::SetFalse => {
127 named.insert(opt.name.clone(), Value::Flag(false));
128 }
129 };
130 } else {
131 positional.push(arg);
132 }
133 }
134
135 Ok(Matches::new(exec_name, positional, named))
136 }
137
138 fn populate_defaults(&self, named: &mut HashMap<String, Value>) {
139 for opt in self.opts.iter() {
140 if let Some(default) = &opt.default {
141 named.insert(opt.name.clone(), Value::Single(default.to_owned()));
142 } else {
143 match opt.action {
144 Action::Append => {
145 named.insert(opt.name.clone(), Value::Multi(vec![]));
146 }
147 Action::SetTrue => {
148 named.insert(opt.name.clone(), Value::Flag(false));
149 }
150 Action::SetFalse => {
151 named.insert(opt.name.clone(), Value::Flag(false));
152 }
153 _ => {}
154 }
155 }
156 }
157 }
158
159 fn find_opt(&self, arg: &str) -> Result<&Opt, ParseError> {
160 let opt = if arg.starts_with("--") {
161 let long = arg.strip_prefix("--").unwrap();
162 self.opts.iter().find(|o| o.long.as_deref() == Some(long))
163 } else if arg.starts_with("-") {
164 if arg.chars().count() != 2 {
165 return Err(ParseError::MalformedOption(arg.to_string()));
166 }
167 let short = arg.chars().nth(1);
168 self.opts.iter().find(|o| o.short == short)
169 } else {
170 return Err(ParseError::UnexpectedOption(arg.to_string()));
171 };
172
173 if let Some(opt) = opt {
174 Ok(opt)
175 } else {
176 Err(ParseError::UnexpectedOption(arg.to_string()))
177 }
178 }
179
180 fn validate(&self) -> Result<(), String> {
181 let mut names: HashSet<String> = HashSet::new();
182 let mut short: HashSet<char> = HashSet::new();
183 let mut long: HashSet<String> = HashSet::new();
184
185 for arg in &self.opts {
186 if names.contains(&arg.name) {
187 return Err(format!(
188 "Optument names must be unique; found two with name {}",
189 arg.name
190 ));
191 } else if arg.short.is_some() && short.contains(&arg.short.unwrap()) {
192 return Err(format!(
193 "Short flags must be unique; found two with short flag -{}",
194 arg.short.unwrap()
195 ));
196 } else if arg.long.is_some() && long.contains(arg.long.as_ref().unwrap()) {
197 return Err(format!(
198 "Long flags must be unique; found two with long flag --{}",
199 arg.long.as_ref().unwrap()
200 ));
201 }
202
203 names.insert(arg.name.to_string());
204 if let Some(c) = arg.short {
205 short.insert(c);
206 }
207 if let Some(s) = &arg.long {
208 long.insert(s.to_string());
209 }
210 }
211
212 Ok(())
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned argument vector `Opts::parse` takes.
    fn argv(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_validates_empty_args() {
        let _ = Opts::new(vec![]).expect("should validate");
    }

    #[test]
    fn detects_duplicate_names() {
        let opts = Opts::new(vec![
            Opt::name("host"),
            Opt::name("port"),
            Opt::name("port"),
        ]);
        assert_eq!(
            opts,
            Err("Optument names must be unique; found two with name port".to_string())
        );
    }

    #[test]
    fn detects_duplicate_short() {
        let opts = Opts::new(vec![
            Opt::name("host").short('p'),
            Opt::name("port").short('p'),
            Opt::name("threads").short('t'),
        ]);
        assert_eq!(
            opts,
            Err("Short flags must be unique; found two with short flag -p".to_string())
        );
    }

    #[test]
    fn detects_duplicate_long() {
        let opts = Opts::new(vec![
            Opt::name("host").long("host"),
            Opt::name("port").long("host"),
            Opt::name("threads").long("threads"),
        ]);
        assert_eq!(
            opts,
            Err("Long flags must be unique; found two with long flag --host".to_string())
        );
    }

    #[test]
    fn parses_positional_args() {
        let opts = Opts::new(vec![Opt::name("host").long("host")]).unwrap();
        let args = argv(&["myprogram", "1", "2", "blue"]);
        let expected_positional: Vec<_> = args.iter().skip(1).cloned().collect();

        let matches = opts
            .parse(args)
            .expect("input without options should parse");

        assert_eq!(matches.positional(), expected_positional);
    }

    #[test]
    fn parses_named_args() {
        let opts = Opts::new(vec![
            Opt::name("host").long("host"),
            Opt::name("verbose").long("verbose").action(Action::SetTrue),
            Opt::name("queue").short('q').action(Action::Append),
            Opt::name("nocolor")
                .short('n')
                .long("nocolor")
                .action(Action::SetFalse),
            Opt::name("missing").default("something"),
        ])
        .unwrap();
        // "-queue-name-with-dash" directly follows `-q`, so it is consumed as
        // that option's value rather than looked up as an option itself.
        let args = argv(&[
            "myprogram",
            "1",
            "2",
            "--verbose",
            "-q",
            "items",
            "--host",
            "localhost",
            "-q",
            "-queue-name-with-dash",
            "-n",
            "blue",
        ]);

        let expected_positional: Vec<_> = vec!["1", "2", "blue"];

        let matches = opts
            .parse(args)
            .expect("valid command line should parse");

        assert_eq!(matches.positional(), expected_positional);
        assert_eq!(matches.flag("verbose").unwrap(), Some(true));
        assert_eq!(matches.flag("nocolor").unwrap(), Some(false));
        assert_eq!(matches.one("host").unwrap(), Some("localhost".to_string()));
        let queues: Vec<String> = matches.all("queue").unwrap();
        assert_eq!(
            queues,
            vec!["items".to_string(), "-queue-name-with-dash".to_string()]
        );

        assert_eq!(
            matches.one("missing").unwrap(),
            Some("something".to_string())
        );
    }
}