mcpkit_client/
client.rs

1//! MCP client implementation.
2//!
3//! The [`Client`] struct provides a high-level API for interacting with
4//! MCP servers. It handles:
5//!
6//! - Protocol initialization
7//! - Request/response correlation
8//! - Tool, resource, and prompt operations
9//! - Task tracking
10//! - Connection lifecycle
11//! - Server-initiated request handling via [`ClientHandler`]
12
13use futures::channel::oneshot;
14use mcpkit_core::capability::{
15    is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16    ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{
19    HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20    TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::protocol_version::ProtocolVersion;
24use mcpkit_core::types::{
25    CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
26    CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
27    GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
28    ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
29    ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
30    TaskStatus, TaskSummary, Tool,
31};
32use mcpkit_transport::Transport;
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
35use std::sync::Arc;
36use tracing::{debug, error, info, trace, warn};
37
38// Runtime-agnostic sync primitives
39use async_lock::RwLock;
40
41// Use tokio channels when tokio-runtime is enabled, otherwise use the transport abstraction
42#[cfg(feature = "tokio-runtime")]
43use tokio::sync::mpsc;
44
45use crate::handler::ClientHandler;
46
47/// An MCP client connected to a server.
48///
49/// The client provides methods for interacting with MCP servers:
50///
51/// - Tools: `list_tools()`, `call_tool()`
52/// - Resources: `list_resources()`, `read_resource()`
53/// - Prompts: `list_prompts()`, `get_prompt()`
54/// - Tasks: `list_tasks()`, `get_task()`, `cancel_task()`
55///
56/// The client also handles server-initiated requests (sampling, elicitation)
57/// by delegating to a [`ClientHandler`] implementation.
58///
59/// # Example
60///
61/// ```no_run
62/// use mcpkit_client::ClientBuilder;
63/// use mcpkit_transport::SpawnedTransport;
64///
65/// # async fn example() -> Result<(), mcpkit_core::error::McpError> {
66/// let transport = SpawnedTransport::spawn("my-server", &[] as &[&str]).await?;
67/// let client = ClientBuilder::new()
68///     .name("my-client")
69///     .version("1.0.0")
70///     .build(transport)
71///     .await?;
72///
73/// let tools = client.list_tools().await?;
74/// # Ok(())
75/// # }
76/// ```
77pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
78    /// The underlying transport (shared with background task).
79    transport: Arc<T>,
80    /// Server information received during initialization.
81    server_info: ServerInfo,
82    /// Server capabilities.
83    server_caps: ServerCapabilities,
84    /// Negotiated protocol version.
85    ///
86    /// Use this for feature detection via methods like `supports_tasks()`,
87    /// `supports_elicitation()`, etc.
88    protocol_version: ProtocolVersion,
89    /// Client information.
90    client_info: ClientInfo,
91    /// Client capabilities.
92    client_caps: ClientCapabilities,
93    /// Next request ID.
94    next_id: AtomicU64,
95    /// Pending requests awaiting responses.
96    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
97    /// Instructions from the server.
98    instructions: Option<String>,
99    /// Handler for server-initiated requests.
100    handler: Arc<H>,
101    /// Sender for outgoing messages to the background task.
102    outgoing_tx: mpsc::Sender<Message>,
103    /// Flag indicating if the client is running.
104    running: Arc<AtomicBool>,
105    /// Handle to the background task.
106    _background_handle: Option<tokio::task::JoinHandle<()>>,
107}
108
109impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
110    /// Create a new client without a handler (called by builder).
111    pub(crate) fn new(
112        transport: T,
113        init_result: InitializeResult,
114        client_info: ClientInfo,
115        client_caps: ClientCapabilities,
116    ) -> Self {
117        Self::with_handler(
118            transport,
119            init_result,
120            client_info,
121            client_caps,
122            crate::handler::NoOpHandler,
123        )
124    }
125}
126
127impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
128    /// Create a new client with a custom handler (called by builder).
129    pub(crate) fn with_handler(
130        transport: T,
131        init_result: InitializeResult,
132        client_info: ClientInfo,
133        client_caps: ClientCapabilities,
134        handler: H,
135    ) -> Self {
136        let transport = Arc::new(transport);
137        let pending = Arc::new(RwLock::new(HashMap::new()));
138        let handler = Arc::new(handler);
139        let running = Arc::new(AtomicBool::new(true));
140
141        // Parse the negotiated protocol version
142        let protocol_version =
143            if let Ok(v) = init_result.protocol_version.parse::<ProtocolVersion>() {
144                v
145            } else {
146                warn!(
147                    server_version = %init_result.protocol_version,
148                    fallback_version = %ProtocolVersion::LATEST,
149                    "Server returned unknown protocol version, falling back to latest supported"
150                );
151                ProtocolVersion::LATEST
152            };
153
154        // Create channel for outgoing messages
155        let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
156
157        // Start background message routing task
158        let background_handle = Self::spawn_message_router(
159            Arc::clone(&transport),
160            Arc::clone(&pending),
161            Arc::clone(&handler),
162            Arc::clone(&running),
163            outgoing_rx,
164        );
165
166        // Notify handler that connection is established
167        let handler_clone = Arc::clone(&handler);
168        tokio::spawn(async move {
169            handler_clone.on_connected().await;
170        });
171
172        Self {
173            transport,
174            server_info: init_result.server_info,
175            server_caps: init_result.capabilities,
176            protocol_version,
177            client_info,
178            client_caps,
179            next_id: AtomicU64::new(1),
180            pending,
181            instructions: init_result.instructions,
182            handler,
183            outgoing_tx,
184            running,
185            _background_handle: Some(background_handle),
186        }
187    }
188
189    /// Spawn the background message routing task.
190    ///
191    /// This task:
192    /// - Reads incoming messages from the transport
193    /// - Routes responses to pending request channels
194    /// - Delegates server-initiated requests to the handler
195    /// - Handles notifications
196    fn spawn_message_router(
197        transport: Arc<T>,
198        pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
199        handler: Arc<H>,
200        running: Arc<AtomicBool>,
201        mut outgoing_rx: mpsc::Receiver<Message>,
202    ) -> tokio::task::JoinHandle<()> {
203        tokio::spawn(async move {
204            debug!("Starting client message router");
205
206            loop {
207                if !running.load(Ordering::SeqCst) {
208                    debug!("Message router stopping (client closed)");
209                    break;
210                }
211
212                tokio::select! {
213                    // Handle outgoing messages
214                    Some(msg) = outgoing_rx.recv() => {
215                        if let Err(e) = transport.send(msg).await {
216                            error!(?e, "Failed to send message");
217                        }
218                    }
219
220                    // Handle incoming messages
221                    result = transport.recv() => {
222                        match result {
223                            Ok(Some(message)) => {
224                                Self::handle_incoming_message(
225                                    message,
226                                    &pending,
227                                    &handler,
228                                    &transport,
229                                ).await;
230                            }
231                            Ok(None) => {
232                                info!("Connection closed by server");
233                                running.store(false, Ordering::SeqCst);
234                                handler.on_disconnected().await;
235                                break;
236                            }
237                            Err(e) => {
238                                error!(?e, "Transport error in message router");
239                                running.store(false, Ordering::SeqCst);
240                                handler.on_disconnected().await;
241                                break;
242                            }
243                        }
244                    }
245                }
246            }
247
248            debug!("Message router stopped");
249        })
250    }
251
252    /// Handle an incoming message from the server.
253    async fn handle_incoming_message(
254        message: Message,
255        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
256        handler: &Arc<H>,
257        transport: &Arc<T>,
258    ) {
259        match message {
260            Message::Response(response) => {
261                Self::route_response(response, pending).await;
262            }
263            Message::Request(request) => {
264                Self::handle_server_request(request, handler, transport).await;
265            }
266            Message::Notification(notification) => {
267                Self::handle_notification(notification, handler).await;
268            }
269        }
270    }
271
272    /// Route a response to the appropriate pending request.
273    async fn route_response(
274        response: Response,
275        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
276    ) {
277        let sender = {
278            let mut pending_guard = pending.write().await;
279            pending_guard.remove(&response.id)
280        };
281
282        if let Some(sender) = sender {
283            trace!(?response.id, "Routing response to pending request");
284            if sender.send(response).is_err() {
285                warn!("Pending request receiver dropped");
286            }
287        } else {
288            warn!(?response.id, "Received response for unknown request");
289        }
290    }
291
292    /// Handle a server-initiated request.
293    async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
294        trace!(method = %request.method, "Handling server request");
295
296        let response = match request.method.as_ref() {
297            "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
298            "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
299            "roots/list" => Self::handle_roots_request(&request, handler).await,
300            "ping" => {
301                // Respond to ping with empty result
302                Response::success(request.id.clone(), serde_json::json!({}))
303            }
304            _ => {
305                warn!(method = %request.method, "Unknown server request method");
306                Response::error(
307                    request.id.clone(),
308                    JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
309                )
310            }
311        };
312
313        // Send the response
314        if let Err(e) = transport.send(Message::Response(response)).await {
315            error!(?e, "Failed to send response to server request");
316        }
317    }
318
319    /// Handle a sampling/createMessage request.
320    async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
321        let params = match &request.params {
322            Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
323                Ok(req) => req,
324                Err(e) => {
325                    return Response::error(
326                        request.id.clone(),
327                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
328                    );
329                }
330            },
331            None => {
332                return Response::error(
333                    request.id.clone(),
334                    JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
335                );
336            }
337        };
338
339        match handler.create_message(params).await {
340            Ok(result) => match serde_json::to_value(result) {
341                Ok(value) => Response::success(request.id.clone(), value),
342                Err(e) => Response::error(
343                    request.id.clone(),
344                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
345                ),
346            },
347            Err(e) => Response::error(
348                request.id.clone(),
349                JsonRpcError::internal_error(e.to_string()),
350            ),
351        }
352    }
353
354    /// Handle an elicitation/elicit request.
355    async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
356        let params = match &request.params {
357            Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
358                Ok(req) => req,
359                Err(e) => {
360                    return Response::error(
361                        request.id.clone(),
362                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
363                    );
364                }
365            },
366            None => {
367                return Response::error(
368                    request.id.clone(),
369                    JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
370                );
371            }
372        };
373
374        match handler.elicit(params).await {
375            Ok(result) => match serde_json::to_value(result) {
376                Ok(value) => Response::success(request.id.clone(), value),
377                Err(e) => Response::error(
378                    request.id.clone(),
379                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
380                ),
381            },
382            Err(e) => Response::error(
383                request.id.clone(),
384                JsonRpcError::internal_error(e.to_string()),
385            ),
386        }
387    }
388
389    /// Handle a roots/list request.
390    async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
391        match handler.list_roots().await {
392            Ok(roots) => {
393                let roots_json: Vec<serde_json::Value> = roots
394                    .into_iter()
395                    .map(|r| {
396                        serde_json::json!({
397                            "uri": r.uri,
398                            "name": r.name
399                        })
400                    })
401                    .collect();
402                Response::success(
403                    request.id.clone(),
404                    serde_json::json!({ "roots": roots_json }),
405                )
406            }
407            Err(e) => Response::error(
408                request.id.clone(),
409                JsonRpcError::internal_error(e.to_string()),
410            ),
411        }
412    }
413
414    /// Handle a notification from the server.
415    async fn handle_notification(notification: Notification, handler: &Arc<H>) {
416        trace!(method = %notification.method, "Received server notification");
417
418        match notification.method.as_ref() {
419            "notifications/cancelled" => {
420                // Handle cancellation notifications
421                if let Some(params) = &notification.params {
422                    if let Some(request_id) = params.get("requestId") {
423                        debug!(?request_id, "Server cancelled request");
424                    }
425                }
426            }
427            "notifications/progress" => {
428                // Handle progress notifications
429                if let Some(params) = notification.params {
430                    if let (Some(task_id), Some(progress)) = (
431                        params.get("progressToken").and_then(|v| v.as_str()),
432                        params.get("progress"),
433                    ) {
434                        if let Ok(progress) = serde_json::from_value::<
435                            mcpkit_core::types::TaskProgress,
436                        >(progress.clone())
437                        {
438                            debug!(task_id = %task_id, "Task progress update");
439                            handler.on_task_progress(task_id.into(), progress).await;
440                        }
441                    }
442                }
443            }
444            "notifications/resources/updated" => {
445                if let Some(params) = notification.params {
446                    if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
447                        debug!(uri = %uri, "Resource updated");
448                        handler.on_resource_updated(uri.to_string()).await;
449                    }
450                }
451            }
452            "notifications/resources/list_changed" => {
453                debug!("Resources list changed");
454                handler.on_resources_list_changed().await;
455            }
456            "notifications/tools/list_changed" => {
457                debug!("Tools list changed");
458                handler.on_tools_list_changed().await;
459            }
460            "notifications/prompts/list_changed" => {
461                debug!("Prompts list changed");
462                handler.on_prompts_list_changed().await;
463            }
464            _ => {
465                trace!(method = %notification.method, "Unhandled notification");
466            }
467        }
468    }
469
470    /// Get the server information.
471    pub const fn server_info(&self) -> &ServerInfo {
472        &self.server_info
473    }
474
475    /// Get the server capabilities.
476    pub const fn server_capabilities(&self) -> &ServerCapabilities {
477        &self.server_caps
478    }
479
480    /// Get the negotiated protocol version.
481    ///
482    /// Use this for feature detection. For example:
483    /// ```rust,ignore
484    /// if client.protocol_version().supports_tasks() {
485    ///     // Use task-related features
486    /// }
487    /// ```
488    pub fn protocol_version(&self) -> ProtocolVersion {
489        self.protocol_version
490    }
491
492    /// Get the client information.
493    pub const fn client_info(&self) -> &ClientInfo {
494        &self.client_info
495    }
496
497    /// Get the client capabilities.
498    pub const fn client_capabilities(&self) -> &ClientCapabilities {
499        &self.client_caps
500    }
501
502    /// Get the server instructions, if provided.
503    pub fn instructions(&self) -> Option<&str> {
504        self.instructions.as_deref()
505    }
506
507    /// Check if the server supports tools.
508    pub const fn has_tools(&self) -> bool {
509        self.server_caps.has_tools()
510    }
511
512    /// Check if the server supports resources.
513    pub const fn has_resources(&self) -> bool {
514        self.server_caps.has_resources()
515    }
516
517    /// Check if the server supports prompts.
518    pub const fn has_prompts(&self) -> bool {
519        self.server_caps.has_prompts()
520    }
521
522    /// Check if the server supports tasks.
523    pub const fn has_tasks(&self) -> bool {
524        self.server_caps.has_tasks()
525    }
526
527    /// Check if the server supports completions.
528    pub const fn has_completions(&self) -> bool {
529        self.server_caps.has_completions()
530    }
531
532    /// Check if the client is still connected.
533    pub fn is_connected(&self) -> bool {
534        self.running.load(Ordering::SeqCst)
535    }
536
537    // ==========================================================================
538    // Tool Operations
539    // ==========================================================================
540
541    /// List all available tools.
542    ///
543    /// # Errors
544    ///
545    /// Returns an error if tools are not supported or the request fails.
546    pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
547        self.ensure_capability("tools", self.has_tools())?;
548
549        let result: ListToolsResult = self.request("tools/list", None).await?;
550        Ok(result.tools)
551    }
552
553    /// List tools with pagination.
554    ///
555    /// # Errors
556    ///
557    /// Returns an error if tools are not supported or the request fails.
558    pub async fn list_tools_paginated(
559        &self,
560        cursor: Option<&str>,
561    ) -> Result<ListToolsResult, McpError> {
562        self.ensure_capability("tools", self.has_tools())?;
563
564        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
565        self.request("tools/list", params).await
566    }
567
568    /// Call a tool by name.
569    ///
570    /// # Arguments
571    ///
572    /// * `name` - The name of the tool to call
573    /// * `arguments` - The arguments to pass to the tool (as JSON)
574    ///
575    /// # Errors
576    ///
577    /// Returns an error if tools are not supported or the call fails.
578    pub async fn call_tool(
579        &self,
580        name: impl Into<String>,
581        arguments: serde_json::Value,
582    ) -> Result<CallToolResult, McpError> {
583        self.ensure_capability("tools", self.has_tools())?;
584
585        let request = CallToolRequest {
586            name: name.into(),
587            arguments: Some(arguments),
588        };
589        self.request("tools/call", Some(serde_json::to_value(request)?))
590            .await
591    }
592
593    // ==========================================================================
594    // Resource Operations
595    // ==========================================================================
596
597    /// List all available resources.
598    ///
599    /// # Errors
600    ///
601    /// Returns an error if resources are not supported or the request fails.
602    pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
603        self.ensure_capability("resources", self.has_resources())?;
604
605        let result: ListResourcesResult = self.request("resources/list", None).await?;
606        Ok(result.resources)
607    }
608
609    /// List resources with pagination.
610    ///
611    /// # Errors
612    ///
613    /// Returns an error if resources are not supported or the request fails.
614    pub async fn list_resources_paginated(
615        &self,
616        cursor: Option<&str>,
617    ) -> Result<ListResourcesResult, McpError> {
618        self.ensure_capability("resources", self.has_resources())?;
619
620        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
621        self.request("resources/list", params).await
622    }
623
624    /// List resource templates.
625    ///
626    /// # Errors
627    ///
628    /// Returns an error if resources are not supported or the request fails.
629    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
630        self.ensure_capability("resources", self.has_resources())?;
631
632        let result: ListResourceTemplatesResult =
633            self.request("resources/templates/list", None).await?;
634        Ok(result.resource_templates)
635    }
636
637    /// Read a resource by URI.
638    ///
639    /// # Errors
640    ///
641    /// Returns an error if resources are not supported or the read fails.
642    pub async fn read_resource(
643        &self,
644        uri: impl Into<String>,
645    ) -> Result<Vec<ResourceContents>, McpError> {
646        self.ensure_capability("resources", self.has_resources())?;
647
648        let request = ReadResourceRequest { uri: uri.into() };
649        let result: ReadResourceResult = self
650            .request("resources/read", Some(serde_json::to_value(request)?))
651            .await?;
652        Ok(result.contents)
653    }
654
655    // ==========================================================================
656    // Prompt Operations
657    // ==========================================================================
658
659    /// List all available prompts.
660    ///
661    /// # Errors
662    ///
663    /// Returns an error if prompts are not supported or the request fails.
664    pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
665        self.ensure_capability("prompts", self.has_prompts())?;
666
667        let result: ListPromptsResult = self.request("prompts/list", None).await?;
668        Ok(result.prompts)
669    }
670
671    /// List prompts with pagination.
672    ///
673    /// # Errors
674    ///
675    /// Returns an error if prompts are not supported or the request fails.
676    pub async fn list_prompts_paginated(
677        &self,
678        cursor: Option<&str>,
679    ) -> Result<ListPromptsResult, McpError> {
680        self.ensure_capability("prompts", self.has_prompts())?;
681
682        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
683        self.request("prompts/list", params).await
684    }
685
686    /// Get a prompt by name, optionally with arguments.
687    ///
688    /// # Errors
689    ///
690    /// Returns an error if prompts are not supported or the get fails.
691    pub async fn get_prompt(
692        &self,
693        name: impl Into<String>,
694        arguments: Option<serde_json::Map<String, serde_json::Value>>,
695    ) -> Result<GetPromptResult, McpError> {
696        self.ensure_capability("prompts", self.has_prompts())?;
697
698        let request = GetPromptRequest {
699            name: name.into(),
700            arguments,
701        };
702        self.request("prompts/get", Some(serde_json::to_value(request)?))
703            .await
704    }
705
706    // ==========================================================================
707    // Task Operations
708    // ==========================================================================
709
710    /// List all tasks.
711    ///
712    /// # Errors
713    ///
714    /// Returns an error if tasks are not supported or the request fails.
715    pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
716        self.ensure_capability("tasks", self.has_tasks())?;
717
718        let result: ListTasksResult = self.request("tasks/list", None).await?;
719        Ok(result.tasks)
720    }
721
722    /// List tasks with optional status filter and pagination.
723    ///
724    /// # Errors
725    ///
726    /// Returns an error if tasks are not supported or the request fails.
727    pub async fn list_tasks_filtered(
728        &self,
729        status: Option<TaskStatus>,
730        cursor: Option<&str>,
731    ) -> Result<ListTasksResult, McpError> {
732        self.ensure_capability("tasks", self.has_tasks())?;
733
734        let request = ListTasksRequest {
735            status,
736            cursor: cursor.map(String::from),
737        };
738        self.request("tasks/list", Some(serde_json::to_value(request)?))
739            .await
740    }
741
742    /// Get a task by ID.
743    ///
744    /// # Errors
745    ///
746    /// Returns an error if tasks are not supported or the task is not found.
747    pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
748        self.ensure_capability("tasks", self.has_tasks())?;
749
750        let request = GetTaskRequest {
751            id: id.into().into(),
752        };
753        self.request("tasks/get", Some(serde_json::to_value(request)?))
754            .await
755    }
756
757    /// Cancel a running task.
758    ///
759    /// # Errors
760    ///
761    /// Returns an error if tasks are not supported, cancellation is not supported,
762    /// or the task is not found.
763    pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
764        self.ensure_capability("tasks", self.has_tasks())?;
765
766        let request = CancelTaskRequest {
767            id: id.into().into(),
768        };
769        let _: serde_json::Value = self
770            .request("tasks/cancel", Some(serde_json::to_value(request)?))
771            .await?;
772        Ok(())
773    }
774
775    // ==========================================================================
776    // Completion Operations
777    // ==========================================================================
778
779    /// Get completions for a prompt argument.
780    ///
781    /// # Arguments
782    ///
783    /// * `prompt_name` - The name of the prompt
784    /// * `argument_name` - The name of the argument to complete
785    /// * `current_value` - The current partial value being typed
786    ///
787    /// # Errors
788    ///
789    /// Returns an error if completions are not supported or the request fails.
790    pub async fn complete_prompt_argument(
791        &self,
792        prompt_name: impl Into<String>,
793        argument_name: impl Into<String>,
794        current_value: impl Into<String>,
795    ) -> Result<CompleteResult, McpError> {
796        self.ensure_capability("completions", self.has_completions())?;
797
798        let request = CompleteRequest {
799            ref_: CompletionRef::prompt(prompt_name),
800            argument: CompletionArgument {
801                name: argument_name.into(),
802                value: current_value.into(),
803            },
804        };
805        self.request("completion/complete", Some(serde_json::to_value(request)?))
806            .await
807    }
808
809    /// Get completions for a resource argument.
810    ///
811    /// # Arguments
812    ///
813    /// * `resource_uri` - The URI of the resource
814    /// * `argument_name` - The name of the argument to complete
815    /// * `current_value` - The current partial value being typed
816    ///
817    /// # Errors
818    ///
819    /// Returns an error if completions are not supported or the request fails.
820    pub async fn complete_resource_argument(
821        &self,
822        resource_uri: impl Into<String>,
823        argument_name: impl Into<String>,
824        current_value: impl Into<String>,
825    ) -> Result<CompleteResult, McpError> {
826        self.ensure_capability("completions", self.has_completions())?;
827
828        let request = CompleteRequest {
829            ref_: CompletionRef::resource(resource_uri),
830            argument: CompletionArgument {
831                name: argument_name.into(),
832                value: current_value.into(),
833            },
834        };
835        self.request("completion/complete", Some(serde_json::to_value(request)?))
836            .await
837    }
838
839    // ==========================================================================
840    // Resource Subscription Operations
841    // ==========================================================================
842
843    /// Subscribe to updates for a resource.
844    ///
845    /// When subscribed, the server will send `notifications/resources/updated`
846    /// when the resource changes.
847    ///
848    /// # Errors
849    ///
850    /// Returns an error if resource subscriptions are not supported or the request fails.
851    pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
852        self.ensure_capability("resources", self.has_resources())?;
853
854        // Check if subscribe is supported
855        if !self.server_caps.has_resource_subscribe() {
856            return Err(McpError::CapabilityNotSupported {
857                capability: "resources.subscribe".to_string(),
858                available: self.available_capabilities().into_boxed_slice(),
859            });
860        }
861
862        let params = serde_json::json!({ "uri": uri.into() });
863        let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
864        Ok(())
865    }
866
867    /// Unsubscribe from updates for a resource.
868    ///
869    /// # Errors
870    ///
871    /// Returns an error if resource subscriptions are not supported or the request fails.
872    pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
873        self.ensure_capability("resources", self.has_resources())?;
874
875        // Check if subscribe is supported
876        if !self.server_caps.has_resource_subscribe() {
877            return Err(McpError::CapabilityNotSupported {
878                capability: "resources.subscribe".to_string(),
879                available: self.available_capabilities().into_boxed_slice(),
880            });
881        }
882
883        let params = serde_json::json!({ "uri": uri.into() });
884        let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
885        Ok(())
886    }
887
888    // ==========================================================================
889    // Connection Operations
890    // ==========================================================================
891
892    /// Ping the server.
893    ///
894    /// # Errors
895    ///
896    /// Returns an error if the ping fails.
897    pub async fn ping(&self) -> Result<(), McpError> {
898        let _: serde_json::Value = self.request("ping", None).await?;
899        Ok(())
900    }
901
902    /// Close the connection gracefully.
903    ///
904    /// # Errors
905    ///
906    /// Returns an error if the close fails.
907    pub async fn close(self) -> Result<(), McpError> {
908        debug!("Closing client connection");
909
910        // Signal the background task to stop
911        self.running.store(false, Ordering::SeqCst);
912
913        // Notify handler
914        self.handler.on_disconnected().await;
915
916        // Close the transport
917        self.transport.close().await.map_err(|e| {
918            McpError::Transport(Box::new(TransportDetails {
919                kind: TransportErrorKind::ConnectionClosed,
920                message: e.to_string(),
921                context: TransportContext::default(),
922                source: None,
923            }))
924        })
925    }
926
927    // ==========================================================================
928    // Internal Methods
929    // ==========================================================================
930
931    /// Generate the next request ID.
932    fn next_request_id(&self) -> RequestId {
933        RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
934    }
935
936    /// Send a request and wait for the response.
937    async fn request<R: serde::de::DeserializeOwned>(
938        &self,
939        method: &str,
940        params: Option<serde_json::Value>,
941    ) -> Result<R, McpError> {
942        if !self.is_connected() {
943            return Err(McpError::Transport(Box::new(TransportDetails {
944                kind: TransportErrorKind::ConnectionClosed,
945                message: "Client is not connected".to_string(),
946                context: TransportContext::default(),
947                source: None,
948            })));
949        }
950
951        let id = self.next_request_id();
952        let request = if let Some(params) = params {
953            Request::with_params(method.to_string(), id.clone(), params)
954        } else {
955            Request::new(method.to_string(), id.clone())
956        };
957
958        trace!(?id, method, "Sending request");
959
960        // Create a channel for the response
961        let (tx, rx) = oneshot::channel();
962        {
963            let mut pending = self.pending.write().await;
964            pending.insert(id.clone(), tx);
965        }
966
967        // Send the request through the outgoing channel
968        self.outgoing_tx
969            .send(Message::Request(request))
970            .await
971            .map_err(|_| {
972                McpError::Transport(Box::new(TransportDetails {
973                    kind: TransportErrorKind::WriteFailed,
974                    message: "Failed to send request (channel closed)".to_string(),
975                    context: TransportContext::default(),
976                    source: None,
977                }))
978            })?;
979
980        // Wait for the response with a timeout
981        let response = rx.await.map_err(|_| {
982            McpError::Transport(Box::new(TransportDetails {
983                kind: TransportErrorKind::ConnectionClosed,
984                message: "Response channel closed (server may have disconnected)".to_string(),
985                context: TransportContext::default(),
986                source: None,
987            }))
988        })?;
989
990        // Process the response
991        if let Some(error) = response.error {
992            return Err(McpError::Internal {
993                message: error.message,
994                source: None,
995            });
996        }
997
998        let result = response.result.ok_or_else(|| McpError::Internal {
999            message: "Response contained neither result nor error".to_string(),
1000            source: None,
1001        })?;
1002
1003        serde_json::from_value(result).map_err(McpError::from)
1004    }
1005
1006    /// Check that a capability is supported.
1007    fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
1008        if supported {
1009            Ok(())
1010        } else {
1011            Err(McpError::CapabilityNotSupported {
1012                capability: name.to_string(),
1013                available: self.available_capabilities().into_boxed_slice(),
1014            })
1015        }
1016    }
1017
1018    /// Get list of available capabilities.
1019    fn available_capabilities(&self) -> Vec<String> {
1020        let mut caps = Vec::new();
1021        if self.has_tools() {
1022            caps.push("tools".to_string());
1023        }
1024        if self.has_resources() {
1025            caps.push("resources".to_string());
1026        }
1027        if self.has_prompts() {
1028            caps.push("prompts".to_string());
1029        }
1030        if self.has_tasks() {
1031            caps.push("tasks".to_string());
1032        }
1033        if self.has_completions() {
1034            caps.push("completions".to_string());
1035        }
1036        caps
1037    }
1038}
1039
1040/// Initialize a client connection.
1041///
1042/// This performs the MCP handshake with protocol version negotiation:
1043/// 1. Send initialize request with our preferred protocol version
1044/// 2. Wait for initialize result with server's negotiated version
1045/// 3. Validate we support the server's version (disconnect if not)
1046/// 4. Send initialized notification
1047///
1048/// # Protocol Version Negotiation
1049///
1050/// Per the MCP specification:
1051/// - Client sends its preferred (latest) protocol version
1052/// - Server responds with the same version if supported, or its own preferred version
1053/// - Client must support the server's version or the handshake fails
1054///
1055/// This SDK supports protocol versions: `2025-11-25`, `2024-11-05`.
1056pub(crate) async fn initialize<T: Transport>(
1057    transport: &T,
1058    client_info: &ClientInfo,
1059    capabilities: &ClientCapabilities,
1060) -> Result<InitializeResult, McpError> {
1061    debug!(
1062        protocol_version = %PROTOCOL_VERSION,
1063        supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1064        "Initializing MCP connection"
1065    );
1066
1067    // Build initialize request
1068    let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1069    let init_request = Request::with_params(
1070        "initialize".to_string(),
1071        RequestId::Number(0),
1072        serde_json::to_value(&request)?,
1073    );
1074
1075    // Send initialize request
1076    transport
1077        .send(Message::Request(init_request))
1078        .await
1079        .map_err(|e| {
1080            McpError::Transport(Box::new(TransportDetails {
1081                kind: TransportErrorKind::WriteFailed,
1082                message: format!("Failed to send initialize: {e}"),
1083                context: TransportContext::default(),
1084                source: None,
1085            }))
1086        })?;
1087
1088    // Wait for response
1089    let response = loop {
1090        match transport.recv().await {
1091            Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1092            Ok(Some(_)) => {} // Skip non-matching messages
1093            Ok(None) => {
1094                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1095                    message: "Connection closed during initialization".to_string(),
1096                    client_version: Some(PROTOCOL_VERSION.to_string()),
1097                    server_version: None,
1098                    source: None,
1099                })));
1100            }
1101            Err(e) => {
1102                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1103                    message: format!("Transport error during initialization: {e}"),
1104                    client_version: Some(PROTOCOL_VERSION.to_string()),
1105                    server_version: None,
1106                    source: None,
1107                })));
1108            }
1109        }
1110    };
1111
1112    // Parse the response
1113    if let Some(error) = response.error {
1114        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1115            message: error.message,
1116            client_version: Some(PROTOCOL_VERSION.to_string()),
1117            server_version: None,
1118            source: None,
1119        })));
1120    }
1121
1122    let result: InitializeResult = response
1123        .result
1124        .map(serde_json::from_value)
1125        .transpose()?
1126        .ok_or_else(|| {
1127            McpError::HandshakeFailed(Box::new(HandshakeDetails {
1128                message: "Empty initialize result".to_string(),
1129                client_version: Some(PROTOCOL_VERSION.to_string()),
1130                server_version: None,
1131                source: None,
1132            }))
1133        })?;
1134
1135    // Validate protocol version
1136    let server_version = &result.protocol_version;
1137    if !is_version_supported(server_version) {
1138        warn!(
1139            server_version = %server_version,
1140            supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1141            "Server returned unsupported protocol version"
1142        );
1143        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1144            message: format!(
1145                "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1146            ),
1147            client_version: Some(PROTOCOL_VERSION.to_string()),
1148            server_version: Some(server_version.clone()),
1149            source: None,
1150        })));
1151    }
1152
1153    debug!(
1154        server = %result.server_info.name,
1155        server_version = %result.server_info.version,
1156        protocol_version = %result.protocol_version,
1157        "Received initialize result with compatible protocol version"
1158    );
1159
1160    // Send initialized notification
1161    let notification = Notification::new("notifications/initialized");
1162    transport
1163        .send(Message::Notification(notification))
1164        .await
1165        .map_err(|e| {
1166            McpError::Transport(Box::new(TransportDetails {
1167                kind: TransportErrorKind::WriteFailed,
1168                message: format!("Failed to send initialized: {e}"),
1169                context: TransportContext::default(),
1170                source: None,
1171            }))
1172        })?;
1173
1174    debug!("MCP initialization complete");
1175    Ok(result)
1176}
1177
#[cfg(test)]
mod tests {
    use super::*;

    /// `fetch_add` returns the pre-increment value. `next_request_id` relies on
    /// this to hand out sequential IDs.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        for expected in 1..=2_u64 {
            assert_eq!(counter.fetch_add(1, Ordering::SeqCst), expected);
        }
    }
}