// matrixcode_core/tools/tool_hooks.rs

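//! Tool Execution Hooks System
//!
//! This module provides a hook system for tool execution, allowing checks and
//! modifications before and after a tool runs.
//!
//! # Hook Phases
//!
//! - `pre_execute`: runs before tool execution; can block the call or modify its params
//! - `post_execute`: runs after tool execution; can modify the result
//!
//! # Usage
//!
//! ```ignore
//! // Create a hook registry and register hooks (they run in registration order).
//! // `MyHook` stands for any type implementing `ToolHook`.
//! let mut registry = HookRegistry::new();
//! registry.register(Box::new(MyHook));
//!
//! // Before running the tool, let the hooks block it or rewrite its params.
//! match registry.pre_execute("write", &params).await? {
//!     HookResult::Block { reason, .. } => { /* skip the tool; report `reason` to the AI */ }
//!     HookResult::Modify(new_params) => { /* run the tool with `new_params` */ }
//!     HookResult::Continue => { /* run the tool with the original `params` */ }
//! }
//!
//! // After the tool runs, hooks may rewrite its textual output.
//! let output = registry.post_execute("write", &params, &tool_output).await?;
//! ```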
23
24use anyhow::Result;
25use async_trait::async_trait;
26use serde_json::Value;
27use std::sync::Arc;
28
29/// Hook execution result
30#[derive(Debug, Clone)]
31pub enum HookResult {
32    /// Continue with original params
33    Continue,
34    /// Block execution with reason (returned to AI for correction)
35    Block {
36        reason: String,
37        /// Detailed error information for AI to correct
38        details: Option<String>,
39    },
40    /// Modify params before execution
41    Modify(Value),
42}
43
44/// Tool execution hook trait
45#[async_trait]
46pub trait ToolHook: Send + Sync {
47    /// Hook name for identification
48    fn name(&self) -> &str;
49
50    /// Whether this hook is enabled
51    fn is_enabled(&self) -> bool;
52
53    /// Tools this hook applies to (empty = all tools)
54    fn applies_to(&self) -> Vec<&str> {
55        Vec::new()
56    }
57
58    /// Check if hook applies to a specific tool
59    fn applies_to_tool(&self, tool_name: &str) -> bool {
60        let applies_to = self.applies_to();
61        applies_to.is_empty() || applies_to.iter().any(|t| *t == tool_name)
62    }
63
64    /// Pre-execute hook (before tool execution)
65    /// Returns HookResult to control execution flow
66    async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult>;
67
68    /// Post-execute hook (after tool execution)
69    /// Can modify the result before returning to AI
70    async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String>;
71}
72
73/// Hook registry for managing multiple hooks
74pub struct HookRegistry {
75    hooks: Vec<Box<dyn ToolHook>>,
76}
77
78impl Default for HookRegistry {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl HookRegistry {
85    /// Create empty registry
86    pub fn new() -> Self {
87        Self { hooks: Vec::new() }
88    }
89
90    /// Create registry with default hooks
91    pub fn with_defaults() -> Self {
92        Self::new()
93    }
94
95    /// Register a hook
96    pub fn register(&mut self, hook: Box<dyn ToolHook>) {
97        self.hooks.push(hook);
98    }
99
100    /// Get all registered hooks
101    pub fn hooks(&self) -> &[Box<dyn ToolHook>] {
102        &self.hooks
103    }
104
105    /// Run pre-execute hooks for a tool
106    /// Returns first Block result, or Continue/Modify from last hook
107    pub async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult> {
108        let mut current_params = params.clone();
109
110        for hook in &self.hooks {
111            if !hook.is_enabled() || !hook.applies_to_tool(tool_name) {
112                continue;
113            }
114
115            let result = hook.pre_execute(tool_name, &current_params).await?;
116
117            match result {
118                HookResult::Block { .. } => {
119                    // Block immediately, return to AI
120                    return Ok(result);
121                }
122                HookResult::Modify(new_params) => {
123                    // Update params for next hook
124                    current_params = new_params;
125                }
126                HookResult::Continue => {
127                    // Continue to next hook
128                }
129            }
130        }
131
132        // If params were modified, return Modify; otherwise Continue
133        if current_params != *params {
134            Ok(HookResult::Modify(current_params))
135        } else {
136            Ok(HookResult::Continue)
137        }
138    }
139
140    /// Run post-execute hooks for a tool
141    /// Returns modified result after all hooks
142    pub async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String> {
143        let mut current_result = result.to_string();
144
145        for hook in &self.hooks {
146            if !hook.is_enabled() || !hook.applies_to_tool(tool_name) {
147                continue;
148            }
149
150            current_result = hook.post_execute(tool_name, params, &current_result).await?;
151        }
152
153        Ok(current_result)
154    }
155}
156
157/// Global hook registry instance
158static GLOBAL_HOOK_REGISTRY: std::sync::OnceLock<Arc<HookRegistry>> = std::sync::OnceLock::new();
159
160/// Get the global hook registry
161pub fn global_hook_registry() -> Arc<HookRegistry> {
162    GLOBAL_HOOK_REGISTRY
163        .get_or_init(|| Arc::new(HookRegistry::with_defaults()))
164        .clone()
165}
166
167/// Set the global hook registry
168pub fn set_global_hook_registry(registry: HookRegistry) {
169    let _ = GLOBAL_HOOK_REGISTRY.set(Arc::new(registry));
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook scoped to the `write` tool: blocks or continues in `pre_execute`
    /// depending on `block`, and appends ` [hooked]` in `post_execute`.
    struct TestHook {
        enabled: bool,
        block: bool,
    }

    #[async_trait]
    impl ToolHook for TestHook {
        fn name(&self) -> &str {
            "test_hook"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn applies_to(&self) -> Vec<&str> {
            vec!["write"]
        }

        async fn pre_execute(&self, _tool_name: &str, _params: &Value) -> Result<HookResult> {
            if self.block {
                Ok(HookResult::Block {
                    reason: "Test block".to_string(),
                    details: Some("Test details".to_string()),
                })
            } else {
                Ok(HookResult::Continue)
            }
        }

        async fn post_execute(
            &self,
            _tool_name: &str,
            _params: &Value,
            result: &str,
        ) -> Result<String> {
            Ok(format!("{} [hooked]", result))
        }
    }

    /// Hook for every tool that sets `params[key] = value` and always reports
    /// the result as a modification; `post_execute` passes output through.
    struct SetKeyHook {
        key: &'static str,
        value: Value,
    }

    #[async_trait]
    impl ToolHook for SetKeyHook {
        fn name(&self) -> &str {
            "set_key_hook"
        }

        fn is_enabled(&self) -> bool {
            true
        }

        async fn pre_execute(&self, _tool_name: &str, params: &Value) -> Result<HookResult> {
            let mut updated = params.clone();
            updated[self.key] = self.value.clone();
            Ok(HookResult::Modify(updated))
        }

        async fn post_execute(
            &self,
            _tool_name: &str,
            _params: &Value,
            result: &str,
        ) -> Result<String> {
            Ok(result.to_string())
        }
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_block() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_disabled_hook() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: false, block: true }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_tool_filter() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        // Should not block for tools not in applies_to
        let result = registry.pre_execute("read", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));

        // Should block for tools in applies_to
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_post_execute() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original [hooked]");
    }

    #[tokio::test]
    async fn test_hook_registry_empty_is_passthrough() {
        let registry = HookRegistry::new();

        let result = registry.pre_execute("write", &serde_json::json!({"a": 1})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));

        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original");
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_modify_chains() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetKeyHook { key: "a", value: serde_json::json!(1) }));
        registry.register(Box::new(SetKeyHook { key: "b", value: serde_json::json!(2) }));

        // The second hook must see the params produced by the first one.
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        match result.unwrap() {
            HookResult::Modify(params) => {
                assert_eq!(params, serde_json::json!({"a": 1, "b": 2}));
            }
            other => panic!("expected Modify, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_noop_modify_is_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetKeyHook { key: "a", value: serde_json::json!(1) }));

        // The hook "modifies" the params to the value they already had.
        let result = registry.pre_execute("write", &serde_json::json!({"a": 1})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_block_after_modify() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetKeyHook { key: "a", value: serde_json::json!(1) }));
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_post_execute_chains_in_order() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original [hooked] [hooked]");
    }
}