//! omamori 0.3.2
//!
//! AI Agent's Omamori — protect your system from dangerous commands executed via AI CLI tools.
use serde::{Deserialize, Serialize};

/// A command to be checked: the program name and its arguments, kept separately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandInvocation {
    /// Program name; rule matching compares it for exact equality with a rule's `command`.
    pub program: String,
    /// Arguments passed to the program, in order.
    pub args: Vec<String>,
}

impl CommandInvocation {
    /// Creates an invocation from a program name and its arguments.
    pub fn new(program: String, args: Vec<String>) -> Self {
        Self { program, args }
    }

    /// Extract the operand (target) arguments from the command args.
    ///
    /// Follows the POSIX `--` convention:
    /// - arguments before the first `--` are targets only if they do not
    ///   start with `-`;
    /// - every argument after the first `--` is a target, whatever its
    ///   leading character (including any later `--`);
    /// - the first `--` itself is never returned.
    ///
    /// Without a `--`, all arguments not starting with `-` are returned.
    ///
    /// Operands that appear *before* `--` are kept, so `rm -rf dir -- -x`
    /// yields `["dir", "-x"]`. Dropping them would hide real targets from
    /// whatever consumes this list.
    ///
    /// Note: a flag's separate value argument (e.g. `dir` in `-C dir`) cannot
    /// be told apart from an operand here and is reported as a target.
    pub fn target_args(&self) -> Vec<&str> {
        let (before_sep, after_sep) = match self.args.iter().position(|a| a == "--") {
            Some(sep) => (&self.args[..sep], &self.args[sep + 1..]),
            None => (self.args.as_slice(), &[] as &[String]),
        };
        before_sep
            .iter()
            .filter(|a| !a.starts_with('-'))
            .chain(after_sep)
            .map(String::as_str)
            .collect()
    }
}

/// The action carried by a rule (see the `action` field of `RuleConfig`).
///
/// Variants are (de)serialized in kebab-case, e.g. `StashThenExec` is
/// written as `"stash-then-exec"`; [`ActionKind::as_str`] returns the same
/// spelling.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum ActionKind {
    Trash,
    StashThenExec,
    Block,
    LogOnly,
    MoveTo,
}

impl ActionKind {
    /// Returns the kebab-case name of this action, identical to the name
    /// serde uses for it.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ActionKind::Trash => "trash",
            ActionKind::StashThenExec => "stash-then-exec",
            ActionKind::Block => "block",
            ActionKind::LogOnly => "log-only",
            ActionKind::MoveTo => "move-to",
        }
    }
}

/// Serde default for `RuleConfig::enabled`: a rule is enabled unless the
/// configuration explicitly disables it.
fn default_true() -> bool {
    true
}

/// A single matching rule: the command it applies to, the argument tokens it
/// looks for, and the [`ActionKind`] attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleConfig {
    /// Rule identifier.
    pub name: String,
    /// Program name this rule applies to (compared for exact equality).
    pub command: String,
    /// Action associated with the rule.
    pub action: ActionKind,
    /// Tokens that must all be present among the invocation's arguments.
    #[serde(default)]
    pub match_all: Vec<String>,
    /// Tokens of which at least one must be present; empty means no constraint.
    #[serde(default)]
    pub match_any: Vec<String>,
    /// Optional free-form message (presumably shown to the user; confirm at call sites).
    pub message: Option<String>,
    /// Disabled rules are skipped during matching. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional destination (presumably used by `ActionKind::MoveTo`; confirm at call sites).
    #[serde(default)]
    pub destination: Option<String>,
}

impl RuleConfig {
    /// Builds an enabled rule with no destination.
    pub fn new(
        name: &str,
        command: &str,
        action: ActionKind,
        match_all: Vec<String>,
        match_any: Vec<String>,
        message: Option<String>,
    ) -> Self {
        RuleConfig {
            name: name.to_owned(),
            command: command.to_owned(),
            action,
            match_all,
            match_any,
            message,
            enabled: default_true(),
            destination: None,
        }
    }

    /// Returns this rule with its `enabled` flag replaced.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Returns this rule with `destination` set to `Some(destination)`.
    pub fn with_destination(self, destination: String) -> Self {
        Self {
            destination: Some(destination),
            ..self
        }
    }
}

/// Returns the first enabled rule in `rules` that matches `invocation`.
///
/// Rules are tried in slice order, so an earlier rule takes precedence over a
/// later one; rules with `enabled == false` are never returned.
pub fn match_rule<'a>(
    rules: &'a [RuleConfig],
    invocation: &CommandInvocation,
) -> Option<&'a RuleConfig> {
    rules
        .iter()
        .find(|candidate| candidate.enabled && rule_matches(candidate, invocation))
}

