use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout};
use tokio::sync::Mutex;
use super::protocol::{McpRequest, McpResponse};
/// Common interface for the ways a client can exchange messages with an MCP server.
///
/// Implementors must be `Send + Sync`. The implementations in this file are
/// `HttpTransport` and `StdioTransport`.
#[async_trait]
pub trait McpTransport: Send + Sync {
/// Sends one request and returns the server's response.
///
/// # Errors
/// Returns a human-readable message describing what went wrong.
async fn send(&self, request: &McpRequest) -> Result<McpResponse, String>;
/// Shuts the transport down.
///
/// # Errors
/// Returns a human-readable message if shutting down fails.
async fn shutdown(&self) -> Result<(), String>;
/// Returns a short name for the kind of transport (the implementations in
/// this file return `"http"` and `"stdio"`).
fn transport_type(&self) -> &str;
}
/// Transport that delivers each MCP request as an HTTP POST to a fixed URL.
pub struct HttpTransport {
/// Endpoint every request is posted to.
url: String,
/// HTTP client used for all requests made by this transport.
http: reqwest::Client,
}
impl HttpTransport {
    /// Builds a transport that posts requests to `url`.
    ///
    /// `timeout_secs` is installed as the HTTP client's timeout. If the
    /// configured client cannot be built, a default `reqwest::Client` is used
    /// in its place.
    pub fn new(url: &str, timeout_secs: u64) -> Self {
        let request_timeout = std::time::Duration::from_secs(timeout_secs);
        let built = reqwest::Client::builder().timeout(request_timeout).build();
        Self {
            url: url.to_owned(),
            http: built.unwrap_or_default(),
        }
    }

    /// Returns the endpoint URL this transport posts to.
    pub fn url(&self) -> &str {
        self.url.as_str()
    }
}
#[async_trait]
impl McpTransport for HttpTransport {
    /// Posts `request` as JSON to the configured URL and decodes the JSON reply.
    ///
    /// # Errors
    /// Returns a message when the request cannot be sent, when the server
    /// answers with a non-success status (the response body is included in the
    /// message), or when the reply does not decode as an `McpResponse`.
    async fn send(&self, request: &McpRequest) -> Result<McpResponse, String> {
        let outcome = self.http.post(&self.url).json(request).send().await;
        let response = match outcome {
            Ok(response) => response,
            Err(err) => return Err(format!("HTTP request failed: {}", err)),
        };

        let code = response.status();
        if code.is_success() {
            return response
                .json::<McpResponse>()
                .await
                .map_err(|err| format!("Failed to parse MCP response: {}", err));
        }

        // A body that cannot be read is reported as an empty string.
        let text = response.text().await.unwrap_or_default();
        Err(format!("HTTP {} from MCP server: {}", code, text))
    }

    /// Does nothing and always succeeds.
    async fn shutdown(&self) -> Result<(), String> {
        Ok(())
    }

    /// Always `"http"`.
    fn transport_type(&self) -> &str {
        "http"
    }
}
/// Transport that talks to an MCP server running as a child process,
/// exchanging `Content-Length`-framed JSON messages over the child's
/// stdin and stdout.
pub struct StdioTransport {
    /// Both pipe ends, behind one lock so that a request/response exchange
    /// cannot interleave with another one.
    io: Arc<Mutex<StdioIo>>,
    /// Handle used to wait for, or kill, the server process.
    child: Arc<Mutex<Child>>,
    /// Timeout in seconds, applied separately to each write, flush and read
    /// step of `send`.
    timeout_secs: u64,
}

/// The parent's ends of the child's pipes.
struct StdioIo {
    /// Write end of the child's stdin; `None` once `shutdown` has closed it.
    stdin: Option<ChildStdin>,
    /// Buffered read end of the child's stdout.
    stdout: BufReader<ChildStdout>,
}

impl StdioTransport {
    /// Spawns `command` with `args`, adding the entries of `env` to the
    /// child's environment, and captures its stdin and stdout.
    ///
    /// The child's stderr is discarded. `timeout_secs` is stored and used by
    /// `send` for each I/O step.
    ///
    /// # Errors
    /// Returns a message if the process cannot be spawned or if either pipe
    /// handle is unavailable after spawning.
    pub async fn spawn(
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
        timeout_secs: u64,
    ) -> Result<Self, String> {
        let mut cmd = tokio::process::Command::new(command);
        cmd.args(args)
            .envs(env)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null());

        let mut child = cmd
            .spawn()
            .map_err(|e| format!("Failed to spawn MCP server '{}': {}", command, e))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| "Failed to capture child stdin".to_string())?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| "Failed to capture child stdout".to_string())?;

        Ok(Self {
            io: Arc::new(Mutex::new(StdioIo {
                stdin: Some(stdin),
                stdout: BufReader::new(stdout),
            })),
            child: Arc::new(Mutex::new(child)),
            timeout_secs,
        })
    }
}

