1use async_openai::types::chat::{ChatCompletionTool, ChatCompletionTools, FunctionObject};
2use serde_json::{Value, json};
3
4use super::skill::Skill;
5
/// Expands a leading `~` in `path` to the value of the `HOME` environment variable.
///
/// Only two forms are rewritten: a bare `~` and a `~/`-prefixed path. Any other
/// input (including `~user/...` or a `~` that is not at the start) is returned
/// unchanged. If `HOME` is unset or not valid Unicode, the input is returned as-is.
fn expand_tilde(path: &str) -> String {
    let home = || std::env::var("HOME").ok();
    match path.strip_prefix("~/") {
        Some(rest) => home().map_or_else(|| path.to_string(), |h| format!("{}/{}", h, rest)),
        None if path == "~" => home().unwrap_or_else(|| "~".to_string()),
        None => path.to_string(),
    }
}
19
/// Outcome of a single tool invocation, returned to the caller for the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// Text handed back: command output, file contents, or an error description.
    pub output: String,
    /// `true` when the invocation failed; `output` then describes the failure.
    pub is_error: bool,
}
27
28pub trait Tool: Send + Sync {
30 fn name(&self) -> &str;
31 fn description(&self) -> &str;
32 fn parameters_schema(&self) -> Value;
33 fn execute(&self, arguments: &str) -> ToolResult;
35 fn requires_confirmation(&self) -> bool {
37 false
38 }
39 fn confirmation_message(&self, arguments: &str) -> String {
41 format!("调用工具 {} 参数: {}", self.name(), arguments)
42 }
43}
44
/// Tool that runs a shell command through `bash -c` (exposed as `run_shell`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ShellTool;
49
/// Returns `true` if `cmd` looks like one of the blocklisted destructive commands.
///
/// This is a best-effort denylist, not a sandbox. Matching is a case-insensitive
/// substring search, so it can over-block (any command containing `alias`) and
/// can still be bypassed by quoting, variables, or other indirection.
///
/// Matching rules:
/// * The command is lowercased and runs of whitespace are collapsed to a single
///   space, so `rm  -rf /` or `rm\t-rf /` cannot slip past the literal patterns.
/// * The fork-bomb pattern is also checked with all whitespace removed, which
///   catches the common `:(){ :|:& };:` spelling.
/// * A pipeline that feeds `curl`/`wget` output into a shell interpreter is
///   rejected. The literal `curl | sh` patterns only ever matched commands with
///   no URL, which is never how that idiom is written.
fn is_dangerous_command(cmd: &str) -> bool {
    // All patterns MUST be lowercase: they are matched against the lowercased
    // command. (Previously `chmod -R`, `chown -R` and `wget -O-` contained
    // uppercase letters and therefore could never match.)
    const DANGEROUS_PATTERNS: [&str; 12] = [
        "rm -rf /",
        "rm -rf /*",
        "mkfs",
        "dd if=",
        ":(){:|:&};:",
        "chmod -r 777 /",
        "chown -r",
        "> /dev/sda",
        "wget -o- | sh",
        "curl | sh",
        "alias",
        "curl | bash",
    ];
    const FORK_BOMB: &str = ":(){:|:&};:";

    /// `true` if some pipeline stage starts a shell after an earlier stage ran curl/wget.
    fn pipes_download_into_shell(normalized: &str) -> bool {
        let mut downloaded = false;
        for stage in normalized.split('|') {
            let mut words = stage.split_whitespace();
            let first = words.next().unwrap_or("");
            // Look through a leading `sudo` so `curl x | sudo sh` is caught too.
            let program = if first == "sudo" {
                words.next().unwrap_or("")
            } else {
                first
            };
            // Reduce `/bin/sh` and similar to the bare program name.
            let program = program.rsplit('/').next().unwrap_or(program);
            if downloaded && matches!(program, "sh" | "bash" | "zsh" | "dash" | "ksh") {
                return true;
            }
            if stage.split_whitespace().any(|w| {
                w == "curl" || w == "wget" || w.ends_with("/curl") || w.ends_with("/wget")
            }) {
                downloaded = true;
            }
        }
        false
    }

    let lowered = cmd.to_lowercase();
    let normalized = lowered.split_whitespace().collect::<Vec<_>>().join(" ");
    if DANGEROUS_PATTERNS.iter().any(|pat| normalized.contains(pat)) {
        return true;
    }

    let compact: String = lowered.chars().filter(|c| !c.is_whitespace()).collect();
    if compact.contains(FORK_BOMB) {
        return true;
    }

    pipes_download_into_shell(&normalized)
}
74
75impl Tool for ShellTool {
76 fn name(&self) -> &str {
77 "run_shell"
78 }
79
80 fn description(&self) -> &str {
81 "在当前系统上执行 shell 命令,返回命令的 stdout 和 stderr 输出;注意每次调用 run_shell 都会创建一个新的进程,状态是不延续的"
82 }
83
84 fn parameters_schema(&self) -> Value {
85 json!({
86 "type": "object",
87 "properties": {
88 "command": {
89 "type": "string",
90 "description": "要执行的 shell 命令(在 bash 中执行)"
91 }
92 },
93 "required": ["command"]
94 })
95 }
96
97 fn execute(&self, arguments: &str) -> ToolResult {
98 let command = match serde_json::from_str::<Value>(arguments) {
99 Ok(v) => match v.get("command").and_then(|c| c.as_str()) {
100 Some(cmd) => cmd.to_string(),
101 None => {
102 return ToolResult {
103 output: "参数缺少 command 字段".to_string(),
104 is_error: true,
105 };
106 }
107 },
108 Err(e) => {
109 return ToolResult {
110 output: format!("参数解析失败: {}", e),
111 is_error: true,
112 };
113 }
114 };
115
116 if is_dangerous_command(&command) {
118 return ToolResult {
119 output: "该命令被安全策略拒绝执行".to_string(),
120 is_error: true,
121 };
122 }
123
124 match std::process::Command::new("bash")
125 .arg("-c")
126 .arg(&command)
127 .output()
128 {
129 Ok(output) => {
130 let mut result = String::new();
131 let stdout = String::from_utf8_lossy(&output.stdout);
132 let stderr = String::from_utf8_lossy(&output.stderr);
133
134 if !stdout.is_empty() {
135 result.push_str(&stdout);
136 }
137 if !stderr.is_empty() {
138 if !result.is_empty() {
139 result.push_str("\n[stderr]\n");
140 } else {
141 result.push_str("[stderr]\n");
142 }
143 result.push_str(&stderr);
144 }
145
146 if result.is_empty() {
147 result = "(无输出)".to_string();
148 }
149
150 const MAX_BYTES: usize = 4000;
152 let truncated = if result.len() > MAX_BYTES {
153 let mut end = MAX_BYTES;
154 while !result.is_char_boundary(end) {
155 end -= 1;
156 }
157 format!("{}\n...(输出已截断)", &result[..end])
158 } else {
159 result
160 };
161
162 let is_error = !output.status.success();
163 ToolResult {
164 output: truncated,
165 is_error,
166 }
167 }
168 Err(e) => ToolResult {
169 output: format!("执行失败: {}", e),
170 is_error: true,
171 },
172 }
173 }
174
175 fn requires_confirmation(&self) -> bool {
176 true
177 }
178
179 fn confirmation_message(&self, arguments: &str) -> String {
180 let cmd = serde_json::from_str::<Value>(arguments)
182 .ok()
183 .and_then(|v| {
184 v.get("command")
185 .and_then(|c| c.as_str())
186 .map(|s| s.to_string())
187 })
188 .unwrap_or_else(|| arguments.to_string());
189 format!("即将执行: {}", cmd)
190 }
191}
192
/// Tool that reads a local text file and returns it with line numbers (`read_file`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ReadFileTool;
197
198impl Tool for ReadFileTool {
199 fn name(&self) -> &str {
200 "read_file"
201 }
202
203 fn description(&self) -> &str {
204 "读取本地文件内容并返回(带行号)。支持通过 offset 和 limit 参数按行范围读取。"
205 }
206
207 fn parameters_schema(&self) -> Value {
208 json!({
209 "type": "object",
210 "properties": {
211 "path": {
212 "type": "string",
213 "description": "要读取的文件路径(绝对路径或相对于当前工作目录)"
214 },
215 "offset": {
216 "type": "integer",
217 "description": "从第几行开始读取(0-based,即 0 表示第 1 行),不传则从头开始"
218 },
219 "limit": {
220 "type": "integer",
221 "description": "读取多少行,不传则读到文件末尾"
222 }
223 },
224 "required": ["path"]
225 })
226 }
227
228 fn execute(&self, arguments: &str) -> ToolResult {
229 let v = match serde_json::from_str::<Value>(arguments) {
230 Ok(v) => v,
231 Err(e) => {
232 return ToolResult {
233 output: format!("参数解析失败: {}", e),
234 is_error: true,
235 };
236 }
237 };
238
239 let path = match v.get("path").and_then(|c| c.as_str()) {
240 Some(p) => expand_tilde(p),
241 None => {
242 return ToolResult {
243 output: "参数缺少 path 字段".to_string(),
244 is_error: true,
245 };
246 }
247 };
248
249 let offset = v.get("offset").and_then(|o| o.as_u64()).map(|o| o as usize);
250 let limit = v.get("limit").and_then(|l| l.as_u64()).map(|l| l as usize);
251
252 match std::fs::read_to_string(&path) {
253 Ok(content) => {
254 let lines: Vec<&str> = content.lines().collect();
255 let total = lines.len();
256 let start = offset.unwrap_or(0).min(total);
257 let count = limit.unwrap_or(total - start).min(total - start);
258 let selected: Vec<String> = lines[start..start + count]
259 .iter()
260 .enumerate()
261 .map(|(i, line)| format!("{:>4}│ {}", start + i + 1, line))
262 .collect();
263 let mut result = selected.join("\n");
264
265 if start + count < total {
266 result.push_str(&format!("\n...(还有 {} 行未显示)", total - start - count));
267 }
268
269 const MAX_BYTES: usize = 8000;
271 let truncated = if result.len() > MAX_BYTES {
272 let mut end = MAX_BYTES;
273 while !result.is_char_boundary(end) {
274 end -= 1;
275 }
276 format!("{}\n...(文件内容已截断)", &result[..end])
277 } else {
278 result
279 };
280 ToolResult {
281 output: truncated,
282 is_error: false,
283 }
284 }
285 Err(e) => ToolResult {
286 output: format!("读取文件失败: {}", e),
287 is_error: true,
288 },
289 }
290 }
291
292 fn requires_confirmation(&self) -> bool {
293 false
294 }
295}
296
/// Tool that writes (creating or overwriting) a local file (`write_file`).
#[derive(Debug, Clone, Copy, Default)]
pub struct WriteFileTool;
301
302impl Tool for WriteFileTool {
303 fn name(&self) -> &str {
304 "write_file"
305 }
306
307 fn description(&self) -> &str {
308 "将内容写入指定文件。如果文件已存在则覆盖,如果目录不存在会自动创建。"
309 }
310
311 fn parameters_schema(&self) -> Value {
312 json!({
313 "type": "object",
314 "properties": {
315 "path": {
316 "type": "string",
317 "description": "要写入的文件路径(绝对路径或相对于当前工作目录)"
318 },
319 "content": {
320 "type": "string",
321 "description": "要写入的文件内容"
322 }
323 },
324 "required": ["path", "content"]
325 })
326 }
327
328 fn execute(&self, arguments: &str) -> ToolResult {
329 let v = match serde_json::from_str::<Value>(arguments) {
330 Ok(v) => v,
331 Err(e) => {
332 return ToolResult {
333 output: format!("参数解析失败: {}", e),
334 is_error: true,
335 };
336 }
337 };
338
339 let path = match v.get("path").and_then(|c| c.as_str()) {
340 Some(p) => expand_tilde(p),
341 None => {
342 return ToolResult {
343 output: "参数缺少 path 字段".to_string(),
344 is_error: true,
345 };
346 }
347 };
348
349 let content = match v.get("content").and_then(|c| c.as_str()) {
350 Some(c) => c.to_string(),
351 None => {
352 return ToolResult {
353 output: "参数缺少 content 字段".to_string(),
354 is_error: true,
355 };
356 }
357 };
358
359 let file_path = std::path::Path::new(&path);
361 if let Some(parent) = file_path.parent() {
362 if !parent.exists() {
363 if let Err(e) = std::fs::create_dir_all(parent) {
364 return ToolResult {
365 output: format!("创建目录失败: {}", e),
366 is_error: true,
367 };
368 }
369 }
370 }
371
372 match std::fs::write(&path, &content) {
373 Ok(_) => ToolResult {
374 output: format!("已写入文件: {} ({} 字节)", path, content.len()),
375 is_error: false,
376 },
377 Err(e) => ToolResult {
378 output: format!("写入文件失败: {}", e),
379 is_error: true,
380 },
381 }
382 }
383
384 fn requires_confirmation(&self) -> bool {
385 true
386 }
387
388 fn confirmation_message(&self, arguments: &str) -> String {
389 let path = serde_json::from_str::<Value>(arguments)
390 .ok()
391 .and_then(|v| {
392 v.get("path")
393 .and_then(|c| c.as_str())
394 .map(|s| expand_tilde(s))
395 })
396 .unwrap_or_else(|| "未知路径".to_string());
397 format!("即将写入文件: {}", path)
398 }
399}
400
/// Tool that edits a file by replacing one unique exact substring (`edit_file`).
#[derive(Debug, Clone, Copy, Default)]
pub struct EditFileTool;
405
406impl Tool for EditFileTool {
407 fn name(&self) -> &str {
408 "edit_file"
409 }
410
411 fn description(&self) -> &str {
412 "通过精确字符串匹配替换来编辑文件。old_string 必须在文件中唯一匹配,替换为 new_string。如果 new_string 为空字符串则表示删除匹配内容。"
413 }
414
415 fn parameters_schema(&self) -> Value {
416 json!({
417 "type": "object",
418 "properties": {
419 "path": {
420 "type": "string",
421 "description": "要编辑的文件路径"
422 },
423 "old_string": {
424 "type": "string",
425 "description": "要被替换的原始字符串(必须在文件中唯一存在)"
426 },
427 "new_string": {
428 "type": "string",
429 "description": "替换后的新字符串,为空则表示删除"
430 }
431 },
432 "required": ["path", "old_string", "new_string"]
433 })
434 }
435
436 fn execute(&self, arguments: &str) -> ToolResult {
437 let v = match serde_json::from_str::<Value>(arguments) {
438 Ok(v) => v,
439 Err(e) => {
440 return ToolResult {
441 output: format!("参数解析失败: {}", e),
442 is_error: true,
443 };
444 }
445 };
446
447 let path = match v.get("path").and_then(|c| c.as_str()) {
448 Some(p) => expand_tilde(p),
449 None => {
450 return ToolResult {
451 output: "参数缺少 path 字段".to_string(),
452 is_error: true,
453 };
454 }
455 };
456
457 let old_string = match v.get("old_string").and_then(|c| c.as_str()) {
458 Some(s) => s.to_string(),
459 None => {
460 return ToolResult {
461 output: "参数缺少 old_string 字段".to_string(),
462 is_error: true,
463 };
464 }
465 };
466
467 let new_string = v
468 .get("new_string")
469 .and_then(|c| c.as_str())
470 .unwrap_or("")
471 .to_string();
472
473 let content = match std::fs::read_to_string(&path) {
475 Ok(c) => c,
476 Err(e) => {
477 return ToolResult {
478 output: format!("读取文件失败: {}", e),
479 is_error: true,
480 };
481 }
482 };
483
484 let count = content.matches(&old_string).count();
486 if count == 0 {
487 return ToolResult {
488 output: "未找到匹配的字符串".to_string(),
489 is_error: true,
490 };
491 }
492 if count > 1 {
493 return ToolResult {
494 output: format!(
495 "old_string 在文件中匹配了 {} 次,必须唯一匹配。请提供更多上下文使其唯一",
496 count
497 ),
498 is_error: true,
499 };
500 }
501
502 let new_content = content.replacen(&old_string, &new_string, 1);
504 match std::fs::write(&path, &new_content) {
505 Ok(_) => ToolResult {
506 output: format!("已编辑文件: {}", path),
507 is_error: false,
508 },
509 Err(e) => ToolResult {
510 output: format!("写入文件失败: {}", e),
511 is_error: true,
512 },
513 }
514 }
515
516 fn requires_confirmation(&self) -> bool {
517 true
518 }
519
520 fn confirmation_message(&self, arguments: &str) -> String {
521 let v = serde_json::from_str::<Value>(arguments).ok();
522 let path = v
523 .as_ref()
524 .and_then(|v| {
525 v.get("path")
526 .and_then(|c| c.as_str())
527 .map(|s| expand_tilde(s))
528 })
529 .unwrap_or_else(|| "未知路径".to_string());
530 let old = v
531 .as_ref()
532 .and_then(|v| v.get("old_string").and_then(|c| c.as_str()))
533 .unwrap_or("");
534 let first_line = old.lines().next().unwrap_or("");
535 let has_more = old.lines().count() > 1;
536 let preview = if has_more {
537 format!("{}...", first_line)
538 } else {
539 first_line.to_string()
540 };
541 format!("即将编辑文件 {} (替换: \"{}\")", path, preview)
542 }
543}
544
545pub struct ToolRegistry {
549 tools: Vec<Box<dyn Tool>>,
550}
551
552impl ToolRegistry {
553 pub fn new(skills: Vec<Skill>) -> Self {
555 let mut registry = Self {
556 tools: vec![
557 Box::new(ShellTool),
558 Box::new(ReadFileTool),
559 Box::new(WriteFileTool),
560 Box::new(EditFileTool),
561 ],
562 };
563
564 if !skills.is_empty() {
566 registry.register(Box::new(super::skill::LoadSkillTool { skills }));
567 }
568
569 registry
570 }
571
572 pub fn register(&mut self, tool: Box<dyn Tool>) {
574 self.tools.push(tool);
575 }
576
577 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
579 self.tools
580 .iter()
581 .find(|t| t.name() == name)
582 .map(|t| t.as_ref())
583 }
584
585 pub fn build_tools_summary(&self) -> String {
587 self.tools
588 .iter()
589 .map(|t| format!("- **{}**: {}", t.name(), t.description()))
590 .collect::<Vec<_>>()
591 .join("\n")
592 }
593
594 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTools> {
596 self.tools
597 .iter()
598 .map(|t| {
599 ChatCompletionTools::Function(ChatCompletionTool {
600 function: FunctionObject {
601 name: t.name().to_string(),
602 description: Some(t.description().to_string()),
603 parameters: Some(t.parameters_schema()),
604 strict: None,
605 },
606 })
607 })
608 .collect()
609 }
610}