ragit_cli/
lib.rs

1use std::collections::HashMap;
2
3mod dist;
4mod error;
5mod file_size;
6mod span;
7
8pub use dist::{get_closest_string, substr_edit_distance};
9pub use error::{Error, ErrorKind, RawError};
10use file_size::parse_file_size;
11pub use span::{RenderedSpan, Span, underline_span};
12
13pub struct ArgParser {
14    arg_count: ArgCount,
15    arg_type: ArgType,
16    flags: Vec<Flag>,
17    aliases: HashMap<String, String>,
18
19    // `--N=20`, `--prefix=rust`
20    arg_flags: HashMap<String, ArgFlag>,
21
22    // '-f' -> '--force'
23    short_flags: HashMap<String, String>,
24}
25
26impl ArgParser {
27    pub fn new() -> Self {
28        ArgParser {
29            arg_count: ArgCount::None,
30            arg_type: ArgType::String,
31            flags: vec![],
32            aliases: HashMap::new(),
33            arg_flags: HashMap::new(),
34            short_flags: HashMap::new(),
35        }
36    }
37
38    pub fn args(&mut self, arg_type: ArgType, arg_count: ArgCount) -> &mut Self {
39        self.arg_type = arg_type;
40        self.arg_count = arg_count;
41        self
42    }
43
44    pub fn flag(&mut self, flags: &[&str]) -> &mut Self {
45        self.flags.push(Flag {
46            values: flags.iter().map(|flag| flag.to_string()).collect(),
47            optional: false,
48            default: None,
49        });
50        self
51    }
52
53    pub fn optional_flag(&mut self, flags: &[&str]) -> &mut Self {
54        self.flags.push(Flag {
55            values: flags.iter().map(|flag| flag.to_string()).collect(),
56            optional: true,
57            default: None,
58        });
59        self
60    }
61
62    pub fn arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
63        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: false, default: None, arg_type });
64        self
65    }
66
67    pub fn optional_arg_flag(&mut self, flag: &str, arg_type: ArgType) -> &mut Self {
68        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: None, arg_type });
69        self
70    }
71
72    pub fn arg_flag_with_default(&mut self, flag: &str, default: &str, arg_type: ArgType) -> &mut Self {
73        self.arg_flags.insert(flag.to_string(), ArgFlag { flag: flag.to_string(), optional: true, default: Some(default.to_string()), arg_type });
74        self
75    }
76
77    // the first flag is the default value
78    pub fn flag_with_default(&mut self, flags: &[&str]) -> &mut Self {
79        self.flags.push(Flag {
80            values: flags.iter().map(|flag| flag.to_string()).collect(),
81            optional: true,
82            default: Some(0),
83        });
84        self
85    }
86
87    fn map_short_flag(&self, flag: &str) -> String {
88        match self.short_flags.get(flag) {
89            Some(f) => f.to_string(),
90            None => flag.to_string(),
91        }
92    }
93
94    pub fn short_flag(&mut self, flags: &[&str]) -> &mut Self {
95        for flag in flags.iter() {
96            let short_flag = flag.get(1..3).unwrap().to_string();
97
98            if let Some(old) = self.short_flags.get(&short_flag) {
99                panic!("{flag} and {old} have the same short name!")
100            }
101
102            self.short_flags.insert(short_flag, flag.to_string());
103        }
104
105        self
106    }
107
108    pub fn alias(&mut self, from: &str, to: &str) -> &mut Self {
109        self.aliases.insert(from.to_string(), to.to_string());
110        self
111    }
112
113    /// Let's say `raw_args` is `["rag", "ls-files", "--json", "--staged", "--name-only"]` and
114    /// you don't care about the first 2 args (path and command name). You only want to parse
115    /// the flags (the last 3 args). In this case, you set `skip_first_n` to 2.
116    pub fn parse(&self, raw_args: &[String], skip_first_n: usize) -> Result<ParsedArgs, Error> {
117        self.parse_worker(raw_args, skip_first_n).map_err(
118            |e| Error {
119                span: e.span.render(raw_args, skip_first_n),
120                kind: e.kind,
121            }
122        )
123    }
124
125    fn parse_worker(&self, raw_args: &[String], skip_first_n: usize) -> Result<ParsedArgs, RawError> {
126        let mut args = vec![];
127        let mut flags = vec![None; self.flags.len()];
128        let mut arg_flags = HashMap::new();
129        let mut expecting_flag_arg: Option<ArgFlag> = None;
130        let mut no_more_flags = false;
131
132        if raw_args.get(skip_first_n).map(|arg| arg.as_str()) == Some("--help") {
133            return Ok(ParsedArgs {
134                skip_first_n,
135                raw_args: raw_args.to_vec(),
136                args,
137                flags: vec![],
138                arg_flags,
139                show_help: true,
140            });
141        }
142
143        'raw_arg_loop: for (arg_index, raw_arg) in raw_args[skip_first_n..].iter().enumerate() {
144            let raw_arg = match self.aliases.get(raw_arg) {
145                Some(alias) => alias.to_string(),
146                None => raw_arg.to_string(),
147            };
148
149            if raw_arg == "--" {
150                if let Some(arg_flag) = expecting_flag_arg {
151                    return Err(RawError {
152                        span: Span::End,
153                        kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
154                    });
155                }
156
157                no_more_flags = true;
158                continue;
159            }
160
161            if let Some(arg_flag) = expecting_flag_arg {
162                expecting_flag_arg = None;
163                let flag_arg = arg_flag.arg_type.parse(&raw_arg, Span::Exact(arg_index + skip_first_n))?;
164
165                if let Some(_) = arg_flags.insert(arg_flag.flag.clone(), flag_arg) {
166                    return Err(RawError {
167                        span: Span::Exact(arg_index + skip_first_n),
168                        kind: ErrorKind::SameFlagMultipleTimes(
169                            arg_flag.flag.clone(),
170                            arg_flag.flag.clone(),
171                        ),
172                    });
173                }
174
175                continue;
176            }
177
178            if raw_arg.starts_with("-") && !no_more_flags {
179                let mapped_flag = self.map_short_flag(&raw_arg);
180
181                for (flag_index, flag) in self.flags.iter().enumerate() {
182                    if flag.values.contains(&mapped_flag) {
183                        if flags[flag_index].is_none() {
184                            flags[flag_index] = Some(mapped_flag.to_string());
185                            continue 'raw_arg_loop;
186                        }
187
188                        else {
189                            return Err(RawError {
190                                span: Span::Exact(arg_index + skip_first_n),
191                                kind: ErrorKind::SameFlagMultipleTimes(
192                                    flags[flag_index].as_ref().unwrap().to_string(),
193                                    raw_arg.to_string(),
194                                ),
195                            });
196                        }
197                    }
198                }
199
200                if let Some(arg_flag) = self.arg_flags.get(&mapped_flag) {
201                    expecting_flag_arg = Some(arg_flag.clone());
202                    continue;
203                }
204
205                if raw_arg.contains("=") {
206                    let splitted = raw_arg.splitn(2, '=').collect::<Vec<_>>();
207                    let flag = self.map_short_flag(splitted[0]);
208                    let flag_arg = splitted[1];
209
210                    if let Some(arg_flag) = self.arg_flags.get(&flag) {
211                        let flag_arg = arg_flag.arg_type.parse(flag_arg, Span::Exact(arg_index + skip_first_n))?;
212
213                        if let Some(_) = arg_flags.insert(flag.to_string(), flag_arg) {
214                            return Err(RawError {
215                                span: Span::Exact(arg_index + skip_first_n),
216                                kind: ErrorKind::SameFlagMultipleTimes(
217                                    flag.to_string(),
218                                    flag.to_string(),
219                                ),
220                            });
221                        }
222
223                        continue;
224                    }
225
226                    else {
227                        return Err(RawError {
228                            span: Span::Exact(arg_index + skip_first_n),
229                            kind: ErrorKind::UnknownFlag {
230                                flag: flag.to_string(),
231                                similar_flag: self.get_similar_flag(&flag),
232                            },
233                        });
234                    }
235                }
236
237                return Err(RawError {
238                    span: Span::Exact(arg_index + skip_first_n),
239                    kind: ErrorKind::UnknownFlag {
240                        flag: raw_arg.to_string(),
241                        similar_flag: self.get_similar_flag(&raw_arg),
242                    },
243                });
244            }
245
246            else {
247                args.push(self.arg_type.parse(&raw_arg, Span::Exact(arg_index + skip_first_n))?);
248            }
249        }
250
251        if let Some(arg_flag) = expecting_flag_arg {
252            return Err(RawError {
253                span: Span::End,
254                kind: ErrorKind::MissingArgument(arg_flag.flag.to_string(), arg_flag.arg_type),
255            });
256        }
257
258        for i in 0..flags.len() {
259            if flags[i].is_none() {
260                if let Some(j) = self.flags[i].default {
261                    flags[i] = Some(self.flags[i].values[j].clone());
262                }
263
264                else if !self.flags[i].optional {
265                    return Err(RawError {
266                        span: Span::End,
267                        kind: ErrorKind::MissingFlag(self.flags[i].values.join(" | ")),
268                    });
269                }
270            }
271        }
272
273        loop {
274            let span = match self.arg_count {
275                ArgCount::Geq(n) if args.len() < n => { Span::End },
276                ArgCount::Leq(n) if args.len() > n => { Span::NthArg(n + 1) },
277                ArgCount::Exact(n) if args.len() > n => { Span::NthArg(n + 1) },
278                ArgCount::Exact(n) if args.len() < n => { Span::NthArg(args.len().max(1) - 1) },
279                ArgCount::None if args.len() > 0 => { Span::FirstArg },
280                _ => { break; },
281            };
282
283            return Err(RawError {
284                span,
285                kind: ErrorKind::WrongArgCount {
286                    expected: self.arg_count,
287                    got: args.len(),
288                },
289            });
290        }
291
292        for (flag, arg_flag) in self.arg_flags.iter() {
293            if arg_flags.contains_key(flag) {
294                continue;
295            }
296
297            else if let Some(default) = &arg_flag.default {
298                arg_flags.insert(flag.to_string(), arg_flag.arg_type.parse(default, Span::None)?);
299            }
300
301            else if !arg_flag.optional {
302                return Err(RawError {
303                    span: Span::End,
304                    kind: ErrorKind::MissingFlag(flag.to_string()),
305                });
306            }
307        }
308
309        Ok(ParsedArgs {
310            skip_first_n,
311            raw_args: raw_args.to_vec(),
312            args,
313            flags,
314            arg_flags,
315            show_help: false,
316        })
317    }
318
319    fn get_similar_flag(&self, flag: &str) -> Option<String> {
320        let mut candidates = vec![];
321
322        for flag in self.flags.iter() {
323            for flag in flag.values.iter() {
324                candidates.push(flag.to_string());
325            }
326        }
327
328        for flag in self.arg_flags.keys() {
329            candidates.push(flag.to_string());
330        }
331
332        get_closest_string(&candidates, flag)
333    }
334}
335
/// How many positional arguments a command accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ArgCount {
    /// At least `n` arguments.
    Geq(usize),

    /// At most `n` arguments.
    Leq(usize),

    /// Exactly `n` arguments.
    Exact(usize),

    /// Any number of arguments, including none.
    Any,

    /// No positional arguments at all.
    None,
}
344
/// The type an argument (positional or flag value) is validated against.
#[derive(Clone, Debug)]
pub enum ArgType {
    /// Accepts any string unchanged.
    String,

