Skip to main content

j_cli/command/chat/
tools.rs

1use async_openai::types::chat::{ChatCompletionTool, ChatCompletionTools, FunctionObject};
2use serde_json::{Value, json};
3
4use super::skill::Skill;
5
/// Expands a leading `~` (exactly `~` or a `~/` prefix) to the user's home directory.
///
/// The home directory is read from the `HOME` environment variable. If `HOME`
/// is unset *or empty*, the path is returned unchanged. An empty `HOME` would
/// otherwise turn `~/x` into `/x` and silently redirect reads and writes to the
/// filesystem root. `~user/...` forms are not expanded.
fn expand_tilde(path: &str) -> String {
    let home = || std::env::var("HOME").ok().filter(|h| !h.is_empty());
    if path == "~" {
        home().unwrap_or_else(|| path.to_string())
    } else if let Some(rest) = path.strip_prefix("~/") {
        match home() {
            // Trim a trailing '/' so `HOME=/home/u/` does not produce `//`.
            // `HOME=/` trims to "" and still yields an absolute "/rest".
            Some(home) => format!("{}/{}", home.trim_end_matches('/'), rest),
            None => path.to_string(),
        }
    } else {
        path.to_string()
    }
}
19
/// Outcome of a single tool invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// Text handed back to the LLM (command output, file content, or an error description).
    pub output: String,
    /// Whether the execution failed.
    pub is_error: bool,
}
27
28/// 工具 trait
29pub trait Tool: Send + Sync {
30    fn name(&self) -> &str;
31    fn description(&self) -> &str;
32    fn parameters_schema(&self) -> Value;
33    /// 执行工具(同步)
34    fn execute(&self, arguments: &str) -> ToolResult;
35    /// 是否需要用户确认(shell 命令需要,文件读取不需要)
36    fn requires_confirmation(&self) -> bool {
37        false
38    }
39    /// 生成确认提示文字(供 TUI 展示)
40    fn confirmation_message(&self, arguments: &str) -> String {
41        format!("调用工具 {} 参数: {}", self.name(), arguments)
42    }
43}
44
45// ========== run_shell ==========
46
/// Tool that executes a shell command (via `bash -c`) and returns its output.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShellTool;
49
/// Best-effort deny-list check for obviously destructive shell commands.
///
/// Matching is case-insensitive and insensitive to runs of whitespace. It also
/// detects a `curl`/`wget` stage piped into a shell interpreter
/// (e.g. `curl -fsSL URL | sudo bash`). This is a heuristic, not a sandbox:
/// quoting, variables or indirection can still get around it.
fn is_dangerous_command(cmd: &str) -> bool {
    // Patterns MUST be lowercase: they are matched against the lowercased command.
    // (Previously "chmod -R 777 /", "chown -R" and "wget -O- | sh" contained
    // uppercase letters and therefore could never match.)
    const DANGEROUS_PATTERNS: [&str; 12] = [
        "rm -rf /",
        "rm -rf /*",
        "mkfs",
        "dd if=",
        ":(){:|:&};:",
        "chmod -r 777 /",
        "chown -r",
        "> /dev/sda",
        "wget -o- | sh",
        "curl | sh",
        "alias",
        "curl | bash",
    ];

    // Collapse whitespace so `rm   -rf  /` (or tabs/newlines) cannot slip past.
    let normalized = cmd
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_lowercase();
    if DANGEROUS_PATTERNS.iter().any(|pat| normalized.contains(pat)) {
        return true;
    }

    // The classic fork bomb is often written with spaces, e.g. `:(){ :|:& };:`.
    let compact: String = cmd.chars().filter(|c| !c.is_whitespace()).collect();
    if compact.contains(":(){:|:&};:") {
        return true;
    }

    pipes_download_into_shell(&normalized)
}

