use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::{get, post};
use futures::stream::Stream;
use tokio::sync::{RwLock, broadcast};
use turbomcp_core::error::{McpError, McpResult};
use turbomcp_core::handler::McpHandler;
use turbomcp_core::types::core::ProtocolVersion;
use uuid::Uuid;
use crate::config::{RateLimiter, ServerConfig};
use crate::context::RequestContext;
use crate::router::{self, JsonRpcIncoming, JsonRpcOutgoing};
/// Largest JSON-RPC request body accepted by the HTTP routes: 10 MiB.
const MAX_BODY_SIZE: usize = 10 << 20;
/// Interval, in seconds, between keep-alive messages on an idle SSE stream.
const SSE_KEEP_ALIVE_SECS: u64 = 30;
/// Per-session state stored by [`SessionManager`].
#[derive(Clone, Debug)]
struct SessionData {
    /// Sending half of the session's broadcast channel; the SSE stream owns the receiver.
    tx: broadcast::Sender<String>,
    /// Protocol version recorded from the `initialize` response, if any.
    protocol_version: Option<ProtocolVersion>,
}
/// Registry of live SSE sessions, keyed by session id.
///
/// Cloning is cheap: every clone shares the same map through an `Arc`.
#[derive(Debug, Clone)]
pub struct SessionManager {
    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
}
impl Default for SessionManager {
fn default() -> Self {
Self::new()
}
}
impl SessionManager {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn create_session(&self) -> (String, broadcast::Receiver<String>) {
let session_id = Uuid::new_v4().to_string();
let (tx, rx) = broadcast::channel(100);
self.sessions.write().await.insert(
session_id.clone(),
SessionData {
tx,
protocol_version: None,
},
);
tracing::debug!("Created SSE session: {}", session_id);
(session_id, rx)
}
pub async fn remove_session(&self, session_id: &str) {
self.sessions.write().await.remove(session_id);
tracing::debug!("Removed SSE session: {}", session_id);
}
#[allow(dead_code)] pub(crate) async fn send_to_session(&self, session_id: &str, message: &str) -> bool {
if let Some(data) = self.sessions.read().await.get(session_id) {
data.tx.send(message.to_string()).is_ok()
} else {
false
}
}
#[allow(dead_code)] pub(crate) async fn broadcast(&self, message: &str) {
let sessions = self.sessions.read().await;
for (session_id, data) in sessions.iter() {
if data.tx.send(message.to_string()).is_err() {
tracing::warn!("Failed to send to session {}", session_id);
}
}
}
#[allow(dead_code)] pub(crate) async fn session_count(&self) -> usize {
self.sessions.read().await.len()
}
pub(crate) async fn set_protocol_version(&self, session_id: &str, version: ProtocolVersion) {
if let Some(data) = self.sessions.write().await.get_mut(session_id) {
data.protocol_version = Some(version);
}
}
pub(crate) async fn get_protocol_version(&self, session_id: &str) -> Option<ProtocolVersion> {
self.sessions
.read()
.await
.get(session_id)
.and_then(|data| data.protocol_version.clone())
}
}
pub async fn run<H: McpHandler>(handler: &H, addr: &str) -> McpResult<()> {
handler.on_initialize().await?;
let session_manager = SessionManager::new();
let state = SseState {
handler: handler.clone(),
session_manager: session_manager.clone(),
rate_limiter: None,
config: None,
};
let app = Router::new()
.route("/", post(handle_json_rpc::<H>))
.route("/mcp", post(handle_json_rpc::<H>))
.route("/sse", get(handle_sse::<H>))
.layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
.with_state(state);
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| McpError::internal(format!("Invalid address '{}': {}", addr, e)))?;
let listener = tokio::net::TcpListener::bind(socket_addr)
.await
.map_err(|e| McpError::internal(format!("Failed to bind to {}: {}", addr, e)))?;
tracing::info!(
"MCP server listening on http://{} (POST /, /mcp; GET /sse)",
socket_addr
);
axum::serve(listener, app)
.await
.map_err(|e| McpError::internal(format!("Server error: {}", e)))?;
handler.on_shutdown().await?;
Ok(())
}
pub async fn run_with_config<H: McpHandler>(
handler: &H,
addr: &str,
config: &ServerConfig,
) -> McpResult<()> {
handler.on_initialize().await?;
let rate_limiter = config
.rate_limit
.as_ref()
.map(|cfg| Arc::new(RateLimiter::new(cfg.clone())));
let session_manager = SessionManager::new();
let state = SseState {
handler: handler.clone(),
session_manager: session_manager.clone(),
rate_limiter,
config: Some(config.clone()),
};
let app = Router::new()
.route("/", post(handle_json_rpc_with_rate_limit::<H>))
.route("/mcp", post(handle_json_rpc_with_rate_limit::<H>))
.route("/sse", get(handle_sse::<H>))
.layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
.with_state(state);
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| McpError::internal(format!("Invalid address '{}': {}", addr, e)))?;
let listener = tokio::net::TcpListener::bind(socket_addr)
.await
.map_err(|e| McpError::internal(format!("Failed to bind to {}: {}", addr, e)))?;
let rate_limit_info = config
.rate_limit
.as_ref()
.map(|cfg| {
format!(
" (rate limit: {}/{}s)",
cfg.max_requests,
cfg.window.as_secs()
)
})
.unwrap_or_default();
tracing::info!(
"MCP server listening on http://{}{} (POST /, /mcp; GET /sse)",
socket_addr,
rate_limit_info
);
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.map_err(|e| McpError::internal(format!("Server error: {}", e)))?;
handler.on_shutdown().await?;
Ok(())
}
/// Shared state handed to every axum handler via `with_state`.
#[derive(Clone)]
struct SseState<H>
where
    H: McpHandler,
{
    /// Application handler that requests are routed to.
    handler: H,
    /// Registry of SSE sessions and their negotiated protocol versions.
    session_manager: SessionManager,
    /// Per-client limiter; `None` disables rate limiting.
    rate_limiter: Option<Arc<RateLimiter>>,
    /// Server configuration; `None` when started through `run`.
    config: Option<ServerConfig>,
}
/// Routes one JSON-RPC request, recording or applying the session's protocol version.
///
/// * `initialize`: routed with `config`. If a `session_id` was supplied and the
///   response's `result.protocolVersion` is a string, that version is stored for the
///   session. Storing is a no-op when the session is not registered.
/// * Any other method: if the session has a stored version, the request goes through
///   `route_request_versioned`; otherwise it falls back to `route_request_with_config`.
async fn route_with_version_tracking<H: McpHandler>(
    handler: &H,
    request: router::JsonRpcIncoming,
    session_manager: &SessionManager,
    config: Option<&ServerConfig>,
    session_id: Option<&str>,
) -> router::JsonRpcOutgoing {
    let ctx = RequestContext::http();
    let core_ctx = ctx.to_core_context();

    if request.method != "initialize" {
        let stored_version = match session_id {
            Some(sid) => session_manager.get_protocol_version(sid).await,
            None => None,
        };
        return match stored_version {
            Some(version) => {
                router::route_request_versioned(handler, request, &core_ctx, &version).await
            }
            None => router::route_request_with_config(handler, request, &core_ctx, config).await,
        };
    }

    let response = router::route_request_with_config(handler, request, &core_ctx, config).await;
    let negotiated = response
        .result
        .as_ref()
        .and_then(|result| result.get("protocolVersion"))
        .and_then(|value| value.as_str());
    if let (Some(sid), Some(version_str)) = (session_id, negotiated) {
        session_manager
            .set_protocol_version(sid, ProtocolVersion::from(version_str))
            .await;
        tracing::debug!(
            session_id = sid,
            protocol_version = version_str,
            "Stored negotiated protocol version for session"
        );
    }
    response
}
/// Axum handler for `POST /` and `POST /mcp` when no rate limiting is configured.
///
/// The optional `mcp-session-id` header selects the session whose protocol version
/// is tracked. A header value that is not visible ASCII is treated as absent.
async fn handle_json_rpc<H: McpHandler>(
    axum::extract::State(state): axum::extract::State<SseState<H>>,
    headers: axum::http::HeaderMap,
    axum::Json(request): axum::Json<JsonRpcIncoming>,
) -> axum::Json<JsonRpcOutgoing> {
    let session_id: Option<&str> = match headers.get("mcp-session-id") {
        Some(value) => value.to_str().ok(),
        None => None,
    };
    let config = state.config.as_ref();
    let outgoing = route_with_version_tracking(
        &state.handler,
        request,
        &state.session_manager,
        config,
        session_id,
    )
    .await;
    axum::Json(outgoing)
}
/// Axum handler for `POST /` and `POST /mcp` when served by `run_with_config`.
///
/// When a rate limiter is configured, the client is keyed by its IP address. An
/// over-limit client gets `429 Too Many Requests` before the body is routed. Otherwise
/// this behaves like `handle_json_rpc`.
async fn handle_json_rpc_with_rate_limit<H: McpHandler>(
    axum::extract::State(state): axum::extract::State<SseState<H>>,
    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<SocketAddr>,
    headers: axum::http::HeaderMap,
    axum::Json(request): axum::Json<JsonRpcIncoming>,
) -> Result<axum::Json<JsonRpcOutgoing>, axum::http::StatusCode> {
    if let Some(limiter) = state.rate_limiter.as_ref() {
        let client_id = addr.ip().to_string();
        let allowed = limiter.check(Some(&client_id));
        if !allowed {
            tracing::warn!("Rate limit exceeded for client {}", client_id);
            return Err(axum::http::StatusCode::TOO_MANY_REQUESTS);
        }
    }

    let session_id: Option<&str> = match headers.get("mcp-session-id") {
        Some(value) => value.to_str().ok(),
        None => None,
    };
    let config = state.config.as_ref();
    let outgoing = route_with_version_tracking(
        &state.handler,
        request,
        &state.session_manager,
        config,
        session_id,
    )
    .await;
    Ok(axum::Json(outgoing))
}
/// Axum handler for `GET /sse`: opens a new server-sent-events session.
///
/// The handler registers a session with the `SessionManager` and then returns the
/// stream. The stream first emits a `connected` event whose data is
/// `{"sessionId":"<id>"}`, then forwards each broadcast message as a `message` event.
/// The id is also sent in the `mcp-session-id` response header. Lagged receivers skip
/// the missed messages, and the stream ends when the broadcast channel closes. The
/// session is unregistered when the stream is dropped (see `CleanupStream`).
async fn handle_sse<H: McpHandler>(
    axum::extract::State(state): axum::extract::State<SseState<H>>,
) -> impl axum::response::IntoResponse {
    use axum::http::header::{HeaderName, HeaderValue};

    let (session_id, mut rx) = state.session_manager.create_session().await;

    // The id comes from `Uuid::new_v4().to_string()`, so the fallback is defensive only.
    let header_value = HeaderValue::from_str(&session_id)
        .unwrap_or_else(|_| HeaderValue::from_static("invalid-session"));
    let connected_payload = format!(r#"{{"sessionId":"{}"}}"#, session_id);

    let events = async_stream::stream! {
        yield Ok::<_, Infallible>(
            Event::default().event("connected").data(connected_payload),
        );
        loop {
            let message = match rx.recv().await {
                Ok(message) => message,
                Err(broadcast::error::RecvError::Lagged(missed)) => {
                    tracing::warn!("SSE client lagged, missed {} messages", missed);
                    continue;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    tracing::debug!("SSE broadcast channel closed");
                    break;
                }
            };
            yield Ok(Event::default().event("message").data(message));
        }
    };

    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(SSE_KEEP_ALIVE_SECS))
        .text("keep-alive");
    let sse = Sse::new(CleanupStream {
        inner: Box::pin(events),
        session_manager: state.session_manager.clone(),
        session_id,
    })
    .keep_alive(keep_alive);

    (
        [(HeaderName::from_static("mcp-session-id"), header_value)],
        sse,
    )
}
/// SSE event stream wrapper that unregisters its session when dropped.
///
/// Polling is delegated to `inner`. `session_manager` and `session_id` identify the
/// entry that the `Drop` impl removes.
struct CleanupStream<S> {
    /// The wrapped event stream, boxed and pinned so the wrapper is `Unpin` for any `S`.
    inner: std::pin::Pin<Box<S>>,
    /// Registry the session belongs to.
    session_manager: SessionManager,
    /// Id of the session created for this stream.
    session_id: String,
}
impl<S> Stream for CleanupStream<S>
where
    S: Stream<Item = Result<Event, Infallible>>,
{
    type Item = Result<Event, Infallible>;

    /// Forwards polling to the wrapped stream without altering its items.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Every field is `Unpin` (the inner stream is already behind `Pin<Box<_>>`),
        // so the wrapper can be unpinned safely.
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}
impl<S> Drop for CleanupStream<S> {
fn drop(&mut self) {
let session_manager = self.session_manager.clone();
let session_id = self.session_id.clone();
tokio::spawn(async move {
session_manager.remove_session(&session_id).await;
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `fut` to completion on a fresh current-thread Tokio runtime.
    ///
    /// Only the `rt` feature is needed, which this module already relies on for
    /// `tokio::spawn`.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime")
            .block_on(fut)
    }

    #[test]
    fn create_and_remove_session_updates_count() {
        block_on(async {
            let manager = SessionManager::new();
            let (first, _rx1) = manager.create_session().await;
            let (second, _rx2) = manager.create_session().await;
            assert_ne!(first, second);
            assert_eq!(manager.session_count().await, 2);

            manager.remove_session(&first).await;
            assert_eq!(manager.session_count().await, 1);

            // Removing an unknown id is a no-op.
            manager.remove_session("does-not-exist").await;
            assert_eq!(manager.session_count().await, 1);
        });
    }

    #[test]
    fn send_to_session_delivers_only_to_known_sessions() {
        block_on(async {
            let manager = SessionManager::new();
            let (id, mut rx) = manager.create_session().await;
            assert!(manager.send_to_session(&id, "hello").await);
            assert_eq!(rx.recv().await.expect("message expected"), "hello");
            assert!(!manager.send_to_session("missing", "hello").await);
        });
    }

    #[test]
    fn broadcast_reaches_every_session() {
        block_on(async {
            let manager = SessionManager::new();
            let (_a, mut rx_a) = manager.create_session().await;
            let (_b, mut rx_b) = manager.create_session().await;
            manager.broadcast("ping").await;
            assert_eq!(rx_a.recv().await.expect("message expected"), "ping");
            assert_eq!(rx_b.recv().await.expect("message expected"), "ping");
        });
    }

    #[test]
    fn protocol_version_is_tracked_only_for_registered_sessions() {
        block_on(async {
            let manager = SessionManager::new();
            let (id, _rx) = manager.create_session().await;
            assert!(manager.get_protocol_version(&id).await.is_none());

            manager
                .set_protocol_version(&id, ProtocolVersion::from("2025-06-18"))
                .await;
            assert!(manager.get_protocol_version(&id).await.is_some());

            manager
                .set_protocol_version("unknown", ProtocolVersion::from("2025-06-18"))
                .await;
            assert!(manager.get_protocol_version("unknown").await.is_none());
        });
    }
}