use hashbrown::HashMap;
/// Verdict produced when a command line is classified by
/// `SafeCommandRegistry::is_safe` or by a rule's custom check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SafetyDecision {
/// The command matched a rule and passed every check applied to it.
Allow,
/// The command was rejected; the payload is a human-readable reason.
Deny(String),
/// No verdict: the command line was empty or no rule is registered for it.
/// Custom checks also return this to defer to the generic rule checks.
Unknown,
}
/// Registry of per-command safety rules.
///
/// Rules are looked up by the file-name component of the first command
/// token, so `/usr/bin/git` and `git` resolve to the same rule.
#[derive(Clone)]
pub struct SafeCommandRegistry {
// Command name -> rule describing what is acceptable for that command.
rules: HashMap<String, CommandRule>,
}
/// Safety rule for a single command, evaluated by
/// `SafeCommandRegistry::is_safe`.
#[derive(Clone)]
pub struct CommandRule {
// When `Some`, the second command token must be a member of this set;
// `None` means no subcommand restriction is applied.
safe_subcommands: Option<rustc_hash::FxHashSet<String>>,
// Options that cause a denial when any token equals the option or starts
// with the option followed by `=`.
forbidden_options: Vec<String>,
// Optional command-specific check that runs before the generic checks; any
// result other than `SafetyDecision::Unknown` is returned as the verdict.
custom_check: Option<fn(&[String]) -> SafetyDecision>,
}
impl CommandRule {
    /// Creates a rule with no subcommand restriction, no forbidden options
    /// and no custom check.
    pub fn safe_readonly() -> Self {
        Self {
            safe_subcommands: None,
            forbidden_options: Vec::new(),
            custom_check: None,
        }
    }

    /// Creates a rule that only accepts the listed subcommands and has no
    /// forbidden options or custom check.
    pub fn with_allowed_subcommands(subcommands: Vec<&str>) -> Self {
        let allowed: rustc_hash::FxHashSet<String> =
            subcommands.into_iter().map(String::from).collect();
        Self {
            safe_subcommands: Some(allowed),
            ..Self::safe_readonly()
        }
    }

