Skip to main content

j_agent/permission/
rules.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::{LazyLock, Mutex};
7
8/// .jcli/ 目录权限配置
9#[derive(Debug, Deserialize, Serialize, Default, Clone)]
10pub struct JcliConfig {
11    #[serde(default)]
12    pub permissions: PermissionConfig,
13}
14
15/// 权限配置
16#[derive(Debug, Deserialize, Serialize, Default, Clone)]
17pub struct PermissionConfig {
18    /// 完全放开(跳过所有工具确认)
19    #[serde(default)]
20    pub allow_all: bool,
21    /// 允许列表(匹配则跳过确认)
22    #[serde(default)]
23    pub allow: Vec<String>,
24    /// 拒绝列表(优先于 allow,匹配则直接拒绝执行)
25    #[serde(default)]
26    pub deny: Vec<String>,
27}
28
29impl JcliConfig {
30    /// 从 cwd 向上查找 .jcli/ 目录并加载 permissions.yaml
31    pub fn load() -> Self {
32        if let Some(dir) = Self::find_config_dir() {
33            let perm_path = dir.join("permissions.yaml");
34            match std::fs::read_to_string(&perm_path) {
35                Ok(content) => {
36                    let permissions =
37                        serde_yaml::from_str::<PermissionConfig>(&content).unwrap_or_default();
38                    JcliConfig { permissions }
39                }
40                Err(_) => Self::default(),
41            }
42        } else {
43            Self::default()
44        }
45    }
46
47    /// 从当前目录向上查找 .jcli/ 目录
48    pub fn find_config_dir() -> Option<PathBuf> {
49        let mut dir = std::env::current_dir().ok()?;
50        loop {
51            let candidate = dir.join(".jcli");
52            if candidate.is_dir() {
53                return Some(candidate);
54            }
55            if !dir.pop() {
56                return None;
57            }
58        }
59    }
60
61    /// 确保 cwd 下存在 .jcli/ 目录,返回该目录路径
62    pub fn ensure_config_dir() -> Option<PathBuf> {
63        let dir = std::env::current_dir().ok()?.join(".jcli");
64        let _ = std::fs::create_dir_all(&dir);
65
66        // 创建 hooks/example 目录和 HOOK.yaml.example 模板(仅在首次创建时)
67        let example_dir = dir.join("hooks").join("example");
68        if !example_dir.exists() {
69            let _ = std::fs::create_dir_all(&example_dir);
70            let example_yaml = example_dir.join("HOOK.yaml.example");
71            if !example_yaml.exists() {
72                const HOOK_YAML_EXAMPLE: &str = include_str!("../../assets/hook_yaml_example.yaml");
73                let _ = std::fs::write(&example_yaml, HOOK_YAML_EXAMPLE);
74            }
75        }
76
77        Some(dir)
78    }
79
80    /// 检查某个工具调用是否被自动允许(跳过确认)
81    ///
82    /// - tool_name: "Bash", "Read", "Write" 等
83    /// - arguments: JSON 字符串(用于提取 command/path 等)
84    ///
85    /// 返回 true 表示该调用无需用户确认
86    pub fn is_allowed(&self, tool_name: &str, arguments: &str) -> bool {
87        // 先检查 deny(deny 优先于 allow)
88        if self.is_denied(tool_name, arguments) {
89            return false;
90        }
91
92        // allow_all 模式
93        if self.permissions.allow_all {
94            return true;
95        }
96
97        // 逐条匹配 allow 列表
98        for rule in &self.permissions.allow {
99            if matches_rule(rule, tool_name, arguments) {
100                return true;
101            }
102        }
103
104        false
105    }
106
107    /// 检查是否被 deny 列表拦截(deny 中匹配则直接拒绝执行)
108    pub fn is_denied(&self, tool_name: &str, arguments: &str) -> bool {
109        for rule in &self.permissions.deny {
110            if matches_rule(rule, tool_name, arguments) {
111                return true;
112            }
113        }
114        false
115    }
116
117    /// 将一条 allow 规则追加到 .jcli/permissions.yaml(若目录/文件不存在则创建)
118    /// 去重:如果 allow 列表已包含该规则则不重复添加
119    pub fn add_allow_rule(&mut self, rule: &str) {
120        // 去重
121        if self.permissions.allow.contains(&rule.to_string()) {
122            return;
123        }
124
125        // 更新内存
126        self.permissions.allow.push(rule.to_string());
127
128        // 确保 .jcli/ 目录存在
129        let config_dir = match Self::ensure_config_dir() {
130            Some(dir) => dir,
131            None => return,
132        };
133        let perm_path = config_dir.join("permissions.yaml");
134
135        // 如果文件已存在,尝试加载已有内容再追加
136        let mut permissions = if perm_path.is_file() {
137            match std::fs::read_to_string(&perm_path) {
138                Ok(content) => {
139                    serde_yaml::from_str::<PermissionConfig>(&content).unwrap_or_default()
140                }
141                Err(_) => PermissionConfig::default(),
142            }
143        } else {
144            PermissionConfig::default()
145        };
146
147        if !permissions.allow.contains(&rule.to_string()) {
148            permissions.allow.push(rule.to_string());
149        }
150
151        if let Ok(yaml) = serde_yaml::to_string(&permissions) {
152            let _ = std::fs::write(&perm_path, yaml);
153        }
154    }
155}
156
157/// 匹配单条规则
158///
159/// 支持的格式:
160/// - `"*"` → 匹配所有工具所有调用
161/// - `"Read"` → 匹配该工具所有调用(工具名不带括号)
162/// - `"Bash(cargo build:*)"` → Bash 命令前缀匹配
163/// - `"Write(path:/foo/bar/*)"` → 文件路径前缀匹配
164/// - `"WebFetch(domain:docs.rs)"` → URL 域名匹配
165fn matches_rule(rule: &str, tool_name: &str, arguments: &str) -> bool {
166    let rule = rule.trim();
167
168    // 通配符:匹配所有
169    if rule == "*" {
170        return true;
171    }
172
173    // 带括号的规则:ToolName(condition)
174    if let Some(paren_start) = rule.find('(') {
175        if !rule.ends_with(')') {
176            return false;
177        }
178        let rule_tool = &rule[..paren_start];
179        if rule_tool != tool_name {
180            return false;
181        }
182        let condition = &rule[paren_start + 1..rule.len() - 1];
183        return match_condition(tool_name, condition, arguments);
184    }
185
186    // 不带括号:纯工具名,匹配该工具所有调用
187    rule == tool_name
188}
189
190/// 匹配条件部分
191///
192/// - `"cargo build:*"` → Bash 命令前缀(取 arguments.command)
193/// - `"path:/foo/*"` → 文件路径前缀(取 arguments.file_path)
194/// - `"domain:docs.rs"` → URL 域名(取 arguments.url)
195/// - 支持 regex: `/pattern/` 语法,如 `"path:/\.rs$/"`, `"/^cargo (build|test)/"`
196/// - `"domain:/.*\.google\.com$/"` → regex 域名匹配
197fn match_condition(tool_name: &str, condition: &str, arguments: &str) -> bool {
198    let parsed: Value = match serde_json::from_str(arguments) {
199        Ok(v) => v,
200        Err(_) => return false,
201    };
202
203    // path: 前缀 → 文件路径匹配(Write, Edit, Read, Glob, Grep)
204    if let Some(path_pattern) = condition.strip_prefix("path:") {
205        let file_path = parsed
206            .get("file_path")
207            .or_else(|| parsed.get("path"))
208            .and_then(|v| v.as_str())
209            .unwrap_or("");
210        if is_regex_pattern(path_pattern) {
211            return match_regex(path_pattern, file_path);
212        }
213        return match_glob_prefix(path_pattern, file_path);
214    }
215
216    // domain: 前缀 → URL 域名匹配(WebFetch, WebSearch)
217    if let Some(domain) = condition.strip_prefix("domain:") {
218        let url = parsed.get("url").and_then(|v| v.as_str()).unwrap_or("");
219        if is_regex_pattern(domain) {
220            // 提取 host 后对 host 做 regex 匹配
221            let host = extract_host(url);
222            return match_regex(domain, &host);
223        }
224        return url_matches_domain(url, domain);
225    }
226
227    // ComputerUse: action 前缀匹配(格式 "action:screenshot:*" 或 "screenshot:*")
228    if tool_name == "ComputerUse" {
229        let action = parsed.get("action").and_then(|v| v.as_str()).unwrap_or("");
230        // 支持 "action:screenshot:*" 和 "screenshot:*" 两种格式
231        let action_pattern = if let Some(rest) = condition.strip_prefix("action:") {
232            rest
233        } else {
234            condition
235        };
236        if is_regex_pattern(action_pattern) {
237            return match_regex(action_pattern, action);
238        }
239        return match_command_prefix(action_pattern, action);
240    }
241
242    // 默认:Bash 命令前缀匹配(格式 "command_prefix:*")
243    if tool_name == "Bash" || tool_name == "Shell" {
244        let command = parsed.get("command").and_then(|v| v.as_str()).unwrap_or("");
245        if is_regex_pattern(condition) {
246            return match_regex(condition, command);
247        }
248        return match_command_prefix(condition, command);
249    }
250
251    false
252}
253
254// ========== Regex 辅助函数 ==========
255
256/// 全局 regex 编译缓存(避免重复编译同一正则表达式)
257static REGEX_CACHE: LazyLock<Mutex<HashMap<String, Regex>>> =
258    LazyLock::new(|| Mutex::new(HashMap::new()));
259
260/// 判断是否为 `/pattern/` 格式的 regex 模式
261fn is_regex_pattern(pattern: &str) -> bool {
262    pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() >= 2
263}
264
265/// 用 `/pattern/` 匹配 input,带编译缓存
266/// 返回 regex 是否匹配(`is_match` 语义)
267fn match_regex(pattern: &str, input: &str) -> bool {
268    // 去掉首尾的 /
269    let regex_str = &pattern[1..pattern.len() - 1];
270    if regex_str.is_empty() {
271        return false;
272    }
273
274    let mut cache = match REGEX_CACHE.lock() {
275        Ok(c) => c,
276        Err(poisoned) => poisoned.into_inner(),
277    };
278
279    let re = cache
280        .entry(regex_str.to_string())
281        .or_insert_with(|| match Regex::new(regex_str) {
282            Ok(r) => r,
283            // SAFETY: "^$" 是合法正则模式,此处 unwrap 永不触发 panic
284            Err(_) => Regex::new("^$").unwrap_or_else(|_| unreachable!("^$ 是合法正则")),
285        });
286
287    re.is_match(input)
288}
289
290/// 从 URL 中提取 host 部分
291fn extract_host(url: &str) -> String {
292    let url_lower = url.to_lowercase();
293    let after_scheme = if let Some(pos) = url_lower.find("://") {
294        &url_lower[pos + 3..]
295    } else {
296        &url_lower
297    };
298    after_scheme
299        .split('/')
300        .next()
301        .unwrap_or("")
302        .split(':')
303        .next()
304        .unwrap_or("")
305        .to_string()
306}
307
308/// Bash 命令前缀匹配
309///
310/// 规则格式:`"cargo build:*"` 表示命令以 "cargo build" 开头
311/// 也支持 `"ls"` 不带 `:*` 后缀的精确前缀匹配
312fn match_command_prefix(pattern: &str, command: &str) -> bool {
313    // 去掉尾部的 `:*` 通配符
314    let prefix = pattern.strip_suffix(":*").unwrap_or(pattern).trim();
315
316    let command = command.trim();
317    // 前缀匹配:命令以 prefix 开头,后续要么是空、要么是空格/参数
318    if command == prefix {
319        return true;
320    }
321    if let Some(rest) = command.strip_prefix(prefix) {
322        return rest.starts_with(' ') || rest.starts_with('\t');
323    }
324    false
325}
326
327/// 简单的 glob 前缀匹配(支持尾部 `*` 通配符)
328///
329/// - `/foo/bar/*` → 匹配 /foo/bar/ 下的所有文件
330/// - `/foo/bar/baz.rs` → 精确匹配
331fn match_glob_prefix(pattern: &str, path: &str) -> bool {
332    if pattern == "*" {
333        return true;
334    }
335    if let Some(prefix) = pattern.strip_suffix('*') {
336        return path.starts_with(prefix);
337    }
338    path == pattern
339}
340
341/// URL 域名匹配
342fn url_matches_domain(url: &str, domain: &str) -> bool {
343    let host = extract_host(url);
344    let domain_lower = domain.to_lowercase();
345
346    host == domain_lower || host.ends_with(&format!(".{}", domain_lower))
347}
348
349/// 根据工具名和参数生成对应的 allow 规则
350///
351/// - Bash: 提取 command 字段的前两个词(如 `cargo build --release` → `Bash(cargo build:*)`)
352/// - Write/Edit: 提取 file_path 所在目录 → `Write(path:/dir/*)`
353/// - WebFetch: 提取 url 域名 → `WebFetch(domain:xxx)`
354/// - 其他工具: 直接用工具名 → `"Read"`
355pub fn generate_allow_rule(tool_name: &str, arguments: &str) -> String {
356    let parsed: Value = serde_json::from_str(arguments).unwrap_or(Value::Null);
357
358    match tool_name {
359        "ComputerUse" => {
360            let action = parsed.get("action").and_then(|v| v.as_str()).unwrap_or("");
361            if !action.is_empty() {
362                format!("ComputerUse({}:*)", action)
363            } else {
364                "ComputerUse".to_string()
365            }
366        }
367        "Bash" | "Shell" => {
368            let command = parsed.get("command").and_then(|v| v.as_str()).unwrap_or("");
369            let words: Vec<&str> = command.split_whitespace().collect();
370            let prefix = if words.len() >= 2 {
371                format!("{} {}", words[0], words[1])
372            } else if words.len() == 1 {
373                words[0].to_string()
374            } else {
375                return tool_name.to_string();
376            };
377            format!("{}({}:*)", tool_name, prefix)
378        }
379        "Write" | "Edit" => {
380            let file_path = parsed
381                .get("file_path")
382                .and_then(|v| v.as_str())
383                .unwrap_or("");
384            if let Some(dir) = std::path::Path::new(file_path).parent() {
385                format!("{}(path:{}/*)", tool_name, dir.display())
386            } else {
387                tool_name.to_string()
388            }
389        }
390        "WebFetch" => {
391            let url = parsed.get("url").and_then(|v| v.as_str()).unwrap_or("");
392            let host = extract_host(url);
393            if !host.is_empty() {
394                format!("WebFetch(domain:{})", host)
395            } else {
396                "WebFetch".to_string()
397            }
398        }
399        _ => tool_name.to_string(),
400    }
401}
402
403#[cfg(test)]
404mod tests;