use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use tower_service::Service;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use crate::async_task::TaskStore;
use crate::context::{
CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
ServerNotification,
};
use crate::error::{Error, JsonRpcError, Result};
use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
use crate::prompt::Prompt;
use crate::protocol::*;
#[cfg(feature = "dynamic-tools")]
use crate::registry::{
DynamicPromptRegistry, DynamicPromptsInner, DynamicResourceRegistry,
DynamicResourceTemplateRegistry, DynamicResourceTemplatesInner, DynamicResourcesInner,
DynamicToolRegistry, DynamicToolsInner,
};
use crate::resource::{Resource, ResourceTemplate};
use crate::session::SessionState;
use crate::tool::Tool;
/// Type-erased async handler for completion requests.
///
/// It receives the request's `CompleteParams` and returns a pinned, boxed
/// `Send` future resolving to `Result<CompleteResult>`. It is stored behind an
/// `Arc` because `McpRouterInner` derives `Clone` and a bare `dyn Fn` cannot be
/// cloned.
pub(crate) type CompletionHandler = Arc<
    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
        + Send
        + Sync,
>;
fn decode_cursor(cursor: &str) -> Result<usize> {
let bytes = BASE64
.decode(cursor)
.map_err(|_| Error::JsonRpc(JsonRpcError::invalid_params("Invalid pagination cursor")))?;
let s = String::from_utf8(bytes)
.map_err(|_| Error::JsonRpc(JsonRpcError::invalid_params("Invalid pagination cursor")))?;
s.parse::<usize>()
.map_err(|_| Error::JsonRpc(JsonRpcError::invalid_params("Invalid pagination cursor")))
}
/// Turn a list offset into an opaque pagination cursor: the offset's decimal
/// representation, base64-encoded with the standard alphabet.
fn encode_cursor(offset: usize) -> String {
    let decimal = offset.to_string();
    BASE64.encode(decimal.as_bytes())
}
fn paginate<T>(
items: Vec<T>,
cursor: Option<&str>,
page_size: Option<usize>,
) -> Result<(Vec<T>, Option<String>)> {
let Some(page_size) = page_size else {
return Ok((items, None));
};
let offset = match cursor {
Some(c) => decode_cursor(c)?,
None => 0,
};
if offset >= items.len() {
return Ok((Vec::new(), None));
}
let end = (offset + page_size).min(items.len());
let next_cursor = if end < items.len() {
Some(encode_cursor(end))
} else {
None
};
let mut items = items;
let page = items.drain(offset..end).collect();
Ok((page, next_cursor))
}
/// Routes MCP requests to the registered tools, resources and prompts.
///
/// `inner` keeps the configuration and registrations behind an `Arc`, so
/// cloning a router does not copy them. `session` holds this router's
/// protocol session state. `with_fresh_session` pairs the same `inner` with a
/// newly created `SessionState`.
#[derive(Clone)]
pub struct McpRouter {
    inner: Arc<McpRouterInner>,
    session: SessionState,
}
// Summary-only debug output: server identity, the counts of statically
// registered tools/resources/prompts, and the session phase. Handlers and
// other internals are not printed.
impl std::fmt::Debug for McpRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shared = &*self.inner;
        let mut summary = f.debug_struct("McpRouter");
        summary.field("server_name", &shared.server_name);
        summary.field("server_version", &shared.server_version);
        summary.field("tools_count", &shared.tools.len());
        summary.field("resources_count", &shared.resources.len());
        summary.field("prompts_count", &shared.prompts.len());
        summary.field("session_phase", &self.session.phase());
        summary.finish()
    }
}
/// Settings for auto-generated server instructions.
#[derive(Clone, Debug)]
struct AutoInstructionsConfig {
    // Text emitted verbatim before the generated listing, if any.
    prefix: Option<String>,
    // Text emitted verbatim after the generated listing, if any.
    suffix: Option<String>,
}
/// Configuration and registrations shared by every clone of an `McpRouter`.
///
/// Builder methods modify it through `Arc::make_mut`, which copies it first
/// when other routers still share it. The `Arc<RwLock<..>>` fields are cloned
/// by pointer, though. Runtime changes to them (in-flight tracking,
/// subscriptions, enable/disable sets, log level) are therefore visible
/// through every copy.
#[derive(Clone)]
struct McpRouterInner {
    // Server identity reported in the `initialize` response.
    server_name: String,
    server_version: String,
    server_title: Option<String>,
    server_description: Option<String>,
    server_icons: Option<Vec<ToolIcon>>,
    server_website_url: Option<String>,
    // Static instructions for `initialize`; ignored when `auto_instructions` is set.
    instructions: Option<String>,
    // When set, instructions are generated from the static registrations instead.
    auto_instructions: Option<AutoInstructionsConfig>,
    // Static tools keyed by name.
    tools: HashMap<String, Arc<Tool>>,
    // Static resources keyed by URI.
    resources: HashMap<String, Arc<Resource>>,
    // URI templates, tried in registration order when reading a resource.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    // Static prompts keyed by name.
    prompts: HashMap<String, Arc<Prompt>>,
    // Cancellation tokens of requests currently being handled, keyed by request id.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    // Server-to-client notification channel; `None` disables notifications and logging.
    notification_tx: Option<NotificationSender>,
    // Handle for server-initiated requests, passed into request contexts.
    client_requester: Option<ClientRequesterHandle>,
    // State of async tool tasks.
    task_store: TaskStore,
    // URIs the client has subscribed to.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    // Completion handler; its presence advertises the completions capability.
    completion_handler: Option<CompletionHandler>,
    // Optional per-session visibility filters.
    tool_filter: Option<ToolFilter>,
    resource_filter: Option<ResourceFilter>,
    prompt_filter: Option<PromptFilter>,
    // Typed state handed to request contexts.
    extensions: Arc<crate::context::Extensions>,
    // Minimum log level, shared with every request context.
    min_log_level: Arc<RwLock<LogLevel>>,
    // Maximum items per list page; `None` returns everything.
    page_size: Option<usize>,
    // Names/URIs hidden from listings (and, for tools/resources, rejected on call/read).
    disabled_tools: Arc<RwLock<HashSet<String>>>,
    disabled_resources: Arc<RwLock<HashSet<String>>>,
    disabled_prompts: Arc<RwLock<HashSet<String>>>,
    // Runtime registries, enabled by the matching `with_dynamic_*` builder.
    #[cfg(feature = "dynamic-tools")]
    dynamic_tools: Option<Arc<DynamicToolsInner>>,
    #[cfg(feature = "dynamic-tools")]
    dynamic_prompts: Option<Arc<DynamicPromptsInner>>,
    #[cfg(feature = "dynamic-tools")]
    dynamic_resources: Option<Arc<DynamicResourcesInner>>,
    #[cfg(feature = "dynamic-tools")]
    dynamic_resource_templates: Option<Arc<DynamicResourceTemplatesInner>>,
}
impl McpRouterInner {
    /// Render Markdown instructions describing the statically registered items.
    ///
    /// Sections are separated by a blank line, and each is omitted when empty:
    ///
    /// 1. `config.prefix`, verbatim;
    /// 2. `## Tools` — tools sorted by name, with hint tags in brackets when present;
    /// 3. `## Resources` — resources sorted by URI, then templates sorted by URI template;
    /// 4. `## Prompts` — prompts sorted by name;
    /// 5. `config.suffix`, verbatim.
    ///
    /// Missing descriptions are shown as "No description". Dynamic registries
    /// are not consulted.
    fn generate_instructions(&self, config: &AutoInstructionsConfig) -> String {
        let mut sections: Vec<String> = config.prefix.iter().cloned().collect();

        if !self.tools.is_empty() {
            let mut by_name: Vec<_> = self.tools.values().collect();
            by_name.sort_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
            let bullets = by_name.into_iter().map(|tool| {
                let summary = tool.description.as_deref().unwrap_or("No description");
                let tags = annotation_tags(tool.annotations.as_ref());
                if tags.is_empty() {
                    format!("- **{}**: {}", tool.name, summary)
                } else {
                    format!("- **{}**: {} [{}]", tool.name, summary, tags)
                }
            });
            sections.push(Self::instruction_section("## Tools", bullets));
        }

        if !(self.resources.is_empty() && self.resource_templates.is_empty()) {
            let mut by_uri: Vec<_> = self.resources.values().collect();
            by_uri.sort_by(|lhs, rhs| lhs.uri.cmp(&rhs.uri));
            let mut by_template: Vec<_> = self.resource_templates.iter().collect();
            by_template.sort_by(|lhs, rhs| lhs.uri_template.cmp(&rhs.uri_template));
            let bullets = by_uri
                .into_iter()
                .map(|r| Self::instruction_bullet(&r.uri, r.description.as_deref()))
                .chain(by_template.into_iter().map(|t| {
                    Self::instruction_bullet(&t.uri_template, t.description.as_deref())
                }));
            sections.push(Self::instruction_section("## Resources", bullets));
        }

        if !self.prompts.is_empty() {
            let mut by_name: Vec<_> = self.prompts.values().collect();
            by_name.sort_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
            let bullets = by_name
                .into_iter()
                .map(|p| Self::instruction_bullet(&p.name, p.description.as_deref()));
            sections.push(Self::instruction_section("## Prompts", bullets));
        }

        sections.extend(config.suffix.iter().cloned());
        sections.join("\n\n")
    }

    /// Format one `- **label**: description` bullet, with "No description" as fallback.
    fn instruction_bullet(label: impl std::fmt::Display, description: Option<&str>) -> String {
        format!("- **{}**: {}", label, description.unwrap_or("No description"))
    }

