Skip to main content

agy_bridge/hooks/
runner.rs

1//! Hook runner: registration and execution of lifecycle callbacks.
2
3use super::types::{
4    HookCallback, HookPoint, HookResult, OnCompactionContext, OnInteractionContext,
5    OnSessionEndContext, OnSessionStartContext, OnToolErrorContext, PostToolCallContext,
6    PostTurnContext, PreToolCallDecideContext, PreTurnContext,
7};
8
9// ── Hook runner ─────────────────────────────────────────────────────────────
10
11/// Stores and executes registered hook callbacks.
12///
13/// Callbacks at the same [`HookPoint`] fire in the order they were registered.
14///
15/// # Example
16///
17/// Fluent builder pattern (recommended):
18///
19/// ```
20/// use agy_bridge::hooks::{HookResult, Hooks, PreToolCallDecideContext, PreTurnContext};
21///
22/// let hooks = Hooks::new()
23///     .with_pre_turn("logger", |ctx: &PreTurnContext| {
24///         println!("Turn {} prompt: {}", ctx.turn_number, ctx.prompt);
25///     })
26///     .with_pre_tool_call_decide("gate", |ctx: &PreToolCallDecideContext| {
27///         if ctx.tool_name == "dangerous_tool" {
28///             HookResult::deny("blocked by policy")
29///         } else {
30///             HookResult::allow()
31///         }
32///     });
33///
34/// hooks.run_pre_turn(&PreTurnContext {
35///     prompt: "hi".into(),
36///     turn_number: 1,
37/// });
38/// let result = hooks.run_pre_tool_call_decide(&PreToolCallDecideContext {
39///     tool_name: "safe_tool".into(),
40///     tool_args: serde_json::Value::Null,
41/// });
42/// assert!(result.allow);
43/// ```
44///
45/// For conditional or loop-based registration, use the `on_*(&mut self)` methods:
46///
47/// ```
48/// # use agy_bridge::hooks::{HookResult, Hooks};
49/// let mut hooks = Hooks::new();
50/// hooks.on_pre_turn("logger", |ctx| {
51///     println!("Turn {}", ctx.turn_number);
52/// });
53/// ```
54pub struct Hooks {
55    callbacks: Vec<(HookPoint, String, HookCallback)>,
56}
57
58impl Hooks {
59    /// Create an empty hook runner.
60    #[must_use]
61    pub const fn new() -> Self {
62        Self {
63            callbacks: Vec::new(),
64        }
65    }
66
67    /// Register a named callback.
68    ///
69    /// The [`HookPoint`] is derived automatically from the callback variant.
70    /// If a callback with the same name AND hook point already exists, it is
71    /// replaced and a warning is logged.
72    /// Returns `&mut Self` for chaining.
73    pub fn register(&mut self, name: impl Into<String>, callback: HookCallback) -> &mut Self {
74        let point = callback.hook_point();
75        let name = name.into();
76        if let Some(pos) = self
77            .callbacks
78            .iter()
79            .position(|(p, n, _)| *p == point && n == &name)
80        {
81            tracing::warn!(
82                hook = %name,
83                point = %point.label(),
84                "duplicate hook name+point in Hooks — replacing previous callback"
85            );
86            self.callbacks[pos] = (point, name, callback);
87        } else {
88            tracing::debug!(hook = %name, point = %point.label(), "registered hook callback");
89            self.callbacks.push((point, name, callback));
90        }
91        self
92    }
93
94    /// Run all [`HookPoint::PreTurn`] callbacks in registration order.
95    pub fn run_pre_turn(&self, ctx: &PreTurnContext) {
96        for (_, name, cb) in self.iter_at(HookPoint::PreTurn) {
97            tracing::trace!(hook = %name, turn = ctx.turn_number, "firing pre_turn hook");
98            if let HookCallback::PreTurn(f) = cb {
99                let name = name.clone();
100                if let Err(panic) =
101                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
102                {
103                    tracing::error!(hook = %name, panic = ?panic, "pre_turn hook panicked — continuing");
104                }
105            }
106        }
107    }
108
109    /// Run all [`HookPoint::PostTurn`] callbacks in registration order.
110    pub fn run_post_turn(&self, ctx: &PostTurnContext) {
111        for (_, name, cb) in self.iter_at(HookPoint::PostTurn) {
112            tracing::trace!(hook = %name, turn = ctx.turn_number, "firing post_turn hook");
113            if let HookCallback::PostTurn(f) = cb {
114                let name = name.clone();
115                if let Err(panic) =
116                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
117                {
118                    tracing::error!(hook = %name, panic = ?panic, "post_turn hook panicked — continuing");
119                }
120            }
121        }
122    }
123
124    /// Run all [`HookPoint::PreToolCallDecide`] callbacks in registration order.
125    ///
126    /// If any callback returns [`HookResult`] with `allow: false`, execution
127    /// short-circuits and that deny result is returned immediately.  Otherwise
128    /// returns [`HookResult::allow()`].
129    ///
130    /// If a callback panics, the tool call is denied as a safe default.
131    pub fn run_pre_tool_call_decide(&self, ctx: &PreToolCallDecideContext) -> HookResult {
132        for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
133            tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing pre_tool_call_decide hook");
134            if let HookCallback::PreToolCallDecide(f) = cb {
135                let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
136                {
137                    Ok(r) => r,
138                    Err(panic) => {
139                        tracing::error!(
140                            hook = %name,
141                            tool = %ctx.tool_name,
142                            panic = ?panic,
143                            "pre_tool_call_decide hook panicked — denying tool call as safe default"
144                        );
145                        return HookResult::deny(format!(
146                            "hook '{name}' panicked — tool call denied as safe default"
147                        ));
148                    }
149                };
150                if !result.allow {
151                    tracing::info!(
152                        hook = %name,
153                        tool = %ctx.tool_name,
154                        reason = %result.message,
155                        "tool call denied by hook"
156                    );
157                    return result;
158                }
159            }
160        }
161        HookResult::allow()
162    }
163
164    /// Run all [`HookPoint::PostToolCall`] callbacks in registration order.
165    pub fn run_post_tool_call(&self, ctx: &PostToolCallContext) {
166        for (_, name, cb) in self.iter_at(HookPoint::PostToolCall) {
167            tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing post_tool_call hook");
168            if let HookCallback::PostToolCall(f) = cb {
169                let name = name.clone();
170                if let Err(panic) =
171                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
172                {
173                    tracing::error!(hook = %name, panic = ?panic, "post_tool_call hook panicked — continuing");
174                }
175            }
176        }
177    }
178
179    /// Run all [`HookPoint::OnToolError`] callbacks in registration order.
180    pub fn run_on_tool_error(&self, ctx: &OnToolErrorContext) {
181        for (_, name, cb) in self.iter_at(HookPoint::OnToolError) {
182            tracing::trace!(hook = %name, tool = %ctx.tool_name, error = %ctx.error, "firing on_tool_error hook");
183            if let HookCallback::OnToolError(f) = cb {
184                let name = name.clone();
185                if let Err(panic) =
186                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
187                {
188                    tracing::error!(hook = %name, panic = ?panic, "on_tool_error hook panicked — continuing");
189                }
190            }
191        }
192    }
193
194    /// Run all [`HookPoint::OnSessionStart`] callbacks in registration order.
195    pub fn run_on_session_start(&self, ctx: &OnSessionStartContext) {
196        for (_, name, cb) in self.iter_at(HookPoint::OnSessionStart) {
197            tracing::trace!(hook = %name, "firing on_session_start hook");
198            if let HookCallback::OnSessionStart(f) = cb {
199                let name = name.clone();
200                if let Err(panic) =
201                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
202                {
203                    tracing::error!(hook = %name, panic = ?panic, "on_session_start hook panicked — continuing");
204                }
205            }
206        }
207    }
208
209    /// Run all [`HookPoint::OnSessionEnd`] callbacks in registration order.
210    pub fn run_on_session_end(&self, ctx: &OnSessionEndContext) {
211        for (_, name, cb) in self.iter_at(HookPoint::OnSessionEnd) {
212            tracing::trace!(hook = %name, "firing on_session_end hook");
213            if let HookCallback::OnSessionEnd(f) = cb {
214                let name = name.clone();
215                if let Err(panic) =
216                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
217                {
218                    tracing::error!(hook = %name, panic = ?panic, "on_session_end hook panicked — continuing");
219                }
220            }
221        }
222    }
223
224    /// Run all [`HookPoint::OnCompaction`] callbacks in registration order.
225    pub fn run_on_compaction(&self, ctx: &OnCompactionContext) {
226        for (_, name, cb) in self.iter_at(HookPoint::OnCompaction) {
227            tracing::trace!(hook = %name, "firing on_compaction hook");
228            if let HookCallback::OnCompaction(f) = cb {
229                let name = name.clone();
230                if let Err(panic) =
231                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
232                {
233                    tracing::error!(hook = %name, panic = ?panic, "on_compaction hook panicked — continuing");
234                }
235            }
236        }
237    }
238
239    /// Run all [`HookPoint::OnInteraction`] callbacks in registration order.
240    ///
241    /// If a callback panics, the panic is logged and execution continues
242    /// (the interaction is not blocked).
243    pub fn run_on_interaction(&self, ctx: &OnInteractionContext) -> HookResult {
244        for (_, name, cb) in self.iter_at(HookPoint::OnInteraction) {
245            tracing::trace!(hook = %name, "firing on_interaction hook");
246            if let HookCallback::OnInteraction(f) = cb {
247                let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
248                {
249                    Ok(r) => r,
250                    Err(panic) => {
251                        tracing::error!(
252                            hook = %name,
253                            panic = ?panic,
254                            "on_interaction hook panicked — continuing"
255                        );
256                        continue;
257                    }
258                };
259                if !result.allow {
260                    return result;
261                }
262            }
263        }
264        HookResult::allow()
265    }
266
267    /// Run all [`TransformToolInput`](HookCallback::TransformToolInput)
268    /// callbacks in registration order, threading the (possibly modified)
269    /// tool arguments through each transform.
270    ///
271    /// Returns the final tool arguments after all transforms have been
272    /// applied.  If no transform returns `Some`, the original arguments
273    /// are returned unchanged.
274    ///
275    /// Panicking transforms are logged and skipped (original args kept).
276    pub fn run_transform_tool_input(&self, ctx: &PreToolCallDecideContext) -> serde_json::Value {
277        let mut args = ctx.tool_args.clone();
278        for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
279            if let HookCallback::TransformToolInput(f) = cb {
280                let current_ctx = PreToolCallDecideContext {
281                    tool_name: ctx.tool_name.clone(),
282                    tool_args: args.clone(),
283                };
284                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&current_ctx))) {
285                    Ok(Some(new_args)) => {
286                        tracing::debug!(
287                            hook = %name,
288                            tool = %ctx.tool_name,
289                            "transform_tool_input hook modified tool arguments"
290                        );
291                        args = new_args;
292                    }
293                    Ok(None) => { /* no modification */ }
294                    Err(panic) => {
295                        tracing::error!(
296                            hook = %name,
297                            tool = %ctx.tool_name,
298                            panic = ?panic,
299                            "transform_tool_input hook panicked — keeping current args"
300                        );
301                    }
302                }
303            }
304        }
305        args
306    }
307
308    // ── Convenience builder methods (Python decorator parity) ────────
309
310    /// Register a [`HookPoint::PreTurn`] callback.
311    ///
312    /// Convenience wrapper matching the Python SDK's `@on_pre_turn` decorator.
313    pub fn on_pre_turn(
314        &mut self,
315        name: impl Into<String>,
316        f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
317    ) -> &mut Self {
318        self.register(name, HookCallback::PreTurn(Box::new(f)))
319    }
320
321    /// Register a [`HookPoint::PostTurn`] callback.
322    ///
323    /// Convenience wrapper matching the Python SDK's `@on_post_turn` decorator.
324    pub fn on_post_turn(
325        &mut self,
326        name: impl Into<String>,
327        f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
328    ) -> &mut Self {
329        self.register(name, HookCallback::PostTurn(Box::new(f)))
330    }
331
332    /// Register a [`HookPoint::PreToolCallDecide`] callback.
333    ///
334    /// Convenience wrapper matching the Python SDK's `@on_pre_tool_call_decide`
335    /// decorator.
336    pub fn on_pre_tool_call_decide(
337        &mut self,
338        name: impl Into<String>,
339        f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
340    ) -> &mut Self {
341        self.register(name, HookCallback::PreToolCallDecide(Box::new(f)))
342    }
343
344    /// Register a [`HookPoint::PostToolCall`] callback.
345    ///
346    /// Convenience wrapper matching the Python SDK's `@on_post_tool_call` decorator.
347    pub fn on_post_tool_call(
348        &mut self,
349        name: impl Into<String>,
350        f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
351    ) -> &mut Self {
352        self.register(name, HookCallback::PostToolCall(Box::new(f)))
353    }
354
355    /// Register a [`HookPoint::OnToolError`] callback.
356    ///
357    /// Convenience wrapper matching the Python SDK's `@on_tool_error` decorator.
358    pub fn on_tool_error(
359        &mut self,
360        name: impl Into<String>,
361        f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
362    ) -> &mut Self {
363        self.register(name, HookCallback::OnToolError(Box::new(f)))
364    }
365
366    /// Register a [`HookPoint::OnCompaction`] callback.
367    ///
368    /// Convenience wrapper matching the Python SDK's `@on_compaction` decorator.
369    pub fn on_compaction(
370        &mut self,
371        name: impl Into<String>,
372        f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
373    ) -> &mut Self {
374        self.register(name, HookCallback::OnCompaction(Box::new(f)))
375    }
376
377    /// Register a [`HookPoint::OnInteraction`] callback.
378    ///
379    /// Convenience wrapper matching the Python SDK's `@on_interaction` decorator.
380    pub fn on_interaction(
381        &mut self,
382        name: impl Into<String>,
383        f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
384    ) -> &mut Self {
385        self.register(name, HookCallback::OnInteraction(Box::new(f)))
386    }
387
388    /// Register a [`HookPoint::OnSessionStart`] callback.
389    ///
390    /// Convenience wrapper matching the Python SDK's `@on_session_start` decorator.
391    pub fn on_session_start(
392        &mut self,
393        name: impl Into<String>,
394        f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
395    ) -> &mut Self {
396        self.register(name, HookCallback::OnSessionStart(Box::new(f)))
397    }
398
399    /// Register a [`HookPoint::OnSessionEnd`] callback.
400    ///
401    /// Convenience wrapper matching the Python SDK's `@on_session_end` decorator.
402    pub fn on_session_end(
403        &mut self,
404        name: impl Into<String>,
405        f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
406    ) -> &mut Self {
407        self.register(name, HookCallback::OnSessionEnd(Box::new(f)))
408    }
409
410    /// Register a [`TransformToolInput`](HookCallback::TransformToolInput) callback.
411    ///
412    /// The closure receives the pre-tool-call context and may return
413    /// `Some(new_args)` to replace tool arguments, or `None` to leave them
414    /// unchanged.
415    pub fn on_transform_tool_input(
416        &mut self,
417        name: impl Into<String>,
418        f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
419    ) -> &mut Self {
420        self.register(name, HookCallback::TransformToolInput(Box::new(f)))
421    }
422
423    // ── Owned-self builder methods (for fluent chaining) ────────────
424
425    /// Register a [`HookPoint::PreTurn`] callback, returning `self` for chaining.
426    ///
427    /// This is the owned-self variant of [`on_pre_turn`](Self::on_pre_turn).
428    #[must_use]
429    pub fn with_pre_turn(
430        mut self,
431        name: impl Into<String>,
432        f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
433    ) -> Self {
434        self.on_pre_turn(name, f);
435        self
436    }
437
438    /// Register a [`HookPoint::PostTurn`] callback, returning `self` for chaining.
439    ///
440    /// This is the owned-self variant of [`on_post_turn`](Self::on_post_turn).
441    #[must_use]
442    pub fn with_post_turn(
443        mut self,
444        name: impl Into<String>,
445        f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
446    ) -> Self {
447        self.on_post_turn(name, f);
448        self
449    }
450
451    /// Register a [`HookPoint::PreToolCallDecide`] callback, returning `self`
452    /// for chaining.
453    ///
454    /// This is the owned-self variant of
455    /// [`on_pre_tool_call_decide`](Self::on_pre_tool_call_decide).
456    #[must_use]
457    pub fn with_pre_tool_call_decide(
458        mut self,
459        name: impl Into<String>,
460        f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
461    ) -> Self {
462        self.on_pre_tool_call_decide(name, f);
463        self
464    }
465
466    /// Register a [`HookPoint::PostToolCall`] callback, returning `self` for
467    /// chaining.
468    ///
469    /// This is the owned-self variant of
470    /// [`on_post_tool_call`](Self::on_post_tool_call).
471    #[must_use]
472    pub fn with_post_tool_call(
473        mut self,
474        name: impl Into<String>,
475        f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
476    ) -> Self {
477        self.on_post_tool_call(name, f);
478        self
479    }
480
481    /// Register a [`HookPoint::OnToolError`] callback, returning `self` for
482    /// chaining.
483    ///
484    /// This is the owned-self variant of
485    /// [`on_tool_error`](Self::on_tool_error).
486    #[must_use]
487    pub fn with_tool_error(
488        mut self,
489        name: impl Into<String>,
490        f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
491    ) -> Self {
492        self.on_tool_error(name, f);
493        self
494    }
495
496    /// Register a [`HookPoint::OnCompaction`] callback, returning `self` for
497    /// chaining.
498    ///
499    /// This is the owned-self variant of
500    /// [`on_compaction`](Self::on_compaction).
501    #[must_use]
502    pub fn with_compaction(
503        mut self,
504        name: impl Into<String>,
505        f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
506    ) -> Self {
507        self.on_compaction(name, f);
508        self
509    }
510
511    /// Register a [`HookPoint::OnInteraction`] callback, returning `self` for
512    /// chaining.
513    ///
514    /// This is the owned-self variant of
515    /// [`on_interaction`](Self::on_interaction).
516    #[must_use]
517    pub fn with_interaction(
518        mut self,
519        name: impl Into<String>,
520        f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
521    ) -> Self {
522        self.on_interaction(name, f);
523        self
524    }
525
526    /// Register a [`HookPoint::OnSessionStart`] callback, returning `self`
527    /// for chaining.
528    ///
529    /// This is the owned-self variant of
530    /// [`on_session_start`](Self::on_session_start).
531    #[must_use]
532    pub fn with_session_start(
533        mut self,
534        name: impl Into<String>,
535        f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
536    ) -> Self {
537        self.on_session_start(name, f);
538        self
539    }
540
541    /// Register a [`HookPoint::OnSessionEnd`] callback, returning `self` for
542    /// chaining.
543    ///
544    /// This is the owned-self variant of
545    /// [`on_session_end`](Self::on_session_end).
546    #[must_use]
547    pub fn with_session_end(
548        mut self,
549        name: impl Into<String>,
550        f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
551    ) -> Self {
552        self.on_session_end(name, f);
553        self
554    }
555
556    /// Register a [`TransformToolInput`](HookCallback::TransformToolInput)
557    /// callback, returning `self` for chaining.
558    ///
559    /// This is the owned-self variant of
560    /// [`on_transform_tool_input`](Self::on_transform_tool_input).
561    #[must_use]
562    pub fn with_transform_tool_input(
563        mut self,
564        name: impl Into<String>,
565        f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
566    ) -> Self {
567        self.on_transform_tool_input(name, f);
568        self
569    }
570
571    /// Iterate callbacks at a given hook point in registration order.
572    fn iter_at(
573        &self,
574        point: HookPoint,
575    ) -> impl Iterator<Item = &(HookPoint, String, HookCallback)> {
576        self.callbacks.iter().filter(move |(p, _, _)| *p == point)
577    }
578
579    /// Extract a list of [`HookEntry`](super::types::HookEntry) objects
580    /// corresponding to the registered callbacks.
581    ///
582    /// This allows the `AgentBuilder` to automatically populate the agent's
583    /// configuration with the necessary entries to connect the Python SDK's
584    /// hook dispatcher back to the Rust runner.
585    #[must_use]
586    pub fn entries(&self) -> Vec<super::types::HookEntry> {
587        self.callbacks
588            .iter()
589            .map(|(point, name, _)| super::types::HookEntry {
590                name: name.clone(),
591                point: *point,
592                callback_id: name.clone(),
593            })
594            .collect()
595    }
596}
597
598impl Default for Hooks {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
// Unit tests for this module live in the sibling file `runner_tests.rs` and
// are compiled only for test builds.
#[cfg(test)]
#[path = "runner_tests.rs"]
mod tests;