safe-chains 0.113.1

Auto-allow safe, read-only bash commands in agentic coding tools
Documentation
use crate::parse::{Token, WordSet};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FlagStyle {
    Strict,
    Positional,
}

pub trait FlagSet {
    fn contains_flag(&self, token: &str) -> bool;
    fn contains_short(&self, byte: u8) -> bool;
}

impl FlagSet for WordSet {
    fn contains_flag(&self, token: &str) -> bool {
        self.contains(token)
    }
    fn contains_short(&self, byte: u8) -> bool {
        self.contains_short(byte)
    }
}

impl FlagSet for [String] {
    fn contains_flag(&self, token: &str) -> bool {
        self.iter().any(|f| f.as_str() == token)
    }
    fn contains_short(&self, byte: u8) -> bool {
        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
    }
}

impl FlagSet for Vec<String> {
    fn contains_flag(&self, token: &str) -> bool {
        self.as_slice().contains_flag(token)
    }
    fn contains_short(&self, byte: u8) -> bool {
        self.as_slice().contains_short(byte)
    }
}

pub struct FlagPolicy {
    pub standalone: WordSet,
    pub valued: WordSet,
    pub bare: bool,
    pub max_positional: Option<usize>,
    pub flag_style: FlagStyle,
}

impl FlagPolicy {
    pub fn describe(&self) -> String {
        use crate::docs::wordset_items;
        let mut lines = Vec::new();
        let standalone = wordset_items(&self.standalone);
        if !standalone.is_empty() {
            lines.push(format!("- Allowed standalone flags: {standalone}"));
        }
        let valued = wordset_items(&self.valued);
        if !valued.is_empty() {
            lines.push(format!("- Allowed valued flags: {valued}"));
        }
        if self.bare {
            lines.push("- Bare invocation allowed".to_string());
        }
        if self.flag_style == FlagStyle::Positional {
            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
        }
        if lines.is_empty() && !self.bare {
            return "- Positional arguments only".to_string();
        }
        lines.join("\n")
    }

}

pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
    check_flags(
        tokens,
        &policy.standalone,
        &policy.valued,
        policy.bare,
        policy.max_positional,
        policy.flag_style,
    )
}

pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
    tokens: &[Token],
    standalone: &S,
    valued: &V,
    bare: bool,
    max_positional: Option<usize>,
    flag_style: FlagStyle,
) -> bool {
    if tokens.len() == 1 {
        return bare;
    }

    let mut i = 1;
    let mut positionals: usize = 0;
    while i < tokens.len() {
        let t = &tokens[i];

        if *t == "--" {
            positionals += tokens.len() - i - 1;
            break;
        }

        if !t.starts_with('-') {
            positionals += 1;
            i += 1;
            continue;
        }

        if standalone.contains_flag(t) {
            i += 1;
            continue;
        }

        if valued.contains_flag(t) {
            i += 2;
            continue;
        }

        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
            if valued.contains_flag(flag) {
                i += 1;
                continue;
            }
            if flag_style == FlagStyle::Positional {
                positionals += 1;
                i += 1;
                continue;
            }
            return false;
        }

        if t.starts_with("--") {
            if flag_style == FlagStyle::Positional {
                positionals += 1;
                i += 1;
                continue;
            }
            return false;
        }

        let bytes = t.as_bytes();
        let mut j = 1;
        while j < bytes.len() {
            let b = bytes[j];
            let is_last = j == bytes.len() - 1;
            if standalone.contains_short(b) {
                j += 1;
                continue;
            }
            if valued.contains_short(b) {
                if is_last {
                    i += 1;
                }
                break;
            }
            if flag_style == FlagStyle::Positional {
                positionals += 1;
                break;
            }
            return false;
        }
        i += 1;
    }
    max_positional.is_none_or(|max| positionals <= max)
}

#[cfg(test)]
mod tests {
    use super::*;

    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    #[test]
    fn positional_style_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(check(&toks(&["echo", "---"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }
}