use anyhow::{Result, anyhow};
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::mpsc;
#[cfg(feature = "mcp")]
use futures_util::{SinkExt, StreamExt};
#[cfg(feature = "mcp")]
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use super::jsonrpc::JsonRpcMessage;
/// Asynchronous, bidirectional channel for exchanging [`JsonRpcMessage`]s.
///
/// Implementors must be `Send + Sync`. Every operation is fallible and
/// reports failures through `anyhow::Result`.
#[async_trait]
pub trait Transport: Send + Sync {
/// Sends a single message to the peer.
async fn send(&mut self, message: JsonRpcMessage) -> Result<()>;
/// Waits for the next incoming message.
///
/// `Ok(None)` means no message was produced; the implementations in this
/// file return it in cases such as end of input, shutdown, or a closed
/// connection.
async fn receive(&mut self) -> Result<Option<JsonRpcMessage>>;
/// Closes the transport and releases whatever connection it holds.
async fn close(&mut self) -> Result<()>;
}
/// [`Transport`] that exchanges line-delimited JSON over the process's
/// standard input and standard output.
pub struct StdioTransport {
/// Buffered reader over standard input; incoming messages are read from it
/// one line at a time.
stdin_reader: BufReader<tokio::io::Stdin>,
/// Standard output handle that outgoing messages are written to.
stdout: tokio::io::Stdout,
/// Shutdown signal. `receive` races this against the stdin read and returns
/// `Ok(None)` as soon as `recv()` on it completes.
shutdown_rx: mpsc::Receiver<()>,
}
impl StdioTransport {
    /// Builds a transport bound to this process's stdin and stdout.
    ///
    /// `shutdown_rx` is the receiving end of the shutdown channel that
    /// `receive` polls alongside the stdin read.
    pub fn new(shutdown_rx: mpsc::Receiver<()>) -> Self {
        let stdin_reader = BufReader::new(tokio::io::stdin());
        let stdout = tokio::io::stdout();

        Self {
            stdin_reader,
            stdout,
            shutdown_rx,
        }
    }
}
#[async_trait]
impl Transport for StdioTransport {
    /// Serializes `message` to JSON and writes it to stdout as a single
    /// newline-terminated line, then flushes.
    ///
    /// # Errors
    /// Fails if serialization, the write, or the flush fails.
    async fn send(&mut self, message: JsonRpcMessage) -> Result<()> {
        let mut payload = serde_json::to_string(&message)?;
        // Append the delimiter up front so the whole frame goes out in one
        // `write_all` call instead of two.
        payload.push('\n');
        self.stdout.write_all(payload.as_bytes()).await?;
        self.stdout.flush().await?;
        Ok(())
    }

    /// Reads the next non-blank line from stdin and parses it as a
    /// [`JsonRpcMessage`].
    ///
    /// Returns `Ok(None)` when stdin reaches end of input or when the
    /// shutdown channel yields (a message arrives or the sender is dropped).
    /// Blank lines are skipped rather than reported as `Ok(None)`, so a stray
    /// empty line is not mistaken for end of stream.
    ///
    /// # Errors
    /// Fails if reading stdin fails or the line is not valid JSON for a
    /// `JsonRpcMessage`.
    async fn receive(&mut self) -> Result<Option<JsonRpcMessage>> {
        let mut line = String::new();
        loop {
            // `read_line` appends, so start every attempt from an empty buffer.
            line.clear();

            tokio::select! {
                _ = self.shutdown_rx.recv() => return Ok(None),
                read = self.stdin_reader.read_line(&mut line) => {
                    // Zero bytes read means stdin hit end of input.
                    if read? == 0 {
                        return Ok(None);
                    }
                }
            }

            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }

            let message: JsonRpcMessage = serde_json::from_str(trimmed)?;
            return Ok(Some(message));
        }
    }

    /// Nothing to release for stdio; always succeeds.
    async fn close(&mut self) -> Result<()> {
        Ok(())
    }
}
/// Concrete WebSocket stream type shared by both halves of [`WebSocketPair`].
#[cfg(feature = "mcp")]
type WsStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// Write half (sink of [`Message`]s) and read half of a split [`WsStream`],
/// in that order.
#[cfg(feature = "mcp")]
type WebSocketPair = (
    futures_util::stream::SplitSink<WsStream, Message>,
    futures_util::stream::SplitStream<WsStream>,
);
/// [`Transport`] that talks to a server over HTTP and/or WebSocket.
///
/// Outgoing messages use the WebSocket when one is open and fall back to HTTP
/// POST otherwise; incoming messages are read from the WebSocket, which
/// `receive` opens on demand.
#[cfg(feature = "mcp")]
pub struct HttpTransport {
/// Server base URL; endpoint paths such as `/mcp` are appended to it.
base_url: String,
/// HTTP client used for the POST fallback.
client: reqwest::Client,
/// Split WebSocket halves (sink, stream) while a connection is open.
websocket: Option<WebSocketPair>,
/// Whether a WebSocket connection has been established and not yet torn down.
is_connected: bool,
}
#[cfg(feature = "mcp")]
impl HttpTransport {
    /// Creates a disconnected transport targeting `base_url`.
    ///
    /// No network activity happens here; the WebSocket is opened lazily.
    pub fn new(base_url: String) -> Self {
        Self {
            base_url,
            client: reqwest::Client::new(),
            websocket: None,
            is_connected: false,
        }
    }

