1use futures::channel::oneshot;
14use mcpkit_core::capability::{
15 is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16 ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{
19 HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20 TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::types::{
24 CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
25 CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
26 GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
27 ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
28 ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
29 TaskStatus, TaskSummary, Tool,
30};
31use mcpkit_transport::Transport;
32use std::collections::HashMap;
33use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
34use std::sync::Arc;
35use tracing::{debug, error, info, trace, warn};
36
37use async_lock::RwLock;
39
40#[cfg(feature = "tokio-runtime")]
42use tokio::sync::mpsc;
43
44use crate::handler::ClientHandler;
45
/// An initialized MCP client session over transport `T`.
///
/// Server-initiated requests and notifications are dispatched to the handler `H`
/// (`NoOpHandler` unless another handler is supplied). After construction, a
/// background router task performs all reads from the transport and writes the
/// messages queued on `outgoing_tx`.
pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
    /// Shared transport; also held by the background router task.
    transport: Arc<T>,
    /// Server identity taken from the `initialize` result.
    server_info: ServerInfo,
    /// Capabilities the server advertised in the `initialize` result.
    server_caps: ServerCapabilities,
    /// Client identity this session was constructed with.
    client_info: ClientInfo,
    /// Client capabilities this session was constructed with.
    client_caps: ClientCapabilities,
    /// Source of request ids; seeded with 1 (the handshake uses id 0).
    next_id: AtomicU64,
    /// In-flight requests keyed by id; each sender is completed by the router
    /// when the matching response arrives.
    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
    /// Optional instructions from the `initialize` result.
    instructions: Option<String>,
    /// Handler for server-initiated requests and notifications.
    handler: Arc<H>,
    /// Queue (capacity 256) of messages the router task writes to the transport.
    outgoing_tx: mpsc::Sender<Message>,
    /// `true` while the connection is considered live; cleared on close or
    /// when the router sees the transport close/fail.
    running: Arc<AtomicBool>,
    /// Handle to the background router task. Dropping a tokio `JoinHandle`
    /// detaches the task; it does not stop it.
    _background_handle: Option<tokio::task::JoinHandle<()>>,
}
102
103impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
104 pub(crate) fn new(
106 transport: T,
107 init_result: InitializeResult,
108 client_info: ClientInfo,
109 client_caps: ClientCapabilities,
110 ) -> Self {
111 Self::with_handler(
112 transport,
113 init_result,
114 client_info,
115 client_caps,
116 crate::handler::NoOpHandler,
117 )
118 }
119}
120
121impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
122 pub(crate) fn with_handler(
124 transport: T,
125 init_result: InitializeResult,
126 client_info: ClientInfo,
127 client_caps: ClientCapabilities,
128 handler: H,
129 ) -> Self {
130 let transport = Arc::new(transport);
131 let pending = Arc::new(RwLock::new(HashMap::new()));
132 let handler = Arc::new(handler);
133 let running = Arc::new(AtomicBool::new(true));
134
135 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
137
138 let background_handle = Self::spawn_message_router(
140 Arc::clone(&transport),
141 Arc::clone(&pending),
142 Arc::clone(&handler),
143 Arc::clone(&running),
144 outgoing_rx,
145 );
146
147 let handler_clone = Arc::clone(&handler);
149 tokio::spawn(async move {
150 handler_clone.on_connected().await;
151 });
152
153 Self {
154 transport,
155 server_info: init_result.server_info,
156 server_caps: init_result.capabilities,
157 client_info,
158 client_caps,
159 next_id: AtomicU64::new(1),
160 pending,
161 instructions: init_result.instructions,
162 handler,
163 outgoing_tx,
164 running,
165 _background_handle: Some(background_handle),
166 }
167 }
168
169 fn spawn_message_router(
177 transport: Arc<T>,
178 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
179 handler: Arc<H>,
180 running: Arc<AtomicBool>,
181 mut outgoing_rx: mpsc::Receiver<Message>,
182 ) -> tokio::task::JoinHandle<()> {
183 tokio::spawn(async move {
184 debug!("Starting client message router");
185
186 loop {
187 if !running.load(Ordering::SeqCst) {
188 debug!("Message router stopping (client closed)");
189 break;
190 }
191
192 tokio::select! {
193 Some(msg) = outgoing_rx.recv() => {
195 if let Err(e) = transport.send(msg).await {
196 error!(?e, "Failed to send message");
197 }
198 }
199
200 result = transport.recv() => {
202 match result {
203 Ok(Some(message)) => {
204 Self::handle_incoming_message(
205 message,
206 &pending,
207 &handler,
208 &transport,
209 ).await;
210 }
211 Ok(None) => {
212 info!("Connection closed by server");
213 running.store(false, Ordering::SeqCst);
214 handler.on_disconnected().await;
215 break;
216 }
217 Err(e) => {
218 error!(?e, "Transport error in message router");
219 running.store(false, Ordering::SeqCst);
220 handler.on_disconnected().await;
221 break;
222 }
223 }
224 }
225 }
226 }
227
228 debug!("Message router stopped");
229 })
230 }
231
232 async fn handle_incoming_message(
234 message: Message,
235 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
236 handler: &Arc<H>,
237 transport: &Arc<T>,
238 ) {
239 match message {
240 Message::Response(response) => {
241 Self::route_response(response, pending).await;
242 }
243 Message::Request(request) => {
244 Self::handle_server_request(request, handler, transport).await;
245 }
246 Message::Notification(notification) => {
247 Self::handle_notification(notification, handler).await;
248 }
249 }
250 }
251
252 async fn route_response(
254 response: Response,
255 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
256 ) {
257 let sender = {
258 let mut pending_guard = pending.write().await;
259 pending_guard.remove(&response.id)
260 };
261
262 if let Some(sender) = sender {
263 trace!(?response.id, "Routing response to pending request");
264 if sender.send(response).is_err() {
265 warn!("Pending request receiver dropped");
266 }
267 } else {
268 warn!(?response.id, "Received response for unknown request");
269 }
270 }
271
272 async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
274 trace!(method = %request.method, "Handling server request");
275
276 let response = match request.method.as_ref() {
277 "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
278 "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
279 "roots/list" => Self::handle_roots_request(&request, handler).await,
280 "ping" => {
281 Response::success(request.id.clone(), serde_json::json!({}))
283 }
284 _ => {
285 warn!(method = %request.method, "Unknown server request method");
286 Response::error(
287 request.id.clone(),
288 JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
289 )
290 }
291 };
292
293 if let Err(e) = transport.send(Message::Response(response)).await {
295 error!(?e, "Failed to send response to server request");
296 }
297 }
298
299 async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
301 let params = match &request.params {
302 Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
303 Ok(req) => req,
304 Err(e) => {
305 return Response::error(
306 request.id.clone(),
307 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
308 );
309 }
310 },
311 None => {
312 return Response::error(
313 request.id.clone(),
314 JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
315 );
316 }
317 };
318
319 match handler.create_message(params).await {
320 Ok(result) => match serde_json::to_value(result) {
321 Ok(value) => Response::success(request.id.clone(), value),
322 Err(e) => Response::error(
323 request.id.clone(),
324 JsonRpcError::internal_error(format!("Serialization error: {e}")),
325 ),
326 },
327 Err(e) => Response::error(
328 request.id.clone(),
329 JsonRpcError::internal_error(e.to_string()),
330 ),
331 }
332 }
333
334 async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
336 let params = match &request.params {
337 Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
338 Ok(req) => req,
339 Err(e) => {
340 return Response::error(
341 request.id.clone(),
342 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
343 );
344 }
345 },
346 None => {
347 return Response::error(
348 request.id.clone(),
349 JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
350 );
351 }
352 };
353
354 match handler.elicit(params).await {
355 Ok(result) => match serde_json::to_value(result) {
356 Ok(value) => Response::success(request.id.clone(), value),
357 Err(e) => Response::error(
358 request.id.clone(),
359 JsonRpcError::internal_error(format!("Serialization error: {e}")),
360 ),
361 },
362 Err(e) => Response::error(
363 request.id.clone(),
364 JsonRpcError::internal_error(e.to_string()),
365 ),
366 }
367 }
368
369 async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
371 match handler.list_roots().await {
372 Ok(roots) => {
373 let roots_json: Vec<serde_json::Value> = roots
374 .into_iter()
375 .map(|r| {
376 serde_json::json!({
377 "uri": r.uri,
378 "name": r.name
379 })
380 })
381 .collect();
382 Response::success(
383 request.id.clone(),
384 serde_json::json!({ "roots": roots_json }),
385 )
386 }
387 Err(e) => Response::error(
388 request.id.clone(),
389 JsonRpcError::internal_error(e.to_string()),
390 ),
391 }
392 }
393
394 async fn handle_notification(notification: Notification, handler: &Arc<H>) {
396 trace!(method = %notification.method, "Received server notification");
397
398 match notification.method.as_ref() {
399 "notifications/cancelled" => {
400 if let Some(params) = ¬ification.params {
402 if let Some(request_id) = params.get("requestId") {
403 debug!(?request_id, "Server cancelled request");
404 }
405 }
406 }
407 "notifications/progress" => {
408 if let Some(params) = notification.params {
410 if let (Some(task_id), Some(progress)) = (
411 params.get("progressToken").and_then(|v| v.as_str()),
412 params.get("progress"),
413 ) {
414 if let Ok(progress) = serde_json::from_value::<
415 mcpkit_core::types::TaskProgress,
416 >(progress.clone())
417 {
418 debug!(task_id = %task_id, "Task progress update");
419 handler.on_task_progress(task_id.into(), progress).await;
420 }
421 }
422 }
423 }
424 "notifications/resources/updated" => {
425 if let Some(params) = notification.params {
426 if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
427 debug!(uri = %uri, "Resource updated");
428 handler.on_resource_updated(uri.to_string()).await;
429 }
430 }
431 }
432 "notifications/resources/list_changed" => {
433 debug!("Resources list changed");
434 handler.on_resources_list_changed().await;
435 }
436 "notifications/tools/list_changed" => {
437 debug!("Tools list changed");
438 handler.on_tools_list_changed().await;
439 }
440 "notifications/prompts/list_changed" => {
441 debug!("Prompts list changed");
442 handler.on_prompts_list_changed().await;
443 }
444 _ => {
445 trace!(method = %notification.method, "Unhandled notification");
446 }
447 }
448 }
449
    /// Server identity taken from the `initialize` result.
    pub const fn server_info(&self) -> &ServerInfo {
        &self.server_info
    }

    /// Capabilities the server advertised in the `initialize` result.
    pub const fn server_capabilities(&self) -> &ServerCapabilities {
        &self.server_caps
    }

    /// Client identity this session was constructed with
    /// (presumably the one sent during `initialize` — confirm at the call site).
    pub const fn client_info(&self) -> &ClientInfo {
        &self.client_info
    }

    /// Client capabilities this session was constructed with.
    pub const fn client_capabilities(&self) -> &ClientCapabilities {
        &self.client_caps
    }

    /// Optional instructions from the `initialize` result, if the server sent any.
    pub fn instructions(&self) -> Option<&str> {
        self.instructions.as_deref()
    }

    /// Whether the server advertised the `tools` capability.
    pub const fn has_tools(&self) -> bool {
        self.server_caps.has_tools()
    }

    /// Whether the server advertised the `resources` capability.
    pub const fn has_resources(&self) -> bool {
        self.server_caps.has_resources()
    }

    /// Whether the server advertised the `prompts` capability.
    pub const fn has_prompts(&self) -> bool {
        self.server_caps.has_prompts()
    }

    /// Whether the server advertised the `tasks` capability.
    pub const fn has_tasks(&self) -> bool {
        self.server_caps.has_tasks()
    }

    /// Whether the server advertised the `completions` capability.
    pub const fn has_completions(&self) -> bool {
        self.server_caps.has_completions()
    }

    /// Whether the connection is still considered live.
    ///
    /// Becomes `false` after `close()` or once the background router sees the
    /// transport close or fail; nothing in this type sets it back to `true`.
    pub fn is_connected(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }
