1use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18 ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27pub type CompletionHandler = Arc<
29 dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30 + Send
31 + Sync,
32>;
33
34#[derive(Clone)]
59pub struct McpRouter {
60 inner: Arc<McpRouterInner>,
61 session: SessionState,
62}
63
64impl std::fmt::Debug for McpRouter {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("McpRouter")
67 .field("server_name", &self.inner.server_name)
68 .field("server_version", &self.inner.server_version)
69 .field("tools_count", &self.inner.tools.len())
70 .field("resources_count", &self.inner.resources.len())
71 .field("prompts_count", &self.inner.prompts.len())
72 .field("session_phase", &self.session.phase())
73 .finish()
74 }
75}
76
77#[derive(Clone)]
79struct McpRouterInner {
80 server_name: String,
81 server_version: String,
82 server_title: Option<String>,
84 server_description: Option<String>,
86 server_icons: Option<Vec<ToolIcon>>,
88 server_website_url: Option<String>,
90 instructions: Option<String>,
91 tools: HashMap<String, Arc<Tool>>,
92 resources: HashMap<String, Arc<Resource>>,
93 resource_templates: Vec<Arc<ResourceTemplate>>,
95 prompts: HashMap<String, Arc<Prompt>>,
96 in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
98 notification_tx: Option<NotificationSender>,
100 client_requester: Option<ClientRequesterHandle>,
102 task_store: TaskStore,
104 subscriptions: Arc<RwLock<HashSet<String>>>,
106 completion_handler: Option<CompletionHandler>,
108}
109
110impl McpRouter {
111 pub fn new() -> Self {
113 Self {
114 inner: Arc::new(McpRouterInner {
115 server_name: "tower-mcp".to_string(),
116 server_version: env!("CARGO_PKG_VERSION").to_string(),
117 server_title: None,
118 server_description: None,
119 server_icons: None,
120 server_website_url: None,
121 instructions: None,
122 tools: HashMap::new(),
123 resources: HashMap::new(),
124 resource_templates: Vec::new(),
125 prompts: HashMap::new(),
126 in_flight: Arc::new(RwLock::new(HashMap::new())),
127 notification_tx: None,
128 client_requester: None,
129 task_store: TaskStore::new(),
130 subscriptions: Arc::new(RwLock::new(HashSet::new())),
131 completion_handler: None,
132 }),
133 session: SessionState::new(),
134 }
135 }
136
137 pub fn task_store(&self) -> &TaskStore {
139 &self.inner.task_store
140 }
141
142 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
146 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
147 self
148 }
149
150 pub fn notification_sender(&self) -> Option<&NotificationSender> {
152 self.inner.notification_tx.as_ref()
153 }
154
155 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
160 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
161 self
162 }
163
164 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
166 self.inner.client_requester.as_ref()
167 }
168
169 pub fn create_context(
174 &self,
175 request_id: RequestId,
176 progress_token: Option<ProgressToken>,
177 ) -> RequestContext {
178 let ctx = RequestContext::new(request_id.clone());
179
180 let ctx = if let Some(token) = progress_token {
182 ctx.with_progress_token(token)
183 } else {
184 ctx
185 };
186
187 let ctx = if let Some(tx) = &self.inner.notification_tx {
189 ctx.with_notification_sender(tx.clone())
190 } else {
191 ctx
192 };
193
194 let ctx = if let Some(requester) = &self.inner.client_requester {
196 ctx.with_client_requester(requester.clone())
197 } else {
198 ctx
199 };
200
201 let token = ctx.cancellation_token();
203 if let Ok(mut in_flight) = self.inner.in_flight.write() {
204 in_flight.insert(request_id, token);
205 }
206
207 ctx
208 }
209
210 pub fn complete_request(&self, request_id: &RequestId) {
212 if let Ok(mut in_flight) = self.inner.in_flight.write() {
213 in_flight.remove(request_id);
214 }
215 }
216
217 fn cancel_request(&self, request_id: &RequestId) -> bool {
219 if let Ok(in_flight) = self.inner.in_flight.read()
220 && let Some(token) = in_flight.get(request_id)
221 {
222 token.cancel();
223 return true;
224 }
225 false
226 }
227
228 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
230 let inner = Arc::make_mut(&mut self.inner);
231 inner.server_name = name.into();
232 inner.server_version = version.into();
233 self
234 }
235
236 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
238 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
239 self
240 }
241
242 pub fn server_title(mut self, title: impl Into<String>) -> Self {
244 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
245 self
246 }
247
248 pub fn server_description(mut self, description: impl Into<String>) -> Self {
250 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
251 self
252 }
253
254 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
256 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
257 self
258 }
259
260 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
262 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
263 self
264 }
265
266 pub fn tool(mut self, tool: Tool) -> Self {
268 Arc::make_mut(&mut self.inner)
269 .tools
270 .insert(tool.name.clone(), Arc::new(tool));
271 self
272 }
273
274 pub fn resource(mut self, resource: Resource) -> Self {
276 Arc::make_mut(&mut self.inner)
277 .resources
278 .insert(resource.uri.clone(), Arc::new(resource));
279 self
280 }
281
282 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
313 Arc::make_mut(&mut self.inner)
314 .resource_templates
315 .push(Arc::new(template));
316 self
317 }
318
319 pub fn prompt(mut self, prompt: Prompt) -> Self {
321 Arc::make_mut(&mut self.inner)
322 .prompts
323 .insert(prompt.name.clone(), Arc::new(prompt));
324 self
325 }
326
327 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
353 tools
354 .into_iter()
355 .fold(self, |router, tool| router.tool(tool))
356 }
357
358 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
377 resources
378 .into_iter()
379 .fold(self, |router, resource| router.resource(resource))
380 }
381
382 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
401 prompts
402 .into_iter()
403 .fold(self, |router, prompt| router.prompt(prompt))
404 }
405
406 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
433 where
434 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
435 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
436 {
437 Arc::make_mut(&mut self.inner).completion_handler =
438 Some(Arc::new(move |params| Box::pin(handler(params))));
439 self
440 }
441
442 pub fn session(&self) -> &SessionState {
444 &self.session
445 }
446
447 pub fn log(&self, params: LoggingMessageParams) -> bool {
469 if let Some(tx) = &self.inner.notification_tx
470 && tx.try_send(ServerNotification::LogMessage(params)).is_ok()
471 {
472 return true;
473 }
474 false
475 }
476
477 pub fn log_info(&self, message: &str) -> bool {
481 self.log(
482 LoggingMessageParams::new(LogLevel::Info)
483 .with_data(serde_json::json!({ "message": message })),
484 )
485 }
486
487 pub fn log_warning(&self, message: &str) -> bool {
489 self.log(
490 LoggingMessageParams::new(LogLevel::Warning)
491 .with_data(serde_json::json!({ "message": message })),
492 )
493 }
494
495 pub fn log_error(&self, message: &str) -> bool {
497 self.log(
498 LoggingMessageParams::new(LogLevel::Error)
499 .with_data(serde_json::json!({ "message": message })),
500 )
501 }
502
503 pub fn log_debug(&self, message: &str) -> bool {
505 self.log(
506 LoggingMessageParams::new(LogLevel::Debug)
507 .with_data(serde_json::json!({ "message": message })),
508 )
509 }
510
511 pub fn is_subscribed(&self, uri: &str) -> bool {
513 if let Ok(subs) = self.inner.subscriptions.read() {
514 return subs.contains(uri);
515 }
516 false
517 }
518
519 pub fn subscribed_uris(&self) -> Vec<String> {
521 if let Ok(subs) = self.inner.subscriptions.read() {
522 return subs.iter().cloned().collect();
523 }
524 Vec::new()
525 }
526
527 fn subscribe(&self, uri: &str) -> bool {
529 if let Ok(mut subs) = self.inner.subscriptions.write() {
530 return subs.insert(uri.to_string());
531 }
532 false
533 }
534
535 fn unsubscribe(&self, uri: &str) -> bool {
537 if let Ok(mut subs) = self.inner.subscriptions.write() {
538 return subs.remove(uri);
539 }
540 false
541 }
542
543 pub fn notify_resource_updated(&self, uri: &str) -> bool {
548 if !self.is_subscribed(uri) {
550 return false;
551 }
552
553 if let Some(tx) = &self.inner.notification_tx
554 && tx
555 .try_send(ServerNotification::ResourceUpdated {
556 uri: uri.to_string(),
557 })
558 .is_ok()
559 {
560 return true;
561 }
562 false
563 }
564
565 pub fn notify_resources_list_changed(&self) -> bool {
569 if let Some(tx) = &self.inner.notification_tx
570 && tx
571 .try_send(ServerNotification::ResourcesListChanged)
572 .is_ok()
573 {
574 return true;
575 }
576 false
577 }
578
579 fn capabilities(&self) -> ServerCapabilities {
581 let has_resources =
582 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
583
584 ServerCapabilities {
585 tools: if self.inner.tools.is_empty() {
586 None
587 } else {
588 Some(ToolsCapability::default())
589 },
590 resources: if has_resources {
591 Some(ResourcesCapability {
592 subscribe: true,
593 ..Default::default()
594 })
595 } else {
596 None
597 },
598 prompts: if self.inner.prompts.is_empty() {
599 None
600 } else {
601 Some(PromptsCapability::default())
602 },
603 logging: if self.inner.notification_tx.is_some() {
605 Some(LoggingCapability::default())
606 } else {
607 None
608 },
609 tasks: Some(TasksCapability::default()),
611 completions: if self.inner.completion_handler.is_some() {
613 Some(CompletionsCapability::default())
614 } else {
615 None
616 },
617 }
618 }
619
620 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
622 let method = request.method_name();
624 if !self.session.is_request_allowed(method) {
625 tracing::warn!(
626 method = %method,
627 phase = ?self.session.phase(),
628 "Request rejected: session not initialized"
629 );
630 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
631 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
632 method
633 ))));
634 }
635
636 match request {
637 McpRequest::Initialize(params) => {
638 tracing::info!(
639 client = %params.client_info.name,
640 version = %params.client_info.version,
641 "Client initializing"
642 );
643
644 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
647 .contains(¶ms.protocol_version.as_str())
648 {
649 params.protocol_version
650 } else {
651 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
652 };
653
654 self.session.mark_initializing();
656
657 Ok(McpResponse::Initialize(InitializeResult {
658 protocol_version,
659 capabilities: self.capabilities(),
660 server_info: Implementation {
661 name: self.inner.server_name.clone(),
662 version: self.inner.server_version.clone(),
663 title: self.inner.server_title.clone(),
664 description: self.inner.server_description.clone(),
665 icons: self.inner.server_icons.clone(),
666 website_url: self.inner.server_website_url.clone(),
667 },
668 instructions: self.inner.instructions.clone(),
669 }))
670 }
671
672 McpRequest::ListTools(_params) => {
673 let tools: Vec<ToolDefinition> =
674 self.inner.tools.values().map(|t| t.definition()).collect();
675
676 Ok(McpResponse::ListTools(ListToolsResult {
677 tools,
678 next_cursor: None,
679 }))
680 }
681
682 McpRequest::CallTool(params) => {
683 let tool =
684 self.inner.tools.get(¶ms.name).ok_or_else(|| {
685 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
686 })?;
687
688 let progress_token = params.meta.and_then(|m| m.progress_token);
690 let ctx = self.create_context(request_id, progress_token);
691
692 tracing::debug!(tool = %params.name, "Calling tool");
693 let result = tool.call_with_context(ctx, params.arguments).await?;
694
695 Ok(McpResponse::CallTool(result))
696 }
697
698 McpRequest::ListResources(_params) => {
699 let resources: Vec<ResourceDefinition> = self
700 .inner
701 .resources
702 .values()
703 .map(|r| r.definition())
704 .collect();
705
706 Ok(McpResponse::ListResources(ListResourcesResult {
707 resources,
708 next_cursor: None,
709 }))
710 }
711
712 McpRequest::ListResourceTemplates(_params) => {
713 let resource_templates: Vec<ResourceTemplateDefinition> = self
714 .inner
715 .resource_templates
716 .iter()
717 .map(|t| t.definition())
718 .collect();
719
720 Ok(McpResponse::ListResourceTemplates(
721 ListResourceTemplatesResult {
722 resource_templates,
723 next_cursor: None,
724 },
725 ))
726 }
727
728 McpRequest::ReadResource(params) => {
729 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
731 tracing::debug!(uri = %params.uri, "Reading static resource");
732 let result = resource.read().await?;
733 return Ok(McpResponse::ReadResource(result));
734 }
735
736 for template in &self.inner.resource_templates {
738 if let Some(variables) = template.match_uri(¶ms.uri) {
739 tracing::debug!(
740 uri = %params.uri,
741 template = %template.uri_template,
742 "Reading resource via template"
743 );
744 let result = template.read(¶ms.uri, variables).await?;
745 return Ok(McpResponse::ReadResource(result));
746 }
747 }
748
749 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
751 ¶ms.uri,
752 )))
753 }
754
755 McpRequest::SubscribeResource(params) => {
756 if !self.inner.resources.contains_key(¶ms.uri) {
758 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
759 ¶ms.uri,
760 )));
761 }
762
763 tracing::debug!(uri = %params.uri, "Subscribing to resource");
764 self.subscribe(¶ms.uri);
765
766 Ok(McpResponse::SubscribeResource(EmptyResult {}))
767 }
768
769 McpRequest::UnsubscribeResource(params) => {
770 if !self.inner.resources.contains_key(¶ms.uri) {
772 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
773 ¶ms.uri,
774 )));
775 }
776
777 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
778 self.unsubscribe(¶ms.uri);
779
780 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
781 }
782
783 McpRequest::ListPrompts(_params) => {
784 let prompts: Vec<PromptDefinition> = self
785 .inner
786 .prompts
787 .values()
788 .map(|p| p.definition())
789 .collect();
790
791 Ok(McpResponse::ListPrompts(ListPromptsResult {
792 prompts,
793 next_cursor: None,
794 }))
795 }
796
797 McpRequest::GetPrompt(params) => {
798 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
799 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
800 "Prompt not found: {}",
801 params.name
802 )))
803 })?;
804
805 tracing::debug!(name = %params.name, "Getting prompt");
806 let result = prompt.get(params.arguments).await?;
807
808 Ok(McpResponse::GetPrompt(result))
809 }
810
811 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
812
813 McpRequest::EnqueueTask(params) => {
814 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
816 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
817 "Tool not found: {}",
818 params.tool_name
819 )))
820 })?;
821
822 let (task_id, cancellation_token) = self.inner.task_store.create_task(
824 ¶ms.tool_name,
825 params.arguments.clone(),
826 params.ttl,
827 );
828
829 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
830
831 let ctx = self.create_context(request_id, None);
833
834 let task_store = self.inner.task_store.clone();
836 let tool = tool.clone();
837 let arguments = params.arguments;
838 let task_id_clone = task_id.clone();
839
840 tokio::spawn(async move {
841 if cancellation_token.is_cancelled() {
843 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
844 return;
845 }
846
847 match tool.call_with_context(ctx, arguments).await {
849 Ok(result) => {
850 if cancellation_token.is_cancelled() {
851 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
852 } else {
853 task_store.complete_task(&task_id_clone, result);
854 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
855 }
856 }
857 Err(e) => {
858 task_store.fail_task(&task_id_clone, &e.to_string());
859 tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
860 }
861 }
862 });
863
864 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
865 task_id,
866 status: TaskStatus::Working,
867 poll_interval: Some(2),
868 }))
869 }
870
871 McpRequest::ListTasks(params) => {
872 let tasks = self.inner.task_store.list_tasks(params.status);
873
874 Ok(McpResponse::ListTasks(ListTasksResult {
875 tasks,
876 next_cursor: None,
877 }))
878 }
879
880 McpRequest::GetTaskInfo(params) => {
881 let task = self
882 .inner
883 .task_store
884 .get_task(¶ms.task_id)
885 .ok_or_else(|| {
886 Error::JsonRpc(JsonRpcError::invalid_params(format!(
887 "Task not found: {}",
888 params.task_id
889 )))
890 })?;
891
892 Ok(McpResponse::GetTaskInfo(task))
893 }
894
895 McpRequest::GetTaskResult(params) => {
896 let (status, result, error) = self
897 .inner
898 .task_store
899 .get_task_full(¶ms.task_id)
900 .ok_or_else(|| {
901 Error::JsonRpc(JsonRpcError::invalid_params(format!(
902 "Task not found: {}",
903 params.task_id
904 )))
905 })?;
906
907 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
908 task_id: params.task_id,
909 status,
910 result,
911 error,
912 }))
913 }
914
915 McpRequest::CancelTask(params) => {
916 let status = self
917 .inner
918 .task_store
919 .cancel_task(¶ms.task_id, params.reason.as_deref())
920 .ok_or_else(|| {
921 Error::JsonRpc(JsonRpcError::invalid_params(format!(
922 "Task not found: {}",
923 params.task_id
924 )))
925 })?;
926
927 let cancelled = status == TaskStatus::Cancelled;
928
929 Ok(McpResponse::CancelTask(CancelTaskResult {
930 cancelled,
931 status,
932 }))
933 }
934
935 McpRequest::SetLoggingLevel(params) => {
936 tracing::debug!(level = ?params.level, "Client set logging level");
940 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
941 }
942
943 McpRequest::Complete(params) => {
944 tracing::debug!(
945 reference = ?params.reference,
946 argument = %params.argument.name,
947 "Completion request"
948 );
949
950 if let Some(ref handler) = self.inner.completion_handler {
952 let result = handler(params).await?;
953 Ok(McpResponse::Complete(result))
954 } else {
955 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
957 }
958 }
959
960 McpRequest::Unknown { method, .. } => {
961 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
962 }
963 }
964 }
965
966 pub fn handle_notification(&self, notification: McpNotification) {
968 match notification {
969 McpNotification::Initialized => {
970 if self.session.mark_initialized() {
971 tracing::info!("Session initialized, entering operation phase");
972 } else {
973 tracing::warn!(
974 "Received initialized notification in unexpected state: {:?}",
975 self.session.phase()
976 );
977 }
978 }
979 McpNotification::Cancelled(params) => {
980 if self.cancel_request(¶ms.request_id) {
981 tracing::info!(
982 request_id = ?params.request_id,
983 reason = ?params.reason,
984 "Request cancelled"
985 );
986 } else {
987 tracing::debug!(
988 request_id = ?params.request_id,
989 reason = ?params.reason,
990 "Cancellation requested for unknown request"
991 );
992 }
993 }
994 McpNotification::Progress(params) => {
995 tracing::trace!(
996 token = ?params.progress_token,
997 progress = params.progress,
998 total = ?params.total,
999 "Progress notification"
1000 );
1001 }
1003 McpNotification::RootsListChanged => {
1004 tracing::info!("Client roots list changed");
1005 }
1008 McpNotification::Unknown { method, .. } => {
1009 tracing::debug!(method = %method, "Unknown notification received");
1010 }
1011 }
1012 }
1013}
1014
1015impl Default for McpRouter {
1016 fn default() -> Self {
1017 Self::new()
1018 }
1019}
1020
/// A type-keyed bag of values attached to a request.
///
/// At most one value per concrete type is stored; inserting another value of the same type
/// replaces the previous one. Values are kept behind `Arc`, so cloning the map is shallow.
#[derive(Default, Clone)]
pub struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty map.
    pub fn new() -> Self {
        Extensions {
            map: HashMap::new(),
        }
    }

    /// Stores `val`, replacing any existing value of the same type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Returns a reference to the stored value of type `T`, or `None` if there is none.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.map.get(&TypeId::of::<T>())?;
        stored.downcast_ref::<T>()
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored values; the values themselves are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
1075
1076#[derive(Debug)]
1078pub struct RouterRequest {
1079 pub id: RequestId,
1080 pub inner: McpRequest,
1081 pub extensions: Extensions,
1083}
1084
1085#[derive(Debug)]
1087pub struct RouterResponse {
1088 pub id: RequestId,
1089 pub inner: std::result::Result<McpResponse, JsonRpcError>,
1090}
1091
1092impl RouterResponse {
1093 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1095 match self.inner {
1096 Ok(response) => match serde_json::to_value(response) {
1097 Ok(result) => JsonRpcResponse::result(self.id, result),
1098 Err(e) => {
1099 tracing::error!(error = %e, "Failed to serialize response");
1100 JsonRpcResponse::error(
1101 Some(self.id),
1102 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1103 )
1104 }
1105 },
1106 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1107 }
1108 }
1109}
1110
1111impl Service<RouterRequest> for McpRouter {
1112 type Response = RouterResponse;
1113 type Error = std::convert::Infallible; type Future =
1115 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1116
1117 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1118 Poll::Ready(Ok(()))
1119 }
1120
1121 fn call(&mut self, req: RouterRequest) -> Self::Future {
1122 let router = self.clone();
1123 let request_id = req.id.clone();
1124 Box::pin(async move {
1125 let result = router.handle(req.id, req.inner).await;
1126 router.complete_request(&request_id);
1128 Ok(RouterResponse {
1129 id: request_id,
1130 inner: result.map_err(|e| match e {
1135 Error::JsonRpc(err) => err,
1136 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1137 e => JsonRpcError::internal_error(e.to_string()),
1138 }),
1139 })
1140 })
1141 }
1142}
1143
1144#[cfg(test)]
1145mod tests {
1146 use super::*;
1147 use crate::jsonrpc::JsonRpcService;
1148 use crate::tool::ToolBuilder;
1149 use schemars::JsonSchema;
1150 use serde::Deserialize;
1151 use tower::ServiceExt;
1152
1153 #[derive(Debug, Deserialize, JsonSchema)]
1154 struct AddInput {
1155 a: i64,
1156 b: i64,
1157 }
1158
1159 async fn init_router(router: &mut McpRouter) {
1161 let init_req = RouterRequest {
1163 id: RequestId::Number(0),
1164 inner: McpRequest::Initialize(InitializeParams {
1165 protocol_version: "2025-11-25".to_string(),
1166 capabilities: ClientCapabilities {
1167 roots: None,
1168 sampling: None,
1169 elicitation: None,
1170 },
1171 client_info: Implementation {
1172 name: "test".to_string(),
1173 version: "1.0".to_string(),
1174 ..Default::default()
1175 },
1176 }),
1177 extensions: Extensions::new(),
1178 };
1179 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1180 router.handle_notification(McpNotification::Initialized);
1182 }
1183
1184 #[tokio::test]
1185 async fn test_router_list_tools() {
1186 let add_tool = ToolBuilder::new("add")
1187 .description("Add two numbers")
1188 .handler(|input: AddInput| async move {
1189 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1190 })
1191 .build()
1192 .expect("valid tool name");
1193
1194 let mut router = McpRouter::new().tool(add_tool);
1195
1196 init_router(&mut router).await;
1198
1199 let req = RouterRequest {
1200 id: RequestId::Number(1),
1201 inner: McpRequest::ListTools(ListToolsParams::default()),
1202 extensions: Extensions::new(),
1203 };
1204
1205 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1206
1207 match resp.inner {
1208 Ok(McpResponse::ListTools(result)) => {
1209 assert_eq!(result.tools.len(), 1);
1210 assert_eq!(result.tools[0].name, "add");
1211 }
1212 _ => panic!("Expected ListTools response"),
1213 }
1214 }
1215
1216 #[tokio::test]
1217 async fn test_router_call_tool() {
1218 let add_tool = ToolBuilder::new("add")
1219 .description("Add two numbers")
1220 .handler(|input: AddInput| async move {
1221 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1222 })
1223 .build()
1224 .expect("valid tool name");
1225
1226 let mut router = McpRouter::new().tool(add_tool);
1227
1228 init_router(&mut router).await;
1230
1231 let req = RouterRequest {
1232 id: RequestId::Number(1),
1233 inner: McpRequest::CallTool(CallToolParams {
1234 name: "add".to_string(),
1235 arguments: serde_json::json!({"a": 2, "b": 3}),
1236 meta: None,
1237 }),
1238 extensions: Extensions::new(),
1239 };
1240
1241 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1242
1243 match resp.inner {
1244 Ok(McpResponse::CallTool(result)) => {
1245 assert!(!result.is_error);
1246 match &result.content[0] {
1248 Content::Text { text, .. } => assert_eq!(text, "5"),
1249 _ => panic!("Expected text content"),
1250 }
1251 }
1252 _ => panic!("Expected CallTool response"),
1253 }
1254 }
1255
1256 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1258 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1259 "protocolVersion": "2025-11-25",
1260 "capabilities": {},
1261 "clientInfo": { "name": "test", "version": "1.0" }
1262 }));
1263 let _ = service.call_single(init_req).await.unwrap();
1264 router.handle_notification(McpNotification::Initialized);
1265 }
1266
1267 #[tokio::test]
1268 async fn test_jsonrpc_service() {
1269 let add_tool = ToolBuilder::new("add")
1270 .description("Add two numbers")
1271 .handler(|input: AddInput| async move {
1272 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1273 })
1274 .build()
1275 .expect("valid tool name");
1276
1277 let router = McpRouter::new().tool(add_tool);
1278 let mut service = JsonRpcService::new(router.clone());
1279
1280 init_jsonrpc_service(&mut service, &router).await;
1282
1283 let req = JsonRpcRequest::new(1, "tools/list");
1284
1285 let resp = service.call_single(req).await.unwrap();
1286
1287 match resp {
1288 JsonRpcResponse::Result(r) => {
1289 assert_eq!(r.id, RequestId::Number(1));
1290 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1291 assert_eq!(tools.len(), 1);
1292 }
1293 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1294 }
1295 }
1296
1297 #[tokio::test]
1298 async fn test_batch_request() {
1299 let add_tool = ToolBuilder::new("add")
1300 .description("Add two numbers")
1301 .handler(|input: AddInput| async move {
1302 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1303 })
1304 .build()
1305 .expect("valid tool name");
1306
1307 let router = McpRouter::new().tool(add_tool);
1308 let mut service = JsonRpcService::new(router.clone());
1309
1310 init_jsonrpc_service(&mut service, &router).await;
1312
1313 let requests = vec![
1315 JsonRpcRequest::new(1, "tools/list"),
1316 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1317 "name": "add",
1318 "arguments": {"a": 10, "b": 20}
1319 })),
1320 JsonRpcRequest::new(3, "ping"),
1321 ];
1322
1323 let responses = service.call_batch(requests).await.unwrap();
1324
1325 assert_eq!(responses.len(), 3);
1326
1327 match &responses[0] {
1329 JsonRpcResponse::Result(r) => {
1330 assert_eq!(r.id, RequestId::Number(1));
1331 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1332 assert_eq!(tools.len(), 1);
1333 }
1334 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1335 }
1336
1337 match &responses[1] {
1339 JsonRpcResponse::Result(r) => {
1340 assert_eq!(r.id, RequestId::Number(2));
1341 let content = r.result.get("content").unwrap().as_array().unwrap();
1342 let text = content[0].get("text").unwrap().as_str().unwrap();
1343 assert_eq!(text, "30");
1344 }
1345 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1346 }
1347
1348 match &responses[2] {
1350 JsonRpcResponse::Result(r) => {
1351 assert_eq!(r.id, RequestId::Number(3));
1352 }
1353 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1354 }
1355 }
1356
1357 #[tokio::test]
1358 async fn test_empty_batch_error() {
1359 let router = McpRouter::new();
1360 let mut service = JsonRpcService::new(router);
1361
1362 let result = service.call_batch(vec![]).await;
1363 assert!(result.is_err());
1364 }
1365
1366 #[tokio::test]
1371 async fn test_progress_token_extraction() {
1372 use crate::context::{RequestContext, ServerNotification, notification_channel};
1373 use crate::protocol::ProgressToken;
1374 use std::sync::Arc;
1375 use std::sync::atomic::{AtomicBool, Ordering};
1376
1377 let progress_reported = Arc::new(AtomicBool::new(false));
1379 let progress_ref = progress_reported.clone();
1380
1381 let tool = ToolBuilder::new("progress_tool")
1383 .description("Tool that reports progress")
1384 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1385 let reported = progress_ref.clone();
1386 async move {
1387 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1389 .await;
1390 reported.store(true, Ordering::SeqCst);
1391 Ok(CallToolResult::text("done"))
1392 }
1393 })
1394 .build()
1395 .expect("valid tool name");
1396
1397 let (tx, mut rx) = notification_channel(10);
1399 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1400 let mut service = JsonRpcService::new(router.clone());
1401
1402 init_jsonrpc_service(&mut service, &router).await;
1404
1405 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1407 "name": "progress_tool",
1408 "arguments": {"a": 1, "b": 2},
1409 "_meta": {
1410 "progressToken": "test-token-123"
1411 }
1412 }));
1413
1414 let resp = service.call_single(req).await.unwrap();
1415
1416 match resp {
1418 JsonRpcResponse::Result(_) => {}
1419 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1420 }
1421
1422 assert!(progress_reported.load(Ordering::SeqCst));
1424
1425 let notification = rx.try_recv().expect("Expected progress notification");
1427 match notification {
1428 ServerNotification::Progress(params) => {
1429 assert_eq!(
1430 params.progress_token,
1431 ProgressToken::String("test-token-123".to_string())
1432 );
1433 assert_eq!(params.progress, 50.0);
1434 assert_eq!(params.total, Some(100.0));
1435 assert_eq!(params.message.as_deref(), Some("Halfway"));
1436 }
1437 _ => panic!("Expected Progress notification"),
1438 }
1439 }
1440
1441 #[tokio::test]
1442 async fn test_tool_call_without_progress_token() {
1443 use crate::context::{RequestContext, notification_channel};
1444 use std::sync::Arc;
1445 use std::sync::atomic::{AtomicBool, Ordering};
1446
1447 let progress_attempted = Arc::new(AtomicBool::new(false));
1448 let progress_ref = progress_attempted.clone();
1449
1450 let tool = ToolBuilder::new("no_token_tool")
1451 .description("Tool that tries to report progress without token")
1452 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1453 let attempted = progress_ref.clone();
1454 async move {
1455 ctx.report_progress(50.0, Some(100.0), None).await;
1457 attempted.store(true, Ordering::SeqCst);
1458 Ok(CallToolResult::text("done"))
1459 }
1460 })
1461 .build()
1462 .expect("valid tool name");
1463
1464 let (tx, mut rx) = notification_channel(10);
1465 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1466 let mut service = JsonRpcService::new(router.clone());
1467
1468 init_jsonrpc_service(&mut service, &router).await;
1469
1470 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1472 "name": "no_token_tool",
1473 "arguments": {"a": 1, "b": 2}
1474 }));
1475
1476 let resp = service.call_single(req).await.unwrap();
1477 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1478
1479 assert!(progress_attempted.load(Ordering::SeqCst));
1481
1482 assert!(rx.try_recv().is_err());
1484 }
1485
1486 #[tokio::test]
1487 async fn test_batch_errors_returned_not_dropped() {
1488 let add_tool = ToolBuilder::new("add")
1489 .description("Add two numbers")
1490 .handler(|input: AddInput| async move {
1491 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1492 })
1493 .build()
1494 .expect("valid tool name");
1495
1496 let router = McpRouter::new().tool(add_tool);
1497 let mut service = JsonRpcService::new(router.clone());
1498
1499 init_jsonrpc_service(&mut service, &router).await;
1500
1501 let requests = vec![
1503 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1505 "name": "add",
1506 "arguments": {"a": 10, "b": 20}
1507 })),
1508 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1510 "name": "nonexistent_tool",
1511 "arguments": {}
1512 })),
1513 JsonRpcRequest::new(3, "ping"),
1515 ];
1516
1517 let responses = service.call_batch(requests).await.unwrap();
1518
1519 assert_eq!(responses.len(), 3);
1521
1522 match &responses[0] {
1524 JsonRpcResponse::Result(r) => {
1525 assert_eq!(r.id, RequestId::Number(1));
1526 }
1527 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1528 }
1529
1530 match &responses[1] {
1532 JsonRpcResponse::Error(e) => {
1533 assert_eq!(e.id, Some(RequestId::Number(2)));
1534 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1536 }
1537 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1538 }
1539
1540 match &responses[2] {
1542 JsonRpcResponse::Result(r) => {
1543 assert_eq!(r.id, RequestId::Number(3));
1544 }
1545 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1546 }
1547 }
1548
1549 #[tokio::test]
1554 async fn test_list_resource_templates() {
1555 use crate::resource::ResourceTemplateBuilder;
1556 use std::collections::HashMap;
1557
1558 let template = ResourceTemplateBuilder::new("file:///{path}")
1559 .name("Project Files")
1560 .description("Access project files")
1561 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1562 Ok(ReadResourceResult {
1563 contents: vec![ResourceContent {
1564 uri,
1565 mime_type: None,
1566 text: None,
1567 blob: None,
1568 }],
1569 })
1570 });
1571
1572 let mut router = McpRouter::new().resource_template(template);
1573
1574 init_router(&mut router).await;
1576
1577 let req = RouterRequest {
1578 id: RequestId::Number(1),
1579 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1580 extensions: Extensions::new(),
1581 };
1582
1583 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1584
1585 match resp.inner {
1586 Ok(McpResponse::ListResourceTemplates(result)) => {
1587 assert_eq!(result.resource_templates.len(), 1);
1588 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1589 assert_eq!(result.resource_templates[0].name, "Project Files");
1590 }
1591 _ => panic!("Expected ListResourceTemplates response"),
1592 }
1593 }
1594
1595 #[tokio::test]
1596 async fn test_read_resource_via_template() {
1597 use crate::resource::ResourceTemplateBuilder;
1598 use std::collections::HashMap;
1599
1600 let template = ResourceTemplateBuilder::new("db://users/{id}")
1601 .name("User Records")
1602 .handler(|uri: String, vars: HashMap<String, String>| async move {
1603 let id = vars.get("id").unwrap().clone();
1604 Ok(ReadResourceResult {
1605 contents: vec![ResourceContent {
1606 uri,
1607 mime_type: Some("application/json".to_string()),
1608 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1609 blob: None,
1610 }],
1611 })
1612 });
1613
1614 let mut router = McpRouter::new().resource_template(template);
1615
1616 init_router(&mut router).await;
1618
1619 let req = RouterRequest {
1621 id: RequestId::Number(1),
1622 inner: McpRequest::ReadResource(ReadResourceParams {
1623 uri: "db://users/123".to_string(),
1624 }),
1625 extensions: Extensions::new(),
1626 };
1627
1628 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1629
1630 match resp.inner {
1631 Ok(McpResponse::ReadResource(result)) => {
1632 assert_eq!(result.contents.len(), 1);
1633 assert_eq!(result.contents[0].uri, "db://users/123");
1634 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1635 }
1636 _ => panic!("Expected ReadResource response"),
1637 }
1638 }
1639
1640 #[tokio::test]
1641 async fn test_static_resource_takes_precedence_over_template() {
1642 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1643 use std::collections::HashMap;
1644
1645 let template = ResourceTemplateBuilder::new("file:///{path}")
1647 .name("Files Template")
1648 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1649 Ok(ReadResourceResult {
1650 contents: vec![ResourceContent {
1651 uri,
1652 mime_type: None,
1653 text: Some("from template".to_string()),
1654 blob: None,
1655 }],
1656 })
1657 });
1658
1659 let static_resource = ResourceBuilder::new("file:///README.md")
1661 .name("README")
1662 .text("from static resource");
1663
1664 let mut router = McpRouter::new()
1665 .resource_template(template)
1666 .resource(static_resource);
1667
1668 init_router(&mut router).await;
1670
1671 let req = RouterRequest {
1673 id: RequestId::Number(1),
1674 inner: McpRequest::ReadResource(ReadResourceParams {
1675 uri: "file:///README.md".to_string(),
1676 }),
1677 extensions: Extensions::new(),
1678 };
1679
1680 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1681
1682 match resp.inner {
1683 Ok(McpResponse::ReadResource(result)) => {
1684 assert_eq!(
1686 result.contents[0].text.as_deref(),
1687 Some("from static resource")
1688 );
1689 }
1690 _ => panic!("Expected ReadResource response"),
1691 }
1692 }
1693
1694 #[tokio::test]
1695 async fn test_resource_not_found_when_no_match() {
1696 use crate::resource::ResourceTemplateBuilder;
1697 use std::collections::HashMap;
1698
1699 let template = ResourceTemplateBuilder::new("db://users/{id}")
1700 .name("Users")
1701 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1702 Ok(ReadResourceResult {
1703 contents: vec![ResourceContent {
1704 uri,
1705 mime_type: None,
1706 text: None,
1707 blob: None,
1708 }],
1709 })
1710 });
1711
1712 let mut router = McpRouter::new().resource_template(template);
1713
1714 init_router(&mut router).await;
1716
1717 let req = RouterRequest {
1719 id: RequestId::Number(1),
1720 inner: McpRequest::ReadResource(ReadResourceParams {
1721 uri: "db://posts/123".to_string(),
1722 }),
1723 extensions: Extensions::new(),
1724 };
1725
1726 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1727
1728 match resp.inner {
1729 Err(err) => {
1730 assert!(err.message.contains("not found"));
1731 }
1732 Ok(_) => panic!("Expected error for non-matching URI"),
1733 }
1734 }
1735
1736 #[tokio::test]
1737 async fn test_capabilities_include_resources_with_only_templates() {
1738 use crate::resource::ResourceTemplateBuilder;
1739 use std::collections::HashMap;
1740
1741 let template = ResourceTemplateBuilder::new("file:///{path}")
1742 .name("Files")
1743 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1744 Ok(ReadResourceResult {
1745 contents: vec![ResourceContent {
1746 uri,
1747 mime_type: None,
1748 text: None,
1749 blob: None,
1750 }],
1751 })
1752 });
1753
1754 let mut router = McpRouter::new().resource_template(template);
1755
1756 let init_req = RouterRequest {
1758 id: RequestId::Number(0),
1759 inner: McpRequest::Initialize(InitializeParams {
1760 protocol_version: "2025-11-25".to_string(),
1761 capabilities: ClientCapabilities {
1762 roots: None,
1763 sampling: None,
1764 elicitation: None,
1765 },
1766 client_info: Implementation {
1767 name: "test".to_string(),
1768 version: "1.0".to_string(),
1769 ..Default::default()
1770 },
1771 }),
1772 extensions: Extensions::new(),
1773 };
1774 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1775
1776 match resp.inner {
1777 Ok(McpResponse::Initialize(result)) => {
1778 assert!(result.capabilities.resources.is_some());
1780 }
1781 _ => panic!("Expected Initialize response"),
1782 }
1783 }
1784
1785 #[tokio::test]
1790 async fn test_log_sends_notification() {
1791 use crate::context::notification_channel;
1792
1793 let (tx, mut rx) = notification_channel(10);
1794 let router = McpRouter::new().with_notification_sender(tx);
1795
1796 let sent = router.log_info("Test message");
1798 assert!(sent);
1799
1800 let notification = rx.try_recv().unwrap();
1802 match notification {
1803 ServerNotification::LogMessage(params) => {
1804 assert_eq!(params.level, LogLevel::Info);
1805 let data = params.data.unwrap();
1806 assert_eq!(
1807 data.get("message").unwrap().as_str().unwrap(),
1808 "Test message"
1809 );
1810 }
1811 _ => panic!("Expected LogMessage notification"),
1812 }
1813 }
1814
1815 #[tokio::test]
1816 async fn test_log_with_custom_params() {
1817 use crate::context::notification_channel;
1818
1819 let (tx, mut rx) = notification_channel(10);
1820 let router = McpRouter::new().with_notification_sender(tx);
1821
1822 let params = LoggingMessageParams::new(LogLevel::Error)
1824 .with_logger("database")
1825 .with_data(serde_json::json!({
1826 "error": "Connection failed",
1827 "host": "localhost"
1828 }));
1829
1830 let sent = router.log(params);
1831 assert!(sent);
1832
1833 let notification = rx.try_recv().unwrap();
1834 match notification {
1835 ServerNotification::LogMessage(params) => {
1836 assert_eq!(params.level, LogLevel::Error);
1837 assert_eq!(params.logger.as_deref(), Some("database"));
1838 let data = params.data.unwrap();
1839 assert_eq!(
1840 data.get("error").unwrap().as_str().unwrap(),
1841 "Connection failed"
1842 );
1843 }
1844 _ => panic!("Expected LogMessage notification"),
1845 }
1846 }
1847
1848 #[tokio::test]
1849 async fn test_log_without_channel_returns_false() {
1850 let router = McpRouter::new();
1852
1853 assert!(!router.log_info("Test"));
1855 assert!(!router.log_warning("Test"));
1856 assert!(!router.log_error("Test"));
1857 assert!(!router.log_debug("Test"));
1858 }
1859
1860 #[tokio::test]
1861 async fn test_logging_capability_with_channel() {
1862 use crate::context::notification_channel;
1863
1864 let (tx, _rx) = notification_channel(10);
1865 let mut router = McpRouter::new().with_notification_sender(tx);
1866
1867 let init_req = RouterRequest {
1869 id: RequestId::Number(0),
1870 inner: McpRequest::Initialize(InitializeParams {
1871 protocol_version: "2025-11-25".to_string(),
1872 capabilities: ClientCapabilities {
1873 roots: None,
1874 sampling: None,
1875 elicitation: None,
1876 },
1877 client_info: Implementation {
1878 name: "test".to_string(),
1879 version: "1.0".to_string(),
1880 ..Default::default()
1881 },
1882 }),
1883 extensions: Extensions::new(),
1884 };
1885 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1886
1887 match resp.inner {
1888 Ok(McpResponse::Initialize(result)) => {
1889 assert!(result.capabilities.logging.is_some());
1891 }
1892 _ => panic!("Expected Initialize response"),
1893 }
1894 }
1895
1896 #[tokio::test]
1897 async fn test_no_logging_capability_without_channel() {
1898 let mut router = McpRouter::new();
1899
1900 let init_req = RouterRequest {
1902 id: RequestId::Number(0),
1903 inner: McpRequest::Initialize(InitializeParams {
1904 protocol_version: "2025-11-25".to_string(),
1905 capabilities: ClientCapabilities {
1906 roots: None,
1907 sampling: None,
1908 elicitation: None,
1909 },
1910 client_info: Implementation {
1911 name: "test".to_string(),
1912 version: "1.0".to_string(),
1913 ..Default::default()
1914 },
1915 }),
1916 extensions: Extensions::new(),
1917 };
1918 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1919
1920 match resp.inner {
1921 Ok(McpResponse::Initialize(result)) => {
1922 assert!(result.capabilities.logging.is_none());
1924 }
1925 _ => panic!("Expected Initialize response"),
1926 }
1927 }
1928
1929 #[tokio::test]
1934 async fn test_enqueue_task() {
1935 let add_tool = ToolBuilder::new("add")
1936 .description("Add two numbers")
1937 .handler(|input: AddInput| async move {
1938 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1939 })
1940 .build()
1941 .expect("valid tool name");
1942
1943 let mut router = McpRouter::new().tool(add_tool);
1944 init_router(&mut router).await;
1945
1946 let req = RouterRequest {
1947 id: RequestId::Number(1),
1948 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1949 tool_name: "add".to_string(),
1950 arguments: serde_json::json!({"a": 5, "b": 10}),
1951 ttl: None,
1952 }),
1953 extensions: Extensions::new(),
1954 };
1955
1956 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1957
1958 match resp.inner {
1959 Ok(McpResponse::EnqueueTask(result)) => {
1960 assert!(result.task_id.starts_with("task-"));
1961 assert_eq!(result.status, TaskStatus::Working);
1962 }
1963 _ => panic!("Expected EnqueueTask response"),
1964 }
1965 }
1966
1967 #[tokio::test]
1968 async fn test_list_tasks_empty() {
1969 let mut router = McpRouter::new();
1970 init_router(&mut router).await;
1971
1972 let req = RouterRequest {
1973 id: RequestId::Number(1),
1974 inner: McpRequest::ListTasks(ListTasksParams::default()),
1975 extensions: Extensions::new(),
1976 };
1977
1978 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1979
1980 match resp.inner {
1981 Ok(McpResponse::ListTasks(result)) => {
1982 assert!(result.tasks.is_empty());
1983 }
1984 _ => panic!("Expected ListTasks response"),
1985 }
1986 }
1987
    /// End-to-end task lifecycle: enqueue an `add` call, wait, then fetch the
    /// task result and check it is `Completed` with text output "15".
    #[tokio::test]
    async fn test_task_lifecycle_complete() {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
            })
            .build()
            .expect("valid tool name");

        let mut router = McpRouter::new().tool(add_tool);
        init_router(&mut router).await;

        // Step 1: enqueue 7 + 8 and capture the generated task id.
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
                tool_name: "add".to_string(),
                arguments: serde_json::json!({"a": 7, "b": 8}),
                ttl: None,
            }),
            extensions: Extensions::new(),
        };

        let resp = router.ready().await.unwrap().call(req).await.unwrap();
        let task_id = match resp.inner {
            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
            _ => panic!("Expected EnqueueTask response"),
        };

        // Step 2: give the task time to run.
        // NOTE(review): this is a fixed delay, so the test assumes the enqueued
        // task finishes within 100 ms. That may be flaky on a heavily loaded CI
        // machine; consider polling until the status is terminal.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Step 3: fetch the result for the same task id.
        let req = RouterRequest {
            id: RequestId::Number(2),
            inner: McpRequest::GetTaskResult(GetTaskResultParams {
                task_id: task_id.clone(),
            }),
            extensions: Extensions::new(),
        };

        let resp = router.ready().await.unwrap().call(req).await.unwrap();

        match resp.inner {
            Ok(McpResponse::GetTaskResult(result)) => {
                assert_eq!(result.task_id, task_id);
                assert_eq!(result.status, TaskStatus::Completed);
                // A completed task carries a result and no error.
                assert!(result.result.is_some());
                assert!(result.error.is_none());

                // The tool's output is its first content item: 7 + 8 as text.
                let tool_result = result.result.unwrap();
                match &tool_result.content[0] {
                    Content::Text { text, .. } => assert_eq!(text, "15"),
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected GetTaskResult response"),
        }
    }
2049
    /// Cancelling a task that is still running reports `cancelled == true` and
    /// a `Cancelled` status.
    #[tokio::test]
    async fn test_task_cancellation() {
        // The tool sleeps for 60 s so it is still in progress when the cancel
        // request arrives immediately after enqueueing.
        let slow_tool = ToolBuilder::new("slow")
            .description("Slow tool")
            .handler(|_input: serde_json::Value| async move {
                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
                Ok(CallToolResult::text("done"))
            })
            .build()
            .expect("valid tool name");

        let mut router = McpRouter::new().tool(slow_tool);
        init_router(&mut router).await;

        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
                tool_name: "slow".to_string(),
                arguments: serde_json::json!({}),
                ttl: None,
            }),
            extensions: Extensions::new(),
        };

        let resp = router.ready().await.unwrap().call(req).await.unwrap();
        let task_id = match resp.inner {
            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
            _ => panic!("Expected EnqueueTask response"),
        };

        // Cancel straight away, with no intervening await that would let the
        // task finish.
        // NOTE(review): whether the background future is actually aborted is
        // not observable here; the test only checks the reported state.
        let req = RouterRequest {
            id: RequestId::Number(2),
            inner: McpRequest::CancelTask(CancelTaskParams {
                task_id: task_id.clone(),
                reason: Some("Test cancellation".to_string()),
            }),
            extensions: Extensions::new(),
        };

        let resp = router.ready().await.unwrap().call(req).await.unwrap();

        match resp.inner {
            Ok(McpResponse::CancelTask(result)) => {
                assert!(result.cancelled);
                assert_eq!(result.status, TaskStatus::Cancelled);
            }
            _ => panic!("Expected CancelTask response"),
        }
    }