    /// Derives the WebSocket endpoint (`…/mcp`) from `base_url`.
    ///
    /// Only the leading scheme is rewritten (`http://` → `ws://`,
    /// `https://` → `wss://`). A global string replace would also rewrite any
    /// `http://` or `https://` that appears later in the URL, e.g. inside a
    /// path or query string. Other schemes are passed through untouched.
    fn websocket_url(&self) -> String {
        let base = self.base_url.as_str();
        let rewritten = if let Some(rest) = base.strip_prefix("https://") {
            format!("wss://{}", rest)
        } else if let Some(rest) = base.strip_prefix("http://") {
            format!("ws://{}", rest)
        } else {
            base.to_owned()
        };
        rewritten + "/mcp"
    }

    /// Opens the WebSocket connection if it is not already open.
    ///
    /// On success the split sink/stream pair is stored and `is_connected` is
    /// set.
    ///
    /// # Errors
    /// Fails if the WebSocket handshake fails.
    async fn connect(&mut self) -> Result<()> {
        if self.is_connected {
            return Ok(());
        }

        let ws_url = self.websocket_url();
        let (ws_stream, _) = connect_async(&ws_url)
            .await
            .map_err(|e| anyhow!("Failed to connect to WebSocket: {}", e))?;

        self.websocket = Some(ws_stream.split());
        self.is_connected = true;
        Ok(())
    }

    /// POSTs `message` as JSON to `{base_url}/mcp` and parses the response
    /// body as a [`JsonRpcMessage`].
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the server answers with a
    /// non-success status, or the body is not a valid `JsonRpcMessage`.
    async fn send_http(&self, message: &JsonRpcMessage) -> Result<JsonRpcMessage> {
        let response = self
            .client
            .post(format!("{}/mcp", self.base_url))
            .header("Content-Type", "application/json")
            .json(message)
            .send()
            .await
            .map_err(|e| anyhow!("HTTP request failed: {}", e))?;

        if !response.status().is_success() {
            return Err(anyhow!(
                "HTTP request failed with status: {}",
                response.status()
            ));
        }

        let response_json: JsonRpcMessage = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse JSON response: {}", e))?;
        Ok(response_json)
    }
}
#[cfg(feature = "mcp")]
#[async_trait]
impl Transport for HttpTransport {
    /// Sends `message` over the WebSocket when one is open, otherwise via
    /// HTTP POST.
    ///
    /// In the HTTP fallback, requests go to `{base_url}/mcp` and
    /// notifications to `{base_url}/mcp/notify`; response messages cannot be
    /// sent that way.
    ///
    /// # Errors
    /// Fails on serialization errors, WebSocket send errors, HTTP transport
    /// errors, non-success HTTP statuses, or when asked to send a response
    /// message over HTTP.
    async fn send(&mut self, message: JsonRpcMessage) -> Result<()> {
        if let Some((sink, _)) = &mut self.websocket {
            let json = serde_json::to_string(&message)?;
            sink.send(Message::Text(json))
                .await
                .map_err(|e| anyhow!("WebSocket send failed: {}", e))?;
            Ok(())
        } else {
            match &message {
                JsonRpcMessage::Request { .. } => {
                    // NOTE(review): the server's reply is parsed and then
                    // dropped here; callers only learn that the POST succeeded.
                    let _response = self.send_http(&message).await?;
                    Ok(())
                }
                JsonRpcMessage::Notification { .. } => {
                    let response = self
                        .client
                        .post(format!("{}/mcp/notify", self.base_url))
                        .header("Content-Type", "application/json")
                        .json(&message)
                        .send()
                        .await
                        .map_err(|e| anyhow!("HTTP notification failed: {}", e))?;
                    // Surface server-side rejections the same way the request
                    // path does instead of silently ignoring them.
                    if !response.status().is_success() {
                        return Err(anyhow!(
                            "HTTP notification failed with status: {}",
                            response.status()
                        ));
                    }
                    Ok(())
                }
                JsonRpcMessage::Response { .. } => {
                    Err(anyhow!("Cannot send response message via HTTP transport"))
                }
            }
        }
    }

