Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How a hyphen-prefixed token that no flag set recognizes is treated by `check_flags`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Unrecognized hyphen-prefixed tokens cause the whole command line to be rejected.
    Strict,
    /// Unrecognized hyphen-prefixed tokens are accepted and counted as positional arguments.
    Positional,
}
8
/// A queryable collection of flag spellings, used by `check_flags` to classify tokens.
pub trait FlagSet {
    /// Returns whether the full token (e.g. `--color` or `-v`) is one of this set's flags.
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns whether the short flag `-<byte>` is one of this set's flags.
    fn contains_short(&self, byte: u8) -> bool;
}
13
14impl FlagSet for WordSet {
15    fn contains_flag(&self, token: &str) -> bool {
16        self.contains(token)
17    }
18    fn contains_short(&self, byte: u8) -> bool {
19        self.contains_short(byte)
20    }
21}
22
23impl FlagSet for [String] {
24    fn contains_flag(&self, token: &str) -> bool {
25        self.iter().any(|f| f.as_str() == token)
26    }
27    fn contains_short(&self, byte: u8) -> bool {
28        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29    }
30}
31
32impl FlagSet for Vec<String> {
33    fn contains_flag(&self, token: &str) -> bool {
34        self.as_slice().contains_flag(token)
35    }
36    fn contains_short(&self, byte: u8) -> bool {
37        self.as_slice().contains_short(byte)
38    }
39}
40
41pub struct FlagPolicy {
42    pub standalone: WordSet,
43    pub valued: WordSet,
44    pub bare: bool,
45    pub max_positional: Option<usize>,
46    pub flag_style: FlagStyle,
47}
48
49impl FlagPolicy {
50    pub fn describe(&self) -> String {
51        use crate::docs::wordset_items;
52        let mut lines = Vec::new();
53        let standalone = wordset_items(&self.standalone);
54        if !standalone.is_empty() {
55            lines.push(format!("- Allowed standalone flags: {standalone}"));
56        }
57        let valued = wordset_items(&self.valued);
58        if !valued.is_empty() {
59            lines.push(format!("- Allowed valued flags: {valued}"));
60        }
61        if self.bare {
62            lines.push("- Bare invocation allowed".to_string());
63        }
64        if self.flag_style == FlagStyle::Positional {
65            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
66        }
67        if lines.is_empty() && !self.bare {
68            return "- Positional arguments only".to_string();
69        }
70        lines.join("\n")
71    }
72
73}
74
75pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
76    check_flags(
77        tokens,
78        &policy.standalone,
79        &policy.valued,
80        policy.bare,
81        policy.max_positional,
82        policy.flag_style,
83    )
84}
85
86pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
87    tokens: &[Token],
88    standalone: &S,
89    valued: &V,
90    bare: bool,
91    max_positional: Option<usize>,
92    flag_style: FlagStyle,
93) -> bool {
94    if tokens.len() == 1 {
95        return bare;
96    }
97
98    let mut i = 1;
99    let mut positionals: usize = 0;
100    while i < tokens.len() {
101        let t = &tokens[i];
102
103        if *t == "--" {
104            positionals += tokens.len() - i - 1;
105            break;
106        }
107
108        if !t.starts_with('-') {
109            positionals += 1;
110            i += 1;
111            continue;
112        }
113
114        if standalone.contains_flag(t) {
115            i += 1;
116            continue;
117        }
118
119        if valued.contains_flag(t) {
120            i += 2;
121            continue;
122        }
123
124        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
125            if valued.contains_flag(flag) {
126                i += 1;
127                continue;
128            }
129            if flag_style == FlagStyle::Positional {
130                positionals += 1;
131                i += 1;
132                continue;
133            }
134            return false;
135        }
136
137        if t.starts_with("--") {
138            if flag_style == FlagStyle::Positional {
139                positionals += 1;
140                i += 1;
141                continue;
142            }
143            return false;
144        }
145
146        let bytes = t.as_bytes();
147        let mut j = 1;
148        while j < bytes.len() {
149            let b = bytes[j];
150            let is_last = j == bytes.len() - 1;
151            if standalone.contains_short(b) {
152                j += 1;
153                continue;
154            }
155            if valued.contains_short(b) {
156                if is_last {
157                    i += 1;
158                }
159                break;
160            }
161            if flag_style == FlagStyle::Positional {
162                positionals += 1;
163                break;
164            }
165            return false;
166        }
167        i += 1;
168    }
169    max_positional.is_none_or(|max| positionals <= max)
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// grep-like policy: strict flags, bare invocation denied, unlimited positionals.
    static GREP_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    /// uniq-like policy: bare invocation allowed, at most one positional.
    static UNIQ_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    /// echo-like policy: unrecognized hyphenated tokens are treated as positionals.
    static ECHO_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|w| Token::from_test(w)).collect()
    }

    /// Runs `check` for `words` under `policy`.
    fn allows(policy: &FlagPolicy, words: &[&str]) -> bool {
        check(&toks(words), policy)
    }

    // --- grep-like (strict) policy -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allows(&GREP_POLICY, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allows(&policy, &["uname"]));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(allows(&GREP_POLICY, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allows(&GREP_POLICY, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allows(&GREP_POLICY, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allows(&GREP_POLICY, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allows(&GREP_POLICY, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allows(&GREP_POLICY, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allows(&GREP_POLICY, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allows(&GREP_POLICY, &["grep", "-rmn", "pattern"]));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allows(&GREP_POLICY, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allows(&GREP_POLICY, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allows(&GREP_POLICY, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allows(&GREP_POLICY, &["grep", "--output=file.txt", "pattern"]));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(allows(&GREP_POLICY, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allows(&GREP_POLICY, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(allows(&GREP_POLICY, &words));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allows(&GREP_POLICY, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allows(&GREP_POLICY, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allows(&GREP_POLICY, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allows(&GREP_POLICY, &["grep", "-c", "pattern"]));
    }

    // --- uniq-like policy (max_positional = 1) --------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allows(&UNIQ_POLICY, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allows(&UNIQ_POLICY, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allows(&UNIQ_POLICY, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allows(&UNIQ_POLICY, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allows(&UNIQ_POLICY, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allows(&UNIQ_POLICY, &["uniq"]));
    }

    // --- echo-like policy (FlagStyle::Positional) -----------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(allows(&ECHO_POLICY, &["echo", "--unknown", "hello"]));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allows(&ECHO_POLICY, &["echo", "-x", "hello"]));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allows(&ECHO_POLICY, &["echo", "---"]));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allows(&ECHO_POLICY, &["echo", "-n", "hello"]));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allows(&ECHO_POLICY, &["echo", "-ne", "hello"]));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allows(&ECHO_POLICY, &["echo", "-nx", "hello"]));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allows(&ECHO_POLICY, &["echo", "--foo=bar"]));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(allows(&capped, &["echo", "--unknown", "hello"]));
        assert!(!allows(&capped, &["echo", "--a", "--b", "--c"]));
    }
}