// safe_chains/policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check`] treats hyphen-prefixed arguments that the policy lists
/// neither as a standalone nor as a valued flag.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FlagStyle {
    /// An unrecognised hyphen-prefixed argument causes the command to be rejected.
    Strict,
    /// An unrecognised hyphen-prefixed argument is accepted and counted as a
    /// positional argument (e.g. for `echo`-like commands).
    Positional,
}
8
/// Describes which arguments a command may be invoked with.
///
/// Consumed by [`check`], which walks the arguments following the command
/// name and accepts them only if every one fits this policy.
pub struct FlagPolicy {
    /// Flags that take no value. Long flags are matched as whole tokens; short
    /// flags may additionally be combined into a cluster such as `-rn`.
    pub standalone: WordSet,
    /// Flags that take a value, supplied as the next token (`--max-count 5`,
    /// `-m 5`), after `=` (`--max-count=5`), or appended to a short flag (`-m5`).
    pub valued: WordSet,
    /// Whether the command may be run with no arguments at all.
    pub bare: bool,
    /// Upper bound on the number of positional arguments; `None` means unlimited.
    pub max_positional: Option<usize>,
    /// How hyphen-prefixed arguments not listed in either flag set are treated.
    pub flag_style: FlagStyle,
}
16
17impl FlagPolicy {
18    pub fn describe(&self) -> String {
19        use crate::docs::wordset_items;
20        let mut lines = Vec::new();
21        let standalone = wordset_items(&self.standalone);
22        if !standalone.is_empty() {
23            lines.push(format!("- Allowed standalone flags: {standalone}"));
24        }
25        let valued = wordset_items(&self.valued);
26        if !valued.is_empty() {
27            lines.push(format!("- Allowed valued flags: {valued}"));
28        }
29        if self.bare {
30            lines.push("- Bare invocation allowed".to_string());
31        }
32        if self.flag_style == FlagStyle::Positional {
33            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
34        }
35        if lines.is_empty() && !self.bare {
36            return "- Positional arguments only".to_string();
37        }
38        lines.join("\n")
39    }
40
41    pub fn flag_summary(&self) -> String {
42        use crate::docs::wordset_items;
43        let mut parts = Vec::new();
44        let standalone = wordset_items(&self.standalone);
45        if !standalone.is_empty() {
46            parts.push(format!("Flags: {standalone}"));
47        }
48        let valued = wordset_items(&self.valued);
49        if !valued.is_empty() {
50            parts.push(format!("Valued: {valued}"));
51        }
52        if self.flag_style == FlagStyle::Positional {
53            parts.push("Positional args accepted".to_string());
54        }
55        parts.join(". ")
56    }
57}
58
59pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
60    if tokens.len() == 1 {
61        return policy.bare;
62    }
63
64    let mut i = 1;
65    let mut positionals: usize = 0;
66    while i < tokens.len() {
67        let t = &tokens[i];
68
69        if *t == "--" {
70            positionals += tokens.len() - i - 1;
71            break;
72        }
73
74        if !t.starts_with('-') {
75            positionals += 1;
76            i += 1;
77            continue;
78        }
79
80        if policy.standalone.contains(t) {
81            i += 1;
82            continue;
83        }
84
85        if policy.valued.contains(t) {
86            i += 2;
87            continue;
88        }
89
90        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
91            if policy.valued.contains(flag) {
92                i += 1;
93                continue;
94            }
95            if policy.flag_style == FlagStyle::Positional {
96                positionals += 1;
97                i += 1;
98                continue;
99            }
100            return false;
101        }
102
103        if t.starts_with("--") {
104            if policy.flag_style == FlagStyle::Positional {
105                positionals += 1;
106                i += 1;
107                continue;
108            }
109            return false;
110        }
111
112        let bytes = t.as_bytes();
113        let mut j = 1;
114        while j < bytes.len() {
115            let b = bytes[j];
116            let is_last = j == bytes.len() - 1;
117            if policy.standalone.contains_short(b) {
118                j += 1;
119                continue;
120            }
121            if policy.valued.contains_short(b) {
122                if is_last {
123                    i += 1;
124                }
125                break;
126            }
127            if policy.flag_style == FlagStyle::Positional {
128                positionals += 1;
129                break;
130            }
131            return false;
132        }
133        i += 1;
134    }
135    policy.max_positional.is_none_or(|max| positionals <= max)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    // grep-like policy: standalone and valued flags, strict parsing, no bare call.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    // uniq-like policy: bare call allowed, at most one positional argument.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    // echo-like policy: unrecognised hyphen-prefixed words count as positional.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    /// Tokenises `words` and reports whether `policy` accepts them.
    fn allows(policy: &FlagPolicy, words: &[&str]) -> bool {
        let tokens: Vec<Token> = words.iter().map(|w| Token::from_test(w)).collect();
        check(&tokens, policy)
    }

    /// Inverse of [`allows`], for readability in rejection tests.
    fn denies(policy: &FlagPolicy, words: &[&str]) -> bool {
        !allows(policy, words)
    }

    // ---- bare invocation -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(denies(&TEST_POLICY, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let uname = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allows(&uname, &["uname"]));
    }

    // ---- strict policy: accepted forms -----------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(allows(&TEST_POLICY, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allows(&TEST_POLICY, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allows(&TEST_POLICY, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allows(&TEST_POLICY, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allows(&TEST_POLICY, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allows(&TEST_POLICY, &["grep", "-rmn", "pattern"]));
    }

    // ---- strict policy: rejected forms -----------------------------------

    #[test]
    fn unknown_long_flag_denied() {
        assert!(denies(&TEST_POLICY, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(denies(&TEST_POLICY, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(denies(&TEST_POLICY, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(denies(&TEST_POLICY, &["grep", "--output=file.txt", "pattern"]));
    }

    // ---- strict policy: positionals and edge cases ----------------------

    #[test]
    fn double_dash_stops_checking() {
        assert!(allows(&TEST_POLICY, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allows(&TEST_POLICY, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(allows(
            &TEST_POLICY,
            &["grep", "-rn", "--color", "--max-count", "10", "pattern", "."],
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allows(&TEST_POLICY, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allows(&TEST_POLICY, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allows(&TEST_POLICY, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allows(&TEST_POLICY, &["grep", "-c", "pattern"]));
    }

    // ---- max_positional --------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allows(&LIMITED_POLICY, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(denies(&LIMITED_POLICY, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allows(&LIMITED_POLICY, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(denies(&LIMITED_POLICY, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(denies(&LIMITED_POLICY, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allows(&LIMITED_POLICY, &["uniq"]));
    }

    // ---- FlagStyle::Positional -------------------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "--unknown", "hello"]));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-x", "hello"]));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "---"]));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-n", "hello"]));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-ne", "hello"]));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "-nx", "hello"]));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allows(&POSITIONAL_POLICY, &["echo", "--foo=bar"]));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(allows(&capped, &["echo", "--unknown", "hello"]));
        assert!(denies(&capped, &["echo", "--a", "--b", "--c"]));
    }
}