    /// Returns the next JSON-RPC message from the WebSocket, connecting first
    /// if necessary.
    ///
    /// Non-text frames are skipped. `Ok(None)` is returned when the peer
    /// sends a close frame, the stream ends, or no WebSocket is available; in
    /// the first two cases the connection state is reset.
    ///
    /// # Errors
    /// Fails if connecting fails, the WebSocket reports an error, or a text
    /// frame is not a valid `JsonRpcMessage`.
    async fn receive(&mut self) -> Result<Option<JsonRpcMessage>> {
        if !self.is_connected {
            self.connect().await?;
        }

        // Loop rather than recurse: a long run of non-text frames must not
        // build an ever-deeper chain of nested futures.
        loop {
            let stream = match &mut self.websocket {
                Some((_, stream)) => stream,
                None => return Ok(None),
            };

            match stream.next().await {
                Some(Ok(Message::Text(text))) => {
                    let message: JsonRpcMessage = serde_json::from_str(&text)
                        .map_err(|e| anyhow!("Failed to parse WebSocket message: {}", e))?;
                    return Ok(Some(message));
                }
                Some(Ok(Message::Close(_))) | None => {
                    self.is_connected = false;
                    self.websocket = None;
                    return Ok(None);
                }
                // Any other frame kind carries no JSON-RPC payload; wait for
                // the next frame.
                Some(Ok(_)) => continue,
                Some(Err(e)) => return Err(anyhow!("WebSocket error: {}", e)),
            }
        }
    }

    /// Sends a close frame (if a WebSocket is open) and marks the transport
    /// disconnected.
    ///
    /// The connection state is cleared before the close frame is sent, so a
    /// failed send cannot leave the transport flagged as connected with no
    /// socket, which would block any later reconnect.
    ///
    /// # Errors
    /// Fails if the close frame cannot be sent.
    async fn close(&mut self) -> Result<()> {
        self.is_connected = false;
        if let Some((mut sink, _)) = self.websocket.take() {
            sink.send(Message::Close(None))
                .await
                .map_err(|e| anyhow!("Failed to close WebSocket: {}", e))?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a stdio transport can be constructed from a fresh
    /// shutdown channel.
    #[tokio::test]
    async fn test_stdio_transport_creation() {
        // The sender is bound (not discarded) so the channel stays open for
        // the lifetime of the test.
        let (_shutdown_tx, shutdown_rx) = mpsc::channel(1);
        let _stdio = StdioTransport::new(shutdown_rx);
    }

    /// Smoke test: an HTTP transport can be constructed from a base URL.
    #[cfg(feature = "mcp")]
    #[test]
    fn test_http_transport_creation() {
        let base_url = "http://localhost:8080".to_string();
        let _http = HttpTransport::new(base_url);
    }
}