use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use tempfile::TempDir;
use super::tool_hooks::{HookResult, ToolHook};
use crate::tools::verify::{ProjectType, VerifyTool};
/// Tool hook that checks source files submitted through the `write` tool with the
/// language's own toolchain (`rustc`, `tsc`, `py_compile`, `go vet`, …) according to a
/// [`VerificationStrategy`].
#[derive(Debug, Clone)]
pub struct CodeQualityHook {
    /// When (and whether) verification runs.
    strategy: VerificationStrategy,
    /// Master switch; `false` disables the hook regardless of `strategy`.
    enabled: bool,
    /// Project root used for project-type detection and project-wide checks. When unset,
    /// the current working directory is used for detection.
    project_root: Option<Arc<std::path::PathBuf>>,
}
/// Controls when [`CodeQualityHook`] verifies code relative to a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VerificationStrategy {
    /// No verification; a hook created with this strategy starts disabled.
    None,
    /// The default. No pre-write check is performed; `post_execute` tags the tool result
    /// with the strategy name.
    #[default]
    Post,
    /// Verifies the content of `write` calls before it is written and may block the write
    /// when the checks report problems.
    Pre,
    /// Pre-write verification variant (parsed from `pre-quick` / `prequick`).
    PreQuick,
}
impl VerificationStrategy {
    /// Parses a strategy name case-insensitively, ignoring surrounding whitespace.
    ///
    /// Accepts `none`, `post`, `pre` and `pre-quick` (also spelled `prequick` or
    /// `pre_quick`). Any unrecognised value falls back to [`VerificationStrategy::Post`].
    // Deliberately infallible: the `FromStr` trait would force a `Result` on callers.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        match s.trim().to_ascii_lowercase().as_str() {
            "none" => Self::None,
            "post" => Self::Post,
            "pre" => Self::Pre,
            "pre-quick" | "prequick" | "pre_quick" => Self::PreQuick,
            _ => Self::Post,
        }
    }

    /// Returns the canonical name of the strategy; it parses back via [`Self::from_str`].
    pub fn to_str(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::Post => "post",
            Self::Pre => "pre",
            Self::PreQuick => "pre-quick",
        }
    }
}
impl Default for CodeQualityHook {
fn default() -> Self {
Self::new(VerificationStrategy::default())
}
}
impl CodeQualityHook {
pub fn new(strategy: VerificationStrategy) -> Self {
Self {
strategy,
enabled: strategy != VerificationStrategy::None,
project_root: None,
}
}
pub fn from_strategy_str(strategy: &str) -> Self {
Self::new(VerificationStrategy::from_str(strategy))
}
pub fn with_project_root(mut self, root: Arc<std::path::PathBuf>) -> Self {
self.project_root = Some(root);
self
}
pub fn set_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn strategy(&self) -> VerificationStrategy {
self.strategy
}
fn is_code_file(path: &str) -> bool {
let ext = Path::new(path)
.extension()
.and_then(|e| e.to_str());
matches!(ext, Some("rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go"))
}
fn get_extension(path: &str) -> Option<&str> {
Path::new(path)
.extension()
.and_then(|e| e.to_str())
}
fn detect_project_type(&self) -> ProjectType {
if let Some(root) = &self.project_root {
VerifyTool::detect_project_type(root.as_ref())
} else {
let current_dir = std::env::current_dir().ok();
current_dir
.as_ref()
.map(|d| VerifyTool::detect_project_type(d))
.unwrap_or(ProjectType::Unknown)
}
}
async fn verify_before_write(&self, path: &str, content: &str) -> Result<HookResult> {
if !Self::is_code_file(path) {
return Ok(HookResult::Continue);
}
let temp_dir = TempDir::new()?;
let temp_path = temp_dir.path().join(Path::new(path).file_name().unwrap_or_default());
tokio::fs::write(&temp_path, content).await?;
let project_type = self.detect_project_type();
let extension = Self::get_extension(path);
let verify_result = match project_type {
ProjectType::Rust if extension == Some("rs") => {
self.verify_rust(&temp_path).await
}
ProjectType::NodeJs if matches!(extension, Some("ts" | "tsx")) => {
self.verify_typescript(&temp_path).await
}
ProjectType::Python if extension == Some("py") => {
self.verify_python(&temp_path).await
}
ProjectType::Go if extension == Some("go") => {
self.verify_go(&temp_path).await
}
_ => {
return Ok(HookResult::Continue);
}
};
match verify_result {
Ok(VerifyOutcome::Pass) => {
Ok(HookResult::Continue)
}
Ok(VerifyOutcome::Fail { errors, warnings }) => {
let reason = if errors.is_empty() {
format!("⚠️ 代码验证发现警告,建议检查:\n{}", warnings.join("\n"))
} else {
format!("❌ 代码验证失败,请修正以下错误后再写入:\n{}", errors.join("\n"))
};
let details = if !warnings.is_empty() && !errors.is_empty() {
Some(format!("警告:\n{}\n\n错误:\n{}",
warnings.join("\n"),
errors.join("\n")))
} else if !warnings.is_empty() {
Some(format!("警告:\n{}", warnings.join("\n")))
} else {
None
};
Ok(HookResult::Block { reason, details })
}
Err(e) => {
log::warn!("Code verification failed: {}", e);
Ok(HookResult::Continue)
}
}
}
async fn verify_rust(&self, path: &Path) -> Result<VerifyOutcome> {
let fmt_output = tokio::process::Command::new("rustfmt")
.arg("--check")
.arg(path)
.output()
.await;
let mut errors = Vec::new();
let mut warnings = Vec::new();
match fmt_output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
if !stderr.is_empty() {
warnings.push(format!("格式问题: 建议运行 rustfmt"));
}
}
Err(_) => {
}
_ => {}
}
let syntax_output = tokio::process::Command::new("rustc")
.arg("--edition=2021")
.arg("--emit=metadata")
.arg("-o")
.arg("/dev/null") .arg(path)
.output()
.await;
match syntax_output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
for line in stderr.lines() {
if line.contains("error") {
errors.push(line.to_string());
} else if line.contains("warning") {
warnings.push(line.to_string());
}
}
}
Err(_) => {
}
_ => {}
}
if errors.is_empty() {
if let Some(root) = &self.project_root {
let cargo_output = tokio::process::Command::new("cargo")
.args(["check", "--quiet"])
.current_dir(root.as_ref())
.output()
.await;
match cargo_output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
for line in stderr.lines().filter(|l| l.contains("error")) {
errors.push(line.to_string());
}
}
Err(_) => {}
_ => {}
}
}
}
if errors.is_empty() && warnings.is_empty() {
Ok(VerifyOutcome::Pass)
} else {
Ok(VerifyOutcome::Fail { errors, warnings })
}
}
async fn verify_typescript(&self, path: &Path) -> Result<VerifyOutcome> {
let mut errors = Vec::new();
let mut warnings = Vec::new();
let tsc_output = tokio::process::Command::new("npx")
.args(["tsc", "--noEmit", "--skipLibCheck"])
.arg(path)
.output()
.await;
match tsc_output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
let stdout = String::from_utf8_lossy(&o.stdout);
for line in stderr.lines().chain(stdout.lines()) {
if line.contains("error TS") {
errors.push(line.to_string());
}
}
}
Err(_) => {
warnings.push("tsc 不可用,跳过 TypeScript 验证".to_string());
}
_ => {}
}
if errors.is_empty() {
if let Some(root) = &self.project_root {
let project_output = tokio::process::Command::new("npx")
.args(["tsc", "--noEmit"])
.current_dir(root.as_ref())
.output()
.await;
match project_output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
for line in stderr.lines().filter(|l| l.contains("error TS")) {
errors.push(line.to_string());
}
}
Err(_) => {}
_ => {}
}
}
}
if errors.is_empty() && warnings.is_empty() {
Ok(VerifyOutcome::Pass)
} else {
Ok(VerifyOutcome::Fail { errors, warnings })
}
}
async fn verify_python(&self, path: &Path) -> Result<VerifyOutcome> {
let mut errors = Vec::new();
let mut warnings = Vec::new();
let output = tokio::process::Command::new("python")
.args(["-m", "py_compile"])
.arg(path)
.output()
.await;
match output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
for line in stderr.lines() {
if line.contains("SyntaxError") || line.contains("Error") {
errors.push(line.to_string());
}
}
}
Err(_) => {
warnings.push("python 不可用,跳过语法验证".to_string());
}
_ => {}
}
if errors.is_empty() && warnings.is_empty() {
Ok(VerifyOutcome::Pass)
} else {
Ok(VerifyOutcome::Fail { errors, warnings })
}
}
async fn verify_go(&self, path: &Path) -> Result<VerifyOutcome> {
let mut errors = Vec::new();
let mut warnings = Vec::new();
let output = tokio::process::Command::new("go")
.args(["vet"])
.arg(path)
.output()
.await;
match output {
Ok(o) if !o.status.success() => {
let stderr = String::from_utf8_lossy(&o.stderr);
for line in stderr.lines() {
if line.contains("error") || line.contains("undefined") {
errors.push(line.to_string());
}
}
}
Err(_) => {
warnings.push("go vet 不可用,跳过验证".to_string());
}
_ => {}
}
let fmt_output = tokio::process::Command::new("gofmt")
.args(["-l"])
.arg(path)
.output()
.await;
match fmt_output {
Ok(o) if !o.stdout.is_empty() => {
warnings.push("格式问题: 建议运行 gofmt".to_string());
}
Err(_) => {}
_ => {}
}
if errors.is_empty() && warnings.is_empty() {
Ok(VerifyOutcome::Pass)
} else {
Ok(VerifyOutcome::Fail { errors, warnings })
}
}
}
/// Result of running the language-specific checks on a staged file.
#[derive(Debug, Clone)]
enum VerifyOutcome {
    /// No errors and no warnings were collected.
    Pass,
    /// At least one error or warning was collected.
    Fail {
        /// Diagnostic lines classified as errors.
        errors: Vec<String>,
        /// Advisory findings (e.g. formatting issues or an unavailable tool).
        warnings: Vec<String>,
    },
}
#[async_trait]
impl ToolHook for CodeQualityHook {
    fn name(&self) -> &str {
        "code_quality"
    }

