use anyhow::{Result, anyhow};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tokio::time::{Duration, timeout};
#[async_trait]
/// A bidirectional, string-based message channel to an MCP server.
///
/// Messages are passed as raw strings; framing (one line per message, one
/// HTTP request per message, ...) is up to the implementation.
pub trait Transport: Send + Sync {
    /// Sends `message` and waits for the server's reply, returned verbatim.
    async fn send(&self, message: &str) -> Result<String>;

    /// Sends `message` without handing any reply back to the caller.
    async fn notify(&self, message: &str) -> Result<()>;

    /// Waits for the next message the server sends on its own.
    ///
    /// Implementations that cannot receive unsolicited messages return an error.
    async fn receive(&self) -> Result<String>;

    /// Shuts the connection down and releases any held resources.
    async fn close(&self) -> Result<()>;
}
/// Transport that exchanges line-delimited messages with a child process over
/// its stdin/stdout.
///
/// Each handle sits behind its own async mutex and is reset to `None` when the
/// transport is closed.
pub struct StdioTransport {
    /// Server name used in log lines and error messages.
    server_name: String,
    /// The spawned child process; taken out when the transport is closed.
    process: Arc<Mutex<Option<Child>>>,
    /// The child's stdin; outgoing messages are written here, one per line.
    writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
    /// Buffered reader over the child's stdout; replies are read line by line.
    reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
}
impl StdioTransport {
pub async fn spawn(
name: impl Into<String>,
command: &str,
args: &[String],
env: Option<Vec<(String, String)>>,
) -> Result<Self> {
let server_name = name.into();
let (actual_command, actual_args) = if cfg!(target_os = "windows")
&& (command == "npx" || command == "npm" || command == "node")
{
let mut full_args = vec!["/c".to_string(), command.to_string()];
full_args.extend(args.iter().cloned());
("cmd.exe".to_string(), full_args)
} else {
(command.to_string(), args.to_vec())
};
let mut cmd = Command::new(&actual_command);
cmd.args(&actual_args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
if let Some(env_vars) = env {
for (key, value) in env_vars {
cmd.env(key, value);
}
}
tracing::debug!(
"Spawning MCP server '{}' with command: {} {:?}",
server_name,
actual_command,
actual_args
);
let mut child = cmd.spawn().map_err(|e| {
anyhow!(
"Failed to spawn MCP server '{}': {} (command: {} {:?})",
server_name,
e,
actual_command,
actual_args
)
})?;
tracing::debug!("MCP server '{}' process spawned successfully", server_name);
let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(
child
.stdin
.take()
.ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?,
);
let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(
child
.stdout
.take()
.ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?,
);
tracing::info!(
"MCP server '{}' started: {} {:?}",
server_name,
actual_command,
actual_args
);
Ok(Self {
process: Arc::new(Mutex::new(Some(child))),
writer: Arc::new(Mutex::new(Some(stdin))),
reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
server_name,
})
}
async fn read_line(&self) -> Result<String> {
let mut reader_lock = self.reader.lock().await;
let reader = reader_lock
.as_mut()
.ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
let mut line = String::new();
tracing::debug!("Reading from '{}' (timeout: 30s)...", self.server_name);
let read_result =
tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut line)).await;
tracing::debug!(
"Read result from '{}': {:?}",
self.server_name,
read_result.is_ok()
);
match read_result {
Ok(Ok(_)) => {
if line.is_empty() {
return Err(anyhow!("EOF reached for server '{}'", self.server_name));
}
Ok(line.trim_end().to_string())
}
Ok(Err(e)) => Err(anyhow!(
"Read error for server '{}': {}",
self.server_name,
e
)),
Err(_) => Err(anyhow!(
"Read timeout for server '{}' after 30s",
self.server_name
)),
}
}
}
#[async_trait]
impl Transport for StdioTransport {
    /// Writes `message` as one line to the server's stdin and returns the next
    /// line it prints.
    ///
    /// Messages are framed by newlines, so `message` must not contain one. The
    /// writer lock is held until the reply is read, which serializes concurrent
    /// `send` calls. A concurrent `receive` takes only the reader lock and can
    /// therefore race `send` for the reply. NOTE(review): the next line is
    /// returned whatever it is; a server-initiated message arriving before the
    /// reply would be treated as the response.
    ///
    /// # Errors
    ///
    /// Fails if the transport is closed, on write/flush errors, or on any
    /// read failure (EOF, I/O error, timeout).
    async fn send(&self, message: &str) -> Result<String> {
        tracing::debug!(
            "MCP send to '{}': {}",
            self.server_name,
            message.chars().take(200).collect::<String>()
        );
        let mut writer_lock = self.writer.lock().await;
        let writer = writer_lock
            .as_mut()
            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
        writer
            .write_all(format!("{}\n", message).as_bytes())
            .await?;
        writer.flush().await?;
        tracing::debug!(
            "MCP sent, waiting for response from '{}'...",
            self.server_name
        );
        // `writer_lock` stays held across the read so that the next `send`
        // cannot write before this request's reply has been consumed.
        let response = self.read_line().await?;
        tracing::debug!(
            "MCP received from '{}': {}",
            self.server_name,
            response.chars().take(200).collect::<String>()
        );
        Ok(response)
    }

