use super::McpTransport;
use crate::mcp::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{mpsc, oneshot, RwLock};
/// Per-request timeout, in seconds, used when a transport is created with
/// `StdioTransport::spawn` (i.e. without an explicit timeout).
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;
/// MCP transport that exchanges newline-terminated JSON-RPC messages with a child
/// process over its stdin/stdout pipes.
pub struct StdioTransport {
    /// Handle to the spawned server process; taken (left `None`) and killed by `close`.
    child: RwLock<Option<Child>>,
    /// Queue of serialized, newline-terminated messages; a background task writes each
    /// one to the child's stdin.
    stdin_tx: mpsc::Sender<String>,
    /// Waiters for in-flight requests keyed by JSON-RPC request id. Shared with the
    /// stdout reader task, which removes and completes an entry when its response arrives.
    pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
    /// Receiving end of the server-notification channel; handed out (and replaced with
    /// `None`) by the first call to `notifications`.
    notification_rx: RwLock<Option<mpsc::Receiver<McpNotification>>>,
    /// Set to `false` by `close`; `request` and `notify` refuse to send while it is false.
    connected: AtomicBool,
    /// Per-request timeout, in seconds, applied by `request`.
    request_timeout_secs: u64,
}
impl StdioTransport {
    /// Spawns an MCP server process and connects to it over stdio, using
    /// [`DEFAULT_REQUEST_TIMEOUT_SECS`] as the per-request timeout.
    ///
    /// # Errors
    /// Same as [`StdioTransport::spawn_with_timeout`].
    pub async fn spawn(
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
    ) -> Result<Self> {
        Self::spawn_with_timeout(command, args, env, DEFAULT_REQUEST_TIMEOUT_SECS).await
    }

    /// Spawns `command` with `args` as an MCP server and connects to it over stdio.
    ///
    /// Entries in `env` are added on top of the inherited environment. The child is
    /// created with `kill_on_drop`, so it is killed when the transport is dropped.
    /// Background tasks are started to:
    /// - write queued outgoing messages to the child's stdin,
    /// - read stdout line by line, completing pending requests and forwarding
    ///   notifications,
    /// - drain stderr into `tracing`, so the child never blocks on a full stderr pipe.
    ///
    /// `request_timeout_secs` is the timeout applied to each `request` call.
    ///
    /// # Errors
    /// Returns an error if the process cannot be spawned or if its stdin/stdout pipes
    /// are unavailable.
    pub async fn spawn_with_timeout(
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
        request_timeout_secs: u64,
    ) -> Result<Self> {
        let mut cmd = Command::new(command);
        cmd.args(args)
            .envs(env)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .kill_on_drop(true);
        let mut child = cmd
            .spawn()
            .with_context(|| format!("Failed to spawn MCP server: {} {:?}", command, args))?;
        let stdin = child.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
        let stdout = child.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
        // A piped stderr that nobody reads eventually fills the OS pipe buffer and
        // blocks the server mid-write, so it must be drained.
        if let Some(stderr) = child.stderr.take() {
            Self::spawn_stderr_drain(stderr);
        }

        let (stdin_tx, stdin_rx) = mpsc::channel::<String>(100);
        let (notification_tx, notification_rx) = mpsc::channel::<McpNotification>(100);
        let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        Self::spawn_stdin_writer(stdin, stdin_rx);
        Self::spawn_stdout_reader(stdout, Arc::clone(&pending), notification_tx);

        Ok(Self {
            child: RwLock::new(Some(child)),
            stdin_tx,
            pending,
            notification_rx: RwLock::new(Some(notification_rx)),
            connected: AtomicBool::new(true),
            request_timeout_secs,
        })
    }

