use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use super::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
/// Transport abstraction for exchanging JSON-RPC messages with an MCP server.
///
/// Implementors must be `Send + Sync`, and every method takes `&self`, so one
/// transport value can be shared between tasks.
#[async_trait]
pub trait McpTransport: Send + Sync {
/// Sends `request` and resolves to the response produced for it.
///
/// # Errors
/// Returns an error when the exchange fails; the exact conditions are
/// defined by each implementation.
async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
/// Sends `request` without handing any response back to the caller.
///
/// # Errors
/// Returns an error when the message cannot be delivered.
async fn send_notification(&self, request: JsonRpcRequest) -> Result<()>;
/// Hands a protocol version string to the transport.
///
/// The default implementation ignores the value and does nothing.
// NOTE(review): presumably invoked once the protocol version has been
// negotiated, for transports that need to echo it back — confirm with callers.
async fn set_protocol_version(&self, _version: &str) {}
/// Closes the transport. What "closed" means is implementation-defined.
async fn close(&self) -> Result<()>;
}
/// Default time a request waits for its response before failing: one minute.
const DEFAULT_RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_mins(1);
/// In-flight requests, keyed by request id. Each value is the one-shot sender
/// through which the matching response is delivered to the waiting caller.
///
/// The map sits behind a `std::sync::Mutex` (not the Tokio mutex); in this file
/// the lock is only ever taken for a single insert/remove and is never held
/// across an `.await`.
type PendingMap = std::sync::Mutex<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>;
/// Scope guard that removes one entry from a [`PendingMap`] when it is dropped,
/// so a request that ends early (error or timeout) does not leave its entry behind.
struct PendingGuard<'a> {
/// Map the entry is removed from.
pending: &'a PendingMap,
/// Key of the entry to remove on drop.
request_id: RequestId,
}
impl Drop for PendingGuard<'_> {
    /// Removes this guard's request id from the pending map.
    ///
    /// A poisoned lock is recovered rather than propagated: the removal must
    /// still happen, and panicking inside `drop` would be worse than working
    /// with whatever state the map was left in.
    fn drop(&mut self) {
        let mut in_flight = match self.pending.lock() {
            Ok(map) => map,
            Err(poisoned) => poisoned.into_inner(),
        };
        // Removing an id that is already gone (the response arrived) is a no-op.
        in_flight.remove(&self.request_id);
    }
}
/// [`McpTransport`] that talks to a child process over its stdin/stdout,
/// one JSON message per line.
pub struct StdioTransport {
/// Source of numeric request ids; see `next_request_id`.
next_id: AtomicU64,
/// Requests that have been registered and are still waiting for a response.
pending: PendingMap,
/// Buffered handle to the child's stdin. The Tokio mutex is held while a
/// message is written and flushed, so concurrent senders cannot interleave
/// their output.
writer: Mutex<tokio::io::BufWriter<tokio::process::ChildStdin>>,
/// Handle to the child process. It is never locked in this file; it is stored
/// so the process stays owned by the transport (the child is spawned with
/// `kill_on_drop(true)`).
_child: Arc<Mutex<Child>>,
/// How long `send` waits for a response before giving up.
response_timeout: std::time::Duration,
}
impl StdioTransport {
pub fn spawn(command: &str, args: &[&str]) -> Result<Arc<Self>> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.with_context(|| format!("Failed to spawn MCP server: {command}"))?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let transport = Arc::new(Self {
next_id: AtomicU64::new(1),
pending: std::sync::Mutex::new(HashMap::new()),
writer: Mutex::new(tokio::io::BufWriter::new(stdin)),
_child: Arc::new(Mutex::new(child)),
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break, Ok(_) => {
const MAX_LINE_LEN: usize = 10 * 1024 * 1024; if line.len() > MAX_LINE_LEN {
log::warn!(
"MCP stdout line exceeds {} bytes (got {}), skipping",
MAX_LINE_LEN,
line.len()
);
continue;
}
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let sender = {
let mut pending = match transport_clone.pending.lock() {
Ok(pending) => pending,
Err(poisoned) => poisoned.into_inner(),
};
pending.remove(&response.id)
};
if let Some(sender) = sender {
let _ = sender.send(response);
}
}
}
}
}
});
Ok(transport)
}
pub fn spawn_with_env(command: &str, args: &[&str], env: &[(&str, &str)]) -> Result<Arc<Self>> {
let mut cmd = Command::new(command);
cmd.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true);
for (key, value) in env {
cmd.env(key, value);
}
let mut child = cmd
.spawn()
.with_context(|| format!("Failed to spawn MCP server: {command}"))?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let transport = Arc::new(Self {
next_id: AtomicU64::new(1),
pending: std::sync::Mutex::new(HashMap::new()),
writer: Mutex::new(tokio::io::BufWriter::new(stdin)),
_child: Arc::new(Mutex::new(child)),
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break, Ok(_) => {
const MAX_LINE_LEN: usize = 10 * 1024 * 1024; if line.len() > MAX_LINE_LEN {
log::warn!(
"MCP stdout line exceeds {} bytes (got {}), skipping",
MAX_LINE_LEN,
line.len()
);
continue;
}
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let sender = {
let mut pending = match transport_clone.pending.lock() {
Ok(pending) => pending,
Err(poisoned) => poisoned.into_inner(),
};
pending.remove(&response.id)
};
if let Some(sender) = sender {
let _ = sender.send(response);
}
}
}
}
}
});
Ok(transport)
}
fn next_request_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::SeqCst)
}
#[cfg(test)]
fn pending_len(&self) -> usize {
self.pending
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len()
}
#[cfg(test)]
fn spawn_with_timeout(
command: &str,
args: &[&str],
response_timeout: std::time::Duration,
) -> Result<Arc<Self>> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.with_context(|| format!("Failed to spawn MCP server: {command}"))?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let transport = Arc::new(Self {
next_id: AtomicU64::new(1),
pending: std::sync::Mutex::new(HashMap::new()),
writer: Mutex::new(tokio::io::BufWriter::new(stdin)),
_child: Arc::new(Mutex::new(child)),
response_timeout,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break,
Ok(_) => {
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let sender = {
let mut pending = match transport_clone.pending.lock() {
Ok(pending) => pending,
Err(poisoned) => poisoned.into_inner(),
};
pending.remove(&response.id)
};
if let Some(sender) = sender {
let _ = sender.send(response);
}
}
}
}
}
});
Ok(transport)
}
}
#[async_trait]
impl McpTransport for StdioTransport {
    /// Sends `request` as one JSON line on the child's stdin and waits for the
    /// response carrying the same id.
    ///
    /// The id supplied by the caller is overwritten with a fresh one.
    ///
    /// # Errors
    /// Fails if serialization or writing fails, if no response arrives within
    /// the configured timeout, if the response channel is closed, or if the
    /// response carries a JSON-RPC error object.
    async fn send(&self, mut request: JsonRpcRequest) -> Result<JsonRpcResponse> {
        request.id = RequestId::Number(self.next_request_id());

        // Register the waiter before anything is written so the response cannot
        // arrive ahead of the registration.
        let (reply_tx, reply_rx) = oneshot::channel();
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(request.id.clone(), reply_tx);
        // From here on, every exit path removes the entry again.
        let _cleanup = PendingGuard {
            pending: &self.pending,
            request_id: request.id.clone(),
        };

        let payload = serde_json::to_string(&request)?;
        {
            // Hold the writer for the whole line so messages never interleave.
            let mut stdin = self.writer.lock().await;
            stdin.write_all(payload.as_bytes()).await?;
            stdin.write_all(b"\n").await?;
            stdin.flush().await?;
        }

        let waited = tokio::time::timeout(self.response_timeout, reply_rx).await;
        let delivered = waited.context("MCP response timed out")?;
        let response = delivered.context("Response channel closed")?;

        if let Some(error) = &response.error {
            bail!("JSON-RPC error {}: {}", error.code, error.message);
        }
        Ok(response)
    }

