mcp_protocol_sdk/client/
mcp_client.rs

1//! MCP client implementation
2//!
3//! This module provides the main MCP client that can connect to MCP servers,
4//! initialize connections, and perform operations like calling tools, reading resources,
5//! and executing prompts according to the Model Context Protocol specification.
6
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, methods, types::*, validation::*};
14use crate::transport::traits::Transport;
15
/// Tunable settings for an [`McpClient`].
///
/// All durations are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Per-request timeout, in milliseconds
    pub request_timeout_ms: u64,
    /// Upper bound on the number of retry attempts
    pub max_retries: u32,
    /// Pause between retry attempts, in milliseconds
    pub retry_delay_ms: u64,
    /// When `true`, outgoing requests are validated before being sent
    pub validate_requests: bool,
    /// When `true`, incoming responses are validated after being received
    pub validate_responses: bool,
}

impl Default for ClientConfig {
    /// Defaults: 30 s request timeout, 3 retries spaced 1 s apart, and
    /// validation enabled in both directions.
    fn default() -> Self {
        ClientConfig {
            validate_requests: true,
            validate_responses: true,
            request_timeout_ms: 30_000,
            retry_delay_ms: 1_000,
            max_retries: 3,
        }
    }
}
42
/// Main MCP client implementation
///
/// Bundles the client's own identity and settings (`info`, `capabilities`,
/// `config`) with the state of the current server session: the transport in
/// use, the metadata the server returned from the initialize request, a
/// request-ID counter and a connected flag.
///
/// The session state sits behind async locks so that methods taking `&self`
/// can update it.
pub struct McpClient {
    /// Client information
    info: ClientInfo,
    /// Client capabilities
    capabilities: ClientCapabilities,
    /// Client configuration
    config: ClientConfig,
    /// Active transport; `None` until `connect` installs one and again after
    /// `disconnect`
    transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
    /// Server capabilities (available after initialization)
    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
    /// Server information (available after initialization)
    server_info: Arc<RwLock<Option<ServerInfo>>>,
    /// Request ID counter; starts at 0 and is incremented before each use,
    /// so the first request ID handed out is 1
    request_counter: Arc<Mutex<u64>>,
    /// Connection state; set to `true` once `connect` has completed the
    /// initialize exchange and back to `false` by `disconnect`
    connected: Arc<RwLock<bool>>,
}
62
63impl McpClient {
64    /// Create a new MCP client with the given name and version
65    pub fn new(name: String, version: String) -> Self {
66        Self {
67            info: ClientInfo { name, version },
68            capabilities: ClientCapabilities::default(),
69            config: ClientConfig::default(),
70            transport: Arc::new(Mutex::new(None)),
71            server_capabilities: Arc::new(RwLock::new(None)),
72            server_info: Arc::new(RwLock::new(None)),
73            request_counter: Arc::new(Mutex::new(0)),
74            connected: Arc::new(RwLock::new(false)),
75        }
76    }
77
78    /// Create a new MCP client with custom configuration
79    pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
80        let mut client = Self::new(name, version);
81        client.config = config;
82        client
83    }
84
85    /// Set client capabilities
86    pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
87        self.capabilities = capabilities;
88    }
89
90    /// Get client information
91    pub fn info(&self) -> &ClientInfo {
92        &self.info
93    }
94
95    /// Get client capabilities
96    pub fn capabilities(&self) -> &ClientCapabilities {
97        &self.capabilities
98    }
99
100    /// Get client configuration
101    pub fn config(&self) -> &ClientConfig {
102        &self.config
103    }
104
105    /// Get server capabilities (if connected)
106    pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
107        let capabilities = self.server_capabilities.read().await;
108        capabilities.clone()
109    }
110
111    /// Get server information (if connected)
112    pub async fn server_info(&self) -> Option<ServerInfo> {
113        let info = self.server_info.read().await;
114        info.clone()
115    }
116
117    /// Check if the client is connected
118    pub async fn is_connected(&self) -> bool {
119        let connected = self.connected.read().await;
120        *connected
121    }
122
123    // ========================================================================
124    // Connection Management
125    // ========================================================================
126
127    /// Connect to an MCP server using the provided transport
128    pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
129    where
130        T: Transport + 'static,
131    {
132        // Set the transport
133        {
134            let mut transport_guard = self.transport.lock().await;
135            *transport_guard = Some(Box::new(transport));
136        }
137
138        // Initialize the connection
139        let init_result = self.initialize().await?;
140
141        // Mark as connected
142        {
143            let mut connected = self.connected.write().await;
144            *connected = true;
145        }
146
147        Ok(init_result)
148    }
149
150    /// Disconnect from the server
151    pub async fn disconnect(&self) -> McpResult<()> {
152        // Close the transport
153        {
154            let mut transport_guard = self.transport.lock().await;
155            if let Some(transport) = transport_guard.as_mut() {
156                transport.close().await?;
157            }
158            *transport_guard = None;
159        }
160
161        // Clear server information
162        {
163            let mut server_capabilities = self.server_capabilities.write().await;
164            *server_capabilities = None;
165        }
166        {
167            let mut server_info = self.server_info.write().await;
168            *server_info = None;
169        }
170
171        // Mark as disconnected
172        {
173            let mut connected = self.connected.write().await;
174            *connected = false;
175        }
176
177        Ok(())
178    }
179
180    /// Initialize the connection with the server
181    async fn initialize(&self) -> McpResult<InitializeResult> {
182        let params = InitializeParams::new(
183            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
184            self.capabilities.clone(),
185            self.info.clone(),
186        );
187
188        let request = JsonRpcRequest::new(
189            Value::from(self.next_request_id().await),
190            methods::INITIALIZE.to_string(),
191            Some(params),
192        )?;
193
194        let response = self.send_request(request).await?;
195
196        // The send_request method will return an error if there was a JSON-RPC error
197        // so we can safely extract the result here
198
199        let result: InitializeResult = serde_json::from_value(
200            response
201                .result
202                .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
203        )?;
204
205        // Store server information
206        {
207            let mut server_capabilities = self.server_capabilities.write().await;
208            *server_capabilities = Some(result.capabilities.clone());
209        }
210        {
211            let mut server_info = self.server_info.write().await;
212            *server_info = Some(result.server_info.clone());
213        }
214
215        Ok(result)
216    }
217
218    // ========================================================================
219    // Tool Operations
220    // ========================================================================
221
222    /// List available tools from the server
223    pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
224        self.ensure_connected().await?;
225
226        let params = ListToolsParams { cursor, meta: None };
227        let request = JsonRpcRequest::new(
228            Value::from(self.next_request_id().await),
229            methods::TOOLS_LIST.to_string(),
230            Some(params),
231        )?;
232
233        let response = self.send_request(request).await?;
234        self.handle_response(response)
235    }
236
237    /// Call a tool on the server
238    pub async fn call_tool(
239        &self,
240        name: String,
241        arguments: Option<HashMap<String, Value>>,
242    ) -> McpResult<CallToolResult> {
243        self.ensure_connected().await?;
244
245        let params = if let Some(args) = arguments {
246            CallToolParams::new_with_arguments(name, args)
247        } else {
248            CallToolParams::new(name)
249        };
250
251        if self.config.validate_requests {
252            validate_call_tool_params(&params)?;
253        }
254
255        let request = JsonRpcRequest::new(
256            Value::from(self.next_request_id().await),
257            methods::TOOLS_CALL.to_string(),
258            Some(params),
259        )?;
260
261        let response = self.send_request(request).await?;
262        self.handle_response(response)
263    }
264
265    // ========================================================================
266    // Resource Operations
267    // ========================================================================
268
269    /// List available resources from the server
270    pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
271        self.ensure_connected().await?;
272
273        let params = ListResourcesParams { cursor, meta: None };
274        let request = JsonRpcRequest::new(
275            Value::from(self.next_request_id().await),
276            methods::RESOURCES_LIST.to_string(),
277            Some(params),
278        )?;
279
280        let response = self.send_request(request).await?;
281        self.handle_response(response)
282    }
283
284    /// Read a resource from the server
285    pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
286        self.ensure_connected().await?;
287
288        let params = ReadResourceParams::new(uri);
289
290        if self.config.validate_requests {
291            validate_read_resource_params(&params)?;
292        }
293
294        let request = JsonRpcRequest::new(
295            Value::from(self.next_request_id().await),
296            methods::RESOURCES_READ.to_string(),
297            Some(params),
298        )?;
299
300        let response = self.send_request(request).await?;
301        self.handle_response(response)
302    }
303
304    /// Subscribe to resource updates
305    pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
306        self.ensure_connected().await?;
307
308        let params = SubscribeResourceParams { uri, meta: None };
309        let request = JsonRpcRequest::new(
310            Value::from(self.next_request_id().await),
311            methods::RESOURCES_SUBSCRIBE.to_string(),
312            Some(params),
313        )?;
314
315        let response = self.send_request(request).await?;
316        self.handle_response(response)
317    }
318
319    /// Unsubscribe from resource updates
320    pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
321        self.ensure_connected().await?;
322
323        let params = UnsubscribeResourceParams { uri, meta: None };
324        let request = JsonRpcRequest::new(
325            Value::from(self.next_request_id().await),
326            methods::RESOURCES_UNSUBSCRIBE.to_string(),
327            Some(params),
328        )?;
329
330        let response = self.send_request(request).await?;
331        self.handle_response(response)
332    }
333
334    // ========================================================================
335    // Prompt Operations
336    // ========================================================================
337
338    /// List available prompts from the server
339    pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
340        self.ensure_connected().await?;
341
342        let params = ListPromptsParams { cursor, meta: None };
343        let request = JsonRpcRequest::new(
344            Value::from(self.next_request_id().await),
345            methods::PROMPTS_LIST.to_string(),
346            Some(params),
347        )?;
348
349        let response = self.send_request(request).await?;
350        self.handle_response(response)
351    }
352
353    /// Get a prompt from the server
354    pub async fn get_prompt(
355        &self,
356        name: String,
357        arguments: Option<HashMap<String, String>>,
358    ) -> McpResult<GetPromptResult> {
359        self.ensure_connected().await?;
360
361        let params = if let Some(args) = arguments {
362            GetPromptParams::new_with_arguments(name, args)
363        } else {
364            GetPromptParams::new(name)
365        };
366
367        if self.config.validate_requests {
368            validate_get_prompt_params(&params)?;
369        }
370
371        let request = JsonRpcRequest::new(
372            Value::from(self.next_request_id().await),
373            methods::PROMPTS_GET.to_string(),
374            Some(params),
375        )?;
376
377        let response = self.send_request(request).await?;
378        self.handle_response(response)
379    }
380
381    // ========================================================================
382    // Sampling Operations (if supported by server)
383    // ========================================================================
384
385    /// Create a message using server-side sampling
386    pub async fn create_message(
387        &self,
388        params: CreateMessageParams,
389    ) -> McpResult<CreateMessageResult> {
390        self.ensure_connected().await?;
391
392        // Check if server supports sampling
393        {
394            let server_capabilities = self.server_capabilities.read().await;
395            if let Some(capabilities) = server_capabilities.as_ref() {
396                if capabilities.sampling.is_none() {
397                    return Err(McpError::Protocol(
398                        "Server does not support sampling".to_string(),
399                    ));
400                }
401            } else {
402                return Err(McpError::Protocol("Not connected to server".to_string()));
403            }
404        }
405
406        if self.config.validate_requests {
407            validate_create_message_params(&params)?;
408        }
409
410        let request = JsonRpcRequest::new(
411            Value::from(self.next_request_id().await),
412            methods::SAMPLING_CREATE_MESSAGE.to_string(),
413            Some(params),
414        )?;
415
416        let response = self.send_request(request).await?;
417        self.handle_response(response)
418    }
419
420    // ========================================================================
421    // Utility Operations
422    // ========================================================================
423
424    /// Send a ping to the server
425    pub async fn ping(&self) -> McpResult<PingResult> {
426        self.ensure_connected().await?;
427
428        let request = JsonRpcRequest::new(
429            Value::from(self.next_request_id().await),
430            methods::PING.to_string(),
431            Some(PingParams { meta: None }),
432        )?;
433
434        let response = self.send_request(request).await?;
435        self.handle_response(response)
436    }
437
438    /// Set the logging level on the server
439    pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
440        self.ensure_connected().await?;
441
442        let params = SetLoggingLevelParams { level, meta: None };
443        let request = JsonRpcRequest::new(
444            Value::from(self.next_request_id().await),
445            methods::LOGGING_SET_LEVEL.to_string(),
446            Some(params),
447        )?;
448
449        let response = self.send_request(request).await?;
450        self.handle_response(response)
451    }
452
453    // ========================================================================
454    // Notification Handling
455    // ========================================================================
456
457    /// Receive notifications from the server
458    pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
459        let mut transport_guard = self.transport.lock().await;
460        if let Some(transport) = transport_guard.as_mut() {
461            transport.receive_notification().await
462        } else {
463            Err(McpError::Transport("Not connected".to_string()))
464        }
465    }
466
467    // ========================================================================
468    // Helper Methods
469    // ========================================================================
470
471    /// Send a request and get a response
472    async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
473        if self.config.validate_requests {
474            validate_jsonrpc_request(&request)?;
475            validate_mcp_request(&request.method, request.params.as_ref())?;
476        }
477
478        let mut transport_guard = self.transport.lock().await;
479        if let Some(transport) = transport_guard.as_mut() {
480            let response = transport.send_request(request).await?;
481
482            if self.config.validate_responses {
483                validate_jsonrpc_response(&response)?;
484            }
485
486            Ok(response)
487        } else {
488            Err(McpError::Transport("Not connected".to_string()))
489        }
490    }
491
492    /// Handle a JSON-RPC response and extract the result
493    fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
494    where
495        T: serde::de::DeserializeOwned,
496    {
497        // JsonRpcResponse only contains successful responses
498        // Errors are handled separately by the transport layer
499        let result = response
500            .result
501            .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
502
503        serde_json::from_value(result).map_err(|e| McpError::Serialization(e.to_string()))
504    }
505
506    /// Ensure the client is connected
507    async fn ensure_connected(&self) -> McpResult<()> {
508        if !self.is_connected().await {
509            return Err(McpError::Connection("Not connected to server".to_string()));
510        }
511        Ok(())
512    }
513
514    /// Get the next request ID
515    async fn next_request_id(&self) -> u64 {
516        let mut counter = self.request_counter.lock().await;
517        *counter += 1;
518        *counter
519    }
520}
521
522/// Client builder for easier construction
523pub struct McpClientBuilder {
524    name: String,
525    version: String,
526    capabilities: ClientCapabilities,
527    config: ClientConfig,
528}
529
530impl McpClientBuilder {
531    /// Create a new client builder
532    pub fn new(name: String, version: String) -> Self {
533        Self {
534            name,
535            version,
536            capabilities: ClientCapabilities::default(),
537            config: ClientConfig::default(),
538        }
539    }
540
541    /// Set client capabilities
542    pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
543        self.capabilities = capabilities;
544        self
545    }
546
547    /// Set client configuration
548    pub fn config(mut self, config: ClientConfig) -> Self {
549        self.config = config;
550        self
551    }
552
553    /// Set request timeout
554    pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
555        self.config.request_timeout_ms = timeout_ms;
556        self
557    }
558
559    /// Set maximum retries
560    pub fn max_retries(mut self, retries: u32) -> Self {
561        self.config.max_retries = retries;
562        self
563    }
564
565    /// Enable or disable request validation
566    pub fn validate_requests(mut self, validate: bool) -> Self {
567        self.config.validate_requests = validate;
568        self
569    }
570
571    /// Enable or disable response validation
572    pub fn validate_responses(mut self, validate: bool) -> Self {
573        self.config.validate_responses = validate;
574        self
575    }
576
577    /// Build the client
578    pub fn build(self) -> McpClient {
579        let mut client = McpClient::new(self.name, self.version);
580        client.set_capabilities(self.capabilities);
581        client.config = self.config;
582        client
583    }
584}
585
586#[cfg(test)]
587mod tests {
588    use super::*;
589    use async_trait::async_trait;
590
591    // Mock transport for testing
592    struct MockTransport {
593        responses: Vec<JsonRpcResponse>,
594        current: usize,
595    }
596
597    impl MockTransport {
598        fn new(responses: Vec<JsonRpcResponse>) -> Self {
599            Self {
600                responses,
601                current: 0,
602            }
603        }
604    }
605
606    #[async_trait]
607    impl Transport for MockTransport {
608        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
609            if self.current < self.responses.len() {
610                let response = self.responses[self.current].clone();
611                self.current += 1;
612                Ok(response)
613            } else {
614                Err(McpError::Transport("No more responses".to_string()))
615            }
616        }
617
618        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
619            Ok(())
620        }
621
622        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
623            Ok(None)
624        }
625
626        async fn close(&mut self) -> McpResult<()> {
627            Ok(())
628        }
629    }
630
631    #[tokio::test]
632    async fn test_client_creation() {
633        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
634        assert_eq!(client.info().name, "test-client");
635        assert_eq!(client.info().version, "1.0.0");
636        assert!(!client.is_connected().await);
637    }
638
639    #[tokio::test]
640    async fn test_client_builder() {
641        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
642            .request_timeout(5000)
643            .max_retries(5)
644            .validate_requests(false)
645            .build();
646
647        assert_eq!(client.config().request_timeout_ms, 5000);
648        assert_eq!(client.config().max_retries, 5);
649        assert!(!client.config().validate_requests);
650    }
651
652    #[tokio::test]
653    async fn test_mock_connection() {
654        let init_result = InitializeResult::new(
655            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
656            ServerCapabilities::default(),
657            ServerInfo {
658                name: "test-server".to_string(),
659                version: "1.0.0".to_string(),
660            },
661        );
662
663        let init_response = JsonRpcResponse::success(Value::from(1), init_result.clone()).unwrap();
664
665        let transport = MockTransport::new(vec![init_response]);
666
667        let mut client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
668        let result = client.connect(transport).await.unwrap();
669
670        assert_eq!(result.server_info.name, "test-server");
671        assert!(client.is_connected().await);
672    }
673
674    #[tokio::test]
675    async fn test_disconnect() {
676        let init_result = InitializeResult::new(
677            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
678            ServerCapabilities::default(),
679            ServerInfo {
680                name: "test-server".to_string(),
681                version: "1.0.0".to_string(),
682            },
683        );
684
685        let init_response = JsonRpcResponse::success(Value::from(1), init_result).unwrap();
686
687        let transport = MockTransport::new(vec![init_response]);
688
689        let mut client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
690        client.connect(transport).await.unwrap();
691
692        assert!(client.is_connected().await);
693
694        client.disconnect().await.unwrap();
695        assert!(!client.is_connected().await);
696        assert!(client.server_info().await.is_none());
697        assert!(client.server_capabilities().await.is_none());
698    }
699}