Skip to main content

rmcp_server_kit/
tool_hooks.rs

1//! Opt-in tool-call instrumentation for `ServerHandler` implementations.
2//!
3//! [`crate::tool_hooks::HookedHandler`] wraps any [`rmcp::ServerHandler`] with:
4//!
5//! - **Before hooks** (async) that observe `(tool_name, arguments, identity,
6//!   role, sub, request_id)` and may [`HookOutcome::Continue`](crate::tool_hooks::HookOutcome::Continue),
7//!   [`HookOutcome::Deny`](crate::tool_hooks::HookOutcome::Deny), or
8//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) the call.
9//! - **After hooks** (async) that observe the same context plus a
10//!   [`HookDisposition`](crate::tool_hooks::HookDisposition) describing how the call resolved and the
11//!   approximate result size in bytes.  After-hooks are spawned via
12//!   `tokio::spawn` and never block the response path.
13//! - **Result-size capping**: serialized tool results larger than
14//!   `max_result_bytes` are replaced with a structured error, preventing
15//!   token-expensive or memory-expensive payloads from reaching clients.
16//!   The cap applies both to inner-handler results and to
17//!   [`HookOutcome::Replace`](crate::tool_hooks::HookOutcome::Replace) payloads.
18//!
19//! This is entirely **opt-in** at the application layer - `rmcp_server_kit::serve()`
20//! does not wrap handlers automatically.  Applications that want hooks do:
21//!
22//! ```no_run
23//! use std::sync::Arc;
24//! use rmcp_server_kit::tool_hooks::{HookedHandler, HookOutcome, ToolHooks, with_hooks};
25//!
26//! # #[derive(Clone, Default)]
27//! # struct MyHandler;
28//! # impl rmcp::ServerHandler for MyHandler {}
29//! let handler = MyHandler::default();
30//! let hooks = Arc::new(
31//!     ToolHooks::new()
32//!         .with_max_result_bytes(256 * 1024)
33//!         .with_before(Arc::new(|_ctx| Box::pin(async { HookOutcome::Continue })))
34//!         .with_after(Arc::new(|_ctx, _disp, _bytes| Box::pin(async {}))),
35//! );
36//! let _wrapped = with_hooks(handler, hooks);
37//! ```
38
39use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42    ErrorData, RoleServer, ServerHandler,
43    model::{
44        CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45        InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46        ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47        ReadResourceResult, ServerInfo, Tool,
48    },
49    service::RequestContext,
50};
51
/// Context passed to before/after hooks for a single tool call.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot
/// build it with a struct literal; use [`ToolCallContext::for_tool`] when a
/// context is needed in tests or benchmarks.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolCallContext {
    /// Tool name being invoked.
    pub tool_name: String,
    /// JSON arguments as sent by the client (may be `None`).  On the
    /// runtime path a present value is always a JSON object.
    pub arguments: Option<serde_json::Value>,
    /// Identity name from the authenticated request, if any.
    pub identity: Option<String>,
    /// RBAC role associated with the request, if any.
    pub role: Option<String>,
    /// OAuth `sub` claim, if present.
    pub sub: Option<String>,
    /// Raw JSON-RPC request id rendered as a string, if available.  The
    /// runtime path renders the id with its `Debug` formatting.
    pub request_id: Option<String>,
}
69
70impl ToolCallContext {
71    /// Construct a [`ToolCallContext`] with the given tool name and all
72    /// optional fields cleared.  Primarily for use in unit tests and
73    /// benchmarks of user-supplied hooks; the runtime path populates
74    /// these fields from the request and task-local RBAC state.
75    #[must_use]
76    pub fn for_tool(tool_name: impl Into<String>) -> Self {
77        Self {
78            tool_name: tool_name.into(),
79            arguments: None,
80            identity: None,
81            role: None,
82            sub: None,
83            request_id: None,
84        }
85    }
86}
87
/// Outcome returned by a [`BeforeHook`] to control invocation flow.
///
/// - [`HookOutcome::Continue`] - proceed with the wrapped handler.
/// - [`HookOutcome::Deny`] - reject the call with the supplied
///   [`ErrorData`]; the inner handler is **not** called.
/// - [`HookOutcome::Replace`] - return the supplied result instead of
///   invoking the inner handler.  The result is still subject to
///   `max_result_bytes` capping.
///
/// Whatever the outcome, a configured after-hook is still scheduled for
/// the call with the matching [`HookDisposition`].
#[derive(Debug)]
#[non_exhaustive]
pub enum HookOutcome {
    /// Proceed with the wrapped handler.
    Continue,
    /// Reject the call.  The error is propagated to the client as-is.
    Deny(ErrorData),
    /// Skip the inner handler and return the supplied result instead.
    // NOTE(review): the payload is boxed, presumably to keep the enum
    // small relative to `CallToolResult` - confirm before unboxing.
    Replace(Box<CallToolResult>),
}
106
/// How a tool call resolved, passed to the [`AfterHook`].
///
/// Exactly one disposition is reported per call.  `ResultTooLarge` takes
/// precedence over `InnerExecuted` and `ReplacedBefore` whenever the size
/// cap fired on the result.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The inner handler ran and returned `Ok`.
    InnerExecuted,
    /// The inner handler ran and returned `Err`.
    InnerErrored,
    /// The before-hook returned [`HookOutcome::Deny`].
    DeniedBefore,
    /// The before-hook returned [`HookOutcome::Replace`].
    ReplacedBefore,
    /// The result (from inner or replace) exceeded `max_result_bytes`
    /// and was substituted with a structured error.
    ResultTooLarge,
}
123
/// Async before-hook callback type.
///
/// Returns a [`HookOutcome`] controlling whether the inner handler runs.
/// The borrow of `ToolCallContext` is held for the duration of the
/// returned future, which avoids forcing implementations to clone the
/// context for every invocation.
///
/// The returned future is awaited inline before the inner handler is
/// invoked, so time spent in a before-hook adds to the call's latency.
pub type BeforeHook = Arc<
    dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
        + Send
        + Sync
        + 'static,
