1use super::types::{
4 HookCallback, HookPoint, HookResult, OnCompactionContext, OnInteractionContext,
5 OnSessionEndContext, OnSessionStartContext, OnToolErrorContext, PostToolCallContext,
6 PostTurnContext, PreToolCallDecideContext, PreTurnContext,
7};
8
/// Ordered registry of named hook callbacks.
///
/// Callbacks are stored together with the hook point they were registered at
/// and are fired by the `run_*` methods of the matching point.
pub struct Hooks {
    /// `(point, name, callback)` entries in registration order. `register`
    /// keeps at most one entry per `(point, name)` pair, replacing in place on
    /// a duplicate.
    callbacks: Vec<(HookPoint, String, HookCallback)>,
}
57
58impl Hooks {
59 #[must_use]
61 pub const fn new() -> Self {
62 Self {
63 callbacks: Vec::new(),
64 }
65 }
66
67 pub fn register(&mut self, name: impl Into<String>, callback: HookCallback) -> &mut Self {
74 let point = callback.hook_point();
75 let name = name.into();
76 if let Some(pos) = self
77 .callbacks
78 .iter()
79 .position(|(p, n, _)| *p == point && n == &name)
80 {
81 tracing::warn!(
82 hook = %name,
83 point = %point.label(),
84 "duplicate hook name+point in Hooks — replacing previous callback"
85 );
86 self.callbacks[pos] = (point, name, callback);
87 } else {
88 tracing::debug!(hook = %name, point = %point.label(), "registered hook callback");
89 self.callbacks.push((point, name, callback));
90 }
91 self
92 }
93
94 pub fn run_pre_turn(&self, ctx: &PreTurnContext) {
96 for (_, name, cb) in self.iter_at(HookPoint::PreTurn) {
97 tracing::trace!(hook = %name, turn = ctx.turn_number, "firing pre_turn hook");
98 if let HookCallback::PreTurn(f) = cb {
99 let name = name.clone();
100 if let Err(panic) =
101 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
102 {
103 tracing::error!(hook = %name, panic = ?panic, "pre_turn hook panicked — continuing");
104 }
105 }
106 }
107 }
108
109 pub fn run_post_turn(&self, ctx: &PostTurnContext) {
111 for (_, name, cb) in self.iter_at(HookPoint::PostTurn) {
112 tracing::trace!(hook = %name, turn = ctx.turn_number, "firing post_turn hook");
113 if let HookCallback::PostTurn(f) = cb {
114 let name = name.clone();
115 if let Err(panic) =
116 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
117 {
118 tracing::error!(hook = %name, panic = ?panic, "post_turn hook panicked — continuing");
119 }
120 }
121 }
122 }
123
124 pub fn run_pre_tool_call_decide(&self, ctx: &PreToolCallDecideContext) -> HookResult {
132 for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
133 tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing pre_tool_call_decide hook");
134 if let HookCallback::PreToolCallDecide(f) = cb {
135 let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
136 {
137 Ok(r) => r,
138 Err(panic) => {
139 tracing::error!(
140 hook = %name,
141 tool = %ctx.tool_name,
142 panic = ?panic,
143 "pre_tool_call_decide hook panicked — denying tool call as safe default"
144 );
145 return HookResult::deny(format!(
146 "hook '{name}' panicked — tool call denied as safe default"
147 ));
148 }
149 };
150 if !result.allow {
151 tracing::info!(
152 hook = %name,
153 tool = %ctx.tool_name,
154 reason = %result.message,
155 "tool call denied by hook"
156 );
157 return result;
158 }
159 }
160 }
161 HookResult::allow()
162 }
163
164 pub fn run_post_tool_call(&self, ctx: &PostToolCallContext) {
166 for (_, name, cb) in self.iter_at(HookPoint::PostToolCall) {
167 tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing post_tool_call hook");
168 if let HookCallback::PostToolCall(f) = cb {
169 let name = name.clone();
170 if let Err(panic) =
171 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
172 {
173 tracing::error!(hook = %name, panic = ?panic, "post_tool_call hook panicked — continuing");
174 }
175 }
176 }
177 }
178
179 pub fn run_on_tool_error(&self, ctx: &OnToolErrorContext) {
181 for (_, name, cb) in self.iter_at(HookPoint::OnToolError) {
182 tracing::trace!(hook = %name, tool = %ctx.tool_name, error = %ctx.error, "firing on_tool_error hook");
183 if let HookCallback::OnToolError(f) = cb {
184 let name = name.clone();
185 if let Err(panic) =
186 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
187 {
188 tracing::error!(hook = %name, panic = ?panic, "on_tool_error hook panicked — continuing");
189 }
190 }
191 }
192 }
193
194 pub fn run_on_session_start(&self, ctx: &OnSessionStartContext) {
196 for (_, name, cb) in self.iter_at(HookPoint::OnSessionStart) {
197 tracing::trace!(hook = %name, "firing on_session_start hook");
198 if let HookCallback::OnSessionStart(f) = cb {
199 let name = name.clone();
200 if let Err(panic) =
201 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
202 {
203 tracing::error!(hook = %name, panic = ?panic, "on_session_start hook panicked — continuing");
204 }
205 }
206 }
207 }
208
209 pub fn run_on_session_end(&self, ctx: &OnSessionEndContext) {
211 for (_, name, cb) in self.iter_at(HookPoint::OnSessionEnd) {
212 tracing::trace!(hook = %name, "firing on_session_end hook");
213 if let HookCallback::OnSessionEnd(f) = cb {
214 let name = name.clone();
215 if let Err(panic) =
216 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
217 {
218 tracing::error!(hook = %name, panic = ?panic, "on_session_end hook panicked — continuing");
219 }
220 }
221 }
222 }
223
224 pub fn run_on_compaction(&self, ctx: &OnCompactionContext) {
226 for (_, name, cb) in self.iter_at(HookPoint::OnCompaction) {
227 tracing::trace!(hook = %name, "firing on_compaction hook");
228 if let HookCallback::OnCompaction(f) = cb {
229 let name = name.clone();
230 if let Err(panic) =
231 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
232 {
233 tracing::error!(hook = %name, panic = ?panic, "on_compaction hook panicked — continuing");
234 }
235 }
236 }
237 }
238
239 pub fn run_on_interaction(&self, ctx: &OnInteractionContext) -> HookResult {
244 for (_, name, cb) in self.iter_at(HookPoint::OnInteraction) {
245 tracing::trace!(hook = %name, "firing on_interaction hook");
246 if let HookCallback::OnInteraction(f) = cb {
247 let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
248 {
249 Ok(r) => r,
250 Err(panic) => {
251 tracing::error!(
252 hook = %name,
253 panic = ?panic,
254 "on_interaction hook panicked — continuing"
255 );
256 continue;
257 }
258 };
259 if !result.allow {
260 return result;
261 }
262 }
263 }
264 HookResult::allow()
265 }
266
267 pub fn run_transform_tool_input(&self, ctx: &PreToolCallDecideContext) -> serde_json::Value {
277 let mut args = ctx.tool_args.clone();
278 for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
279 if let HookCallback::TransformToolInput(f) = cb {
280 let current_ctx = PreToolCallDecideContext {
281 tool_name: ctx.tool_name.clone(),
282 tool_args: args.clone(),
283 };
284 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(¤t_ctx))) {
285 Ok(Some(new_args)) => {
286 tracing::debug!(
287 hook = %name,
288 tool = %ctx.tool_name,
289 "transform_tool_input hook modified tool arguments"
290 );
291 args = new_args;
292 }
293 Ok(None) => { }
294 Err(panic) => {
295 tracing::error!(
296 hook = %name,
297 tool = %ctx.tool_name,
298 panic = ?panic,
299 "transform_tool_input hook panicked — keeping current args"
300 );
301 }
302 }
303 }
304 }
305 args
306 }
307
308 pub fn on_pre_turn(
314 &mut self,
315 name: impl Into<String>,
316 f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
317 ) -> &mut Self {
318 self.register(name, HookCallback::PreTurn(Box::new(f)))
319 }
320
321 pub fn on_post_turn(
325 &mut self,
326 name: impl Into<String>,
327 f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
328 ) -> &mut Self {
329 self.register(name, HookCallback::PostTurn(Box::new(f)))
330 }
331
332 pub fn on_pre_tool_call_decide(
337 &mut self,
338 name: impl Into<String>,
339 f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
340 ) -> &mut Self {
341 self.register(name, HookCallback::PreToolCallDecide(Box::new(f)))
342 }
343
344 pub fn on_post_tool_call(
348 &mut self,
349 name: impl Into<String>,
350 f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
351 ) -> &mut Self {
352 self.register(name, HookCallback::PostToolCall(Box::new(f)))
353 }
354
355 pub fn on_tool_error(
359 &mut self,
360 name: impl Into<String>,
361 f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
362 ) -> &mut Self {
363 self.register(name, HookCallback::OnToolError(Box::new(f)))
364 }
365
366 pub fn on_compaction(
370 &mut self,
371 name: impl Into<String>,
372 f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
373 ) -> &mut Self {
374 self.register(name, HookCallback::OnCompaction(Box::new(f)))
375 }
376
377 pub fn on_interaction(
381 &mut self,
382 name: impl Into<String>,
383 f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
384 ) -> &mut Self {
385 self.register(name, HookCallback::OnInteraction(Box::new(f)))
386 }
387
388 pub fn on_session_start(
392 &mut self,
393 name: impl Into<String>,
394 f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
395 ) -> &mut Self {
396 self.register(name, HookCallback::OnSessionStart(Box::new(f)))
397 }
398
399 pub fn on_session_end(
403 &mut self,
404 name: impl Into<String>,
405 f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
406 ) -> &mut Self {
407 self.register(name, HookCallback::OnSessionEnd(Box::new(f)))
408 }
409
410 pub fn on_transform_tool_input(
416 &mut self,
417 name: impl Into<String>,
418 f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
419 ) -> &mut Self {
420 self.register(name, HookCallback::TransformToolInput(Box::new(f)))
421 }
422
423 #[must_use]
429 pub fn with_pre_turn(
430 mut self,
431 name: impl Into<String>,
432 f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
433 ) -> Self {
434 self.on_pre_turn(name, f);
435 self
436 }
437
438 #[must_use]
442 pub fn with_post_turn(
443 mut self,
444 name: impl Into<String>,
445 f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
446 ) -> Self {
447 self.on_post_turn(name, f);
448 self
449 }
450
451 #[must_use]
457 pub fn with_pre_tool_call_decide(
458 mut self,
459 name: impl Into<String>,
460 f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
461 ) -> Self {
462 self.on_pre_tool_call_decide(name, f);
463 self
464 }
465
466 #[must_use]
472 pub fn with_post_tool_call(
473 mut self,
474 name: impl Into<String>,
475 f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
476 ) -> Self {
477 self.on_post_tool_call(name, f);
478 self
479 }
480
481 #[must_use]
487 pub fn with_tool_error(
488 mut self,
489 name: impl Into<String>,
490 f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
491 ) -> Self {
492 self.on_tool_error(name, f);
493 self
494 }
495
496 #[must_use]
502 pub fn with_compaction(
503 mut self,
504 name: impl Into<String>,
505 f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
506 ) -> Self {
507 self.on_compaction(name, f);
508 self
509 }
510
511 #[must_use]
517 pub fn with_interaction(
518 mut self,
519 name: impl Into<String>,
520 f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
521 ) -> Self {
522 self.on_interaction(name, f);
523 self
524 }
525
526 #[must_use]
532 pub fn with_session_start(
533 mut self,
534 name: impl Into<String>,
535 f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
536 ) -> Self {
537 self.on_session_start(name, f);
538 self
539 }
540
541 #[must_use]
547 pub fn with_session_end(
548 mut self,
549 name: impl Into<String>,
550 f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
551 ) -> Self {
552 self.on_session_end(name, f);
553 self
554 }
555
556 #[must_use]
562 pub fn with_transform_tool_input(
563 mut self,
564 name: impl Into<String>,
565 f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
566 ) -> Self {
567 self.on_transform_tool_input(name, f);
568 self
569 }
570
571 fn iter_at(
573 &self,
574 point: HookPoint,
575 ) -> impl Iterator<Item = &(HookPoint, String, HookCallback)> {
576 self.callbacks.iter().filter(move |(p, _, _)| *p == point)
577 }
578
579 #[must_use]
586 pub fn entries(&self) -> Vec<super::types::HookEntry> {
587 self.callbacks
588 .iter()
589 .map(|(point, name, _)| super::types::HookEntry {
590 name: name.clone(),
591 point: *point,
592 callback_id: name.clone(),
593 })
594 .collect()
595 }
596}
597
598impl Default for Hooks {
599 fn default() -> Self {
600 Self::new()
601 }
602}
603
604#[cfg(test)]
605#[path = "runner_tests.rs"]
606mod tests;