Skip to main content

matrixcode_core/tools/
write.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use super::{Tool, ToolDefinition};
6use super::tool_hooks::{HookRegistry, HookResult};
7use super::code_quality_hook::{CodeQualityHook, VerificationStrategy};
8use crate::approval::RiskLevel;
9use crate::path_validator::{validate_content_size, validate_path};
10use super::verify::{VerifyTool, ProjectType};
11
12pub struct WriteTool {
13    /// Hook registry for pre/post execution hooks
14    hook_registry: HookRegistry,
15}
16
17impl Default for WriteTool {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl WriteTool {
24    /// Create a new WriteTool with default hooks
25    pub fn new() -> Self {
26        Self::with_verification_strategy(VerificationStrategy::default())
27    }
28
29    /// Create WriteTool with specific verification strategy
30    pub fn with_verification_strategy(strategy: VerificationStrategy) -> Self {
31        let mut registry = HookRegistry::new();
32        // Note: CodeQualityHook is added when strategy is not None
33        if strategy != VerificationStrategy::None {
34            registry.register(Box::new(CodeQualityHook::new(strategy)));
35        }
36        Self { hook_registry: registry }
37    }
38
39    /// Create WriteTool with custom hook registry
40    pub fn with_hooks(registry: HookRegistry) -> Self {
41        Self { hook_registry: registry }
42    }
43
44    /// Get the hook registry
45    pub fn hook_registry(&self) -> &HookRegistry {
46        &self.hook_registry
47    }
48}
49
50#[async_trait]
51impl Tool for WriteTool {
52    fn definition(&self) -> ToolDefinition {
53        ToolDefinition {
54            name: "write".to_string(),
55            description: "向文件写入内容,若文件不存在则创建。
56
57【重要】写入现有文件前必须先读取:
58- 如果文件已存在,必须先用 read 工具读取当前内容
59- 如果没先读文件,此工具会失败
60- 了解现有内容可防止意外覆盖重要信息
61
62【代码质量验证】写入代码文件时会自动验证:
63- 根据 verify_strategy 配置决定验证时机
64- 'pre' 策略:写入前验证,失败则阻止写入并返回错误给 AI 纠正
65- 'post' 策略:写入后验证,结果附加在输出中
66- 支持的验证:cargo check / tsc / python -m py_compile / go vet
67
68优先用 edit 工具修改现有文件(只发送 diff)
69只在以下情况使用此工具:
70- 创建新文件
71- 完整重写文件(用户明确要求)
72
73路径安全:自动验证路径安全性,阻止路径穿越和系统文件写入"
74                .to_string(),
75            parameters: json!({
76                "type": "object",
77                "properties": {
78                    "path": {
79                        "type": "string",
80                        "description": "要写入的文件路径(会自动验证安全性,阻止路径穿越和系统文件写入)"
81                    },
82                    "content": {
83                        "type": "string",
84                        "description": "要写入的内容(单次写入最大10MB,超大内容请分批写入)"
85                    }
86                },
87                "required": ["path", "content"]
88            }),
89            ..Default::default()
90        }
91    }
92
93    async fn execute(&self, params: Value) -> Result<String> {
94        // Extract path early (before potential modification by hooks)
95        let path_str = params["path"]
96            .as_str()
97            .ok_or_else(|| anyhow::anyhow!("missing 'path'"))?
98            .to_string();  // Clone to avoid borrow issues
99
100        // 1. Run pre-execute hooks (including code quality verification)
101        let hook_result = self.hook_registry.pre_execute("write", &params).await?;
102
103        // Check if hooks blocked execution
104        let final_params = match hook_result {
105            HookResult::Block { reason, details } => {
106                // Return error to AI for correction
107                let error_msg = if let Some(d) = details {
108                    format!("{}\n\n详细信息:\n{}", reason, d)
109                } else {
110                    reason
111                };
112                return Err(anyhow::anyhow!(error_msg));
113            }
114            HookResult::Modify(new_params) => new_params,
115            HookResult::Continue => params,
116        };
117
118        // 2. Validate content size (prevent accidental huge writes)
119        let final_content = final_params["content"]
120            .as_str()
121            .ok_or_else(|| anyhow::anyhow!("missing 'content' after hook modification"))?;
122        validate_content_size(final_content)?;
123
124        // 3. Validate path security (prevent path traversal and system file writes)
125        // For writes, we use strict validation (is_write=true)
126        let validated_path = validate_path(&path_str, None, true)?;
127
128        // 4. Create parent directories if needed
129        if let Some(parent) = validated_path.parent() {
130            tokio::fs::create_dir_all(parent).await?;
131        }
132
133        // 5. Write the file with validated path
134        let total_bytes = final_content.len();
135        let size_mb = total_bytes as f64 / 1_000_000.0;
136
137        // Write the file
138        tokio::fs::write(&validated_path, final_content).await?;
139
140        // 6. Run code verification for code files (post-write)
141        let verify_result = self.run_code_verification(&validated_path, final_content).await;
142
143        // 7. Provide helpful feedback based on file size
144        let size_feedback = if size_mb > 1.0 {
145            format!(
146                " ({:.2} MB - large file written successfully. \
147                Consider splitting if this causes performance issues)",
148                size_mb
149            )
150        } else if size_mb > 0.1 {
151            format!(" ({:.2} MB)", size_mb)
152        } else {
153            format!(" ({:.2} KB)", total_bytes as f64 / 1_000.0)
154        };
155
156        let verify_feedback = match verify_result {
157            Ok(msg) => msg,
158            Err(e) => format!(" ⚠️ Code verification failed: {}", e),
159        };
160
161        // 8. Run post-execute hooks
162        let base_result = format!(
163            "Successfully wrote {} bytes{} to {}\nPath validated: {}\n{}",
164            total_bytes,
165            size_feedback,
166            path_str,
167            validated_path.display(),
168            verify_feedback
169        );
170
171        let final_result = self.hook_registry.post_execute("write", &final_params, &base_result).await?;
172
173        Ok(final_result)
174    }
175
176    fn risk_level(&self) -> RiskLevel {
177        RiskLevel::Mutating
178    }
179}
180
181impl WriteTool {
182    /// Run code verification after writing code files
183    async fn run_code_verification(&self, path: &std::path::Path, _content: &str) -> Result<String> {
184        // Only verify code files
185        let extension = path.extension().and_then(|e| e.to_str());
186        let is_code_file = matches!(extension, Some("rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go"));
187
188        if !is_code_file {
189            return Ok("(非代码文件,跳过代码检测)".to_string());
190        }
191
192        // Detect project type
193        let project_root = std::env::current_dir()?;
194        let verify_tool = VerifyTool::new(project_root);
195        let project_type = verify_tool.project_type();
196
197        // Run appropriate verification based on project type and file type
198        match project_type {
199            ProjectType::Rust if extension == Some("rs") => {
200                self.verify_rust_code(path).await
201            }
202            ProjectType::NodeJs if matches!(extension, Some("ts" | "tsx")) => {
203                self.verify_typescript_code(path).await
204            }
205            ProjectType::Python if extension == Some("py") => {
206                self.verify_python_code(path).await
207            }
208            ProjectType::Go if extension == Some("go") => {
209                self.verify_go_code(path).await
210            }
211            _ => Ok(format!("({} 文件,当前项目类型 {} 无自动检测)",
212                extension.unwrap_or("unknown"),
213                Self::project_type_str(project_type)))
214        }
215    }
216
217    /// Get project type display name
218    fn project_type_str(pt: ProjectType) -> &'static str {
219        match pt {
220            ProjectType::Rust => "Rust",
221            ProjectType::NodeJs => "Node.js",
222            ProjectType::Python => "Python",
223            ProjectType::Go => "Go",
224            ProjectType::Java => "Java",
225            ProjectType::Unknown => "未知",
226        }
227    }
228
229    /// Verify Rust code with cargo check
230    async fn verify_rust_code(&self, _path: &std::path::Path) -> Result<String> {
231        // Run cargo check in background
232        let output = tokio::process::Command::new("cargo")
233            .args(["check", "--quiet"])
234            .output()
235            .await;
236
237        match output {
238            Ok(o) if o.status.success() => {
239                Ok("✅ cargo check 通过".to_string())
240            }
241            Ok(o) => {
242                let stderr = String::from_utf8_lossy(&o.stderr);
243                let errors = stderr.lines()
244                    .filter(|l| l.contains("error") || l.contains("Error"))
245                    .take(5)
246                    .collect::<Vec<_>>()
247                    .join("\n");
248                if errors.is_empty() {
249                    Ok("⚠️ cargo check 有警告,请检查".to_string())
250                } else {
251                    Ok(format!("❌ cargo check 失败:\n{}", errors))
252                }
253            }
254            Err(e) => {
255                // cargo not found or other error - don't fail the write
256                Ok(format!("⚠️ 无法运行 cargo check: {}", e))
257            }
258        }
259    }
260
261    /// Verify TypeScript code with tsc --noEmit
262    async fn verify_typescript_code(&self, _path: &std::path::Path) -> Result<String> {
263        // Run tsc --noEmit in background
264        let output = tokio::process::Command::new("npx")
265            .args(["tsc", "--noEmit"])
266            .output()
267            .await;
268
269        match output {
270            Ok(o) if o.status.success() => {
271                Ok("✅ tsc --noEmit 通过".to_string())
272            }
273            Ok(o) => {
274                let stderr = String::from_utf8_lossy(&o.stderr);
275                let errors = stderr.lines()
276                    .filter(|l| l.contains("error"))
277                    .take(5)
278                    .collect::<Vec<_>>()
279                    .join("\n");
280                if errors.is_empty() {
281                    Ok("⚠️ TypeScript 类型检查有警告".to_string())
282                } else {
283                    Ok(format!("❌ TypeScript 类型检查失败:\n{}", errors))
284                }
285            }
286            Err(e) => {
287                // tsc not found - don't fail the write
288                Ok(format!("⚠️ 无法运行 tsc: {}", e))
289            }
290        }
291    }
292
293    /// Verify Python code with python -m py_compile
294    async fn verify_python_code(&self, path: &std::path::Path) -> Result<String> {
295        // Quick syntax check with python -m py_compile
296        let output = tokio::process::Command::new("python")
297            .args(["-m", "py_compile"])
298            .arg(path)
299            .output()
300            .await;
301
302        match output {
303            Ok(o) if o.status.success() => {
304                Ok("✅ Python 语法检查通过".to_string())
305            }
306            Ok(o) => {
307                let stderr = String::from_utf8_lossy(&o.stderr);
308                let errors = stderr.lines()
309                    .filter(|l| l.contains("Error") || l.contains("SyntaxError"))
310                    .take(3)
311                    .collect::<Vec<_>>()
312                    .join("\n");
313                if errors.is_empty() {
314                    Ok("⚠️ Python 检查有警告".to_string())
315                } else {
316                    Ok(format!("❌ Python 语法检查失败:\n{}", errors))
317                }
318            }
319            Err(e) => {
320                Ok(format!("⚠️ 无法运行 Python 检查: {}", e))
321            }
322        }
323    }
324
325    /// Verify Go code with go vet
326    async fn verify_go_code(&self, _path: &std::path::Path) -> Result<String> {
327        // Run go vet for quick syntax check
328        let output = tokio::process::Command::new("go")
329            .args(["vet"])
330            .output()
331            .await;
332
333        match output {
334            Ok(o) if o.status.success() => {
335                Ok("✅ go vet 通过".to_string())
336            }
337            Ok(o) => {
338                let stderr = String::from_utf8_lossy(&o.stderr);
339                let errors = stderr.lines()
340                    .filter(|l| l.contains("error") || l.contains("undefined"))
341                    .take(3)
342                    .collect::<Vec<_>>()
343                    .join("\n");
344                if errors.is_empty() {
345                    Ok("⚠️ go vet 有警告".to_string())
346                } else {
347                    Ok(format!("❌ go vet 失败:\n{}", errors))
348                }
349            }
350            Err(e) => {
351                Ok(format!("⚠️ 无法运行 go vet: {}", e))
352            }
353        }
354    }
355}