    /// Creates a rule that rejects the listed options and has no subcommand
    /// restriction or custom check.
    pub fn with_forbidden_options(options: Vec<&str>) -> Self {
        let forbidden_options: Vec<String> = options.into_iter().map(String::from).collect();
        Self {
            forbidden_options,
            ..Self::safe_readonly()
        }
    }
}
impl SafeCommandRegistry {
/// Creates a registry pre-populated with the built-in default rules.
pub fn new() -> Self {
    let rules = Self::default_rules();
    Self { rules }
}
fn default_rules() -> HashMap<String, CommandRule> {
let mut rules = HashMap::new();
rules.insert(
"git".to_string(),
CommandRule {
safe_subcommands: Some(
vec!["status", "log", "diff", "show"]
.into_iter()
.map(|s| s.to_string())
.collect(),
),
forbidden_options: vec![],
custom_check: Some(Self::check_git),
},
);
rules.insert(
"cargo".to_string(),
CommandRule {
safe_subcommands: Some(
vec!["check", "build", "clippy"]
.into_iter()
.map(|s| s.to_string())
.collect(),
),
forbidden_options: vec![],
custom_check: Some(Self::check_cargo),
},
);
rules.insert(
"find".to_string(),
CommandRule {
safe_subcommands: None,
forbidden_options: vec![
"-exec".to_string(),
"-execdir".to_string(),
"-ok".to_string(),
"-okdir".to_string(),
"-delete".to_string(),
"-fls".to_string(),
"-fprint".to_string(),
"-fprint0".to_string(),
"-fprintf".to_string(),
],
custom_check: None,
},
);
rules.insert(
"base64".to_string(),
CommandRule {
safe_subcommands: None,
forbidden_options: vec!["-o".to_string(), "--output".to_string()],
custom_check: Some(Self::check_base64),
},
);
rules.insert(
"sed".to_string(),
CommandRule {
safe_subcommands: None,
forbidden_options: vec![],
custom_check: Some(Self::check_sed),
},
);
rules.insert(
"rg".to_string(),
CommandRule {
safe_subcommands: None,
forbidden_options: vec![
"--pre".to_string(),
"--hostname-bin".to_string(),
"--search-zip".to_string(),
"-z".to_string(),
],
custom_check: None,
},
);
for cmd in &[
"cat", "ls", "pwd", "echo", "grep", "head", "tail", "wc", "tr", "cut", "paste", "sort",
"uniq", "rev", "seq", "expr", "uname", "whoami", "id", "stat", "which",
] {
rules.insert(
cmd.to_string(),
CommandRule {
safe_subcommands: None,
forbidden_options: vec![],
custom_check: None,
},
);
}
rules
}
pub fn is_safe(&self, command: &[String]) -> SafetyDecision {
if command.is_empty() {
return SafetyDecision::Unknown;
}
let cmd_name = Self::extract_command_name(&command[0]);
let Some(rule) = self.rules.get(cmd_name) else {
return SafetyDecision::Unknown;
};
if let Some(check_fn) = rule.custom_check {
let result = check_fn(command);
if result != SafetyDecision::Unknown {
return result;
}
}
if let Some(ref safe_subs) = rule.safe_subcommands {
if command.len() < 2 {
return SafetyDecision::Deny(format!("Command {} requires a subcommand", cmd_name));
}
let subcommand = &command[1];
if !safe_subs.contains(subcommand) {
return SafetyDecision::Deny(format!(
"Subcommand {} not in safe list for {}",
subcommand, cmd_name
));
}
}
if !rule.forbidden_options.is_empty() {
let forbidden_with_eq: Vec<String> = rule
.forbidden_options
.iter()
.map(|opt| format!("{}=", opt))
.collect();
for arg in command {
for (i, forbidden) in rule.forbidden_options.iter().enumerate() {
if arg == forbidden || arg.starts_with(&forbidden_with_eq[i]) {
return SafetyDecision::Deny(format!(
"Option {} is not allowed for {}",
forbidden, cmd_name
));
}
}
}
}
SafetyDecision::Allow
}
/// Returns the final path component of `cmd` (e.g. `git` for
/// `/usr/bin/git`).
///
/// Falls back to `cmd` unchanged when it has no file-name component or that
/// component is not valid UTF-8.
fn extract_command_name(cmd: &str) -> &str {
    let base_name = std::path::Path::new(cmd).file_name();
    match base_name.and_then(std::ffi::OsStr::to_str) {
        Some(name) => name,
        None => cmd,
    }
}
/// Custom safety check for `git` invocations.
///
/// * Fewer than two tokens: `Unknown` (the generic checks decide).
/// * Any token after the program name that
///   `dangerous_commands::git_global_option_requires_prompt` flags: `Deny`.
/// * `status`, `log`, `diff`, `show`: `Allow`, unless an argument after the
///   subcommand is `--output` / `--output=<file>`, which makes git write the
///   result to a file instead of stdout.
/// * `branch`: `Allow` only when every argument after the subcommand is a
///   recognised listing/display flag (or there are none); known mutating
///   flags and anything unrecognised are denied.
/// * No recognised subcommand located by
///   `dangerous_commands::find_git_subcommand`: `Unknown`.
fn check_git(command: &[String]) -> SafetyDecision {
    if command.len() < 2 {
        return SafetyDecision::Unknown;
    }
    if command
        .iter()
        .skip(1)
        .map(String::as_str)
        .any(crate::command_safety::dangerous_commands::git_global_option_requires_prompt)
    {
        return SafetyDecision::Deny(
            "git global options that redirect config, repository, or helper lookup are not allowed"
                .to_string(),
        );
    }
    let subcommands = &["status", "log", "diff", "show", "branch"];
    let Some((idx, subcommand)) =
        crate::command_safety::dangerous_commands::find_git_subcommand(command, subcommands)
    else {
        return SafetyDecision::Unknown;
    };
    match subcommand {
        "status" | "log" | "diff" | "show" => {
            // `--output[=<file>]` redirects the output into a file, so the
            // invocation is no longer read-only.
            let writes_output_file = command[idx + 1..]
                .iter()
                .map(String::as_str)
                .any(|arg| arg == "--output" || arg.starts_with("--output="));
            if writes_output_file {
                SafetyDecision::Deny(
                    "git --output writes to a file and is not allowed".to_string(),
                )
            } else {
                SafetyDecision::Allow
            }
        }
        "branch" => {
            let branch_args = &command[idx + 1..];
            let is_read_only = branch_args.iter().all(|arg| {
                let arg = arg.as_str();
                matches!(
                    arg,
                    "--show-current"
                        | "--list"
                        | "-l"
                        | "-v"
                        | "-vv"
                        | "-a"
                        | "-r"
                        | "--all"
                        // `--remotes` is the canonical long form of `-r`;
                        // `--remote` is kept because git accepts it as an
                        // abbreviation and it was accepted before.
                        | "--remotes"
                        | "--remote"
                        | "--verbose"
                        | "--format"
                ) || arg.starts_with("--format=")
                    || arg.starts_with("--sort=")
                    || arg.starts_with("--contains=")
                    || arg.starts_with("--no-contains=")
                    || arg.starts_with("--merged=")
                    || arg.starts_with("--no-merged=")
                    || arg.starts_with("--points-at=")
            });
            // Only used to pick a more specific denial message; anything not
            // on the read-only list above is denied either way.
            let has_dangerous_flag = branch_args.iter().any(|arg| {
                let arg = arg.as_str();
                matches!(
                    arg,
                    "-d" | "-D"
                        | "--delete"
                        | "-m"
                        | "-M"
                        | "--move"
                        | "-c"
                        | "-C"
                        | "--create"
                        | "--set-upstream"
                        | "--set-upstream-to"
                        | "--unset-upstream"
                ) || arg.starts_with("--delete=")
                    || arg.starts_with("--move=")
                    || arg.starts_with("--create=")
                    || arg.starts_with("--set-upstream-to=")
            });
            if has_dangerous_flag {
                SafetyDecision::Deny(
                    "git branch with modification flags is not allowed".to_string(),
                )
            } else if is_read_only || branch_args.is_empty() {
                SafetyDecision::Allow
            } else {
                SafetyDecision::Deny(
                    "git branch with unknown flags requires approval".to_string(),
                )
            }
        }
        _ => SafetyDecision::Unknown,
    }
}
/// Custom safety check for `cargo` invocations.
///
/// * Fewer than two tokens: `Unknown` (the generic checks decide).
/// * `check`, `build`, `clippy`: `Allow`, except `clippy` with `--fix`,
///   which rewrites source files and is therefore denied.
/// * `fmt`: `Allow` only when `--check` is present, otherwise `Deny`.
/// * Any other second token: `Deny`.
fn check_cargo(command: &[String]) -> SafetyDecision {
    let Some(subcommand) = command.get(1).map(String::as_str) else {
        return SafetyDecision::Unknown;
    };
    // Scans the whole command line without allocating a temporary `String`.
    let has_flag = |flag: &str| command.iter().any(|arg| arg.as_str() == flag);

    match subcommand {
        "clippy" if has_flag("--fix") => {
            SafetyDecision::Deny("cargo clippy --fix is not allowed".to_string())
        }
        "check" | "build" | "clippy" => SafetyDecision::Allow,
        "fmt" if has_flag("--check") => SafetyDecision::Allow,
        "fmt" => SafetyDecision::Deny("cargo fmt without --check is not allowed".to_string()),
        other => SafetyDecision::Deny(format!(
            "cargo {} is not in safe subcommand list",
            other
        )),
    }
}
/// Custom safety check for `base64` invocations.
///
/// Denies any argument after the program name that selects an output file:
/// the exact options `-o` / `--output`, the `--output=<file>` form, and the
/// attached short form `-o<file>`. Returns `Unknown` otherwise so the
/// generic rule checks still run.
fn check_base64(command: &[String]) -> SafetyDecision {
    command
        .iter()
        .skip(1)
        .map(String::as_str)
        .find_map(|arg| match arg {
            "-o" | "--output" => Some(SafetyDecision::Deny(format!(
                "base64 {} is not allowed (output redirection)",
                arg
            ))),
            // A bare `-o` was handled by the arm above, so a `-o` prefix here
            // always carries an attached value.
            _ if arg.starts_with("--output=") || arg.starts_with("-o") => Some(
                SafetyDecision::Deny("base64 output redirection is not allowed".to_string()),
            ),
            _ => None,
        })
        .unwrap_or(SafetyDecision::Unknown)
}
/// Custom safety check for `sed` invocations.
///
/// * Two tokens or fewer: `Unknown`, deferring to the generic rule checks.
/// * Three or four tokens where the second is `-n` and the third is a
///   `{N|M,N}p` print expression: `Allow`.
/// * Anything else: `Deny`.
///
/// NOTE(review): two-token invocations such as `sed <script>` are not
/// inspected here, so the script text itself is never validated in that
/// form — confirm this is intended.
fn check_sed(command: &[String]) -> SafetyDecision {
    match command {
        [] | [_] | [_, _] => SafetyDecision::Unknown,
        [_, flag, script] | [_, flag, script, _]
            if flag.as_str() == "-n" && Self::is_valid_sed_n_arg(script) =>
        {
            SafetyDecision::Allow
        }
        _ => SafetyDecision::Deny("sed only allows safe pattern: sed -n {N|M,N}p".to_string()),
    }
}
/// Returns `true` when `arg` is a sed print expression of the form `Np` or
/// `M,Np`, where `M` and `N` are non-empty runs of ASCII digits.
fn is_valid_sed_n_arg(arg: &str) -> bool {
    /// A line number is a non-empty run of ASCII digits.
    fn is_line_number(text: &str) -> bool {
        !text.is_empty() && text.bytes().all(|byte| byte.is_ascii_digit())
    }

    match arg.strip_suffix('p') {
        None => false,
        Some(address) => match address.split_once(',') {
            None => is_line_number(address),
            // A second comma ends up in `last` and fails the digit test.
            Some((first, last)) => is_line_number(first) && is_line_number(last),
        },
    }
}
}
impl Default for SafeCommandRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned token list `is_safe` expects.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_owned()).collect()
    }

    /// Classifies `parts` with a freshly built default registry.
    fn classify(parts: &[&str]) -> SafetyDecision {
        SafeCommandRegistry::new().is_safe(&argv(parts))
    }

    #[test]
    fn git_status_is_safe() {
        assert_eq!(classify(&["git", "status"]), SafetyDecision::Allow);
    }

    #[test]
    fn git_global_options_require_approval() {
        let registry = SafeCommandRegistry::new();
        let cases: [&[&str]; 7] = [
            &["git", "-c", "core.pager=cat", "show", "HEAD:foo.rs"],
            &["git", "--config-env", "core.pager=PAGER", "show", "HEAD"],
            &["git", "--git-dir=.evil-git", "diff", "HEAD~1..HEAD"],
            &["git", "--work-tree", ".", "status"],
            &["git", "--exec-path=.git/helpers", "show", "HEAD"],
            &["git", "--namespace=attacker", "show", "HEAD"],
            &["git", "--super-prefix=attacker/", "show", "HEAD"],
        ];
        for parts in cases {
            let cmd = argv(parts);
            assert!(
                matches!(registry.is_safe(&cmd), SafetyDecision::Deny(_)),
                "expected {cmd:?} to require approval due to unsafe git global option",
            );
        }
    }

    #[test]
    fn git_reset_is_dangerous() {
        assert!(matches!(
            classify(&["git", "reset"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn cargo_check_is_safe() {
        assert_eq!(classify(&["cargo", "check"]), SafetyDecision::Allow);
    }

    #[test]
    fn cargo_clean_is_dangerous() {
        assert!(matches!(
            classify(&["cargo", "clean"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn cargo_fmt_without_check_is_dangerous() {
        assert!(matches!(
            classify(&["cargo", "fmt"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn cargo_fmt_with_check_is_safe() {
        assert_eq!(
            classify(&["cargo", "fmt", "--check"]),
            SafetyDecision::Allow
        );
    }

    #[test]
    fn find_without_dangerous_options_is_allowed() {
        assert_eq!(classify(&["find", "."]), SafetyDecision::Allow);
    }

    #[test]
    fn find_with_delete_is_dangerous() {
        assert!(matches!(
            classify(&["find", ".", "-delete"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn find_with_exec_is_dangerous() {
        assert!(matches!(
            classify(&["find", ".", "-exec", "rm"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn base64_without_output_is_allowed() {
        assert_eq!(classify(&["base64", "file.txt"]), SafetyDecision::Allow);
    }

    #[test]
    fn base64_with_output_is_dangerous() {
        assert!(matches!(
            classify(&["base64", "file.txt", "-o", "output.txt"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn sed_n_single_line_is_safe() {
        assert_eq!(classify(&["sed", "-n", "10p"]), SafetyDecision::Allow);
    }

    #[test]
    fn sed_n_range_is_safe() {
        assert_eq!(classify(&["sed", "-n", "1,5p"]), SafetyDecision::Allow);
    }

    #[test]
    fn sed_without_n_is_allowed() {
        assert_eq!(classify(&["sed", "s/foo/bar/g"]), SafetyDecision::Allow);
    }

    #[test]
    fn rg_with_pre_is_dangerous() {
        assert!(matches!(
            classify(&["rg", "--pre", "some_command", "pattern"]),
            SafetyDecision::Deny(_)
        ));
    }

    #[test]
    fn cat_is_always_safe() {
        assert_eq!(classify(&["cat", "file.txt"]), SafetyDecision::Allow);
    }

    #[test]
    fn extract_command_name_from_path() {
        let expectations = [
            ("/usr/bin/git", "git"),
            ("/usr/local/bin/cargo", "cargo"),
            ("git", "git"),
        ];
        for (input, expected) in expectations {
            assert_eq!(SafeCommandRegistry::extract_command_name(input), expected);
        }
    }
}