use anyhow::{Context, Result, anyhow};
use std::sync::atomic::AtomicUsize;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio::time::{Duration, timeout};
use crate::protocol::*;
/// Client for a single MCP server that is run as a child process and spoken
/// to over its stdin/stdout pipes.
///
/// All mutable state sits behind async locks so the client can be used
/// through a shared reference (`&self`).
pub struct McpClient {
/// Launch configuration for the server (name, command, args, env).
server: McpServer,
/// Handle of the spawned server process; `None` while no process is running.
child: RwLock<Option<Child>>,
/// Buffered writer over the child's stdin; `None` while no process is running.
stdin: Mutex<Option<tokio::io::BufWriter<ChildStdin>>>,
/// Buffered reader over the child's stdout; `None` while no process is running.
stdout: Mutex<Option<BufReader<ChildStdout>>>,
/// Set to `true` by `initialize` and back to `false` by `shutdown` / `cleanup_child`.
initialized: RwLock<bool>,
/// Tools returned by the last `tools/list` request; cleared by `shutdown`.
tool_cache: RwLock<Option<Vec<McpTool>>>,
/// Server info taken from the `initialize` response.
server_info: RwLock<Option<ServerInfo>>,
/// Limit applied to the request write and to each response-line read
/// (30 seconds unless changed through `with_timeout`).
request_timeout: Duration,
/// Background task that forwards the child's stderr lines to `tracing`.
stderr_task: Mutex<Option<JoinHandle<()>>>,
/// Counter handing out request ids, starting at 1.
next_id: AtomicUsize,
}
impl McpClient {
/// Builds a client for `server` without starting any process.
///
/// The client starts out uninitialised, with empty caches, a 30 second
/// request timeout and request ids beginning at 1.
pub fn new(server: McpServer) -> Self {
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;
    const FIRST_REQUEST_ID: usize = 1;

    let request_timeout = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
    let next_id = AtomicUsize::new(FIRST_REQUEST_ID);

    Self {
        server,
        child: RwLock::new(Option::None),
        stdin: Mutex::new(Option::None),
        stdout: Mutex::new(Option::None),
        initialized: RwLock::new(false),
        tool_cache: RwLock::new(Option::None),
        server_info: RwLock::new(Option::None),
        request_timeout,
        stderr_task: Mutex::new(Option::None),
        next_id,
    }
}
/// Returns the client with its request timeout replaced by `timeout`.
#[must_use]
pub fn with_timeout(self, timeout: Duration) -> Self {
    let mut client = self;
    client.request_timeout = timeout;
    client
}
/// Spawns the configured server process and performs the MCP `initialize`
/// handshake with it.
///
/// Returns immediately with `Ok(())` when the client is already marked as
/// initialised. Otherwise the server command is spawned with piped
/// stdin/stdout/stderr (and `kill_on_drop`), a background task starts
/// forwarding stderr lines to `tracing`, and the handshake is run.
///
/// # Errors
///
/// Fails when the process cannot be spawned or when any handshake step
/// fails. After a handshake failure the freshly spawned process and its
/// pipes are torn down again, so the client is left in the "not running"
/// state rather than half-initialised.
///
/// # Panics
///
/// Panics if a pipe that was requested as `piped()` is missing from the
/// spawned child, which would be a broken invariant.
pub async fn initialize(&self) -> Result<()> {
    if *self.initialized.read().await {
        return Ok(());
    }

    let mut child = Command::new(&self.server.command)
        .args(&self.server.args)
        .envs(&self.server.env)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .kill_on_drop(true)
        .spawn()
        .with_context(|| format!("Failed to spawn MCP server '{}'", self.server.name))?;

    // All three streams were configured as piped just above.
    let stdin = child
        .stdin
        .take()
        .expect("stdin not captured — stdin was piped");
    let stdout = child
        .stdout
        .take()
        .expect("stdout not captured — stdout was piped");
    let stderr = child
        .stderr
        .take()
        .expect("stderr not captured — stderr was piped");

    // Drain stderr in the background so its output ends up in the debug log.
    let stderr_server_name = self.server.name.clone();
    let stderr_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let mut line = String::new();
        loop {
            line.clear();
            match reader.read_line(&mut line).await {
                // EOF: the child closed its stderr.
                Ok(0) => break,
                Ok(_) => {
                    let trimmed = line.trim_end_matches(['\n', '\r']);
                    if !trimmed.is_empty() {
                        tracing::debug!(
                            server = %stderr_server_name,
                            stream = "stderr",
                            "{}",
                            trimmed
                        );
                    }
                }
                Err(e) => {
                    tracing::debug!(
                        server = %stderr_server_name,
                        stream = "stderr",
                        error = %e,
                        "stderr drain stopping"
                    );
                    break;
                }
            }
        }
    });

    *self.stdin.lock().await = Some(tokio::io::BufWriter::new(stdin));
    *self.stdout.lock().await = Some(BufReader::new(stdout));
    *self.stderr_task.lock().await = Some(stderr_task);
    *self.child.write().await = Some(child);

    // Any handshake failure (transport, error response, malformed result,
    // notification write) must not leave the spawned process behind.
    match self.perform_handshake().await {
        Ok(()) => Ok(()),
        Err(e) => {
            self.cleanup_child().await;
            Err(e)
        }
    }
}

