use std::{fmt, future::Future, pin::Pin, sync::Arc};
use rmcp::{
ErrorData, RoleServer, ServerHandler,
model::{
CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams, GetPromptResult,
InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
ListResourcesResult, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams,
ReadResourceResult, ServerInfo, Tool,
},
service::RequestContext,
};
/// Snapshot of a single tool invocation that is handed to the before/after hooks.
///
/// `HookedHandler` fills this in once per `call_tool` request, before any hook runs.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolCallContext {
/// Name of the tool being invoked (copied from the request's `name`).
pub tool_name: String,
/// Call arguments wrapped as a JSON object value, or `None` when the request carried none.
pub arguments: Option<serde_json::Value>,
/// Value of `crate::rbac::current_identity()` at the time the context was built.
pub identity: Option<String>,
/// Value of `crate::rbac::current_role()` at the time the context was built.
pub role: Option<String>,
/// Value of `crate::rbac::current_sub()` at the time the context was built.
pub sub: Option<String>,
/// `Debug` rendering of the request id, when the context was built from a live request.
pub request_id: Option<String>,
}
impl ToolCallContext {
    /// Creates a context that carries only a tool name.
    ///
    /// Every optional field (`arguments`, `identity`, `role`, `sub`, `request_id`) starts
    /// out as `None`; callers fill in whatever they need afterwards.
    #[must_use]
    pub fn for_tool(tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        Self {
            tool_name,
            arguments: None,
            identity: None,
            role: None,
            sub: None,
            request_id: None,
        }
    }
}
/// Decision returned by a [`BeforeHook`], telling `HookedHandler::call_tool` how to proceed.
#[derive(Debug)]
#[non_exhaustive]
pub enum HookOutcome {
/// Proceed to the wrapped handler's `call_tool`.
Continue,
/// Skip the wrapped handler and return this error to the caller.
Deny(ErrorData),
/// Skip the wrapped handler and return this result instead; it is still subject to the
/// `max_result_bytes` cap.
Replace(Box<CallToolResult>),
}
/// How a tool call ended, as reported to the [`AfterHook`].
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum HookDisposition {
/// The wrapped handler returned `Ok` and the result was within the size cap.
InnerExecuted,
/// The wrapped handler returned `Err`.
InnerErrored,
/// The before-hook returned `HookOutcome::Deny`; the wrapped handler was not called.
DeniedBefore,
/// The before-hook returned `HookOutcome::Replace` and the replacement was within the cap.
ReplacedBefore,
/// The result (from the wrapped handler or from a replacement) exceeded `max_result_bytes`
/// and was swapped for a structured error.
ResultTooLarge,
}
/// Shared async callback that runs before the wrapped handler's `call_tool`.
///
/// It receives the call's [`ToolCallContext`] by reference and resolves to a
/// [`HookOutcome`]. The returned future may borrow from the context and must be `Send`.
pub type BeforeHook = Arc<
dyn for<'a> Fn(&'a ToolCallContext) -> Pin<Box<dyn Future<Output = HookOutcome> + Send + 'a>>
+ Send
+ Sync
+ 'static,
>;
/// Shared async callback that runs after a tool call has been resolved.
///
/// Arguments, in order: the call's [`ToolCallContext`], the [`HookDisposition`] describing
/// how the call ended, and a `usize` size. `HookedHandler::call_tool` passes the serialized
/// size in bytes of the result *before* any size-cap replacement, or `0` when the call was
/// denied or the wrapped handler returned an error.
pub type AfterHook = Arc<
dyn for<'a> Fn(
&'a ToolCallContext,
HookDisposition,
usize,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
+ Send
+ Sync
+ 'static,
>;
/// Optional hooks and limits applied around every tool call of a `HookedHandler`.
///
/// The `Default` value has no cap and no hooks.
#[allow(clippy::struct_field_names, reason = "before/after read naturally")]
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct ToolHooks {
/// Upper bound, in bytes of serialized JSON, for a tool result. A result strictly larger
/// than this is replaced with a structured `result_too_large` error. `None` disables the cap.
pub max_result_bytes: Option<usize>,
/// Callback awaited before the wrapped handler runs; may deny or replace the call.
pub before: Option<BeforeHook>,
/// Callback spawned on a separate task once the call's outcome is known.
pub after: Option<AfterHook>,
}
impl ToolHooks {
    /// Creates an empty configuration: no size cap, no before-hook, no after-hook.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with the result-size cap set to `max` bytes.
    #[must_use]
    pub fn with_max_result_bytes(self, max: usize) -> Self {
        Self {
            max_result_bytes: Some(max),
            ..self
        }
    }

