use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde_json::Value as JsonValue;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{Mutex, oneshot};
use super::protocol::*;
// Per-request deadlines handed to `call`; exceeding one yields `McpClientError::Timeout`.
/// Maximum wait for the reply to the `initialize` request.
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum wait for each `tools/list` page requested during tool discovery.
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum wait for the reply to a `tools/call` request.
const CALL_TOOL_TIMEOUT: Duration = Duration::from_secs(120);
/// Errors produced by [`McpStdioClient`].
#[derive(Debug, thiserror::Error)]
pub enum McpClientError {
/// Spawning the server command failed. Because of `#[from]`, any
/// `std::io::Error` propagated with `?` is converted into this variant.
#[error("failed to spawn MCP server process: {0}")]
SpawnFailed(#[from] std::io::Error),
/// The spawned child exposed no stdin handle.
#[error("MCP server process has no stdin")]
NoStdin,
/// The spawned child exposed no stdout handle.
#[error("MCP server process has no stdout")]
NoStdout,
/// The server answered a request with a JSON-RPC `error` object.
#[error("MCP server returned JSON-RPC error: {0}")]
JsonRpc(JsonRpcError),
/// No response arrived within the given duration.
#[error("timeout waiting for MCP response ({0:?})")]
Timeout(Duration),
/// Writing to the server's stdin failed, or the stdout reader stopped
/// (EOF or read error) while a request was still waiting for its reply.
#[error("MCP server process exited unexpectedly")]
ProcessExited,
/// The response payload could not be deserialized into the expected type.
#[error("failed to parse MCP response: {0}")]
ParseError(String),
/// NOTE(review): not constructed anywhere in the code visible here —
/// presumably reserved for calls made after shutdown; confirm usage.
#[error("MCP client is already shut down")]
Shutdown,
}
/// In-flight requests: maps a JSON-RPC request id to the one-shot sender that
/// delivers the outcome (result value or error) to the task awaiting it.
type PendingMap = HashMap<u64, oneshot::Sender<Result<JsonValue, McpClientError>>>;
/// Client for an MCP server running as a child process, exchanging
/// newline-delimited JSON-RPC messages over the child's stdin/stdout.
pub struct McpStdioClient {
/// Child's stdin; the mutex serializes writers so each message is written
/// and flushed as a unit.
stdin: Mutex<ChildStdin>,
/// Handle of the spawned server process. It is stored but never read by
/// the methods in this file.
_child: Mutex<Child>,
/// Source of JSON-RPC request ids; starts at 1 and is incremented per request.
next_id: AtomicU64,
/// Requests awaiting a response, shared with the background stdout reader task.
pending: Arc<Mutex<PendingMap>>,
/// Server metadata from the `initialize` response; `None` until the handshake succeeds.
server_info: Option<McpServerInfo>,
/// Tool definitions collected from `tools/list` during connection.
tools: Vec<McpToolDefinition>,
/// Copy of the `name` field of the configuration used to connect.
server_name: String,
}
impl McpStdioClient {
/// Spawns the MCP server described by `config`, performs the `initialize`
/// handshake and loads the server's tool list.
///
/// The child is started with the configured arguments, environment
/// variables and (optional) working directory. Its stdin and stdout are
/// piped and its stderr is discarded. A background task reads stdout line by
/// line and routes every parsed response with a numeric id to the request
/// waiting on that id.
///
/// The child is spawned with `kill_on_drop(true)`, so it is terminated when
/// the client (or a partially constructed client after a failed handshake)
/// is dropped instead of being left running as an orphan. Call
/// [`shutdown`](Self::shutdown) first to give the server a chance to exit on
/// its own.
///
/// # Errors
///
/// * [`McpClientError::SpawnFailed`] if the process cannot be started.
/// * [`McpClientError::NoStdin`] / [`McpClientError::NoStdout`] if a pipe
///   handle is unavailable.
/// * Any error returned by the `initialize` or `tools/list` requests.
pub async fn connect(config: &McpServerConfig) -> Result<Self, McpClientError> {
    let mut cmd = Command::new(&config.command);
    cmd.args(&config.args)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::null())
        // tokio does not kill a child when its handle is dropped unless asked
        // to; without this a failed handshake would leak the server process.
        .kill_on_drop(true);
    for (key, value) in &config.env {
        cmd.env(key, value);
    }
    if let Some(cwd) = &config.cwd {
        cmd.current_dir(cwd);
    }

