1use anyhow::{anyhow, Context, Result};
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
13use tokio::process::{Child, ChildStdin, ChildStdout, Command};
14use tokio::sync::RwLock;
15use tokio::time::{timeout, Duration};
16
17use super::protocol::*;
18
19pub struct McpClient {
38 server: McpServer,
40 child: RwLock<Option<Child>>,
42 stdin: RwLock<Option<tokio::io::BufWriter<ChildStdin>>>,
44 stdout: RwLock<Option<BufReader<ChildStdout>>>,
46 initialized: RwLock<bool>,
48 tool_cache: RwLock<Option<Vec<McpTool>>>,
50 server_info: RwLock<Option<ServerInfo>>,
52 request_timeout: Duration,
54}
55
56impl McpClient {
57 pub fn new(server: McpServer) -> Self {
61 Self {
62 server,
63 child: RwLock::new(None),
64 stdin: RwLock::new(None),
65 stdout: RwLock::new(None),
66 initialized: RwLock::new(false),
67 tool_cache: RwLock::new(None),
68 server_info: RwLock::new(None),
69 request_timeout: Duration::from_secs(30),
70 }
71 }
72
73 #[must_use]
75 pub fn with_timeout(mut self, timeout: Duration) -> Self {
76 self.request_timeout = timeout;
77 self
78 }
79
80 pub async fn initialize(&self) -> Result<()> {
82 if *self.initialized.read().await {
83 return Ok(());
84 }
85
86 let mut child = Command::new(&self.server.command)
88 .args(&self.server.args)
89 .envs(&self.server.env)
90 .stdin(std::process::Stdio::piped())
91 .stdout(std::process::Stdio::piped())
92 .stderr(std::process::Stdio::piped())
93 .spawn()
94 .with_context(|| format!("Failed to spawn MCP server '{}'", self.server.name))?;
95
96 let stdin = child
97 .stdin
98 .take()
99 .expect("stdin not captured — stdin was piped");
100 let stdout = child
101 .stdout
102 .take()
103 .expect("stdout not captured — stdout was piped");
104
105 *self.stdin.write().await = Some(tokio::io::BufWriter::new(stdin));
107 *self.stdout.write().await = Some(BufReader::new(stdout));
108
109 *self.child.write().await = Some(child);
111
112 let params = InitializeParams::default();
114 let request = McpRequest::new("initialize").with_params(serde_json::to_value(¶ms)?);
115
116 let response = self.do_request(request).await?;
119
120 let result_json = response.into_result()?;
122 let init_result: InitializeResult = serde_json::from_value(result_json)?;
123
124 *self.server_info.write().await = Some(init_result.server_info.clone());
125 *self.initialized.write().await = true;
126
127 let notification = McpRequest::new("notifications/initialized");
129 self.send_notification(notification).await?;
130
131 tracing::debug!(
132 server = %self.server.name,
133 version = %init_result.server_info.version,
134 "MCP server initialized"
135 );
136
137 Ok(())
138 }
139
140 pub async fn is_initialized(&self) -> bool {
142 *self.initialized.read().await
143 }
144
145 pub async fn server_info(&self) -> Option<ServerInfo> {
147 self.server_info.read().await.clone()
148 }
149
150 async fn do_request(&self, request: McpRequest) -> Result<McpResponse> {
155 let request_id = request.id.clone();
156
157 let mut stdin_guard = self.stdin.write().await;
159 let stdin = stdin_guard
160 .as_mut()
161 .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
162
163 let json = request.to_jsonl()?;
165 timeout(self.request_timeout, async {
166 stdin.write_all(&json).await?;
167 stdin.flush().await?;
168 Ok::<(), tokio::io::Error>(())
169 })
170 .await
171 .map_err(|e| anyhow::anyhow!("MCP request timed out (write): {}", e))??;
172
173 let mut stdout_guard = self.stdout.write().await;
175 let stdout = stdout_guard
176 .as_mut()
177 .ok_or_else(|| anyhow!("stdout not available on '{}'", self.server.name))?;
178
179 let line: std::io::Result<Option<String>> = timeout(self.request_timeout, async {
181 stdout.lines().next_line().await
182 })
183 .await
184 .map_err(|e| anyhow::anyhow!("MCP request timed out (read): {}", e))?;
185
186 let response_str: String = line
187 .context("Failed to read MCP response line from stdout")?
188 .with_context(|| format!("MCP server {} returned no response", self.server.name))?;
189
190 let parsed: McpResponse = serde_json::from_str(&response_str)
191 .with_context(|| format!("Failed to parse MCP response JSON: {}", response_str))?;
192
193 if parsed.id != request_id {
195 tracing::warn!(
196 server = %self.server.name,
197 expected_id = ?request_id,
198 got_id = ?parsed.id,
199 "MCP response ID mismatch"
200 );
201 }
202
203 Ok(parsed)
204 }
205
206 async fn send_notification(&self, notification: McpRequest) -> Result<()> {
208 let mut stdin_guard = self.stdin.write().await;
209 let stdin = stdin_guard
210 .as_mut()
211 .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
212
213 let json = notification.to_jsonl()?;
214 stdin.write_all(&json).await?;
215 stdin.flush().await?;
216
217 Ok(())
218 }
219
220 pub(crate) async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
226 {
228 let child = self.child.read().await;
229 if child.is_none() {
230 tracing::warn!(
231 server = %self.server.name,
232 "MCP server not running, attempting auto-start"
233 );
234 drop(child);
235 self.restart().await?;
237 }
238 }
239
240 match self.do_request(request).await {
241 Ok(resp) => Ok(resp),
242 Err(e) => {
243 let err_str = e.to_string();
245 let is_comm_error = err_str.contains("not available")
246 || err_str.contains("broken pipe")
247 || err_str.contains("timed out")
248 || err_str.contains("no response");
249
250 if is_comm_error {
251 tracing::warn!(
252 server = %self.server.name,
253 error = %err_str,
254 "MCP communication error, attempting auto-restart"
255 );
256 self.restart().await?;
257 anyhow::bail!(
258 "MCP server '{}' restarted after error. Please retry the request.",
259 self.server.name
260 );
261 } else {
262 Err(e)
263 }
264 }
265 }
266 }
267
268 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
272 if let Some(cached) = self.tool_cache.read().await.clone() {
274 return Ok(cached);
275 }
276
277 self.refresh_tools().await
278 }
279
280 pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
282 let request = McpRequest::new("tools/list");
283 let response = self.send_request(request).await?;
284
285 let result_json = response.into_result()?;
286 let tools_result: McpToolsResult = serde_json::from_value(result_json)?;
287
288 let tools = tools_result.tools;
289 *self.tool_cache.write().await = Some(tools.clone());
290
291 tracing::debug!(
292 server = %self.server.name,
293 count = tools.len(),
294 "Refreshed tool cache"
295 );
296
297 Ok(tools)
298 }
299
300 pub async fn call_tool(
304 &self,
305 tool_name: &str,
306 arguments: serde_json::Value,
307 ) -> Result<McpToolCallResult> {
308 let params = serde_json::json!({
309 "name": tool_name,
310 "arguments": arguments,
311 });
312
313 let request = McpRequest::new("tools/call").with_params(params);
314 let response = self.send_request(request).await?;
315
316 let result_json = response.into_result()?;
317 let call_result: McpToolCallResult = serde_json::from_value(result_json)?;
318
319 tracing::debug!(
320 server = %self.server.name,
321 tool = tool_name,
322 "Tool call completed"
323 );
324
325 Ok(call_result)
326 }
327
328 pub async fn call_tool_text(
332 &self,
333 tool_name: &str,
334 arguments: serde_json::Value,
335 ) -> Result<String> {
336 let result = self.call_tool(tool_name, arguments).await?;
337
338 for block in result.content {
339 if let McpContentBlock::Text { text } = block {
340 return Ok(text);
341 }
342 }
343
344 Err(anyhow!("Tool '{}' returned no text content", tool_name))
345 }
346
347 pub async fn shutdown(&self) -> Result<()> {
351 *self.stdin.write().await = None;
353 *self.stdout.write().await = None;
354
355 let mut child_guard = self.child.write().await;
356
357 if let Some(mut child) = child_guard.take() {
358 tracing::debug!(server = %self.server.name, "Shutting down MCP server");
359
360 let _ = child.try_wait();
362
363 child.kill().await?;
365 let _ = child.wait().await;
366 }
367
368 *self.initialized.write().await = false;
369 *self.tool_cache.write().await = None;
370
371 Ok(())
372 }
373
374 pub async fn restart(&self) -> Result<()> {
376 self.shutdown().await?;
377 self.initialize().await
378 }
379
380 pub fn server(&self) -> &McpServer {
382 &self.server
383 }
384}
385
386impl std::fmt::Debug for McpClient {
387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 f.debug_struct("McpClient")
389 .field("server", &self.server.name)
390 .field("initialized", &self.initialized)
391 .finish()
392 }
393}
394
#[cfg(test)]
mod tests {
    //! Unit tests for `McpClient` that need no real MCP server: construction,
    //! configuration, debug formatting, and the state machine around failed
    //! initialization and shutdown.

    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_client_construction() {
        let server = McpServer::new("test-server", "npx");
        let client = McpClient::new(server);

        assert_eq!(client.server.name, "test-server");
        assert_eq!(client.server.command, "npx");
    }

