1use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17 ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::prompt::Prompt;
21use crate::protocol::*;
22use crate::resource::{Resource, ResourceTemplate};
23use crate::session::SessionState;
24use crate::tool::Tool;
25
/// Type-erased handler for completion requests (`McpRequest::Complete`).
///
/// It receives the request's [`CompleteParams`] and returns a boxed, `Send` future that
/// resolves to the [`CompleteResult`] sent back to the client. The handler lives behind an
/// `Arc`, so cloning the router's state shares the handler instead of copying it.
pub type CompletionHandler = Arc<
    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
        + Send
        + Sync,
>;
32
/// Routes MCP requests to registered tools, resources, resource templates and prompts.
///
/// The registry and server configuration live behind an `Arc`, so cloning a router copies
/// that pointer rather than the registry. The builder methods go through `Arc::make_mut`,
/// which copies the registry only when it is modified while shared.
#[derive(Clone)]
pub struct McpRouter {
    /// Shared registry and server configuration.
    inner: Arc<McpRouterInner>,
    /// Lifecycle state of the session (initialization phase).
    session: SessionState,
}
62
63impl std::fmt::Debug for McpRouter {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("McpRouter")
66 .field("server_name", &self.inner.server_name)
67 .field("server_version", &self.inner.server_version)
68 .field("tools_count", &self.inner.tools.len())
69 .field("resources_count", &self.inner.resources.len())
70 .field("prompts_count", &self.inner.prompts.len())
71 .field("session_phase", &self.session.phase())
72 .finish()
73 }
74}
75
/// State shared by all clones of a [`McpRouter`].
///
/// The registries (`tools`, `resources`, `resource_templates`, `prompts`) are only written
/// by the router's builder methods through `Arc::make_mut`. The `Arc<RwLock<..>>` fields
/// survive that copy-on-write by pointer, so in-flight requests and subscriptions are
/// shared by every clone.
#[derive(Clone)]
struct McpRouterInner {
    /// Server name reported in the initialize response.
    server_name: String,
    /// Server version reported in the initialize response.
    server_version: String,
    /// Optional instructions returned to the client in the initialize response.
    instructions: Option<String>,
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, Arc<Tool>>,
    /// Static resources, keyed by their exact URI.
    resources: HashMap<String, Arc<Resource>>,
    /// URI templates, tried in registration order when no static resource matches.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    /// Registered prompts, keyed by prompt name.
    prompts: HashMap<String, Arc<Prompt>>,
    /// Cancellation tokens of requests currently being handled, keyed by request id.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    /// Channel for server-to-client notifications; `None` disables them.
    notification_tx: Option<NotificationSender>,
    /// Handle that request contexts use to send requests to the client.
    client_requester: Option<ClientRequesterHandle>,
    /// Store backing asynchronous tasks (`EnqueueTask` and related requests).
    task_store: TaskStore,
    /// URIs the client has subscribed to for update notifications.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    /// Optional handler for completion requests.
    completion_handler: Option<CompletionHandler>,
}
100
101impl McpRouter {
102 pub fn new() -> Self {
104 Self {
105 inner: Arc::new(McpRouterInner {
106 server_name: "tower-mcp".to_string(),
107 server_version: env!("CARGO_PKG_VERSION").to_string(),
108 instructions: None,
109 tools: HashMap::new(),
110 resources: HashMap::new(),
111 resource_templates: Vec::new(),
112 prompts: HashMap::new(),
113 in_flight: Arc::new(RwLock::new(HashMap::new())),
114 notification_tx: None,
115 client_requester: None,
116 task_store: TaskStore::new(),
117 subscriptions: Arc::new(RwLock::new(HashSet::new())),
118 completion_handler: None,
119 }),
120 session: SessionState::new(),
121 }
122 }
123
124 pub fn task_store(&self) -> &TaskStore {
126 &self.inner.task_store
127 }
128
129 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
133 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
134 self
135 }
136
137 pub fn notification_sender(&self) -> Option<&NotificationSender> {
139 self.inner.notification_tx.as_ref()
140 }
141
142 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
147 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
148 self
149 }
150
151 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
153 self.inner.client_requester.as_ref()
154 }
155
156 pub fn create_context(
161 &self,
162 request_id: RequestId,
163 progress_token: Option<ProgressToken>,
164 ) -> RequestContext {
165 let ctx = RequestContext::new(request_id.clone());
166
167 let ctx = if let Some(token) = progress_token {
169 ctx.with_progress_token(token)
170 } else {
171 ctx
172 };
173
174 let ctx = if let Some(tx) = &self.inner.notification_tx {
176 ctx.with_notification_sender(tx.clone())
177 } else {
178 ctx
179 };
180
181 let ctx = if let Some(requester) = &self.inner.client_requester {
183 ctx.with_client_requester(requester.clone())
184 } else {
185 ctx
186 };
187
188 let token = ctx.cancellation_token();
190 if let Ok(mut in_flight) = self.inner.in_flight.write() {
191 in_flight.insert(request_id, token);
192 }
193
194 ctx
195 }
196
197 pub fn complete_request(&self, request_id: &RequestId) {
199 if let Ok(mut in_flight) = self.inner.in_flight.write() {
200 in_flight.remove(request_id);
201 }
202 }
203
204 fn cancel_request(&self, request_id: &RequestId) -> bool {
206 if let Ok(in_flight) = self.inner.in_flight.read()
207 && let Some(token) = in_flight.get(request_id)
208 {
209 token.cancel();
210 return true;
211 }
212 false
213 }
214
215 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
217 let inner = Arc::make_mut(&mut self.inner);
218 inner.server_name = name.into();
219 inner.server_version = version.into();
220 self
221 }
222
223 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
225 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
226 self
227 }
228
229 pub fn tool(mut self, tool: Tool) -> Self {
231 Arc::make_mut(&mut self.inner)
232 .tools
233 .insert(tool.name.clone(), Arc::new(tool));
234 self
235 }
236
237 pub fn resource(mut self, resource: Resource) -> Self {
239 Arc::make_mut(&mut self.inner)
240 .resources
241 .insert(resource.uri.clone(), Arc::new(resource));
242 self
243 }
244
245 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
276 Arc::make_mut(&mut self.inner)
277 .resource_templates
278 .push(Arc::new(template));
279 self
280 }
281
282 pub fn prompt(mut self, prompt: Prompt) -> Self {
284 Arc::make_mut(&mut self.inner)
285 .prompts
286 .insert(prompt.name.clone(), Arc::new(prompt));
287 self
288 }
289
290 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
317 where
318 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
319 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
320 {
321 Arc::make_mut(&mut self.inner).completion_handler =
322 Some(Arc::new(move |params| Box::pin(handler(params))));
323 self
324 }
325
326 pub fn session(&self) -> &SessionState {
328 &self.session
329 }
330
331 pub fn log(&self, params: LoggingMessageParams) -> bool {
353 if let Some(tx) = &self.inner.notification_tx
354 && tx.try_send(ServerNotification::LogMessage(params)).is_ok()
355 {
356 return true;
357 }
358 false
359 }
360
361 pub fn log_info(&self, message: &str) -> bool {
365 self.log(
366 LoggingMessageParams::new(LogLevel::Info)
367 .with_data(serde_json::json!({ "message": message })),
368 )
369 }
370
371 pub fn log_warning(&self, message: &str) -> bool {
373 self.log(
374 LoggingMessageParams::new(LogLevel::Warning)
375 .with_data(serde_json::json!({ "message": message })),
376 )
377 }
378
379 pub fn log_error(&self, message: &str) -> bool {
381 self.log(
382 LoggingMessageParams::new(LogLevel::Error)
383 .with_data(serde_json::json!({ "message": message })),
384 )
385 }
386
387 pub fn log_debug(&self, message: &str) -> bool {
389 self.log(
390 LoggingMessageParams::new(LogLevel::Debug)
391 .with_data(serde_json::json!({ "message": message })),
392 )
393 }
394
395 pub fn is_subscribed(&self, uri: &str) -> bool {
397 if let Ok(subs) = self.inner.subscriptions.read() {
398 return subs.contains(uri);
399 }
400 false
401 }
402
403 pub fn subscribed_uris(&self) -> Vec<String> {
405 if let Ok(subs) = self.inner.subscriptions.read() {
406 return subs.iter().cloned().collect();
407 }
408 Vec::new()
409 }
410
411 fn subscribe(&self, uri: &str) -> bool {
413 if let Ok(mut subs) = self.inner.subscriptions.write() {
414 return subs.insert(uri.to_string());
415 }
416 false
417 }
418
419 fn unsubscribe(&self, uri: &str) -> bool {
421 if let Ok(mut subs) = self.inner.subscriptions.write() {
422 return subs.remove(uri);
423 }
424 false
425 }
426
427 pub fn notify_resource_updated(&self, uri: &str) -> bool {
432 if !self.is_subscribed(uri) {
434 return false;
435 }
436
437 if let Some(tx) = &self.inner.notification_tx
438 && tx
439 .try_send(ServerNotification::ResourceUpdated {
440 uri: uri.to_string(),
441 })
442 .is_ok()
443 {
444 return true;
445 }
446 false
447 }
448
449 pub fn notify_resources_list_changed(&self) -> bool {
453 if let Some(tx) = &self.inner.notification_tx
454 && tx
455 .try_send(ServerNotification::ResourcesListChanged)
456 .is_ok()
457 {
458 return true;
459 }
460 false
461 }
462
463 fn capabilities(&self) -> ServerCapabilities {
465 let has_resources =
466 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
467
468 ServerCapabilities {
469 tools: if self.inner.tools.is_empty() {
470 None
471 } else {
472 Some(ToolsCapability::default())
473 },
474 resources: if has_resources {
475 Some(ResourcesCapability {
476 subscribe: true,
477 ..Default::default()
478 })
479 } else {
480 None
481 },
482 prompts: if self.inner.prompts.is_empty() {
483 None
484 } else {
485 Some(PromptsCapability::default())
486 },
487 logging: if self.inner.notification_tx.is_some() {
489 Some(LoggingCapability::default())
490 } else {
491 None
492 },
493 tasks: Some(TasksCapability::default()),
495 completions: if self.inner.completion_handler.is_some() {
497 Some(CompletionsCapability::default())
498 } else {
499 None
500 },
501 }
502 }
503
504 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
506 let method = request.method_name();
508 if !self.session.is_request_allowed(method) {
509 tracing::warn!(
510 method = %method,
511 phase = ?self.session.phase(),
512 "Request rejected: session not initialized"
513 );
514 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
515 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
516 method
517 ))));
518 }
519
520 match request {
521 McpRequest::Initialize(params) => {
522 tracing::info!(
523 client = %params.client_info.name,
524 version = %params.client_info.version,
525 "Client initializing"
526 );
527
528 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
531 .contains(¶ms.protocol_version.as_str())
532 {
533 params.protocol_version
534 } else {
535 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
536 };
537
538 self.session.mark_initializing();
540
541 Ok(McpResponse::Initialize(InitializeResult {
542 protocol_version,
543 capabilities: self.capabilities(),
544 server_info: Implementation {
545 name: self.inner.server_name.clone(),
546 version: self.inner.server_version.clone(),
547 },
548 instructions: self.inner.instructions.clone(),
549 }))
550 }
551
552 McpRequest::ListTools(_params) => {
553 let tools: Vec<ToolDefinition> =
554 self.inner.tools.values().map(|t| t.definition()).collect();
555
556 Ok(McpResponse::ListTools(ListToolsResult {
557 tools,
558 next_cursor: None,
559 }))
560 }
561
562 McpRequest::CallTool(params) => {
563 let tool =
564 self.inner.tools.get(¶ms.name).ok_or_else(|| {
565 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
566 })?;
567
568 let progress_token = params.meta.and_then(|m| m.progress_token);
570 let ctx = self.create_context(request_id, progress_token);
571
572 tracing::debug!(tool = %params.name, "Calling tool");
573 let result = tool.call_with_context(ctx, params.arguments).await?;
574
575 Ok(McpResponse::CallTool(result))
576 }
577
578 McpRequest::ListResources(_params) => {
579 let resources: Vec<ResourceDefinition> = self
580 .inner
581 .resources
582 .values()
583 .map(|r| r.definition())
584 .collect();
585
586 Ok(McpResponse::ListResources(ListResourcesResult {
587 resources,
588 next_cursor: None,
589 }))
590 }
591
592 McpRequest::ListResourceTemplates(_params) => {
593 let resource_templates: Vec<ResourceTemplateDefinition> = self
594 .inner
595 .resource_templates
596 .iter()
597 .map(|t| t.definition())
598 .collect();
599
600 Ok(McpResponse::ListResourceTemplates(
601 ListResourceTemplatesResult {
602 resource_templates,
603 next_cursor: None,
604 },
605 ))
606 }
607
608 McpRequest::ReadResource(params) => {
609 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
611 tracing::debug!(uri = %params.uri, "Reading static resource");
612 let result = resource.read().await?;
613 return Ok(McpResponse::ReadResource(result));
614 }
615
616 for template in &self.inner.resource_templates {
618 if let Some(variables) = template.match_uri(¶ms.uri) {
619 tracing::debug!(
620 uri = %params.uri,
621 template = %template.uri_template,
622 "Reading resource via template"
623 );
624 let result = template.read(¶ms.uri, variables).await?;
625 return Ok(McpResponse::ReadResource(result));
626 }
627 }
628
629 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
631 ¶ms.uri,
632 )))
633 }
634
635 McpRequest::SubscribeResource(params) => {
636 if !self.inner.resources.contains_key(¶ms.uri) {
638 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
639 ¶ms.uri,
640 )));
641 }
642
643 tracing::debug!(uri = %params.uri, "Subscribing to resource");
644 self.subscribe(¶ms.uri);
645
646 Ok(McpResponse::SubscribeResource(EmptyResult {}))
647 }
648
649 McpRequest::UnsubscribeResource(params) => {
650 if !self.inner.resources.contains_key(¶ms.uri) {
652 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
653 ¶ms.uri,
654 )));
655 }
656
657 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
658 self.unsubscribe(¶ms.uri);
659
660 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
661 }
662
663 McpRequest::ListPrompts(_params) => {
664 let prompts: Vec<PromptDefinition> = self
665 .inner
666 .prompts
667 .values()
668 .map(|p| p.definition())
669 .collect();
670
671 Ok(McpResponse::ListPrompts(ListPromptsResult {
672 prompts,
673 next_cursor: None,
674 }))
675 }
676
677 McpRequest::GetPrompt(params) => {
678 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
679 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
680 "Prompt not found: {}",
681 params.name
682 )))
683 })?;
684
685 tracing::debug!(name = %params.name, "Getting prompt");
686 let result = prompt.get(params.arguments).await?;
687
688 Ok(McpResponse::GetPrompt(result))
689 }
690
691 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
692
693 McpRequest::EnqueueTask(params) => {
694 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
696 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
697 "Tool not found: {}",
698 params.tool_name
699 )))
700 })?;
701
702 let (task_id, cancellation_token) = self.inner.task_store.create_task(
704 ¶ms.tool_name,
705 params.arguments.clone(),
706 params.ttl,
707 );
708
709 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
710
711 let ctx = self.create_context(request_id, None);
713
714 let task_store = self.inner.task_store.clone();
716 let tool = tool.clone();
717 let arguments = params.arguments;
718 let task_id_clone = task_id.clone();
719
720 tokio::spawn(async move {
721 if cancellation_token.is_cancelled() {
723 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
724 return;
725 }
726
727 match tool.call_with_context(ctx, arguments).await {
729 Ok(result) => {
730 if cancellation_token.is_cancelled() {
731 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
732 } else {
733 task_store.complete_task(&task_id_clone, result);
734 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
735 }
736 }
737 Err(e) => {
738 task_store.fail_task(&task_id_clone, &e.to_string());
739 tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
740 }
741 }
742 });
743
744 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
745 task_id,
746 status: TaskStatus::Working,
747 poll_interval: Some(2),
748 }))
749 }
750
751 McpRequest::ListTasks(params) => {
752 let tasks = self.inner.task_store.list_tasks(params.status);
753
754 Ok(McpResponse::ListTasks(ListTasksResult {
755 tasks,
756 next_cursor: None,
757 }))
758 }
759
760 McpRequest::GetTaskInfo(params) => {
761 let task = self
762 .inner
763 .task_store
764 .get_task(¶ms.task_id)
765 .ok_or_else(|| {
766 Error::JsonRpc(JsonRpcError::invalid_params(format!(
767 "Task not found: {}",
768 params.task_id
769 )))
770 })?;
771
772 Ok(McpResponse::GetTaskInfo(task))
773 }
774
775 McpRequest::GetTaskResult(params) => {
776 let (status, result, error) = self
777 .inner
778 .task_store
779 .get_task_full(¶ms.task_id)
780 .ok_or_else(|| {
781 Error::JsonRpc(JsonRpcError::invalid_params(format!(
782 "Task not found: {}",
783 params.task_id
784 )))
785 })?;
786
787 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
788 task_id: params.task_id,
789 status,
790 result,
791 error,
792 }))
793 }
794
795 McpRequest::CancelTask(params) => {
796 let status = self
797 .inner
798 .task_store
799 .cancel_task(¶ms.task_id, params.reason.as_deref())
800 .ok_or_else(|| {
801 Error::JsonRpc(JsonRpcError::invalid_params(format!(
802 "Task not found: {}",
803 params.task_id
804 )))
805 })?;
806
807 let cancelled = status == TaskStatus::Cancelled;
808
809 Ok(McpResponse::CancelTask(CancelTaskResult {
810 cancelled,
811 status,
812 }))
813 }
814
815 McpRequest::SetLoggingLevel(params) => {
816 tracing::debug!(level = ?params.level, "Client set logging level");
820 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
821 }
822
823 McpRequest::Complete(params) => {
824 tracing::debug!(
825 reference = ?params.reference,
826 argument = %params.argument.name,
827 "Completion request"
828 );
829
830 if let Some(ref handler) = self.inner.completion_handler {
832 let result = handler(params).await?;
833 Ok(McpResponse::Complete(result))
834 } else {
835 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
837 }
838 }
839
840 McpRequest::Unknown { method, .. } => {
841 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
842 }
843 }
844 }
845
846 pub fn handle_notification(&self, notification: McpNotification) {
848 match notification {
849 McpNotification::Initialized => {
850 if self.session.mark_initialized() {
851 tracing::info!("Session initialized, entering operation phase");
852 } else {
853 tracing::warn!(
854 "Received initialized notification in unexpected state: {:?}",
855 self.session.phase()
856 );
857 }
858 }
859 McpNotification::Cancelled(params) => {
860 if self.cancel_request(¶ms.request_id) {
861 tracing::info!(
862 request_id = ?params.request_id,
863 reason = ?params.reason,
864 "Request cancelled"
865 );
866 } else {
867 tracing::debug!(
868 request_id = ?params.request_id,
869 reason = ?params.reason,
870 "Cancellation requested for unknown request"
871 );
872 }
873 }
874 McpNotification::Progress(params) => {
875 tracing::trace!(
876 token = ?params.progress_token,
877 progress = params.progress,
878 total = ?params.total,
879 "Progress notification"
880 );
881 }
883 McpNotification::RootsListChanged => {
884 tracing::info!("Client roots list changed");
885 }
888 McpNotification::Unknown { method, .. } => {
889 tracing::debug!(method = %method, "Unknown notification received");
890 }
891 }
892 }
893}
894
895impl Default for McpRouter {
896 fn default() -> Self {
897 Self::new()
898 }
899}
900
/// A single MCP request addressed to the router, paired with its JSON-RPC id.
#[derive(Debug)]
pub struct RouterRequest {
    /// JSON-RPC request id, echoed back in the matching [`RouterResponse`].
    pub id: RequestId,
    /// The parsed MCP request.
    pub inner: McpRequest,
}
911
/// The router's answer to a [`RouterRequest`].
#[derive(Debug)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// The MCP response on success, or the JSON-RPC error to report to the client.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
918
919impl RouterResponse {
920 pub fn into_jsonrpc(self) -> JsonRpcResponse {
922 match self.inner {
923 Ok(response) => match serde_json::to_value(response) {
924 Ok(result) => JsonRpcResponse::result(self.id, result),
925 Err(e) => {
926 tracing::error!(error = %e, "Failed to serialize response");
927 JsonRpcResponse::error(
928 Some(self.id),
929 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
930 )
931 }
932 },
933 Err(error) => JsonRpcResponse::error(Some(self.id), error),
934 }
935 }
936}
937
938impl Service<RouterRequest> for McpRouter {
939 type Response = RouterResponse;
940 type Error = std::convert::Infallible; type Future =
942 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
943
944 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
945 Poll::Ready(Ok(()))
946 }
947
948 fn call(&mut self, req: RouterRequest) -> Self::Future {
949 let router = self.clone();
950 let request_id = req.id.clone();
951 Box::pin(async move {
952 let result = router.handle(req.id, req.inner).await;
953 router.complete_request(&request_id);
955 Ok(RouterResponse {
956 id: request_id,
957 inner: result.map_err(|e| match e {
958 Error::JsonRpc(err) => err,
959 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
960 e => JsonRpcError::internal_error(e.to_string()),
961 }),
962 })
963 })
964 }
965}
966
967#[cfg(test)]
968mod tests {
969 use super::*;
970 use crate::jsonrpc::JsonRpcService;
971 use crate::tool::ToolBuilder;
972 use schemars::JsonSchema;
973 use serde::Deserialize;
974 use tower::ServiceExt;
975
/// Arguments of the test tools: two integers (several tests sum them).
#[derive(Debug, Deserialize, JsonSchema)]
struct AddInput {
    a: i64,
    b: i64,
}
981
982 async fn init_router(router: &mut McpRouter) {
984 let init_req = RouterRequest {
986 id: RequestId::Number(0),
987 inner: McpRequest::Initialize(InitializeParams {
988 protocol_version: "2025-03-26".to_string(),
989 capabilities: ClientCapabilities {
990 roots: None,
991 sampling: None,
992 elicitation: None,
993 },
994 client_info: Implementation {
995 name: "test".to_string(),
996 version: "1.0".to_string(),
997 },
998 }),
999 };
1000 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1001 router.handle_notification(McpNotification::Initialized);
1003 }
1004
1005 #[tokio::test]
1006 async fn test_router_list_tools() {
1007 let add_tool = ToolBuilder::new("add")
1008 .description("Add two numbers")
1009 .handler(|input: AddInput| async move {
1010 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1011 })
1012 .build()
1013 .expect("valid tool name");
1014
1015 let mut router = McpRouter::new().tool(add_tool);
1016
1017 init_router(&mut router).await;
1019
1020 let req = RouterRequest {
1021 id: RequestId::Number(1),
1022 inner: McpRequest::ListTools(ListToolsParams::default()),
1023 };
1024
1025 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1026
1027 match resp.inner {
1028 Ok(McpResponse::ListTools(result)) => {
1029 assert_eq!(result.tools.len(), 1);
1030 assert_eq!(result.tools[0].name, "add");
1031 }
1032 _ => panic!("Expected ListTools response"),
1033 }
1034 }
1035
1036 #[tokio::test]
1037 async fn test_router_call_tool() {
1038 let add_tool = ToolBuilder::new("add")
1039 .description("Add two numbers")
1040 .handler(|input: AddInput| async move {
1041 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1042 })
1043 .build()
1044 .expect("valid tool name");
1045
1046 let mut router = McpRouter::new().tool(add_tool);
1047
1048 init_router(&mut router).await;
1050
1051 let req = RouterRequest {
1052 id: RequestId::Number(1),
1053 inner: McpRequest::CallTool(CallToolParams {
1054 name: "add".to_string(),
1055 arguments: serde_json::json!({"a": 2, "b": 3}),
1056 meta: None,
1057 }),
1058 };
1059
1060 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1061
1062 match resp.inner {
1063 Ok(McpResponse::CallTool(result)) => {
1064 assert!(!result.is_error);
1065 match &result.content[0] {
1067 Content::Text { text, .. } => assert_eq!(text, "5"),
1068 _ => panic!("Expected text content"),
1069 }
1070 }
1071 _ => panic!("Expected CallTool response"),
1072 }
1073 }
1074
1075 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1077 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1078 "protocolVersion": "2025-03-26",
1079 "capabilities": {},
1080 "clientInfo": { "name": "test", "version": "1.0" }
1081 }));
1082 let _ = service.call_single(init_req).await.unwrap();
1083 router.handle_notification(McpNotification::Initialized);
1084 }
1085
1086 #[tokio::test]
1087 async fn test_jsonrpc_service() {
1088 let add_tool = ToolBuilder::new("add")
1089 .description("Add two numbers")
1090 .handler(|input: AddInput| async move {
1091 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1092 })
1093 .build()
1094 .expect("valid tool name");
1095
1096 let router = McpRouter::new().tool(add_tool);
1097 let mut service = JsonRpcService::new(router.clone());
1098
1099 init_jsonrpc_service(&mut service, &router).await;
1101
1102 let req = JsonRpcRequest::new(1, "tools/list");
1103
1104 let resp = service.call_single(req).await.unwrap();
1105
1106 match resp {
1107 JsonRpcResponse::Result(r) => {
1108 assert_eq!(r.id, RequestId::Number(1));
1109 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1110 assert_eq!(tools.len(), 1);
1111 }
1112 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1113 }
1114 }
1115
1116 #[tokio::test]
1117 async fn test_batch_request() {
1118 let add_tool = ToolBuilder::new("add")
1119 .description("Add two numbers")
1120 .handler(|input: AddInput| async move {
1121 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1122 })
1123 .build()
1124 .expect("valid tool name");
1125
1126 let router = McpRouter::new().tool(add_tool);
1127 let mut service = JsonRpcService::new(router.clone());
1128
1129 init_jsonrpc_service(&mut service, &router).await;
1131
1132 let requests = vec![
1134 JsonRpcRequest::new(1, "tools/list"),
1135 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1136 "name": "add",
1137 "arguments": {"a": 10, "b": 20}
1138 })),
1139 JsonRpcRequest::new(3, "ping"),
1140 ];
1141
1142 let responses = service.call_batch(requests).await.unwrap();
1143
1144 assert_eq!(responses.len(), 3);
1145
1146 match &responses[0] {
1148 JsonRpcResponse::Result(r) => {
1149 assert_eq!(r.id, RequestId::Number(1));
1150 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1151 assert_eq!(tools.len(), 1);
1152 }
1153 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1154 }
1155
1156 match &responses[1] {
1158 JsonRpcResponse::Result(r) => {
1159 assert_eq!(r.id, RequestId::Number(2));
1160 let content = r.result.get("content").unwrap().as_array().unwrap();
1161 let text = content[0].get("text").unwrap().as_str().unwrap();
1162 assert_eq!(text, "30");
1163 }
1164 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1165 }
1166
1167 match &responses[2] {
1169 JsonRpcResponse::Result(r) => {
1170 assert_eq!(r.id, RequestId::Number(3));
1171 }
1172 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1173 }
1174 }
1175
1176 #[tokio::test]
1177 async fn test_empty_batch_error() {
1178 let router = McpRouter::new();
1179 let mut service = JsonRpcService::new(router);
1180
1181 let result = service.call_batch(vec![]).await;
1182 assert!(result.is_err());
1183 }
1184
1185 #[tokio::test]
1190 async fn test_progress_token_extraction() {
1191 use crate::context::{RequestContext, ServerNotification, notification_channel};
1192 use crate::protocol::ProgressToken;
1193 use std::sync::Arc;
1194 use std::sync::atomic::{AtomicBool, Ordering};
1195
1196 let progress_reported = Arc::new(AtomicBool::new(false));
1198 let progress_ref = progress_reported.clone();
1199
1200 let tool = ToolBuilder::new("progress_tool")
1202 .description("Tool that reports progress")
1203 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1204 let reported = progress_ref.clone();
1205 async move {
1206 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1208 .await;
1209 reported.store(true, Ordering::SeqCst);
1210 Ok(CallToolResult::text("done"))
1211 }
1212 })
1213 .build()
1214 .expect("valid tool name");
1215
1216 let (tx, mut rx) = notification_channel(10);
1218 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1219 let mut service = JsonRpcService::new(router.clone());
1220
1221 init_jsonrpc_service(&mut service, &router).await;
1223
1224 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1226 "name": "progress_tool",
1227 "arguments": {"a": 1, "b": 2},
1228 "_meta": {
1229 "progressToken": "test-token-123"
1230 }
1231 }));
1232
1233 let resp = service.call_single(req).await.unwrap();
1234
1235 match resp {
1237 JsonRpcResponse::Result(_) => {}
1238 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1239 }
1240
1241 assert!(progress_reported.load(Ordering::SeqCst));
1243
1244 let notification = rx.try_recv().expect("Expected progress notification");
1246 match notification {
1247 ServerNotification::Progress(params) => {
1248 assert_eq!(
1249 params.progress_token,
1250 ProgressToken::String("test-token-123".to_string())
1251 );
1252 assert_eq!(params.progress, 50.0);
1253 assert_eq!(params.total, Some(100.0));
1254 assert_eq!(params.message.as_deref(), Some("Halfway"));
1255 }
1256 _ => panic!("Expected Progress notification"),
1257 }
1258 }
1259
1260 #[tokio::test]
1261 async fn test_tool_call_without_progress_token() {
1262 use crate::context::{RequestContext, notification_channel};
1263 use std::sync::Arc;
1264 use std::sync::atomic::{AtomicBool, Ordering};
1265
1266 let progress_attempted = Arc::new(AtomicBool::new(false));
1267 let progress_ref = progress_attempted.clone();
1268
1269 let tool = ToolBuilder::new("no_token_tool")
1270 .description("Tool that tries to report progress without token")
1271 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1272 let attempted = progress_ref.clone();
1273 async move {
1274 ctx.report_progress(50.0, Some(100.0), None).await;
1276 attempted.store(true, Ordering::SeqCst);
1277 Ok(CallToolResult::text("done"))
1278 }
1279 })
1280 .build()
1281 .expect("valid tool name");
1282
1283 let (tx, mut rx) = notification_channel(10);
1284 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1285 let mut service = JsonRpcService::new(router.clone());
1286
1287 init_jsonrpc_service(&mut service, &router).await;
1288
1289 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1291 "name": "no_token_tool",
1292 "arguments": {"a": 1, "b": 2}
1293 }));
1294
1295 let resp = service.call_single(req).await.unwrap();
1296 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1297
1298 assert!(progress_attempted.load(Ordering::SeqCst));
1300
1301 assert!(rx.try_recv().is_err());
1303 }
1304
1305 #[tokio::test]
1306 async fn test_batch_errors_returned_not_dropped() {
1307 let add_tool = ToolBuilder::new("add")
1308 .description("Add two numbers")
1309 .handler(|input: AddInput| async move {
1310 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1311 })
1312 .build()
1313 .expect("valid tool name");
1314
1315 let router = McpRouter::new().tool(add_tool);
1316 let mut service = JsonRpcService::new(router.clone());
1317
1318 init_jsonrpc_service(&mut service, &router).await;
1319
1320 let requests = vec![
1322 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1324 "name": "add",
1325 "arguments": {"a": 10, "b": 20}
1326 })),
1327 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1329 "name": "nonexistent_tool",
1330 "arguments": {}
1331 })),
1332 JsonRpcRequest::new(3, "ping"),
1334 ];
1335
1336 let responses = service.call_batch(requests).await.unwrap();
1337
1338 assert_eq!(responses.len(), 3);
1340
1341 match &responses[0] {
1343 JsonRpcResponse::Result(r) => {
1344 assert_eq!(r.id, RequestId::Number(1));
1345 }
1346 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1347 }
1348
1349 match &responses[1] {
1351 JsonRpcResponse::Error(e) => {
1352 assert_eq!(e.id, Some(RequestId::Number(2)));
1353 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1355 }
1356 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1357 }
1358
1359 match &responses[2] {
1361 JsonRpcResponse::Result(r) => {
1362 assert_eq!(r.id, RequestId::Number(3));
1363 }
1364 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1365 }
1366 }
1367
1368 #[tokio::test]
1373 async fn test_list_resource_templates() {
1374 use crate::resource::ResourceTemplateBuilder;
1375 use std::collections::HashMap;
1376
1377 let template = ResourceTemplateBuilder::new("file:///{path}")
1378 .name("Project Files")
1379 .description("Access project files")
1380 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1381 Ok(ReadResourceResult {
1382 contents: vec![ResourceContent {
1383 uri,
1384 mime_type: None,
1385 text: None,
1386 blob: None,
1387 }],
1388 })
1389 });
1390
1391 let mut router = McpRouter::new().resource_template(template);
1392
1393 init_router(&mut router).await;
1395
1396 let req = RouterRequest {
1397 id: RequestId::Number(1),
1398 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1399 };
1400
1401 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1402
1403 match resp.inner {
1404 Ok(McpResponse::ListResourceTemplates(result)) => {
1405 assert_eq!(result.resource_templates.len(), 1);
1406 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1407 assert_eq!(result.resource_templates[0].name, "Project Files");
1408 }
1409 _ => panic!("Expected ListResourceTemplates response"),
1410 }
1411 }
1412
1413 #[tokio::test]
1414 async fn test_read_resource_via_template() {
1415 use crate::resource::ResourceTemplateBuilder;
1416 use std::collections::HashMap;
1417
1418 let template = ResourceTemplateBuilder::new("db://users/{id}")
1419 .name("User Records")
1420 .handler(|uri: String, vars: HashMap<String, String>| async move {
1421 let id = vars.get("id").unwrap().clone();
1422 Ok(ReadResourceResult {
1423 contents: vec![ResourceContent {
1424 uri,
1425 mime_type: Some("application/json".to_string()),
1426 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1427 blob: None,
1428 }],
1429 })
1430 });
1431
1432 let mut router = McpRouter::new().resource_template(template);
1433
1434 init_router(&mut router).await;
1436
1437 let req = RouterRequest {
1439 id: RequestId::Number(1),
1440 inner: McpRequest::ReadResource(ReadResourceParams {
1441 uri: "db://users/123".to_string(),
1442 }),
1443 };
1444
1445 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1446
1447 match resp.inner {
1448 Ok(McpResponse::ReadResource(result)) => {
1449 assert_eq!(result.contents.len(), 1);
1450 assert_eq!(result.contents[0].uri, "db://users/123");
1451 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1452 }
1453 _ => panic!("Expected ReadResource response"),
1454 }
1455 }
1456
1457 #[tokio::test]
1458 async fn test_static_resource_takes_precedence_over_template() {
1459 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1460 use std::collections::HashMap;
1461
1462 let template = ResourceTemplateBuilder::new("file:///{path}")
1464 .name("Files Template")
1465 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1466 Ok(ReadResourceResult {
1467 contents: vec![ResourceContent {
1468 uri,
1469 mime_type: None,
1470 text: Some("from template".to_string()),
1471 blob: None,
1472 }],
1473 })
1474 });
1475
1476 let static_resource = ResourceBuilder::new("file:///README.md")
1478 .name("README")
1479 .text("from static resource");
1480
1481 let mut router = McpRouter::new()
1482 .resource_template(template)
1483 .resource(static_resource);
1484
1485 init_router(&mut router).await;
1487
1488 let req = RouterRequest {
1490 id: RequestId::Number(1),
1491 inner: McpRequest::ReadResource(ReadResourceParams {
1492 uri: "file:///README.md".to_string(),
1493 }),
1494 };
1495
1496 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1497
1498 match resp.inner {
1499 Ok(McpResponse::ReadResource(result)) => {
1500 assert_eq!(
1502 result.contents[0].text.as_deref(),
1503 Some("from static resource")
1504 );
1505 }
1506 _ => panic!("Expected ReadResource response"),
1507 }
1508 }
1509
1510 #[tokio::test]
1511 async fn test_resource_not_found_when_no_match() {
1512 use crate::resource::ResourceTemplateBuilder;
1513 use std::collections::HashMap;
1514
1515 let template = ResourceTemplateBuilder::new("db://users/{id}")
1516 .name("Users")
1517 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1518 Ok(ReadResourceResult {
1519 contents: vec![ResourceContent {
1520 uri,
1521 mime_type: None,
1522 text: None,
1523 blob: None,
1524 }],
1525 })
1526 });
1527
1528 let mut router = McpRouter::new().resource_template(template);
1529
1530 init_router(&mut router).await;
1532
1533 let req = RouterRequest {
1535 id: RequestId::Number(1),
1536 inner: McpRequest::ReadResource(ReadResourceParams {
1537 uri: "db://posts/123".to_string(),
1538 }),
1539 };
1540
1541 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1542
1543 match resp.inner {
1544 Err(err) => {
1545 assert!(err.message.contains("not found"));
1546 }
1547 Ok(_) => panic!("Expected error for non-matching URI"),
1548 }
1549 }
1550
1551 #[tokio::test]
1552 async fn test_capabilities_include_resources_with_only_templates() {
1553 use crate::resource::ResourceTemplateBuilder;
1554 use std::collections::HashMap;
1555
1556 let template = ResourceTemplateBuilder::new("file:///{path}")
1557 .name("Files")
1558 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1559 Ok(ReadResourceResult {
1560 contents: vec![ResourceContent {
1561 uri,
1562 mime_type: None,
1563 text: None,
1564 blob: None,
1565 }],
1566 })
1567 });
1568
1569 let mut router = McpRouter::new().resource_template(template);
1570
1571 let init_req = RouterRequest {
1573 id: RequestId::Number(0),
1574 inner: McpRequest::Initialize(InitializeParams {
1575 protocol_version: "2025-03-26".to_string(),
1576 capabilities: ClientCapabilities {
1577 roots: None,
1578 sampling: None,
1579 elicitation: None,
1580 },
1581 client_info: Implementation {
1582 name: "test".to_string(),
1583 version: "1.0".to_string(),
1584 },
1585 }),
1586 };
1587 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1588
1589 match resp.inner {
1590 Ok(McpResponse::Initialize(result)) => {
1591 assert!(result.capabilities.resources.is_some());
1593 }
1594 _ => panic!("Expected Initialize response"),
1595 }
1596 }
1597
1598 #[tokio::test]
1603 async fn test_log_sends_notification() {
1604 use crate::context::notification_channel;
1605
1606 let (tx, mut rx) = notification_channel(10);
1607 let router = McpRouter::new().with_notification_sender(tx);
1608
1609 let sent = router.log_info("Test message");
1611 assert!(sent);
1612
1613 let notification = rx.try_recv().unwrap();
1615 match notification {
1616 ServerNotification::LogMessage(params) => {
1617 assert_eq!(params.level, LogLevel::Info);
1618 let data = params.data.unwrap();
1619 assert_eq!(
1620 data.get("message").unwrap().as_str().unwrap(),
1621 "Test message"
1622 );
1623 }
1624 _ => panic!("Expected LogMessage notification"),
1625 }
1626 }
1627
1628 #[tokio::test]
1629 async fn test_log_with_custom_params() {
1630 use crate::context::notification_channel;
1631
1632 let (tx, mut rx) = notification_channel(10);
1633 let router = McpRouter::new().with_notification_sender(tx);
1634
1635 let params = LoggingMessageParams::new(LogLevel::Error)
1637 .with_logger("database")
1638 .with_data(serde_json::json!({
1639 "error": "Connection failed",
1640 "host": "localhost"
1641 }));
1642
1643 let sent = router.log(params);
1644 assert!(sent);
1645
1646 let notification = rx.try_recv().unwrap();
1647 match notification {
1648 ServerNotification::LogMessage(params) => {
1649 assert_eq!(params.level, LogLevel::Error);
1650 assert_eq!(params.logger.as_deref(), Some("database"));
1651 let data = params.data.unwrap();
1652 assert_eq!(
1653 data.get("error").unwrap().as_str().unwrap(),
1654 "Connection failed"
1655 );
1656 }
1657 _ => panic!("Expected LogMessage notification"),
1658 }
1659 }
1660
1661 #[tokio::test]
1662 async fn test_log_without_channel_returns_false() {
1663 let router = McpRouter::new();
1665
1666 assert!(!router.log_info("Test"));
1668 assert!(!router.log_warning("Test"));
1669 assert!(!router.log_error("Test"));
1670 assert!(!router.log_debug("Test"));
1671 }
1672
1673 #[tokio::test]
1674 async fn test_logging_capability_with_channel() {
1675 use crate::context::notification_channel;
1676
1677 let (tx, _rx) = notification_channel(10);
1678 let mut router = McpRouter::new().with_notification_sender(tx);
1679
1680 let init_req = RouterRequest {
1682 id: RequestId::Number(0),
1683 inner: McpRequest::Initialize(InitializeParams {
1684 protocol_version: "2025-03-26".to_string(),
1685 capabilities: ClientCapabilities {
1686 roots: None,
1687 sampling: None,
1688 elicitation: None,
1689 },
1690 client_info: Implementation {
1691 name: "test".to_string(),
1692 version: "1.0".to_string(),
1693 },
1694 }),
1695 };
1696 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1697
1698 match resp.inner {
1699 Ok(McpResponse::Initialize(result)) => {
1700 assert!(result.capabilities.logging.is_some());
1702 }
1703 _ => panic!("Expected Initialize response"),
1704 }
1705 }
1706
1707 #[tokio::test]
1708 async fn test_no_logging_capability_without_channel() {
1709 let mut router = McpRouter::new();
1710
1711 let init_req = RouterRequest {
1713 id: RequestId::Number(0),
1714 inner: McpRequest::Initialize(InitializeParams {
1715 protocol_version: "2025-03-26".to_string(),
1716 capabilities: ClientCapabilities {
1717 roots: None,
1718 sampling: None,
1719 elicitation: None,
1720 },
1721 client_info: Implementation {
1722 name: "test".to_string(),
1723 version: "1.0".to_string(),
1724 },
1725 }),
1726 };
1727 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1728
1729 match resp.inner {
1730 Ok(McpResponse::Initialize(result)) => {
1731 assert!(result.capabilities.logging.is_none());
1733 }
1734 _ => panic!("Expected Initialize response"),
1735 }
1736 }
1737
1738 #[tokio::test]
1743 async fn test_enqueue_task() {
1744 let add_tool = ToolBuilder::new("add")
1745 .description("Add two numbers")
1746 .handler(|input: AddInput| async move {
1747 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1748 })
1749 .build()
1750 .expect("valid tool name");
1751
1752 let mut router = McpRouter::new().tool(add_tool);
1753 init_router(&mut router).await;
1754
1755 let req = RouterRequest {
1756 id: RequestId::Number(1),
1757 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1758 tool_name: "add".to_string(),
1759 arguments: serde_json::json!({"a": 5, "b": 10}),
1760 ttl: None,
1761 }),
1762 };
1763
1764 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1765
1766 match resp.inner {
1767 Ok(McpResponse::EnqueueTask(result)) => {
1768 assert!(result.task_id.starts_with("task-"));
1769 assert_eq!(result.status, TaskStatus::Working);
1770 }
1771 _ => panic!("Expected EnqueueTask response"),
1772 }
1773 }
1774
1775 #[tokio::test]
1776 async fn test_list_tasks_empty() {
1777 let mut router = McpRouter::new();
1778 init_router(&mut router).await;
1779
1780 let req = RouterRequest {
1781 id: RequestId::Number(1),
1782 inner: McpRequest::ListTasks(ListTasksParams::default()),
1783 };
1784
1785 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1786
1787 match resp.inner {
1788 Ok(McpResponse::ListTasks(result)) => {
1789 assert!(result.tasks.is_empty());
1790 }
1791 _ => panic!("Expected ListTasks response"),
1792 }
1793 }
1794
1795 #[tokio::test]
1796 async fn test_task_lifecycle_complete() {
1797 let add_tool = ToolBuilder::new("add")
1798 .description("Add two numbers")
1799 .handler(|input: AddInput| async move {
1800 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1801 })
1802 .build()
1803 .expect("valid tool name");
1804
1805 let mut router = McpRouter::new().tool(add_tool);
1806 init_router(&mut router).await;
1807
1808 let req = RouterRequest {
1810 id: RequestId::Number(1),
1811 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1812 tool_name: "add".to_string(),
1813 arguments: serde_json::json!({"a": 7, "b": 8}),
1814 ttl: None,
1815 }),
1816 };
1817
1818 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1819 let task_id = match resp.inner {
1820 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1821 _ => panic!("Expected EnqueueTask response"),
1822 };
1823
1824 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1826
1827 let req = RouterRequest {
1829 id: RequestId::Number(2),
1830 inner: McpRequest::GetTaskResult(GetTaskResultParams {
1831 task_id: task_id.clone(),
1832 }),
1833 };
1834
1835 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1836
1837 match resp.inner {
1838 Ok(McpResponse::GetTaskResult(result)) => {
1839 assert_eq!(result.task_id, task_id);
1840 assert_eq!(result.status, TaskStatus::Completed);
1841 assert!(result.result.is_some());
1842 assert!(result.error.is_none());
1843
1844 let tool_result = result.result.unwrap();
1846 match &tool_result.content[0] {
1847 Content::Text { text, .. } => assert_eq!(text, "15"),
1848 _ => panic!("Expected text content"),
1849 }
1850 }
1851 _ => panic!("Expected GetTaskResult response"),
1852 }
1853 }
1854
1855 #[tokio::test]
1856 async fn test_task_cancellation() {
1857 let slow_tool = ToolBuilder::new("slow")
1859 .description("Slow tool")
1860 .handler(|_input: serde_json::Value| async move {
1861 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
1862 Ok(CallToolResult::text("done"))
1863 })
1864 .build()
1865 .expect("valid tool name");
1866
1867 let mut router = McpRouter::new().tool(slow_tool);
1868 init_router(&mut router).await;
1869
1870 let req = RouterRequest {
1872 id: RequestId::Number(1),
1873 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1874 tool_name: "slow".to_string(),
1875 arguments: serde_json::json!({}),
1876 ttl: None,
1877 }),
1878 };
1879
1880 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1881 let task_id = match resp.inner {
1882 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1883 _ => panic!("Expected EnqueueTask response"),
1884 };
1885
1886 let req = RouterRequest {
1888 id: RequestId::Number(2),
1889 inner: McpRequest::CancelTask(CancelTaskParams {
1890 task_id: task_id.clone(),
1891 reason: Some("Test cancellation".to_string()),
1892 }),
1893 };
1894
1895 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1896
1897 match resp.inner {
1898 Ok(McpResponse::CancelTask(result)) => {
1899 assert!(result.cancelled);
1900 assert_eq!(result.status, TaskStatus::Cancelled);
1901 }
1902 _ => panic!("Expected CancelTask response"),
1903 }
1904 }
1905
1906 #[tokio::test]
1907 async fn test_get_task_info() {
1908 let add_tool = ToolBuilder::new("add")
1909 .description("Add two numbers")
1910 .handler(|input: AddInput| async move {
1911 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1912 })
1913 .build()
1914 .expect("valid tool name");
1915
1916 let mut router = McpRouter::new().tool(add_tool);
1917 init_router(&mut router).await;
1918
1919 let req = RouterRequest {
1921 id: RequestId::Number(1),
1922 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1923 tool_name: "add".to_string(),
1924 arguments: serde_json::json!({"a": 1, "b": 2}),
1925 ttl: Some(600),
1926 }),
1927 };
1928
1929 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1930 let task_id = match resp.inner {
1931 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1932 _ => panic!("Expected EnqueueTask response"),
1933 };
1934
1935 let req = RouterRequest {
1937 id: RequestId::Number(2),
1938 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
1939 task_id: task_id.clone(),
1940 }),
1941 };
1942
1943 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1944
1945 match resp.inner {
1946 Ok(McpResponse::GetTaskInfo(info)) => {
1947 assert_eq!(info.task_id, task_id);
1948 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
1950 }
1951 _ => panic!("Expected GetTaskInfo response"),
1952 }
1953 }
1954
1955 #[tokio::test]
1956 async fn test_enqueue_nonexistent_tool() {
1957 let mut router = McpRouter::new();
1958 init_router(&mut router).await;
1959
1960 let req = RouterRequest {
1961 id: RequestId::Number(1),
1962 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1963 tool_name: "nonexistent".to_string(),
1964 arguments: serde_json::json!({}),
1965 ttl: None,
1966 }),
1967 };
1968
1969 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1970
1971 match resp.inner {
1972 Err(e) => {
1973 assert!(e.message.contains("not found"));
1974 }
1975 _ => panic!("Expected error response"),
1976 }
1977 }
1978
1979 #[tokio::test]
1980 async fn test_get_nonexistent_task() {
1981 let mut router = McpRouter::new();
1982 init_router(&mut router).await;
1983
1984 let req = RouterRequest {
1985 id: RequestId::Number(1),
1986 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
1987 task_id: "task-999".to_string(),
1988 }),
1989 };
1990
1991 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1992
1993 match resp.inner {
1994 Err(e) => {
1995 assert!(e.message.contains("not found"));
1996 }
1997 _ => panic!("Expected error response"),
1998 }
1999 }
2000
2001 #[tokio::test]
2006 async fn test_subscribe_to_resource() {
2007 use crate::resource::ResourceBuilder;
2008
2009 let resource = ResourceBuilder::new("file:///test.txt")
2010 .name("Test File")
2011 .text("Hello");
2012
2013 let mut router = McpRouter::new().resource(resource);
2014 init_router(&mut router).await;
2015
2016 let req = RouterRequest {
2018 id: RequestId::Number(1),
2019 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2020 uri: "file:///test.txt".to_string(),
2021 }),
2022 };
2023
2024 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2025
2026 match resp.inner {
2027 Ok(McpResponse::SubscribeResource(_)) => {
2028 assert!(router.is_subscribed("file:///test.txt"));
2030 }
2031 _ => panic!("Expected SubscribeResource response"),
2032 }
2033 }
2034
2035 #[tokio::test]
2036 async fn test_unsubscribe_from_resource() {
2037 use crate::resource::ResourceBuilder;
2038
2039 let resource = ResourceBuilder::new("file:///test.txt")
2040 .name("Test File")
2041 .text("Hello");
2042
2043 let mut router = McpRouter::new().resource(resource);
2044 init_router(&mut router).await;
2045
2046 let req = RouterRequest {
2048 id: RequestId::Number(1),
2049 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2050 uri: "file:///test.txt".to_string(),
2051 }),
2052 };
2053 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2054 assert!(router.is_subscribed("file:///test.txt"));
2055
2056 let req = RouterRequest {
2058 id: RequestId::Number(2),
2059 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2060 uri: "file:///test.txt".to_string(),
2061 }),
2062 };
2063
2064 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2065
2066 match resp.inner {
2067 Ok(McpResponse::UnsubscribeResource(_)) => {
2068 assert!(!router.is_subscribed("file:///test.txt"));
2070 }
2071 _ => panic!("Expected UnsubscribeResource response"),
2072 }
2073 }
2074
2075 #[tokio::test]
2076 async fn test_subscribe_nonexistent_resource() {
2077 let mut router = McpRouter::new();
2078 init_router(&mut router).await;
2079
2080 let req = RouterRequest {
2081 id: RequestId::Number(1),
2082 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2083 uri: "file:///nonexistent.txt".to_string(),
2084 }),
2085 };
2086
2087 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2088
2089 match resp.inner {
2090 Err(e) => {
2091 assert!(e.message.contains("not found"));
2092 }
2093 _ => panic!("Expected error response"),
2094 }
2095 }
2096
2097 #[tokio::test]
2098 async fn test_notify_resource_updated() {
2099 use crate::context::notification_channel;
2100 use crate::resource::ResourceBuilder;
2101
2102 let (tx, mut rx) = notification_channel(10);
2103
2104 let resource = ResourceBuilder::new("file:///test.txt")
2105 .name("Test File")
2106 .text("Hello");
2107
2108 let router = McpRouter::new()
2109 .resource(resource)
2110 .with_notification_sender(tx);
2111
2112 router.subscribe("file:///test.txt");
2114
2115 let sent = router.notify_resource_updated("file:///test.txt");
2117 assert!(sent);
2118
2119 let notification = rx.try_recv().unwrap();
2121 match notification {
2122 ServerNotification::ResourceUpdated { uri } => {
2123 assert_eq!(uri, "file:///test.txt");
2124 }
2125 _ => panic!("Expected ResourceUpdated notification"),
2126 }
2127 }
2128
2129 #[tokio::test]
2130 async fn test_notify_resource_updated_not_subscribed() {
2131 use crate::context::notification_channel;
2132 use crate::resource::ResourceBuilder;
2133
2134 let (tx, mut rx) = notification_channel(10);
2135
2136 let resource = ResourceBuilder::new("file:///test.txt")
2137 .name("Test File")
2138 .text("Hello");
2139
2140 let router = McpRouter::new()
2141 .resource(resource)
2142 .with_notification_sender(tx);
2143
2144 let sent = router.notify_resource_updated("file:///test.txt");
2146 assert!(!sent); assert!(rx.try_recv().is_err());
2150 }
2151
2152 #[tokio::test]
2153 async fn test_notify_resources_list_changed() {
2154 use crate::context::notification_channel;
2155
2156 let (tx, mut rx) = notification_channel(10);
2157 let router = McpRouter::new().with_notification_sender(tx);
2158
2159 let sent = router.notify_resources_list_changed();
2160 assert!(sent);
2161
2162 let notification = rx.try_recv().unwrap();
2163 match notification {
2164 ServerNotification::ResourcesListChanged => {}
2165 _ => panic!("Expected ResourcesListChanged notification"),
2166 }
2167 }
2168
2169 #[tokio::test]
2170 async fn test_subscribed_uris() {
2171 use crate::resource::ResourceBuilder;
2172
2173 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2174
2175 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2176
2177 let router = McpRouter::new().resource(resource1).resource(resource2);
2178
2179 router.subscribe("file:///a.txt");
2181 router.subscribe("file:///b.txt");
2182
2183 let uris = router.subscribed_uris();
2184 assert_eq!(uris.len(), 2);
2185 assert!(uris.contains(&"file:///a.txt".to_string()));
2186 assert!(uris.contains(&"file:///b.txt".to_string()));
2187 }
2188
2189 #[tokio::test]
2190 async fn test_subscription_capability_advertised() {
2191 use crate::resource::ResourceBuilder;
2192
2193 let resource = ResourceBuilder::new("file:///test.txt")
2194 .name("Test")
2195 .text("Hello");
2196
2197 let mut router = McpRouter::new().resource(resource);
2198
2199 let init_req = RouterRequest {
2201 id: RequestId::Number(0),
2202 inner: McpRequest::Initialize(InitializeParams {
2203 protocol_version: "2025-03-26".to_string(),
2204 capabilities: ClientCapabilities {
2205 roots: None,
2206 sampling: None,
2207 elicitation: None,
2208 },
2209 client_info: Implementation {
2210 name: "test".to_string(),
2211 version: "1.0".to_string(),
2212 },
2213 }),
2214 };
2215 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2216
2217 match resp.inner {
2218 Ok(McpResponse::Initialize(result)) => {
2219 let resources_cap = result.capabilities.resources.unwrap();
2221 assert!(resources_cap.subscribe);
2222 }
2223 _ => panic!("Expected Initialize response"),
2224 }
2225 }
2226
2227 #[tokio::test]
2228 async fn test_completion_handler() {
2229 let router = McpRouter::new()
2230 .server_info("test", "1.0")
2231 .completion_handler(|params: CompleteParams| async move {
2232 let prefix = ¶ms.argument.value;
2234 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2235 .into_iter()
2236 .filter(|s| s.starts_with(prefix))
2237 .map(String::from)
2238 .collect();
2239 Ok(CompleteResult::new(suggestions))
2240 });
2241
2242 let init_req = RouterRequest {
2244 id: RequestId::Number(0),
2245 inner: McpRequest::Initialize(InitializeParams {
2246 protocol_version: "2025-03-26".to_string(),
2247 capabilities: ClientCapabilities::default(),
2248 client_info: Implementation {
2249 name: "test".to_string(),
2250 version: "1.0".to_string(),
2251 },
2252 }),
2253 };
2254 let resp = router
2255 .clone()
2256 .ready()
2257 .await
2258 .unwrap()
2259 .call(init_req)
2260 .await
2261 .unwrap();
2262
2263 match resp.inner {
2265 Ok(McpResponse::Initialize(result)) => {
2266 assert!(result.capabilities.completions.is_some());
2267 }
2268 _ => panic!("Expected Initialize response"),
2269 }
2270
2271 router.handle_notification(McpNotification::Initialized);
2273
2274 let complete_req = RouterRequest {
2276 id: RequestId::Number(1),
2277 inner: McpRequest::Complete(CompleteParams {
2278 reference: CompletionReference::prompt("test-prompt"),
2279 argument: CompletionArgument::new("query", "al"),
2280 }),
2281 };
2282 let resp = router
2283 .clone()
2284 .ready()
2285 .await
2286 .unwrap()
2287 .call(complete_req)
2288 .await
2289 .unwrap();
2290
2291 match resp.inner {
2292 Ok(McpResponse::Complete(result)) => {
2293 assert_eq!(result.completion.values, vec!["alpha"]);
2294 }
2295 _ => panic!("Expected Complete response"),
2296 }
2297 }
2298
2299 #[tokio::test]
2300 async fn test_completion_without_handler_returns_empty() {
2301 let router = McpRouter::new().server_info("test", "1.0");
2302
2303 let init_req = RouterRequest {
2305 id: RequestId::Number(0),
2306 inner: McpRequest::Initialize(InitializeParams {
2307 protocol_version: "2025-03-26".to_string(),
2308 capabilities: ClientCapabilities::default(),
2309 client_info: Implementation {
2310 name: "test".to_string(),
2311 version: "1.0".to_string(),
2312 },
2313 }),
2314 };
2315 let resp = router
2316 .clone()
2317 .ready()
2318 .await
2319 .unwrap()
2320 .call(init_req)
2321 .await
2322 .unwrap();
2323
2324 match resp.inner {
2326 Ok(McpResponse::Initialize(result)) => {
2327 assert!(result.capabilities.completions.is_none());
2328 }
2329 _ => panic!("Expected Initialize response"),
2330 }
2331
2332 router.handle_notification(McpNotification::Initialized);
2334
2335 let complete_req = RouterRequest {
2337 id: RequestId::Number(1),
2338 inner: McpRequest::Complete(CompleteParams {
2339 reference: CompletionReference::prompt("test-prompt"),
2340 argument: CompletionArgument::new("query", "al"),
2341 }),
2342 };
2343 let resp = router
2344 .clone()
2345 .ready()
2346 .await
2347 .unwrap()
2348 .call(complete_req)
2349 .await
2350 .unwrap();
2351
2352 match resp.inner {
2353 Ok(McpResponse::Complete(result)) => {
2354 assert!(result.completion.values.is_empty());
2355 }
2356 _ => panic!("Expected Complete response"),
2357 }
2358 }
2359}