mcpkit_client/
client.rs

//! MCP client implementation.
//!
//! The [`Client`] struct provides a high-level API for interacting with
//! MCP servers. It handles:
//!
//! - Protocol initialization
//! - Request/response correlation
//! - Tool, resource, and prompt operations
//! - Task tracking
//! - Connection lifecycle
//! - Server-initiated request handling via [`ClientHandler`]
12
13use futures::channel::oneshot;
14use mcpkit_core::capability::{
15    is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16    ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails, TransportErrorKind};
19use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
20use mcpkit_core::types::{
21    CallToolRequest, CallToolResult, CreateMessageRequest, ElicitRequest,
22    GetPromptRequest, GetPromptResult, ListPromptsResult, ListResourcesResult,
23    ListResourceTemplatesResult, ListToolsResult, Prompt, ReadResourceRequest,
24    ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Tool,
25};
26use mcpkit_transport::Transport;
27use std::collections::HashMap;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::sync::Arc;
30use tracing::{debug, error, info, trace, warn};
31
32// Runtime-agnostic sync primitives
33use async_lock::RwLock;
34
35// Use tokio channels when tokio-runtime is enabled, otherwise use the transport abstraction
36#[cfg(feature = "tokio-runtime")]
37use tokio::sync::mpsc;
38
39use crate::handler::ClientHandler;
40
41/// An MCP client connected to a server.
42///
43/// The client provides methods for interacting with MCP servers:
44///
45/// - Tools: `list_tools()`, `call_tool()`
46/// - Resources: `list_resources()`, `read_resource()`
47/// - Prompts: `list_prompts()`, `get_prompt()`
48/// - Tasks: `list_tasks()`, `get_task()`, `cancel_task()`
49///
50/// The client also handles server-initiated requests (sampling, elicitation)
51/// by delegating to a [`ClientHandler`] implementation.
52///
53/// # Example
54///
55/// ```no_run
56/// use mcpkit_client::ClientBuilder;
57/// use mcpkit_transport::SpawnedTransport;
58///
59/// # async fn example() -> Result<(), mcpkit_core::error::McpError> {
60/// let transport = SpawnedTransport::spawn("my-server", &[] as &[&str]).await?;
61/// let client = ClientBuilder::new()
62///     .name("my-client")
63///     .version("1.0.0")
64///     .build(transport)
65///     .await?;
66///
67/// let tools = client.list_tools().await?;
68/// # Ok(())
69/// # }
70/// ```
71pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
72    /// The underlying transport (shared with background task).
73    transport: Arc<T>,
74    /// Server information received during initialization.
75    server_info: ServerInfo,
76    /// Server capabilities.
77    server_caps: ServerCapabilities,
78    /// Client information.
79    client_info: ClientInfo,
80    /// Client capabilities.
81    client_caps: ClientCapabilities,
82    /// Next request ID.
83    next_id: AtomicU64,
84    /// Pending requests awaiting responses.
85    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
86    /// Instructions from the server.
87    instructions: Option<String>,
88    /// Handler for server-initiated requests.
89    handler: Arc<H>,
90    /// Sender for outgoing messages to the background task.
91    outgoing_tx: mpsc::Sender<Message>,
92    /// Flag indicating if the client is running.
93    running: Arc<AtomicBool>,
94    /// Handle to the background task.
95    _background_handle: Option<tokio::task::JoinHandle<()>>,
96}
97
98impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
99    /// Create a new client without a handler (called by builder).
100    pub(crate) fn new(
101        transport: T,
102        init_result: InitializeResult,
103        client_info: ClientInfo,
104        client_caps: ClientCapabilities,
105    ) -> Self {
106        Self::with_handler(transport, init_result, client_info, client_caps, crate::handler::NoOpHandler)
107    }
108}
109
110impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
111    /// Create a new client with a custom handler (called by builder).
112    pub(crate) fn with_handler(
113        transport: T,
114        init_result: InitializeResult,
115        client_info: ClientInfo,
116        client_caps: ClientCapabilities,
117        handler: H,
118    ) -> Self {
119        let transport = Arc::new(transport);
120        let pending = Arc::new(RwLock::new(HashMap::new()));
121        let handler = Arc::new(handler);
122        let running = Arc::new(AtomicBool::new(true));
123
124        // Create channel for outgoing messages
125        let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
126
127        // Start background message routing task
128        let background_handle = Self::spawn_message_router(
129            Arc::clone(&transport),
130            Arc::clone(&pending),
131            Arc::clone(&handler),
132            Arc::clone(&running),
133            outgoing_rx,
134        );
135
136        // Notify handler that connection is established
137        let handler_clone = Arc::clone(&handler);
138        tokio::spawn(async move {
139            handler_clone.on_connected().await;
140        });
141
142        Self {
143            transport,
144            server_info: init_result.server_info,
145            server_caps: init_result.capabilities,
146            client_info,
147            client_caps,
148            next_id: AtomicU64::new(1),
149            pending,
150            instructions: init_result.instructions,
151            handler,
152            outgoing_tx,
153            running,
154            _background_handle: Some(background_handle),
155        }
156    }
157
158    /// Spawn the background message routing task.
159    ///
160    /// This task:
161    /// - Reads incoming messages from the transport
162    /// - Routes responses to pending request channels
163    /// - Delegates server-initiated requests to the handler
164    /// - Handles notifications
165    fn spawn_message_router(
166        transport: Arc<T>,
167        pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
168        handler: Arc<H>,
169        running: Arc<AtomicBool>,
170        mut outgoing_rx: mpsc::Receiver<Message>,
171    ) -> tokio::task::JoinHandle<()> {
172        tokio::spawn(async move {
173            debug!("Starting client message router");
174
175            loop {
176                if !running.load(Ordering::SeqCst) {
177                    debug!("Message router stopping (client closed)");
178                    break;
179                }
180
181                tokio::select! {
182                    // Handle outgoing messages
183                    Some(msg) = outgoing_rx.recv() => {
184                        if let Err(e) = transport.send(msg).await {
185                            error!(?e, "Failed to send message");
186                        }
187                    }
188
189                    // Handle incoming messages
190                    result = transport.recv() => {
191                        match result {
192                            Ok(Some(message)) => {
193                                Self::handle_incoming_message(
194                                    message,
195                                    &pending,
196                                    &handler,
197                                    &transport,
198                                ).await;
199                            }
200                            Ok(None) => {
201                                info!("Connection closed by server");
202                                running.store(false, Ordering::SeqCst);
203                                handler.on_disconnected().await;
204                                break;
205                            }
206                            Err(e) => {
207                                error!(?e, "Transport error in message router");
208                                running.store(false, Ordering::SeqCst);
209                                handler.on_disconnected().await;
210                                break;
211                            }
212                        }
213                    }
214                }
215            }
216
217            debug!("Message router stopped");
218        })
219    }
220
221    /// Handle an incoming message from the server.
222    async fn handle_incoming_message(
223        message: Message,
224        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
225        handler: &Arc<H>,
226        transport: &Arc<T>,
227    ) {
228        match message {
229            Message::Response(response) => {
230                Self::route_response(response, pending).await;
231            }
232            Message::Request(request) => {
233                Self::handle_server_request(request, handler, transport).await;
234            }
235            Message::Notification(notification) => {
236                Self::handle_notification(notification, handler).await;
237            }
238        }
239    }
240
241    /// Route a response to the appropriate pending request.
242    async fn route_response(
243        response: Response,
244        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
245    ) {
246        let sender = {
247            let mut pending_guard = pending.write().await;
248            pending_guard.remove(&response.id)
249        };
250
251        if let Some(sender) = sender {
252            trace!(?response.id, "Routing response to pending request");
253            if sender.send(response).is_err() {
254                warn!("Pending request receiver dropped");
255            }
256        } else {
257            warn!(?response.id, "Received response for unknown request");
258        }
259    }
260
261    /// Handle a server-initiated request.
262    async fn handle_server_request(
263        request: Request,
264        handler: &Arc<H>,
265        transport: &Arc<T>,
266    ) {
267        trace!(method = %request.method, "Handling server request");
268
269        let response = match request.method.as_ref() {
270            "sampling/createMessage" => {
271                Self::handle_sampling_request(&request, handler).await
272            }
273            "elicitation/elicit" => {
274                Self::handle_elicitation_request(&request, handler).await
275            }
276            "roots/list" => {
277                Self::handle_roots_request(&request, handler).await
278            }
279            "ping" => {
280                // Respond to ping with empty result
281                Response::success(request.id.clone(), serde_json::json!({}))
282            }
283            _ => {
284                warn!(method = %request.method, "Unknown server request method");
285                Response::error(
286                    request.id.clone(),
287                    JsonRpcError::method_not_found(&format!("Unknown method: {}", request.method)),
288                )
289            }
290        };
291
292        // Send the response
293        if let Err(e) = transport.send(Message::Response(response)).await {
294            error!(?e, "Failed to send response to server request");
295        }
296    }
297
298    /// Handle a sampling/createMessage request.
299    async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
300        let params = match &request.params {
301            Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
302                Ok(req) => req,
303                Err(e) => {
304                    return Response::error(
305                        request.id.clone(),
306                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
307                    );
308                }
309            },
310            None => {
311                return Response::error(
312                    request.id.clone(),
313                    JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
314                );
315            }
316        };
317
318        match handler.create_message(params).await {
319            Ok(result) => {
320                match serde_json::to_value(result) {
321                    Ok(value) => Response::success(request.id.clone(), value),
322                    Err(e) => Response::error(
323                        request.id.clone(),
324                        JsonRpcError::internal_error(format!("Serialization error: {e}")),
325                    ),
326                }
327            }
328            Err(e) => Response::error(
329                request.id.clone(),
330                JsonRpcError::internal_error(e.to_string()),
331            ),
332        }
333    }
334
335    /// Handle an elicitation/elicit request.
336    async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
337        let params = match &request.params {
338            Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
339                Ok(req) => req,
340                Err(e) => {
341                    return Response::error(
342                        request.id.clone(),
343                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
344                    );
345                }
346            },
347            None => {
348                return Response::error(
349                    request.id.clone(),
350                    JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
351                );
352            }
353        };
354
355        match handler.elicit(params).await {
356            Ok(result) => {
357                match serde_json::to_value(result) {
358                    Ok(value) => Response::success(request.id.clone(), value),
359                    Err(e) => Response::error(
360                        request.id.clone(),
361                        JsonRpcError::internal_error(format!("Serialization error: {e}")),
362                    ),
363                }
364            }
365            Err(e) => Response::error(
366                request.id.clone(),
367                JsonRpcError::internal_error(e.to_string()),
368            ),
369        }
370    }
371
372    /// Handle a roots/list request.
373    async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
374        match handler.list_roots().await {
375            Ok(roots) => {
376                let roots_json: Vec<serde_json::Value> = roots
377                    .into_iter()
378                    .map(|r| {
379                        serde_json::json!({
380                            "uri": r.uri,
381                            "name": r.name
382                        })
383                    })
384                    .collect();
385                Response::success(request.id.clone(), serde_json::json!({ "roots": roots_json }))
386            }
387            Err(e) => Response::error(
388                request.id.clone(),
389                JsonRpcError::internal_error(&e.to_string()),
390            ),
391        }
392    }
393
394    /// Handle a notification from the server.
395    async fn handle_notification(notification: Notification, _handler: &Arc<H>) {
396        trace!(method = %notification.method, "Received server notification");
397
398        match notification.method.as_ref() {
399            "notifications/cancelled" => {
400                // Handle cancellation notifications
401                if let Some(params) = notification.params {
402                    if let Some(request_id) = params.get("requestId") {
403                        debug!(?request_id, "Server cancelled request");
404                    }
405                }
406            }
407            "notifications/progress" => {
408                // Handle progress notifications
409                trace!("Received progress notification");
410            }
411            "notifications/resources/updated" => {
412                trace!("Resources updated notification");
413            }
414            "notifications/tools/list_changed" => {
415                trace!("Tools list changed notification");
416            }
417            "notifications/prompts/list_changed" => {
418                trace!("Prompts list changed notification");
419            }
420            _ => {
421                trace!(method = %notification.method, "Unhandled notification");
422            }
423        }
424    }
425
426    /// Get the server information.
427    pub fn server_info(&self) -> &ServerInfo {
428        &self.server_info
429    }
430
431    /// Get the server capabilities.
432    pub fn server_capabilities(&self) -> &ServerCapabilities {
433        &self.server_caps
434    }
435
436    /// Get the client information.
437    pub fn client_info(&self) -> &ClientInfo {
438        &self.client_info
439    }
440
441    /// Get the client capabilities.
442    pub fn client_capabilities(&self) -> &ClientCapabilities {
443        &self.client_caps
444    }
445
446    /// Get the server instructions, if provided.
447    pub fn instructions(&self) -> Option<&str> {
448        self.instructions.as_deref()
449    }
450
451    /// Check if the server supports tools.
452    pub fn has_tools(&self) -> bool {
453        self.server_caps.has_tools()
454    }
455
456    /// Check if the server supports resources.
457    pub fn has_resources(&self) -> bool {
458        self.server_caps.has_resources()
459    }
460
461    /// Check if the server supports prompts.
462    pub fn has_prompts(&self) -> bool {
463        self.server_caps.has_prompts()
464    }
465
466    /// Check if the server supports tasks.
467    pub fn has_tasks(&self) -> bool {
468        self.server_caps.has_tasks()
469    }
470
471    /// Check if the client is still connected.
472    pub fn is_connected(&self) -> bool {
473        self.running.load(Ordering::SeqCst)
474    }
475
476    // ==========================================================================
477    // Tool Operations
478    // ==========================================================================
479
480    /// List all available tools.
481    ///
482    /// # Errors
483    ///
484    /// Returns an error if tools are not supported or the request fails.
485    pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
486        self.ensure_capability("tools", self.has_tools())?;
487
488        let result: ListToolsResult = self.request("tools/list", None).await?;
489        Ok(result.tools)
490    }
491
492    /// List tools with pagination.
493    ///
494    /// # Errors
495    ///
496    /// Returns an error if tools are not supported or the request fails.
497    pub async fn list_tools_paginated(
498        &self,
499        cursor: Option<&str>,
500    ) -> Result<ListToolsResult, McpError> {
501        self.ensure_capability("tools", self.has_tools())?;
502
503        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
504        self.request("tools/list", params).await
505    }
506
507    /// Call a tool by name.
508    ///
509    /// # Arguments
510    ///
511    /// * `name` - The name of the tool to call
512    /// * `arguments` - The arguments to pass to the tool (as JSON)
513    ///
514    /// # Errors
515    ///
516    /// Returns an error if tools are not supported or the call fails.
517    pub async fn call_tool(
518        &self,
519        name: impl Into<String>,
520        arguments: serde_json::Value,
521    ) -> Result<CallToolResult, McpError> {
522        self.ensure_capability("tools", self.has_tools())?;
523
524        let request = CallToolRequest {
525            name: name.into(),
526            arguments: Some(arguments),
527        };
528        self.request("tools/call", Some(serde_json::to_value(request)?))
529            .await
530    }
531
532    // ==========================================================================
533    // Resource Operations
534    // ==========================================================================
535
536    /// List all available resources.
537    ///
538    /// # Errors
539    ///
540    /// Returns an error if resources are not supported or the request fails.
541    pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
542        self.ensure_capability("resources", self.has_resources())?;
543
544        let result: ListResourcesResult = self.request("resources/list", None).await?;
545        Ok(result.resources)
546    }
547
548    /// List resources with pagination.
549    ///
550    /// # Errors
551    ///
552    /// Returns an error if resources are not supported or the request fails.
553    pub async fn list_resources_paginated(
554        &self,
555        cursor: Option<&str>,
556    ) -> Result<ListResourcesResult, McpError> {
557        self.ensure_capability("resources", self.has_resources())?;
558
559        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
560        self.request("resources/list", params).await
561    }
562
563    /// List resource templates.
564    ///
565    /// # Errors
566    ///
567    /// Returns an error if resources are not supported or the request fails.
568    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
569        self.ensure_capability("resources", self.has_resources())?;
570
571        let result: ListResourceTemplatesResult =
572            self.request("resources/templates/list", None).await?;
573        Ok(result.resource_templates)
574    }
575
576    /// Read a resource by URI.
577    ///
578    /// # Errors
579    ///
580    /// Returns an error if resources are not supported or the read fails.
581    pub async fn read_resource(&self, uri: impl Into<String>) -> Result<Vec<ResourceContents>, McpError> {
582        self.ensure_capability("resources", self.has_resources())?;
583
584        let request = ReadResourceRequest { uri: uri.into() };
585        let result: ReadResourceResult =
586            self.request("resources/read", Some(serde_json::to_value(request)?))
587                .await?;
588        Ok(result.contents)
589    }
590
591    // ==========================================================================
592    // Prompt Operations
593    // ==========================================================================
594
595    /// List all available prompts.
596    ///
597    /// # Errors
598    ///
599    /// Returns an error if prompts are not supported or the request fails.
600    pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
601        self.ensure_capability("prompts", self.has_prompts())?;
602
603        let result: ListPromptsResult = self.request("prompts/list", None).await?;
604        Ok(result.prompts)
605    }
606
607    /// List prompts with pagination.
608    ///
609    /// # Errors
610    ///
611    /// Returns an error if prompts are not supported or the request fails.
612    pub async fn list_prompts_paginated(
613        &self,
614        cursor: Option<&str>,
615    ) -> Result<ListPromptsResult, McpError> {
616        self.ensure_capability("prompts", self.has_prompts())?;
617
618        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
619        self.request("prompts/list", params).await
620    }
621
622    /// Get a prompt by name, optionally with arguments.
623    ///
624    /// # Errors
625    ///
626    /// Returns an error if prompts are not supported or the get fails.
627    pub async fn get_prompt(
628        &self,
629        name: impl Into<String>,
630        arguments: Option<serde_json::Map<String, serde_json::Value>>,
631    ) -> Result<GetPromptResult, McpError> {
632        self.ensure_capability("prompts", self.has_prompts())?;
633
634        let request = GetPromptRequest {
635            name: name.into(),
636            arguments,
637        };
638        self.request("prompts/get", Some(serde_json::to_value(request)?))
639            .await
640    }
641
642    // ==========================================================================
643    // Connection Operations
644    // ==========================================================================
645
646    /// Ping the server.
647    ///
648    /// # Errors
649    ///
650    /// Returns an error if the ping fails.
651    pub async fn ping(&self) -> Result<(), McpError> {
652        let _: serde_json::Value = self.request("ping", None).await?;
653        Ok(())
654    }
655
656    /// Close the connection gracefully.
657    ///
658    /// # Errors
659    ///
660    /// Returns an error if the close fails.
661    pub async fn close(self) -> Result<(), McpError> {
662        debug!("Closing client connection");
663
664        // Signal the background task to stop
665        self.running.store(false, Ordering::SeqCst);
666
667        // Notify handler
668        self.handler.on_disconnected().await;
669
670        // Close the transport
671        self.transport.close().await.map_err(|e| {
672            McpError::Transport(Box::new(TransportDetails {
673                kind: TransportErrorKind::ConnectionClosed,
674                message: e.to_string(),
675                context: TransportContext::default(),
676                source: None,
677            }))
678        })
679    }
680
681    // ==========================================================================
682    // Internal Methods
683    // ==========================================================================
684
685    /// Generate the next request ID.
686    fn next_request_id(&self) -> RequestId {
687        RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
688    }
689
690    /// Send a request and wait for the response.
691    async fn request<R: serde::de::DeserializeOwned>(
692        &self,
693        method: &str,
694        params: Option<serde_json::Value>,
695    ) -> Result<R, McpError> {
696        if !self.is_connected() {
697            return Err(McpError::Transport(Box::new(TransportDetails {
698                kind: TransportErrorKind::ConnectionClosed,
699                message: "Client is not connected".to_string(),
700                context: TransportContext::default(),
701                source: None,
702            })));
703        }
704
705        let id = self.next_request_id();
706        let request = if let Some(params) = params {
707            Request::with_params(method.to_string(), id.clone(), params)
708        } else {
709            Request::new(method.to_string(), id.clone())
710        };
711
712        trace!(?id, method, "Sending request");
713
714        // Create a channel for the response
715        let (tx, rx) = oneshot::channel();
716        {
717            let mut pending = self.pending.write().await;
718            pending.insert(id.clone(), tx);
719        }
720
721        // Send the request through the outgoing channel
722        self.outgoing_tx
723            .send(Message::Request(request))
724            .await
725            .map_err(|_| McpError::Transport(Box::new(TransportDetails {
726                kind: TransportErrorKind::WriteFailed,
727                message: "Failed to send request (channel closed)".to_string(),
728                context: TransportContext::default(),
729                source: None,
730            })))?;
731
732        // Wait for the response with a timeout
733        let response = rx.await.map_err(|_| McpError::Transport(Box::new(TransportDetails {
734            kind: TransportErrorKind::ConnectionClosed,
735            message: "Response channel closed (server may have disconnected)".to_string(),
736            context: TransportContext::default(),
737            source: None,
738        })))?;
739
740        // Process the response
741        if let Some(error) = response.error {
742            return Err(McpError::Internal {
743                message: error.message,
744                source: None,
745            });
746        }
747
748        let result = response.result.ok_or_else(|| McpError::Internal {
749            message: "Response contained neither result nor error".to_string(),
750            source: None,
751        })?;
752
753        serde_json::from_value(result).map_err(McpError::from)
754    }
755
756    /// Check that a capability is supported.
757    fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
758        if supported {
759            Ok(())
760        } else {
761            Err(McpError::CapabilityNotSupported {
762                capability: name.to_string(),
763                available: self.available_capabilities().into_boxed_slice(),
764            })
765        }
766    }
767
768    /// Get list of available capabilities.
769    fn available_capabilities(&self) -> Vec<String> {
770        let mut caps = Vec::new();
771        if self.has_tools() {
772            caps.push("tools".to_string());
773        }
774        if self.has_resources() {
775            caps.push("resources".to_string());
776        }
777        if self.has_prompts() {
778            caps.push("prompts".to_string());
779        }
780        if self.has_tasks() {
781            caps.push("tasks".to_string());
782        }
783        caps
784    }
785}
786
787/// Initialize a client connection.
788///
789/// This performs the MCP handshake with protocol version negotiation:
790/// 1. Send initialize request with our preferred protocol version
791/// 2. Wait for initialize result with server's negotiated version
792/// 3. Validate we support the server's version (disconnect if not)
793/// 4. Send initialized notification
794///
795/// # Protocol Version Negotiation
796///
797/// Per the MCP specification:
798/// - Client sends its preferred (latest) protocol version
799/// - Server responds with the same version if supported, or its own preferred version
800/// - Client must support the server's version or the handshake fails
801///
802/// This SDK supports protocol versions: `2025-11-25`, `2024-11-05`.
803pub(crate) async fn initialize<T: Transport>(
804    transport: &T,
805    client_info: &ClientInfo,
806    capabilities: &ClientCapabilities,
807) -> Result<InitializeResult, McpError> {
808    debug!(
809        protocol_version = %PROTOCOL_VERSION,
810        supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
811        "Initializing MCP connection"
812    );
813
814    // Build initialize request
815    let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
816    let init_request = Request::with_params(
817        "initialize".to_string(),
818        RequestId::Number(0),
819        serde_json::to_value(&request)?,
820    );
821
822    // Send initialize request
823    transport
824        .send(Message::Request(init_request))
825        .await
826        .map_err(|e| McpError::Transport(Box::new(TransportDetails {
827            kind: TransportErrorKind::WriteFailed,
828            message: format!("Failed to send initialize: {e}"),
829            context: TransportContext::default(),
830            source: None,
831        })))?;
832
833    // Wait for response
834    let response = loop {
835        match transport.recv().await {
836            Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
837            Ok(Some(_)) => continue,
838            Ok(None) => {
839                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
840                    message: "Connection closed during initialization".to_string(),
841                    client_version: Some(PROTOCOL_VERSION.to_string()),
842                    server_version: None,
843                    source: None,
844                })));
845            }
846            Err(e) => {
847                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
848                    message: format!("Transport error during initialization: {e}"),
849                    client_version: Some(PROTOCOL_VERSION.to_string()),
850                    server_version: None,
851                    source: None,
852                })));
853            }
854        }
855    };
856
857    // Parse the response
858    if let Some(error) = response.error {
859        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
860            message: error.message,
861            client_version: Some(PROTOCOL_VERSION.to_string()),
862            server_version: None,
863            source: None,
864        })));
865    }
866
867    let result: InitializeResult = response
868        .result
869        .map(serde_json::from_value)
870        .transpose()?
871        .ok_or_else(|| McpError::HandshakeFailed(Box::new(HandshakeDetails {
872            message: "Empty initialize result".to_string(),
873            client_version: Some(PROTOCOL_VERSION.to_string()),
874            server_version: None,
875            source: None,
876        })))?;
877
878    // Validate protocol version
879    let server_version = &result.protocol_version;
880    if !is_version_supported(server_version) {
881        warn!(
882            server_version = %server_version,
883            supported = ?SUPPORTED_PROTOCOL_VERSIONS,
884            "Server returned unsupported protocol version"
885        );
886        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
887            message: format!(
888                "Unsupported protocol version: server returned '{}', but client only supports {:?}",
889                server_version, SUPPORTED_PROTOCOL_VERSIONS
890            ),
891            client_version: Some(PROTOCOL_VERSION.to_string()),
892            server_version: Some(server_version.clone()),
893            source: None,
894        })));
895    }
896
897    debug!(
898        server = %result.server_info.name,
899        server_version = %result.server_info.version,
900        protocol_version = %result.protocol_version,
901        "Received initialize result with compatible protocol version"
902    );
903
904    // Send initialized notification
905    let notification = Notification::new("notifications/initialized");
906    transport
907        .send(Message::Notification(notification))
908        .await
909        .map_err(|e| McpError::Transport(Box::new(TransportDetails {
910            kind: TransportErrorKind::WriteFailed,
911            message: format!("Failed to send initialized: {e}"),
912            context: TransportContext::default(),
913            source: None,
914        })))?;
915
916    debug!("MCP initialization complete");
917    Ok(result)
918}
919
#[cfg(test)]
mod tests {
    use super::*;

    /// `fetch_add` hands back the value *before* incrementing, so a counter
    /// seeded at 1 yields 1 and then 2. `next_request_id` relies on this
    /// sequence.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        for expected in [1_u64, 2] {
            assert_eq!(counter.fetch_add(1, Ordering::SeqCst), expected);
        }
    }
}