1use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42 ErrorData, RoleServer, ServerHandler,
43 model::{
44 CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45 InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46 ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47 ReadResourceResult, ServerInfo, Tool,
48 },
49 service::RequestContext,
50};
51
52#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub struct ToolCallContext {
56 pub tool_name: String,
58 pub arguments: Option<serde_json::Value>,
60 pub identity: Option<String>,
62 pub role: Option<String>,
64 pub sub: Option<String>,
66 pub request_id: Option<String>,
68}
69
70impl ToolCallContext {
71 #[must_use]
76 pub fn for_tool(tool_name: impl Into<String>) -> Self {
77 Self {
78 tool_name: tool_name.into(),
79 arguments: None,
80 identity: None,
81 role: None,
82 sub: None,
83 request_id: None,
84 }
85 }
86}
87
88#[derive(Debug)]
97#[non_exhaustive]
98pub enum HookOutcome {
99 Continue,
101 Deny(ErrorData),
103 Replace(Box<CallToolResult>),
105}
106
/// How a tool call ended, as reported to the [`AfterHook`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so after-hooks can compare dispositions,
/// assert on them with `assert_eq!`, or use them as map keys (for example,
/// per-disposition counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The wrapped handler ran and returned `Ok` within the size cap.
    InnerExecuted,
    /// The wrapped handler ran and returned `Err`.
    InnerErrored,
    /// The before-hook returned `HookOutcome::Deny`. The wrapped handler did
    /// not run.
    DeniedBefore,
    /// The before-hook returned `HookOutcome::Replace` and the replacement fit
    /// within the size cap. The wrapped handler did not run.
    ReplacedBefore,
    /// A result exceeded `max_result_bytes` and was replaced by a structured
    /// `result_too_large` error. The result may have come from the wrapped
    /// handler or from a `Replace`.
    ResultTooLarge,
}
123
124pub type BeforeHook = Arc<
131 dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
132 + Send
133 + Sync
134 + 'static,
135>;
136
137pub type AfterHook = Arc<
145 dyn for<'a> Fn(
146 &'a ToolCallContext,
147 HookDisposition,
148 usize,
149 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
150 + Send
151 + Sync
152 + 'static,
153>;
154
155#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
157#[derive(Clone, Default)]
158#[non_exhaustive]
159pub struct ToolHooks {
160 pub max_result_bytes: Option<usize>,
165 pub before: Option<BeforeHook>,
168 pub after: Option<AfterHook>,
172}
173
174impl ToolHooks {
175 #[must_use]
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 #[must_use]
187 pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188 self.max_result_bytes = Some(max);
189 self
190 }
191
192 #[must_use]
194 pub fn with_before(mut self, before: BeforeHook) -> Self {
195 self.before = Some(before);
196 self
197 }
198
199 #[must_use]
201 pub fn with_after(mut self, after: AfterHook) -> Self {
202 self.after = Some(after);
203 self
204 }
205}
206
207impl fmt::Debug for ToolHooks {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("ToolHooks")
210 .field("max_result_bytes", &self.max_result_bytes)
211 .field("before", &self.before.as_ref().map(|_| "<fn>"))
212 .field("after", &self.after.as_ref().map(|_| "<fn>"))
213 .finish()
214 }
215}
216
217#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220 inner: Arc<H>,
221 hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 f.debug_struct("HookedHandler")
227 .field("hooks", &self.hooks)
228 .finish_non_exhaustive()
229 }
230}
231
232pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
246 HookedHandler {
247 inner: Arc::new(inner),
248 hooks,
249 }
250}
251
252impl<H: ServerHandler> HookedHandler<H> {
253 #[must_use]
255 pub fn inner(&self) -> &H {
256 &self.inner
257 }
258
259 fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
260 ToolCallContext {
261 tool_name: request.name.to_string(),
262 arguments: request.arguments.clone().map(serde_json::Value::Object),
263 identity: crate::rbac::current_identity(),
264 role: crate::rbac::current_role(),
265 sub: crate::rbac::current_sub(),
266 request_id: req_id,
267 }
268 }
269
270 fn spawn_after(
282 after: Option<&Arc<AfterHookHolder>>,
283 ctx: ToolCallContext,
284 disposition: HookDisposition,
285 size: usize,
286 ) {
287 if let Some(after) = after {
288 use tracing::Instrument;
289
290 let after = Arc::clone(after);
291 let span = tracing::Span::current();
294 let role = crate::rbac::current_role().unwrap_or_default();
298 let identity = crate::rbac::current_identity().unwrap_or_default();
299 let token = crate::rbac::current_token()
300 .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
301 let sub = crate::rbac::current_sub().unwrap_or_default();
302 tokio::spawn(
303 async move {
304 crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
305 let fut = (after.f)(&ctx, disposition, size);
306 fut.await;
307 })
308 .await;
309 }
310 .instrument(span),
311 );
312 }
313 }
314}
315
316struct AfterHookHolder {
320 f: AfterHook,
321}
322
323fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
325 let body = serde_json::json!({
326 "error": "result_too_large",
327 "message": format!(
328 "tool '{tool}' result of {actual} bytes exceeds the configured \
329 max_result_bytes={limit}; ask for a narrower query"
330 ),
331 "limit_bytes": limit,
332 "actual_bytes": actual,
333 });
334 let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
335 r.structured_content = None;
336 r
337}
338
339fn serialized_size(result: &CallToolResult) -> usize {
340 serde_json::to_vec(result).map_or(0, |v| v.len())
341}
342
343fn apply_size_cap(
347 result: CallToolResult,
348 max: Option<usize>,
349 tool: &str,
350) -> (CallToolResult, usize, bool) {
351 let size = serialized_size(&result);
352 if let Some(limit) = max
353 && size > limit
354 {
355 tracing::warn!(
356 tool = %tool,
357 size_bytes = size,
358 limit_bytes = limit,
359 "tool result exceeds max_result_bytes; replacing with structured error"
360 );
361 let replaced = too_large_result(limit, size, tool);
362 return (replaced, size, true);
363 }
364 (result, size, false)
365}
366
/// Forwards the `ServerHandler` methods below to the wrapped handler. Only
/// `call_tool` is intercepted, to run the configured [`ToolHooks`].
///
/// NOTE(review): only the methods overridden in this impl are forwarded. Any
/// other `ServerHandler` method the rmcp trait defines in the version in use
/// falls back to the trait's default implementation, not the inner handler's.
/// Examples may include completion, log-level, subscription, and notification
/// callbacks. Confirm this is intended before wrapping a handler that
/// implements them.
impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
    fn get_info(&self) -> ServerInfo {
        self.inner.get_info()
    }

    async fn initialize(
        &self,
        request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, ErrorData> {
        self.inner.initialize(request, context).await
    }

    async fn list_tools(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListToolsResult, ErrorData> {
        self.inner.list_tools(request, context).await
    }

    fn get_tool(&self, name: &str) -> Option<Tool> {
        self.inner.get_tool(name)
    }

    async fn list_prompts(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, ErrorData> {
        self.inner.list_prompts(request, context).await
    }

    async fn get_prompt(
        &self,
        request: GetPromptRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, ErrorData> {
        self.inner.get_prompt(request, context).await
    }

    async fn list_resources(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, ErrorData> {
        self.inner.list_resources(request, context).await
    }

    async fn list_resource_templates(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, ErrorData> {
        self.inner.list_resource_templates(request, context).await
    }

    async fn read_resource(
        &self,
        request: ReadResourceRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, ErrorData> {
        self.inner.read_resource(request, context).await
    }

    /// Runs a tool call through the hook pipeline:
    /// 1. Await the before-hook, if set.
    ///    - `Deny` returns the error.
    ///    - `Replace` returns the (size-capped) replacement.
    ///    - `Continue` falls through.
    /// 2. Otherwise, call the inner handler and apply the size cap to `Ok`
    ///    results.
    /// 3. In every path, spawn the after-hook, if set, before returning.
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, ErrorData> {
        // The request id is rendered with `{:?}`.
        // NOTE(review): the exact text depends on the id type's Debug impl.
        // Confirm this is the format hooks expect.
        let req_id = Some(format!("{:?}", context.id));
        // The context (including a clone of the arguments and the RBAC
        // values) is built even when no hooks are configured. Its
        // `tool_name` is also used for the size-cap log and error text.
        let ctx = Self::build_context(&request, req_id);
        let max = self.hooks.max_result_bytes;
        // Wrap the after-hook once per call so each `spawn_after` path can
        // clone it into its task.
        let after_holder = self
            .hooks
            .after
            .as_ref()
            .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));

        if let Some(before) = self.hooks.before.as_ref() {
            // Awaited inline: a slow before-hook delays the response.
            let outcome = before(&ctx).await;
            match outcome {
                HookOutcome::Continue => {}
                HookOutcome::Deny(err) => {
                    // Denied calls report size 0; the inner handler never runs.
                    Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
                    return Err(err);
                }
                HookOutcome::Replace(boxed) => {
                    // Replacements are subject to the same size cap as real results.
                    let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
                    let disposition = if capped {
                        HookDisposition::ResultTooLarge
                    } else {
                        HookDisposition::ReplacedBefore
                    };
                    Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
                    return Ok(final_result);
                }
            }
        }

        let result = self.inner.call_tool(request, context).await;

        match result {
            Ok(ok) => {
                let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
                let disposition = if capped {
                    HookDisposition::ResultTooLarge
                } else {
                    HookDisposition::InnerExecuted
                };
                Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
                Ok(final_result)
            }
            Err(e) => {
                // Errors from the inner handler are passed through untouched
                // and reported with size 0.
                Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
                Err(e)
            }
        }
    }
}
489
// Unit tests for the hook plumbing.
//
// NOTE(review): none of these tests drive `HookedHandler::call_tool` end to
// end. They exercise `apply_size_cap`, `too_large_result`, the hook closures,
// and `spawn_after` directly. This is presumably because building a
// `RequestContext<RoleServer>` needs a live peer. TODO confirm, and add a
// `call_tool` integration test if one can be constructed.
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    // Minimal inner handler. Its `call_tool` returns one text item made of
    // `body_bytes` copies of 'x' (4 when unset).
    #[derive(Clone, Default)]
    struct TestHandler {
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let body = "x".repeat(self.body_bytes.unwrap_or(4));
            Ok(CallToolResult::success(vec![Content::text(body)]))
        }
    }

    // Context with only `tool_name` set.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext {
            tool_name: name.to_owned(),
            arguments: None,
            identity: None,
            role: None,
            sub: None,
            request_id: None,
        }
    }

    // An over-cap result is replaced by an `is_error` result whose text
    // mentions `result_too_large`. The reported size is the original size.
    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        let inner = TestHandler {
            body_bytes: Some(8_192),
        };
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: Some(256),
            before: None,
            after: None,
        });
        let hooked = with_hooks(inner, hooks);

        let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
        assert!(serialized_size(&small) < 256);

        let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
        let size = serialized_size(&big);
        assert!(size > 256);

        let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
        assert!(capped);
        assert_eq!(accounted, size);
        assert_eq!(replaced.is_error, Some(true));
        assert!(matches!(
            &replaced.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));

        // `hooked` is only constructed; its `call_tool` is not invoked here.
        let _ = hooked;
    }

    // The before-hook closure is called directly, not via `call_tool`. It
    // denies "forbidden", continues otherwise, and runs once per invocation.
    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let c = Arc::clone(&c);
            let name = ctx_ref.tool_name.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
                if name == "forbidden" {
                    HookOutcome::Deny(ErrorData::invalid_request("nope", None))
                } else {
                    HookOutcome::Continue
                }
            })
        });

        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let hooked = with_hooks(TestHandler::default(), hooks);

        let bad_ctx = ctx("forbidden");
        let before_fn = hooked.hooks.before.as_ref().unwrap();
        let outcome = before_fn(&bad_ctx).await;
        assert!(matches!(outcome, HookOutcome::Deny(_)));
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let ok_ctx = ctx("allowed");
        let outcome2 = before_fn(&ok_ctx).await;
        assert!(matches!(outcome2, HookOutcome::Continue));
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    // The serialized error payload names the tool, the limit, and the actual size.
    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        let r = too_large_result(100, 500, "my_tool");
        let body = serde_json::to_string(&r).unwrap();
        assert!(body.contains("result_too_large"));
        assert!(body.contains("my_tool"));
        assert!(body.contains("100"));
        assert!(body.contains("500"));
    }

    // The hook is invoked directly and its `Replace` payload is passed through
    // `apply_size_cap` with no limit. The inner handler is not involved.
    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )])))
            })
        });
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
        let HookOutcome::Replace(boxed) = outcome else {
            panic!("expected HookOutcome::Replace");
        };
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    // An oversized replacement payload is capped exactly like an inner result.
    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
        let huge_size = serialized_size(&huge);
        assert!(huge_size > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_size);
        assert_eq!(final_result.is_error, Some(true));
        assert!(matches!(
            &final_result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));
    }

    // `spawn_after` runs the hook on a spawned task. The test polls for up to
    // 1s.
    // NOTE(review): the loop stops as soon as the counter is non-zero, so a
    // late second invocation would not necessarily be caught.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let c = Arc::clone(&c);
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // A panicking after-hook is confined to its spawned task. The test's own
    // task and the runtime keep working afterwards.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
        assert_eq!(still_alive, 3);
    }
}