>;
136
/// Async after-hook callback type.
///
/// Receives the call context, a [`HookDisposition`] describing how the
/// call resolved, and the approximate serialized result size in bytes
/// (`0` for `DeniedBefore` and `InnerErrored`).  Spawned via
/// `tokio::spawn`, so it must not assume it runs before the response is
/// flushed.
///
/// For `ResultTooLarge` the reported size is that of the original,
/// oversized result, not of the structured error that replaced it.
pub type AfterHook = Arc<
    dyn for<'a> Fn(
            &'a ToolCallContext,
            HookDisposition,
            usize,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
        + Send
        + Sync
        + 'static,
>;
154
/// Opt-in hooks applied by [`crate::tool_hooks::HookedHandler`].
///
/// All fields default to `None`, which disables the corresponding
/// feature.  Construct with [`ToolHooks::new`] and the `with_*` builders.
#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct ToolHooks {
    /// Hard cap on serialized `CallToolResult` size in bytes.  When
    /// exceeded, the result is replaced with an `is_error=true` result
    /// carrying a `result_too_large` structured error.  `None` disables
    /// the cap.  A result whose size equals the cap exactly is allowed.
    pub max_result_bytes: Option<usize>,
    /// Optional before-hook invoked after arg deserialization, before
    /// the wrapped handler is called.
    pub before: Option<BeforeHook>,
    /// Optional after-hook invoked once per call, regardless of how the
    /// call resolved.  Spawned via `tokio::spawn` and never blocks the
    /// response path.
    pub after: Option<AfterHook>,
}
173
174impl ToolHooks {
175    /// Construct an empty [`ToolHooks`] with no cap and no hooks.
176    ///
177    /// Use the `with_*` builder methods to populate fields; this avoids
178    /// the `#[non_exhaustive]` restriction that prevents struct-literal
179    /// construction from outside the crate.
180    #[must_use]
181    pub fn new() -> Self {
182        Self::default()
183    }
184
185    /// Set the serialized result size cap in bytes.
186    #[must_use]
187    pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188        self.max_result_bytes = Some(max);
189        self
190    }
191
192    /// Set the before-hook.
193    #[must_use]
194    pub fn with_before(mut self, before: BeforeHook) -> Self {
195        self.before = Some(before);
196        self
197    }
198
199    /// Set the after-hook.
200    #[must_use]
201    pub fn with_after(mut self, after: AfterHook) -> Self {
202        self.after = Some(after);
203        self
204    }
205}
206
207impl fmt::Debug for ToolHooks {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        f.debug_struct("ToolHooks")
210            .field("max_result_bytes", &self.max_result_bytes)
211            .field("before", &self.before.as_ref().map(|_| "<fn>"))
212            .field("after", &self.after.as_ref().map(|_| "<fn>"))
213            .finish()
214    }
215}
216
217/// `ServerHandler` wrapper that applies [`ToolHooks`].
218#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220    inner: Arc<H>,
221    hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("HookedHandler")
227            .field("hooks", &self.hooks)
228            .finish_non_exhaustive()
229    }
230}
231
232/// Construct a [`crate::tool_hooks::HookedHandler`] from an inner handler and hooks.
233///
234/// Returning the wrapped handler is the entire point of this function;
235/// dropping it on the floor would silently disable the supplied hooks. The
236/// `#[must_use]` attribute would be the natural enforcement here, but adding
237/// it to a public function is a SemVer-minor change per cargo-semver-checks;
238/// it is deferred to the next minor-version bump.
239///
240// NOTE(next-minor): add `#[must_use = "HookedHandler must be wired into a
241// ServerHandler (e.g. via `serve(..., || hooked)`) to take effect; dropping
242// the returned value silently disables the supplied hooks"]` here. Deferred
243// from patch releases on 1.7.x to avoid downstream `-D warnings` churn.
244// Cross-ref: CHANGELOG.md "[Unreleased]" — Quality / lint hygiene section.
245pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
246    HookedHandler {
247        inner: Arc::new(inner),
248        hooks,
249    }
250}
251
252impl<H: ServerHandler> HookedHandler<H> {
253    /// Access the wrapped handler.
254    #[must_use]
255    pub fn inner(&self) -> &H {
256        &self.inner
257    }
258
259    fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
260        ToolCallContext {
261            tool_name: request.name.to_string(),
262            arguments: request.arguments.clone().map(serde_json::Value::Object),
263            identity: crate::rbac::current_identity(),
264            role: crate::rbac::current_role(),
265            sub: crate::rbac::current_sub(),
266            request_id: req_id,
267        }
268    }
269
270    /// Spawn the after-hook on the current Tokio runtime.  The future
271    /// captures clones of `ctx` and the `Arc<AfterHook>` so it can run
272    /// independently of the request task; panics inside the after-hook
273    /// are caught by Tokio and never poison the response path.
274    ///
275    /// The spawned task is **instrumented** with the request span via
276    /// [`tracing::Instrument`] and re-establishes the per-request RBAC
277    /// task-locals (role, identity, token, sub) via
278    /// [`crate::rbac::with_rbac_scope`]. Without this, after-hooks lose
279    /// their parent span (breaking trace correlation) and observe
280    /// `current_role()` / `current_identity()` as `None`.
281    fn spawn_after(
282        after: Option<&Arc<AfterHookHolder>>,
283        ctx: ToolCallContext,
284        disposition: HookDisposition,
285        size: usize,
286    ) {
287        if let Some(after) = after {
288            use tracing::Instrument;
289
290            let after = Arc::clone(after);
291            // Capture the request span before leaving the request task so
292            // after-hook log lines are correlated with the originating call.
293            let span = tracing::Span::current();
294            // Snapshot RBAC task-locals; defaults are empty strings so the
295            // re-established scope is a no-op when the request had no
296            // authenticated identity (e.g. health checks, anonymous tools).
297            let role = crate::rbac::current_role().unwrap_or_default();
298            let identity = crate::rbac::current_identity().unwrap_or_default();
299            let token = crate::rbac::current_token()
300                .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
301            let sub = crate::rbac::current_sub().unwrap_or_default();
302            tokio::spawn(
303                async move {
304                    crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
305                        let fut = (after.f)(&ctx, disposition, size);
306                        fut.await;
307                    })
308                    .await;
309                }
310                .instrument(span),
311            );
312        }
313    }
314}
315
/// Internal newtype that owns the [`AfterHook`] so we can `Arc::clone`
/// the *holder* and let the spawned task borrow `ctx` for the lifetime
/// of the future without lifetime acrobatics in `tokio::spawn`.
///
/// A holder is created per `call_tool` invocation from the configured
/// after-hook and handed to `spawn_after`.
struct AfterHookHolder {
    /// The user-supplied after-hook callback.
    f: AfterHook,
}
322
323/// Structured error body returned when a result exceeds `max_result_bytes`.
324fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
325    let body = serde_json::json!({
326        "error": "result_too_large",
327        "message": format!(
328            "tool '{tool}' result of {actual} bytes exceeds the configured \
329             max_result_bytes={limit}; ask for a narrower query"
330        ),
331        "limit_bytes": limit,
332        "actual_bytes": actual,
333    });
334    let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
335    r.structured_content = None;
336    r
337}
338
339fn serialized_size(result: &CallToolResult) -> usize {
340    serde_json::to_vec(result).map_or(0, |v| v.len())
341}
342
343/// Apply the `max_result_bytes` cap to a result.  Returns the (possibly
344/// replaced) result, the size used for accounting, and whether the cap
345/// fired.
346fn apply_size_cap(
347    result: CallToolResult,
348    max: Option<usize>,
349    tool: &str,
350) -> (CallToolResult, usize, bool) {
351    let size = serialized_size(&result);
352    if let Some(limit) = max
353        && size > limit
354    {
355        tracing::warn!(
356            tool = %tool,
357            size_bytes = size,
358            limit_bytes = limit,
359            "tool result exceeds max_result_bytes; replacing with structured error"
360        );
361        let replaced = too_large_result(limit, size, tool);
362        return (replaced, size, true);
363    }
364    (result, size, false)
365}
366
367impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
368    fn get_info(&self) -> ServerInfo {
369        self.inner.get_info()
370    }
371
372    async fn initialize(
373        &self,
374        request: InitializeRequestParams,
375        context: RequestContext<RoleServer>,
376    ) -> Result<InitializeResult, ErrorData> {
377        self.inner.initialize(request, context).await
378    }
379
380    async fn list_tools(
381        &self,
382        request: Option<PaginatedRequestParams>,
383        context: RequestContext<RoleServer>,
384    ) -> Result<ListToolsResult, ErrorData> {
385        self.inner.list_tools(request, context).await
386    }
387
388    fn get_tool(&self, name: &str) -> Option<Tool> {
389        self.inner.get_tool(name)
390    }
391
392    async fn list_prompts(
393        &self,
394        request: Option<PaginatedRequestParams>,
395        context: RequestContext<RoleServer>,
396    ) -> Result<ListPromptsResult, ErrorData> {
397        self.inner.list_prompts(request, context).await
398    }
399
400    async fn get_prompt(
401        &self,
402        request: GetPromptRequestParams,
403        context: RequestContext<RoleServer>,
404    ) -> Result<GetPromptResult, ErrorData> {
405        self.inner.get_prompt(request, context).await
406    }
407
408    async fn list_resources(
409        &self,
410        request: Option<PaginatedRequestParams>,
411        context: RequestContext<RoleServer>,
412    ) -> Result<ListResourcesResult, ErrorData> {
413        self.inner.list_resources(request, context).await
414    }
415
416    async fn list_resource_templates(
417        &self,
418        request: Option<PaginatedRequestParams>,
419        context: RequestContext<RoleServer>,
420    ) -> Result<ListResourceTemplatesResult, ErrorData> {
421        self.inner.list_resource_templates(request, context).await
422    }
423
424    async fn read_resource(
425        &self,
426        request: ReadResourceRequestParams,
427        context: RequestContext<RoleServer>,
428    ) -> Result<ReadResourceResult, ErrorData> {
429        self.inner.read_resource(request, context).await
430    }
431
432    async fn call_tool(
433        &self,
434        request: CallToolRequestParams,
435        context: RequestContext<RoleServer>,
436    ) -> Result<CallToolResult, ErrorData> {
437        let req_id = Some(format!("{:?}", context.id));
438        let ctx = Self::build_context(&request, req_id);
439        let max = self.hooks.max_result_bytes;
440        let after_holder = self
441            .hooks
442            .after
443            .as_ref()
444            .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));
445
446        // Before hook: may Continue, Deny, or Replace.
447        if let Some(before) = self.hooks.before.as_ref() {
448            let outcome = before(&ctx).await;
449            match outcome {
450                HookOutcome::Continue => {}
451                HookOutcome::Deny(err) => {
452                    Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
453                    return Err(err);
454                }
455                HookOutcome::Replace(boxed) => {
456                    let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
457                    let disposition = if capped {
458                        HookDisposition::ResultTooLarge
459                    } else {
460                        HookDisposition::ReplacedBefore
461                    };
462                    Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
463                    return Ok(final_result);
464                }
465            }
466        }
467
468        // Inner handler.
469        let result = self.inner.call_tool(request, context).await;
470
471        match result {
472            Ok(ok) => {
473                let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
474                let disposition = if capped {
475                    HookDisposition::ResultTooLarge
476                } else {
477                    HookDisposition::InnerExecuted
478                };
479                Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
480                Ok(final_result)
481            }
482            Err(e) => {
483                Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
484                Err(e)
485            }
486        }
487    }
488}
489
// Unit tests for the hook helpers.  They exercise `apply_size_cap`,
// `too_large_result`, the hook closures and `spawn_after` directly; none of
// them drives `HookedHandler::call_tool` itself.
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    /// Minimal in-process `ServerHandler` for tests.
    #[derive(Clone, Default)]
    struct TestHandler {
        /// When Some, `call_tool` returns a body of this many 'x' bytes.
        /// When None, the body is 4 bytes long.
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let body = "x".repeat(self.body_bytes.unwrap_or(4));
            Ok(CallToolResult::success(vec![Content::text(body)]))
        }
    }

    /// Build a context carrying only a tool name, with every optional
    /// field set to `None`.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext {
            tool_name: name.to_owned(),
            arguments: None,
            identity: None,
            role: None,
            sub: None,
            request_id: None,
        }
    }

    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        let inner = TestHandler {
            body_bytes: Some(8_192),
        };
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: Some(256),
            before: None,
            after: None,
        });
        let hooked = with_hooks(inner, hooks);

        // Sanity check: a tiny result serializes to less than the cap.
        let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
        assert!(serialized_size(&small) < 256);

        let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
        let size = serialized_size(&big);
        assert!(size > 256);

        // The accounted size must be that of the original result, not of
        // the structured error that replaces it.
        let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
        assert!(capped);
        assert_eq!(accounted, size);
        assert_eq!(replaced.is_error, Some(true));
        assert!(matches!(
            &replaced.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));

        // Compile-check that HookedHandler instantiates with the test inner.
        let _ = hooked;
    }

    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        // The counter records how many times the before-hook was invoked.
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let c = Arc::clone(&c);
            let name = ctx_ref.tool_name.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
                if name == "forbidden" {
                    HookOutcome::Deny(ErrorData::invalid_request("nope", None))
                } else {
                    HookOutcome::Continue
                }
            })
        });

        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let hooked = with_hooks(TestHandler::default(), hooks);

        // Invoke the stored hook directly rather than through `call_tool`.
        let bad_ctx = ctx("forbidden");
        let before_fn = hooked.hooks.before.as_ref().unwrap();
        let outcome = before_fn(&bad_ctx).await;
        assert!(matches!(outcome, HookOutcome::Deny(_)));
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let ok_ctx = ctx("allowed");
        let outcome2 = before_fn(&ok_ctx).await;
        assert!(matches!(outcome2, HookOutcome::Continue));
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        // The serialized error must name the tool and both byte counts.
        let r = too_large_result(100, 500, "my_tool");
        let body = serde_json::to_string(&r).unwrap();
        assert!(body.contains("result_too_large"));
        assert!(body.contains("my_tool"));
        assert!(body.contains("100"));
        assert!(body.contains("500"));
    }

    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        // Returning Replace from before-hook must yield the supplied
        // CallToolResult directly, with no need for the inner handler.
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )])))
            })
        });
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        // Exercise the before-hook closure + apply_size_cap helper directly,
        // matching the established test pattern in this module.
        let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
        let HookOutcome::Replace(boxed) = outcome else {
            panic!("expected HookOutcome::Replace");
        };
        // With no cap configured the payload must pass through untouched.
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        // A Replace payload that exceeds max_result_bytes must be rewritten
        // to result_too_large just like an inner-handler result would be,
        // and the disposition must reflect ResultTooLarge.
        let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
        let huge_size = serialized_size(&huge);
        assert!(huge_size > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_size);
        assert_eq!(final_result.is_error, Some(true));
        assert!(matches!(
            &final_result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        // spawn_after must enqueue the after-hook exactly one time per
        // invocation and never block the caller; we wait for the spawned
        // task to run by polling the counter with a short timeout.
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let c = Arc::clone(&c);
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        // Wait up to 1s for the spawned task to run.
        // NOTE(review): polling stops at the first increment, so a second,
        // later invocation would not be observed by the assertion below.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        // A panicking after-hook must not affect the request task.  We
        // spawn a panicking after-hook and then verify the current task
        // can still complete an unrelated future to completion.
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        // Give Tokio a chance to run + abort the panicking task, then
        // confirm we're still alive and the runtime is healthy.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
        assert_eq!(still_alive, 3);
    }
}