/// Expand combined short flags such as `-rfv` into their single-letter parts.
///
/// Each argument is copied through unchanged. An argument that is a single
/// `-` followed by at least two ASCII letters is also followed by one `-X`
/// entry per letter, in order. A letter flag is skipped when an identical
/// entry is already in the output. For example, `["-rfv"]` becomes
/// `["-rfv", "-r", "-f", "-v"]`.
///
/// Long flags (`--x`), single flags (`-f`), and clusters containing anything
/// other than ASCII letters (`-C2`, `-1`) are left as-is.
fn expand_short_flags(args: &[String]) -> Vec<String> {
    let mut out: Vec<String> = Vec::with_capacity(args.len());
    for arg in args {
        out.push(arg.clone());
        if let Some(cluster) = arg.strip_prefix('-') {
            let is_letter_cluster =
                cluster.len() >= 2 && cluster.bytes().all(|b| b.is_ascii_alphabetic());
            if is_letter_cluster {
                for letter in cluster.chars() {
                    let flag = format!("-{letter}");
                    if !out.contains(&flag) {
                        out.push(flag);
                    }
                }
            }
        }
    }
    out
}

/// Decides whether `rule` applies to `invocation`. The `enabled` flag is not
/// checked here.
///
/// Conditions, all of which must hold:
/// - the program equals `rule.command` exactly (plain string equality; no
///   path normalization);
/// - every `match_all` token is present among the arguments;
/// - if `match_any` is non-empty, at least one of its tokens is present.
///
/// Arguments are first run through [`expand_short_flags`], so a needle like
/// `-r` is also found inside a combined flag such as `-rfv`.
fn rule_matches(rule: &RuleConfig, invocation: &CommandInvocation) -> bool {
    if invocation.program != rule.command {
        return false;
    }

    let tokens = expand_short_flags(&invocation.args);
    // Captures `tokens` by shared reference, so the closure is `Copy` and can
    // be handed to both `all` and `any` below.
    let present = |needle: &String| tokens.iter().any(|token| token == needle);

    rule.match_all.iter().all(present)
        && (rule.match_any.is_empty() || rule.match_any.iter().any(present))
}

#[cfg(test)]
mod tests {
    // Unit tests for rule matching, target extraction and short-flag expansion.
    use super::*;

    /// Builds an owned argument list from string literals.
    fn strs(items: &[&str]) -> Vec<String> {
        items.iter().map(|&s| s.to_owned()).collect()
    }

    /// The `git push` + (`-f` | `--force`) blocking rule shared by several tests.
    fn git_push_force_rule() -> RuleConfig {
        RuleConfig::new(
            "git-push-force",
            "git",
            ActionKind::Block,
            strs(&["push"]),
            strs(&["-f", "--force"]),
            None,
        )
    }

    // --- match_rule / rule_matches basics ---

    #[test]
    fn matches_when_all_tokens_present() {
        let rule = RuleConfig::new(
            "git-reset",
            "git",
            ActionKind::StashThenExec,
            strs(&["reset", "--hard"]),
            Vec::new(),
            None,
        );
        let invocation = CommandInvocation::new("git".to_string(), strs(&["reset", "--hard"]));
        assert!(match_rule(&[rule], &invocation).is_some());
    }

    #[test]
    fn does_not_match_without_required_any_token() {
        let invocation = CommandInvocation::new("git".to_string(), strs(&["push"]));
        assert!(match_rule(&[git_push_force_rule()], &invocation).is_none());
    }

    #[test]
    fn program_name_must_match_exactly() {
        let invocation = CommandInvocation::new("gitk".to_string(), strs(&["push", "--force"]));
        assert!(match_rule(&[git_push_force_rule()], &invocation).is_none());
    }

    #[test]
    fn rule_without_token_constraints_matches_any_args() {
        let rule = RuleConfig::new(
            "any-rm",
            "rm",
            ActionKind::LogOnly,
            Vec::new(),
            Vec::new(),
            None,
        );
        let invocation = CommandInvocation::new("rm".to_string(), strs(&["file.txt"]));
        assert!(match_rule(&[rule], &invocation).is_some());
    }

    #[test]
    fn first_enabled_matching_rule_wins() {
        let rules = [
            RuleConfig::new(
                "disabled",
                "rm",
                ActionKind::Block,
                Vec::new(),
                Vec::new(),
                None,
            )
            .with_enabled(false),
            RuleConfig::new("first", "rm", ActionKind::LogOnly, Vec::new(), Vec::new(), None),
            RuleConfig::new("second", "rm", ActionKind::Trash, Vec::new(), Vec::new(), None),
        ];
        let invocation = CommandInvocation::new("rm".to_string(), strs(&["file.txt"]));
        let matched = match_rule(&rules, &invocation).map(|rule| rule.name.as_str());
        assert_eq!(matched, Some("first"));
    }

