use anyhow::{anyhow, Context, Result};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::RwLock;
use tokio::time::{timeout, Duration};
use crate::protocol::*;
/// Client for a single MCP server that is reached over the stdio pipes of a
/// spawned child process.
///
/// Every piece of mutable state sits behind a `tokio::sync::RwLock`, so all
/// methods take `&self`.
pub struct McpClient {
    /// Launch configuration of the server (name, command, args, env).
    server: McpServer,
    /// Handle to the spawned server process; `None` before start and after shutdown.
    child: RwLock<Option<Child>>,
    /// Buffered writer over the child's stdin; requests are written as JSON lines.
    stdin: RwLock<Option<tokio::io::BufWriter<ChildStdin>>>,
    /// Buffered reader over the child's stdout; responses are read one line at a time.
    stdout: RwLock<Option<BufReader<ChildStdout>>>,
    /// `true` once the `initialize` handshake has completed; reset by shutdown.
    initialized: RwLock<bool>,
    /// Result of the last `tools/list` request; cleared on shutdown.
    tool_cache: RwLock<Option<Vec<McpTool>>>,
    /// Server info taken from the `initialize` response.
    server_info: RwLock<Option<ServerInfo>>,
    /// Timeout applied separately to writing a request and to reading its response.
    request_timeout: Duration,
}
impl McpClient {
    /// Creates a client for `server` without spawning anything.
    ///
    /// The process is started by [`McpClient::initialize`], either called
    /// directly or triggered by the first request. The request timeout defaults
    /// to 30 seconds.
    pub fn new(server: McpServer) -> Self {
        Self {
            server,
            child: RwLock::new(None),
            stdin: RwLock::new(None),
            stdout: RwLock::new(None),
            initialized: RwLock::new(false),
            tool_cache: RwLock::new(None),
            server_info: RwLock::new(None),
            request_timeout: Duration::from_secs(30),
        }
    }

    /// Sets the timeout applied to each phase of a request.
    ///
    /// Writing the request and reading its response are timed separately.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Spawns the server process and performs the MCP `initialize` handshake.
    ///
    /// Returns immediately if the client is already initialized. If the
    /// handshake fails, the freshly spawned process is shut down again. This
    /// means a failed attempt leaves no live process or stale pipe handles
    /// behind.
    ///
    /// # Errors
    ///
    /// Fails if any of the following happens:
    /// - the process cannot be spawned;
    /// - writing the request or reading the response fails or times out;
    /// - the response cannot be decoded;
    /// - the `notifications/initialized` message cannot be sent.
    pub async fn initialize(&self) -> Result<()> {
        if *self.initialized.read().await {
            return Ok(());
        }
        self.spawn_server().await?;
        if let Err(err) = self.handshake().await {
            // Left in place, the child handle would make `send_request` believe
            // the server is running, and the process itself would linger.
            if let Err(cleanup_err) = self.shutdown().await {
                let detail = format!("{cleanup_err:#}");
                tracing::warn!(
                    server = %self.server.name,
                    error = %detail,
                    "Failed to shut down MCP server after failed handshake"
                );
            }
            return Err(err);
        }
        Ok(())
    }

    /// Spawns the server process with piped stdio and stores its handles.
    async fn spawn_server(&self) -> Result<()> {
        let mut child = Command::new(&self.server.command)
            .args(&self.server.args)
            .envs(&self.server.env)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            // Dropping a `Child` handle (client dropped without `shutdown`, or a
            // handle overwritten by a re-spawn) must not orphan the process.
            .kill_on_drop(true)
            .spawn()
            .with_context(|| format!("Failed to spawn MCP server '{}'", self.server.name))?;
        let stdin = child
            .stdin
            .take()
            .expect("stdin not captured — stdin was piped");
        let stdout = child
            .stdout
            .take()
            .expect("stdout not captured — stdout was piped");
        if let Some(stderr) = child.stderr.take() {
            // A piped stderr that nobody reads fills the OS pipe buffer. After
            // that the server blocks on its next log write and stops answering
            // requests, so forward its output to tracing instead.
            let name = self.server.name.clone();
            tokio::spawn(async move {
                let mut reader = BufReader::new(stderr);
                let mut buf = Vec::new();
                loop {
                    buf.clear();
                    match reader.read_until(b'\n', &mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(_) => {
                            let text = String::from_utf8_lossy(&buf);
                            let text = text.trim_end();
                            tracing::debug!(server = %name, line = text, "MCP server stderr");
                        }
                    }
                }
            });
        }
        *self.stdin.write().await = Some(tokio::io::BufWriter::new(stdin));
        *self.stdout.write().await = Some(BufReader::new(stdout));
        *self.child.write().await = Some(child);
        Ok(())
    }

