1use std::{fmt, future::Future, pin::Pin, sync::Arc};
40
41use rmcp::{
42 ErrorData, RoleServer, ServerHandler,
43 model::{
44 CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
45 InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
46 ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
47 ReadResourceResult, ServerInfo, Tool,
48 },
49 service::RequestContext,
50};
51
/// Snapshot of a single `call_tool` invocation, handed to the tool hooks.
///
/// `HookedHandler::call_tool` builds it from the incoming request and the
/// RBAC values current at call time. The type is `#[non_exhaustive]`, so code
/// outside this crate cannot use a struct literal and should start from
/// [`ToolCallContext::for_tool`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolCallContext {
    /// Name of the tool being invoked (taken from the request's `name`).
    pub tool_name: String,
    /// The request's arguments wrapped as a JSON object value, or `None`
    /// when the request carried no arguments.
    pub arguments: Option<serde_json::Value>,
    /// Caller identity as returned by `crate::rbac::current_identity()`.
    pub identity: Option<String>,
    /// Caller role as returned by `crate::rbac::current_role()`.
    pub role: Option<String>,
    /// Caller subject as returned by `crate::rbac::current_sub()`.
    pub sub: Option<String>,
    /// Request identifier; `call_tool` fills it with the `{:?}` (Debug)
    /// rendering of the request context's id.
    pub request_id: Option<String>,
}
69
70impl ToolCallContext {
71 #[must_use]
76 pub fn for_tool(tool_name: impl Into<String>) -> Self {
77 Self {
78 tool_name: tool_name.into(),
79 arguments: None,
80 identity: None,
81 role: None,
82 sub: None,
83 request_id: None,
84 }
85 }
86}
87
/// Decision returned by a [`BeforeHook`] for one tool call.
#[derive(Debug)]
#[non_exhaustive]
pub enum HookOutcome {
    /// Proceed: the inner handler's `call_tool` runs as usual.
    Continue,
    /// Skip the inner handler and return this error to the caller.
    Deny(ErrorData),
    /// Skip the inner handler and return this result instead. The replacement
    /// is still checked against `ToolHooks::max_result_bytes`. (Boxed,
    /// presumably to keep the enum small.)
    Replace(Box<CallToolResult>),
}
106
/// What happened to a tool call, as reported to the [`AfterHook`].
///
/// Besides being `Copy`, it derives `PartialEq`, `Eq` and `Hash`. After-hooks
/// can therefore compare dispositions with `==` and use them as map/set keys,
/// for example for per-outcome counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HookDisposition {
    /// The inner handler ran and returned a result within the size cap.
    InnerExecuted,
    /// The inner handler returned an error.
    InnerErrored,
    /// The before hook returned `HookOutcome::Deny`; the inner handler did not run.
    DeniedBefore,
    /// The before hook returned `HookOutcome::Replace` and the replacement
    /// fit within the size cap.
    ReplacedBefore,
    /// A result (from the inner handler or a replacement) exceeded
    /// `max_result_bytes` and was swapped for a `result_too_large` error.
    ResultTooLarge,
}
123
/// Async callback run before the inner handler on every `call_tool`.
///
/// It receives the call's [`ToolCallContext`] by reference and returns a boxed
/// future resolving to a [`HookOutcome`]. The `for<'a>` bound lets that future
/// borrow the context. The callback is stored behind an `Arc` and must be
/// `Send + Sync + 'static`.
pub type BeforeHook = Arc<
    dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
        + Send
        + Sync
        + 'static,
>;
136
/// Async callback run after each `call_tool` handled by `HookedHandler`.
///
/// Arguments, in order:
/// - the call's [`ToolCallContext`];
/// - the [`HookDisposition`] describing what happened;
/// - a byte count. For results, this is the serialized size *before* any
///   size-cap replacement. It is `0` when the call was denied or the inner
///   handler errored.
///
/// `HookedHandler` runs it on a spawned tokio task, so the response is sent
/// without waiting for it.
pub type AfterHook = Arc<
    dyn for<'a> Fn(
            &'a ToolCallContext,
            HookDisposition,
            usize,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
        + Send
        + Sync
        + 'static,
