use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use tokio::time::{Duration, timeout};
/// Seconds `send_request` waits for a matching response before failing with a timeout error.
const REQUEST_TIMEOUT_SECS: u64 = 30;
/// JSON-RPC 2.0 transport to an MCP server running as a child process.
///
/// Outgoing messages are written to the child's stdin, one JSON document per line.
/// A background task reads the child's stdout and hands each response to the caller
/// waiting on the matching request id.
pub struct StdioTransport {
    // The child's stdin. Writers hold this lock for a whole write + flush, so
    // concurrent messages are not interleaved on the pipe.
    stdin: Arc<Mutex<tokio::process::ChildStdin>>,
    // In-flight requests keyed by JSON-RPC id. The stdout reader task removes the
    // entry and sends the matching response message through it.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Value>>>>,
    // Next request id to hand out (initialised to 1 in `spawn`).
    next_id: AtomicU64,
    // Handle to the server process; `shutdown` kills it through this.
    child: Arc<Mutex<Child>>,
    // Handle of the stdout reader task, stored alongside the transport; not awaited in
    // the code shown.
    _reader_task: tokio::task::JoinHandle<()>,
}
impl StdioTransport {
pub async fn spawn(
command: &str,
args: &[String],
env: &HashMap<String, String>,
) -> Result<Self> {
let mut cmd = Command::new(command);
cmd.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
for (key, value) in env {
cmd.env(key, value);
}
let mut child = cmd.spawn().with_context(|| {
format!(
"Failed to spawn MCP server: {} {}",
command,
args.join(" ")
)
})?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow!("Failed to capture MCP server stdin"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow!("Failed to capture MCP server stdout"))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| anyhow!("Failed to capture MCP server stderr"))?;
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Value>>>> =
Arc::new(Mutex::new(HashMap::new()));
let pending_clone = Arc::clone(&pending);
let reader_task = tokio::spawn(async move {
let mut reader = BufReader::new(stdout).lines();
while let Ok(Some(line)) = reader.next_line().await {
if line.trim().is_empty() {
continue;
}
let Ok(msg) = serde_json::from_str::<Value>(&line) else {
tracing::warn!("MCP: unparseable stdout line: {}", &line[..line.len().min(200)]);
continue;
};
if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
let mut pending = pending_clone.lock().await;
if let Some(sender) = pending.remove(&id) {
let _ = sender.send(msg);
}
}
}
});
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
while let Ok(Some(line)) = reader.next_line().await {
tracing::debug!("MCP stderr: {}", line);
}
});
Ok(Self {
stdin: Arc::new(Mutex::new(stdin)),
pending,
next_id: AtomicU64::new(1),
child: Arc::new(Mutex::new(child)),
_reader_task: reader_task,
})
}
pub async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let request = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
});
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(id, tx);
}
let msg = format!("{}\n", serde_json::to_string(&request)?);
{
let mut stdin = self.stdin.lock().await;
stdin.write_all(msg.as_bytes()).await.with_context(|| {
format!("Failed to write to MCP server stdin (method: {})", method)
})?;
stdin.flush().await?;
}
let response = timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS), rx)
.await
.map_err(|_| anyhow!("MCP request timed out after {}s: {}", REQUEST_TIMEOUT_SECS, method))?
.map_err(|_| anyhow!("MCP response channel closed unexpectedly"))?;
if let Some(error) = response.get("error") {
let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
return Err(anyhow!("MCP error (code {}): {}", code, message));
}
response
.get("result")
.cloned()
.ok_or_else(|| anyhow!("MCP response missing 'result' field"))
}
pub async fn send_notification(&self, method: &str, params: Value) -> Result<()> {
let notification = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
});
let msg = format!("{}\n", serde_json::to_string(¬ification)?);
let mut stdin = self.stdin.lock().await;
stdin.write_all(msg.as_bytes()).await?;
stdin.flush().await?;
Ok(())
}
pub async fn shutdown(&self) {
let mut child = self.child.lock().await;
let _ = child.kill().await;
}
}