// matrixcode_core/tools/multi_edit.rs

use anyhow::{Context, Result};
use async_trait::async_trait;
use serde_json::{Value, json};

use super::{Tool, ToolDefinition};
use crate::approval::RiskLevel;

8pub struct MultiEditTool;
9
10#[async_trait]
11impl Tool for MultiEditTool {
12    fn definition(&self) -> ToolDefinition {
13        ToolDefinition {
14            name: "multi_edit".to_string(),
15            description: "对单个文件应用多处精确字符串替换,一次性原子写入。\
16                 每个编辑必须在前序编辑后的文件状态中精确匹配一次。\
17                 若任一编辑失败,文件不会被修改。"
18                .to_string(),
19            parameters: json!({
20                "type": "object",
21                "properties": {
22                    "path": {
23                        "type": "string",
24                        "description": "要编辑的文件路径"
25                    },
26                    "edits": {
27                        "type": "array",
28                        "description": "有序的 {old_string, new_string} 替换列表",
29                        "items": {
30                            "type": "object",
31                            "properties": {
32                                "old_string": {"type": "string"},
33                                "new_string": {"type": "string"}
34                            },
35                            "required": ["old_string", "new_string"]
36                        }
37                    }
38                },
39                "required": ["path", "edits"]
40            }),
41        }
42    }
43
44    async fn execute(&self, params: Value) -> Result<String> {
45        let path = params["path"]
46            .as_str()
47            .ok_or_else(|| anyhow::anyhow!("missing 'path'"))?;
48        let edits = params["edits"]
49            .as_array()
50            .ok_or_else(|| anyhow::anyhow!("missing 'edits' array"))?;
51        if edits.is_empty() {
52            anyhow::bail!("'edits' must contain at least one entry");
53        }
54
55        // Show spinner while editing - RAII guard ensures cleanup on error
56        // let mut spinner = ToolSpinner::new(&format!("multi-editing {} ({} edits)", path, edits.len()));
57
58        let mut content = tokio::fs::read_to_string(path).await?;
59
60        for (idx, edit) in edits.iter().enumerate() {
61            let old_string = edit["old_string"]
62                .as_str()
63                .ok_or_else(|| anyhow::anyhow!("edit {}: missing 'old_string'", idx))?;
64            let new_string = edit["new_string"]
65                .as_str()
66                .ok_or_else(|| anyhow::anyhow!("edit {}: missing 'new_string'", idx))?;
67
68            if old_string.is_empty() {
69                // spinner.finish_error("empty old_string");
70                anyhow::bail!("edit {}: 'old_string' must not be empty", idx);
71            }
72
73            let count = content.matches(old_string).count();
74            if count == 0 {
75                // spinner.finish_error(&format!("edit {} not found", idx));
76                anyhow::bail!("edit {}: old_string not found", idx);
77            }
78            if count > 1 {
79                // spinner.finish_error(&format!("edit {} multiple matches", idx));
80                anyhow::bail!(
81                    "edit {}: old_string found {} times — must be unique",
82                    idx,
83                    count
84                );
85            }
86
87            content = content.replacen(old_string, new_string, 1);
88        }
89
90        tokio::fs::write(path, &content).await?;
91
92        // Return diff-style output
93        let mut diff = format!("Applied {} edit(s) to {}\n", edits.len(), path);
94        for (idx, edit) in edits.iter().enumerate() {
95            let old_string = edit["old_string"].as_str().unwrap_or("");
96            let new_string = edit["new_string"].as_str().unwrap_or("");
97            if edits.len() > 1 {
98                diff.push_str(&format!("edit {}:\n", idx + 1));
99            }
100            for line in old_string.lines().take(3) {
101                diff.push_str(&format!("- {}\n", line));
102            }
103            if old_string.lines().count() > 3 {
104                diff.push_str(&format!(
105                    "  ... ({} more lines removed)\n",
106                    old_string.lines().count() - 3
107                ));
108            }
109            for line in new_string.lines().take(3) {
110                diff.push_str(&format!("+ {}\n", line));
111            }
112            if new_string.lines().count() > 3 {
113                diff.push_str(&format!(
114                    "  ... ({} more lines added)\n",
115                    new_string.lines().count() - 3
116                ));
117            }
118        }
119        Ok(diff)
120    }
121
122    fn risk_level(&self) -> RiskLevel {
123        RiskLevel::Mutating
124    }
125}