1use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42 ErrorData, RoleServer, ServerHandler,
43 model::{
44 CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45 InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46 ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47 ReadResourceResult, ServerInfo, Tool,
48 },
49 service::RequestContext,
50};
51
52#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub struct ToolCallContext {
56 pub tool_name: String,
58 pub arguments: Option<serde_json::Value>,
60 pub identity: Option<String>,
62 pub role: Option<String>,
64 pub sub: Option<String>,
66 pub request_id: Option<String>,
68}
69
70impl ToolCallContext {
71 #[must_use]
76 pub fn for_tool(tool_name: impl Into<String>) -> Self {
77 Self {
78 tool_name: tool_name.into(),
79 arguments: None,
80 identity: None,
81 role: None,
82 sub: None,
83 request_id: None,
84 }
85 }
86}
87
88#[derive(Debug)]
97#[non_exhaustive]
98pub enum HookOutcome {
99 Continue,
101 Deny(ErrorData),
103 Replace(Box<CallToolResult>),
105}
106
/// How a tool call ended, as reported to the [`AfterHook`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so hooks can compare dispositions directly
/// (`disposition == HookDisposition::DeniedBefore`) or use them as map/set keys,
/// instead of being limited to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The wrapped handler returned `Ok` and the result fit within the size cap.
    InnerExecuted,
    /// The wrapped handler returned `Err`. The reported size is 0.
    InnerErrored,
    /// The before-hook denied the call. The reported size is 0.
    DeniedBefore,
    /// The before-hook supplied a replacement result that fit within the size cap.
    ReplacedBefore,
    /// A result exceeded `max_result_bytes` and was swapped for a structured
    /// error. This covers both wrapped-handler results and before-hook
    /// replacements. The reported size is that of the oversized original.
    ResultTooLarge,
}
123
124pub type BeforeHook = Arc<
131 dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
132 + Send
133 + Sync
134 + 'static,
135>;
136
137pub type AfterHook = Arc<
145 dyn for<'a> Fn(
146 &'a ToolCallContext,
147 HookDisposition,
148 usize,
149 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
150 + Send
151 + Sync
152 + 'static,
153>;
154
155#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
157#[derive(Clone, Default)]
158#[non_exhaustive]
159pub struct ToolHooks {
160 pub max_result_bytes: Option<usize>,
165 pub before: Option<BeforeHook>,
168 pub after: Option<AfterHook>,
172}
173
174impl ToolHooks {
175 #[must_use]
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 #[must_use]
187 pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188 self.max_result_bytes = Some(max);
189 self
190 }
191
192 #[must_use]
194 pub fn with_before(mut self, before: BeforeHook) -> Self {
195 self.before = Some(before);
196 self
197 }
198
199 #[must_use]
201 pub fn with_after(mut self, after: AfterHook) -> Self {
202 self.after = Some(after);
203 self
204 }
205}
206
207impl fmt::Debug for ToolHooks {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("ToolHooks")
210 .field("max_result_bytes", &self.max_result_bytes)
211 .field("before", &self.before.as_ref().map(|_| "<fn>"))
212 .field("after", &self.after.as_ref().map(|_| "<fn>"))
213 .finish()
214 }
215}
216
217#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220 inner: Arc<H>,
221 hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 f.debug_struct("HookedHandler")
227 .field("hooks", &self.hooks)
228 .finish_non_exhaustive()
229 }
230}
231
232pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
234 HookedHandler {
235 inner: Arc::new(inner),
236 hooks,
237 }
238}
239
240impl<H: ServerHandler> HookedHandler<H> {
241 #[must_use]
243 pub fn inner(&self) -> &H {
244 &self.inner
245 }
246
247 fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
248 ToolCallContext {
249 tool_name: request.name.to_string(),
250 arguments: request.arguments.clone().map(serde_json::Value::Object),
251 identity: crate::rbac::current_identity(),
252 role: crate::rbac::current_role(),
253 sub: crate::rbac::current_sub(),
254 request_id: req_id,
255 }
256 }
257
258 fn spawn_after(
270 after: Option<&Arc<AfterHookHolder>>,
271 ctx: ToolCallContext,
272 disposition: HookDisposition,
273 size: usize,
274 ) {
275 if let Some(after) = after {
276 use tracing::Instrument;
277
278 let after = Arc::clone(after);
279 let span = tracing::Span::current();
282 let role = crate::rbac::current_role().unwrap_or_default();
286 let identity = crate::rbac::current_identity().unwrap_or_default();
287 let token = crate::rbac::current_token()
288 .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
289 let sub = crate::rbac::current_sub().unwrap_or_default();
290 tokio::spawn(
291 async move {
292 crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
293 let fut = (after.f)(&ctx, disposition, size);
294 fut.await;
295 })
296 .await;
297 }
298 .instrument(span),
299 );
300 }
301 }
302}
303
304struct AfterHookHolder {
308 f: AfterHook,
309}
310
311fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
313 let body = serde_json::json!({
314 "error": "result_too_large",
315 "message": format!(
316 "tool '{tool}' result of {actual} bytes exceeds the configured \
317 max_result_bytes={limit}; ask for a narrower query"
318 ),
319 "limit_bytes": limit,
320 "actual_bytes": actual,
321 });
322 let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
323 r.structured_content = None;
324 r
325}
326
327fn serialized_size(result: &CallToolResult) -> usize {
328 serde_json::to_vec(result).map_or(0, |v| v.len())
329}
330
331fn apply_size_cap(
335 result: CallToolResult,
336 max: Option<usize>,
337 tool: &str,
338) -> (CallToolResult, usize, bool) {
339 let size = serialized_size(&result);
340 if let Some(limit) = max
341 && size > limit
342 {
343 tracing::warn!(
344 tool = %tool,
345 size_bytes = size,
346 limit_bytes = limit,
347 "tool result exceeds max_result_bytes; replacing with structured error"
348 );
349 let replaced = too_large_result(limit, size, tool);
350 return (replaced, size, true);
351 }
352 (result, size, false)
353}
354
355impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
356 fn get_info(&self) -> ServerInfo {
357 self.inner.get_info()
358 }
359
360 async fn initialize(
361 &self,
362 request: InitializeRequestParams,
363 context: RequestContext<RoleServer>,
364 ) -> Result<InitializeResult, ErrorData> {
365 self.inner.initialize(request, context).await
366 }
367
368 async fn list_tools(
369 &self,
370 request: Option<PaginatedRequestParams>,
371 context: RequestContext<RoleServer>,
372 ) -> Result<ListToolsResult, ErrorData> {
373 self.inner.list_tools(request, context).await
374 }
375
376 fn get_tool(&self, name: &str) -> Option<Tool> {
377 self.inner.get_tool(name)
378 }
379
380 async fn list_prompts(
381 &self,
382 request: Option<PaginatedRequestParams>,
383 context: RequestContext<RoleServer>,
384 ) -> Result<ListPromptsResult, ErrorData> {
385 self.inner.list_prompts(request, context).await
386 }
387
388 async fn get_prompt(
389 &self,
390 request: GetPromptRequestParams,
391 context: RequestContext<RoleServer>,
392 ) -> Result<GetPromptResult, ErrorData> {
393 self.inner.get_prompt(request, context).await
394 }
395
396 async fn list_resources(
397 &self,
398 request: Option<PaginatedRequestParams>,
399 context: RequestContext<RoleServer>,
400 ) -> Result<ListResourcesResult, ErrorData> {
401 self.inner.list_resources(request, context).await
402 }
403
404 async fn list_resource_templates(
405 &self,
406 request: Option<PaginatedRequestParams>,
407 context: RequestContext<RoleServer>,
408 ) -> Result<ListResourceTemplatesResult, ErrorData> {
409 self.inner.list_resource_templates(request, context).await
410 }
411
412 async fn read_resource(
413 &self,
414 request: ReadResourceRequestParams,
415 context: RequestContext<RoleServer>,
416 ) -> Result<ReadResourceResult, ErrorData> {
417 self.inner.read_resource(request, context).await
418 }
419
420 async fn call_tool(
421 &self,
422 request: CallToolRequestParams,
423 context: RequestContext<RoleServer>,
424 ) -> Result<CallToolResult, ErrorData> {
425 let req_id = Some(format!("{:?}", context.id));
426 let ctx = Self::build_context(&request, req_id);
427 let max = self.hooks.max_result_bytes;
428 let after_holder = self
429 .hooks
430 .after
431 .as_ref()
432 .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));
433
434 if let Some(before) = self.hooks.before.as_ref() {
436 let outcome = before(&ctx).await;
437 match outcome {
438 HookOutcome::Continue => {}
439 HookOutcome::Deny(err) => {
440 Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
441 return Err(err);
442 }
443 HookOutcome::Replace(boxed) => {
444 let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
445 let disposition = if capped {
446 HookDisposition::ResultTooLarge
447 } else {
448 HookDisposition::ReplacedBefore
449 };
450 Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
451 return Ok(final_result);
452 }
453 }
454 }
455
456 let result = self.inner.call_tool(request, context).await;
458
459 match result {
460 Ok(ok) => {
461 let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
462 let disposition = if capped {
463 HookDisposition::ResultTooLarge
464 } else {
465 HookDisposition::InnerExecuted
466 };
467 Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
468 Ok(final_result)
469 }
470 Err(e) => {
471 Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
472 Err(e)
473 }
474 }
475 }
476}
477
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    // NOTE(review): these tests exercise the helpers (`apply_size_cap`, the hook
    // closures, `spawn_after`) directly rather than going through
    // `HookedHandler::call_tool`. Presumably this is because a `RequestContext` is
    // awkward to build in a unit test; confirm.

    /// Minimal wrapped handler. Its `call_tool` always succeeds with one text block
    /// made of `body_bytes` `'x'` characters (4 when unset).
    #[derive(Clone, Default)]
    struct TestHandler {
        // Length of the text body returned by `call_tool`; `None` means 4.
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let body = "x".repeat(self.body_bytes.unwrap_or(4));
            Ok(CallToolResult::success(vec![Content::text(body)]))
        }
    }

    /// `ToolCallContext` for `name` with every optional field unset.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext {
            tool_name: name.to_owned(),
            arguments: None,
            identity: None,
            role: None,
            sub: None,
            request_id: None,
        }
    }

    /// Results over the cap are swapped for a `result_too_large` error. The reported
    /// size is the original (pre-replacement) size.
    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        let inner = TestHandler {
            body_bytes: Some(8_192),
        };
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: Some(256),
            before: None,
            after: None,
        });
        let hooked = with_hooks(inner, hooks);

        // Sanity-check the fixtures sit on either side of the 256-byte cap.
        let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
        assert!(serialized_size(&small) < 256);

        let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
        let size = serialized_size(&big);
        assert!(size > 256);

        let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
        assert!(capped);
        assert_eq!(accounted, size);
        assert_eq!(replaced.is_error, Some(true));
        assert!(matches!(
            &replaced.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));

        // `hooked` is only constructed, never called; this just checks the wrapper
        // builds with a capped configuration.
        let _ = hooked;
    }

    /// The before-hook closure denies the "forbidden" tool and lets others
    /// continue. It is invoked once per call (tracked by `counter`).
    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let c = Arc::clone(&c);
            let name = ctx_ref.tool_name.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
                if name == "forbidden" {
                    HookOutcome::Deny(ErrorData::invalid_request("nope", None))
                } else {
                    HookOutcome::Continue
                }
            })
        });

        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let hooked = with_hooks(TestHandler::default(), hooks);

        // Call the stored hook directly (via the wrapper's `hooks` field).
        let bad_ctx = ctx("forbidden");
        let before_fn = hooked.hooks.before.as_ref().unwrap();
        let outcome = before_fn(&bad_ctx).await;
        assert!(matches!(outcome, HookOutcome::Deny(_)));
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let ok_ctx = ctx("allowed");
        let outcome2 = before_fn(&ok_ctx).await;
        assert!(matches!(outcome2, HookOutcome::Continue));
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    /// The structured error payload names the tool, the limit, and the actual size.
    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        let r = too_large_result(100, 500, "my_tool");
        let body = serde_json::to_string(&r).unwrap();
        assert!(body.contains("result_too_large"));
        assert!(body.contains("my_tool"));
        assert!(body.contains("100"));
        assert!(body.contains("500"));
    }

    /// A `Replace` outcome carries the hook's payload, which passes through an
    /// uncapped `apply_size_cap` unchanged.
    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )])))
            })
        });
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        // Invoke the hook directly and feed its replacement through the cap helper,
        // mirroring what `call_tool` does for `HookOutcome::Replace`.
        let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
        let HookOutcome::Replace(boxed) = outcome else {
            panic!("expected HookOutcome::Replace");
        };
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    /// An oversized replacement is capped just like an oversized inner result.
    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
        let huge_size = serialized_size(&huge);
        assert!(huge_size > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_size);
        assert_eq!(final_result.is_error, Some(true));
        assert!(matches!(
            &final_result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));
    }

    /// `spawn_after` runs the after-hook on a detached task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let c = Arc::clone(&c);
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        // Poll (up to 1s) until the detached task has run.
        // NOTE(review): the assertion runs as soon as the counter turns non-zero, so
        // a later duplicate invocation would not be caught.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    /// A panicking after-hook is confined to its spawned task; the runtime keeps
    /// serving other tasks.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        // Give the panicking task time to run, then prove the runtime still
        // schedules and completes new work.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
        assert_eq!(still_alive, 3);
    }
}