#[async_trait]
impl McpTransport for StdioTransport {
    /// Writes `request` to the child's stdin as one `Content-Length` frame and
    /// reads one frame back from its stdout.
    ///
    /// The I/O lock is held for the whole exchange. Each of the four steps
    /// (write, flush, read headers, read body) gets its own `timeout_secs`
    /// budget.
    ///
    /// # Errors
    /// Returns a message if the request cannot be serialized, if the transport
    /// has already been shut down, if any step fails or times out, or if the
    /// response body does not decode as an `McpResponse`.
    ///
    /// NOTE(review): when a step fails after the request has been written, a
    /// late response is left in the pipe and would be picked up by the next
    /// call; nothing here drains or resynchronizes the stream.
    async fn send(&self, request: &McpRequest) -> Result<McpResponse, String> {
        let body = serde_json::to_string(request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        // `String::len` is the UTF-8 byte count, which is what the header carries.
        let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let timeout = std::time::Duration::from_secs(self.timeout_secs);

        let mut guard = self.io.lock().await;
        // Reborrow as a plain `&mut` so stdin and stdout can be borrowed independently.
        let io = &mut *guard;
        let stdin = io
            .stdin
            .as_mut()
            .ok_or_else(|| "MCP server stdin is closed; transport has been shut down".to_string())?;

        tokio::time::timeout(timeout, stdin.write_all(frame.as_bytes()))
            .await
            .map_err(|_| "Timeout writing to MCP server stdin".to_string())?
            .map_err(|e| format!("Failed to write to MCP server stdin: {}", e))?;
        tokio::time::timeout(timeout, stdin.flush())
            .await
            .map_err(|_| "Timeout flushing MCP server stdin".to_string())?
            .map_err(|e| format!("Failed to flush MCP server stdin: {}", e))?;

        let content_length = tokio::time::timeout(timeout, read_content_length(&mut io.stdout))
            .await
            .map_err(|_| "Timeout reading Content-Length from MCP server".to_string())?
            .map_err(|e| format!("Failed to read Content-Length: {}", e))?;

        // NOTE(review): the buffer is sized from the server-supplied header with no upper bound.
        let mut buf = vec![0u8; content_length];
        tokio::time::timeout(timeout, io.stdout.read_exact(&mut buf))
            .await
            .map_err(|_| "Timeout reading response body from MCP server".to_string())?
            .map_err(|e| format!("Failed to read response body: {}", e))?;

        serde_json::from_slice::<McpResponse>(&buf)
            .map_err(|e| format!("Failed to parse MCP stdio response: {}", e))
    }

    /// Closes the child's stdin, waits up to three seconds for it to exit, and
    /// kills it if it is still running after that.
    ///
    /// # Errors
    /// Returns a message if waiting for the child fails or if the kill fails.
    async fn shutdown(&self) -> Result<(), String> {
        // `spawn` moved the stdin handle out of `Child`, so `Child::wait` cannot
        // close it for us. Dropping it here delivers EOF, which lets a server
        // that exits on end-of-input finish within the grace period instead of
        // always being killed. `try_lock` keeps shutdown from queueing behind an
        // in-flight `send`; in that case the wait-then-kill below still applies.
        if let Ok(mut io) = self.io.try_lock() {
            drop(io.stdin.take());
        }

        let mut child = self.child.lock().await;
        match tokio::time::timeout(std::time::Duration::from_secs(3), child.wait()).await {
            Ok(Ok(_)) => Ok(()),
            Ok(Err(e)) => Err(format!("Failed to wait for MCP server: {}", e)),
            Err(_) => child
                .kill()
                .await
                .map_err(|e| format!("Failed to kill MCP server: {}", e)),
        }
    }

    /// Always `"stdio"`.
    fn transport_type(&self) -> &str {
        "stdio"
    }
}
/// Reads header lines from `reader` up to and including the blank line that
/// ends the header section, and returns the value of the `Content-Length`
/// header.
///
/// The header name is matched ASCII case-insensitively, so `content-length`
/// is accepted as well as `Content-Length`. Other headers are ignored. If the
/// header occurs more than once, the last value wins.
///
/// # Errors
/// Returns a message if reading fails, if the stream ends before the blank
/// line, if the header value is not a valid `usize`, or if no
/// `Content-Length` header was present.
async fn read_content_length<R: tokio::io::AsyncBufRead + Unpin>(
    reader: &mut R,
) -> Result<usize, String> {
    let mut content_length: Option<usize> = None;
    // One buffer reused for every header line.
    let mut header_line = String::new();
    loop {
        header_line.clear();
        let n = reader
            .read_line(&mut header_line)
            .await
            .map_err(|e| format!("Failed to read header line: {}", e))?;
        if n == 0 {
            return Err("MCP server closed stdout while reading headers".to_string());
        }

        let trimmed = header_line.trim();
        if trimmed.is_empty() {
            // Blank line: end of the header section.
            break;
        }

        if let Some((name, value)) = trimmed.split_once(':') {
            if name.trim().eq_ignore_ascii_case("Content-Length") {
                content_length = Some(
                    value
                        .trim()
                        .parse::<usize>()
                        .map_err(|e| format!("Invalid Content-Length value: {}", e))?,
                );
            }
        }
    }
    content_length.ok_or_else(|| "MCP server response missing Content-Length header".to_string())
}
impl Drop for StdioTransport {
    /// Best-effort cleanup: asks for the child to be killed if its lock can be
    /// taken without waiting. Any error from the kill request is ignored.
    fn drop(&mut self) {
        if let Ok(mut process) = self.child.try_lock() {
            process.start_kill().ok();
        }
    }
}
// Unit tests for both transports and for the Content-Length header parser.
// The stdio tests spawn real `bash` and `cat` processes.
#[cfg(test)]
mod tests {
use super::*;
use crate::tools::mcp::protocol::McpRequest;
// The HTTP transport reports its type as "http".
#[test]
fn test_http_transport_type() {
let t = HttpTransport::new("http://localhost:8080", 30);
assert_eq!(t.transport_type(), "http");
}
// `url()` returns exactly the string passed to `new`.
#[test]
fn test_http_transport_url() {
let t = HttpTransport::new("http://localhost:8080", 30);
assert_eq!(t.url(), "http://localhost:8080");
}
// Sending to 127.0.0.1:1 must fail with the "HTTP request failed" message
// (presumably nothing is listening on that port in the test environment).
#[tokio::test]
async fn test_http_transport_send_no_server() {
let t = HttpTransport::new("http://127.0.0.1:1", 5);
let req = McpRequest::new(1, "tools/list", None);
let result = t.send(&req).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("HTTP request failed"));
}
// Shutting down an HTTP transport succeeds even though no request was made.
#[tokio::test]
async fn test_http_transport_shutdown_is_noop() {
let t = HttpTransport::new("http://localhost:8080", 30);
let result = t.shutdown().await;
assert!(result.is_ok());
}
// A single Content-Length header followed by a blank line parses to its value.
#[tokio::test]
async fn test_read_content_length_valid() {
let input = b"Content-Length: 42\r\n\r\n";
let mut reader = BufReader::new(&input[..]);
let len = read_content_length(&mut reader).await.unwrap();
assert_eq!(len, 42);
}
// A header section without Content-Length is rejected with the "missing" message.
#[tokio::test]
async fn test_read_content_length_missing() {
let input = b"X-Custom: foo\r\n\r\n";
let mut reader = BufReader::new(&input[..]);
let result = read_content_length(&mut reader).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing Content-Length"));
}
// Round trip through a bash echo server: the script strips the trailing CR
// from each header line, remembers the Content-Length value, and on the blank
// line reads that many bytes with `dd` and writes them back in a frame of its
// own. The test only checks that `send` returns Ok, i.e. that the echoed
// request is accepted as a response.
#[tokio::test]
async fn test_stdio_transport_content_length_framing() {
let script = r#"
while IFS= read -r line; do
line="${line%%$'\r'}"
if [[ "$line" == Content-Length:* ]]; then
cl="${line#Content-Length: }"
fi
if [[ -z "$line" ]]; then
body=$(dd bs=1 count="$cl" 2>/dev/null)
printf "Content-Length: %d\r\n\r\n%s" "${#body}" "$body"
fi
done
"#;
let transport = StdioTransport::spawn(
"bash",
&["-c".to_string(), script.to_string()],
&HashMap::new(),
10,
)
.await;
assert!(
transport.is_ok(),
"bash echo should spawn: {:?}",
transport.err()
);
let t = transport.unwrap();
assert_eq!(t.transport_type(), "stdio");
let req = McpRequest::new(1, "initialize", None);
let resp = t.send(&req).await;
assert!(
resp.is_ok(),
"Content-Length framing roundtrip should succeed: {:?}",
resp.err()
);
// Cleanup only; the shutdown result is not part of this test.
let _ = t.shutdown().await;
}
// Spawning a path that does not exist must return an error rather than panic.
#[tokio::test]
async fn test_stdio_transport_spawn_nonexistent_command() {
let result = StdioTransport::spawn(
"/nonexistent/binary/that/does/not/exist",
&[],
&HashMap::new(),
5,
)
.await;
assert!(result.is_err(), "Spawning nonexistent binary should fail");
}
// Spawns `cat` and checks that `shutdown` returns Ok.
#[tokio::test]
async fn test_stdio_transport_shutdown_kills_process() {
let transport = StdioTransport::spawn("cat", &[], &HashMap::new(), 10)
.await
.unwrap();
let result = transport.shutdown().await;
assert!(result.is_ok(), "Shutdown should succeed");
}
}