ragit_cli/
lib.rs

1use std::collections::HashMap;
2
3mod error;
4mod span;
5
6pub use error::{Error, ErrorKind};
7pub use span::Span;
8
/// Description of the positional arguments and flags that a command accepts.
///
/// A parser is configured with the builder-style methods (`args`, `flag`,
/// `arg_flag`, `short_flag`, ...) and then applied to raw arguments with `parse`.
pub struct ArgParser {
    // how many positional arguments are allowed
    arg_count: ArgCount,
    // how every positional argument is validated
    arg_type: ArgType,
    // flag groups; at most one value of each group may appear in the input
    flags: Vec<Flag>,

    // `--N=20`, `--prefix=rust`
    arg_flags: HashMap<String, ArgFlag>,

    // '-f' -> '--force'
    short_flags: HashMap<String, String>,
}
20
21impl ArgParser {
22    pub fn new() -> Self {
23        ArgParser {
24            arg_count: ArgCount::None,
25            arg_type: ArgType::String,
26            flags: vec![],
27            arg_flags: HashMap::new(),
28            short_flags: HashMap::new(),
29        }
30    }
31
32    pub fn args(&mut self, arg_type: ArgType, arg_count: ArgCount) -> &mut Self {
33        self.arg_type = arg_type;
34        self.arg_count = arg_count;
35        self
36    }
37
38    pub fn flag(&mut self, flags: &[&str]) -> &mut Self {
39        self.flags.push(Flag {
40            values: flags.iter().map(|flag| flag.to_string()).collect(),
41            optional: false,
42            default: None,
43        });
44        self
45    }
46
47    pub fn optional_flag(&mut self, flags: &[&str]) -> &mut Self {
48        self.flags.push(Flag {
49            values: flags.iter().map(|flag| flag.to_string()).collect(),
50            optional: true,
51            default: None,
52        });
53        self
54    }
55
56    pub fn arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
57        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: false, default: None, arg_type });
58        self
59    }
60
61    pub fn optional_arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
62        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: None, arg_type });
63        self
64    }
65
66    pub fn arg_flag_with_default(&mut self, flag: &str, default: &str, arg_type: ArgType) -> &mut Self {
67        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: Some(default.to_string()), arg_type });
68        self
69    }
70
71    pub fn short_flag(&mut self, flags: &[&str]) -> &mut Self {
72        for flag in flags.iter() {
73            let short_flag = flag.get(1..3).unwrap().to_string();
74
75            if let Some(old) = self.short_flags.get(&short_flag) {
76                panic!("{flag} and {old} have the same short name!")
77            }
78
79            self.short_flags.insert(short_flag, flag.to_string());
80        }
81
82        self
83    }
84
85    // the first flag is the default value
86    pub fn flag_with_default(&mut self, flags: &[&str]) -> &mut Self {
87        self.flags.push(Flag {
88            values: flags.iter().map(|flag| flag.to_string()).collect(),
89            optional: true,
90            default: Some(0),
91        });
92        self
93    }
94
95    pub fn map_short_flag(&self, flag: &str) -> String {
96        match self.short_flags.get(flag) {
97            Some(f) => f.to_string(),
98            None => flag.to_string(),
99        }
100    }
101
102    pub fn parse(&self, raw_args: &[String]) -> Result<ParsedArgs, Error> {
103        self.parse_worker(raw_args).map_err(
104            |mut e| {
105                e.span = e.span.render(raw_args);
106                e
107            }
108        )
109    }
110
111    pub fn parse_worker(&self, raw_args: &[String]) -> Result<ParsedArgs, Error> {
112        let mut args = vec![];
113        let mut flags = vec![None; self.flags.len()];
114        let mut arg_flags = HashMap::new();
115        let mut expecting_flag_arg: Option<ArgFlag> = None;
116        let mut no_more_flags = false;
117
118        if raw_args.get(0).map(|arg| arg.as_str()) == Some("--help") {
119            return Ok(ParsedArgs {
120                raw_args: raw_args.to_vec(),
121                args,
122                flags: vec![],
123                arg_flags,
124                show_help: true,
125            });
126        }
127
128        'raw_arg_loop: for (arg_index, raw_arg) in raw_args.iter().enumerate() {
129            if raw_arg == "--" {
130                if let Some(arg_flag) = expecting_flag_arg {
131                    return Err(Error {
132                        span: Span::End,
133                        kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
134                    });
135                }
136
137                no_more_flags = true;
138                continue;
139            }
140
141            if let Some(arg_flag) = expecting_flag_arg {
142                expecting_flag_arg = None;
143                arg_flag.arg_type.parse(raw_arg, Span::Exact(arg_index))?;
144
145                if let Some(_) = arg_flags.insert(arg_flag.flag.clone(), raw_arg.to_string()) {
146                    return Err(Error {
147                        span: Span::Exact(arg_index),
148                        kind: ErrorKind::SameFlagMultipleTimes(
149                            arg_flag.flag.clone(),
150                            arg_flag.flag.clone(),
151                        ),
152                    });
153                }
154
155                continue;
156            }
157
158            if raw_arg.starts_with("-") && !no_more_flags {
159                let mapped_flag = self.map_short_flag(raw_arg);
160
161                for (flag_index, flag) in self.flags.iter().enumerate() {
162                    if flag.values.contains(&mapped_flag) {
163                        if flags[flag_index].is_none() {
164                            flags[flag_index] = Some(mapped_flag.to_string());
165                            continue 'raw_arg_loop;
166                        }
167
168                        else {
169                            return Err(Error {
170                                span: Span::Exact(arg_index),
171                                kind: ErrorKind::SameFlagMultipleTimes(
172                                    flags[flag_index].as_ref().unwrap().to_string(),
173                                    raw_arg.to_string(),
174                                ),
175                            });
176                        }
177                    }
178                }
179
180                if let Some(arg_flag) = self.arg_flags.get(&mapped_flag) {
181                    expecting_flag_arg = Some(arg_flag.clone());
182                    continue;
183                }
184
185                if raw_arg.contains("=") {
186                    let splitted = raw_arg.splitn(2, '=').collect::<Vec<_>>();
187                    let flag = self.map_short_flag(splitted[0]);
188                    let flag_arg = splitted[1];
189
190                    if let Some(arg_flag) = self.arg_flags.get(&flag) {
191                        arg_flag.arg_type.parse(flag_arg, Span::Exact(arg_index))?;
192
193                        if let Some(_) = arg_flags.insert(flag.to_string(), flag_arg.to_string()) {
194                            return Err(Error {
195                                span: Span::Exact(arg_index),
196                                kind: ErrorKind::SameFlagMultipleTimes(
197                                    flag.to_string(),
198                                    flag.to_string(),
199                                ),
200                            });
201                        }
202
203                        continue;
204                    }
205
206                    else {
207                        return Err(Error {
208                            span: Span::Exact(arg_index),
209                            kind: ErrorKind::UnknownFlag(flag.to_string()),
210                        });
211                    }
212                }
213
214                return Err(Error {
215                    span: Span::Exact(arg_index),
216                    kind: ErrorKind::UnknownFlag(raw_arg.to_string()),
217                });
218            }
219
220            else {
221                args.push(self.arg_type.parse(raw_arg, Span::Exact(arg_index))?);
222            }
223        }
224
225        if let Some(arg_flag) = expecting_flag_arg {
226            return Err(Error {
227                span: Span::End,
228                kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
229            });
230        }
231
232        for i in 0..flags.len() {
233            if flags[i].is_none() {
234                if let Some(j) = self.flags[i].default {
235                    flags[i] = Some(self.flags[i].values[j].clone());
236                }
237
238                else if !self.flags[i].optional {
239                    return Err(Error {
240                        span: Span::End,
241                        kind: ErrorKind::MissingFlag(self.flags[i].values.join(" | ")),
242                    });
243                }
244            }
245        }
246
247        loop {
248            let span = match self.arg_count {
249                ArgCount::Geq(n) if args.len() < n => { Span::End },
250                ArgCount::Leq(n) if args.len() > n => { Span::NthArg(n + 1) },
251                ArgCount::Exact(n) if args.len() != n => { Span::NthArg(n + 1) },
252                ArgCount::None if args.len() > 0 => { Span::FirstArg },
253                _ => { break; },
254            };
255
256            return Err(Error {
257                span,
258                kind: ErrorKind::WrongArgCount {
259                    expected: self.arg_count,
260                    got: args.len(),
261                },
262            });
263        }
264
265        for (flag, arg_flag) in self.arg_flags.iter() {
266            if arg_flags.contains_key(flag) {
267                continue;
268            }
269
270            else if let Some(default) = &arg_flag.default {
271                arg_flags.insert(flag.to_string(), default.to_string());
272            }
273
274            else if !arg_flag.optional {
275                return Err(Error {
276                    span: Span::End,
277                    kind: ErrorKind::MissingFlag(flag.to_string()),
278                });
279            }
280        }
281
282        Ok(ParsedArgs {
283            raw_args: raw_args.to_vec(),
284            args,
285            flags,
286            arg_flags,
287            show_help: false,
288        })
289    }
290}
291
/// Constraint on the number of positional arguments.
#[derive(Clone, Copy, Debug)]
pub enum ArgCount {
    // at least this many
    Geq(usize),
    // at most this many
    Leq(usize),
    // exactly this many
    Exact(usize),
    // any number, the count is not checked
    Any,
    // no positional arguments at all
    None,
}
300
/// Kind of value expected for a positional argument or for the value of a flag.
///
/// Only `Integer` and `UnsignedInteger` are validated by `ArgType::parse`;
/// the other kinds are accepted as they are.
#[derive(Clone, Copy, Debug)]
pub enum ArgType {
    String,
    Path,
    Command,
    Query,  // uid or path
    Integer,  // must parse as `i128`
    UnsignedInteger,  // must parse as `u128`
}
310
311impl ArgType {
312    pub fn parse(&self, arg: &str, span: Span) -> Result<String, Error> {
313        match self {
314            ArgType::Integer => match arg.parse::<i128>() {
315                Ok(_) => Ok(arg.to_string()),
316                Err(e) => Err(Error {
317                    span,
318                    kind: ErrorKind::ParseIntError(e),
319                }),
320            },
321            ArgType::UnsignedInteger => match arg.parse::<u128>() {
322                Ok(_) => Ok(arg.to_string()),
323                Err(e) => Err(Error {
324                    span,
325                    kind: ErrorKind::ParseIntError(e),
326                }),
327            },
328            ArgType::String
329            | ArgType::Path
330            | ArgType::Command  // TODO: validator for ArgType::Command
331            | ArgType::Query => Ok(arg.to_string()),
332        }
333    }
334}
335
/// A group of flags of which at most one may be given.
#[derive(Clone, Debug)]
pub struct Flag {
    // the alternative values of this group
    values: Vec<String>,
    // if false, parsing fails when no value of the group is given (and there is no default)
    optional: bool,
    // index into `values` that is used when no value of the group is given
    default: Option<usize>,
}
342
/// A flag that takes a value, like `--N=20` or `--N 20`.
#[derive(Clone, Debug)]
pub struct ArgFlag {
    // name of the flag; also the key under which the parsed value is stored
    flag: String,
    // if false, parsing fails when the flag is not given (and there is no default)
    optional: bool,
    // value that is used when the flag is not given
    default: Option<String>,
    // how a value given on the command line is validated
    arg_type: ArgType,
}
350
/// Result of a successful `ArgParser::parse`.
pub struct ParsedArgs {
    // the arguments exactly as they were given to the parser
    raw_args: Vec<String>,
    // positional arguments, in the order they were given
    args: Vec<String>,
    // one slot per registered flag group, in registration order;
    // `None` if the group is optional and was not given
    flags: Vec<Option<String>>,
    // values of value-taking flags (given or default), keyed by flag name
    pub arg_flags: HashMap<String, String>,
    show_help: bool,  // TODO: options for help messages
}
358
359impl ParsedArgs {
360    pub fn get_args(&self) -> Vec<String> {
361        self.args.clone()
362    }
363
364    pub fn get_args_exact(&self, count: usize) -> Result<Vec<String>, Error> {
365        if self.args.len() == count {
366            Ok(self.args.clone())
367        }
368
369        else {
370            Err(Error {
371                span: Span::FirstArg.render(&self.raw_args),
372                kind: ErrorKind::WrongArgCount {
373                    expected: ArgCount::Exact(count),
374                    got: self.args.len(),
375                },
376            })
377        }
378    }
379
380    // if there's an index error, it panics instead of returning None
381    // if it returns None, that means Nth flag is optional and its value is None
382    pub fn get_flag(&self, index: usize) -> Option<String> {
383        self.flags[index].clone()
384    }
385
386    pub fn show_help(&self) -> bool {
387        self.show_help
388    }
389}
390
/// Renders `prefix` followed by `args`, and below it a line that marks
/// `start..end` of `args` with `^`.
///
/// The second line is padded with spaces to the width of the first one.
/// All widths are byte lengths, so the marker lines up exactly only for
/// ASCII input.
///
/// Out-of-range spans do not panic: an `end` smaller than `start` yields an
/// empty marker, and an `end` past the end of `args` (e.g. a span that points
/// right behind the last argument) simply gets no trailing padding.
pub fn underline_span(prefix: &str, args: &str, start: usize, end: usize) -> String {
    // plain `end - start` / `args.len() - end` would underflow on such spans
    let end = end.max(start);
    let marker_width = end - start;
    let trailing_width = args.len().saturating_sub(end);

    format!(
        "{prefix}{args}\n{}{}{}{}",
        " ".repeat(prefix.len()),
        " ".repeat(start),
        "^".repeat(marker_width),
        " ".repeat(trailing_width),
    )
}