matrixcode_core/tools/
tool_hooks.rs1use anyhow::Result;
25use async_trait::async_trait;
26use serde_json::Value;
27use std::sync::Arc;
28
29#[derive(Debug, Clone)]
31pub enum HookResult {
32 Continue,
34 Block {
36 reason: String,
37 details: Option<String>,
39 },
40 Modify(Value),
42}
43
44#[async_trait]
46pub trait ToolHook: Send + Sync {
47 fn name(&self) -> &str;
49
50 fn is_enabled(&self) -> bool;
52
53 fn applies_to(&self) -> Vec<&str> {
55 Vec::new()
56 }
57
58 fn applies_to_tool(&self, tool_name: &str) -> bool {
60 let applies_to = self.applies_to();
61 applies_to.is_empty() || applies_to.iter().any(|t| *t == tool_name)
62 }
63
64 async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult>;
67
68 async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String>;
71}
72
73pub struct HookRegistry {
75 hooks: Vec<Box<dyn ToolHook>>,
76}
77
78impl Default for HookRegistry {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl HookRegistry {
85 pub fn new() -> Self {
87 Self { hooks: Vec::new() }
88 }
89
90 pub fn with_defaults() -> Self {
92 Self::new()
93 }
94
95 pub fn register(&mut self, hook: Box<dyn ToolHook>) {
97 self.hooks.push(hook);
98 }
99
100 pub fn hooks(&self) -> &[Box<dyn ToolHook>] {
102 &self.hooks
103 }
104
105 pub async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult> {
108 let mut current_params = params.clone();
109
110 for hook in &self.hooks {
111 if !hook.is_enabled() || !hook.applies_to_tool(tool_name) {
112 continue;
113 }
114
115 let result = hook.pre_execute(tool_name, ¤t_params).await?;
116
117 match result {
118 HookResult::Block { .. } => {
119 return Ok(result);
121 }
122 HookResult::Modify(new_params) => {
123 current_params = new_params;
125 }
126 HookResult::Continue => {
127 }
129 }
130 }
131
132 if current_params != *params {
134 Ok(HookResult::Modify(current_params))
135 } else {
136 Ok(HookResult::Continue)
137 }
138 }
139
140 pub async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String> {
143 let mut current_result = result.to_string();
144
145 for hook in &self.hooks {
146 if !hook.is_enabled() || !hook.applies_to_tool(tool_name) {
147 continue;
148 }
149
150 current_result = hook.post_execute(tool_name, params, ¤t_result).await?;
151 }
152
153 Ok(current_result)
154 }
155}
156
157static GLOBAL_HOOK_REGISTRY: std::sync::OnceLock<Arc<HookRegistry>> = std::sync::OnceLock::new();
159
160pub fn global_hook_registry() -> Arc<HookRegistry> {
162 GLOBAL_HOOK_REGISTRY
163 .get_or_init(|| Arc::new(HookRegistry::with_defaults()))
164 .clone()
165}
166
167pub fn set_global_hook_registry(registry: HookRegistry) {
169 let _ = GLOBAL_HOOK_REGISTRY.set(Arc::new(registry));
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook that targets only the `write` tool, optionally blocks, and tags post-execute output.
    struct TestHook {
        enabled: bool,
        block: bool,
    }

    #[async_trait]
    impl ToolHook for TestHook {
        fn name(&self) -> &str {
            "test_hook"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn applies_to(&self) -> Vec<&str> {
            vec!["write"]
        }

        async fn pre_execute(&self, _tool_name: &str, _params: &Value) -> Result<HookResult> {
            if self.block {
                Ok(HookResult::Block {
                    reason: "Test block".to_string(),
                    details: Some("Test details".to_string()),
                })
            } else {
                Ok(HookResult::Continue)
            }
        }

        async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
            Ok(format!("{} [hooked]", result))
        }
    }

    /// Hook for every tool that sets `params[key] = value` and always answers `Modify`.
    struct SetFieldHook {
        key: &'static str,
        value: i64,
    }

    #[async_trait]
    impl ToolHook for SetFieldHook {
        fn name(&self) -> &str {
            "set_field_hook"
        }

        fn is_enabled(&self) -> bool {
            true
        }

        async fn pre_execute(&self, _tool_name: &str, params: &Value) -> Result<HookResult> {
            let mut updated = params.clone();
            updated[self.key] = serde_json::json!(self.value);
            Ok(HookResult::Modify(updated))
        }

        async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
            Ok(result.to_string())
        }
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_block() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_disabled_hook() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: false, block: true }));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_tool_filter() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        let result = registry.pre_execute("read", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));

        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_post_execute() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original [hooked]");
    }

    #[tokio::test]
    async fn test_empty_registry_is_noop() {
        let registry = HookRegistry::new();

        let result = registry.pre_execute("write", &serde_json::json!({"x": 1})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));

        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original");
    }

    #[tokio::test]
    async fn test_pre_execute_modify_chains_through_hooks() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetFieldHook { key: "a", value: 1 }));
        registry.register(Box::new(SetFieldHook { key: "b", value: 2 }));

        // The second hook must see the first hook's output, so both fields end up set.
        match registry.pre_execute("write", &serde_json::json!({})).await.unwrap() {
            HookResult::Modify(params) => {
                assert_eq!(params, serde_json::json!({"a": 1, "b": 2}));
            }
            other => panic!("expected Modify, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_pre_execute_modify_to_identical_params_is_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetFieldHook { key: "a", value: 1 }));

        let params = serde_json::json!({"a": 1});
        let result = registry.pre_execute("write", &params).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_pre_execute_block_after_modify() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(SetFieldHook { key: "a", value: 1 }));
        registry.register(Box::new(TestHook { enabled: true, block: true }));

        match registry.pre_execute("write", &serde_json::json!({})).await.unwrap() {
            HookResult::Block { reason, details } => {
                assert_eq!(reason, "Test block");
                assert_eq!(details.as_deref(), Some("Test details"));
            }
            other => panic!("expected Block, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_post_execute_skips_non_matching_tool() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));

        let result = registry.post_execute("read", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original");
    }
}