Skip to main content

orcs_hook/
registry.rs

//! Hook registry — central dispatch for all hooks.
//!
//! The registry itself performs no internal locking; it is made safe to share
//! by wrapping it in `Arc<std::sync::RwLock<>>` at the engine level, following
//! the same pattern as `SharedChannelHandles`.
5
6use crate::{Hook, HookAction, HookContext, HookPoint};
7use orcs_types::ComponentId;
8use std::collections::HashMap;
9
/// A registered hook with metadata.
///
/// One entry in a per-[`HookPoint`] list inside [`HookRegistry`]; the
/// registry keeps those lists sorted by `hook.priority()`.
struct RegisteredHook {
    /// The hook implementation. The registry reads its `id()`,
    /// `hook_point()`, `priority()` and `fql_pattern()` and calls `execute()`.
    hook: Box<dyn Hook>,
    /// When `false`, `dispatch()` skips this hook. New registrations start
    /// enabled; toggled via `set_enabled()`.
    enabled: bool,
    /// The Component that owns this hook (for auto-unregister on shutdown).
    /// Config-derived hooks have `owner: None` (hooks added via `register()`
    /// rather than `register_owned()`).
    owner: Option<ComponentId>,
}
18
/// Central registry for all hooks.
///
/// Hooks are indexed by [`HookPoint`] in a `HashMap`, so finding the hooks for
/// a point is an (average) O(1) lookup. Within each point, hooks are kept
/// sorted by priority (ascending); hooks with equal priority keep their
/// registration order.
///
/// # Concurrency
///
/// The registry performs no internal locking. Use
/// `Arc<std::sync::RwLock<HookRegistry>>` for concurrent access:
/// - `dispatch()` takes `&self` (read lock)
/// - `register()` / `unregister()` take `&mut self` (write lock)
pub struct HookRegistry {
    /// Per-point hook lists, each sorted by ascending priority.
    hooks: HashMap<HookPoint, Vec<RegisteredHook>>,
}
32
33impl HookRegistry {
34    /// Creates an empty registry.
35    #[must_use]
36    pub fn new() -> Self {
37        Self {
38            hooks: HashMap::new(),
39        }
40    }
41
42    /// Registers a hook. Returns the hook's ID.
43    ///
44    /// The hook is inserted in priority order (ascending).
45    pub fn register(&mut self, hook: Box<dyn Hook>) -> String {
46        self.register_inner(hook, None)
47    }
48
49    /// Registers a hook owned by a component.
50    ///
51    /// Owned hooks are automatically unregistered when
52    /// `unregister_by_owner()` is called (e.g., on component shutdown).
53    pub fn register_owned(&mut self, hook: Box<dyn Hook>, owner: ComponentId) -> String {
54        self.register_inner(hook, Some(owner))
55    }
56
57    fn register_inner(&mut self, hook: Box<dyn Hook>, owner: Option<ComponentId>) -> String {
58        let id = hook.id().to_string();
59        let point = hook.hook_point();
60        let priority = hook.priority();
61
62        let entry = self.hooks.entry(point).or_default();
63
64        let rh = RegisteredHook {
65            hook,
66            enabled: true,
67            owner,
68        };
69
70        // Insert in priority order (stable: FIFO for same priority)
71        let pos = entry
72            .iter()
73            .position(|h| h.hook.priority() > priority)
74            .unwrap_or(entry.len());
75        entry.insert(pos, rh);
76
77        id
78    }
79
80    /// Unregisters a hook by ID. Returns `true` if found and removed.
81    pub fn unregister(&mut self, id: &str) -> bool {
82        let mut found = false;
83        for hooks in self.hooks.values_mut() {
84            let before = hooks.len();
85            hooks.retain(|rh| rh.hook.id() != id);
86            if hooks.len() < before {
87                found = true;
88            }
89        }
90        found
91    }
92
93    /// Unregisters all hooks owned by the given component.
94    ///
95    /// Returns the number of hooks removed.
96    pub fn unregister_by_owner(&mut self, owner: &ComponentId) -> usize {
97        let mut count = 0;
98        for hooks in self.hooks.values_mut() {
99            let before = hooks.len();
100            hooks.retain(|rh| rh.owner.as_ref() != Some(owner));
101            count += before - hooks.len();
102        }
103        count
104    }
105
106    /// Enables or disables a hook by ID.
107    pub fn set_enabled(&mut self, id: &str, enabled: bool) {
108        for hooks in self.hooks.values_mut() {
109            for rh in hooks.iter_mut() {
110                if rh.hook.id() == id {
111                    rh.enabled = enabled;
112                    return;
113                }
114            }
115        }
116    }
117
118    /// Returns the number of registered hooks.
119    #[must_use]
120    pub fn len(&self) -> usize {
121        self.hooks.values().map(|v| v.len()).sum()
122    }
123
124    /// Returns `true` if no hooks are registered.
125    #[must_use]
126    pub fn is_empty(&self) -> bool {
127        self.len() == 0
128    }
129
130    /// Dispatches hooks for the given point and target.
131    ///
132    /// Hooks are executed in priority order (ascending).
133    /// Chain semantics:
134    ///
135    /// - **Pre-hook**: `Skip` or `Abort` → stop chain immediately
136    /// - **Post-hook**: `Replace` → update payload, continue chain
137    /// - Disabled hooks and FQL non-matches are skipped
138    /// - Depth exceeded → break with warning
139    pub fn dispatch(
140        &self,
141        point: HookPoint,
142        component_id: &ComponentId,
143        child_id: Option<&str>,
144        ctx: HookContext,
145    ) -> HookAction {
146        let Some(hooks) = self.hooks.get(&point) else {
147            return HookAction::Continue(Box::new(ctx));
148        };
149
150        let mut current_ctx = ctx;
151
152        for rh in hooks.iter().filter(|rh| rh.enabled) {
153            if !rh.hook.fql_pattern().matches(component_id, child_id) {
154                continue;
155            }
156
157            // Depth check (recursion prevention)
158            if current_ctx.is_depth_exceeded() {
159                tracing::warn!(
160                    hook_id = rh.hook.id(),
161                    depth = current_ctx.depth,
162                    max_depth = current_ctx.max_depth,
163                    "hook chain depth exceeded, stopping chain"
164                );
165                break;
166            }
167
168            match rh.hook.execute(current_ctx.clone()) {
169                HookAction::Continue(new_ctx) => {
170                    current_ctx = *new_ctx;
171                }
172                action @ (HookAction::Skip(_) | HookAction::Abort { .. }) => {
173                    // Pre-hook aborted the chain
174                    return action;
175                }
176                HookAction::Replace(value) => {
177                    if point.is_post() {
178                        // Post-hook: replace payload, continue chain
179                        current_ctx.payload = value;
180                    } else {
181                        // Pre-hook: Replace is invalid → ignore with warning
182                        tracing::warn!(
183                            hook_id = rh.hook.id(),
184                            point = %point,
185                            "Replace returned from non-post hook, ignoring"
186                        );
187                    }
188                }
189            }
190        }
191
192        HookAction::Continue(Box::new(current_ctx))
193    }
194}
195
196impl Default for HookRegistry {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hook::testing::MockHook;
    use orcs_types::{ChannelId, Principal};
    use serde_json::json;
    use std::sync::atomic::Ordering;

    /// Builds a context for `point` targeting `builtin::llm` with payload
    /// `{"op": "test"}`.
    fn test_ctx(point: HookPoint) -> HookContext {
        HookContext::new(
            point,
            ComponentId::builtin("llm"),
            ChannelId::new(),
            Principal::System,
            0,
            json!({"op": "test"}),
        )
    }

    /// Dispatches `ctx` at `point` for the `builtin::llm` component (no child).
    fn dispatch_llm(reg: &HookRegistry, point: HookPoint, ctx: HookContext) -> HookAction {
        reg.dispatch(point, &ComponentId::builtin("llm"), None, ctx)
    }

    /// Unwraps the context carried by a `Continue` action, panicking otherwise.
    fn expect_continue(action: HookAction) -> HookContext {
        match action {
            HookAction::Continue(ctx) => *ctx,
            _ => panic!("expected Continue"),
        }
    }

    // ── Basic dispatch ───────────────────────────────────────

    #[test]
    fn dispatch_no_hooks_returns_continue() {
        let reg = HookRegistry::new();
        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let expected = ctx.payload.clone();

        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));
        assert_eq!(result.payload, expected);
    }

    #[test]
    fn dispatch_pass_through_hook() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let counter = hook.call_count.clone();
        reg.register(Box::new(hook));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let expected = ctx.payload.clone();
        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));

        assert_eq!(counter.load(Ordering::SeqCst), 1);
        // A pass-through hook must not alter the payload.
        assert_eq!(result.payload, expected);
    }

    #[test]
    fn dispatch_modifying_hook() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::modifier("mod", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload = json!({"modified": true});
        });
        reg.register(Box::new(hook));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));
        assert_eq!(result.payload, json!({"modified": true}));
    }

    // ── Skip & Abort ─────────────────────────────────────────

    #[test]
    fn dispatch_skip_stops_chain() {
        let mut reg = HookRegistry::new();
        let skip = MockHook::skipper(
            "skip",
            "*::*",
            HookPoint::RequestPreDispatch,
            json!({"skipped": true}),
        )
        .with_priority(10);
        let after = MockHook::pass_through("after", "*::*", HookPoint::RequestPreDispatch)
            .with_priority(20);
        let after_counter = after.call_count.clone();

        reg.register(Box::new(skip));
        reg.register(Box::new(after));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let action = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx);

        assert!(action.is_skip());
        // "after" hook should not have been called
        assert_eq!(after_counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn dispatch_abort_stops_chain() {
        let mut reg = HookRegistry::new();
        let abort = MockHook::aborter("abort", "*::*", HookPoint::RequestPreDispatch, "policy");
        reg.register(Box::new(abort));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        match dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx) {
            HookAction::Abort { reason } => assert_eq!(reason, "policy"),
            _ => panic!("expected Abort"),
        }
    }

    // ── Priority ordering ────────────────────────────────────

    #[test]
    fn priority_ordering() {
        let mut reg = HookRegistry::new();

        // Register in reverse priority order
        let h100 = MockHook::modifier("h100", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            let arr = ctx
                .payload
                .as_array_mut()
                .expect("payload should be a JSON array for priority ordering test");
            arr.push(json!("h100"));
        })
        .with_priority(100);

        let h10 = MockHook::modifier("h10", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            let arr = ctx
                .payload
                .as_array_mut()
                .expect("payload should be a JSON array for h10 priority test");
            arr.push(json!("h10"));
        })
        .with_priority(10);

        let h50 = MockHook::modifier("h50", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            let arr = ctx
                .payload
                .as_array_mut()
                .expect("payload should be a JSON array for h50 priority test");
            arr.push(json!("h50"));
        })
        .with_priority(50);

        reg.register(Box::new(h100));
        reg.register(Box::new(h10));
        reg.register(Box::new(h50));

        let mut ctx = test_ctx(HookPoint::RequestPreDispatch);
        ctx.payload = json!([]);

        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));
        // Should be ordered by priority: 10, 50, 100
        assert_eq!(result.payload, json!(["h10", "h50", "h100"]));
    }

    #[test]
    fn equal_priority_preserves_registration_order() {
        let mut reg = HookRegistry::new();

        let first = MockHook::modifier("first", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload
                .as_array_mut()
                .expect("payload should be a JSON array for FIFO test")
                .push(json!("first"));
        })
        .with_priority(10);

        let second = MockHook::modifier("second", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            ctx.payload
                .as_array_mut()
                .expect("payload should be a JSON array for FIFO test")
                .push(json!("second"));
        })
        .with_priority(10);

        reg.register(Box::new(first));
        reg.register(Box::new(second));

        let mut ctx = test_ctx(HookPoint::RequestPreDispatch);
        ctx.payload = json!([]);

        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));
        // Same priority → registration (FIFO) order
        assert_eq!(result.payload, json!(["first", "second"]));
    }

    // ── FQL filtering ────────────────────────────────────────

    #[test]
    fn fql_filtering() {
        let mut reg = HookRegistry::new();

        let llm_only =
            MockHook::pass_through("llm-hook", "builtin::llm", HookPoint::RequestPreDispatch);
        let llm_counter = llm_only.call_count.clone();
        reg.register(Box::new(llm_only));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);

        // Dispatch for LLM → should match
        let _ = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx.clone());
        assert_eq!(llm_counter.load(Ordering::SeqCst), 1);

        // Dispatch for HIL → should NOT match
        let _ = reg.dispatch(
            HookPoint::RequestPreDispatch,
            &ComponentId::builtin("hil"),
            None,
            ctx,
        );
        assert_eq!(llm_counter.load(Ordering::SeqCst), 1);
    }

    // ── Enabled/disabled ─────────────────────────────────────

    #[test]
    fn disabled_hook_skipped() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let counter = hook.call_count.clone();
        reg.register(Box::new(hook));

        reg.set_enabled("h1", false);

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let _ = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn re_enable_hook() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let counter = hook.call_count.clone();
        reg.register(Box::new(hook));

        reg.set_enabled("h1", false);
        reg.set_enabled("h1", true);

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let _ = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn set_enabled_unknown_id_is_noop() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let counter = hook.call_count.clone();
        reg.register(Box::new(hook));

        reg.set_enabled("missing", false);

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let _ = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx);

        assert_eq!(reg.len(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // ── Depth exceeded ───────────────────────────────────────

    #[test]
    fn depth_exceeded_breaks_chain() {
        let mut reg = HookRegistry::new();
        let hook = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let counter = hook.call_count.clone();
        reg.register(Box::new(hook));

        let mut ctx = test_ctx(HookPoint::RequestPreDispatch);
        ctx.depth = 4;
        ctx.max_depth = 4;
        let expected = ctx.payload.clone();

        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));

        // Hook should NOT have been called, and the context passes through untouched
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert_eq!(result.payload, expected);
    }

    // ── Unregister ───────────────────────────────────────────

    #[test]
    fn unregister_by_id() {
        let mut reg = HookRegistry::new();
        reg.register(Box::new(MockHook::pass_through(
            "h1",
            "*::*",
            HookPoint::RequestPreDispatch,
        )));
        assert_eq!(reg.len(), 1);

        assert!(reg.unregister("h1"));
        assert_eq!(reg.len(), 0);
        assert!(reg.is_empty());

        assert!(!reg.unregister("h1")); // Already gone
    }

    #[test]
    fn unregister_by_owner() {
        let mut reg = HookRegistry::new();
        let owner = ComponentId::builtin("llm");

        let h1 = MockHook::pass_through("h1", "*::*", HookPoint::RequestPreDispatch);
        let h1_calls = h1.call_count.clone();
        let h3 = MockHook::pass_through("h3", "*::*", HookPoint::RequestPreDispatch);
        let h3_calls = h3.call_count.clone();

        reg.register_owned(Box::new(h1), owner.clone());
        reg.register_owned(
            Box::new(MockHook::pass_through(
                "h2",
                "*::*",
                HookPoint::SignalPreDispatch,
            )),
            owner.clone(),
        );
        reg.register(Box::new(h3));

        assert_eq!(reg.len(), 3);

        assert_eq!(reg.unregister_by_owner(&owner), 2);
        assert_eq!(reg.len(), 1); // h3 remains (no owner)
        assert_eq!(reg.unregister_by_owner(&owner), 0); // nothing left to remove

        // Only the unowned hook still runs
        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let _ = dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx);
        assert_eq!(h1_calls.load(Ordering::SeqCst), 0);
        assert_eq!(h3_calls.load(Ordering::SeqCst), 1);
    }

    // ── Post-hook Replace ────────────────────────────────────

    #[test]
    fn post_hook_replace_updates_payload_and_continues_chain() {
        let mut reg = HookRegistry::new();

        let replacer = MockHook::replacer(
            "replacer",
            "*::*",
            HookPoint::RequestPostDispatch,
            json!({"replaced": true}),
        )
        .with_priority(10);

        let observer = MockHook::pass_through("observer", "*::*", HookPoint::RequestPostDispatch)
            .with_priority(20);
        let observer_counter = observer.call_count.clone();

        reg.register(Box::new(replacer));
        reg.register(Box::new(observer));

        let ctx = test_ctx(HookPoint::RequestPostDispatch);
        let action = dispatch_llm(&reg, HookPoint::RequestPostDispatch, ctx);

        // Chain should continue (not stop at Replace)
        assert_eq!(observer_counter.load(Ordering::SeqCst), 1);

        // Final payload should be the replaced value
        let result = expect_continue(action);
        assert_eq!(result.payload, json!({"replaced": true}));
    }

    #[test]
    fn pre_hook_replace_is_ignored() {
        let mut reg = HookRegistry::new();

        // Replace in a pre-hook should be ignored (treated as Continue)
        let replacer = MockHook::replacer(
            "bad-replacer",
            "*::*",
            HookPoint::RequestPreDispatch,
            json!({"should_not_replace": true}),
        );
        reg.register(Box::new(replacer));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let original_payload = ctx.payload.clone();
        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));

        // Should continue with original payload (Replace ignored)
        assert_eq!(result.payload, original_payload);
    }

    // ── Chain: multiple hooks modify sequentially ─────────────

    #[test]
    fn chain_hooks_modify_sequentially() {
        let mut reg = HookRegistry::new();

        let h1 = MockHook::modifier("h1", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            if let Some(obj) = ctx.payload.as_object_mut() {
                obj.insert("h1".into(), json!(true));
            }
        })
        .with_priority(10);

        let h2 = MockHook::modifier("h2", "*::*", HookPoint::RequestPreDispatch, |ctx| {
            if let Some(obj) = ctx.payload.as_object_mut() {
                obj.insert("h2".into(), json!(true));
            }
        })
        .with_priority(20);

        reg.register(Box::new(h1));
        reg.register(Box::new(h2));

        let ctx = test_ctx(HookPoint::RequestPreDispatch);
        let result = expect_continue(dispatch_llm(&reg, HookPoint::RequestPreDispatch, ctx));

        // Both hooks should have added their keys
        assert_eq!(result.payload["h1"], json!(true));
        assert_eq!(result.payload["h2"], json!(true));
        // Original key should still be there
        assert_eq!(result.payload["op"], json!("test"));
    }

    // ── Misc ─────────────────────────────────────────────────

    #[test]
    fn empty_registry() {
        let reg = HookRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn len_counts_across_points() {
        let mut reg = HookRegistry::new();
        reg.register(Box::new(MockHook::pass_through(
            "h1",
            "*::*",
            HookPoint::RequestPreDispatch,
        )));
        reg.register(Box::new(MockHook::pass_through(
            "h2",
            "*::*",
            HookPoint::SignalPreDispatch,
        )));
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
    }
}