    /// Writes each queued message to the child's stdin and flushes after every write.
    /// Stops on the first write/flush error or once every sender has been dropped.
    fn spawn_stdin_writer(
        mut stdin: tokio::process::ChildStdin,
        mut outgoing: mpsc::Receiver<String>,
    ) {
        tokio::spawn(async move {
            while let Some(msg) = outgoing.recv().await {
                if let Err(e) = stdin.write_all(msg.as_bytes()).await {
                    tracing::error!("Failed to write to MCP stdin: {}", e);
                    break;
                }
                if let Err(e) = stdin.flush().await {
                    tracing::error!("Failed to flush MCP stdin: {}", e);
                    break;
                }
            }
        });
    }

    /// Reads stdout line by line and dispatches each line until EOF or a read error.
    /// On exit, drops every pending waiter, so in-flight requests fail immediately.
    fn spawn_stdout_reader(
        stdout: tokio::process::ChildStdout,
        pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
        notification_tx: mpsc::Sender<McpNotification>,
    ) {
        tokio::spawn(async move {
            let mut reader = BufReader::new(stdout);
            let mut line = String::new();
            loop {
                line.clear();
                match reader.read_line(&mut line).await {
                    Ok(0) => {
                        tracing::debug!("MCP stdout closed");
                        break;
                    }
                    Ok(_) => Self::dispatch_line(line.trim(), &pending, &notification_tx).await,
                    Err(e) => {
                        tracing::error!("Failed to read MCP stdout: {}", e);
                        break;
                    }
                }
            }
            // No further responses can arrive. Dropping the waiting senders makes each
            // in-flight `request` observe a closed channel instead of waiting out its timeout.
            pending.write().await.clear();
        });
    }

