ragit_cli/
lib.rs

1use std::collections::HashMap;
2
3mod dist;
4mod error;
5mod span;
6
7pub use dist::{get_closest_string, substr_edit_distance};
8pub use error::{Error, ErrorKind};
9pub use span::{Span, underline_span};
10
11pub struct ArgParser {
12    arg_count: ArgCount,
13    arg_type: ArgType,
14    flags: Vec<Flag>,
15    aliases: HashMap<String, String>,
16
17    // `--N=20`, `--prefix=rust`
18    arg_flags: HashMap<String, ArgFlag>,
19
20    // '-f' -> '--force'
21    short_flags: HashMap<String, String>,
22}
23
24impl ArgParser {
25    pub fn new() -> Self {
26        ArgParser {
27            arg_count: ArgCount::None,
28            arg_type: ArgType::String,
29            flags: vec![],
30            aliases: HashMap::new(),
31            arg_flags: HashMap::new(),
32            short_flags: HashMap::new(),
33        }
34    }
35
36    pub fn args(&mut self, arg_type: ArgType, arg_count: ArgCount) -> &mut Self {
37        self.arg_type = arg_type;
38        self.arg_count = arg_count;
39        self
40    }
41
42    pub fn flag(&mut self, flags: &[&str]) -> &mut Self {
43        self.flags.push(Flag {
44            values: flags.iter().map(|flag| flag.to_string()).collect(),
45            optional: false,
46            default: None,
47        });
48        self
49    }
50
51    pub fn optional_flag(&mut self, flags: &[&str]) -> &mut Self {
52        self.flags.push(Flag {
53            values: flags.iter().map(|flag| flag.to_string()).collect(),
54            optional: true,
55            default: None,
56        });
57        self
58    }
59
60    pub fn arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
61        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: false, default: None, arg_type });
62        self
63    }
64
65    pub fn optional_arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
66        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: None, arg_type });
67        self
68    }
69
70    pub fn arg_flag_with_default(&mut self, flag: &str, default: &str, arg_type: ArgType) -> &mut Self {
71        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: Some(default.to_string()), arg_type });
72        self
73    }
74
75    // the first flag is the default value
76    pub fn flag_with_default(&mut self, flags: &[&str]) -> &mut Self {
77        self.flags.push(Flag {
78            values: flags.iter().map(|flag| flag.to_string()).collect(),
79            optional: true,
80            default: Some(0),
81        });
82        self
83    }
84
85    fn map_short_flag(&self, flag: &str) -> String {
86        match self.short_flags.get(flag) {
87            Some(f) => f.to_string(),
88            None => flag.to_string(),
89        }
90    }
91
92    pub fn short_flag(&mut self, flags: &[&str]) -> &mut Self {
93        for flag in flags.iter() {
94            let short_flag = flag.get(1..3).unwrap().to_string();
95
96            if let Some(old) = self.short_flags.get(&short_flag) {
97                panic!("{flag} and {old} have the same short name!")
98            }
99
100            self.short_flags.insert(short_flag, flag.to_string());
101        }
102
103        self
104    }
105
106    pub fn alias(&mut self, from: &str, to: &str) -> &mut Self {
107        self.aliases.insert(from.to_string(), to.to_string());
108        self
109    }
110
111    /// Let's say `raw_args` is `["rag", "ls-files", "--json", "--staged", "--name-only"]` and
112    /// you don't care about the first 2 args (path and command name). You only want to parse
113    /// the flags (the last 3 args). In this case, you set `skip_first_n` to 2.
114    pub fn parse(&self, raw_args: &[String], skip_first_n: usize) -> Result<ParsedArgs, Error> {
115        self.parse_worker(raw_args, skip_first_n).map_err(
116            |mut e| {
117                e.span = e.span.render(raw_args, skip_first_n);
118                e
119            }
120        )
121    }
122
123    fn parse_worker(&self, raw_args: &[String], skip_first_n: usize) -> Result<ParsedArgs, Error> {
124        let mut args = vec![];
125        let mut flags = vec![None; self.flags.len()];
126        let mut arg_flags = HashMap::new();
127        let mut expecting_flag_arg: Option<ArgFlag> = None;
128        let mut no_more_flags = false;
129
130        if raw_args.get(skip_first_n).map(|arg| arg.as_str()) == Some("--help") {
131            return Ok(ParsedArgs {
132                skip_first_n,
133                raw_args: raw_args.to_vec(),
134                args,
135                flags: vec![],
136                arg_flags,
137                show_help: true,
138            });
139        }
140
141        'raw_arg_loop: for (arg_index, raw_arg) in raw_args[skip_first_n..].iter().enumerate() {
142            let raw_arg = match self.aliases.get(raw_arg) {
143                Some(alias) => alias.to_string(),
144                None => raw_arg.to_string(),
145            };
146
147            if raw_arg == "--" {
148                if let Some(arg_flag) = expecting_flag_arg {
149                    return Err(Error {
150                        span: Span::End,
151                        kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
152                    });
153                }
154
155                no_more_flags = true;
156                continue;
157            }
158
159            if let Some(arg_flag) = expecting_flag_arg {
160                expecting_flag_arg = None;
161                arg_flag.arg_type.parse(&raw_arg, Span::Exact(arg_index + skip_first_n))?;
162
163                if let Some(_) = arg_flags.insert(arg_flag.flag.clone(), raw_arg.to_string()) {
164                    return Err(Error {
165                        span: Span::Exact(arg_index + skip_first_n),
166                        kind: ErrorKind::SameFlagMultipleTimes(
167                            arg_flag.flag.clone(),
168                            arg_flag.flag.clone(),
169                        ),
170                    });
171                }
172
173                continue;
174            }
175
176            if raw_arg.starts_with("-") && !no_more_flags {
177                let mapped_flag = self.map_short_flag(&raw_arg);
178
179                for (flag_index, flag) in self.flags.iter().enumerate() {
180                    if flag.values.contains(&mapped_flag) {
181                        if flags[flag_index].is_none() {
182                            flags[flag_index] = Some(mapped_flag.to_string());
183                            continue 'raw_arg_loop;
184                        }
185
186                        else {
187                            return Err(Error {
188                                span: Span::Exact(arg_index + skip_first_n),
189                                kind: ErrorKind::SameFlagMultipleTimes(
190                                    flags[flag_index].as_ref().unwrap().to_string(),
191                                    raw_arg.to_string(),
192                                ),
193                            });
194                        }
195                    }
196                }
197
198                if let Some(arg_flag) = self.arg_flags.get(&mapped_flag) {
199                    expecting_flag_arg = Some(arg_flag.clone());
200                    continue;
201                }
202
203                if raw_arg.contains("=") {
204                    let splitted = raw_arg.splitn(2, '=').collect::<Vec<_>>();
205                    let flag = self.map_short_flag(splitted[0]);
206                    let flag_arg = splitted[1];
207
208                    if let Some(arg_flag) = self.arg_flags.get(&flag) {
209                        arg_flag.arg_type.parse(flag_arg, Span::Exact(arg_index + skip_first_n))?;
210
211                        if let Some(_) = arg_flags.insert(flag.to_string(), flag_arg.to_string()) {
212                            return Err(Error {
213                                span: Span::Exact(arg_index + skip_first_n),
214                                kind: ErrorKind::SameFlagMultipleTimes(
215                                    flag.to_string(),
216                                    flag.to_string(),
217                                ),
218                            });
219                        }
220
221                        continue;
222                    }
223
224                    else {
225                        return Err(Error {
226                            span: Span::Exact(arg_index + skip_first_n),
227                            kind: ErrorKind::UnknownFlag {
228                                flag: flag.to_string(),
229                                similar_flag: self.get_similar_flag(&flag),
230                            },
231                        });
232                    }
233                }
234
235                return Err(Error {
236                    span: Span::Exact(arg_index + skip_first_n),
237                    kind: ErrorKind::UnknownFlag {
238                        flag: raw_arg.to_string(),
239                        similar_flag: self.get_similar_flag(&raw_arg),
240                    },
241                });
242            }
243
244            else {
245                args.push(self.arg_type.parse(&raw_arg, Span::Exact(arg_index + skip_first_n))?);
246            }
247        }
248
249        if let Some(arg_flag) = expecting_flag_arg {
250            return Err(Error {
251                span: Span::End,
252                kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
253            });
254        }
255
256        for i in 0..flags.len() {
257            if flags[i].is_none() {
258                if let Some(j) = self.flags[i].default {
259                    flags[i] = Some(self.flags[i].values[j].clone());
260                }
261
262                else if !self.flags[i].optional {
263                    return Err(Error {
264                        span: Span::End,
265                        kind: ErrorKind::MissingFlag(self.flags[i].values.join(" | ")),
266                    });
267                }
268            }
269        }
270
271        loop {
272            let span = match self.arg_count {
273                ArgCount::Geq(n) if args.len() < n => { Span::End },
274                ArgCount::Leq(n) if args.len() > n => { Span::NthArg(n + 1) },
275                ArgCount::Exact(n) if args.len() > n => { Span::NthArg(n + 1) },
276                ArgCount::Exact(n) if args.len() < n => { Span::NthArg(args.len().max(1) - 1) },
277                ArgCount::None if args.len() > 0 => { Span::FirstArg },
278                _ => { break; },
279            };
280
281            return Err(Error {
282                span,
283                kind: ErrorKind::WrongArgCount {
284                    expected: self.arg_count,
285                    got: args.len(),
286                },
287            });
288        }
289
290        for (flag, arg_flag) in self.arg_flags.iter() {
291            if arg_flags.contains_key(flag) {
292                continue;
293            }
294
295            else if let Some(default) = &arg_flag.default {
296                arg_flags.insert(flag.to_string(), default.to_string());
297            }
298
299            else if !arg_flag.optional {
300                return Err(Error {
301                    span: Span::End,
302                    kind: ErrorKind::MissingFlag(flag.to_string()),
303                });
304            }
305        }
306
307        Ok(ParsedArgs {
308            skip_first_n,
309            raw_args: raw_args.to_vec(),
310            args,
311            flags,
312            arg_flags,
313            show_help: false,
314        })
315    }
316
317    fn get_similar_flag(&self, flag: &str) -> Option<String> {
318        let mut candidates = vec![];
319
320        for flag in self.flags.iter() {
321            for flag in flag.values.iter() {
322                candidates.push(flag.to_string());
323            }
324        }
325
326        for flag in self.arg_flags.keys() {
327            candidates.push(flag.to_string());
328        }
329
330        get_closest_string(&candidates, flag)
331    }
332}
333
/// How many positional arguments a command accepts.
#[derive(Debug, Clone, Copy)]
pub enum ArgCount {
    /// At least `n` arguments.
    Geq(usize),
    /// At most `n` arguments.
    Leq(usize),
    /// Exactly `n` arguments.
    Exact(usize),
    /// Any number of arguments, including zero.
    Any,
    /// No positional arguments at all.
    None,
}
342
/// The type of a positional argument or of a flag's value.
///
/// Only `Integer` and `IntegerBetween` are actually validated (see [`ArgType::parse`]).
/// The remaining variants accept any string; they act as type signatures that make the
/// intent of the calling code clearer.
#[derive(Debug, Clone, Copy)]
pub enum ArgType {
    String,
    Path,
    Command,
    Query,
    UidOrPath,
    Integer,
    Url,

