1use super::transport::{StdioTransport, Transport};
12use super::types::*;
13use crate::error::{Error, Result};
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use uuid::Uuid;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct McpClientConfig {
25 pub name: String,
27 pub command: String,
29 pub args: Vec<String>,
31 #[serde(default)]
33 pub env: HashMap<String, String>,
34 #[serde(default = "default_timeout")]
36 pub timeout_secs: u64,
37 #[serde(default = "default_reconnect")]
39 pub auto_reconnect: bool,
40 #[serde(default = "default_max_retries")]
42 pub max_retries: u32,
43}
44
/// Serde default for [`McpClientConfig::timeout_secs`]: thirty seconds.
fn default_timeout() -> u64 {
    const THIRTY_SECONDS: u64 = 30;
    THIRTY_SECONDS
}
48
/// Serde default for [`McpClientConfig::auto_reconnect`]: enabled.
fn default_reconnect() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
52
/// Serde default for [`McpClientConfig::max_retries`]: three attempts.
fn default_max_retries() -> u32 {
    const ATTEMPTS: u32 = 3;
    ATTEMPTS
}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum ConnectionState {
61 Disconnected,
63 Connecting,
65 Connected,
67 Failed,
69 Reconnecting,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ClientStats {
76 pub requests_sent: u64,
78 pub responses_received: u64,
80 pub errors_total: u64,
82 pub avg_response_time_ms: f64,
84 pub uptime_secs: u64,
86 pub reconnect_attempts: u32,
88 pub last_request_at: Option<DateTime<Utc>>,
90}
91
92impl Default for ClientStats {
93 fn default() -> Self {
94 Self {
95 requests_sent: 0,
96 responses_received: 0,
97 errors_total: 0,
98 avg_response_time_ms: 0.0,
99 uptime_secs: 0,
100 reconnect_attempts: 0,
101 last_request_at: None,
102 }
103 }
104}
105
106#[async_trait]
108pub trait McpClientTrait: Send + Sync {
109 async fn connect(&mut self) -> Result<()>;
111
112 async fn disconnect(&mut self) -> Result<()>;
114
115 async fn state(&self) -> ConnectionState;
117
118 async fn list_tools(&self) -> Result<Vec<super::tools::Tool>>;
120
121 async fn call_tool(
123 &self,
124 name: &str,
125 arguments: serde_json::Value,
126 ) -> Result<super::tools::ToolResult>;
127
128 async fn list_resources(&self) -> Result<Vec<super::tools::ResourceTemplate>>;
130
131 async fn read_resource(&self, uri: &str) -> Result<serde_json::Value>;
133
134 async fn stats(&self) -> ClientStats;
136
137 async fn ping(&self) -> Result<bool>;
139}
140
141pub struct McpClient {
143 pub id: Uuid,
145 pub config: McpClientConfig,
147 transport: Arc<RwLock<Option<Arc<dyn Transport>>>>,
149 state: Arc<RwLock<ConnectionState>>,
151 server_info: Arc<RwLock<Option<ServerInfo>>>,
153 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
155 stats: Arc<RwLock<ClientStats>>,
157 connected_at: Arc<RwLock<Option<DateTime<Utc>>>>,
159}
160
161impl McpClient {
162 pub fn new(config: McpClientConfig) -> Self {
164 Self {
165 id: Uuid::new_v4(),
166 config,
167 transport: Arc::new(RwLock::new(None)),
168 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
169 server_info: Arc::new(RwLock::new(None)),
170 server_capabilities: Arc::new(RwLock::new(None)),
171 stats: Arc::new(RwLock::new(ClientStats::default())),
172 connected_at: Arc::new(RwLock::new(None)),
173 }
174 }
175
176 pub async fn server_info(&self) -> Option<ServerInfo> {
178 self.server_info.read().await.clone()
179 }
180
181 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
183 self.server_capabilities.read().await.clone()
184 }
185
186 async fn record_success(&self, response_time_ms: f64) {
188 let mut s = self.stats.write().await;
189 s.responses_received += 1;
190 s.last_request_at = Some(Utc::now());
191
192 if s.responses_received == 1 {
194 s.avg_response_time_ms = response_time_ms;
195 } else {
196 s.avg_response_time_ms = (s.avg_response_time_ms * 0.9) + (response_time_ms * 0.1);
197 }
198 }
199
200 async fn record_error(&self) {
202 let mut s = self.stats.write().await;
203 s.errors_total += 1;
204 }
205
206 async fn send_request_with_retry(&self, request: McpRequest) -> Result<McpResponse> {
208 let mut attempts = 0;
209 let max_retries = self.config.max_retries;
210
211 loop {
212 let transport_guard = self.transport.read().await;
213 let transport = transport_guard
214 .as_ref()
215 .ok_or_else(|| Error::network("Not connected to server"))?;
216
217 let start = std::time::Instant::now();
218 let result = transport.send_request(request.clone()).await;
219 let elapsed_ms = start.elapsed().as_millis() as f64;
220
221 match result {
222 Ok(response) => {
223 if response.error.is_some() {
224 self.record_error().await;
225 } else {
226 self.record_success(elapsed_ms).await;
227 }
228 return Ok(response);
229 }
230 Err(e) => {
231 self.record_error().await;
232 attempts += 1;
233
234 if attempts >= max_retries {
235 return Err(Error::network(format!(
236 "Request failed after {} attempts: {}",
237 attempts, e
238 )));
239 }
240
241 let backoff_ms = 100 * (2_u64.pow(attempts - 1));
243 tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
244 }
245 }
246 }
247 }
248}
249
250#[async_trait]
251impl McpClientTrait for McpClient {
252 async fn connect(&mut self) -> Result<()> {
253 *self.state.write().await = ConnectionState::Connecting;
255
256 let env_vec: Vec<(String, String)> = self.config.env.clone().into_iter().collect();
258
259 let transport =
261 StdioTransport::spawn(&self.config.command, self.config.args.clone(), env_vec)
262 .await
263 .map_err(|e| Error::network(format!("Failed to create transport: {}", e)))?;
264
265 *self.transport.write().await = Some(Arc::new(transport));
266
267 let init_params = serde_json::json!({
269 "protocolVersion": crate::mcp::MCP_VERSION,
270 "capabilities": {},
271 "clientInfo": {
272 "name": "reasonkit-core",
273 "version": env!("CARGO_PKG_VERSION")
274 }
275 });
276
277 let request = McpRequest::new(
278 RequestId::String(Uuid::new_v4().to_string()),
279 "initialize",
280 Some(init_params),
281 );
282
283 let response = self.send_request_with_retry(request).await?;
284
285 if let Some(error) = response.error {
286 *self.state.write().await = ConnectionState::Failed;
287 return Err(Error::network(format!(
288 "Initialize failed: {}",
289 error.message
290 )));
291 }
292
293 if let Some(result) = response.result {
295 if let Ok(init_result) =
296 serde_json::from_value::<super::lifecycle::InitializeResult>(result)
297 {
298 *self.server_info.write().await = Some(init_result.server_info);
299 *self.server_capabilities.write().await = Some(init_result.capabilities);
300 }
301 }
302
303 let notification = McpNotification {
305 jsonrpc: JsonRpcVersion::default(),
306 method: "notifications/initialized".to_string(),
307 params: None,
308 };
309
310 let transport_guard = self.transport.read().await;
311 if let Some(transport) = transport_guard.as_ref() {
312 transport.send_notification(notification).await.ok();
313 }
314
315 *self.state.write().await = ConnectionState::Connected;
316 *self.connected_at.write().await = Some(Utc::now());
317
318 Ok(())
319 }
320
321 async fn disconnect(&mut self) -> Result<()> {
322 let request = McpRequest::new(
324 RequestId::String(Uuid::new_v4().to_string()),
325 "shutdown",
326 None,
327 );
328
329 let _ = self.send_request_with_retry(request).await;
331
332 *self.transport.write().await = None;
334 *self.state.write().await = ConnectionState::Disconnected;
335 *self.connected_at.write().await = None;
336
337 Ok(())
338 }
339
340 async fn state(&self) -> ConnectionState {
341 *self.state.read().await
342 }
343
344 async fn list_tools(&self) -> Result<Vec<super::tools::Tool>> {
345 let request = McpRequest::new(
346 RequestId::String(Uuid::new_v4().to_string()),
347 "tools/list",
348 None,
349 );
350
351 let response = self.send_request_with_retry(request).await?;
352
353 if let Some(error) = response.error {
354 return Err(Error::network(format!(
355 "tools/list failed: {}",
356 error.message
357 )));
358 }
359
360 let result = response
361 .result
362 .ok_or_else(|| Error::network("tools/list response missing result"))?;
363
364 #[derive(Deserialize)]
365 struct ToolsListResponse {
366 tools: Vec<super::tools::Tool>,
367 }
368
369 let tools_response = serde_json::from_value::<ToolsListResponse>(result)
370 .map_err(|e| Error::network(format!("Failed to parse tools list: {}", e)))?;
371
372 Ok(tools_response.tools)
373 }
374
375 async fn call_tool(
376 &self,
377 name: &str,
378 arguments: serde_json::Value,
379 ) -> Result<super::tools::ToolResult> {
380 let mut stats = self.stats.write().await;
381 stats.requests_sent += 1;
382 drop(stats);
383
384 let params = serde_json::json!({
385 "name": name,
386 "arguments": arguments
387 });
388
389 let request = McpRequest::new(
390 RequestId::String(Uuid::new_v4().to_string()),
391 "tools/call",
392 Some(params),
393 );
394
395 let response = self.send_request_with_retry(request).await?;
396
397 if let Some(error) = response.error {
398 return Err(Error::network(format!(
399 "tools/call failed: {}",
400 error.message
401 )));
402 }
403
404 let result = response
405 .result
406 .ok_or_else(|| Error::network("tools/call response missing result"))?;
407
408 serde_json::from_value::<super::tools::ToolResult>(result)
409 .map_err(|e| Error::network(format!("Failed to parse tool result: {}", e)))
410 }
411
412 async fn list_resources(&self) -> Result<Vec<super::tools::ResourceTemplate>> {
413 let request = McpRequest::new(
414 RequestId::String(Uuid::new_v4().to_string()),
415 "resources/list",
416 None,
417 );
418
419 let response = self.send_request_with_retry(request).await?;
420
421 if let Some(error) = response.error {
422 return Err(Error::network(format!(
423 "resources/list failed: {}",
424 error.message
425 )));
426 }
427
428 let result = response
429 .result
430 .ok_or_else(|| Error::network("resources/list response missing result"))?;
431
432 #[derive(Deserialize)]
433 struct ResourcesListResponse {
434 resources: Vec<super::tools::ResourceTemplate>,
435 }
436
437 let resources_response = serde_json::from_value::<ResourcesListResponse>(result)
438 .map_err(|e| Error::network(format!("Failed to parse resources list: {}", e)))?;
439
440 Ok(resources_response.resources)
441 }
442
443 async fn read_resource(&self, uri: &str) -> Result<serde_json::Value> {
444 let params = serde_json::json!({
445 "uri": uri
446 });
447
448 let request = McpRequest::new(
449 RequestId::String(Uuid::new_v4().to_string()),
450 "resources/read",
451 Some(params),
452 );
453
454 let response = self.send_request_with_retry(request).await?;
455
456 if let Some(error) = response.error {
457 return Err(Error::network(format!(
458 "resources/read failed: {}",
459 error.message
460 )));
461 }
462
463 response
464 .result
465 .ok_or_else(|| Error::network("resources/read response missing result"))
466 }
467
468 async fn stats(&self) -> ClientStats {
469 let mut s = self.stats.read().await.clone();
470
471 if let Some(connected_at) = *self.connected_at.read().await {
473 s.uptime_secs = (Utc::now() - connected_at).num_seconds() as u64;
474 }
475
476 s
477 }
478
479 async fn ping(&self) -> Result<bool> {
480 let request = McpRequest::new(RequestId::String(Uuid::new_v4().to_string()), "ping", None);
481
482 match tokio::time::timeout(
483 std::time::Duration::from_secs(5),
484 self.send_request_with_retry(request),
485 )
486 .await
487 {
488 Ok(Ok(response)) => Ok(response.error.is_none()),
489 Ok(Err(_)) | Err(_) => Ok(false),
490 }
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the required fields are supplied; serde must fill in every default.
    #[test]
    fn test_client_config_default_values() {
        let config: McpClientConfig = serde_json::from_value(serde_json::json!({
            "name": "test-server",
            "command": "test",
            "args": []
        }))
        .expect("a config with only required fields should deserialize");

        assert!(config.env.is_empty());
        assert_eq!(config.timeout_secs, default_timeout());
        assert_eq!(config.timeout_secs, 30);
        assert_eq!(config.auto_reconnect, default_reconnect());
        assert!(config.auto_reconnect);
        assert_eq!(config.max_retries, default_max_retries());
        assert_eq!(config.max_retries, 3);
    }

    /// Every variant serializes as snake_case and round-trips.
    #[test]
    fn test_connection_state_serialization() {
        let cases = [
            (ConnectionState::Disconnected, "\"disconnected\""),
            (ConnectionState::Connecting, "\"connecting\""),
            (ConnectionState::Connected, "\"connected\""),
            (ConnectionState::Failed, "\"failed\""),
            (ConnectionState::Reconnecting, "\"reconnecting\""),
        ];

        for &(state, expected) in cases.iter() {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, expected);

            let decoded: ConnectionState = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, state);
        }
    }

    #[test]
    fn test_client_stats_default() {
        let stats = ClientStats::default();
        assert_eq!(stats.requests_sent, 0);
        assert_eq!(stats.responses_received, 0);
        assert_eq!(stats.errors_total, 0);
        assert_eq!(stats.avg_response_time_ms, 0.0);
        assert_eq!(stats.uptime_secs, 0);
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.last_request_at.is_none());
    }

    #[test]
    fn test_client_creation() {
        let config = McpClientConfig {
            name: "test-server".to_string(),
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            env: HashMap::new(),
            timeout_secs: 30,
            auto_reconnect: true,
            max_retries: 3,
        };

        let client = McpClient::new(config.clone());
        assert_eq!(client.config.name, "test-server");
        assert_eq!(client.config.command, "echo");
        assert_eq!(client.config.args, vec!["hello".to_string()]);

        // Each client gets its own freshly generated identifier.
        let other = McpClient::new(config);
        assert_ne!(client.id, other.id);
    }
}