    /// Accepts only one of the listed variants (exact, case-sensitive match).
    Enum(Vec<String>),

    /// An `i128` with optional, inclusive bounds.
    /// Prefer the constructors `Self::integer()`, `Self::uinteger()` and `Self::integer_between()`.
    Integer { min: Option<i128>, max: Option<i128> },

    /// An `f64` with optional, inclusive bounds.
    /// Prefer the constructors `Self::float()` and `Self::float_between()`.
    Float { min: Option<f64>, max: Option<f64> },

    /// A file size, measured in bytes, with optional, inclusive bounds.
    /// Prefer the constructors `Self::file_size()` and `Self::file_size_between()`.
    FileSize { min: Option<u64>, max: Option<u64> },
}
373
374impl ArgType {
375    pub fn integer() -> Self {
376        ArgType::Integer {
377            min: None,
378            max: None,
379        }
380    }
381
382    pub fn uinteger() -> Self {
383        ArgType::Integer {
384            min: Some(0),
385            max: None,
386        }
387    }
388
389    /// Both inclusive
390    pub fn integer_between(min: Option<i128>, max: Option<i128>) -> Self {
391        ArgType::Integer { min, max }
392    }
393
394    pub fn float() -> Self {
395        ArgType::Float {
396            min: None,
397            max: None,
398        }
399    }
400
401    /// Both inclusive
402    pub fn float_between(min: Option<f64>, max: Option<f64>) -> Self {
403        ArgType::Float { min, max }
404    }
405
406    pub fn enum_(variants: &[&str]) -> Self {
407        ArgType::Enum(variants.iter().map(|v| v.to_string()).collect())
408    }
409
410    pub fn file_size() -> Self {
411        ArgType::FileSize {
412            min: None,
413            max: None,
414        }
415    }
416
417    pub fn file_size_between(min: Option<u64>, max: Option<u64>) -> Self {
418        ArgType::FileSize { min, max }
419    }
420
421    pub fn parse(&self, arg: &str, span: Span) -> Result<String, RawError> {
422        match self {
423            ArgType::Integer { min, max } => match arg.parse::<i128>() {
424                Ok(n) => {
425                    if let Some(min) = *min {
426                        if n < min {
427                            return Err(RawError{
428                                span,
429                                kind: ErrorKind::NumberNotInRange {
430                                    min: Some(min.to_string()),
431                                    max: max.map(|n| n.to_string()),
432                                    n: n.to_string(),
433                                },
434                            });
435                        }
436                    }
437
438                    if let Some(max) = *max {
439                        if n > max {
440                            return Err(RawError{
441                                span,
442                                kind: ErrorKind::NumberNotInRange {
443                                    min: min.map(|n| n.to_string()),
444                                    max: Some(max.to_string()),
445                                    n: n.to_string(),
446                                },
447                            });
448                        }
449                    }
450
451                    Ok(arg.to_string())
452                },
453                Err(e) => Err(RawError {
454                    span,
455                    kind: ErrorKind::ParseIntError(e),
456                }),
457            },
458            ArgType::Float { min, max } => match arg.parse::<f64>() {
459                Ok(n) => {
460                    if let Some(min) = *min {
461                        if n < min {
462                            return Err(RawError{
463                                span,
464                                kind: ErrorKind::NumberNotInRange {
465                                    min: Some(min.to_string()),
466                                    max: max.map(|n| n.to_string()),
467                                    n: n.to_string(),
468                                },
469                            });
470                        }
471                    }
472
473                    if let Some(max) = *max {
474                        if n > max {
475                            return Err(RawError{
476                                span,
477                                kind: ErrorKind::NumberNotInRange {
478                                    min: min.map(|n| n.to_string()),
479                                    max: Some(max.to_string()),
480                                    n: n.to_string(),
481                                },
482                            });
483                        }
484                    }
485
486                    Ok(arg.to_string())
487                },
488                Err(e) => Err(RawError {
489                    span,
490                    kind: ErrorKind::ParseFloatError(e),
491                }),
492            },
493            ArgType::Enum(variants) => {
494                let mut matched = false;
495
496                for variant in variants.iter() {
497                    if variant == arg {
498                        matched = true;
499                        break;
500                    }
501                }
502
503                if matched {
504                    Ok(arg.to_string())
505                }
506
507                else {
508                    Err(RawError {
509                        span,
510                        kind: ErrorKind::UnknownVariant {
511                            variant: arg.to_string(),
512                            similar_variant: get_closest_string(variants, arg),
513                        },
514                    })
515                }
516            },
517            ArgType::FileSize { min, max } => {
518                let file_size = parse_file_size(arg, span)?;
519
520                if let Some(min) = *min {
521                    if file_size < min {
522                        return Err(RawError {
523                            span,
524                            kind: ErrorKind::NumberNotInRange {
525                                min: Some(min.to_string()),
526                                max: max.map(|n| n.to_string()),
527                                n: file_size.to_string(),
528                            },
529                        });
530                    }
531                }
532
533                if let Some(max) = *max {
534                    if file_size > max {
535                        return Err(RawError {
536                            span,
537                            kind: ErrorKind::NumberNotInRange {
538                                min: min.map(|n| n.to_string()),
539                                max: Some(max.to_string()),
540                                n: file_size.to_string(),
541                            },
542                        });
543                    }
544                }
545
546                Ok(file_size.to_string())
547            },
548            ArgType::String => Ok(arg.to_string()),
549        }
550    }
551}
552
/// A value-less flag, given as a group of alternative spellings.
///
/// At most one spelling of a group may appear on the command line.
#[derive(Debug, Clone)]
pub struct Flag {
    /// Accepted spellings, e.g. `["--json", "--plain"]`.
    values: Vec<String>,

