mcp_protocol_sdk/client/
mcp_client.rs

1//! MCP client implementation
2//!
3//! This module provides the main MCP client that can connect to MCP servers,
4//! initialize connections, and perform operations like calling tools, reading resources,
5//! and executing prompts according to the Model Context Protocol specification.
6
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, types::*, validation::*};
14use crate::transport::traits::Transport;
15
/// Configuration for the MCP client.
///
/// Use [`Default`] for the stock values, or adjust individual fields
/// (directly or through [`McpClientBuilder`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Request timeout in milliseconds.
    ///
    /// NOTE(review): `McpClient` in this module does not currently read this
    /// value when sending requests; confirm whether it should be enforced.
    pub request_timeout_ms: u64,
    /// Maximum number of retry attempts.
    ///
    /// NOTE(review): not currently consulted by `McpClient`'s request path.
    pub max_retries: u32,
    /// Delay between retry attempts, in milliseconds.
    ///
    /// NOTE(review): not currently consulted by `McpClient`'s request path.
    pub retry_delay_ms: u64,
    /// Whether to validate outgoing requests (JSON-RPC shape and MCP
    /// parameters) before they are sent.
    pub validate_requests: bool,
    /// Whether to validate incoming JSON-RPC responses after they are received.
    pub validate_responses: bool,
}
30
31impl Default for ClientConfig {
32    fn default() -> Self {
33        Self {
34            request_timeout_ms: 30000,
35            max_retries: 3,
36            retry_delay_ms: 1000,
37            validate_requests: true,
38            validate_responses: true,
39        }
40    }
41}
42
43/// Main MCP client implementation
44pub struct McpClient {
45    /// Client information
46    info: ClientInfo,
47    /// Client capabilities
48    capabilities: ClientCapabilities,
49    /// Client configuration
50    config: ClientConfig,
51    /// Active transport
52    transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
53    /// Server capabilities (available after initialization)
54    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
55    /// Server information (available after initialization)
56    server_info: Arc<RwLock<Option<ServerInfo>>>,
57    /// Request ID counter
58    request_counter: Arc<Mutex<u64>>,
59    /// Connection state
60    connected: Arc<RwLock<bool>>,
61}
62
63impl McpClient {
64    /// Create a new MCP client with the given name and version
65    pub fn new(name: String, version: String) -> Self {
66        Self {
67            info: ClientInfo { name, version },
68            capabilities: ClientCapabilities::default(),
69            config: ClientConfig::default(),
70            transport: Arc::new(Mutex::new(None)),
71            server_capabilities: Arc::new(RwLock::new(None)),
72            server_info: Arc::new(RwLock::new(None)),
73            request_counter: Arc::new(Mutex::new(0)),
74            connected: Arc::new(RwLock::new(false)),
75        }
76    }
77
78    /// Create a new MCP client with custom configuration
79    pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
80        let mut client = Self::new(name, version);
81        client.config = config;
82        client
83    }
84
85    /// Set client capabilities
86    pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
87        self.capabilities = capabilities;
88    }
89
90    /// Get client information
91    pub fn info(&self) -> &ClientInfo {
92        &self.info
93    }
94
95    /// Get client capabilities
96    pub fn capabilities(&self) -> &ClientCapabilities {
97        &self.capabilities
98    }
99
100    /// Get client configuration
101    pub fn config(&self) -> &ClientConfig {
102        &self.config
103    }
104
105    /// Get server capabilities (if connected)
106    pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
107        let capabilities = self.server_capabilities.read().await;
108        capabilities.clone()
109    }
110
111    /// Get server information (if connected)
112    pub async fn server_info(&self) -> Option<ServerInfo> {
113        let info = self.server_info.read().await;
114        info.clone()
115    }
116
117    /// Check if the client is connected
118    pub async fn is_connected(&self) -> bool {
119        let connected = self.connected.read().await;
120        *connected
121    }
122
123    // ========================================================================
124    // Connection Management
125    // ========================================================================
126
127    /// Connect to an MCP server using the provided transport
128    pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
129    where
130        T: Transport + 'static,
131    {
132        // Set the transport
133        {
134            let mut transport_guard = self.transport.lock().await;
135            *transport_guard = Some(Box::new(transport));
136        }
137
138        // Initialize the connection
139        let init_result = self.initialize().await?;
140
141        // Mark as connected
142        {
143            let mut connected = self.connected.write().await;
144            *connected = true;
145        }
146
147        Ok(init_result)
148    }
149
150    /// Disconnect from the server
151    pub async fn disconnect(&self) -> McpResult<()> {
152        // Close the transport
153        {
154            let mut transport_guard = self.transport.lock().await;
155            if let Some(transport) = transport_guard.as_mut() {
156                transport.close().await?;
157            }
158            *transport_guard = None;
159        }
160
161        // Clear server information
162        {
163            let mut server_capabilities = self.server_capabilities.write().await;
164            *server_capabilities = None;
165        }
166        {
167            let mut server_info = self.server_info.write().await;
168            *server_info = None;
169        }
170
171        // Mark as disconnected
172        {
173            let mut connected = self.connected.write().await;
174            *connected = false;
175        }
176
177        Ok(())
178    }
179
180    /// Initialize the connection with the server
181    async fn initialize(&self) -> McpResult<InitializeResult> {
182        let params = InitializeParams::new(
183            self.info.clone(),
184            self.capabilities.clone(),
185            MCP_PROTOCOL_VERSION.to_string(),
186        );
187
188        let request = JsonRpcRequest::new(
189            Value::from(self.next_request_id().await),
190            methods::INITIALIZE.to_string(),
191            Some(params),
192        )?;
193
194        let response = self.send_request(request).await?;
195
196        if let Some(error) = response.error {
197            return Err(McpError::Protocol(format!(
198                "Initialize failed: {}",
199                error.message
200            )));
201        }
202
203        let result: InitializeResult = serde_json::from_value(
204            response
205                .result
206                .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
207        )?;
208
209        // Store server information
210        {
211            let mut server_capabilities = self.server_capabilities.write().await;
212            *server_capabilities = Some(result.capabilities.clone());
213        }
214        {
215            let mut server_info = self.server_info.write().await;
216            *server_info = Some(result.server_info.clone());
217        }
218
219        Ok(result)
220    }
221
222    // ========================================================================
223    // Tool Operations
224    // ========================================================================
225
226    /// List available tools from the server
227    pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
228        self.ensure_connected().await?;
229
230        let params = ListToolsParams { cursor };
231        let request = JsonRpcRequest::new(
232            Value::from(self.next_request_id().await),
233            methods::TOOLS_LIST.to_string(),
234            Some(params),
235        )?;
236
237        let response = self.send_request(request).await?;
238        self.handle_response(response)
239    }
240
241    /// Call a tool on the server
242    pub async fn call_tool(
243        &self,
244        name: String,
245        arguments: Option<HashMap<String, Value>>,
246    ) -> McpResult<CallToolResult> {
247        self.ensure_connected().await?;
248
249        let params = CallToolParams::new(name, arguments);
250
251        if self.config.validate_requests {
252            validate_call_tool_params(&params)?;
253        }
254
255        let request = JsonRpcRequest::new(
256            Value::from(self.next_request_id().await),
257            methods::TOOLS_CALL.to_string(),
258            Some(params),
259        )?;
260
261        let response = self.send_request(request).await?;
262        self.handle_response(response)
263    }
264
265    // ========================================================================
266    // Resource Operations
267    // ========================================================================
268
269    /// List available resources from the server
270    pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
271        self.ensure_connected().await?;
272
273        let params = ListResourcesParams { cursor };
274        let request = JsonRpcRequest::new(
275            Value::from(self.next_request_id().await),
276            methods::RESOURCES_LIST.to_string(),
277            Some(params),
278        )?;
279
280        let response = self.send_request(request).await?;
281        self.handle_response(response)
282    }
283
284    /// Read a resource from the server
285    pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
286        self.ensure_connected().await?;
287
288        let params = ReadResourceParams::new(uri);
289
290        if self.config.validate_requests {
291            validate_read_resource_params(&params)?;
292        }
293
294        let request = JsonRpcRequest::new(
295            Value::from(self.next_request_id().await),
296            methods::RESOURCES_READ.to_string(),
297            Some(params),
298        )?;
299
300        let response = self.send_request(request).await?;
301        self.handle_response(response)
302    }
303
304    /// Subscribe to resource updates
305    pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
306        self.ensure_connected().await?;
307
308        let params = SubscribeResourceParams { uri };
309        let request = JsonRpcRequest::new(
310            Value::from(self.next_request_id().await),
311            methods::RESOURCES_SUBSCRIBE.to_string(),
312            Some(params),
313        )?;
314
315        let response = self.send_request(request).await?;
316        self.handle_response(response)
317    }
318
319    /// Unsubscribe from resource updates
320    pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
321        self.ensure_connected().await?;
322
323        let params = UnsubscribeResourceParams { uri };
324        let request = JsonRpcRequest::new(
325            Value::from(self.next_request_id().await),
326            methods::RESOURCES_UNSUBSCRIBE.to_string(),
327            Some(params),
328        )?;
329
330        let response = self.send_request(request).await?;
331        self.handle_response(response)
332    }
333
334    // ========================================================================
335    // Prompt Operations
336    // ========================================================================
337
338    /// List available prompts from the server
339    pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
340        self.ensure_connected().await?;
341
342        let params = ListPromptsParams { cursor };
343        let request = JsonRpcRequest::new(
344            Value::from(self.next_request_id().await),
345            methods::PROMPTS_LIST.to_string(),
346            Some(params),
347        )?;
348
349        let response = self.send_request(request).await?;
350        self.handle_response(response)
351    }
352
353    /// Get a prompt from the server
354    pub async fn get_prompt(
355        &self,
356        name: String,
357        arguments: Option<HashMap<String, Value>>,
358    ) -> McpResult<GetPromptResult> {
359        self.ensure_connected().await?;
360
361        let params = GetPromptParams::new(name, arguments);
362
363        if self.config.validate_requests {
364            validate_get_prompt_params(&params)?;
365        }
366
367        let request = JsonRpcRequest::new(
368            Value::from(self.next_request_id().await),
369            methods::PROMPTS_GET.to_string(),
370            Some(params),
371        )?;
372
373        let response = self.send_request(request).await?;
374        self.handle_response(response)
375    }
376
377    // ========================================================================
378    // Sampling Operations (if supported by server)
379    // ========================================================================
380
381    /// Create a message using server-side sampling
382    pub async fn create_message(
383        &self,
384        params: CreateMessageParams,
385    ) -> McpResult<CreateMessageResult> {
386        self.ensure_connected().await?;
387
388        // Check if server supports sampling
389        {
390            let server_capabilities = self.server_capabilities.read().await;
391            if let Some(capabilities) = server_capabilities.as_ref() {
392                if capabilities.sampling.is_none() {
393                    return Err(McpError::Protocol(
394                        "Server does not support sampling".to_string(),
395                    ));
396                }
397            } else {
398                return Err(McpError::Protocol("Not connected to server".to_string()));
399            }
400        }
401
402        if self.config.validate_requests {
403            validate_create_message_params(&params)?;
404        }
405
406        let request = JsonRpcRequest::new(
407            Value::from(self.next_request_id().await),
408            methods::SAMPLING_CREATE_MESSAGE.to_string(),
409            Some(params),
410        )?;
411
412        let response = self.send_request(request).await?;
413        self.handle_response(response)
414    }
415
416    // ========================================================================
417    // Utility Operations
418    // ========================================================================
419
420    /// Send a ping to the server
421    pub async fn ping(&self) -> McpResult<PingResult> {
422        self.ensure_connected().await?;
423
424        let request = JsonRpcRequest::new(
425            Value::from(self.next_request_id().await),
426            methods::PING.to_string(),
427            Some(PingParams {}),
428        )?;
429
430        let response = self.send_request(request).await?;
431        self.handle_response(response)
432    }
433
434    /// Set the logging level on the server
435    pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
436        self.ensure_connected().await?;
437
438        let params = SetLoggingLevelParams { level };
439        let request = JsonRpcRequest::new(
440            Value::from(self.next_request_id().await),
441            methods::LOGGING_SET_LEVEL.to_string(),
442            Some(params),
443        )?;
444
445        let response = self.send_request(request).await?;
446        self.handle_response(response)
447    }
448
449    // ========================================================================
450    // Notification Handling
451    // ========================================================================
452
453    /// Receive notifications from the server
454    pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
455        let mut transport_guard = self.transport.lock().await;
456        if let Some(transport) = transport_guard.as_mut() {
457            transport.receive_notification().await
458        } else {
459            Err(McpError::Transport("Not connected".to_string()))
460        }
461    }
462
463    // ========================================================================
464    // Helper Methods
465    // ========================================================================
466
467    /// Send a request and get a response
468    async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
469        if self.config.validate_requests {
470            validate_jsonrpc_request(&request)?;
471            validate_mcp_request(&request.method, request.params.as_ref())?;
472        }
473
474        let mut transport_guard = self.transport.lock().await;
475        if let Some(transport) = transport_guard.as_mut() {
476            let response = transport.send_request(request).await?;
477
478            if self.config.validate_responses {
479                validate_jsonrpc_response(&response)?;
480            }
481
482            Ok(response)
483        } else {
484            Err(McpError::Transport("Not connected".to_string()))
485        }
486    }
487
488    /// Handle a JSON-RPC response and extract the result
489    fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
490    where
491        T: serde::de::DeserializeOwned,
492    {
493        if let Some(error) = response.error {
494            return Err(McpError::Protocol(format!(
495                "Server error: {}",
496                error.message
497            )));
498        }
499
500        let result = response
501            .result
502            .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
503
504        serde_json::from_value(result).map_err(McpError::Serialization)
505    }
506
507    /// Ensure the client is connected
508    async fn ensure_connected(&self) -> McpResult<()> {
509        if !self.is_connected().await {
510            return Err(McpError::Connection("Not connected to server".to_string()));
511        }
512        Ok(())
513    }
514
515    /// Get the next request ID
516    async fn next_request_id(&self) -> u64 {
517        let mut counter = self.request_counter.lock().await;
518        *counter += 1;
519        *counter
520    }
521}
522
523/// Client builder for easier construction
524pub struct McpClientBuilder {
525    name: String,
526    version: String,
527    capabilities: ClientCapabilities,
528    config: ClientConfig,
529}
530
531impl McpClientBuilder {
532    /// Create a new client builder
533    pub fn new(name: String, version: String) -> Self {
534        Self {
535            name,
536            version,
537            capabilities: ClientCapabilities::default(),
538            config: ClientConfig::default(),
539        }
540    }
541
542    /// Set client capabilities
543    pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
544        self.capabilities = capabilities;
545        self
546    }
547
548    /// Set client configuration
549    pub fn config(mut self, config: ClientConfig) -> Self {
550        self.config = config;
551        self
552    }
553
554    /// Set request timeout
555    pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
556        self.config.request_timeout_ms = timeout_ms;
557        self
558    }
559
560    /// Set maximum retries
561    pub fn max_retries(mut self, retries: u32) -> Self {
562        self.config.max_retries = retries;
563        self
564    }
565
566    /// Enable or disable request validation
567    pub fn validate_requests(mut self, validate: bool) -> Self {
568        self.config.validate_requests = validate;
569        self
570    }
571
572    /// Enable or disable response validation
573    pub fn validate_responses(mut self, validate: bool) -> Self {
574        self.config.validate_responses = validate;
575        self
576    }
577
578    /// Build the client
579    pub fn build(self) -> McpClient {
580        let mut client = McpClient::new(self.name, self.version);
581        client.set_capabilities(self.capabilities);
582        client.config = self.config;
583        client
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Test transport that answers each request with the next scripted
    /// response and fails once the script runs out.
    struct MockTransport {
        script: Vec<JsonRpcResponse>,
        cursor: usize,
    }

    impl MockTransport {
        fn new(responses: Vec<JsonRpcResponse>) -> Self {
            MockTransport {
                script: responses,
                cursor: 0,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            match self.script.get(self.cursor).cloned() {
                Some(reply) => {
                    self.cursor += 1;
                    Ok(reply)
                }
                None => Err(McpError::Transport("No more responses".to_string())),
            }
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// A transport primed with a single successful `initialize` reply from a
    /// server called `test-server`.
    fn handshake_transport() -> MockTransport {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };
        let init_result = InitializeResult::new(
            server_info,
            ServerCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );
        let reply = JsonRpcResponse::success(Value::from(1), init_result).unwrap();
        MockTransport::new(vec![reply])
    }

    /// A fresh, unconnected client named `test-client` at version `1.0.0`.
    fn fresh_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = fresh_client();
        let info = client.info();
        assert_eq!(info.name, "test-client");
        assert_eq!(info.version, "1.0.0");
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
            .request_timeout(5000)
            .max_retries(5)
            .validate_requests(false)
            .build();

        let config = client.config();
        assert_eq!(config.request_timeout_ms, 5000);
        assert_eq!(config.max_retries, 5);
        assert!(!config.validate_requests);
    }

    #[tokio::test]
    async fn test_mock_connection() {
        let mut client = fresh_client();
        let outcome = client.connect(handshake_transport()).await.unwrap();

        assert_eq!(outcome.server_info.name, "test-server");
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let mut client = fresh_client();
        client.connect(handshake_transport()).await.unwrap();
        assert!(client.is_connected().await);

        client.disconnect().await.unwrap();

        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }
}