/// Runs the `initialize` request / `notifications/initialized` exchange on
/// the already installed pipes and records the outcome.
///
/// The `initialized` flag and server info are stored only after every step
/// has succeeded.
async fn perform_handshake(&self) -> Result<()> {
    let params = InitializeParams::default();
    let request = McpRequest::with_id(self.next_id(), "initialize")
        .with_params(serde_json::to_value(&params)?);
    let response = self.do_request(request).await?;
    let result_json = response.into_result()?;
    let init_result: InitializeResult = serde_json::from_value(result_json)?;

    let notification = McpRequest::notification("notifications/initialized");
    self.send_notification(notification).await?;

    *self.server_info.write().await = Some(init_result.server_info.clone());
    *self.initialized.write().await = true;

    tracing::debug!(
        server = %self.server.name,
        version = %init_result.server_info.version,
        "MCP server initialized"
    );
    Ok(())
}
/// Reports whether the `initialized` flag is currently set.
pub async fn is_initialized(&self) -> bool {
    let flag = self.initialized.read().await;
    *flag
}
/// Returns a copy of the stored server info, or `None` when none has been
/// recorded yet.
pub async fn server_info(&self) -> Option<ServerInfo> {
    let stored = self.server_info.read().await;
    (*stored).clone()
}
/// Writes one request to the server's stdin and waits for the response
/// carrying the same id on its stdout.
///
/// The stdin lock is held for the whole call (the guard lives until the
/// function returns), so concurrent requests are processed one at a time.
///
/// # Errors
///
/// Returns an error when stdin/stdout are not installed, when the write or
/// a single line read exceeds `request_timeout`, when reading fails, when
/// stdout reaches end-of-file, or when a line is not valid JSON / not a
/// valid response.
async fn do_request(&self, request: McpRequest) -> Result<McpResponse> {
// Remember the id up front; `request` is serialised below.
let request_id = request.id.clone();
let mut stdin_guard = self.stdin.lock().await;
let stdin = stdin_guard
.as_mut()
.ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
let json = request.to_jsonl()?;
// Write + flush under the request timeout so a stalled child cannot block forever.
timeout(self.request_timeout, async {
stdin.write_all(&json).await?;
stdin.flush().await?;
Ok::<(), tokio::io::Error>(())
})
.await
.map_err(|e| anyhow::anyhow!("MCP request timed out (write): {e}"))??;
let mut stdout_guard = self.stdout.lock().await;
let stdout = stdout_guard
.as_mut()
.ok_or_else(|| anyhow!("stdout not available on '{}'", self.server.name))?;
// Read line by line until the matching response shows up.
// NOTE: the timeout applies to each line read separately, not to the loop as a whole.
loop {
let line: std::io::Result<Option<String>> = timeout(self.request_timeout, async {
stdout.lines().next_line().await
})
.await
.map_err(|e| anyhow::anyhow!("MCP request timed out (read): {e}"))?;
// The outer `Result` is the read outcome; an inner `None` means end-of-file.
let response_str: String = line
.context("Failed to read MCP response line from stdout")?
.with_context(|| format!("MCP server {} returned no response", self.server.name))?;
let value: serde_json::Value = serde_json::from_str(&response_str)
.with_context(|| format!("Failed to parse MCP message JSON: {response_str}"))?;
// Anything carrying a "method" field is a message initiated by the server, not a response.
if value.get("method").is_some() {
tracing::debug!(
server = %self.server.name,
method = ?value.get("method"),
"MCP server sent a notification/server request; skipping"
);
continue;
}
// Responses whose id differs from ours are logged and dropped.
let got_id = value.get("id");
if got_id != Some(&request_id) {
tracing::warn!(
server = %self.server.name,
expected_id = ?request_id,
got_id = ?got_id,
"MCP response ID mismatch, skipping"
);
continue;
}
let parsed: McpResponse = serde_json::from_value(value)
.with_context(|| format!("Failed to parse MCP response: {response_str}"))?;
return Ok(parsed);
}
}
async fn send_notification(&self, notification: McpRequest) -> Result<()> {
let mut stdin_guard = self.stdin.lock().await;
let stdin = stdin_guard
.as_mut()
.ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
let json = notification.to_jsonl()?;
stdin.write_all(&json).await?;
stdin.flush().await?;
Ok(())
}
pub(crate) async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
{
let child = self.child.read().await;
if child.is_none() {
tracing::warn!(
server = %self.server.name,
"MCP server not running, attempting auto-start"
);
drop(child);
self.restart().await?;
}
}
let request_for_retry = request.clone();
match self.do_request(request).await {
Ok(resp) => Ok(resp),
Err(e) => {
let err_str = e.to_string();
let is_comm_error = err_str.contains("not available")
|| err_str.contains("broken pipe")
|| err_str.contains("timed out")
|| err_str.contains("no response")
|| err_str.contains("reset by peer");
if is_comm_error {
tracing::warn!(
server = %self.server.name,
error = %err_str,
"MCP communication error, attempting auto-restart + retry"
);
self.restart().await?;
self.do_request(request_for_retry).await
} else {
Err(e)
}
}
}
}
/// Hands out the next request id and advances the counter by one.
fn next_id(&self) -> usize {
    use std::sync::atomic::Ordering;

    let counter = &self.next_id;
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Best-effort teardown of the running server: drops both pipes, aborts the
/// stderr drain task, kills and reaps the child, and clears the
/// `initialized` flag. Failures while killing or waiting are ignored.
async fn cleanup_child(&self) {
    {
        let mut writer_slot = self.stdin.lock().await;
        *writer_slot = None;
    }
    {
        let mut reader_slot = self.stdout.lock().await;
        *reader_slot = None;
    }
    {
        let mut drain_slot = self.stderr_task.lock().await;
        if let Some(drain) = drain_slot.take() {
            drain.abort();
        }
    }
    {
        // The write lock stays held until the process has been killed and reaped.
        let mut child_slot = self.child.write().await;
        if let Some(mut process) = child_slot.take() {
            let _ = process.kill().await;
            let _ = process.wait().await;
        }
    }
    let mut flag = self.initialized.write().await;
    *flag = false;
}
/// Returns the server's tools, served from the cache when it is populated
/// and fetched through `refresh_tools` otherwise.
pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
    // Copy the cache out in its own statement so the read guard is released
    // before `refresh_tools` needs the write lock.
    let cached = self.tool_cache.read().await.clone();
    match cached {
        Some(tools) => Ok(tools),
        None => self.refresh_tools().await,
    }
}
/// Issues a `tools/list` request, replaces the tool cache with the answer
/// and returns the tools.
pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
    let listing = self
        .send_request(McpRequest::with_id(self.next_id(), "tools/list"))
        .await?
        .into_result()?;
    let tools = serde_json::from_value::<McpToolsResult>(listing)?.tools;
    {
        let mut cache = self.tool_cache.write().await;
        *cache = Some(tools.clone());
    }
    tracing::debug!(
        server = %self.server.name,
        count = tools.len(),
        "Refreshed tool cache"
    );
    Ok(tools)
}
/// Invokes the tool `tool_name` with `arguments` through a `tools/call`
/// request and returns the decoded call result.
pub async fn call_tool(
    &self,
    tool_name: &str,
    arguments: serde_json::Value,
) -> Result<McpToolCallResult> {
    let payload = serde_json::json!({
        "name": tool_name,
        "arguments": arguments,
    });
    let call = McpRequest::with_id(self.next_id(), "tools/call").with_params(payload);
    let raw = self.send_request(call).await?.into_result()?;
    let outcome = serde_json::from_value::<McpToolCallResult>(raw)?;
    tracing::debug!(
        server = %self.server.name,
        tool = tool_name,
        "Tool call completed"
    );
    Ok(outcome)
}
/// Invokes a tool and returns the text of the first text block in its
/// result.
///
/// # Errors
///
/// Fails when the call itself fails or when the result contains no text
/// block.
pub async fn call_tool_text(
    &self,
    tool_name: &str,
    arguments: serde_json::Value,
) -> Result<String> {
    let outcome = self.call_tool(tool_name, arguments).await?;
    outcome
        .content
        .into_iter()
        .find_map(|block| match block {
            McpContentBlock::Text { text } => Some(text),
            _ => None,
        })
        .ok_or_else(|| anyhow!("Tool '{tool_name}' returned no text content"))
}
pub async fn shutdown(&self) -> Result<()> {
*self.stdin.lock().await = None;
*self.stdout.lock().await = None;
if let Some(handle) = self.stderr_task.lock().await.take() {
handle.abort();
}
let mut child_guard = self.child.write().await;
if let Some(mut child) = child_guard.take() {
tracing::debug!(server = %self.server.name, "Shutting down MCP server");
let _ = child.try_wait();
child.kill().await?;
let _ = child.wait().await;
}
*self.initialized.write().await = false;
*self.tool_cache.write().await = None;
Ok(())
}
/// Shuts the server down and, if that succeeded, initialises it again.
pub async fn restart(&self) -> Result<()> {
    match self.shutdown().await {
        Ok(()) => self.initialize().await,
        Err(e) => Err(e),
    }
}
pub fn server(&self) -> &McpServer {
&self.server
}
}
/// Compact debug output: only the server name and the `initialized` lock
/// are shown; pipes, process handle and caches are left out.
impl std::fmt::Debug for McpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("McpClient");
        out.field("server", &self.server.name);
        out.field("initialized", &self.initialized);
        out.finish()
    }
}
/// Unit tests for `McpClient` that do not need a real MCP server.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_client_construction() {
        let server = McpServer::new("test-server", "npx");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "test-server");
        assert_eq!(client.server.command, "npx");
    }

    #[test]
    fn test_client_with_timeout() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(60));
        assert_eq!(client.server.name, "test");
        // The builder must actually store the requested timeout.
        assert_eq!(client.request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_client_with_timeout_short() {
        let server = McpServer::new("test", "sleep");
        let client = McpClient::new(server).with_timeout(Duration::from_millis(50));
        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_millis(50));
    }

    #[test]
    fn test_client_debug_format() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug_str = format!("{client:?}");
        assert!(debug_str.contains("debug-test"));
        assert!(debug_str.contains("McpClient"));
    }

    #[test]
    fn test_client_debug_different_servers() {
        let server1 = McpServer::new("server-a", "cmd1");
        let server2 = McpServer::new("server-b", "cmd2");
        let client1 = McpClient::new(server1);
        let client2 = McpClient::new(server2);
        let debug1 = format!("{client1:?}");
        let debug2 = format!("{client2:?}");
        assert!(debug1.contains("server-a"));
        assert!(debug2.contains("server-b"));
        assert_ne!(debug1, debug2);
    }

    #[tokio::test]
    async fn test_is_initialized_false_on_new() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_is_initialized_after_failed_init() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);
        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_when_not_running() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);
        let result = client.shutdown().await;
        assert!(result.is_ok());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let server = McpServer::new("test-idempotent", "echo");
        let client = McpClient::new(server);
        let first = client.shutdown().await;
        assert!(first.is_ok());
        let second = client.shutdown().await;
        assert!(second.is_ok());
    }

    #[test]
    fn test_client_server_config_passed_through() {
        let server = McpServer::new("config-test", "npx")
            .with_args(vec!["-y".to_string(), "@some/mcp-server".to_string()])
            .with_env("DEBUG", "true");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "config-test");
        assert_eq!(client.server.command, "npx");
        assert_eq!(client.server.args, vec!["-y", "@some/mcp-server"]);
        assert_eq!(client.server.env.get("DEBUG"), Some(&"true".to_string()));
    }

    #[test]
    fn test_client_server_method() {
        let server = McpServer::new("method-test", "python");
        let client = McpClient::new(server);
        let retrieved_server = client.server();
        assert_eq!(retrieved_server.name, "method-test");
    }

    #[tokio::test]
    async fn test_server_info_none_on_new_client() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert!(client.server_info().await.is_none());
    }

    #[tokio::test]
    async fn test_initialize_already_initialized_skipped() {
        // The command does not exist, so `initialize` can only succeed by
        // taking the "already initialised" early return.
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);
        *client.initialized.write().await = true;
        let result = client.initialize().await;
        assert!(result.is_ok());
        assert!(client.child.read().await.is_none());
    }

    #[test]
    fn test_client_default_timeout_is_30_seconds() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_shutdown_clears_initialized_flag() {
        let server = McpServer::new("test-clear", "echo");
        let client = McpClient::new(server);
        assert!(!client.is_initialized().await);
        // Pretend a session was established so the reset is observable.
        *client.initialized.write().await = true;
        *client.tool_cache.write().await = Some(Vec::new());
        client.shutdown().await.unwrap();
        assert!(!client.is_initialized().await);
        assert!(client.tool_cache.read().await.is_none());
    }
}