1use anyhow::Result;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use super::context::{ToolContext, ToolResult};
12
/// Lifecycle point at which a group of registered hooks is fired.
///
/// Each variant is used as a key in the hook registry; the code that calls
/// `ToolHookManager::trigger_hooks` decides when each trigger is fired.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum HookTrigger {
    /// Intended to be fired before a tool executes.
    PreExecution,
    /// Intended to be fired after a tool has executed.
    PostExecution,
    /// Intended to be fired when a tool execution fails.
    OnError,
}
23
24#[derive(Debug, Clone)]
26pub struct HookContext {
27 pub tool_name: String,
29 pub tool_params: serde_json::Value,
31 pub tool_result: Option<ToolResult>,
33 pub error_message: Option<String>,
35 pub tool_context: ToolContext,
37 pub metadata: HashMap<String, String>,
39}
40
41impl HookContext {
42 pub fn new(
43 tool_name: String,
44 tool_params: serde_json::Value,
45 tool_context: ToolContext,
46 ) -> Self {
47 Self {
48 tool_name,
49 tool_params,
50 tool_result: None,
51 error_message: None,
52 tool_context,
53 metadata: HashMap::new(),
54 }
55 }
56
57 pub fn with_result(mut self, result: ToolResult) -> Self {
58 self.tool_result = Some(result);
59 self
60 }
61
62 pub fn with_error(mut self, error: String) -> Self {
63 self.error_message = Some(error);
64 self
65 }
66
67 pub fn with_metadata(mut self, key: String, value: String) -> Self {
68 self.metadata.insert(key, value);
69 self
70 }
71}
72
73#[async_trait]
75pub trait ToolHook: Send + Sync {
76 fn name(&self) -> &str;
78
79 fn description(&self) -> &str;
81
82 async fn execute(&self, context: &HookContext) -> Result<()>;
84
85 fn should_execute(&self, _context: &HookContext) -> bool {
87 true }
89
90 fn priority(&self) -> u32 {
92 100
93 }
94}
95
96pub struct LoggingHook {
98 name: String,
99 log_level: tracing::Level,
100}
101
102impl LoggingHook {
103 pub fn new(name: String, log_level: tracing::Level) -> Self {
104 Self { name, log_level }
105 }
106}
107
108#[async_trait]
109impl ToolHook for LoggingHook {
110 fn name(&self) -> &str {
111 &self.name
112 }
113
114 fn description(&self) -> &str {
115 "记录工具执行日志"
116 }
117
118 async fn execute(&self, context: &HookContext) -> Result<()> {
119 match self.log_level {
120 tracing::Level::ERROR => {
121 tracing::error!(
122 tool = %context.tool_name,
123 params = %context.tool_params,
124 "工具执行"
125 );
126 }
127 tracing::Level::WARN => {
128 tracing::warn!(
129 tool = %context.tool_name,
130 params = %context.tool_params,
131 "工具执行"
132 );
133 }
134 tracing::Level::INFO => {
135 tracing::info!(
136 tool = %context.tool_name,
137 params = %context.tool_params,
138 "工具执行"
139 );
140 }
141 tracing::Level::DEBUG => {
142 tracing::debug!(
143 tool = %context.tool_name,
144 params = %context.tool_params,
145 "工具执行"
146 );
147 }
148 tracing::Level::TRACE => {
149 tracing::trace!(
150 tool = %context.tool_name,
151 params = %context.tool_params,
152 "工具执行"
153 );
154 }
155 }
156 Ok(())
157 }
158
159 fn priority(&self) -> u32 {
160 10 }
162}
163
/// Hook that reports calls to tools whose names match a configured list.
pub struct FileOperationHook {
    /// Identifier returned by `ToolHook::name`.
    name: String,
    /// Substrings looked for in the tool name (case-sensitive).
    target_tools: Vec<String>,
}