    /// Runs the `initialize` request followed by `notifications/initialized`.
    ///
    /// `initialized` and `server_info` are set only after both messages succeed.
    async fn handshake(&self) -> Result<()> {
        let params = InitializeParams::default();
        let request = McpRequest::new("initialize").with_params(serde_json::to_value(&params)?);
        let response = self.do_request(request).await?;
        let result_json = response.into_result()?;
        let init_result: InitializeResult = serde_json::from_value(result_json)?;
        self.send_notification(McpRequest::new("notifications/initialized"))
            .await?;
        *self.server_info.write().await = Some(init_result.server_info.clone());
        *self.initialized.write().await = true;
        tracing::debug!(
            server = %self.server.name,
            version = %init_result.server_info.version,
            "MCP server initialized"
        );
        Ok(())
    }

    /// Returns whether the `initialize` handshake has completed.
    pub async fn is_initialized(&self) -> bool {
        *self.initialized.read().await
    }

    /// Returns a copy of the server info from the last successful handshake.
    pub async fn server_info(&self) -> Option<ServerInfo> {
        self.server_info.read().await.clone()
    }

    /// Writes `request` as one JSON line and reads one response line.
    ///
    /// The write and the read are each bounded by `request_timeout`. The stdin
    /// guard is held until the response has been read, which serializes
    /// request/response pairs between concurrent callers. A response whose `id`
    /// differs from the request's is logged and still returned.
    async fn do_request(&self, request: McpRequest) -> Result<McpResponse> {
        let request_id = request.id.clone();
        let mut stdin_guard = self.stdin.write().await;
        let stdin = stdin_guard
            .as_mut()
            .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
        let json = request.to_jsonl()?;
        timeout(self.request_timeout, async {
            stdin.write_all(&json).await?;
            stdin.flush().await?;
            Ok::<(), tokio::io::Error>(())
        })
        .await
        .map_err(|e| anyhow!("MCP request timed out (write): {e}"))??;
        let mut stdout_guard = self.stdout.write().await;
        let stdout = stdout_guard
            .as_mut()
            .ok_or_else(|| anyhow!("stdout not available on '{}'", self.server.name))?;
        let line: std::io::Result<Option<String>> = timeout(self.request_timeout, async {
            stdout.lines().next_line().await
        })
        .await
        .map_err(|e| anyhow!("MCP request timed out (read): {e}"))?;
        let response_str: String = line
            .context("Failed to read MCP response line from stdout")?
            .with_context(|| format!("MCP server {} returned no response", self.server.name))?;
        let parsed: McpResponse = serde_json::from_str(&response_str)
            .with_context(|| format!("Failed to parse MCP response JSON: {response_str}"))?;
        if parsed.id != request_id {
            tracing::warn!(
                server = %self.server.name,
                expected_id = ?request_id,
                got_id = ?parsed.id,
                "MCP response ID mismatch"
            );
        }
        Ok(parsed)
    }

    /// Writes `notification` as one JSON line without waiting for a reply.
    async fn send_notification(&self, notification: McpRequest) -> Result<()> {
        let mut stdin_guard = self.stdin.write().await;
        let stdin = stdin_guard
            .as_mut()
            .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
        let json = notification.to_jsonl()?;
        stdin.write_all(&json).await?;
        stdin.flush().await?;
        Ok(())
    }

    /// Returns `true` when `err` indicates a broken transport rather than a
    /// protocol-level failure.
    ///
    /// Transport failures include a closed pipe, missing handles, a timeout,
    /// or EOF on stdout.
    fn is_communication_error(err: &anyhow::Error) -> bool {
        let transport_io_error = err.chain().any(|cause| {
            matches!(
                cause.downcast_ref::<std::io::Error>().map(std::io::Error::kind),
                Some(
                    std::io::ErrorKind::BrokenPipe
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::UnexpectedEof
                )
            )
        });
        if transport_io_error {
            return true;
        }
        // Fall back to the messages produced by `do_request`. Compare over the
        // whole chain and case-insensitively, because OS error text is
        // capitalized ("Broken pipe").
        let message = format!("{err:#}").to_lowercase();
        ["not available", "broken pipe", "timed out", "no response"]
            .iter()
            .any(|needle| message.contains(*needle))
    }