    // --- target_args: POSIX `--` separator handling ---

    #[test]
    fn target_args_respects_double_dash_separator() {
        let inv = CommandInvocation::new("rm".to_string(), strs(&["-rf", "--", "-dangerous.txt"]));
        assert_eq!(inv.target_args(), vec!["-dangerous.txt"]);
    }

    #[test]
    fn target_args_empty_after_double_dash() {
        let inv = CommandInvocation::new("rm".to_string(), strs(&["-rf", "--"]));
        assert!(inv.target_args().is_empty());
    }

    #[test]
    fn target_args_all_after_double_dash() {
        let inv = CommandInvocation::new("rm".to_string(), strs(&["--", "-a", "-b", "-c"]));
        assert_eq!(inv.target_args(), vec!["-a", "-b", "-c"]);
    }

    #[test]
    fn target_args_no_separator_filters_flags() {
        let inv = CommandInvocation::new("rm".to_string(), strs(&["-rf", "target/"]));
        assert_eq!(inv.target_args(), vec!["target/"]);
    }

    // --- expand_short_flags: combined short-flag expansion ---

    #[test]
    fn expand_short_flags_splits_combined() {
        let expanded = expand_short_flags(&strs(&["-rfv"]));
        assert!(expanded.contains(&"-rfv".to_string()));
        assert!(expanded.contains(&"-r".to_string()));
        assert!(expanded.contains(&"-f".to_string()));
        assert!(expanded.contains(&"-v".to_string()));
    }

    #[test]
    fn expand_short_flags_ignores_long_flags() {
        let expanded = expand_short_flags(&strs(&["--recursive"]));
        assert_eq!(expanded, strs(&["--recursive"]));
    }

    #[test]
    fn expand_short_flags_ignores_non_alpha() {
        let expanded = expand_short_flags(&strs(&["-C2", "-1"]));
        assert_eq!(expanded, strs(&["-C2", "-1"]));
    }

    #[test]
    fn expand_short_flags_single_char_not_expanded() {
        let expanded = expand_short_flags(&strs(&["-f"]));
        assert_eq!(expanded, strs(&["-f"]));
    }

    #[test]
    fn expand_short_flags_skips_duplicates() {
        let expanded = expand_short_flags(&strs(&["-f", "-rf"]));
        assert_eq!(expanded, strs(&["-f", "-rf", "-r"]));
    }

    #[test]
    fn combined_flag_matches_rm_trash_rule() {
        let rule = RuleConfig::new(
            "rm-recursive",
            "rm",
            ActionKind::Trash,
            Vec::new(),
            strs(&["-r", "-rf", "-fr", "--recursive"]),
            None,
        );
        // -rfv should match because it expands to include -r
        let inv = CommandInvocation::new("rm".to_string(), strs(&["-rfv", "target/"]));
        assert!(match_rule(&[rule], &inv).is_some());
    }

    // --- enabled flag ---

    #[test]
    fn disabled_rule_is_skipped() {
        let rule = git_push_force_rule().with_enabled(false);
        let inv = CommandInvocation::new("git".to_string(), strs(&["push", "--force"]));
        assert!(match_rule(&[rule], &inv).is_none());
    }

    #[test]
    fn enabled_rule_still_matches() {
        let rule = git_push_force_rule().with_enabled(true);
        let inv = CommandInvocation::new("git".to_string(), strs(&["push", "--force"]));
        assert!(match_rule(&[rule], &inv).is_some());
    }

    // --- MoveTo action ---

    // Checks the action name and the stored destination; it does not exercise
    // serde serialization.
    #[test]
    fn move_to_action_name_and_destination() {
        let rule = RuleConfig::new(
            "rm-to-backup",
            "rm",
            ActionKind::MoveTo,
            Vec::new(),
            strs(&["-rf"]),
            None,
        )
        .with_destination("/tmp/backup".to_string());
        assert_eq!(rule.action.as_str(), "move-to");
        assert_eq!(rule.destination.as_deref(), Some("/tmp/backup"));
    }

    #[test]
    fn git_push_dash_f_matches_block_rule() {
        let inv = CommandInvocation::new("git".to_string(), strs(&["push", "-f", "origin", "main"]));
        assert!(match_rule(&[git_push_force_rule()], &inv).is_some());
    }
}