Skip to main content

aster/hooks/
registry.rs

//! Hook registry.
//!
//! Stores registered hooks grouped by event and provides a process-wide shared registry.
4
5use super::types::{HookConfig, HookEvent, LegacyHookConfig};
6use parking_lot::RwLock;
7use regex::Regex;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// 已注册的 Hooks 存储
12pub type RegisteredHooks = HashMap<HookEvent, Vec<HookConfig>>;
13
14/// Hook 注册表
15#[derive(Debug, Default)]
16pub struct HookRegistry {
17    hooks: RwLock<RegisteredHooks>,
18}
19
20impl HookRegistry {
21    /// 创建新的注册表
22    pub fn new() -> Self {
23        Self {
24            hooks: RwLock::new(HashMap::new()),
25        }
26    }
27
28    /// 注册 hook
29    pub fn register(&self, event: HookEvent, config: HookConfig) {
30        let mut hooks = self.hooks.write();
31        hooks.entry(event).or_default().push(config);
32    }
33
34    /// 注册旧版 hook(兼容性)
35    pub fn register_legacy(&self, config: LegacyHookConfig) {
36        let (event, hook_config) = config.into();
37        self.register(event, hook_config);
38    }
39
40    /// 获取匹配的 hooks
41    pub fn get_matching(&self, event: HookEvent, tool_name: Option<&str>) -> Vec<HookConfig> {
42        let hooks = self.hooks.read();
43        let event_hooks = match hooks.get(&event) {
44            Some(h) => h,
45            None => return vec![],
46        };
47
48        event_hooks
49            .iter()
50            .filter(|hook| {
51                if let Some(matcher) = hook.matcher() {
52                    if let Some(name) = tool_name {
53                        // 支持正则匹配
54                        if matcher.starts_with('/') && matcher.ends_with('/') {
55                            let pattern = matcher
56                                .get(1..matcher.len().saturating_sub(1))
57                                .unwrap_or("");
58                            if let Ok(regex) = Regex::new(pattern) {
59                                return regex.is_match(name);
60                            }
61                        }
62                        // 精确匹配
63                        return matcher == name;
64                    }
65                    return false;
66                }
67                true
68            })
69            .cloned()
70            .collect()
71    }
72
73    /// 获取指定事件的 hooks
74    pub fn get_for_event(&self, event: HookEvent) -> Vec<HookConfig> {
75        let hooks = self.hooks.read();
76        hooks.get(&event).cloned().unwrap_or_default()
77    }
78
79    /// 获取所有已注册的 hooks
80    pub fn get_all(&self) -> RegisteredHooks {
81        self.hooks.read().clone()
82    }
83
84    /// 获取所有已注册的 hooks(扁平数组)
85    pub fn get_all_flat(&self) -> Vec<(HookEvent, HookConfig)> {
86        let hooks = self.hooks.read();
87        let mut result = Vec::new();
88        for (event, configs) in hooks.iter() {
89            for config in configs {
90                result.push((*event, config.clone()));
91            }
92        }
93        result
94    }
95
96    /// 获取 hook 总数
97    pub fn count(&self) -> usize {
98        let hooks = self.hooks.read();
99        hooks.values().map(|v| v.len()).sum()
100    }
101
102    /// 获取指定事件的 hook 数量
103    pub fn count_for_event(&self, event: HookEvent) -> usize {
104        let hooks = self.hooks.read();
105        hooks.get(&event).map(|v| v.len()).unwrap_or(0)
106    }
107
108    /// 取消注册 hook
109    pub fn unregister(&self, event: HookEvent, config: &HookConfig) -> bool {
110        let mut hooks = self.hooks.write();
111        if let Some(event_hooks) = hooks.get_mut(&event) {
112            let initial_len = event_hooks.len();
113            event_hooks.retain(|h| !Self::configs_match(h, config));
114            let removed = event_hooks.len() < initial_len;
115            if event_hooks.is_empty() {
116                hooks.remove(&event);
117            }
118            return removed;
119        }
120        false
121    }
122
123    /// 清除指定事件的所有 hooks
124    pub fn clear_event(&self, event: HookEvent) {
125        let mut hooks = self.hooks.write();
126        hooks.remove(&event);
127    }
128
129    /// 清除所有 hooks
130    pub fn clear(&self) {
131        let mut hooks = self.hooks.write();
132        hooks.clear();
133    }
134
135    /// 比较两个配置是否匹配
136    fn configs_match(a: &HookConfig, b: &HookConfig) -> bool {
137        match (a, b) {
138            (HookConfig::Command(a), HookConfig::Command(b)) => a.command == b.command,
139            (HookConfig::Url(a), HookConfig::Url(b)) => a.url == b.url,
140            (HookConfig::Mcp(a), HookConfig::Mcp(b)) => a.server == b.server && a.tool == b.tool,
141            (HookConfig::Prompt(a), HookConfig::Prompt(b)) => a.prompt == b.prompt,
142            (HookConfig::Agent(a), HookConfig::Agent(b)) => a.agent_type == b.agent_type,
143            _ => false,
144        }
145    }
146}
147
148/// 共享的 Hook 注册表
149pub type SharedHookRegistry = Arc<HookRegistry>;
150
151/// 全局注册表
152static GLOBAL_REGISTRY: once_cell::sync::Lazy<SharedHookRegistry> =
153    once_cell::sync::Lazy::new(|| Arc::new(HookRegistry::new()));
154
155/// 获取全局注册表
156pub fn global_registry() -> SharedHookRegistry {
157    GLOBAL_REGISTRY.clone()
158}
159
160/// 注册 hook 到全局注册表
161pub fn register_hook(event: HookEvent, config: HookConfig) {
162    global_registry().register(event, config);
163}
164
165/// 注册旧版 hook 到全局注册表
166pub fn register_legacy_hook(config: LegacyHookConfig) {
167    global_registry().register_legacy(config);
168}
169
170/// 清除全局注册表
171pub fn clear_hooks() {
172    global_registry().clear();
173}
174
175/// 获取 hook 总数
176pub fn get_hook_count() -> usize {
177    global_registry().count()
178}
179
180/// 获取指定事件的 hook 数量
181pub fn get_event_hook_count(event: HookEvent) -> usize {
182    global_registry().count_for_event(event)
183}