    /// Sends `request`, starting the server first if no process is running.
    ///
    /// On a transport error, the server is restarted. An error is then returned
    /// that asks the caller to retry; the original failure is kept as its cause.
    /// Requests are never replayed automatically.
    pub(crate) async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
        let running = self.child.read().await.is_some();
        if !running {
            tracing::warn!(
                server = %self.server.name,
                "MCP server not running, attempting auto-start"
            );
            self.restart().await?;
        }
        let err = match self.do_request(request).await {
            Ok(resp) => return Ok(resp),
            Err(err) => err,
        };
        if !Self::is_communication_error(&err) {
            return Err(err);
        }
        let detail = format!("{err:#}");
        tracing::warn!(
            server = %self.server.name,
            error = %detail,
            "MCP communication error, attempting auto-restart"
        );
        self.restart().await?;
        Err(err.context(format!(
            "MCP server '{}' restarted after error. Please retry the request.",
            self.server.name
        )))
    }

    /// Returns the cached tool list, fetching it from the server on first use.
    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
        if let Some(cached) = self.tool_cache.read().await.clone() {
            return Ok(cached);
        }
        self.refresh_tools().await
    }

    /// Fetches `tools/list` from the server and replaces the cache with it.
    pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
        let request = McpRequest::new("tools/list");
        let response = self.send_request(request).await?;
        let result_json = response.into_result()?;
        let tools_result: McpToolsResult = serde_json::from_value(result_json)?;
        let tools = tools_result.tools;
        *self.tool_cache.write().await = Some(tools.clone());
        tracing::debug!(
            server = %self.server.name,
            count = tools.len(),
            "Refreshed tool cache"
        );
        Ok(tools)
    }

    /// Invokes `tools/call` for `tool_name` with `arguments`.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<McpToolCallResult> {
        let params = serde_json::json!({
            "name": tool_name,
            "arguments": arguments,
        });
        let request = McpRequest::new("tools/call").with_params(params);
        let response = self.send_request(request).await?;
        let result_json = response.into_result()?;
        let call_result: McpToolCallResult = serde_json::from_value(result_json)?;
        tracing::debug!(
            server = %self.server.name,
            tool = tool_name,
            "Tool call completed"
        );
        Ok(call_result)
    }

    /// Calls a tool and returns the first text block of its result.
    ///
    /// # Errors
    ///
    /// Fails if the call fails or if the result contains no text block.
    pub async fn call_tool_text(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<String> {
        let result = self.call_tool(tool_name, arguments).await?;
        for block in result.content {
            if let McpContentBlock::Text { text } = block {
                return Ok(text);
            }
        }
        Err(anyhow!("Tool '{tool_name}' returned no text content"))
    }

    /// Stops the server process (if any) and resets the client to the
    /// uninitialized state.
    ///
    /// Safe to call repeatedly. A process that has already exited is only
    /// reaped, not killed.
    ///
    /// # Errors
    ///
    /// Returns an error only if a still-running process could not be killed.
    /// The client state is reset either way.
    pub async fn shutdown(&self) -> Result<()> {
        // Dropping the handles closes our ends of the stdio pipes.
        *self.stdin.write().await = None;
        *self.stdout.write().await = None;
        let child = self.child.write().await.take();
        let result = match child {
            Some(child) => self.terminate_child(child).await,
            None => Ok(()),
        };
        *self.initialized.write().await = false;
        *self.tool_cache.write().await = None;
        result
    }

    /// Kills `child` unless it has already exited, and reaps it.
    async fn terminate_child(&self, mut child: Child) -> Result<()> {
        tracing::debug!(server = %self.server.name, "Shutting down MCP server");
        // `try_wait` reaps a process that already exited. After that, tokio's
        // `kill` refuses to signal it and errors, so there is nothing left to do.
        if let Ok(Some(status)) = child.try_wait() {
            tracing::debug!(
                server = %self.server.name,
                %status,
                "MCP server had already exited"
            );
            return Ok(());
        }
        // `kill` sends the kill request and then waits for the process to exit.
        let killed = child.kill().await;
        match killed {
            Ok(()) => Ok(()),
            Err(err) => {
                // The process may have exited on its own after the `try_wait` above.
                if matches!(child.try_wait(), Ok(Some(_))) {
                    Ok(())
                } else {
                    Err(anyhow::Error::from(err)
                        .context(format!("Failed to kill MCP server '{}'", self.server.name)))
                }
            }
        }
    }

    /// Shuts the server down and initializes it again.
    ///
    /// A shutdown error aborts before re-initialization.
    pub async fn restart(&self) -> Result<()> {
        self.shutdown().await?;
        self.initialize().await
    }

    /// Returns the server's launch configuration.
    pub fn server(&self) -> &McpServer {
        &self.server
    }
}
/// Compact debug view of the client.
///
/// Only the server name and the `initialized` lock are shown. Process and
/// pipe handles, caches and the timeout are left out.
impl std::fmt::Debug for McpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("McpClient");
        view.field("server", &self.server.name);
        view.field("initialized", &self.initialized);
        view.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_client_construction() {
        let server = McpServer::new("test-server", "npx");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "test-server");
        assert_eq!(client.server.command, "npx");
    }

    #[test]
    fn test_client_with_timeout() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(60));
        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_client_with_timeout_short() {
        let server = McpServer::new("test", "sleep");
        let client = McpClient::new(server).with_timeout(Duration::from_millis(50));
        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_millis(50));
    }

    #[test]
    fn test_client_debug_format() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug_str = format!("{client:?}");
        assert!(debug_str.contains("debug-test"));
        assert!(debug_str.contains("McpClient"));
    }

    #[test]
    fn test_client_debug_different_servers() {
        let server1 = McpServer::new("server-a", "cmd1");
        let server2 = McpServer::new("server-b", "cmd2");
        let client1 = McpClient::new(server1);
        let client2 = McpClient::new(server2);
        let debug1 = format!("{client1:?}");
        let debug2 = format!("{client2:?}");
        assert!(debug1.contains("server-a"));
        assert!(debug2.contains("server-b"));
        assert_ne!(debug1, debug2);
    }

    #[tokio::test]
    async fn test_is_initialized_false_on_new() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_is_initialized_after_failed_init() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);
        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(!client.is_initialized().await);
        assert!(client.server_info().await.is_none());
    }

    #[tokio::test]
    async fn test_shutdown_when_not_running() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);
        let result = client.shutdown().await;
        assert!(result.is_ok());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let server = McpServer::new("test-idempotent", "echo");
        let client = McpClient::new(server);
        let first = client.shutdown().await;
        assert!(first.is_ok());
        let second = client.shutdown().await;
        assert!(second.is_ok());
        assert!(!client.is_initialized().await);
    }

    #[test]
    fn test_client_server_config_passed_through() {
        let server = McpServer::new("config-test", "npx")
            .with_args(vec!["-y".to_string(), "@some/mcp-server".to_string()])
            .with_env("DEBUG", "true");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "config-test");
        assert_eq!(client.server.command, "npx");
        assert_eq!(client.server.args, vec!["-y", "@some/mcp-server"]);
        assert_eq!(client.server.env.get("DEBUG"), Some(&"true".to_string()));
    }

    #[test]
    fn test_client_server_method() {
        let server = McpServer::new("method-test", "python");
        let client = McpClient::new(server);
        let retrieved_server = client.server();
        assert_eq!(retrieved_server.name, "method-test");
    }

    #[tokio::test]
    async fn test_server_info_none_on_new_client() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert!(client.server_info().await.is_none());
    }

    /// `echo` does not speak MCP. Either spawning fails (no `echo` binary) or
    /// the handshake fails (empty or unparsable reply), so `initialize` can
    /// never succeed here.
    #[tokio::test]
    async fn test_initialize_retry_after_failed_handshake() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);
        assert!(client.initialize().await.is_err());
        // A second attempt must fail too, not be short-circuited by stale state.
        assert!(client.initialize().await.is_err());
        assert!(!client.is_initialized().await);
        assert!(client.server_info().await.is_none());
    }

    #[test]
    fn test_client_default_timeout_is_30_seconds() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);
        assert_eq!(client.server.name, "test");
        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_shutdown_clears_initialized_flag() {
        let server = McpServer::new("test-clear", "echo");
        let client = McpClient::new(server);
        assert!(!client.is_initialized().await);
        client.shutdown().await.unwrap();
        assert!(!client.is_initialized().await);
    }
}