/// Returns `true` when a `curl`/`wget` pipeline stage feeds a later stage whose
/// program is a shell (`sh`, `bash`, `zsh`, `dash`, `ksh`, optionally via `sudo`).
fn pipes_download_into_shell(cmd: &str) -> bool {
    /// Basename of a command word, so `/usr/bin/bash` is treated as `bash`.
    fn program(word: &str) -> &str {
        word.rsplit('/').next().unwrap_or(word)
    }

    // `||` is a logical OR, not a pipe; neutralise it before splitting on `|`.
    let cmd = cmd.replace("||", ";");
    let stages: Vec<&str> = cmd.split('|').collect();
    let fetch_at = match stages.iter().position(|stage| {
        stage
            .split_whitespace()
            .any(|w| matches!(program(w), "curl" | "wget"))
    }) {
        Some(index) => index,
        None => return false,
    };

    stages[fetch_at + 1..].iter().any(|stage| {
        let mut words = stage.split_whitespace().map(program);
        let mut first = words.next();
        if first == Some("sudo") {
            first = words.next();
        }
        matches!(first, Some("sh" | "bash" | "zsh" | "dash" | "ksh"))
    })
}
74
75impl Tool for ShellTool {
76    fn name(&self) -> &str {
77        "run_shell"
78    }
79
80    fn description(&self) -> &str {
81        "在当前系统上执行 shell 命令,返回命令的 stdout 和 stderr 输出;注意每次调用 run_shell 都会创建一个新的进程,状态是不延续的"
82    }
83
84    fn parameters_schema(&self) -> Value {
85        json!({
86            "type": "object",
87            "properties": {
88                "command": {
89                    "type": "string",
90                    "description": "要执行的 shell 命令(在 bash 中执行)"
91                }
92            },
93            "required": ["command"]
94        })
95    }
96
97    fn execute(&self, arguments: &str) -> ToolResult {
98        let command = match serde_json::from_str::<Value>(arguments) {
99            Ok(v) => match v.get("command").and_then(|c| c.as_str()) {
100                Some(cmd) => cmd.to_string(),
101                None => {
102                    return ToolResult {
103                        output: "参数缺少 command 字段".to_string(),
104                        is_error: true,
105                    };
106                }
107            },
108            Err(e) => {
109                return ToolResult {
110                    output: format!("参数解析失败: {}", e),
111                    is_error: true,
112                };
113            }
114        };
115
116        // 安全过滤
117        if is_dangerous_command(&command) {
118            return ToolResult {
119                output: "该命令被安全策略拒绝执行".to_string(),
120                is_error: true,
121            };
122        }
123
124        match std::process::Command::new("bash")
125            .arg("-c")
126            .arg(&command)
127            .output()
128        {
129            Ok(output) => {
130                let mut result = String::new();
131                let stdout = String::from_utf8_lossy(&output.stdout);
132                let stderr = String::from_utf8_lossy(&output.stderr);
133
134                if !stdout.is_empty() {
135                    result.push_str(&stdout);
136                }
137                if !stderr.is_empty() {
138                    if !result.is_empty() {
139                        result.push_str("\n[stderr]\n");
140                    } else {
141                        result.push_str("[stderr]\n");
142                    }
143                    result.push_str(&stderr);
144                }
145
146                if result.is_empty() {
147                    result = "(无输出)".to_string();
148                }
149
150                // 截断到 4000 字节
151                const MAX_BYTES: usize = 4000;
152                let truncated = if result.len() > MAX_BYTES {
153                    let mut end = MAX_BYTES;
154                    while !result.is_char_boundary(end) {
155                        end -= 1;
156                    }
157                    format!("{}\n...(输出已截断)", &result[..end])
158                } else {
159                    result
160                };
161
162                let is_error = !output.status.success();
163                ToolResult {
164                    output: truncated,
165                    is_error,
166                }
167            }
168            Err(e) => ToolResult {
169                output: format!("执行失败: {}", e),
170                is_error: true,
171            },
172        }
173    }
174
175    fn requires_confirmation(&self) -> bool {
176        true
177    }
178
179    fn confirmation_message(&self, arguments: &str) -> String {
180        // 尝试解析 command 字段
181        let cmd = serde_json::from_str::<Value>(arguments)
182            .ok()
183            .and_then(|v| {
184                v.get("command")
185                    .and_then(|c| c.as_str())
186                    .map(|s| s.to_string())
187            })
188            .unwrap_or_else(|| arguments.to_string());
189        format!("即将执行: {}", cmd)
190    }
191}
192
193// ========== read_file ==========
194
195/// 读取文件的工具
196pub struct ReadFileTool;
197
198impl Tool for ReadFileTool {
199    fn name(&self) -> &str {
200        "read_file"
201    }
202
203    fn description(&self) -> &str {
204        "读取本地文件内容并返回(带行号)。支持通过 offset 和 limit 参数按行范围读取。"
205    }
206
207    fn parameters_schema(&self) -> Value {
208        json!({
209            "type": "object",
210            "properties": {
211                "path": {
212                    "type": "string",
213                    "description": "要读取的文件路径(绝对路径或相对于当前工作目录)"
214                },
215                "offset": {
216                    "type": "integer",
217                    "description": "从第几行开始读取(0-based,即 0 表示第 1 行),不传则从头开始"
218                },
219                "limit": {
220                    "type": "integer",
221                    "description": "读取多少行,不传则读到文件末尾"
222                }
223            },
224            "required": ["path"]
225        })
226    }
227
228    fn execute(&self, arguments: &str) -> ToolResult {
229        let v = match serde_json::from_str::<Value>(arguments) {
230            Ok(v) => v,
231            Err(e) => {
232                return ToolResult {
233                    output: format!("参数解析失败: {}", e),
234                    is_error: true,
235                };
236            }
237        };
238
239        let path = match v.get("path").and_then(|c| c.as_str()) {
240            Some(p) => expand_tilde(p),
241            None => {
242                return ToolResult {
243                    output: "参数缺少 path 字段".to_string(),
244                    is_error: true,
245                };
246            }
247        };
248
249        let offset = v.get("offset").and_then(|o| o.as_u64()).map(|o| o as usize);
250        let limit = v.get("limit").and_then(|l| l.as_u64()).map(|l| l as usize);
251
252        match std::fs::read_to_string(&path) {
253            Ok(content) => {
254                let lines: Vec<&str> = content.lines().collect();
255                let total = lines.len();
256                let start = offset.unwrap_or(0).min(total);
257                let count = limit.unwrap_or(total - start).min(total - start);
258                let selected: Vec<String> = lines[start..start + count]
259                    .iter()
260                    .enumerate()
261                    .map(|(i, line)| format!("{:>4}│ {}", start + i + 1, line))
262                    .collect();
263                let mut result = selected.join("\n");
264
265                if start + count < total {
266                    result.push_str(&format!("\n...(还有 {} 行未显示)", total - start - count));
267                }
268
269                // 截断到 8000 字节
270                const MAX_BYTES: usize = 8000;
271                let truncated = if result.len() > MAX_BYTES {
272                    let mut end = MAX_BYTES;
273                    while !result.is_char_boundary(end) {
274                        end -= 1;
275                    }
276                    format!("{}\n...(文件内容已截断)", &result[..end])
277                } else {
278                    result
279                };
280                ToolResult {
281                    output: truncated,
282                    is_error: false,
283                }
284            }
285            Err(e) => ToolResult {
286                output: format!("读取文件失败: {}", e),
287                is_error: true,
288            },
289        }
290    }
291
292    fn requires_confirmation(&self) -> bool {
293        false
294    }
295}
296
297// ========== write_file ==========
298
299/// 写入文件的工具
300pub struct WriteFileTool;
301
302impl Tool for WriteFileTool {
303    fn name(&self) -> &str {
304        "write_file"
305    }
306
307    fn description(&self) -> &str {
308        "将内容写入指定文件。如果文件已存在则覆盖,如果目录不存在会自动创建。"
309    }
310
311    fn parameters_schema(&self) -> Value {
312        json!({
313            "type": "object",
314            "properties": {
315                "path": {
316                    "type": "string",
317                    "description": "要写入的文件路径(绝对路径或相对于当前工作目录)"
318                },
319                "content": {
320                    "type": "string",
321                    "description": "要写入的文件内容"
322                }
323            },
324            "required": ["path", "content"]
325        })
326    }
327
328    fn execute(&self, arguments: &str) -> ToolResult {
329        let v = match serde_json::from_str::<Value>(arguments) {
330            Ok(v) => v,
331            Err(e) => {
332                return ToolResult {
333                    output: format!("参数解析失败: {}", e),
334                    is_error: true,
335                };
336            }
337        };
338
339        let path = match v.get("path").and_then(|c| c.as_str()) {
340            Some(p) => expand_tilde(p),
341            None => {
342                return ToolResult {
343                    output: "参数缺少 path 字段".to_string(),
344                    is_error: true,
345                };
346            }
347        };
348
349        let content = match v.get("content").and_then(|c| c.as_str()) {
350            Some(c) => c.to_string(),
351            None => {
352                return ToolResult {
353                    output: "参数缺少 content 字段".to_string(),
354                    is_error: true,
355                };
356            }
357        };
358
359        // 自动创建父目录
360        let file_path = std::path::Path::new(&path);
361        if let Some(parent) = file_path.parent() {
362            if !parent.exists() {
363                if let Err(e) = std::fs::create_dir_all(parent) {
364                    return ToolResult {
365                        output: format!("创建目录失败: {}", e),
366                        is_error: true,
367                    };
368                }
369            }
370        }
371
372        match std::fs::write(&path, &content) {
373            Ok(_) => ToolResult {
374                output: format!("已写入文件: {} ({} 字节)", path, content.len()),
375                is_error: false,
376            },
377            Err(e) => ToolResult {
378                output: format!("写入文件失败: {}", e),
379                is_error: true,
380            },
381        }
382    }
383
384    fn requires_confirmation(&self) -> bool {
385        true
386    }
387
388    fn confirmation_message(&self, arguments: &str) -> String {
389        let path = serde_json::from_str::<Value>(arguments)
390            .ok()
391            .and_then(|v| {
392                v.get("path")
393                    .and_then(|c| c.as_str())
394                    .map(|s| expand_tilde(s))
395            })
396            .unwrap_or_else(|| "未知路径".to_string());
397        format!("即将写入文件: {}", path)
398    }
399}
400
401// ========== edit_file ==========
402
403/// 编辑文件的工具(基于字符串替换)
404pub struct EditFileTool;
405
406impl Tool for EditFileTool {
407    fn name(&self) -> &str {
408        "edit_file"
409    }
410
411    fn description(&self) -> &str {
412        "通过精确字符串匹配替换来编辑文件。old_string 必须在文件中唯一匹配,替换为 new_string。如果 new_string 为空字符串则表示删除匹配内容。"
413    }
414
415    fn parameters_schema(&self) -> Value {
416        json!({
417            "type": "object",
418            "properties": {
419                "path": {
420                    "type": "string",
421                    "description": "要编辑的文件路径"
422                },
423                "old_string": {
424                    "type": "string",
425                    "description": "要被替换的原始字符串(必须在文件中唯一存在)"
426                },
427                "new_string": {
428                    "type": "string",
429                    "description": "替换后的新字符串,为空则表示删除"
430                }
431            },
432            "required": ["path", "old_string", "new_string"]
433        })
434    }
435
436    fn execute(&self, arguments: &str) -> ToolResult {
437        let v = match serde_json::from_str::<Value>(arguments) {
438            Ok(v) => v,
439            Err(e) => {
440                return ToolResult {
441                    output: format!("参数解析失败: {}", e),
442                    is_error: true,
443                };
444            }
445        };
446
447        let path = match v.get("path").and_then(|c| c.as_str()) {
448            Some(p) => expand_tilde(p),
449            None => {
450                return ToolResult {
451                    output: "参数缺少 path 字段".to_string(),
452                    is_error: true,
453                };
454            }
455        };
456
457        let old_string = match v.get("old_string").and_then(|c| c.as_str()) {
458            Some(s) => s.to_string(),
459            None => {
460                return ToolResult {
461                    output: "参数缺少 old_string 字段".to_string(),
462                    is_error: true,
463                };
464            }
465        };
466
467        let new_string = v
468            .get("new_string")
469            .and_then(|c| c.as_str())
470            .unwrap_or("")
471            .to_string();
472
473        // 读取文件
474        let content = match std::fs::read_to_string(&path) {
475            Ok(c) => c,
476            Err(e) => {
477                return ToolResult {
478                    output: format!("读取文件失败: {}", e),
479                    is_error: true,
480                };
481            }
482        };
483
484        // 检查匹配次数
485        let count = content.matches(&old_string).count();
486        if count == 0 {
487            return ToolResult {
488                output: "未找到匹配的字符串".to_string(),
489                is_error: true,
490            };
491        }
492        if count > 1 {
493            return ToolResult {
494                output: format!(
495                    "old_string 在文件中匹配了 {} 次,必须唯一匹配。请提供更多上下文使其唯一",
496                    count
497                ),
498                is_error: true,
499            };
500        }
501
502        // 执行替换
503        let new_content = content.replacen(&old_string, &new_string, 1);
504        match std::fs::write(&path, &new_content) {
505            Ok(_) => ToolResult {
506                output: format!("已编辑文件: {}", path),
507                is_error: false,
508            },
509            Err(e) => ToolResult {
510                output: format!("写入文件失败: {}", e),
511                is_error: true,
512            },
513        }
514    }
515
516    fn requires_confirmation(&self) -> bool {
517        true
518    }
519
520    fn confirmation_message(&self, arguments: &str) -> String {
521        let v = serde_json::from_str::<Value>(arguments).ok();
522        let path = v
523            .as_ref()
524            .and_then(|v| {
525                v.get("path")
526                    .and_then(|c| c.as_str())
527                    .map(|s| expand_tilde(s))
528            })
529            .unwrap_or_else(|| "未知路径".to_string());
530        let old = v
531            .as_ref()
532            .and_then(|v| v.get("old_string").and_then(|c| c.as_str()))
533            .unwrap_or("");
534        let first_line = old.lines().next().unwrap_or("");
535        let has_more = old.lines().count() > 1;
536        let preview = if has_more {
537            format!("{}...", first_line)
538        } else {
539            first_line.to_string()
540        };
541        format!("即将编辑文件 {} (替换: \"{}\")", path, preview)
542    }
543}
544
545// ========== ToolRegistry ==========
546
547/// 工具注册表
548pub struct ToolRegistry {
549    tools: Vec<Box<dyn Tool>>,
550}
551
552impl ToolRegistry {
553    /// 创建注册表(包含 run_shell、read_file、write_file、edit_file,以及当 skills 非空时注册 load_skill)
554    pub fn new(skills: Vec<Skill>) -> Self {
555        let mut registry = Self {
556            tools: vec![
557                Box::new(ShellTool),
558                Box::new(ReadFileTool),
559                Box::new(WriteFileTool),
560                Box::new(EditFileTool),
561            ],
562        };
563
564        // 如果有 skills,注册统一的 LoadSkillTool
565        if !skills.is_empty() {
566            registry.register(Box::new(super::skill::LoadSkillTool { skills }));
567        }
568
569        registry
570    }
571
572    /// 注册一个工具
573    pub fn register(&mut self, tool: Box<dyn Tool>) {
574        self.tools.push(tool);
575    }
576
577    /// 按名称获取工具
578    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
579        self.tools
580            .iter()
581            .find(|t| t.name() == name)
582            .map(|t| t.as_ref())
583    }
584
585    /// 构建工具摘要列表,用于系统提示词的 {{.tools}} 占位符
586    pub fn build_tools_summary(&self) -> String {
587        self.tools
588            .iter()
589            .map(|t| format!("- **{}**: {}", t.name(), t.description()))
590            .collect::<Vec<_>>()
591            .join("\n")
592    }
593
594    /// 生成 async-openai 的 ChatCompletionTools 列表
595    pub fn to_openai_tools(&self) -> Vec<ChatCompletionTools> {
596        self.tools
597            .iter()
598            .map(|t| {
599                ChatCompletionTools::Function(ChatCompletionTool {
600                    function: FunctionObject {
601                        name: t.name().to_string(),
602                        description: Some(t.description().to_string()),
603                        parameters: Some(t.parameters_schema()),
604                        strict: None,
605                    },
606                })
607            })
608            .collect()
609    }
610}