1use super::types::{HookConfig, HookEvent, LegacyHookConfig};
6use parking_lot::RwLock;
7use regex::Regex;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Hook configurations grouped by the [`HookEvent`] they are registered for.
pub type RegisteredHooks = HashMap<HookEvent, Vec<HookConfig>>;
13
/// Registry of hook configurations keyed by event.
///
/// The map is kept behind a `parking_lot::RwLock`, so every method takes
/// `&self` and acquires the lock internally.
#[derive(Debug, Default)]
pub struct HookRegistry {
    /// Registered hooks per event, guarded by a read/write lock.
    hooks: RwLock<RegisteredHooks>,
}
19
20impl HookRegistry {
21 pub fn new() -> Self {
23 Self {
24 hooks: RwLock::new(HashMap::new()),
25 }
26 }
27
28 pub fn register(&self, event: HookEvent, config: HookConfig) {
30 let mut hooks = self.hooks.write();
31 hooks.entry(event).or_default().push(config);
32 }
33
34 pub fn register_legacy(&self, config: LegacyHookConfig) {
36 let (event, hook_config) = config.into();
37 self.register(event, hook_config);
38 }
39
40 pub fn get_matching(&self, event: HookEvent, tool_name: Option<&str>) -> Vec<HookConfig> {
42 let hooks = self.hooks.read();
43 let event_hooks = match hooks.get(&event) {
44 Some(h) => h,
45 None => return vec![],
46 };
47
48 event_hooks
49 .iter()
50 .filter(|hook| {
51 if let Some(matcher) = hook.matcher() {
52 if let Some(name) = tool_name {
53 if matcher.starts_with('/') && matcher.ends_with('/') {
55 let pattern = matcher
56 .get(1..matcher.len().saturating_sub(1))
57 .unwrap_or("");
58 if let Ok(regex) = Regex::new(pattern) {
59 return regex.is_match(name);
60 }
61 }
62 return matcher == name;
64 }
65 return false;
66 }
67 true
68 })
69 .cloned()
70 .collect()
71 }
72
73 pub fn get_for_event(&self, event: HookEvent) -> Vec<HookConfig> {
75 let hooks = self.hooks.read();
76 hooks.get(&event).cloned().unwrap_or_default()
77 }
78
79 pub fn get_all(&self) -> RegisteredHooks {
81 self.hooks.read().clone()
82 }
83
84 pub fn get_all_flat(&self) -> Vec<(HookEvent, HookConfig)> {
86 let hooks = self.hooks.read();
87 let mut result = Vec::new();
88 for (event, configs) in hooks.iter() {
89 for config in configs {
90 result.push((*event, config.clone()));
91 }
92 }
93 result
94 }
95
96 pub fn count(&self) -> usize {
98 let hooks = self.hooks.read();
99 hooks.values().map(|v| v.len()).sum()
100 }
101
102 pub fn count_for_event(&self, event: HookEvent) -> usize {
104 let hooks = self.hooks.read();
105 hooks.get(&event).map(|v| v.len()).unwrap_or(0)
106 }
107
108 pub fn unregister(&self, event: HookEvent, config: &HookConfig) -> bool {
110 let mut hooks = self.hooks.write();
111 if let Some(event_hooks) = hooks.get_mut(&event) {
112 let initial_len = event_hooks.len();
113 event_hooks.retain(|h| !Self::configs_match(h, config));
114 let removed = event_hooks.len() < initial_len;
115 if event_hooks.is_empty() {
116 hooks.remove(&event);
117 }
118 return removed;
119 }
120 false
121 }
122
123 pub fn clear_event(&self, event: HookEvent) {
125 let mut hooks = self.hooks.write();
126 hooks.remove(&event);
127 }
128
129 pub fn clear(&self) {
131 let mut hooks = self.hooks.write();
132 hooks.clear();
133 }
134
135 fn configs_match(a: &HookConfig, b: &HookConfig) -> bool {
137 match (a, b) {
138 (HookConfig::Command(a), HookConfig::Command(b)) => a.command == b.command,
139 (HookConfig::Url(a), HookConfig::Url(b)) => a.url == b.url,
140 (HookConfig::Mcp(a), HookConfig::Mcp(b)) => a.server == b.server && a.tool == b.tool,
141 (HookConfig::Prompt(a), HookConfig::Prompt(b)) => a.prompt == b.prompt,
142 (HookConfig::Agent(a), HookConfig::Agent(b)) => a.agent_type == b.agent_type,
143 _ => false,
144 }
145 }
146}
147
/// Reference-counted handle to a [`HookRegistry`] that can be shared between owners.
pub type SharedHookRegistry = Arc<HookRegistry>;
150
/// Global registry instance, created empty the first time it is accessed.
static GLOBAL_REGISTRY: once_cell::sync::Lazy<SharedHookRegistry> =
    once_cell::sync::Lazy::new(|| Arc::new(HookRegistry::new()));
154
155pub fn global_registry() -> SharedHookRegistry {
157 GLOBAL_REGISTRY.clone()
158}
159
160pub fn register_hook(event: HookEvent, config: HookConfig) {
162 global_registry().register(event, config);
163}
164
165pub fn register_legacy_hook(config: LegacyHookConfig) {
167 global_registry().register_legacy(config);
168}
169
170pub fn clear_hooks() {
172 global_registry().clear();
173}
174
175pub fn get_hook_count() -> usize {
177 global_registry().count()
178}
179
180pub fn get_event_hook_count(event: HookEvent) -> usize {
182 global_registry().count_for_event(event)
183}