mcpkit_client/
client.rs

1//! MCP client implementation.
2//!
3//! The [`Client`] struct provides a high-level API for interacting with
4//! MCP servers. It handles:
5//!
6//! - Protocol initialization
7//! - Request/response correlation
8//! - Tool, resource, and prompt operations
9//! - Task tracking
10//! - Connection lifecycle
11//! - Server-initiated request handling via [`ClientHandler`]
12
13use futures::channel::oneshot;
14use mcpkit_core::capability::{
15    is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16    ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{
19    HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20    TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::types::{
24    CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
25    CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
26    GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
27    ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
28    ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
29    TaskStatus, TaskSummary, Tool,
30};
31use mcpkit_transport::Transport;
32use std::collections::HashMap;
33use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
34use std::sync::Arc;
35use tracing::{debug, error, info, trace, warn};
36
37// Runtime-agnostic sync primitives
38use async_lock::RwLock;
39
40// Use tokio channels when tokio-runtime is enabled, otherwise use the transport abstraction
41#[cfg(feature = "tokio-runtime")]
42use tokio::sync::mpsc;
43
44use crate::handler::ClientHandler;
45
46/// An MCP client connected to a server.
47///
48/// The client provides methods for interacting with MCP servers:
49///
50/// - Tools: `list_tools()`, `call_tool()`
51/// - Resources: `list_resources()`, `read_resource()`
52/// - Prompts: `list_prompts()`, `get_prompt()`
53/// - Tasks: `list_tasks()`, `get_task()`, `cancel_task()`
54///
55/// The client also handles server-initiated requests (sampling, elicitation)
56/// by delegating to a [`ClientHandler`] implementation.
57///
58/// # Example
59///
60/// ```no_run
61/// use mcpkit_client::ClientBuilder;
62/// use mcpkit_transport::SpawnedTransport;
63///
64/// # async fn example() -> Result<(), mcpkit_core::error::McpError> {
65/// let transport = SpawnedTransport::spawn("my-server", &[] as &[&str]).await?;
66/// let client = ClientBuilder::new()
67///     .name("my-client")
68///     .version("1.0.0")
69///     .build(transport)
70///     .await?;
71///
72/// let tools = client.list_tools().await?;
73/// # Ok(())
74/// # }
75/// ```
76pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
77    /// The underlying transport (shared with background task).
78    transport: Arc<T>,
79    /// Server information received during initialization.
80    server_info: ServerInfo,
81    /// Server capabilities.
82    server_caps: ServerCapabilities,
83    /// Client information.
84    client_info: ClientInfo,
85    /// Client capabilities.
86    client_caps: ClientCapabilities,
87    /// Next request ID.
88    next_id: AtomicU64,
89    /// Pending requests awaiting responses.
90    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
91    /// Instructions from the server.
92    instructions: Option<String>,
93    /// Handler for server-initiated requests.
94    handler: Arc<H>,
95    /// Sender for outgoing messages to the background task.
96    outgoing_tx: mpsc::Sender<Message>,
97    /// Flag indicating if the client is running.
98    running: Arc<AtomicBool>,
99    /// Handle to the background task.
100    _background_handle: Option<tokio::task::JoinHandle<()>>,
101}
102
103impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
104    /// Create a new client without a handler (called by builder).
105    pub(crate) fn new(
106        transport: T,
107        init_result: InitializeResult,
108        client_info: ClientInfo,
109        client_caps: ClientCapabilities,
110    ) -> Self {
111        Self::with_handler(
112            transport,
113            init_result,
114            client_info,
115            client_caps,
116            crate::handler::NoOpHandler,
117        )
118    }
119}
120
121impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
122    /// Create a new client with a custom handler (called by builder).
123    pub(crate) fn with_handler(
124        transport: T,
125        init_result: InitializeResult,
126        client_info: ClientInfo,
127        client_caps: ClientCapabilities,
128        handler: H,
129    ) -> Self {
130        let transport = Arc::new(transport);
131        let pending = Arc::new(RwLock::new(HashMap::new()));
132        let handler = Arc::new(handler);
133        let running = Arc::new(AtomicBool::new(true));
134
135        // Create channel for outgoing messages
136        let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
137
138        // Start background message routing task
139        let background_handle = Self::spawn_message_router(
140            Arc::clone(&transport),
141            Arc::clone(&pending),
142            Arc::clone(&handler),
143            Arc::clone(&running),
144            outgoing_rx,
145        );
146
147        // Notify handler that connection is established
148        let handler_clone = Arc::clone(&handler);
149        tokio::spawn(async move {
150            handler_clone.on_connected().await;
151        });
152
153        Self {
154            transport,
155            server_info: init_result.server_info,
156            server_caps: init_result.capabilities,
157            client_info,
158            client_caps,
159            next_id: AtomicU64::new(1),
160            pending,
161            instructions: init_result.instructions,
162            handler,
163            outgoing_tx,
164            running,
165            _background_handle: Some(background_handle),
166        }
167    }
168
169    /// Spawn the background message routing task.
170    ///
171    /// This task:
172    /// - Reads incoming messages from the transport
173    /// - Routes responses to pending request channels
174    /// - Delegates server-initiated requests to the handler
175    /// - Handles notifications
176    fn spawn_message_router(
177        transport: Arc<T>,
178        pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
179        handler: Arc<H>,
180        running: Arc<AtomicBool>,
181        mut outgoing_rx: mpsc::Receiver<Message>,
182    ) -> tokio::task::JoinHandle<()> {
183        tokio::spawn(async move {
184            debug!("Starting client message router");
185
186            loop {
187                if !running.load(Ordering::SeqCst) {
188                    debug!("Message router stopping (client closed)");
189                    break;
190                }
191
192                tokio::select! {
193                    // Handle outgoing messages
194                    Some(msg) = outgoing_rx.recv() => {
195                        if let Err(e) = transport.send(msg).await {
196                            error!(?e, "Failed to send message");
197                        }
198                    }
199
200                    // Handle incoming messages
201                    result = transport.recv() => {
202                        match result {
203                            Ok(Some(message)) => {
204                                Self::handle_incoming_message(
205                                    message,
206                                    &pending,
207                                    &handler,
208                                    &transport,
209                                ).await;
210                            }
211                            Ok(None) => {
212                                info!("Connection closed by server");
213                                running.store(false, Ordering::SeqCst);
214                                handler.on_disconnected().await;
215                                break;
216                            }
217                            Err(e) => {
218                                error!(?e, "Transport error in message router");
219                                running.store(false, Ordering::SeqCst);
220                                handler.on_disconnected().await;
221                                break;
222                            }
223                        }
224                    }
225                }
226            }
227
228            debug!("Message router stopped");
229        })
230    }
231
232    /// Handle an incoming message from the server.
233    async fn handle_incoming_message(
234        message: Message,
235        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
236        handler: &Arc<H>,
237        transport: &Arc<T>,
238    ) {
239        match message {
240            Message::Response(response) => {
241                Self::route_response(response, pending).await;
242            }
243            Message::Request(request) => {
244                Self::handle_server_request(request, handler, transport).await;
245            }
246            Message::Notification(notification) => {
247                Self::handle_notification(notification, handler).await;
248            }
249        }
250    }
251
252    /// Route a response to the appropriate pending request.
253    async fn route_response(
254        response: Response,
255        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
256    ) {
257        let sender = {
258            let mut pending_guard = pending.write().await;
259            pending_guard.remove(&response.id)
260        };
261
262        if let Some(sender) = sender {
263            trace!(?response.id, "Routing response to pending request");
264            if sender.send(response).is_err() {
265                warn!("Pending request receiver dropped");
266            }
267        } else {
268            warn!(?response.id, "Received response for unknown request");
269        }
270    }
271
272    /// Handle a server-initiated request.
273    async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
274        trace!(method = %request.method, "Handling server request");
275
276        let response = match request.method.as_ref() {
277            "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
278            "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
279            "roots/list" => Self::handle_roots_request(&request, handler).await,
280            "ping" => {
281                // Respond to ping with empty result
282                Response::success(request.id.clone(), serde_json::json!({}))
283            }
284            _ => {
285                warn!(method = %request.method, "Unknown server request method");
286                Response::error(
287                    request.id.clone(),
288                    JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
289                )
290            }
291        };
292
293        // Send the response
294        if let Err(e) = transport.send(Message::Response(response)).await {
295            error!(?e, "Failed to send response to server request");
296        }
297    }
298
299    /// Handle a sampling/createMessage request.
300    async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
301        let params = match &request.params {
302            Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
303                Ok(req) => req,
304                Err(e) => {
305                    return Response::error(
306                        request.id.clone(),
307                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
308                    );
309                }
310            },
311            None => {
312                return Response::error(
313                    request.id.clone(),
314                    JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
315                );
316            }
317        };
318
319        match handler.create_message(params).await {
320            Ok(result) => match serde_json::to_value(result) {
321                Ok(value) => Response::success(request.id.clone(), value),
322                Err(e) => Response::error(
323                    request.id.clone(),
324                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
325                ),
326            },
327            Err(e) => Response::error(
328                request.id.clone(),
329                JsonRpcError::internal_error(e.to_string()),
330            ),
331        }
332    }
333
334    /// Handle an elicitation/elicit request.
335    async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
336        let params = match &request.params {
337            Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
338                Ok(req) => req,
339                Err(e) => {
340                    return Response::error(
341                        request.id.clone(),
342                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
343                    );
344                }
345            },
346            None => {
347                return Response::error(
348                    request.id.clone(),
349                    JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
350                );
351            }
352        };
353
354        match handler.elicit(params).await {
355            Ok(result) => match serde_json::to_value(result) {
356                Ok(value) => Response::success(request.id.clone(), value),
357                Err(e) => Response::error(
358                    request.id.clone(),
359                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
360                ),
361            },
362            Err(e) => Response::error(
363                request.id.clone(),
364                JsonRpcError::internal_error(e.to_string()),
365            ),
366        }
367    }
368
369    /// Handle a roots/list request.
370    async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
371        match handler.list_roots().await {
372            Ok(roots) => {
373                let roots_json: Vec<serde_json::Value> = roots
374                    .into_iter()
375                    .map(|r| {
376                        serde_json::json!({
377                            "uri": r.uri,
378                            "name": r.name
379                        })
380                    })
381                    .collect();
382                Response::success(
383                    request.id.clone(),
384                    serde_json::json!({ "roots": roots_json }),
385                )
386            }
387            Err(e) => Response::error(
388                request.id.clone(),
389                JsonRpcError::internal_error(e.to_string()),
390            ),
391        }
392    }
393
394    /// Handle a notification from the server.
395    async fn handle_notification(notification: Notification, handler: &Arc<H>) {
396        trace!(method = %notification.method, "Received server notification");
397
398        match notification.method.as_ref() {
399            "notifications/cancelled" => {
400                // Handle cancellation notifications
401                if let Some(params) = &notification.params {
402                    if let Some(request_id) = params.get("requestId") {
403                        debug!(?request_id, "Server cancelled request");
404                    }
405                }
406            }
407            "notifications/progress" => {
408                // Handle progress notifications
409                if let Some(params) = notification.params {
410                    if let (Some(task_id), Some(progress)) = (
411                        params.get("progressToken").and_then(|v| v.as_str()),
412                        params.get("progress"),
413                    ) {
414                        if let Ok(progress) = serde_json::from_value::<
415                            mcpkit_core::types::TaskProgress,
416                        >(progress.clone())
417                        {
418                            debug!(task_id = %task_id, "Task progress update");
419                            handler.on_task_progress(task_id.into(), progress).await;
420                        }
421                    }
422                }
423            }
424            "notifications/resources/updated" => {
425                if let Some(params) = notification.params {
426                    if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
427                        debug!(uri = %uri, "Resource updated");
428                        handler.on_resource_updated(uri.to_string()).await;
429                    }
430                }
431            }
432            "notifications/resources/list_changed" => {
433                debug!("Resources list changed");
434                handler.on_resources_list_changed().await;
435            }
436            "notifications/tools/list_changed" => {
437                debug!("Tools list changed");
438                handler.on_tools_list_changed().await;
439            }
440            "notifications/prompts/list_changed" => {
441                debug!("Prompts list changed");
442                handler.on_prompts_list_changed().await;
443            }
444            _ => {
445                trace!(method = %notification.method, "Unhandled notification");
446            }
447        }
448    }
449
450    /// Get the server information.
451    pub const fn server_info(&self) -> &ServerInfo {
452        &self.server_info
453    }
454
455    /// Get the server capabilities.
456    pub const fn server_capabilities(&self) -> &ServerCapabilities {
457        &self.server_caps
458    }
459
460    /// Get the client information.
461    pub const fn client_info(&self) -> &ClientInfo {
462        &self.client_info
463    }
464
465    /// Get the client capabilities.
466    pub const fn client_capabilities(&self) -> &ClientCapabilities {
467        &self.client_caps
468    }
469
470    /// Get the server instructions, if provided.
471    pub fn instructions(&self) -> Option<&str> {
472        self.instructions.as_deref()
473    }
474
475    /// Check if the server supports tools.
476    pub const fn has_tools(&self) -> bool {
477        self.server_caps.has_tools()
478    }
479
480    /// Check if the server supports resources.
481    pub const fn has_resources(&self) -> bool {
482        self.server_caps.has_resources()
483    }
484
485    /// Check if the server supports prompts.
486    pub const fn has_prompts(&self) -> bool {
487        self.server_caps.has_prompts()
488    }
489
490    /// Check if the server supports tasks.
491    pub const fn has_tasks(&self) -> bool {
492        self.server_caps.has_tasks()
493    }
494
495    /// Check if the server supports completions.
496    pub const fn has_completions(&self) -> bool {
497        self.server_caps.has_completions()
498    }
499
500    /// Check if the client is still connected.
501    pub fn is_connected(&self) -> bool {
502        self.running.load(Ordering::SeqCst)
503    }
504
505    // ==========================================================================
506    // Tool Operations
507    // ==========================================================================
508
509    /// List all available tools.
510    ///
511    /// # Errors
512    ///
513    /// Returns an error if tools are not supported or the request fails.
514    pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
515        self.ensure_capability("tools", self.has_tools())?;
516
517        let result: ListToolsResult = self.request("tools/list", None).await?;
518        Ok(result.tools)
519    }
520
521    /// List tools with pagination.
522    ///
523    /// # Errors
524    ///
525    /// Returns an error if tools are not supported or the request fails.
526    pub async fn list_tools_paginated(
527        &self,
528        cursor: Option<&str>,
529    ) -> Result<ListToolsResult, McpError> {
530        self.ensure_capability("tools", self.has_tools())?;
531
532        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
533        self.request("tools/list", params).await
534    }
535
536    /// Call a tool by name.
537    ///
538    /// # Arguments
539    ///
540    /// * `name` - The name of the tool to call
541    /// * `arguments` - The arguments to pass to the tool (as JSON)
542    ///
543    /// # Errors
544    ///
545    /// Returns an error if tools are not supported or the call fails.
546    pub async fn call_tool(
547        &self,
548        name: impl Into<String>,
549        arguments: serde_json::Value,
550    ) -> Result<CallToolResult, McpError> {
551        self.ensure_capability("tools", self.has_tools())?;
552
553        let request = CallToolRequest {
554            name: name.into(),
555            arguments: Some(arguments),
556        };
557        self.request("tools/call", Some(serde_json::to_value(request)?))
558            .await
559    }
560
561    // ==========================================================================
562    // Resource Operations
563    // ==========================================================================
564
565    /// List all available resources.
566    ///
567    /// # Errors
568    ///
569    /// Returns an error if resources are not supported or the request fails.
570    pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
571        self.ensure_capability("resources", self.has_resources())?;
572
573        let result: ListResourcesResult = self.request("resources/list", None).await?;
574        Ok(result.resources)
575    }
576
577    /// List resources with pagination.
578    ///
579    /// # Errors
580    ///
581    /// Returns an error if resources are not supported or the request fails.
582    pub async fn list_resources_paginated(
583        &self,
584        cursor: Option<&str>,
585    ) -> Result<ListResourcesResult, McpError> {
586        self.ensure_capability("resources", self.has_resources())?;
587
588        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
589        self.request("resources/list", params).await
590    }
591
592    /// List resource templates.
593    ///
594    /// # Errors
595    ///
596    /// Returns an error if resources are not supported or the request fails.
597    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
598        self.ensure_capability("resources", self.has_resources())?;
599
600        let result: ListResourceTemplatesResult =
601            self.request("resources/templates/list", None).await?;
602        Ok(result.resource_templates)
603    }
604
605    /// Read a resource by URI.
606    ///
607    /// # Errors
608    ///
609    /// Returns an error if resources are not supported or the read fails.
610    pub async fn read_resource(
611        &self,
612        uri: impl Into<String>,
613    ) -> Result<Vec<ResourceContents>, McpError> {
614        self.ensure_capability("resources", self.has_resources())?;
615
616        let request = ReadResourceRequest { uri: uri.into() };
617        let result: ReadResourceResult = self
618            .request("resources/read", Some(serde_json::to_value(request)?))
619            .await?;
620        Ok(result.contents)
621    }
622
623    // ==========================================================================
624    // Prompt Operations
625    // ==========================================================================
626
627    /// List all available prompts.
628    ///
629    /// # Errors
630    ///
631    /// Returns an error if prompts are not supported or the request fails.
632    pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
633        self.ensure_capability("prompts", self.has_prompts())?;
634
635        let result: ListPromptsResult = self.request("prompts/list", None).await?;
636        Ok(result.prompts)
637    }
638
639    /// List prompts with pagination.
640    ///
641    /// # Errors
642    ///
643    /// Returns an error if prompts are not supported or the request fails.
644    pub async fn list_prompts_paginated(
645        &self,
646        cursor: Option<&str>,
647    ) -> Result<ListPromptsResult, McpError> {
648        self.ensure_capability("prompts", self.has_prompts())?;
649
650        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
651        self.request("prompts/list", params).await
652    }
653
654    /// Get a prompt by name, optionally with arguments.
655    ///
656    /// # Errors
657    ///
658    /// Returns an error if prompts are not supported or the get fails.
659    pub async fn get_prompt(
660        &self,
661        name: impl Into<String>,
662        arguments: Option<serde_json::Map<String, serde_json::Value>>,
663    ) -> Result<GetPromptResult, McpError> {
664        self.ensure_capability("prompts", self.has_prompts())?;
665
666        let request = GetPromptRequest {
667            name: name.into(),
668            arguments,
669        };
670        self.request("prompts/get", Some(serde_json::to_value(request)?))
671            .await
672    }
673
674    // ==========================================================================
675    // Task Operations
676    // ==========================================================================
677
678    /// List all tasks.
679    ///
680    /// # Errors
681    ///
682    /// Returns an error if tasks are not supported or the request fails.
683    pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
684        self.ensure_capability("tasks", self.has_tasks())?;
685
686        let result: ListTasksResult = self.request("tasks/list", None).await?;
687        Ok(result.tasks)
688    }
689
690    /// List tasks with optional status filter and pagination.
691    ///
692    /// # Errors
693    ///
694    /// Returns an error if tasks are not supported or the request fails.
695    pub async fn list_tasks_filtered(
696        &self,
697        status: Option<TaskStatus>,
698        cursor: Option<&str>,
699    ) -> Result<ListTasksResult, McpError> {
700        self.ensure_capability("tasks", self.has_tasks())?;
701
702        let request = ListTasksRequest {
703            status,
704            cursor: cursor.map(String::from),
705        };
706        self.request("tasks/list", Some(serde_json::to_value(request)?))
707            .await
708    }
709
710    /// Get a task by ID.
711    ///
712    /// # Errors
713    ///
714    /// Returns an error if tasks are not supported or the task is not found.
715    pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
716        self.ensure_capability("tasks", self.has_tasks())?;
717
718        let request = GetTaskRequest {
719            id: id.into().into(),
720        };
721        self.request("tasks/get", Some(serde_json::to_value(request)?))
722            .await
723    }
724
725    /// Cancel a running task.
726    ///
727    /// # Errors
728    ///
729    /// Returns an error if tasks are not supported, cancellation is not supported,
730    /// or the task is not found.
731    pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
732        self.ensure_capability("tasks", self.has_tasks())?;
733
734        let request = CancelTaskRequest {
735            id: id.into().into(),
736        };
737        let _: serde_json::Value = self
738            .request("tasks/cancel", Some(serde_json::to_value(request)?))
739            .await?;
740        Ok(())
741    }
742
743    // ==========================================================================
744    // Completion Operations
745    // ==========================================================================
746
747    /// Get completions for a prompt argument.
748    ///
749    /// # Arguments
750    ///
751    /// * `prompt_name` - The name of the prompt
752    /// * `argument_name` - The name of the argument to complete
753    /// * `current_value` - The current partial value being typed
754    ///
755    /// # Errors
756    ///
757    /// Returns an error if completions are not supported or the request fails.
758    pub async fn complete_prompt_argument(
759        &self,
760        prompt_name: impl Into<String>,
761        argument_name: impl Into<String>,
762        current_value: impl Into<String>,
763    ) -> Result<CompleteResult, McpError> {
764        self.ensure_capability("completions", self.has_completions())?;
765
766        let request = CompleteRequest {
767            ref_: CompletionRef::prompt(prompt_name),
768            argument: CompletionArgument {
769                name: argument_name.into(),
770                value: current_value.into(),
771            },
772        };
773        self.request("completion/complete", Some(serde_json::to_value(request)?))
774            .await
775    }
776
777    /// Get completions for a resource argument.
778    ///
779    /// # Arguments
780    ///
781    /// * `resource_uri` - The URI of the resource
782    /// * `argument_name` - The name of the argument to complete
783    /// * `current_value` - The current partial value being typed
784    ///
785    /// # Errors
786    ///
787    /// Returns an error if completions are not supported or the request fails.
788    pub async fn complete_resource_argument(
789        &self,
790        resource_uri: impl Into<String>,
791        argument_name: impl Into<String>,
792        current_value: impl Into<String>,
793    ) -> Result<CompleteResult, McpError> {
794        self.ensure_capability("completions", self.has_completions())?;
795
796        let request = CompleteRequest {
797            ref_: CompletionRef::resource(resource_uri),
798            argument: CompletionArgument {
799                name: argument_name.into(),
800                value: current_value.into(),
801            },
802        };
803        self.request("completion/complete", Some(serde_json::to_value(request)?))
804            .await
805    }
806
807    // ==========================================================================
808    // Resource Subscription Operations
809    // ==========================================================================
810
811    /// Subscribe to updates for a resource.
812    ///
813    /// When subscribed, the server will send `notifications/resources/updated`
814    /// when the resource changes.
815    ///
816    /// # Errors
817    ///
818    /// Returns an error if resource subscriptions are not supported or the request fails.
819    pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
820        self.ensure_capability("resources", self.has_resources())?;
821
822        // Check if subscribe is supported
823        if !self.server_caps.has_resource_subscribe() {
824            return Err(McpError::CapabilityNotSupported {
825                capability: "resources.subscribe".to_string(),
826                available: self.available_capabilities().into_boxed_slice(),
827            });
828        }
829
830        let params = serde_json::json!({ "uri": uri.into() });
831        let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
832        Ok(())
833    }
834
835    /// Unsubscribe from updates for a resource.
836    ///
837    /// # Errors
838    ///
839    /// Returns an error if resource subscriptions are not supported or the request fails.
840    pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
841        self.ensure_capability("resources", self.has_resources())?;
842
843        // Check if subscribe is supported
844        if !self.server_caps.has_resource_subscribe() {
845            return Err(McpError::CapabilityNotSupported {
846                capability: "resources.subscribe".to_string(),
847                available: self.available_capabilities().into_boxed_slice(),
848            });
849        }
850
851        let params = serde_json::json!({ "uri": uri.into() });
852        let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
853        Ok(())
854    }
855
856    // ==========================================================================
857    // Connection Operations
858    // ==========================================================================
859
860    /// Ping the server.
861    ///
862    /// # Errors
863    ///
864    /// Returns an error if the ping fails.
865    pub async fn ping(&self) -> Result<(), McpError> {
866        let _: serde_json::Value = self.request("ping", None).await?;
867        Ok(())
868    }
869
870    /// Close the connection gracefully.
871    ///
872    /// # Errors
873    ///
874    /// Returns an error if the close fails.
875    pub async fn close(self) -> Result<(), McpError> {
876        debug!("Closing client connection");
877
878        // Signal the background task to stop
879        self.running.store(false, Ordering::SeqCst);
880
881        // Notify handler
882        self.handler.on_disconnected().await;
883
884        // Close the transport
885        self.transport.close().await.map_err(|e| {
886            McpError::Transport(Box::new(TransportDetails {
887                kind: TransportErrorKind::ConnectionClosed,
888                message: e.to_string(),
889                context: TransportContext::default(),
890                source: None,
891            }))
892        })
893    }
894
895    // ==========================================================================
896    // Internal Methods
897    // ==========================================================================
898
899    /// Generate the next request ID.
900    fn next_request_id(&self) -> RequestId {
901        RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
902    }
903
904    /// Send a request and wait for the response.
905    async fn request<R: serde::de::DeserializeOwned>(
906        &self,
907        method: &str,
908        params: Option<serde_json::Value>,
909    ) -> Result<R, McpError> {
910        if !self.is_connected() {
911            return Err(McpError::Transport(Box::new(TransportDetails {
912                kind: TransportErrorKind::ConnectionClosed,
913                message: "Client is not connected".to_string(),
914                context: TransportContext::default(),
915                source: None,
916            })));
917        }
918
919        let id = self.next_request_id();
920        let request = if let Some(params) = params {
921            Request::with_params(method.to_string(), id.clone(), params)
922        } else {
923            Request::new(method.to_string(), id.clone())
924        };
925
926        trace!(?id, method, "Sending request");
927
928        // Create a channel for the response
929        let (tx, rx) = oneshot::channel();
930        {
931            let mut pending = self.pending.write().await;
932            pending.insert(id.clone(), tx);
933        }
934
935        // Send the request through the outgoing channel
936        self.outgoing_tx
937            .send(Message::Request(request))
938            .await
939            .map_err(|_| {
940                McpError::Transport(Box::new(TransportDetails {
941                    kind: TransportErrorKind::WriteFailed,
942                    message: "Failed to send request (channel closed)".to_string(),
943                    context: TransportContext::default(),
944                    source: None,
945                }))
946            })?;
947
948        // Wait for the response with a timeout
949        let response = rx.await.map_err(|_| {
950            McpError::Transport(Box::new(TransportDetails {
951                kind: TransportErrorKind::ConnectionClosed,
952                message: "Response channel closed (server may have disconnected)".to_string(),
953                context: TransportContext::default(),
954                source: None,
955            }))
956        })?;
957
958        // Process the response
959        if let Some(error) = response.error {
960            return Err(McpError::Internal {
961                message: error.message,
962                source: None,
963            });
964        }
965
966        let result = response.result.ok_or_else(|| McpError::Internal {
967            message: "Response contained neither result nor error".to_string(),
968            source: None,
969        })?;
970
971        serde_json::from_value(result).map_err(McpError::from)
972    }
973
974    /// Check that a capability is supported.
975    fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
976        if supported {
977            Ok(())
978        } else {
979            Err(McpError::CapabilityNotSupported {
980                capability: name.to_string(),
981                available: self.available_capabilities().into_boxed_slice(),
982            })
983        }
984    }
985
986    /// Get list of available capabilities.
987    fn available_capabilities(&self) -> Vec<String> {
988        let mut caps = Vec::new();
989        if self.has_tools() {
990            caps.push("tools".to_string());
991        }
992        if self.has_resources() {
993            caps.push("resources".to_string());
994        }
995        if self.has_prompts() {
996            caps.push("prompts".to_string());
997        }
998        if self.has_tasks() {
999            caps.push("tasks".to_string());
1000        }
1001        if self.has_completions() {
1002            caps.push("completions".to_string());
1003        }
1004        caps
1005    }
1006}
1007
1008/// Initialize a client connection.
1009///
1010/// This performs the MCP handshake with protocol version negotiation:
1011/// 1. Send initialize request with our preferred protocol version
1012/// 2. Wait for initialize result with server's negotiated version
1013/// 3. Validate we support the server's version (disconnect if not)
1014/// 4. Send initialized notification
1015///
1016/// # Protocol Version Negotiation
1017///
1018/// Per the MCP specification:
1019/// - Client sends its preferred (latest) protocol version
1020/// - Server responds with the same version if supported, or its own preferred version
1021/// - Client must support the server's version or the handshake fails
1022///
1023/// This SDK supports protocol versions: `2025-11-25`, `2024-11-05`.
1024pub(crate) async fn initialize<T: Transport>(
1025    transport: &T,
1026    client_info: &ClientInfo,
1027    capabilities: &ClientCapabilities,
1028) -> Result<InitializeResult, McpError> {
1029    debug!(
1030        protocol_version = %PROTOCOL_VERSION,
1031        supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1032        "Initializing MCP connection"
1033    );
1034
1035    // Build initialize request
1036    let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1037    let init_request = Request::with_params(
1038        "initialize".to_string(),
1039        RequestId::Number(0),
1040        serde_json::to_value(&request)?,
1041    );
1042
1043    // Send initialize request
1044    transport
1045        .send(Message::Request(init_request))
1046        .await
1047        .map_err(|e| {
1048            McpError::Transport(Box::new(TransportDetails {
1049                kind: TransportErrorKind::WriteFailed,
1050                message: format!("Failed to send initialize: {e}"),
1051                context: TransportContext::default(),
1052                source: None,
1053            }))
1054        })?;
1055
1056    // Wait for response
1057    let response = loop {
1058        match transport.recv().await {
1059            Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1060            Ok(Some(_)) => {} // Skip non-matching messages
1061            Ok(None) => {
1062                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1063                    message: "Connection closed during initialization".to_string(),
1064                    client_version: Some(PROTOCOL_VERSION.to_string()),
1065                    server_version: None,
1066                    source: None,
1067                })));
1068            }
1069            Err(e) => {
1070                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1071                    message: format!("Transport error during initialization: {e}"),
1072                    client_version: Some(PROTOCOL_VERSION.to_string()),
1073                    server_version: None,
1074                    source: None,
1075                })));
1076            }
1077        }
1078    };
1079
1080    // Parse the response
1081    if let Some(error) = response.error {
1082        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1083            message: error.message,
1084            client_version: Some(PROTOCOL_VERSION.to_string()),
1085            server_version: None,
1086            source: None,
1087        })));
1088    }
1089
1090    let result: InitializeResult = response
1091        .result
1092        .map(serde_json::from_value)
1093        .transpose()?
1094        .ok_or_else(|| {
1095            McpError::HandshakeFailed(Box::new(HandshakeDetails {
1096                message: "Empty initialize result".to_string(),
1097                client_version: Some(PROTOCOL_VERSION.to_string()),
1098                server_version: None,
1099                source: None,
1100            }))
1101        })?;
1102
1103    // Validate protocol version
1104    let server_version = &result.protocol_version;
1105    if !is_version_supported(server_version) {
1106        warn!(
1107            server_version = %server_version,
1108            supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1109            "Server returned unsupported protocol version"
1110        );
1111        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1112            message: format!(
1113                "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1114            ),
1115            client_version: Some(PROTOCOL_VERSION.to_string()),
1116            server_version: Some(server_version.clone()),
1117            source: None,
1118        })));
1119    }
1120
1121    debug!(
1122        server = %result.server_info.name,
1123        server_version = %result.server_info.version,
1124        protocol_version = %result.protocol_version,
1125        "Received initialize result with compatible protocol version"
1126    );
1127
1128    // Send initialized notification
1129    let notification = Notification::new("notifications/initialized");
1130    transport
1131        .send(Message::Notification(notification))
1132        .await
1133        .map_err(|e| {
1134            McpError::Transport(Box::new(TransportDetails {
1135                kind: TransportErrorKind::WriteFailed,
1136                message: format!("Failed to send initialized: {e}"),
1137                context: TransportContext::default(),
1138                source: None,
1139            }))
1140        })?;
1141
1142    debug!("MCP initialization complete");
1143    Ok(result)
1144}
1145
#[cfg(test)]
mod tests {
    use super::*;

    /// `fetch_add` returns the value held before the increment, so a counter
    /// that starts at 1 hands out 1 and then 2. This is the same atomic
    /// operation `next_request_id` performs on its counter.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        for expected in 1..=2 {
            let issued = counter.fetch_add(1, Ordering::SeqCst);
            assert_eq!(issued, expected);
        }
    }
}
1156}