1#![deny(missing_docs)]
2use layer0::hook::{Hook, HookAction, HookContext};
11use std::sync::Arc;
12
13pub struct HookRegistry {
19 hooks: Vec<Arc<dyn Hook>>,
20}
21
22impl HookRegistry {
23 pub fn new() -> Self {
25 Self { hooks: Vec::new() }
26 }
27
28 pub fn add(&mut self, hook: Arc<dyn Hook>) {
30 self.hooks.push(hook);
31 }
32
33 pub async fn dispatch(&self, ctx: &HookContext) -> HookAction {
40 for hook in &self.hooks {
41 if !hook.points().contains(&ctx.point) {
43 continue;
44 }
45
46 match hook.on_event(ctx).await {
47 Ok(HookAction::Continue) => continue,
48 Ok(action) => return action,
49 Err(_e) => {
50 continue;
53 }
54 }
55 }
56
57 HookAction::Continue
58 }
59}
60
61impl Default for HookRegistry {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use layer0::error::HookError;
    use layer0::hook::HookPoint;

    /// Test hook that always lets dispatch proceed.
    struct ContinueHook {
        points: Vec<HookPoint>,
    }

    #[async_trait]
    impl Hook for ContinueHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Ok(HookAction::Continue)
        }
    }

    /// Test hook that always halts with a fixed reason.
    struct HaltHook {
        points: Vec<HookPoint>,
        reason: String,
    }

    #[async_trait]
    impl Hook for HaltHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Ok(HookAction::Halt {
                reason: self.reason.clone(),
            })
        }
    }

    /// Test hook that always fails.
    struct ErrorHook {
        points: Vec<HookPoint>,
    }

    #[async_trait]
    impl Hook for ErrorHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Err(HookError::Failed("hook error".into()))
        }
    }

    #[tokio::test]
    async fn empty_registry_returns_continue() {
        let registry = HookRegistry::new();
        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn continue_hook_returns_continue() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn halt_hook_short_circuits() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "budget exceeded".into(),
        }));
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        match action {
            HookAction::Halt { reason } => assert_eq!(reason, "budget exceeded"),
            _ => panic!("expected Halt"),
        }
    }

    #[tokio::test]
    async fn hook_not_matching_point_is_skipped() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PostInference],
            reason: "should not trigger".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn error_hook_treated_as_continue() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ErrorHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn error_hook_does_not_stop_later_hooks() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ErrorHook {
            points: vec![HookPoint::PreInference],
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "after error".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        match action {
            HookAction::Halt { reason } => assert_eq!(reason, "after error"),
            _ => panic!("expected Halt from the hook registered after the failing one"),
        }
    }

    #[tokio::test]
    async fn first_intervening_hook_wins() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "first".into(),
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "second".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        match action {
            HookAction::Halt { reason } => assert_eq!(reason, "first"),
            _ => panic!("expected Halt"),
        }
    }

    #[tokio::test]
    async fn skipped_hook_does_not_shadow_later_matching_hook() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PostInference],
            reason: "wrong point".into(),
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "right point".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        match action {
            HookAction::Halt { reason } => assert_eq!(reason, "right point"),
            _ => panic!("expected Halt from the matching hook"),
        }
    }

    #[tokio::test]
    async fn multiple_continue_hooks_all_pass() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn default_registry_is_empty() {
        // An empty registry has no hook that could intervene, so dispatch
        // must fall through to `Continue`.
        let registry = HookRegistry::default();
        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }
}