1use anyhow::{Context, Result, anyhow};
12use std::sync::atomic::AtomicUsize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout, Command};
15use tokio::sync::{Mutex, RwLock};
16use tokio::task::JoinHandle;
17use tokio::time::{Duration, timeout};
18
19use crate::protocol::*;
20
21pub struct McpClient {
40 server: McpServer,
42 child: RwLock<Option<Child>>,
47 stdin: Mutex<Option<tokio::io::BufWriter<ChildStdin>>>,
51 stdout: Mutex<Option<BufReader<ChildStdout>>>,
55 initialized: RwLock<bool>,
57 tool_cache: RwLock<Option<Vec<McpTool>>>,
59 server_info: RwLock<Option<ServerInfo>>,
61 request_timeout: Duration,
63 stderr_task: Mutex<Option<JoinHandle<()>>>,
66 next_id: AtomicUsize,
69}
70
71impl McpClient {
72 pub fn new(server: McpServer) -> Self {
76 Self {
77 server,
78 child: RwLock::new(None),
79 stdin: Mutex::new(None),
80 stdout: Mutex::new(None),
81 initialized: RwLock::new(false),
82 tool_cache: RwLock::new(None),
83 server_info: RwLock::new(None),
84 request_timeout: Duration::from_secs(30),
85 stderr_task: Mutex::new(None),
86 next_id: AtomicUsize::new(1),
87 }
88 }
89
90 #[must_use]
92 pub fn with_timeout(mut self, timeout: Duration) -> Self {
93 self.request_timeout = timeout;
94 self
95 }
96
97 pub async fn initialize(&self) -> Result<()> {
103 if *self.initialized.read().await {
104 return Ok(());
105 }
106
107 let mut child = Command::new(&self.server.command)
112 .args(&self.server.args)
113 .envs(&self.server.env)
114 .stdin(std::process::Stdio::piped())
115 .stdout(std::process::Stdio::piped())
116 .stderr(std::process::Stdio::piped())
117 .kill_on_drop(true)
118 .spawn()
119 .with_context(|| format!("Failed to spawn MCP server '{}'", self.server.name))?;
120
121 let stdin = child
122 .stdin
123 .take()
124 .expect("stdin not captured — stdin was piped");
125 let stdout = child
126 .stdout
127 .take()
128 .expect("stdout not captured — stdout was piped");
129 let stderr = child
130 .stderr
131 .take()
132 .expect("stderr not captured — stderr was piped");
133
134 let stderr_server_name = self.server.name.clone();
138 let stderr_task = tokio::spawn(async move {
139 let mut reader = BufReader::new(stderr);
140 let mut line = String::new();
141 loop {
142 line.clear();
143 match reader.read_line(&mut line).await {
144 Ok(0) => break, Ok(_) => {
146 let trimmed = line.trim_end_matches(['\n', '\r']);
147 if !trimmed.is_empty() {
148 tracing::debug!(
149 server = %stderr_server_name,
150 stream = "stderr",
151 "{}",
152 trimmed
153 );
154 }
155 }
156 Err(e) => {
157 tracing::debug!(
158 server = %stderr_server_name,
159 stream = "stderr",
160 error = %e,
161 "stderr drain stopping"
162 );
163 break;
164 }
165 }
166 }
167 });
168
169 *self.stdin.lock().await = Some(tokio::io::BufWriter::new(stdin));
171 *self.stdout.lock().await = Some(BufReader::new(stdout));
172 *self.stderr_task.lock().await = Some(stderr_task);
173
174 *self.child.write().await = Some(child);
176
177 let params = InitializeParams::default();
181 let request = McpRequest::with_id(self.next_id(), "initialize")
182 .with_params(serde_json::to_value(¶ms)?);
183
184 let response = match self.do_request(request).await {
188 Ok(resp) => resp,
189 Err(e) => {
190 self.cleanup_child().await;
191 return Err(e);
192 }
193 };
194
195 let result_json = response.into_result()?;
197 let init_result: InitializeResult = serde_json::from_value(result_json)?;
198
199 *self.server_info.write().await = Some(init_result.server_info.clone());
200 *self.initialized.write().await = true;
201
202 let notification = McpRequest::notification("notifications/initialized");
205 self.send_notification(notification).await?;
206
207 tracing::debug!(
208 server = %self.server.name,
209 version = %init_result.server_info.version,
210 "MCP server initialized"
211 );
212
213 Ok(())
214 }
215
216 pub async fn is_initialized(&self) -> bool {
218 *self.initialized.read().await
219 }
220
221 pub async fn server_info(&self) -> Option<ServerInfo> {
223 self.server_info.read().await.clone()
224 }
225
226 async fn do_request(&self, request: McpRequest) -> Result<McpResponse> {
233 let request_id = request.id.clone();
234
235 let mut stdin_guard = self.stdin.lock().await;
237 let stdin = stdin_guard
238 .as_mut()
239 .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
240
241 let json = request.to_jsonl()?;
243 timeout(self.request_timeout, async {
244 stdin.write_all(&json).await?;
245 stdin.flush().await?;
246 Ok::<(), tokio::io::Error>(())
247 })
248 .await
249 .map_err(|e| anyhow::anyhow!("MCP request timed out (write): {e}"))??;
250
251 let mut stdout_guard = self.stdout.lock().await;
253 let stdout = stdout_guard
254 .as_mut()
255 .ok_or_else(|| anyhow!("stdout not available on '{}'", self.server.name))?;
256
257 loop {
263 let line: std::io::Result<Option<String>> = timeout(self.request_timeout, async {
264 stdout.lines().next_line().await
265 })
266 .await
267 .map_err(|e| anyhow::anyhow!("MCP request timed out (read): {e}"))?;
268
269 let response_str: String = line
270 .context("Failed to read MCP response line from stdout")?
271 .with_context(|| format!("MCP server {} returned no response", self.server.name))?;
272
273 let value: serde_json::Value = serde_json::from_str(&response_str)
276 .with_context(|| format!("Failed to parse MCP message JSON: {response_str}"))?;
277
278 if value.get("method").is_some() {
281 tracing::debug!(
282 server = %self.server.name,
283 method = ?value.get("method"),
284 "MCP server sent a notification/server request; skipping"
285 );
286 continue;
287 }
288
289 let got_id = value.get("id");
291 if got_id != Some(&request_id) {
292 tracing::warn!(
297 server = %self.server.name,
298 expected_id = ?request_id,
299 got_id = ?got_id,
300 "MCP response ID mismatch, skipping"
301 );
302 continue;
303 }
304
305 let parsed: McpResponse = serde_json::from_value(value)
306 .with_context(|| format!("Failed to parse MCP response: {response_str}"))?;
307 return Ok(parsed);
308 }
309 }
310
311 async fn send_notification(&self, notification: McpRequest) -> Result<()> {
313 let mut stdin_guard = self.stdin.lock().await;
314 let stdin = stdin_guard
315 .as_mut()
316 .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
317
318 let json = notification.to_jsonl()?;
319 stdin.write_all(&json).await?;
320 stdin.flush().await?;
321
322 Ok(())
323 }
324
325 pub(crate) async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
332 {
334 let child = self.child.read().await;
335 if child.is_none() {
336 tracing::warn!(
337 server = %self.server.name,
338 "MCP server not running, attempting auto-start"
339 );
340 drop(child);
341 self.restart().await?;
343 }
344 }
345
346 let request_for_retry = request.clone();
349 match self.do_request(request).await {
350 Ok(resp) => Ok(resp),
351 Err(e) => {
352 let err_str = e.to_string();
354 let is_comm_error = err_str.contains("not available")
355 || err_str.contains("broken pipe")
356 || err_str.contains("timed out")
357 || err_str.contains("no response")
358 || err_str.contains("reset by peer");
359
360 if is_comm_error {
361 tracing::warn!(
362 server = %self.server.name,
363 error = %err_str,
364 "MCP communication error, attempting auto-restart + retry"
365 );
366 self.restart().await?;
367 self.do_request(request_for_retry).await
370 } else {
371 Err(e)
372 }
373 }
374 }
375 }
376
377 fn next_id(&self) -> usize {
379 self.next_id
380 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
381 }
382
383 async fn cleanup_child(&self) {
387 *self.stdin.lock().await = None;
389 *self.stdout.lock().await = None;
390
391 if let Some(handle) = self.stderr_task.lock().await.take() {
393 handle.abort();
394 }
395
396 if let Some(mut child) = self.child.write().await.take() {
400 let _ = child.kill().await;
401 let _ = child.wait().await;
402 }
403
404 *self.initialized.write().await = false;
405 }
406
407 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
411 if let Some(cached) = self.tool_cache.read().await.clone() {
413 return Ok(cached);
414 }
415
416 self.refresh_tools().await
417 }
418
419 pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
421 let request = McpRequest::with_id(self.next_id(), "tools/list");
422 let response = self.send_request(request).await?;
423
424 let result_json = response.into_result()?;
425 let tools_result: McpToolsResult = serde_json::from_value(result_json)?;
426
427 let tools = tools_result.tools;
428 *self.tool_cache.write().await = Some(tools.clone());
429
430 tracing::debug!(
431 server = %self.server.name,
432 count = tools.len(),
433 "Refreshed tool cache"
434 );
435
436 Ok(tools)
437 }
438
439 pub async fn call_tool(
443 &self,
444 tool_name: &str,
445 arguments: serde_json::Value,
446 ) -> Result<McpToolCallResult> {
447 let params = serde_json::json!({
448 "name": tool_name,
449 "arguments": arguments,
450 });
451
452 let request = McpRequest::with_id(self.next_id(), "tools/call").with_params(params);
453 let response = self.send_request(request).await?;
454
455 let result_json = response.into_result()?;
456 let call_result: McpToolCallResult = serde_json::from_value(result_json)?;
457
458 tracing::debug!(
459 server = %self.server.name,
460 tool = tool_name,
461 "Tool call completed"
462 );
463
464 Ok(call_result)
465 }
466
467 pub async fn call_tool_text(
471 &self,
472 tool_name: &str,
473 arguments: serde_json::Value,
474 ) -> Result<String> {
475 let result = self.call_tool(tool_name, arguments).await?;
476
477 for block in result.content {
478 if let McpContentBlock::Text { text } = block {
479 return Ok(text);
480 }
481 }
482
483 Err(anyhow!("Tool '{tool_name}' returned no text content"))
484 }
485
486 pub async fn shutdown(&self) -> Result<()> {
491 *self.stdin.lock().await = None;
493 *self.stdout.lock().await = None;
494
495 if let Some(handle) = self.stderr_task.lock().await.take() {
497 handle.abort();
498 }
499
500 let mut child_guard = self.child.write().await;
501
502 if let Some(mut child) = child_guard.take() {
503 tracing::debug!(server = %self.server.name, "Shutting down MCP server");
504
505 let _ = child.try_wait();
507
508 child.kill().await?;
510 let _ = child.wait().await;
511 }
512
513 *self.initialized.write().await = false;
514 *self.tool_cache.write().await = None;
515
516 Ok(())
517 }
518
519 pub async fn restart(&self) -> Result<()> {
521 self.shutdown().await?;
522 self.initialize().await
523 }
524
525 pub fn server(&self) -> &McpServer {
527 &self.server
528 }
529}
530
531impl std::fmt::Debug for McpClient {
532 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
533 f.debug_struct("McpClient")
534 .field("server", &self.server.name)
535 .field("initialized", &self.initialized)
536 .finish()
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_client_construction() {
        let server = McpServer::new("test-server", "npx");
        let client = McpClient::new(server);

        assert_eq!(client.server.name, "test-server");
        assert_eq!(client.server.command, "npx");
    }

    #[test]
    fn test_client_with_timeout() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(60));

        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_client_with_timeout_short() {
        let server = McpServer::new("test", "sleep");
        let client = McpClient::new(server).with_timeout(Duration::from_millis(50));

        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_millis(50));
    }

    #[test]
    fn test_client_debug_format() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);

        let debug_str = format!("{client:?}");

        assert!(debug_str.contains("debug-test"));
        assert!(debug_str.contains("McpClient"));
    }

    #[test]
    fn test_client_debug_different_servers() {
        let server1 = McpServer::new("server-a", "cmd1");
        let server2 = McpServer::new("server-b", "cmd2");

        let client1 = McpClient::new(server1);
        let client2 = McpClient::new(server2);

        let debug1 = format!("{client1:?}");
        let debug2 = format!("{client2:?}");

        assert!(debug1.contains("server-a"));
        assert!(debug2.contains("server-b"));
        assert_ne!(debug1, debug2);
    }

    #[tokio::test]
    async fn test_is_initialized_false_on_new() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_is_initialized_after_failed_init() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);

        let result = client.initialize().await;

        assert!(result.is_err());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_when_not_running() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        let result = client.shutdown().await;

        assert!(result.is_ok());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let server = McpServer::new("test-idempotent", "echo");
        let client = McpClient::new(server);

        let first = client.shutdown().await;
        assert!(first.is_ok());

        let second = client.shutdown().await;
        assert!(second.is_ok());
    }

    #[test]
    fn test_client_server_config_passed_through() {
        let server = McpServer::new("config-test", "npx")
            .with_args(vec!["-y".to_string(), "@some/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        let client = McpClient::new(server);

        assert_eq!(client.server.name, "config-test");
        assert_eq!(client.server.command, "npx");
        assert_eq!(client.server.args, vec!["-y", "@some/mcp-server"]);
        assert_eq!(client.server.env.get("DEBUG"), Some(&"true".to_string()));
    }

    #[test]
    fn test_client_server_method() {
        let server = McpServer::new("method-test", "python");
        let client = McpClient::new(server);

        let retrieved_server = client.server();

        assert_eq!(retrieved_server.name, "method-test");
    }

    #[tokio::test]
    async fn test_server_info_none_on_new_client() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert!(client.server_info().await.is_none());
    }

    #[tokio::test]
    async fn test_initialize_failure_is_not_cached() {
        // `echo` is not an MCP server (or does not exist as a binary on some
        // platforms), so the handshake must fail.
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        assert!(client.initialize().await.is_err());
        assert!(!client.is_initialized().await);

        // A failed attempt must not leave the client marked initialized; a second
        // call re-runs the handshake and fails the same way instead of returning Ok.
        assert!(client.initialize().await.is_err());
        assert!(!client.is_initialized().await);
        assert!(client.server_info().await.is_none());
    }

    #[test]
    fn test_client_default_timeout_is_30_seconds() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_request_ids_start_at_one_and_increment() {
        let server = McpServer::new("ids", "echo");
        let client = McpClient::new(server);

        assert_eq!(client.next_id(), 1);
        assert_eq!(client.next_id(), 2);
        assert_eq!(client.next_id(), 3);
    }

    #[tokio::test]
    async fn test_shutdown_clears_initialized_flag() {
        let server = McpServer::new("test-clear", "echo");
        let client = McpClient::new(server);

        assert!(!client.is_initialized().await);

        client.shutdown().await.unwrap();

        assert!(!client.is_initialized().await);
    }
}