use std::process::Stdio;
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use super::transport::ClientTransport;
use crate::error::{Error, Result};
/// Client-side transport that exchanges newline-delimited messages with a
/// child process over its stdin (outgoing) and stdout (incoming).
///
/// `child` and `stdin` are `Option`s so that [`ClientTransport::close`] can
/// release them and later calls can detect the closed state.
#[derive(Debug)]
pub struct StdioClientTransport {
    /// Handle to the spawned process; `None` once the transport is closed.
    child: Option<Child>,
    /// Write end of the child's stdin; `None` once closed. Dropping it closes
    /// the pipe, so the child reads EOF.
    stdin: Option<tokio::process::ChildStdin>,
    /// Buffered read end of the child's stdout, consumed one line at a time.
    stdout: BufReader<tokio::process::ChildStdout>,
}
impl StdioClientTransport {
    /// Spawns `program` with `args` and connects to its stdin/stdout.
    ///
    /// This is a convenience wrapper around
    /// [`StdioClientTransport::spawn_command`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if the process cannot be spawned or its
    /// stdio pipes cannot be obtained.
    pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
        let mut cmd = Command::new(program);
        cmd.args(args);
        Self::spawn_command(&mut cmd).await
    }

    /// Spawns a caller-configured `cmd` and connects to its stdin/stdout.
    ///
    /// Other settings on `cmd` (environment, working directory, ...) are
    /// kept. Its stdin and stdout are overridden to piped, because they carry
    /// the protocol. Its stderr is overridden to inherit.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if the process cannot be spawned or its
    /// stdio pipes cannot be obtained.
    pub async fn spawn_command(cmd: &mut Command) -> Result<Self> {
        cmd.stdin(Stdio::piped())
            .stdout(Stdio::piped())
            // This transport never reads the child's stderr, so it is
            // inherited rather than piped.
            .stderr(Stdio::inherit());

        let child = cmd
            .spawn()
            .map_err(|e| Error::Transport(format!("Failed to spawn process: {}", e)))?;

        // Share the pipe-extraction logic with `from_child` so the two
        // constructors cannot drift apart.
        let transport = Self::from_child(child)?;
        tracing::info!("Spawned MCP server process");
        Ok(transport)
    }

    /// Wraps an already-spawned `child`, taking ownership of its stdin and
    /// stdout handles.
    ///
    /// The child must have been spawned with piped stdin and stdout.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if `child.stdin` or `child.stdout` is
    /// `None`, for example because it was not piped or was already taken.
    pub fn from_child(mut child: Child) -> Result<Self> {
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;

        Ok(Self {
            child: Some(child),
            stdin: Some(stdin),
            stdout: BufReader::new(stdout),
        })
    }
}
#[async_trait]
impl ClientTransport for StdioClientTransport {
async fn send(&mut self, message: &str) -> Result<()> {
let stdin = self
.stdin
.as_mut()
.ok_or_else(|| Error::Transport("Transport closed".to_string()))?;
stdin
.write_all(message.as_bytes())
.await
.map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
stdin
.write_all(b"\n")
.await
.map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
stdin
.flush()
.await
.map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
Ok(())
}
async fn recv(&mut self) -> Result<Option<String>> {
let mut line = String::new();
let bytes = self
.stdout
.read_line(&mut line)
.await
.map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
if bytes == 0 {
return Ok(None); }
Ok(Some(line.trim().to_string()))
}
fn is_connected(&self) -> bool {
self.child.is_some() && self.stdin.is_some()
}
async fn close(&mut self) -> Result<()> {
self.stdin.take();
if let Some(mut child) = self.child.take() {
let result =
tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
match result {
Ok(Ok(status)) => {
tracing::info!(status = ?status, "Child process exited");
}
Ok(Err(e)) => {
tracing::error!(error = %e, "Error waiting for child");
}
Err(_) => {
tracing::warn!("Timeout waiting for child, killing");
let _ = child.kill().await;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // `cat` copies each stdin line back to stdout, so it acts as an echo peer.
    async fn echo_peer() -> StdioClientTransport {
        StdioClientTransport::spawn("cat", &[]).await.unwrap()
    }

    #[tokio::test]
    async fn test_spawn_nonexistent_program() {
        let outcome = StdioClientTransport::spawn("nonexistent-program-xyz", &[]).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_send_and_recv_via_cat() {
        let mut peer = echo_peer().await;
        assert!(peer.is_connected());

        let request = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        peer.send(request).await.unwrap();

        let echoed = peer.recv().await.unwrap();
        assert_eq!(echoed.as_deref(), Some(request));
    }

    #[tokio::test]
    async fn test_close_signals_eof() {
        let mut peer = echo_peer().await;
        assert!(peer.is_connected());
        peer.close().await.unwrap();
        assert!(!peer.is_connected());
    }

    #[tokio::test]
    async fn test_recv_returns_none_on_eof() {
        // `true` exits immediately without printing anything.
        let mut silent = StdioClientTransport::spawn("true", &[]).await.unwrap();
        assert_eq!(silent.recv().await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_send_after_close_fails() {
        let mut peer = echo_peer().await;
        peer.close().await.unwrap();
        assert!(peer.send("hello").await.is_err());
    }

    #[tokio::test]
    async fn test_spawn_command_with_env() {
        let mut cmd = Command::new("sh");
        cmd.args(["-c", "echo $TEST_VAR"])
            .env("TEST_VAR", "hello_from_test");

        let mut env_echo = StdioClientTransport::spawn_command(&mut cmd).await.unwrap();
        let line = env_echo.recv().await.unwrap();
        assert_eq!(line.as_deref(), Some("hello_from_test"));
    }

    #[tokio::test]
    async fn test_multiple_send_recv_roundtrips() {
        let mut peer = echo_peer().await;
        for id in 0..5 {
            let payload = format!(r#"{{"id":{id},"msg":"test"}}"#);
            peer.send(&payload).await.unwrap();
            let echoed = peer.recv().await.unwrap();
            assert_eq!(echoed, Some(payload));
        }
        peer.close().await.unwrap();
    }
}