use super::{handlers::McpHandler, protocol::*};
use anyhow::Result;
use axum::{
extract::State,
http::{header, StatusCode},
response::{sse::Event, IntoResponse, Sse},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde_json;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tower_http::cors::CorsLayer;
pub struct StdioTransport {
handler: Arc<McpHandler>,
}
impl StdioTransport {
pub fn new(handler: Arc<McpHandler>) -> Self {
Self { handler }
}
pub async fn run(&self) -> Result<()> {
let stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let mut reader = BufReader::new(stdin);
let mut line = String::new();
tracing::info!("MCP STDIO transport started");
loop {
line.clear();
let n = reader.read_line(&mut line).await?;
if n == 0 {
break;
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let request: McpRequest = match serde_json::from_str(trimmed) {
Ok(req) => req,
Err(e) => {
let error_response = McpResponse::error(
None,
McpError::new(error_codes::PARSE_ERROR, e.to_string()),
);
let response_json = serde_json::to_string(&error_response)?;
stdout.write_all(response_json.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
continue;
}
};
let response = self.handler.handle_request(request).await;
let response_json = serde_json::to_string(&response)?;
stdout.write_all(response_json.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
}
tracing::info!("MCP STDIO transport stopped");
Ok(())
}
}
pub struct SseTransport {
handler: Arc<McpHandler>,
host: String,
port: u16,
}
impl SseTransport {
pub fn new(handler: Arc<McpHandler>, host: String, port: u16) -> Self {
Self {
handler,
host,
port,
}
}
pub async fn run(&self) -> Result<()> {
let app = Router::new()
.route("/", get(root))
.route("/mcp", post(mcp_handler))
.route("/mcp/sse", get(mcp_sse_handler))
.layer(CorsLayer::permissive())
.with_state(self.handler.clone());
let addr = format!("{}:{}", self.host, self.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("MCP SSE transport listening on http://{}", addr);
axum::serve(listener, app).await?;
Ok(())
}
}
/// `GET /` handler: replies with a fixed plain-text banner naming the server.
async fn root() -> &'static str {
    const BANNER: &str = "Ruvector MCP Server";
    BANNER
}
/// `POST /mcp` handler: decodes one JSON MCP request from the body, passes it
/// to the shared handler, and returns the handler's response as JSON.
async fn mcp_handler(
    State(dispatcher): State<Arc<McpHandler>>,
    Json(incoming): Json<McpRequest>,
) -> Json<McpResponse> {
    Json(dispatcher.handle_request(incoming).await)
}
/// `GET /mcp/sse` handler: opens a Server-Sent Events stream.
///
/// Immediately sends a data event `connected`. After that it sends an event
/// named `ping` with data `keep-alive` once every [`PING_PERIOD`]; the loop has
/// no exit of its own. axum's SSE keep-alive is also configured with the same
/// period and `keep-alive` text.
///
/// The shared handler is extracted to keep the route's state type consistent
/// with the other routes, but nothing in this stream uses it yet.
async fn mcp_sse_handler(
    State(_handler): State<Arc<McpHandler>>,
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
    /// Interval between `ping` events and between axum keep-alive messages.
    const PING_PERIOD: tokio::time::Duration = tokio::time::Duration::from_secs(30);

    let stream = async_stream::stream! {
        yield Ok(Event::default().data("connected"));
        // A plain `interval` completes its first tick immediately, which would emit a
        // `ping` right after `connected`; starting one period out keeps pings evenly spaced.
        let mut ticker = tokio::time::interval_at(
            tokio::time::Instant::now() + PING_PERIOD,
            PING_PERIOD,
        );
        loop {
            ticker.tick().await;
            yield Ok(Event::default().event("ping").data("keep-alive"));
        }
    };

    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(PING_PERIOD)
            .text("keep-alive"),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Builds a shared handler from the default configuration.
    fn default_handler() -> Arc<McpHandler> {
        Arc::new(McpHandler::new(Config::default()))
    }

    /// Constructing a STDIO transport must not panic.
    #[tokio::test]
    async fn test_stdio_transport_creation() {
        let _transport = StdioTransport::new(default_handler());
    }

    /// Constructing an SSE transport (without binding) must not panic.
    #[tokio::test]
    async fn test_sse_transport_creation() {
        let _transport = SseTransport::new(default_handler(), "127.0.0.1".to_string(), 3000);
    }
}