// matrixcode_core/tools/code_quality_hook.rs

//! Code Quality Verification Hook
//!
//! This hook verifies code quality before files are written, preventing
//! invalid code from being written and returning the errors to the AI for correction.
//!
//! # Verification Strategy
//!
//! - `none`: No verification
//! - `post`: Verify after write (default, current behavior)
//! - `pre`: Verify before write, block if errors are found
//! - `pre-quick`: Quick syntax check before write, full check after
//!
//! # Workflow
//!
//! 1. Detect the file type (Rust, TypeScript, Python, Go)
//! 2. Write the content to a temporary file for verification
//! 3. Run the appropriate verification command
//! 4. If errors are found, block the write and return the errors to the AI
//! 5. The AI corrects the code and tries again

21use anyhow::Result;
22use async_trait::async_trait;
23use serde_json::Value;
24use std::path::Path;
25use std::sync::Arc;
26use tempfile::TempDir;
27
28use super::tool_hooks::{HookResult, ToolHook};
29use crate::tools::verify::{ProjectType, VerifyTool};
30
31/// Code quality verification hook
32pub struct CodeQualityHook {
33    /// Verification strategy
34    strategy: VerificationStrategy,
35    /// Whether hook is enabled
36    enabled: bool,
37    /// Project root for detection
38    project_root: Option<Arc<std::path::PathBuf>>,
39}
40
/// When (and whether) code handled by the hook is verified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VerificationStrategy {
    /// No verification
    None,
    /// Verify after write (current behavior); this is the default
    #[default]
    Post,
    /// Verify before write, block if errors
    Pre,
    /// Quick syntax check before, full check after
    PreQuick,
}

