use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tracing::{debug, warn};
use crate::error::{Error, Result};
/// Capacity of the bounded channel that carries stdout lines from the reader task to
/// `StdioTransport::inbound`. When the channel is full the reader task awaits on `send`
/// rather than dropping lines, so this bounds how far the child can run ahead of the consumer.
const INBOUND_CAPACITY: usize = 64;
/// Line-delimited transport over the stdio pipes of a spawned (MCP server) child process.
///
/// Outbound payloads are written to the child's stdin, one per line. A background task
/// forwards each non-empty stdout line into `inbound`. When stderr was captured, a second
/// task relays it to `tracing`.
///
/// NOTE(review): fields drop in declaration order. `child` is spawned with
/// `kill_on_drop(true)` and is declared last, so it is dropped after the pipes and task
/// handles. Consider this before reordering fields.
pub struct StdioTransport {
    /// The child's stdin, behind an async mutex held for a whole `send`, so concurrent
    /// callers never interleave bytes.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Receiving end of the bounded channel fed by the stdout reader task.
    pub(crate) inbound: Mutex<mpsc::Receiver<String>>,
    /// Background task reading stdout lines; aborted by `shutdown`.
    pub(crate) reader: JoinHandle<()>,
    /// Background task relaying the child's stderr, if stderr was available.
    pub(crate) stderr_logger: Option<JoinHandle<()>>,
    /// The child process; taken (left as `None`) by `shutdown`.
    child: Mutex<Option<Child>>,
}
impl StdioTransport {
pub async fn spawn(command: &str, args: &[String]) -> Result<Self> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|e| Error::other(format!("mcp spawn '{command}': {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::other("mcp child has no stdin"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::other("mcp child has no stdout"))?;
let (tx, rx) = mpsc::channel::<String>(INBOUND_CAPACITY);
let reader = tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
debug!("mcp transport: stdout EOF");
break;
}
Ok(_) => {
let trimmed = line.trim_end_matches(['\r', '\n']);
if trimmed.is_empty() {
continue;
}
if tx.send(trimmed.to_string()).await.is_err() {
debug!("mcp transport: receiver dropped");
break;
}
}
Err(e) => {
warn!(?e, "mcp transport: stdout read failed");
break;
}
}
}
});
let stderr_logger = child.stderr.take().map(spawn_stderr_logger);
Ok(Self {
stdin: Arc::new(Mutex::new(stdin)),
inbound: Mutex::new(rx),
reader,
stderr_logger,
child: Mutex::new(Some(child)),
})
}
pub async fn send(&self, payload: &str) -> Result<()> {
let mut stdin = self.stdin.lock().await;
stdin
.write_all(payload.as_bytes())
.await
.map_err(|e| Error::other(format!("mcp write: {e}")))?;
stdin
.write_all(b"\n")
.await
.map_err(|e| Error::other(format!("mcp write nl: {e}")))?;
stdin
.flush()
.await
.map_err(|e| Error::other(format!("mcp flush: {e}")))?;
Ok(())
}
pub async fn shutdown(&self) {
{
let mut stdin = self.stdin.lock().await;
let _ = stdin.shutdown().await;
}
self.reader.abort();
if let Some(h) = &self.stderr_logger {
h.abort();
}
let mut guard = self.child.lock().await;
if let Some(mut child) = guard.take() {
let wait = tokio::time::timeout(
std::time::Duration::from_secs(2),
child.wait(),
)
.await;
if wait.is_err() {
let _ = child.start_kill();
let _ = child.wait().await;
}
}
}
}
fn spawn_stderr_logger(mut stderr: tokio::process::ChildStderr) -> JoinHandle<()> {
tokio::spawn(async move {
let mut buf = Vec::with_capacity(4096);
let mut scratch = [0u8; 4096];
loop {
match stderr.read(&mut scratch).await {
Ok(0) => break,
Ok(n) => {
buf.extend_from_slice(&scratch[..n]);
drain_complete_lines(&mut buf, emit_stderr_line);
}
Err(_) => break,
}
}
flush_remainder(&mut buf, emit_stderr_line);
})
}
/// Emits every complete `\n`-terminated line in `buf` (excluding the `\n`) and removes the
/// consumed bytes, leaving any trailing partial line in `buf` for the next read.
///
/// Lines are handed to `emit` as borrowed slices of `buf`, and the consumed prefix is removed
/// with a single `drain`. The cost is therefore linear in the buffer length, with no per-line
/// allocation or repeated shifting of the tail. A `\r` before the `\n` is left in the emitted
/// slice for the caller to handle.
fn drain_complete_lines(buf: &mut Vec<u8>, mut emit: impl FnMut(&[u8])) {
    let mut start = 0;
    while let Some(offset) = buf[start..].iter().position(|&b| b == b'\n') {
        let end = start + offset;
        emit(&buf[start..end]);
        start = end + 1;
    }
    buf.drain(..start);
}
/// Emits whatever is left in `buf` as one final line, then empties `buf`.
///
/// Does nothing (and does not call `emit`) when `buf` is already empty, so a stream that
/// ended cleanly on a newline produces no spurious blank line.
fn flush_remainder(buf: &mut Vec<u8>, mut emit: impl FnMut(&[u8])) {
    if buf.is_empty() {
        return;
    }
    emit(buf.as_slice());
    buf.clear();
}
/// Logs one line of child stderr at debug level under the `localharness::mcp` target.
///
/// Bytes are decoded lossily, so invalid UTF-8 becomes U+FFFD, and a trailing `\r` left
/// over from CRLF output is stripped.
fn emit_stderr_line(line: &[u8]) {
    let decoded = String::from_utf8_lossy(line);
    let without_cr = decoded.trim_end_matches('\r');
    debug!(target: "localharness::mcp", "{}", without_cr);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes an emitted line and strips a trailing `\r`, the same way `emit_stderr_line`
    /// does, without logging.
    fn normalise(raw: &[u8]) -> String {
        String::from_utf8_lossy(raw).trim_end_matches('\r').to_owned()
    }

    /// Runs the drain-then-flush sequence the stderr logger performs at EOF.
    ///
    /// Returns `(complete lines, flushed remainder)`, both normalised.
    fn collect(buf: &mut Vec<u8>) -> (Vec<String>, Vec<String>) {
        let mut complete = Vec::new();
        drain_complete_lines(buf, |l| complete.push(normalise(l)));
        let mut remainder = Vec::new();
        flush_remainder(buf, |l| remainder.push(normalise(l)));
        (complete, remainder)
    }

    #[test]
    fn drains_complete_lines_and_keeps_partial() {
        let mut buf = b"one\ntwo\npartial".to_vec();
        let mut seen: Vec<String> = Vec::new();
        drain_complete_lines(&mut buf, |l| seen.push(String::from_utf8_lossy(l).into_owned()));
        assert_eq!(seen, ["one", "two"]);
        assert_eq!(buf, b"partial");
    }

    #[test]
    fn flush_emits_final_unterminated_line() {
        let mut buf = b"boom: panic at the end".to_vec();
        let (complete, remainder) = collect(&mut buf);
        assert!(complete.is_empty());
        assert_eq!(remainder, ["boom: panic at the end"]);
        assert!(buf.is_empty());
    }

    #[test]
    fn flush_is_noop_on_clean_eof() {
        let mut buf = b"only\n".to_vec();
        let (complete, remainder) = collect(&mut buf);
        assert_eq!(complete, ["only"]);
        assert!(remainder.is_empty());
    }

    #[test]
    fn trims_trailing_cr_on_crlf_lines() {
        let mut buf = b"crlf\r\nlast\r".to_vec();
        let (complete, remainder) = collect(&mut buf);
        assert_eq!(complete, ["crlf"]);
        assert_eq!(remainder, ["last"]);
    }
}