use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use tokio::time::{Duration, timeout};
/// Upper bound, in seconds, that `StdioTransport::send_request` waits for a response.
const REQUEST_TIMEOUT_SECS: u64 = 30;
/// JSON-RPC transport to an MCP server running as a child process.
///
/// Messages are exchanged as newline-delimited JSON: requests and notifications are written
/// to the child's stdin, and a background task reads the child's stdout line by line and
/// hands each response to the caller waiting on its `id`.
// NOTE: field declaration order is also drop order (stdin is dropped before child).
pub struct StdioTransport {
    /// The child's stdin; the mutex ensures one whole message is written (and flushed) at a
    /// time so concurrent callers never interleave bytes.
    stdin: Arc<Mutex<tokio::process::ChildStdin>>,
    /// Requests awaiting a reply, keyed by JSON-RPC id. The stdout reader task removes the
    /// matching entry and sends the parsed response through its channel.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Value>>>>,
    /// Source of request ids; initialised to 1 in `spawn`.
    next_id: AtomicU64,
    /// The spawned server process (created with `kill_on_drop(true)` in `spawn`).
    child: Arc<Mutex<Child>>,
    /// Handle to the stdout reader task. Dropping a tokio `JoinHandle` detaches the task
    /// rather than cancelling it; the field just ties the handle's lifetime to the transport.
    _reader_task: tokio::task::JoinHandle<()>,
}
impl StdioTransport {
    /// Spawns the MCP server `command` with `args` and the extra environment variables in
    /// `env`, and wires its standard streams to a new transport.
    ///
    /// The child is created with `kill_on_drop(true)`. Two background tasks are started: one
    /// routes JSON-RPC responses from the child's stdout to waiting
    /// [`send_request`](Self::send_request) callers, the other forwards stderr lines to
    /// `tracing` at debug level.
    ///
    /// # Errors
    ///
    /// Fails if the process cannot be spawned or any of its standard streams cannot be
    /// captured.
    pub async fn spawn(
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
    ) -> Result<Self> {
        let mut cmd = Command::new(command);
        cmd.args(args)
            .envs(env)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .kill_on_drop(true);
        let mut child = cmd.spawn().with_context(|| {
            format!("Failed to spawn MCP server: {} {}", command, args.join(" "))
        })?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| anyhow!("Failed to capture MCP server stdin"))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| anyhow!("Failed to capture MCP server stdout"))?;
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| anyhow!("Failed to capture MCP server stderr"))?;

        let pending = Arc::new(Mutex::new(HashMap::new()));
        let reader_task = Self::spawn_stdout_router(stdout, Arc::clone(&pending));
        Self::spawn_stderr_logger(stderr);

        Ok(Self {
            stdin: Arc::new(Mutex::new(stdin)),
            pending,
            next_id: AtomicU64::new(1),
            child: Arc::new(Mutex::new(child)),
            _reader_task: reader_task,
        })
    }

    /// Starts the task that reads the server's stdout and delivers each response to the
    /// `pending` entry with the same id.
    ///
    /// Blank lines are skipped and unparseable lines are logged (truncated) and skipped. When
    /// stdout reaches EOF or a read error occurs, every still-pending request is failed at
    /// once by dropping its sender, instead of leaving it to run into the request timeout.
    fn spawn_stdout_router(
        stdout: tokio::process::ChildStdout,
        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Value>>>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut lines = BufReader::new(stdout).lines();
            loop {
                let line = match lines.next_line().await {
                    Ok(Some(line)) => line,
                    Ok(None) => break,
                    Err(e) => {
                        tracing::warn!("MCP: failed to read server stdout: {}", e);
                        break;
                    }
                };
                if line.trim().is_empty() {
                    continue;
                }
                let Ok(msg) = serde_json::from_str::<Value>(&line) else {
                    // Truncate on a char boundary so slicing cannot panic on multi-byte text.
                    let end = line.floor_char_boundary(200);
                    tracing::warn!("MCP: unparseable stdout line: {}", &line[..end]);
                    continue;
                };
                let Some(id) = Self::response_id(&msg) else {
                    continue;
                };
                if let Some(sender) = pending.lock().await.remove(&id) {
                    // The receiver may already be gone (caller timed out); that's fine.
                    let _ = sender.send(msg);
                }
            }
            // The server can no longer answer: dropping the senders wakes every waiter with a
            // channel-closed error right away. (A request registered after this point still
            // relies on its write failing or on the timeout.)
            pending.lock().await.clear();
        })
    }

    /// Forwards each line the server writes to stderr to `tracing` at debug level, until EOF
    /// or a read error.
    fn spawn_stderr_logger(stderr: tokio::process::ChildStderr) {
        tokio::spawn(async move {
            let mut lines = BufReader::new(stderr).lines();
            while let Ok(Some(line)) = lines.next_line().await {
                tracing::debug!("MCP stderr: {}", line);
            }
        });
    }

    /// Returns the id of `msg` if it is a JSON-RPC response that may match one of our
    /// requests.
    ///
    /// Messages carrying a `method` are server-initiated requests or notifications. Any id
    /// they carry belongs to the server's own id space, so it must not be matched against
    /// ours.
    fn response_id(msg: &Value) -> Option<u64> {
        if msg.get("method").is_some() {
            return None;
        }
        msg.get("id").and_then(parse_response_id)
    }

    /// Writes one already-serialized, newline-terminated message to the server's stdin and
    /// flushes it, holding the stdin lock for the whole write.
    async fn write_message(&self, msg: &str) -> std::io::Result<()> {
        let mut stdin = self.stdin.lock().await;
        stdin.write_all(msg.as_bytes()).await?;
        stdin.flush().await
    }

    /// Sends `msg` and waits, bounded by `REQUEST_TIMEOUT_SECS`, for the response on `rx`.
    async fn await_response(
        &self,
        method: &str,
        msg: &str,
        rx: oneshot::Receiver<Value>,
    ) -> Result<Value> {
        self.write_message(msg).await.with_context(|| {
            format!("Failed to write to MCP server stdin (method: {})", method)
        })?;
        timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS), rx)
            .await
            .map_err(|_| {
                anyhow!(
                    "MCP request timed out after {}s: {}",
                    REQUEST_TIMEOUT_SECS,
                    method
                )
            })?
            .map_err(|_| anyhow!("MCP response channel closed unexpectedly"))
    }

    /// Sends a JSON-RPC request and returns the `result` member of its response.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be written, no response arrives within
    /// `REQUEST_TIMEOUT_SECS`, the response channel is closed (e.g. the server's stdout
    /// ended), the response carries an `error` object, or it has no `result` member.
    pub async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params,
        });
        // Serialize before registering so a serialization error cannot leave a stale entry.
        let msg = format!("{}\n", serde_json::to_string(&request)?);

        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(id, tx);

        let outcome = self.await_response(method, &msg, rx).await;
        if outcome.is_err() {
            // Unregister so failed or timed-out requests don't accumulate in `pending`; a late
            // reply for this id is then discarded by the stdout router.
            self.pending.lock().await.remove(&id);
        }
        let response = outcome?;

        if let Some(error) = response.get("error") {
            let code = error.get("code").and_then(Value::as_i64).unwrap_or(-1);
            let message = error
                .get("message")
                .and_then(Value::as_str)
                .unwrap_or("Unknown error");
            return Err(anyhow!("MCP error (code {}): {}", code, message));
        }
        response
            .get("result")
            .cloned()
            .ok_or_else(|| anyhow!("MCP response missing 'result' field"))
    }

    /// Sends a JSON-RPC notification (a message without an `id`, so no response is awaited).
    ///
    /// # Errors
    ///
    /// Fails if the notification cannot be serialized or written to the server's stdin.
    pub async fn send_notification(&self, method: &str, params: Value) -> Result<()> {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        });
        let msg = format!("{}\n", serde_json::to_string(&notification)?);
        self.write_message(&msg).await.with_context(|| {
            format!("Failed to write to MCP server stdin (notification: {})", method)
        })
    }

    /// Stops the server, escalating until it has exited.
    ///
    /// First shuts down stdin and waits up to 2s for the process to exit. If it is still
    /// running, it calls `start_kill` and waits up to 1s more. As a last resort it calls
    /// `kill().await`.
    pub async fn shutdown(&self) {
        {
            // NOTE(review): confirm that `shutdown()` on tokio's `ChildStdin` actually closes the
            // pipe. If it is a no-op, the server never sees EOF here and the 2s graceful wait
            // always elapses; dropping the handle would then be required.
            let mut stdin = self.stdin.lock().await;
            let _ = stdin.shutdown().await;
        }
        let mut child = self.child.lock().await;
        if timeout(Duration::from_secs(2), child.wait()).await.is_ok() {
            return;
        }
        if let Err(e) = child.start_kill() {
            tracing::debug!("MCP: start_kill failed: {}", e);
        }
        if timeout(Duration::from_secs(1), child.wait()).await.is_ok() {
            return;
        }
        // `kill()` already waits for the process to exit. An extra unbounded `wait()` here
        // could hang forever if the kill itself failed.
        let _ = child.kill().await;
    }
}
fn parse_response_id(v: &Value) -> Option<u64> {
v.as_u64()
.or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
}
#[cfg(test)]
mod tests {
    use super::parse_response_id;
    use serde_json::json;

