1use crate::{Hook, HookAction, HookContext, HookPoint};
7use orcs_types::ComponentId;
8use std::collections::HashMap;
9
/// A hook stored in the registry together with its bookkeeping state.
struct RegisteredHook {
    /// The hook implementation that dispatch executes.
    hook: Box<dyn Hook>,
    /// Whether dispatch may run this hook; hooks with `false` are filtered out.
    enabled: bool,
    /// Component the hook was registered for, if any; matched when removing
    /// hooks by owner.
    owner: Option<ComponentId>,
}
18
/// Registry of hooks, grouped by the hook point at which they run.
pub struct HookRegistry {
    /// Hooks per hook point. Registration inserts each hook before the first
    /// existing entry with a strictly greater priority, so every list stays in
    /// ascending priority order.
    hooks: HashMap<HookPoint, Vec<RegisteredHook>>,
}
32
33impl HookRegistry {
34 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 hooks: HashMap::new(),
39 }
40 }
41
42 pub fn register(&mut self, hook: Box<dyn Hook>) -> String {
46 self.register_inner(hook, None)
47 }
48
49 pub fn register_owned(&mut self, hook: Box<dyn Hook>, owner: ComponentId) -> String {
54 self.register_inner(hook, Some(owner))
55 }
56
57 fn register_inner(&mut self, hook: Box<dyn Hook>, owner: Option<ComponentId>) -> String {
58 let id = hook.id().to_string();
59 let point = hook.hook_point();
60 let priority = hook.priority();
61
62 let entry = self.hooks.entry(point).or_default();
63
64 let rh = RegisteredHook {
65 hook,
66 enabled: true,
67 owner,
68 };
69
70 let pos = entry
72 .iter()
73 .position(|h| h.hook.priority() > priority)
74 .unwrap_or(entry.len());
75 entry.insert(pos, rh);
76
77 id
78 }
79
80 pub fn unregister(&mut self, id: &str) -> bool {
82 let mut found = false;
83 for hooks in self.hooks.values_mut() {
84 let before = hooks.len();
85 hooks.retain(|rh| rh.hook.id() != id);
86 if hooks.len() < before {
87 found = true;
88 }
89 }
90 found
91 }
92
93 pub fn unregister_by_owner(&mut self, owner: &ComponentId) -> usize {
97 let mut count = 0;
98 for hooks in self.hooks.values_mut() {
99 let before = hooks.len();
100 hooks.retain(|rh| rh.owner.as_ref() != Some(owner));
101 count += before - hooks.len();
102 }
103 count
104 }
105
106 pub fn set_enabled(&mut self, id: &str, enabled: bool) {
108 for hooks in self.hooks.values_mut() {
109 for rh in hooks.iter_mut() {
110 if rh.hook.id() == id {
111 rh.enabled = enabled;
112 return;
113 }
114 }
115 }
116 }
117
118 #[must_use]
120 pub fn len(&self) -> usize {
121 self.hooks.values().map(|v| v.len()).sum()
122 }
123
124 #[must_use]
126 pub fn is_empty(&self) -> bool {
127 self.len() == 0
128 }
129
130 pub fn dispatch(
140 &self,
141 point: HookPoint,
142 component_id: &ComponentId,
143 child_id: Option<&str>,
144 ctx: HookContext,
145 ) -> HookAction {
146 let Some(hooks) = self.hooks.get(&point) else {
147 return HookAction::Continue(Box::new(ctx));
148 };
149
150 let mut current_ctx = ctx;
151
152 for rh in hooks.iter().filter(|rh| rh.enabled) {
153 if !rh.hook.fql_pattern().matches(component_id, child_id) {
154 continue;
155 }
156
157 if current_ctx.is_depth_exceeded() {
159 tracing::warn!(
160 hook_id = rh.hook.id(),
161 depth = current_ctx.depth,
162 max_depth = current_ctx.max_depth,
163 "hook chain depth exceeded, stopping chain"
164 );
165 break;
166 }
167
168 match rh.hook.execute(current_ctx.clone()) {
169 HookAction::Continue(new_ctx) => {
170 current_ctx = *new_ctx;
171 }
172 action @ (HookAction::Skip(_) | HookAction::Abort { .. }) => {
173 return action;
175 }
176 HookAction::Replace(value) => {
177 if point.is_post() {
178 current_ctx.payload = value;
180 } else {
181 tracing::warn!(
183 hook_id = rh.hook.id(),
184 point = %point,
185 "Replace returned from non-post hook, ignoring"
186 );
187 }
188 }
189 }
190 }
191
192 HookAction::Continue(Box::new(current_ctx))
193 }
194}
195
196impl Default for HookRegistry {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hook::testing::MockHook;
    use orcs_types::{ChannelId, Principal};
    use serde_json::json;
    use std::sync::atomic::Ordering;

    /// Builds a context for `point` whose payload is `{"op": "test"}`.
    fn test_ctx(point: HookPoint) -> HookContext {
        HookContext::new(
            point,
            ComponentId::builtin("llm"),
            ChannelId::new(),
            Principal::System,
            0,
            json!({"op": "test"}),
        )
    }

    /// Dispatches `ctx` at `point` against the builtin `llm` component with no
    /// child ID.
    fn dispatch_llm(reg: &HookRegistry, point: HookPoint, ctx: HookContext) -> HookAction {
        reg.dispatch(point, &ComponentId::builtin("llm"), None, ctx)
    }

    /// Unwraps a `Continue` action into its context, failing the test for any
    /// other action.
    fn expect_continue(action: HookAction) -> HookContext {
        let HookAction::Continue(ctx) = action else {
            panic!("expected Continue");
        };
        *ctx
    }

    // With nothing registered, dispatch hands the context back untouched.
    #[test]
    fn dispatch_no_hooks_returns_continue() {
        let registry = HookRegistry::new();
        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let expected_payload = ctx.payload.clone();

        let action = dispatch_llm(&registry, HookPoint::RequestPreDispatch, ctx);

        assert!(action.is_continue());
        assert_eq!(expect_continue(action).payload, expected_payload);
    }

    #[test]
    fn dispatch_pass_through_hook() {
        let mut registry = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let calls = hook.call_count.clone();
        registry.register(Box::new(hook));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert!(action.is_continue());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn dispatch_modifying_hook() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(MockHook::modifier(
            "mod",
            "*::*",
            HookPoint::RequestPreDispatch,
            |ctx| {
                ctx.payload = json!({"modified": true});
            },
        )));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert_eq!(expect_continue(action).payload, json!({"modified": true}));
    }

    // A Skip from an earlier hook must prevent later hooks from running.
    #[test]
    fn dispatch_skip_stops_chain() {
        let mut registry = HookRegistry::new();
        let skipper = MockHook::skipper(
            "skip",
            "*::*",
            HookPoint::RequestPreDispatch,
            json!({"skipped": true}),
        )
        .with_priority(10);
        let later = MockHook::pass_through("after", "*::*", HookPoint::RequestPreDispatch)
            .with_priority(20);
        let later_calls = later.call_count.clone();

        registry.register(Box::new(skipper));
        registry.register(Box::new(later));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert!(action.is_skip());
        assert_eq!(later_calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn dispatch_abort_stops_chain() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(MockHook::aborter(
            "abort",
            "*::*",
            HookPoint::RequestPreDispatch,
            "policy",
        )));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert!(action.is_abort());
        if let HookAction::Abort { reason } = action {
            assert_eq!(reason, "policy");
        }
    }

    // Hooks registered out of order must still run in ascending priority.
    #[test]
    fn priority_ordering() {
        let mut registry = HookRegistry::new();

        let high = MockHook::modifier("h100", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload
                .as_array_mut()
                .expect("payload should be a JSON array for priority ordering test")
                .push(json!("h100"));
        })
        .with_priority(100);

        let low = MockHook::modifier("h10", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload
                .as_array_mut()
                .expect("payload should be a JSON array for h10 priority test")
                .push(json!("h10"));
        })
        .with_priority(10);

        let middle = MockHook::modifier("h50", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload
                .as_array_mut()
                .expect("payload should be a JSON array for h50 priority test")
                .push(json!("h50"));
        })
        .with_priority(50);

        registry.register(Box::new(high));
        registry.register(Box::new(low));
        registry.register(Box::new(middle));

        let mut ctx = test_ctx(HookPoint::RequestPreDispatch);
        ctx.payload = json!([]);

        let action = dispatch_llm(&registry, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(
            expect_continue(action).payload,
            json!(["h10", "h50", "h100"])
        );
    }

    // A hook scoped to `builtin::llm` fires for that component only.
    #[test]
    fn fql_filtering() {
        let mut registry = HookRegistry::new();

        let scoped =
            MockHook::pass_through("llm-hook", "builtin::llm", HookPoint::RequestPreDispatch);
        let scoped_calls = scoped.call_count.clone();
        registry.register(Box::new(scoped));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);

        // Matching component: the hook runs once.
        dispatch_llm(&registry, HookPoint::RequestPreDispatch, ctx.clone());
        assert_eq!(scoped_calls.load(Ordering::SeqCst), 1);

        // Non-matching component: the call count stays where it was.
        registry.dispatch(
            HookPoint::RequestPreDispatch,
            &ComponentId::builtin("hil"),
            None,
            ctx,
        );
        assert_eq!(scoped_calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn disabled_hook_skipped() {
        let mut registry = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let calls = hook.call_count.clone();
        registry.register(Box::new(hook));

        registry.set_enabled("h1", false);

        dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn re_enable_hook() {
        let mut registry = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let calls = hook.call_count.clone();
        registry.register(Box::new(hook));

        registry.set_enabled("h1", false);
        registry.set_enabled("h1", true);

        dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // A context already at its maximum depth must not reach any hook.
    #[test]
    fn depth_exceeded_breaks_chain() {
        let mut registry = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let calls = hook.call_count.clone();
        registry.register(Box::new(hook));

        let mut ctx = test_ctx(HookPoint::RequestPreDispatch);
        ctx.depth = 4;
        ctx.max_depth = 4;

        dispatch_llm(&registry, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn unregister_by_id() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(MockHook::pass_through(
            "h1",
            "*::*",
            HookPoint::RequestPreDispatch,
        )));
        assert_eq!(registry.len(), 1);

        assert!(registry.unregister("h1"));
        assert_eq!(registry.len(), 0);

        // A second removal finds nothing.
        assert!(!registry.unregister("h1"));
    }

    #[test]
    fn unregister_by_owner() {
        let mut registry = HookRegistry::new();
        let owner = ComponentId::builtin("llm");

        registry.register_owned(
            Box::new(MockHook::pass_through(
                "h1",
                "*::*",
                HookPoint::RequestPreDispatch,
            )),
            owner.clone(),
        );
        registry.register_owned(
            Box::new(MockHook::pass_through(
                "h2",
                "*::*",
                HookPoint::SignalPreDispatch,
            )),
            owner.clone(),
        );
        registry.register(Box::new(MockHook::pass_through(
            "h3",
            "*::*",
            HookPoint::RequestPreDispatch,
        )));

        assert_eq!(registry.len(), 3);

        let removed = registry.unregister_by_owner(&owner);
        assert_eq!(removed, 2);
        // Only the ownerless hook is left.
        assert_eq!(registry.len(), 1);
    }

    // At a post point, Replace swaps the payload and later hooks still run.
    #[test]
    fn post_hook_replace_updates_payload_and_continues_chain() {
        let mut registry = HookRegistry::new();

        let replacer = MockHook::replacer(
            "replacer",
            "*::*",
            HookPoint::RequestPostDispatch,
            json!({"replaced": true}),
        )
        .with_priority(10);

        let observer = MockHook::pass_through("observer", "*::*", HookPoint::RequestPostDispatch)
            .with_priority(20);
        let observer_calls = observer.call_count.clone();

        registry.register(Box::new(replacer));
        registry.register(Box::new(observer));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPostDispatch,
            test_ctx(HookPoint::RequestPostDispatch),
        );

        assert_eq!(observer_calls.load(Ordering::SeqCst), 1);
        assert_eq!(expect_continue(action).payload, json!({"replaced": true}));
    }

    // At a pre point, Replace is ignored and the payload is left as it was.
    #[test]
    fn pre_hook_replace_is_ignored() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(MockHook::replacer(
            "bad-replacer",
            "*::*",
            HookPoint::RequestPreDispatch,
            json!({"should_not_replace": true}),
        )));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let original_payload = ctx.payload.clone();

        let action = dispatch_llm(&registry, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(expect_continue(action).payload, original_payload);
    }

    // Each hook sees the changes made by the hooks that ran before it.
    #[test]
    fn chain_hooks_modify_sequentially() {
        let mut registry = HookRegistry::new();

        let first = MockHook::modifier("h1", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            if let Some(fields) = ctx.payload.as_object_mut() {
                fields.insert("h1".into(), json!(true));
            }
        })
        .with_priority(10);

        let second = MockHook::modifier("h2", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            if let Some(fields) = ctx.payload.as_object_mut() {
                fields.insert("h2".into(), json!(true));
            }
        })
        .with_priority(20);

        registry.register(Box::new(first));
        registry.register(Box::new(second));

        let action = dispatch_llm(
            &registry,
            HookPoint::RequestPreDispatch,
            test_ctx(HookPoint::RequestPreDispatch),
        );

        let result = expect_continue(action);
        assert_eq!(result.payload["h1"], json!(true));
        assert_eq!(result.payload["h2"], json!(true));
        // The original field survives both modifications.
        assert_eq!(result.payload["op"], json!("test"));
    }

    #[test]
    fn empty_registry() {
        let registry = HookRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn len_counts_across_points() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(MockHook::pass_through(
            "h1",
            "*::*",
            HookPoint::RequestPreDispatch,
        )));
        registry.register(Box::new(MockHook::pass_through(
            "h2",
            "*::*",
            HookPoint::SignalPreDispatch,
        )));
        assert_eq!(registry.len(), 2);
        assert!(!registry.is_empty());
    }
}