    /// Whether the whole group may be absent.
    optional: bool,

    /// Index into `values` of the spelling used when the group is absent.
    default: Option<usize>,
}
559
560#[derive(Clone, Debug)]
561pub struct ArgFlag {
562    flag: String,
563    optional: bool,
564    default: Option<String>,
565    arg_type: ArgType,
566}
567
/// The result of `ArgParser::parse`.
pub struct ParsedArgs {
    /// Number of leading raw args that were not parsed; kept to render error spans.
    skip_first_n: usize,

    /// The full, unmodified command line; kept to render error spans.
    raw_args: Vec<String>,

    /// Validated positional arguments, in order.
    args: Vec<String>,

    /// One slot per registered flag group, in registration order. Each slot holds the
    /// spelling that was given (or its default), or `None`.
    flags: Vec<Option<String>>,

    /// Values of value-taking flags, keyed by flag name.
    pub arg_flags: HashMap<String, String>,

    /// Whether `--help` was the first parsed argument.
    show_help: bool,  // TODO: options for help messages
}
576
577impl ParsedArgs {
578    pub fn new() -> Self {
579        ParsedArgs {
580            skip_first_n: 0,
581            raw_args: vec![],
582            args: vec![],
583            flags: vec![],
584            arg_flags: HashMap::new(),
585            show_help: false,
586        }
587    }
588
589    pub fn get_args(&self) -> Vec<String> {
590        self.args.clone()
591    }
592
593    pub fn get_args_exact(&self, count: usize) -> Result<Vec<String>, Error> {
594        if self.args.len() == count {
595            Ok(self.args.clone())
596        }
597
598        else {
599            Err(Error {
600                span: Span::FirstArg.render(&self.raw_args, self.skip_first_n),
601                kind: ErrorKind::WrongArgCount {
602                    expected: ArgCount::Exact(count),
603                    got: self.args.len(),
604                },
605            })
606        }
607    }
608
609    // if there's an index error, it panics instead of returning None
610    // if it returns None, that means Nth flag is optional and its value is None
611    pub fn get_flag(&self, index: usize) -> Option<String> {
612        self.flags[index].clone()
613    }
614
615    pub fn show_help(&self) -> bool {
616        self.show_help
617    }
618}
619
620/// It parses `rag [-C <path>] <command> <args>` and returns
621/// `Ok((args, pre_args))` where `args` is `rag <command> <args>` and
622/// `pre_args` is `-C <path>`.
623///
624/// NOTE: Do not use this function. I have implemented this because I'm not sure
625/// how to implement `-C` option. I'll remove this function as soon as I come up
626/// with a nice way to implement `-C`.
627///
628/// It only supports `-C <path>` and not `-C=<path>` and that's intentional. Git
629/// neither supports `-C=<path>` (I don't know why), and I decided to blindly follow that.
630pub fn parse_pre_args(args: &[String]) -> Result<(Vec<String>, ParsedArgs), Error> {
631    match args.get(1).map(|s| s.as_str()) {
632        Some("-C") => match args.get(2).map(|s| s.as_str()) {
633            Some(path) => {
634                let mut result = ParsedArgs::new();
635                result.arg_flags.insert(String::from("-C"), path.to_string());
636                Ok((
637                    vec![
638                        vec![args[0].clone()],
639                        if args.len() < 4 { vec![] } else { args[3..].to_vec() },
640                    ].concat(),
641                    result,
642                ))
643            },
644            None => Err(Error {
645                span: Span::Exact(2).render(args, 0),
646                kind: ErrorKind::MissingArgument(String::from("-C"), ArgType::String),
647            }),
648        },
649        _ => Ok((args.to_vec(), ParsedArgs::new())),
650    }
651}