1use crate::parse::{Token, WordSet};
2
/// How `check_flags` treats hyphen-prefixed tokens that are not in either
/// allowlist.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Any unrecognized hyphen-prefixed token rejects the whole command.
    Strict,
    /// Unrecognized hyphen-prefixed tokens are counted as positional
    /// arguments instead of being rejected.
    Positional,
}
8
/// A queryable set of allowed flags, as consumed by `check_flags`.
pub trait FlagSet {
    /// Returns `true` when the whole token (e.g. `--color` or `-c`) is
    /// allowed.
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns `true` when `byte` names an allowed single-character short
    /// flag; `check_flags` calls this for each byte of a cluster like `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
13
/// Lets a `WordSet` be used directly as a `FlagSet` by forwarding to its own
/// lookup methods.
impl FlagSet for WordSet {
    fn contains_flag(&self, token: &str) -> bool {
        self.contains(token)
    }
    fn contains_short(&self, byte: u8) -> bool {
        // Method-call syntax picks an inherent `WordSet::contains_short` over
        // this trait method of the same name, so this is a delegation, not
        // recursion.
        // NOTE(review): assumes `WordSet` keeps an inherent `contains_short`;
        // if it were removed this call would resolve to itself and recurse.
        self.contains_short(byte)
    }
}
22
23impl FlagSet for [String] {
24 fn contains_flag(&self, token: &str) -> bool {
25 self.iter().any(|f| f.as_str() == token)
26 }
27 fn contains_short(&self, byte: u8) -> bool {
28 self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29 }
30}
31
32impl FlagSet for Vec<String> {
33 fn contains_flag(&self, token: &str) -> bool {
34 self.as_slice().contains_flag(token)
35 }
36 fn contains_short(&self, byte: u8) -> bool {
37 self.as_slice().contains_short(byte)
38 }
39}
40
41pub struct FlagPolicy {
42 pub standalone: WordSet,
43 pub valued: WordSet,
44 pub bare: bool,
45 pub max_positional: Option<usize>,
46 pub flag_style: FlagStyle,
47}
48
49impl FlagPolicy {
50 pub fn describe(&self) -> String {
51 use crate::docs::wordset_items;
52 let mut lines = Vec::new();
53 let standalone = wordset_items(&self.standalone);
54 if !standalone.is_empty() {
55 lines.push(format!("- Allowed standalone flags: {standalone}"));
56 }
57 let valued = wordset_items(&self.valued);
58 if !valued.is_empty() {
59 lines.push(format!("- Allowed valued flags: {valued}"));
60 }
61 if self.bare {
62 lines.push("- Bare invocation allowed".to_string());
63 }
64 if self.flag_style == FlagStyle::Positional {
65 lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
66 }
67 if lines.is_empty() && !self.bare {
68 return "- Positional arguments only".to_string();
69 }
70 lines.join("\n")
71 }
72
73 pub fn flag_summary(&self) -> String {
74 use crate::docs::wordset_items;
75 let mut parts = Vec::new();
76 let standalone = wordset_items(&self.standalone);
77 if !standalone.is_empty() {
78 parts.push(format!("Flags: {standalone}"));
79 }
80 let valued = wordset_items(&self.valued);
81 if !valued.is_empty() {
82 parts.push(format!("Valued: {valued}"));
83 }
84 if self.flag_style == FlagStyle::Positional {
85 parts.push("Positional args accepted".to_string());
86 }
87 parts.join(". ")
88 }
89}
90
91pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
92 check_flags(
93 tokens,
94 &policy.standalone,
95 &policy.valued,
96 policy.bare,
97 policy.max_positional,
98 policy.flag_style,
99 )
100}
101
102pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
103 tokens: &[Token],
104 standalone: &S,
105 valued: &V,
106 bare: bool,
107 max_positional: Option<usize>,
108 flag_style: FlagStyle,
109) -> bool {
110 if tokens.len() == 1 {
111 return bare;
112 }
113
114 let mut i = 1;
115 let mut positionals: usize = 0;
116 while i < tokens.len() {
117 let t = &tokens[i];
118
119 if *t == "--" {
120 positionals += tokens.len() - i - 1;
121 break;
122 }
123
124 if !t.starts_with('-') {
125 positionals += 1;
126 i += 1;
127 continue;
128 }
129
130 if standalone.contains_flag(t) {
131 i += 1;
132 continue;
133 }
134
135 if valued.contains_flag(t) {
136 i += 2;
137 continue;
138 }
139
140 if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
141 if valued.contains_flag(flag) {
142 i += 1;
143 continue;
144 }
145 if flag_style == FlagStyle::Positional {
146 positionals += 1;
147 i += 1;
148 continue;
149 }
150 return false;
151 }
152
153 if t.starts_with("--") {
154 if flag_style == FlagStyle::Positional {
155 positionals += 1;
156 i += 1;
157 continue;
158 }
159 return false;
160 }
161
162 let bytes = t.as_bytes();
163 let mut j = 1;
164 while j < bytes.len() {
165 let b = bytes[j];
166 let is_last = j == bytes.len() - 1;
167 if standalone.contains_short(b) {
168 j += 1;
169 continue;
170 }
171 if valued.contains_short(b) {
172 if is_last {
173 i += 1;
174 }
175 break;
176 }
177 if flag_style == FlagStyle::Positional {
178 positionals += 1;
179 break;
180 }
181 return false;
182 }
183 i += 1;
184 }
185 max_positional.is_none_or(|max| positionals <= max)
186}
187
// Unit tests for `check` / `check_flags`, driven mainly by three static
// policies modelled on `grep` (strict flags), `uniq` (limited positionals)
// and `echo` (positional style), plus a few ad-hoc policies.
#[cfg(test)]
mod tests {
    use super::*;

    // grep-like policy: strict flag checking, bare invocation denied, no
    // limit on positional arguments.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    // Builds test tokens from plain words; the first word is the command.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    // The valued `m` is the last byte of the cluster, so `5` is its value.
    #[test]
    fn combined_short_with_valued_last() {
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    // The valued `m` is mid-cluster, so the trailing `n` is taken as its
    // value rather than checked as a flag.
    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    // Everything after `--` is positional, even if it looks like a flag.
    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    // A lone `-` (conventionally stdin) is accepted.
    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    // A trailing valued flag with no value is tolerated, not rejected.
    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    // NOTE(review): despite the name, this only exercises the WordSet-backed
    // TEST_POLICY; no byte-array flag representation is involved here.
    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    // uniq-like policy: bare invocation allowed, at most one positional.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // The value `3` of the valued `-f` does not count as a positional.
    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // Tokens after `--` still count toward `max_positional`.
    #[test]
    fn max_positional_after_double_dash() {
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    // echo-like policy: unrecognized hyphen-prefixed tokens are treated as
    // positional arguments instead of being rejected.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    #[test]
    fn positional_style_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(check(&toks(&["echo", "---"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &POSITIONAL_POLICY));
    }

    // Unknown flags counted as positionals still respect the limit:
    // `--a --b --c` are three positionals against a maximum of two.
    #[test]
    fn positional_style_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }
}