use std::collections::BTreeMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::task::JoinHandle;
use tokio::time::timeout;
use super::transport::{JsonRpcMessage, McpError, Transport};
/// Grace period used when shutting the child down: how long `close` waits for
/// the child to exit on its own (after its stdin is closed) before force-killing it.
const CLOSE_GRACE: Duration = Duration::from_millis(5_000);
/// A [`Transport`] that exchanges newline-delimited JSON-RPC messages with a
/// child process over its stdin/stdout pipes.
///
/// Construct one with [`StdioTransport::builder`]. The child's stderr is read
/// by a background task and echoed to this process's stderr.
///
/// The process handle and pipes are held in `Option`s so they can be taken out
/// individually as the transport shuts down.
pub struct StdioTransport {
// The spawned server process. It is spawned with `kill_on_drop(true)`, so
// dropping the handle without a successful `close` kills the process.
child: Option<Child>,
// Write end of the child's stdin; `close` drops it first.
stdin: Option<ChildStdin>,
// Buffered read end of the child's stdout; `recv` reads one line per message.
stdout: Option<BufReader<ChildStdout>>,
// Background task forwarding the child's stderr lines.
stderr_task: Option<JoinHandle<()>>,
// Line buffer reused by `recv` across calls (cleared before each read).
line_buf: String,
}
impl StdioTransport {
pub fn builder() -> StdioTransportBuilder {
StdioTransportBuilder::default()
}
pub fn pid(&self) -> Option<u32> {
self.child.as_ref().and_then(|c| c.id())
}
}
/// Builder for [`StdioTransport`]: describes the server's command line and
/// environment, then [`spawn`](StdioTransportBuilder::spawn)s it.
///
/// Obtain one via [`StdioTransport::builder`] or `Default::default()`; by
/// default the child inherits this process's environment.
#[derive(Debug, Clone)]
pub struct StdioTransportBuilder {
// Program to execute; `spawn` fails if this was never set.
command: Option<String>,
// Arguments passed to the program, in order.
args: Vec<String>,
// Extra environment variables for the child, applied after the optional env clear.
env: BTreeMap<String, String>,
// Working directory for the child; `None` leaves it unchanged.
cwd: Option<PathBuf>,
// When false, the child's environment is cleared before `env` is applied.
inherit_env: bool,
}
impl StdioTransportBuilder {
    /// Sets the program to run. Required: [`spawn`](Self::spawn) fails without it.
    pub fn command(mut self, command: impl Into<String>) -> Self {
        self.command = Some(command.into());
        self
    }

    /// Appends a single argument to the command line.
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Appends several arguments to the command line, in order.
    ///
    /// Like [`std::process::Command::args`], this extends any arguments added
    /// earlier via [`arg`](Self::arg) or `args` instead of discarding them.
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Sets an environment variable for the child, replacing any value
    /// previously set for the same key on this builder.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Sets the child's working directory.
    pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
        self.cwd = Some(cwd.into());
        self
    }

    /// Controls whether the child inherits this process's environment (the
    /// default). When `false`, the child sees only variables set via [`env`](Self::env).
    pub fn inherit_env(mut self, inherit: bool) -> Self {
        self.inherit_env = inherit;
        self
    }

    /// Spawns the configured command and wires its stdio into a [`StdioTransport`].
    ///
    /// The child's stdin/stdout carry the JSON-RPC stream. Its stderr is
    /// forwarded line by line to this process's stderr with an `[mcp-stdio]`
    /// prefix by a background task. The child is spawned with
    /// `kill_on_drop(true)`.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::Other`] in three cases:
    /// - no command was set;
    /// - the OS fails to spawn the process;
    /// - a stdio pipe is unexpectedly missing.
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime. The stderr forwarder is
    /// started with [`tokio::spawn`], which panics outside one.
    pub fn spawn(mut self) -> Result<StdioTransport, McpError> {
        let command = self
            .command
            .take()
            .ok_or_else(|| McpError::Other("StdioTransport: command not set".into()))?;
        let mut cmd = Command::new(&command);
        cmd.args(&self.args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .kill_on_drop(true);
        if !self.inherit_env {
            cmd.env_clear();
        }
        for (k, v) in &self.env {
            cmd.env(k, v);
        }
        if let Some(cwd) = &self.cwd {
            cmd.current_dir(cwd);
        }
        let mut child = cmd
            .spawn()
            .map_err(|e| McpError::Other(format!("spawn {command}: {e}")))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| McpError::Other("no stdin pipe".into()))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| McpError::Other("no stdout pipe".into()))?;
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| McpError::Other("no stderr pipe".into()))?;
        let stderr_task = tokio::spawn(async move {
            let mut reader = BufReader::new(stderr);
            // Read raw bytes rather than using `read_line`. A single non-UTF-8
            // byte makes `read_line` fail, which would end this loop and drop
            // the pipe's read end. The child's later stderr writes would then
            // hit EPIPE/SIGPIPE. Lossy decoding keeps the pipe drained.
            let mut raw = Vec::new();
            loop {
                raw.clear();
                match reader.read_until(b'\n', &mut raw).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let text = String::from_utf8_lossy(&raw);
                        let trimmed = text.trim_end_matches(['\r', '\n']);
                        if !trimmed.is_empty() {
                            eprintln!("[mcp-stdio] {trimmed}");
                        }
                    }
                }
            }
        });
        Ok(StdioTransport {
            child: Some(child),
            stdin: Some(stdin),
            stdout: Some(BufReader::new(stdout)),
            stderr_task: Some(stderr_task),
            line_buf: String::new(),
        })
    }
}
impl Default for StdioTransportBuilder {
    /// An empty builder: no command, no arguments, no extra environment
    /// variables, no working directory override, and the parent environment
    /// inherited.
    fn default() -> Self {
        Self {
            inherit_env: true,
            cwd: None,
            env: BTreeMap::default(),
            args: vec![],
            command: None,
        }
    }
}
#[async_trait]
impl Transport for StdioTransport {
    /// Serialises `msg` as one JSON line and writes it to the child's stdin,
    /// flushing before returning.
    ///
    /// Returns [`McpError::Closed`] once the transport has been closed.
    ///
    /// NOTE(review): `write_all` is not cancellation safe. Dropping this future
    /// mid-write can leave a partial line in the pipe.
    async fn send(&mut self, msg: JsonRpcMessage) -> Result<(), McpError> {
        let stdin = self.stdin.as_mut().ok_or(McpError::Closed)?;
        let mut bytes = serde_json::to_vec(&msg)?;
        bytes.push(b'\n');
        stdin.write_all(&bytes).await?;
        stdin.flush().await?;
        Ok(())
    }

