// safe_chains/policy.rs

1use crate::parse::{Token, WordSet};
2
/// Policy for flag-shaped tokens (`-x`, `--foo`) that are not in the
/// allowlist: either deny them, or let them through as positional
/// arguments. `Strict` is the default and keeps the allowlist authoritative.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum UnknownTolerance {
    /// Every unrecognized flag-shaped token is denied. The safe default.
    #[default]
    Strict,
    /// Unrecognized single-dash tokens (`-X`, `-help`, `-mayDie`) pass as
    /// positional arguments; unrecognized double-dash tokens are still
    /// denied. Intended for tools such as `pdftotext` whose long flags use
    /// a single dash.
    Short,
    /// Unrecognized double-dash tokens (`--foo`, `--foo=value`) pass as
    /// positional arguments; unrecognized single-dash tokens are still
    /// denied. Dangerous: most modern destructive options are double-dash,
    /// so this can silently admit mutating flags. Reserve it for tools with
    /// a genuinely unbounded long-flag surface (AWS CLI service flags).
    Long,
    /// Unrecognized single- and double-dash tokens both pass as positional
    /// arguments. The most permissive setting: it carries the risks of both
    /// `Short` and `Long`.
    Both,
}

impl UnknownTolerance {
    /// `true` when unknown single-dash tokens are tolerated (`Short`, `Both`).
    pub const fn allows_short(self) -> bool {
        match self {
            Self::Short | Self::Both => true,
            Self::Strict | Self::Long => false,
        }
    }

    /// `true` when unknown double-dash tokens are tolerated (`Long`, `Both`).
    pub const fn allows_long(self) -> bool {
        match self {
            Self::Long | Self::Both => true,
            Self::Strict | Self::Short => false,
        }
    }
}
34
35/// How the dispatcher treats tokens that look like flags but aren't in the
36/// allowlist. `unknown` controls flag-shaped unknowns; `numeric_dash` opts
37/// into `-NUMBER` shorthand (e.g. `head -20`).
38#[derive(Clone, Copy, Debug, Default)]
39pub struct FlagTolerance {
40    pub unknown: UnknownTolerance,
41    pub numeric_dash: bool,
42}
43
44impl FlagTolerance {
45    /// Strict allowlist: deny every unrecognized flag-shaped token.
46    /// `const`-callable for use in static `FlagPolicy` literals.
47    pub const fn strict() -> Self {
48        Self { unknown: UnknownTolerance::Strict, numeric_dash: false }
49    }
50}
51
/// Predicate applied to the first positional token of a fallback grammar.
/// It lets a TOML-declared fallback require, for example, that the first
/// positional "looks like a path" without the handler hardcoding the check.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PositionalShape {
    /// A file-path-looking token: contains `/` or `.`, or is exactly `-`
    /// (the conventional stdin marker). Flag-shaped tokens never match.
    Path,
}

impl PositionalShape {
    /// Whether `token` has this shape.
    pub fn matches(self, token: &str) -> bool {
        match self {
            PositionalShape::Path => looks_like_path(token),
        }
    }

    /// Parses the configuration name of a shape (`"path"`); any other name
    /// yields `None`.
    pub fn from_name(name: &str) -> Option<Self> {
        if name == "path" {
            Some(PositionalShape::Path)
        } else {
            None
        }
    }
}

/// Heuristic: does `token` look like a file path?
///
/// Accepts the lone `-` (stdin) and any token that does not start with `-`
/// and contains `/` or `.`. Deliberately conservative: a bare word such as
/// `Tiltfile` is a valid file name in the current directory, but it is
/// rejected so that flag-less subcommands are not swallowed. Callers that
/// need bare names should match a sub block instead.
pub fn looks_like_path(token: &str) -> bool {
    match token.strip_prefix('-') {
        // Any dash-prefixed token is a flag unless it is exactly `-`.
        Some(rest) => rest.is_empty(),
        // Covers the empty token too: it contains neither separator.
        None => token.chars().any(|c| c == '/' || c == '.'),
    }
}
91
/// A lookup table of allowed flags, queried by `check_flags`.
pub trait FlagSet {
    /// Whether the exact token (e.g. `--color`, `-n`) is in the set.
    fn contains_flag(&self, token: &str) -> bool;

