agy_bridge/hooks/runner.rs
1//! Hook runner: registration and execution of lifecycle callbacks.
2
3use super::types::{
4 HookCallback, HookPoint, HookResult, OnCompactionContext, OnInteractionContext,
5 OnSessionEndContext, OnSessionStartContext, OnToolErrorContext, PostToolCallContext,
6 PostTurnContext, PreToolCallDecideContext, PreTurnContext,
7};
8
9// ── Hook runner ─────────────────────────────────────────────────────────────
10
11/// Stores and executes registered hook callbacks.
12///
13/// Callbacks at the same [`HookPoint`] fire in the order they were registered.
14///
15/// # Example
16///
17/// Fluent builder pattern (recommended):
18///
19/// ```
20/// use agy_bridge::hooks::{HookResult, Hooks, PreToolCallDecideContext, PreTurnContext};
21///
22/// let hooks = Hooks::new()
23/// .with_pre_turn("logger", |ctx: &PreTurnContext| {
24/// println!("Turn {} prompt: {}", ctx.turn_number, ctx.prompt);
25/// })
26/// .with_pre_tool_call_decide("gate", |ctx: &PreToolCallDecideContext| {
27/// if ctx.tool_name == "dangerous_tool" {
28/// HookResult::deny("blocked by policy")
29/// } else {
30/// HookResult::allow()
31/// }
32/// });
33///
34/// hooks.run_pre_turn(&PreTurnContext::new("hi", 1));
35/// let result = hooks.run_pre_tool_call_decide(&PreToolCallDecideContext::new(
36/// "safe_tool",
37/// serde_json::Value::Null,
38/// ));
39/// assert!(result.allow);
40/// ```
41///
42/// For conditional or loop-based registration, use the `on_*(&mut self)` methods:
43///
44/// ```
45/// # use agy_bridge::hooks::{HookResult, Hooks};
46/// let mut hooks = Hooks::new();
47/// hooks.on_pre_turn("logger", |ctx| {
48/// println!("Turn {}", ctx.turn_number);
49/// });
50/// ```
51pub struct Hooks {
52 callbacks: Vec<(HookPoint, String, HookCallback)>,
53}
54
55impl Hooks {
56 /// Create an empty hook runner.
57 #[must_use]
58 pub const fn new() -> Self {
59 Self {
60 callbacks: Vec::new(),
61 }
62 }
63
64 /// Register a named callback.
65 ///
66 /// The [`HookPoint`] is derived automatically from the callback variant.
67 /// If a callback with the same name AND hook point already exists, it is
68 /// replaced and a warning is logged.
69 /// Returns `&mut Self` for chaining.
70 pub fn register(&mut self, name: impl Into<String>, callback: HookCallback) -> &mut Self {
71 let point = callback.hook_point();
72 let name = name.into();
73 if let Some(pos) = self
74 .callbacks
75 .iter()
76 .position(|(p, n, _)| *p == point && n == &name)
77 {
78 tracing::warn!(
79 hook = %name,
80 point = %point.label(),
81 "duplicate hook name+point in Hooks — replacing previous callback"
82 );
83 self.callbacks[pos] = (point, name, callback);
84 } else {
85 tracing::debug!(hook = %name, point = %point.label(), "registered hook callback");
86 self.callbacks.push((point, name, callback));
87 }
88 self
89 }
90
91 /// Run all observer callbacks at the given [`HookPoint`], calling `invoke`
92 /// for each matching callback.
93 ///
94 /// Panics in individual callbacks are caught and logged; execution
95 /// continues with the remaining callbacks.
96 fn run_observer<F>(&self, point: HookPoint, mut invoke: F)
97 where
98 F: FnMut(&str, &HookCallback),
99 {
100 for (_, name, cb) in self.iter_at(point) {
101 let name_owned = name.clone();
102 if let Err(panic) =
103 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| invoke(&name_owned, cb)))
104 {
105 tracing::error!(
106 hook = %name,
107 panic = ?panic,
108 "{} hook panicked — continuing", point.label(),
109 );
110 }
111 }
112 }
113
114 /// Run all [`HookPoint::PreTurn`] callbacks in registration order.
115 pub fn run_pre_turn(&self, ctx: &PreTurnContext) {
116 self.run_observer(HookPoint::PreTurn, |name, cb| {
117 tracing::trace!(hook = %name, turn = ctx.turn_number, "firing pre_turn hook");
118 if let HookCallback::PreTurn(f) = cb {
119 f(ctx);
120 }
121 });
122 }
123
124 /// Run all [`HookPoint::PostTurn`] callbacks in registration order.
125 pub fn run_post_turn(&self, ctx: &PostTurnContext) {
126 self.run_observer(HookPoint::PostTurn, |name, cb| {
127 tracing::trace!(hook = %name, turn = ctx.turn_number, "firing post_turn hook");
128 if let HookCallback::PostTurn(f) = cb {
129 f(ctx);
130 }
131 });
132 }
133
134 /// Run all [`HookPoint::PreToolCallDecide`] callbacks in registration order.
135 ///
136 /// If any callback returns [`HookResult`] with `allow: false`, execution
137 /// short-circuits and that deny result is returned immediately. Otherwise
138 /// returns [`HookResult::allow()`].
139 ///
140 /// If a callback panics, the tool call is denied as a safe default.
141 pub fn run_pre_tool_call_decide(&self, ctx: &PreToolCallDecideContext) -> HookResult {
142 for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
143 tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing pre_tool_call_decide hook");
144 if let HookCallback::PreToolCallDecide(f) = cb {
145 let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
146 {
147 Ok(r) => r,
148 Err(panic) => {
149 tracing::error!(
150 hook = %name,
151 tool = %ctx.tool_name,
152 panic = ?panic,
153 "pre_tool_call_decide hook panicked — denying tool call as safe default"
154 );
155 return HookResult::deny(format!(
156 "hook '{name}' panicked — tool call denied as safe default"
157 ));
158 }
159 };
160 if !result.allow {
161 tracing::info!(
162 hook = %name,
163 tool = %ctx.tool_name,
164 reason = %result.message,
165 "tool call denied by hook"
166 );
167 return result;
168 }
169 }
170 }
171 HookResult::allow()
172 }
173
174 /// Run all [`HookPoint::PostToolCall`] callbacks in registration order.
175 pub fn run_post_tool_call(&self, ctx: &PostToolCallContext) {
176 self.run_observer(HookPoint::PostToolCall, |name, cb| {
177 tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing post_tool_call hook");
178 if let HookCallback::PostToolCall(f) = cb {
179 f(ctx);
180 }
181 });
182 }
183
184 /// Run all [`HookPoint::OnToolError`] callbacks in registration order.
185 pub fn run_on_tool_error(&self, ctx: &OnToolErrorContext) {
186 self.run_observer(HookPoint::OnToolError, |name, cb| {
187 tracing::trace!(hook = %name, tool = %ctx.tool_name, error = %ctx.error, "firing on_tool_error hook");
188 if let HookCallback::OnToolError(f) = cb {
189 f(ctx);
190 }
191 });
192 }
193
194 /// Run all [`HookPoint::OnSessionStart`] callbacks in registration order.
195 pub fn run_on_session_start(&self, ctx: &OnSessionStartContext) {
196 self.run_observer(HookPoint::OnSessionStart, |name, cb| {
197 tracing::trace!(hook = %name, "firing on_session_start hook");
198 if let HookCallback::OnSessionStart(f) = cb {
199 f(ctx);
200 }
201 });
202 }
203
204 /// Run all [`HookPoint::OnSessionEnd`] callbacks in registration order.
205 pub fn run_on_session_end(&self, ctx: &OnSessionEndContext) {
206 self.run_observer(HookPoint::OnSessionEnd, |name, cb| {
207 tracing::trace!(hook = %name, "firing on_session_end hook");
208 if let HookCallback::OnSessionEnd(f) = cb {
209 f(ctx);
210 }
211 });
212 }
213
214 /// Run all [`HookPoint::OnCompaction`] callbacks in registration order.
215 pub fn run_on_compaction(&self, ctx: &OnCompactionContext) {
216 self.run_observer(HookPoint::OnCompaction, |name, cb| {
217 tracing::trace!(hook = %name, "firing on_compaction hook");
218 if let HookCallback::OnCompaction(f) = cb {
219 f(ctx);
220 }
221 });
222 }
223
224 /// Run all [`HookPoint::OnInteraction`] callbacks in registration order.
225 ///
226 /// If a callback panics, the panic is logged and execution continues
227 /// (the interaction is not blocked).
228 pub fn run_on_interaction(&self, ctx: &OnInteractionContext) -> HookResult {
229 for (_, name, cb) in self.iter_at(HookPoint::OnInteraction) {
230 tracing::trace!(hook = %name, "firing on_interaction hook");
231 if let HookCallback::OnInteraction(f) = cb {
232 let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
233 {
234 Ok(r) => r,
235 Err(panic) => {
236 tracing::error!(
237 hook = %name,
238 panic = ?panic,
239 "on_interaction hook panicked — continuing"
240 );
241 continue;
242 }
243 };
244 if !result.allow {
245 return result;
246 }
247 }
248 }
249 HookResult::allow()
250 }
251
252 /// Run all [`TransformToolInput`](HookCallback::TransformToolInput)
253 /// callbacks in registration order, threading the (possibly modified)
254 /// tool arguments through each transform.
255 ///
256 /// Returns the final tool arguments after all transforms have been
257 /// applied. If no transform returns `Some`, the original arguments
258 /// are returned unchanged.
259 ///
260 /// Panicking transforms are logged and skipped (original args kept).
261 pub fn run_transform_tool_input(&self, ctx: &PreToolCallDecideContext) -> serde_json::Value {
262 let mut args = ctx.tool_args.clone();
263 for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
264 if let HookCallback::TransformToolInput(f) = cb {
265 let current_ctx = PreToolCallDecideContext {
266 tool_name: ctx.tool_name.clone(),
267 tool_args: args.clone(),
268 };
269 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(¤t_ctx))) {
270 Ok(Some(new_args)) => {
271 tracing::debug!(
272 hook = %name,
273 tool = %ctx.tool_name,
274 "transform_tool_input hook modified tool arguments"
275 );
276 args = new_args;
277 }
278 Ok(None) => { /* no modification */ }
279 Err(panic) => {
280 tracing::error!(
281 hook = %name,
282 tool = %ctx.tool_name,
283 panic = ?panic,
284 "transform_tool_input hook panicked — keeping current args"
285 );
286 }
287 }
288 }
289 }
290 args
291 }
292
293 // ── Convenience builder methods (Python decorator parity) ────────
294
295 /// Register a [`HookPoint::PreTurn`] callback.
296 ///
297 /// Convenience wrapper matching the Python SDK's `@on_pre_turn` decorator.
298 pub fn on_pre_turn(
299 &mut self,
300 name: impl Into<String>,
301 f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
302 ) -> &mut Self {
303 self.register(name, HookCallback::PreTurn(Box::new(f)))
304 }
305
306 /// Register a [`HookPoint::PostTurn`] callback.
307 ///
308 /// Convenience wrapper matching the Python SDK's `@on_post_turn` decorator.
309 pub fn on_post_turn(
310 &mut self,
311 name: impl Into<String>,
312 f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
313 ) -> &mut Self {
314 self.register(name, HookCallback::PostTurn(Box::new(f)))
315 }
316
317 /// Register a [`HookPoint::PreToolCallDecide`] callback.
318 ///
319 /// Convenience wrapper matching the Python SDK's `@on_pre_tool_call_decide`
320 /// decorator.
321 pub fn on_pre_tool_call_decide(
322 &mut self,
323 name: impl Into<String>,
324 f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
325 ) -> &mut Self {
326 self.register(name, HookCallback::PreToolCallDecide(Box::new(f)))
327 }
328
329 /// Register a [`HookPoint::PostToolCall`] callback.
330 ///
331 /// Convenience wrapper matching the Python SDK's `@on_post_tool_call` decorator.
332 pub fn on_post_tool_call(
333 &mut self,
334 name: impl Into<String>,
335 f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
336 ) -> &mut Self {
337 self.register(name, HookCallback::PostToolCall(Box::new(f)))
338 }
339
340 /// Register a [`HookPoint::OnToolError`] callback.
341 ///
342 /// Convenience wrapper matching the Python SDK's `@on_tool_error` decorator.
343 pub fn on_tool_error(
344 &mut self,
345 name: impl Into<String>,
346 f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
347 ) -> &mut Self {
348 self.register(name, HookCallback::OnToolError(Box::new(f)))
349 }
350
351 /// Register a [`HookPoint::OnCompaction`] callback.
352 ///
353 /// Convenience wrapper matching the Python SDK's `@on_compaction` decorator.
354 pub fn on_compaction(
355 &mut self,
356 name: impl Into<String>,
357 f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
358 ) -> &mut Self {
359 self.register(name, HookCallback::OnCompaction(Box::new(f)))
360 }
361
362 /// Register a [`HookPoint::OnInteraction`] callback.
363 ///
364 /// Convenience wrapper matching the Python SDK's `@on_interaction` decorator.
365 pub fn on_interaction(
366 &mut self,
367 name: impl Into<String>,
368 f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
369 ) -> &mut Self {
370 self.register(name, HookCallback::OnInteraction(Box::new(f)))
371 }
372
373 /// Register a [`HookPoint::OnSessionStart`] callback.
374 ///
375 /// Convenience wrapper matching the Python SDK's `@on_session_start` decorator.
376 pub fn on_session_start(
377 &mut self,
378 name: impl Into<String>,
379 f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
380 ) -> &mut Self {
381 self.register(name, HookCallback::OnSessionStart(Box::new(f)))
382 }
383
384 /// Register a [`HookPoint::OnSessionEnd`] callback.
385 ///
386 /// Convenience wrapper matching the Python SDK's `@on_session_end` decorator.
387 pub fn on_session_end(
388 &mut self,
389 name: impl Into<String>,
390 f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
391 ) -> &mut Self {
392 self.register(name, HookCallback::OnSessionEnd(Box::new(f)))
393 }
394
395 /// Register a [`TransformToolInput`](HookCallback::TransformToolInput) callback.
396 ///
397 /// The closure receives the pre-tool-call context and may return
398 /// `Some(new_args)` to replace tool arguments, or `None` to leave them
399 /// unchanged.
400 pub fn on_transform_tool_input(
401 &mut self,
402 name: impl Into<String>,
403 f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
404 ) -> &mut Self {
405 self.register(name, HookCallback::TransformToolInput(Box::new(f)))
406 }
407
408 // ── Owned-self builder methods (for fluent chaining) ────────────
409
410 /// Register a [`HookPoint::PreTurn`] callback, returning `self` for chaining.
411 ///
412 /// This is the owned-self variant of [`on_pre_turn`](Self::on_pre_turn).
413 #[must_use]
414 pub fn with_pre_turn(
415 mut self,
416 name: impl Into<String>,
417 f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
418 ) -> Self {
419 self.on_pre_turn(name, f);
420 self
421 }
422
423 /// Register a [`HookPoint::PostTurn`] callback, returning `self` for chaining.
424 ///
425 /// This is the owned-self variant of [`on_post_turn`](Self::on_post_turn).
426 #[must_use]
427 pub fn with_post_turn(
428 mut self,
429 name: impl Into<String>,
430 f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
431 ) -> Self {
432 self.on_post_turn(name, f);
433 self
434 }
435
436 /// Register a [`HookPoint::PreToolCallDecide`] callback, returning `self`
437 /// for chaining.
438 ///
439 /// This is the owned-self variant of
440 /// [`on_pre_tool_call_decide`](Self::on_pre_tool_call_decide).
441 #[must_use]
442 pub fn with_pre_tool_call_decide(
443 mut self,
444 name: impl Into<String>,
445 f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
446 ) -> Self {
447 self.on_pre_tool_call_decide(name, f);
448 self
449 }
450
451 /// Register a [`HookPoint::PostToolCall`] callback, returning `self` for
452 /// chaining.
453 ///
454 /// This is the owned-self variant of
455 /// [`on_post_tool_call`](Self::on_post_tool_call).
456 #[must_use]
457 pub fn with_post_tool_call(
458 mut self,
459 name: impl Into<String>,
460 f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
461 ) -> Self {
462 self.on_post_tool_call(name, f);
463 self
464 }
465
466 /// Register a [`HookPoint::OnToolError`] callback, returning `self` for
467 /// chaining.
468 ///
469 /// This is the owned-self variant of
470 /// [`on_tool_error`](Self::on_tool_error).
471 #[must_use]
472 pub fn with_tool_error(
473 mut self,
474 name: impl Into<String>,
475 f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
476 ) -> Self {
477 self.on_tool_error(name, f);
478 self
479 }
480
481 /// Register a [`HookPoint::OnCompaction`] callback, returning `self` for
482 /// chaining.
483 ///
484 /// This is the owned-self variant of
485 /// [`on_compaction`](Self::on_compaction).
486 #[must_use]
487 pub fn with_compaction(
488 mut self,
489 name: impl Into<String>,
490 f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
491 ) -> Self {
492 self.on_compaction(name, f);
493 self
494 }
495
496 /// Register a [`HookPoint::OnInteraction`] callback, returning `self` for
497 /// chaining.
498 ///
499 /// This is the owned-self variant of
500 /// [`on_interaction`](Self::on_interaction).
501 #[must_use]
502 pub fn with_interaction(
503 mut self,
504 name: impl Into<String>,
505 f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
506 ) -> Self {
507 self.on_interaction(name, f);
508 self
509 }
510
511 /// Register a [`HookPoint::OnSessionStart`] callback, returning `self`
512 /// for chaining.
513 ///
514 /// This is the owned-self variant of
515 /// [`on_session_start`](Self::on_session_start).
516 #[must_use]
517 pub fn with_session_start(
518 mut self,
519 name: impl Into<String>,
520 f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
521 ) -> Self {
522 self.on_session_start(name, f);
523 self
524 }
525
526 /// Register a [`HookPoint::OnSessionEnd`] callback, returning `self` for
527 /// chaining.
528 ///
529 /// This is the owned-self variant of
530 /// [`on_session_end`](Self::on_session_end).
531 #[must_use]
532 pub fn with_session_end(
533 mut self,
534 name: impl Into<String>,
535 f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
536 ) -> Self {
537 self.on_session_end(name, f);
538 self
539 }
540
541 /// Register a [`TransformToolInput`](HookCallback::TransformToolInput)
542 /// callback, returning `self` for chaining.
543 ///
544 /// This is the owned-self variant of
545 /// [`on_transform_tool_input`](Self::on_transform_tool_input).
546 #[must_use]
547 pub fn with_transform_tool_input(
548 mut self,
549 name: impl Into<String>,
550 f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
551 ) -> Self {
552 self.on_transform_tool_input(name, f);
553 self
554 }
555
556 /// Iterate callbacks at a given hook point in registration order.
557 fn iter_at(
558 &self,
559 point: HookPoint,
560 ) -> impl Iterator<Item = &(HookPoint, String, HookCallback)> {
561 self.callbacks.iter().filter(move |(p, _, _)| *p == point)
562 }
563
564 /// Extract a list of [`HookEntry`](super::types::HookEntry) objects
565 /// corresponding to the registered callbacks.
566 ///
567 /// This allows the `AgentBuilder` to automatically populate the agent's
568 /// configuration with the necessary entries to connect the Python SDK's
569 /// hook dispatcher back to the Rust runner.
570 #[must_use]
571 pub fn entries(&self) -> Vec<super::types::HookEntry> {
572 self.callbacks
573 .iter()
574 .map(|(point, name, _)| super::types::HookEntry {
575 name: name.clone(),
576 point: *point,
577 callback_id: name.clone(),
578 })
579 .collect()
580 }
581}
582
583impl Default for Hooks {
584 fn default() -> Self {
585 Self::new()
586 }
587}
588
// Unit tests for the hook runner. The module body is loaded from the file
// named by `#[path]` and is compiled only when `cfg(test)` is set.
#[cfg(test)]
#[path = "runner_tests.rs"]
mod tests;