use super::types::*;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
/// Per-request round-trip limit (30 seconds) used when a transport is constructed
/// without an explicit timeout.
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// A channel for exchanging JSON-RPC messages with an MCP server.
///
/// Implementors must be `Send + Sync` so one transport can be shared across tasks.
/// Both methods take `&self`, so any mutable state has to sit behind interior
/// mutability.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Delivers `request` to the server and returns the server's reply.
    async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError>;

    /// Releases whatever connection or process backs this transport.
    async fn close(&self) -> Result<(), McpError>;
}
/// Transport that talks to an MCP server running as a child process, exchanging
/// newline-delimited JSON over the child's stdin/stdout pipes.
pub struct StdioTransport {
    /// Write end of the child's stdin pipe; each request is written here as one line.
    stdin: Arc<Mutex<tokio::process::ChildStdin>>,
    /// Buffered read end of the child's stdout pipe; responses are read one line at a time.
    stdout: Arc<Mutex<BufReader<tokio::process::ChildStdout>>>,
    /// Handle to the spawned process; `close` kills it through this handle.
    child: Arc<Mutex<Child>>,
    /// Upper bound on a single `send` round trip.
    request_timeout: Duration,
}
impl StdioTransport {
pub async fn new(
command: &str, args: &[&str], env: Option<HashMap<String, String>>, ) -> Result<Self, McpError> {
let mut cmd = Command::new(command);
cmd.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
if let Some(env_vars) = env {
for (k, v) in env_vars {
cmd.env(k, v);
}
}
let mut child = cmd
.spawn()
.map_err(|e| McpError::Transport(format!("Failed to spawn '{}': {}", command, e)))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| McpError::Transport("Failed to capture stdin".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| McpError::Transport("Failed to capture stdout".into()))?;
Ok(Self {
stdin: Arc::new(Mutex::new(stdin)),
stdout: Arc::new(Mutex::new(BufReader::new(stdout))), child: Arc::new(Mutex::new(child)),
request_timeout: DEFAULT_REQUEST_TIMEOUT,
})
}
pub fn with_timeout(mut self, request_timeout: Duration) -> Self {
self.request_timeout = request_timeout;
self
}
}
#[async_trait]
impl McpTransport for StdioTransport {
async fn send(
&self,
request: JsonRpcRequest, ) -> Result<JsonRpcResponse, McpError> {
let timeout = self.request_timeout;
let work = async {
let mut line = serde_json::to_string(&request)?;
line.push('\n');
{
let mut stdin = self.stdin.lock().await;
stdin
.write_all(line.as_bytes())
.await
.map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
stdin
.flush()
.await
.map_err(|e| McpError::Transport(format!("Flush error: {}", e)))?;
}
let mut response_line = String::new();
{
let mut stdout = self.stdout.lock().await;
let bytes_read = stdout
.read_line(&mut response_line)
.await
.map_err(|e| McpError::Transport(format!("Read error: {}", e)))?;
if bytes_read == 0 {
return Err(McpError::ConnectionClosed);
}
}
let response: JsonRpcResponse = serde_json::from_str(response_line.trim())?;
Ok::<_, McpError>(response)
};
match tokio::time::timeout(timeout, work).await {
Ok(result) => result,
Err(_) => Err(McpError::Timeout { duration: timeout }),
}
}
async fn close(&self) -> Result<(), McpError> {
let mut child = self.child.lock().await;
let _ = child.kill().await;
Ok(())
}
}
/// Transport that delivers each JSON-RPC request as an HTTP POST to one endpoint.
pub struct HttpTransport {
    /// HTTP client shared by every request sent through this transport.
    client: reqwest::Client,
    /// Endpoint URL, stored with any trailing `/` characters removed.
    base_url: String,
    /// Upper bound on a single `send` round trip.
    request_timeout: Duration,
}
impl HttpTransport {
pub fn new(url: &str) -> Result<Self, McpError> {
Self::new_with_timeout(url, DEFAULT_REQUEST_TIMEOUT)
}
pub fn new_with_timeout(url: &str, request_timeout: Duration) -> Result<Self, McpError> {
Ok(Self {
client: reqwest::Client::new(),
base_url: url.trim_end_matches('/').to_string(),
request_timeout,
})
}
}
#[async_trait]
impl McpTransport for HttpTransport {
async fn send(
&self,
request: JsonRpcRequest, ) -> Result<JsonRpcResponse, McpError> {
let timeout = self.request_timeout;
let work = async {
let resp = self
.client
.post(&self.base_url)
.json(&request)
.send()
.await
.map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
if !resp.status().is_success() {
return Err(McpError::Transport(format!(
"HTTP {} from server",
resp.status()
)));
}
let response: JsonRpcResponse = resp
.json()
.await
.map_err(|e| McpError::Transport(format!("Response parse error: {}", e)))?;
Ok::<_, McpError>(response)
};
match tokio::time::timeout(timeout, work).await {
Ok(result) => result,
Err(_) => Err(McpError::Timeout { duration: timeout }),
}
}
async fn close(&self) -> Result<(), McpError> {
Ok(())
}
}
// Tests for both transports. The stdio tests spawn real processes (`cat`, `sleep`,
// `bash`), so they need those binaries on PATH (i.e. a Unix-like host).
#[cfg(test)]
mod tests {
    use super::*;