    /// Enabled only when the flag is set and the strategy is not
    /// [`VerificationStrategy::None`].
    fn is_enabled(&self) -> bool {
        self.enabled && self.strategy != VerificationStrategy::None
    }

    fn applies_to(&self) -> Vec<&str> {
        vec!["write", "edit", "multi_edit"]
    }

    /// Verifies the full content of a `write` call before it reaches disk. This only
    /// happens while the hook is enabled with a pre-write strategy (`Pre`/`PreQuick`);
    /// every other call passes through untouched.
    ///
    /// # Errors
    ///
    /// Returns an error when a verified `write` call lacks a string `path` or `content`.
    async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult> {
        let pre_write = matches!(
            self.strategy,
            VerificationStrategy::Pre | VerificationStrategy::PreQuick
        );
        // Decide applicability before reading params. `edit`/`multi_edit` are not verified
        // here and may lack a `content` field, so requiring it up front would make those
        // calls fail with a missing-parameter error.
        if !self.is_enabled() || !pre_write || tool_name != "write" {
            return Ok(HookResult::Continue);
        }
        let path = params["path"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("missing 'path' in params"))?;
        let content = params["content"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("missing 'content' in params"))?;
        self.verify_before_write(path, content).await
    }

    /// Appends a `[code_quality_hook: strategy=…]` tag to the tool result while the hook is
    /// enabled. Otherwise the result is returned unchanged.
    async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
        if !self.is_enabled() {
            return Ok(result.to_string());
        }
        Ok(format!(
            "{}\n[code_quality_hook: strategy={}]",
            result,
            self.strategy.to_str()
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_verification_strategy_parse() {
        let cases = [
            ("none", VerificationStrategy::None),
            ("post", VerificationStrategy::Post),
            ("pre", VerificationStrategy::Pre),
            ("pre-quick", VerificationStrategy::PreQuick),
            // Unknown names fall back to the default strategy.
            ("invalid", VerificationStrategy::Post),
        ];
        for (input, expected) in cases {
            assert_eq!(
                VerificationStrategy::from_str(input),
                expected,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_is_code_file() {
        for path in ["test.rs", "test.ts", "test.py", "test.go"] {
            assert!(
                CodeQualityHook::is_code_file(path),
                "{} should be a code file",
                path
            );
        }
        for path in ["test.txt", "test.md"] {
            assert!(
                !CodeQualityHook::is_code_file(path),
                "{} should not be a code file",
                path
            );
        }
    }

    #[test]
    fn test_hook_applies_to() {
        let hook = CodeQualityHook::default();
        let tools = hook.applies_to();
        for expected in ["write", "edit", "multi_edit"] {
            assert!(tools.contains(&expected), "missing {}", expected);
        }
        assert!(!tools.contains(&"read"));
    }

    #[tokio::test]
    async fn test_hook_disabled() {
        let hook = CodeQualityHook::new(VerificationStrategy::None);
        assert!(!hook.is_enabled());
        let params = serde_json::json!({
            "path": "test.rs",
            "content": "fn main() {}"
        });
        let outcome = hook.pre_execute("write", &params).await;
        assert!(matches!(outcome.unwrap(), HookResult::Continue));
    }
}