    #[test]
    fn truncation_respects_char_boundary() {
        // 199 ASCII bytes followed by a 3-byte CJK char: byte 200 falls inside '你', so the
        // truncation point must back off to 199 rather than split the character.
        let line = format!("{}你好", "a".repeat(199));
        assert!(!line.is_char_boundary(200));
        let end = line.floor_char_boundary(200);
        assert_eq!(end, 199);
        assert!(line.is_char_boundary(end));
        assert_eq!(&line[..end], "a".repeat(199));
    }

    #[test]
    fn parse_response_id_accepts_integer() {
        assert_eq!(parse_response_id(&json!(5)), Some(5));
        assert_eq!(parse_response_id(&json!(0)), Some(0));
        assert_eq!(parse_response_id(&json!(u64::MAX)), Some(u64::MAX));
    }

    #[test]
    fn parse_response_id_accepts_string_integer() {
        assert_eq!(parse_response_id(&json!("5")), Some(5));
        assert_eq!(parse_response_id(&json!("0")), Some(0));
    }

    #[test]
    fn parse_response_id_rejects_non_numeric() {
        assert_eq!(parse_response_id(&json!("abc")), None);
        assert_eq!(parse_response_id(&json!("")), None);
        assert_eq!(parse_response_id(&json!(null)), None);
        assert_eq!(parse_response_id(&json!({})), None);
        assert_eq!(parse_response_id(&json!(-1)), None);
        assert_eq!(parse_response_id(&json!(5.5)), None);
    }
}