use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use super::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
#[async_trait]
pub trait McpTransport: Send + Sync {
async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
async fn send_notification(&self, request: JsonRpcRequest) -> Result<()>;
async fn close(&self) -> Result<()>;
}
const DEFAULT_RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
pub struct StdioTransport {
next_id: AtomicU64,
pending: Mutex<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>,
writer: Mutex<tokio::io::BufWriter<tokio::process::ChildStdin>>,
_child: Arc<Mutex<Child>>,
response_timeout: std::time::Duration,
}
impl StdioTransport {
pub fn spawn(command: &str, args: &[&str]) -> Result<Arc<Self>> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.with_context(|| format!("Failed to spawn MCP server: {command}"))?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let transport = Arc::new(Self {
next_id: AtomicU64::new(1),
pending: Mutex::new(HashMap::new()),
writer: Mutex::new(tokio::io::BufWriter::new(stdin)),
_child: Arc::new(Mutex::new(child)),
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break, Ok(_) => {
const MAX_LINE_LEN: usize = 10 * 1024 * 1024; if line.len() > MAX_LINE_LEN {
log::warn!(
"MCP stdout line exceeds {} bytes (got {}), skipping",
MAX_LINE_LEN,
line.len()
);
continue;
}
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let mut pending = transport_clone.pending.lock().await;
if let Some(sender) = pending.remove(&response.id) {
let _ = sender.send(response);
}
}
}
}
}
});
Ok(transport)
}
pub fn spawn_with_env(command: &str, args: &[&str], env: &[(&str, &str)]) -> Result<Arc<Self>> {
let mut cmd = Command::new(command);
cmd.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true);
for (key, value) in env {
cmd.env(key, value);
}
let mut child = cmd
.spawn()
.with_context(|| format!("Failed to spawn MCP server: {command}"))?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let transport = Arc::new(Self {
next_id: AtomicU64::new(1),
pending: Mutex::new(HashMap::new()),
writer: Mutex::new(tokio::io::BufWriter::new(stdin)),
_child: Arc::new(Mutex::new(child)),
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
});
let transport_clone = Arc::clone(&transport);
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break, Ok(_) => {
const MAX_LINE_LEN: usize = 10 * 1024 * 1024; if line.len() > MAX_LINE_LEN {
log::warn!(
"MCP stdout line exceeds {} bytes (got {}), skipping",
MAX_LINE_LEN,
line.len()
);
continue;
}
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let mut pending = transport_clone.pending.lock().await;
if let Some(sender) = pending.remove(&response.id) {
let _ = sender.send(response);
}
}
}
}
}
});
Ok(transport)
}
fn next_request_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::SeqCst)
}
}
#[async_trait]
impl McpTransport for StdioTransport {
async fn send(&self, mut request: JsonRpcRequest) -> Result<JsonRpcResponse> {
let id = self.next_request_id();
request.id = RequestId::Number(id);
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(request.id.clone(), tx);
}
let json = serde_json::to_string(&request)?;
let mut writer = self.writer.lock().await;
writer.write_all(json.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.flush().await?;
drop(writer);
let response = tokio::time::timeout(self.response_timeout, rx)
.await
.context("MCP response timed out")?
.context("Response channel closed")?;
if let Some(ref error) = response.error {
bail!("JSON-RPC error {}: {}", error.code, error.message);
}
Ok(response)
}
async fn send_notification(&self, mut request: JsonRpcRequest) -> Result<()> {
let id = self.next_request_id();
request.id = RequestId::Number(id);
let json = serde_json::to_string(&request)?;
let mut writer = self.writer.lock().await;
writer.write_all(json.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.flush().await?;
drop(writer);
Ok(())
}
async fn close(&self) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_request_id_generation() {
let next_id = AtomicU64::new(1);
assert_eq!(next_id.fetch_add(1, Ordering::SeqCst), 1);
assert_eq!(next_id.fetch_add(1, Ordering::SeqCst), 2);
assert_eq!(next_id.fetch_add(1, Ordering::SeqCst), 3);
}
}