1use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17 ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
/// Shared, type-erased async callback used to answer completion requests.
///
/// The callback receives the request's `CompleteParams` and returns a boxed,
/// `Send` future resolving to a `CompleteResult` or an error. Wrapping it in an
/// `Arc` makes the handler cheap to clone.
pub type CompletionHandler = Arc<
    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
        + Send
        + Sync,
>;
33
/// Router that dispatches MCP requests to registered tools, resources and prompts.
///
/// The configuration lives behind an `Arc`, so cloning the router does not copy
/// the registries; the builder methods copy-on-write through `Arc::make_mut`.
#[derive(Clone)]
pub struct McpRouter {
    /// Shared configuration, registries and per-server state.
    inner: Arc<McpRouterInner>,
    /// Session lifecycle state consulted before a request is dispatched.
    session: SessionState,
}
63
64impl std::fmt::Debug for McpRouter {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("McpRouter")
67 .field("server_name", &self.inner.server_name)
68 .field("server_version", &self.inner.server_version)
69 .field("tools_count", &self.inner.tools.len())
70 .field("resources_count", &self.inner.resources.len())
71 .field("prompts_count", &self.inner.prompts.len())
72 .field("session_phase", &self.session.phase())
73 .finish()
74 }
75}
76
/// Configuration and shared state behind an `McpRouter`.
#[derive(Clone)]
struct McpRouterInner {
    /// Server name reported in the `initialize` response.
    server_name: String,
    /// Server version reported in the `initialize` response.
    server_version: String,
    /// Optional human-readable title reported in the `initialize` response.
    server_title: Option<String>,
    /// Optional description reported in the `initialize` response.
    server_description: Option<String>,
    /// Optional icons reported in the `initialize` response.
    server_icons: Option<Vec<ToolIcon>>,
    /// Optional website URL reported in the `initialize` response.
    server_website_url: Option<String>,
    /// Optional instructions returned alongside the `initialize` result.
    instructions: Option<String>,
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, Arc<Tool>>,
    /// Registered static resources, keyed by URI.
    resources: HashMap<String, Arc<Resource>>,
    /// Registered resource templates, tried in order when no static resource matches.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    /// Registered prompts, keyed by prompt name.
    prompts: HashMap<String, Arc<Prompt>>,
    /// Cancellation tokens of currently tracked requests, keyed by request id.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    /// Channel used to push server notifications; `None` disables them.
    notification_tx: Option<NotificationSender>,
    /// Client requester handle attached to every request context when present.
    client_requester: Option<ClientRequesterHandle>,
    /// Store backing the task-related requests.
    task_store: TaskStore,
    /// URIs that currently have an active subscription.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    /// Handler for completion requests; without one an empty result is returned.
    completion_handler: Option<CompletionHandler>,
    /// Optional visibility filter applied to tools.
    tool_filter: Option<ToolFilter>,
    /// Optional visibility filter applied to static resources.
    resource_filter: Option<ResourceFilter>,
    /// Optional visibility filter applied to prompts.
    prompt_filter: Option<PromptFilter>,
    /// Shared typed values handed to every request context.
    extensions: Arc<crate::context::Extensions>,
}
117
118impl McpRouter {
119 pub fn new() -> Self {
121 Self {
122 inner: Arc::new(McpRouterInner {
123 server_name: "tower-mcp".to_string(),
124 server_version: env!("CARGO_PKG_VERSION").to_string(),
125 server_title: None,
126 server_description: None,
127 server_icons: None,
128 server_website_url: None,
129 instructions: None,
130 tools: HashMap::new(),
131 resources: HashMap::new(),
132 resource_templates: Vec::new(),
133 prompts: HashMap::new(),
134 in_flight: Arc::new(RwLock::new(HashMap::new())),
135 notification_tx: None,
136 client_requester: None,
137 task_store: TaskStore::new(),
138 subscriptions: Arc::new(RwLock::new(HashSet::new())),
139 extensions: Arc::new(crate::context::Extensions::new()),
140 completion_handler: None,
141 tool_filter: None,
142 resource_filter: None,
143 prompt_filter: None,
144 }),
145 session: SessionState::new(),
146 }
147 }
148
149 pub fn with_fresh_session(&self) -> Self {
157 Self {
158 inner: self.inner.clone(),
159 session: SessionState::new(),
160 }
161 }
162
163 pub fn task_store(&self) -> &TaskStore {
165 &self.inner.task_store
166 }
167
168 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
172 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
173 self
174 }
175
176 pub fn notification_sender(&self) -> Option<&NotificationSender> {
178 self.inner.notification_tx.as_ref()
179 }
180
181 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
186 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
187 self
188 }
189
190 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
192 self.inner.client_requester.as_ref()
193 }
194
195 pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
239 let inner = Arc::make_mut(&mut self.inner);
240 Arc::make_mut(&mut inner.extensions).insert(state);
241 self
242 }
243
244 pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
249 self.with_state(value)
250 }
251
252 pub fn extensions(&self) -> &crate::context::Extensions {
254 &self.inner.extensions
255 }
256
257 pub fn create_context(
262 &self,
263 request_id: RequestId,
264 progress_token: Option<ProgressToken>,
265 ) -> RequestContext {
266 let ctx = RequestContext::new(request_id.clone());
267
268 let ctx = if let Some(token) = progress_token {
270 ctx.with_progress_token(token)
271 } else {
272 ctx
273 };
274
275 let ctx = if let Some(tx) = &self.inner.notification_tx {
277 ctx.with_notification_sender(tx.clone())
278 } else {
279 ctx
280 };
281
282 let ctx = if let Some(requester) = &self.inner.client_requester {
284 ctx.with_client_requester(requester.clone())
285 } else {
286 ctx
287 };
288
289 let ctx = ctx.with_extensions(self.inner.extensions.clone());
291
292 let token = ctx.cancellation_token();
294 if let Ok(mut in_flight) = self.inner.in_flight.write() {
295 in_flight.insert(request_id, token);
296 }
297
298 ctx
299 }
300
301 pub fn complete_request(&self, request_id: &RequestId) {
303 if let Ok(mut in_flight) = self.inner.in_flight.write() {
304 in_flight.remove(request_id);
305 }
306 }
307
308 fn cancel_request(&self, request_id: &RequestId) -> bool {
310 let Ok(in_flight) = self.inner.in_flight.read() else {
311 return false;
312 };
313 let Some(token) = in_flight.get(request_id) else {
314 return false;
315 };
316 token.cancel();
317 true
318 }
319
320 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
322 let inner = Arc::make_mut(&mut self.inner);
323 inner.server_name = name.into();
324 inner.server_version = version.into();
325 self
326 }
327
328 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
330 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
331 self
332 }
333
334 pub fn server_title(mut self, title: impl Into<String>) -> Self {
336 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
337 self
338 }
339
340 pub fn server_description(mut self, description: impl Into<String>) -> Self {
342 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
343 self
344 }
345
346 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
348 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
349 self
350 }
351
352 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
354 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
355 self
356 }
357
358 pub fn tool(mut self, tool: Tool) -> Self {
360 Arc::make_mut(&mut self.inner)
361 .tools
362 .insert(tool.name.clone(), Arc::new(tool));
363 self
364 }
365
366 pub fn resource(mut self, resource: Resource) -> Self {
368 Arc::make_mut(&mut self.inner)
369 .resources
370 .insert(resource.uri.clone(), Arc::new(resource));
371 self
372 }
373
374 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
405 Arc::make_mut(&mut self.inner)
406 .resource_templates
407 .push(Arc::new(template));
408 self
409 }
410
411 pub fn prompt(mut self, prompt: Prompt) -> Self {
413 Arc::make_mut(&mut self.inner)
414 .prompts
415 .insert(prompt.name.clone(), Arc::new(prompt));
416 self
417 }
418
419 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
445 tools
446 .into_iter()
447 .fold(self, |router, tool| router.tool(tool))
448 }
449
450 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
469 resources
470 .into_iter()
471 .fold(self, |router, resource| router.resource(resource))
472 }
473
474 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
493 prompts
494 .into_iter()
495 .fold(self, |router, prompt| router.prompt(prompt))
496 }
497
498 pub fn merge(mut self, other: McpRouter) -> Self {
545 let inner = Arc::make_mut(&mut self.inner);
546 let other_inner = other.inner;
547
548 for (name, tool) in &other_inner.tools {
550 inner.tools.insert(name.clone(), tool.clone());
551 }
552
553 for (uri, resource) in &other_inner.resources {
555 inner.resources.insert(uri.clone(), resource.clone());
556 }
557
558 for template in &other_inner.resource_templates {
561 inner.resource_templates.push(template.clone());
562 }
563
564 for (name, prompt) in &other_inner.prompts {
566 inner.prompts.insert(name.clone(), prompt.clone());
567 }
568
569 self
570 }
571
572 pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
614 let prefix = prefix.into();
615 let inner = Arc::make_mut(&mut self.inner);
616 let other_inner = other.inner;
617
618 for tool in other_inner.tools.values() {
620 let prefixed_tool = tool.with_name_prefix(&prefix);
621 inner
622 .tools
623 .insert(prefixed_tool.name.clone(), Arc::new(prefixed_tool));
624 }
625
626 for (uri, resource) in &other_inner.resources {
628 inner.resources.insert(uri.clone(), resource.clone());
629 }
630
631 for template in &other_inner.resource_templates {
633 inner.resource_templates.push(template.clone());
634 }
635
636 for (name, prompt) in &other_inner.prompts {
638 inner.prompts.insert(name.clone(), prompt.clone());
639 }
640
641 self
642 }
643
644 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
671 where
672 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
673 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
674 {
675 Arc::make_mut(&mut self.inner).completion_handler =
676 Some(Arc::new(move |params| Box::pin(handler(params))));
677 self
678 }
679
680 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
717 Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
718 self
719 }
720
721 pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
752 Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
753 self
754 }
755
756 pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
785 Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
786 self
787 }
788
789 pub fn session(&self) -> &SessionState {
791 &self.session
792 }
793
794 pub fn log(&self, params: LoggingMessageParams) -> bool {
816 let Some(tx) = &self.inner.notification_tx else {
817 return false;
818 };
819 tx.try_send(ServerNotification::LogMessage(params)).is_ok()
820 }
821
822 pub fn log_info(&self, message: &str) -> bool {
826 self.log(
827 LoggingMessageParams::new(LogLevel::Info)
828 .with_data(serde_json::json!({ "message": message })),
829 )
830 }
831
832 pub fn log_warning(&self, message: &str) -> bool {
834 self.log(
835 LoggingMessageParams::new(LogLevel::Warning)
836 .with_data(serde_json::json!({ "message": message })),
837 )
838 }
839
840 pub fn log_error(&self, message: &str) -> bool {
842 self.log(
843 LoggingMessageParams::new(LogLevel::Error)
844 .with_data(serde_json::json!({ "message": message })),
845 )
846 }
847
848 pub fn log_debug(&self, message: &str) -> bool {
850 self.log(
851 LoggingMessageParams::new(LogLevel::Debug)
852 .with_data(serde_json::json!({ "message": message })),
853 )
854 }
855
856 pub fn is_subscribed(&self, uri: &str) -> bool {
858 if let Ok(subs) = self.inner.subscriptions.read() {
859 return subs.contains(uri);
860 }
861 false
862 }
863
864 pub fn subscribed_uris(&self) -> Vec<String> {
866 if let Ok(subs) = self.inner.subscriptions.read() {
867 return subs.iter().cloned().collect();
868 }
869 Vec::new()
870 }
871
872 fn subscribe(&self, uri: &str) -> bool {
874 if let Ok(mut subs) = self.inner.subscriptions.write() {
875 return subs.insert(uri.to_string());
876 }
877 false
878 }
879
880 fn unsubscribe(&self, uri: &str) -> bool {
882 if let Ok(mut subs) = self.inner.subscriptions.write() {
883 return subs.remove(uri);
884 }
885 false
886 }
887
888 pub fn notify_resource_updated(&self, uri: &str) -> bool {
893 if !self.is_subscribed(uri) {
895 return false;
896 }
897
898 let Some(tx) = &self.inner.notification_tx else {
899 return false;
900 };
901 tx.try_send(ServerNotification::ResourceUpdated {
902 uri: uri.to_string(),
903 })
904 .is_ok()
905 }
906
907 pub fn notify_resources_list_changed(&self) -> bool {
911 let Some(tx) = &self.inner.notification_tx else {
912 return false;
913 };
914 tx.try_send(ServerNotification::ResourcesListChanged)
915 .is_ok()
916 }
917
918 fn capabilities(&self) -> ServerCapabilities {
920 let has_resources =
921 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
922
923 ServerCapabilities {
924 tools: if self.inner.tools.is_empty() {
925 None
926 } else {
927 Some(ToolsCapability::default())
928 },
929 resources: if has_resources {
930 Some(ResourcesCapability {
931 subscribe: true,
932 ..Default::default()
933 })
934 } else {
935 None
936 },
937 prompts: if self.inner.prompts.is_empty() {
938 None
939 } else {
940 Some(PromptsCapability::default())
941 },
942 logging: if self.inner.notification_tx.is_some() {
944 Some(LoggingCapability::default())
945 } else {
946 None
947 },
948 tasks: Some(TasksCapability::default()),
950 completions: if self.inner.completion_handler.is_some() {
952 Some(CompletionsCapability::default())
953 } else {
954 None
955 },
956 }
957 }
958
959 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
961 let method = request.method_name();
963 if !self.session.is_request_allowed(method) {
964 tracing::warn!(
965 method = %method,
966 phase = ?self.session.phase(),
967 "Request rejected: session not initialized"
968 );
969 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
970 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
971 method
972 ))));
973 }
974
975 match request {
976 McpRequest::Initialize(params) => {
977 tracing::info!(
978 client = %params.client_info.name,
979 version = %params.client_info.version,
980 "Client initializing"
981 );
982
983 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
986 .contains(¶ms.protocol_version.as_str())
987 {
988 params.protocol_version
989 } else {
990 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
991 };
992
993 self.session.mark_initializing();
995
996 Ok(McpResponse::Initialize(InitializeResult {
997 protocol_version,
998 capabilities: self.capabilities(),
999 server_info: Implementation {
1000 name: self.inner.server_name.clone(),
1001 version: self.inner.server_version.clone(),
1002 title: self.inner.server_title.clone(),
1003 description: self.inner.server_description.clone(),
1004 icons: self.inner.server_icons.clone(),
1005 website_url: self.inner.server_website_url.clone(),
1006 },
1007 instructions: self.inner.instructions.clone(),
1008 }))
1009 }
1010
1011 McpRequest::ListTools(_params) => {
1012 let tools: Vec<ToolDefinition> = self
1013 .inner
1014 .tools
1015 .values()
1016 .filter(|t| {
1017 self.inner
1019 .tool_filter
1020 .as_ref()
1021 .map(|f| f.is_visible(&self.session, t))
1022 .unwrap_or(true)
1023 })
1024 .map(|t| t.definition())
1025 .collect();
1026
1027 Ok(McpResponse::ListTools(ListToolsResult {
1028 tools,
1029 next_cursor: None,
1030 }))
1031 }
1032
1033 McpRequest::CallTool(params) => {
1034 let tool =
1035 self.inner.tools.get(¶ms.name).ok_or_else(|| {
1036 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
1037 })?;
1038
1039 if let Some(filter) = &self.inner.tool_filter {
1041 if !filter.is_visible(&self.session, tool) {
1042 return Err(filter.denial_error(¶ms.name));
1043 }
1044 }
1045
1046 let progress_token = params.meta.and_then(|m| m.progress_token);
1048 let ctx = self.create_context(request_id, progress_token);
1049
1050 tracing::debug!(tool = %params.name, "Calling tool");
1051 let result = tool.call_with_context(ctx, params.arguments).await;
1052
1053 Ok(McpResponse::CallTool(result))
1054 }
1055
1056 McpRequest::ListResources(_params) => {
1057 let resources: Vec<ResourceDefinition> = self
1058 .inner
1059 .resources
1060 .values()
1061 .filter(|r| {
1062 self.inner
1064 .resource_filter
1065 .as_ref()
1066 .map(|f| f.is_visible(&self.session, r))
1067 .unwrap_or(true)
1068 })
1069 .map(|r| r.definition())
1070 .collect();
1071
1072 Ok(McpResponse::ListResources(ListResourcesResult {
1073 resources,
1074 next_cursor: None,
1075 }))
1076 }
1077
1078 McpRequest::ListResourceTemplates(_params) => {
1079 let resource_templates: Vec<ResourceTemplateDefinition> = self
1080 .inner
1081 .resource_templates
1082 .iter()
1083 .map(|t| t.definition())
1084 .collect();
1085
1086 Ok(McpResponse::ListResourceTemplates(
1087 ListResourceTemplatesResult {
1088 resource_templates,
1089 next_cursor: None,
1090 },
1091 ))
1092 }
1093
1094 McpRequest::ReadResource(params) => {
1095 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
1097 if let Some(filter) = &self.inner.resource_filter {
1099 if !filter.is_visible(&self.session, resource) {
1100 return Err(filter.denial_error(¶ms.uri));
1101 }
1102 }
1103
1104 tracing::debug!(uri = %params.uri, "Reading static resource");
1105 let result = resource.read().await;
1106 return Ok(McpResponse::ReadResource(result));
1107 }
1108
1109 for template in &self.inner.resource_templates {
1111 if let Some(variables) = template.match_uri(¶ms.uri) {
1112 tracing::debug!(
1113 uri = %params.uri,
1114 template = %template.uri_template,
1115 "Reading resource via template"
1116 );
1117 let result = template.read(¶ms.uri, variables).await?;
1118 return Ok(McpResponse::ReadResource(result));
1119 }
1120 }
1121
1122 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1124 ¶ms.uri,
1125 )))
1126 }
1127
1128 McpRequest::SubscribeResource(params) => {
1129 if !self.inner.resources.contains_key(¶ms.uri) {
1131 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1132 ¶ms.uri,
1133 )));
1134 }
1135
1136 tracing::debug!(uri = %params.uri, "Subscribing to resource");
1137 self.subscribe(¶ms.uri);
1138
1139 Ok(McpResponse::SubscribeResource(EmptyResult {}))
1140 }
1141
1142 McpRequest::UnsubscribeResource(params) => {
1143 if !self.inner.resources.contains_key(¶ms.uri) {
1145 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1146 ¶ms.uri,
1147 )));
1148 }
1149
1150 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
1151 self.unsubscribe(¶ms.uri);
1152
1153 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
1154 }
1155
1156 McpRequest::ListPrompts(_params) => {
1157 let prompts: Vec<PromptDefinition> = self
1158 .inner
1159 .prompts
1160 .values()
1161 .filter(|p| {
1162 self.inner
1164 .prompt_filter
1165 .as_ref()
1166 .map(|f| f.is_visible(&self.session, p))
1167 .unwrap_or(true)
1168 })
1169 .map(|p| p.definition())
1170 .collect();
1171
1172 Ok(McpResponse::ListPrompts(ListPromptsResult {
1173 prompts,
1174 next_cursor: None,
1175 }))
1176 }
1177
1178 McpRequest::GetPrompt(params) => {
1179 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
1180 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1181 "Prompt not found: {}",
1182 params.name
1183 )))
1184 })?;
1185
1186 if let Some(filter) = &self.inner.prompt_filter {
1188 if !filter.is_visible(&self.session, prompt) {
1189 return Err(filter.denial_error(¶ms.name));
1190 }
1191 }
1192
1193 tracing::debug!(name = %params.name, "Getting prompt");
1194 let result = prompt.get(params.arguments).await?;
1195
1196 Ok(McpResponse::GetPrompt(result))
1197 }
1198
1199 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
1200
1201 McpRequest::EnqueueTask(params) => {
1202 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
1204 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1205 "Tool not found: {}",
1206 params.tool_name
1207 )))
1208 })?;
1209
1210 let (task_id, cancellation_token) = self.inner.task_store.create_task(
1212 ¶ms.tool_name,
1213 params.arguments.clone(),
1214 params.ttl,
1215 );
1216
1217 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
1218
1219 let ctx = self.create_context(request_id, None);
1221
1222 let task_store = self.inner.task_store.clone();
1224 let tool = tool.clone();
1225 let arguments = params.arguments;
1226 let task_id_clone = task_id.clone();
1227
1228 tokio::spawn(async move {
1229 if cancellation_token.is_cancelled() {
1231 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1232 return;
1233 }
1234
1235 let result = tool.call_with_context(ctx, arguments).await;
1237
1238 if cancellation_token.is_cancelled() {
1239 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1240 } else if result.is_error {
1241 let error_msg = result.first_text().unwrap_or("Tool execution failed");
1243 task_store.fail_task(&task_id_clone, error_msg);
1244 tracing::warn!(task_id = %task_id_clone, error = %error_msg, "Task failed");
1245 } else {
1246 task_store.complete_task(&task_id_clone, result);
1247 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1248 }
1249 });
1250
1251 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1252 task_id,
1253 status: TaskStatus::Working,
1254 poll_interval: Some(2),
1255 }))
1256 }
1257
1258 McpRequest::ListTasks(params) => {
1259 let tasks = self.inner.task_store.list_tasks(params.status);
1260
1261 Ok(McpResponse::ListTasks(ListTasksResult {
1262 tasks,
1263 next_cursor: None,
1264 }))
1265 }
1266
1267 McpRequest::GetTaskInfo(params) => {
1268 let task = self
1269 .inner
1270 .task_store
1271 .get_task(¶ms.task_id)
1272 .ok_or_else(|| {
1273 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1274 "Task not found: {}",
1275 params.task_id
1276 )))
1277 })?;
1278
1279 Ok(McpResponse::GetTaskInfo(task))
1280 }
1281
1282 McpRequest::GetTaskResult(params) => {
1283 let (status, result, error) = self
1284 .inner
1285 .task_store
1286 .get_task_full(¶ms.task_id)
1287 .ok_or_else(|| {
1288 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1289 "Task not found: {}",
1290 params.task_id
1291 )))
1292 })?;
1293
1294 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1295 task_id: params.task_id,
1296 status,
1297 result,
1298 error,
1299 }))
1300 }
1301
1302 McpRequest::CancelTask(params) => {
1303 let status = self
1304 .inner
1305 .task_store
1306 .cancel_task(¶ms.task_id, params.reason.as_deref())
1307 .ok_or_else(|| {
1308 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1309 "Task not found: {}",
1310 params.task_id
1311 )))
1312 })?;
1313
1314 let cancelled = status == TaskStatus::Cancelled;
1315
1316 Ok(McpResponse::CancelTask(CancelTaskResult {
1317 cancelled,
1318 status,
1319 }))
1320 }
1321
1322 McpRequest::SetLoggingLevel(params) => {
1323 tracing::debug!(level = ?params.level, "Client set logging level");
1327 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1328 }
1329
1330 McpRequest::Complete(params) => {
1331 tracing::debug!(
1332 reference = ?params.reference,
1333 argument = %params.argument.name,
1334 "Completion request"
1335 );
1336
1337 if let Some(ref handler) = self.inner.completion_handler {
1339 let result = handler(params).await?;
1340 Ok(McpResponse::Complete(result))
1341 } else {
1342 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1344 }
1345 }
1346
1347 McpRequest::Unknown { method, .. } => {
1348 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1349 }
1350 }
1351 }
1352
1353 pub fn handle_notification(&self, notification: McpNotification) {
1355 match notification {
1356 McpNotification::Initialized => {
1357 if self.session.mark_initialized() {
1358 tracing::info!("Session initialized, entering operation phase");
1359 } else {
1360 tracing::warn!(
1361 "Received initialized notification in unexpected state: {:?}",
1362 self.session.phase()
1363 );
1364 }
1365 }
1366 McpNotification::Cancelled(params) => {
1367 if self.cancel_request(¶ms.request_id) {
1368 tracing::info!(
1369 request_id = ?params.request_id,
1370 reason = ?params.reason,
1371 "Request cancelled"
1372 );
1373 } else {
1374 tracing::debug!(
1375 request_id = ?params.request_id,
1376 reason = ?params.reason,
1377 "Cancellation requested for unknown request"
1378 );
1379 }
1380 }
1381 McpNotification::Progress(params) => {
1382 tracing::trace!(
1383 token = ?params.progress_token,
1384 progress = params.progress,
1385 total = ?params.total,
1386 "Progress notification"
1387 );
1388 }
1390 McpNotification::RootsListChanged => {
1391 tracing::info!("Client roots list changed");
1392 }
1395 McpNotification::Unknown { method, .. } => {
1396 tracing::debug!(method = %method, "Unknown notification received");
1397 }
1398 }
1399 }
1400}
1401
1402impl Default for McpRouter {
1403 fn default() -> Self {
1404 Self::new()
1405 }
1406}
1407
1408pub use crate::context::Extensions;
1414
/// Request type accepted by the router's `Service` implementation.
#[derive(Debug, Clone)]
pub struct RouterRequest {
    /// JSON-RPC request id; echoed back in the matching `RouterResponse`.
    pub id: RequestId,
    /// The MCP request to dispatch.
    pub inner: McpRequest,
    /// Extensions supplied alongside the request.
    pub extensions: Extensions,
}
1423
/// Response type produced by the router's `Service` implementation.
#[derive(Debug, Clone)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// The successful MCP response, or the JSON-RPC error to report.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1430
1431impl RouterResponse {
1432 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1434 match self.inner {
1435 Ok(response) => match serde_json::to_value(response) {
1436 Ok(result) => JsonRpcResponse::result(self.id, result),
1437 Err(e) => {
1438 tracing::error!(error = %e, "Failed to serialize response");
1439 JsonRpcResponse::error(
1440 Some(self.id),
1441 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1442 )
1443 }
1444 },
1445 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1446 }
1447 }
1448}
1449
1450impl Service<RouterRequest> for McpRouter {
1451 type Response = RouterResponse;
1452 type Error = std::convert::Infallible; type Future =
1454 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1455
1456 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1457 Poll::Ready(Ok(()))
1458 }
1459
1460 fn call(&mut self, req: RouterRequest) -> Self::Future {
1461 let router = self.clone();
1462 let request_id = req.id.clone();
1463 Box::pin(async move {
1464 let result = router.handle(req.id, req.inner).await;
1465 router.complete_request(&request_id);
1467 Ok(RouterResponse {
1468 id: request_id,
1469 inner: result.map_err(|e| match e {
1474 Error::JsonRpc(err) => err,
1475 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1476 e => JsonRpcError::internal_error(e.to_string()),
1477 }),
1478 })
1479 })
1480 }
1481}
1482
1483#[cfg(test)]
1484mod tests {
1485 use super::*;
1486 use crate::extract::{Context, Json};
1487 use crate::jsonrpc::JsonRpcService;
1488 use crate::tool::ToolBuilder;
1489 use schemars::JsonSchema;
1490 use serde::Deserialize;
1491 use tower::ServiceExt;
1492
/// Input payload for the `add` test tool: the two operands to sum.
#[derive(Debug, Deserialize, JsonSchema)]
struct AddInput {
    /// First operand.
    a: i64,
    /// Second operand.
    b: i64,
}
1498
1499 async fn init_router(router: &mut McpRouter) {
1501 let init_req = RouterRequest {
1503 id: RequestId::Number(0),
1504 inner: McpRequest::Initialize(InitializeParams {
1505 protocol_version: "2025-11-25".to_string(),
1506 capabilities: ClientCapabilities {
1507 roots: None,
1508 sampling: None,
1509 elicitation: None,
1510 },
1511 client_info: Implementation {
1512 name: "test".to_string(),
1513 version: "1.0".to_string(),
1514 ..Default::default()
1515 },
1516 }),
1517 extensions: Extensions::new(),
1518 };
1519 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1520 router.handle_notification(McpNotification::Initialized);
1522 }
1523
1524 #[tokio::test]
1525 async fn test_router_list_tools() {
1526 let add_tool = ToolBuilder::new("add")
1527 .description("Add two numbers")
1528 .handler(|input: AddInput| async move {
1529 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1530 })
1531 .build()
1532 .expect("valid tool name");
1533
1534 let mut router = McpRouter::new().tool(add_tool);
1535
1536 init_router(&mut router).await;
1538
1539 let req = RouterRequest {
1540 id: RequestId::Number(1),
1541 inner: McpRequest::ListTools(ListToolsParams::default()),
1542 extensions: Extensions::new(),
1543 };
1544
1545 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1546
1547 match resp.inner {
1548 Ok(McpResponse::ListTools(result)) => {
1549 assert_eq!(result.tools.len(), 1);
1550 assert_eq!(result.tools[0].name, "add");
1551 }
1552 _ => panic!("Expected ListTools response"),
1553 }
1554 }
1555
1556 #[tokio::test]
1557 async fn test_router_call_tool() {
1558 let add_tool = ToolBuilder::new("add")
1559 .description("Add two numbers")
1560 .handler(|input: AddInput| async move {
1561 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1562 })
1563 .build()
1564 .expect("valid tool name");
1565
1566 let mut router = McpRouter::new().tool(add_tool);
1567
1568 init_router(&mut router).await;
1570
1571 let req = RouterRequest {
1572 id: RequestId::Number(1),
1573 inner: McpRequest::CallTool(CallToolParams {
1574 name: "add".to_string(),
1575 arguments: serde_json::json!({"a": 2, "b": 3}),
1576 meta: None,
1577 }),
1578 extensions: Extensions::new(),
1579 };
1580
1581 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1582
1583 match resp.inner {
1584 Ok(McpResponse::CallTool(result)) => {
1585 assert!(!result.is_error);
1586 match &result.content[0] {
1588 Content::Text { text, .. } => assert_eq!(text, "5"),
1589 _ => panic!("Expected text content"),
1590 }
1591 }
1592 _ => panic!("Expected CallTool response"),
1593 }
1594 }
1595
1596 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1598 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1599 "protocolVersion": "2025-11-25",
1600 "capabilities": {},
1601 "clientInfo": { "name": "test", "version": "1.0" }
1602 }));
1603 let _ = service.call_single(init_req).await.unwrap();
1604 router.handle_notification(McpNotification::Initialized);
1605 }
1606
1607 #[tokio::test]
1608 async fn test_jsonrpc_service() {
1609 let add_tool = ToolBuilder::new("add")
1610 .description("Add two numbers")
1611 .handler(|input: AddInput| async move {
1612 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1613 })
1614 .build()
1615 .expect("valid tool name");
1616
1617 let router = McpRouter::new().tool(add_tool);
1618 let mut service = JsonRpcService::new(router.clone());
1619
1620 init_jsonrpc_service(&mut service, &router).await;
1622
1623 let req = JsonRpcRequest::new(1, "tools/list");
1624
1625 let resp = service.call_single(req).await.unwrap();
1626
1627 match resp {
1628 JsonRpcResponse::Result(r) => {
1629 assert_eq!(r.id, RequestId::Number(1));
1630 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1631 assert_eq!(tools.len(), 1);
1632 }
1633 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1634 }
1635 }
1636
1637 #[tokio::test]
1638 async fn test_batch_request() {
1639 let add_tool = ToolBuilder::new("add")
1640 .description("Add two numbers")
1641 .handler(|input: AddInput| async move {
1642 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1643 })
1644 .build()
1645 .expect("valid tool name");
1646
1647 let router = McpRouter::new().tool(add_tool);
1648 let mut service = JsonRpcService::new(router.clone());
1649
1650 init_jsonrpc_service(&mut service, &router).await;
1652
1653 let requests = vec![
1655 JsonRpcRequest::new(1, "tools/list"),
1656 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1657 "name": "add",
1658 "arguments": {"a": 10, "b": 20}
1659 })),
1660 JsonRpcRequest::new(3, "ping"),
1661 ];
1662
1663 let responses = service.call_batch(requests).await.unwrap();
1664
1665 assert_eq!(responses.len(), 3);
1666
1667 match &responses[0] {
1669 JsonRpcResponse::Result(r) => {
1670 assert_eq!(r.id, RequestId::Number(1));
1671 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1672 assert_eq!(tools.len(), 1);
1673 }
1674 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1675 }
1676
1677 match &responses[1] {
1679 JsonRpcResponse::Result(r) => {
1680 assert_eq!(r.id, RequestId::Number(2));
1681 let content = r.result.get("content").unwrap().as_array().unwrap();
1682 let text = content[0].get("text").unwrap().as_str().unwrap();
1683 assert_eq!(text, "30");
1684 }
1685 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1686 }
1687
1688 match &responses[2] {
1690 JsonRpcResponse::Result(r) => {
1691 assert_eq!(r.id, RequestId::Number(3));
1692 }
1693 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1694 }
1695 }
1696
1697 #[tokio::test]
1698 async fn test_empty_batch_error() {
1699 let router = McpRouter::new();
1700 let mut service = JsonRpcService::new(router);
1701
1702 let result = service.call_batch(vec![]).await;
1703 assert!(result.is_err());
1704 }
1705
1706 #[tokio::test]
1711 async fn test_progress_token_extraction() {
1712 use crate::context::{ServerNotification, notification_channel};
1713 use crate::protocol::ProgressToken;
1714 use std::sync::Arc;
1715 use std::sync::atomic::{AtomicBool, Ordering};
1716
1717 let progress_reported = Arc::new(AtomicBool::new(false));
1719 let progress_ref = progress_reported.clone();
1720
1721 let tool = ToolBuilder::new("progress_tool")
1723 .description("Tool that reports progress")
1724 .extractor_handler_typed::<_, _, _, AddInput>(
1725 (),
1726 move |ctx: Context, Json(_input): Json<AddInput>| {
1727 let reported = progress_ref.clone();
1728 async move {
1729 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1731 .await;
1732 reported.store(true, Ordering::SeqCst);
1733 Ok(CallToolResult::text("done"))
1734 }
1735 },
1736 )
1737 .build()
1738 .expect("valid tool name");
1739
1740 let (tx, mut rx) = notification_channel(10);
1742 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1743 let mut service = JsonRpcService::new(router.clone());
1744
1745 init_jsonrpc_service(&mut service, &router).await;
1747
1748 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1750 "name": "progress_tool",
1751 "arguments": {"a": 1, "b": 2},
1752 "_meta": {
1753 "progressToken": "test-token-123"
1754 }
1755 }));
1756
1757 let resp = service.call_single(req).await.unwrap();
1758
1759 match resp {
1761 JsonRpcResponse::Result(_) => {}
1762 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1763 }
1764
1765 assert!(progress_reported.load(Ordering::SeqCst));
1767
1768 let notification = rx.try_recv().expect("Expected progress notification");
1770 match notification {
1771 ServerNotification::Progress(params) => {
1772 assert_eq!(
1773 params.progress_token,
1774 ProgressToken::String("test-token-123".to_string())
1775 );
1776 assert_eq!(params.progress, 50.0);
1777 assert_eq!(params.total, Some(100.0));
1778 assert_eq!(params.message.as_deref(), Some("Halfway"));
1779 }
1780 _ => panic!("Expected Progress notification"),
1781 }
1782 }
1783
1784 #[tokio::test]
1785 async fn test_tool_call_without_progress_token() {
1786 use crate::context::notification_channel;
1787 use std::sync::Arc;
1788 use std::sync::atomic::{AtomicBool, Ordering};
1789
1790 let progress_attempted = Arc::new(AtomicBool::new(false));
1791 let progress_ref = progress_attempted.clone();
1792
1793 let tool = ToolBuilder::new("no_token_tool")
1794 .description("Tool that tries to report progress without token")
1795 .extractor_handler_typed::<_, _, _, AddInput>(
1796 (),
1797 move |ctx: Context, Json(_input): Json<AddInput>| {
1798 let attempted = progress_ref.clone();
1799 async move {
1800 ctx.report_progress(50.0, Some(100.0), None).await;
1802 attempted.store(true, Ordering::SeqCst);
1803 Ok(CallToolResult::text("done"))
1804 }
1805 },
1806 )
1807 .build()
1808 .expect("valid tool name");
1809
1810 let (tx, mut rx) = notification_channel(10);
1811 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1812 let mut service = JsonRpcService::new(router.clone());
1813
1814 init_jsonrpc_service(&mut service, &router).await;
1815
1816 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1818 "name": "no_token_tool",
1819 "arguments": {"a": 1, "b": 2}
1820 }));
1821
1822 let resp = service.call_single(req).await.unwrap();
1823 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1824
1825 assert!(progress_attempted.load(Ordering::SeqCst));
1827
1828 assert!(rx.try_recv().is_err());
1830 }
1831
1832 #[tokio::test]
1833 async fn test_batch_errors_returned_not_dropped() {
1834 let add_tool = ToolBuilder::new("add")
1835 .description("Add two numbers")
1836 .handler(|input: AddInput| async move {
1837 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1838 })
1839 .build()
1840 .expect("valid tool name");
1841
1842 let router = McpRouter::new().tool(add_tool);
1843 let mut service = JsonRpcService::new(router.clone());
1844
1845 init_jsonrpc_service(&mut service, &router).await;
1846
1847 let requests = vec![
1849 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1851 "name": "add",
1852 "arguments": {"a": 10, "b": 20}
1853 })),
1854 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1856 "name": "nonexistent_tool",
1857 "arguments": {}
1858 })),
1859 JsonRpcRequest::new(3, "ping"),
1861 ];
1862
1863 let responses = service.call_batch(requests).await.unwrap();
1864
1865 assert_eq!(responses.len(), 3);
1867
1868 match &responses[0] {
1870 JsonRpcResponse::Result(r) => {
1871 assert_eq!(r.id, RequestId::Number(1));
1872 }
1873 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1874 }
1875
1876 match &responses[1] {
1878 JsonRpcResponse::Error(e) => {
1879 assert_eq!(e.id, Some(RequestId::Number(2)));
1880 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1882 }
1883 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1884 }
1885
1886 match &responses[2] {
1888 JsonRpcResponse::Result(r) => {
1889 assert_eq!(r.id, RequestId::Number(3));
1890 }
1891 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1892 }
1893 }
1894
1895 #[tokio::test]
1900 async fn test_list_resource_templates() {
1901 use crate::resource::ResourceTemplateBuilder;
1902 use std::collections::HashMap;
1903
1904 let template = ResourceTemplateBuilder::new("file:///{path}")
1905 .name("Project Files")
1906 .description("Access project files")
1907 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1908 Ok(ReadResourceResult {
1909 contents: vec![ResourceContent {
1910 uri,
1911 mime_type: None,
1912 text: None,
1913 blob: None,
1914 }],
1915 })
1916 });
1917
1918 let mut router = McpRouter::new().resource_template(template);
1919
1920 init_router(&mut router).await;
1922
1923 let req = RouterRequest {
1924 id: RequestId::Number(1),
1925 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1926 extensions: Extensions::new(),
1927 };
1928
1929 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1930
1931 match resp.inner {
1932 Ok(McpResponse::ListResourceTemplates(result)) => {
1933 assert_eq!(result.resource_templates.len(), 1);
1934 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1935 assert_eq!(result.resource_templates[0].name, "Project Files");
1936 }
1937 _ => panic!("Expected ListResourceTemplates response"),
1938 }
1939 }
1940
1941 #[tokio::test]
1942 async fn test_read_resource_via_template() {
1943 use crate::resource::ResourceTemplateBuilder;
1944 use std::collections::HashMap;
1945
1946 let template = ResourceTemplateBuilder::new("db://users/{id}")
1947 .name("User Records")
1948 .handler(|uri: String, vars: HashMap<String, String>| async move {
1949 let id = vars.get("id").unwrap().clone();
1950 Ok(ReadResourceResult {
1951 contents: vec![ResourceContent {
1952 uri,
1953 mime_type: Some("application/json".to_string()),
1954 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1955 blob: None,
1956 }],
1957 })
1958 });
1959
1960 let mut router = McpRouter::new().resource_template(template);
1961
1962 init_router(&mut router).await;
1964
1965 let req = RouterRequest {
1967 id: RequestId::Number(1),
1968 inner: McpRequest::ReadResource(ReadResourceParams {
1969 uri: "db://users/123".to_string(),
1970 }),
1971 extensions: Extensions::new(),
1972 };
1973
1974 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1975
1976 match resp.inner {
1977 Ok(McpResponse::ReadResource(result)) => {
1978 assert_eq!(result.contents.len(), 1);
1979 assert_eq!(result.contents[0].uri, "db://users/123");
1980 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1981 }
1982 _ => panic!("Expected ReadResource response"),
1983 }
1984 }
1985
1986 #[tokio::test]
1987 async fn test_static_resource_takes_precedence_over_template() {
1988 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1989 use std::collections::HashMap;
1990
1991 let template = ResourceTemplateBuilder::new("file:///{path}")
1993 .name("Files Template")
1994 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1995 Ok(ReadResourceResult {
1996 contents: vec![ResourceContent {
1997 uri,
1998 mime_type: None,
1999 text: Some("from template".to_string()),
2000 blob: None,
2001 }],
2002 })
2003 });
2004
2005 let static_resource = ResourceBuilder::new("file:///README.md")
2007 .name("README")
2008 .text("from static resource");
2009
2010 let mut router = McpRouter::new()
2011 .resource_template(template)
2012 .resource(static_resource);
2013
2014 init_router(&mut router).await;
2016
2017 let req = RouterRequest {
2019 id: RequestId::Number(1),
2020 inner: McpRequest::ReadResource(ReadResourceParams {
2021 uri: "file:///README.md".to_string(),
2022 }),
2023 extensions: Extensions::new(),
2024 };
2025
2026 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2027
2028 match resp.inner {
2029 Ok(McpResponse::ReadResource(result)) => {
2030 assert_eq!(
2032 result.contents[0].text.as_deref(),
2033 Some("from static resource")
2034 );
2035 }
2036 _ => panic!("Expected ReadResource response"),
2037 }
2038 }
2039
2040 #[tokio::test]
2041 async fn test_resource_not_found_when_no_match() {
2042 use crate::resource::ResourceTemplateBuilder;
2043 use std::collections::HashMap;
2044
2045 let template = ResourceTemplateBuilder::new("db://users/{id}")
2046 .name("Users")
2047 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2048 Ok(ReadResourceResult {
2049 contents: vec![ResourceContent {
2050 uri,
2051 mime_type: None,
2052 text: None,
2053 blob: None,
2054 }],
2055 })
2056 });
2057
2058 let mut router = McpRouter::new().resource_template(template);
2059
2060 init_router(&mut router).await;
2062
2063 let req = RouterRequest {
2065 id: RequestId::Number(1),
2066 inner: McpRequest::ReadResource(ReadResourceParams {
2067 uri: "db://posts/123".to_string(),
2068 }),
2069 extensions: Extensions::new(),
2070 };
2071
2072 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2073
2074 match resp.inner {
2075 Err(err) => {
2076 assert!(err.message.contains("not found"));
2077 }
2078 Ok(_) => panic!("Expected error for non-matching URI"),
2079 }
2080 }
2081
2082 #[tokio::test]
2083 async fn test_capabilities_include_resources_with_only_templates() {
2084 use crate::resource::ResourceTemplateBuilder;
2085 use std::collections::HashMap;
2086
2087 let template = ResourceTemplateBuilder::new("file:///{path}")
2088 .name("Files")
2089 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2090 Ok(ReadResourceResult {
2091 contents: vec![ResourceContent {
2092 uri,
2093 mime_type: None,
2094 text: None,
2095 blob: None,
2096 }],
2097 })
2098 });
2099
2100 let mut router = McpRouter::new().resource_template(template);
2101
2102 let init_req = RouterRequest {
2104 id: RequestId::Number(0),
2105 inner: McpRequest::Initialize(InitializeParams {
2106 protocol_version: "2025-11-25".to_string(),
2107 capabilities: ClientCapabilities {
2108 roots: None,
2109 sampling: None,
2110 elicitation: None,
2111 },
2112 client_info: Implementation {
2113 name: "test".to_string(),
2114 version: "1.0".to_string(),
2115 ..Default::default()
2116 },
2117 }),
2118 extensions: Extensions::new(),
2119 };
2120 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2121
2122 match resp.inner {
2123 Ok(McpResponse::Initialize(result)) => {
2124 assert!(result.capabilities.resources.is_some());
2126 }
2127 _ => panic!("Expected Initialize response"),
2128 }
2129 }
2130
2131 #[tokio::test]
2136 async fn test_log_sends_notification() {
2137 use crate::context::notification_channel;
2138
2139 let (tx, mut rx) = notification_channel(10);
2140 let router = McpRouter::new().with_notification_sender(tx);
2141
2142 let sent = router.log_info("Test message");
2144 assert!(sent);
2145
2146 let notification = rx.try_recv().unwrap();
2148 match notification {
2149 ServerNotification::LogMessage(params) => {
2150 assert_eq!(params.level, LogLevel::Info);
2151 let data = params.data.unwrap();
2152 assert_eq!(
2153 data.get("message").unwrap().as_str().unwrap(),
2154 "Test message"
2155 );
2156 }
2157 _ => panic!("Expected LogMessage notification"),
2158 }
2159 }
2160
2161 #[tokio::test]
2162 async fn test_log_with_custom_params() {
2163 use crate::context::notification_channel;
2164
2165 let (tx, mut rx) = notification_channel(10);
2166 let router = McpRouter::new().with_notification_sender(tx);
2167
2168 let params = LoggingMessageParams::new(LogLevel::Error)
2170 .with_logger("database")
2171 .with_data(serde_json::json!({
2172 "error": "Connection failed",
2173 "host": "localhost"
2174 }));
2175
2176 let sent = router.log(params);
2177 assert!(sent);
2178
2179 let notification = rx.try_recv().unwrap();
2180 match notification {
2181 ServerNotification::LogMessage(params) => {
2182 assert_eq!(params.level, LogLevel::Error);
2183 assert_eq!(params.logger.as_deref(), Some("database"));
2184 let data = params.data.unwrap();
2185 assert_eq!(
2186 data.get("error").unwrap().as_str().unwrap(),
2187 "Connection failed"
2188 );
2189 }
2190 _ => panic!("Expected LogMessage notification"),
2191 }
2192 }
2193
2194 #[tokio::test]
2195 async fn test_log_without_channel_returns_false() {
2196 let router = McpRouter::new();
2198
2199 assert!(!router.log_info("Test"));
2201 assert!(!router.log_warning("Test"));
2202 assert!(!router.log_error("Test"));
2203 assert!(!router.log_debug("Test"));
2204 }
2205
2206 #[tokio::test]
2207 async fn test_logging_capability_with_channel() {
2208 use crate::context::notification_channel;
2209
2210 let (tx, _rx) = notification_channel(10);
2211 let mut router = McpRouter::new().with_notification_sender(tx);
2212
2213 let init_req = RouterRequest {
2215 id: RequestId::Number(0),
2216 inner: McpRequest::Initialize(InitializeParams {
2217 protocol_version: "2025-11-25".to_string(),
2218 capabilities: ClientCapabilities {
2219 roots: None,
2220 sampling: None,
2221 elicitation: None,
2222 },
2223 client_info: Implementation {
2224 name: "test".to_string(),
2225 version: "1.0".to_string(),
2226 ..Default::default()
2227 },
2228 }),
2229 extensions: Extensions::new(),
2230 };
2231 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2232
2233 match resp.inner {
2234 Ok(McpResponse::Initialize(result)) => {
2235 assert!(result.capabilities.logging.is_some());
2237 }
2238 _ => panic!("Expected Initialize response"),
2239 }
2240 }
2241
2242 #[tokio::test]
2243 async fn test_no_logging_capability_without_channel() {
2244 let mut router = McpRouter::new();
2245
2246 let init_req = RouterRequest {
2248 id: RequestId::Number(0),
2249 inner: McpRequest::Initialize(InitializeParams {
2250 protocol_version: "2025-11-25".to_string(),
2251 capabilities: ClientCapabilities {
2252 roots: None,
2253 sampling: None,
2254 elicitation: None,
2255 },
2256 client_info: Implementation {
2257 name: "test".to_string(),
2258 version: "1.0".to_string(),
2259 ..Default::default()
2260 },
2261 }),
2262 extensions: Extensions::new(),
2263 };
2264 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2265
2266 match resp.inner {
2267 Ok(McpResponse::Initialize(result)) => {
2268 assert!(result.capabilities.logging.is_none());
2270 }
2271 _ => panic!("Expected Initialize response"),
2272 }
2273 }
2274
2275 #[tokio::test]
2280 async fn test_enqueue_task() {
2281 let add_tool = ToolBuilder::new("add")
2282 .description("Add two numbers")
2283 .handler(|input: AddInput| async move {
2284 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2285 })
2286 .build()
2287 .expect("valid tool name");
2288
2289 let mut router = McpRouter::new().tool(add_tool);
2290 init_router(&mut router).await;
2291
2292 let req = RouterRequest {
2293 id: RequestId::Number(1),
2294 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2295 tool_name: "add".to_string(),
2296 arguments: serde_json::json!({"a": 5, "b": 10}),
2297 ttl: None,
2298 }),
2299 extensions: Extensions::new(),
2300 };
2301
2302 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2303
2304 match resp.inner {
2305 Ok(McpResponse::EnqueueTask(result)) => {
2306 assert!(result.task_id.starts_with("task-"));
2307 assert_eq!(result.status, TaskStatus::Working);
2308 }
2309 _ => panic!("Expected EnqueueTask response"),
2310 }
2311 }
2312
2313 #[tokio::test]
2314 async fn test_list_tasks_empty() {
2315 let mut router = McpRouter::new();
2316 init_router(&mut router).await;
2317
2318 let req = RouterRequest {
2319 id: RequestId::Number(1),
2320 inner: McpRequest::ListTasks(ListTasksParams::default()),
2321 extensions: Extensions::new(),
2322 };
2323
2324 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2325
2326 match resp.inner {
2327 Ok(McpResponse::ListTasks(result)) => {
2328 assert!(result.tasks.is_empty());
2329 }
2330 _ => panic!("Expected ListTasks response"),
2331 }
2332 }
2333
2334 #[tokio::test]
2335 async fn test_task_lifecycle_complete() {
2336 let add_tool = ToolBuilder::new("add")
2337 .description("Add two numbers")
2338 .handler(|input: AddInput| async move {
2339 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2340 })
2341 .build()
2342 .expect("valid tool name");
2343
2344 let mut router = McpRouter::new().tool(add_tool);
2345 init_router(&mut router).await;
2346
2347 let req = RouterRequest {
2349 id: RequestId::Number(1),
2350 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2351 tool_name: "add".to_string(),
2352 arguments: serde_json::json!({"a": 7, "b": 8}),
2353 ttl: None,
2354 }),
2355 extensions: Extensions::new(),
2356 };
2357
2358 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2359 let task_id = match resp.inner {
2360 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2361 _ => panic!("Expected EnqueueTask response"),
2362 };
2363
2364 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2366
2367 let req = RouterRequest {
2369 id: RequestId::Number(2),
2370 inner: McpRequest::GetTaskResult(GetTaskResultParams {
2371 task_id: task_id.clone(),
2372 }),
2373 extensions: Extensions::new(),
2374 };
2375
2376 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2377
2378 match resp.inner {
2379 Ok(McpResponse::GetTaskResult(result)) => {
2380 assert_eq!(result.task_id, task_id);
2381 assert_eq!(result.status, TaskStatus::Completed);
2382 assert!(result.result.is_some());
2383 assert!(result.error.is_none());
2384
2385 let tool_result = result.result.unwrap();
2387 match &tool_result.content[0] {
2388 Content::Text { text, .. } => assert_eq!(text, "15"),
2389 _ => panic!("Expected text content"),
2390 }
2391 }
2392 _ => panic!("Expected GetTaskResult response"),
2393 }
2394 }
2395
2396 #[tokio::test]
2397 async fn test_task_cancellation() {
2398 let slow_tool = ToolBuilder::new("slow")
2400 .description("Slow tool")
2401 .handler(|_input: serde_json::Value| async move {
2402 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2403 Ok(CallToolResult::text("done"))
2404 })
2405 .build()
2406 .expect("valid tool name");
2407
2408 let mut router = McpRouter::new().tool(slow_tool);
2409 init_router(&mut router).await;
2410
2411 let req = RouterRequest {
2413 id: RequestId::Number(1),
2414 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2415 tool_name: "slow".to_string(),
2416 arguments: serde_json::json!({}),
2417 ttl: None,
2418 }),
2419 extensions: Extensions::new(),
2420 };
2421
2422 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2423 let task_id = match resp.inner {
2424 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2425 _ => panic!("Expected EnqueueTask response"),
2426 };
2427
2428 let req = RouterRequest {
2430 id: RequestId::Number(2),
2431 inner: McpRequest::CancelTask(CancelTaskParams {
2432 task_id: task_id.clone(),
2433 reason: Some("Test cancellation".to_string()),
2434 }),
2435 extensions: Extensions::new(),
2436 };
2437
2438 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2439
2440 match resp.inner {
2441 Ok(McpResponse::CancelTask(result)) => {
2442 assert!(result.cancelled);
2443 assert_eq!(result.status, TaskStatus::Cancelled);
2444 }
2445 _ => panic!("Expected CancelTask response"),
2446 }
2447 }
2448
2449 #[tokio::test]
2450 async fn test_get_task_info() {
2451 let add_tool = ToolBuilder::new("add")
2452 .description("Add two numbers")
2453 .handler(|input: AddInput| async move {
2454 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2455 })
2456 .build()
2457 .expect("valid tool name");
2458
2459 let mut router = McpRouter::new().tool(add_tool);
2460 init_router(&mut router).await;
2461
2462 let req = RouterRequest {
2464 id: RequestId::Number(1),
2465 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2466 tool_name: "add".to_string(),
2467 arguments: serde_json::json!({"a": 1, "b": 2}),
2468 ttl: Some(600),
2469 }),
2470 extensions: Extensions::new(),
2471 };
2472
2473 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2474 let task_id = match resp.inner {
2475 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2476 _ => panic!("Expected EnqueueTask response"),
2477 };
2478
2479 let req = RouterRequest {
2481 id: RequestId::Number(2),
2482 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2483 task_id: task_id.clone(),
2484 }),
2485 extensions: Extensions::new(),
2486 };
2487
2488 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2489
2490 match resp.inner {
2491 Ok(McpResponse::GetTaskInfo(info)) => {
2492 assert_eq!(info.task_id, task_id);
2493 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2495 }
2496 _ => panic!("Expected GetTaskInfo response"),
2497 }
2498 }
2499
2500 #[tokio::test]
2501 async fn test_enqueue_nonexistent_tool() {
2502 let mut router = McpRouter::new();
2503 init_router(&mut router).await;
2504
2505 let req = RouterRequest {
2506 id: RequestId::Number(1),
2507 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2508 tool_name: "nonexistent".to_string(),
2509 arguments: serde_json::json!({}),
2510 ttl: None,
2511 }),
2512 extensions: Extensions::new(),
2513 };
2514
2515 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2516
2517 match resp.inner {
2518 Err(e) => {
2519 assert!(e.message.contains("not found"));
2520 }
2521 _ => panic!("Expected error response"),
2522 }
2523 }
2524
2525 #[tokio::test]
2526 async fn test_get_nonexistent_task() {
2527 let mut router = McpRouter::new();
2528 init_router(&mut router).await;
2529
2530 let req = RouterRequest {
2531 id: RequestId::Number(1),
2532 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2533 task_id: "task-999".to_string(),
2534 }),
2535 extensions: Extensions::new(),
2536 };
2537
2538 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2539
2540 match resp.inner {
2541 Err(e) => {
2542 assert!(e.message.contains("not found"));
2543 }
2544 _ => panic!("Expected error response"),
2545 }
2546 }
2547
2548 #[tokio::test]
2553 async fn test_subscribe_to_resource() {
2554 use crate::resource::ResourceBuilder;
2555
2556 let resource = ResourceBuilder::new("file:///test.txt")
2557 .name("Test File")
2558 .text("Hello");
2559
2560 let mut router = McpRouter::new().resource(resource);
2561 init_router(&mut router).await;
2562
2563 let req = RouterRequest {
2565 id: RequestId::Number(1),
2566 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2567 uri: "file:///test.txt".to_string(),
2568 }),
2569 extensions: Extensions::new(),
2570 };
2571
2572 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2573
2574 match resp.inner {
2575 Ok(McpResponse::SubscribeResource(_)) => {
2576 assert!(router.is_subscribed("file:///test.txt"));
2578 }
2579 _ => panic!("Expected SubscribeResource response"),
2580 }
2581 }
2582
2583 #[tokio::test]
2584 async fn test_unsubscribe_from_resource() {
2585 use crate::resource::ResourceBuilder;
2586
2587 let resource = ResourceBuilder::new("file:///test.txt")
2588 .name("Test File")
2589 .text("Hello");
2590
2591 let mut router = McpRouter::new().resource(resource);
2592 init_router(&mut router).await;
2593
2594 let req = RouterRequest {
2596 id: RequestId::Number(1),
2597 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2598 uri: "file:///test.txt".to_string(),
2599 }),
2600 extensions: Extensions::new(),
2601 };
2602 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2603 assert!(router.is_subscribed("file:///test.txt"));
2604
2605 let req = RouterRequest {
2607 id: RequestId::Number(2),
2608 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2609 uri: "file:///test.txt".to_string(),
2610 }),
2611 extensions: Extensions::new(),
2612 };
2613
2614 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2615
2616 match resp.inner {
2617 Ok(McpResponse::UnsubscribeResource(_)) => {
2618 assert!(!router.is_subscribed("file:///test.txt"));
2620 }
2621 _ => panic!("Expected UnsubscribeResource response"),
2622 }
2623 }
2624
2625 #[tokio::test]
2626 async fn test_subscribe_nonexistent_resource() {
2627 let mut router = McpRouter::new();
2628 init_router(&mut router).await;
2629
2630 let req = RouterRequest {
2631 id: RequestId::Number(1),
2632 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2633 uri: "file:///nonexistent.txt".to_string(),
2634 }),
2635 extensions: Extensions::new(),
2636 };
2637
2638 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2639
2640 match resp.inner {
2641 Err(e) => {
2642 assert!(e.message.contains("not found"));
2643 }
2644 _ => panic!("Expected error response"),
2645 }
2646 }
2647
2648 #[tokio::test]
2649 async fn test_notify_resource_updated() {
2650 use crate::context::notification_channel;
2651 use crate::resource::ResourceBuilder;
2652
2653 let (tx, mut rx) = notification_channel(10);
2654
2655 let resource = ResourceBuilder::new("file:///test.txt")
2656 .name("Test File")
2657 .text("Hello");
2658
2659 let router = McpRouter::new()
2660 .resource(resource)
2661 .with_notification_sender(tx);
2662
2663 router.subscribe("file:///test.txt");
2665
2666 let sent = router.notify_resource_updated("file:///test.txt");
2668 assert!(sent);
2669
2670 let notification = rx.try_recv().unwrap();
2672 match notification {
2673 ServerNotification::ResourceUpdated { uri } => {
2674 assert_eq!(uri, "file:///test.txt");
2675 }
2676 _ => panic!("Expected ResourceUpdated notification"),
2677 }
2678 }
2679
2680 #[tokio::test]
2681 async fn test_notify_resource_updated_not_subscribed() {
2682 use crate::context::notification_channel;
2683 use crate::resource::ResourceBuilder;
2684
2685 let (tx, mut rx) = notification_channel(10);
2686
2687 let resource = ResourceBuilder::new("file:///test.txt")
2688 .name("Test File")
2689 .text("Hello");
2690
2691 let router = McpRouter::new()
2692 .resource(resource)
2693 .with_notification_sender(tx);
2694
2695 let sent = router.notify_resource_updated("file:///test.txt");
2697 assert!(!sent); assert!(rx.try_recv().is_err());
2701 }
2702
2703 #[tokio::test]
2704 async fn test_notify_resources_list_changed() {
2705 use crate::context::notification_channel;
2706
2707 let (tx, mut rx) = notification_channel(10);
2708 let router = McpRouter::new().with_notification_sender(tx);
2709
2710 let sent = router.notify_resources_list_changed();
2711 assert!(sent);
2712
2713 let notification = rx.try_recv().unwrap();
2714 match notification {
2715 ServerNotification::ResourcesListChanged => {}
2716 _ => panic!("Expected ResourcesListChanged notification"),
2717 }
2718 }
2719
2720 #[tokio::test]
2721 async fn test_subscribed_uris() {
2722 use crate::resource::ResourceBuilder;
2723
2724 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2725
2726 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2727
2728 let router = McpRouter::new().resource(resource1).resource(resource2);
2729
2730 router.subscribe("file:///a.txt");
2732 router.subscribe("file:///b.txt");
2733
2734 let uris = router.subscribed_uris();
2735 assert_eq!(uris.len(), 2);
2736 assert!(uris.contains(&"file:///a.txt".to_string()));
2737 assert!(uris.contains(&"file:///b.txt".to_string()));
2738 }
2739
2740 #[tokio::test]
2741 async fn test_subscription_capability_advertised() {
2742 use crate::resource::ResourceBuilder;
2743
2744 let resource = ResourceBuilder::new("file:///test.txt")
2745 .name("Test")
2746 .text("Hello");
2747
2748 let mut router = McpRouter::new().resource(resource);
2749
2750 let init_req = RouterRequest {
2752 id: RequestId::Number(0),
2753 inner: McpRequest::Initialize(InitializeParams {
2754 protocol_version: "2025-11-25".to_string(),
2755 capabilities: ClientCapabilities {
2756 roots: None,
2757 sampling: None,
2758 elicitation: None,
2759 },
2760 client_info: Implementation {
2761 name: "test".to_string(),
2762 version: "1.0".to_string(),
2763 ..Default::default()
2764 },
2765 }),
2766 extensions: Extensions::new(),
2767 };
2768 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2769
2770 match resp.inner {
2771 Ok(McpResponse::Initialize(result)) => {
2772 let resources_cap = result.capabilities.resources.unwrap();
2774 assert!(resources_cap.subscribe);
2775 }
2776 _ => panic!("Expected Initialize response"),
2777 }
2778 }
2779
2780 #[tokio::test]
2781 async fn test_completion_handler() {
2782 let router = McpRouter::new()
2783 .server_info("test", "1.0")
2784 .completion_handler(|params: CompleteParams| async move {
2785 let prefix = ¶ms.argument.value;
2787 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2788 .into_iter()
2789 .filter(|s| s.starts_with(prefix))
2790 .map(String::from)
2791 .collect();
2792 Ok(CompleteResult::new(suggestions))
2793 });
2794
2795 let init_req = RouterRequest {
2797 id: RequestId::Number(0),
2798 inner: McpRequest::Initialize(InitializeParams {
2799 protocol_version: "2025-11-25".to_string(),
2800 capabilities: ClientCapabilities::default(),
2801 client_info: Implementation {
2802 name: "test".to_string(),
2803 version: "1.0".to_string(),
2804 ..Default::default()
2805 },
2806 }),
2807 extensions: Extensions::new(),
2808 };
2809 let resp = router
2810 .clone()
2811 .ready()
2812 .await
2813 .unwrap()
2814 .call(init_req)
2815 .await
2816 .unwrap();
2817
2818 match resp.inner {
2820 Ok(McpResponse::Initialize(result)) => {
2821 assert!(result.capabilities.completions.is_some());
2822 }
2823 _ => panic!("Expected Initialize response"),
2824 }
2825
2826 router.handle_notification(McpNotification::Initialized);
2828
2829 let complete_req = RouterRequest {
2831 id: RequestId::Number(1),
2832 inner: McpRequest::Complete(CompleteParams {
2833 reference: CompletionReference::prompt("test-prompt"),
2834 argument: CompletionArgument::new("query", "al"),
2835 }),
2836 extensions: Extensions::new(),
2837 };
2838 let resp = router
2839 .clone()
2840 .ready()
2841 .await
2842 .unwrap()
2843 .call(complete_req)
2844 .await
2845 .unwrap();
2846
2847 match resp.inner {
2848 Ok(McpResponse::Complete(result)) => {
2849 assert_eq!(result.completion.values, vec!["alpha"]);
2850 }
2851 _ => panic!("Expected Complete response"),
2852 }
2853 }
2854
2855 #[tokio::test]
2856 async fn test_completion_without_handler_returns_empty() {
2857 let router = McpRouter::new().server_info("test", "1.0");
2858
2859 let init_req = RouterRequest {
2861 id: RequestId::Number(0),
2862 inner: McpRequest::Initialize(InitializeParams {
2863 protocol_version: "2025-11-25".to_string(),
2864 capabilities: ClientCapabilities::default(),
2865 client_info: Implementation {
2866 name: "test".to_string(),
2867 version: "1.0".to_string(),
2868 ..Default::default()
2869 },
2870 }),
2871 extensions: Extensions::new(),
2872 };
2873 let resp = router
2874 .clone()
2875 .ready()
2876 .await
2877 .unwrap()
2878 .call(init_req)
2879 .await
2880 .unwrap();
2881
2882 match resp.inner {
2884 Ok(McpResponse::Initialize(result)) => {
2885 assert!(result.capabilities.completions.is_none());
2886 }
2887 _ => panic!("Expected Initialize response"),
2888 }
2889
2890 router.handle_notification(McpNotification::Initialized);
2892
2893 let complete_req = RouterRequest {
2895 id: RequestId::Number(1),
2896 inner: McpRequest::Complete(CompleteParams {
2897 reference: CompletionReference::prompt("test-prompt"),
2898 argument: CompletionArgument::new("query", "al"),
2899 }),
2900 extensions: Extensions::new(),
2901 };
2902 let resp = router
2903 .clone()
2904 .ready()
2905 .await
2906 .unwrap()
2907 .call(complete_req)
2908 .await
2909 .unwrap();
2910
2911 match resp.inner {
2912 Ok(McpResponse::Complete(result)) => {
2913 assert!(result.completion.values.is_empty());
2914 }
2915 _ => panic!("Expected Complete response"),
2916 }
2917 }
2918
2919 #[tokio::test]
2920 async fn test_tool_filter_list() {
2921 use crate::filter::CapabilityFilter;
2922 use crate::tool::Tool;
2923
2924 let public_tool = ToolBuilder::new("public")
2925 .description("Public tool")
2926 .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2927 .build()
2928 .expect("valid tool name");
2929
2930 let admin_tool = ToolBuilder::new("admin")
2931 .description("Admin tool")
2932 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2933 .build()
2934 .expect("valid tool name");
2935
2936 let mut router = McpRouter::new()
2937 .tool(public_tool)
2938 .tool(admin_tool)
2939 .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2940
2941 init_router(&mut router).await;
2943
2944 let req = RouterRequest {
2945 id: RequestId::Number(1),
2946 inner: McpRequest::ListTools(ListToolsParams::default()),
2947 extensions: Extensions::new(),
2948 };
2949
2950 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2951
2952 match resp.inner {
2953 Ok(McpResponse::ListTools(result)) => {
2954 assert_eq!(result.tools.len(), 1);
2956 assert_eq!(result.tools[0].name, "public");
2957 }
2958 _ => panic!("Expected ListTools response"),
2959 }
2960 }
2961
2962 #[tokio::test]
2963 async fn test_tool_filter_call_denied() {
2964 use crate::filter::CapabilityFilter;
2965 use crate::tool::Tool;
2966
2967 let admin_tool = ToolBuilder::new("admin")
2968 .description("Admin tool")
2969 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2970 .build()
2971 .expect("valid tool name");
2972
2973 let mut router = McpRouter::new()
2974 .tool(admin_tool)
2975 .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); init_router(&mut router).await;
2979
2980 let req = RouterRequest {
2981 id: RequestId::Number(1),
2982 inner: McpRequest::CallTool(CallToolParams {
2983 name: "admin".to_string(),
2984 arguments: serde_json::json!({"a": 1, "b": 2}),
2985 meta: None,
2986 }),
2987 extensions: Extensions::new(),
2988 };
2989
2990 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2991
2992 match resp.inner {
2994 Err(e) => {
2995 assert_eq!(e.code, -32601); }
2997 _ => panic!("Expected JsonRpc error"),
2998 }
2999 }
3000
3001 #[tokio::test]
3002 async fn test_tool_filter_call_allowed() {
3003 use crate::filter::CapabilityFilter;
3004 use crate::tool::Tool;
3005
3006 let public_tool = ToolBuilder::new("public")
3007 .description("Public tool")
3008 .handler(|input: AddInput| async move {
3009 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
3010 })
3011 .build()
3012 .expect("valid tool name");
3013
3014 let mut router = McpRouter::new()
3015 .tool(public_tool)
3016 .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); init_router(&mut router).await;
3020
3021 let req = RouterRequest {
3022 id: RequestId::Number(1),
3023 inner: McpRequest::CallTool(CallToolParams {
3024 name: "public".to_string(),
3025 arguments: serde_json::json!({"a": 1, "b": 2}),
3026 meta: None,
3027 }),
3028 extensions: Extensions::new(),
3029 };
3030
3031 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3032
3033 match resp.inner {
3034 Ok(McpResponse::CallTool(result)) => {
3035 assert!(!result.is_error);
3036 }
3037 _ => panic!("Expected CallTool response"),
3038 }
3039 }
3040
3041 #[tokio::test]
3042 async fn test_tool_filter_custom_denial() {
3043 use crate::filter::{CapabilityFilter, DenialBehavior};
3044 use crate::tool::Tool;
3045
3046 let admin_tool = ToolBuilder::new("admin")
3047 .description("Admin tool")
3048 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3049 .build()
3050 .expect("valid tool name");
3051
3052 let mut router = McpRouter::new().tool(admin_tool).tool_filter(
3053 CapabilityFilter::new(|_, _: &Tool| false)
3054 .denial_behavior(DenialBehavior::Unauthorized),
3055 );
3056
3057 init_router(&mut router).await;
3059
3060 let req = RouterRequest {
3061 id: RequestId::Number(1),
3062 inner: McpRequest::CallTool(CallToolParams {
3063 name: "admin".to_string(),
3064 arguments: serde_json::json!({"a": 1, "b": 2}),
3065 meta: None,
3066 }),
3067 extensions: Extensions::new(),
3068 };
3069
3070 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3071
3072 match resp.inner {
3074 Err(e) => {
3075 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3077 }
3078 _ => panic!("Expected JsonRpc error"),
3079 }
3080 }
3081
3082 #[tokio::test]
3083 async fn test_resource_filter_list() {
3084 use crate::filter::CapabilityFilter;
3085 use crate::resource::{Resource, ResourceBuilder};
3086
3087 let public_resource = ResourceBuilder::new("file:///public.txt")
3088 .name("Public File")
3089 .text("public content");
3090
3091 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3092 .name("Secret File")
3093 .text("secret content");
3094
3095 let mut router = McpRouter::new()
3096 .resource(public_resource)
3097 .resource(secret_resource)
3098 .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
3099 !r.name.contains("Secret")
3100 }));
3101
3102 init_router(&mut router).await;
3104
3105 let req = RouterRequest {
3106 id: RequestId::Number(1),
3107 inner: McpRequest::ListResources(ListResourcesParams::default()),
3108 extensions: Extensions::new(),
3109 };
3110
3111 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3112
3113 match resp.inner {
3114 Ok(McpResponse::ListResources(result)) => {
3115 assert_eq!(result.resources.len(), 1);
3117 assert_eq!(result.resources[0].name, "Public File");
3118 }
3119 _ => panic!("Expected ListResources response"),
3120 }
3121 }
3122
3123 #[tokio::test]
3124 async fn test_resource_filter_read_denied() {
3125 use crate::filter::CapabilityFilter;
3126 use crate::resource::{Resource, ResourceBuilder};
3127
3128 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3129 .name("Secret File")
3130 .text("secret content");
3131
3132 let mut router = McpRouter::new()
3133 .resource(secret_resource)
3134 .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); init_router(&mut router).await;
3138
3139 let req = RouterRequest {
3140 id: RequestId::Number(1),
3141 inner: McpRequest::ReadResource(ReadResourceParams {
3142 uri: "file:///secret.txt".to_string(),
3143 }),
3144 extensions: Extensions::new(),
3145 };
3146
3147 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3148
3149 match resp.inner {
3151 Err(e) => {
3152 assert_eq!(e.code, -32601); }
3154 _ => panic!("Expected JsonRpc error"),
3155 }
3156 }
3157
3158 #[tokio::test]
3159 async fn test_resource_filter_read_allowed() {
3160 use crate::filter::CapabilityFilter;
3161 use crate::resource::{Resource, ResourceBuilder};
3162
3163 let public_resource = ResourceBuilder::new("file:///public.txt")
3164 .name("Public File")
3165 .text("public content");
3166
3167 let mut router = McpRouter::new()
3168 .resource(public_resource)
3169 .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); init_router(&mut router).await;
3173
3174 let req = RouterRequest {
3175 id: RequestId::Number(1),
3176 inner: McpRequest::ReadResource(ReadResourceParams {
3177 uri: "file:///public.txt".to_string(),
3178 }),
3179 extensions: Extensions::new(),
3180 };
3181
3182 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3183
3184 match resp.inner {
3185 Ok(McpResponse::ReadResource(result)) => {
3186 assert_eq!(result.contents.len(), 1);
3187 assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3188 }
3189 _ => panic!("Expected ReadResource response"),
3190 }
3191 }
3192
3193 #[tokio::test]
3194 async fn test_resource_filter_custom_denial() {
3195 use crate::filter::{CapabilityFilter, DenialBehavior};
3196 use crate::resource::{Resource, ResourceBuilder};
3197
3198 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3199 .name("Secret File")
3200 .text("secret content");
3201
3202 let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3203 CapabilityFilter::new(|_, _: &Resource| false)
3204 .denial_behavior(DenialBehavior::Unauthorized),
3205 );
3206
3207 init_router(&mut router).await;
3209
3210 let req = RouterRequest {
3211 id: RequestId::Number(1),
3212 inner: McpRequest::ReadResource(ReadResourceParams {
3213 uri: "file:///secret.txt".to_string(),
3214 }),
3215 extensions: Extensions::new(),
3216 };
3217
3218 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3219
3220 match resp.inner {
3222 Err(e) => {
3223 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3225 }
3226 _ => panic!("Expected JsonRpc error"),
3227 }
3228 }
3229
3230 #[tokio::test]
3231 async fn test_prompt_filter_list() {
3232 use crate::filter::CapabilityFilter;
3233 use crate::prompt::{Prompt, PromptBuilder};
3234
3235 let public_prompt = PromptBuilder::new("greeting")
3236 .description("A greeting")
3237 .user_message("Hello!");
3238
3239 let admin_prompt = PromptBuilder::new("system_debug")
3240 .description("Admin prompt")
3241 .user_message("Debug");
3242
3243 let mut router = McpRouter::new()
3244 .prompt(public_prompt)
3245 .prompt(admin_prompt)
3246 .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3247 !p.name.contains("system")
3248 }));
3249
3250 init_router(&mut router).await;
3252
3253 let req = RouterRequest {
3254 id: RequestId::Number(1),
3255 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3256 extensions: Extensions::new(),
3257 };
3258
3259 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3260
3261 match resp.inner {
3262 Ok(McpResponse::ListPrompts(result)) => {
3263 assert_eq!(result.prompts.len(), 1);
3265 assert_eq!(result.prompts[0].name, "greeting");
3266 }
3267 _ => panic!("Expected ListPrompts response"),
3268 }
3269 }
3270
3271 #[tokio::test]
3272 async fn test_prompt_filter_get_denied() {
3273 use crate::filter::CapabilityFilter;
3274 use crate::prompt::{Prompt, PromptBuilder};
3275 use std::collections::HashMap;
3276
3277 let admin_prompt = PromptBuilder::new("system_debug")
3278 .description("Admin prompt")
3279 .user_message("Debug");
3280
3281 let mut router = McpRouter::new()
3282 .prompt(admin_prompt)
3283 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); init_router(&mut router).await;
3287
3288 let req = RouterRequest {
3289 id: RequestId::Number(1),
3290 inner: McpRequest::GetPrompt(GetPromptParams {
3291 name: "system_debug".to_string(),
3292 arguments: HashMap::new(),
3293 }),
3294 extensions: Extensions::new(),
3295 };
3296
3297 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3298
3299 match resp.inner {
3301 Err(e) => {
3302 assert_eq!(e.code, -32601); }
3304 _ => panic!("Expected JsonRpc error"),
3305 }
3306 }
3307
3308 #[tokio::test]
3309 async fn test_prompt_filter_get_allowed() {
3310 use crate::filter::CapabilityFilter;
3311 use crate::prompt::{Prompt, PromptBuilder};
3312 use std::collections::HashMap;
3313
3314 let public_prompt = PromptBuilder::new("greeting")
3315 .description("A greeting")
3316 .user_message("Hello!");
3317
3318 let mut router = McpRouter::new()
3319 .prompt(public_prompt)
3320 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); init_router(&mut router).await;
3324
3325 let req = RouterRequest {
3326 id: RequestId::Number(1),
3327 inner: McpRequest::GetPrompt(GetPromptParams {
3328 name: "greeting".to_string(),
3329 arguments: HashMap::new(),
3330 }),
3331 extensions: Extensions::new(),
3332 };
3333
3334 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3335
3336 match resp.inner {
3337 Ok(McpResponse::GetPrompt(result)) => {
3338 assert_eq!(result.messages.len(), 1);
3339 }
3340 _ => panic!("Expected GetPrompt response"),
3341 }
3342 }
3343
3344 #[tokio::test]
3345 async fn test_prompt_filter_custom_denial() {
3346 use crate::filter::{CapabilityFilter, DenialBehavior};
3347 use crate::prompt::{Prompt, PromptBuilder};
3348 use std::collections::HashMap;
3349
3350 let admin_prompt = PromptBuilder::new("system_debug")
3351 .description("Admin prompt")
3352 .user_message("Debug");
3353
3354 let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3355 CapabilityFilter::new(|_, _: &Prompt| false)
3356 .denial_behavior(DenialBehavior::Unauthorized),
3357 );
3358
3359 init_router(&mut router).await;
3361
3362 let req = RouterRequest {
3363 id: RequestId::Number(1),
3364 inner: McpRequest::GetPrompt(GetPromptParams {
3365 name: "system_debug".to_string(),
3366 arguments: HashMap::new(),
3367 }),
3368 extensions: Extensions::new(),
3369 };
3370
3371 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3372
3373 match resp.inner {
3375 Err(e) => {
3376 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3378 }
3379 _ => panic!("Expected JsonRpc error"),
3380 }
3381 }
3382
3383 #[derive(Debug, Deserialize, JsonSchema)]
3388 struct StringInput {
3389 value: String,
3390 }
3391
3392 #[tokio::test]
3393 async fn test_router_merge_tools() {
3394 let tool_a = ToolBuilder::new("tool_a")
3396 .description("Tool A")
3397 .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
3398 .build()
3399 .unwrap();
3400
3401 let router_a = McpRouter::new().tool(tool_a);
3402
3403 let tool_b = ToolBuilder::new("tool_b")
3405 .description("Tool B")
3406 .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
3407 .build()
3408 .unwrap();
3409 let tool_c = ToolBuilder::new("tool_c")
3410 .description("Tool C")
3411 .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
3412 .build()
3413 .unwrap();
3414
3415 let router_b = McpRouter::new().tool(tool_b).tool(tool_c);
3416
3417 let mut merged = McpRouter::new()
3419 .server_info("merged", "1.0")
3420 .merge(router_a)
3421 .merge(router_b);
3422
3423 init_router(&mut merged).await;
3424
3425 let req = RouterRequest {
3427 id: RequestId::Number(1),
3428 inner: McpRequest::ListTools(ListToolsParams::default()),
3429 extensions: Extensions::new(),
3430 };
3431
3432 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3433
3434 match resp.inner {
3435 Ok(McpResponse::ListTools(result)) => {
3436 assert_eq!(result.tools.len(), 3);
3437 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3438 assert!(names.contains(&"tool_a"));
3439 assert!(names.contains(&"tool_b"));
3440 assert!(names.contains(&"tool_c"));
3441 }
3442 _ => panic!("Expected ListTools response"),
3443 }
3444 }
3445
3446 #[tokio::test]
3447 async fn test_router_merge_overwrites_duplicates() {
3448 let tool_v1 = ToolBuilder::new("shared")
3450 .description("Version 1")
3451 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
3452 .build()
3453 .unwrap();
3454
3455 let router_a = McpRouter::new().tool(tool_v1);
3456
3457 let tool_v2 = ToolBuilder::new("shared")
3459 .description("Version 2")
3460 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
3461 .build()
3462 .unwrap();
3463
3464 let router_b = McpRouter::new().tool(tool_v2);
3465
3466 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3468
3469 init_router(&mut merged).await;
3470
3471 let req = RouterRequest {
3472 id: RequestId::Number(1),
3473 inner: McpRequest::ListTools(ListToolsParams::default()),
3474 extensions: Extensions::new(),
3475 };
3476
3477 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3478
3479 match resp.inner {
3480 Ok(McpResponse::ListTools(result)) => {
3481 assert_eq!(result.tools.len(), 1);
3482 assert_eq!(result.tools[0].name, "shared");
3483 assert_eq!(result.tools[0].description.as_deref(), Some("Version 2"));
3484 }
3485 _ => panic!("Expected ListTools response"),
3486 }
3487 }
3488
3489 #[tokio::test]
3490 async fn test_router_merge_resources() {
3491 use crate::resource::ResourceBuilder;
3492
3493 let router_a = McpRouter::new().resource(
3495 ResourceBuilder::new("file:///a.txt")
3496 .name("File A")
3497 .text("content a"),
3498 );
3499
3500 let router_b = McpRouter::new().resource(
3501 ResourceBuilder::new("file:///b.txt")
3502 .name("File B")
3503 .text("content b"),
3504 );
3505
3506 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3507
3508 init_router(&mut merged).await;
3509
3510 let req = RouterRequest {
3511 id: RequestId::Number(1),
3512 inner: McpRequest::ListResources(ListResourcesParams::default()),
3513 extensions: Extensions::new(),
3514 };
3515
3516 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3517
3518 match resp.inner {
3519 Ok(McpResponse::ListResources(result)) => {
3520 assert_eq!(result.resources.len(), 2);
3521 let uris: Vec<&str> = result.resources.iter().map(|r| r.uri.as_str()).collect();
3522 assert!(uris.contains(&"file:///a.txt"));
3523 assert!(uris.contains(&"file:///b.txt"));
3524 }
3525 _ => panic!("Expected ListResources response"),
3526 }
3527 }
3528
3529 #[tokio::test]
3530 async fn test_router_merge_prompts() {
3531 use crate::prompt::PromptBuilder;
3532
3533 let router_a =
3534 McpRouter::new().prompt(PromptBuilder::new("prompt_a").user_message("Hello A"));
3535
3536 let router_b =
3537 McpRouter::new().prompt(PromptBuilder::new("prompt_b").user_message("Hello B"));
3538
3539 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3540
3541 init_router(&mut merged).await;
3542
3543 let req = RouterRequest {
3544 id: RequestId::Number(1),
3545 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3546 extensions: Extensions::new(),
3547 };
3548
3549 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3550
3551 match resp.inner {
3552 Ok(McpResponse::ListPrompts(result)) => {
3553 assert_eq!(result.prompts.len(), 2);
3554 let names: Vec<&str> = result.prompts.iter().map(|p| p.name.as_str()).collect();
3555 assert!(names.contains(&"prompt_a"));
3556 assert!(names.contains(&"prompt_b"));
3557 }
3558 _ => panic!("Expected ListPrompts response"),
3559 }
3560 }
3561
3562 #[tokio::test]
3563 async fn test_router_nest_prefixes_tools() {
3564 let tool_query = ToolBuilder::new("query")
3566 .description("Query the database")
3567 .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
3568 .build()
3569 .unwrap();
3570 let tool_insert = ToolBuilder::new("insert")
3571 .description("Insert into database")
3572 .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
3573 .build()
3574 .unwrap();
3575
3576 let db_router = McpRouter::new().tool(tool_query).tool(tool_insert);
3577
3578 let mut router = McpRouter::new()
3580 .server_info("nested", "1.0")
3581 .nest("db", db_router);
3582
3583 init_router(&mut router).await;
3584
3585 let req = RouterRequest {
3586 id: RequestId::Number(1),
3587 inner: McpRequest::ListTools(ListToolsParams::default()),
3588 extensions: Extensions::new(),
3589 };
3590
3591 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3592
3593 match resp.inner {
3594 Ok(McpResponse::ListTools(result)) => {
3595 assert_eq!(result.tools.len(), 2);
3596 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3597 assert!(names.contains(&"db.query"));
3598 assert!(names.contains(&"db.insert"));
3599 }
3600 _ => panic!("Expected ListTools response"),
3601 }
3602 }
3603
3604 #[tokio::test]
3605 async fn test_router_nest_call_prefixed_tool() {
3606 let tool = ToolBuilder::new("echo")
3607 .description("Echo input")
3608 .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
3609 .build()
3610 .unwrap();
3611
3612 let nested_router = McpRouter::new().tool(tool);
3613
3614 let mut router = McpRouter::new().nest("api", nested_router);
3615
3616 init_router(&mut router).await;
3617
3618 let req = RouterRequest {
3620 id: RequestId::Number(1),
3621 inner: McpRequest::CallTool(CallToolParams {
3622 name: "api.echo".to_string(),
3623 arguments: serde_json::json!({"value": "hello world"}),
3624 meta: None,
3625 }),
3626 extensions: Extensions::new(),
3627 };
3628
3629 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3630
3631 match resp.inner {
3632 Ok(McpResponse::CallTool(result)) => {
3633 assert!(!result.is_error);
3634 match &result.content[0] {
3635 Content::Text { text, .. } => assert_eq!(text, "hello world"),
3636 _ => panic!("Expected text content"),
3637 }
3638 }
3639 _ => panic!("Expected CallTool response"),
3640 }
3641 }
3642
3643 #[tokio::test]
3644 async fn test_router_multiple_nests() {
3645 let db_tool = ToolBuilder::new("query")
3646 .description("Database query")
3647 .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
3648 .build()
3649 .unwrap();
3650
3651 let api_tool = ToolBuilder::new("fetch")
3652 .description("API fetch")
3653 .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
3654 .build()
3655 .unwrap();
3656
3657 let db_router = McpRouter::new().tool(db_tool);
3658 let api_router = McpRouter::new().tool(api_tool);
3659
3660 let mut router = McpRouter::new()
3661 .nest("db", db_router)
3662 .nest("api", api_router);
3663
3664 init_router(&mut router).await;
3665
3666 let req = RouterRequest {
3667 id: RequestId::Number(1),
3668 inner: McpRequest::ListTools(ListToolsParams::default()),
3669 extensions: Extensions::new(),
3670 };
3671
3672 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3673
3674 match resp.inner {
3675 Ok(McpResponse::ListTools(result)) => {
3676 assert_eq!(result.tools.len(), 2);
3677 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3678 assert!(names.contains(&"db.query"));
3679 assert!(names.contains(&"api.fetch"));
3680 }
3681 _ => panic!("Expected ListTools response"),
3682 }
3683 }
3684
3685 #[tokio::test]
3686 async fn test_router_merge_and_nest_combined() {
3687 let tool_a = ToolBuilder::new("local")
3689 .description("Local tool")
3690 .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
3691 .build()
3692 .unwrap();
3693
3694 let nested_tool = ToolBuilder::new("remote")
3695 .description("Remote tool")
3696 .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
3697 .build()
3698 .unwrap();
3699
3700 let nested_router = McpRouter::new().tool(nested_tool);
3701
3702 let mut router = McpRouter::new()
3703 .tool(tool_a)
3704 .nest("external", nested_router);
3705
3706 init_router(&mut router).await;
3707
3708 let req = RouterRequest {
3709 id: RequestId::Number(1),
3710 inner: McpRequest::ListTools(ListToolsParams::default()),
3711 extensions: Extensions::new(),
3712 };
3713
3714 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3715
3716 match resp.inner {
3717 Ok(McpResponse::ListTools(result)) => {
3718 assert_eq!(result.tools.len(), 2);
3719 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3720 assert!(names.contains(&"local"));
3721 assert!(names.contains(&"external.remote"));
3722 }
3723 _ => panic!("Expected ListTools response"),
3724 }
3725 }
3726
3727 #[tokio::test]
3728 async fn test_router_merge_preserves_server_info() {
3729 let child_router = McpRouter::new()
3730 .server_info("child", "2.0")
3731 .instructions("Child instructions");
3732
3733 let mut router = McpRouter::new()
3734 .server_info("parent", "1.0")
3735 .instructions("Parent instructions")
3736 .merge(child_router);
3737
3738 init_router(&mut router).await;
3739
3740 let init_req = RouterRequest {
3742 id: RequestId::Number(99),
3743 inner: McpRequest::Initialize(InitializeParams {
3744 protocol_version: "2025-11-25".to_string(),
3745 capabilities: ClientCapabilities::default(),
3746 client_info: Implementation {
3747 name: "test".to_string(),
3748 version: "1.0".to_string(),
3749 ..Default::default()
3750 },
3751 }),
3752 extensions: Extensions::new(),
3753 };
3754
3755 let child_router2 = McpRouter::new().server_info("child", "2.0");
3757 let mut fresh_router = McpRouter::new()
3758 .server_info("parent", "1.0")
3759 .merge(child_router2);
3760
3761 let resp = fresh_router
3762 .ready()
3763 .await
3764 .unwrap()
3765 .call(init_req)
3766 .await
3767 .unwrap();
3768
3769 match resp.inner {
3770 Ok(McpResponse::Initialize(result)) => {
3771 assert_eq!(result.server_info.name, "parent");
3772 assert_eq!(result.server_info.version, "1.0");
3773 }
3774 _ => panic!("Expected Initialize response"),
3775 }
3776 }
3777}