2102
2103 #[tokio::test]
2104 async fn test_get_task_info() {
2105 let add_tool = ToolBuilder::new("add")
2106 .description("Add two numbers")
2107 .handler(|input: AddInput| async move {
2108 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2109 })
2110 .build()
2111 .expect("valid tool name");
2112
2113 let mut router = McpRouter::new().tool(add_tool);
2114 init_router(&mut router).await;
2115
2116 let req = RouterRequest {
2118 id: RequestId::Number(1),
2119 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2120 tool_name: "add".to_string(),
2121 arguments: serde_json::json!({"a": 1, "b": 2}),
2122 ttl: Some(600),
2123 }),
2124 extensions: Extensions::new(),
2125 };
2126
2127 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2128 let task_id = match resp.inner {
2129 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2130 _ => panic!("Expected EnqueueTask response"),
2131 };
2132
2133 let req = RouterRequest {
2135 id: RequestId::Number(2),
2136 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2137 task_id: task_id.clone(),
2138 }),
2139 extensions: Extensions::new(),
2140 };
2141
2142 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2143
2144 match resp.inner {
2145 Ok(McpResponse::GetTaskInfo(info)) => {
2146 assert_eq!(info.task_id, task_id);
2147 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2149 }
2150 _ => panic!("Expected GetTaskInfo response"),
2151 }
2152 }
2153
2154 #[tokio::test]
2155 async fn test_enqueue_nonexistent_tool() {
2156 let mut router = McpRouter::new();
2157 init_router(&mut router).await;
2158
2159 let req = RouterRequest {
2160 id: RequestId::Number(1),
2161 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2162 tool_name: "nonexistent".to_string(),
2163 arguments: serde_json::json!({}),
2164 ttl: None,
2165 }),
2166 extensions: Extensions::new(),
2167 };
2168
2169 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2170
2171 match resp.inner {
2172 Err(e) => {
2173 assert!(e.message.contains("not found"));
2174 }
2175 _ => panic!("Expected error response"),
2176 }
2177 }
2178
2179 #[tokio::test]
2180 async fn test_get_nonexistent_task() {
2181 let mut router = McpRouter::new();
2182 init_router(&mut router).await;
2183
2184 let req = RouterRequest {
2185 id: RequestId::Number(1),
2186 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2187 task_id: "task-999".to_string(),
2188 }),
2189 extensions: Extensions::new(),
2190 };
2191
2192 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2193
2194 match resp.inner {
2195 Err(e) => {
2196 assert!(e.message.contains("not found"));
2197 }
2198 _ => panic!("Expected error response"),
2199 }
2200 }
2201
2202 #[tokio::test]
2207 async fn test_subscribe_to_resource() {
2208 use crate::resource::ResourceBuilder;
2209
2210 let resource = ResourceBuilder::new("file:///test.txt")
2211 .name("Test File")
2212 .text("Hello");
2213
2214 let mut router = McpRouter::new().resource(resource);
2215 init_router(&mut router).await;
2216
2217 let req = RouterRequest {
2219 id: RequestId::Number(1),
2220 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2221 uri: "file:///test.txt".to_string(),
2222 }),
2223 extensions: Extensions::new(),
2224 };
2225
2226 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2227
2228 match resp.inner {
2229 Ok(McpResponse::SubscribeResource(_)) => {
2230 assert!(router.is_subscribed("file:///test.txt"));
2232 }
2233 _ => panic!("Expected SubscribeResource response"),
2234 }
2235 }
2236
2237 #[tokio::test]
2238 async fn test_unsubscribe_from_resource() {
2239 use crate::resource::ResourceBuilder;
2240
2241 let resource = ResourceBuilder::new("file:///test.txt")
2242 .name("Test File")
2243 .text("Hello");
2244
2245 let mut router = McpRouter::new().resource(resource);
2246 init_router(&mut router).await;
2247
2248 let req = RouterRequest {
2250 id: RequestId::Number(1),
2251 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2252 uri: "file:///test.txt".to_string(),
2253 }),
2254 extensions: Extensions::new(),
2255 };
2256 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2257 assert!(router.is_subscribed("file:///test.txt"));
2258
2259 let req = RouterRequest {
2261 id: RequestId::Number(2),
2262 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2263 uri: "file:///test.txt".to_string(),
2264 }),
2265 extensions: Extensions::new(),
2266 };
2267
2268 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2269
2270 match resp.inner {
2271 Ok(McpResponse::UnsubscribeResource(_)) => {
2272 assert!(!router.is_subscribed("file:///test.txt"));
2274 }
2275 _ => panic!("Expected UnsubscribeResource response"),
2276 }
2277 }
2278
2279 #[tokio::test]
2280 async fn test_subscribe_nonexistent_resource() {
2281 let mut router = McpRouter::new();
2282 init_router(&mut router).await;
2283
2284 let req = RouterRequest {
2285 id: RequestId::Number(1),
2286 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2287 uri: "file:///nonexistent.txt".to_string(),
2288 }),
2289 extensions: Extensions::new(),
2290 };
2291
2292 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2293
2294 match resp.inner {
2295 Err(e) => {
2296 assert!(e.message.contains("not found"));
2297 }
2298 _ => panic!("Expected error response"),
2299 }
2300 }
2301
2302 #[tokio::test]
2303 async fn test_notify_resource_updated() {
2304 use crate::context::notification_channel;
2305 use crate::resource::ResourceBuilder;
2306
2307 let (tx, mut rx) = notification_channel(10);
2308
2309 let resource = ResourceBuilder::new("file:///test.txt")
2310 .name("Test File")
2311 .text("Hello");
2312
2313 let router = McpRouter::new()
2314 .resource(resource)
2315 .with_notification_sender(tx);
2316
2317 router.subscribe("file:///test.txt");
2319
2320 let sent = router.notify_resource_updated("file:///test.txt");
2322 assert!(sent);
2323
2324 let notification = rx.try_recv().unwrap();
2326 match notification {
2327 ServerNotification::ResourceUpdated { uri } => {
2328 assert_eq!(uri, "file:///test.txt");
2329 }
2330 _ => panic!("Expected ResourceUpdated notification"),
2331 }
2332 }
2333
2334 #[tokio::test]
2335 async fn test_notify_resource_updated_not_subscribed() {
2336 use crate::context::notification_channel;
2337 use crate::resource::ResourceBuilder;
2338
2339 let (tx, mut rx) = notification_channel(10);
2340
2341 let resource = ResourceBuilder::new("file:///test.txt")
2342 .name("Test File")
2343 .text("Hello");
2344
2345 let router = McpRouter::new()
2346 .resource(resource)
2347 .with_notification_sender(tx);
2348
2349 let sent = router.notify_resource_updated("file:///test.txt");
2351 assert!(!sent); assert!(rx.try_recv().is_err());
2355 }
2356
2357 #[tokio::test]
2358 async fn test_notify_resources_list_changed() {
2359 use crate::context::notification_channel;
2360
2361 let (tx, mut rx) = notification_channel(10);
2362 let router = McpRouter::new().with_notification_sender(tx);
2363
2364 let sent = router.notify_resources_list_changed();
2365 assert!(sent);
2366
2367 let notification = rx.try_recv().unwrap();
2368 match notification {
2369 ServerNotification::ResourcesListChanged => {}
2370 _ => panic!("Expected ResourcesListChanged notification"),
2371 }
2372 }
2373
2374 #[tokio::test]
2375 async fn test_subscribed_uris() {
2376 use crate::resource::ResourceBuilder;
2377
2378 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2379
2380 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2381
2382 let router = McpRouter::new().resource(resource1).resource(resource2);
2383
2384 router.subscribe("file:///a.txt");
2386 router.subscribe("file:///b.txt");
2387
2388 let uris = router.subscribed_uris();
2389 assert_eq!(uris.len(), 2);
2390 assert!(uris.contains(&"file:///a.txt".to_string()));
2391 assert!(uris.contains(&"file:///b.txt".to_string()));
2392 }
2393
2394 #[tokio::test]
2395 async fn test_subscription_capability_advertised() {
2396 use crate::resource::ResourceBuilder;
2397
2398 let resource = ResourceBuilder::new("file:///test.txt")
2399 .name("Test")
2400 .text("Hello");
2401
2402 let mut router = McpRouter::new().resource(resource);
2403
2404 let init_req = RouterRequest {
2406 id: RequestId::Number(0),
2407 inner: McpRequest::Initialize(InitializeParams {
2408 protocol_version: "2025-11-25".to_string(),
2409 capabilities: ClientCapabilities {
2410 roots: None,
2411 sampling: None,
2412 elicitation: None,
2413 },
2414 client_info: Implementation {
2415 name: "test".to_string(),
2416 version: "1.0".to_string(),
2417 ..Default::default()
2418 },
2419 }),
2420 extensions: Extensions::new(),
2421 };
2422 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2423
2424 match resp.inner {
2425 Ok(McpResponse::Initialize(result)) => {
2426 let resources_cap = result.capabilities.resources.unwrap();
2428 assert!(resources_cap.subscribe);
2429 }
2430 _ => panic!("Expected Initialize response"),
2431 }
2432 }
2433
2434 #[tokio::test]
2435 async fn test_completion_handler() {
2436 let router = McpRouter::new()
2437 .server_info("test", "1.0")
2438 .completion_handler(|params: CompleteParams| async move {
2439 let prefix = ¶ms.argument.value;
2441 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2442 .into_iter()
2443 .filter(|s| s.starts_with(prefix))
2444 .map(String::from)
2445 .collect();
2446 Ok(CompleteResult::new(suggestions))
2447 });
2448
2449 let init_req = RouterRequest {
2451 id: RequestId::Number(0),
2452 inner: McpRequest::Initialize(InitializeParams {
2453 protocol_version: "2025-11-25".to_string(),
2454 capabilities: ClientCapabilities::default(),
2455 client_info: Implementation {
2456 name: "test".to_string(),
2457 version: "1.0".to_string(),
2458 ..Default::default()
2459 },
2460 }),
2461 extensions: Extensions::new(),
2462 };
2463 let resp = router
2464 .clone()
2465 .ready()
2466 .await
2467 .unwrap()
2468 .call(init_req)
2469 .await
2470 .unwrap();
2471
2472 match resp.inner {
2474 Ok(McpResponse::Initialize(result)) => {
2475 assert!(result.capabilities.completions.is_some());
2476 }
2477 _ => panic!("Expected Initialize response"),
2478 }
2479
2480 router.handle_notification(McpNotification::Initialized);
2482
2483 let complete_req = RouterRequest {
2485 id: RequestId::Number(1),
2486 inner: McpRequest::Complete(CompleteParams {
2487 reference: CompletionReference::prompt("test-prompt"),
2488 argument: CompletionArgument::new("query", "al"),
2489 }),
2490 extensions: Extensions::new(),
2491 };
2492 let resp = router
2493 .clone()
2494 .ready()
2495 .await
2496 .unwrap()
2497 .call(complete_req)
2498 .await
2499 .unwrap();
2500
2501 match resp.inner {
2502 Ok(McpResponse::Complete(result)) => {
2503 assert_eq!(result.completion.values, vec!["alpha"]);
2504 }
2505 _ => panic!("Expected Complete response"),
2506 }
2507 }
2508
2509 #[tokio::test]
2510 async fn test_completion_without_handler_returns_empty() {
2511 let router = McpRouter::new().server_info("test", "1.0");
2512
2513 let init_req = RouterRequest {
2515 id: RequestId::Number(0),
2516 inner: McpRequest::Initialize(InitializeParams {
2517 protocol_version: "2025-11-25".to_string(),
2518 capabilities: ClientCapabilities::default(),
2519 client_info: Implementation {
2520 name: "test".to_string(),
2521 version: "1.0".to_string(),
2522 ..Default::default()
2523 },
2524 }),
2525 extensions: Extensions::new(),
2526 };
2527 let resp = router
2528 .clone()
2529 .ready()
2530 .await
2531 .unwrap()
2532 .call(init_req)
2533 .await
2534 .unwrap();
2535
2536 match resp.inner {
2538 Ok(McpResponse::Initialize(result)) => {
2539 assert!(result.capabilities.completions.is_none());
2540 }
2541 _ => panic!("Expected Initialize response"),
2542 }
2543
2544 router.handle_notification(McpNotification::Initialized);
2546
2547 let complete_req = RouterRequest {
2549 id: RequestId::Number(1),
2550 inner: McpRequest::Complete(CompleteParams {
2551 reference: CompletionReference::prompt("test-prompt"),
2552 argument: CompletionArgument::new("query", "al"),
2553 }),
2554 extensions: Extensions::new(),
2555 };
2556 let resp = router
2557 .clone()
2558 .ready()
2559 .await
2560 .unwrap()
2561 .call(complete_req)
2562 .await
2563 .unwrap();
2564
2565 match resp.inner {
2566 Ok(McpResponse::Complete(result)) => {
2567 assert!(result.completion.values.is_empty());
2568 }
2569 _ => panic!("Expected Complete response"),
2570 }
2571 }
2572}