use crate::parse::{Token, WordSet};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FlagStyle {
Strict,
Positional,
}
pub trait FlagSet {
fn contains_flag(&self, token: &str) -> bool;
fn contains_short(&self, byte: u8) -> bool;
}
impl FlagSet for WordSet {
fn contains_flag(&self, token: &str) -> bool {
self.contains(token)
}
fn contains_short(&self, byte: u8) -> bool {
self.contains_short(byte)
}
}
impl FlagSet for [String] {
fn contains_flag(&self, token: &str) -> bool {
self.iter().any(|f| f.as_str() == token)
}
fn contains_short(&self, byte: u8) -> bool {
self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
}
}
impl FlagSet for Vec<String> {
fn contains_flag(&self, token: &str) -> bool {
self.as_slice().contains_flag(token)
}
fn contains_short(&self, byte: u8) -> bool {
self.as_slice().contains_short(byte)
}
}
pub struct FlagPolicy {
pub standalone: WordSet,
pub valued: WordSet,
pub bare: bool,
pub max_positional: Option<usize>,
pub flag_style: FlagStyle,
pub numeric_dash: bool,
}
impl FlagPolicy {
pub fn describe(&self) -> String {
use crate::docs::wordset_items;
let mut lines = Vec::new();
let standalone = wordset_items(&self.standalone);
if !standalone.is_empty() {
lines.push(format!("- Allowed standalone flags: {standalone}"));
}
let valued = wordset_items(&self.valued);
if !valued.is_empty() {
lines.push(format!("- Allowed valued flags: {valued}"));
}
if self.bare {
lines.push("- Bare invocation allowed".to_string());
}
if self.flag_style == FlagStyle::Positional {
lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
}
if self.numeric_dash {
lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
}
if lines.is_empty() && !self.bare {
return "- Positional arguments only".to_string();
}
lines.join("\n")
}
}
pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
check_flags(
tokens,
&policy.standalone,
&policy.valued,
policy.bare,
policy.max_positional,
policy.flag_style,
policy.numeric_dash,
)
}
pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
tokens: &[Token],
standalone: &S,
valued: &V,
bare: bool,
max_positional: Option<usize>,
flag_style: FlagStyle,
numeric_dash: bool,
) -> bool {
if tokens.len() == 1 {
return bare;
}
let mut i = 1;
let mut positionals: usize = 0;
while i < tokens.len() {
let t = &tokens[i];
if *t == "--" {
positionals += tokens.len() - i - 1;
break;
}
if !t.starts_with('-') {
positionals += 1;
i += 1;
continue;
}
if numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
i += 1;
continue;
}
if standalone.contains_flag(t) {
i += 1;
continue;
}
if valued.contains_flag(t) {
i += 2;
continue;
}
if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
if valued.contains_flag(flag) {
i += 1;
continue;
}
if flag_style == FlagStyle::Positional {
positionals += 1;
i += 1;
continue;
}
return false;
}
if t.starts_with("--") {
if flag_style == FlagStyle::Positional {
positionals += 1;
i += 1;
continue;
}
return false;
}
let bytes = t.as_bytes();
let mut j = 1;
while j < bytes.len() {
let b = bytes[j];
let is_last = j == bytes.len() - 1;
if standalone.contains_short(b) {
j += 1;
continue;
}
if valued.contains_short(b) {
if is_last {
i += 1;
}
break;
}
if flag_style == FlagStyle::Positional {
positionals += 1;
break;
}
return false;
}
i += 1;
}
max_positional.is_none_or(|max| positionals <= max)
}
#[cfg(test)]
mod tests {
use super::*;
static TEST_POLICY: FlagPolicy = FlagPolicy {
standalone: WordSet::flags(&[
"--color", "--count", "--help", "--recursive", "--version",
"-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
]),
valued: WordSet::flags(&[
"--after-context", "--before-context", "--max-count",
"-A", "-B", "-m",
]),
bare: false,
max_positional: None,
flag_style: FlagStyle::Strict,
numeric_dash: false,
};
fn toks(words: &[&str]) -> Vec<Token> {
words.iter().map(|s| Token::from_test(s)).collect()
}
#[test]
fn bare_denied_when_bare_false() {
assert!(!check(&toks(&["grep"]), &TEST_POLICY));
}
#[test]
fn bare_allowed_when_bare_true() {
let policy = FlagPolicy {
standalone: WordSet::flags(&[]),
valued: WordSet::flags(&[]),
bare: true,
max_positional: None,
flag_style: FlagStyle::Strict,
numeric_dash: false,
};
assert!(check(&toks(&["uname"]), &policy));
}
#[test]
fn standalone_long_flag() {
assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
}
#[test]
fn standalone_short_flag() {
assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
}
#[test]
fn valued_long_flag_space() {
assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
}
#[test]
fn valued_long_flag_eq() {
assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
}
#[test]
fn valued_short_flag_space() {
assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
}
#[test]
fn combined_standalone_short() {
assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
}
#[test]
fn combined_short_with_valued_last() {
assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
}
#[test]
fn combined_short_valued_mid_consumes_rest() {
assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
}
#[test]
fn unknown_long_flag_denied() {
assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
}
#[test]
fn unknown_short_flag_denied() {
assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
}
#[test]
fn unknown_combined_short_denied() {
assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
}
#[test]
fn unknown_long_eq_denied() {
assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
}
#[test]
fn double_dash_stops_checking() {
assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
}
#[test]
fn positional_args_allowed() {
assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
}
#[test]
fn mixed_flags_and_positional() {
assert!(check(
&toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
&TEST_POLICY,
));
}
#[test]
fn valued_short_in_explicit_form() {
assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
}
#[test]
fn bare_dash_allowed_as_stdin() {
assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
}
#[test]
fn valued_flag_at_end_without_value() {
assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
}
#[test]
fn single_short_in_wordset_and_byte_array() {
assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
}
static LIMITED_POLICY: FlagPolicy = FlagPolicy {
standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
bare: true,
max_positional: Some(1),
flag_style: FlagStyle::Strict,
numeric_dash: false,
};
#[test]
fn max_positional_within_limit() {
assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
}
#[test]
fn max_positional_exceeded() {
assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
}
#[test]
fn max_positional_with_flags_within_limit() {
assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
}
#[test]
fn max_positional_with_flags_exceeded() {
assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
}
#[test]
fn max_positional_after_double_dash() {
assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
}
#[test]
fn max_positional_bare_allowed() {
assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
}
static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
standalone: WordSet::flags(&["-E", "-e", "-n"]),
valued: WordSet::flags(&[]),
bare: true,
max_positional: None,
flag_style: FlagStyle::Positional,
numeric_dash: false,
};
#[test]
fn positional_style_unknown_long() {
assert!(check(&toks(&["echo", "--unknown", "hello"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_unknown_short() {
assert!(check(&toks(&["echo", "-x", "hello"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_dashes() {
assert!(check(&toks(&["echo", "---"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_known_flags_still_work() {
assert!(check(&toks(&["echo", "-n", "hello"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_combo_known() {
assert!(check(&toks(&["echo", "-ne", "hello"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_combo_unknown_byte() {
assert!(check(&toks(&["echo", "-nx", "hello"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_unknown_eq() {
assert!(check(&toks(&["echo", "--foo=bar"]), &POSITIONAL_POLICY));
}
#[test]
fn positional_style_with_max_positional() {
let policy = FlagPolicy {
standalone: WordSet::flags(&["-n"]),
valued: WordSet::flags(&[]),
bare: true,
max_positional: Some(2),
flag_style: FlagStyle::Positional,
numeric_dash: false,
};
assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
}
static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
standalone: WordSet::flags(&[
"--help", "--quiet", "--verbose", "--version",
"-V", "-h", "-q", "-v", "-z",
]),
valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
bare: true,
max_positional: None,
flag_style: FlagStyle::Strict,
numeric_dash: true,
};
#[test]
fn numeric_dash_single_digit() {
assert!(check(&toks(&["head", "-5"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_multi_digit() {
assert!(check(&toks(&["head", "-20"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_large_number() {
assert!(check(&toks(&["head", "-1000"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_with_file_arg() {
assert!(check(&toks(&["head", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_with_other_flags() {
assert!(check(&toks(&["head", "-q", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_zero() {
assert!(check(&toks(&["head", "-0"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_still_rejects_unknown_flags() {
assert!(!check(&toks(&["head", "-x"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_rejects_mixed_alpha_num() {
assert!(!check(&toks(&["head", "-20x"]), &NUMERIC_DASH_POLICY));
}
#[test]
fn numeric_dash_disabled_rejects_multi_digit() {
assert!(!check(&toks(&["grep", "-20", "pattern"]), &TEST_POLICY));
}
}