Skip to main content

rmcp_server_kit/
tool_hooks.rs

1//! Opt-in tool-call instrumentation for `ServerHandler` implementations.
2//!
3//! [`crate::tool_hooks::HookedHandler`] wraps any [`rmcp::ServerHandler`] with:
4//!
5//! - **Before hooks** (async) that observe `(tool_name, arguments, identity,
6//!   role, sub, request_id)` and may [`HookOutcome::Continue`](crate::tool_hooks::HookOutcome::Continue),
7//!   [`HookOutcome::Deny`](crate::tool_hooks::HookOutcome::Deny), or
8//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) the call.
9//! - **After hooks** (async) that observe the same context plus a
10//!   [`HookDisposition`](crate::tool_hooks::HookDisposition) describing how the call resolved and the
11//!   approximate result size in bytes.  After-hooks are spawned via
12//!   `tokio::spawn` and never block the response path.
13//! - **Result-size capping**: serialized tool results larger than
14//!   `max_result_bytes` are replaced with a structured error, preventing
15//!   token-expensive or memory-expensive payloads from reaching clients.
16//!   The cap applies both to inner-handler results and to
17//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) payloads.
18//!
19//! This is entirely **opt-in** at the application layer - `rmcp_server_kit::serve()`
20//! does not wrap handlers automatically.  Applications that want hooks do:
21//!
22//! ```no_run
23//! use std::sync::Arc;
24//! use rmcp_server_kit::tool_hooks::{HookedHandler, HookOutcome, ToolHooks, with_hooks};
25//!
26//! # #[derive(Clone, Default)]
27//! # struct MyHandler;
28//! # impl rmcp::ServerHandler for MyHandler {}
29//! let handler = MyHandler::default();
30//! let hooks = Arc::new(
31//!     ToolHooks::new()
32//!         .with_max_result_bytes(256 * 1024)
33//!         .with_before(Arc::new(|_ctx| Box::pin(async { HookOutcome::Continue })))
34//!         .with_after(Arc::new(|_ctx, _disp, _bytes| Box::pin(async {}))),
35//! );
36//! let _wrapped = with_hooks(handler, hooks);
37//! ```
38
39use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42    ErrorData, RoleServer, ServerHandler,
43    model::{
44        CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45        InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46        ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47        ReadResourceResult, ServerInfo, Tool,
48    },
49    service::RequestContext,
50};
51
52/// Context passed to before/after hooks for a single tool call.
53#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub struct ToolCallContext {
56    /// Tool name being invoked.
57    pub tool_name: String,
58    /// JSON arguments as sent by the client (may be `None`).
59    pub arguments: Option<serde_json::Value>,
60    /// Identity name from the authenticated request, if any.
61    pub identity: Option<String>,
62    /// RBAC role associated with the request, if any.
63    pub role: Option<String>,
64    /// OAuth `sub` claim, if present.
65    pub sub: Option<String>,
66    /// Raw JSON-RPC request id rendered as a string, if available.
67    pub request_id: Option<String>,
68}
69
70impl ToolCallContext {
71    /// Construct a [`ToolCallContext`] with the given tool name and all
72    /// optional fields cleared.  Primarily for use in unit tests and
73    /// benchmarks of user-supplied hooks; the runtime path populates
74    /// these fields from the request and task-local RBAC state.
75    #[must_use]
76    pub fn for_tool(tool_name: impl Into<String>) -> Self {
77        Self {
78            tool_name: tool_name.into(),
79            arguments: None,
80            identity: None,
81            role: None,
82            sub: None,
83            request_id: None,
84        }
85    }
86}
87
88/// Outcome returned by a [`BeforeHook`] to control invocation flow.
89///
90/// - [`HookOutcome::Continue`] - proceed with the wrapped handler.
91/// - [`HookOutcome::Deny`] - reject the call with the supplied
92///   [`ErrorData`]; the inner handler is **not** called.
93/// - [`HookOutcome::Replace`] - return the supplied result instead of
94///   invoking the inner handler.  The result is still subject to
95///   `max_result_bytes` capping.
96#[derive(Debug)]
97#[non_exhaustive]
98pub enum HookOutcome {
99    /// Proceed with the wrapped handler.
100    Continue,
101    /// Reject the call.  The error is propagated to the client as-is.
102    Deny(ErrorData),
103    /// Skip the inner handler and return the supplied result instead.
104    Replace(Box<CallToolResult>),
105}
106
/// How a tool call resolved, passed to the [`AfterHook`].
///
/// Besides `Debug`/`Clone`/`Copy`, this derives `PartialEq`, `Eq` and `Hash`
/// so after-hooks can compare dispositions directly (`disp ==
/// HookDisposition::DeniedBefore`) or use them as keys, e.g. in
/// per-outcome counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The inner handler ran and returned `Ok`.
    InnerExecuted,
    /// The inner handler ran and returned `Err`.
    InnerErrored,
    /// The before-hook returned [`HookOutcome::Deny`].
    DeniedBefore,
    /// The before-hook returned [`HookOutcome::Replace`].
    ReplacedBefore,
    /// The result (from inner or replace) exceeded `max_result_bytes`
    /// and was substituted with a structured error.
    ResultTooLarge,
}
123
124/// Async before-hook callback type.
125///
126/// Returns a [`HookOutcome`] controlling whether the inner handler runs.
127/// The borrow of `ToolCallContext` is held for the duration of the
128/// returned future, which avoids forcing implementations to clone the
129/// context for every invocation.
130pub type BeforeHook = Arc<
131    dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
132        + Send
133        + Sync
134        + 'static,
135>;
136
137/// Async after-hook callback type.
138///
139/// Receives the call context, a [`HookDisposition`] describing how the
140/// call resolved, and the approximate serialized result size in bytes
141/// (`0` for `DeniedBefore` and `InnerErrored`).  Spawned via
142/// `tokio::spawn`, so it must not assume it runs before the response is
143/// flushed.
144pub type AfterHook = Arc<
145    dyn for<'a> Fn(
146            &'a ToolCallContext,
147            HookDisposition,
148            usize,
149        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
150        + Send
151        + Sync
152        + 'static,
153>;
154
155/// Opt-in hooks applied by [`crate::tool_hooks::HookedHandler`].
156#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
157#[derive(Clone, Default)]
158#[non_exhaustive]
159pub struct ToolHooks {
160    /// Hard cap on serialized `CallToolResult` size in bytes.  When
161    /// exceeded, the result is replaced with an `is_error=true` result
162    /// carrying a `result_too_large` structured error.  `None` disables
163    /// the cap.
164    pub max_result_bytes: Option<usize>,
165    /// Optional before-hook invoked after arg deserialization, before
166    /// the wrapped handler is called.
167    pub before: Option<BeforeHook>,
168    /// Optional after-hook invoked once per call, regardless of how the
169    /// call resolved.  Spawned via `tokio::spawn` and never blocks the
170    /// response path.
171    pub after: Option<AfterHook>,
172}
173
174impl ToolHooks {
175    /// Construct an empty [`ToolHooks`] with no cap and no hooks.
176    ///
177    /// Use the `with_*` builder methods to populate fields; this avoids
178    /// the `#[non_exhaustive]` restriction that prevents struct-literal
179    /// construction from outside the crate.
180    #[must_use]
181    pub fn new() -> Self {
182        Self::default()
183    }
184
185    /// Set the serialized result size cap in bytes.
186    #[must_use]
187    pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188        self.max_result_bytes = Some(max);
189        self
190    }
191
192    /// Set the before-hook.
193    #[must_use]
194    pub fn with_before(mut self, before: BeforeHook) -> Self {
195        self.before = Some(before);
196        self
197    }
198
199    /// Set the after-hook.
200    #[must_use]
201    pub fn with_after(mut self, after: AfterHook) -> Self {
202        self.after = Some(after);
203        self
204    }
205}
206
207impl fmt::Debug for ToolHooks {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        f.debug_struct("ToolHooks")
210            .field("max_result_bytes", &self.max_result_bytes)
211            .field("before", &self.before.as_ref().map(|_| "<fn>"))
212            .field("after", &self.after.as_ref().map(|_| "<fn>"))
213            .finish()
214    }
215}
216
217/// `ServerHandler` wrapper that applies [`ToolHooks`].
218#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220    inner: Arc<H>,
221    hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("HookedHandler")
227            .field("hooks", &self.hooks)
228            .finish_non_exhaustive()
229    }
230}
231
232/// Construct a [`crate::tool_hooks::HookedHandler`] from an inner handler and hooks.
233///
234/// Returning the wrapped handler is the entire point of this function;
235/// dropping it on the floor would silently disable the supplied hooks. The
236/// `#[must_use]` attribute would be the natural enforcement here, but adding
237/// it to a public function is a SemVer-minor change per cargo-semver-checks;
238/// it is deferred to the next minor-version bump.
239pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
240    HookedHandler {
241        inner: Arc::new(inner),
242        hooks,
243    }
244}
245
246impl<H: ServerHandler> HookedHandler<H> {
247    /// Access the wrapped handler.
248    #[must_use]
249    pub fn inner(&self) -> &H {
250        &self.inner
251    }
252
253    fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
254        ToolCallContext {
255            tool_name: request.name.to_string(),
256            arguments: request.arguments.clone().map(serde_json::Value::Object),
257            identity: crate::rbac::current_identity(),
258            role: crate::rbac::current_role(),
259            sub: crate::rbac::current_sub(),
260            request_id: req_id,
261        }
262    }
263
264    /// Spawn the after-hook on the current Tokio runtime.  The future
265    /// captures clones of `ctx` and the `Arc<AfterHook>` so it can run
266    /// independently of the request task; panics inside the after-hook
267    /// are caught by Tokio and never poison the response path.
268    ///
269    /// The spawned task is **instrumented** with the request span via
270    /// [`tracing::Instrument`] and re-establishes the per-request RBAC
271    /// task-locals (role, identity, token, sub) via
272    /// [`crate::rbac::with_rbac_scope`]. Without this, after-hooks lose
273    /// their parent span (breaking trace correlation) and observe
274    /// `current_role()` / `current_identity()` as `None`.
275    fn spawn_after(
276        after: Option<&Arc<AfterHookHolder>>,
277        ctx: ToolCallContext,
278        disposition: HookDisposition,
279        size: usize,
280    ) {
281        if let Some(after) = after {
282            use tracing::Instrument;
283
284            let after = Arc::clone(after);
285            // Capture the request span before leaving the request task so
286            // after-hook log lines are correlated with the originating call.
287            let span = tracing::Span::current();
288            // Snapshot RBAC task-locals; defaults are empty strings so the
289            // re-established scope is a no-op when the request had no
290            // authenticated identity (e.g. health checks, anonymous tools).
291            let role = crate::rbac::current_role().unwrap_or_default();
292            let identity = crate::rbac::current_identity().unwrap_or_default();
293            let token = crate::rbac::current_token()
294                .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
295            let sub = crate::rbac::current_sub().unwrap_or_default();
296            tokio::spawn(
297                async move {
298                    crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
299                        let fut = (after.f)(&ctx, disposition, size);
300                        fut.await;
301                    })
302                    .await;
303                }
304                .instrument(span),
305            );
306        }
307    }
308}
309
310/// Internal newtype that owns the [`AfterHook`] so we can `Arc::clone`
311/// the *holder* and let the spawned task borrow `ctx` for the lifetime
312/// of the future without lifetime acrobatics in `tokio::spawn`.
313struct AfterHookHolder {
314    f: AfterHook,
315}
316
317/// Structured error body returned when a result exceeds `max_result_bytes`.
318fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
319    let body = serde_json::json!({
320        "error": "result_too_large",
321        "message": format!(
322            "tool '{tool}' result of {actual} bytes exceeds the configured \
323             max_result_bytes={limit}; ask for a narrower query"
324        ),
325        "limit_bytes": limit,
326        "actual_bytes": actual,
327    });
328    let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
329    r.structured_content = None;
330    r
331}
332
333fn serialized_size(result: &CallToolResult) -> usize {
334    serde_json::to_vec(result).map_or(0, |v| v.len())
335}
336
337/// Apply the `max_result_bytes` cap to a result.  Returns the (possibly
338/// replaced) result, the size used for accounting, and whether the cap
339/// fired.
340fn apply_size_cap(
341    result: CallToolResult,
342    max: Option<usize>,
343    tool: &str,
344) -> (CallToolResult, usize, bool) {
345    let size = serialized_size(&result);
346    if let Some(limit) = max
347        && size > limit
348    {
349        tracing::warn!(
350            tool = %tool,
351            size_bytes = size,
352            limit_bytes = limit,
353            "tool result exceeds max_result_bytes; replacing with structured error"
354        );
355        let replaced = too_large_result(limit, size, tool);
356        return (replaced, size, true);
357    }
358    (result, size, false)
359}
360
361impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
362    fn get_info(&self) -> ServerInfo {
363        self.inner.get_info()
364    }
365
366    async fn initialize(
367        &self,
368        request: InitializeRequestParams,
369        context: RequestContext<RoleServer>,
370    ) -> Result<InitializeResult, ErrorData> {
371        self.inner.initialize(request, context).await
372    }
373
374    async fn list_tools(
375        &self,
376        request: Option<PaginatedRequestParams>,
377        context: RequestContext<RoleServer>,
378    ) -> Result<ListToolsResult, ErrorData> {
379        self.inner.list_tools(request, context).await
380    }
381
382    fn get_tool(&self, name: &str) -> Option<Tool> {
383        self.inner.get_tool(name)
384    }
385
386    async fn list_prompts(
387        &self,
388        request: Option<PaginatedRequestParams>,
389        context: RequestContext<RoleServer>,
390    ) -> Result<ListPromptsResult, ErrorData> {
391        self.inner.list_prompts(request, context).await
392    }
393
394    async fn get_prompt(
395        &self,
396        request: GetPromptRequestParams,
397        context: RequestContext<RoleServer>,
398    ) -> Result<GetPromptResult, ErrorData> {
399        self.inner.get_prompt(request, context).await
400    }
401
402    async fn list_resources(
403        &self,
404        request: Option<PaginatedRequestParams>,
405        context: RequestContext<RoleServer>,
406    ) -> Result<ListResourcesResult, ErrorData> {
407        self.inner.list_resources(request, context).await
408    }
409
410    async fn list_resource_templates(
411        &self,
412        request: Option<PaginatedRequestParams>,
413        context: RequestContext<RoleServer>,
414    ) -> Result<ListResourceTemplatesResult, ErrorData> {
415        self.inner.list_resource_templates(request, context).await
416    }
417
418    async fn read_resource(
419        &self,
420        request: ReadResourceRequestParams,
421        context: RequestContext<RoleServer>,
422    ) -> Result<ReadResourceResult, ErrorData> {
423        self.inner.read_resource(request, context).await
424    }
425
426    async fn call_tool(
427        &self,
428        request: CallToolRequestParams,
429        context: RequestContext<RoleServer>,
430    ) -> Result<CallToolResult, ErrorData> {
431        let req_id = Some(format!("{:?}", context.id));
432        let ctx = Self::build_context(&request, req_id);
433        let max = self.hooks.max_result_bytes;
434        let after_holder = self
435            .hooks
436            .after
437            .as_ref()
438            .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));
439
440        // Before hook: may Continue, Deny, or Replace.
441        if let Some(before) = self.hooks.before.as_ref() {
442            let outcome = before(&ctx).await;
443            match outcome {
444                HookOutcome::Continue => {}
445                HookOutcome::Deny(err) => {
446                    Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
447                    return Err(err);
448                }
449                HookOutcome::Replace(boxed) => {
450                    let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
451                    let disposition = if capped {
452                        HookDisposition::ResultTooLarge
453                    } else {
454                        HookDisposition::ReplacedBefore
455                    };
456                    Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
457                    return Ok(final_result);
458                }
459            }
460        }
461
462        // Inner handler.
463        let result = self.inner.call_tool(request, context).await;
464
465        match result {
466            Ok(ok) => {
467                let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
468                let disposition = if capped {
469                    HookDisposition::ResultTooLarge
470                } else {
471                    HookDisposition::InnerExecuted
472                };
473                Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
474                Ok(final_result)
475            }
476            Err(e) => {
477                Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
478                Err(e)
479            }
480        }
481    }
482}
483
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    /// Minimal in-process `ServerHandler` for tests.
    #[derive(Clone, Default)]
    struct TestHandler {
        /// When Some, `call_tool` returns a body of this many 'x' bytes.
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let body = "x".repeat(self.body_bytes.unwrap_or(4));
            Ok(CallToolResult::success(vec![Content::text(body)]))
        }
    }

    /// Context with only a tool name set.  Delegates to the public
    /// constructor so tests build contexts exactly the way hook authors do.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext::for_tool(name)
    }

    #[test]
    fn for_tool_leaves_optional_fields_empty() {
        let c = ToolCallContext::for_tool("probe");
        assert_eq!(c.tool_name, "probe");
        assert!(c.arguments.is_none());
        assert!(c.identity.is_none());
        assert!(c.role.is_none());
        assert!(c.sub.is_none());
        assert!(c.request_id.is_none());
    }

    #[test]
    fn tool_hooks_debug_redacts_callbacks() {
        let before: BeforeHook = Arc::new(|_ctx| Box::pin(async { HookOutcome::Continue }));
        let hooks = ToolHooks::new()
            .with_max_result_bytes(7)
            .with_before(before);
        let rendered = format!("{hooks:?}");
        assert!(rendered.contains("max_result_bytes: Some(7)"));
        assert!(rendered.contains("<fn>"));
        assert!(rendered.contains("after: None"));
    }

    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        let inner = TestHandler {
            body_bytes: Some(8_192),
        };
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: Some(256),
            before: None,
            after: None,
        });
        let hooked = with_hooks(inner, hooks);

        let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
        assert!(serialized_size(&small) < 256);

        let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
        let size = serialized_size(&big);
        assert!(size > 256);

        let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
        assert!(capped);
        assert_eq!(accounted, size);
        assert_eq!(replaced.is_error, Some(true));
        assert!(matches!(
            &replaced.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));

        // Compile-check that HookedHandler instantiates with the test inner.
        let _ = hooked;
    }

    #[test]
    fn size_cap_boundary_is_inclusive() {
        // Exactly at the limit passes through, one byte over is replaced,
        // and no limit never replaces.
        let make = || CallToolResult::success(vec![Content::text("boundary".to_owned())]);
        let size = serialized_size(&make());
        assert!(size > 0);

        let (kept, accounted, capped) = apply_size_cap(make(), Some(size), "edge");
        assert!(!capped);
        assert_eq!(accounted, size);
        assert!(!kept.is_error.unwrap_or(false));

        let (_, _, capped) = apply_size_cap(make(), Some(size - 1), "edge");
        assert!(capped);

        let (_, _, capped) = apply_size_cap(make(), None, "edge");
        assert!(!capped);
    }

    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let c = Arc::clone(&c);
            let name = ctx_ref.tool_name.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
                if name == "forbidden" {
                    HookOutcome::Deny(ErrorData::invalid_request("nope", None))
                } else {
                    HookOutcome::Continue
                }
            })
        });

        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let hooked = with_hooks(TestHandler::default(), hooks);

        let bad_ctx = ctx("forbidden");
        let before_fn = hooked.hooks.before.as_ref().unwrap();
        let outcome = before_fn(&bad_ctx).await;
        assert!(matches!(outcome, HookOutcome::Deny(_)));
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let ok_ctx = ctx("allowed");
        let outcome2 = before_fn(&ok_ctx).await;
        assert!(matches!(outcome2, HookOutcome::Continue));
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        let r = too_large_result(100, 500, "my_tool");
        let body = serde_json::to_string(&r).unwrap();
        assert!(body.contains("result_too_large"));
        assert!(body.contains("my_tool"));
        assert!(body.contains("100"));
        assert!(body.contains("500"));
    }

    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        // Returning Replace from before-hook must yield the supplied
        // CallToolResult directly, with no need for the inner handler.
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )])))
            })
        });
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        // Exercise the before-hook closure + apply_size_cap helper directly,
        // matching the established test pattern in this module.
        let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
        let HookOutcome::Replace(boxed) = outcome else {
            panic!("expected HookOutcome::Replace");
        };
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        // A Replace payload that exceeds max_result_bytes must be rewritten
        // to result_too_large just like an inner-handler result would be,
        // and the disposition must reflect ResultTooLarge.
        let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
        let huge_size = serialized_size(&huge);
        assert!(huge_size > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_size);
        assert_eq!(final_result.is_error, Some(true));
        assert!(matches!(
            &final_result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        // spawn_after must enqueue the after-hook exactly one time per
        // invocation and never block the caller; we wait for the spawned
        // task to run by polling the counter with a short timeout.
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let c = Arc::clone(&c);
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        // Wait up to 1s for the spawned task to run.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        // Leave room for a duplicate dispatch to surface before concluding
        // the hook really ran only once.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        // A panicking after-hook must not affect the request task.  We
        // spawn a panicking after-hook and then verify the current task
        // can still complete an unrelated future to completion.
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        // Give Tokio a chance to run + abort the panicking task, then
        // confirm we're still alive and the runtime is healthy.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
        assert_eq!(still_alive, 3);
    }
}