use anyhow::{anyhow, Result};
use serde_json::{json, Value};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tokio::time::{timeout, Duration};
use crate::rpc::{RpcRequest, RpcResponse};
/// Minimal line-delimited JSON-RPC client for talking to an MCP server.
///
/// Each request is serialized as a single JSON line on `writer`, and exactly one
/// response line is read back from `reader`. Requests are strictly sequential
/// (every call takes `&mut self`), so a response is paired with a request by its
/// position in the stream and then checked against the request id.
pub struct McpClient {
    /// Source of newline-terminated response lines from the server.
    reader: Box<dyn AsyncBufRead + Unpin + Send>,
    /// Sink for newline-terminated request lines to the server.
    writer: Box<dyn AsyncWrite + Unpin + Send>,
    /// Upper bound on one full request/response round trip (write + read).
    timeout: Duration,
    /// Id assigned to the next request; sent on the wire as its decimal string.
    next_id: u64,
    /// Set once the line stream is known to be out of step with our requests
    /// (timeout, I/O failure, EOF, or a reply with an unexpected id). Further
    /// requests then fail fast instead of reading another request's reply.
    poisoned: bool,
}

impl McpClient {
    /// Round-trip timeout used by [`McpClient::new`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a client over an already-connected reader/writer pair.
    ///
    /// Request ids start at 1 and the round-trip timeout defaults to 30 seconds;
    /// use [`McpClient::with_timeout`] to change the timeout.
    pub fn new<R, W>(reader: R, writer: W) -> Self
    where
        R: AsyncBufRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        Self {
            reader: Box::new(reader),
            writer: Box::new(writer),
            timeout: Self::DEFAULT_TIMEOUT,
            next_id: 1,
            poisoned: false,
        }
    }

    /// Returns this client with a different per-request round-trip timeout.
    ///
    /// The limit covers writing and flushing the request and reading the
    /// response line.
    #[must_use]
    pub fn with_timeout(mut self, limit: Duration) -> Self {
        self.timeout = limit;
        self
    }

    /// Placeholder for spawning an MCP server as a child process over stdio.
    ///
    /// # Errors
    ///
    /// Always returns an error: process spawning is not implemented. Build the
    /// transport yourself and pass it to [`McpClient::new`].
    pub async fn connect_stdio(_cmd: &str, _args: &[&str]) -> Result<Self> {
        Err(anyhow!(
            "McpClient::connect_stdio is not implemented; use McpClient::new(reader, writer) with a custom IO transport instead"
        ))
    }

    /// Sends a `list_tools` request and returns the `tools` array of the result.
    ///
    /// A result without a `tools` array (missing, or not an array) yields an
    /// empty list rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying request.
    pub async fn list_tools(&mut self) -> Result<Vec<Value>> {
        let mut result = self.send_request("list_tools", json!({})).await?;
        // Move the array out of the owned result instead of deep-cloning it.
        let tools = match result.get_mut("tools").map(Value::take) {
            Some(Value::Array(tools)) => tools,
            _ => Vec::new(),
        };
        Ok(tools)
    }

    /// Sends a `call_tool` request for tool `name` with the given arguments.
    ///
    /// Arguments are sent under the `args` key. Returns the raw `result` value,
    /// or `Value::Null` when the response carries no result.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying request.
    pub async fn call_tool(&mut self, name: &str, args: Value) -> Result<Value> {
        let params = json!({ "name": name, "args": args });
        self.send_request("call_tool", params).await
    }

    /// Performs one request/response round trip.
    ///
    /// Writes `RpcRequest { id, method, params }` as a single JSON line, then
    /// reads exactly one line and parses it as an `RpcResponse`. The whole
    /// exchange (write, flush and read) is bounded by the configured timeout.
    ///
    /// # Errors
    ///
    /// - the client was poisoned by an earlier transport failure;
    /// - the request cannot be serialized;
    /// - the exchange times out or an I/O error occurs;
    /// - the server closes the connection (zero bytes read);
    /// - the response line does not parse, carries a different id, or contains
    ///   an error object (whose `message` becomes the returned error).
    async fn send_request(&mut self, method: &str, params: Value) -> Result<Value> {
        if self.poisoned {
            return Err(anyhow!(
                "MCP connection is out of sync after an earlier failure; create a new client"
            ));
        }

        let id = self.next_id.to_string();
        self.next_id += 1;
        let req = RpcRequest {
            id: id.clone(),
            method: method.to_string(),
            params,
        };
        let mut text = serde_json::to_string(&req)?;
        text.push('\n');

        // `read_line` is not cancel-safe: if the timeout fires mid-read, bytes
        // already consumed are lost and the rest of the late reply stays in the
        // stream. A timed-out or failed write may likewise leave a partial
        // request on the wire. Either way the line stream no longer lines up with
        // our requests, so the client is poisoned rather than misreading later
        // replies as answers to later calls.
        let limit = self.timeout;
        let mut line = String::new();
        let exchange = async {
            self.writer.write_all(text.as_bytes()).await?;
            self.writer.flush().await?;
            self.reader.read_line(&mut line).await
        };
        let outcome = timeout(limit, exchange).await;
        let bytes_read = match outcome {
            Ok(Ok(n)) => n,
            Ok(Err(err)) => {
                self.poisoned = true;
                return Err(err.into());
            }
            Err(_) => {
                self.poisoned = true;
                return Err(anyhow!("MCP response timed out"));
            }
        };
        if bytes_read == 0 {
            self.poisoned = true;
            return Err(anyhow!("MCP server closed the connection"));
        }

        // `read_line` consumed a complete line, so the stream is still aligned
        // and a parse failure does not poison the client.
        let resp: RpcResponse = serde_json::from_str(line.trim())
            .map_err(|e| anyhow!("failed to parse MCP response: {e}"))?;
        if resp.id != id {
            // This reply belongs to some other request; subsequent replies would
            // be paired with the wrong calls as well.
            self.poisoned = true;
            return Err(anyhow!(
                "mismatched response id: expected {}, got {}",
                id,
                resp.id
            ));
        }
        if let Some(err) = resp.error {
            return Err(anyhow!(err.message));
        }
        Ok(resp.result.unwrap_or(Value::Null))
    }
}