    /// Writes `request` as one JSON line on the child's stdin and returns as
    /// soon as it has been flushed; no response is awaited.
    ///
    /// # Errors
    /// Fails if serialization or writing fails.
    // NOTE(review): JSON-RPC 2.0 defines a notification as a request without an
    // `id`, yet a fresh id is assigned here. Whether it ends up on the wire
    // depends on how `JsonRpcRequest` serializes, which is not visible in this
    // file — confirm.
    async fn send_notification(&self, mut request: JsonRpcRequest) -> Result<()> {
        request.id = RequestId::Number(self.next_request_id());
        let payload = serde_json::to_string(&request)?;

        let mut stdin = self.writer.lock().await;
        stdin.write_all(payload.as_bytes()).await?;
        stdin.write_all(b"\n").await?;
        stdin.flush().await?;
        Ok(())
    }

    /// Currently a no-op that always succeeds; the child process is not
    /// signalled from here.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::ensure;

    /// The id counter pattern hands out 1, 2, 3, ... in order.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        for expected in 1..=3_u64 {
            assert_eq!(counter.fetch_add(1, Ordering::SeqCst), expected);
        }
    }

    /// Requests that time out must not leave their entry in the pending map.
    ///
    /// The child swallows its stdin and never writes a reply, so every request
    /// ends in an error.
    #[tokio::test]
    async fn timed_out_requests_do_not_leak_pending_entries() -> Result<()> {
        const N: usize = 8;
        let short_timeout = std::time::Duration::from_millis(50);
        let transport =
            StdioTransport::spawn_with_timeout("sh", &["-c", "cat > /dev/null"], short_timeout)?;

        let initially_pending = transport.pending_len();
        ensure!(initially_pending == 0, "pending map should start empty");

        for _attempt in 0..N {
            let outcome = transport
                .send(JsonRpcRequest::new("tools/list", None, 0))
                .await;
            ensure!(
                outcome.is_err(),
                "request should time out when the server never replies"
            );
        }

        let stale = transport.pending_len();
        ensure!(
            stale == 0,
            "pending map must be empty after {N} timeouts, found {} stale entries",
            stale,
        );
        Ok(())
    }
}