    /// Routes one trimmed stdout line: a response completes the pending request with
    /// the same id, a notification is forwarded, and anything else is logged.
    async fn dispatch_line(
        trimmed: &str,
        pending: &RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
        notification_tx: &mpsc::Sender<McpNotification>,
    ) {
        if trimmed.is_empty() {
            return;
        }
        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
            if let Some(id) = response.id {
                let waiter = pending.write().await.remove(&id);
                if let Some(tx) = waiter {
                    // The requester may already have timed out and dropped its receiver.
                    let _ = tx.send(response);
                }
            }
            return;
        }
        if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(trimmed) {
            let mcp_notif = McpNotification::from_json_rpc(&notification);
            // Never await here: if the notification receiver is never drained, a blocking
            // send would stall this reader and, with it, every pending response.
            match notification_tx.try_send(mcp_notif) {
                Ok(()) | Err(mpsc::error::TrySendError::Closed(_)) => {}
                Err(mpsc::error::TrySendError::Full(_)) => {
                    tracing::warn!("MCP notification channel full; dropping notification");
                }
            }
            return;
        }
        tracing::warn!("Unknown MCP message: {}", trimmed);
    }

    /// Reads the child's stderr until EOF and logs each line at debug level. Bytes are
    /// decoded lossily so non-UTF-8 output cannot stop the drain.
    fn spawn_stderr_drain(stderr: tokio::process::ChildStderr) {
        tokio::spawn(async move {
            let mut reader = BufReader::new(stderr);
            let mut buf = Vec::new();
            loop {
                buf.clear();
                match reader.read_until(b'\n', &mut buf).await {
                    Ok(0) => break,
                    Ok(_) => {
                        let text = String::from_utf8_lossy(&buf);
                        let text = text.trim_end();
                        if !text.is_empty() {
                            tracing::debug!("MCP stderr: {}", text);
                        }
                    }
                    Err(e) => {
                        tracing::debug!("Failed to read MCP stderr: {}", e);
                        break;
                    }
                }
            }
        });
    }
}
#[async_trait]
impl McpTransport for StdioTransport {
async fn request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
if !self.connected.load(Ordering::SeqCst) {
return Err(anyhow!("Transport not connected"));
}
let (tx, rx) = oneshot::channel();
let request_id = request.id;
{
let mut pending = self.pending.write().await;
pending.insert(request_id, tx);
}
let msg = serde_json::to_string(&request)? + "\n";
self.stdin_tx
.send(msg)
.await
.map_err(|_| anyhow!("Failed to send request"))?;
let response = match tokio::time::timeout(
std::time::Duration::from_secs(self.request_timeout_secs),
rx,
)
.await
{
Ok(Ok(resp)) => resp,
Ok(Err(_)) => {
self.pending.write().await.remove(&request_id);
return Err(anyhow!("Response channel closed"));
}
Err(_) => {
self.pending.write().await.remove(&request_id);
return Err(anyhow!(
"MCP request timed out after {}s",
self.request_timeout_secs
));
}
};
Ok(response)
}
async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
if !self.connected.load(Ordering::SeqCst) {
return Err(anyhow!("Transport not connected"));
}
let msg = serde_json::to_string(¬ification)? + "\n";
self.stdin_tx
.send(msg)
.await
.map_err(|_| anyhow!("Failed to send notification"))?;
Ok(())
}
fn notifications(&self) -> mpsc::Receiver<McpNotification> {
let mut rx_guard = self.notification_rx.blocking_write();
rx_guard.take().unwrap_or_else(|| {
let (_, rx) = mpsc::channel(1);
rx
})
}
async fn close(&self) -> Result<()> {
self.connected.store(false, Ordering::SeqCst);
let mut child_guard = self.child.write().await;
if let Some(mut child) = child_guard.take() {
let _ = child.kill().await;
}
Ok(())
}
fn is_connected(&self) -> bool {
self.connected.load(Ordering::SeqCst)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts `cat` with no arguments and no extra environment. Yields `None` if the
    /// spawn fails (e.g. `cat` is not installed), which makes the calling test a no-op.
    async fn spawn_cat() -> Option<StdioTransport> {
        StdioTransport::spawn("cat", &[], &HashMap::new()).await.ok()
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_invalid_command() {
        let spawned =
            StdioTransport::spawn("nonexistent_command_12345", &[], &HashMap::new()).await;
        assert!(spawned.is_err());
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_echo() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_is_connected_initial() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_close_disconnects() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_args() {
        // Only checks that spawning with extra arguments does not panic.
        let version_flag = ["--version".to_string()];
        let _outcome = StdioTransport::spawn("cat", &version_flag, &HashMap::new()).await;
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_env() {
        let env: HashMap<String, String> =
            std::iter::once(("TEST_VAR".to_string(), "test_value".to_string())).collect();
        if let Ok(transport) = StdioTransport::spawn("cat", &[], &env).await {
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_double_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();
            assert!(transport.close().await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_request_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();
            let message = transport
                .request(JsonRpcRequest::new(1, "test", None))
                .await
                .err()
                .map(|e| e.to_string())
                .expect("request after close must fail");
            assert!(message.contains("not connected"));
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_notify_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();
            let message = transport
                .notify(JsonRpcNotification::new("test", None))
                .await
                .err()
                .map(|e| e.to_string())
                .expect("notify after close must fail");
            assert!(message.contains("not connected"));
        }
    }

    #[test]
    fn test_json_rpc_request_creation() {
        let params = serde_json::json!({"key": "value"});
        let req = JsonRpcRequest::new(1, "test_method", Some(params));
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "test_method");
        assert!(req.params.is_some());
    }

    #[test]
    fn test_json_rpc_notification_creation() {
        let notif = JsonRpcNotification::new("test_notification", None);
        assert_eq!(notif.method, "test_notification");
        assert!(notif.params.is_none());
    }

    #[tokio::test]
    async fn test_stdio_transport_custom_timeout() {
        let spawned = StdioTransport::spawn_with_timeout("cat", &[], &HashMap::new(), 1).await;
        if let Ok(transport) = spawned {
            assert_eq!(transport.request_timeout_secs, 1);
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_default_timeout() {
        if let Some(transport) = spawn_cat().await {
            assert_eq!(transport.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
            let _ = transport.close().await;
        }
    }
}