    /// Reads the next JSON-RPC message (one per line) from the child's stdout.
    ///
    /// Blank or whitespace-only lines are skipped rather than reported as parse
    /// errors. Returns [`McpError::Closed`] on EOF or after the transport has
    /// been closed.
    ///
    /// Not cancellation safe. `read_line` may lose partially read bytes if this
    /// future is dropped before completing (e.g. as a losing `select!` branch).
    async fn recv(&mut self) -> Result<JsonRpcMessage, McpError> {
        let stdout = self.stdout.as_mut().ok_or(McpError::Closed)?;
        loop {
            self.line_buf.clear();
            let n = stdout.read_line(&mut self.line_buf).await?;
            if n == 0 {
                return Err(McpError::Closed);
            }
            let line = self.line_buf.trim_end_matches(['\r', '\n']);
            if line.trim().is_empty() {
                continue;
            }
            let msg: JsonRpcMessage = serde_json::from_str(line)?;
            return Ok(msg);
        }
    }

    /// Shuts the child process down and releases every pipe.
    ///
    /// The shutdown proceeds in three steps:
    /// 1. Closes the child's stdin (EOF) and waits up to `CLOSE_GRACE` for it
    ///    to exit, force-killing it if it does not.
    /// 2. Drops stdout.
    /// 3. Waits, again bounded by `CLOSE_GRACE`, for the stderr forwarder to
    ///    finish. If something such as a grandchild still holds the stderr pipe
    ///    open, the forwarder is aborted instead of hanging `close` forever.
    ///
    /// Cleanup always runs to completion. An I/O error from waiting on the
    /// child is returned only afterwards. Calling `close` again returns `Ok(())`.
    async fn close(&mut self) -> Result<(), McpError> {
        drop(self.stdin.take());
        // Kept as a plain `io::Error` (Send + Sync) while the remaining cleanup
        // awaits; converted to `McpError` only at the end.
        let mut wait_err: Option<std::io::Error> = None;
        if let Some(mut child) = self.child.take() {
            match timeout(CLOSE_GRACE, child.wait()).await {
                Ok(Ok(_status)) => {}
                // `child` is dropped at the end of this scope; `kill_on_drop`
                // then kills it if it is still running.
                Ok(Err(e)) => wait_err = Some(e),
                Err(_) => {
                    let _ = child.start_kill();
                    let _ = child.wait().await;
                }
            }
        }
        drop(self.stdout.take());
        if let Some(mut task) = self.stderr_task.take() {
            if timeout(CLOSE_GRACE, &mut task).await.is_err() {
                task.abort();
            }
        }
        match wait_err {
            Some(e) => Err(McpError::Io(e)),
            None => Ok(()),
        }
    }
}