Skip to main content

aster/tools/
hooks.rs

//! Tool hook system.
//!
//! Provides hook support for tool execution, allowing custom logic to be triggered
//! before and after a tool runs.
4
5use anyhow::Result;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use super::context::{ToolContext, ToolResult};
12
/// Point in a tool's lifecycle at which registered hooks are fired.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HookTrigger {
    /// Fired before the tool executes.
    PreExecution,
    /// Fired after the tool has executed.
    PostExecution,
    /// Fired when the tool execution fails.
    OnError,
}
23
24/// 钩子执行上下文
25#[derive(Debug, Clone)]
26pub struct HookContext {
27    /// 工具名称
28    pub tool_name: String,
29    /// 工具参数
30    pub tool_params: serde_json::Value,
31    /// 工具执行结果(仅在 PostExecution 时有效)
32    pub tool_result: Option<ToolResult>,
33    /// 错误信息(仅在 OnError 时有效)
34    pub error_message: Option<String>,
35    /// 工具执行上下文
36    pub tool_context: ToolContext,
37    /// 额外元数据
38    pub metadata: HashMap<String, String>,
39}
40
41impl HookContext {
42    pub fn new(
43        tool_name: String,
44        tool_params: serde_json::Value,
45        tool_context: ToolContext,
46    ) -> Self {
47        Self {
48            tool_name,
49            tool_params,
50            tool_result: None,
51            error_message: None,
52            tool_context,
53            metadata: HashMap::new(),
54        }
55    }
56
57    pub fn with_result(mut self, result: ToolResult) -> Self {
58        self.tool_result = Some(result);
59        self
60    }
61
62    pub fn with_error(mut self, error: String) -> Self {
63        self.error_message = Some(error);
64        self
65    }
66
67    pub fn with_metadata(mut self, key: String, value: String) -> Self {
68        self.metadata.insert(key, value);
69        self
70    }
71}
72
73/// 工具钩子特征
74#[async_trait]
75pub trait ToolHook: Send + Sync {
76    /// 钩子名称
77    fn name(&self) -> &str;
78
79    /// 钩子描述
80    fn description(&self) -> &str;
81
82    /// 执行钩子
83    async fn execute(&self, context: &HookContext) -> Result<()>;
84
85    /// 检查是否应该执行此钩子
86    fn should_execute(&self, _context: &HookContext) -> bool {
87        true // 默认总是执行
88    }
89
90    /// 钩子优先级(数字越小优先级越高)
91    fn priority(&self) -> u32 {
92        100
93    }
94}
95
96/// 日志钩子 - 记录工具执行日志
97pub struct LoggingHook {
98    name: String,
99    log_level: tracing::Level,
100}
101
102impl LoggingHook {
103    pub fn new(name: String, log_level: tracing::Level) -> Self {
104        Self { name, log_level }
105    }
106}
107
108#[async_trait]
109impl ToolHook for LoggingHook {
110    fn name(&self) -> &str {
111        &self.name
112    }
113
114    fn description(&self) -> &str {
115        "记录工具执行日志"
116    }
117
118    async fn execute(&self, context: &HookContext) -> Result<()> {
119        match self.log_level {
120            tracing::Level::ERROR => {
121                tracing::error!(
122                    tool = %context.tool_name,
123                    params = %context.tool_params,
124                    "工具执行"
125                );
126            }
127            tracing::Level::WARN => {
128                tracing::warn!(
129                    tool = %context.tool_name,
130                    params = %context.tool_params,
131                    "工具执行"
132                );
133            }
134            tracing::Level::INFO => {
135                tracing::info!(
136                    tool = %context.tool_name,
137                    params = %context.tool_params,
138                    "工具执行"
139                );
140            }
141            tracing::Level::DEBUG => {
142                tracing::debug!(
143                    tool = %context.tool_name,
144                    params = %context.tool_params,
145                    "工具执行"
146                );
147            }
148            tracing::Level::TRACE => {
149                tracing::trace!(
150                    tool = %context.tool_name,
151                    params = %context.tool_params,
152                    "工具执行"
153                );
154            }
155        }
156        Ok(())
157    }
158
159    fn priority(&self) -> u32 {
160        10 // 高优先级,确保日志记录
161    }
162}
163
/// File-operation hook: runs extra logic around tools that operate on files.
pub struct FileOperationHook {
    /// Name reported through `ToolHook::name`.
    name: String,
    /// Substrings of tool names that this hook applies to.
    target_tools: Vec<String>,
}
169
170impl FileOperationHook {
171    pub fn new(name: String, target_tools: Vec<String>) -> Self {
172        Self { name, target_tools }
173    }
174}
175
176#[async_trait]
177impl ToolHook for FileOperationHook {
178    fn name(&self) -> &str {
179        &self.name
180    }
181
182    fn description(&self) -> &str {
183        "文件操作钩子"
184    }
185
186    async fn execute(&self, context: &HookContext) -> Result<()> {
187        // 检查是否是文件操作工具
188        if self
189            .target_tools
190            .iter()
191            .any(|tool| context.tool_name.contains(tool))
192        {
193            tracing::info!("文件操作检测: 工具 {} 正在操作文件", context.tool_name);
194
195            // 可以在这里添加文件备份、权限检查等逻辑
196            if let Some(path) = context.tool_params.get("path").and_then(|p| p.as_str()) {
197                tracing::debug!("操作文件路径: {}", path);
198            }
199        }
200        Ok(())
201    }
202
203    fn should_execute(&self, context: &HookContext) -> bool {
204        self.target_tools
205            .iter()
206            .any(|tool| context.tool_name.contains(tool))
207    }
208}
209
210/// 错误跟踪钩子 - 跟踪和学习错误模式
211pub struct ErrorTrackingHook {
212    name: String,
213    error_history: Arc<RwLock<HashMap<String, Vec<String>>>>,
214}
215
216impl ErrorTrackingHook {
217    pub fn new(name: String) -> Self {
218        Self {
219            name,
220            error_history: Arc::new(RwLock::new(HashMap::new())),
221        }
222    }
223
224    /// 获取工具的错误历史
225    pub async fn get_error_history(&self, tool_name: &str) -> Vec<String> {
226        let history = self.error_history.read().await;
227        history.get(tool_name).cloned().unwrap_or_default()
228    }
229
230    /// 检查是否是重复错误
231    pub async fn is_repeated_error(&self, tool_name: &str, error: &str) -> bool {
232        let history = self.error_history.read().await;
233        if let Some(errors) = history.get(tool_name) {
234            errors
235                .iter()
236                .any(|e| e.contains(error) || error.contains(e))
237        } else {
238            false
239        }
240    }
241}
242
243#[async_trait]
244impl ToolHook for ErrorTrackingHook {
245    fn name(&self) -> &str {
246        &self.name
247    }
248
249    fn description(&self) -> &str {
250        "跟踪工具执行错误"
251    }
252
253    async fn execute(&self, context: &HookContext) -> Result<()> {
254        if let Some(error_msg) = &context.error_message {
255            let mut history = self.error_history.write().await;
256            let tool_errors = history
257                .entry(context.tool_name.clone())
258                .or_insert_with(Vec::new);
259
260            // 避免重复记录相同错误
261            if !tool_errors.iter().any(|e| e == error_msg) {
262                tool_errors.push(error_msg.clone());
263
264                // 限制历史记录数量
265                if tool_errors.len() > 10 {
266                    tool_errors.remove(0);
267                }
268
269                tracing::warn!(
270                    tool = %context.tool_name,
271                    error = %error_msg,
272                    "记录工具错误"
273                );
274            }
275        }
276        Ok(())
277    }
278
279    fn should_execute(&self, context: &HookContext) -> bool {
280        context.error_message.is_some()
281    }
282}
283
284/// 钩子集合类型别名
285type HookCollection = HashMap<HookTrigger, Vec<Box<dyn ToolHook>>>;
286
287/// 工具钩子管理器
288pub struct ToolHookManager {
289    hooks: Arc<RwLock<HookCollection>>,
290    enabled: bool,
291}
292
293impl ToolHookManager {
294    /// 创建新的钩子管理器
295    pub fn new(enabled: bool) -> Self {
296        Self {
297            hooks: Arc::new(RwLock::new(HashMap::new())),
298            enabled,
299        }
300    }
301
302    /// 注册钩子
303    pub async fn register_hook(&self, trigger: HookTrigger, hook: Box<dyn ToolHook>) {
304        if !self.enabled {
305            return;
306        }
307
308        let mut hooks = self.hooks.write().await;
309        let hook_list = hooks.entry(trigger).or_insert_with(Vec::new);
310        hook_list.push(hook);
311
312        // 按优先级排序
313        hook_list.sort_by_key(|h| h.priority());
314    }
315
316    /// 触发钩子
317    pub async fn trigger_hooks(&self, trigger: HookTrigger, context: &HookContext) -> Result<()> {
318        if !self.enabled {
319            return Ok(());
320        }
321
322        let hooks = self.hooks.read().await;
323        if let Some(hook_list) = hooks.get(&trigger) {
324            for hook in hook_list {
325                if hook.should_execute(context) {
326                    if let Err(e) = hook.execute(context).await {
327                        tracing::warn!("钩子 '{}' 执行失败: {}", hook.name(), e);
328                    }
329                }
330            }
331        }
332
333        Ok(())
334    }
335
336    /// 获取已注册的钩子数量
337    pub async fn hook_count(&self, trigger: HookTrigger) -> usize {
338        let hooks = self.hooks.read().await;
339        hooks.get(&trigger).map(|list| list.len()).unwrap_or(0)
340    }
341
342    /// 检查是否启用
343    pub fn is_enabled(&self) -> bool {
344        self.enabled
345    }
346
347    /// 启用/禁用钩子系统
348    pub fn set_enabled(&mut self, enabled: bool) {
349        self.enabled = enabled;
350    }
351
352    /// 注册默认钩子
353    pub async fn register_default_hooks(&self) {
354        // 注册日志钩子
355        self.register_hook(
356            HookTrigger::PreExecution,
357            Box::new(LoggingHook::new(
358                "pre_execution_log".to_string(),
359                tracing::Level::DEBUG,
360            )),
361        )
362        .await;
363
364        self.register_hook(
365            HookTrigger::PostExecution,
366            Box::new(LoggingHook::new(
367                "post_execution_log".to_string(),
368                tracing::Level::DEBUG,
369            )),
370        )
371        .await;
372
373        // 注册文件操作钩子
374        self.register_hook(
375            HookTrigger::PreExecution,
376            Box::new(FileOperationHook::new(
377                "file_operation_check".to_string(),
378                vec!["Write".to_string(), "Edit".to_string(), "Read".to_string()],
379            )),
380        )
381        .await;
382
383        // 注册错误跟踪钩子
384        self.register_hook(
385            HookTrigger::OnError,
386            Box::new(ErrorTrackingHook::new("error_tracker".to_string())),
387        )
388        .await;
389    }
390}
391
392impl Default for ToolHookManager {
393    fn default() -> Self {
394        Self::new(true)
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Hook that flips a shared flag when it is executed.
    struct TestHook {
        name: String,
        executed: Arc<RwLock<bool>>,
    }

    impl TestHook {
        /// Returns the hook together with a handle to its "was executed" flag.
        fn new(name: String) -> (Self, Arc<RwLock<bool>>) {
            let flag = Arc::new(RwLock::new(false));
            let hook = TestHook {
                name,
                executed: Arc::clone(&flag),
            };
            (hook, flag)
        }
    }

    #[async_trait]
    impl ToolHook for TestHook {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "测试钩子"
        }

        async fn execute(&self, _context: &HookContext) -> Result<()> {
            *self.executed.write().await = true;
            Ok(())
        }
    }

    /// Context for a tool named `TestTool` with a single dummy parameter.
    fn create_test_context() -> HookContext {
        let params = serde_json::json!({"test": "value"});
        let tool_context = ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user");

        HookContext::new("TestTool".to_string(), params, tool_context)
    }

    #[tokio::test]
    async fn test_hook_manager_creation() {
        let enabled_manager = ToolHookManager::new(true);
        assert!(enabled_manager.is_enabled());

        let disabled_manager = ToolHookManager::new(false);
        assert!(!disabled_manager.is_enabled());
    }

    #[tokio::test]
    async fn test_hook_registration_and_execution() {
        let manager = ToolHookManager::new(true);
        let (hook, executed) = TestHook::new("test_hook".to_string());

        manager
            .register_hook(HookTrigger::PreExecution, Box::new(hook))
            .await;
        let registered = manager.hook_count(HookTrigger::PreExecution).await;
        assert_eq!(registered, 1);

        let context = create_test_context();
        manager
            .trigger_hooks(HookTrigger::PreExecution, &context)
            .await
            .unwrap();

        assert!(*executed.read().await);
    }

    #[tokio::test]
    async fn test_error_tracking_hook() {
        let tracker = ErrorTrackingHook::new("error_tracker".to_string());
        let context = create_test_context().with_error("Test error message".to_string());

        tracker.execute(&context).await.unwrap();

        let history = tracker.get_error_history("TestTool").await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0], "Test error message");

        // A fragment of a recorded message counts as a repeated error.
        assert!(tracker.is_repeated_error("TestTool", "Test error").await);
    }

    #[tokio::test]
    async fn test_file_operation_hook() {
        let hook = FileOperationHook::new("file_hook".to_string(), vec!["Write".to_string()]);

        // "TestTool" does not contain "Write", so it must not match.
        let unrelated_context = create_test_context();
        assert!(!hook.should_execute(&unrelated_context));

        // "WriteTool" contains "Write", so it must match.
        let write_context = HookContext::new(
            "WriteTool".to_string(),
            serde_json::json!({"path": "/test/file.txt"}),
            ToolContext::new(PathBuf::from("/tmp")),
        );
        assert!(hook.should_execute(&write_context));
    }
}