    /// Returns the configuration with `before` installed as the before-hook.
    #[must_use]
    pub fn with_before(self, before: BeforeHook) -> Self {
        Self {
            before: Some(before),
            ..self
        }
    }

    /// Returns the configuration with `after` installed as the after-hook.
    #[must_use]
    pub fn with_after(self, after: AfterHook) -> Self {
        Self {
            after: Some(after),
            ..self
        }
    }
}
impl fmt::Debug for ToolHooks {
    /// Formats the configuration; closures cannot be printed, so an installed hook is shown
    /// as `Some("<fn>")` and a missing one as `None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let before = self.before.is_some().then_some("<fn>");
        let after = self.after.is_some().then_some("<fn>");
        f.debug_struct("ToolHooks")
            .field("max_result_bytes", &self.max_result_bytes)
            .field("before", &before)
            .field("after", &after)
            .finish()
    }
}
/// A `ServerHandler` wrapper that runs [`ToolHooks`] around the inner handler's tool calls.
///
/// Both fields are reference-counted, so cloning the wrapper is cheap and every clone
/// shares the same inner handler and hook configuration.
pub struct HookedHandler<H: ServerHandler> {
    /// The wrapped handler that requests are delegated to.
    inner: Arc<H>,
    /// Hook configuration shared by all clones of this wrapper.
    hooks: Arc<ToolHooks>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive adds an implicit `H: Clone`
// bound, which is unnecessary here because only the `Arc`s are cloned, never `H` itself.
impl<H: ServerHandler> Clone for HookedHandler<H> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            hooks: Arc::clone(&self.hooks),
        }
    }
}
impl<H: ServerHandler> fmt::Debug for HookedHandler<H> {
    /// Prints only the hook configuration; the inner handler is elided because `H` is not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("HookedHandler");
        out.field("hooks", &self.hooks);
        out.finish_non_exhaustive()
    }
}
/// Wraps `inner` so that its tool calls go through `hooks`.
///
/// The handler is moved into a fresh `Arc`; the hook configuration is shared as given.
pub fn with_hooks<H: ServerHandler>(inner: H, hooks: Arc<ToolHooks>) -> HookedHandler<H> {
    let inner = Arc::new(inner);
    HookedHandler { inner, hooks }
}
impl<H: ServerHandler> HookedHandler<H> {
    /// Borrows the wrapped handler.
    #[must_use]
    pub fn inner(&self) -> &H {
        &self.inner
    }

    /// Builds the hook context for one request.
    ///
    /// The tool name and arguments are copied from `request`; identity, role and subject are
    /// read from `crate::rbac` at the moment of the call; `req_id` is stored as given.
    fn build_context(request: &CallToolRequestParams, req_id: Option<String>) -> ToolCallContext {
        let tool_name = request.name.to_string();
        let arguments = request
            .arguments
            .as_ref()
            .map(|args| serde_json::Value::Object(args.clone()));
        let identity = crate::rbac::current_identity();
        let role = crate::rbac::current_role();
        let sub = crate::rbac::current_sub();
        ToolCallContext {
            tool_name,
            arguments,
            identity,
            role,
            sub,
            request_id: req_id,
        }
    }

