// reasonkit/mcp/client.rs

//! MCP Client Implementation
//!
//! Client for connecting to and communicating with MCP servers.
//!
//! This module provides:
//! - Connection management to external MCP servers
//! - Tool execution via RPC
//! - Resource access
//! - Request retries with exponential backoff, and error handling (the `auto_reconnect`
//!   setting is not yet acted on here)
10
11use super::transport::{StdioTransport, Transport};
12use super::types::*;
13use crate::error::{Error, Result};
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use uuid::Uuid;
21
22/// MCP client configuration
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct McpClientConfig {
25    /// Server name
26    pub name: String,
27    /// Server command (e.g., "npx", "node", "python")
28    pub command: String,
29    /// Command arguments
30    pub args: Vec<String>,
31    /// Environment variables
32    #[serde(default)]
33    pub env: HashMap<String, String>,
34    /// Connection timeout in seconds
35    #[serde(default = "default_timeout")]
36    pub timeout_secs: u64,
37    /// Auto-reconnect on failure
38    #[serde(default = "default_reconnect")]
39    pub auto_reconnect: bool,
40    /// Maximum reconnection attempts
41    #[serde(default = "default_max_retries")]
42    pub max_retries: u32,
43}
44
/// Serde default for `McpClientConfig::timeout_secs`: thirty seconds.
fn default_timeout() -> u64 {
    const THIRTY_SECONDS: u64 = 30;
    THIRTY_SECONDS
}

/// Serde default for `McpClientConfig::auto_reconnect`: enabled.
fn default_reconnect() -> bool {
    const ENABLED: bool = true;
    ENABLED
}

