1use futures::channel::oneshot;
14use mcpkit_core::capability::{
15 ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, PROTOCOL_VERSION,
16 SUPPORTED_PROTOCOL_VERSIONS, ServerCapabilities, ServerInfo, is_version_supported,
17};
18use mcpkit_core::error::{
19 HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20 TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::protocol_version::ProtocolVersion;
24use mcpkit_core::types::{
25 CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
26 CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
27 GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
28 ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
29 ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
30 TaskStatus, TaskSummary, Tool,
31};
32use mcpkit_transport::Transport;
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
36use tracing::{debug, error, info, trace, warn};
37
38use async_lock::RwLock;
40
41#[cfg(feature = "tokio-runtime")]
43use tokio::sync::mpsc;
44
45use crate::handler::ClientHandler;
46
47pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
78 transport: Arc<T>,
80 server_info: ServerInfo,
82 server_caps: ServerCapabilities,
84 protocol_version: ProtocolVersion,
89 client_info: ClientInfo,
91 client_caps: ClientCapabilities,
93 next_id: AtomicU64,
95 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
97 instructions: Option<String>,
99 handler: Arc<H>,
101 outgoing_tx: mpsc::Sender<Message>,
103 running: Arc<AtomicBool>,
105 _background_handle: Option<tokio::task::JoinHandle<()>>,
107}
108
109impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
110 pub(crate) fn new(
112 transport: T,
113 init_result: InitializeResult,
114 client_info: ClientInfo,
115 client_caps: ClientCapabilities,
116 ) -> Self {
117 Self::with_handler(
118 transport,
119 init_result,
120 client_info,
121 client_caps,
122 crate::handler::NoOpHandler,
123 )
124 }
125}
126
127impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
128 pub(crate) fn with_handler(
130 transport: T,
131 init_result: InitializeResult,
132 client_info: ClientInfo,
133 client_caps: ClientCapabilities,
134 handler: H,
135 ) -> Self {
136 let transport = Arc::new(transport);
137 let pending = Arc::new(RwLock::new(HashMap::new()));
138 let handler = Arc::new(handler);
139 let running = Arc::new(AtomicBool::new(true));
140
141 let protocol_version =
143 if let Ok(v) = init_result.protocol_version.parse::<ProtocolVersion>() {
144 v
145 } else {
146 warn!(
147 server_version = %init_result.protocol_version,
148 fallback_version = %ProtocolVersion::LATEST,
149 "Server returned unknown protocol version, falling back to latest supported"
150 );
151 ProtocolVersion::LATEST
152 };
153
154 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
156
157 let background_handle = Self::spawn_message_router(
159 Arc::clone(&transport),
160 Arc::clone(&pending),
161 Arc::clone(&handler),
162 Arc::clone(&running),
163 outgoing_rx,
164 );
165
166 let handler_clone = Arc::clone(&handler);
168 tokio::spawn(async move {
169 handler_clone.on_connected().await;
170 });
171
172 Self {
173 transport,
174 server_info: init_result.server_info,
175 server_caps: init_result.capabilities,
176 protocol_version,
177 client_info,
178 client_caps,
179 next_id: AtomicU64::new(1),
180 pending,
181 instructions: init_result.instructions,
182 handler,
183 outgoing_tx,
184 running,
185 _background_handle: Some(background_handle),
186 }
187 }
188
189 fn spawn_message_router(
197 transport: Arc<T>,
198 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
199 handler: Arc<H>,
200 running: Arc<AtomicBool>,
201 mut outgoing_rx: mpsc::Receiver<Message>,
202 ) -> tokio::task::JoinHandle<()> {
203 tokio::spawn(async move {
204 debug!("Starting client message router");
205
206 loop {
207 if !running.load(Ordering::SeqCst) {
208 debug!("Message router stopping (client closed)");
209 break;
210 }
211
212 tokio::select! {
213 Some(msg) = outgoing_rx.recv() => {
215 if let Err(e) = transport.send(msg).await {
216 error!(?e, "Failed to send message");
217 }
218 }
219
220 result = transport.recv() => {
222 match result {
223 Ok(Some(message)) => {
224 let msg_id = match &message {
226 Message::Request(r) => format!("Request({})", r.id),
227 Message::Response(r) => format!("Response({})", r.id),
228 Message::Notification(n) => format!("Notification({})", n.method),
229 };
230 debug!(msg = %msg_id, "Router received message from transport");
231 Self::handle_incoming_message(
232 message,
233 &pending,
234 &handler,
235 &transport,
236 ).await;
237 }
238 Ok(None) => {
239 info!("Connection closed by server");
240 running.store(false, Ordering::SeqCst);
241 handler.on_disconnected().await;
242 break;
243 }
244 Err(e) => {
245 error!(?e, "Transport error in message router");
246 running.store(false, Ordering::SeqCst);
247 handler.on_disconnected().await;
248 break;
249 }
250 }
251 }
252 }
253 }
254
255 debug!("Message router stopped");
256 })
257 }
258
259 async fn handle_incoming_message(
261 message: Message,
262 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
263 handler: &Arc<H>,
264 transport: &Arc<T>,
265 ) {
266 match message {
267 Message::Response(response) => {
268 Self::route_response(response, pending).await;
269 }
270 Message::Request(request) => {
271 Self::handle_server_request(request, handler, transport).await;
272 }
273 Message::Notification(notification) => {
274 Self::handle_notification(notification, handler).await;
275 }
276 }
277 }
278
279 async fn route_response(
281 response: Response,
282 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
283 ) {
284 let pending_count = pending.read().await.len();
285 let sender = {
286 let mut pending_guard = pending.write().await;
287 pending_guard.remove(&response.id)
288 };
289
290 if let Some(sender) = sender {
291 debug!(?response.id, pending_count, "Routing response to pending request (found in pending)");
292 if sender.send(response).is_err() {
293 warn!("Pending request receiver dropped");
294 }
295 } else {
296 debug!(?response.id, pending_count, "Response not found in pending (possible race or duplicate)");
302 }
303 }
304
305 async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
307 trace!(method = %request.method, "Handling server request");
308
309 let response = match request.method.as_ref() {
310 "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
311 "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
312 "roots/list" => Self::handle_roots_request(&request, handler).await,
313 "ping" => {
314 Response::success(request.id.clone(), serde_json::json!({}))
316 }
317 _ => {
318 warn!(method = %request.method, "Unknown server request method");
319 Response::error(
320 request.id.clone(),
321 JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
322 )
323 }
324 };
325
326 if let Err(e) = transport.send(Message::Response(response)).await {
328 error!(?e, "Failed to send response to server request");
329 }
330 }
331
332 async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
334 let params = match &request.params {
335 Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
336 Ok(req) => req,
337 Err(e) => {
338 return Response::error(
339 request.id.clone(),
340 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
341 );
342 }
343 },
344 None => {
345 return Response::error(
346 request.id.clone(),
347 JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
348 );
349 }
350 };
351
352 match handler.create_message(params).await {
353 Ok(result) => match serde_json::to_value(result) {
354 Ok(value) => Response::success(request.id.clone(), value),
355 Err(e) => Response::error(
356 request.id.clone(),
357 JsonRpcError::internal_error(format!("Serialization error: {e}")),
358 ),
359 },
360 Err(e) => Response::error(
361 request.id.clone(),
362 JsonRpcError::internal_error(e.to_string()),
363 ),
364 }
365 }
366
367 async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
369 let params = match &request.params {
370 Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
371 Ok(req) => req,
372 Err(e) => {
373 return Response::error(
374 request.id.clone(),
375 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
376 );
377 }
378 },
379 None => {
380 return Response::error(
381 request.id.clone(),
382 JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
383 );
384 }
385 };
386
387 match handler.elicit(params).await {
388 Ok(result) => match serde_json::to_value(result) {
389 Ok(value) => Response::success(request.id.clone(), value),
390 Err(e) => Response::error(
391 request.id.clone(),
392 JsonRpcError::internal_error(format!("Serialization error: {e}")),
393 ),
394 },
395 Err(e) => Response::error(
396 request.id.clone(),
397 JsonRpcError::internal_error(e.to_string()),
398 ),
399 }
400 }
401
402 async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
404 match handler.list_roots().await {
405 Ok(roots) => {
406 let roots_json: Vec<serde_json::Value> = roots
407 .into_iter()
408 .map(|r| {
409 serde_json::json!({
410 "uri": r.uri,
411 "name": r.name
412 })
413 })
414 .collect();
415 Response::success(
416 request.id.clone(),
417 serde_json::json!({ "roots": roots_json }),
418 )
419 }
420 Err(e) => Response::error(
421 request.id.clone(),
422 JsonRpcError::internal_error(e.to_string()),
423 ),
424 }
425 }
426
427 async fn handle_notification(notification: Notification, handler: &Arc<H>) {
429 trace!(method = %notification.method, "Received server notification");
430
431 match notification.method.as_ref() {
432 "notifications/cancelled" => {
433 if let Some(params) = ¬ification.params {
435 if let Some(request_id) = params.get("requestId") {
436 debug!(?request_id, "Server cancelled request");
437 }
438 }
439 }
440 "notifications/progress" => {
441 if let Some(params) = notification.params {
443 if let (Some(task_id), Some(progress)) = (
444 params.get("progressToken").and_then(|v| v.as_str()),
445 params.get("progress"),
446 ) {
447 if let Ok(progress) = serde_json::from_value::<
448 mcpkit_core::types::TaskProgress,
449 >(progress.clone())
450 {
451 debug!(task_id = %task_id, "Task progress update");
452 handler.on_task_progress(task_id.into(), progress).await;
453 }
454 }
455 }
456 }
457 "notifications/resources/updated" => {
458 if let Some(params) = notification.params {
459 if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
460 debug!(uri = %uri, "Resource updated");
461 handler.on_resource_updated(uri.to_string()).await;
462 }
463 }
464 }
465 "notifications/resources/list_changed" => {
466 debug!("Resources list changed");
467 handler.on_resources_list_changed().await;
468 }
469 "notifications/tools/list_changed" => {
470 debug!("Tools list changed");
471 handler.on_tools_list_changed().await;
472 }
473 "notifications/prompts/list_changed" => {
474 debug!("Prompts list changed");
475 handler.on_prompts_list_changed().await;
476 }
477 _ => {
478 trace!(method = %notification.method, "Unhandled notification");
479 }
480 }
481 }
482
483 pub const fn server_info(&self) -> &ServerInfo {
485 &self.server_info
486 }
487
488 pub const fn server_capabilities(&self) -> &ServerCapabilities {
490 &self.server_caps
491 }
492
493 pub fn protocol_version(&self) -> ProtocolVersion {
502 self.protocol_version
503 }
504
505 pub const fn client_info(&self) -> &ClientInfo {
507 &self.client_info
508 }
509
510 pub const fn client_capabilities(&self) -> &ClientCapabilities {
512 &self.client_caps
513 }
514
515 pub fn instructions(&self) -> Option<&str> {
517 self.instructions.as_deref()
518 }
519
520 pub const fn has_tools(&self) -> bool {
522 self.server_caps.has_tools()
523 }
524
525 pub const fn has_resources(&self) -> bool {
527 self.server_caps.has_resources()
528 }
529
530 pub const fn has_prompts(&self) -> bool {
532 self.server_caps.has_prompts()
533 }
534
535 pub const fn has_tasks(&self) -> bool {
537 self.server_caps.has_tasks()
538 }
539
540 pub const fn has_completions(&self) -> bool {
542 self.server_caps.has_completions()
543 }
544
545 pub fn is_connected(&self) -> bool {
547 self.running.load(Ordering::SeqCst)
548 }
549
550 pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
560 self.ensure_capability("tools", self.has_tools())?;
561
562 let result: ListToolsResult = self.request("tools/list", None).await?;
563 Ok(result.tools)
564 }
565
566 pub async fn list_tools_paginated(
572 &self,
573 cursor: Option<&str>,
574 ) -> Result<ListToolsResult, McpError> {
575 self.ensure_capability("tools", self.has_tools())?;
576
577 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
578 self.request("tools/list", params).await
579 }
580
581 pub async fn call_tool(
592 &self,
593 name: impl Into<String>,
594 arguments: serde_json::Value,
595 ) -> Result<CallToolResult, McpError> {
596 self.ensure_capability("tools", self.has_tools())?;
597
598 let request = CallToolRequest {
599 name: name.into(),
600 arguments: Some(arguments),
601 };
602 self.request("tools/call", Some(serde_json::to_value(request)?))
603 .await
604 }
605
606 pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
616 self.ensure_capability("resources", self.has_resources())?;
617
618 let result: ListResourcesResult = self.request("resources/list", None).await?;
619 Ok(result.resources)
620 }
621
622 pub async fn list_resources_paginated(
628 &self,
629 cursor: Option<&str>,
630 ) -> Result<ListResourcesResult, McpError> {
631 self.ensure_capability("resources", self.has_resources())?;
632
633 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
634 self.request("resources/list", params).await
635 }
636
637 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
643 self.ensure_capability("resources", self.has_resources())?;
644
645 let result: ListResourceTemplatesResult =
646 self.request("resources/templates/list", None).await?;
647 Ok(result.resource_templates)
648 }
649
650 pub async fn read_resource(
656 &self,
657 uri: impl Into<String>,
658 ) -> Result<Vec<ResourceContents>, McpError> {
659 self.ensure_capability("resources", self.has_resources())?;
660
661 let request = ReadResourceRequest { uri: uri.into() };
662 let result: ReadResourceResult = self
663 .request("resources/read", Some(serde_json::to_value(request)?))
664 .await?;
665 Ok(result.contents)
666 }
667
668 pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
678 self.ensure_capability("prompts", self.has_prompts())?;
679
680 let result: ListPromptsResult = self.request("prompts/list", None).await?;
681 Ok(result.prompts)
682 }
683
684 pub async fn list_prompts_paginated(
690 &self,
691 cursor: Option<&str>,
692 ) -> Result<ListPromptsResult, McpError> {
693 self.ensure_capability("prompts", self.has_prompts())?;
694
695 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
696 self.request("prompts/list", params).await
697 }
698
699 pub async fn get_prompt(
705 &self,
706 name: impl Into<String>,
707 arguments: Option<serde_json::Map<String, serde_json::Value>>,
708 ) -> Result<GetPromptResult, McpError> {
709 self.ensure_capability("prompts", self.has_prompts())?;
710
711 let request = GetPromptRequest {
712 name: name.into(),
713 arguments,
714 };
715 self.request("prompts/get", Some(serde_json::to_value(request)?))
716 .await
717 }
718
719 pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
729 self.ensure_capability("tasks", self.has_tasks())?;
730
731 let result: ListTasksResult = self.request("tasks/list", None).await?;
732 Ok(result.tasks)
733 }
734
735 pub async fn list_tasks_filtered(
741 &self,
742 status: Option<TaskStatus>,
743 cursor: Option<&str>,
744 ) -> Result<ListTasksResult, McpError> {
745 self.ensure_capability("tasks", self.has_tasks())?;
746
747 let request = ListTasksRequest {
748 status,
749 cursor: cursor.map(String::from),
750 };
751 self.request("tasks/list", Some(serde_json::to_value(request)?))
752 .await
753 }
754
755 pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
761 self.ensure_capability("tasks", self.has_tasks())?;
762
763 let request = GetTaskRequest {
764 id: id.into().into(),
765 };
766 self.request("tasks/get", Some(serde_json::to_value(request)?))
767 .await
768 }
769
770 pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
777 self.ensure_capability("tasks", self.has_tasks())?;
778
779 let request = CancelTaskRequest {
780 id: id.into().into(),
781 };
782 let _: serde_json::Value = self
783 .request("tasks/cancel", Some(serde_json::to_value(request)?))
784 .await?;
785 Ok(())
786 }
787
788 pub async fn complete_prompt_argument(
804 &self,
805 prompt_name: impl Into<String>,
806 argument_name: impl Into<String>,
807 current_value: impl Into<String>,
808 ) -> Result<CompleteResult, McpError> {
809 self.ensure_capability("completions", self.has_completions())?;
810
811 let request = CompleteRequest {
812 ref_: CompletionRef::prompt(prompt_name),
813 argument: CompletionArgument {
814 name: argument_name.into(),
815 value: current_value.into(),
816 },
817 };
818 self.request("completion/complete", Some(serde_json::to_value(request)?))
819 .await
820 }
821
822 pub async fn complete_resource_argument(
834 &self,
835 resource_uri: impl Into<String>,
836 argument_name: impl Into<String>,
837 current_value: impl Into<String>,
838 ) -> Result<CompleteResult, McpError> {
839 self.ensure_capability("completions", self.has_completions())?;
840
841 let request = CompleteRequest {
842 ref_: CompletionRef::resource(resource_uri),
843 argument: CompletionArgument {
844 name: argument_name.into(),
845 value: current_value.into(),
846 },
847 };
848 self.request("completion/complete", Some(serde_json::to_value(request)?))
849 .await
850 }
851
852 pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
865 self.ensure_capability("resources", self.has_resources())?;
866
867 if !self.server_caps.has_resource_subscribe() {
869 return Err(McpError::CapabilityNotSupported {
870 capability: "resources.subscribe".to_string(),
871 available: self.available_capabilities().into_boxed_slice(),
872 });
873 }
874
875 let params = serde_json::json!({ "uri": uri.into() });
876 let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
877 Ok(())
878 }
879
880 pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
886 self.ensure_capability("resources", self.has_resources())?;
887
888 if !self.server_caps.has_resource_subscribe() {
890 return Err(McpError::CapabilityNotSupported {
891 capability: "resources.subscribe".to_string(),
892 available: self.available_capabilities().into_boxed_slice(),
893 });
894 }
895
896 let params = serde_json::json!({ "uri": uri.into() });
897 let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
898 Ok(())
899 }
900
901 pub async fn ping(&self) -> Result<(), McpError> {
911 let _: serde_json::Value = self.request("ping", None).await?;
912 Ok(())
913 }
914
915 pub async fn close(self) -> Result<(), McpError> {
921 debug!("Closing client connection");
922
923 self.running.store(false, Ordering::SeqCst);
925
926 self.handler.on_disconnected().await;
928
929 self.transport.close().await.map_err(|e| {
931 McpError::Transport(Box::new(TransportDetails {
932 kind: TransportErrorKind::ConnectionClosed,
933 message: e.to_string(),
934 context: TransportContext::default(),
935 source: None,
936 }))
937 })
938 }
939
940 fn next_request_id(&self) -> RequestId {
946 RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
947 }
948
949 async fn request<R: serde::de::DeserializeOwned>(
951 &self,
952 method: &str,
953 params: Option<serde_json::Value>,
954 ) -> Result<R, McpError> {
955 if !self.is_connected() {
956 return Err(McpError::Transport(Box::new(TransportDetails {
957 kind: TransportErrorKind::ConnectionClosed,
958 message: "Client is not connected".to_string(),
959 context: TransportContext::default(),
960 source: None,
961 })));
962 }
963
964 let id = self.next_request_id();
965 let request = if let Some(params) = params {
966 Request::with_params(method.to_string(), id.clone(), params)
967 } else {
968 Request::new(method.to_string(), id.clone())
969 };
970
971 trace!(?id, method, "Sending request");
972
973 let (tx, rx) = oneshot::channel();
975 {
976 let mut pending = self.pending.write().await;
977 pending.insert(id.clone(), tx);
978 }
979
980 self.outgoing_tx
982 .send(Message::Request(request))
983 .await
984 .map_err(|_| {
985 McpError::Transport(Box::new(TransportDetails {
986 kind: TransportErrorKind::WriteFailed,
987 message: "Failed to send request (channel closed)".to_string(),
988 context: TransportContext::default(),
989 source: None,
990 }))
991 })?;
992
993 let response = rx.await.map_err(|_| {
995 McpError::Transport(Box::new(TransportDetails {
996 kind: TransportErrorKind::ConnectionClosed,
997 message: "Response channel closed (server may have disconnected)".to_string(),
998 context: TransportContext::default(),
999 source: None,
1000 }))
1001 })?;
1002
1003 if let Some(error) = response.error {
1005 return Err(McpError::Internal {
1006 message: error.message,
1007 source: None,
1008 });
1009 }
1010
1011 let result = response.result.ok_or_else(|| McpError::Internal {
1012 message: "Response contained neither result nor error".to_string(),
1013 source: None,
1014 })?;
1015
1016 serde_json::from_value(result).map_err(McpError::from)
1017 }
1018
1019 fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
1021 if supported {
1022 Ok(())
1023 } else {
1024 Err(McpError::CapabilityNotSupported {
1025 capability: name.to_string(),
1026 available: self.available_capabilities().into_boxed_slice(),
1027 })
1028 }
1029 }
1030
1031 fn available_capabilities(&self) -> Vec<String> {
1033 let mut caps = Vec::new();
1034 if self.has_tools() {
1035 caps.push("tools".to_string());
1036 }
1037 if self.has_resources() {
1038 caps.push("resources".to_string());
1039 }
1040 if self.has_prompts() {
1041 caps.push("prompts".to_string());
1042 }
1043 if self.has_tasks() {
1044 caps.push("tasks".to_string());
1045 }
1046 if self.has_completions() {
1047 caps.push("completions".to_string());
1048 }
1049 caps
1050 }
1051}
1052
1053pub(crate) async fn initialize<T: Transport>(
1070 transport: &T,
1071 client_info: &ClientInfo,
1072 capabilities: &ClientCapabilities,
1073) -> Result<InitializeResult, McpError> {
1074 debug!(
1075 protocol_version = %PROTOCOL_VERSION,
1076 supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1077 "Initializing MCP connection"
1078 );
1079
1080 let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1082 let init_request = Request::with_params(
1083 "initialize".to_string(),
1084 RequestId::Number(0),
1085 serde_json::to_value(&request)?,
1086 );
1087
1088 transport
1090 .send(Message::Request(init_request))
1091 .await
1092 .map_err(|e| {
1093 McpError::Transport(Box::new(TransportDetails {
1094 kind: TransportErrorKind::WriteFailed,
1095 message: format!("Failed to send initialize: {e}"),
1096 context: TransportContext::default(),
1097 source: None,
1098 }))
1099 })?;
1100
1101 let response = loop {
1103 match transport.recv().await {
1104 Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1105 Ok(Some(_)) => {} Ok(None) => {
1107 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1108 message: "Connection closed during initialization".to_string(),
1109 client_version: Some(PROTOCOL_VERSION.to_string()),
1110 server_version: None,
1111 source: None,
1112 })));
1113 }
1114 Err(e) => {
1115 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1116 message: format!("Transport error during initialization: {e}"),
1117 client_version: Some(PROTOCOL_VERSION.to_string()),
1118 server_version: None,
1119 source: None,
1120 })));
1121 }
1122 }
1123 };
1124
1125 if let Some(error) = response.error {
1127 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1128 message: error.message,
1129 client_version: Some(PROTOCOL_VERSION.to_string()),
1130 server_version: None,
1131 source: None,
1132 })));
1133 }
1134
1135 let result: InitializeResult = response
1136 .result
1137 .map(serde_json::from_value)
1138 .transpose()?
1139 .ok_or_else(|| {
1140 McpError::HandshakeFailed(Box::new(HandshakeDetails {
1141 message: "Empty initialize result".to_string(),
1142 client_version: Some(PROTOCOL_VERSION.to_string()),
1143 server_version: None,
1144 source: None,
1145 }))
1146 })?;
1147
1148 let server_version = &result.protocol_version;
1150 if !is_version_supported(server_version) {
1151 warn!(
1152 server_version = %server_version,
1153 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1154 "Server returned unsupported protocol version"
1155 );
1156 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1157 message: format!(
1158 "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1159 ),
1160 client_version: Some(PROTOCOL_VERSION.to_string()),
1161 server_version: Some(server_version.clone()),
1162 source: None,
1163 })));
1164 }
1165
1166 debug!(
1167 server = %result.server_info.name,
1168 server_version = %result.server_info.version,
1169 protocol_version = %result.protocol_version,
1170 "Received initialize result with compatible protocol version"
1171 );
1172
1173 let notification = Notification::new("notifications/initialized");
1175 transport
1176 .send(Message::Notification(notification))
1177 .await
1178 .map_err(|e| {
1179 McpError::Transport(Box::new(TransportDetails {
1180 kind: TransportErrorKind::WriteFailed,
1181 message: format!("Failed to send initialized: {e}"),
1182 context: TransportContext::default(),
1183 source: None,
1184 }))
1185 })?;
1186
1187 debug!("MCP initialization complete");
1188 Ok(result)
1189}
1190
#[cfg(test)]
mod tests {
    use super::*;

    /// A counter seeded at 1 hands out 1, then 2: `fetch_add` returns the
    /// value held before the increment.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        for expected in 1..=2u64 {
            let issued = counter.fetch_add(1, Ordering::SeqCst);
            assert_eq!(issued, expected);
        }
    }
}