504
505 pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
515 self.ensure_capability("tools", self.has_tools())?;
516
517 let result: ListToolsResult = self.request("tools/list", None).await?;
518 Ok(result.tools)
519 }
520
521 pub async fn list_tools_paginated(
527 &self,
528 cursor: Option<&str>,
529 ) -> Result<ListToolsResult, McpError> {
530 self.ensure_capability("tools", self.has_tools())?;
531
532 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
533 self.request("tools/list", params).await
534 }
535
536 pub async fn call_tool(
547 &self,
548 name: impl Into<String>,
549 arguments: serde_json::Value,
550 ) -> Result<CallToolResult, McpError> {
551 self.ensure_capability("tools", self.has_tools())?;
552
553 let request = CallToolRequest {
554 name: name.into(),
555 arguments: Some(arguments),
556 };
557 self.request("tools/call", Some(serde_json::to_value(request)?))
558 .await
559 }
560
561 pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
571 self.ensure_capability("resources", self.has_resources())?;
572
573 let result: ListResourcesResult = self.request("resources/list", None).await?;
574 Ok(result.resources)
575 }
576
577 pub async fn list_resources_paginated(
583 &self,
584 cursor: Option<&str>,
585 ) -> Result<ListResourcesResult, McpError> {
586 self.ensure_capability("resources", self.has_resources())?;
587
588 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
589 self.request("resources/list", params).await
590 }
591
592 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
598 self.ensure_capability("resources", self.has_resources())?;
599
600 let result: ListResourceTemplatesResult =
601 self.request("resources/templates/list", None).await?;
602 Ok(result.resource_templates)
603 }
604
605 pub async fn read_resource(
611 &self,
612 uri: impl Into<String>,
613 ) -> Result<Vec<ResourceContents>, McpError> {
614 self.ensure_capability("resources", self.has_resources())?;
615
616 let request = ReadResourceRequest { uri: uri.into() };
617 let result: ReadResourceResult = self
618 .request("resources/read", Some(serde_json::to_value(request)?))
619 .await?;
620 Ok(result.contents)
621 }
622
623 pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
633 self.ensure_capability("prompts", self.has_prompts())?;
634
635 let result: ListPromptsResult = self.request("prompts/list", None).await?;
636 Ok(result.prompts)
637 }
638
639 pub async fn list_prompts_paginated(
645 &self,
646 cursor: Option<&str>,
647 ) -> Result<ListPromptsResult, McpError> {
648 self.ensure_capability("prompts", self.has_prompts())?;
649
650 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
651 self.request("prompts/list", params).await
652 }
653
654 pub async fn get_prompt(
660 &self,
661 name: impl Into<String>,
662 arguments: Option<serde_json::Map<String, serde_json::Value>>,
663 ) -> Result<GetPromptResult, McpError> {
664 self.ensure_capability("prompts", self.has_prompts())?;
665
666 let request = GetPromptRequest {
667 name: name.into(),
668 arguments,
669 };
670 self.request("prompts/get", Some(serde_json::to_value(request)?))
671 .await
672 }
673
674 pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
684 self.ensure_capability("tasks", self.has_tasks())?;
685
686 let result: ListTasksResult = self.request("tasks/list", None).await?;
687 Ok(result.tasks)
688 }
689
690 pub async fn list_tasks_filtered(
696 &self,
697 status: Option<TaskStatus>,
698 cursor: Option<&str>,
699 ) -> Result<ListTasksResult, McpError> {
700 self.ensure_capability("tasks", self.has_tasks())?;
701
702 let request = ListTasksRequest {
703 status,
704 cursor: cursor.map(String::from),
705 };
706 self.request("tasks/list", Some(serde_json::to_value(request)?))
707 .await
708 }
709
710 pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
716 self.ensure_capability("tasks", self.has_tasks())?;
717
718 let request = GetTaskRequest {
719 id: id.into().into(),
720 };
721 self.request("tasks/get", Some(serde_json::to_value(request)?))
722 .await
723 }
724
725 pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
732 self.ensure_capability("tasks", self.has_tasks())?;
733
734 let request = CancelTaskRequest {
735 id: id.into().into(),
736 };
737 let _: serde_json::Value = self
738 .request("tasks/cancel", Some(serde_json::to_value(request)?))
739 .await?;
740 Ok(())
741 }
742
743 pub async fn complete_prompt_argument(
759 &self,
760 prompt_name: impl Into<String>,
761 argument_name: impl Into<String>,
762 current_value: impl Into<String>,
763 ) -> Result<CompleteResult, McpError> {
764 self.ensure_capability("completions", self.has_completions())?;
765
766 let request = CompleteRequest {
767 ref_: CompletionRef::prompt(prompt_name),
768 argument: CompletionArgument {
769 name: argument_name.into(),
770 value: current_value.into(),
771 },
772 };
773 self.request("completion/complete", Some(serde_json::to_value(request)?))
774 .await
775 }
776
777 pub async fn complete_resource_argument(
789 &self,
790 resource_uri: impl Into<String>,
791 argument_name: impl Into<String>,
792 current_value: impl Into<String>,
793 ) -> Result<CompleteResult, McpError> {
794 self.ensure_capability("completions", self.has_completions())?;
795
796 let request = CompleteRequest {
797 ref_: CompletionRef::resource(resource_uri),
798 argument: CompletionArgument {
799 name: argument_name.into(),
800 value: current_value.into(),
801 },
802 };
803 self.request("completion/complete", Some(serde_json::to_value(request)?))
804 .await
805 }
806
807 pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
820 self.ensure_capability("resources", self.has_resources())?;
821
822 if !self.server_caps.has_resource_subscribe() {
824 return Err(McpError::CapabilityNotSupported {
825 capability: "resources.subscribe".to_string(),
826 available: self.available_capabilities().into_boxed_slice(),
827 });
828 }
829
830 let params = serde_json::json!({ "uri": uri.into() });
831 let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
832 Ok(())
833 }
834
835 pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
841 self.ensure_capability("resources", self.has_resources())?;
842
843 if !self.server_caps.has_resource_subscribe() {
845 return Err(McpError::CapabilityNotSupported {
846 capability: "resources.subscribe".to_string(),
847 available: self.available_capabilities().into_boxed_slice(),
848 });
849 }
850
851 let params = serde_json::json!({ "uri": uri.into() });
852 let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
853 Ok(())
854 }
855
856 pub async fn ping(&self) -> Result<(), McpError> {
866 let _: serde_json::Value = self.request("ping", None).await?;
867 Ok(())
868 }
869
870 pub async fn close(self) -> Result<(), McpError> {
876 debug!("Closing client connection");
877
878 self.running.store(false, Ordering::SeqCst);
880
881 self.handler.on_disconnected().await;
883
884 self.transport.close().await.map_err(|e| {
886 McpError::Transport(Box::new(TransportDetails {
887 kind: TransportErrorKind::ConnectionClosed,
888 message: e.to_string(),
889 context: TransportContext::default(),
890 source: None,
891 }))
892 })
893 }
894
895 fn next_request_id(&self) -> RequestId {
901 RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
902 }
903
904 async fn request<R: serde::de::DeserializeOwned>(
906 &self,
907 method: &str,
908 params: Option<serde_json::Value>,
909 ) -> Result<R, McpError> {
910 if !self.is_connected() {
911 return Err(McpError::Transport(Box::new(TransportDetails {
912 kind: TransportErrorKind::ConnectionClosed,
913 message: "Client is not connected".to_string(),
914 context: TransportContext::default(),
915 source: None,
916 })));
917 }
918
919 let id = self.next_request_id();
920 let request = if let Some(params) = params {
921 Request::with_params(method.to_string(), id.clone(), params)
922 } else {
923 Request::new(method.to_string(), id.clone())
924 };
925
926 trace!(?id, method, "Sending request");
927
928 let (tx, rx) = oneshot::channel();
930 {
931 let mut pending = self.pending.write().await;
932 pending.insert(id.clone(), tx);
933 }
934
935 self.outgoing_tx
937 .send(Message::Request(request))
938 .await
939 .map_err(|_| {
940 McpError::Transport(Box::new(TransportDetails {
941 kind: TransportErrorKind::WriteFailed,
942 message: "Failed to send request (channel closed)".to_string(),
943 context: TransportContext::default(),
944 source: None,
945 }))
946 })?;
947
948 let response = rx.await.map_err(|_| {
950 McpError::Transport(Box::new(TransportDetails {
951 kind: TransportErrorKind::ConnectionClosed,
952 message: "Response channel closed (server may have disconnected)".to_string(),
953 context: TransportContext::default(),
954 source: None,
955 }))
956 })?;
957
958 if let Some(error) = response.error {
960 return Err(McpError::Internal {
961 message: error.message,
962 source: None,
963 });
964 }
965
966 let result = response.result.ok_or_else(|| McpError::Internal {
967 message: "Response contained neither result nor error".to_string(),
968 source: None,
969 })?;
970
971 serde_json::from_value(result).map_err(McpError::from)
972 }
973
974 fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
976 if supported {
977 Ok(())
978 } else {
979 Err(McpError::CapabilityNotSupported {
980 capability: name.to_string(),
981 available: self.available_capabilities().into_boxed_slice(),
982 })
983 }
984 }
985
986 fn available_capabilities(&self) -> Vec<String> {
988 let mut caps = Vec::new();
989 if self.has_tools() {
990 caps.push("tools".to_string());
991 }
992 if self.has_resources() {
993 caps.push("resources".to_string());
994 }
995 if self.has_prompts() {
996 caps.push("prompts".to_string());
997 }
998 if self.has_tasks() {
999 caps.push("tasks".to_string());
1000 }
1001 if self.has_completions() {
1002 caps.push("completions".to_string());
1003 }
1004 caps
1005 }
1006}
1007
1008pub(crate) async fn initialize<T: Transport>(
1025 transport: &T,
1026 client_info: &ClientInfo,
1027 capabilities: &ClientCapabilities,
1028) -> Result<InitializeResult, McpError> {
1029 debug!(
1030 protocol_version = %PROTOCOL_VERSION,
1031 supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1032 "Initializing MCP connection"
1033 );
1034
1035 let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1037 let init_request = Request::with_params(
1038 "initialize".to_string(),
1039 RequestId::Number(0),
1040 serde_json::to_value(&request)?,
1041 );
1042
1043 transport
1045 .send(Message::Request(init_request))
1046 .await
1047 .map_err(|e| {
1048 McpError::Transport(Box::new(TransportDetails {
1049 kind: TransportErrorKind::WriteFailed,
1050 message: format!("Failed to send initialize: {e}"),
1051 context: TransportContext::default(),
1052 source: None,
1053 }))
1054 })?;
1055
1056 let response = loop {
1058 match transport.recv().await {
1059 Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1060 Ok(Some(_)) => {} Ok(None) => {
1062 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1063 message: "Connection closed during initialization".to_string(),
1064 client_version: Some(PROTOCOL_VERSION.to_string()),
1065 server_version: None,
1066 source: None,
1067 })));
1068 }
1069 Err(e) => {
1070 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1071 message: format!("Transport error during initialization: {e}"),
1072 client_version: Some(PROTOCOL_VERSION.to_string()),
1073 server_version: None,
1074 source: None,
1075 })));
1076 }
1077 }
1078 };
1079
1080 if let Some(error) = response.error {
1082 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1083 message: error.message,
1084 client_version: Some(PROTOCOL_VERSION.to_string()),
1085 server_version: None,
1086 source: None,
1087 })));
1088 }
1089
1090 let result: InitializeResult = response
1091 .result
1092 .map(serde_json::from_value)
1093 .transpose()?
1094 .ok_or_else(|| {
1095 McpError::HandshakeFailed(Box::new(HandshakeDetails {
1096 message: "Empty initialize result".to_string(),
1097 client_version: Some(PROTOCOL_VERSION.to_string()),
1098 server_version: None,
1099 source: None,
1100 }))
1101 })?;
1102
1103 let server_version = &result.protocol_version;
1105 if !is_version_supported(server_version) {
1106 warn!(
1107 server_version = %server_version,
1108 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1109 "Server returned unsupported protocol version"
1110 );
1111 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1112 message: format!(
1113 "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1114 ),
1115 client_version: Some(PROTOCOL_VERSION.to_string()),
1116 server_version: Some(server_version.clone()),
1117 source: None,
1118 })));
1119 }
1120
1121 debug!(
1122 server = %result.server_info.name,
1123 server_version = %result.server_info.version,
1124 protocol_version = %result.protocol_version,
1125 "Received initialize result with compatible protocol version"
1126 );
1127
1128 let notification = Notification::new("notifications/initialized");
1130 transport
1131 .send(Message::Notification(notification))
1132 .await
1133 .map_err(|e| {
1134 McpError::Transport(Box::new(TransportDetails {
1135 kind: TransportErrorKind::WriteFailed,
1136 message: format!("Failed to send initialized: {e}"),
1137 context: TransportContext::default(),
1138 source: None,
1139 }))
1140 })?;
1141
1142 debug!("MCP initialization complete");
1143 Ok(result)
1144}
1145
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors `next_request_id`: a counter seeded at 1 hands out 1, then 2.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        let issued: Vec<u64> = (0..2)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();
        assert_eq!(issued, [1, 2]);
    }
}