1use futures::channel::oneshot;
14use mcpkit_core::capability::{
15 is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16 ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{
19 HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20 TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::protocol_version::ProtocolVersion;
24use mcpkit_core::types::{
25 CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
26 CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
27 GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
28 ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
29 ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
30 TaskStatus, TaskSummary, Tool,
31};
32use mcpkit_transport::Transport;
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
35use std::sync::Arc;
36use tracing::{debug, error, info, trace, warn};
37
38use async_lock::RwLock;
40
41#[cfg(feature = "tokio-runtime")]
43use tokio::sync::mpsc;
44
45use crate::handler::ClientHandler;
46
47pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
78 transport: Arc<T>,
80 server_info: ServerInfo,
82 server_caps: ServerCapabilities,
84 protocol_version: ProtocolVersion,
89 client_info: ClientInfo,
91 client_caps: ClientCapabilities,
93 next_id: AtomicU64,
95 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
97 instructions: Option<String>,
99 handler: Arc<H>,
101 outgoing_tx: mpsc::Sender<Message>,
103 running: Arc<AtomicBool>,
105 _background_handle: Option<tokio::task::JoinHandle<()>>,
107}
108
109impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
110 pub(crate) fn new(
112 transport: T,
113 init_result: InitializeResult,
114 client_info: ClientInfo,
115 client_caps: ClientCapabilities,
116 ) -> Self {
117 Self::with_handler(
118 transport,
119 init_result,
120 client_info,
121 client_caps,
122 crate::handler::NoOpHandler,
123 )
124 }
125}
126
127impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
128 pub(crate) fn with_handler(
130 transport: T,
131 init_result: InitializeResult,
132 client_info: ClientInfo,
133 client_caps: ClientCapabilities,
134 handler: H,
135 ) -> Self {
136 let transport = Arc::new(transport);
137 let pending = Arc::new(RwLock::new(HashMap::new()));
138 let handler = Arc::new(handler);
139 let running = Arc::new(AtomicBool::new(true));
140
141 let protocol_version =
143 if let Ok(v) = init_result.protocol_version.parse::<ProtocolVersion>() {
144 v
145 } else {
146 warn!(
147 server_version = %init_result.protocol_version,
148 fallback_version = %ProtocolVersion::LATEST,
149 "Server returned unknown protocol version, falling back to latest supported"
150 );
151 ProtocolVersion::LATEST
152 };
153
154 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
156
157 let background_handle = Self::spawn_message_router(
159 Arc::clone(&transport),
160 Arc::clone(&pending),
161 Arc::clone(&handler),
162 Arc::clone(&running),
163 outgoing_rx,
164 );
165
166 let handler_clone = Arc::clone(&handler);
168 tokio::spawn(async move {
169 handler_clone.on_connected().await;
170 });
171
172 Self {
173 transport,
174 server_info: init_result.server_info,
175 server_caps: init_result.capabilities,
176 protocol_version,
177 client_info,
178 client_caps,
179 next_id: AtomicU64::new(1),
180 pending,
181 instructions: init_result.instructions,
182 handler,
183 outgoing_tx,
184 running,
185 _background_handle: Some(background_handle),
186 }
187 }
188
189 fn spawn_message_router(
197 transport: Arc<T>,
198 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
199 handler: Arc<H>,
200 running: Arc<AtomicBool>,
201 mut outgoing_rx: mpsc::Receiver<Message>,
202 ) -> tokio::task::JoinHandle<()> {
203 tokio::spawn(async move {
204 debug!("Starting client message router");
205
206 loop {
207 if !running.load(Ordering::SeqCst) {
208 debug!("Message router stopping (client closed)");
209 break;
210 }
211
212 tokio::select! {
213 Some(msg) = outgoing_rx.recv() => {
215 if let Err(e) = transport.send(msg).await {
216 error!(?e, "Failed to send message");
217 }
218 }
219
220 result = transport.recv() => {
222 match result {
223 Ok(Some(message)) => {
224 Self::handle_incoming_message(
225 message,
226 &pending,
227 &handler,
228 &transport,
229 ).await;
230 }
231 Ok(None) => {
232 info!("Connection closed by server");
233 running.store(false, Ordering::SeqCst);
234 handler.on_disconnected().await;
235 break;
236 }
237 Err(e) => {
238 error!(?e, "Transport error in message router");
239 running.store(false, Ordering::SeqCst);
240 handler.on_disconnected().await;
241 break;
242 }
243 }
244 }
245 }
246 }
247
248 debug!("Message router stopped");
249 })
250 }
251
252 async fn handle_incoming_message(
254 message: Message,
255 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
256 handler: &Arc<H>,
257 transport: &Arc<T>,
258 ) {
259 match message {
260 Message::Response(response) => {
261 Self::route_response(response, pending).await;
262 }
263 Message::Request(request) => {
264 Self::handle_server_request(request, handler, transport).await;
265 }
266 Message::Notification(notification) => {
267 Self::handle_notification(notification, handler).await;
268 }
269 }
270 }
271
272 async fn route_response(
274 response: Response,
275 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
276 ) {
277 let sender = {
278 let mut pending_guard = pending.write().await;
279 pending_guard.remove(&response.id)
280 };
281
282 if let Some(sender) = sender {
283 trace!(?response.id, "Routing response to pending request");
284 if sender.send(response).is_err() {
285 warn!("Pending request receiver dropped");
286 }
287 } else {
288 warn!(?response.id, "Received response for unknown request");
289 }
290 }
291
292 async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
294 trace!(method = %request.method, "Handling server request");
295
296 let response = match request.method.as_ref() {
297 "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
298 "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
299 "roots/list" => Self::handle_roots_request(&request, handler).await,
300 "ping" => {
301 Response::success(request.id.clone(), serde_json::json!({}))
303 }
304 _ => {
305 warn!(method = %request.method, "Unknown server request method");
306 Response::error(
307 request.id.clone(),
308 JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
309 )
310 }
311 };
312
313 if let Err(e) = transport.send(Message::Response(response)).await {
315 error!(?e, "Failed to send response to server request");
316 }
317 }
318
319 async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
321 let params = match &request.params {
322 Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
323 Ok(req) => req,
324 Err(e) => {
325 return Response::error(
326 request.id.clone(),
327 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
328 );
329 }
330 },
331 None => {
332 return Response::error(
333 request.id.clone(),
334 JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
335 );
336 }
337 };
338
339 match handler.create_message(params).await {
340 Ok(result) => match serde_json::to_value(result) {
341 Ok(value) => Response::success(request.id.clone(), value),
342 Err(e) => Response::error(
343 request.id.clone(),
344 JsonRpcError::internal_error(format!("Serialization error: {e}")),
345 ),
346 },
347 Err(e) => Response::error(
348 request.id.clone(),
349 JsonRpcError::internal_error(e.to_string()),
350 ),
351 }
352 }
353
354 async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
356 let params = match &request.params {
357 Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
358 Ok(req) => req,
359 Err(e) => {
360 return Response::error(
361 request.id.clone(),
362 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
363 );
364 }
365 },
366 None => {
367 return Response::error(
368 request.id.clone(),
369 JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
370 );
371 }
372 };
373
374 match handler.elicit(params).await {
375 Ok(result) => match serde_json::to_value(result) {
376 Ok(value) => Response::success(request.id.clone(), value),
377 Err(e) => Response::error(
378 request.id.clone(),
379 JsonRpcError::internal_error(format!("Serialization error: {e}")),
380 ),
381 },
382 Err(e) => Response::error(
383 request.id.clone(),
384 JsonRpcError::internal_error(e.to_string()),
385 ),
386 }
387 }
388
389 async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
391 match handler.list_roots().await {
392 Ok(roots) => {
393 let roots_json: Vec<serde_json::Value> = roots
394 .into_iter()
395 .map(|r| {
396 serde_json::json!({
397 "uri": r.uri,
398 "name": r.name
399 })
400 })
401 .collect();
402 Response::success(
403 request.id.clone(),
404 serde_json::json!({ "roots": roots_json }),
405 )
406 }
407 Err(e) => Response::error(
408 request.id.clone(),
409 JsonRpcError::internal_error(e.to_string()),
410 ),
411 }
412 }
413
414 async fn handle_notification(notification: Notification, handler: &Arc<H>) {
416 trace!(method = %notification.method, "Received server notification");
417
418 match notification.method.as_ref() {
419 "notifications/cancelled" => {
420 if let Some(params) = ¬ification.params {
422 if let Some(request_id) = params.get("requestId") {
423 debug!(?request_id, "Server cancelled request");
424 }
425 }
426 }
427 "notifications/progress" => {
428 if let Some(params) = notification.params {
430 if let (Some(task_id), Some(progress)) = (
431 params.get("progressToken").and_then(|v| v.as_str()),
432 params.get("progress"),
433 ) {
434 if let Ok(progress) = serde_json::from_value::<
435 mcpkit_core::types::TaskProgress,
436 >(progress.clone())
437 {
438 debug!(task_id = %task_id, "Task progress update");
439 handler.on_task_progress(task_id.into(), progress).await;
440 }
441 }
442 }
443 }
444 "notifications/resources/updated" => {
445 if let Some(params) = notification.params {
446 if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
447 debug!(uri = %uri, "Resource updated");
448 handler.on_resource_updated(uri.to_string()).await;
449 }
450 }
451 }
452 "notifications/resources/list_changed" => {
453 debug!("Resources list changed");
454 handler.on_resources_list_changed().await;
455 }
456 "notifications/tools/list_changed" => {
457 debug!("Tools list changed");
458 handler.on_tools_list_changed().await;
459 }
460 "notifications/prompts/list_changed" => {
461 debug!("Prompts list changed");
462 handler.on_prompts_list_changed().await;
463 }
464 _ => {
465 trace!(method = %notification.method, "Unhandled notification");
466 }
467 }
468 }
469
470 pub const fn server_info(&self) -> &ServerInfo {
472 &self.server_info
473 }
474
475 pub const fn server_capabilities(&self) -> &ServerCapabilities {
477 &self.server_caps
478 }
479
480 pub fn protocol_version(&self) -> ProtocolVersion {
489 self.protocol_version
490 }
491
492 pub const fn client_info(&self) -> &ClientInfo {
494 &self.client_info
495 }
496
497 pub const fn client_capabilities(&self) -> &ClientCapabilities {
499 &self.client_caps
500 }
501
502 pub fn instructions(&self) -> Option<&str> {
504 self.instructions.as_deref()
505 }
506
507 pub const fn has_tools(&self) -> bool {
509 self.server_caps.has_tools()
510 }
511
512 pub const fn has_resources(&self) -> bool {
514 self.server_caps.has_resources()
515 }
516
517 pub const fn has_prompts(&self) -> bool {
519 self.server_caps.has_prompts()
520 }
521
522 pub const fn has_tasks(&self) -> bool {
524 self.server_caps.has_tasks()
525 }
526
527 pub const fn has_completions(&self) -> bool {
529 self.server_caps.has_completions()
530 }
531
532 pub fn is_connected(&self) -> bool {
534 self.running.load(Ordering::SeqCst)
535 }
536
537 pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
547 self.ensure_capability("tools", self.has_tools())?;
548
549 let result: ListToolsResult = self.request("tools/list", None).await?;
550 Ok(result.tools)
551 }
552
553 pub async fn list_tools_paginated(
559 &self,
560 cursor: Option<&str>,
561 ) -> Result<ListToolsResult, McpError> {
562 self.ensure_capability("tools", self.has_tools())?;
563
564 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
565 self.request("tools/list", params).await
566 }
567
568 pub async fn call_tool(
579 &self,
580 name: impl Into<String>,
581 arguments: serde_json::Value,
582 ) -> Result<CallToolResult, McpError> {
583 self.ensure_capability("tools", self.has_tools())?;
584
585 let request = CallToolRequest {
586 name: name.into(),
587 arguments: Some(arguments),
588 };
589 self.request("tools/call", Some(serde_json::to_value(request)?))
590 .await
591 }
592
593 pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
603 self.ensure_capability("resources", self.has_resources())?;
604
605 let result: ListResourcesResult = self.request("resources/list", None).await?;
606 Ok(result.resources)
607 }
608
609 pub async fn list_resources_paginated(
615 &self,
616 cursor: Option<&str>,
617 ) -> Result<ListResourcesResult, McpError> {
618 self.ensure_capability("resources", self.has_resources())?;
619
620 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
621 self.request("resources/list", params).await
622 }
623
624 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
630 self.ensure_capability("resources", self.has_resources())?;
631
632 let result: ListResourceTemplatesResult =
633 self.request("resources/templates/list", None).await?;
634 Ok(result.resource_templates)
635 }
636
637 pub async fn read_resource(
643 &self,
644 uri: impl Into<String>,
645 ) -> Result<Vec<ResourceContents>, McpError> {
646 self.ensure_capability("resources", self.has_resources())?;
647
648 let request = ReadResourceRequest { uri: uri.into() };
649 let result: ReadResourceResult = self
650 .request("resources/read", Some(serde_json::to_value(request)?))
651 .await?;
652 Ok(result.contents)
653 }
654
655 pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
665 self.ensure_capability("prompts", self.has_prompts())?;
666
667 let result: ListPromptsResult = self.request("prompts/list", None).await?;
668 Ok(result.prompts)
669 }
670
671 pub async fn list_prompts_paginated(
677 &self,
678 cursor: Option<&str>,
679 ) -> Result<ListPromptsResult, McpError> {
680 self.ensure_capability("prompts", self.has_prompts())?;
681
682 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
683 self.request("prompts/list", params).await
684 }
685
686 pub async fn get_prompt(
692 &self,
693 name: impl Into<String>,
694 arguments: Option<serde_json::Map<String, serde_json::Value>>,
695 ) -> Result<GetPromptResult, McpError> {
696 self.ensure_capability("prompts", self.has_prompts())?;
697
698 let request = GetPromptRequest {
699 name: name.into(),
700 arguments,
701 };
702 self.request("prompts/get", Some(serde_json::to_value(request)?))
703 .await
704 }
705
706 pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
716 self.ensure_capability("tasks", self.has_tasks())?;
717
718 let result: ListTasksResult = self.request("tasks/list", None).await?;
719 Ok(result.tasks)
720 }
721
722 pub async fn list_tasks_filtered(
728 &self,
729 status: Option<TaskStatus>,
730 cursor: Option<&str>,
731 ) -> Result<ListTasksResult, McpError> {
732 self.ensure_capability("tasks", self.has_tasks())?;
733
734 let request = ListTasksRequest {
735 status,
736 cursor: cursor.map(String::from),
737 };
738 self.request("tasks/list", Some(serde_json::to_value(request)?))
739 .await
740 }
741
742 pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
748 self.ensure_capability("tasks", self.has_tasks())?;
749
750 let request = GetTaskRequest {
751 id: id.into().into(),
752 };
753 self.request("tasks/get", Some(serde_json::to_value(request)?))
754 .await
755 }
756
757 pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
764 self.ensure_capability("tasks", self.has_tasks())?;
765
766 let request = CancelTaskRequest {
767 id: id.into().into(),
768 };
769 let _: serde_json::Value = self
770 .request("tasks/cancel", Some(serde_json::to_value(request)?))
771 .await?;
772 Ok(())
773 }
774
775 pub async fn complete_prompt_argument(
791 &self,
792 prompt_name: impl Into<String>,
793 argument_name: impl Into<String>,
794 current_value: impl Into<String>,
795 ) -> Result<CompleteResult, McpError> {
796 self.ensure_capability("completions", self.has_completions())?;
797
798 let request = CompleteRequest {
799 ref_: CompletionRef::prompt(prompt_name),
800 argument: CompletionArgument {
801 name: argument_name.into(),
802 value: current_value.into(),
803 },
804 };
805 self.request("completion/complete", Some(serde_json::to_value(request)?))
806 .await
807 }
808
809 pub async fn complete_resource_argument(
821 &self,
822 resource_uri: impl Into<String>,
823 argument_name: impl Into<String>,
824 current_value: impl Into<String>,
825 ) -> Result<CompleteResult, McpError> {
826 self.ensure_capability("completions", self.has_completions())?;
827
828 let request = CompleteRequest {
829 ref_: CompletionRef::resource(resource_uri),
830 argument: CompletionArgument {
831 name: argument_name.into(),
832 value: current_value.into(),
833 },
834 };
835 self.request("completion/complete", Some(serde_json::to_value(request)?))
836 .await
837 }
838
839 pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
852 self.ensure_capability("resources", self.has_resources())?;
853
854 if !self.server_caps.has_resource_subscribe() {
856 return Err(McpError::CapabilityNotSupported {
857 capability: "resources.subscribe".to_string(),
858 available: self.available_capabilities().into_boxed_slice(),
859 });
860 }
861
862 let params = serde_json::json!({ "uri": uri.into() });
863 let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
864 Ok(())
865 }
866
867 pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
873 self.ensure_capability("resources", self.has_resources())?;
874
875 if !self.server_caps.has_resource_subscribe() {
877 return Err(McpError::CapabilityNotSupported {
878 capability: "resources.subscribe".to_string(),
879 available: self.available_capabilities().into_boxed_slice(),
880 });
881 }
882
883 let params = serde_json::json!({ "uri": uri.into() });
884 let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
885 Ok(())
886 }
887
888 pub async fn ping(&self) -> Result<(), McpError> {
898 let _: serde_json::Value = self.request("ping", None).await?;
899 Ok(())
900 }
901
902 pub async fn close(self) -> Result<(), McpError> {
908 debug!("Closing client connection");
909
910 self.running.store(false, Ordering::SeqCst);
912
913 self.handler.on_disconnected().await;
915
916 self.transport.close().await.map_err(|e| {
918 McpError::Transport(Box::new(TransportDetails {
919 kind: TransportErrorKind::ConnectionClosed,
920 message: e.to_string(),
921 context: TransportContext::default(),
922 source: None,
923 }))
924 })
925 }
926
927 fn next_request_id(&self) -> RequestId {
933 RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
934 }
935
936 async fn request<R: serde::de::DeserializeOwned>(
938 &self,
939 method: &str,
940 params: Option<serde_json::Value>,
941 ) -> Result<R, McpError> {
942 if !self.is_connected() {
943 return Err(McpError::Transport(Box::new(TransportDetails {
944 kind: TransportErrorKind::ConnectionClosed,
945 message: "Client is not connected".to_string(),
946 context: TransportContext::default(),
947 source: None,
948 })));
949 }
950
951 let id = self.next_request_id();
952 let request = if let Some(params) = params {
953 Request::with_params(method.to_string(), id.clone(), params)
954 } else {
955 Request::new(method.to_string(), id.clone())
956 };
957
958 trace!(?id, method, "Sending request");
959
960 let (tx, rx) = oneshot::channel();
962 {
963 let mut pending = self.pending.write().await;
964 pending.insert(id.clone(), tx);
965 }
966
967 self.outgoing_tx
969 .send(Message::Request(request))
970 .await
971 .map_err(|_| {
972 McpError::Transport(Box::new(TransportDetails {
973 kind: TransportErrorKind::WriteFailed,
974 message: "Failed to send request (channel closed)".to_string(),
975 context: TransportContext::default(),
976 source: None,
977 }))
978 })?;
979
980 let response = rx.await.map_err(|_| {
982 McpError::Transport(Box::new(TransportDetails {
983 kind: TransportErrorKind::ConnectionClosed,
984 message: "Response channel closed (server may have disconnected)".to_string(),
985 context: TransportContext::default(),
986 source: None,
987 }))
988 })?;
989
990 if let Some(error) = response.error {
992 return Err(McpError::Internal {
993 message: error.message,
994 source: None,
995 });
996 }
997
998 let result = response.result.ok_or_else(|| McpError::Internal {
999 message: "Response contained neither result nor error".to_string(),
1000 source: None,
1001 })?;
1002
1003 serde_json::from_value(result).map_err(McpError::from)
1004 }
1005
1006 fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
1008 if supported {
1009 Ok(())
1010 } else {
1011 Err(McpError::CapabilityNotSupported {
1012 capability: name.to_string(),
1013 available: self.available_capabilities().into_boxed_slice(),
1014 })
1015 }
1016 }
1017
1018 fn available_capabilities(&self) -> Vec<String> {
1020 let mut caps = Vec::new();
1021 if self.has_tools() {
1022 caps.push("tools".to_string());
1023 }
1024 if self.has_resources() {
1025 caps.push("resources".to_string());
1026 }
1027 if self.has_prompts() {
1028 caps.push("prompts".to_string());
1029 }
1030 if self.has_tasks() {
1031 caps.push("tasks".to_string());
1032 }
1033 if self.has_completions() {
1034 caps.push("completions".to_string());
1035 }
1036 caps
1037 }
1038}
1039
1040pub(crate) async fn initialize<T: Transport>(
1057 transport: &T,
1058 client_info: &ClientInfo,
1059 capabilities: &ClientCapabilities,
1060) -> Result<InitializeResult, McpError> {
1061 debug!(
1062 protocol_version = %PROTOCOL_VERSION,
1063 supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1064 "Initializing MCP connection"
1065 );
1066
1067 let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1069 let init_request = Request::with_params(
1070 "initialize".to_string(),
1071 RequestId::Number(0),
1072 serde_json::to_value(&request)?,
1073 );
1074
1075 transport
1077 .send(Message::Request(init_request))
1078 .await
1079 .map_err(|e| {
1080 McpError::Transport(Box::new(TransportDetails {
1081 kind: TransportErrorKind::WriteFailed,
1082 message: format!("Failed to send initialize: {e}"),
1083 context: TransportContext::default(),
1084 source: None,
1085 }))
1086 })?;
1087
1088 let response = loop {
1090 match transport.recv().await {
1091 Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1092 Ok(Some(_)) => {} Ok(None) => {
1094 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1095 message: "Connection closed during initialization".to_string(),
1096 client_version: Some(PROTOCOL_VERSION.to_string()),
1097 server_version: None,
1098 source: None,
1099 })));
1100 }
1101 Err(e) => {
1102 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1103 message: format!("Transport error during initialization: {e}"),
1104 client_version: Some(PROTOCOL_VERSION.to_string()),
1105 server_version: None,
1106 source: None,
1107 })));
1108 }
1109 }
1110 };
1111
1112 if let Some(error) = response.error {
1114 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1115 message: error.message,
1116 client_version: Some(PROTOCOL_VERSION.to_string()),
1117 server_version: None,
1118 source: None,
1119 })));
1120 }
1121
1122 let result: InitializeResult = response
1123 .result
1124 .map(serde_json::from_value)
1125 .transpose()?
1126 .ok_or_else(|| {
1127 McpError::HandshakeFailed(Box::new(HandshakeDetails {
1128 message: "Empty initialize result".to_string(),
1129 client_version: Some(PROTOCOL_VERSION.to_string()),
1130 server_version: None,
1131 source: None,
1132 }))
1133 })?;
1134
1135 let server_version = &result.protocol_version;
1137 if !is_version_supported(server_version) {
1138 warn!(
1139 server_version = %server_version,
1140 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1141 "Server returned unsupported protocol version"
1142 );
1143 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1144 message: format!(
1145 "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1146 ),
1147 client_version: Some(PROTOCOL_VERSION.to_string()),
1148 server_version: Some(server_version.clone()),
1149 source: None,
1150 })));
1151 }
1152
1153 debug!(
1154 server = %result.server_info.name,
1155 server_version = %result.server_info.version,
1156 protocol_version = %result.protocol_version,
1157 "Received initialize result with compatible protocol version"
1158 );
1159
1160 let notification = Notification::new("notifications/initialized");
1162 transport
1163 .send(Message::Notification(notification))
1164 .await
1165 .map_err(|e| {
1166 McpError::Transport(Box::new(TransportDetails {
1167 kind: TransportErrorKind::WriteFailed,
1168 message: format!("Failed to send initialized: {e}"),
1169 context: TransportContext::default(),
1170 source: None,
1171 }))
1172 })?;
1173
1174 debug!("MCP initialization complete");
1175 Ok(result)
1176}
1177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_id_generation() {
        // Ids come from `fetch_add`, which yields the value *before* incrementing,
        // so a counter seeded with 1 issues 1 and then 2.
        let counter = AtomicU64::new(1);
        let issued: Vec<u64> = (0..2)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();
        assert_eq!(issued, vec![1, 2]);
    }
}