    /// Writes `message` as one line to the server's stdin without reading a reply.
    ///
    /// # Errors
    ///
    /// Fails if the transport is closed or on write/flush errors.
    async fn notify(&self, message: &str) -> Result<()> {
        let mut writer_lock = self.writer.lock().await;
        let writer = writer_lock
            .as_mut()
            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
        tracing::info!(
            "MCP >> '{}' : {}",
            self.server_name,
            message.chars().take(100).collect::<String>()
        );
        writer
            .write_all(format!("{}\n", message).as_bytes())
            .await?;
        writer.flush().await?;
        Ok(())
    }

    /// Reads the next line the server prints on stdout.
    ///
    /// # Errors
    ///
    /// Fails if the transport is closed, on EOF, on an I/O error, or on timeout.
    async fn receive(&self) -> Result<String> {
        let line = self.read_line().await?;
        tracing::info!(
            "MCP << '{}' : {}",
            self.server_name,
            line.chars().take(100).collect::<String>()
        );
        Ok(line)
    }

    /// Kills the child process (if still held) and drops both pipes.
    ///
    /// The pipes are cleared even if killing fails, so the transport never
    /// looks open after `close` has run. Calling `close` again is a no-op that
    /// returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the kill error, if any, after the pipes have been cleared.
    async fn close(&self) -> Result<()> {
        let mut process_lock = self.process.lock().await;
        // Kill first: once the child is gone its stdout typically hits EOF, so
        // an in-flight `send` (which holds the writer lock while reading)
        // returns promptly instead of making us wait out the read timeout.
        let kill_result = match process_lock.take() {
            Some(mut child) => {
                let result = child.kill().await.map_err(|e| {
                    anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e)
                });
                if result.is_ok() {
                    tracing::info!("MCP server '{}' stopped", self.server_name);
                }
                result
            }
            None => Ok(()),
        };
        *self.writer.lock().await = None;
        *self.reader.lock().await = None;
        kill_result
    }
}
/// HTTP-based transport: every message is POSTed to the server, and the
/// response body is the reply.
pub struct SseTransport {
    /// Server name used in error messages.
    server_name: String,
    /// Base URL of the server; requests go to `{base_url}/mcp`.
    base_url: String,
    /// Per-request timeout in milliseconds.
    timeout_ms: u64,
    /// Shared HTTP client used for every request.
    client: reqwest::Client,
}
impl SseTransport {
pub fn new(
name: impl Into<String>,
base_url: impl Into<String>,
timeout_ms: Option<u64>,
) -> Self {
Self {
base_url: base_url.into(),
client: reqwest::Client::new(),
server_name: name.into(),
timeout_ms: timeout_ms.unwrap_or(30000),
}
}
async fn send_http(&self, body: &str) -> Result<String> {
let url = format!("{}/mcp", self.base_url);
let response = timeout(
Duration::from_millis(self.timeout_ms),
self.client
.post(&url)
.header("Content-Type", "application/json")
.body(body.to_string())
.send(),
)
.await
.map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
.map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
let text = response.text().await?;
Ok(text)
}
}
#[async_trait]
impl Transport for SseTransport {
async fn send(&self, message: &str) -> Result<String> {
self.send_http(message).await
}
async fn notify(&self, message: &str) -> Result<()> {
self.send_http(message).await?;
Ok(())
}
async fn receive(&self) -> Result<String> {
Err(anyhow!(
"SSE receive not implemented - use send() for request/response"
))
}
async fn close(&self) -> Result<()> {
Ok(())
}
}
/// Describes how to reach an MCP server; turned into a live transport by
/// `create_transport`.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Launch a local process and exchange messages over its stdin/stdout.
    Stdio {
        /// Executable to run.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra `(key, value)` environment variables for the child, if any.
        env: Option<Vec<(String, String)>>,
    },
    /// Exchange messages with a remote server over HTTP.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Per-request timeout in milliseconds; `None` uses the transport default.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a stdio config for `command args...` with no extra environment.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio {
            command,
            args,
            env: None,
        }
    }

    /// Builds an HTTP config for `url` using the default timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse {
            url,
            timeout_ms: None,
        }
    }
}
/// Builds the transport described by `config` for the server named `server_name`.
///
/// For [`TransportConfig::Stdio`] this spawns the server process. For
/// [`TransportConfig::Sse`] it only constructs the HTTP client wrapper.
///
/// # Errors
///
/// Propagates any error from spawning a stdio server.
pub async fn create_transport(
    server_name: &str,
    config: &TransportConfig,
) -> Result<Box<dyn Transport>> {
    let transport: Box<dyn Transport> = match config {
        TransportConfig::Sse { url, timeout_ms } => {
            Box::new(SseTransport::new(server_name, url, *timeout_ms))
        }
        TransportConfig::Stdio { command, args, env } => {
            let stdio = StdioTransport::spawn(server_name, command, args, env.clone()).await?;
            Box::new(stdio)
        }
    };
    Ok(transport)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_config_stdio() {
        let cli_args: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        let TransportConfig::Stdio { command, args, .. } = TransportConfig::stdio("npx", cli_args)
        else {
            panic!("Expected Stdio variant");
        };
        assert_eq!(command, "npx");
        assert_eq!(args.len(), 2);
    }

    #[test]
    fn test_transport_config_sse() {
        let TransportConfig::Sse { url, .. } = TransportConfig::sse("http://localhost:3000") else {
            panic!("Expected Sse variant");
        };
        assert_eq!(url, "http://localhost:3000");
    }
}