1use crate::parse::{Token, WordSet};
2
/// How [`check_flags`] treats hyphen-prefixed tokens it does not recognise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Any unrecognised flag rejects the whole command.
    Strict,
    /// Unrecognised hyphen-prefixed tokens are accepted and counted as
    /// positional arguments instead of being rejected.
    Positional,
}
8
/// A set of allowed flags that [`check_flags`] can query.
pub trait FlagSet {
    /// Returns `true` if the whole token `token` is one of the allowed flags.
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns `true` if `byte` is allowed as a single-character short flag
    /// (the `x` in `-x`), which [`check_flags`] also consults for every byte
    /// of a combined cluster such as `-xyz`.
    fn contains_short(&self, byte: u8) -> bool;
}
13
14impl FlagSet for WordSet {
15 fn contains_flag(&self, token: &str) -> bool {
16 self.contains(token)
17 }
18 fn contains_short(&self, byte: u8) -> bool {
19 self.contains_short(byte)
20 }
21}
22
23impl FlagSet for [String] {
24 fn contains_flag(&self, token: &str) -> bool {
25 self.iter().any(|f| f.as_str() == token)
26 }
27 fn contains_short(&self, byte: u8) -> bool {
28 self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29 }
30}
31
32impl FlagSet for Vec<String> {
33 fn contains_flag(&self, token: &str) -> bool {
34 self.as_slice().contains_flag(token)
35 }
36 fn contains_short(&self, byte: u8) -> bool {
37 self.as_slice().contains_short(byte)
38 }
39}
40
41pub struct FlagPolicy {
42 pub standalone: WordSet,
43 pub valued: WordSet,
44 pub bare: bool,
45 pub max_positional: Option<usize>,
46 pub flag_style: FlagStyle,
47}
48
49impl FlagPolicy {
50 pub fn describe(&self) -> String {
51 use crate::docs::wordset_items;
52 let mut lines = Vec::new();
53 let standalone = wordset_items(&self.standalone);
54 if !standalone.is_empty() {
55 lines.push(format!("- Allowed standalone flags: {standalone}"));
56 }
57 let valued = wordset_items(&self.valued);
58 if !valued.is_empty() {
59 lines.push(format!("- Allowed valued flags: {valued}"));
60 }
61 if self.bare {
62 lines.push("- Bare invocation allowed".to_string());
63 }
64 if self.flag_style == FlagStyle::Positional {
65 lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
66 }
67 if lines.is_empty() && !self.bare {
68 return "- Positional arguments only".to_string();
69 }
70 lines.join("\n")
71 }
72
73}
74
75pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
76 check_flags(
77 tokens,
78 &policy.standalone,
79 &policy.valued,
80 policy.bare,
81 policy.max_positional,
82 policy.flag_style,
83 )
84}
85
86pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
87 tokens: &[Token],
88 standalone: &S,
89 valued: &V,
90 bare: bool,
91 max_positional: Option<usize>,
92 flag_style: FlagStyle,
93) -> bool {
94 if tokens.len() == 1 {
95 return bare;
96 }
97
98 let mut i = 1;
99 let mut positionals: usize = 0;
100 while i < tokens.len() {
101 let t = &tokens[i];
102
103 if *t == "--" {
104 positionals += tokens.len() - i - 1;
105 break;
106 }
107
108 if !t.starts_with('-') {
109 positionals += 1;
110 i += 1;
111 continue;
112 }
113
114 if standalone.contains_flag(t) {
115 i += 1;
116 continue;
117 }
118
119 if valued.contains_flag(t) {
120 i += 2;
121 continue;
122 }
123
124 if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
125 if valued.contains_flag(flag) {
126 i += 1;
127 continue;
128 }
129 if flag_style == FlagStyle::Positional {
130 positionals += 1;
131 i += 1;
132 continue;
133 }
134 return false;
135 }
136
137 if t.starts_with("--") {
138 if flag_style == FlagStyle::Positional {
139 positionals += 1;
140 i += 1;
141 continue;
142 }
143 return false;
144 }
145
146 let bytes = t.as_bytes();
147 let mut j = 1;
148 while j < bytes.len() {
149 let b = bytes[j];
150 let is_last = j == bytes.len() - 1;
151 if standalone.contains_short(b) {
152 j += 1;
153 continue;
154 }
155 if valued.contains_short(b) {
156 if is_last {
157 i += 1;
158 }
159 break;
160 }
161 if flag_style == FlagStyle::Positional {
162 positionals += 1;
163 break;
164 }
165 return false;
166 }
167 i += 1;
168 }
169 max_positional.is_none_or(|max| positionals <= max)
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    // grep-like policy: strict flags, bare invocation denied, unlimited positionals.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    // uniq-like policy: bare invocation allowed, at most one positional argument.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    // echo-like policy: unknown hyphen-prefixed words are treated as positionals.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    /// Tokenizes `words` and runs [`check`] against `policy`.
    fn allowed(words: &[&str], policy: &FlagPolicy) -> bool {
        check(&toks(words), policy)
    }

    // --- bare invocation -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allowed(&["grep"], &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allowed(&["uname"], &policy));
    }

    // --- strict style ----------------------------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(allowed(&["grep", "--recursive", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allowed(&["grep", "-r", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allowed(&["grep", "--max-count", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allowed(&["grep", "--max-count=5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allowed(&["grep", "-m", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allowed(&["grep", "-rn", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allowed(&["grep", "-rnm", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allowed(&["grep", "-rmn", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allowed(&["grep", "--exec", "cmd"], &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allowed(&["grep", "-z", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allowed(&["grep", "-rz", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allowed(&["grep", "--output=file.txt", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(allowed(&["grep", "--", "--not-a-flag", "file"], &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allowed(&["grep", "pattern", "file.txt", "other.txt"], &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(allowed(
            &["grep", "-rn", "--color", "--max-count", "10", "pattern", "."],
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allowed(&["grep", "-A", "3", "-B", "3", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allowed(&["grep", "pattern", "-"], &TEST_POLICY));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allowed(&["grep", "--max-count"], &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allowed(&["grep", "-c", "pattern"], &TEST_POLICY));
    }

    // --- max_positional --------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allowed(&["uniq", "input.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allowed(&["uniq", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allowed(&["uniq", "-c", "-f", "3", "input.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allowed(&["uniq", "-c", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allowed(&["uniq", "--", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allowed(&["uniq"], &LIMITED_POLICY));
    }

    // --- positional style ------------------------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(allowed(&["echo", "--unknown", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allowed(&["echo", "-x", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allowed(&["echo", "---"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allowed(&["echo", "-n", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allowed(&["echo", "-ne", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allowed(&["echo", "-nx", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allowed(&["echo", "--foo=bar"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(allowed(&["echo", "--unknown", "hello"], &policy));
        assert!(!allowed(&["echo", "--a", "--b", "--c"], &policy));
    }
}