use std::ffi::OsStr;
use std::process::Stdio;
use tokio::process::{Command, Child, ChildStdout};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::mpsc;
use crate::error::McpError;
use crate::McpResult;
use super::{McpMessage, McpTransport};
/// MCP transport that talks to a server child process over its stdin/stdout.
///
/// Outgoing messages are handed to a background task that writes them to the
/// child's stdin; incoming messages are read line-by-line from the child's stdout.
#[derive(Debug)]
pub struct McpProcessTransport {
    /// Buffered reader over the child's stdout; each line is one incoming message.
    reader: BufReader<ChildStdout>,
    /// Queue feeding the background task that owns the child's stdin.
    writer_tx: mpsc::Sender<String>,
    /// Handle to the spawned server process, used to terminate it on close.
    child: Child,
}
impl McpProcessTransport {
    /// Capacity of the bounded queue between `send` and the background stdin writer.
    const WRITE_QUEUE_CAPACITY: usize = 32;

    /// Spawns `cmd` with `args` and wires its stdio up as an MCP transport.
    ///
    /// The child's stdin and stdout are piped:
    /// - Outgoing messages are queued on a bounded channel. A background tokio task
    ///   writes each one to stdin as a newline-terminated line.
    /// - Incoming messages are read line-by-line from stdout.
    ///
    /// The child's stderr is inherited, so server diagnostics appear on this
    /// process's stderr.
    ///
    /// The child is spawned with `kill_on_drop(true)`. Dropping the transport
    /// without calling `close` therefore does not leave the server process running.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::Transport`] if the process cannot be spawned or its
    /// stdin/stdout pipes cannot be obtained.
    ///
    /// # Panics
    ///
    /// Panics if called outside a tokio runtime, because the writer task is
    /// started with `tokio::spawn`.
    pub fn new<I, S>(cmd: S, args: I) -> McpResult<Self>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        let cmd = cmd.as_ref();
        let mut command = Command::new(cmd);
        command
            .args(args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            // Tokio does not kill a child when its `Child` handle is dropped unless
            // asked to; without this the server outlives a dropped transport.
            .kill_on_drop(true);

        let mut child = command.spawn().map_err(|e| {
            McpError::Transport(format!(
                "Failed to spawn process '{}': {}",
                cmd.to_string_lossy(),
                e
            ))
        })?;

        let child_stdin = child.stdin.take().ok_or_else(|| {
            McpError::Transport("Failed to open child process stdin".to_string())
        })?;
        let child_stdout = child.stdout.take().ok_or_else(|| {
            McpError::Transport("Failed to open child process stdout".to_string())
        })?;

        let (writer_tx, writer_rx) = mpsc::channel::<String>(Self::WRITE_QUEUE_CAPACITY);
        tokio::spawn(Self::write_loop(child_stdin, writer_rx));

        Ok(Self {
            reader: BufReader::new(child_stdout),
            writer_tx,
            child,
        })
    }

    /// Drains `rx`, writing each message to the child's stdin as one newline-terminated line.
    ///
    /// Returns on the first I/O error, which is logged to stderr, or once every
    /// sender has been dropped. Either way `rx` and the stdin handle are dropped,
    /// so later `send` calls fail with a closed-channel error.
    async fn write_loop(stdin: tokio::process::ChildStdin, mut rx: mpsc::Receiver<String>) {
        let mut writer = BufWriter::new(stdin);
        while let Some(message) = rx.recv().await {
            if let Err(e) = writer.write_all(message.as_bytes()).await {
                eprintln!("Error writing to server process stdin: {}", e);
                break;
            }
            // The stdio protocol is line-delimited; terminate the message if the caller didn't.
            if !message.ends_with('\n') {
                if let Err(e) = writer.write_all(b"\n").await {
                    eprintln!("Error writing to server process stdin: {}", e);
                    break;
                }
            }
            // Flush per message so the server sees it immediately rather than when the buffer fills.
            if let Err(e) = writer.flush().await {
                eprintln!("Error flushing to server process: {}", e);
                break;
            }
        }
    }
}
#[async_trait::async_trait]
impl McpTransport for McpProcessTransport {
    /// Serializes `message` to JSON and queues it for the stdin writer task.
    ///
    /// This returns once the message is queued, not once it has been written.
    /// A failure inside the background writer stops that task, so it surfaces
    /// here only as a closed-channel error on a later call.
    ///
    /// # Errors
    ///
    /// - [`McpError::Serialization`] if `message` cannot be encoded.
    /// - [`McpError::Transport`] if the writer task has stopped.
    async fn send(&mut self, message: McpMessage) -> McpResult<()> {
        let json = serde_json::to_string(&message)
            .map_err(|err| McpError::Serialization(err.to_string()))?;
        self.writer_tx.send(json).await.map_err(|err| {
            McpError::Transport(format!("Failed to send message to process writer: {}", err))
        })
    }

    /// Reads the next message from the server's stdout.
    ///
    /// Each non-blank line is parsed as one JSON message. Blank or
    /// whitespace-only lines are skipped rather than reported as invalid JSON.
    ///
    /// # Errors
    ///
    /// - [`McpError::Transport`] on EOF or a read failure, including stdout
    ///   containing invalid UTF-8.
    /// - [`McpError::Serialization`] if a line is not a valid message. The error
    ///   text includes the offending line.
    async fn receive(&mut self) -> McpResult<McpMessage> {
        let mut line = String::new();
        loop {
            line.clear();
            let bytes_read = self.reader.read_line(&mut line).await.map_err(|err| {
                McpError::Transport(format!("Failed to read from process: {}", err))
            })?;
            if bytes_read == 0 {
                return Err(McpError::Transport(
                    "Connection closed by remote process (EOF)".to_string(),
                ));
            }

            let payload = line.trim();
            if payload.is_empty() {
                // A stray empty line carries no message; keep reading.
                continue;
            }

            return serde_json::from_str(payload).map_err(|err| {
                McpError::Serialization(format!(
                    "Invalid JSON from server: {}. Raw: {}",
                    err, payload
                ))
            });
        }
    }

    /// Kills the server process and waits for it to exit.
    ///
    /// Calling this when the process has already exited is not an error. That
    /// includes calling `close` a second time.
    async fn close(&mut self) -> McpResult<()> {
        match self.child.kill().await {
            Ok(()) => Ok(()),
            // Tokio reports `InvalidInput` when asked to kill a child that has already exited.
            Err(e) if e.kind() == std::io::ErrorKind::InvalidInput => Ok(()),
            Err(e) => Err(McpError::Transport(format!("Failed to kill process: {}", e))),
        }
    }
}