impl FileOperationHook {
    /// Creates a hook called `name` that matches tools whose name contains any
    /// entry of `target_tools`.
    pub fn new(name: String, target_tools: Vec<String>) -> Self {
        Self { target_tools, name }
    }
}
175
176#[async_trait]
177impl ToolHook for FileOperationHook {
178 fn name(&self) -> &str {
179 &self.name
180 }
181
182 fn description(&self) -> &str {
183 "文件操作钩子"
184 }
185
186 async fn execute(&self, context: &HookContext) -> Result<()> {
187 if self
189 .target_tools
190 .iter()
191 .any(|tool| context.tool_name.contains(tool))
192 {
193 tracing::info!("文件操作检测: 工具 {} 正在操作文件", context.tool_name);
194
195 if let Some(path) = context.tool_params.get("path").and_then(|p| p.as_str()) {
197 tracing::debug!("操作文件路径: {}", path);
198 }
199 }
200 Ok(())
201 }
202
203 fn should_execute(&self, context: &HookContext) -> bool {
204 self.target_tools
205 .iter()
206 .any(|tool| context.tool_name.contains(tool))
207 }
208}
209
210pub struct ErrorTrackingHook {
212 name: String,
213 error_history: Arc<RwLock<HashMap<String, Vec<String>>>>,
214}
215
216impl ErrorTrackingHook {
217 pub fn new(name: String) -> Self {
218 Self {
219 name,
220 error_history: Arc::new(RwLock::new(HashMap::new())),
221 }
222 }
223
224 pub async fn get_error_history(&self, tool_name: &str) -> Vec<String> {
226 let history = self.error_history.read().await;
227 history.get(tool_name).cloned().unwrap_or_default()
228 }
229
230 pub async fn is_repeated_error(&self, tool_name: &str, error: &str) -> bool {
232 let history = self.error_history.read().await;
233 if let Some(errors) = history.get(tool_name) {
234 errors
235 .iter()
236 .any(|e| e.contains(error) || error.contains(e))
237 } else {
238 false
239 }
240 }
241}
242
243#[async_trait]
244impl ToolHook for ErrorTrackingHook {
245 fn name(&self) -> &str {
246 &self.name
247 }
248
249 fn description(&self) -> &str {
250 "跟踪工具执行错误"
251 }
252
253 async fn execute(&self, context: &HookContext) -> Result<()> {
254 if let Some(error_msg) = &context.error_message {
255 let mut history = self.error_history.write().await;
256 let tool_errors = history
257 .entry(context.tool_name.clone())
258 .or_insert_with(Vec::new);
259
260 if !tool_errors.iter().any(|e| e == error_msg) {
262 tool_errors.push(error_msg.clone());
263
264 if tool_errors.len() > 10 {
266 tool_errors.remove(0);
267 }
268
269 tracing::warn!(
270 tool = %context.tool_name,
271 error = %error_msg,
272 "记录工具错误"
273 );
274 }
275 }
276 Ok(())
277 }
278
279 fn should_execute(&self, context: &HookContext) -> bool {
280 context.error_message.is_some()
281 }
282}
283
284type HookCollection = HashMap<HookTrigger, Vec<Box<dyn ToolHook>>>;
286
287pub struct ToolHookManager {
289 hooks: Arc<RwLock<HookCollection>>,
290 enabled: bool,
291}
292
293impl ToolHookManager {
294 pub fn new(enabled: bool) -> Self {
296 Self {
297 hooks: Arc::new(RwLock::new(HashMap::new())),
298 enabled,
299 }
300 }
301
302 pub async fn register_hook(&self, trigger: HookTrigger, hook: Box<dyn ToolHook>) {
304 if !self.enabled {
305 return;
306 }
307
308 let mut hooks = self.hooks.write().await;
309 let hook_list = hooks.entry(trigger).or_insert_with(Vec::new);
310 hook_list.push(hook);
311
312 hook_list.sort_by_key(|h| h.priority());
314 }
315
316 pub async fn trigger_hooks(&self, trigger: HookTrigger, context: &HookContext) -> Result<()> {
318 if !self.enabled {
319 return Ok(());
320 }
321
322 let hooks = self.hooks.read().await;
323 if let Some(hook_list) = hooks.get(&trigger) {
324 for hook in hook_list {
325 if hook.should_execute(context) {
326 if let Err(e) = hook.execute(context).await {
327 tracing::warn!("钩子 '{}' 执行失败: {}", hook.name(), e);
328 }
329 }
330 }
331 }
332
333 Ok(())
334 }
335
336 pub async fn hook_count(&self, trigger: HookTrigger) -> usize {
338 let hooks = self.hooks.read().await;
339 hooks.get(&trigger).map(|list| list.len()).unwrap_or(0)
340 }
341
342 pub fn is_enabled(&self) -> bool {
344 self.enabled
345 }
346
347 pub fn set_enabled(&mut self, enabled: bool) {
349 self.enabled = enabled;
350 }
351
352 pub async fn register_default_hooks(&self) {
354 self.register_hook(
356 HookTrigger::PreExecution,
357 Box::new(LoggingHook::new(
358 "pre_execution_log".to_string(),
359 tracing::Level::DEBUG,
360 )),
361 )
362 .await;
363
364 self.register_hook(
365 HookTrigger::PostExecution,
366 Box::new(LoggingHook::new(
367 "post_execution_log".to_string(),
368 tracing::Level::DEBUG,
369 )),
370 )
371 .await;
372
373 self.register_hook(
375 HookTrigger::PreExecution,
376 Box::new(FileOperationHook::new(
377 "file_operation_check".to_string(),
378 vec!["Write".to_string(), "Edit".to_string(), "Read".to_string()],
379 )),
380 )
381 .await;
382
383 self.register_hook(
385 HookTrigger::OnError,
386 Box::new(ErrorTrackingHook::new("error_tracker".to_string())),
387 )
388 .await;
389 }
390}
391
392impl Default for ToolHookManager {
393 fn default() -> Self {
394 Self::new(true)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Minimal hook that sets a shared flag when executed.
    struct TestHook {
        name: String,
        executed: Arc<RwLock<bool>>,
    }

    impl TestHook {
        /// Builds the hook together with a handle to its "executed" flag.
        fn new(name: String) -> (Self, Arc<RwLock<bool>>) {
            let flag = Arc::new(RwLock::new(false));
            (
                Self {
                    name,
                    executed: Arc::clone(&flag),
                },
                flag,
            )
        }
    }

    #[async_trait]
    impl ToolHook for TestHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "测试钩子"
        }

        async fn execute(&self, _context: &HookContext) -> Result<()> {
            *self.executed.write().await = true;
            Ok(())
        }
    }

    /// Context for a `TestTool` call with `{"test": "value"}` as parameters.
    fn create_test_context() -> HookContext {
        let tool_context = ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user");

        HookContext::new(
            String::from("TestTool"),
            serde_json::json!({ "test": "value" }),
            tool_context,
        )
    }

    #[tokio::test]
    async fn test_hook_manager_creation() {
        assert!(ToolHookManager::new(true).is_enabled());
        assert!(!ToolHookManager::new(false).is_enabled());
    }

    #[tokio::test]
    async fn test_hook_registration_and_execution() {
        let manager = ToolHookManager::new(true);
        let (hook, executed) = TestHook::new("test_hook".to_string());

        manager
            .register_hook(HookTrigger::PreExecution, Box::new(hook))
            .await;
        assert_eq!(manager.hook_count(HookTrigger::PreExecution).await, 1);

        let context = create_test_context();
        manager
            .trigger_hooks(HookTrigger::PreExecution, &context)
            .await
            .unwrap();

        assert!(*executed.read().await);
    }

    #[tokio::test]
    async fn test_error_tracking_hook() {
        let hook = ErrorTrackingHook::new("error_tracker".to_string());
        let context = create_test_context().with_error("Test error message".to_string());

        hook.execute(&context).await.unwrap();

        let history = hook.get_error_history("TestTool").await;
        assert_eq!(history, vec!["Test error message".to_string()]);

        // A prefix of a recorded message counts as a repeat.
        assert!(hook.is_repeated_error("TestTool", "Test error").await);
    }

    #[tokio::test]
    async fn test_file_operation_hook() {
        let hook = FileOperationHook::new("file_hook".to_string(), vec!["Write".to_string()]);

        // "TestTool" does not contain "Write".
        assert!(!hook.should_execute(&create_test_context()));

        let write_context = HookContext::new(
            "WriteTool".to_string(),
            serde_json::json!({ "path": "/test/file.txt" }),
            ToolContext::new(PathBuf::from("/tmp")),
        );
        assert!(hook.should_execute(&write_context));
    }
}