    /// Assemble a section: the heading, a blank line, then one entry per line.
    fn instruction_section(heading: &str, entries: impl Iterator<Item = String>) -> String {
        std::iter::once(heading.to_string())
            .chain(std::iter::once(String::new()))
            .chain(entries)
            .collect::<Vec<_>>()
            .join("\n")
    }
}
/// Render a tool's behavioural hints as a comma-separated tag list.
///
/// The result is `"read-only"`, `"idempotent"`, both in that order joined by
/// `", "`, or an empty string when neither hint applies or there are no
/// annotations.
fn annotation_tags(annotations: Option<&crate::protocol::ToolAnnotations>) -> String {
    match annotations {
        None => String::new(),
        Some(ann) => [
            (ann.is_read_only(), "read-only"),
            (ann.is_idempotent(), "idempotent"),
        ]
        .into_iter()
        .filter_map(|(applies, tag)| applies.then_some(tag))
        .collect::<Vec<_>>()
        .join(", "),
    }
}
impl McpRouter {
/// Create a router with nothing registered.
///
/// Defaults:
/// * server name `"tower-mcp"` and this crate's `CARGO_PKG_VERSION`;
/// * no title, description, icons, website, instructions or auto-instructions;
/// * no notification sender, client requester, completion handler or filters;
/// * pagination disabled and minimum log level `Debug`;
/// * nothing disabled or subscribed, and a new `SessionState`.
pub fn new() -> Self {
    // Every lock-protected string set starts out empty.
    fn empty_shared_set() -> Arc<RwLock<HashSet<String>>> {
        Arc::new(RwLock::new(HashSet::new()))
    }

    let inner = McpRouterInner {
        server_name: String::from("tower-mcp"),
        server_version: String::from(env!("CARGO_PKG_VERSION")),
        server_title: None,
        server_description: None,
        server_icons: None,
        server_website_url: None,
        instructions: None,
        auto_instructions: None,
        tools: HashMap::new(),
        resources: HashMap::new(),
        resource_templates: Vec::new(),
        prompts: HashMap::new(),
        in_flight: Arc::new(RwLock::new(HashMap::new())),
        notification_tx: None,
        client_requester: None,
        task_store: TaskStore::new(),
        subscriptions: empty_shared_set(),
        extensions: Arc::new(crate::context::Extensions::new()),
        completion_handler: None,
        tool_filter: None,
        resource_filter: None,
        prompt_filter: None,
        min_log_level: Arc::new(RwLock::new(LogLevel::Debug)),
        page_size: None,
        disabled_tools: empty_shared_set(),
        disabled_resources: empty_shared_set(),
        disabled_prompts: empty_shared_set(),
        #[cfg(feature = "dynamic-tools")]
        dynamic_tools: None,
        #[cfg(feature = "dynamic-tools")]
        dynamic_prompts: None,
        #[cfg(feature = "dynamic-tools")]
        dynamic_resources: None,
        #[cfg(feature = "dynamic-tools")]
        dynamic_resource_templates: None,
    };
    Self {
        inner: Arc::new(inner),
        session: SessionState::new(),
    }
}
pub fn with_fresh_session(&self) -> Self {
Self {
inner: self.inner.clone(),
session: SessionState::new(),
}
}
/// Snapshot the annotations of every enabled tool, keyed by tool name.
///
/// Static tools come first. With the `dynamic-tools` feature, a dynamic tool
/// is added only when no static tool has already claimed its name. Disabled
/// tools and tools without annotations are omitted.
///
/// # Panics
///
/// Panics if the disabled-tools lock is poisoned.
pub fn tool_annotations_map(&self) -> ToolAnnotationsMap {
    let disabled = self.inner.disabled_tools.read().unwrap();
    let mut map = HashMap::new();
    for (name, tool) in self.inner.tools.iter() {
        match &tool.annotations {
            Some(annotations) if !disabled.contains(name) => {
                map.insert(name.clone(), annotations.clone());
            }
            _ => {}
        }
    }
    #[cfg(feature = "dynamic-tools")]
    if let Some(registry) = &self.inner.dynamic_tools {
        for tool in registry.list() {
            if disabled.contains(&tool.name) || map.contains_key(&tool.name) {
                continue;
            }
            if let Some(annotations) = &tool.annotations {
                map.insert(tool.name.clone(), annotations.clone());
            }
        }
    }
    ToolAnnotationsMap { map: Arc::new(map) }
}
pub fn task_store(&self) -> &TaskStore {
&self.inner.task_store
}
/// Enable runtime tool registration.
///
/// Returns the router together with a `DynamicToolRegistry` handle for adding
/// tools after construction.
#[cfg(feature = "dynamic-tools")]
pub fn with_dynamic_tools(mut self) -> (Self, DynamicToolRegistry) {
    let shared = Arc::new(DynamicToolsInner::new());
    let router_inner = Arc::make_mut(&mut self.inner);
    router_inner.dynamic_tools = Some(Arc::clone(&shared));
    (self, DynamicToolRegistry::new(shared))
}
/// Enable runtime prompt registration.
///
/// Returns the router together with a `DynamicPromptRegistry` handle.
#[cfg(feature = "dynamic-tools")]
pub fn with_dynamic_prompts(mut self) -> (Self, DynamicPromptRegistry) {
    let shared = Arc::new(DynamicPromptsInner::new());
    let router_inner = Arc::make_mut(&mut self.inner);
    router_inner.dynamic_prompts = Some(Arc::clone(&shared));
    (self, DynamicPromptRegistry::new(shared))
}
/// Enable runtime resource registration.
///
/// Returns the router together with a `DynamicResourceRegistry` handle.
#[cfg(feature = "dynamic-tools")]
pub fn with_dynamic_resources(mut self) -> (Self, DynamicResourceRegistry) {
    let shared = Arc::new(DynamicResourcesInner::new());
    let router_inner = Arc::make_mut(&mut self.inner);
    router_inner.dynamic_resources = Some(Arc::clone(&shared));
    (self, DynamicResourceRegistry::new(shared))
}
/// Enable runtime resource-template registration.
///
/// Returns the router together with a `DynamicResourceTemplateRegistry` handle.
#[cfg(feature = "dynamic-tools")]
pub fn with_dynamic_resource_templates(mut self) -> (Self, DynamicResourceTemplateRegistry) {
    let shared = Arc::new(DynamicResourceTemplatesInner::new());
    let router_inner = Arc::make_mut(&mut self.inner);
    router_inner.dynamic_resource_templates = Some(Arc::clone(&shared));
    (self, DynamicResourceTemplateRegistry::new(shared))
}
/// Attach the channel used for server-to-client notifications.
///
/// Any dynamic registries already enabled on this router also receive a
/// clone of the sender. Registries enabled afterwards do not.
pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    #[cfg(feature = "dynamic-tools")]
    {
        if let Some(registry) = &inner.dynamic_tools {
            registry.add_notification_sender(tx.clone());
        }
        if let Some(registry) = &inner.dynamic_prompts {
            registry.add_notification_sender(tx.clone());
        }
        if let Some(registry) = &inner.dynamic_resources {
            registry.add_notification_sender(tx.clone());
        }
        if let Some(registry) = &inner.dynamic_resource_templates {
            registry.add_notification_sender(tx.clone());
        }
    }
    inner.notification_tx = Some(tx);
    self
}
/// The attached notification sender, if any.
pub fn notification_sender(&self) -> Option<&NotificationSender> {
    Option::as_ref(&self.inner.notification_tx)
}
/// Attach the handle used to send requests to the client. Request contexts
/// created afterwards carry a clone of it.
pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.client_requester = Some(requester);
    self
}
/// The attached client-requester handle, if any.
pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
    Option::as_ref(&self.inner.client_requester)
}
/// Store a typed value in the router's extensions. The value is handed to
/// request contexts through the shared extensions.
pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
    let extensions = &mut Arc::make_mut(&mut self.inner).extensions;
    Arc::make_mut(extensions).insert(state);
    self
}
pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
self.with_state(value)
}
/// Borrow the router's typed extensions.
pub fn extensions(&self) -> &crate::context::Extensions {
    &*self.inner.extensions
}
pub fn create_context(
&self,
request_id: RequestId,
progress_token: Option<ProgressToken>,
) -> RequestContext {
let ctx = RequestContext::new(request_id.clone());
let ctx = if let Some(token) = progress_token {
ctx.with_progress_token(token)
} else {
ctx
};
let ctx = if let Some(tx) = &self.inner.notification_tx {
ctx.with_notification_sender(tx.clone())
} else {
ctx
};
let ctx = if let Some(requester) = &self.inner.client_requester {
ctx.with_client_requester(requester.clone())
} else {
ctx
};
let ctx = ctx.with_extensions(self.inner.extensions.clone());
let ctx = ctx.with_min_log_level(self.inner.min_log_level.clone());
let token = ctx.cancellation_token();
if let Ok(mut in_flight) = self.inner.in_flight.write() {
in_flight.insert(request_id, token);
}
ctx
}
/// Remove a request from the in-flight table so it can no longer be
/// cancelled. Does nothing if the in-flight lock is poisoned.
pub fn complete_request(&self, request_id: &RequestId) {
    let Ok(mut registry) = self.inner.in_flight.write() else {
        return;
    };
    registry.remove(request_id);
}
/// Trigger the cancellation token of an in-flight request.
///
/// Returns `true` if the request was found and cancelled. Returns `false` if
/// it is unknown or the in-flight lock is poisoned.
fn cancel_request(&self, request_id: &RequestId) -> bool {
    match self.inner.in_flight.read() {
        Ok(registry) => match registry.get(request_id) {
            Some(token) => {
                token.cancel();
                true
            }
            None => false,
        },
        Err(_) => false,
    }
}
/// Set the server name and version reported during `initialize`.
pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    (inner.server_name, inner.server_version) = (name.into(), version.into());
    self
}
/// Paginate list responses with at most `size` items per page.
pub fn page_size(mut self, size: usize) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.page_size = Some(size);
    self
}
/// Set static instructions for the `initialize` response. They are ignored
/// while auto-generated instructions are enabled.
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
    let text: String = instructions.into();
    Arc::make_mut(&mut self.inner).instructions = Some(text);
    self
}
pub fn auto_instructions(mut self) -> Self {
Arc::make_mut(&mut self.inner).auto_instructions = Some(AutoInstructionsConfig {
prefix: None,
suffix: None,
});
self
}
/// Generate instructions from the registered tools, resources and prompts.
/// `prefix` and `suffix`, when given, are placed verbatim before and after
/// the generated listing.
pub fn auto_instructions_with(
    mut self,
    prefix: Option<impl Into<String>>,
    suffix: Option<impl Into<String>>,
) -> Self {
    let config = AutoInstructionsConfig {
        prefix: prefix.map(|text| text.into()),
        suffix: suffix.map(|text| text.into()),
    };
    Arc::make_mut(&mut self.inner).auto_instructions = Some(config);
    self
}
/// Set the human-readable server title reported during `initialize`.
pub fn server_title(mut self, title: impl Into<String>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.server_title = Some(title.into());
    self
}
/// Set the server description reported during `initialize`.
pub fn server_description(mut self, description: impl Into<String>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.server_description = Some(description.into());
    self
}
/// Set the server icons reported during `initialize`.
pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.server_icons = Some(icons);
    self
}
/// Set the server website URL reported during `initialize`.
pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.server_website_url = Some(url.into());
    self
}
/// Register a static tool, replacing any existing tool with the same name.
pub fn tool(mut self, tool: Tool) -> Self {
    let key = tool.name.clone();
    let registry = &mut Arc::make_mut(&mut self.inner).tools;
    registry.insert(key, Arc::new(tool));
    self
}
/// Register `tool` only when `condition` is true.
pub fn tool_if(self, condition: bool, tool: Tool) -> Self {
    if !condition {
        return self;
    }
    self.tool(tool)
}
/// Register a static resource, replacing any existing resource with the same URI.
pub fn resource(mut self, resource: Resource) -> Self {
    let key = resource.uri.clone();
    let registry = &mut Arc::make_mut(&mut self.inner).resources;
    registry.insert(key, Arc::new(resource));
    self
}
/// Register `resource` only when `condition` is true.
pub fn resource_if(self, condition: bool, resource: Resource) -> Self {
    if !condition {
        return self;
    }
    self.resource(resource)
}
/// Append a resource template. Templates are not deduplicated and are tried
/// in registration order when reading a resource.
pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
    let templates = &mut Arc::make_mut(&mut self.inner).resource_templates;
    templates.push(Arc::new(template));
    self
}
/// Register a static prompt, replacing any existing prompt with the same name.
pub fn prompt(mut self, prompt: Prompt) -> Self {
    let key = prompt.name.clone();
    let registry = &mut Arc::make_mut(&mut self.inner).prompts;
    registry.insert(key, Arc::new(prompt));
    self
}
/// Register `prompt` only when `condition` is true.
pub fn prompt_if(self, condition: bool, prompt: Prompt) -> Self {
    if !condition {
        return self;
    }
    self.prompt(prompt)
}
/// Register every tool in `tools`, in iteration order; later duplicates win.
pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
    let mut router = self;
    for tool in tools {
        router = router.tool(tool);
    }
    router
}
/// Register every tool in `tools` only when `condition` is true.
pub fn tools_if(self, condition: bool, tools: impl IntoIterator<Item = Tool>) -> Self {
    match condition {
        true => self.tools(tools),
        false => self,
    }
}
/// Register every resource in `resources`, in iteration order; later
/// duplicates win.
pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
    let mut router = self;
    for resource in resources {
        router = router.resource(resource);
    }
    router
}
/// Register every resource in `resources` only when `condition` is true.
pub fn resources_if(
    self,
    condition: bool,
    resources: impl IntoIterator<Item = Resource>,
) -> Self {
    match condition {
        true => self.resources(resources),
        false => self,
    }
}
/// Register every prompt in `prompts`, in iteration order; later duplicates win.
pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
    let mut router = self;
    for prompt in prompts {
        router = router.prompt(prompt);
    }
    router
}
/// Register every prompt in `prompts` only when `condition` is true.
pub fn prompts_if(self, condition: bool, prompts: impl IntoIterator<Item = Prompt>) -> Self {
    match condition {
        true => self.prompts(prompts),
        false => self,
    }
}
/// Copy another router's static tools, resources, resource templates and
/// prompts into this one.
///
/// Entries from `other` replace same-named tools/prompts and same-URI
/// resources. Templates are appended. Nothing else is taken from `other`:
/// not its server info, filters, senders or dynamic registries.
pub fn merge(mut self, other: McpRouter) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    let theirs = &other.inner;
    inner
        .tools
        .extend(theirs.tools.iter().map(|(k, v)| (k.clone(), Arc::clone(v))));
    inner
        .resources
        .extend(theirs.resources.iter().map(|(k, v)| (k.clone(), Arc::clone(v))));
    inner
        .resource_templates
        .extend(theirs.resource_templates.iter().cloned());
    inner
        .prompts
        .extend(theirs.prompts.iter().map(|(k, v)| (k.clone(), Arc::clone(v))));
    self
}
/// Like a merge, but `other`'s tools are renamed via `Tool::with_name_prefix`.
///
/// Resources and prompts are copied unchanged, replacing same-key entries,
/// and templates are appended. Only tool names receive the prefix.
pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
    let prefix: String = prefix.into();
    let inner = Arc::make_mut(&mut self.inner);
    let theirs = &other.inner;
    for tool in theirs.tools.values() {
        let renamed = tool.with_name_prefix(&prefix);
        inner.tools.insert(renamed.name.clone(), Arc::new(renamed));
    }
    inner
        .resources
        .extend(theirs.resources.iter().map(|(k, v)| (k.clone(), Arc::clone(v))));
    inner
        .resource_templates
        .extend(theirs.resource_templates.iter().cloned());
    inner
        .prompts
        .extend(theirs.prompts.iter().map(|(k, v)| (k.clone(), Arc::clone(v))));
    self
}
pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
where
F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
{
Arc::make_mut(&mut self.inner).completion_handler =
Some(Arc::new(move |params| Box::pin(handler(params))));
self
}
/// Install a per-session tool visibility filter, used for listing and calls.
pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.tool_filter = Some(filter);
    self
}
/// Install a per-session resource visibility filter, used for listing and reads.
pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.resource_filter = Some(filter);
    self
}
/// Install a per-session prompt visibility filter.
pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.prompt_filter = Some(filter);
    self
}
pub fn session(&self) -> &SessionState {
&self.session
}
/// Queue a log-message notification for the client.
///
/// Returns `false` when no notification sender is attached or when
/// `try_send` fails (presumably a full or closed channel). Returns `true`
/// otherwise. The router's minimum log level is not consulted here.
pub fn log(&self, params: LoggingMessageParams) -> bool {
    match &self.inner.notification_tx {
        Some(tx) => tx.try_send(ServerNotification::LogMessage(params)).is_ok(),
        None => false,
    }
}
/// Send `{"message": message}` at `Info` level; see `log` for the return value.
pub fn log_info(&self, message: &str) -> bool {
    let payload = serde_json::json!({ "message": message });
    self.log(LoggingMessageParams::new(LogLevel::Info, payload))
}
/// Send `{"message": message}` at `Warning` level; see `log` for the return value.
pub fn log_warning(&self, message: &str) -> bool {
    let payload = serde_json::json!({ "message": message });
    self.log(LoggingMessageParams::new(LogLevel::Warning, payload))
}
pub fn log_error(&self, message: &str) -> bool {
self.log(LoggingMessageParams::new(
LogLevel::Error,
serde_json::json!({ "message": message }),
))
}
/// Send `{"message": message}` at `Debug` level; see `log` for the return value.
pub fn log_debug(&self, message: &str) -> bool {
    let payload = serde_json::json!({ "message": message });
    self.log(LoggingMessageParams::new(LogLevel::Debug, payload))
}
/// Whether the client is subscribed to `uri`. Returns `false` if the
/// subscriptions lock is poisoned.
pub fn is_subscribed(&self, uri: &str) -> bool {
    self.inner
        .subscriptions
        .read()
        .map(|subs| subs.contains(uri))
        .unwrap_or(false)
}
/// All subscribed URIs, in unspecified order. Returns an empty list if the
/// subscriptions lock is poisoned.
pub fn subscribed_uris(&self) -> Vec<String> {
    self.inner
        .subscriptions
        .read()
        .map(|subs| subs.iter().cloned().collect())
        .unwrap_or_default()
}
/// Record a subscription to `uri`.
///
/// Returns `true` if it was newly added. Returns `false` if it already
/// existed or the lock is poisoned.
fn subscribe(&self, uri: &str) -> bool {
    self.inner
        .subscriptions
        .write()
        .map(|mut subs| subs.insert(uri.to_string()))
        .unwrap_or(false)
}
/// Drop a subscription to `uri`.
///
/// Returns `true` if one was removed. Returns `false` if none existed or the
/// lock is poisoned.
fn unsubscribe(&self, uri: &str) -> bool {
    self.inner
        .subscriptions
        .write()
        .map(|mut subs| subs.remove(uri))
        .unwrap_or(false)
}
/// Notify the client that `uri` changed, but only if it is subscribed.
///
/// Returns `true` only when the notification was queued. The result is
/// `false` when the client is not subscribed, no sender is attached, or the
/// send fails.
pub fn notify_resource_updated(&self, uri: &str) -> bool {
    self.is_subscribed(uri)
        && self.inner.notification_tx.as_ref().is_some_and(|tx| {
            tx.try_send(ServerNotification::ResourceUpdated {
                uri: uri.to_string(),
            })
            .is_ok()
        })
}
/// Tell the client the resource list changed. Returns `true` if the
/// notification was queued.
pub fn notify_resources_list_changed(&self) -> bool {
    self.inner
        .notification_tx
        .as_ref()
        .is_some_and(|tx| tx.try_send(ServerNotification::ResourcesListChanged).is_ok())
}
/// Tell the client the tool list changed. Returns `true` if the notification
/// was queued.
pub fn notify_tools_list_changed(&self) -> bool {
    self.inner
        .notification_tx
        .as_ref()
        .is_some_and(|tx| tx.try_send(ServerNotification::ToolsListChanged).is_ok())
}
/// Tell the client the prompt list changed. Returns `true` if the
/// notification was queued.
pub fn notify_prompts_list_changed(&self) -> bool {
    self.inner
        .notification_tx
        .as_ref()
        .is_some_and(|tx| tx.try_send(ServerNotification::PromptsListChanged).is_ok())
}
/// Hide a tool from listings and reject calls to it.
///
/// The disabled set is shared by every router cloned from this configuration.
///
/// # Panics
///
/// Panics if the disabled-tools lock is poisoned.
pub fn disable_tool(&self, name: impl Into<String>) {
    self.inner.disabled_tools.write().unwrap().insert(name.into());
}
/// Re-enable a previously disabled tool.
///
/// # Panics
///
/// Panics if the disabled-tools lock is poisoned.
pub fn enable_tool(&self, name: &str) {
    self.inner.disabled_tools.write().unwrap().remove(name);
}
/// Whether `name` is absent from the disabled-tools set.
///
/// # Panics
///
/// Panics if the disabled-tools lock is poisoned.
pub fn is_tool_enabled(&self, name: &str) -> bool {
    let disabled = self.inner.disabled_tools.read().unwrap();
    !disabled.contains(name)
}
/// Hide a resource from listings and reject reads of it.
///
/// The disabled set is shared by every router cloned from this configuration.
///
/// # Panics
///
/// Panics if the disabled-resources lock is poisoned.
pub fn disable_resource(&self, uri: impl Into<String>) {
    self.inner.disabled_resources.write().unwrap().insert(uri.into());
}
/// Re-enable a previously disabled resource.
///
/// # Panics
///
/// Panics if the disabled-resources lock is poisoned.
pub fn enable_resource(&self, uri: &str) {
    self.inner.disabled_resources.write().unwrap().remove(uri);
}
/// Whether `uri` is absent from the disabled-resources set.
///
/// # Panics
///
/// Panics if the disabled-resources lock is poisoned.
pub fn is_resource_enabled(&self, uri: &str) -> bool {
    let disabled = self.inner.disabled_resources.read().unwrap();
    !disabled.contains(uri)
}
/// Hide a prompt from listings.
///
/// The disabled set is shared by every router cloned from this configuration.
///
/// # Panics
///
/// Panics if the disabled-prompts lock is poisoned.
pub fn disable_prompt(&self, name: impl Into<String>) {
    self.inner.disabled_prompts.write().unwrap().insert(name.into());
}
/// Re-enable a previously disabled prompt.
///
/// # Panics
///
/// Panics if the disabled-prompts lock is poisoned.
pub fn enable_prompt(&self, name: &str) {
    self.inner.disabled_prompts.write().unwrap().remove(name);
}
/// Whether `name` is absent from the disabled-prompts set.
///
/// # Panics
///
/// Panics if the disabled-prompts lock is poisoned.
pub fn is_prompt_enabled(&self, name: &str) -> bool {
    let disabled = self.inner.disabled_prompts.read().unwrap();
    !disabled.contains(name)
}
fn capabilities(&self) -> ServerCapabilities {
let has_resources =
!self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
let has_notifications = self.inner.notification_tx.is_some();
#[cfg(feature = "dynamic-tools")]
let has_dynamic_tools = self.inner.dynamic_tools.is_some();
#[cfg(not(feature = "dynamic-tools"))]
let has_dynamic_tools = false;
#[cfg(feature = "dynamic-tools")]
let has_dynamic_prompts = self.inner.dynamic_prompts.is_some();
#[cfg(not(feature = "dynamic-tools"))]
let has_dynamic_prompts = false;
#[cfg(feature = "dynamic-tools")]
let has_dynamic_resources = self.inner.dynamic_resources.is_some()
|| self.inner.dynamic_resource_templates.is_some();
#[cfg(not(feature = "dynamic-tools"))]
let has_dynamic_resources = false;
ServerCapabilities {
tools: if self.inner.tools.is_empty() && !has_dynamic_tools {
None
} else {
Some(ToolsCapability {
list_changed: has_notifications,
})
},
resources: if has_resources || has_dynamic_resources {
Some(ResourcesCapability {
subscribe: true,
list_changed: has_notifications,
})
} else {
None
},
prompts: if self.inner.prompts.is_empty() && !has_dynamic_prompts {
None
} else {
Some(PromptsCapability {
list_changed: has_notifications,
})
},
logging: if self.inner.notification_tx.is_some() {
Some(LoggingCapability::default())
} else {
None
},
tasks: {
let has_task_support = self
.inner
.tools
.values()
.any(|t| !matches!(t.task_support, TaskSupportMode::Forbidden));
if has_task_support {
Some(TasksCapability {
list: Some(TasksListCapability {}),
cancel: Some(TasksCancelCapability {}),
requests: Some(TasksRequestsCapability {
tools: Some(TasksToolsRequestsCapability {
call: Some(TasksToolsCallCapability {}),
}),
}),
})
} else {
None
}
},
completions: if self.inner.completion_handler.is_some() {
Some(CompletionsCapability::default())
} else {
None
},
experimental: None,
extensions: None,
}
}
async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
let method = request.method_name();
if !self.session.is_request_allowed(method) {
tracing::warn!(
method = %method,
phase = ?self.session.phase(),
"Request rejected: session not initialized"
);
return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
"Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
method
))));
}
match request {
McpRequest::Initialize(params) => {
tracing::info!(
client = %params.client_info.name,
version = %params.client_info.version,
"Client initializing"
);
let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
.contains(¶ms.protocol_version.as_str())
{
params.protocol_version
} else {
crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
};
self.session.mark_initializing();
Ok(McpResponse::Initialize(InitializeResult {
protocol_version,
capabilities: self.capabilities(),
server_info: Implementation {
name: self.inner.server_name.clone(),
version: self.inner.server_version.clone(),
title: self.inner.server_title.clone(),
description: self.inner.server_description.clone(),
icons: self.inner.server_icons.clone(),
website_url: self.inner.server_website_url.clone(),
meta: None,
},
instructions: if let Some(config) = &self.inner.auto_instructions {
Some(self.inner.generate_instructions(config))
} else {
self.inner.instructions.clone()
},
meta: None,
}))
}
McpRequest::ListTools(params) => {
let filter = self.inner.tool_filter.as_ref();
let disabled = self.inner.disabled_tools.read().unwrap().clone();
let is_visible = |t: &Tool| {
!disabled.contains(&t.name)
&& filter
.map(|f| f.is_visible(&self.session, t))
.unwrap_or(true)
};
let mut tools: Vec<ToolDefinition> = self
.inner
.tools
.values()
.filter(|t| is_visible(t))
.map(|t| t.definition())
.collect();
#[cfg(feature = "dynamic-tools")]
if let Some(ref dynamic) = self.inner.dynamic_tools {
let static_names: HashSet<String> =
tools.iter().map(|t| t.name.clone()).collect();
for t in dynamic.list() {
if !static_names.contains(&t.name) && is_visible(&t) {
tools.push(t.definition());
}
}
}
tools.sort_by(|a, b| a.name.cmp(&b.name));
let (tools, next_cursor) =
paginate(tools, params.cursor.as_deref(), self.inner.page_size)?;
Ok(McpResponse::ListTools(ListToolsResult {
tools,
next_cursor,
meta: None,
}))
}
McpRequest::CallTool(params) => {
if self
.inner
.disabled_tools
.read()
.unwrap()
.contains(¶ms.name)
{
tracing::info!(
target: "mcp::tools",
tool = %params.name,
status = "disabled",
"tool call completed"
);
return Err(Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name)));
}
let tool = self.inner.tools.get(¶ms.name).cloned();
#[cfg(feature = "dynamic-tools")]
let tool = tool.or_else(|| {
self.inner
.dynamic_tools
.as_ref()
.and_then(|d| d.get(¶ms.name))
});
let tool = match tool {
Some(t) => t,
None => {
tracing::info!(
target: "mcp::tools",
tool = %params.name,
status = "not_found",
"tool call completed"
);
return Err(Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name)));
}
};
if let Some(filter) = &self.inner.tool_filter
&& !filter.is_visible(&self.session, &tool)
{
tracing::info!(
target: "mcp::tools",
tool = %params.name,
status = "denied",
"tool call completed"
);
return Err(filter.denial_error(¶ms.name));
}
if let Some(task_params) = params.task {
if matches!(tool.task_support, TaskSupportMode::Forbidden) {
return Err(Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Tool '{}' does not support async tasks",
params.name
))));
}
let (task_id, cancellation_token) = self.inner.task_store.create_task(
¶ms.name,
params.arguments.clone(),
task_params.ttl,
);
tracing::info!(task_id = %task_id, tool = %params.name, "Created async task");
let progress_token = params.meta.and_then(|m| m.progress_token);
let ctx = self.create_context(request_id, progress_token);
let task_store = self.inner.task_store.clone();
let tool = tool.clone();
let arguments = params.arguments;
let task_id_clone = task_id.clone();
let tool_name = params.name.clone();
tokio::spawn(async move {
if cancellation_token.is_cancelled() {
tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
return;
}
let start = std::time::Instant::now();
let result = tool.call_with_context(ctx, arguments).await;
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
if cancellation_token.is_cancelled() {
tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
} else if result.is_error {
let error_msg = result.first_text().unwrap_or("Tool execution failed");
task_store.fail_task(&task_id_clone, error_msg);
tracing::info!(
target: "mcp::tools",
tool = %tool_name,
task_id = %task_id_clone,
duration_ms,
status = "error",
error = %error_msg,
"tool call completed"
);
} else {
task_store.complete_task(&task_id_clone, result);
tracing::info!(
target: "mcp::tools",
tool = %tool_name,
task_id = %task_id_clone,
duration_ms,
status = "success",
"tool call completed"
);
}
});
let task = self.inner.task_store.get_task(&task_id).ok_or_else(|| {
Error::JsonRpc(JsonRpcError::internal_error(
"Failed to retrieve created task",
))
})?;
Ok(McpResponse::CreateTask(CreateTaskResult {
task,
meta: None,
}))
} else {
if matches!(tool.task_support, TaskSupportMode::Required) {
return Err(Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Tool '{}' requires async task execution (include 'task' in params)",
params.name
))));
}
let progress_token = params.meta.and_then(|m| m.progress_token);
let ctx = self.create_context(request_id, progress_token);
let start = std::time::Instant::now();
let result = tool.call_with_context(ctx, params.arguments).await;
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
if result.is_error {
tracing::info!(
target: "mcp::tools",
tool = %params.name,
duration_ms,
status = "error",
"tool call completed"
);
} else {
tracing::info!(
target: "mcp::tools",
tool = %params.name,
duration_ms,
status = "success",
"tool call completed"
);
}
Ok(McpResponse::CallTool(result))
}
}
McpRequest::ListResources(params) => {
let disabled = self.inner.disabled_resources.read().unwrap().clone();
let is_visible = |r: &Resource| -> bool {
!disabled.contains(&r.uri)
&& self
.inner
.resource_filter
.as_ref()
.map(|f| f.is_visible(&self.session, r))
.unwrap_or(true)
};
let mut resources: Vec<ResourceDefinition> = self
.inner
.resources
.values()
.filter(|r| is_visible(r))
.map(|r| r.definition())
.collect();
#[cfg(feature = "dynamic-tools")]
if let Some(ref dynamic) = self.inner.dynamic_resources {
let static_uris: HashSet<String> =
resources.iter().map(|r| r.uri.clone()).collect();
for r in dynamic.list() {
if !static_uris.contains(&r.uri) && is_visible(&r) {
resources.push(r.definition());
}
}
}
resources.sort_by(|a, b| a.uri.cmp(&b.uri));
let (resources, next_cursor) =
paginate(resources, params.cursor.as_deref(), self.inner.page_size)?;
Ok(McpResponse::ListResources(ListResourcesResult {
resources,
next_cursor,
meta: None,
}))
}
McpRequest::ListResourceTemplates(params) => {
let mut resource_templates: Vec<ResourceTemplateDefinition> = self
.inner
.resource_templates
.iter()
.map(|t| t.definition())
.collect();
#[cfg(feature = "dynamic-tools")]
if let Some(ref dynamic) = self.inner.dynamic_resource_templates {
let static_patterns: HashSet<String> = resource_templates
.iter()
.map(|t| t.uri_template.clone())
.collect();
for t in dynamic.list() {
if !static_patterns.contains(&t.uri_template) {
resource_templates.push(t.definition());
}
}
}
resource_templates.sort_by(|a, b| a.uri_template.cmp(&b.uri_template));
let (resource_templates, next_cursor) = paginate(
resource_templates,
params.cursor.as_deref(),
self.inner.page_size,
)?;
Ok(McpResponse::ListResourceTemplates(
ListResourceTemplatesResult {
resource_templates,
next_cursor,
meta: None,
},
))
}
McpRequest::ReadResource(params) => {
if self
.inner
.disabled_resources
.read()
.unwrap()
.contains(¶ms.uri)
{
return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
¶ms.uri,
)));
}
if let Some(resource) = self.inner.resources.get(¶ms.uri) {
if let Some(filter) = &self.inner.resource_filter
&& !filter.is_visible(&self.session, resource)
{
return Err(filter.denial_error(¶ms.uri));
}
tracing::debug!(uri = %params.uri, "Reading static resource");
let ctx = self.create_context(request_id, None);
let result = resource.read_with_context(ctx).await;
return Ok(McpResponse::ReadResource(result));
}
#[cfg(feature = "dynamic-tools")]
#[allow(clippy::collapsible_if)]
if let Some(ref dynamic) = self.inner.dynamic_resources {
if let Some(resource) = dynamic.get(¶ms.uri) {
if let Some(filter) = &self.inner.resource_filter
&& !filter.is_visible(&self.session, &resource)
{
return Err(filter.denial_error(¶ms.uri));
}
tracing::debug!(uri = %params.uri, "Reading dynamic resource");
let ctx = self.create_context(request_id, None);
let result = resource.read_with_context(ctx).await;
return Ok(McpResponse::ReadResource(result));
}
}
for template in &self.inner.resource_templates {
if let Some(variables) = template.match_uri(¶ms.uri) {
tracing::debug!(
uri = %params.uri,
template = %template.uri_template,
"Reading resource via template"
);
let result = template.read(¶ms.uri, variables).await?;
return Ok(McpResponse::ReadResource(result));
}
}
#[cfg(feature = "dynamic-tools")]
#[allow(clippy::collapsible_if)]
if let Some(ref dynamic) = self.inner.dynamic_resource_templates {
if let Some((template, variables)) = dynamic.match_uri(¶ms.uri) {
tracing::debug!(
uri = %params.uri,
template = %template.uri_template,
"Reading resource via dynamic template"
);
let result = template.read(¶ms.uri, variables).await?;
return Ok(McpResponse::ReadResource(result));
}
}
Err(Error::JsonRpc(JsonRpcError::resource_not_found(
¶ms.uri,
)))
}
McpRequest::SubscribeResource(params) => {
if !self.inner.resources.contains_key(¶ms.uri) {
return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
¶ms.uri,
)));
}
tracing::debug!(uri = %params.uri, "Subscribing to resource");
self.subscribe(¶ms.uri);
Ok(McpResponse::SubscribeResource(EmptyResult {}))
}
McpRequest::UnsubscribeResource(params) => {
if !self.inner.resources.contains_key(¶ms.uri) {
return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
¶ms.uri,
)));
}
tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
self.unsubscribe(¶ms.uri);
Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
}
McpRequest::ListPrompts(params) => {
let disabled = self.inner.disabled_prompts.read().unwrap().clone();
let is_visible = |p: &Prompt| -> bool {
!disabled.contains(&p.name)
&& self
.inner
.prompt_filter
.as_ref()
.map(|f| f.is_visible(&self.session, p))
.unwrap_or(true)
};
let mut prompts: Vec<PromptDefinition> = self
.inner
.prompts
.values()
.filter(|p| is_visible(p))
.map(|p| p.definition())
.collect();
#[cfg(feature = "dynamic-tools")]
if let Some(ref dynamic) = self.inner.dynamic_prompts {
let static_names: HashSet<String> =
prompts.iter().map(|p| p.name.clone()).collect();
for p in dynamic.list() {
if !static_names.contains(&p.name) && is_visible(&p) {
prompts.push(p.definition());
}
}
}
prompts.sort_by(|a, b| a.name.cmp(&b.name));
let (prompts, next_cursor) =
paginate(prompts, params.cursor.as_deref(), self.inner.page_size)?;
Ok(McpResponse::ListPrompts(ListPromptsResult {
prompts,
next_cursor,
meta: None,
}))
}
McpRequest::GetPrompt(params) => {
if self
.inner
.disabled_prompts
.read()
.unwrap()
.contains(¶ms.name)
{
return Err(Error::JsonRpc(JsonRpcError::method_not_found(&format!(
"Prompt not found: {}",
params.name
))));
}
let prompt = self.inner.prompts.get(¶ms.name).cloned();
#[cfg(feature = "dynamic-tools")]
let prompt = prompt.or_else(|| {
self.inner
.dynamic_prompts
.as_ref()
.and_then(|d| d.get(¶ms.name))
});
let prompt = prompt.ok_or_else(|| {
Error::JsonRpc(JsonRpcError::method_not_found(&format!(
"Prompt not found: {}",
params.name
)))
})?;
if let Some(filter) = &self.inner.prompt_filter
&& !filter.is_visible(&self.session, &prompt)
{
return Err(filter.denial_error(¶ms.name));
}
tracing::debug!(name = %params.name, "Getting prompt");
let ctx = self.create_context(request_id, None);
let result = prompt.get_with_context(ctx, params.arguments).await?;
Ok(McpResponse::GetPrompt(result))
}
McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
McpRequest::ListTasks(params) => {
let tasks = self.inner.task_store.list_tasks(params.status);
let (tasks, next_cursor) =
paginate(tasks, params.cursor.as_deref(), self.inner.page_size)?;
Ok(McpResponse::ListTasks(ListTasksResult {
tasks,
next_cursor,
}))
}
McpRequest::GetTaskInfo(params) => {
let task = self
.inner
.task_store
.get_task(¶ms.task_id)
.ok_or_else(|| {
Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Task not found: {}",
params.task_id
)))
})?;
Ok(McpResponse::GetTaskInfo(task))
}
McpRequest::GetTaskResult(params) => {
let (task_obj, result, error) = self
.inner
.task_store
.wait_for_completion(¶ms.task_id)
.await
.ok_or_else(|| {
Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Task not found: {}",
params.task_id
)))
})?;
let meta = serde_json::json!({
"io.modelcontextprotocol/related-task": task_obj
});
match task_obj.status {
TaskStatus::Cancelled => Err(Error::JsonRpc(JsonRpcError::invalid_params(
format!("Task {} was cancelled", params.task_id),
))),
TaskStatus::Failed => {
let mut call_result = CallToolResult::error(
error.unwrap_or_else(|| "Task failed".to_string()),
);
call_result.meta = Some(meta);
Ok(McpResponse::GetTaskResult(call_result))
}
_ => {
let mut call_result = result.unwrap_or_else(|| CallToolResult::text(""));
call_result.meta = Some(meta);
Ok(McpResponse::GetTaskResult(call_result))
}
}
}
McpRequest::CancelTask(params) => {
let current = self
.inner
.task_store
.get_task(¶ms.task_id)
.ok_or_else(|| {
Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Task not found: {}",
params.task_id
)))
})?;
if current.status.is_terminal() {
return Err(Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Task {} is already in terminal state: {}",
params.task_id, current.status
))));
}
let task_obj = self
.inner
.task_store
.cancel_task(¶ms.task_id, params.reason.as_deref())
.ok_or_else(|| {
Error::JsonRpc(JsonRpcError::invalid_params(format!(
"Task not found: {}",
params.task_id
)))
})?;
Ok(McpResponse::CancelTask(task_obj))
}
McpRequest::SetLoggingLevel(params) => {
tracing::debug!(level = ?params.level, "Client set logging level");
if let Ok(mut level) = self.inner.min_log_level.write() {
*level = params.level;
}
Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
}
McpRequest::Complete(params) => {
tracing::debug!(
reference = ?params.reference,
argument = %params.argument.name,
"Completion request"
);
if let Some(ref handler) = self.inner.completion_handler {
let result = handler(params).await?;
Ok(McpResponse::Complete(result))
} else {
Ok(McpResponse::Complete(CompleteResult::new(vec![])))
}
}
McpRequest::Unknown { ref method, .. } if method == "server/discover" => {
#[cfg(feature = "stateless")]
{
use crate::protocol::SUPPORTED_PROTOCOL_VERSIONS;
let result = crate::stateless::DiscoverResult {
supported_versions: SUPPORTED_PROTOCOL_VERSIONS
.iter()
.map(|v| v.to_string())
.collect(),
capabilities: self.capabilities(),
server_info: Implementation {
name: self.inner.server_name.clone(),
version: self.inner.server_version.clone(),
title: self.inner.server_title.clone(),
description: self.inner.server_description.clone(),
icons: self.inner.server_icons.clone(),
website_url: None,
meta: None,
},
instructions: self.inner.instructions.clone(),
};
Ok(McpResponse::Raw(serde_json::to_value(result).unwrap()))
}
#[cfg(not(feature = "stateless"))]
{
Err(Error::JsonRpc(JsonRpcError::method_not_found(method)))
}
}
McpRequest::Unknown { method, .. } => {
Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
}
_ => Err(Error::JsonRpc(JsonRpcError::method_not_found(
"unknown method",
))),
}
}
pub fn handle_notification(&self, notification: McpNotification) {
    // Client notifications never produce a response; each arm only updates
    // router/session state and/or emits a tracing event.
    match notification {
        McpNotification::Initialized => {
            // Sample the phase before attempting the transition so the log can
            // distinguish a jump straight from `Uninitialized`.
            let phase_before = self.session.phase();
            if !self.session.mark_initialized() {
                tracing::warn!(
                    phase = ?self.session.phase(),
                    "Received initialized notification in unexpected state"
                );
            } else if phase_before == crate::session::SessionPhase::Uninitialized {
                tracing::info!("Session initialized from uninitialized state (race resolved)");
            } else {
                tracing::info!("Session initialized, entering operation phase");
            }
        }
        McpNotification::Cancelled(params) => {
            // Without a request id there is nothing to cancel; just record it.
            let Some(ref request_id) = params.request_id else {
                tracing::debug!(
                    reason = ?params.reason,
                    "Cancellation notification received without request_id"
                );
                return;
            };
            // `cancel_request` reports whether the id matched a tracked request.
            if self.cancel_request(request_id) {
                tracing::info!(
                    request_id = ?request_id,
                    reason = ?params.reason,
                    "Request cancelled"
                );
            } else {
                tracing::debug!(
                    request_id = ?request_id,
                    reason = ?params.reason,
                    "Cancellation requested for unknown request"
                );
            }
        }
        McpNotification::Progress(params) => tracing::trace!(
            token = ?params.progress_token,
            progress = params.progress,
            total = ?params.total,
            "Progress notification"
        ),
        McpNotification::RootsListChanged => tracing::info!("Client roots list changed"),
        McpNotification::Unknown { method, .. } => {
            tracing::debug!(method = %method, "Unknown notification received");
        }
        // Catch-all for variants this router does not handle explicitly.
        _ => tracing::debug!("Unrecognized notification variant received"),
    }
}
}
impl Default for McpRouter {
fn default() -> Self {
Self::new()
}
}
pub use crate::context::Extensions;
/// Shared, read-only lookup table from tool name to its [`ToolAnnotations`].
///
/// The table lives behind an `Arc`, so cloning the map only bumps a refcount.
#[derive(Debug, Clone)]
pub struct ToolAnnotationsMap {
    map: Arc<HashMap<String, ToolAnnotations>>,
}

