mcpkit_client/
client.rs

1//! MCP client implementation.
2//!
3//! The [`Client`] struct provides a high-level API for interacting with
4//! MCP servers. It handles:
5//!
6//! - Protocol initialization
7//! - Request/response correlation
8//! - Tool, resource, and prompt operations
9//! - Task tracking
10//! - Connection lifecycle
11//! - Server-initiated request handling via [`ClientHandler`]
12
13use futures::channel::oneshot;
14use mcpkit_core::capability::{
15    ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, PROTOCOL_VERSION,
16    SUPPORTED_PROTOCOL_VERSIONS, ServerCapabilities, ServerInfo, is_version_supported,
17};
18use mcpkit_core::error::{
19    HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails,
20    TransportErrorKind,
21};
22use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
23use mcpkit_core::protocol_version::ProtocolVersion;
24use mcpkit_core::types::{
25    CallToolRequest, CallToolResult, CancelTaskRequest, CompleteRequest, CompleteResult,
26    CompletionArgument, CompletionRef, CreateMessageRequest, ElicitRequest, GetPromptRequest,
27    GetPromptResult, GetTaskRequest, ListPromptsResult, ListResourceTemplatesResult,
28    ListResourcesResult, ListTasksRequest, ListTasksResult, ListToolsResult, Prompt,
29    ReadResourceRequest, ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Task,
30    TaskStatus, TaskSummary, Tool,
31};
32use mcpkit_transport::Transport;
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
36use tracing::{debug, error, info, trace, warn};
37
38// Runtime-agnostic sync primitives
39use async_lock::RwLock;
40
41// Use tokio channels when tokio-runtime is enabled, otherwise use the transport abstraction
42#[cfg(feature = "tokio-runtime")]
43use tokio::sync::mpsc;
44
45use crate::handler::ClientHandler;
46
/// An MCP client connected to a server.
///
/// The client provides methods for interacting with MCP servers:
///
/// - Tools: `list_tools()`, `list_tools_paginated()`, `call_tool()`
/// - Resources: `list_resources()`, `list_resource_templates()`, `read_resource()`,
///   `subscribe_resource()`
/// - Prompts: `list_prompts()`, `get_prompt()`
/// - Tasks: `list_tasks()`, `get_task()`, `cancel_task()`
/// - Completions: `complete_prompt_argument()`, `complete_resource_argument()`
///
/// Server-initiated requests for sampling, elicitation and roots are delegated
/// to a [`ClientHandler`] implementation; `ping` is answered directly.
///
/// All transport I/O happens on a background task spawned when the client is
/// constructed: outgoing messages are queued on a channel, and incoming
/// messages are routed to pending requests, the handler, or notification
/// callbacks.
///
/// # Example
///
/// ```no_run
/// use mcpkit_client::ClientBuilder;
/// use mcpkit_transport::SpawnedTransport;
///
/// # async fn example() -> Result<(), mcpkit_core::error::McpError> {
/// let transport = SpawnedTransport::spawn("my-server", &[] as &[&str]).await?;
/// let client = ClientBuilder::new()
///     .name("my-client")
///     .version("1.0.0")
///     .build(transport)
///     .await?;
///
/// let tools = client.list_tools().await?;
/// # Ok(())
/// # }
/// ```
pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
    /// The underlying transport, shared with the background message router task.
    transport: Arc<T>,
    /// Server information received during initialization.
    server_info: ServerInfo,
    /// Capabilities the server advertised during initialization.
    server_caps: ServerCapabilities,
    /// Negotiated protocol version.
    ///
    /// Use this for feature detection via methods like `supports_tasks()`,
    /// `supports_elicitation()`, etc.
    protocol_version: ProtocolVersion,
    /// Client information sent to the server.
    client_info: ClientInfo,
    /// Client capabilities sent to the server.
    client_caps: ClientCapabilities,
    /// Next request ID to allocate.
    next_id: AtomicU64,
    /// Requests awaiting responses, keyed by request ID. The router removes an
    /// entry when the matching response arrives and forwards the response
    /// through the stored oneshot sender.
    pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
    /// Instructions from the server, if it provided any.
    instructions: Option<String>,
    /// Handler for server-initiated requests and notifications.
    handler: Arc<H>,
    /// Sender for outgoing messages; the background task writes them to the transport.
    outgoing_tx: mpsc::Sender<Message>,
    /// Connection flag: starts `true`, cleared by the router when the
    /// transport closes or errors; read by `is_connected()`.
    running: Arc<AtomicBool>,
    /// Handle to the background router task.
    ///
    /// NOTE(review): not awaited or aborted anywhere in the code visible here;
    /// dropping a tokio `JoinHandle` detaches the task rather than stopping it.
    _background_handle: Option<tokio::task::JoinHandle<()>>,
}
108
109impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
110    /// Create a new client without a handler (called by builder).
111    pub(crate) fn new(
112        transport: T,
113        init_result: InitializeResult,
114        client_info: ClientInfo,
115        client_caps: ClientCapabilities,
116    ) -> Self {
117        Self::with_handler(
118            transport,
119            init_result,
120            client_info,
121            client_caps,
122            crate::handler::NoOpHandler,
123        )
124    }
125}
126
127impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
128    /// Create a new client with a custom handler (called by builder).
129    pub(crate) fn with_handler(
130        transport: T,
131        init_result: InitializeResult,
132        client_info: ClientInfo,
133        client_caps: ClientCapabilities,
134        handler: H,
135    ) -> Self {
136        let transport = Arc::new(transport);
137        let pending = Arc::new(RwLock::new(HashMap::new()));
138        let handler = Arc::new(handler);
139        let running = Arc::new(AtomicBool::new(true));
140
141        // Parse the negotiated protocol version
142        let protocol_version =
143            if let Ok(v) = init_result.protocol_version.parse::<ProtocolVersion>() {
144                v
145            } else {
146                warn!(
147                    server_version = %init_result.protocol_version,
148                    fallback_version = %ProtocolVersion::LATEST,
149                    "Server returned unknown protocol version, falling back to latest supported"
150                );
151                ProtocolVersion::LATEST
152            };
153
154        // Create channel for outgoing messages
155        let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
156
157        // Start background message routing task
158        let background_handle = Self::spawn_message_router(
159            Arc::clone(&transport),
160            Arc::clone(&pending),
161            Arc::clone(&handler),
162            Arc::clone(&running),
163            outgoing_rx,
164        );
165
166        // Notify handler that connection is established
167        let handler_clone = Arc::clone(&handler);
168        tokio::spawn(async move {
169            handler_clone.on_connected().await;
170        });
171
172        Self {
173            transport,
174            server_info: init_result.server_info,
175            server_caps: init_result.capabilities,
176            protocol_version,
177            client_info,
178            client_caps,
179            next_id: AtomicU64::new(1),
180            pending,
181            instructions: init_result.instructions,
182            handler,
183            outgoing_tx,
184            running,
185            _background_handle: Some(background_handle),
186        }
187    }
188
189    /// Spawn the background message routing task.
190    ///
191    /// This task:
192    /// - Reads incoming messages from the transport
193    /// - Routes responses to pending request channels
194    /// - Delegates server-initiated requests to the handler
195    /// - Handles notifications
196    fn spawn_message_router(
197        transport: Arc<T>,
198        pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
199        handler: Arc<H>,
200        running: Arc<AtomicBool>,
201        mut outgoing_rx: mpsc::Receiver<Message>,
202    ) -> tokio::task::JoinHandle<()> {
203        tokio::spawn(async move {
204            debug!("Starting client message router");
205
206            loop {
207                if !running.load(Ordering::SeqCst) {
208                    debug!("Message router stopping (client closed)");
209                    break;
210                }
211
212                tokio::select! {
213                    // Handle outgoing messages
214                    Some(msg) = outgoing_rx.recv() => {
215                        if let Err(e) = transport.send(msg).await {
216                            error!(?e, "Failed to send message");
217                        }
218                    }
219
220                    // Handle incoming messages
221                    result = transport.recv() => {
222                        match result {
223                            Ok(Some(message)) => {
224                                // Debug: log what was received
225                                let msg_id = match &message {
226                                    Message::Request(r) => format!("Request({})", r.id),
227                                    Message::Response(r) => format!("Response({})", r.id),
228                                    Message::Notification(n) => format!("Notification({})", n.method),
229                                };
230                                debug!(msg = %msg_id, "Router received message from transport");
231                                Self::handle_incoming_message(
232                                    message,
233                                    &pending,
234                                    &handler,
235                                    &transport,
236                                ).await;
237                            }
238                            Ok(None) => {
239                                info!("Connection closed by server");
240                                running.store(false, Ordering::SeqCst);
241                                handler.on_disconnected().await;
242                                break;
243                            }
244                            Err(e) => {
245                                error!(?e, "Transport error in message router");
246                                running.store(false, Ordering::SeqCst);
247                                handler.on_disconnected().await;
248                                break;
249                            }
250                        }
251                    }
252                }
253            }
254
255            debug!("Message router stopped");
256        })
257    }
258
259    /// Handle an incoming message from the server.
260    async fn handle_incoming_message(
261        message: Message,
262        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
263        handler: &Arc<H>,
264        transport: &Arc<T>,
265    ) {
266        match message {
267            Message::Response(response) => {
268                Self::route_response(response, pending).await;
269            }
270            Message::Request(request) => {
271                Self::handle_server_request(request, handler, transport).await;
272            }
273            Message::Notification(notification) => {
274                Self::handle_notification(notification, handler).await;
275            }
276        }
277    }
278
279    /// Route a response to the appropriate pending request.
280    async fn route_response(
281        response: Response,
282        pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
283    ) {
284        let pending_count = pending.read().await.len();
285        let sender = {
286            let mut pending_guard = pending.write().await;
287            pending_guard.remove(&response.id)
288        };
289
290        if let Some(sender) = sender {
291            debug!(?response.id, pending_count, "Routing response to pending request (found in pending)");
292            if sender.send(response).is_err() {
293                warn!("Pending request receiver dropped");
294            }
295        } else {
296            // This can happen benignly when:
297            // 1. A response arrives that was already handled (e.g., after timeout)
298            // 2. The server sends an unsolicited response
299            // 3. A previous response is re-delivered due to transport buffering
300            // Log at debug level to help diagnose correlation issues.
301            debug!(?response.id, pending_count, "Response not found in pending (possible race or duplicate)");
302        }
303    }
304
305    /// Handle a server-initiated request.
306    async fn handle_server_request(request: Request, handler: &Arc<H>, transport: &Arc<T>) {
307        trace!(method = %request.method, "Handling server request");
308
309        let response = match request.method.as_ref() {
310            "sampling/createMessage" => Self::handle_sampling_request(&request, handler).await,
311            "elicitation/elicit" => Self::handle_elicitation_request(&request, handler).await,
312            "roots/list" => Self::handle_roots_request(&request, handler).await,
313            "ping" => {
314                // Respond to ping with empty result
315                Response::success(request.id.clone(), serde_json::json!({}))
316            }
317            _ => {
318                warn!(method = %request.method, "Unknown server request method");
319                Response::error(
320                    request.id.clone(),
321                    JsonRpcError::method_not_found(format!("Unknown method: {}", request.method)),
322                )
323            }
324        };
325
326        // Send the response
327        if let Err(e) = transport.send(Message::Response(response)).await {
328            error!(?e, "Failed to send response to server request");
329        }
330    }
331
332    /// Handle a sampling/createMessage request.
333    async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
334        let params = match &request.params {
335            Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
336                Ok(req) => req,
337                Err(e) => {
338                    return Response::error(
339                        request.id.clone(),
340                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
341                    );
342                }
343            },
344            None => {
345                return Response::error(
346                    request.id.clone(),
347                    JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
348                );
349            }
350        };
351
352        match handler.create_message(params).await {
353            Ok(result) => match serde_json::to_value(result) {
354                Ok(value) => Response::success(request.id.clone(), value),
355                Err(e) => Response::error(
356                    request.id.clone(),
357                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
358                ),
359            },
360            Err(e) => Response::error(
361                request.id.clone(),
362                JsonRpcError::internal_error(e.to_string()),
363            ),
364        }
365    }
366
367    /// Handle an elicitation/elicit request.
368    async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
369        let params = match &request.params {
370            Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
371                Ok(req) => req,
372                Err(e) => {
373                    return Response::error(
374                        request.id.clone(),
375                        JsonRpcError::invalid_params(format!("Invalid params: {e}")),
376                    );
377                }
378            },
379            None => {
380                return Response::error(
381                    request.id.clone(),
382                    JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
383                );
384            }
385        };
386
387        match handler.elicit(params).await {
388            Ok(result) => match serde_json::to_value(result) {
389                Ok(value) => Response::success(request.id.clone(), value),
390                Err(e) => Response::error(
391                    request.id.clone(),
392                    JsonRpcError::internal_error(format!("Serialization error: {e}")),
393                ),
394            },
395            Err(e) => Response::error(
396                request.id.clone(),
397                JsonRpcError::internal_error(e.to_string()),
398            ),
399        }
400    }
401
402    /// Handle a roots/list request.
403    async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
404        match handler.list_roots().await {
405            Ok(roots) => {
406                let roots_json: Vec<serde_json::Value> = roots
407                    .into_iter()
408                    .map(|r| {
409                        serde_json::json!({
410                            "uri": r.uri,
411                            "name": r.name
412                        })
413                    })
414                    .collect();
415                Response::success(
416                    request.id.clone(),
417                    serde_json::json!({ "roots": roots_json }),
418                )
419            }
420            Err(e) => Response::error(
421                request.id.clone(),
422                JsonRpcError::internal_error(e.to_string()),
423            ),
424        }
425    }
426
427    /// Handle a notification from the server.
428    async fn handle_notification(notification: Notification, handler: &Arc<H>) {
429        trace!(method = %notification.method, "Received server notification");
430
431        match notification.method.as_ref() {
432            "notifications/cancelled" => {
433                // Handle cancellation notifications
434                if let Some(params) = &notification.params {
435                    if let Some(request_id) = params.get("requestId") {
436                        debug!(?request_id, "Server cancelled request");
437                    }
438                }
439            }
440            "notifications/progress" => {
441                // Handle progress notifications
442                if let Some(params) = notification.params {
443                    if let (Some(task_id), Some(progress)) = (
444                        params.get("progressToken").and_then(|v| v.as_str()),
445                        params.get("progress"),
446                    ) {
447                        if let Ok(progress) = serde_json::from_value::<
448                            mcpkit_core::types::TaskProgress,
449                        >(progress.clone())
450                        {
451                            debug!(task_id = %task_id, "Task progress update");
452                            handler.on_task_progress(task_id.into(), progress).await;
453                        }
454                    }
455                }
456            }
457            "notifications/resources/updated" => {
458                if let Some(params) = notification.params {
459                    if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
460                        debug!(uri = %uri, "Resource updated");
461                        handler.on_resource_updated(uri.to_string()).await;
462                    }
463                }
464            }
465            "notifications/resources/list_changed" => {
466                debug!("Resources list changed");
467                handler.on_resources_list_changed().await;
468            }
469            "notifications/tools/list_changed" => {
470                debug!("Tools list changed");
471                handler.on_tools_list_changed().await;
472            }
473            "notifications/prompts/list_changed" => {
474                debug!("Prompts list changed");
475                handler.on_prompts_list_changed().await;
476            }
477            _ => {
478                trace!(method = %notification.method, "Unhandled notification");
479            }
480        }
481    }
482
483    /// Get the server information.
484    pub const fn server_info(&self) -> &ServerInfo {
485        &self.server_info
486    }
487
488    /// Get the server capabilities.
489    pub const fn server_capabilities(&self) -> &ServerCapabilities {
490        &self.server_caps
491    }
492
493    /// Get the negotiated protocol version.
494    ///
495    /// Use this for feature detection. For example:
496    /// ```rust,ignore
497    /// if client.protocol_version().supports_tasks() {
498    ///     // Use task-related features
499    /// }
500    /// ```
501    pub fn protocol_version(&self) -> ProtocolVersion {
502        self.protocol_version
503    }
504
505    /// Get the client information.
506    pub const fn client_info(&self) -> &ClientInfo {
507        &self.client_info
508    }
509
510    /// Get the client capabilities.
511    pub const fn client_capabilities(&self) -> &ClientCapabilities {
512        &self.client_caps
513    }
514
515    /// Get the server instructions, if provided.
516    pub fn instructions(&self) -> Option<&str> {
517        self.instructions.as_deref()
518    }
519
520    /// Check if the server supports tools.
521    pub const fn has_tools(&self) -> bool {
522        self.server_caps.has_tools()
523    }
524
525    /// Check if the server supports resources.
526    pub const fn has_resources(&self) -> bool {
527        self.server_caps.has_resources()
528    }
529
530    /// Check if the server supports prompts.
531    pub const fn has_prompts(&self) -> bool {
532        self.server_caps.has_prompts()
533    }
534
535    /// Check if the server supports tasks.
536    pub const fn has_tasks(&self) -> bool {
537        self.server_caps.has_tasks()
538    }
539
540    /// Check if the server supports completions.
541    pub const fn has_completions(&self) -> bool {
542        self.server_caps.has_completions()
543    }
544
545    /// Check if the client is still connected.
546    pub fn is_connected(&self) -> bool {
547        self.running.load(Ordering::SeqCst)
548    }
549
550    // ==========================================================================
551    // Tool Operations
552    // ==========================================================================
553
554    /// List all available tools.
555    ///
556    /// # Errors
557    ///
558    /// Returns an error if tools are not supported or the request fails.
559    pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
560        self.ensure_capability("tools", self.has_tools())?;
561
562        let result: ListToolsResult = self.request("tools/list", None).await?;
563        Ok(result.tools)
564    }
565
566    /// List tools with pagination.
567    ///
568    /// # Errors
569    ///
570    /// Returns an error if tools are not supported or the request fails.
571    pub async fn list_tools_paginated(
572        &self,
573        cursor: Option<&str>,
574    ) -> Result<ListToolsResult, McpError> {
575        self.ensure_capability("tools", self.has_tools())?;
576
577        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
578        self.request("tools/list", params).await
579    }
580
581    /// Call a tool by name.
582    ///
583    /// # Arguments
584    ///
585    /// * `name` - The name of the tool to call
586    /// * `arguments` - The arguments to pass to the tool (as JSON)
587    ///
588    /// # Errors
589    ///
590    /// Returns an error if tools are not supported or the call fails.
591    pub async fn call_tool(
592        &self,
593        name: impl Into<String>,
594        arguments: serde_json::Value,
595    ) -> Result<CallToolResult, McpError> {
596        self.ensure_capability("tools", self.has_tools())?;
597
598        let request = CallToolRequest {
599            name: name.into(),
600            arguments: Some(arguments),
601        };
602        self.request("tools/call", Some(serde_json::to_value(request)?))
603            .await
604    }
605
606    // ==========================================================================
607    // Resource Operations
608    // ==========================================================================
609
610    /// List all available resources.
611    ///
612    /// # Errors
613    ///
614    /// Returns an error if resources are not supported or the request fails.
615    pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
616        self.ensure_capability("resources", self.has_resources())?;
617
618        let result: ListResourcesResult = self.request("resources/list", None).await?;
619        Ok(result.resources)
620    }
621
622    /// List resources with pagination.
623    ///
624    /// # Errors
625    ///
626    /// Returns an error if resources are not supported or the request fails.
627    pub async fn list_resources_paginated(
628        &self,
629        cursor: Option<&str>,
630    ) -> Result<ListResourcesResult, McpError> {
631        self.ensure_capability("resources", self.has_resources())?;
632
633        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
634        self.request("resources/list", params).await
635    }
636
637    /// List resource templates.
638    ///
639    /// # Errors
640    ///
641    /// Returns an error if resources are not supported or the request fails.
642    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
643        self.ensure_capability("resources", self.has_resources())?;
644
645        let result: ListResourceTemplatesResult =
646            self.request("resources/templates/list", None).await?;
647        Ok(result.resource_templates)
648    }
649
650    /// Read a resource by URI.
651    ///
652    /// # Errors
653    ///
654    /// Returns an error if resources are not supported or the read fails.
655    pub async fn read_resource(
656        &self,
657        uri: impl Into<String>,
658    ) -> Result<Vec<ResourceContents>, McpError> {
659        self.ensure_capability("resources", self.has_resources())?;
660
661        let request = ReadResourceRequest { uri: uri.into() };
662        let result: ReadResourceResult = self
663            .request("resources/read", Some(serde_json::to_value(request)?))
664            .await?;
665        Ok(result.contents)
666    }
667
668    // ==========================================================================
669    // Prompt Operations
670    // ==========================================================================
671
672    /// List all available prompts.
673    ///
674    /// # Errors
675    ///
676    /// Returns an error if prompts are not supported or the request fails.
677    pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
678        self.ensure_capability("prompts", self.has_prompts())?;
679
680        let result: ListPromptsResult = self.request("prompts/list", None).await?;
681        Ok(result.prompts)
682    }
683
684    /// List prompts with pagination.
685    ///
686    /// # Errors
687    ///
688    /// Returns an error if prompts are not supported or the request fails.
689    pub async fn list_prompts_paginated(
690        &self,
691        cursor: Option<&str>,
692    ) -> Result<ListPromptsResult, McpError> {
693        self.ensure_capability("prompts", self.has_prompts())?;
694
695        let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
696        self.request("prompts/list", params).await
697    }
698
699    /// Get a prompt by name, optionally with arguments.
700    ///
701    /// # Errors
702    ///
703    /// Returns an error if prompts are not supported or the get fails.
704    pub async fn get_prompt(
705        &self,
706        name: impl Into<String>,
707        arguments: Option<serde_json::Map<String, serde_json::Value>>,
708    ) -> Result<GetPromptResult, McpError> {
709        self.ensure_capability("prompts", self.has_prompts())?;
710
711        let request = GetPromptRequest {
712            name: name.into(),
713            arguments,
714        };
715        self.request("prompts/get", Some(serde_json::to_value(request)?))
716            .await
717    }
718
719    // ==========================================================================
720    // Task Operations
721    // ==========================================================================
722
723    /// List all tasks.
724    ///
725    /// # Errors
726    ///
727    /// Returns an error if tasks are not supported or the request fails.
728    pub async fn list_tasks(&self) -> Result<Vec<TaskSummary>, McpError> {
729        self.ensure_capability("tasks", self.has_tasks())?;
730
731        let result: ListTasksResult = self.request("tasks/list", None).await?;
732        Ok(result.tasks)
733    }
734
735    /// List tasks with optional status filter and pagination.
736    ///
737    /// # Errors
738    ///
739    /// Returns an error if tasks are not supported or the request fails.
740    pub async fn list_tasks_filtered(
741        &self,
742        status: Option<TaskStatus>,
743        cursor: Option<&str>,
744    ) -> Result<ListTasksResult, McpError> {
745        self.ensure_capability("tasks", self.has_tasks())?;
746
747        let request = ListTasksRequest {
748            status,
749            cursor: cursor.map(String::from),
750        };
751        self.request("tasks/list", Some(serde_json::to_value(request)?))
752            .await
753    }
754
755    /// Get a task by ID.
756    ///
757    /// # Errors
758    ///
759    /// Returns an error if tasks are not supported or the task is not found.
760    pub async fn get_task(&self, id: impl Into<String>) -> Result<Task, McpError> {
761        self.ensure_capability("tasks", self.has_tasks())?;
762
763        let request = GetTaskRequest {
764            id: id.into().into(),
765        };
766        self.request("tasks/get", Some(serde_json::to_value(request)?))
767            .await
768    }
769
770    /// Cancel a running task.
771    ///
772    /// # Errors
773    ///
774    /// Returns an error if tasks are not supported, cancellation is not supported,
775    /// or the task is not found.
776    pub async fn cancel_task(&self, id: impl Into<String>) -> Result<(), McpError> {
777        self.ensure_capability("tasks", self.has_tasks())?;
778
779        let request = CancelTaskRequest {
780            id: id.into().into(),
781        };
782        let _: serde_json::Value = self
783            .request("tasks/cancel", Some(serde_json::to_value(request)?))
784            .await?;
785        Ok(())
786    }
787
788    // ==========================================================================
789    // Completion Operations
790    // ==========================================================================
791
792    /// Get completions for a prompt argument.
793    ///
794    /// # Arguments
795    ///
796    /// * `prompt_name` - The name of the prompt
797    /// * `argument_name` - The name of the argument to complete
798    /// * `current_value` - The current partial value being typed
799    ///
800    /// # Errors
801    ///
802    /// Returns an error if completions are not supported or the request fails.
803    pub async fn complete_prompt_argument(
804        &self,
805        prompt_name: impl Into<String>,
806        argument_name: impl Into<String>,
807        current_value: impl Into<String>,
808    ) -> Result<CompleteResult, McpError> {
809        self.ensure_capability("completions", self.has_completions())?;
810
811        let request = CompleteRequest {
812            ref_: CompletionRef::prompt(prompt_name),
813            argument: CompletionArgument {
814                name: argument_name.into(),
815                value: current_value.into(),
816            },
817        };
818        self.request("completion/complete", Some(serde_json::to_value(request)?))
819            .await
820    }
821
822    /// Get completions for a resource argument.
823    ///
824    /// # Arguments
825    ///
826    /// * `resource_uri` - The URI of the resource
827    /// * `argument_name` - The name of the argument to complete
828    /// * `current_value` - The current partial value being typed
829    ///
830    /// # Errors
831    ///
832    /// Returns an error if completions are not supported or the request fails.
833    pub async fn complete_resource_argument(
834        &self,
835        resource_uri: impl Into<String>,
836        argument_name: impl Into<String>,
837        current_value: impl Into<String>,
838    ) -> Result<CompleteResult, McpError> {
839        self.ensure_capability("completions", self.has_completions())?;
840
841        let request = CompleteRequest {
842            ref_: CompletionRef::resource(resource_uri),
843            argument: CompletionArgument {
844                name: argument_name.into(),
845                value: current_value.into(),
846            },
847        };
848        self.request("completion/complete", Some(serde_json::to_value(request)?))
849            .await
850    }
851
852    // ==========================================================================
853    // Resource Subscription Operations
854    // ==========================================================================
855
856    /// Subscribe to updates for a resource.
857    ///
858    /// When subscribed, the server will send `notifications/resources/updated`
859    /// when the resource changes.
860    ///
861    /// # Errors
862    ///
863    /// Returns an error if resource subscriptions are not supported or the request fails.
864    pub async fn subscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
865        self.ensure_capability("resources", self.has_resources())?;
866
867        // Check if subscribe is supported
868        if !self.server_caps.has_resource_subscribe() {
869            return Err(McpError::CapabilityNotSupported {
870                capability: "resources.subscribe".to_string(),
871                available: self.available_capabilities().into_boxed_slice(),
872            });
873        }
874
875        let params = serde_json::json!({ "uri": uri.into() });
876        let _: serde_json::Value = self.request("resources/subscribe", Some(params)).await?;
877        Ok(())
878    }
879
880    /// Unsubscribe from updates for a resource.
881    ///
882    /// # Errors
883    ///
884    /// Returns an error if resource subscriptions are not supported or the request fails.
885    pub async fn unsubscribe_resource(&self, uri: impl Into<String>) -> Result<(), McpError> {
886        self.ensure_capability("resources", self.has_resources())?;
887
888        // Check if subscribe is supported
889        if !self.server_caps.has_resource_subscribe() {
890            return Err(McpError::CapabilityNotSupported {
891                capability: "resources.subscribe".to_string(),
892                available: self.available_capabilities().into_boxed_slice(),
893            });
894        }
895
896        let params = serde_json::json!({ "uri": uri.into() });
897        let _: serde_json::Value = self.request("resources/unsubscribe", Some(params)).await?;
898        Ok(())
899    }
900
901    // ==========================================================================
902    // Connection Operations
903    // ==========================================================================
904
905    /// Ping the server.
906    ///
907    /// # Errors
908    ///
909    /// Returns an error if the ping fails.
910    pub async fn ping(&self) -> Result<(), McpError> {
911        let _: serde_json::Value = self.request("ping", None).await?;
912        Ok(())
913    }
914
915    /// Close the connection gracefully.
916    ///
917    /// # Errors
918    ///
919    /// Returns an error if the close fails.
920    pub async fn close(self) -> Result<(), McpError> {
921        debug!("Closing client connection");
922
923        // Signal the background task to stop
924        self.running.store(false, Ordering::SeqCst);
925
926        // Notify handler
927        self.handler.on_disconnected().await;
928
929        // Close the transport
930        self.transport.close().await.map_err(|e| {
931            McpError::Transport(Box::new(TransportDetails {
932                kind: TransportErrorKind::ConnectionClosed,
933                message: e.to_string(),
934                context: TransportContext::default(),
935                source: None,
936            }))
937        })
938    }
939
940    // ==========================================================================
941    // Internal Methods
942    // ==========================================================================
943
944    /// Generate the next request ID.
945    fn next_request_id(&self) -> RequestId {
946        RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
947    }
948
949    /// Send a request and wait for the response.
950    async fn request<R: serde::de::DeserializeOwned>(
951        &self,
952        method: &str,
953        params: Option<serde_json::Value>,
954    ) -> Result<R, McpError> {
955        if !self.is_connected() {
956            return Err(McpError::Transport(Box::new(TransportDetails {
957                kind: TransportErrorKind::ConnectionClosed,
958                message: "Client is not connected".to_string(),
959                context: TransportContext::default(),
960                source: None,
961            })));
962        }
963
964        let id = self.next_request_id();
965        let request = if let Some(params) = params {
966            Request::with_params(method.to_string(), id.clone(), params)
967        } else {
968            Request::new(method.to_string(), id.clone())
969        };
970
971        trace!(?id, method, "Sending request");
972
973        // Create a channel for the response
974        let (tx, rx) = oneshot::channel();
975        {
976            let mut pending = self.pending.write().await;
977            pending.insert(id.clone(), tx);
978        }
979
980        // Send the request through the outgoing channel
981        self.outgoing_tx
982            .send(Message::Request(request))
983            .await
984            .map_err(|_| {
985                McpError::Transport(Box::new(TransportDetails {
986                    kind: TransportErrorKind::WriteFailed,
987                    message: "Failed to send request (channel closed)".to_string(),
988                    context: TransportContext::default(),
989                    source: None,
990                }))
991            })?;
992
993        // Wait for the response with a timeout
994        let response = rx.await.map_err(|_| {
995            McpError::Transport(Box::new(TransportDetails {
996                kind: TransportErrorKind::ConnectionClosed,
997                message: "Response channel closed (server may have disconnected)".to_string(),
998                context: TransportContext::default(),
999                source: None,
1000            }))
1001        })?;
1002
1003        // Process the response
1004        if let Some(error) = response.error {
1005            return Err(McpError::Internal {
1006                message: error.message,
1007                source: None,
1008            });
1009        }
1010
1011        let result = response.result.ok_or_else(|| McpError::Internal {
1012            message: "Response contained neither result nor error".to_string(),
1013            source: None,
1014        })?;
1015
1016        serde_json::from_value(result).map_err(McpError::from)
1017    }
1018
1019    /// Check that a capability is supported.
1020    fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
1021        if supported {
1022            Ok(())
1023        } else {
1024            Err(McpError::CapabilityNotSupported {
1025                capability: name.to_string(),
1026                available: self.available_capabilities().into_boxed_slice(),
1027            })
1028        }
1029    }
1030
1031    /// Get list of available capabilities.
1032    fn available_capabilities(&self) -> Vec<String> {
1033        let mut caps = Vec::new();
1034        if self.has_tools() {
1035            caps.push("tools".to_string());
1036        }
1037        if self.has_resources() {
1038            caps.push("resources".to_string());
1039        }
1040        if self.has_prompts() {
1041            caps.push("prompts".to_string());
1042        }
1043        if self.has_tasks() {
1044            caps.push("tasks".to_string());
1045        }
1046        if self.has_completions() {
1047            caps.push("completions".to_string());
1048        }
1049        caps
1050    }
1051}
1052
1053/// Initialize a client connection.
1054///
1055/// This performs the MCP handshake with protocol version negotiation:
1056/// 1. Send initialize request with our preferred protocol version
1057/// 2. Wait for initialize result with server's negotiated version
1058/// 3. Validate we support the server's version (disconnect if not)
1059/// 4. Send initialized notification
1060///
1061/// # Protocol Version Negotiation
1062///
1063/// Per the MCP specification:
1064/// - Client sends its preferred (latest) protocol version
1065/// - Server responds with the same version if supported, or its own preferred version
1066/// - Client must support the server's version or the handshake fails
1067///
1068/// This SDK supports protocol versions: `2025-11-25`, `2024-11-05`.
1069pub(crate) async fn initialize<T: Transport>(
1070    transport: &T,
1071    client_info: &ClientInfo,
1072    capabilities: &ClientCapabilities,
1073) -> Result<InitializeResult, McpError> {
1074    debug!(
1075        protocol_version = %PROTOCOL_VERSION,
1076        supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
1077        "Initializing MCP connection"
1078    );
1079
1080    // Build initialize request
1081    let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
1082    let init_request = Request::with_params(
1083        "initialize".to_string(),
1084        RequestId::Number(0),
1085        serde_json::to_value(&request)?,
1086    );
1087
1088    // Send initialize request
1089    transport
1090        .send(Message::Request(init_request))
1091        .await
1092        .map_err(|e| {
1093            McpError::Transport(Box::new(TransportDetails {
1094                kind: TransportErrorKind::WriteFailed,
1095                message: format!("Failed to send initialize: {e}"),
1096                context: TransportContext::default(),
1097                source: None,
1098            }))
1099        })?;
1100
1101    // Wait for response
1102    let response = loop {
1103        match transport.recv().await {
1104            Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
1105            Ok(Some(_)) => {} // Skip non-matching messages
1106            Ok(None) => {
1107                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1108                    message: "Connection closed during initialization".to_string(),
1109                    client_version: Some(PROTOCOL_VERSION.to_string()),
1110                    server_version: None,
1111                    source: None,
1112                })));
1113            }
1114            Err(e) => {
1115                return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1116                    message: format!("Transport error during initialization: {e}"),
1117                    client_version: Some(PROTOCOL_VERSION.to_string()),
1118                    server_version: None,
1119                    source: None,
1120                })));
1121            }
1122        }
1123    };
1124
1125    // Parse the response
1126    if let Some(error) = response.error {
1127        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1128            message: error.message,
1129            client_version: Some(PROTOCOL_VERSION.to_string()),
1130            server_version: None,
1131            source: None,
1132        })));
1133    }
1134
1135    let result: InitializeResult = response
1136        .result
1137        .map(serde_json::from_value)
1138        .transpose()?
1139        .ok_or_else(|| {
1140            McpError::HandshakeFailed(Box::new(HandshakeDetails {
1141                message: "Empty initialize result".to_string(),
1142                client_version: Some(PROTOCOL_VERSION.to_string()),
1143                server_version: None,
1144                source: None,
1145            }))
1146        })?;
1147
1148    // Validate protocol version
1149    let server_version = &result.protocol_version;
1150    if !is_version_supported(server_version) {
1151        warn!(
1152            server_version = %server_version,
1153            supported = ?SUPPORTED_PROTOCOL_VERSIONS,
1154            "Server returned unsupported protocol version"
1155        );
1156        return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
1157            message: format!(
1158                "Unsupported protocol version: server returned '{server_version}', but client only supports {SUPPORTED_PROTOCOL_VERSIONS:?}"
1159            ),
1160            client_version: Some(PROTOCOL_VERSION.to_string()),
1161            server_version: Some(server_version.clone()),
1162            source: None,
1163        })));
1164    }
1165
1166    debug!(
1167        server = %result.server_info.name,
1168        server_version = %result.server_info.version,
1169        protocol_version = %result.protocol_version,
1170        "Received initialize result with compatible protocol version"
1171    );
1172
1173    // Send initialized notification
1174    let notification = Notification::new("notifications/initialized");
1175    transport
1176        .send(Message::Notification(notification))
1177        .await
1178        .map_err(|e| {
1179            McpError::Transport(Box::new(TransportDetails {
1180                kind: TransportErrorKind::WriteFailed,
1181                message: format!("Failed to send initialized: {e}"),
1182                context: TransportContext::default(),
1183                source: None,
1184            }))
1185        })?;
1186
1187    debug!("MCP initialization complete");
1188    Ok(result)
1189}
1190
#[cfg(test)]
mod tests {
    use super::*;

    /// An `AtomicU64` counter hands out consecutive values starting from its
    /// initial value. `next_request_id` uses this pattern.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        let issued: Vec<u64> = (0..2)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();
        assert_eq!(issued, vec![1, 2]);
    }
}