55impl VerificationStrategy {
56    /// Parse from string
57    pub fn from_str(s: &str) -> Self {
58        match s.to_lowercase().as_str() {
59            "none" => Self::None,
60            "post" => Self::Post,
61            "pre" => Self::Pre,
62            "pre-quick" | "prequick" => Self::PreQuick,
63            _ => Self::Post,
64        }
65    }
66
67    /// Convert to string
68    pub fn to_str(&self) -> &'static str {
69        match self {
70            Self::None => "none",
71            Self::Post => "post",
72            Self::Pre => "pre",
73            Self::PreQuick => "pre-quick",
74        }
75    }
76}
77
78impl Default for CodeQualityHook {
79    fn default() -> Self {
80        Self::new(VerificationStrategy::default())
81    }
82}
83
84impl CodeQualityHook {
85    /// Create with strategy
86    pub fn new(strategy: VerificationStrategy) -> Self {
87        Self {
88            strategy,
89            enabled: strategy != VerificationStrategy::None,
90            project_root: None,
91        }
92    }
93
94    /// Create with strategy string
95    pub fn from_strategy_str(strategy: &str) -> Self {
96        Self::new(VerificationStrategy::from_str(strategy))
97    }
98
99    /// Set project root
100    pub fn with_project_root(mut self, root: Arc<std::path::PathBuf>) -> Self {
101        self.project_root = Some(root);
102        self
103    }
104
105    /// Set enabled status
106    pub fn set_enabled(mut self, enabled: bool) -> Self {
107        self.enabled = enabled;
108        self
109    }
110
111    /// Get verification strategy
112    pub fn strategy(&self) -> VerificationStrategy {
113        self.strategy
114    }
115
116    /// Check if file is a code file that needs verification
117    fn is_code_file(path: &str) -> bool {
118        let ext = Path::new(path)
119            .extension()
120            .and_then(|e| e.to_str());
121        matches!(ext, Some("rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go"))
122    }
123
124    /// Get file extension
125    fn get_extension(path: &str) -> Option<&str> {
126        Path::new(path)
127            .extension()
128            .and_then(|e| e.to_str())
129    }
130
131    /// Detect project type
132    fn detect_project_type(&self) -> ProjectType {
133        if let Some(root) = &self.project_root {
134            VerifyTool::detect_project_type(root.as_ref())
135        } else {
136            // Fallback: detect from current directory
137            let current_dir = std::env::current_dir().ok();
138            current_dir
139                .as_ref()
140                .map(|d| VerifyTool::detect_project_type(d))
141                .unwrap_or(ProjectType::Unknown)
142        }
143    }
144
145    /// Run pre-write verification
146    async fn verify_before_write(&self, path: &str, content: &str) -> Result<HookResult> {
147        // Only verify code files
148        if !Self::is_code_file(path) {
149            return Ok(HookResult::Continue);
150        }
151
152        // Create temporary directory for verification
153        let temp_dir = TempDir::new()?;
154        let temp_path = temp_dir.path().join(Path::new(path).file_name().unwrap_or_default());
155
156        // Write content to temp file
157        tokio::fs::write(&temp_path, content).await?;
158
159        // Detect project type and run appropriate verification
160        let project_type = self.detect_project_type();
161        let extension = Self::get_extension(path);
162
163        let verify_result = match project_type {
164            ProjectType::Rust if extension == Some("rs") => {
165                self.verify_rust(&temp_path).await
166            }
167            ProjectType::NodeJs if matches!(extension, Some("ts" | "tsx")) => {
168                self.verify_typescript(&temp_path).await
169            }
170            ProjectType::Python if extension == Some("py") => {
171                self.verify_python(&temp_path).await
172            }
173            ProjectType::Go if extension == Some("go") => {
174                self.verify_go(&temp_path).await
175            }
176            _ => {
177                // No verification for mismatched types
178                return Ok(HookResult::Continue);
179            }
180        };
181
182        match verify_result {
183            Ok(VerifyOutcome::Pass) => {
184                Ok(HookResult::Continue)
185            }
186            Ok(VerifyOutcome::Fail { errors, warnings }) => {
187                // Build detailed error message for AI correction
188                let reason = if errors.is_empty() {
189                    format!("⚠️ 代码验证发现警告,建议检查:\n{}", warnings.join("\n"))
190                } else {
191                    format!("❌ 代码验证失败,请修正以下错误后再写入:\n{}", errors.join("\n"))
192                };
193
194                let details = if !warnings.is_empty() && !errors.is_empty() {
195                    Some(format!("警告:\n{}\n\n错误:\n{}",
196                        warnings.join("\n"),
197                        errors.join("\n")))
198                } else if !warnings.is_empty() {
199                    Some(format!("警告:\n{}", warnings.join("\n")))
200                } else {
201                    None
202                };
203
204                Ok(HookResult::Block { reason, details })
205            }
206            Err(e) => {
207                // Verification tool not available - don't block
208                log::warn!("Code verification failed: {}", e);
209                Ok(HookResult::Continue)
210            }
211        }
212    }
213
214    /// Verify Rust code with rustfmt and rustc
215    async fn verify_rust(&self, path: &Path) -> Result<VerifyOutcome> {
216        // 1. Quick syntax check with rustfmt --check
217        let fmt_output = tokio::process::Command::new("rustfmt")
218            .arg("--check")
219            .arg(path)
220            .output()
221            .await;
222
223        let mut errors = Vec::new();
224        let mut warnings = Vec::new();
225
226        // Check formatting
227        match fmt_output {
228            Ok(o) if !o.status.success() => {
229                // Format issues - treat as warning, not blocking
230                let stderr = String::from_utf8_lossy(&o.stderr);
231                if !stderr.is_empty() {
232                    warnings.push(format!("格式问题: 建议运行 rustfmt"));
233                }
234            }
235            Err(_) => {
236                // rustfmt not available - skip format check
237            }
238            _ => {}
239        }
240
241        // 2. Syntax check with rustc (fast, no full compilation)
242        // For single file, we can't do full cargo check, but we can catch syntax errors
243        let syntax_output = tokio::process::Command::new("rustc")
244            .arg("--edition=2021")
245            .arg("--emit=metadata")
246            .arg("-o")
247            .arg("/dev/null")  // We just want to check syntax
248            .arg(path)
249            .output()
250            .await;
251
252        match syntax_output {
253            Ok(o) if !o.status.success() => {
254                let stderr = String::from_utf8_lossy(&o.stderr);
255                for line in stderr.lines() {
256                    if line.contains("error") {
257                        errors.push(line.to_string());
258                    } else if line.contains("warning") {
259                        warnings.push(line.to_string());
260                    }
261                }
262            }
263            Err(_) => {
264                // rustc not available - try cargo check
265                // This might not work for single file, but let's try
266            }
267            _ => {}
268        }
269
270        // 3. If we have project context, run cargo check
271        if errors.is_empty() {
272            if let Some(root) = &self.project_root {
273                let cargo_output = tokio::process::Command::new("cargo")
274                    .args(["check", "--quiet"])
275                    .current_dir(root.as_ref())
276                    .output()
277                    .await;
278
279                match cargo_output {
280                    Ok(o) if !o.status.success() => {
281                        let stderr = String::from_utf8_lossy(&o.stderr);
282                        for line in stderr.lines().filter(|l| l.contains("error")) {
283                            errors.push(line.to_string());
284                        }
285                    }
286                    Err(_) => {}
287                    _ => {}
288                }
289            }
290        }
291
292        if errors.is_empty() && warnings.is_empty() {
293            Ok(VerifyOutcome::Pass)
294        } else {
295            Ok(VerifyOutcome::Fail { errors, warnings })
296        }
297    }
298
299    /// Verify TypeScript code with tsc
300    async fn verify_typescript(&self, path: &Path) -> Result<VerifyOutcome> {
301        let mut errors = Vec::new();
302        let mut warnings = Vec::new();
303
304        // For single file verification, we use tsc with the file
305        // Note: This may require a tsconfig.json in the temp directory
306        // For quick check, we just verify syntax
307
308        // Try to run tsc
309        let tsc_output = tokio::process::Command::new("npx")
310            .args(["tsc", "--noEmit", "--skipLibCheck"])
311            .arg(path)
312            .output()
313            .await;
314
315        match tsc_output {
316            Ok(o) if !o.status.success() => {
317                let stderr = String::from_utf8_lossy(&o.stderr);
318                let stdout = String::from_utf8_lossy(&o.stdout);
319
320                for line in stderr.lines().chain(stdout.lines()) {
321                    if line.contains("error TS") {
322                        errors.push(line.to_string());
323                    }
324                }
325            }
326            Err(_) => {
327                // tsc not available - skip
328                warnings.push("tsc 不可用,跳过 TypeScript 验证".to_string());
329            }
330            _ => {}
331        }
332
333        // If we have project context, run project-level check
334        if errors.is_empty() {
335            if let Some(root) = &self.project_root {
336                let project_output = tokio::process::Command::new("npx")
337                    .args(["tsc", "--noEmit"])
338                    .current_dir(root.as_ref())
339                    .output()
340                    .await;
341
342                match project_output {
343                    Ok(o) if !o.status.success() => {
344                        let stderr = String::from_utf8_lossy(&o.stderr);
345                        for line in stderr.lines().filter(|l| l.contains("error TS")) {
346                            errors.push(line.to_string());
347                        }
348                    }
349                    Err(_) => {}
350                    _ => {}
351                }
352            }
353        }
354
355        if errors.is_empty() && warnings.is_empty() {
356            Ok(VerifyOutcome::Pass)
357        } else {
358            Ok(VerifyOutcome::Fail { errors, warnings })
359        }
360    }
361
362    /// Verify Python code with python -m py_compile
363    async fn verify_python(&self, path: &Path) -> Result<VerifyOutcome> {
364        let mut errors = Vec::new();
365        let mut warnings = Vec::new();
366
367        // Quick syntax check
368        let output = tokio::process::Command::new("python")
369            .args(["-m", "py_compile"])
370            .arg(path)
371            .output()
372            .await;
373
374        match output {
375            Ok(o) if !o.status.success() => {
376                let stderr = String::from_utf8_lossy(&o.stderr);
377                for line in stderr.lines() {
378                    if line.contains("SyntaxError") || line.contains("Error") {
379                        errors.push(line.to_string());
380                    }
381                }
382            }
383            Err(_) => {
384                warnings.push("python 不可用,跳过语法验证".to_string());
385            }
386            _ => {}
387        }
388
389        if errors.is_empty() && warnings.is_empty() {
390            Ok(VerifyOutcome::Pass)
391        } else {
392            Ok(VerifyOutcome::Fail { errors, warnings })
393        }
394    }
395
396    /// Verify Go code with go vet
397    async fn verify_go(&self, path: &Path) -> Result<VerifyOutcome> {
398        let mut errors = Vec::new();
399        let mut warnings = Vec::new();
400
401        // Go vet for single file
402        let output = tokio::process::Command::new("go")
403            .args(["vet"])
404            .arg(path)
405            .output()
406            .await;
407
408        match output {
409            Ok(o) if !o.status.success() => {
410                let stderr = String::from_utf8_lossy(&o.stderr);
411                for line in stderr.lines() {
412                    if line.contains("error") || line.contains("undefined") {
413                        errors.push(line.to_string());
414                    }
415                }
416            }
417            Err(_) => {
418                warnings.push("go vet 不可用,跳过验证".to_string());
419            }
420            _ => {}
421        }
422
423        // gofmt check (formatting)
424        let fmt_output = tokio::process::Command::new("gofmt")
425            .args(["-l"])
426            .arg(path)
427            .output()
428            .await;
429
430        match fmt_output {
431            Ok(o) if !o.stdout.is_empty() => {
432                warnings.push("格式问题: 建议运行 gofmt".to_string());
433            }
434            Err(_) => {}
435            _ => {}
436        }
437
438        if errors.is_empty() && warnings.is_empty() {
439            Ok(VerifyOutcome::Pass)
440        } else {
441            Ok(VerifyOutcome::Fail { errors, warnings })
442        }
443    }
444}
445
/// Result of running a language checker over the pending content.
#[derive(Debug, Clone)]
enum VerifyOutcome {
    /// The checker reported neither errors nor warnings
    Pass,
    /// The checker reported errors and/or warnings
    Fail {
        /// Error lines extracted from the checker's output
        errors: Vec<String>,
        /// Non-fatal findings, e.g. formatting issues or an unavailable tool
        warnings: Vec<String>,
    },
}

