Skip to main content

a3s_code_core/permissions/
rule.rs

1use serde::{Deserialize, Serialize};
2
3/// A permission rule with pattern matching support
4///
5/// Format: `ToolName(pattern)` or `ToolName` (matches all)
6///
7/// Examples:
8/// - `Bash(cargo:*)` - matches all cargo commands
9/// - `Bash(npm run test:*)` - matches npm run test with any args
10/// - `Read(src/**/*.rs)` - matches Rust files in src/
11/// - `Grep(*)` - matches all grep invocations
12/// - `mcp__pencil` - matches all pencil MCP tools
13///
14/// Deserialization supports both plain strings and `{rule: "..."}` objects:
15/// ```yaml
16/// allow:
17///   - read                   # plain string
18///   - rule: "Bash(cargo:*)"  # struct form
19/// ```
20#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
21pub struct PermissionRule {
22    /// The original rule string
23    pub rule: String,
24    /// Parsed tool name
25    #[serde(skip)]
26    pub(crate) tool_name: Option<String>,
27    /// Parsed argument pattern (None means match all)
28    #[serde(skip)]
29    pub(crate) arg_pattern: Option<String>,
30}
31
32impl<'de> Deserialize<'de> for PermissionRule {
33    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34    where
35        D: serde::Deserializer<'de>,
36    {
37        /// Helper enum to accept both `"read"` and `{rule: "read"}` in YAML/JSON.
38        #[derive(Deserialize)]
39        #[serde(untagged)]
40        enum RuleRepr {
41            Plain(String),
42            Struct { rule: String },
43        }
44
45        let rule_str = match RuleRepr::deserialize(deserializer)? {
46            RuleRepr::Plain(s) => s,
47            RuleRepr::Struct { rule } => rule,
48        };
49        // `new()` calls `parse_rule()` to populate tool_name and arg_pattern.
50        Ok(PermissionRule::new(&rule_str))
51    }
52}
53
54impl PermissionRule {
55    /// Create a new permission rule from a pattern string
56    pub fn new(rule: &str) -> Self {
57        let (tool_name, arg_pattern) = Self::parse_rule(rule);
58        Self {
59            rule: rule.to_string(),
60            tool_name,
61            arg_pattern,
62        }
63    }
64
65    /// Parse rule string into tool name and argument pattern
66    fn parse_rule(rule: &str) -> (Option<String>, Option<String>) {
67        // Handle format: ToolName(pattern) or ToolName
68        if let Some(paren_start) = rule.find('(') {
69            if rule.ends_with(')') {
70                let tool_name = rule[..paren_start].to_string();
71                let pattern = rule[paren_start + 1..rule.len() - 1].to_string();
72                return (Some(tool_name), Some(pattern));
73            }
74        }
75        // No parentheses - tool name only, matches all args
76        (Some(rule.to_string()), None)
77    }
78
79    /// Check if this rule matches a tool invocation
80    pub fn matches(&self, tool_name: &str, args: &serde_json::Value) -> bool {
81        // Check tool name
82        let rule_tool = match &self.tool_name {
83            Some(t) => t,
84            None => return false,
85        };
86
87        if !self.matches_tool_name(rule_tool, tool_name) {
88            return false;
89        }
90
91        // If no argument pattern, match all
92        let pattern = match &self.arg_pattern {
93            Some(p) => p,
94            None => return true,
95        };
96
97        // Match against argument pattern
98        self.matches_args(pattern, tool_name, args)
99    }
100
101    /// Check if tool names match (case-insensitive, wildcard-aware)
102    fn matches_tool_name(&self, rule_tool: &str, actual_tool: &str) -> bool {
103        // If the rule contains wildcards, use glob matching on the tool name directly.
104        // e.g. "mcp__longvt__*" must use glob, not starts_with, because starts_with
105        // treats '*' as a literal character and will never match.
106        if rule_tool.contains('*') || rule_tool.contains('?') {
107            return self.glob_match(rule_tool, actual_tool);
108        }
109
110        // Handle MCP tools: mcp__server matches mcp__server__tool
111        if rule_tool.starts_with("mcp__") && actual_tool.starts_with("mcp__") {
112            // mcp__pencil matches mcp__pencil__batch_design
113            if actual_tool.starts_with(rule_tool) {
114                return true;
115            }
116        }
117        rule_tool.eq_ignore_ascii_case(actual_tool)
118    }
119
120    /// Match argument pattern against tool arguments
121    fn matches_args(&self, pattern: &str, tool_name: &str, args: &serde_json::Value) -> bool {
122        // Handle wildcard pattern "*" - matches everything
123        if pattern == "*" {
124            return true;
125        }
126
127        // Build argument string based on tool type
128        let arg_string = self.build_arg_string(tool_name, args);
129
130        // Perform glob-style matching
131        self.glob_match(pattern, &arg_string)
132    }
133
134    /// Build a string representation of arguments for matching
135    fn build_arg_string(&self, tool_name: &str, args: &serde_json::Value) -> String {
136        match tool_name.to_lowercase().as_str() {
137            "bash" => {
138                // For Bash, use the command field
139                args.get("command")
140                    .and_then(|v| v.as_str())
141                    .unwrap_or("")
142                    .to_string()
143            }
144            "read" | "write" | "edit" => {
145                // For file operations, use the file_path field
146                args.get("file_path")
147                    .and_then(|v| v.as_str())
148                    .unwrap_or("")
149                    .to_string()
150            }
151            "glob" => {
152                // For glob, use the pattern field
153                args.get("pattern")
154                    .and_then(|v| v.as_str())
155                    .unwrap_or("")
156                    .to_string()
157            }
158            "grep" => {
159                // For grep, combine pattern and path
160                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("");
161                let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
162                format!("{} {}", pattern, path)
163            }
164            "ls" => {
165                // For ls, use the path field
166                args.get("path")
167                    .and_then(|v| v.as_str())
168                    .unwrap_or("")
169                    .to_string()
170            }
171            _ => {
172                // For other tools, serialize the entire args
173                serde_json::to_string(args).unwrap_or_default()
174            }
175        }
176    }
177
178    /// Perform glob-style pattern matching
179    ///
180    /// Supports:
181    /// - `*` matches any sequence of characters (except /)
182    /// - `**` matches any sequence including /
183    /// - `:*` at the end matches any suffix (including empty)
184    fn glob_match(&self, pattern: &str, text: &str) -> bool {
185        // Handle special `:*` suffix (matches any args after the prefix)
186        if let Some(prefix) = pattern.strip_suffix(":*") {
187            return text.starts_with(prefix);
188        }
189
190        // Normalize Windows backslashes to forward slashes for consistent matching
191        let text = text.replace('\\', "/");
192
193        // Convert glob pattern to regex pattern
194        let regex_pattern = Self::glob_to_regex(pattern);
195        if let Ok(re) = regex::Regex::new(&regex_pattern) {
196            re.is_match(&text)
197        } else {
198            // Fallback to simple prefix match if regex fails
199            text.starts_with(pattern)
200        }
201    }
202
203    /// Convert glob pattern to regex pattern
204    fn glob_to_regex(pattern: &str) -> String {
205        let mut regex = String::from("^");
206        let chars: Vec<char> = pattern.chars().collect();
207        let mut i = 0;
208
209        while i < chars.len() {
210            let c = chars[i];
211            match c {
212                '*' => {
213                    // Check for ** (matches anything including /)
214                    if i + 1 < chars.len() && chars[i + 1] == '*' {
215                        // ** matches any path including /
216                        // Skip optional following /
217                        if i + 2 < chars.len() && chars[i + 2] == '/' {
218                            regex.push_str(".*");
219                            i += 3;
220                        } else {
221                            regex.push_str(".*");
222                            i += 2;
223                        }
224                    } else {
225                        // * matches anything except path separators
226                        regex.push_str("[^/\\\\]*");
227                        i += 1;
228                    }
229                }
230                '?' => {
231                    // ? matches any single character except path separators
232                    regex.push_str("[^/\\\\]");
233                    i += 1;
234                }
235                '.' | '+' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
236                    // Escape regex special characters
237                    regex.push('\\');
238                    regex.push(c);
239                    i += 1;
240                }
241                _ => {
242                    regex.push(c);
243                    i += 1;
244                }
245            }
246        }
247
248        regex.push('$');
249        regex
250    }
251}