Skip to main content

rmcp_server_kit/
tool_hooks.rs

1//! Opt-in tool-call instrumentation for `ServerHandler` implementations.
2//!
3//! [`crate::tool_hooks::HookedHandler`] wraps any [`rmcp::ServerHandler`] with:
4//!
5//! - **Before hooks** (async) that observe `(tool_name, arguments, identity,
6//!   role, sub, request_id)` and may [`HookOutcome::Continue`](crate::tool_hooks::HookOutcome::Continue),
7//!   [`HookOutcome::Deny`](crate::tool_hooks::HookOutcome::Deny), or
8//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) the call.
9//! - **After hooks** (async) that observe the same context plus a
10//!   [`HookDisposition`](crate::tool_hooks::HookDisposition) describing how the call resolved and the
11//!   approximate result size in bytes.  After-hooks are spawned via
12//!   `tokio::spawn` and never block the response path.
13//! - **Result-size capping**: serialized tool results larger than
14//!   `max_result_bytes` are replaced with a structured error, preventing
15//!   token-expensive or memory-expensive payloads from reaching clients.
16//!   The cap applies both to inner-handler results and to
17//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) payloads.
18//!
19//! This is entirely **opt-in** at the application layer - `rmcp_server_kit::serve()`
20//! does not wrap handlers automatically.  Applications that want hooks do:
21//!
22//! ```no_run
23//! use std::sync::Arc;
24//! use rmcp_server_kit::tool_hooks::{HookedHandler, HookOutcome, ToolHooks, with_hooks};
25//!
26//! # #[derive(Clone, Default)]
27//! # struct MyHandler;
28//! # impl rmcp::ServerHandler for MyHandler {}
29//! let handler = MyHandler::default();
30//! let hooks = Arc::new(
31//!     ToolHooks::new()
32//!         .with_max_result_bytes(256 * 1024)
33//!         .with_before(Arc::new(|_ctx| Box::pin(async { HookOutcome::Continue })))
34//!         .with_after(Arc::new(|_ctx, _disp, _bytes| Box::pin(async {}))),
35//! );
36//! let _wrapped = with_hooks(handler, hooks);
37//! ```
38
39use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42    ErrorData, RoleServer, ServerHandler,
43    model::{
44        CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45        InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46        ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47        ReadResourceResult, ServerInfo, Tool,
48    },
49    service::RequestContext,
50};
51
52/// Context passed to before/after hooks for a single tool call.
53#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub struct ToolCallContext {
56    /// Tool name being invoked.
57    pub tool_name: String,
58    /// JSON arguments as sent by the client (may be `None`).
59    pub arguments: Option<serde_json::Value>,
60    /// Identity name from the authenticated request, if any.
61    pub identity: Option<String>,
62    /// RBAC role associated with the request, if any.
63    pub role: Option<String>,
64    /// OAuth `sub` claim, if present.
65    pub sub: Option<String>,
66    /// Raw JSON-RPC request id rendered as a string, if available.
67    pub request_id: Option<String>,
68}
69
70impl ToolCallContext {
71    /// Construct a [`ToolCallContext`] with the given tool name and all
72    /// optional fields cleared.  Primarily for use in unit tests and
73    /// benchmarks of user-supplied hooks; the runtime path populates
74    /// these fields from the request and task-local RBAC state.
75    #[must_use]
76    pub fn for_tool(tool_name: impl Into<String>) -> Self {
77        Self {
78            tool_name: tool_name.into(),
79            arguments: None,
80            identity: None,
81            role: None,
82            sub: None,
83            request_id: None,
84        }
85    }
86}
87
88/// Outcome returned by a [`BeforeHook`] to control invocation flow.
89///
90/// - [`HookOutcome::Continue`] - proceed with the wrapped handler.
91/// - [`HookOutcome::Deny`] - reject the call with the supplied
92///   [`ErrorData`]; the inner handler is **not** called.
93/// - [`HookOutcome::Replace`] - return the supplied result instead of
94///   invoking the inner handler.  The result is still subject to
95///   `max_result_bytes` capping.
96#[derive(Debug)]
97#[non_exhaustive]
98pub enum HookOutcome {
99    /// Proceed with the wrapped handler.
100    Continue,
101    /// Reject the call.  The error is propagated to the client as-is.
102    Deny(ErrorData),
103    /// Skip the inner handler and return the supplied result instead.
104    Replace(Box<CallToolResult>),
105}
106
/// How a tool call resolved, passed to the [`AfterHook`].
///
/// The enum is a plain field-less tag, so besides `Debug`/`Clone`/`Copy` it
/// also implements `PartialEq`, `Eq` and `Hash`.  After-hooks can therefore
/// compare dispositions with `==` or use them as map keys (for example to
/// keep per-disposition counters) without resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The inner handler ran and returned `Ok`.
    InnerExecuted,
    /// The inner handler ran and returned `Err`.
    InnerErrored,
    /// The before-hook returned [`HookOutcome::Deny`].
    DeniedBefore,
    /// The before-hook returned [`HookOutcome::Replace`].
    ReplacedBefore,
    /// The result (from inner or replace) exceeded `max_result_bytes`
    /// and was substituted with a structured error.
    ResultTooLarge,
}
123
124/// Async before-hook callback type.
125///
126/// Returns a [`HookOutcome`] controlling whether the inner handler runs.
127/// The borrow of `ToolCallContext` is held for the duration of the
128/// returned future, which avoids forcing implementations to clone the
129/// context for every invocation.
130pub type BeforeHook = Arc<
131    dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
132        + Send
133        + Sync
134        + 'static,
135>;
136
137/// Async after-hook callback type.
138///
139/// Receives the call context, a [`HookDisposition`] describing how the
140/// call resolved, and the approximate serialized result size in bytes
141/// (`0` for `DeniedBefore` and `InnerErrored`).  Spawned via
142/// `tokio::spawn`, so it must not assume it runs before the response is
143/// flushed.
144pub type AfterHook = Arc<
145    dyn for<'a> Fn(
146            &'a ToolCallContext,
147            HookDisposition,
148            usize,
149        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
150        + Send
151        + Sync
152        + 'static,
153>;
154
155/// Opt-in hooks applied by [`crate::tool_hooks::HookedHandler`].
156#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
157#[derive(Clone, Default)]
158#[non_exhaustive]
159pub struct ToolHooks {
160    /// Hard cap on serialized `CallToolResult` size in bytes.  When
161    /// exceeded, the result is replaced with an `is_error=true` result
162    /// carrying a `result_too_large` structured error.  `None` disables
163    /// the cap.
164    pub max_result_bytes: Option<usize>,
165    /// Optional before-hook invoked after arg deserialization, before
166    /// the wrapped handler is called.
167    pub before: Option<BeforeHook>,
168    /// Optional after-hook invoked once per call, regardless of how the
169    /// call resolved.  Spawned via `tokio::spawn` and never blocks the
170    /// response path.
171    pub after: Option<AfterHook>,
172}
173
174impl ToolHooks {
175    /// Construct an empty [`ToolHooks`] with no cap and no hooks.
176    ///
177    /// Use the `with_*` builder methods to populate fields; this avoids
178    /// the `#[non_exhaustive]` restriction that prevents struct-literal
179    /// construction from outside the crate.
180    #[must_use]
181    pub fn new() -> Self {
182        Self::default()
183    }
184
185    /// Set the serialized result size cap in bytes.
186    #[must_use]
187    pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188        self.max_result_bytes = Some(max);
189        self
190    }
191
192    /// Set the before-hook.
193    #[must_use]
194    pub fn with_before(mut self, before: BeforeHook) -> Self {
195        self.before = Some(before);
196        self
197    }
198
199    /// Set the after-hook.
200    #[must_use]
201    pub fn with_after(mut self, after: AfterHook) -> Self {
202        self.after = Some(after);
203        self
204    }
205}
206
207impl fmt::Debug for ToolHooks {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        f.debug_struct("ToolHooks")
210            .field("max_result_bytes", &self.max_result_bytes)
211            .field("before", &self.before.as_ref().map(|_| "<fn>"))
212            .field("after", &self.after.as_ref().map(|_| "<fn>"))
213            .finish()
214    }
215}
216
217/// `ServerHandler` wrapper that applies [`ToolHooks`].
218#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220    inner: Arc<H>,
221    hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("HookedHandler")
227            .field("hooks", &self.hooks)
228            .finish_non_exhaustive()
229    }
230}
231
232/// Construct a [`crate::tool_hooks::HookedHandler`] from an inner handler and hooks.
233pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
234    HookedHandler {
235        inner: Arc::new(inner),
236        hooks,
237    }
238}
239
240impl<H: ServerHandler> HookedHandler<H> {
241    /// Access the wrapped handler.
242    #[must_use]
243    pub fn inner(&self) -> &H {
244        &self.inner
245    }
246
247    fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
248        ToolCallContext {
249            tool_name: request.name.to_string(),
250            arguments: request.arguments.clone().map(serde_json::Value::Object),
251            identity: crate::rbac::current_identity(),
252            role: crate::rbac::current_role(),
253            sub: crate::rbac::current_sub(),
254            request_id: req_id,
255        }
256    }
257
258    /// Spawn the after-hook on the current Tokio runtime.  The future
259    /// captures clones of `ctx` and the `Arc<AfterHook>` so it can run
260    /// independently of the request task; panics inside the after-hook
261    /// are caught by Tokio and never poison the response path.
262    ///
263    /// The spawned task is **instrumented** with the request span via
264    /// [`tracing::Instrument`] and re-establishes the per-request RBAC
265    /// task-locals (role, identity, token, sub) via
266    /// [`crate::rbac::with_rbac_scope`]. Without this, after-hooks lose
267    /// their parent span (breaking trace correlation) and observe
268    /// `current_role()` / `current_identity()` as `None`.
269    fn spawn_after(
270        after: Option<&Arc<AfterHookHolder>>,
271        ctx: ToolCallContext,
272        disposition: HookDisposition,
273        size: usize,
274    ) {
275        if let Some(after) = after {
276            use tracing::Instrument;
277
278            let after = Arc::clone(after);
279            // Capture the request span before leaving the request task so
280            // after-hook log lines are correlated with the originating call.
281            let span = tracing::Span::current();
282            // Snapshot RBAC task-locals; defaults are empty strings so the
283            // re-established scope is a no-op when the request had no
284            // authenticated identity (e.g. health checks, anonymous tools).
285            let role = crate::rbac::current_role().unwrap_or_default();
286            let identity = crate::rbac::current_identity().unwrap_or_default();
287            let token = crate::rbac::current_token()
288                .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
289            let sub = crate::rbac::current_sub().unwrap_or_default();
290            tokio::spawn(
291                async move {
292                    crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
293                        let fut = (after.f)(&ctx, disposition, size);
294                        fut.await;
295                    })
296                    .await;
297                }
298                .instrument(span),
299            );
300        }
301    }
302}
303
304/// Internal newtype that owns the [`AfterHook`] so we can `Arc::clone`
305/// the *holder* and let the spawned task borrow `ctx` for the lifetime
306/// of the future without lifetime acrobatics in `tokio::spawn`.
307struct AfterHookHolder {
308    f: AfterHook,
309}
310
311/// Structured error body returned when a result exceeds `max_result_bytes`.
312fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
313    let body = serde_json::json!({
314        "error": "result_too_large",
315        "message": format!(
316            "tool '{tool}' result of {actual} bytes exceeds the configured \
317             max_result_bytes={limit}; ask for a narrower query"
318        ),
319        "limit_bytes": limit,
320        "actual_bytes": actual,
321    });
322    let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
323    r.structured_content = None;
324    r
325}
326
327fn serialized_size(result: &CallToolResult) -> usize {
328    serde_json::to_vec(result).map_or(0, |v| v.len())
329}
330
331/// Apply the `max_result_bytes` cap to a result.  Returns the (possibly
332/// replaced) result, the size used for accounting, and whether the cap
333/// fired.
334fn apply_size_cap(
335    result: CallToolResult,
336    max: Option<usize>,
337    tool: &str,
338) -> (CallToolResult, usize, bool) {
339    let size = serialized_size(&result);
340    if let Some(limit) = max
341        && size > limit
342    {
343        tracing::warn!(
344            tool = %tool,
345            size_bytes = size,
346            limit_bytes = limit,
347            "tool result exceeds max_result_bytes; replacing with structured error"
348        );
349        let replaced = too_large_result(limit, size, tool);
350        return (replaced, size, true);
351    }
352    (result, size, false)
353}
354
355impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
356    fn get_info(&self) -> ServerInfo {
357        self.inner.get_info()
358    }
359
360    async fn initialize(
361        &self,
362        request: InitializeRequestParams,
363        context: RequestContext<RoleServer>,
364    ) -> Result<InitializeResult, ErrorData> {
365        self.inner.initialize(request, context).await
366    }
367
368    async fn list_tools(
369        &self,
370        request: Option<PaginatedRequestParams>,
371        context: RequestContext<RoleServer>,
372    ) -> Result<ListToolsResult, ErrorData> {
373        self.inner.list_tools(request, context).await
374    }
375
376    fn get_tool(&self, name: &str) -> Option<Tool> {
377        self.inner.get_tool(name)
378    }
379
380    async fn list_prompts(
381        &self,
382        request: Option<PaginatedRequestParams>,
383        context: RequestContext<RoleServer>,
384    ) -> Result<ListPromptsResult, ErrorData> {
385        self.inner.list_prompts(request, context).await
386    }
387
388    async fn get_prompt(
389        &self,
390        request: GetPromptRequestParams,
391        context: RequestContext<RoleServer>,
392    ) -> Result<GetPromptResult, ErrorData> {
393        self.inner.get_prompt(request, context).await
394    }
395
396    async fn list_resources(
397        &self,
398        request: Option<PaginatedRequestParams>,
399        context: RequestContext<RoleServer>,
400    ) -> Result<ListResourcesResult, ErrorData> {
401        self.inner.list_resources(request, context).await
402    }
403
404    async fn list_resource_templates(
405        &self,
406        request: Option<PaginatedRequestParams>,
407        context: RequestContext<RoleServer>,
408    ) -> Result<ListResourceTemplatesResult, ErrorData> {
409        self.inner.list_resource_templates(request, context).await
410    }
411
412    async fn read_resource(
413        &self,
414        request: ReadResourceRequestParams,
415        context: RequestContext<RoleServer>,
416    ) -> Result<ReadResourceResult, ErrorData> {
417        self.inner.read_resource(request, context).await
418    }
419
420    async fn call_tool(
421        &self,
422        request: CallToolRequestParams,
423        context: RequestContext<RoleServer>,
424    ) -> Result<CallToolResult, ErrorData> {
425        let req_id = Some(format!("{:?}", context.id));
426        let ctx = Self::build_context(&request, req_id);
427        let max = self.hooks.max_result_bytes;
428        let after_holder = self
429            .hooks
430            .after
431            .as_ref()
432            .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));
433
434        // Before hook: may Continue, Deny, or Replace.
435        if let Some(before) = self.hooks.before.as_ref() {
436            let outcome = before(&ctx).await;
437            match outcome {
438                HookOutcome::Continue => {}
439                HookOutcome::Deny(err) => {
440                    Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
441                    return Err(err);
442                }
443                HookOutcome::Replace(boxed) => {
444                    let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
445                    let disposition = if capped {
446                        HookDisposition::ResultTooLarge
447                    } else {
448                        HookDisposition::ReplacedBefore
449                    };
450                    Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
451                    return Ok(final_result);
452                }
453            }
454        }
455
456        // Inner handler.
457        let result = self.inner.call_tool(request, context).await;
458
459        match result {
460            Ok(ok) => {
461                let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
462                let disposition = if capped {
463                    HookDisposition::ResultTooLarge
464                } else {
465                    HookDisposition::InnerExecuted
466                };
467                Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
468                Ok(final_result)
469            }
470            Err(e) => {
471                Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
472                Err(e)
473            }
474        }
475    }
476}
477
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    /// Bare-bones in-process `ServerHandler` used as the wrapped handler.
    #[derive(Clone, Default)]
    struct TestHandler {
        /// Length of the 'x' body returned by `call_tool`; 4 when `None`.
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let len = self.body_bytes.unwrap_or(4);
            Ok(CallToolResult::success(vec![Content::text("x".repeat(len))]))
        }
    }

    /// Context with only the tool name set.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext::for_tool(name)
    }

    /// Successful single-text-item result with the given body.
    fn text_result(body: String) -> CallToolResult {
        CallToolResult::success(vec![Content::text(body)])
    }

    /// True when the first content item is text mentioning `result_too_large`.
    fn mentions_too_large(result: &CallToolResult) -> bool {
        matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        )
    }

    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        // Building the wrapper doubles as a compile-check that
        // HookedHandler instantiates with the test inner handler.
        let _hooked = with_hooks(
            TestHandler {
                body_bytes: Some(8_192),
            },
            Arc::new(ToolHooks::new().with_max_result_bytes(256)),
        );

        let tiny = text_result("ok".to_owned());
        assert!(serialized_size(&tiny) < 256);

        let oversized = text_result("x".repeat(8_192));
        let oversized_len = serialized_size(&oversized);
        assert!(oversized_len > 256);

        let (replaced, accounted, capped) = apply_size_cap(oversized, Some(256), "whatever");
        assert!(capped);
        assert_eq!(accounted, oversized_len);
        assert_eq!(replaced.is_error, Some(true));
        assert!(mentions_too_large(&replaced));
    }

    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let calls = Arc::clone(&calls_in_hook);
            let tool = ctx_ref.tool_name.clone();
            Box::pin(async move {
                calls.fetch_add(1, Ordering::Relaxed);
                match tool.as_str() {
                    "forbidden" => HookOutcome::Deny(ErrorData::invalid_request("nope", None)),
                    _ => HookOutcome::Continue,
                }
            })
        });

        let hooked = with_hooks(
            TestHandler::default(),
            Arc::new(ToolHooks::new().with_before(before)),
        );
        let before_fn = hooked.hooks.before.as_ref().unwrap();

        let forbidden = ctx("forbidden");
        let denied = before_fn(&forbidden).await;
        assert!(matches!(denied, HookOutcome::Deny(_)));
        assert_eq!(calls.load(Ordering::Relaxed), 1);

        let allowed = ctx("allowed");
        let passed = before_fn(&allowed).await;
        assert!(matches!(passed, HookOutcome::Continue));
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        let rendered = serde_json::to_string(&too_large_result(100, 500, "my_tool")).unwrap();
        for needle in ["result_too_large", "my_tool", "100", "500"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        // A before-hook answering Replace must hand back its own
        // CallToolResult as-is; the inner handler is not needed at all.
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                let payload = CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )]);
                HookOutcome::Replace(Box::new(payload))
            })
        });
        let hooks = Arc::new(ToolHooks::new().with_before(before));
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        // As elsewhere in this module, drive the before-hook closure and the
        // apply_size_cap helper directly.
        let before_fn = hooks.before.as_ref().unwrap();
        let HookOutcome::Replace(boxed) = before_fn(&ctx("any")).await else {
            panic!("expected HookOutcome::Replace");
        };
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        // An over-limit Replace payload gets the same result_too_large
        // rewrite an inner-handler result would, and the cap reports that it
        // fired (which maps to the ResultTooLarge disposition).
        let huge = text_result("y".repeat(8_192));
        let huge_len = serialized_size(&huge);
        assert!(huge_len > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_len);
        assert_eq!(final_result.is_error, Some(true));
        assert!(mentions_too_large(&final_result));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        use std::time::{Duration, Instant};

        // One spawn_after call must schedule the after-hook exactly once and
        // must not block the caller.  The hook runs on a spawned task, so
        // the counter is polled with a short timeout.
        let fired = Arc::new(AtomicUsize::new(0));
        let fired_in_hook = Arc::clone(&fired);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let fired = Arc::clone(&fired_in_hook);
            Box::pin(async move {
                fired.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        // Poll for at most one second.
        let give_up_at = Instant::now() + Duration::from_secs(1);
        loop {
            if fired.load(Ordering::Relaxed) != 0 || Instant::now() >= give_up_at {
                break;
            }
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
        assert_eq!(fired.load(Ordering::Relaxed), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        // A panic inside the after-hook must stay inside its own task: after
        // spawning a panicking hook, the current task must still be able to
        // drive an unrelated future to completion.
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        // Let Tokio run (and tear down) the panicking task, then check that
        // this task and the runtime are both still healthy.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let probe = tokio::spawn(async { 1_u32 + 2 });
        assert_eq!(probe.await.unwrap(), 3);
    }
}