/// Serde default for `McpClientConfig::max_retries`: three attempts.
fn default_max_retries() -> u32 {
    const THREE_ATTEMPTS: u32 = 3;
    THREE_ATTEMPTS
}
56
57/// MCP client connection state
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum ConnectionState {
61    /// Not connected
62    Disconnected,
63    /// Connecting
64    Connecting,
65    /// Connected and initialized
66    Connected,
67    /// Connection failed
68    Failed,
69    /// Reconnecting
70    Reconnecting,
71}
72
73/// MCP client statistics
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ClientStats {
76    /// Total requests sent
77    pub requests_sent: u64,
78    /// Total responses received
79    pub responses_received: u64,
80    /// Total errors encountered
81    pub errors_total: u64,
82    /// Average response time (ms)
83    pub avg_response_time_ms: f64,
84    /// Connection uptime (seconds)
85    pub uptime_secs: u64,
86    /// Reconnection attempts
87    pub reconnect_attempts: u32,
88    /// Last successful request
89    pub last_request_at: Option<DateTime<Utc>>,
90}
91
92impl Default for ClientStats {
93    fn default() -> Self {
94        Self {
95            requests_sent: 0,
96            responses_received: 0,
97            errors_total: 0,
98            avg_response_time_ms: 0.0,
99            uptime_secs: 0,
100            reconnect_attempts: 0,
101            last_request_at: None,
102        }
103    }
104}
105
106/// MCP client trait
107#[async_trait]
108pub trait McpClientTrait: Send + Sync {
109    /// Connect to the server
110    async fn connect(&mut self) -> Result<()>;
111
112    /// Disconnect from the server
113    async fn disconnect(&mut self) -> Result<()>;
114
115    /// Get connection state
116    async fn state(&self) -> ConnectionState;
117
118    /// List available tools
119    async fn list_tools(&self) -> Result<Vec<super::tools::Tool>>;
120
121    /// Call a tool
122    async fn call_tool(
123        &self,
124        name: &str,
125        arguments: serde_json::Value,
126    ) -> Result<super::tools::ToolResult>;
127
128    /// List available resources
129    async fn list_resources(&self) -> Result<Vec<super::tools::ResourceTemplate>>;
130
131    /// Read a resource
132    async fn read_resource(&self, uri: &str) -> Result<serde_json::Value>;
133
134    /// Get client statistics
135    async fn stats(&self) -> ClientStats;
136
137    /// Perform a health check
138    async fn ping(&self) -> Result<bool>;
139}
140
141/// Concrete MCP client implementation
142pub struct McpClient {
143    /// Client ID
144    pub id: Uuid,
145    /// Client configuration
146    pub config: McpClientConfig,
147    /// Transport layer
148    transport: Arc<RwLock<Option<Arc<dyn Transport>>>>,
149    /// Connection state
150    state: Arc<RwLock<ConnectionState>>,
151    /// Server information (after initialization)
152    server_info: Arc<RwLock<Option<ServerInfo>>>,
153    /// Server capabilities (after initialization)
154    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
155    /// Client statistics
156    stats: Arc<RwLock<ClientStats>>,
157    /// Connected at timestamp
158    connected_at: Arc<RwLock<Option<DateTime<Utc>>>>,
159}
160
161impl McpClient {
162    /// Create a new MCP client
163    pub fn new(config: McpClientConfig) -> Self {
164        Self {
165            id: Uuid::new_v4(),
166            config,
167            transport: Arc::new(RwLock::new(None)),
168            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
169            server_info: Arc::new(RwLock::new(None)),
170            server_capabilities: Arc::new(RwLock::new(None)),
171            stats: Arc::new(RwLock::new(ClientStats::default())),
172            connected_at: Arc::new(RwLock::new(None)),
173        }
174    }
175
176    /// Get server information (must be connected)
177    pub async fn server_info(&self) -> Option<ServerInfo> {
178        self.server_info.read().await.clone()
179    }
180
181    /// Get server capabilities (must be connected)
182    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
183        self.server_capabilities.read().await.clone()
184    }
185
186    /// Update statistics for a successful request
187    async fn record_success(&self, response_time_ms: f64) {
188        let mut s = self.stats.write().await;
189        s.responses_received += 1;
190        s.last_request_at = Some(Utc::now());
191
192        // Update average response time (exponential moving average)
193        if s.responses_received == 1 {
194            s.avg_response_time_ms = response_time_ms;
195        } else {
196            s.avg_response_time_ms = (s.avg_response_time_ms * 0.9) + (response_time_ms * 0.1);
197        }
198    }
199
200    /// Update statistics for an error
201    async fn record_error(&self) {
202        let mut s = self.stats.write().await;
203        s.errors_total += 1;
204    }
205
206    /// Send a request with automatic retries
207    async fn send_request_with_retry(&self, request: McpRequest) -> Result<McpResponse> {
208        let mut attempts = 0;
209        let max_retries = self.config.max_retries;
210
211        loop {
212            let transport_guard = self.transport.read().await;
213            let transport = transport_guard
214                .as_ref()
215                .ok_or_else(|| Error::network("Not connected to server"))?;
216
217            let start = std::time::Instant::now();
218            let result = transport.send_request(request.clone()).await;
219            let elapsed_ms = start.elapsed().as_millis() as f64;
220
221            match result {
222                Ok(response) => {
223                    if response.error.is_some() {
224                        self.record_error().await;
225                    } else {
226                        self.record_success(elapsed_ms).await;
227                    }
228                    return Ok(response);
229                }
230                Err(e) => {
231                    self.record_error().await;
232                    attempts += 1;
233
234                    if attempts >= max_retries {
235                        return Err(Error::network(format!(
236                            "Request failed after {} attempts: {}",
237                            attempts, e
238                        )));
239                    }
240
241                    // Exponential backoff
242                    let backoff_ms = 100 * (2_u64.pow(attempts - 1));
243                    tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
244                }
245            }
246        }
247    }
248}
249
250#[async_trait]
251impl McpClientTrait for McpClient {
252    async fn connect(&mut self) -> Result<()> {
253        // Set state to connecting
254        *self.state.write().await = ConnectionState::Connecting;
255
256        // Convert HashMap to Vec<(String, String)> for env
257        let env_vec: Vec<(String, String)> = self.config.env.clone().into_iter().collect();
258
259        // Create stdio transport using spawn method
260        let transport =
261            StdioTransport::spawn(&self.config.command, self.config.args.clone(), env_vec)
262                .await
263                .map_err(|e| Error::network(format!("Failed to create transport: {}", e)))?;
264
265        *self.transport.write().await = Some(Arc::new(transport));
266
267        // Send initialize request
268        let init_params = serde_json::json!({
269            "protocolVersion": crate::mcp::MCP_VERSION,
270            "capabilities": {},
271            "clientInfo": {
272                "name": "reasonkit-core",
273                "version": env!("CARGO_PKG_VERSION")
274            }
275        });
276
277        let request = McpRequest::new(
278            RequestId::String(Uuid::new_v4().to_string()),
279            "initialize",
280            Some(init_params),
281        );
282
283        let response = self.send_request_with_retry(request).await?;
284
285        if let Some(error) = response.error {
286            *self.state.write().await = ConnectionState::Failed;
287            return Err(Error::network(format!(
288                "Initialize failed: {}",
289                error.message
290            )));
291        }
292
293        // Parse initialization result
294        if let Some(result) = response.result {
295            if let Ok(init_result) =
296                serde_json::from_value::<super::lifecycle::InitializeResult>(result)
297            {
298                *self.server_info.write().await = Some(init_result.server_info);
299                *self.server_capabilities.write().await = Some(init_result.capabilities);
300            }
301        }
302
303        // Send initialized notification
304        let notification = McpNotification {
305            jsonrpc: JsonRpcVersion::default(),
306            method: "notifications/initialized".to_string(),
307            params: None,
308        };
309
310        let transport_guard = self.transport.read().await;
311        if let Some(transport) = transport_guard.as_ref() {
312            transport.send_notification(notification).await.ok();
313        }
314
315        *self.state.write().await = ConnectionState::Connected;
316        *self.connected_at.write().await = Some(Utc::now());
317
318        Ok(())
319    }
320
321    async fn disconnect(&mut self) -> Result<()> {
322        // Send shutdown request
323        let request = McpRequest::new(
324            RequestId::String(Uuid::new_v4().to_string()),
325            "shutdown",
326            None,
327        );
328
329        // Best effort - ignore errors
330        let _ = self.send_request_with_retry(request).await;
331
332        // Clear transport
333        *self.transport.write().await = None;
334        *self.state.write().await = ConnectionState::Disconnected;
335        *self.connected_at.write().await = None;
336
337        Ok(())
338    }
339
340    async fn state(&self) -> ConnectionState {
341        *self.state.read().await
342    }
343
344    async fn list_tools(&self) -> Result<Vec<super::tools::Tool>> {
345        let request = McpRequest::new(
346            RequestId::String(Uuid::new_v4().to_string()),
347            "tools/list",
348            None,
349        );
350
351        let response = self.send_request_with_retry(request).await?;
352
353        if let Some(error) = response.error {
354            return Err(Error::network(format!(
355                "tools/list failed: {}",
356                error.message
357            )));
358        }
359
360        let result = response
361            .result
362            .ok_or_else(|| Error::network("tools/list response missing result"))?;
363
364        #[derive(Deserialize)]
365        struct ToolsListResponse {
366            tools: Vec<super::tools::Tool>,
367        }
368
369        let tools_response = serde_json::from_value::<ToolsListResponse>(result)
370            .map_err(|e| Error::network(format!("Failed to parse tools list: {}", e)))?;
371
372        Ok(tools_response.tools)
373    }
374
375    async fn call_tool(
376        &self,
377        name: &str,
378        arguments: serde_json::Value,
379    ) -> Result<super::tools::ToolResult> {
380        let mut stats = self.stats.write().await;
381        stats.requests_sent += 1;
382        drop(stats);
383
384        let params = serde_json::json!({
385            "name": name,
386            "arguments": arguments
387        });
388
389        let request = McpRequest::new(
390            RequestId::String(Uuid::new_v4().to_string()),
391            "tools/call",
392            Some(params),
393        );
394
395        let response = self.send_request_with_retry(request).await?;
396
397        if let Some(error) = response.error {
398            return Err(Error::network(format!(
399                "tools/call failed: {}",
400                error.message
401            )));
402        }
403
404        let result = response
405            .result
406            .ok_or_else(|| Error::network("tools/call response missing result"))?;
407
408        serde_json::from_value::<super::tools::ToolResult>(result)
409            .map_err(|e| Error::network(format!("Failed to parse tool result: {}", e)))
410    }
411
412    async fn list_resources(&self) -> Result<Vec<super::tools::ResourceTemplate>> {
413        let request = McpRequest::new(
414            RequestId::String(Uuid::new_v4().to_string()),
415            "resources/list",
416            None,
417        );
418
419        let response = self.send_request_with_retry(request).await?;
420
421        if let Some(error) = response.error {
422            return Err(Error::network(format!(
423                "resources/list failed: {}",
424                error.message
425            )));
426        }
427
428        let result = response
429            .result
430            .ok_or_else(|| Error::network("resources/list response missing result"))?;
431
432        #[derive(Deserialize)]
433        struct ResourcesListResponse {
434            resources: Vec<super::tools::ResourceTemplate>,
435        }
436
437        let resources_response = serde_json::from_value::<ResourcesListResponse>(result)
438            .map_err(|e| Error::network(format!("Failed to parse resources list: {}", e)))?;
439
440        Ok(resources_response.resources)
441    }
442
443    async fn read_resource(&self, uri: &str) -> Result<serde_json::Value> {
444        let params = serde_json::json!({
445            "uri": uri
446        });
447
448        let request = McpRequest::new(
449            RequestId::String(Uuid::new_v4().to_string()),
450            "resources/read",
451            Some(params),
452        );
453
454        let response = self.send_request_with_retry(request).await?;
455
456        if let Some(error) = response.error {
457            return Err(Error::network(format!(
458                "resources/read failed: {}",
459                error.message
460            )));
461        }
462
463        response
464            .result
465            .ok_or_else(|| Error::network("resources/read response missing result"))
466    }
467
468    async fn stats(&self) -> ClientStats {
469        let mut s = self.stats.read().await.clone();
470
471        // Calculate uptime
472        if let Some(connected_at) = *self.connected_at.read().await {
473            s.uptime_secs = (Utc::now() - connected_at).num_seconds() as u64;
474        }
475
476        s
477    }
478
479    async fn ping(&self) -> Result<bool> {
480        let request = McpRequest::new(RequestId::String(Uuid::new_v4().to_string()), "ping", None);
481
482        match tokio::time::timeout(
483            std::time::Duration::from_secs(5),
484            self.send_request_with_retry(request),
485        )
486        .await
487        {
488            Ok(Ok(response)) => Ok(response.error.is_none()),
489            Ok(Err(_)) | Err(_) => Ok(false),
490        }
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully specified configuration shared by the tests below.
    fn sample_config() -> McpClientConfig {
        McpClientConfig {
            name: "test-server".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            env: HashMap::new(),
            timeout_secs: 30,
            auto_reconnect: true,
            max_retries: 3,
        }
    }

    #[test]
    fn test_client_config_default_values() {
        assert_eq!(default_timeout(), 30);
        assert!(default_reconnect());
        assert_eq!(default_max_retries(), 3);
    }

    #[test]
    fn test_client_config_serde_defaults() {
        // Omitting every field that has a serde default must fall back to that default.
        let config: McpClientConfig = serde_json::from_value(serde_json::json!({
            "name": "test-server",
            "command": "test",
            "args": []
        }))
        .expect("minimal config should deserialize");

        assert!(config.env.is_empty());
        assert_eq!(config.timeout_secs, 30);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_connection_state_serialization() {
        let cases = [
            (ConnectionState::Disconnected, "\"disconnected\""),
            (ConnectionState::Connecting, "\"connecting\""),
            (ConnectionState::Connected, "\"connected\""),
            (ConnectionState::Failed, "\"failed\""),
            (ConnectionState::Reconnecting, "\"reconnecting\""),
        ];

        for &(state, expected) in &cases {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, expected);

            let round_trip: ConnectionState = serde_json::from_str(&json).unwrap();
            assert_eq!(round_trip, state);
        }
    }

    #[test]
    fn test_client_stats_default() {
        let stats = ClientStats::default();
        assert_eq!(stats.requests_sent, 0);
        assert_eq!(stats.responses_received, 0);
        assert_eq!(stats.errors_total, 0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.uptime_secs, 0);
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.last_request_at.is_none());
    }

    #[test]
    fn test_client_creation() {
        let config = sample_config();

        let client = McpClient::new(config.clone());
        assert_eq!(client.config.name, "test-server");
        assert_eq!(client.config.command, "echo");
        assert_eq!(client.config.max_retries, 3);

        // Each client gets its own random identifier.
        let other = McpClient::new(config);
        assert_ne!(client.id, other.id);
    }
}