use anyhow::Result;
use async_trait::async_trait;
use mcp_protocol::messages::JsonRpcMessage;
use std::process::Stdio;
use tokio::process::{Child, Command};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{mpsc, Mutex};
/// A transport that exchanges JSON-RPC messages with a child process over its
/// standard input and output, one JSON message per line.
///
/// Clones share the same child process, stdin handle and message sender; only
/// the command line is copied.
pub struct StdioTransport {
    /// Program executed by `start`.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// The spawned child; `Some` after `start`, reset to `None` by `close`.
    child_process: Arc<Mutex<Option<Child>>>,
    /// The child's piped stdin, written to by `send`; `None` until started.
    stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
    /// Sender on which messages parsed from the child's stdout are forwarded.
    tx: mpsc::Sender<JsonRpcMessage>,
}
impl StdioTransport {
    /// Bounded capacity of the channel carrying messages read from the child.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a transport that will run `command` with `args` once started.
    ///
    /// No process is spawned here. Returns the transport together with the
    /// receiver on which messages parsed from the child's stdout are delivered.
    pub fn new(command: &str, args: Vec<String>) -> (Self, mpsc::Receiver<JsonRpcMessage>) {
        let (sender, receiver) = mpsc::channel(Self::CHANNEL_CAPACITY);
        (
            Self {
                command: command.to_owned(),
                args,
                child_process: Arc::new(Mutex::new(None)),
                stdin: Arc::new(Mutex::new(None)),
                tx: sender,
            },
            receiver,
        )
    }
}
#[async_trait]
impl super::Transport for StdioTransport {
async fn start(&self) -> Result<()> {
let mut child = Command::new(&self.command)
.args(&self.args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
let stdout = child.stdout.take().expect("Failed to get stdout");
let stdin = child.stdin.take().expect("Failed to get stdin");
{
let mut guard = self.child_process.lock().await;
*guard = Some(child);
}
{
let mut stdin_guard = self.stdin.lock().await;
*stdin_guard = Some(stdin);
}
let tx = self.tx.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
while reader.read_line(&mut line).await.unwrap_or(0) > 0 {
match serde_json::from_str::<JsonRpcMessage>(&line) {
Ok(message) => {
if tx.send(message).await.is_err() {
break;
}
}
Err(err) => {
tracing::error!("Failed to parse JSON-RPC message: {}", err);
}
}
line.clear();
}
});
Ok(())
}
async fn send(&self, message: JsonRpcMessage) -> Result<()> {
let mut stdin_guard = self.stdin.lock().await;
let stdin = stdin_guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Child process not started"))?;
let serialized = serde_json::to_string(&message)?;
stdin.write_all(serialized.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
Ok(())
}
async fn close(&self) -> Result<()> {
{
let mut stdin_guard = self.stdin.lock().await;
*stdin_guard = None;
}
let mut guard = self.child_process.lock().await;
if let Some(mut child) = guard.take() {
let wait_future = child.wait();
match tokio::time::timeout(std::time::Duration::from_secs(1), wait_future).await {
Ok(Ok(_)) => return Ok(()),
_ => {
child.kill().await?;
child.wait().await?;
}
}
}
Ok(())
}
fn box_clone(&self) -> Box<dyn super::Transport> {
Box::new(self.clone())
}
}
impl Clone for StdioTransport {
fn clone(&self) -> Self {
Self {
child_process: self.child_process.clone(),
tx: self.tx.clone(),
command: self.command.clone(),
args: self.args.clone(),
stdin: self.stdin.clone(),
}
}
}