impl ToolAnnotationsMap {
    /// Looks up the annotations registered for `tool_name`.
    pub fn get(&self, tool_name: &str) -> Option<&ToolAnnotations> {
        self.map.get(tool_name)
    }

    /// Whether `tool_name` is hinted as read-only.
    ///
    /// Tools without registered annotations are reported as NOT read-only.
    pub fn is_read_only(&self, tool_name: &str) -> bool {
        match self.get(tool_name) {
            Some(annotations) => annotations.read_only_hint,
            None => false,
        }
    }

    /// Whether `tool_name` is hinted as destructive.
    ///
    /// Tools without registered annotations are reported as destructive, the
    /// conservative answer.
    pub fn is_destructive(&self, tool_name: &str) -> bool {
        match self.get(tool_name) {
            Some(annotations) => annotations.destructive_hint,
            None => true,
        }
    }

    /// Whether `tool_name` is hinted as idempotent.
    ///
    /// Tools without registered annotations are reported as NOT idempotent.
    pub fn is_idempotent(&self, tool_name: &str) -> bool {
        match self.get(tool_name) {
            Some(annotations) => annotations.idempotent_hint,
            None => false,
        }
    }
}
/// A request as it is handed to the router's `Service` implementation.
///
/// Bundles the request id, the decoded [`McpRequest`], and an [`Extensions`]
/// value. The `with_*` / `clone_with_*` helpers keep the extensions intact
/// while swapping the payload.
#[derive(Debug, Clone)]
pub struct RouterRequest {
    /// Request id; `Service::call` echoes it into the [`RouterResponse`].
    pub id: RequestId,
    /// The decoded MCP request.
    pub inner: McpRequest,
    /// Extension values attached to this request (empty after [`RouterRequest::new`]).
    pub extensions: Extensions,
}

