1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, RwLock};
3
4use crate::plugin::{
5 BlockReason, HookAction, HookContext, HookIssue, HookPhase, HookReport, PostHook, PreHook,
6};
7
8#[derive(Clone, Default)]
9pub struct RuntimeHookConfig {
10 pub pre_hooks: Vec<Arc<dyn PreHook>>,
11 pub post_hooks: Vec<Arc<dyn PostHook>>,
12 pub pre_tool_use_hooks: Vec<Arc<dyn PreHook>>,
16}
17
18impl std::fmt::Debug for RuntimeHookConfig {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("RuntimeHookConfig")
21 .field("pre_hooks", &hook_names(&self.pre_hooks))
22 .field("post_hooks", &hook_names(&self.post_hooks))
23 .field("pre_tool_use_hooks", &hook_names(&self.pre_tool_use_hooks))
24 .finish()
25 }
26}
27
28impl PartialEq for RuntimeHookConfig {
29 fn eq(&self, other: &Self) -> bool {
30 hook_names(&self.pre_hooks) == hook_names(&other.pre_hooks)
31 && hook_names(&self.post_hooks) == hook_names(&other.post_hooks)
32 && hook_names(&self.pre_tool_use_hooks) == hook_names(&other.pre_tool_use_hooks)
33 }
34}
35
36impl Eq for RuntimeHookConfig {}
37
38impl RuntimeHookConfig {
39 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn with_pre_hook(mut self, hook: Arc<dyn PreHook>) -> Self {
48 self.pre_hooks.push(hook);
49 self
50 }
51
52 pub fn with_post_hook(mut self, hook: Arc<dyn PostHook>) -> Self {
55 self.post_hooks.push(hook);
56 self
57 }
58
59 pub fn with_pre_tool_use_hook(mut self, hook: Arc<dyn PreHook>) -> Self {
62 self.pre_tool_use_hooks.push(hook);
63 self
64 }
65
66 pub fn has_pre_tool_use_hooks(&self) -> bool {
69 !self.pre_tool_use_hooks.is_empty()
70 }
71
72 pub fn is_empty(&self) -> bool {
75 self.pre_hooks.is_empty()
76 && self.post_hooks.is_empty()
77 && self.pre_tool_use_hooks.is_empty()
78 }
79}
80
81pub(crate) fn merge_hook_configs(
84 defaults: &RuntimeHookConfig,
85 overlay: &RuntimeHookConfig,
86) -> RuntimeHookConfig {
87 if defaults.is_empty() {
88 return overlay.clone();
89 }
90 if overlay.is_empty() {
91 return defaults.clone();
92 }
93 RuntimeHookConfig {
94 pre_hooks: merge_preferred_hooks(&overlay.pre_hooks, &defaults.pre_hooks),
95 post_hooks: merge_preferred_hooks(&overlay.post_hooks, &defaults.post_hooks),
96 pre_tool_use_hooks: merge_preferred_hooks(
97 &overlay.pre_tool_use_hooks,
98 &defaults.pre_tool_use_hooks,
99 ),
100 }
101}
102
103pub(crate) struct HookKernel {
104 pre_hooks: RwLock<Vec<Arc<dyn PreHook>>>,
105 post_hooks: RwLock<Vec<Arc<dyn PostHook>>>,
106 pre_tool_use_hooks: RwLock<Vec<Arc<dyn PreHook>>>,
107 thread_scoped_pre_tool_use_hooks: RwLock<HashMap<String, Vec<Arc<dyn PreHook>>>>,
108 latest_report: RwLock<HookReport>,
109}
110
111#[derive(Clone, Debug)]
112pub(crate) struct PreHookDecision {
113 pub hook_name: String,
114 pub action: HookAction,
115}
116
117impl HookKernel {
118 pub(crate) fn new(config: RuntimeHookConfig) -> Self {
119 Self {
120 pre_hooks: RwLock::new(config.pre_hooks),
121 post_hooks: RwLock::new(config.post_hooks),
122 pre_tool_use_hooks: RwLock::new(config.pre_tool_use_hooks),
123 thread_scoped_pre_tool_use_hooks: RwLock::new(HashMap::new()),
124 latest_report: RwLock::new(HookReport::default()),
125 }
126 }
127
128 pub(crate) fn is_enabled(&self) -> bool {
129 rwlock_len(&self.pre_hooks) > 0
130 || rwlock_len(&self.post_hooks) > 0
131 || rwlock_len(&self.pre_tool_use_hooks) > 0
132 }
133
134 pub(crate) fn has_pre_tool_use_hooks(&self) -> bool {
137 rwlock_len(&self.pre_tool_use_hooks) > 0
138 || match self.thread_scoped_pre_tool_use_hooks.read() {
139 Ok(guard) => guard.values().any(|hooks| !hooks.is_empty()),
140 Err(poisoned) => poisoned
141 .into_inner()
142 .values()
143 .any(|hooks| !hooks.is_empty()),
144 }
145 }
146
147 pub(crate) fn register_thread_scoped_pre_tool_use_hooks(
148 &self,
149 thread_id: &str,
150 hooks: &[Arc<dyn PreHook>],
151 ) {
152 if hooks.is_empty() {
153 return;
154 }
155 let mut guard = match self.thread_scoped_pre_tool_use_hooks.write() {
156 Ok(guard) => guard,
157 Err(poisoned) => poisoned.into_inner(),
158 };
159 let entry = guard.entry(thread_id.to_owned()).or_default();
160 let mut names: HashSet<&'static str> = entry.iter().map(|hook| hook.hook_name()).collect();
161 for hook in hooks {
162 if names.insert(hook.hook_name()) {
163 entry.push(Arc::clone(hook));
164 }
165 }
166 }
167
168 pub(crate) fn clear_thread_scoped_pre_tool_use_hooks(&self, thread_id: &str) {
169 let mut guard = match self.thread_scoped_pre_tool_use_hooks.write() {
170 Ok(guard) => guard,
171 Err(poisoned) => poisoned.into_inner(),
172 };
173 guard.remove(thread_id);
174 }
175
176 pub(crate) fn register(&self, config: RuntimeHookConfig) {
180 if config.is_empty() {
181 return;
182 }
183 register_dedup_hooks(&self.pre_hooks, config.pre_hooks);
184 register_dedup_hooks(&self.post_hooks, config.post_hooks);
185 register_dedup_hooks(&self.pre_tool_use_hooks, config.pre_tool_use_hooks);
186 }
187
188 pub(crate) fn report_snapshot(&self) -> HookReport {
189 match self.latest_report.read() {
190 Ok(guard) => guard.clone(),
191 Err(poisoned) => poisoned.into_inner().clone(),
192 }
193 }
194
195 pub(crate) fn set_latest_report(&self, report: HookReport) {
196 match self.latest_report.write() {
197 Ok(mut guard) => *guard = report,
198 Err(poisoned) => *poisoned.into_inner() = report,
199 }
200 }
201
202 pub(crate) async fn run_pre_with(
207 &self,
208 ctx: &HookContext,
209 report: &mut HookReport,
210 scoped: Option<&RuntimeHookConfig>,
211 ) -> Result<Vec<PreHookDecision>, BlockReason> {
212 let hooks = merge_owned_with_overlay(
213 read_rwlock_vec(&self.pre_hooks),
214 scoped.map(|cfg| cfg.pre_hooks.as_slice()),
215 );
216 let mut decisions = Vec::with_capacity(hooks.len());
217 for hook in hooks {
218 match hook.call(ctx).await {
219 Ok(HookAction::Block(reason)) => return Err(reason),
220 Ok(action) => decisions.push(PreHookDecision {
221 hook_name: hook.name().to_owned(),
222 action,
223 }),
224 Err(issue) => report.push(normalize_issue(issue, hook.name(), ctx.phase)),
225 }
226 }
227 Ok(decisions)
228 }
229
230 pub(crate) async fn run_pre_tool_use_with(
235 &self,
236 ctx: &HookContext,
237 report: &mut HookReport,
238 ) -> Result<(), BlockReason> {
239 let mut hooks = read_rwlock_vec(&self.pre_tool_use_hooks);
240 if let Some(thread_id) = ctx.thread_id.as_deref() {
241 let scoped = self.thread_scoped_pre_tool_use_hooks_for(thread_id);
242 hooks = merge_owned_with_overlay(hooks, scoped.as_deref());
243 }
244 for hook in hooks {
245 match hook.call(ctx).await {
246 Ok(HookAction::Block(reason)) => return Err(reason),
247 Ok(_) => {}
248 Err(issue) => report.push(normalize_issue(issue, hook.name(), ctx.phase)),
249 }
250 }
251 Ok(())
252 }
253
254 fn thread_scoped_pre_tool_use_hooks_for(
255 &self,
256 thread_id: &str,
257 ) -> Option<Vec<Arc<dyn PreHook>>> {
258 let guard = match self.thread_scoped_pre_tool_use_hooks.read() {
259 Ok(guard) => guard,
260 Err(poisoned) => poisoned.into_inner(),
261 };
262 guard.get(thread_id).cloned()
263 }
264
265 pub(crate) async fn run_post_with(
268 &self,
269 ctx: &HookContext,
270 report: &mut HookReport,
271 scoped: Option<&RuntimeHookConfig>,
272 ) {
273 let hooks = merge_owned_with_overlay(
274 read_rwlock_vec(&self.post_hooks),
275 scoped.map(|cfg| cfg.post_hooks.as_slice()),
276 );
277 for hook in hooks {
278 if let Err(issue) = hook.call(ctx).await {
279 report.push(normalize_issue(issue, hook.name(), ctx.phase));
280 }
281 }
282 }
283}
284
285fn normalize_issue(mut issue: HookIssue, fallback_name: &str, phase: HookPhase) -> HookIssue {
286 if issue.hook_name.trim().is_empty() {
287 issue.hook_name = fallback_name.to_owned();
288 }
289 issue.phase = phase;
290 issue
291}
292
293fn hook_names<T>(hooks: &[Arc<T>]) -> Vec<&'static str>
294where
295 T: ?Sized + HookName,
296{
297 hooks.iter().map(|hook| hook.hook_name()).collect()
298}
299
300trait HookName {
301 fn hook_name(&self) -> &'static str;
302}
303
304impl HookName for dyn PreHook {
305 fn hook_name(&self) -> &'static str {
306 self.name()
307 }
308}
309
310impl HookName for dyn PostHook {
311 fn hook_name(&self) -> &'static str {
312 self.name()
313 }
314}
315
/// Returns the number of hooks stored in `target`.
///
/// If the lock is poisoned, the data is read anyway.
fn rwlock_len<T: ?Sized>(target: &RwLock<Vec<Arc<T>>>) -> usize {
    let hooks = target.read().unwrap_or_else(|poisoned| poisoned.into_inner());
    hooks.len()
}
324
/// Snapshots the hooks in `target` as a new `Vec` of shared `Arc` handles.
///
/// The read lock is released before this returns. A poisoned lock is read anyway.
fn read_rwlock_vec<T: ?Sized>(target: &RwLock<Vec<Arc<T>>>) -> Vec<Arc<T>> {
    let hooks = target.read().unwrap_or_else(|poisoned| poisoned.into_inner());
    hooks.to_vec()
}
333
334fn merge_preferred_hooks<T>(preferred: &[Arc<T>], fallback: &[Arc<T>]) -> Vec<Arc<T>>
335where
336 T: ?Sized + HookName,
337{
338 let mut merged = Vec::with_capacity(preferred.len() + fallback.len());
339 let mut names: HashSet<&'static str> = HashSet::with_capacity(preferred.len() + fallback.len());
340 for hook in preferred {
341 if names.insert(hook.hook_name()) {
342 merged.push(Arc::clone(hook));
343 }
344 }
345 for hook in fallback {
346 if names.insert(hook.hook_name()) {
347 merged.push(Arc::clone(hook));
348 }
349 }
350 merged
351}
352
353fn merge_owned_with_overlay<T>(mut base: Vec<Arc<T>>, overlay: Option<&[Arc<T>]>) -> Vec<Arc<T>>
354where
355 T: ?Sized + HookName,
356{
357 let Some(overlay) = overlay else {
358 return base;
359 };
360 if overlay.is_empty() {
361 return base;
362 }
363 let mut names: HashSet<&'static str> = base.iter().map(|hook| hook.hook_name()).collect();
364 for hook in overlay {
365 if names.insert(hook.hook_name()) {
366 base.push(Arc::clone(hook));
367 }
368 }
369 base
370}
371
372fn register_dedup_hooks<T>(target: &RwLock<Vec<Arc<T>>>, incoming: Vec<Arc<T>>)
375where
376 T: ?Sized + HookName,
377{
378 let mut guard = match target.write() {
379 Ok(guard) => guard,
380 Err(poisoned) => poisoned.into_inner(),
381 };
382 let mut names: HashSet<&'static str> = guard.iter().map(|hook| hook.hook_name()).collect();
383 for hook in incoming {
384 if names.insert(hook.hook_name()) {
385 guard.push(hook);
386 }
387 }
388}