458#[async_trait]
459impl ToolHook for CodeQualityHook {
460    fn name(&self) -> &str {
461        "code_quality"
462    }
463
464    fn is_enabled(&self) -> bool {
465        self.enabled && self.strategy != VerificationStrategy::None
466    }
467
468    fn applies_to(&self) -> Vec<&str> {
469        vec!["write", "edit", "multi_edit"]
470    }
471
472    async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult> {
473        // Only run pre-verification for write tool with pre strategy
474        if self.strategy != VerificationStrategy::Pre &&
475           self.strategy != VerificationStrategy::PreQuick {
476            return Ok(HookResult::Continue);
477        }
478
479        // Get path and content from params
480        let path = params["path"].as_str().ok_or_else(||
481            anyhow::anyhow!("missing 'path' in params"))?;
482
483        let content = params["content"].as_str().ok_or_else(||
484            anyhow::anyhow!("missing 'content' in params"))?;
485
486        // For edit/multi_edit, we need to apply the edit first to get full content
487        // This is complex, so we skip pre-verification for edits
488        if tool_name != "write" {
489            return Ok(HookResult::Continue);
490        }
491
492        self.verify_before_write(path, content).await
493    }
494
495    async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
496        // Post-verification is handled by WriteTool's own run_code_verification
497        // This hook just adds additional context if needed
498
499        if self.strategy == VerificationStrategy::None {
500            return Ok(result.to_string());
501        }
502
503        // Add hook signature to result (for debugging)
504        Ok(format!("{}\n[code_quality_hook: strategy={}]", result, self.strategy.to_str()))
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_verification_strategy_parse() {
        assert_eq!(VerificationStrategy::from_str("none"), VerificationStrategy::None);
        assert_eq!(VerificationStrategy::from_str("post"), VerificationStrategy::Post);
        assert_eq!(VerificationStrategy::from_str("pre"), VerificationStrategy::Pre);
        assert_eq!(VerificationStrategy::from_str("pre-quick"), VerificationStrategy::PreQuick);
        assert_eq!(VerificationStrategy::from_str("invalid"), VerificationStrategy::Post);
    }

    #[test]
    fn test_verification_strategy_parse_case_and_alias() {
        assert_eq!(VerificationStrategy::from_str("PRE"), VerificationStrategy::Pre);
        assert_eq!(VerificationStrategy::from_str("Pre-Quick"), VerificationStrategy::PreQuick);
        assert_eq!(VerificationStrategy::from_str("prequick"), VerificationStrategy::PreQuick);
    }

    #[test]
    fn test_verification_strategy_round_trip() {
        let all = [
            VerificationStrategy::None,
            VerificationStrategy::Post,
            VerificationStrategy::Pre,
            VerificationStrategy::PreQuick,
        ];
        for strategy in all {
            assert_eq!(VerificationStrategy::from_str(strategy.to_str()), strategy);
        }
    }

    #[test]
    fn test_hook_construction_and_enablement() {
        assert_eq!(CodeQualityHook::default().strategy(), VerificationStrategy::Post);
        assert_eq!(
            CodeQualityHook::from_strategy_str("pre").strategy(),
            VerificationStrategy::Pre
        );
        assert!(CodeQualityHook::new(VerificationStrategy::Pre).is_enabled());
        assert!(!CodeQualityHook::new(VerificationStrategy::None).is_enabled());
        assert!(!CodeQualityHook::new(VerificationStrategy::Pre)
            .set_enabled(false)
            .is_enabled());
    }

    #[test]
    fn test_is_code_file() {
        for path in ["test.rs", "test.ts", "test.tsx", "test.js", "test.jsx", "test.py", "test.go"] {
            assert!(CodeQualityHook::is_code_file(path), "{} should be code", path);
        }
        for path in ["test.txt", "test.md", "Makefile", "archive.tar.gz"] {
            assert!(!CodeQualityHook::is_code_file(path), "{} should not be code", path);
        }
    }

    #[test]
    fn test_hook_applies_to() {
        let hook = CodeQualityHook::default();
        let applies_to = hook.applies_to();
        assert!(applies_to.contains(&"write"));
        assert!(applies_to.contains(&"edit"));
        assert!(applies_to.contains(&"multi_edit"));
        assert!(!applies_to.contains(&"read"));
    }

    #[tokio::test]
    async fn test_hook_disabled() {
        let hook = CodeQualityHook::new(VerificationStrategy::None);
        assert!(!hook.is_enabled());

        let result = hook.pre_execute("write", &serde_json::json!({
            "path": "test.rs",
            "content": "fn main() {}"
        })).await;

        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_post_strategy_skips_pre_execute() {
        let hook = CodeQualityHook::new(VerificationStrategy::Post);
        let result = hook.pre_execute("write", &serde_json::json!({
            "path": "test.rs",
            "content": "fn main( {"
        })).await;

        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_pre_strategy_ignores_non_code_files() {
        let hook = CodeQualityHook::new(VerificationStrategy::Pre);
        let result = hook.pre_execute("write", &serde_json::json!({
            "path": "notes.md",
            "content": "# not code"
        })).await;

        assert!(matches!(result.unwrap(), HookResult::Continue));
    }
}