>;
154
/// Hook configuration that [`HookedHandler`] applies to every tool call.
#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct ToolHooks {
    /// Upper bound, in bytes, on the compact JSON serialization of a tool
    /// result. Larger results are replaced with a structured
    /// `result_too_large` error. `None` disables the cap.
    pub max_result_bytes: Option<usize>,
    /// Optional hook run before the inner handler; see [`BeforeHook`].
    pub before: Option<BeforeHook>,
    /// Optional hook run on a spawned task after each call; see [`AfterHook`].
    pub after: Option<AfterHook>,
}
173
174impl ToolHooks {
175 #[must_use]
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 #[must_use]
187 pub fn with_max_result_bytes(mut self, max: usize) -> Self {
188 self.max_result_bytes = Some(max);
189 self
190 }
191
192 #[must_use]
194 pub fn with_before(mut self, before: BeforeHook) -> Self {
195 self.before = Some(before);
196 self
197 }
198
199 #[must_use]
201 pub fn with_after(mut self, after: AfterHook) -> Self {
202 self.after = Some(after);
203 self
204 }
205}
206
207impl fmt::Debug for ToolHooks {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("ToolHooks")
210 .field("max_result_bytes", &self.max_result_bytes)
211 .field("before", &self.before.as_ref().map(|_| "<fn>"))
212 .field("after", &self.after.as_ref().map(|_| "<fn>"))
213 .finish()
214 }
215}
216
217#[derive(Clone)]
219pub struct HookedHandler<H: ServerHandler> {
220 inner: Arc<H>,
221 hooks: Arc<ToolHooks>,
222}
223
224impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 f.debug_struct("HookedHandler")
227 .field("hooks", &self.hooks)
228 .finish_non_exhaustive()
229 }
230}
231
232pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
240 HookedHandler {
241 inner: Arc::new(inner),
242 hooks,
243 }
244}
245
246impl<H: ServerHandler> HookedHandler<H> {
247 #[must_use]
249 pub fn inner(&self) -> &H {
250 &self.inner
251 }
252
253 fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
254 ToolCallContext {
255 tool_name: request.name.to_string(),
256 arguments: request.arguments.clone().map(serde_json::Value::Object),
257 identity: crate::rbac::current_identity(),
258 role: crate::rbac::current_role(),
259 sub: crate::rbac::current_sub(),
260 request_id: req_id,
261 }
262 }
263
264 fn spawn_after(
276 after: Option<&Arc<AfterHookHolder>>,
277 ctx: ToolCallContext,
278 disposition: HookDisposition,
279 size: usize,
280 ) {
281 if let Some(after) = after {
282 use tracing::Instrument;
283
284 let after = Arc::clone(after);
285 let span = tracing::Span::current();
288 let role = crate::rbac::current_role().unwrap_or_default();
292 let identity = crate::rbac::current_identity().unwrap_or_default();
293 let token = crate::rbac::current_token()
294 .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
295 let sub = crate::rbac::current_sub().unwrap_or_default();
296 tokio::spawn(
297 async move {
298 crate::rbac::with_rbac_scope(role, identity, token, sub, async move {
299 let fut = (after.f)(&ctx, disposition, size);
300 fut.await;
301 })
302 .await;
303 }
304 .instrument(span),
305 );
306 }
307 }
308}
309
/// Owned wrapper around an [`AfterHook`].
///
/// `call_tool` allocates one per call, and `spawn_after` clones its `Arc`
/// into the spawned task.
struct AfterHookHolder {
    /// The after hook to invoke.
    f: AfterHook,
}
316
317fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
319 let body = serde_json::json!({
320 "error": "result_too_large",
321 "message": format!(
322 "tool '{tool}' result of {actual} bytes exceeds the configured \
323 max_result_bytes={limit}; ask for a narrower query"
324 ),
325 "limit_bytes": limit,
326 "actual_bytes": actual,
327 });
328 let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
329 r.structured_content = None;
330 r
331}
332
333fn serialized_size(result: &CallToolResult) -> usize {
334 serde_json::to_vec(result).map_or(0, |v| v.len())
335}
336
337fn apply_size_cap(
341 result: CallToolResult,
342 max: Option<usize>,
343 tool: &str,
344) -> (CallToolResult, usize, bool) {
345 let size = serialized_size(&result);
346 if let Some(limit) = max
347 && size > limit
348 {
349 tracing::warn!(
350 tool = %tool,
351 size_bytes = size,
352 limit_bytes = limit,
353 "tool result exceeds max_result_bytes; replacing with structured error"
354 );
355 let replaced = too_large_result(limit, size, tool);
356 return (replaced, size, true);
357 }
358 (result, size, false)
359}
360
361impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
362 fn get_info(&self) -> ServerInfo {
363 self.inner.get_info()
364 }
365
366 async fn initialize(
367 &self,
368 request: InitializeRequestParams,
369 context: RequestContext<RoleServer>,
370 ) -> Result<InitializeResult, ErrorData> {
371 self.inner.initialize(request, context).await
372 }
373
374 async fn list_tools(
375 &self,
376 request: Option<PaginatedRequestParams>,
377 context: RequestContext<RoleServer>,
378 ) -> Result<ListToolsResult, ErrorData> {
379 self.inner.list_tools(request, context).await
380 }
381
382 fn get_tool(&self, name: &str) -> Option<Tool> {
383 self.inner.get_tool(name)
384 }
385
386 async fn list_prompts(
387 &self,
388 request: Option<PaginatedRequestParams>,
389 context: RequestContext<RoleServer>,
390 ) -> Result<ListPromptsResult, ErrorData> {
391 self.inner.list_prompts(request, context).await
392 }
393
394 async fn get_prompt(
395 &self,
396 request: GetPromptRequestParams,
397 context: RequestContext<RoleServer>,
398 ) -> Result<GetPromptResult, ErrorData> {
399 self.inner.get_prompt(request, context).await
400 }
401
402 async fn list_resources(
403 &self,
404 request: Option<PaginatedRequestParams>,
405 context: RequestContext<RoleServer>,
406 ) -> Result<ListResourcesResult, ErrorData> {
407 self.inner.list_resources(request, context).await
408 }
409
410 async fn list_resource_templates(
411 &self,
412 request: Option<PaginatedRequestParams>,
413 context: RequestContext<RoleServer>,
414 ) -> Result<ListResourceTemplatesResult, ErrorData> {
415 self.inner.list_resource_templates(request, context).await
416 }
417
418 async fn read_resource(
419 &self,
420 request: ReadResourceRequestParams,
421 context: RequestContext<RoleServer>,
422 ) -> Result<ReadResourceResult, ErrorData> {
423 self.inner.read_resource(request, context).await
424 }
425
426 async fn call_tool(
427 &self,
428 request: CallToolRequestParams,
429 context: RequestContext<RoleServer>,
430 ) -> Result<CallToolResult, ErrorData> {
431 let req_id = Some(format!("{:?}", context.id));
432 let ctx = Self::build_context(&request, req_id);
433 let max = self.hooks.max_result_bytes;
434 let after_holder = self
435 .hooks
436 .after
437 .as_ref()
438 .map(|f| Arc::new(AfterHookHolder { f: Arc::clone(f) }));
439
440 if let Some(before) = self.hooks.before.as_ref() {
442 let outcome = before(&ctx).await;
443 match outcome {
444 HookOutcome::Continue => {}
445 HookOutcome::Deny(err) => {
446 Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
447 return Err(err);
448 }
449 HookOutcome::Replace(boxed) => {
450 let (final_result, size, capped) = apply_size_cap(*boxed, max, &ctx.tool_name);
451 let disposition = if capped {
452 HookDisposition::ResultTooLarge
453 } else {
454 HookDisposition::ReplacedBefore
455 };
456 Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
457 return Ok(final_result);
458 }
459 }
460 }
461
462 let result = self.inner.call_tool(request, context).await;
464
465 match result {
466 Ok(ok) => {
467 let (final_result, size, capped) = apply_size_cap(ok, max, &ctx.tool_name);
468 let disposition = if capped {
469 HookDisposition::ResultTooLarge
470 } else {
471 HookDisposition::InnerExecuted
472 };
473 Self::spawn_after(after_holder.as_ref(), ctx, disposition, size);
474 Ok(final_result)
475 }
476 Err(e) => {
477 Self::spawn_after(after_holder.as_ref(), ctx, HookDisposition::InnerErrored, 0);
478 Err(e)
479 }
480 }
481 }
482}
483
// Unit tests for the hook wrapper. Most tests call the helpers (`apply_size_cap`,
// `too_large_result`, `spawn_after`) or the hook closures directly rather than
// driving `HookedHandler::call_tool`.
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use rmcp::{
        ErrorData, RoleServer, ServerHandler,
        model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
        service::RequestContext,
    };

    use super::*;

    /// Minimal inner handler: `call_tool` always succeeds with a text body of
    /// `body_bytes` `'x'` characters (4 when unset). All other `ServerHandler`
    /// methods use the trait defaults.
    #[derive(Clone, Default)]
    struct TestHandler {
        body_bytes: Option<usize>,
    }

    impl ServerHandler for TestHandler {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::default()
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            let body = "x".repeat(self.body_bytes.unwrap_or(4));
            Ok(CallToolResult::success(vec![Content::text(body)]))
        }
    }

    /// Builds a context for `name` with every optional field unset.
    fn ctx(name: &str) -> ToolCallContext {
        ToolCallContext {
            tool_name: name.to_owned(),
            arguments: None,
            identity: None,
            role: None,
            sub: None,
            request_id: None,
        }
    }

    // Checks that `apply_size_cap` swaps an 8 KiB result for the
    // `result_too_large` error while reporting the original size.
    // NOTE(review): `hooked` is constructed but never called, so the wrapper's
    // own `call_tool` path is not exercised here.
    #[tokio::test]
    async fn size_cap_replaces_oversized_result() {
        let inner = TestHandler {
            body_bytes: Some(8_192),
        };
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: Some(256),
            before: None,
            after: None,
        });
        let hooked = with_hooks(inner, hooks);

        let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
        assert!(serialized_size(&small) < 256);

        let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
        let size = serialized_size(&big);
        assert!(size > 256);

        let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
        assert!(capped);
        // The reported size is that of the original result, not the replacement.
        assert_eq!(accounted, size);
        assert_eq!(replaced.is_error, Some(true));
        assert!(matches!(
            &replaced.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));

        let _ = hooked;
    }

    // Calls the stored before hook directly: "forbidden" must be denied and
    // anything else allowed, and the hook must run once per invocation.
    #[tokio::test]
    async fn before_hook_deny_builds_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let before: BeforeHook = Arc::new(move |ctx_ref| {
            let c = Arc::clone(&c);
            // Copy the name out so the future need not borrow the context.
            let name = ctx_ref.tool_name.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
                if name == "forbidden" {
                    HookOutcome::Deny(ErrorData::invalid_request("nope", None))
                } else {
                    HookOutcome::Continue
                }
            })
        });

        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let hooked = with_hooks(TestHandler::default(), hooks);

        let bad_ctx = ctx("forbidden");
        let before_fn = hooked.hooks.before.as_ref().unwrap();
        let outcome = before_fn(&bad_ctx).await;
        assert!(matches!(outcome, HookOutcome::Deny(_)));
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let ok_ctx = ctx("allowed");
        let outcome2 = before_fn(&ok_ctx).await;
        assert!(matches!(outcome2, HookOutcome::Continue));
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    // The error body must name the tool and both byte counts.
    #[test]
    fn too_large_result_mentions_limit_and_actual() {
        let r = too_large_result(100, 500, "my_tool");
        let body = serde_json::to_string(&r).unwrap();
        assert!(body.contains("result_too_large"));
        assert!(body.contains("my_tool"));
        assert!(body.contains("100"));
        assert!(body.contains("500"));
    }

    // A `Replace` outcome's payload passes through an uncapped `apply_size_cap`
    // untouched.
    #[tokio::test]
    async fn replace_outcome_skips_inner_and_returns_payload() {
        let before: BeforeHook = Arc::new(|_ctx| {
            Box::pin(async {
                HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
                    "from-replace".to_owned(),
                )])))
            })
        });
        let hooks = Arc::new(ToolHooks {
            max_result_bytes: None,
            before: Some(before),
            after: None,
        });
        let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));

        let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
        let HookOutcome::Replace(boxed) = outcome else {
            panic!("expected HookOutcome::Replace");
        };
        let (result, size, capped) = apply_size_cap(*boxed, None, "any");
        assert!(!capped);
        assert!(size > 0);
        assert!(!result.is_error.unwrap_or(false));
        assert!(matches!(
            &result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text == "from-replace"
        ));
    }

    // A replacement result is capped just like an inner result would be.
    #[tokio::test]
    async fn replace_outcome_subject_to_size_cap() {
        let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
        let huge_size = serialized_size(&huge);
        assert!(huge_size > 256);

        let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
        assert!(capped);
        assert_eq!(accounted, huge_size);
        assert_eq!(final_result.is_error, Some(true));
        assert!(matches!(
            &final_result.content[0].raw,
            rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
        ));
    }

    // `spawn_after` runs the hook on a spawned task; poll with a 1 s deadline
    // until it has run, then check it ran exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_fires_exactly_once_via_spawn() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
            let c = Arc::clone(&c);
            Box::pin(async move {
                c.fetch_add(1, Ordering::Relaxed);
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("t"),
            HookDisposition::InnerExecuted,
            42,
        );

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // A panicking after hook must stay confined to its spawned task: the
    // calling task keeps running and the runtime can still execute new tasks.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn after_hook_panic_is_isolated_from_response_path() {
        let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
            Box::pin(async {
                panic!("intentional panic in after-hook");
            })
        });
        let holder = Arc::new(AfterHookHolder { f: after });

        HookedHandler::<TestHandler>::spawn_after(
            Some(&holder),
            ctx("boom"),
            HookDisposition::InnerExecuted,
            0,
        );

        // Give the spawned task time to run (and panic).
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
        assert_eq!(still_alive, 3);
    }
}