    /// Runs the after-hook (if one is configured) on a separate Tokio task.
    ///
    /// The spawned task's handle is discarded, so the caller never waits for the hook. The
    /// current tracing span and the current RBAC values are captured here, on the calling
    /// task, and re-applied inside the spawned task via `crate::rbac::with_rbac_scope`.
    /// Missing RBAC values are replaced by empty defaults before being passed on.
    fn spawn_after(
        after: Option<&Arc<AfterHookHolder>>,
        ctx: ToolCallContext,
        disposition: HookDisposition,
        size: usize,
    ) {
        use tracing::Instrument;

        // Nothing to do (and nothing to capture) when no after-hook is installed.
        let Some(holder) = after else {
            return;
        };
        let holder = Arc::clone(holder);
        let span = tracing::Span::current();

        // Capture RBAC state now; it is read from the *calling* task, not the spawned one.
        let role = crate::rbac::current_role().unwrap_or_default();
        let identity = crate::rbac::current_identity().unwrap_or_default();
        let token = crate::rbac::current_token()
            .unwrap_or_else(|| secrecy::SecretString::from(String::new()));
        let sub = crate::rbac::current_sub().unwrap_or_default();

        let task = async move {
            let hook_call = async move {
                (holder.f)(&ctx, disposition, size).await;
            };
            crate::rbac::with_rbac_scope(role, identity, token, sub, hook_call).await;
        };
        tokio::spawn(task.instrument(span));
    }
}
/// Private wrapper around an [`AfterHook`] so it can be passed to `spawn_after` behind an `Arc`.
struct AfterHookHolder {
/// The after-hook callback to invoke.
f: AfterHook,
}
fn too_large_result(limit: usize, actual: usize, tool: &str) -> CallToolResult {
let body = serde_json::json!({
"error": "result_too_large",
"message": format!(
"tool '{tool}' result of {actual} bytes exceeds the configured \
max_result_bytes={limit}; ask for a narrower query"
),
"limit_bytes": limit,
"actual_bytes": actual,
});
let mut r = CallToolResult::error(vec![Content::text(body.to_string())]);
r.structured_content = None;
r
}
/// Returns the number of bytes `result` occupies when serialized as compact JSON.
///
/// The JSON is streamed into a counting sink instead of being collected into a buffer, so
/// measuring a large result does not allocate a second copy of it.
///
/// Returns `0` if serialization fails.
/// NOTE(review): a failure therefore never trips the size cap; confirm that failing open is
/// the intended behavior.
fn serialized_size(result: &CallToolResult) -> usize {
    /// `io::Write` sink that discards its input and only tallies the bytes written.
    struct ByteCounter(usize);

    impl std::io::Write for ByteCounter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0 = self.0.saturating_add(buf.len());
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    let mut counter = ByteCounter(0);
    match serde_json::to_writer(&mut counter, result) {
        Ok(()) => counter.0,
        Err(_) => 0,
    }
}
/// Enforces the optional result-size cap.
///
/// Returns `(result, size, capped)`:
/// - `size` is always the serialized size of the *incoming* result, even when it is replaced;
/// - when `max` is `Some(limit)` and `size > limit`, a warning is logged, the result is
///   swapped for the `too_large_result` error, and `capped` is `true`;
/// - otherwise the incoming result is passed through untouched and `capped` is `false`.
fn apply_size_cap(
    result: CallToolResult,
    max: Option<usize>,
    tool: &str,
) -> (CallToolResult, usize, bool) {
    let size = serialized_size(&result);
    match max {
        Some(limit) if size > limit => {
            tracing::warn!(
                tool = %tool,
                size_bytes = size,
                limit_bytes = limit,
                "tool result exceeds max_result_bytes; replacing with structured error"
            );
            (too_large_result(limit, size, tool), size, true)
        }
        _ => (result, size, false),
    }
}
/// `ServerHandler` implementation: everything except `call_tool` is forwarded unchanged to
/// the wrapped handler; `call_tool` is wrapped with the configured hooks and size cap.
///
/// NOTE(review): only the methods listed below are forwarded. Any other `ServerHandler`
/// method uses the trait's default implementation rather than the inner handler's
/// override — confirm that is intended.
impl<H: ServerHandler> ServerHandler for HookedHandler<H> {
    /// Forwards to the inner handler.
    fn get_info(&self) -> ServerInfo {
        self.inner.get_info()
    }

