// matrixcode_core/tools/write.rs
use std::path::Path;

use anyhow::Result;
use async_trait::async_trait;
use serde_json::{Value, json};

use super::code_quality_hook::{CodeQualityHook, VerificationStrategy};
use super::tool_hooks::{HookRegistry, HookResult};
use super::verify::{ProjectType, VerifyTool};
use super::{Tool, ToolDefinition};
use crate::approval::RiskLevel;
use crate::path_validator::{validate_content_size, validate_path};

12pub struct WriteTool {
13 hook_registry: HookRegistry,
15}
16
17impl Default for WriteTool {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl WriteTool {
24 pub fn new() -> Self {
26 Self::with_verification_strategy(VerificationStrategy::default())
27 }
28
29 pub fn with_verification_strategy(strategy: VerificationStrategy) -> Self {
31 let mut registry = HookRegistry::new();
32 if strategy != VerificationStrategy::None {
34 registry.register(Box::new(CodeQualityHook::new(strategy)));
35 }
36 Self { hook_registry: registry }
37 }
38
39 pub fn with_hooks(registry: HookRegistry) -> Self {
41 Self { hook_registry: registry }
42 }
43
44 pub fn hook_registry(&self) -> &HookRegistry {
46 &self.hook_registry
47 }
48}
49
50#[async_trait]
51impl Tool for WriteTool {
52 fn definition(&self) -> ToolDefinition {
53 ToolDefinition {
54 name: "write".to_string(),
55 description: "向文件写入内容,若文件不存在则创建。
56
57【重要】写入现有文件前必须先读取:
58- 如果文件已存在,必须先用 read 工具读取当前内容
59- 如果没先读文件,此工具会失败
60- 了解现有内容可防止意外覆盖重要信息
61
62【代码质量验证】写入代码文件时会自动验证:
63- 根据 verify_strategy 配置决定验证时机
64- 'pre' 策略:写入前验证,失败则阻止写入并返回错误给 AI 纠正
65- 'post' 策略:写入后验证,结果附加在输出中
66- 支持的验证:cargo check / tsc / python -m py_compile / go vet
67
68优先用 edit 工具修改现有文件(只发送 diff)
69只在以下情况使用此工具:
70- 创建新文件
71- 完整重写文件(用户明确要求)
72
73路径安全:自动验证路径安全性,阻止路径穿越和系统文件写入"
74 .to_string(),
75 parameters: json!({
76 "type": "object",
77 "properties": {
78 "path": {
79 "type": "string",
80 "description": "要写入的文件路径(会自动验证安全性,阻止路径穿越和系统文件写入)"
81 },
82 "content": {
83 "type": "string",
84 "description": "要写入的内容(单次写入最大10MB,超大内容请分批写入)"
85 }
86 },
87 "required": ["path", "content"]
88 }),
89 ..Default::default()
90 }
91 }
92
93 async fn execute(&self, params: Value) -> Result<String> {
94 let path_str = params["path"]
96 .as_str()
97 .ok_or_else(|| anyhow::anyhow!("missing 'path'"))?
98 .to_string(); let hook_result = self.hook_registry.pre_execute("write", ¶ms).await?;
102
103 let final_params = match hook_result {
105 HookResult::Block { reason, details } => {
106 let error_msg = if let Some(d) = details {
108 format!("{}\n\n详细信息:\n{}", reason, d)
109 } else {
110 reason
111 };
112 return Err(anyhow::anyhow!(error_msg));
113 }
114 HookResult::Modify(new_params) => new_params,
115 HookResult::Continue => params,
116 };
117
118 let final_content = final_params["content"]
120 .as_str()
121 .ok_or_else(|| anyhow::anyhow!("missing 'content' after hook modification"))?;
122 validate_content_size(final_content)?;
123
124 let validated_path = validate_path(&path_str, None, true)?;
127
128 if let Some(parent) = validated_path.parent() {
130 tokio::fs::create_dir_all(parent).await?;
131 }
132
133 let total_bytes = final_content.len();
135 let size_mb = total_bytes as f64 / 1_000_000.0;
136
137 tokio::fs::write(&validated_path, final_content).await?;
139
140 let verify_result = self.run_code_verification(&validated_path, final_content).await;
142
143 let size_feedback = if size_mb > 1.0 {
145 format!(
146 " ({:.2} MB - large file written successfully. \
147 Consider splitting if this causes performance issues)",
148 size_mb
149 )
150 } else if size_mb > 0.1 {
151 format!(" ({:.2} MB)", size_mb)
152 } else {
153 format!(" ({:.2} KB)", total_bytes as f64 / 1_000.0)
154 };
155
156 let verify_feedback = match verify_result {
157 Ok(msg) => msg,
158 Err(e) => format!(" ⚠️ Code verification failed: {}", e),
159 };
160
161 let base_result = format!(
163 "Successfully wrote {} bytes{} to {}\nPath validated: {}\n{}",
164 total_bytes,
165 size_feedback,
166 path_str,
167 validated_path.display(),
168 verify_feedback
169 );
170
171 let final_result = self.hook_registry.post_execute("write", &final_params, &base_result).await?;
172
173 Ok(final_result)
174 }
175
176 fn risk_level(&self) -> RiskLevel {
177 RiskLevel::Mutating
178 }
179}
180
181impl WriteTool {
182 async fn run_code_verification(&self, path: &std::path::Path, _content: &str) -> Result<String> {
184 let extension = path.extension().and_then(|e| e.to_str());
186 let is_code_file = matches!(extension, Some("rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go"));
187
188 if !is_code_file {
189 return Ok("(非代码文件,跳过代码检测)".to_string());
190 }
191
192 let project_root = std::env::current_dir()?;
194 let verify_tool = VerifyTool::new(project_root);
195 let project_type = verify_tool.project_type();
196
197 match project_type {
199 ProjectType::Rust if extension == Some("rs") => {
200 self.verify_rust_code(path).await
201 }
202 ProjectType::NodeJs if matches!(extension, Some("ts" | "tsx")) => {
203 self.verify_typescript_code(path).await
204 }
205 ProjectType::Python if extension == Some("py") => {
206 self.verify_python_code(path).await
207 }
208 ProjectType::Go if extension == Some("go") => {
209 self.verify_go_code(path).await
210 }
211 _ => Ok(format!("({} 文件,当前项目类型 {} 无自动检测)",
212 extension.unwrap_or("unknown"),
213 Self::project_type_str(project_type)))
214 }
215 }
216
217 fn project_type_str(pt: ProjectType) -> &'static str {
219 match pt {
220 ProjectType::Rust => "Rust",
221 ProjectType::NodeJs => "Node.js",
222 ProjectType::Python => "Python",
223 ProjectType::Go => "Go",
224 ProjectType::Java => "Java",
225 ProjectType::Unknown => "未知",
226 }
227 }
228
229 async fn verify_rust_code(&self, _path: &std::path::Path) -> Result<String> {
231 let output = tokio::process::Command::new("cargo")
233 .args(["check", "--quiet"])
234 .output()
235 .await;
236
237 match output {
238 Ok(o) if o.status.success() => {
239 Ok("✅ cargo check 通过".to_string())
240 }
241 Ok(o) => {
242 let stderr = String::from_utf8_lossy(&o.stderr);
243 let errors = stderr.lines()
244 .filter(|l| l.contains("error") || l.contains("Error"))
245 .take(5)
246 .collect::<Vec<_>>()
247 .join("\n");
248 if errors.is_empty() {
249 Ok("⚠️ cargo check 有警告,请检查".to_string())
250 } else {
251 Ok(format!("❌ cargo check 失败:\n{}", errors))
252 }
253 }
254 Err(e) => {
255 Ok(format!("⚠️ 无法运行 cargo check: {}", e))
257 }
258 }
259 }
260
261 async fn verify_typescript_code(&self, _path: &std::path::Path) -> Result<String> {
263 let output = tokio::process::Command::new("npx")
265 .args(["tsc", "--noEmit"])
266 .output()
267 .await;
268
269 match output {
270 Ok(o) if o.status.success() => {
271 Ok("✅ tsc --noEmit 通过".to_string())
272 }
273 Ok(o) => {
274 let stderr = String::from_utf8_lossy(&o.stderr);
275 let errors = stderr.lines()
276 .filter(|l| l.contains("error"))
277 .take(5)
278 .collect::<Vec<_>>()
279 .join("\n");
280 if errors.is_empty() {
281 Ok("⚠️ TypeScript 类型检查有警告".to_string())
282 } else {
283 Ok(format!("❌ TypeScript 类型检查失败:\n{}", errors))
284 }
285 }
286 Err(e) => {
287 Ok(format!("⚠️ 无法运行 tsc: {}", e))
289 }
290 }
291 }
292
293 async fn verify_python_code(&self, path: &std::path::Path) -> Result<String> {
295 let output = tokio::process::Command::new("python")
297 .args(["-m", "py_compile"])
298 .arg(path)
299 .output()
300 .await;
301
302 match output {
303 Ok(o) if o.status.success() => {
304 Ok("✅ Python 语法检查通过".to_string())
305 }
306 Ok(o) => {
307 let stderr = String::from_utf8_lossy(&o.stderr);
308 let errors = stderr.lines()
309 .filter(|l| l.contains("Error") || l.contains("SyntaxError"))
310 .take(3)
311 .collect::<Vec<_>>()
312 .join("\n");
313 if errors.is_empty() {
314 Ok("⚠️ Python 检查有警告".to_string())
315 } else {
316 Ok(format!("❌ Python 语法检查失败:\n{}", errors))
317 }
318 }
319 Err(e) => {
320 Ok(format!("⚠️ 无法运行 Python 检查: {}", e))
321 }
322 }
323 }
324
325 async fn verify_go_code(&self, _path: &std::path::Path) -> Result<String> {
327 let output = tokio::process::Command::new("go")
329 .args(["vet"])
330 .output()
331 .await;
332
333 match output {
334 Ok(o) if o.status.success() => {
335 Ok("✅ go vet 通过".to_string())
336 }
337 Ok(o) => {
338 let stderr = String::from_utf8_lossy(&o.stderr);
339 let errors = stderr.lines()
340 .filter(|l| l.contains("error") || l.contains("undefined"))
341 .take(3)
342 .collect::<Vec<_>>()
343 .join("\n");
344 if errors.is_empty() {
345 Ok("⚠️ go vet 有警告".to_string())
346 } else {
347 Ok(format!("❌ go vet 失败:\n{}", errors))
348 }
349 }
350 Err(e) => {
351 Ok(format!("⚠️ 无法运行 go vet: {}", e))
352 }
353 }
354 }
355}