1use crate::HookError;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::str::FromStr;
11
/// A named point at which hooks can be attached.
///
/// Each variant has a canonical dotted name of the form `<domain>.<action>`
/// (returned by [`HookPoint::as_str`]); that name is what `Display` writes and
/// what `FromStr` accepts.
///
/// Every variant belongs to exactly one category, reported by
/// [`HookPoint::is_pre`], [`HookPoint::is_post`] and [`HookPoint::is_event`]:
/// the `*Pre*` variants, the `*Post*` variants, and the `*On*` event variants.
///
/// NOTE(review): no `#[serde(...)]` attribute is visible on this type, so the
/// derived serde representation presumably uses the Rust variant names (e.g.
/// `"ComponentPreInit"`) rather than the dotted names above — confirm this is
/// intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HookPoint {
    // `component.*`
    /// `component.pre_init` (pre).
    ComponentPreInit,
    /// `component.post_init` (post).
    ComponentPostInit,
    /// `component.pre_shutdown` (pre).
    ComponentPreShutdown,
    /// `component.post_shutdown` (post).
    ComponentPostShutdown,

    // `request.*`
    /// `request.pre_dispatch` (pre).
    RequestPreDispatch,
    /// `request.post_dispatch` (post).
    RequestPostDispatch,

    // `signal.*`
    /// `signal.pre_dispatch` (pre).
    SignalPreDispatch,
    /// `signal.post_dispatch` (post).
    SignalPostDispatch,

    // `child.*`
    /// `child.pre_spawn` (pre).
    ChildPreSpawn,
    /// `child.post_spawn` (post).
    ChildPostSpawn,
    /// `child.pre_run` (pre).
    ChildPreRun,
    /// `child.post_run` (post).
    ChildPostRun,

    // `channel.*`
    /// `channel.pre_create` (pre).
    ChannelPreCreate,
    /// `channel.post_create` (post).
    ChannelPostCreate,
    /// `channel.pre_destroy` (pre).
    ChannelPreDestroy,
    /// `channel.post_destroy` (post).
    ChannelPostDestroy,

    // `tool.*`
    /// `tool.pre_execute` (pre).
    ToolPreExecute,
    /// `tool.post_execute` (post).
    ToolPostExecute,

    // `auth.*`
    /// `auth.pre_check` (pre).
    AuthPreCheck,
    /// `auth.post_check` (post).
    AuthPostCheck,
    /// `auth.on_grant` (event: neither pre nor post).
    AuthOnGrant,

    // `bus.*`
    /// `bus.pre_broadcast` (pre).
    BusPreBroadcast,
    /// `bus.post_broadcast` (post).
    BusPostBroadcast,
    /// `bus.on_register` (event: neither pre nor post).
    BusOnRegister,
    /// `bus.on_unregister` (event: neither pre nor post).
    BusOnUnregister,
}
81
82impl HookPoint {
83 #[must_use]
85 pub fn is_pre(&self) -> bool {
86 matches!(
87 self,
88 Self::ComponentPreInit
89 | Self::ComponentPreShutdown
90 | Self::RequestPreDispatch
91 | Self::SignalPreDispatch
92 | Self::ChildPreSpawn
93 | Self::ChildPreRun
94 | Self::ChannelPreCreate
95 | Self::ChannelPreDestroy
96 | Self::ToolPreExecute
97 | Self::AuthPreCheck
98 | Self::BusPreBroadcast
99 )
100 }
101
102 #[must_use]
104 pub fn is_post(&self) -> bool {
105 matches!(
106 self,
107 Self::ComponentPostInit
108 | Self::ComponentPostShutdown
109 | Self::RequestPostDispatch
110 | Self::SignalPostDispatch
111 | Self::ChildPostSpawn
112 | Self::ChildPostRun
113 | Self::ChannelPostCreate
114 | Self::ChannelPostDestroy
115 | Self::ToolPostExecute
116 | Self::AuthPostCheck
117 | Self::BusPostBroadcast
118 )
119 }
120
121 #[must_use]
123 pub fn is_event(&self) -> bool {
124 !self.is_pre() && !self.is_post()
125 }
126
127 #[must_use]
129 pub fn as_str(&self) -> &'static str {
130 match self {
131 Self::ComponentPreInit => "component.pre_init",
132 Self::ComponentPostInit => "component.post_init",
133 Self::ComponentPreShutdown => "component.pre_shutdown",
134 Self::ComponentPostShutdown => "component.post_shutdown",
135 Self::RequestPreDispatch => "request.pre_dispatch",
136 Self::RequestPostDispatch => "request.post_dispatch",
137 Self::SignalPreDispatch => "signal.pre_dispatch",
138 Self::SignalPostDispatch => "signal.post_dispatch",
139 Self::ChildPreSpawn => "child.pre_spawn",
140 Self::ChildPostSpawn => "child.post_spawn",
141 Self::ChildPreRun => "child.pre_run",
142 Self::ChildPostRun => "child.post_run",
143 Self::ChannelPreCreate => "channel.pre_create",
144 Self::ChannelPostCreate => "channel.post_create",
145 Self::ChannelPreDestroy => "channel.pre_destroy",
146 Self::ChannelPostDestroy => "channel.post_destroy",
147 Self::ToolPreExecute => "tool.pre_execute",
148 Self::ToolPostExecute => "tool.post_execute",
149 Self::AuthPreCheck => "auth.pre_check",
150 Self::AuthPostCheck => "auth.post_check",
151 Self::AuthOnGrant => "auth.on_grant",
152 Self::BusPreBroadcast => "bus.pre_broadcast",
153 Self::BusPostBroadcast => "bus.post_broadcast",
154 Self::BusOnRegister => "bus.on_register",
155 Self::BusOnUnregister => "bus.on_unregister",
156 }
157 }
158
159 pub const KNOWN_PREFIXES: &'static [&'static str] = &[
161 "component.",
162 "request.",
163 "signal.",
164 "child.",
165 "channel.",
166 "tool.",
167 "auth.",
168 "bus.",
169 ];
170}
171
172impl FromStr for HookPoint {
173 type Err = HookError;
174
175 fn from_str(s: &str) -> Result<Self, Self::Err> {
176 match s {
177 "component.pre_init" => Ok(Self::ComponentPreInit),
178 "component.post_init" => Ok(Self::ComponentPostInit),
179 "component.pre_shutdown" => Ok(Self::ComponentPreShutdown),
180 "component.post_shutdown" => Ok(Self::ComponentPostShutdown),
181 "request.pre_dispatch" => Ok(Self::RequestPreDispatch),
182 "request.post_dispatch" => Ok(Self::RequestPostDispatch),
183 "signal.pre_dispatch" => Ok(Self::SignalPreDispatch),
184 "signal.post_dispatch" => Ok(Self::SignalPostDispatch),
185 "child.pre_spawn" => Ok(Self::ChildPreSpawn),
186 "child.post_spawn" => Ok(Self::ChildPostSpawn),
187 "child.pre_run" => Ok(Self::ChildPreRun),
188 "child.post_run" => Ok(Self::ChildPostRun),
189 "channel.pre_create" => Ok(Self::ChannelPreCreate),
190 "channel.post_create" => Ok(Self::ChannelPostCreate),
191 "channel.pre_destroy" => Ok(Self::ChannelPreDestroy),
192 "channel.post_destroy" => Ok(Self::ChannelPostDestroy),
193 "tool.pre_execute" => Ok(Self::ToolPreExecute),
194 "tool.post_execute" => Ok(Self::ToolPostExecute),
195 "auth.pre_check" => Ok(Self::AuthPreCheck),
196 "auth.post_check" => Ok(Self::AuthPostCheck),
197 "auth.on_grant" => Ok(Self::AuthOnGrant),
198 "bus.pre_broadcast" => Ok(Self::BusPreBroadcast),
199 "bus.post_broadcast" => Ok(Self::BusPostBroadcast),
200 "bus.on_register" => Ok(Self::BusOnRegister),
201 "bus.on_unregister" => Ok(Self::BusOnUnregister),
202 _ => Err(HookError::UnknownHookPoint(s.to_string())),
203 }
204 }
205}
206
207impl fmt::Display for HookPoint {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.write_str(self.as_str())
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Hand-maintained list of every `HookPoint` variant.
    ///
    /// `all_variants_count` and `all_points_are_distinct` together guarantee it
    /// holds every variant exactly once, provided the expected count (25)
    /// matches the enum.
    const ALL_POINTS: &[HookPoint] = &[
        HookPoint::ComponentPreInit,
        HookPoint::ComponentPostInit,
        HookPoint::ComponentPreShutdown,
        HookPoint::ComponentPostShutdown,
        HookPoint::RequestPreDispatch,
        HookPoint::RequestPostDispatch,
        HookPoint::SignalPreDispatch,
        HookPoint::SignalPostDispatch,
        HookPoint::ChildPreSpawn,
        HookPoint::ChildPostSpawn,
        HookPoint::ChildPreRun,
        HookPoint::ChildPostRun,
        HookPoint::ChannelPreCreate,
        HookPoint::ChannelPostCreate,
        HookPoint::ChannelPreDestroy,
        HookPoint::ChannelPostDestroy,
        HookPoint::ToolPreExecute,
        HookPoint::ToolPostExecute,
        HookPoint::AuthPreCheck,
        HookPoint::AuthPostCheck,
        HookPoint::AuthOnGrant,
        HookPoint::BusPreBroadcast,
        HookPoint::BusPostBroadcast,
        HookPoint::BusOnRegister,
        HookPoint::BusOnUnregister,
    ];

    #[test]
    fn all_variants_count() {
        assert_eq!(ALL_POINTS.len(), 25);
    }

    #[test]
    fn all_points_are_distinct() {
        // A duplicated entry would keep the count at 25 while hiding a missing
        // variant, so check distinctness explicitly.
        let unique: HashSet<HookPoint> = ALL_POINTS.iter().copied().collect();
        assert_eq!(unique.len(), ALL_POINTS.len(), "ALL_POINTS contains a duplicate entry");
    }

    #[test]
    fn canonical_names_are_distinct_and_prefixed() {
        let mut seen = HashSet::new();
        for &point in ALL_POINTS {
            let name = point.as_str();
            assert!(seen.insert(name), "duplicate canonical name {name}");
            assert!(
                HookPoint::KNOWN_PREFIXES
                    .iter()
                    .any(|&prefix| name.starts_with(prefix)),
                "{name} does not start with any KNOWN_PREFIXES entry"
            );
        }
    }

    #[test]
    fn from_str_roundtrip_all() {
        for &point in ALL_POINTS {
            let s = point.to_string();
            let parsed: HookPoint = s.parse().unwrap_or_else(|e| {
                panic!("Failed to parse '{s}': {e}");
            });
            assert_eq!(parsed, point, "roundtrip failed for {s}");
        }
    }

    #[test]
    fn from_str_unknown() {
        let err = "foo.bar"
            .parse::<HookPoint>()
            .expect_err("unknown hook point 'foo.bar' should return error");
        assert!(matches!(&err, HookError::UnknownHookPoint(name) if name == "foo.bar"));
    }

    #[test]
    fn from_str_empty() {
        let err = ""
            .parse::<HookPoint>()
            .expect_err("empty string should not parse as a hook point");
        assert!(matches!(err, HookError::UnknownHookPoint(_)));
    }

    #[test]
    fn is_pre_correct() {
        let pre_points = [
            HookPoint::ComponentPreInit,
            HookPoint::ComponentPreShutdown,
            HookPoint::RequestPreDispatch,
            HookPoint::SignalPreDispatch,
            HookPoint::ChildPreSpawn,
            HookPoint::ChildPreRun,
            HookPoint::ChannelPreCreate,
            HookPoint::ChannelPreDestroy,
            HookPoint::ToolPreExecute,
            HookPoint::AuthPreCheck,
            HookPoint::BusPreBroadcast,
        ];
        for &point in &pre_points {
            assert!(point.is_pre(), "{point} should be pre");
            assert!(!point.is_post(), "{point} should not be post");
        }
    }

    #[test]
    fn is_post_correct() {
        let post_points = [
            HookPoint::ComponentPostInit,
            HookPoint::ComponentPostShutdown,
            HookPoint::RequestPostDispatch,
            HookPoint::SignalPostDispatch,
            HookPoint::ChildPostSpawn,
            HookPoint::ChildPostRun,
            HookPoint::ChannelPostCreate,
            HookPoint::ChannelPostDestroy,
            HookPoint::ToolPostExecute,
            HookPoint::AuthPostCheck,
            HookPoint::BusPostBroadcast,
        ];
        for &point in &post_points {
            assert!(point.is_post(), "{point} should be post");
            assert!(!point.is_pre(), "{point} should not be pre");
        }
    }

    #[test]
    fn event_hooks_are_neither_pre_nor_post() {
        let event_points = [
            HookPoint::AuthOnGrant,
            HookPoint::BusOnRegister,
            HookPoint::BusOnUnregister,
        ];
        for &point in &event_points {
            assert!(!point.is_pre(), "{point} should not be pre");
            assert!(!point.is_post(), "{point} should not be post");
            assert!(point.is_event(), "{point} should be event");
        }
    }

    #[test]
    fn every_variant_is_exactly_one_category() {
        for &point in ALL_POINTS {
            let cats = [point.is_pre(), point.is_post(), point.is_event()];
            let count = cats.iter().filter(|&&v| v).count();
            assert_eq!(count, 1, "{point} should be in exactly 1 category");
        }
    }

    #[test]
    fn serde_roundtrip() {
        for &point in ALL_POINTS {
            let json = serde_json::to_string(&point).expect("HookPoint should serialize to JSON");
            let restored: HookPoint =
                serde_json::from_str(&json).expect("HookPoint should deserialize from JSON");
            assert_eq!(restored, point);
        }
    }
}