    let mut child = cmd.spawn()?;
    let stdin = child.stdin.take().ok_or(McpClientError::NoStdin)?;
    let stdout = child.stdout.take().ok_or(McpClientError::NoStdout)?;

    let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
    let pending_clone = pending.clone();

    // Reader task: one JSON document per line on the child's stdout.
    tokio::spawn(async move {
        let mut reader = BufReader::new(stdout);
        let mut line = String::new();
        loop {
            line.clear();
            match reader.read_line(&mut line).await {
                // EOF or a read error ends the task.
                Ok(0) | Err(_) => break,
                Ok(_) => {}
            }
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            // Lines that do not parse as a response are silently skipped.
            if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
                // Only numeric ids can match a request issued by this client.
                let id = match &resp.id {
                    JsonRpcId::Number(n) => *n,
                    _ => continue,
                };
                let mut pending = pending_clone.lock().await;
                if let Some(tx) = pending.remove(&id) {
                    let result = if let Some(err) = resp.error {
                        Err(McpClientError::JsonRpc(err))
                    } else {
                        Ok(resp.result.unwrap_or(JsonValue::Null))
                    };
                    // The waiter may already be gone; that is not an error.
                    let _ = tx.send(result);
                }
            }
        }
        // The stream is closed: fail every request still waiting for a reply.
        let mut pending = pending_clone.lock().await;
        for (_, tx) in pending.drain() {
            let _ = tx.send(Err(McpClientError::ProcessExited));
        }
    });