    #[test]
    fn test_client_with_timeout() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(60));

        assert_eq!(client.server.name, "test");
        // The builder must actually store the requested timeout.
        assert_eq!(client.request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_client_with_timeout_short() {
        let server = McpServer::new("test", "sleep");
        let client = McpClient::new(server).with_timeout(Duration::from_millis(50));

        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_millis(50));
    }

    #[test]
    fn test_client_debug_format() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);

        let debug_str = format!("{:?}", client);

        assert!(debug_str.contains("debug-test"));
        assert!(debug_str.contains("McpClient"));
    }

    #[test]
    fn test_client_debug_different_servers() {
        let server1 = McpServer::new("server-a", "cmd1");
        let server2 = McpServer::new("server-b", "cmd2");

        let client1 = McpClient::new(server1);
        let client2 = McpClient::new(server2);

        let debug1 = format!("{:?}", client1);
        let debug2 = format!("{:?}", client2);

        assert!(debug1.contains("server-a"));
        assert!(debug2.contains("server-b"));
        assert_ne!(debug1, debug2);
    }

    #[tokio::test]
    async fn test_is_initialized_false_on_new() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_is_initialized_after_failed_init() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_when_not_running() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        let result = client.shutdown().await;
        assert!(result.is_ok());

        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let server = McpServer::new("test-idempotent", "echo");
        let client = McpClient::new(server);

        let first = client.shutdown().await;
        assert!(first.is_ok());

        let second = client.shutdown().await;
        assert!(second.is_ok());
    }

    #[test]
    fn test_client_server_config_passed_through() {
        let server = McpServer::new("config-test", "npx")
            .with_args(vec!["-y".to_string(), "@some/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        let client = McpClient::new(server);

        assert_eq!(client.server.name, "config-test");
        assert_eq!(client.server.command, "npx");
        assert_eq!(client.server.args, vec!["-y", "@some/mcp-server"]);
        assert_eq!(client.server.env.get("DEBUG"), Some(&"true".to_string()));
    }

    #[test]
    fn test_client_server_method() {
        let server = McpServer::new("method-test", "python");
        let client = McpClient::new(server);

        let retrieved_server = client.server();
        assert_eq!(retrieved_server.name, "method-test");
    }

    #[tokio::test]
    async fn test_server_info_none_on_new_client() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert!(client.server_info().await.is_none());
    }

    /// `echo` is not an MCP server: it either cannot be spawned or answers
    /// with something that is not a valid `initialize` response. Every attempt
    /// must therefore fail and must never flip the `initialized` flag, so a
    /// repeated call is not short-circuited as "already initialized".
    #[tokio::test]
    async fn test_initialize_failure_is_repeatable() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(5));

        let first = client.initialize().await;
        assert!(first.is_err());
        assert!(!client.is_initialized().await);

        let second = client.initialize().await;
        assert!(second.is_err());
        assert!(!client.is_initialized().await);
        assert!(client.server_info().await.is_none());

        // Best-effort cleanup of anything the attempts may have left behind.
        let _ = client.shutdown().await;
    }

    #[test]
    fn test_client_default_timeout_is_30_seconds() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_shutdown_clears_initialized_flag() {
        let server = McpServer::new("test-clear", "echo");
        let client = McpClient::new(server);

        assert!(!client.is_initialized().await);

        client.shutdown().await.unwrap();
        assert!(!client.is_initialized().await);
    }
}