use async_trait::async_trait;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use super::protocol::McpMessage;
use anyhow::{anyhow, Result};
#[derive(Debug, Clone)]
pub enum McpTransportType {
Stdio {
command: String,
args: Vec<String>,
},
Tcp {
addr: String,
},
#[cfg(unix)]
Unix {
path: String,
},
}
#[async_trait]
pub trait McpTransport: Send + Sync {
async fn send(&self, message: &McpMessage) -> Result<()>;
async fn receive(&self) -> Result<Option<McpMessage>>;
async fn close(&self) -> Result<()>;
}
pub struct StdioTransport {
process: Arc<Mutex<Option<tokio::process::Child>>>,
stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
stdout: Arc<Mutex<Option<tokio::io::BufReader<tokio::process::ChildStdout>>>>,
}
impl StdioTransport {
pub fn new(_command: &str, _args: &[String]) -> Result<Self> {
Ok(Self {
process: Arc::new(Mutex::new(None)),
stdin: Arc::new(Mutex::new(None)),
stdout: Arc::new(Mutex::new(None)),
})
}
pub async fn start(&self, command: &str, args: &[String]) -> Result<()> {
use std::process::Stdio;
let mut cmd = tokio::process::Command::new(command);
cmd.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null());
let mut child = cmd.spawn()?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow!("Failed to open stdin"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow!("Failed to open stdout"))?;
*self.stdin.lock().await = Some(stdin);
*self.stdout.lock().await = Some(BufReader::new(stdout));
*self.process.lock().await = Some(child);
Ok(())
}
}
#[async_trait]
impl McpTransport for StdioTransport {
async fn send(&self, message: &McpMessage) -> Result<()> {
let mut stdin_guard = self.stdin.lock().await;
let stdin = stdin_guard
.as_mut()
.ok_or_else(|| anyhow!("Transport not started"))?;
let json = serde_json::to_string(message)?;
let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
stdin.write_all(frame.as_bytes()).await?;
stdin.flush().await?;
Ok(())
}
async fn receive(&self) -> Result<Option<McpMessage>> {
let mut stdout_guard = self.stdout.lock().await;
let stdout = stdout_guard
.as_mut()
.ok_or_else(|| anyhow!("Transport not started"))?;
let mut header_buf = vec![0u8; 1024];
let mut total_read = 0;
loop {
let n = stdout.read(&mut header_buf[total_read..]).await?;
if n == 0 {
return Ok(None); }
total_read += n;
if let Some(pos) = find_header_end(&header_buf[..total_read]) {
let header = String::from_utf8_lossy(&header_buf[..pos]);
let content_length = parse_content_length(&header)?;
let header_size = pos + 4; let body_size = content_length;
let mut body_buf = vec![0u8; body_size];
let already_read = total_read - header_size;
if already_read > 0 {
body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
}
if already_read < body_size {
stdout.read_exact(&mut body_buf[already_read..]).await?;
}
let message: McpMessage = serde_json::from_slice(&body_buf)?;
return Ok(Some(message));
}
if total_read >= header_buf.len() {
return Err(anyhow!("Header too large"));
}
}
}
async fn close(&self) -> Result<()> {
let mut process_guard = self.process.lock().await;
if let Some(mut process) = process_guard.take() {
process.kill().await?;
}
Ok(())
}
}
pub struct TcpTransport {
stream: Arc<Mutex<Option<TcpStream>>>,
listener: Arc<Mutex<Option<TcpListener>>>,
}
impl TcpTransport {
pub async fn connect(addr: &str) -> Result<Self> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
stream: Arc::new(Mutex::new(Some(stream))),
listener: Arc::new(Mutex::new(None)),
})
}
pub async fn bind(addr: &str) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
Ok(Self {
stream: Arc::new(Mutex::new(None)),
listener: Arc::new(Mutex::new(Some(listener))),
})
}
pub async fn accept(&self) -> Result<()> {
let mut listener_guard = self.listener.lock().await;
let listener = listener_guard
.as_mut()
.ok_or_else(|| anyhow!("Not in server mode"))?;
let (stream, _) = listener.accept().await?;
*self.stream.lock().await = Some(stream);
Ok(())
}
}
#[async_trait]
impl McpTransport for TcpTransport {
async fn send(&self, message: &McpMessage) -> Result<()> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| anyhow!("Not connected"))?;
let json = serde_json::to_string(message)?;
let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
stream.write_all(frame.as_bytes()).await?;
stream.flush().await?;
Ok(())
}
async fn receive(&self) -> Result<Option<McpMessage>> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| anyhow!("Not connected"))?;
let mut header_buf = vec![0u8; 1024];
let mut total_read = 0;
loop {
let n = stream.read(&mut header_buf[total_read..]).await?;
if n == 0 {
return Ok(None); }
total_read += n;
if let Some(pos) = find_header_end(&header_buf[..total_read]) {
let header = String::from_utf8_lossy(&header_buf[..pos]);
let content_length = parse_content_length(&header)?;
let header_size = pos + 4; let body_size = content_length;
let mut body_buf = vec![0u8; body_size];
let already_read = total_read - header_size;
if already_read > 0 {
body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
}
if already_read < body_size {
stream.read_exact(&mut body_buf[already_read..]).await?;
}
let message: McpMessage = serde_json::from_slice(&body_buf)?;
return Ok(Some(message));
}
if total_read >= header_buf.len() {
return Err(anyhow!("Header too large"));
}
}
}
async fn close(&self) -> Result<()> {
let mut stream_guard = self.stream.lock().await;
stream_guard.take();
Ok(())
}
}
fn find_header_end(buf: &[u8]) -> Option<usize> {
for i in 0..buf.len().saturating_sub(3) {
if &buf[i..i + 4] == b"\r\n\r\n" {
return Some(i);
}
}
None
}
fn parse_content_length(header: &str) -> Result<usize> {
for line in header.lines() {
if let Some(value) = line.strip_prefix("Content-Length:") {
return Ok(value.trim().parse()?);
}
}
Err(anyhow!("Content-Length header not found"))
}
#[cfg(unix)]
pub struct UnixSocketTransport {
stream: Arc<Mutex<Option<UnixStream>>>,
}
#[cfg(unix)]
impl UnixSocketTransport {
pub async fn connect(path: &str) -> Result<Self> {
let stream = UnixStream::connect(path).await?;
Ok(Self {
stream: Arc::new(Mutex::new(Some(stream))),
})
}
}
#[cfg(unix)]
#[async_trait]
impl McpTransport for UnixSocketTransport {
async fn send(&self, message: &McpMessage) -> Result<()> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| anyhow!("Not connected"))?;
let json = serde_json::to_string(message)?;
let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
stream.write_all(frame.as_bytes()).await?;
stream.flush().await?;
Ok(())
}
async fn receive(&self) -> Result<Option<McpMessage>> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| anyhow!("Not connected"))?;
let mut header_buf = vec![0u8; 1024];
let mut total_read = 0;
loop {
let n = stream.read(&mut header_buf[total_read..]).await?;
if n == 0 {
return Ok(None); }
total_read += n;
if let Some(pos) = find_header_end(&header_buf[..total_read]) {
let header = String::from_utf8_lossy(&header_buf[..pos]);
let content_length = parse_content_length(&header)?;
let header_size = pos + 4; let body_size = content_length;
let mut body_buf = vec![0u8; body_size];
let already_read = total_read - header_size;
if already_read > 0 {
body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
}
if already_read < body_size {
stream.read_exact(&mut body_buf[already_read..]).await?;
}
let message: McpMessage = serde_json::from_slice(&body_buf)?;
return Ok(Some(message));
}
if total_read >= header_buf.len() {
return Err(anyhow!("Header too large"));
}
}
}
async fn close(&self) -> Result<()> {
let mut stream_guard = self.stream.lock().await;
stream_guard.take();
Ok(())
}
}
pub struct MemoryTransport {
messages: Arc<Mutex<Vec<McpMessage>>>,
position: Arc<Mutex<usize>>,
}
impl Default for MemoryTransport {
fn default() -> Self {
Self::new()
}
}
impl MemoryTransport {
pub fn new() -> Self {
Self {
messages: Arc::new(Mutex::new(Vec::new())),
position: Arc::new(Mutex::new(0)),
}
}
pub async fn push(&self, message: McpMessage) {
self.messages.lock().await.push(message);
}
}
#[async_trait]
impl McpTransport for MemoryTransport {
async fn send(&self, message: &McpMessage) -> Result<()> {
self.messages.lock().await.push(message.clone());
Ok(())
}
async fn receive(&self) -> Result<Option<McpMessage>> {
let messages = self.messages.lock().await;
let mut pos = self.position.lock().await;
if *pos < messages.len() {
let message = messages[*pos].clone();
*pos += 1;
Ok(Some(message))
} else {
Ok(None)
}
}
async fn close(&self) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::super::protocol::{McpRequest, RequestId};
use super::*;
#[test]
fn test_parse_content_length() {
let header = "Content-Length: 42\r\n";
let len = parse_content_length(header).unwrap();
assert_eq!(len, 42);
}
#[test]
fn test_find_header_end() {
let buf = b"Content-Length: 10\r\n\r\n";
let pos = find_header_end(buf).unwrap();
assert_eq!(pos, 18); }
#[tokio::test]
async fn test_memory_transport() {
let transport = MemoryTransport::new();
let msg = McpMessage::Request(McpRequest {
id: RequestId::Number(1),
method: "test".to_string(),
params: None,
});
transport.send(&msg).await.unwrap();
let received = transport.receive().await.unwrap();
assert!(received.is_some());
}
}