Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check`] treats hyphen-prefixed tokens that match no known flag.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FlagStyle {
    /// An unrecognised flag makes the whole invocation fail the check.
    Strict,
    /// An unrecognised hyphen-prefixed token is accepted and counted as a
    /// single positional argument.
    Positional,
}
8
9pub struct FlagPolicy {
10    pub standalone: WordSet,
11    pub standalone_short: &'static [u8],
12    pub valued: WordSet,
13    pub valued_short: &'static [u8],
14    pub bare: bool,
15    pub max_positional: Option<usize>,
16    pub flag_style: FlagStyle,
17}
18
19impl FlagPolicy {
20    pub fn describe(&self) -> String {
21        use crate::docs::{wordset_items, DocBuilder};
22        let mut builder = DocBuilder::new();
23        let standalone = wordset_items(&self.standalone);
24        if !standalone.is_empty() {
25            builder = builder.section(format!("Allowed standalone flags: {standalone}."));
26        }
27        let valued = wordset_items(&self.valued);
28        if !valued.is_empty() {
29            builder = builder.section(format!("Allowed valued flags: {valued}."));
30        }
31        if self.bare {
32            builder = builder.section("Bare invocation allowed.".to_string());
33        }
34        if self.flag_style == FlagStyle::Positional {
35            builder = builder.section("Hyphen-prefixed positional arguments accepted.".to_string());
36        }
37        builder.build()
38    }
39}
40
41pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
42    if tokens.len() == 1 {
43        return policy.bare;
44    }
45
46    let mut i = 1;
47    let mut positionals: usize = 0;
48    while i < tokens.len() {
49        let t = &tokens[i];
50
51        if *t == "--" {
52            positionals += tokens.len() - i - 1;
53            break;
54        }
55
56        if !t.starts_with('-') {
57            positionals += 1;
58            i += 1;
59            continue;
60        }
61
62        if policy.standalone.contains(t) {
63            i += 1;
64            continue;
65        }
66
67        if policy.valued.contains(t) {
68            i += 2;
69            continue;
70        }
71
72        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
73            if policy.valued.contains(flag) {
74                i += 1;
75                continue;
76            }
77            if policy.flag_style == FlagStyle::Positional {
78                positionals += 1;
79                i += 1;
80                continue;
81            }
82            return false;
83        }
84
85        if t.starts_with("--") {
86            if policy.flag_style == FlagStyle::Positional {
87                positionals += 1;
88                i += 1;
89                continue;
90            }
91            return false;
92        }
93
94        let bytes = t.as_bytes();
95        let mut j = 1;
96        while j < bytes.len() {
97            let b = bytes[j];
98            let is_last = j == bytes.len() - 1;
99            if policy.standalone_short.contains(&b) {
100                j += 1;
101                continue;
102            }
103            if policy.valued_short.contains(&b) {
104                if is_last {
105                    i += 1;
106                }
107                break;
108            }
109            if policy.flag_style == FlagStyle::Positional {
110                positionals += 1;
111                break;
112            }
113            return false;
114        }
115        i += 1;
116    }
117    policy.max_positional.is_none_or(|max| positionals <= max)
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    // ----- fixtures ----------------------------------------------------------

    /// A grep-like policy: strict flag handling, bare invocation denied and no
    /// limit on positional arguments.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-c", "-r",
        ]),
        standalone_short: b"cHilnorsvw",
        valued: WordSet::new(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        valued_short: b"ABm",
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    /// A uniq-like policy: bare invocation allowed, at most one positional.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--count", "-c", "-d", "-i", "-u"]),
        standalone_short: b"cdiu",
        valued: WordSet::new(&["--skip-fields", "-f", "-s"]),
        valued_short: b"fs",
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    /// An echo-like policy: unknown hyphen-prefixed tokens count as positionals.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["-E", "-e", "-n"]),
        standalone_short: b"Een",
        valued: WordSet::new(&[]),
        valued_short: b"",
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    /// Turns plain words into test tokens.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|w| Token::from_test(w)).collect()
    }

    /// Runs [`check`] over `words` (command name first) against `policy`.
    fn allows(policy: &FlagPolicy, words: &[&str]) -> bool {
        check(&toks(words), policy)
    }

    // ----- bare invocation ---------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allows(&TEST_POLICY, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let no_flags = FlagPolicy {
            standalone: WordSet::new(&[]),
            standalone_short: b"",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allows(&no_flags, &["uname"]));
    }

    // ----- strict style: accepted forms --------------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(allows(&TEST_POLICY, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allows(&TEST_POLICY, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allows(&TEST_POLICY, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allows(&TEST_POLICY, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allows(&TEST_POLICY, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allows(&TEST_POLICY, &["grep", "-rmn", "pattern"]));
    }

    // ----- strict style: rejected forms --------------------------------------

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allows(&TEST_POLICY, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allows(&TEST_POLICY, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allows(&TEST_POLICY, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allows(&TEST_POLICY, &["grep", "--output=file.txt", "pattern"]));
    }

    // ----- strict style: positionals and edge cases --------------------------

    #[test]
    fn double_dash_stops_checking() {
        assert!(allows(&TEST_POLICY, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allows(&TEST_POLICY, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(allows(&TEST_POLICY, &words));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allows(&TEST_POLICY, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allows(&TEST_POLICY, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allows(&TEST_POLICY, &["grep", "-c", "pattern"]));
    }

    // ----- positional limits -------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allows(&LIMITED_POLICY, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allows(&LIMITED_POLICY, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allows(&LIMITED_POLICY, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allows(&LIMITED_POLICY, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allows(&LIMITED_POLICY, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allows(&LIMITED_POLICY, &["uniq"]));
    }

    // ----- positional flag style ---------------------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "--unknown", "hello"]));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-x", "hello"]));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "---"]));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-n", "hello"]));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-ne", "hello"]));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-nx", "hello"]));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "--foo=bar"]));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::new(&["-n"]),
            standalone_short: b"n",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(allows(&capped, &["echo", "--unknown", "hello"]));
        assert!(!allows(&capped, &["echo", "--a", "--b", "--c"]));
    }
}