use anyhow::{anyhow, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
/// A message channel to an MCP server.
///
/// Messages are passed as already-serialized strings; the implementations in this
/// file only move text and do not parse it.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Sends `message` and returns the server's reply as a string.
    async fn send(&self, message: &str) -> Result<String>;
    /// Sends `message` without returning a reply to the caller.
    async fn notify(&self, message: &str) -> Result<()>;
    /// Waits for the next message from the server. Not every transport supports this.
    async fn receive(&self) -> Result<String>;
    /// Shuts the transport down; later calls are expected to fail.
    async fn close(&self) -> Result<()>;
}
/// [`Transport`] that talks to an MCP server running as a child process, exchanging
/// newline-terminated messages over the child's stdin and stdout.
pub struct StdioTransport {
    /// Handle to the child process; taken (left as `None`) by `close`.
    process: Arc<Mutex<Option<Child>>>,
    /// The child's stdin. `None` once the transport has been closed.
    writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
    /// Buffered reader over the child's stdout. `None` once the transport has been closed.
    reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
    /// Server name used in log and error messages.
    server_name: String,
}
impl StdioTransport {
pub async fn spawn(
name: impl Into<String>,
command: &str,
args: &[String],
env: Option<Vec<(String, String)>>,
) -> Result<Self> {
let server_name = name.into();
let (actual_command, actual_args) = if cfg!(target_os = "windows")
&& (command == "npx" || command == "npm" || command == "node") {
let mut full_args = vec!["/c".to_string(), command.to_string()];
full_args.extend(args.iter().cloned());
("cmd.exe".to_string(), full_args)
} else {
(command.to_string(), args.to_vec())
};
let mut cmd = Command::new(&actual_command);
cmd.args(&actual_args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
if let Some(env_vars) = env {
for (key, value) in env_vars {
cmd.env(key, value);
}
}
let mut child = cmd.spawn()
.map_err(|e| anyhow!("Failed to spawn MCP server '{}': {} (command: {} {:?})",
server_name, e, actual_command, actual_args))?;
let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(child.stdin.take()
.ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?);
let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(child.stdout.take()
.ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?);
tracing::info!("MCP server '{}' started: {} {:?}", server_name, actual_command, actual_args);
Ok(Self {
process: Arc::new(Mutex::new(Some(child))),
writer: Arc::new(Mutex::new(Some(stdin))),
reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
server_name,
})
}
async fn read_line(&self) -> Result<String> {
let mut reader_lock = self.reader.lock().await;
let reader = reader_lock.as_mut()
.ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
let mut line = String::new();
reader.read_line(&mut line).await?;
if line.is_empty() {
return Err(anyhow!("EOF reached for server '{}'", self.server_name));
}
let line = line.trim_end().to_string();
Ok(line)
}
}
#[async_trait]
impl Transport for StdioTransport {
async fn send(&self, message: &str) -> Result<String> {
let mut writer_lock = self.writer.lock().await;
let writer = writer_lock.as_mut()
.ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
writer.write_all(format!("{}\n", message).as_bytes()).await?;
writer.flush().await?;
let response = self.read_line().await?;
Ok(response)
}
async fn notify(&self, message: &str) -> Result<()> {
let mut writer_lock = self.writer.lock().await;
let writer = writer_lock.as_mut()
.ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
writer.write_all(format!("{}\n", message).as_bytes()).await?;
writer.flush().await?;
Ok(())
}
async fn receive(&self) -> Result<String> {
self.read_line().await
}
async fn close(&self) -> Result<()> {
let mut process_lock = self.process.lock().await;
if let Some(mut child) = process_lock.take() {
child.kill().await.map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
tracing::info!("MCP server '{}' stopped", self.server_name);
}
*self.writer.lock().await = None;
*self.reader.lock().await = None;
Ok(())
}
}
/// [`Transport`] that POSTs each message over HTTP to the server's `/mcp` endpoint and
/// returns the response body as the reply.
///
/// NOTE(review): despite the name, nothing here parses a Server-Sent Events stream; the
/// reply is the plain response body.
pub struct SseTransport {
    /// Server base URL; the `/mcp` path is appended for each request.
    base_url: String,
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Server name used in error messages.
    server_name: String,
    /// Request timeout in milliseconds (`new` defaults it to 30000).
    timeout_ms: u64,
}
impl SseTransport {
    /// Creates a transport for the MCP server at `base_url`.
    ///
    /// `timeout_ms` bounds each request end to end: sending it and reading the whole
    /// response body. `None` selects the 30000 ms default.
    pub fn new(
        name: impl Into<String>,
        base_url: impl Into<String>,
        timeout_ms: Option<u64>,
    ) -> Self {
        Self {
            base_url: base_url.into(),
            client: reqwest::Client::new(),
            server_name: name.into(),
            timeout_ms: timeout_ms.unwrap_or(30000),
        }
    }

