mcp_protocol_sdk/client/
mcp_client.rs

1//! MCP client implementation
2//!
3//! This module provides the main MCP client that can connect to MCP servers,
4//! initialize connections, and perform operations like calling tools, reading resources,
5//! and executing prompts according to the Model Context Protocol specification.
6
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{Mutex, RwLock};
12
13use crate::core::error::{McpError, McpResult};
14use crate::protocol::{messages::*, types::*, validation::*};
15use crate::transport::traits::Transport;
16
/// Configuration for the MCP client.
///
/// Values are plain numbers and flags, so configurations can be cloned and
/// compared for equality (handy in tests and when diffing settings).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Per-request timeout in milliseconds
    pub request_timeout_ms: u64,
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Delay between retry attempts, in milliseconds
    pub retry_delay_ms: u64,
    /// Whether to validate every outgoing request before it is sent
    pub validate_requests: bool,
    /// Whether to validate every incoming response before it is returned
    pub validate_responses: bool,
}
31
32impl Default for ClientConfig {
33    fn default() -> Self {
34        Self {
35            request_timeout_ms: 30000,
36            max_retries: 3,
37            retry_delay_ms: 1000,
38            validate_requests: true,
39            validate_responses: true,
40        }
41    }
42}
43
44/// Main MCP client implementation
45pub struct McpClient {
46    /// Client information
47    info: ClientInfo,
48    /// Client capabilities
49    capabilities: ClientCapabilities,
50    /// Client configuration
51    config: ClientConfig,
52    /// Active transport
53    transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
54    /// Server capabilities (available after initialization)
55    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
56    /// Server information (available after initialization)
57    server_info: Arc<RwLock<Option<ServerInfo>>>,
58    /// Request ID counter
59    request_counter: Arc<Mutex<u64>>,
60    /// Connection state
61    connected: Arc<RwLock<bool>>,
62}
63
64impl McpClient {
65    /// Create a new MCP client with the given name and version
66    pub fn new(name: String, version: String) -> Self {
67        Self {
68            info: ClientInfo { name, version },
69            capabilities: ClientCapabilities::default(),
70            config: ClientConfig::default(),
71            transport: Arc::new(Mutex::new(None)),
72            server_capabilities: Arc::new(RwLock::new(None)),
73            server_info: Arc::new(RwLock::new(None)),
74            request_counter: Arc::new(Mutex::new(0)),
75            connected: Arc::new(RwLock::new(false)),
76        }
77    }
78
79    /// Create a new MCP client with custom configuration
80    pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
81        let mut client = Self::new(name, version);
82        client.config = config;
83        client
84    }
85
86    /// Set client capabilities
87    pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
88        self.capabilities = capabilities;
89    }
90
91    /// Get client information
92    pub fn info(&self) -> &ClientInfo {
93        &self.info
94    }
95
96    /// Get client capabilities
97    pub fn capabilities(&self) -> &ClientCapabilities {
98        &self.capabilities
99    }
100
101    /// Get client configuration
102    pub fn config(&self) -> &ClientConfig {
103        &self.config
104    }
105
106    /// Get server capabilities (if connected)
107    pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
108        let capabilities = self.server_capabilities.read().await;
109        capabilities.clone()
110    }
111
112    /// Get server information (if connected)
113    pub async fn server_info(&self) -> Option<ServerInfo> {
114        let info = self.server_info.read().await;
115        info.clone()
116    }
117
118    /// Check if the client is connected
119    pub async fn is_connected(&self) -> bool {
120        let connected = self.connected.read().await;
121        *connected
122    }
123
124    // ========================================================================
125    // Connection Management
126    // ========================================================================
127
128    /// Connect to an MCP server using the provided transport
129    pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
130    where
131        T: Transport + 'static,
132    {
133        // Set the transport
134        {
135            let mut transport_guard = self.transport.lock().await;
136            *transport_guard = Some(Box::new(transport));
137        }
138
139        // Initialize the connection
140        let init_result = self.initialize().await?;
141
142        // Mark as connected
143        {
144            let mut connected = self.connected.write().await;
145            *connected = true;
146        }
147
148        Ok(init_result)
149    }
150
151    /// Disconnect from the server
152    pub async fn disconnect(&self) -> McpResult<()> {
153        // Close the transport
154        {
155            let mut transport_guard = self.transport.lock().await;
156            if let Some(transport) = transport_guard.as_mut() {
157                transport.close().await?;
158            }
159            *transport_guard = None;
160        }
161
162        // Clear server information
163        {
164            let mut server_capabilities = self.server_capabilities.write().await;
165            *server_capabilities = None;
166        }
167        {
168            let mut server_info = self.server_info.write().await;
169            *server_info = None;
170        }
171
172        // Mark as disconnected
173        {
174            let mut connected = self.connected.write().await;
175            *connected = false;
176        }
177
178        Ok(())
179    }
180
181    /// Initialize the connection with the server
182    async fn initialize(&self) -> McpResult<InitializeResult> {
183        let params = InitializeParams::new(
184            self.info.clone(),
185            self.capabilities.clone(),
186            MCP_PROTOCOL_VERSION.to_string(),
187        );
188
189        let request = JsonRpcRequest::new(
190            Value::from(self.next_request_id().await),
191            methods::INITIALIZE.to_string(),
192            Some(params),
193        )?;
194
195        let response = self.send_request(request).await?;
196
197        if let Some(error) = response.error {
198            return Err(McpError::Protocol(format!(
199                "Initialize failed: {}",
200                error.message
201            )));
202        }
203
204        let result: InitializeResult = serde_json::from_value(
205            response
206                .result
207                .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
208        )?;
209
210        // Store server information
211        {
212            let mut server_capabilities = self.server_capabilities.write().await;
213            *server_capabilities = Some(result.capabilities.clone());
214        }
215        {
216            let mut server_info = self.server_info.write().await;
217            *server_info = Some(result.server_info.clone());
218        }
219
220        Ok(result)
221    }
222
223    // ========================================================================
224    // Tool Operations
225    // ========================================================================
226
227    /// List available tools from the server
228    pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
229        self.ensure_connected().await?;
230
231        let params = ListToolsParams { cursor };
232        let request = JsonRpcRequest::new(
233            Value::from(self.next_request_id().await),
234            methods::TOOLS_LIST.to_string(),
235            Some(params),
236        )?;
237
238        let response = self.send_request(request).await?;
239        self.handle_response(response)
240    }
241
242    /// Call a tool on the server
243    pub async fn call_tool(
244        &self,
245        name: String,
246        arguments: Option<HashMap<String, Value>>,
247    ) -> McpResult<CallToolResult> {
248        self.ensure_connected().await?;
249
250        let params = CallToolParams::new(name, arguments);
251
252        if self.config.validate_requests {
253            validate_call_tool_params(&params)?;
254        }
255
256        let request = JsonRpcRequest::new(
257            Value::from(self.next_request_id().await),
258            methods::TOOLS_CALL.to_string(),
259            Some(params),
260        )?;
261
262        let response = self.send_request(request).await?;
263        self.handle_response(response)
264    }
265
266    // ========================================================================
267    // Resource Operations
268    // ========================================================================
269
270    /// List available resources from the server
271    pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
272        self.ensure_connected().await?;
273
274        let params = ListResourcesParams { cursor };
275        let request = JsonRpcRequest::new(
276            Value::from(self.next_request_id().await),
277            methods::RESOURCES_LIST.to_string(),
278            Some(params),
279        )?;
280
281        let response = self.send_request(request).await?;
282        self.handle_response(response)
283    }
284
285    /// Read a resource from the server
286    pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
287        self.ensure_connected().await?;
288
289        let params = ReadResourceParams::new(uri);
290
291        if self.config.validate_requests {
292            validate_read_resource_params(&params)?;
293        }
294
295        let request = JsonRpcRequest::new(
296            Value::from(self.next_request_id().await),
297            methods::RESOURCES_READ.to_string(),
298            Some(params),
299        )?;
300
301        let response = self.send_request(request).await?;
302        self.handle_response(response)
303    }
304
305    /// Subscribe to resource updates
306    pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
307        self.ensure_connected().await?;
308
309        let params = SubscribeResourceParams { uri };
310        let request = JsonRpcRequest::new(
311            Value::from(self.next_request_id().await),
312            methods::RESOURCES_SUBSCRIBE.to_string(),
313            Some(params),
314        )?;
315
316        let response = self.send_request(request).await?;
317        self.handle_response(response)
318    }
319
320    /// Unsubscribe from resource updates
321    pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
322        self.ensure_connected().await?;
323
324        let params = UnsubscribeResourceParams { uri };
325        let request = JsonRpcRequest::new(
326            Value::from(self.next_request_id().await),
327            methods::RESOURCES_UNSUBSCRIBE.to_string(),
328            Some(params),
329        )?;
330
331        let response = self.send_request(request).await?;
332        self.handle_response(response)
333    }
334
335    // ========================================================================
336    // Prompt Operations
337    // ========================================================================
338
339    /// List available prompts from the server
340    pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
341        self.ensure_connected().await?;
342
343        let params = ListPromptsParams { cursor };
344        let request = JsonRpcRequest::new(
345            Value::from(self.next_request_id().await),
346            methods::PROMPTS_LIST.to_string(),
347            Some(params),
348        )?;
349
350        let response = self.send_request(request).await?;
351        self.handle_response(response)
352    }
353
354    /// Get a prompt from the server
355    pub async fn get_prompt(
356        &self,
357        name: String,
358        arguments: Option<HashMap<String, Value>>,
359    ) -> McpResult<GetPromptResult> {
360        self.ensure_connected().await?;
361
362        let params = GetPromptParams::new(name, arguments);
363
364        if self.config.validate_requests {
365            validate_get_prompt_params(&params)?;
366        }
367
368        let request = JsonRpcRequest::new(
369            Value::from(self.next_request_id().await),
370            methods::PROMPTS_GET.to_string(),
371            Some(params),
372        )?;
373
374        let response = self.send_request(request).await?;
375        self.handle_response(response)
376    }
377
378    // ========================================================================
379    // Sampling Operations (if supported by server)
380    // ========================================================================
381
382    /// Create a message using server-side sampling
383    pub async fn create_message(
384        &self,
385        params: CreateMessageParams,
386    ) -> McpResult<CreateMessageResult> {
387        self.ensure_connected().await?;
388
389        // Check if server supports sampling
390        {
391            let server_capabilities = self.server_capabilities.read().await;
392            if let Some(capabilities) = server_capabilities.as_ref() {
393                if capabilities.sampling.is_none() {
394                    return Err(McpError::Protocol(
395                        "Server does not support sampling".to_string(),
396                    ));
397                }
398            } else {
399                return Err(McpError::Protocol("Not connected to server".to_string()));
400            }
401        }
402
403        if self.config.validate_requests {
404            validate_create_message_params(&params)?;
405        }
406
407        let request = JsonRpcRequest::new(
408            Value::from(self.next_request_id().await),
409            methods::SAMPLING_CREATE_MESSAGE.to_string(),
410            Some(params),
411        )?;
412
413        let response = self.send_request(request).await?;
414        self.handle_response(response)
415    }
416
417    // ========================================================================
418    // Utility Operations
419    // ========================================================================
420
421    /// Send a ping to the server
422    pub async fn ping(&self) -> McpResult<PingResult> {
423        self.ensure_connected().await?;
424
425        let request = JsonRpcRequest::new(
426            Value::from(self.next_request_id().await),
427            methods::PING.to_string(),
428            Some(PingParams {}),
429        )?;
430
431        let response = self.send_request(request).await?;
432        self.handle_response(response)
433    }
434
435    /// Set the logging level on the server
436    pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
437        self.ensure_connected().await?;
438
439        let params = SetLoggingLevelParams { level };
440        let request = JsonRpcRequest::new(
441            Value::from(self.next_request_id().await),
442            methods::LOGGING_SET_LEVEL.to_string(),
443            Some(params),
444        )?;
445
446        let response = self.send_request(request).await?;
447        self.handle_response(response)
448    }
449
450    // ========================================================================
451    // Notification Handling
452    // ========================================================================
453
454    /// Receive notifications from the server
455    pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
456        let mut transport_guard = self.transport.lock().await;
457        if let Some(transport) = transport_guard.as_mut() {
458            transport.receive_notification().await
459        } else {
460            Err(McpError::Transport("Not connected".to_string()))
461        }
462    }
463
464    // ========================================================================
465    // Helper Methods
466    // ========================================================================
467
468    /// Send a request and get a response
469    async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
470        if self.config.validate_requests {
471            validate_jsonrpc_request(&request)?;
472            validate_mcp_request(&request.method, request.params.as_ref())?;
473        }
474
475        let mut transport_guard = self.transport.lock().await;
476        if let Some(transport) = transport_guard.as_mut() {
477            let response = transport.send_request(request).await?;
478
479            if self.config.validate_responses {
480                validate_jsonrpc_response(&response)?;
481            }
482
483            Ok(response)
484        } else {
485            Err(McpError::Transport("Not connected".to_string()))
486        }
487    }
488
489    /// Handle a JSON-RPC response and extract the result
490    fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
491    where
492        T: serde::de::DeserializeOwned,
493    {
494        if let Some(error) = response.error {
495            return Err(McpError::Protocol(format!(
496                "Server error: {}",
497                error.message
498            )));
499        }
500
501        let result = response
502            .result
503            .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
504
505        serde_json::from_value(result).map_err(|e| McpError::Serialization(e))
506    }
507
508    /// Ensure the client is connected
509    async fn ensure_connected(&self) -> McpResult<()> {
510        if !self.is_connected().await {
511            return Err(McpError::Connection("Not connected to server".to_string()));
512        }
513        Ok(())
514    }
515
516    /// Get the next request ID
517    async fn next_request_id(&self) -> u64 {
518        let mut counter = self.request_counter.lock().await;
519        *counter += 1;
520        *counter
521    }
522}
523
524/// Client builder for easier construction
525pub struct McpClientBuilder {
526    name: String,
527    version: String,
528    capabilities: ClientCapabilities,
529    config: ClientConfig,
530}
531
532impl McpClientBuilder {
533    /// Create a new client builder
534    pub fn new(name: String, version: String) -> Self {
535        Self {
536            name,
537            version,
538            capabilities: ClientCapabilities::default(),
539            config: ClientConfig::default(),
540        }
541    }
542
543    /// Set client capabilities
544    pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
545        self.capabilities = capabilities;
546        self
547    }
548
549    /// Set client configuration
550    pub fn config(mut self, config: ClientConfig) -> Self {
551        self.config = config;
552        self
553    }
554
555    /// Set request timeout
556    pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
557        self.config.request_timeout_ms = timeout_ms;
558        self
559    }
560
561    /// Set maximum retries
562    pub fn max_retries(mut self, retries: u32) -> Self {
563        self.config.max_retries = retries;
564        self
565    }
566
567    /// Enable or disable request validation
568    pub fn validate_requests(mut self, validate: bool) -> Self {
569        self.config.validate_requests = validate;
570        self
571    }
572
573    /// Enable or disable response validation
574    pub fn validate_responses(mut self, validate: bool) -> Self {
575        self.config.validate_responses = validate;
576        self
577    }
578
579    /// Build the client
580    pub fn build(self) -> McpClient {
581        let mut client = McpClient::new(self.name, self.version);
582        client.set_capabilities(self.capabilities);
583        client.config = self.config;
584        client
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Transport double that replays a fixed list of responses, one per
    /// request, and fails once the list is exhausted.
    struct MockTransport {
        responses: Vec<JsonRpcResponse>,
        current: usize,
    }

    impl MockTransport {
        fn new(responses: Vec<JsonRpcResponse>) -> Self {
            Self {
                responses,
                current: 0,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            match self.responses.get(self.current) {
                Some(next) => {
                    let response = next.clone();
                    self.current += 1;
                    Ok(response)
                }
                None => Err(McpError::Transport("No more responses".to_string())),
            }
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Successful reply to the client's first (id 1) initialize request.
    fn initialize_response() -> JsonRpcResponse {
        let init_result = InitializeResult::new(
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            },
            ServerCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );
        JsonRpcResponse::success(Value::from(1), init_result).unwrap()
    }

    /// Fresh, disconnected client with the fixed test identity.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = test_client();
        let info = client.info();
        assert_eq!(info.name, "test-client");
        assert_eq!(info.version, "1.0.0");
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
            .request_timeout(5000)
            .max_retries(5)
            .validate_requests(false)
            .build();

        let config = client.config();
        assert_eq!(config.request_timeout_ms, 5000);
        assert_eq!(config.max_retries, 5);
        assert!(!config.validate_requests);
    }

    #[tokio::test]
    async fn test_mock_connection() {
        let transport = MockTransport::new(vec![initialize_response()]);
        let mut client = test_client();

        let result = client.connect(transport).await.unwrap();

        assert_eq!(result.server_info.name, "test-server");
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let transport = MockTransport::new(vec![initialize_response()]);
        let mut client = test_client();
        client.connect(transport).await.unwrap();
        assert!(client.is_connected().await);

        client.disconnect().await.unwrap();

        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }
}