    /// Whether the single-byte short flag `byte` (e.g. `b'n'` for `-n`) is
    /// in the set; used when walking clustered short flags such as `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
96
97impl FlagSet for WordSet {
98    fn contains_flag(&self, token: &str) -> bool {
99        self.contains(token)
100    }
101    fn contains_short(&self, byte: u8) -> bool {
102        self.contains_short(byte)
103    }
104}
105
106impl FlagSet for [String] {
107    fn contains_flag(&self, token: &str) -> bool {
108        self.iter().any(|f| f.as_str() == token)
109    }
110    fn contains_short(&self, byte: u8) -> bool {
111        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
112    }
113}
114
115impl FlagSet for Vec<String> {
116    fn contains_flag(&self, token: &str) -> bool {
117        self.as_slice().contains_flag(token)
118    }
119    fn contains_short(&self, byte: u8) -> bool {
120        self.as_slice().contains_short(byte)
121    }
122}
123
124pub struct FlagPolicy {
125    pub standalone: WordSet,
126    pub valued: WordSet,
127    pub bare: bool,
128    pub max_positional: Option<usize>,
129    pub tolerance: FlagTolerance,
130}
131
132impl FlagPolicy {
133    pub fn describe(&self) -> String {
134        use crate::docs::wordset_items;
135        let mut lines = Vec::new();
136        let standalone = wordset_items(&self.standalone);
137        if !standalone.is_empty() {
138            lines.push(format!("- Allowed standalone flags: {standalone}"));
139        }
140        let valued = wordset_items(&self.valued);
141        if !valued.is_empty() {
142            lines.push(format!("- Allowed valued flags: {valued}"));
143        }
144        if self.bare {
145            lines.push("- Bare invocation allowed".to_string());
146        }
147        if self.tolerance.unknown != UnknownTolerance::Strict {
148            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
149        }
150        if self.tolerance.numeric_dash {
151            lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
152        }
153        if lines.is_empty() && !self.bare {
154            return "- Positional arguments only".to_string();
155        }
156        lines.join("\n")
157    }
158
159}
160
161pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
162    check_flags(
163        tokens,
164        &policy.standalone,
165        &policy.valued,
166        policy.bare,
167        policy.max_positional,
168        policy.tolerance,
169    )
170}
171
172pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
173    tokens: &[Token],
174    standalone: &S,
175    valued: &V,
176    bare: bool,
177    max_positional: Option<usize>,
178    tolerance: FlagTolerance,
179) -> bool {
180    if tokens.len() == 1 {
181        return bare;
182    }
183
184    let mut i = 1;
185    let mut positionals: usize = 0;
186    while i < tokens.len() {
187        let t = &tokens[i];
188
189        if *t == "--" {
190            positionals += tokens.len() - i - 1;
191            break;
192        }
193
194        if !t.starts_with('-') {
195            positionals += 1;
196            i += 1;
197            continue;
198        }
199
200        if tolerance.numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
201            i += 1;
202            continue;
203        }
204
205        if standalone.contains_flag(t) {
206            i += 1;
207            continue;
208        }
209
210        if valued.contains_flag(t) {
211            i += 2;
212            continue;
213        }
214
215        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
216            if valued.contains_flag(flag) {
217                i += 1;
218                continue;
219            }
220            // `--foo=value` forms are governed by the long-flag tolerance.
221            if tolerance.unknown.allows_long() {
222                positionals += 1;
223                i += 1;
224                continue;
225            }
226            return false;
227        }
228
229        if t.starts_with("--") {
230            if tolerance.unknown.allows_long() {
231                positionals += 1;
232                i += 1;
233                continue;
234            }
235            return false;
236        }
237
238        let bytes = t.as_bytes();
239        let mut j = 1;
240        while j < bytes.len() {
241            let b = bytes[j];
242            let is_last = j == bytes.len() - 1;
243            if standalone.contains_short(b) {
244                j += 1;
245                continue;
246            }
247            if valued.contains_short(b) {
248                if is_last {
249                    i += 1;
250                }
251                break;
252            }
253            if tolerance.unknown.allows_short() {
254                positionals += 1;
255                break;
256            }
257            return false;
258        }
259        i += 1;
260    }
261    max_positional.is_none_or(|max| positionals <= max)
262}
263
#[cfg(test)]
mod tests {
    // Unit tests for `check` / `check_flags`, the tolerance modes, and the
    // path-shape heuristic.
    use super::*;

    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            tolerance: FlagTolerance::strict(),
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn check_flags_accepts_owned_flag_lists() {
        // Exercises the generic `FlagSet` path with `Vec<String>` allowlists.
        let standalone: Vec<String> = vec!["-n".to_string(), "--quiet".to_string()];
        let valued: Vec<String> = vec!["--lines".to_string()];
        let tolerance = FlagTolerance::strict();
        assert!(check_flags(
            &toks(&["head", "-n", "--quiet", "--lines", "5", "f"]),
            &standalone,
            &valued,
            false,
            None,
            tolerance,
        ));
        assert!(!check_flags(
            &toks(&["head", "--bytes", "5"]),
            &standalone,
            &valued,
            false,
            None,
            tolerance,
        ));
    }

    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    static BOTH_TOLERANCES_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
    };

    #[test]
    fn both_tolerances_accept_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_triple_dash() {
        assert!(check(&toks(&["echo", "---"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_known_short() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_unknown_short_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_unknown_eq_form() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &BOTH_TOLERANCES_POLICY));
    }

    // ============ Narrow tolerance: short-only ============
    // `UnknownTolerance::Short` accepts unknown single-dash tokens
    // (-X, -mayDie, -help) as positional, while leaving double-dash unknowns
    // strict. This is the safer setting because most modern destructive
    // flags are double-dash.

    static SHORT_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Short, numeric_dash: false },
    };

    #[test]
    fn short_only_accepts_unknown_dash_letter() {
        assert!(check(&toks(&["sample", "-mayDie"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_accepts_single_dash_long_word() {
        // pdftotext-style: `-help`, `-layout`, `-version` (single dash + word)
        assert!(check(&toks(&["pdftotext", "-layout"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_double_dash() {
        // The whole point of the narrow split: --evil-flag must not slip
        // through when only short-tolerance is on.
        assert!(!check(&toks(&["sample", "--evil-flag"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_eq_form() {
        assert!(!check(&toks(&["sample", "--evil=value"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_known_long_flag_still_works() {
        assert!(check(&toks(&["sample", "--help"]), &SHORT_ONLY_POLICY));
    }

    // ============ Narrow tolerance: long-only ============
    // `UnknownTolerance::Long` accepts unknown double-dash tokens as
    // positional. This is the dangerous form; reserved for tools like AWS
    // CLI whose long-flag surface is genuinely unbounded.

    static LONG_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Long, numeric_dash: false },
    };

    #[test]
    fn long_only_accepts_unknown_double_dash() {
        assert!(check(&toks(&["aws", "--some-aws-flag"]), &LONG_ONLY_POLICY));
    }

    #[test]
    fn long_only_accepts_unknown_eq_form() {
        assert!(check(
            &toks(&["aws", "--filter=Name=tag,Values=foo"]),
            &LONG_ONLY_POLICY,
        ));
    }

    #[test]
    fn long_only_denies_unknown_short_dash() {
        assert!(!check(&toks(&["aws", "-x"]), &LONG_ONLY_POLICY));
    }

    // ============ Strict: no unknown-flag tolerance ============

    static STRICT_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn strict_denies_unknown_short() {
        assert!(!check(&toks(&["foo", "-evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_denies_unknown_long() {
        assert!(!check(&toks(&["foo", "--evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_known_flag_passes() {
        assert!(check(&toks(&["foo", "--help"]), &STRICT_POLICY));
    }

    #[test]
    fn both_tolerances_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }

    static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--help", "--quiet", "--verbose", "--version",
            "-V", "-h", "-q", "-v", "-z",
        ]),
        valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { numeric_dash: true, ..FlagTolerance::strict() },
    };

    #[test]
    fn numeric_dash_single_digit() {
        assert!(check(&toks(&["head", "-5"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_multi_digit() {
        assert!(check(&toks(&["head", "-20"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_large_number() {
        assert!(check(&toks(&["head", "-1000"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_file_arg() {
        assert!(check(&toks(&["head", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_other_flags() {
        assert!(check(&toks(&["head", "-q", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_zero() {
        assert!(check(&toks(&["head", "-0"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_still_rejects_unknown_flags() {
        assert!(!check(&toks(&["head", "-x"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_rejects_mixed_alpha_num() {
        assert!(!check(&toks(&["head", "-20x"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_disabled_rejects_multi_digit() {
        assert!(!check(&toks(&["grep", "-20", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn looks_like_path_accepts_relative() {
        assert!(looks_like_path("./Tiltfile"));
        assert!(looks_like_path("path/to/file"));
    }

    #[test]
    fn looks_like_path_accepts_dotted() {
        assert!(looks_like_path("Tiltfile.dev"));
        assert!(looks_like_path("file.rb"));
    }

    #[test]
    fn looks_like_path_accepts_stdin_dash() {
        assert!(looks_like_path("-"));
    }

    #[test]
    fn looks_like_path_rejects_flag() {
        assert!(!looks_like_path("--help"));
        assert!(!looks_like_path("-x"));
    }

    #[test]
    fn looks_like_path_rejects_bare_word() {
        assert!(!looks_like_path("Tiltfile"));
        assert!(!looks_like_path("up"));
    }

    #[test]
    fn looks_like_path_rejects_empty() {
        assert!(!looks_like_path(""));
    }

    #[test]
    fn positional_shape_path_matches() {
        assert!(PositionalShape::Path.matches("./file.rb"));
        assert!(!PositionalShape::Path.matches("--flag"));
    }

    #[test]
    fn positional_shape_from_name() {
        assert_eq!(PositionalShape::from_name("path"), Some(PositionalShape::Path));
        assert_eq!(PositionalShape::from_name("nope"), None);
    }
}