    /// Forwards to the inner handler.
    fn get_tool(&self, name: &str) -> Option<Tool> {
        self.inner.get_tool(name)
    }

    /// Forwards to the inner handler.
    async fn initialize(
        &self,
        request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, ErrorData> {
        self.inner.initialize(request, context).await
    }

    /// Forwards to the inner handler.
    async fn list_tools(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListToolsResult, ErrorData> {
        self.inner.list_tools(request, context).await
    }

    /// Forwards to the inner handler.
    async fn list_prompts(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, ErrorData> {
        self.inner.list_prompts(request, context).await
    }

    /// Forwards to the inner handler.
    async fn get_prompt(
        &self,
        request: GetPromptRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, ErrorData> {
        self.inner.get_prompt(request, context).await
    }

    /// Forwards to the inner handler.
    async fn list_resources(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, ErrorData> {
        self.inner.list_resources(request, context).await
    }

    /// Forwards to the inner handler.
    async fn list_resource_templates(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, ErrorData> {
        self.inner.list_resource_templates(request, context).await
    }

    /// Forwards to the inner handler.
    async fn read_resource(
        &self,
        request: ReadResourceRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, ErrorData> {
        self.inner.read_resource(request, context).await
    }

    /// Executes a tool call with hooks applied.
    ///
    /// Sequence:
    /// 1. Build the [`ToolCallContext`] (the request id is stored as its `Debug` rendering).
    /// 2. If a before-hook is configured, await it. `Deny` returns the error immediately;
    ///    `Replace` supplies the result and skips the inner handler; `Continue` proceeds.
    /// 3. Otherwise call the inner handler.
    /// 4. Any `Ok` result — inner or replacement — goes through the size cap.
    /// 5. The after-hook is spawned exactly once on every path, with the disposition and the
    ///    pre-cap serialized size (`0` for denied calls and inner errors).
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, ErrorData> {
        let ctx = Self::build_context(&request, Some(format!("{:?}", context.id)));
        let limit = self.hooks.max_result_bytes;
        let after = self
            .hooks
            .after
            .as_ref()
            .map(|hook| Arc::new(AfterHookHolder { f: Arc::clone(hook) }));

        // A replacement supplied by the before-hook, if any.
        let mut replacement = None;
        if let Some(before) = self.hooks.before.as_ref() {
            // Bind the verdict first so the hook's future (which borrows `ctx`) is dropped
            // before `ctx` is moved into `spawn_after` below.
            let verdict = before(&ctx).await;
            match verdict {
                HookOutcome::Continue => {}
                HookOutcome::Deny(err) => {
                    Self::spawn_after(after.as_ref(), ctx, HookDisposition::DeniedBefore, 0);
                    return Err(err);
                }
                HookOutcome::Replace(boxed) => replacement = Some(*boxed),
            }
        }

        // Either use the replacement or run the inner handler; remember which disposition
        // applies if the result turns out to be within the size cap.
        let (outcome, uncapped_disposition) = if let Some(result) = replacement {
            (Ok(result), HookDisposition::ReplacedBefore)
        } else {
            (
                self.inner.call_tool(request, context).await,
                HookDisposition::InnerExecuted,
            )
        };

        match outcome {
            Ok(raw) => {
                let (final_result, size, capped) = apply_size_cap(raw, limit, &ctx.tool_name);
                let disposition = if capped {
                    HookDisposition::ResultTooLarge
                } else {
                    uncapped_disposition
                };
                Self::spawn_after(after.as_ref(), ctx, disposition, size);
                Ok(final_result)
            }
            Err(err) => {
                Self::spawn_after(after.as_ref(), ctx, HookDisposition::InnerErrored, 0);
                Err(err)
            }
        }
    }
}
// Unit tests for the hook plumbing.
//
// None of these tests invoke `HookedHandler::call_tool`; they exercise the private helpers
// (`apply_size_cap`, `too_large_result`, `serialized_size`, `spawn_after`) and the hook
// closures directly.
#[cfg(test)]
mod tests {
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use rmcp::{
ErrorData, RoleServer, ServerHandler,
model::{CallToolRequestParams, CallToolResult, Content, ServerInfo},
service::RequestContext,
};
use super::*;
// Minimal handler whose `call_tool` returns a text body of `body_bytes` `x` characters
// (4 when unset).
#[derive(Clone, Default)]
struct TestHandler {
body_bytes: Option<usize>,
}
impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::default()
}
async fn call_tool(
&self,
_request: CallToolRequestParams,
_context: RequestContext<RoleServer>,
) -> Result<CallToolResult, ErrorData> {
let body = "x".repeat(self.body_bytes.unwrap_or(4));
Ok(CallToolResult::success(vec![Content::text(body)]))
}
}
// Builds a context with only the tool name set.
fn ctx(name: &str) -> ToolCallContext {
ToolCallContext {
tool_name: name.to_owned(),
arguments: None,
identity: None,
role: None,
sub: None,
request_id: None,
}
}
// An 8 KiB result against a 256-byte cap must be replaced by the `result_too_large` error,
// while the reported size stays that of the original result.
#[tokio::test]
async fn size_cap_replaces_oversized_result() {
let inner = TestHandler {
body_bytes: Some(8_192),
};
let hooks = Arc::new(ToolHooks {
max_result_bytes: Some(256),
before: None,
after: None,
});
let hooked = with_hooks(inner, hooks);
let small = CallToolResult::success(vec![Content::text("ok".to_owned())]);
assert!(serialized_size(&small) < 256);
let big = CallToolResult::success(vec![Content::text("x".repeat(8_192))]);
let size = serialized_size(&big);
assert!(size > 256);
let (replaced, accounted, capped) = apply_size_cap(big, Some(256), "whatever");
assert!(capped);
assert_eq!(accounted, size);
assert_eq!(replaced.is_error, Some(true));
assert!(matches!(
&replaced.content[0].raw,
rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
));
// The wrapper is only constructed here; `call_tool` is not invoked on it.
let _ = hooked;
}
// Calls the before-hook closure directly: it must deny the "forbidden" tool, allow any
// other name, and run once per invocation.
#[tokio::test]
async fn before_hook_deny_builds_error() {
let counter = Arc::new(AtomicUsize::new(0));
let c = Arc::clone(&counter);
let before: BeforeHook = Arc::new(move |ctx_ref| {
let c = Arc::clone(&c);
let name = ctx_ref.tool_name.clone();
Box::pin(async move {
c.fetch_add(1, Ordering::Relaxed);
if name == "forbidden" {
HookOutcome::Deny(ErrorData::invalid_request("nope", None))
} else {
HookOutcome::Continue
}
})
});
let hooks = Arc::new(ToolHooks {
max_result_bytes: None,
before: Some(before),
after: None,
});
let hooked = with_hooks(TestHandler::default(), hooks);
let bad_ctx = ctx("forbidden");
let before_fn = hooked.hooks.before.as_ref().unwrap();
let outcome = before_fn(&bad_ctx).await;
assert!(matches!(outcome, HookOutcome::Deny(_)));
assert_eq!(counter.load(Ordering::Relaxed), 1);
let ok_ctx = ctx("allowed");
let outcome2 = before_fn(&ok_ctx).await;
assert!(matches!(outcome2, HookOutcome::Continue));
assert_eq!(counter.load(Ordering::Relaxed), 2);
}
// The serialized error must carry the error code, the tool name and both byte counts.
#[test]
fn too_large_result_mentions_limit_and_actual() {
let r = too_large_result(100, 500, "my_tool");
let body = serde_json::to_string(&r).unwrap();
assert!(body.contains("result_too_large"));
assert!(body.contains("my_tool"));
assert!(body.contains("100"));
assert!(body.contains("500"));
}
// A `Replace` outcome with no cap configured passes through `apply_size_cap` unchanged.
#[tokio::test]
async fn replace_outcome_skips_inner_and_returns_payload() {
let before: BeforeHook = Arc::new(|_ctx| {
Box::pin(async {
HookOutcome::Replace(Box::new(CallToolResult::success(vec![Content::text(
"from-replace".to_owned(),
)])))
})
});
let hooks = Arc::new(ToolHooks {
max_result_bytes: None,
before: Some(before),
after: None,
});
let _hooked = with_hooks(TestHandler::default(), Arc::clone(&hooks));
let outcome = (hooks.before.as_ref().unwrap())(&ctx("any")).await;
let HookOutcome::Replace(boxed) = outcome else {
panic!("expected HookOutcome::Replace");
};
let (result, size, capped) = apply_size_cap(*boxed, None, "any");
assert!(!capped);
assert!(size > 0);
assert!(!result.is_error.unwrap_or(false));
assert!(matches!(
&result.content[0].raw,
rmcp::model::RawContent::Text(t) if t.text == "from-replace"
));
}
// An oversized payload is capped by `apply_size_cap` just like any other result.
#[tokio::test]
async fn replace_outcome_subject_to_size_cap() {
let huge = CallToolResult::success(vec![Content::text("y".repeat(8_192))]);
let huge_size = serialized_size(&huge);
assert!(huge_size > 256);
let (final_result, accounted, capped) = apply_size_cap(huge, Some(256), "replaced_tool");
assert!(capped);
assert_eq!(accounted, huge_size);
assert_eq!(final_result.is_error, Some(true));
assert!(matches!(
&final_result.content[0].raw,
rmcp::model::RawContent::Text(t) if t.text.contains("result_too_large")
));
}
// `spawn_after` runs the hook on another task, so the test polls (for up to one second)
// until the counter moves, then checks it moved exactly once.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn after_hook_fires_exactly_once_via_spawn() {
let counter = Arc::new(AtomicUsize::new(0));
let c = Arc::clone(&counter);
let after: AfterHook = Arc::new(move |_ctx, _disp, _size| {
let c = Arc::clone(&c);
Box::pin(async move {
c.fetch_add(1, Ordering::Relaxed);
})
});
let holder = Arc::new(AfterHookHolder { f: after });
HookedHandler::<TestHandler>::spawn_after(
Some(&holder),
ctx("t"),
HookDisposition::InnerExecuted,
42,
);
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
while counter.load(Ordering::Relaxed) == 0 && std::time::Instant::now() < deadline {
tokio::task::yield_now().await;
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
}
assert_eq!(counter.load(Ordering::Relaxed), 1);
}
// A panicking after-hook must not take down the runtime: after giving the spawned task
// time to panic, an unrelated task must still run to completion.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn after_hook_panic_is_isolated_from_response_path() {
let after: AfterHook = Arc::new(|_ctx, _disp, _size| {
Box::pin(async {
panic!("intentional panic in after-hook");
})
});
let holder = Arc::new(AfterHookHolder { f: after });
HookedHandler::<TestHandler>::spawn_after(
Some(&holder),
ctx("boom"),
HookDisposition::InnerExecuted,
0,
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let still_alive = tokio::spawn(async { 1_u32 + 2 }).await.unwrap();
assert_eq!(still_alive, 3);
}
}