Skip to main content

agy_bridge/hooks/
runner.rs

1//! Hook runner: registration and execution of lifecycle callbacks.
2
3use super::types::{
4    HookCallback, HookPoint, HookResult, OnCompactionContext, OnInteractionContext,
5    OnSessionEndContext, OnSessionStartContext, OnToolErrorContext, PostToolCallContext,
6    PostTurnContext, PreToolCallDecideContext, PreTurnContext,
7};
8
9// ── Hook runner ─────────────────────────────────────────────────────────────
10
11/// Stores and executes registered hook callbacks.
12///
13/// Callbacks at the same [`HookPoint`] fire in the order they were registered.
14///
15/// # Example
16///
17/// Fluent builder pattern (recommended):
18///
19/// ```
20/// use agy_bridge::hooks::{HookResult, Hooks, PreToolCallDecideContext, PreTurnContext};
21///
22/// let hooks = Hooks::new()
23///     .with_pre_turn("logger", |ctx: &PreTurnContext| {
24///         println!("Turn {} prompt: {}", ctx.turn_number, ctx.prompt);
25///     })
26///     .with_pre_tool_call_decide("gate", |ctx: &PreToolCallDecideContext| {
27///         if ctx.tool_name == "dangerous_tool" {
28///             HookResult::deny("blocked by policy")
29///         } else {
30///             HookResult::allow()
31///         }
32///     });
33///
34/// hooks.run_pre_turn(&PreTurnContext::new("hi", 1));
35/// let result = hooks.run_pre_tool_call_decide(&PreToolCallDecideContext::new(
36///     "safe_tool",
37///     serde_json::Value::Null,
38/// ));
39/// assert!(result.allow);
40/// ```
41///
42/// For conditional or loop-based registration, use the `on_*(&mut self)` methods:
43///
44/// ```
45/// # use agy_bridge::hooks::{HookResult, Hooks};
46/// let mut hooks = Hooks::new();
47/// hooks.on_pre_turn("logger", |ctx| {
48///     println!("Turn {}", ctx.turn_number);
49/// });
50/// ```
51pub struct Hooks {
52    callbacks: Vec<(HookPoint, String, HookCallback)>,
53}
54
55impl Hooks {
56    /// Create an empty hook runner.
57    #[must_use]
58    pub const fn new() -> Self {
59        Self {
60            callbacks: Vec::new(),
61        }
62    }
63
64    /// Register a named callback.
65    ///
66    /// The [`HookPoint`] is derived automatically from the callback variant.
67    /// If a callback with the same name AND hook point already exists, it is
68    /// replaced and a warning is logged.
69    /// Returns `&mut Self` for chaining.
70    pub fn register(&mut self, name: impl Into<String>, callback: HookCallback) -> &mut Self {
71        let point = callback.hook_point();
72        let name = name.into();
73        if let Some(pos) = self
74            .callbacks
75            .iter()
76            .position(|(p, n, _)| *p == point && n == &name)
77        {
78            tracing::warn!(
79                hook = %name,
80                point = %point.label(),
81                "duplicate hook name+point in Hooks — replacing previous callback"
82            );
83            self.callbacks[pos] = (point, name, callback);
84        } else {
85            tracing::debug!(hook = %name, point = %point.label(), "registered hook callback");
86            self.callbacks.push((point, name, callback));
87        }
88        self
89    }
90
91    /// Run all observer callbacks at the given [`HookPoint`], calling `invoke`
92    /// for each matching callback.
93    ///
94    /// Panics in individual callbacks are caught and logged; execution
95    /// continues with the remaining callbacks.
96    fn run_observer<F>(&self, point: HookPoint, mut invoke: F)
97    where
98        F: FnMut(&str, &HookCallback),
99    {
100        for (_, name, cb) in self.iter_at(point) {
101            let name_owned = name.clone();
102            if let Err(panic) =
103                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| invoke(&name_owned, cb)))
104            {
105                tracing::error!(
106                    hook = %name,
107                    panic = ?panic,
108                    "{} hook panicked — continuing", point.label(),
109                );
110            }
111        }
112    }
113
114    /// Run all [`HookPoint::PreTurn`] callbacks in registration order.
115    pub fn run_pre_turn(&self, ctx: &PreTurnContext) {
116        self.run_observer(HookPoint::PreTurn, |name, cb| {
117            tracing::trace!(hook = %name, turn = ctx.turn_number, "firing pre_turn hook");
118            if let HookCallback::PreTurn(f) = cb {
119                f(ctx);
120            }
121        });
122    }
123
124    /// Run all [`HookPoint::PostTurn`] callbacks in registration order.
125    pub fn run_post_turn(&self, ctx: &PostTurnContext) {
126        self.run_observer(HookPoint::PostTurn, |name, cb| {
127            tracing::trace!(hook = %name, turn = ctx.turn_number, "firing post_turn hook");
128            if let HookCallback::PostTurn(f) = cb {
129                f(ctx);
130            }
131        });
132    }
133
134    /// Run all [`HookPoint::PreToolCallDecide`] callbacks in registration order.
135    ///
136    /// If any callback returns [`HookResult`] with `allow: false`, execution
137    /// short-circuits and that deny result is returned immediately.  Otherwise
138    /// returns [`HookResult::allow()`].
139    ///
140    /// If a callback panics, the tool call is denied as a safe default.
141    pub fn run_pre_tool_call_decide(&self, ctx: &PreToolCallDecideContext) -> HookResult {
142        for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
143            tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing pre_tool_call_decide hook");
144            if let HookCallback::PreToolCallDecide(f) = cb {
145                let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
146                {
147                    Ok(r) => r,
148                    Err(panic) => {
149                        tracing::error!(
150                            hook = %name,
151                            tool = %ctx.tool_name,
152                            panic = ?panic,
153                            "pre_tool_call_decide hook panicked — denying tool call as safe default"
154                        );
155                        return HookResult::deny(format!(
156                            "hook '{name}' panicked — tool call denied as safe default"
157                        ));
158                    }
159                };
160                if !result.allow {
161                    tracing::info!(
162                        hook = %name,
163                        tool = %ctx.tool_name,
164                        reason = %result.message,
165                        "tool call denied by hook"
166                    );
167                    return result;
168                }
169            }
170        }
171        HookResult::allow()
172    }
173
174    /// Run all [`HookPoint::PostToolCall`] callbacks in registration order.
175    pub fn run_post_tool_call(&self, ctx: &PostToolCallContext) {
176        self.run_observer(HookPoint::PostToolCall, |name, cb| {
177            tracing::trace!(hook = %name, tool = %ctx.tool_name, "firing post_tool_call hook");
178            if let HookCallback::PostToolCall(f) = cb {
179                f(ctx);
180            }
181        });
182    }
183
184    /// Run all [`HookPoint::OnToolError`] callbacks in registration order.
185    pub fn run_on_tool_error(&self, ctx: &OnToolErrorContext) {
186        self.run_observer(HookPoint::OnToolError, |name, cb| {
187            tracing::trace!(hook = %name, tool = %ctx.tool_name, error = %ctx.error, "firing on_tool_error hook");
188            if let HookCallback::OnToolError(f) = cb {
189                f(ctx);
190            }
191        });
192    }
193
194    /// Run all [`HookPoint::OnSessionStart`] callbacks in registration order.
195    pub fn run_on_session_start(&self, ctx: &OnSessionStartContext) {
196        self.run_observer(HookPoint::OnSessionStart, |name, cb| {
197            tracing::trace!(hook = %name, "firing on_session_start hook");
198            if let HookCallback::OnSessionStart(f) = cb {
199                f(ctx);
200            }
201        });
202    }
203
204    /// Run all [`HookPoint::OnSessionEnd`] callbacks in registration order.
205    pub fn run_on_session_end(&self, ctx: &OnSessionEndContext) {
206        self.run_observer(HookPoint::OnSessionEnd, |name, cb| {
207            tracing::trace!(hook = %name, "firing on_session_end hook");
208            if let HookCallback::OnSessionEnd(f) = cb {
209                f(ctx);
210            }
211        });
212    }
213
214    /// Run all [`HookPoint::OnCompaction`] callbacks in registration order.
215    pub fn run_on_compaction(&self, ctx: &OnCompactionContext) {
216        self.run_observer(HookPoint::OnCompaction, |name, cb| {
217            tracing::trace!(hook = %name, "firing on_compaction hook");
218            if let HookCallback::OnCompaction(f) = cb {
219                f(ctx);
220            }
221        });
222    }
223
224    /// Run all [`HookPoint::OnInteraction`] callbacks in registration order.
225    ///
226    /// If a callback panics, the panic is logged and execution continues
227    /// (the interaction is not blocked).
228    pub fn run_on_interaction(&self, ctx: &OnInteractionContext) -> HookResult {
229        for (_, name, cb) in self.iter_at(HookPoint::OnInteraction) {
230            tracing::trace!(hook = %name, "firing on_interaction hook");
231            if let HookCallback::OnInteraction(f) = cb {
232                let result = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(ctx)))
233                {
234                    Ok(r) => r,
235                    Err(panic) => {
236                        tracing::error!(
237                            hook = %name,
238                            panic = ?panic,
239                            "on_interaction hook panicked — continuing"
240                        );
241                        continue;
242                    }
243                };
244                if !result.allow {
245                    return result;
246                }
247            }
248        }
249        HookResult::allow()
250    }
251
252    /// Run all [`TransformToolInput`](HookCallback::TransformToolInput)
253    /// callbacks in registration order, threading the (possibly modified)
254    /// tool arguments through each transform.
255    ///
256    /// Returns the final tool arguments after all transforms have been
257    /// applied.  If no transform returns `Some`, the original arguments
258    /// are returned unchanged.
259    ///
260    /// Panicking transforms are logged and skipped (original args kept).
261    pub fn run_transform_tool_input(&self, ctx: &PreToolCallDecideContext) -> serde_json::Value {
262        let mut args = ctx.tool_args.clone();
263        for (_, name, cb) in self.iter_at(HookPoint::PreToolCallDecide) {
264            if let HookCallback::TransformToolInput(f) = cb {
265                let current_ctx = PreToolCallDecideContext {
266                    tool_name: ctx.tool_name.clone(),
267                    tool_args: args.clone(),
268                };
269                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&current_ctx))) {
270                    Ok(Some(new_args)) => {
271                        tracing::debug!(
272                            hook = %name,
273                            tool = %ctx.tool_name,
274                            "transform_tool_input hook modified tool arguments"
275                        );
276                        args = new_args;
277                    }
278                    Ok(None) => { /* no modification */ }
279                    Err(panic) => {
280                        tracing::error!(
281                            hook = %name,
282                            tool = %ctx.tool_name,
283                            panic = ?panic,
284                            "transform_tool_input hook panicked — keeping current args"
285                        );
286                    }
287                }
288            }
289        }
290        args
291    }
292
293    // ── Convenience builder methods (Python decorator parity) ────────
294
295    /// Register a [`HookPoint::PreTurn`] callback.
296    ///
297    /// Convenience wrapper matching the Python SDK's `@on_pre_turn` decorator.
298    pub fn on_pre_turn(
299        &mut self,
300        name: impl Into<String>,
301        f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
302    ) -> &mut Self {
303        self.register(name, HookCallback::PreTurn(Box::new(f)))
304    }
305
306    /// Register a [`HookPoint::PostTurn`] callback.
307    ///
308    /// Convenience wrapper matching the Python SDK's `@on_post_turn` decorator.
309    pub fn on_post_turn(
310        &mut self,
311        name: impl Into<String>,
312        f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
313    ) -> &mut Self {
314        self.register(name, HookCallback::PostTurn(Box::new(f)))
315    }
316
317    /// Register a [`HookPoint::PreToolCallDecide`] callback.
318    ///
319    /// Convenience wrapper matching the Python SDK's `@on_pre_tool_call_decide`
320    /// decorator.
321    pub fn on_pre_tool_call_decide(
322        &mut self,
323        name: impl Into<String>,
324        f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
325    ) -> &mut Self {
326        self.register(name, HookCallback::PreToolCallDecide(Box::new(f)))
327    }
328
329    /// Register a [`HookPoint::PostToolCall`] callback.
330    ///
331    /// Convenience wrapper matching the Python SDK's `@on_post_tool_call` decorator.
332    pub fn on_post_tool_call(
333        &mut self,
334        name: impl Into<String>,
335        f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
336    ) -> &mut Self {
337        self.register(name, HookCallback::PostToolCall(Box::new(f)))
338    }
339
340    /// Register a [`HookPoint::OnToolError`] callback.
341    ///
342    /// Convenience wrapper matching the Python SDK's `@on_tool_error` decorator.
343    pub fn on_tool_error(
344        &mut self,
345        name: impl Into<String>,
346        f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
347    ) -> &mut Self {
348        self.register(name, HookCallback::OnToolError(Box::new(f)))
349    }
350
351    /// Register a [`HookPoint::OnCompaction`] callback.
352    ///
353    /// Convenience wrapper matching the Python SDK's `@on_compaction` decorator.
354    pub fn on_compaction(
355        &mut self,
356        name: impl Into<String>,
357        f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
358    ) -> &mut Self {
359        self.register(name, HookCallback::OnCompaction(Box::new(f)))
360    }
361
362    /// Register a [`HookPoint::OnInteraction`] callback.
363    ///
364    /// Convenience wrapper matching the Python SDK's `@on_interaction` decorator.
365    pub fn on_interaction(
366        &mut self,
367        name: impl Into<String>,
368        f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
369    ) -> &mut Self {
370        self.register(name, HookCallback::OnInteraction(Box::new(f)))
371    }
372
373    /// Register a [`HookPoint::OnSessionStart`] callback.
374    ///
375    /// Convenience wrapper matching the Python SDK's `@on_session_start` decorator.
376    pub fn on_session_start(
377        &mut self,
378        name: impl Into<String>,
379        f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
380    ) -> &mut Self {
381        self.register(name, HookCallback::OnSessionStart(Box::new(f)))
382    }
383
384    /// Register a [`HookPoint::OnSessionEnd`] callback.
385    ///
386    /// Convenience wrapper matching the Python SDK's `@on_session_end` decorator.
387    pub fn on_session_end(
388        &mut self,
389        name: impl Into<String>,
390        f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
391    ) -> &mut Self {
392        self.register(name, HookCallback::OnSessionEnd(Box::new(f)))
393    }
394
395    /// Register a [`TransformToolInput`](HookCallback::TransformToolInput) callback.
396    ///
397    /// The closure receives the pre-tool-call context and may return
398    /// `Some(new_args)` to replace tool arguments, or `None` to leave them
399    /// unchanged.
400    pub fn on_transform_tool_input(
401        &mut self,
402        name: impl Into<String>,
403        f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
404    ) -> &mut Self {
405        self.register(name, HookCallback::TransformToolInput(Box::new(f)))
406    }
407
408    // ── Owned-self builder methods (for fluent chaining) ────────────
409
410    /// Register a [`HookPoint::PreTurn`] callback, returning `self` for chaining.
411    ///
412    /// This is the owned-self variant of [`on_pre_turn`](Self::on_pre_turn).
413    #[must_use]
414    pub fn with_pre_turn(
415        mut self,
416        name: impl Into<String>,
417        f: impl Fn(&PreTurnContext) + Send + Sync + 'static,
418    ) -> Self {
419        self.on_pre_turn(name, f);
420        self
421    }
422
423    /// Register a [`HookPoint::PostTurn`] callback, returning `self` for chaining.
424    ///
425    /// This is the owned-self variant of [`on_post_turn`](Self::on_post_turn).
426    #[must_use]
427    pub fn with_post_turn(
428        mut self,
429        name: impl Into<String>,
430        f: impl Fn(&PostTurnContext) + Send + Sync + 'static,
431    ) -> Self {
432        self.on_post_turn(name, f);
433        self
434    }
435
436    /// Register a [`HookPoint::PreToolCallDecide`] callback, returning `self`
437    /// for chaining.
438    ///
439    /// This is the owned-self variant of
440    /// [`on_pre_tool_call_decide`](Self::on_pre_tool_call_decide).
441    #[must_use]
442    pub fn with_pre_tool_call_decide(
443        mut self,
444        name: impl Into<String>,
445        f: impl Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync + 'static,
446    ) -> Self {
447        self.on_pre_tool_call_decide(name, f);
448        self
449    }
450
451    /// Register a [`HookPoint::PostToolCall`] callback, returning `self` for
452    /// chaining.
453    ///
454    /// This is the owned-self variant of
455    /// [`on_post_tool_call`](Self::on_post_tool_call).
456    #[must_use]
457    pub fn with_post_tool_call(
458        mut self,
459        name: impl Into<String>,
460        f: impl Fn(&PostToolCallContext) + Send + Sync + 'static,
461    ) -> Self {
462        self.on_post_tool_call(name, f);
463        self
464    }
465
466    /// Register a [`HookPoint::OnToolError`] callback, returning `self` for
467    /// chaining.
468    ///
469    /// This is the owned-self variant of
470    /// [`on_tool_error`](Self::on_tool_error).
471    #[must_use]
472    pub fn with_tool_error(
473        mut self,
474        name: impl Into<String>,
475        f: impl Fn(&OnToolErrorContext) + Send + Sync + 'static,
476    ) -> Self {
477        self.on_tool_error(name, f);
478        self
479    }
480
481    /// Register a [`HookPoint::OnCompaction`] callback, returning `self` for
482    /// chaining.
483    ///
484    /// This is the owned-self variant of
485    /// [`on_compaction`](Self::on_compaction).
486    #[must_use]
487    pub fn with_compaction(
488        mut self,
489        name: impl Into<String>,
490        f: impl Fn(&OnCompactionContext) + Send + Sync + 'static,
491    ) -> Self {
492        self.on_compaction(name, f);
493        self
494    }
495
496    /// Register a [`HookPoint::OnInteraction`] callback, returning `self` for
497    /// chaining.
498    ///
499    /// This is the owned-self variant of
500    /// [`on_interaction`](Self::on_interaction).
501    #[must_use]
502    pub fn with_interaction(
503        mut self,
504        name: impl Into<String>,
505        f: impl Fn(&OnInteractionContext) -> HookResult + Send + Sync + 'static,
506    ) -> Self {
507        self.on_interaction(name, f);
508        self
509    }
510
511    /// Register a [`HookPoint::OnSessionStart`] callback, returning `self`
512    /// for chaining.
513    ///
514    /// This is the owned-self variant of
515    /// [`on_session_start`](Self::on_session_start).
516    #[must_use]
517    pub fn with_session_start(
518        mut self,
519        name: impl Into<String>,
520        f: impl Fn(&OnSessionStartContext) + Send + Sync + 'static,
521    ) -> Self {
522        self.on_session_start(name, f);
523        self
524    }
525
526    /// Register a [`HookPoint::OnSessionEnd`] callback, returning `self` for
527    /// chaining.
528    ///
529    /// This is the owned-self variant of
530    /// [`on_session_end`](Self::on_session_end).
531    #[must_use]
532    pub fn with_session_end(
533        mut self,
534        name: impl Into<String>,
535        f: impl Fn(&OnSessionEndContext) + Send + Sync + 'static,
536    ) -> Self {
537        self.on_session_end(name, f);
538        self
539    }
540
541    /// Register a [`TransformToolInput`](HookCallback::TransformToolInput)
542    /// callback, returning `self` for chaining.
543    ///
544    /// This is the owned-self variant of
545    /// [`on_transform_tool_input`](Self::on_transform_tool_input).
546    #[must_use]
547    pub fn with_transform_tool_input(
548        mut self,
549        name: impl Into<String>,
550        f: impl Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync + 'static,
551    ) -> Self {
552        self.on_transform_tool_input(name, f);
553        self
554    }
555
556    /// Iterate callbacks at a given hook point in registration order.
557    fn iter_at(
558        &self,
559        point: HookPoint,
560    ) -> impl Iterator<Item = &(HookPoint, String, HookCallback)> {
561        self.callbacks.iter().filter(move |(p, _, _)| *p == point)
562    }
563
564    /// Extract a list of [`HookEntry`](super::types::HookEntry) objects
565    /// corresponding to the registered callbacks.
566    ///
567    /// This allows the `AgentBuilder` to automatically populate the agent's
568    /// configuration with the necessary entries to connect the Python SDK's
569    /// hook dispatcher back to the Rust runner.
570    #[must_use]
571    pub fn entries(&self) -> Vec<super::types::HookEntry> {
572        self.callbacks
573            .iter()
574            .map(|(point, name, _)| super::types::HookEntry {
575                name: name.clone(),
576                point: *point,
577                callback_id: name.clone(),
578            })
579            .collect()
580    }
581}
582
583impl Default for Hooks {
584    fn default() -> Self {
585        Self::new()
586    }
587}
588
// Unit tests for this module live in `runner_tests.rs`, next to this file,
// and are compiled only for test builds.
#[cfg(test)]
#[path = "runner_tests.rs"]
mod tests;