    /// An integer in `min..=max` (both bounds inclusive); a `None` bound leaves that
    /// side unbounded.
    IntegerBetween {
        min: Option<i128>,
        max: Option<i128>,
    },
}
363
364impl ArgType {
365    pub fn parse(&self, arg: &str, span: Span) -> Result<String, Error> {
366        match self {
367            ArgType::Integer => match arg.parse::<i128>() {
368                Ok(_) => Ok(arg.to_string()),
369                Err(e) => Err(Error {
370                    span,
371                    kind: ErrorKind::ParseIntError(e),
372                }),
373            },
374            ArgType::IntegerBetween { min, max } => match arg.parse::<i128>() {
375                Ok(n) => {
376                    if let Some(min) = *min {
377                        if n < min {
378                            return Err(Error{
379                                span,
380                                kind: ErrorKind::IntegerNotInRange { min: Some(min), max: *max, n },
381                            });
382                        }
383                    }
384
385                    if let Some(max) = *max {
386                        if n > max {
387                            return Err(Error{
388                                span,
389                                kind: ErrorKind::IntegerNotInRange { min: *min, max: Some(max), n },
390                            });
391                        }
392                    }
393
394                    Ok(arg.to_string())
395                },
396                Err(e) => Err(Error {
397                    span,
398                    kind: ErrorKind::ParseIntError(e),
399                }),
400            },
401            ArgType::String
402            | ArgType::Path
403            | ArgType::Url
404            | ArgType::UidOrPath
405            | ArgType::Command
406            | ArgType::Query => Ok(arg.to_string()),
407        }
408    }
409}
410
/// A group of boolean flags of which at most one may be given, e.g. `--json | --plain`.
#[derive(Debug, Clone)]
pub struct Flag {
    /// Every spelling in the group, in registration order.
    values: Vec<String>,
    /// When `false` and `default` is `None`, parsing fails unless one of `values` is given.
    optional: bool,
    /// Index into `values` of the value used when none of the group is given.
    default: Option<usize>,
}
417
418#[derive(Clone, Debug)]
419pub struct ArgFlag {
420    flag: String,
421    optional: bool,
422    default: Option<String>,
423    arg_type: ArgType,
424}
425
/// The result of parsing a command line with [`ArgParser::parse`].
pub struct ParsedArgs {
    /// Values of flags that take a value, keyed by flag name (e.g. `--prefix`).
    pub arg_flags: HashMap<String, String>,

