use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::BufReader;
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle;
use crate::tools::mcp::protocol::{McpRequest, McpResponse};
use crate::tools::mcp::transport::{McpTransport, spawn_jsonrpc_reader, stream_transport_send};
use crate::tools::tool::ToolError;
/// MCP transport that exchanges JSON-RPC messages with a child process over its
/// stdin/stdout pipes.
///
/// Created with [`StdioMcpTransport::spawn`]; torn down with `McpTransport::shutdown`.
pub struct StdioMcpTransport {
    /// Human-readable server name, used as the `[name]` prefix in logs and error messages.
    server_name: String,
    /// The child's stdin, behind a mutex so concurrent senders do not interleave writes.
    stdin: Arc<Mutex<tokio::process::ChildStdin>>,
    /// Response senders for in-flight requests, shared with the stdout reader task.
    /// NOTE(review): the `u64` key is presumably the JSON-RPC request id — confirm
    /// against `spawn_jsonrpc_reader` / `stream_transport_send`.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>>,
    /// Background task reading the child's stdout; taken and aborted on shutdown.
    reader_handle: Mutex<Option<JoinHandle<()>>>,
    /// Background task draining the child's stderr into debug logs; taken and aborted
    /// on shutdown.
    stderr_handle: Mutex<Option<JoinHandle<()>>>,
    /// Handle to the spawned server process; killed on shutdown.
    child: Arc<Mutex<Child>>,
}
impl StdioMcpTransport {
pub async fn spawn(
name: impl Into<String>,
command: &str,
args: impl IntoIterator<Item = impl AsRef<std::ffi::OsStr>>,
env: impl IntoIterator<Item = (impl AsRef<std::ffi::OsStr>, impl AsRef<std::ffi::OsStr>)>,
) -> Result<Self, ToolError> {
let server_name = name.into();
let mut cmd = Command::new(command);
cmd.args(args)
.envs(env)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|e| {
ToolError::ExternalService(format!(
"[{}] Failed to spawn MCP server '{}': {}",
server_name, command, e
))
})?;
let stdin = child.stdin.take().ok_or_else(|| {
ToolError::ExternalService(format!(
"[{}] Failed to capture stdin of MCP server",
server_name
))
})?;
let stdout = child.stdout.take().ok_or_else(|| {
ToolError::ExternalService(format!(
"[{}] Failed to capture stdout of MCP server",
server_name
))
})?;
let stderr = child.stderr.take().ok_or_else(|| {
ToolError::ExternalService(format!(
"[{}] Failed to capture stderr of MCP server",
server_name
))
})?;
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>> =
Arc::new(Mutex::new(HashMap::new()));
let reader = BufReader::new(stdout);
let reader_handle = spawn_jsonrpc_reader(reader, pending.clone(), server_name.clone());
let stderr_name = server_name.clone();
let stderr_handle = tokio::spawn(async move {
use tokio::io::{AsyncBufReadExt, BufReader as TokioBufReader};
let reader = TokioBufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::debug!("[{}] stderr: {}", stderr_name, line);
}
});
Ok(Self {
server_name,
stdin: Arc::new(Mutex::new(stdin)),
pending,
reader_handle: Mutex::new(Some(reader_handle)),
stderr_handle: Mutex::new(Some(stderr_handle)),
child: Arc::new(Mutex::new(child)),
})
}
}
#[async_trait]
impl McpTransport for StdioMcpTransport {
    /// Sends `request` over the child's stdin via `stream_transport_send`, with a
    /// 30-second timeout.
    ///
    /// Headers have no meaning on a stdio pipe, so `_headers` is ignored.
    async fn send(
        &self,
        request: &McpRequest,
        _headers: &HashMap<String, String>,
    ) -> Result<McpResponse, ToolError> {
        let timeout = Duration::from_secs(30);
        stream_transport_send(&self.stdin, &self.pending, request, &self.server_name, timeout)
            .await
    }

    /// Tears the transport down. In order, it:
    /// 1. kills the child process;
    /// 2. aborts the stdout reader and stderr drain tasks;
    /// 3. drops every outstanding response sender.
    ///
    /// Kill errors are discarded, so this always returns `Ok(())`.
    async fn shutdown(&self) -> Result<(), ToolError> {
        // Best-effort: the kill result is intentionally ignored.
        let _ = self.child.lock().await.kill().await;

        for task_slot in [&self.reader_handle, &self.stderr_handle] {
            if let Some(task) = task_slot.lock().await.take() {
                task.abort();
            }
        }

        // Dropping the senders presumably makes any still-waiting `send` fail
        // rather than hang.
        self.pending.lock().await.clear();

        tracing::debug!("[{}] Stdio transport shut down", self.server_name);
        Ok(())
    }

    /// A stdio pipe has no HTTP layer, so HTTP-specific features are never available.
    fn supports_http_features(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawning a command that does not exist must surface a "Failed to spawn" error.
    #[tokio::test]
    async fn test_spawn_nonexistent_command_fails() {
        let env: HashMap<String, String> = HashMap::new();
        let result = StdioMcpTransport::spawn(
            "test",
            "this-command-does-not-exist-ironclaw-test",
            std::iter::empty::<&str>(),
            &env,
        )
        .await;
        let err = result.err().expect("should be an error").to_string();
        assert!(
            err.contains("Failed to spawn"),
            "Error should mention spawn failure: {}",
            err
        );
    }

    /// Spawning `cat` and shutting it down succeeds. Relies on `cat`, so Unix-only.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_spawn_and_shutdown() {
        let env: HashMap<String, String> = HashMap::new();
        let transport =
            StdioMcpTransport::spawn("test-cat", "cat", std::iter::empty::<&str>(), &env)
                .await
                .expect("cat should be available");
        transport.shutdown().await.expect("shutdown should succeed");
    }

    /// A second `shutdown` still returns `Ok`: the kill result is discarded and the
    /// task handles have already been taken.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let env: HashMap<String, String> = HashMap::new();
        let transport =
            StdioMcpTransport::spawn("test-cat-twice", "cat", std::iter::empty::<&str>(), &env)
                .await
                .expect("cat should be available");
        transport
            .shutdown()
            .await
            .expect("first shutdown should succeed");
        transport
            .shutdown()
            .await
            .expect("second shutdown should succeed");
    }

    /// Despite its name, this test checks two things after `shutdown`:
    /// - the pending map has been cleared;
    /// - a subsequent `send` fails.
    ///
    /// It does not exercise the request timeout. Relies on `cat`, so Unix-only.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_send_timeout_on_non_jsonrpc_server() {
        let env: HashMap<String, String> = HashMap::new();
        let transport =
            StdioMcpTransport::spawn("test-echo", "cat", std::iter::empty::<&str>(), &env)
                .await
                .expect("cat should be available");
        let request = McpRequest::list_tools(999);
        let headers = HashMap::new();
        transport.shutdown().await.expect("shutdown should succeed");
        let pending = transport.pending.lock().await;
        assert!(pending.is_empty());
        drop(pending);
        let result = transport.send(&request, &headers).await;
        assert!(result.is_err());
    }
}