    // Raw pipe round trip through `cat`, which echoes each input line unchanged.
    // This drives the stdin/stdout handles directly instead of calling `send`,
    // presumably because an echoed *request* would not decode as a `JsonRpcResponse`.
    #[tokio::test]
    async fn test_stdio_transport_with_cat() {
        let transport = StdioTransport::new("cat", &[], None).await.unwrap();
        let request = JsonRpcRequest::new("test/echo", Some(serde_json::json!({"hello": "world"})));
        let request_id = request.id;
        let mut line = serde_json::to_string(&request).unwrap();
        line.push('\n');
        {
            let mut stdin = transport.stdin.lock().await;
            stdin.write_all(line.as_bytes()).await.unwrap();
            stdin.flush().await.unwrap();
        }
        let mut response_line = String::new();
        {
            let mut stdout = transport.stdout.lock().await;
            stdout.read_line(&mut response_line).await.unwrap();
        }
        // The echoed line must decode back to the same request.
        let echoed: JsonRpcRequest = serde_json::from_str(response_line.trim()).unwrap();
        assert_eq!(echoed.id, request_id);
        assert_eq!(echoed.method, "test/echo");
        transport.close().await.unwrap();
    }

    // A trailing `/` on the endpoint URL is stripped at construction time.
    #[test]
    fn test_http_transport_creation() {
        let transport = HttpTransport::new("http://localhost:8080/mcp").unwrap();
        assert_eq!(transport.base_url, "http://localhost:8080/mcp");
        let transport = HttpTransport::new("http://localhost:8080/mcp/").unwrap();
        assert_eq!(transport.base_url, "http://localhost:8080/mcp");
    }

    // `sleep 60` never writes to stdout. `send` must therefore fail with `Timeout`
    // carrying the configured 150 ms, and return long before the child exits on its
    // own.
    #[tokio::test]
    async fn stdio_send_times_out_on_silent_child() {
        let transport = StdioTransport::new("sleep", &["60"], None)
            .await
            .unwrap()
            .with_timeout(Duration::from_millis(150));
        let request = JsonRpcRequest::new("test/timeout", None);
        let start = std::time::Instant::now();
        let result = transport.send(request).await;
        let elapsed = start.elapsed();
        match result {
            Err(McpError::Timeout { duration }) => {
                assert_eq!(duration, Duration::from_millis(150));
            }
            other => panic!("expected McpError::Timeout, got {:?}", other),
        }
        // The 2 s bound is far above 150 ms; presumably slack for slow machines.
        assert!(
            elapsed < Duration::from_secs(2),
            "send() should have returned promptly after timeout, took {:?}",
            elapsed
        );
        transport.close().await.unwrap();
    }

    // Happy path: the bash loop answers every stdin line with one canned success
    // response, well within the 5 s timeout.
    // NOTE(review): the canned reply always carries id 1, whatever the request id is.
    #[tokio::test]
    async fn stdio_send_succeeds_within_timeout() {
        let script = r#"while IFS= read -r _line; do printf '{"jsonrpc":"2.0","id":1,"result":{"ok":true}}\n'; done"#;
        let transport = StdioTransport::new("bash", &["-c", script], None)
            .await
            .unwrap()
            .with_timeout(Duration::from_secs(5));
        let request = JsonRpcRequest::new("test/ok", None);
        let response = transport.send(request).await.expect("send should succeed");
        assert!(response.result.is_some());
        assert!(response.error.is_none());
        transport.close().await.unwrap();
    }

    // A local TCP server that accepts connections but never reads or answers. Each
    // accepted socket is simply held open for 60 s, so the HTTP request must hit the
    // 200 ms timeout.
    #[tokio::test]
    async fn http_send_times_out_on_silent_server() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                if let Ok((stream, _)) = listener.accept().await {
                    tokio::spawn(async move {
                        // Keeping the stream alive keeps the connection open.
                        let _stream = stream;
                        tokio::time::sleep(Duration::from_secs(60)).await;
                    });
                }
            }
        });
        let url = format!("http://{}/", addr);
        let transport = HttpTransport::new_with_timeout(&url, Duration::from_millis(200)).unwrap();
        let request = JsonRpcRequest::new("test/timeout", None);
        let start = std::time::Instant::now();
        let result = transport.send(request).await;
        let elapsed = start.elapsed();
        match result {
            Err(McpError::Timeout { duration }) => {
                assert_eq!(duration, Duration::from_millis(200));
            }
            other => panic!("expected McpError::Timeout, got {:?}", other),
        }
        assert!(
            elapsed < Duration::from_secs(2),
            "send() should have returned promptly after timeout, took {:?}",
            elapsed
        );
    }

    // Pins the default request timeout at 30 s.
    #[test]
    fn stdio_default_timeout_is_30s() {
        assert_eq!(DEFAULT_REQUEST_TIMEOUT, Duration::from_secs(30));
    }
}