    // positional arguments, in command-line order
    args: Vec<String>,

    // one slot per boolean flag group, in registration order
    flags: Vec<Option<String>>,

    // the full command line and how many leading entries were not parsed;
    // kept so that later errors can point into the command line
    raw_args: Vec<String>,
    skip_first_n: usize,

    // set when the first parsed argument is `--help`
    show_help: bool,  // TODO: options for help messages
}
434
435impl ParsedArgs {
436    pub fn new() -> Self {
437        ParsedArgs {
438            skip_first_n: 0,
439            raw_args: vec![],
440            args: vec![],
441            flags: vec![],
442            arg_flags: HashMap::new(),
443            show_help: false,
444        }
445    }
446
447    pub fn get_args(&self) -> Vec<String> {
448        self.args.clone()
449    }
450
451    pub fn get_args_exact(&self, count: usize) -> Result<Vec<String>, Error> {
452        if self.args.len() == count {
453            Ok(self.args.clone())
454        }
455
456        else {
457            Err(Error {
458                span: Span::FirstArg.render(&self.raw_args, self.skip_first_n),
459                kind: ErrorKind::WrongArgCount {
460                    expected: ArgCount::Exact(count),
461                    got: self.args.len(),
462                },
463            })
464        }
465    }
466
467    // if there's an index error, it panics instead of returning None
468    // if it returns None, that means Nth flag is optional and its value is None
469    pub fn get_flag(&self, index: usize) -> Option<String> {
470        self.flags[index].clone()
471    }
472
473    pub fn show_help(&self) -> bool {
474        self.show_help
475    }
476}
477
478/// It parses `rag [-C <path>] <command> <args>` and returns
479/// `Ok((args, pre_args))` where `args` is `rag <command> <args>` and
480/// `pre_args` is `-C <path>`.
481///
482/// NOTE: Do not use this function. I have implemented this because I'm not sure
483/// how to implement `-C` option. I'll remove this function as soon as I come up
484/// with a nice way to implement `-C`.
485///
486/// It only supports `-C <path>` and not `-C=<path>` and that's intentional. Git
487/// neither supports `-C=<path>` (I don't know why), and I decided to blindly follow that.
488pub fn parse_pre_args(args: &[String]) -> Result<(Vec<String>, ParsedArgs), Error> {
489    match args.get(1).map(|s| s.as_str()) {
490        Some("-C") => match args.get(2).map(|s| s.as_str()) {
491            Some(path) => {
492                let mut result = ParsedArgs::new();
493                result.arg_flags.insert(String::from("-C"), path.to_string());
494                Ok((
495                    vec![
496                        vec![args[0].clone()],
497                        if args.len() < 4 { vec![] } else { args[3..].to_vec() },
498                    ].concat(),
499                    result,
500                ))
501            },
502            None => Err(Error {
503                span: Span::Exact(2),
504                kind: ErrorKind::MissingArgument(String::from("-C"), ArgType::Path),
505            }),
506        },
507        _ => Ok((args.to_vec(), ParsedArgs::new())),
508    }
509}