    let mut client = Self {
        stdin: Mutex::new(stdin),
        _child: Mutex::new(child),
        next_id: AtomicU64::new(1),
        pending,
        server_info: None,
        tools: Vec::new(),
        server_name: config.name.clone(),
    };
    client.initialize().await?;
    client.discover_tools().await?;
    Ok(client)
}
/// Returns the server name stored on this client when it was created.
pub fn server_name(&self) -> &str {
    self.server_name.as_str()
}
/// Returns the server metadata received from the `initialize` response, or
/// `None` if it has not been recorded.
pub fn server_info(&self) -> Option<&McpServerInfo> {
    Option::as_ref(&self.server_info)
}
/// Returns the tool definitions gathered by tool discovery.
pub fn tools(&self) -> &[McpToolDefinition] {
    self.tools.as_slice()
}
/// Sends a JSON-RPC request and waits for the matching response.
///
/// A fresh id is allocated, the request is written to the server's stdin as
/// a single line, and the call then waits up to `timeout_duration` for the
/// stdout reader task to deliver the response with that id. Parameters that
/// serialize to JSON `null` are omitted from the request.
///
/// The pending-map entry registered for the request is removed again when
/// the write fails or the wait times out, so failed calls do not accumulate
/// stale senders in the map. NOTE(review): if the returned future is dropped
/// before completing, the entry stays until the reader task drains the map.
///
/// # Errors
///
/// * [`McpClientError::ProcessExited`] if writing to stdin fails or the
///   response channel is closed without a reply.
/// * [`McpClientError::Timeout`] if no response arrives in time.
/// * [`McpClientError::JsonRpc`] if the server replies with an error object.
/// * [`McpClientError::ParseError`] if the result cannot be deserialized
///   into `R`.
///
/// # Panics
///
/// Panics if `params` or the request envelope cannot be serialized to JSON.
async fn call<P: serde::Serialize, R: DeserializeOwned>(
    &self,
    method: &str,
    params: Option<P>,
    timeout_duration: Duration,
) -> Result<R, McpClientError> {
    let id = self.next_id.fetch_add(1, Ordering::Relaxed);
    let params_value = params
        .map(|p| serde_json::to_value(p).expect("serialize params"))
        .filter(|v| !v.is_null());
    let request = JsonRpcRequest::new(id, method, params_value);
    let mut line = serde_json::to_string(&request).expect("serialize request");
    line.push('\n');

    // Register the waiter before writing so a fast reply cannot be missed.
    let (tx, rx) = oneshot::channel();
    self.pending.lock().await.insert(id, tx);

    let write_result = {
        let mut stdin = self.stdin.lock().await;
        match stdin.write_all(line.as_bytes()).await {
            Ok(()) => stdin.flush().await,
            Err(err) => Err(err),
        }
    };
    if write_result.is_err() {
        // Nothing was sent, so no response can ever claim this entry.
        self.pending.lock().await.remove(&id);
        return Err(McpClientError::ProcessExited);
    }

    let result = match tokio::time::timeout(timeout_duration, rx).await {
        Ok(Ok(outcome)) => outcome?,
        // Sender dropped without a reply; its map entry is already gone.
        Ok(Err(_)) => return Err(McpClientError::ProcessExited),
        Err(_) => {
            // Forget the request so a late reply is discarded by the reader
            // instead of the sender lingering in the map.
            self.pending.lock().await.remove(&id);
            return Err(McpClientError::Timeout(timeout_duration));
        }
    };

    serde_json::from_value(result)
        .map_err(|e| McpClientError::ParseError(format!("deserialize response: {e}")))
}
async fn notify<P: serde::Serialize>(
&self,
method: &str,
params: Option<P>,
) -> Result<(), McpClientError> {
let mut obj = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
});
if let Some(p) = params {
obj["params"] = serde_json::to_value(p).expect("serialize params");
}
let mut line = serde_json::to_string(&obj).expect("serialize notification");
line.push('\n');
let mut stdin = self.stdin.lock().await;
stdin
.write_all(line.as_bytes())
.await
.map_err(|_| McpClientError::ProcessExited)?;
stdin
.flush()
.await
.map_err(|_| McpClientError::ProcessExited)?;
Ok(())
}
/// Performs the MCP handshake with the server.
///
/// Sends an `initialize` request advertising protocol version `2024-11-05`,
/// empty capabilities and this crate's name/version, stores the returned
/// server info, then sends the `notifications/initialized` notification.
///
/// # Errors
///
/// Propagates any error from the request or from the notification.
async fn initialize(&mut self) -> Result<(), McpClientError> {
    let client_info = McpClientInfo {
        name: String::from("mentra"),
        version: String::from(env!("CARGO_PKG_VERSION")),
    };
    let handshake = McpInitializeParams {
        protocol_version: String::from("2024-11-05"),
        capabilities: serde_json::json!({}),
        client_info,
    };

    let info = self
        .call::<_, McpInitializeResult>("initialize", Some(handshake), INITIALIZE_TIMEOUT)
        .await?
        .server_info;
    self.server_info = Some(info);

    self.notify("notifications/initialized", None::<JsonValue>)
        .await
}
/// Fetches every page of `tools/list` and stores the combined tool list.
///
/// The first request carries no cursor. Each later request carries the
/// `next_cursor` from the previous page, and paging stops once that cursor
/// is absent or empty. `self.tools` is replaced only after all pages were
/// fetched successfully.
///
/// # Errors
///
/// Propagates the first error returned by a `tools/list` request.
async fn discover_tools(&mut self) -> Result<(), McpClientError> {
    let mut collected = Vec::new();
    let mut next_page: Option<String> = None;
    loop {
        let request = McpListToolsParams {
            cursor: next_page.take(),
        };
        let page: McpListToolsResult = self
            .call("tools/list", Some(request), LIST_TOOLS_TIMEOUT)
            .await?;
        collected.extend(page.tools);

        // An empty cursor is treated the same as no cursor at all.
        next_page = page.next_cursor.filter(|token| !token.is_empty());
        if next_page.is_none() {
            break;
        }
    }
    self.tools = collected;
    Ok(())
}
/// Invokes the tool `tool_name` on the server via a `tools/call` request.
///
/// `arguments` is forwarded as given. The call waits up to
/// `CALL_TOOL_TIMEOUT` for the response.
///
/// # Errors
///
/// Propagates any error from the underlying request.
pub async fn call_tool(
    &self,
    tool_name: &str,
    arguments: Option<JsonValue>,
) -> Result<McpToolCallResult, McpClientError> {
    let request = McpToolCallParams {
        arguments,
        name: tool_name.to_owned(),
    };
    let outcome: McpToolCallResult = self
        .call("tools/call", Some(request), CALL_TOOL_TIMEOUT)
        .await?;
    Ok(outcome)
}
/// Shuts down the write side of the server's stdin.
///
/// Any error from the shutdown is ignored. This method does not wait for
/// the child process to exit.
pub async fn shutdown(&self) {
    let _ = self.stdin.lock().await.shutdown().await;
}
}
/// Destructor hook for the client.
///
/// The body is intentionally empty: no explicit teardown happens here, and
/// the fields (stdin handle, child handle, pending map) are released by
/// their own destructors after this returns.
impl Drop for McpStdioClient {
fn drop(&mut self) {
// No explicit cleanup; see the type-level comment on this impl.
}
}