    /// POSTs `body` as JSON to `{base_url}/mcp` and returns the response body.
    ///
    /// # Errors
    /// Fails after `timeout_ms` (covering the whole exchange), on connection/transport
    /// errors, when the body cannot be read, and on any non-2xx status (the error includes
    /// the status and body).
    async fn send_http(&self, body: &str) -> Result<String> {
        // Trim trailing slashes so a base URL such as "http://host/" does not yield "//mcp".
        let url = format!("{}/mcp", self.base_url.trim_end_matches('/'));
        timeout(Duration::from_millis(self.timeout_ms), self.post_once(&url, body))
            .await
            .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
    }

    /// Performs a single POST and reads the full response body; applies no timeout itself.
    async fn post_once(&self, url: &str, body: &str) -> Result<String> {
        let response = self
            .client
            .post(url)
            .header("Content-Type", "application/json")
            .body(body.to_string())
            .send()
            .await
            .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
        let status = response.status();
        let text = response.text().await.map_err(|e| {
            anyhow!("Failed to read response from MCP server '{}': {}", self.server_name, e)
        })?;
        if !status.is_success() {
            return Err(anyhow!(
                "MCP server '{}' returned HTTP {}: {}",
                self.server_name,
                status,
                text
            ));
        }
        Ok(text)
    }
}
#[async_trait]
impl Transport for SseTransport {
    /// POSTs `message` and returns the HTTP response body.
    async fn send(&self, message: &str) -> Result<String> {
        self.send_http(message).await
    }

    /// POSTs `message` and discards the response body; request errors are still propagated.
    async fn notify(&self, message: &str) -> Result<()> {
        self.send_http(message).await?;
        Ok(())
    }

    /// Unsupported: this transport only performs request/response exchanges, so there is
    /// no channel to read unsolicited server messages from.
    async fn receive(&self) -> Result<String> {
        Err(anyhow!(
            "SSE receive not implemented for server '{}' - use send() for request/response",
            self.server_name
        ))
    }

    /// No-op: there is no process or session state to tear down.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}
/// Describes how to reach an MCP server.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Run `command` with `args` as a child process and talk over its stdin/stdout.
    /// `env` lists extra environment variables set on the child.
    Stdio {
        command: String,
        args: Vec<String>,
        env: Option<Vec<(String, String)>>,
    },
    /// Talk over HTTP to the server at base URL `url`. `timeout_ms` of `None` means the
    /// transport's default timeout is used.
    Sse {
        url: String,
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a [`TransportConfig::Stdio`] with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio { command, args, env: None }
    }

    /// Builds a [`TransportConfig::Sse`] that uses the default request timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse { url, timeout_ms: None }
    }
}
/// Builds the [`Transport`] described by `config` for the server called `server_name`.
///
/// The stdio variant spawns the server process immediately. The SSE variant only
/// constructs an HTTP client and performs no I/O here.
///
/// # Errors
/// Propagates any error from spawning a stdio server.
pub async fn create_transport(
    server_name: &str,
    config: &TransportConfig,
) -> Result<Box<dyn Transport>> {
    let transport: Box<dyn Transport> = match config {
        TransportConfig::Sse { url, timeout_ms } => {
            Box::new(SseTransport::new(server_name, url.as_str(), *timeout_ms))
        }
        TransportConfig::Stdio { command, args, env } => {
            let spawned = StdioTransport::spawn(server_name, command, args, env.clone()).await?;
            Box::new(spawned)
        }
    };
    Ok(transport)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stdio constructor keeps the command and every argument it was given.
    #[test]
    fn test_transport_config_stdio() {
        let playwright_args: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        let config = TransportConfig::stdio("npx", playwright_args);
        if let TransportConfig::Stdio { command, args, .. } = config {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio variant");
        }
    }

    /// The SSE constructor stores the URL unchanged.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");
        if let TransportConfig::Sse { url, .. } = config {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected Sse variant");
        }
    }
}