impl RouterRequest {
    /// Builds a request with empty extensions.
    pub fn new(id: RequestId, inner: McpRequest) -> Self {
        let extensions = Extensions::new();
        Self {
            id,
            inner,
            extensions,
        }
    }

    /// Replaces the payload, keeping the id and extensions.
    pub fn with_inner(self, inner: McpRequest) -> Self {
        Self { inner, ..self }
    }

    /// Replaces both id and payload, keeping the extensions.
    pub fn with_id_and_inner(self, id: RequestId, inner: McpRequest) -> Self {
        Self { id, inner, ..self }
    }

    /// Returns a copy with a new payload; `self` is left untouched, and its id
    /// and extensions are cloned into the copy.
    pub fn clone_with_inner(&self, inner: McpRequest) -> Self {
        let id = self.id.clone();
        let extensions = self.extensions.clone();
        Self {
            id,
            inner,
            extensions,
        }
    }
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouterResponse {
pub id: RequestId,
pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
impl RouterResponse {
pub fn is_error(&self) -> bool {
self.inner.is_err()
}
pub fn into_jsonrpc(self) -> JsonRpcResponse {
match self.inner {
Ok(response) => match serde_json::to_value(response) {
Ok(result) => JsonRpcResponse::result(self.id, result),
Err(e) => {
tracing::error!(error = %e, "Failed to serialize response");
JsonRpcResponse::error(
Some(self.id),
JsonRpcError::internal_error(format!("Serialization error: {}", e)),
)
}
},
Err(error) => JsonRpcResponse::error(Some(self.id), error),
}
}
}
impl Service<RouterRequest> for McpRouter {
type Response = RouterResponse;
type Error = std::convert::Infallible; type Future =
Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: RouterRequest) -> Self::Future {
let router = self.clone();
let request_id = req.id.clone();
Box::pin(async move {
let result = router.handle(req.id, req.inner).await;
router.complete_request(&request_id);
Ok(RouterResponse {
id: request_id,
inner: result.map_err(|e| match e {
Error::JsonRpc(err) => err,
Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
e => JsonRpcError::internal_error(e.to_string()),
}),
})
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::extract::{Context, Json};
use crate::jsonrpc::JsonRpcService;
use crate::tool::ToolBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
use tower::ServiceExt;
#[derive(Debug, Deserialize, JsonSchema)]
struct AddInput {
a: i64,
b: i64,
}
async fn init_router(router: &mut McpRouter) {
let init_req = RouterRequest {
id: RequestId::Number(0),
inner: McpRequest::Initialize(InitializeParams {
protocol_version: "2025-11-25".to_string(),
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
tasks: None,
experimental: None,
extensions: None,
},
client_info: Implementation {
name: "test".to_string(),
version: "1.0".to_string(),
..Default::default()
},
meta: None,
}),
extensions: Extensions::new(),
};
let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
router.handle_notification(McpNotification::Initialized);
}
#[tokio::test]
async fn test_router_list_tools() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let mut router = McpRouter::new().tool(add_tool);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ListTools(ListToolsParams::default()),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ListTools(result)) => {
assert_eq!(result.tools.len(), 1);
assert_eq!(result.tools[0].name, "add");
}
_ => panic!("Expected ListTools response"),
}
}
#[tokio::test]
async fn test_router_call_tool() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let mut router = McpRouter::new().tool(add_tool);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::CallTool(CallToolParams {
name: "add".to_string(),
arguments: serde_json::json!({"a": 2, "b": 3}),
meta: None,
task: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::CallTool(result)) => {
assert!(!result.is_error);
match &result.content[0] {
Content::Text { text, .. } => assert_eq!(text, "5"),
_ => panic!("Expected text content"),
}
}
_ => panic!("Expected CallTool response"),
}
}
async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}));
let _ = service.call_single(init_req).await.unwrap();
router.handle_notification(McpNotification::Initialized);
}
#[tokio::test]
async fn test_jsonrpc_service() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let router = McpRouter::new().tool(add_tool);
let mut service = JsonRpcService::new(router.clone());
init_jsonrpc_service(&mut service, &router).await;
let req = JsonRpcRequest::new(1, "tools/list");
let resp = service.call_single(req).await.unwrap();
match resp {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(1));
let tools = r.result.get("tools").unwrap().as_array().unwrap();
assert_eq!(tools.len(), 1);
}
JsonRpcResponse::Error(_) => panic!("Expected success response"),
_ => panic!("unexpected response variant"),
}
}
#[tokio::test]
async fn test_batch_request() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let router = McpRouter::new().tool(add_tool);
let mut service = JsonRpcService::new(router.clone());
init_jsonrpc_service(&mut service, &router).await;
let requests = vec![
JsonRpcRequest::new(1, "tools/list"),
JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
"name": "add",
"arguments": {"a": 10, "b": 20}
})),
JsonRpcRequest::new(3, "ping"),
];
let responses = service.call_batch(requests).await.unwrap();
assert_eq!(responses.len(), 3);
match &responses[0] {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(1));
let tools = r.result.get("tools").unwrap().as_array().unwrap();
assert_eq!(tools.len(), 1);
}
JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
_ => panic!("unexpected response variant"),
}
match &responses[1] {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(2));
let content = r.result.get("content").unwrap().as_array().unwrap();
let text = content[0].get("text").unwrap().as_str().unwrap();
assert_eq!(text, "30");
}
JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
_ => panic!("unexpected response variant"),
}
match &responses[2] {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(3));
}
JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
_ => panic!("unexpected response variant"),
}
}
#[tokio::test]
async fn test_empty_batch_error() {
let router = McpRouter::new();
let mut service = JsonRpcService::new(router);
let result = service.call_batch(vec![]).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_progress_token_extraction() {
use crate::context::{ServerNotification, notification_channel};
use crate::protocol::ProgressToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
let progress_reported = Arc::new(AtomicBool::new(false));
let progress_ref = progress_reported.clone();
let tool = ToolBuilder::new("progress_tool")
.description("Tool that reports progress")
.extractor_handler((), move |ctx: Context, Json(_input): Json<AddInput>| {
let reported = progress_ref.clone();
async move {
ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
.await;
reported.store(true, Ordering::SeqCst);
Ok(CallToolResult::text("done"))
}
})
.build();
let (tx, mut rx) = notification_channel(10);
let router = McpRouter::new().with_notification_sender(tx).tool(tool);
let mut service = JsonRpcService::new(router.clone());
init_jsonrpc_service(&mut service, &router).await;
let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
"name": "progress_tool",
"arguments": {"a": 1, "b": 2},
"_meta": {
"progressToken": "test-token-123"
}
}));
let resp = service.call_single(req).await.unwrap();
match resp {
JsonRpcResponse::Result(_) => {}
JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
_ => panic!("unexpected response variant"),
}
assert!(progress_reported.load(Ordering::SeqCst));
let notification = rx.try_recv().expect("Expected progress notification");
match notification {
ServerNotification::Progress(params) => {
assert_eq!(
params.progress_token,
ProgressToken::String("test-token-123".to_string())
);
assert_eq!(params.progress, 50.0);
assert_eq!(params.total, Some(100.0));
assert_eq!(params.message.as_deref(), Some("Halfway"));
}
_ => panic!("Expected Progress notification"),
}
}
#[tokio::test]
async fn test_tool_call_without_progress_token() {
use crate::context::notification_channel;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
let progress_attempted = Arc::new(AtomicBool::new(false));
let progress_ref = progress_attempted.clone();
let tool = ToolBuilder::new("no_token_tool")
.description("Tool that tries to report progress without token")
.extractor_handler((), move |ctx: Context, Json(_input): Json<AddInput>| {
let attempted = progress_ref.clone();
async move {
ctx.report_progress(50.0, Some(100.0), None).await;
attempted.store(true, Ordering::SeqCst);
Ok(CallToolResult::text("done"))
}
})
.build();
let (tx, mut rx) = notification_channel(10);
let router = McpRouter::new().with_notification_sender(tx).tool(tool);
let mut service = JsonRpcService::new(router.clone());
init_jsonrpc_service(&mut service, &router).await;
let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
"name": "no_token_tool",
"arguments": {"a": 1, "b": 2}
}));
let resp = service.call_single(req).await.unwrap();
assert!(matches!(resp, JsonRpcResponse::Result(_)));
assert!(progress_attempted.load(Ordering::SeqCst));
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn test_batch_errors_returned_not_dropped() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let router = McpRouter::new().tool(add_tool);
let mut service = JsonRpcService::new(router.clone());
init_jsonrpc_service(&mut service, &router).await;
let requests = vec![
JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
"name": "add",
"arguments": {"a": 10, "b": 20}
})),
JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
"name": "nonexistent_tool",
"arguments": {}
})),
JsonRpcRequest::new(3, "ping"),
];
let responses = service.call_batch(requests).await.unwrap();
assert_eq!(responses.len(), 3);
match &responses[0] {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(1));
}
JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
_ => panic!("unexpected response variant"),
}
match &responses[1] {
JsonRpcResponse::Error(e) => {
assert_eq!(e.id, Some(RequestId::Number(2)));
assert!(e.error.message.contains("not found") || e.error.code == -32601);
}
JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
_ => panic!("unexpected response variant"),
}
match &responses[2] {
JsonRpcResponse::Result(r) => {
assert_eq!(r.id, RequestId::Number(3));
}
JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
_ => panic!("unexpected response variant"),
}
}
#[tokio::test]
async fn test_list_resource_templates() {
use crate::resource::ResourceTemplateBuilder;
use std::collections::HashMap;
let template = ResourceTemplateBuilder::new("file:///{path}")
.name("Project Files")
.description("Access project files")
.handler(|uri: String, _vars: HashMap<String, String>| async move {
Ok(ReadResourceResult {
contents: vec![ResourceContent {
uri,
mime_type: None,
text: None,
blob: None,
meta: None,
}],
meta: None,
})
});
let mut router = McpRouter::new().resource_template(template);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ListResourceTemplates(result)) => {
assert_eq!(result.resource_templates.len(), 1);
assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
assert_eq!(result.resource_templates[0].name, "Project Files");
}
_ => panic!("Expected ListResourceTemplates response"),
}
}
#[tokio::test]
async fn test_read_resource_via_template() {
use crate::resource::ResourceTemplateBuilder;
use std::collections::HashMap;
let template = ResourceTemplateBuilder::new("db://users/{id}")
.name("User Records")
.handler(|uri: String, vars: HashMap<String, String>| async move {
let id = vars.get("id").unwrap().clone();
Ok(ReadResourceResult {
contents: vec![ResourceContent {
uri,
mime_type: Some("application/json".to_string()),
text: Some(format!(r#"{{"id": "{}"}}"#, id)),
blob: None,
meta: None,
}],
meta: None,
})
});
let mut router = McpRouter::new().resource_template(template);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ReadResource(ReadResourceParams {
uri: "db://users/123".to_string(),
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ReadResource(result)) => {
assert_eq!(result.contents.len(), 1);
assert_eq!(result.contents[0].uri, "db://users/123");
assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
}
_ => panic!("Expected ReadResource response"),
}
}
#[tokio::test]
async fn test_static_resource_takes_precedence_over_template() {
use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
use std::collections::HashMap;
let template = ResourceTemplateBuilder::new("file:///{path}")
.name("Files Template")
.handler(|uri: String, _vars: HashMap<String, String>| async move {
Ok(ReadResourceResult {
contents: vec![ResourceContent {
uri,
mime_type: None,
text: Some("from template".to_string()),
blob: None,
meta: None,
}],
meta: None,
})
});
let static_resource = ResourceBuilder::new("file:///README.md")
.name("README")
.text("from static resource");
let mut router = McpRouter::new()
.resource_template(template)
.resource(static_resource);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ReadResource(ReadResourceParams {
uri: "file:///README.md".to_string(),
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ReadResource(result)) => {
assert_eq!(
result.contents[0].text.as_deref(),
Some("from static resource")
);
}
_ => panic!("Expected ReadResource response"),
}
}
#[tokio::test]
async fn test_resource_not_found_when_no_match() {
use crate::resource::ResourceTemplateBuilder;
use std::collections::HashMap;
let template = ResourceTemplateBuilder::new("db://users/{id}")
.name("Users")
.handler(|uri: String, _vars: HashMap<String, String>| async move {
Ok(ReadResourceResult {
contents: vec![ResourceContent {
uri,
mime_type: None,
text: None,
blob: None,
meta: None,
}],
meta: None,
})
});
let mut router = McpRouter::new().resource_template(template);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ReadResource(ReadResourceParams {
uri: "db://posts/123".to_string(),
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Err(err) => {
assert!(err.message.contains("not found"));
}
Ok(_) => panic!("Expected error for non-matching URI"),
}
}
#[tokio::test]
async fn test_capabilities_include_resources_with_only_templates() {
use crate::resource::ResourceTemplateBuilder;
use std::collections::HashMap;
let template = ResourceTemplateBuilder::new("file:///{path}")
.name("Files")
.handler(|uri: String, _vars: HashMap<String, String>| async move {
Ok(ReadResourceResult {
contents: vec![ResourceContent {
uri,
mime_type: None,
text: None,
blob: None,
meta: None,
}],
meta: None,
})
});
let mut router = McpRouter::new().resource_template(template);
let init_req = RouterRequest {
id: RequestId::Number(0),
inner: McpRequest::Initialize(InitializeParams {
protocol_version: "2025-11-25".to_string(),
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
tasks: None,
experimental: None,
extensions: None,
},
client_info: Implementation {
name: "test".to_string(),
version: "1.0".to_string(),
..Default::default()
},
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
match resp.inner {
Ok(McpResponse::Initialize(result)) => {
assert!(result.capabilities.resources.is_some());
}
_ => panic!("Expected Initialize response"),
}
}
#[tokio::test]
async fn test_log_sends_notification() {
use crate::context::notification_channel;
let (tx, mut rx) = notification_channel(10);
let router = McpRouter::new().with_notification_sender(tx);
let sent = router.log_info("Test message");
assert!(sent);
let notification = rx.try_recv().unwrap();
match notification {
ServerNotification::LogMessage(params) => {
assert_eq!(params.level, LogLevel::Info);
let data = params.data;
assert_eq!(
data.get("message").unwrap().as_str().unwrap(),
"Test message"
);
}
_ => panic!("Expected LogMessage notification"),
}
}
#[tokio::test]
async fn test_log_with_custom_params() {
use crate::context::notification_channel;
let (tx, mut rx) = notification_channel(10);
let router = McpRouter::new().with_notification_sender(tx);
let params = LoggingMessageParams::new(
LogLevel::Error,
serde_json::json!({
"error": "Connection failed",
"host": "localhost"
}),
)
.with_logger("database");
let sent = router.log(params);
assert!(sent);
let notification = rx.try_recv().unwrap();
match notification {
ServerNotification::LogMessage(params) => {
assert_eq!(params.level, LogLevel::Error);
assert_eq!(params.logger.as_deref(), Some("database"));
let data = params.data;
assert_eq!(
data.get("error").unwrap().as_str().unwrap(),
"Connection failed"
);
}
_ => panic!("Expected LogMessage notification"),
}
}
#[tokio::test]
async fn test_log_without_channel_returns_false() {
let router = McpRouter::new();
assert!(!router.log_info("Test"));
assert!(!router.log_warning("Test"));
assert!(!router.log_error("Test"));
assert!(!router.log_debug("Test"));
}
#[tokio::test]
async fn test_logging_capability_with_channel() {
use crate::context::notification_channel;
let (tx, _rx) = notification_channel(10);
let mut router = McpRouter::new().with_notification_sender(tx);
let init_req = RouterRequest {
id: RequestId::Number(0),
inner: McpRequest::Initialize(InitializeParams {
protocol_version: "2025-11-25".to_string(),
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
tasks: None,
experimental: None,
extensions: None,
},
client_info: Implementation {
name: "test".to_string(),
version: "1.0".to_string(),
..Default::default()
},
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
match resp.inner {
Ok(McpResponse::Initialize(result)) => {
assert!(result.capabilities.logging.is_some());
}
_ => panic!("Expected Initialize response"),
}
}
#[tokio::test]
async fn test_no_logging_capability_without_channel() {
let mut router = McpRouter::new();
let init_req = RouterRequest {
id: RequestId::Number(0),
inner: McpRequest::Initialize(InitializeParams {
protocol_version: "2025-11-25".to_string(),
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
tasks: None,
experimental: None,
extensions: None,
},
client_info: Implementation {
name: "test".to_string(),
version: "1.0".to_string(),
..Default::default()
},
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
match resp.inner {
Ok(McpResponse::Initialize(result)) => {
assert!(result.capabilities.logging.is_none());
}
_ => panic!("Expected Initialize response"),
}
}
#[tokio::test]
async fn test_create_task_via_call_tool() {
let add_tool = ToolBuilder::new("add")
.description("Add two numbers")
.task_support(TaskSupportMode::Optional)
.handler(|input: AddInput| async move {
Ok(CallToolResult::text(format!("{}", input.a + input.b)))
})
.build();
let mut router = McpRouter::new().tool(add_tool);
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::CallTool(CallToolParams {
name: "add".to_string(),
arguments: serde_json::json!({"a": 5, "b": 10}),
meta: None,
task: Some(TaskRequestParams { ttl: None }),
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::CreateTask(result)) => {
assert!(result.task.task_id.starts_with("task-"));
assert_eq!(result.task.status, TaskStatus::Working);
}
_ => panic!("Expected CreateTask response"),
}
}
#[tokio::test]
async fn test_list_tasks_empty() {
let mut router = McpRouter::new();
init_router(&mut router).await;
let req = RouterRequest {
id: RequestId::Number(1),
inner: McpRequest::ListTasks(ListTasksParams::default()),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ListTasks(result)) => {
assert!(result.tasks.is_empty());
}
_ => panic!("Expected ListTasks response"),
}
}
#[tokio::test]
async fn test_task_lifecycle_complete() {
    // `Optional` task support lets the caller opt into async execution.
    let adder = ToolBuilder::new("add")
        .description("Add two numbers")
        .task_support(TaskSupportMode::Optional)
        .handler(|input: AddInput| async move {
            Ok(CallToolResult::text(format!("{}", input.a + input.b)))
        })
        .build();
    let mut router = McpRouter::new().tool(adder);
    init_router(&mut router).await;

    // Step 1: start the call as a task and remember its id.
    let call_params = CallToolParams {
        name: "add".to_string(),
        arguments: serde_json::json!({"a": 7, "b": 8}),
        meta: None,
        task: Some(TaskRequestParams { ttl: None }),
    };
    let create = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(call_params),
        extensions: Extensions::new(),
    };
    let created = router.ready().await.unwrap().call(create).await.unwrap();
    let Ok(McpResponse::CreateTask(created_task)) = created.inner else {
        panic!("Expected CreateTask response");
    };
    let task_id = created_task.task.task_id;

    // Step 2: give the background task time to finish.
    // NOTE(review): a fixed sleep may be flaky on a heavily loaded machine;
    // polling until the task is terminal would be more robust.
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

    // Step 3: fetch the stored result and check the computed sum.
    let fetch = RouterRequest {
        id: RequestId::Number(2),
        inner: McpRequest::GetTaskResult(GetTaskResultParams {
            task_id,
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let fetched = router.ready().await.unwrap().call(fetch).await.unwrap();
    let Ok(McpResponse::GetTaskResult(outcome)) = fetched.inner else {
        panic!("Expected GetTaskResult response");
    };
    assert!(outcome.meta.is_some());
    let Content::Text { text, .. } = &outcome.content[0] else {
        panic!("Expected text content");
    };
    assert_eq!(text, "15");
}
#[tokio::test]
async fn test_task_cancellation() {
    // The handler sleeps far longer than the test runs, so the task is
    // guaranteed to still be in flight when we cancel it.
    let sleeper = ToolBuilder::new("slow")
        .description("Slow tool")
        .task_support(TaskSupportMode::Optional)
        .handler(|_input: serde_json::Value| async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
            Ok(CallToolResult::text("done"))
        })
        .build();
    let mut router = McpRouter::new().tool(sleeper);
    init_router(&mut router).await;

    let start = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(CallToolParams {
            name: "slow".to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: Some(TaskRequestParams { ttl: None }),
        }),
        extensions: Extensions::new(),
    };
    let started = router.ready().await.unwrap().call(start).await.unwrap();
    let Ok(McpResponse::CreateTask(created)) = started.inner else {
        panic!("Expected CreateTask response");
    };

    let cancel = RouterRequest {
        id: RequestId::Number(2),
        inner: McpRequest::CancelTask(CancelTaskParams {
            task_id: created.task.task_id,
            reason: Some("Test cancellation".to_string()),
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let cancelled = router.ready().await.unwrap().call(cancel).await.unwrap();
    let Ok(McpResponse::CancelTask(task_obj)) = cancelled.inner else {
        panic!("Expected CancelTask response");
    };
    assert_eq!(task_obj.status, TaskStatus::Cancelled);
}
#[tokio::test]
async fn test_get_task_info() {
    let adder = ToolBuilder::new("add")
        .description("Add two numbers")
        .task_support(TaskSupportMode::Optional)
        .handler(|input: AddInput| async move {
            Ok(CallToolResult::text(format!("{}", input.a + input.b)))
        })
        .build();
    let mut router = McpRouter::new().tool(adder);
    init_router(&mut router).await;

    // Request an explicit TTL so we can verify it is echoed back by GetTaskInfo.
    let start = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(CallToolParams {
            name: "add".to_string(),
            arguments: serde_json::json!({"a": 1, "b": 2}),
            meta: None,
            task: Some(TaskRequestParams { ttl: Some(600_000) }),
        }),
        extensions: Extensions::new(),
    };
    let started = router.ready().await.unwrap().call(start).await.unwrap();
    let Ok(McpResponse::CreateTask(created)) = started.inner else {
        panic!("Expected CreateTask response");
    };
    let task_id = created.task.task_id;

    let lookup = RouterRequest {
        id: RequestId::Number(2),
        inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
            task_id: task_id.clone(),
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let looked_up = router.ready().await.unwrap().call(lookup).await.unwrap();
    let Ok(McpResponse::GetTaskInfo(info)) = looked_up.inner else {
        panic!("Expected GetTaskInfo response");
    };
    assert_eq!(info.task_id, task_id);
    // Loose timestamp sanity check: a date/time string containing a 'T' separator.
    assert!(info.created_at.contains('T'));
    assert_eq!(info.ttl, Some(600_000));
}
#[tokio::test]
async fn test_task_forbidden_tool_rejects_task_params() {
    // No `.task_support(...)` call: this tool is built with the default mode,
    // and asking for a task must be rejected.
    let sync_tool = ToolBuilder::new("sync_only")
        .description("Sync only tool")
        .handler(|_input: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().tool(sync_tool);
    init_router(&mut router).await;

    let params = CallToolParams {
        name: "sync_only".to_string(),
        arguments: serde_json::json!({}),
        meta: None,
        task: Some(TaskRequestParams { ttl: None }),
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(params),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected error response");
    };
    assert!(err.message.contains("does not support async tasks"));
}
#[tokio::test]
async fn test_get_nonexistent_task() {
    let mut router = McpRouter::new();
    init_router(&mut router).await;

    // No task was ever created, so any id lookup must fail.
    let lookup = GetTaskInfoParams {
        task_id: "task-999".to_string(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::GetTaskInfo(lookup),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected error response");
    };
    assert!(err.message.contains("not found"));
}
#[tokio::test]
async fn test_subscribe_to_resource() {
    use crate::resource::ResourceBuilder;

    const URI: &str = "file:///test.txt";
    let file = ResourceBuilder::new(URI).name("Test File").text("Hello");
    let mut router = McpRouter::new().resource(file);
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::SubscribeResource(SubscribeResourceParams {
            uri: URI.to_string(),
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::SubscribeResource(_)) = response.inner else {
        panic!("Expected SubscribeResource response");
    };
    // The subscription must be visible through the router's public API.
    assert!(router.is_subscribed(URI));
}
#[tokio::test]
async fn test_unsubscribe_from_resource() {
    use crate::resource::ResourceBuilder;

    const URI: &str = "file:///test.txt";
    let file = ResourceBuilder::new(URI).name("Test File").text("Hello");
    let mut router = McpRouter::new().resource(file);
    init_router(&mut router).await;

    // Subscribe first so there is something to remove.
    let subscribe = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::SubscribeResource(SubscribeResourceParams {
            uri: URI.to_string(),
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let _ = router.ready().await.unwrap().call(subscribe).await.unwrap();
    assert!(router.is_subscribed(URI));

    let unsubscribe = RouterRequest {
        id: RequestId::Number(2),
        inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
            uri: URI.to_string(),
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(unsubscribe).await.unwrap();

    let Ok(McpResponse::UnsubscribeResource(_)) = response.inner else {
        panic!("Expected UnsubscribeResource response");
    };
    assert!(!router.is_subscribed(URI));
}
#[tokio::test]
async fn test_subscribe_nonexistent_resource() {
    // No resources registered: subscribing to any URI must be rejected.
    let mut router = McpRouter::new();
    init_router(&mut router).await;

    let params = SubscribeResourceParams {
        uri: "file:///nonexistent.txt".to_string(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::SubscribeResource(params),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected error response");
    };
    assert!(err.message.contains("not found"));
}
#[tokio::test]
async fn test_notify_resource_updated() {
    use crate::context::notification_channel;
    use crate::resource::ResourceBuilder;

    let (sender, mut receiver) = notification_channel(10);
    let file = ResourceBuilder::new("file:///test.txt")
        .name("Test File")
        .text("Hello");
    let router = McpRouter::new()
        .resource(file)
        .with_notification_sender(sender);

    // Subscribing makes the URI eligible for update notifications.
    router.subscribe("file:///test.txt");
    assert!(router.notify_resource_updated("file:///test.txt"));

    let ServerNotification::ResourceUpdated { uri } = receiver.try_recv().unwrap() else {
        panic!("Expected ResourceUpdated notification");
    };
    assert_eq!(uri, "file:///test.txt");
}
#[tokio::test]
async fn test_notify_resource_updated_not_subscribed() {
    use crate::context::notification_channel;
    use crate::resource::ResourceBuilder;

    let (sender, mut receiver) = notification_channel(10);
    let file = ResourceBuilder::new("file:///test.txt")
        .name("Test File")
        .text("Hello");
    let router = McpRouter::new()
        .resource(file)
        .with_notification_sender(sender);

    // Without a prior subscribe() the router reports "not sent" and the
    // channel stays empty.
    assert!(!router.notify_resource_updated("file:///test.txt"));
    assert!(receiver.try_recv().is_err());
}
#[tokio::test]
async fn test_notify_resources_list_changed() {
    use crate::context::notification_channel;

    let (sender, mut receiver) = notification_channel(10);
    let router = McpRouter::new().with_notification_sender(sender);

    // List-changed notifications need no subscription.
    assert!(router.notify_resources_list_changed());

    let received = receiver.try_recv().unwrap();
    if !matches!(received, ServerNotification::ResourcesListChanged) {
        panic!("Expected ResourcesListChanged notification");
    }
}
#[tokio::test]
async fn test_subscribed_uris() {
    use crate::resource::ResourceBuilder;

    let first = ResourceBuilder::new("file:///a.txt").name("A").text("A");
    let second = ResourceBuilder::new("file:///b.txt").name("B").text("B");
    let router = McpRouter::new().resource(first).resource(second);

    for uri in ["file:///a.txt", "file:///b.txt"] {
        router.subscribe(uri);
    }

    // Every subscribed URI is reported, and nothing else.
    let subscribed = router.subscribed_uris();
    assert_eq!(subscribed.len(), 2);
    for uri in ["file:///a.txt", "file:///b.txt"] {
        assert!(subscribed.contains(&uri.to_string()));
    }
}
#[tokio::test]
async fn test_subscription_capability_advertised() {
    use crate::resource::ResourceBuilder;

    let file = ResourceBuilder::new("file:///test.txt")
        .name("Test")
        .text("Hello");
    let mut router = McpRouter::new().resource(file);

    // A client that declares no optional capabilities at all.
    let client_caps = ClientCapabilities {
        roots: None,
        sampling: None,
        elicitation: None,
        tasks: None,
        experimental: None,
        extensions: None,
    };
    let client = Implementation {
        name: "test".to_string(),
        version: "1.0".to_string(),
        ..Default::default()
    };
    let handshake = RouterRequest {
        id: RequestId::Number(0),
        inner: McpRequest::Initialize(InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: client_caps,
            client_info: client,
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(handshake).await.unwrap();

    let Ok(McpResponse::Initialize(init)) = response.inner else {
        panic!("Expected Initialize response");
    };
    // Registering a resource must surface resources.subscribe = true.
    assert!(init.capabilities.resources.unwrap().subscribe);
}
#[tokio::test]
async fn test_completion_handler() {
    // Completion handler that offers every candidate starting with the
    // user's partial input.
    let router = McpRouter::new()
        .server_info("test", "1.0")
        .completion_handler(|params: CompleteParams| async move {
            // Borrow the partial value typed so far. The original line had a
            // mis-encoded `&params` token and did not compile.
            let prefix = &params.argument.value;
            let suggestions: Vec<String> = ["alpha", "beta", "gamma"]
                .iter()
                .copied()
                .filter(|s| s.starts_with(prefix))
                .map(String::from)
                .collect();
            Ok(CompleteResult::new(suggestions))
        });

    let init_req = RouterRequest {
        id: RequestId::Number(0),
        inner: McpRequest::Initialize(InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test".to_string(),
                version: "1.0".to_string(),
                ..Default::default()
            },
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let resp = router
        .clone()
        .ready()
        .await
        .unwrap()
        .call(init_req)
        .await
        .unwrap();
    match resp.inner {
        Ok(McpResponse::Initialize(result)) => {
            // Registering a completion handler must advertise the capability.
            assert!(result.capabilities.completions.is_some());
        }
        _ => panic!("Expected Initialize response"),
    }
    router.handle_notification(McpNotification::Initialized);

    let complete_req = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::Complete(CompleteParams {
            reference: CompletionReference::prompt("test-prompt"),
            argument: CompletionArgument::new("query", "al"),
            context: None,
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let resp = router
        .clone()
        .ready()
        .await
        .unwrap()
        .call(complete_req)
        .await
        .unwrap();
    match resp.inner {
        Ok(McpResponse::Complete(result)) => {
            // Only "alpha" starts with "al".
            assert_eq!(result.completion.values, vec!["alpha"]);
        }
        _ => panic!("Expected Complete response"),
    }
}
#[tokio::test]
async fn test_completion_without_handler_returns_empty() {
    // No completion handler is registered on this router.
    let router = McpRouter::new().server_info("test", "1.0");

    let handshake = RouterRequest {
        id: RequestId::Number(0),
        inner: McpRequest::Initialize(InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test".to_string(),
                version: "1.0".to_string(),
                ..Default::default()
            },
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let mut service = router.clone();
    let init_response = service.ready().await.unwrap().call(handshake).await.unwrap();
    let Ok(McpResponse::Initialize(init)) = init_response.inner else {
        panic!("Expected Initialize response");
    };
    // Without a handler the capability must not be advertised.
    assert!(init.capabilities.completions.is_none());

    router.handle_notification(McpNotification::Initialized);

    let ask = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::Complete(CompleteParams {
            reference: CompletionReference::prompt("test-prompt"),
            argument: CompletionArgument::new("query", "al"),
            context: None,
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let mut service = router.clone();
    let answer = service.ready().await.unwrap().call(ask).await.unwrap();
    let Ok(McpResponse::Complete(completion)) = answer.inner else {
        panic!("Expected Complete response");
    };
    // The request still succeeds, but with no suggestions.
    assert!(completion.completion.values.is_empty());
}
#[tokio::test]
async fn test_tool_filter_list() {
    use crate::filter::CapabilityFilter;
    use crate::tool::Tool;

    let public = ToolBuilder::new("public")
        .description("Public tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
        .build();
    let admin = ToolBuilder::new("admin")
        .description("Admin tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
        .build();

    // The filter hides every tool whose name is exactly "admin".
    let hide_admin = CapabilityFilter::new(|_, candidate: &Tool| candidate.name != "admin");
    let mut router = McpRouter::new()
        .tool(public)
        .tool(admin)
        .tool_filter(hide_admin);
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 1);
    assert_eq!(listing.tools[0].name, "public");
}
#[tokio::test]
async fn test_tool_filter_call_denied() {
    use crate::filter::CapabilityFilter;
    use crate::tool::Tool;

    let admin = ToolBuilder::new("admin")
        .description("Admin tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
        .build();
    // A filter that rejects every tool.
    let deny_all = CapabilityFilter::new(|_, _: &Tool| false);
    let mut router = McpRouter::new().tool(admin).tool_filter(deny_all);
    init_router(&mut router).await;

    let call = CallToolParams {
        name: "admin".to_string(),
        arguments: serde_json::json!({"a": 1, "b": 2}),
        meta: None,
        task: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(call),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    // -32601 is the JSON-RPC 2.0 "Method not found" code; by default a
    // filtered tool is indistinguishable from a missing one.
    assert_eq!(err.code, -32601);
}
#[tokio::test]
async fn test_tool_filter_call_allowed() {
    use crate::filter::CapabilityFilter;
    use crate::tool::Tool;

    let public = ToolBuilder::new("public")
        .description("Public tool")
        .handler(|input: AddInput| async move {
            Ok(CallToolResult::text(format!("{}", input.a + input.b)))
        })
        .build();
    // A filter that admits every tool: calls must pass straight through.
    let allow_all = CapabilityFilter::new(|_, _: &Tool| true);
    let mut router = McpRouter::new().tool(public).tool_filter(allow_all);
    init_router(&mut router).await;

    let call = CallToolParams {
        name: "public".to_string(),
        arguments: serde_json::json!({"a": 1, "b": 2}),
        meta: None,
        task: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(call),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::CallTool(outcome)) = response.inner else {
        panic!("Expected CallTool response");
    };
    assert!(!outcome.is_error);
}
#[tokio::test]
async fn test_tool_filter_custom_denial() {
    use crate::filter::{CapabilityFilter, DenialBehavior};
    use crate::tool::Tool;

    let admin = ToolBuilder::new("admin")
        .description("Admin tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
        .build();
    // Deny everything, but report it as Unauthorized rather than not-found.
    let deny_unauthorized =
        CapabilityFilter::new(|_, _: &Tool| false).denial_behavior(DenialBehavior::Unauthorized);
    let mut router = McpRouter::new().tool(admin).tool_filter(deny_unauthorized);
    init_router(&mut router).await;

    let call = CallToolParams {
        name: "admin".to_string(),
        arguments: serde_json::json!({"a": 1, "b": 2}),
        meta: None,
        task: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(call),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    assert_eq!(err.code, -32007);
    assert!(err.message.contains("Unauthorized"));
}
#[tokio::test]
async fn test_resource_filter_list() {
    use crate::filter::CapabilityFilter;
    use crate::resource::{Resource, ResourceBuilder};

    let public = ResourceBuilder::new("file:///public.txt")
        .name("Public File")
        .text("public content");
    let secret = ResourceBuilder::new("file:///secret.txt")
        .name("Secret File")
        .text("secret content");

    // Hide any resource whose display name mentions "Secret".
    let hide_secrets = CapabilityFilter::new(|_, res: &Resource| !res.name.contains("Secret"));
    let mut router = McpRouter::new()
        .resource(public)
        .resource(secret)
        .resource_filter(hide_secrets);
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListResources(ListResourcesParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListResources(listing)) = response.inner else {
        panic!("Expected ListResources response");
    };
    assert_eq!(listing.resources.len(), 1);
    assert_eq!(listing.resources[0].name, "Public File");
}
#[tokio::test]
async fn test_resource_filter_read_denied() {
    use crate::filter::CapabilityFilter;
    use crate::resource::{Resource, ResourceBuilder};

    let secret = ResourceBuilder::new("file:///secret.txt")
        .name("Secret File")
        .text("secret content");
    let deny_all = CapabilityFilter::new(|_, _: &Resource| false);
    let mut router = McpRouter::new().resource(secret).resource_filter(deny_all);
    init_router(&mut router).await;

    let read = ReadResourceParams {
        uri: "file:///secret.txt".to_string(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ReadResource(read),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    // -32601 is the JSON-RPC 2.0 "Method not found" code (default denial).
    assert_eq!(err.code, -32601);
}
#[tokio::test]
async fn test_resource_filter_read_allowed() {
    use crate::filter::CapabilityFilter;
    use crate::resource::{Resource, ResourceBuilder};

    let public = ResourceBuilder::new("file:///public.txt")
        .name("Public File")
        .text("public content");
    let allow_all = CapabilityFilter::new(|_, _: &Resource| true);
    let mut router = McpRouter::new().resource(public).resource_filter(allow_all);
    init_router(&mut router).await;

    let read = ReadResourceParams {
        uri: "file:///public.txt".to_string(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ReadResource(read),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ReadResource(payload)) = response.inner else {
        panic!("Expected ReadResource response");
    };
    assert_eq!(payload.contents.len(), 1);
    assert_eq!(payload.contents[0].text.as_deref(), Some("public content"));
}
#[tokio::test]
async fn test_resource_filter_custom_denial() {
    use crate::filter::{CapabilityFilter, DenialBehavior};
    use crate::resource::{Resource, ResourceBuilder};

    let secret = ResourceBuilder::new("file:///secret.txt")
        .name("Secret File")
        .text("secret content");
    // Deny everything, surfacing the denial as Unauthorized.
    let deny_unauthorized = CapabilityFilter::new(|_, _: &Resource| false)
        .denial_behavior(DenialBehavior::Unauthorized);
    let mut router = McpRouter::new()
        .resource(secret)
        .resource_filter(deny_unauthorized);
    init_router(&mut router).await;

    let read = ReadResourceParams {
        uri: "file:///secret.txt".to_string(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ReadResource(read),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    assert_eq!(err.code, -32007);
    assert!(err.message.contains("Unauthorized"));
}
#[tokio::test]
async fn test_prompt_filter_list() {
    use crate::filter::CapabilityFilter;
    use crate::prompt::{Prompt, PromptBuilder};

    let greeting = PromptBuilder::new("greeting")
        .description("A greeting")
        .user_message("Hello!");
    let debug = PromptBuilder::new("system_debug")
        .description("Admin prompt")
        .user_message("Debug");

    // Hide any prompt whose name contains "system".
    let hide_system = CapabilityFilter::new(|_, prompt: &Prompt| !prompt.name.contains("system"));
    let mut router = McpRouter::new()
        .prompt(greeting)
        .prompt(debug)
        .prompt_filter(hide_system);
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListPrompts(ListPromptsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListPrompts(listing)) = response.inner else {
        panic!("Expected ListPrompts response");
    };
    assert_eq!(listing.prompts.len(), 1);
    assert_eq!(listing.prompts[0].name, "greeting");
}
#[tokio::test]
async fn test_prompt_filter_get_denied() {
    use crate::filter::CapabilityFilter;
    use crate::prompt::{Prompt, PromptBuilder};
    use std::collections::HashMap;

    let debug = PromptBuilder::new("system_debug")
        .description("Admin prompt")
        .user_message("Debug");
    let deny_all = CapabilityFilter::new(|_, _: &Prompt| false);
    let mut router = McpRouter::new().prompt(debug).prompt_filter(deny_all);
    init_router(&mut router).await;

    let get = GetPromptParams {
        name: "system_debug".to_string(),
        arguments: HashMap::new(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::GetPrompt(get),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    // -32601 is the JSON-RPC 2.0 "Method not found" code (default denial).
    assert_eq!(err.code, -32601);
}
#[tokio::test]
async fn test_prompt_filter_get_allowed() {
    use crate::filter::CapabilityFilter;
    use crate::prompt::{Prompt, PromptBuilder};
    use std::collections::HashMap;

    let greeting = PromptBuilder::new("greeting")
        .description("A greeting")
        .user_message("Hello!");
    let allow_all = CapabilityFilter::new(|_, _: &Prompt| true);
    let mut router = McpRouter::new().prompt(greeting).prompt_filter(allow_all);
    init_router(&mut router).await;

    let get = GetPromptParams {
        name: "greeting".to_string(),
        arguments: HashMap::new(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::GetPrompt(get),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::GetPrompt(prompt)) = response.inner else {
        panic!("Expected GetPrompt response");
    };
    // The prompt was built with exactly one user message.
    assert_eq!(prompt.messages.len(), 1);
}
#[tokio::test]
async fn test_prompt_filter_custom_denial() {
    use crate::filter::{CapabilityFilter, DenialBehavior};
    use crate::prompt::{Prompt, PromptBuilder};
    use std::collections::HashMap;

    let debug = PromptBuilder::new("system_debug")
        .description("Admin prompt")
        .user_message("Debug");
    // Deny everything, surfacing the denial as Unauthorized.
    let deny_unauthorized = CapabilityFilter::new(|_, _: &Prompt| false)
        .denial_behavior(DenialBehavior::Unauthorized);
    let mut router = McpRouter::new()
        .prompt(debug)
        .prompt_filter(deny_unauthorized);
    init_router(&mut router).await;

    let get = GetPromptParams {
        name: "system_debug".to_string(),
        arguments: HashMap::new(),
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::GetPrompt(get),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Err(err) = response.inner else {
        panic!("Expected JsonRpc error");
    };
    assert_eq!(err.code, -32007);
    assert!(err.message.contains("Unauthorized"));
}
// Minimal single-string tool input shared by the merge/nest tests below.
// Plain `//` comments are used on purpose: `///` doc comments would be
// picked up by the JsonSchema derive and change the generated schema.
#[derive(Deserialize, JsonSchema, Debug)]
struct StringInput {
    // Echoed back by the `echo` tool in the nesting tests.
    value: String,
}
#[tokio::test]
async fn test_router_merge_tools() {
    let a = ToolBuilder::new("tool_a")
        .description("Tool A")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
        .build();
    let b = ToolBuilder::new("tool_b")
        .description("Tool B")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
        .build();
    let c = ToolBuilder::new("tool_c")
        .description("Tool C")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
        .build();

    // Two independent routers, one with a single tool, one with two.
    let first = McpRouter::new().tool(a);
    let second = McpRouter::new().tool(b).tool(c);
    let mut merged = McpRouter::new()
        .server_info("merged", "1.0")
        .merge(first)
        .merge(second);
    init_router(&mut merged).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = merged.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 3);
    let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    for expected in ["tool_a", "tool_b", "tool_c"] {
        assert!(names.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_merge_overwrites_duplicates() {
    let v1 = ToolBuilder::new("shared")
        .description("Version 1")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
        .build();
    let v2 = ToolBuilder::new("shared")
        .description("Version 2")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
        .build();

    // Both routers register "shared"; the later merge must win.
    let mut merged = McpRouter::new()
        .merge(McpRouter::new().tool(v1))
        .merge(McpRouter::new().tool(v2));
    init_router(&mut merged).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = merged.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 1);
    let survivor = &listing.tools[0];
    assert_eq!(survivor.name, "shared");
    assert_eq!(survivor.description.as_deref(), Some("Version 2"));
}
#[tokio::test]
async fn test_router_merge_resources() {
    use crate::resource::ResourceBuilder;

    let file_a = ResourceBuilder::new("file:///a.txt")
        .name("File A")
        .text("content a");
    let file_b = ResourceBuilder::new("file:///b.txt")
        .name("File B")
        .text("content b");
    let left = McpRouter::new().resource(file_a);
    let right = McpRouter::new().resource(file_b);

    let mut merged = McpRouter::new().merge(left).merge(right);
    init_router(&mut merged).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListResources(ListResourcesParams::default()),
        extensions: Extensions::new(),
    };
    let response = merged.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListResources(listing)) = response.inner else {
        panic!("Expected ListResources response");
    };
    assert_eq!(listing.resources.len(), 2);
    let uris: Vec<&str> = listing.resources.iter().map(|r| r.uri.as_str()).collect();
    for expected in ["file:///a.txt", "file:///b.txt"] {
        assert!(uris.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_merge_prompts() {
    use crate::prompt::PromptBuilder;

    let prompt_a = PromptBuilder::new("prompt_a").user_message("Hello A");
    let prompt_b = PromptBuilder::new("prompt_b").user_message("Hello B");
    let left = McpRouter::new().prompt(prompt_a);
    let right = McpRouter::new().prompt(prompt_b);

    let mut merged = McpRouter::new().merge(left).merge(right);
    init_router(&mut merged).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListPrompts(ListPromptsParams::default()),
        extensions: Extensions::new(),
    };
    let response = merged.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListPrompts(listing)) = response.inner else {
        panic!("Expected ListPrompts response");
    };
    assert_eq!(listing.prompts.len(), 2);
    let names: Vec<&str> = listing.prompts.iter().map(|p| p.name.as_str()).collect();
    for expected in ["prompt_a", "prompt_b"] {
        assert!(names.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_nest_prefixes_tools() {
    let query = ToolBuilder::new("query")
        .description("Query the database")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
        .build();
    let insert = ToolBuilder::new("insert")
        .description("Insert into database")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
        .build();

    // Nesting under "db" must expose the tools as "db.<name>".
    let db = McpRouter::new().tool(query).tool(insert);
    let mut router = McpRouter::new()
        .server_info("nested", "1.0")
        .nest("db", db);
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 2);
    let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    for expected in ["db.query", "db.insert"] {
        assert!(names.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_nest_call_prefixed_tool() {
    let echo = ToolBuilder::new("echo")
        .description("Echo input")
        .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
        .build();
    let mut router = McpRouter::new().nest("api", McpRouter::new().tool(echo));
    init_router(&mut router).await;

    // Calls must be addressed by the prefixed name.
    let call = CallToolParams {
        name: "api.echo".to_string(),
        arguments: serde_json::json!({"value": "hello world"}),
        meta: None,
        task: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::CallTool(call),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::CallTool(outcome)) = response.inner else {
        panic!("Expected CallTool response");
    };
    assert!(!outcome.is_error);
    let Content::Text { text, .. } = &outcome.content[0] else {
        panic!("Expected text content");
    };
    assert_eq!(text, "hello world");
}
#[tokio::test]
async fn test_router_multiple_nests() {
    let query = ToolBuilder::new("query")
        .description("Database query")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
        .build();
    let fetch = ToolBuilder::new("fetch")
        .description("API fetch")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
        .build();

    // Each nested router gets its own prefix.
    let mut router = McpRouter::new()
        .nest("db", McpRouter::new().tool(query))
        .nest("api", McpRouter::new().tool(fetch));
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 2);
    let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    for expected in ["db.query", "api.fetch"] {
        assert!(names.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_merge_and_nest_combined() {
    let local = ToolBuilder::new("local")
        .description("Local tool")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
        .build();
    let remote = ToolBuilder::new("remote")
        .description("Remote tool")
        .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
        .build();

    // A directly registered tool keeps its bare name; the nested one is prefixed.
    let mut router = McpRouter::new()
        .tool(local)
        .nest("external", McpRouter::new().tool(remote));
    init_router(&mut router).await;

    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();

    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert_eq!(listing.tools.len(), 2);
    let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    for expected in ["local", "external.remote"] {
        assert!(names.contains(&expected));
    }
}
#[tokio::test]
async fn test_router_merge_preserves_server_info() {
    // The child carries its own identity and instructions; merging it into the
    // parent must not replace the parent's server info.
    let child_router = McpRouter::new()
        .server_info("child", "2.0")
        .instructions("Child instructions");
    // Deliberately NOT run through `init_router`: the Initialize request below
    // is the handshake under test and must go to this merged router. The
    // previous version initialized this router and then asserted against a
    // separately built router, so the merged router was never checked.
    let mut router = McpRouter::new()
        .server_info("parent", "1.0")
        .instructions("Parent instructions")
        .merge(child_router);
    let init_req = RouterRequest {
        id: RequestId::Number(99),
        inner: McpRequest::Initialize(InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test".to_string(),
                version: "1.0".to_string(),
                ..Default::default()
            },
            meta: None,
        }),
        extensions: Extensions::new(),
    };
    let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
    match resp.inner {
        Ok(McpResponse::Initialize(result)) => {
            assert_eq!(result.server_info.name, "parent");
            assert_eq!(result.server_info.version, "1.0");
        }
        _ => panic!("Expected Initialize response"),
    }
}
/// With only tools registered, auto-instructions contain a Tools section with one bullet
/// per tool and omit the Resources and Prompts sections.
#[tokio::test]
async fn test_auto_instructions_tools_only() {
    let alpha = ToolBuilder::new("alpha")
        .description("Alpha tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let beta = ToolBuilder::new("beta")
        .description("Beta tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().auto_instructions().tool(alpha).tool(beta);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.expect("should have instructions");
    for present in ["## Tools", "- **alpha**: Alpha tool", "- **beta**: Beta tool"] {
        assert!(text.contains(present), "missing {present:?}");
    }
    for absent in ["## Resources", "## Prompts"] {
        assert!(!text.contains(absent), "unexpected {absent:?}");
    }
}
/// Tool annotation hints are rendered as bracketed tags after the description; a tool
/// without hints gets no tag at all.
#[tokio::test]
async fn test_auto_instructions_with_annotations() {
    let query = ToolBuilder::new("query")
        .description("Run a query")
        .read_only()
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let delete = ToolBuilder::new("delete")
        .description("Delete a record")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let upsert = ToolBuilder::new("upsert")
        .description("Upsert a record")
        .non_destructive()
        .idempotent()
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .auto_instructions()
        .tool(query)
        .tool(delete)
        .tool(upsert);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    let expected_lines = [
        "- **query**: Run a query [read-only]",
        // The trailing newline proves nothing (no tag) follows the description.
        "- **delete**: Delete a record\n",
        "- **upsert**: Upsert a record [idempotent]",
    ];
    for line in expected_lines {
        assert!(text.contains(line), "missing {line:?}");
    }
}
/// A static resource is listed under Resources keyed by its URI; with no tools the
/// Tools section is absent.
#[tokio::test]
async fn test_auto_instructions_with_resources() {
    use crate::resource::ResourceBuilder;
    let schema = ResourceBuilder::new("file:///schema.sql")
        .name("Schema")
        .description("Database schema")
        .text("CREATE TABLE ...");
    let mut router = McpRouter::new().auto_instructions().resource(schema);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.contains("## Resources"));
    assert!(text.contains("- **file:///schema.sql**: Database schema"));
    assert!(!text.contains("## Tools"));
}
/// Resource templates also appear under Resources, keyed by their URI template.
#[tokio::test]
async fn test_auto_instructions_with_resource_templates() {
    use crate::resource::ResourceTemplateBuilder;
    let file_template = ResourceTemplateBuilder::new("file:///{path}")
        .name("File")
        .description("Read a file by path")
        .handler(
            |_uri: String, _vars: std::collections::HashMap<String, String>| async move {
                Ok(crate::ReadResourceResult::text("content", "text/plain"))
            },
        );
    let mut router = McpRouter::new()
        .auto_instructions()
        .resource_template(file_template);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    for expected in ["## Resources", "- **file:///{path}**: Read a file by path"] {
        assert!(text.contains(expected), "missing {expected:?}");
    }
}
/// Prompts are listed under a Prompts section; a prompt-only router has no Tools section.
#[tokio::test]
async fn test_auto_instructions_with_prompts() {
    use crate::prompt::PromptBuilder;
    let write_query = PromptBuilder::new("write_query")
        .description("Help write a SQL query")
        .user_message("Write a query for: {task}");
    let mut router = McpRouter::new().auto_instructions().prompt(write_query);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.contains("## Prompts"));
    assert!(text.contains("- **write_query**: Help write a SQL query"));
    assert!(!text.contains("## Tools"));
}
/// With tools, resources and prompts all registered, every section is emitted, in the
/// fixed order Tools, then Resources, then Prompts.
#[tokio::test]
async fn test_auto_instructions_all_sections() {
    use crate::prompt::PromptBuilder;
    use crate::resource::ResourceBuilder;
    let query_tool = ToolBuilder::new("query")
        .description("Execute SQL")
        .read_only()
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let schema = ResourceBuilder::new("db://schema")
        .name("Schema")
        .description("Full database schema")
        .text("schema");
    let write_query = PromptBuilder::new("write_query")
        .description("Help write a SQL query")
        .user_message("Write a query");
    let mut router = McpRouter::new()
        .auto_instructions()
        .tool(query_tool)
        .resource(schema)
        .prompt(write_query);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    // Locate each heading; a missing one fails here, before the ordering check.
    let offsets: Vec<usize> = ["## Tools", "## Resources", "## Prompts"]
        .iter()
        .map(|&heading| {
            text.find(heading)
                .unwrap_or_else(|| panic!("missing heading {heading:?}"))
        })
        .collect();
    assert!(
        offsets.windows(2).all(|pair| pair[0] < pair[1]),
        "sections out of order: {offsets:?}"
    );
}
/// `auto_instructions_with` wraps the generated listing with the supplied prefix and
/// suffix text.
#[tokio::test]
async fn test_auto_instructions_with_prefix_and_suffix() {
    let echo = ToolBuilder::new("echo")
        .description("Echo input")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let intro = Some("This server provides echo capabilities.");
    let outro = Some("Contact admin@example.com for support.");
    let mut router = McpRouter::new().auto_instructions_with(intro, outro).tool(echo);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.starts_with("This server provides echo capabilities."));
    assert!(text.ends_with("Contact admin@example.com for support."));
    assert!(text.contains("## Tools"));
    assert!(text.contains("- **echo**: Echo input"));
}
/// A prefix may be supplied without a suffix; the listing still follows it.
#[tokio::test]
async fn test_auto_instructions_prefix_only() {
    let echo = ToolBuilder::new("echo")
        .description("Echo input")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let no_suffix: Option<String> = None;
    let mut router = McpRouter::new()
        .auto_instructions_with(Some("My server intro."), no_suffix)
        .tool(echo);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.starts_with("My server intro."));
    assert!(text.contains("- **echo**: Echo input"));
}
/// An auto-instructions router with nothing registered still returns instructions, but
/// they are empty: no section headings at all.
#[tokio::test]
async fn test_auto_instructions_empty_router() {
    let mut router = McpRouter::new().auto_instructions();
    let init = send_initialize(&mut router).await;
    let text = init.instructions.expect("should have instructions");
    for heading in ["## Tools", "## Resources", "## Prompts"] {
        assert!(!text.contains(heading), "unexpected {heading:?}");
    }
    assert!(text.is_empty());
}
/// Enabling auto-instructions replaces any manually configured instructions text.
#[tokio::test]
async fn test_auto_instructions_overrides_manual() {
    let echo = ToolBuilder::new("echo")
        .description("Echo input")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .instructions("This will be overridden")
        .auto_instructions()
        .tool(echo);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    let kept_manual = text.contains("This will be overridden");
    assert!(!kept_manual);
    assert!(text.contains("- **echo**: Echo input"));
}
/// Without auto-instructions, the manually configured text is returned verbatim.
#[tokio::test]
async fn test_no_auto_instructions_returns_manual() {
    let echo = ToolBuilder::new("echo")
        .description("Echo input")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .instructions("Manual instructions here")
        .tool(echo);
    let init = send_initialize(&mut router).await;
    assert_eq!(init.instructions.unwrap(), "Manual instructions here");
}
/// A tool registered without a description is listed with the placeholder text.
#[tokio::test]
async fn test_auto_instructions_no_description_fallback() {
    let undocumented = ToolBuilder::new("mystery")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().auto_instructions().tool(undocumented);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.contains("- **mystery**: No description"));
}
/// Tools are listed alphabetically regardless of registration order.
#[tokio::test]
async fn test_auto_instructions_sorted_alphabetically() {
    let zebra = ToolBuilder::new("zebra")
        .description("Z tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let alpha = ToolBuilder::new("alpha")
        .description("A tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let middle = ToolBuilder::new("middle")
        .description("M tool")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    // Registered deliberately out of order.
    let mut router = McpRouter::new()
        .auto_instructions()
        .tool(zebra)
        .tool(alpha)
        .tool(middle);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    let offsets: Vec<usize> = ["**alpha**", "**middle**", "**zebra**"]
        .iter()
        .map(|&needle| text.find(needle).unwrap())
        .collect();
    assert!(offsets[0] < offsets[1]);
    assert!(offsets[1] < offsets[2]);
}
/// A tool marked only `idempotent` gets exactly the `[idempotent]` tag on its line.
/// It must not pick up a `read-only` tag it was never given.
#[tokio::test]
async fn test_auto_instructions_read_only_and_idempotent_tags() {
    let tool = ToolBuilder::new("safe_update")
        .description("Safe update operation")
        .idempotent()
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().auto_instructions().tool(tool);
    let resp = send_initialize(&mut router).await;
    let instructions = resp.instructions.unwrap();
    // Pin the whole line, not just a loose "[idempotent]" substring, so the tag is
    // known to belong to this tool and to be the only tag on it.
    assert!(
        instructions.contains("- **safe_update**: Safe update operation [idempotent]"),
        "got: {}",
        instructions
    );
    assert!(
        !instructions.contains("read-only"),
        "idempotent-only tool must not be tagged read-only, got: {}",
        instructions
    );
}
/// Instructions are generated at initialize time, so tools added after calling
/// `auto_instructions()` are still included.
#[tokio::test]
async fn test_auto_instructions_lazy_generation() {
    let router = McpRouter::new().auto_instructions();
    let late = ToolBuilder::new("late_tool")
        .description("Added after auto_instructions")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = router.tool(late);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.contains("- **late_tool**: Added after auto_instructions"));
}
/// Several active hints are joined into one bracketed, comma-separated tag list.
#[tokio::test]
async fn test_auto_instructions_multiple_annotation_tags() {
    let hints = ToolAnnotations {
        read_only_hint: true,
        idempotent_hint: true,
        ..Default::default()
    };
    let update = ToolBuilder::new("update")
        .description("Update a record")
        .annotations(hints)
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().auto_instructions().tool(update);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    assert!(text.contains("[read-only, idempotent]"), "got: {}", text);
}
/// A tool without annotation hints produces no bracketed tag anywhere in the output.
#[tokio::test]
async fn test_auto_instructions_no_annotations_no_tags() {
    let fetch = ToolBuilder::new("fetch")
        .description("Fetch data")
        .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().auto_instructions().tool(fetch);
    let init = send_initialize(&mut router).await;
    let text = init.instructions.unwrap();
    let has_tag = text.contains('[');
    assert!(!has_tag, "should have no tags, got: {}", text);
    assert!(text.contains("- **fetch**: Fetch data"));
}
/// Test helper that sends an `initialize` request through `router` and returns the
/// decoded `InitializeResult`.
///
/// The request uses id 0, protocol version `2025-11-25`, and client capabilities with
/// every field `None`. Panics if the router answers with anything other than an
/// `Initialize` response.
async fn send_initialize(router: &mut McpRouter) -> InitializeResult {
    let capabilities = ClientCapabilities {
        roots: None,
        sampling: None,
        elicitation: None,
        tasks: None,
        experimental: None,
        extensions: None,
    };
    let client_info = Implementation {
        name: "test".to_string(),
        version: "1.0".to_string(),
        ..Default::default()
    };
    let params = InitializeParams {
        protocol_version: "2025-11-25".to_string(),
        capabilities,
        client_info,
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(0),
        inner: McpRequest::Initialize(params),
        extensions: Extensions::new(),
    };
    let service = router.ready().await.unwrap();
    let response = service.call(request).await.unwrap();
    match response.inner {
        Ok(McpResponse::Initialize(result)) => result,
        unexpected => panic!("Expected Initialize response, got {:?}", unexpected),
    }
}
/// With a notification sender attached, `notify_tools_list_changed` reports success and
/// delivers a `ToolsListChanged` notification on the channel.
#[tokio::test]
async fn test_notify_tools_list_changed() {
    let (sender, mut receiver) = crate::context::notification_channel(16);
    let router = McpRouter::new()
        .server_info("test", "1.0")
        .with_notification_sender(sender);
    let delivered = router.notify_tools_list_changed();
    assert!(delivered);
    let received = receiver.recv().await.unwrap();
    assert!(matches!(received, ServerNotification::ToolsListChanged));
}
/// With a notification sender attached, `notify_prompts_list_changed` reports success
/// and delivers a `PromptsListChanged` notification on the channel.
#[tokio::test]
async fn test_notify_prompts_list_changed() {
    let (sender, mut receiver) = crate::context::notification_channel(16);
    let router = McpRouter::new()
        .server_info("test", "1.0")
        .with_notification_sender(sender);
    let delivered = router.notify_prompts_list_changed();
    assert!(delivered);
    let received = receiver.recv().await.unwrap();
    assert!(matches!(received, ServerNotification::PromptsListChanged));
}
/// Without a notification sender, every `notify_*_list_changed` call returns `false`.
#[tokio::test]
async fn test_notify_without_sender_returns_false() {
    let router = McpRouter::new().server_info("test", "1.0");
    let outcomes = [
        router.notify_tools_list_changed(),
        router.notify_prompts_list_changed(),
        router.notify_resources_list_changed(),
    ];
    assert!(outcomes.iter().all(|&sent| !sent), "got {outcomes:?}");
}
/// Advertised `tools.listChanged` is true once a notification sender is configured.
#[tokio::test]
async fn test_list_changed_capabilities_with_notification_sender() {
    // Keep the receiver alive for the duration of the test.
    let (sender, _receiver) = crate::context::notification_channel(16);
    let probe = ToolBuilder::new("test")
        .description("test")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .server_info("test", "1.0")
        .tool(probe)
        .with_notification_sender(sender);
    init_router(&mut router).await;
    let capabilities = router.capabilities();
    let tools_capability = capabilities
        .tools
        .expect("tools capability should be present");
    assert!(
        tools_capability.list_changed,
        "tools.listChanged should be true when notification sender is configured"
    );
}
/// Advertised `tools.listChanged` stays false when no notification sender exists.
#[tokio::test]
async fn test_list_changed_capabilities_without_notification_sender() {
    let probe = ToolBuilder::new("test")
        .description("test")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().server_info("test", "1.0").tool(probe);
    init_router(&mut router).await;
    let capabilities = router.capabilities();
    let tools_capability = capabilities
        .tools
        .expect("tools capability should be present");
    assert!(
        !tools_capability.list_changed,
        "tools.listChanged should be false without notification sender"
    );
}
#[tokio::test]
async fn test_set_logging_level_filters_messages() {
let (tx, mut rx) = crate::context::notification_channel(16);
let mut router = McpRouter::new()
.server_info("test", "1.0")
.with_notification_sender(tx);
init_router(&mut router).await;
let set_level_req = RouterRequest {
id: RequestId::Number(99),
inner: McpRequest::SetLoggingLevel(SetLogLevelParams {
level: LogLevel::Warning,
meta: None,
}),
extensions: crate::context::Extensions::new(),
};
let resp = router
.ready()
.await
.unwrap()
.call(set_level_req)
.await
.unwrap();
assert!(matches!(resp.inner, Ok(McpResponse::SetLoggingLevel(_))));
let ctx = router.create_context(RequestId::Number(100), None);
ctx.send_log(LoggingMessageParams::new(
LogLevel::Error,
serde_json::Value::Null,
));
assert!(
rx.try_recv().is_ok(),
"Error should pass through Warning filter"
);
ctx.send_log(LoggingMessageParams::new(
LogLevel::Info,
serde_json::Value::Null,
));
assert!(
rx.try_recv().is_err(),
"Info should be filtered at Warning level"
);
}
/// Without a page size, `paginate` returns every item and no continuation cursor.
#[test]
fn test_paginate_no_page_size() {
    let all = vec![1, 2, 3, 4, 5];
    let (page, next) = paginate(all.clone(), None, None).unwrap();
    assert_eq!(page, all);
    assert_eq!(next, None);
}
/// With no cursor, the first `page_size` items are returned along with a cursor.
#[test]
fn test_paginate_first_page() {
    let (page, next) = paginate(vec![1, 2, 3, 4, 5], None, Some(2)).unwrap();
    assert_eq!(page, vec![1, 2]);
    assert!(next.is_some());
}
/// Feeding the first page's cursor back yields the next slice and yet another cursor.
#[test]
fn test_paginate_middle_page() {
    let items = vec![1, 2, 3, 4, 5];
    let (first, after_first) = paginate(items.clone(), None, Some(2)).unwrap();
    assert_eq!(first, vec![1, 2]);
    let (second, after_second) = paginate(items, after_first.as_deref(), Some(2)).unwrap();
    assert_eq!(second, vec![3, 4]);
    assert!(after_second.is_some());
}
/// A cursor near the end returns the short remainder and no further cursor.
#[test]
fn test_paginate_last_page() {
    let resume_at = encode_cursor(4);
    let (page, next) = paginate(vec![1, 2, 3, 4, 5], Some(resume_at.as_str()), Some(2)).unwrap();
    assert_eq!(page, vec![5]);
    assert!(next.is_none());
}
/// When the page size equals the item count, everything fits and no cursor is issued.
#[test]
fn test_paginate_exact_boundary() {
    let (page, next) = paginate(vec![1, 2, 3, 4], None, Some(4)).unwrap();
    assert_eq!(page, vec![1, 2, 3, 4]);
    assert!(next.is_none());
}
/// A cursor that is not valid base64 is rejected with an error.
#[test]
fn test_paginate_invalid_cursor() {
    let outcome = paginate(vec![1, 2, 3], Some("not-valid-base64!@#$"), Some(2));
    assert!(outcome.is_err());
}
/// `decode_cursor` inverts `encode_cursor`.
#[test]
fn test_cursor_round_trip() {
    let original: usize = 42;
    let round_tripped = decode_cursor(&encode_cursor(original)).unwrap();
    assert_eq!(round_tripped, original);
}
/// With `page_size(2)` and three tools, `tools/list` returns two pages. The first page
/// holds `alpha` and `beta` plus a cursor; the second holds `gamma` and no cursor.
#[tokio::test]
async fn test_list_tools_pagination() {
    let alpha = ToolBuilder::new("alpha")
        .description("a")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let beta = ToolBuilder::new("beta")
        .description("b")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let gamma = ToolBuilder::new("gamma")
        .description("c")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .server_info("test", "1.0")
        .page_size(2)
        .tool(alpha)
        .tool(beta)
        .tool(gamma);
    init_router(&mut router).await;

    // Builds a `tools/list` request with the given id and continuation cursor.
    let list_request = |id, cursor| RouterRequest {
        id: RequestId::Number(id),
        inner: McpRequest::ListTools(ListToolsParams { cursor, meta: None }),
        extensions: Extensions::new(),
    };

    let response = router
        .ready()
        .await
        .unwrap()
        .call(list_request(1, None))
        .await
        .unwrap();
    let first_page = match response.inner {
        Ok(McpResponse::ListTools(result)) => result,
        other => panic!("Expected ListTools, got {:?}", other),
    };
    let first_names: Vec<&str> = first_page.tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(first_names, ["alpha", "beta"]);
    assert!(first_page.next_cursor.is_some());

    let response = router
        .ready()
        .await
        .unwrap()
        .call(list_request(2, first_page.next_cursor))
        .await
        .unwrap();
    let second_page = match response.inner {
        Ok(McpResponse::ListTools(result)) => result,
        other => panic!("Expected ListTools, got {:?}", other),
    };
    let second_names: Vec<&str> = second_page.tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(second_names, ["gamma"]);
    assert!(second_page.next_cursor.is_none());
}
/// Without a configured page size, `tools/list` returns every tool and no cursor.
#[tokio::test]
async fn test_list_tools_no_pagination_by_default() {
    let alpha = ToolBuilder::new("alpha")
        .description("a")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let beta = ToolBuilder::new("beta")
        .description("b")
        .handler(|_input: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new()
        .server_info("test", "1.0")
        .tool(alpha)
        .tool(beta);
    init_router(&mut router).await;
    let params = ListToolsParams {
        cursor: None,
        meta: None,
    };
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(params),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();
    let listing = match response.inner {
        Ok(McpResponse::ListTools(result)) => result,
        other => panic!("Expected ListTools, got {:?}", other),
    };
    assert_eq!(listing.tools.len(), 2);
    assert!(listing.next_cursor.is_none());
}
/// Tests for the runtime-mutable tool registry returned by `with_dynamic_tools`:
/// registration, removal, shadowing by static tools, filtering, and the
/// `ToolsListChanged` notifications emitted on mutation.
#[cfg(feature = "dynamic-tools")]
mod dynamic_tools_tests {
    use super::*;

    /// A tool registered through the registry is returned by `tools/list`.
    #[tokio::test]
    async fn test_dynamic_tools_register_and_list() {
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        registry.register(
            ToolBuilder::new("dynamic_echo")
                .description("Dynamic echo")
                .handler(|input: AddInput| async move {
                    Ok(CallToolResult::text(input.a.to_string()))
                })
                .build(),
        );
        init_router(&mut router).await;
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Ok(McpResponse::ListTools(listing)) = response.inner else {
            panic!("Expected ListTools response");
        };
        assert_eq!(listing.tools.len(), 1);
        assert_eq!(listing.tools[0].name, "dynamic_echo");
    }

    /// `unregister` removes a tool exactly once; a second call reports `false`, and the
    /// tool no longer appears in `tools/list`.
    #[tokio::test]
    async fn test_dynamic_tools_unregister() {
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        registry.register(
            ToolBuilder::new("temp")
                .description("Temporary")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        );
        assert!(registry.contains("temp"));
        let first_removal = registry.unregister("temp");
        assert!(first_removal);
        assert!(!registry.contains("temp"));
        let second_removal = registry.unregister("temp");
        assert!(!second_removal);
        init_router(&mut router).await;
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Ok(McpResponse::ListTools(listing)) = response.inner else {
            panic!("Expected ListTools response");
        };
        assert!(listing.tools.is_empty());
    }

    /// Static and dynamic tools are listed side by side.
    #[tokio::test]
    async fn test_dynamic_tools_merged_with_static() {
        let fixed = ToolBuilder::new("static_tool")
            .description("Static")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("static")) })
            .build();
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .tool(fixed)
            .with_dynamic_tools();
        registry.register(
            ToolBuilder::new("dynamic_tool")
                .description("Dynamic")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("dynamic")) })
                .build(),
        );
        init_router(&mut router).await;
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Ok(McpResponse::ListTools(listing)) = response.inner else {
            panic!("Expected ListTools response");
        };
        assert_eq!(listing.tools.len(), 2);
        let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
        for expected in ["static_tool", "dynamic_tool"] {
            assert!(names.contains(&expected), "missing {expected:?}");
        }
    }

    /// When a static and a dynamic tool share a name, the static one wins for both
    /// listing and invocation.
    #[tokio::test]
    async fn test_static_tools_shadow_dynamic() {
        let fixed = ToolBuilder::new("shared")
            .description("Static version")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("static")) })
            .build();
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .tool(fixed)
            .with_dynamic_tools();
        registry.register(
            ToolBuilder::new("shared")
                .description("Dynamic version")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("dynamic")) })
                .build(),
        );
        init_router(&mut router).await;

        let list_request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(list_request).await.unwrap();
        let Ok(McpResponse::ListTools(listing)) = response.inner else {
            panic!("Expected ListTools response");
        };
        assert_eq!(listing.tools.len(), 1);
        assert_eq!(listing.tools[0].name, "shared");
        assert_eq!(
            listing.tools[0].description.as_deref(),
            Some("Static version")
        );

        let call_request = RouterRequest {
            id: RequestId::Number(2),
            inner: McpRequest::CallTool(CallToolParams {
                name: "shared".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(call_request).await.unwrap();
        let Ok(McpResponse::CallTool(outcome)) = response.inner else {
            panic!("Expected CallTool response");
        };
        assert!(!outcome.is_error);
        let Content::Text { text, .. } = &outcome.content[0] else {
            panic!("Expected text content");
        };
        assert_eq!(text, "static");
    }

    /// A dynamically registered tool can be invoked and its handler result returned.
    #[tokio::test]
    async fn test_dynamic_tools_call() {
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        registry.register(
            ToolBuilder::new("add")
                .description("Add two numbers")
                .handler(|input: AddInput| async move {
                    Ok(CallToolResult::text((input.a + input.b).to_string()))
                })
                .build(),
        );
        init_router(&mut router).await;
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "add".to_string(),
                arguments: serde_json::json!({"a": 3, "b": 4}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Ok(McpResponse::CallTool(outcome)) = response.inner else {
            panic!("Expected CallTool response");
        };
        assert!(!outcome.is_error);
        let Content::Text { text, .. } = &outcome.content[0] else {
            panic!("Expected text content");
        };
        assert_eq!(text, "7");
    }

    /// Registering a tool emits `ToolsListChanged` to an attached session.
    #[tokio::test]
    async fn test_dynamic_tools_notification_on_register() {
        let (sender, mut receiver) = crate::context::notification_channel(16);
        let (router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        // Keep the session bound (not dropped) while the registry is mutated.
        let _session = router.with_notification_sender(sender);
        registry.register(
            ToolBuilder::new("notified")
                .description("Test")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        );
        let received = receiver.recv().await.unwrap();
        assert!(matches!(received, ServerNotification::ToolsListChanged));
    }

    /// Unregistering an existing tool emits another `ToolsListChanged`.
    #[tokio::test]
    async fn test_dynamic_tools_notification_on_unregister() {
        let (sender, mut receiver) = crate::context::notification_channel(16);
        let (router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        let _session = router.with_notification_sender(sender);
        registry.register(
            ToolBuilder::new("notified")
                .description("Test")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        );
        // Drain the notification produced by the registration above.
        let _ = receiver.recv().await.unwrap();
        registry.unregister("notified");
        let received = receiver.recv().await.unwrap();
        assert!(matches!(received, ServerNotification::ToolsListChanged));
    }

    /// Unregistering an unknown name returns `false` and emits nothing.
    #[tokio::test]
    async fn test_dynamic_tools_no_notification_on_empty_unregister() {
        let (sender, mut receiver) = crate::context::notification_channel(16);
        let (router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        let _session = router.with_notification_sender(sender);
        let removed = registry.unregister("nonexistent");
        assert!(!removed);
        assert!(receiver.try_recv().is_err());
    }

    /// The router's tool filter hides matching dynamic tools from listing and makes
    /// calling them fail with -32601.
    #[tokio::test]
    async fn test_dynamic_tools_filter_applies() {
        use crate::filter::CapabilityFilter;
        let (mut router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .tool_filter(CapabilityFilter::new(|_, tool: &Tool| {
                tool.name != "hidden"
            }))
            .with_dynamic_tools();
        let visible = ToolBuilder::new("visible")
            .description("Visible")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
            .build();
        let hidden = ToolBuilder::new("hidden")
            .description("Hidden")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
            .build();
        registry.register(visible);
        registry.register(hidden);
        init_router(&mut router).await;

        let list_request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(list_request).await.unwrap();
        let Ok(McpResponse::ListTools(listing)) = response.inner else {
            panic!("Expected ListTools response");
        };
        assert_eq!(listing.tools.len(), 1);
        assert_eq!(listing.tools[0].name, "visible");

        let call_request = RouterRequest {
            id: RequestId::Number(2),
            inner: McpRequest::CallTool(CallToolParams {
                name: "hidden".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(call_request).await.unwrap();
        let Err(error) = response.inner else {
            panic!("Expected JsonRpc error");
        };
        assert_eq!(error.code, -32601);
    }

    /// A router with dynamic tools advertises the tools capability even when empty.
    #[tokio::test]
    async fn test_dynamic_tools_capabilities_advertised() {
        let (mut router, _registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        let params = InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test".to_string(),
                version: "1.0".to_string(),
                ..Default::default()
            },
            meta: None,
        };
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::Initialize(params),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Ok(McpResponse::Initialize(init)) = response.inner else {
            panic!("Expected Initialize response");
        };
        assert!(init.capabilities.tools.is_some());
    }

    /// Every session cloned from the router receives the registry's notifications.
    #[tokio::test]
    async fn test_dynamic_tools_multi_session_notification() {
        let (sender_one, mut receiver_one) = crate::context::notification_channel(16);
        let (sender_two, mut receiver_two) = crate::context::notification_channel(16);
        let (router, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        let _first_session = router.clone().with_notification_sender(sender_one);
        let _second_session = router.clone().with_notification_sender(sender_two);
        registry.register(
            ToolBuilder::new("broadcast")
                .description("Test")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        );
        let first = receiver_one.recv().await.unwrap();
        let second = receiver_two.recv().await.unwrap();
        assert!(matches!(first, ServerNotification::ToolsListChanged));
        assert!(matches!(second, ServerNotification::ToolsListChanged));
    }

    /// Calling a name that is neither static nor dynamic fails with -32601.
    #[tokio::test]
    async fn test_dynamic_tools_call_not_found() {
        let (mut router, _registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        init_router(&mut router).await;
        let request = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "nonexistent".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };
        let response = router.ready().await.unwrap().call(request).await.unwrap();
        let Err(error) = response.inner else {
            panic!("Expected method not found error");
        };
        assert_eq!(error.code, -32601);
    }

    /// `registry.list()` reflects exactly the tools registered so far.
    #[tokio::test]
    async fn test_dynamic_tools_registry_list() {
        let (_, registry) = McpRouter::new()
            .server_info("test", "1.0")
            .with_dynamic_tools();
        assert!(registry.list().is_empty());
        for (name, description) in [("tool_a", "A"), ("tool_b", "B")] {
            registry.register(
                ToolBuilder::new(name)
                    .description(description)
                    .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                    .build(),
            );
        }
        let listed = registry.list();
        assert_eq!(listed.len(), 2);
        let names: Vec<&str> = listed.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }
}
/// `tool_if(true, ..)` registers the tool, so it appears in `tools/list`.
#[tokio::test]
async fn test_tool_if_true_registers() {
    let conditional = ToolBuilder::new("conditional")
        .description("Conditional tool")
        .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().tool_if(true, conditional);
    init_router(&mut router).await;
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();
    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    let names: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(names, ["conditional"]);
}
/// `tool_if(false, ..)` skips registration, leaving `tools/list` empty.
#[tokio::test]
async fn test_tool_if_false_skips() {
    let conditional = ToolBuilder::new("conditional")
        .description("Conditional tool")
        .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
        .build();
    let mut router = McpRouter::new().tool_if(false, conditional);
    init_router(&mut router).await;
    let request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router.ready().await.unwrap().call(request).await.unwrap();
    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert!(listing.tools.is_empty());
}
/// `tools_if` registers the whole batch when the condition holds and none of
/// it otherwise. Both branches are exercised so a broken `true` path cannot
/// hide behind a passing `false` path.
#[tokio::test]
async fn test_tools_if_batch_conditional() {
    // Tools are not assumed to be `Clone`, so build a fresh batch per router.
    let make_tools = || {
        vec![
            ToolBuilder::new("a")
                .description("Tool A")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
            ToolBuilder::new("b")
                .description("Tool B")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        ]
    };

    for (enabled, expected) in [(false, Vec::<&str>::new()), (true, vec!["a", "b"])] {
        let mut router = McpRouter::new().tools_if(enabled, make_tools());
        init_router(&mut router).await;
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams::default()),
            extensions: Extensions::new(),
        };
        let resp = router.ready().await.unwrap().call(req).await.unwrap();
        match resp.inner {
            Ok(McpResponse::ListTools(result)) => {
                // Sort so the assertion does not depend on listing order.
                let mut names: Vec<&str> =
                    result.tools.iter().map(|t| t.name.as_str()).collect();
                names.sort_unstable();
                assert_eq!(names, expected, "tools_if({enabled}, ..)");
            }
            _ => panic!("Expected ListTools response"),
        }
    }
}
/// `resource_if(true, ..)` adds the resource to the router.
#[test]
fn test_resource_if_true_registers() {
    let router = McpRouter::new().resource_if(
        true,
        crate::resource::ResourceBuilder::new("file:///test.txt")
            .name("test")
            .text("hello"),
    );
    let registered = router.inner.resources.len();
    assert_eq!(registered, 1);
}
/// `resource_if(false, ..)` leaves the router without the resource.
#[test]
fn test_resource_if_false_skips() {
    let router = McpRouter::new().resource_if(
        false,
        crate::resource::ResourceBuilder::new("file:///test.txt")
            .name("test")
            .text("hello"),
    );
    let registered = router.inner.resources.len();
    assert_eq!(registered, 0);
}
/// `prompt_if(true, ..)` adds the prompt to the router.
#[test]
fn test_prompt_if_true_registers() {
    let router = McpRouter::new().prompt_if(
        true,
        crate::prompt::PromptBuilder::new("greet")
            .description("Greeting")
            .user_message("Hello!"),
    );
    let registered = router.inner.prompts.len();
    assert_eq!(registered, 1);
}
/// `prompt_if(false, ..)` leaves the router without the prompt.
#[test]
fn test_prompt_if_false_skips() {
    let router = McpRouter::new().prompt_if(
        false,
        crate::prompt::PromptBuilder::new("greet")
            .description("Greeting")
            .user_message("Hello!"),
    );
    let registered = router.inner.prompts.len();
    assert_eq!(registered, 0);
}
/// Disabling one tool removes it from `tools/list` while the others stay visible.
#[tokio::test]
async fn test_disable_tool_hides_from_list() {
    let mut router = McpRouter::new()
        .tool(
            ToolBuilder::new("safe")
                .description("Safe tool")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        )
        .tool(
            ToolBuilder::new("dangerous")
                .description("Dangerous tool")
                .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
                .build(),
        );
    init_router(&mut router).await;

    router.disable_tool("dangerous");
    assert!(router.is_tool_enabled("safe"));
    assert!(!router.is_tool_enabled("dangerous"));

    let list_request = RouterRequest {
        id: RequestId::Number(1),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = router
        .ready()
        .await
        .unwrap()
        .call(list_request)
        .await
        .unwrap();
    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    let visible: Vec<&str> = listing.tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(visible, vec!["safe"]);
}
#[tokio::test]
async fn test_disable_tool_blocks_call() {
let dangerous = ToolBuilder::new("dangerous")
.description("Dangerous tool")
.handler(|_: AddInput| async { Ok(CallToolResult::text("ran")) })
.build();
let mut router = McpRouter::new().tool(dangerous);
init_router(&mut router).await;
router.disable_tool("dangerous");
let req = RouterRequest {
id: RequestId::Number(2),
inner: McpRequest::CallTool(CallToolParams {
name: "dangerous".to_string(),
arguments: serde_json::json!({"a": 1, "b": 2}),
meta: None,
task: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
let err = resp.inner.expect_err("disabled tool should error");
assert_eq!(err.code, crate::error::ErrorCode::MethodNotFound as i32);
}
/// Re-enabling a previously disabled tool makes it callable again.
#[tokio::test]
async fn test_enable_tool_restores_visibility() {
    let mut router = McpRouter::new().tool(
        ToolBuilder::new("flippy")
            .description("Toggleable tool")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("ran")) })
            .build(),
    );
    init_router(&mut router).await;

    // Toggle off and straight back on; only the final state should matter.
    router.disable_tool("flippy");
    router.enable_tool("flippy");
    assert!(router.is_tool_enabled("flippy"));

    let call = CallToolParams {
        name: "flippy".to_string(),
        arguments: serde_json::json!({"a": 1, "b": 2}),
        meta: None,
        task: None,
    };
    let call_request = RouterRequest {
        id: RequestId::Number(3),
        inner: McpRequest::CallTool(call),
        extensions: Extensions::new(),
    };
    let response = router
        .ready()
        .await
        .unwrap()
        .call(call_request)
        .await
        .unwrap();
    let Ok(McpResponse::CallTool(outcome)) = response.inner else {
        panic!("Expected CallTool response");
    };
    assert_eq!(outcome.first_text(), Some("ran"));
}
/// A tool disabled on the parent router is still disabled in a router
/// obtained from `with_fresh_session`.
#[tokio::test]
async fn test_disable_propagates_through_fresh_session() {
    let parent = McpRouter::new().tool(
        ToolBuilder::new("shared")
            .description("Shared across sessions")
            .handler(|_: AddInput| async { Ok(CallToolResult::text("ok")) })
            .build(),
    );
    // Disable before the session is derived; the child must observe it.
    parent.disable_tool("shared");

    let mut session = parent.with_fresh_session();
    init_router(&mut session).await;
    assert!(!session.is_tool_enabled("shared"));

    let list_request = RouterRequest {
        id: RequestId::Number(4),
        inner: McpRequest::ListTools(ListToolsParams::default()),
        extensions: Extensions::new(),
    };
    let response = session
        .ready()
        .await
        .unwrap()
        .call(list_request)
        .await
        .unwrap();
    let Ok(McpResponse::ListTools(listing)) = response.inner else {
        panic!("Expected ListTools response");
    };
    assert!(listing.tools.is_empty());
}
#[tokio::test]
async fn test_disable_resource_and_prompt() {
let resource = crate::resource::ResourceBuilder::new("file:///hidden.txt")
.name("hidden")
.text("secret");
let prompt = crate::prompt::PromptBuilder::new("hidden_prompt")
.description("hidden")
.user_message("hello");
let mut router = McpRouter::new().resource(resource).prompt(prompt);
init_router(&mut router).await;
router.disable_resource("file:///hidden.txt");
router.disable_prompt("hidden_prompt");
assert!(!router.is_resource_enabled("file:///hidden.txt"));
assert!(!router.is_prompt_enabled("hidden_prompt"));
let req = RouterRequest {
id: RequestId::Number(5),
inner: McpRequest::ListResources(ListResourcesParams::default()),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ListResources(result)) => {
assert!(result.resources.is_empty());
}
_ => panic!("Expected ListResources response"),
}
let req = RouterRequest {
id: RequestId::Number(6),
inner: McpRequest::ReadResource(ReadResourceParams {
uri: "file:///hidden.txt".to_string(),
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
let err = resp.inner.expect_err("disabled resource should error");
assert_eq!(err.code, -32002);
let req = RouterRequest {
id: RequestId::Number(7),
inner: McpRequest::ListPrompts(ListPromptsParams::default()),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
match resp.inner {
Ok(McpResponse::ListPrompts(result)) => {
assert!(result.prompts.is_empty());
}
_ => panic!("Expected ListPrompts response"),
}
let req = RouterRequest {
id: RequestId::Number(8),
inner: McpRequest::GetPrompt(GetPromptParams {
name: "hidden_prompt".to_string(),
arguments: Default::default(),
meta: None,
}),
extensions: Extensions::new(),
};
let resp = router.ready().await.unwrap().call(req).await.unwrap();
let err = resp.inner.expect_err("disabled prompt should error");
assert_eq!(err.code, crate::error::ErrorCode::MethodNotFound as i32);
}
/// `RouterRequest::new` stores both the id and the request payload, and
/// starts with an empty extension set.
#[test]
fn test_router_request_new() {
    let req = RouterRequest::new(RequestId::Number(1), McpRequest::Ping);
    assert_eq!(req.id, RequestId::Number(1));
    // Without this check a constructor that dropped or replaced the payload
    // would still pass.
    assert!(matches!(req.inner, McpRequest::Ping));
    assert!(req.extensions.is_empty());
}
/// `with_inner` swaps the request payload but keeps the id and extensions.
#[test]
fn test_with_inner_preserves_extensions() {
    let mut original = RouterRequest::new(RequestId::Number(1), McpRequest::Ping);
    original.extensions.insert(42u32);

    let swapped = original.with_inner(McpRequest::ListTools(Default::default()));

    assert_eq!(swapped.id, RequestId::Number(1));
    assert!(matches!(swapped.inner, McpRequest::ListTools(_)));
    assert_eq!(swapped.extensions.get::<u32>(), Some(&42));
}
/// `with_id_and_inner` replaces both id and payload while carrying the
/// extensions over unchanged.
#[test]
fn test_with_id_and_inner_preserves_extensions() {
    let token = String::from("token-abc");
    let mut original = RouterRequest::new(RequestId::Number(1), McpRequest::Ping);
    original.extensions.insert(token.clone());

    let rewritten = original.with_id_and_inner(
        RequestId::Number(99),
        McpRequest::ListResources(Default::default()),
    );

    assert_eq!(rewritten.id, RequestId::Number(99));
    assert!(matches!(rewritten.inner, McpRequest::ListResources(_)));
    assert_eq!(rewritten.extensions.get::<String>(), Some(&token));
}
/// `clone_with_inner` produces a new request with a different payload and
/// the same extensions, leaving the source request intact.
#[test]
fn test_clone_with_inner_preserves_extensions() {
    let mut source = RouterRequest::new(RequestId::Number(1), McpRequest::Ping);
    source.extensions.insert(true);

    let copy = source.clone_with_inner(McpRequest::ListTools(Default::default()));

    // The source is still usable afterwards and keeps its payload and extensions.
    assert!(matches!(source.inner, McpRequest::Ping));
    assert_eq!(source.extensions.get::<bool>(), Some(&true));
    // The copy carries the new payload plus the same extensions.
    assert!(matches!(copy.inner, McpRequest::ListTools(_)));
    assert_eq!(copy.extensions.get::<bool>(), Some(&true));
}
/// `is_error` reports `false` for an `Ok` payload and `true` for an `Err` payload.
#[test]
fn test_router_response_is_error() {
    let success = RouterResponse {
        id: RequestId::Number(1),
        inner: Ok(McpResponse::Pong(Default::default())),
    };
    let failure = RouterResponse {
        id: RequestId::Number(2),
        inner: Err(JsonRpcError::internal_error("boom")),
    };
    assert!(!success.is_error());
    assert!(failure.is_error());
}
/// `Extensions::len` / `is_empty` track the number of inserted values.
#[test]
fn test_extensions_len_and_is_empty() {
    let mut extensions = Extensions::new();
    assert_eq!(extensions.len(), 0);
    assert!(extensions.is_empty());

    extensions.insert(42u32);
    assert_eq!(extensions.len(), 1);
    assert!(!extensions.is_empty());

    // A value of a different type adds a second entry.
    extensions.insert(String::from("hello"));
    assert_eq!(extensions.len(), 2);
}
/// `RouterResponse` survives a JSON round trip for both success and error
/// payloads. The request id and, for errors, the JSON-RPC error code must be
/// preserved.
#[test]
fn test_router_response_serde_roundtrip() {
    let response = RouterResponse {
        id: RequestId::Number(1),
        inner: Ok(McpResponse::Empty(EmptyResult {})),
    };
    let json = serde_json::to_string(&response).unwrap();
    let deserialized: RouterResponse = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.id, RequestId::Number(1));
    assert!(!deserialized.is_error());

    let response = RouterResponse {
        id: RequestId::String("req-2".into()),
        inner: Err(JsonRpcError::method_not_found("unknown")),
    };
    let expected_code = response
        .inner
        .as_ref()
        .expect_err("response was constructed as an error")
        .code;
    let json = serde_json::to_string(&response).unwrap();
    let deserialized: RouterResponse = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.id, RequestId::String("req-2".into()));
    assert!(deserialized.is_error());
    // Checking only `is_error()` would miss a code that was lost or rewritten
    // during (de)serialization.
    let roundtripped_code = deserialized
        .inner
        .as_ref()
        .expect_err("deserialized response should still be an error")
        .code;
    assert_eq!(roundtripped_code, expected_code);
}
}