use crate::error::Result;
use crate::server::http_middleware::{
adapters::{from_axum_with_limit, into_axum},
ServerHttpContext, ServerHttpMiddlewareChain, ServerHttpResponse,
};
use crate::server::tower_layers::{AllowedOrigins, DnsRebindingLayer, SecurityHeadersLayer};
use crate::server::Server;
use crate::shared::http_constants::{
APPLICATION_JSON, LAST_EVENT_ID, MCP_PROTOCOL_VERSION, MCP_SESSION_ID, TEXT_EVENT_STREAM,
};
use crate::shared::TransportMessage;
use crate::types::{ClientRequest, Request};
use async_trait::async_trait;
use axum::{
body::Body,
extract::State,
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{sse::Event, IntoResponse, Response, Sse},
routing::{delete, get, post},
Json, Router,
};
use futures_util::StreamExt;
use parking_lot::RwLock;
use serde_json::json;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
/// Storage for emitted events so that they can be replayed to a reconnecting client.
///
/// Implementations must be `Send + Sync` because the store is shared between request handlers.
#[async_trait]
pub trait EventStore: Send + Sync {
/// Records `message` under `event_id` as belonging to the stream `stream_id`.
async fn store_event(
&self,
stream_id: &str,
event_id: &str,
message: &TransportMessage,
) -> Result<()>;
/// Returns the `(event_id, message)` pairs of the events that follow `last_event_id`.
async fn replay_events_after(
&self,
last_event_id: &str,
) -> Result<Vec<(String, TransportMessage)>>;
/// Looks up the ID of the stream that `event_id` was stored under, if the event is known.
async fn get_stream_for_event(&self, event_id: &str) -> Result<Option<String>>;
}
/// `(event_id, message)` pairs recorded for one stream, in insertion order.
type EventList = Vec<(String, TransportMessage)>;
/// Event lists keyed by stream ID.
type EventsMap = HashMap<String, EventList>;
/// [`EventStore`] that keeps every event in process memory.
///
/// Nothing in this type evicts events, so its memory use grows with every stored event.
#[derive(Debug, Default)]
pub struct InMemoryEventStore {
/// Events grouped by the stream they were stored under.
events: Arc<RwLock<EventsMap>>,
/// Index from event ID to the ID of the stream that owns it.
event_to_stream: Arc<RwLock<HashMap<String, String>>>,
/// Event IDs across all streams, in the order they were stored.
event_order: Arc<RwLock<Vec<String>>>,
}
#[async_trait]
impl EventStore for InMemoryEventStore {
    /// Records `message` in the per-stream list, the event-to-stream index and the global order.
    ///
    /// The three locks are always taken in the order `event_order` → `events` →
    /// `event_to_stream`, the same order [`Self::replay_events_after`] uses. Taking them in a
    /// different order here could deadlock against a concurrent replay, because each side would
    /// hold the lock the other one is waiting for.
    async fn store_event(
        &self,
        stream_id: &str,
        event_id: &str,
        message: &TransportMessage,
    ) -> Result<()> {
        let mut event_order = self.event_order.write();
        let mut events = self.events.write();
        let mut event_to_stream = self.event_to_stream.write();

        events
            .entry(stream_id.to_string())
            .or_default()
            .push((event_id.to_string(), message.clone()));
        event_to_stream.insert(event_id.to_string(), stream_id.to_string());
        event_order.push(event_id.to_string());
        Ok(())
    }

    /// Returns every stored event that comes after `last_event_id` in global storage order.
    ///
    /// When `last_event_id` is unknown, all stored events are returned. Events from every
    /// stream are included; callers that need a single stream must filter the result, for
    /// example with [`Self::get_stream_for_event`].
    async fn replay_events_after(
        &self,
        last_event_id: &str,
    ) -> Result<Vec<(String, TransportMessage)>> {
        // Same acquisition order as `store_event`; see the note there.
        let event_order = self.event_order.read();
        let events = self.events.read();
        let event_to_stream = self.event_to_stream.read();

        let start_pos = event_order
            .iter()
            .position(|id| id == last_event_id)
            .map_or(0, |pos| pos + 1);

        let replayed = event_order
            .iter()
            .skip(start_pos)
            .filter_map(|event_id| {
                // Resolve the owning stream, then pick the matching entry out of its list.
                let stream_events = events.get(event_to_stream.get(event_id)?)?;
                stream_events
                    .iter()
                    .find(|(stored_id, _)| stored_id == event_id)
                    .cloned()
            })
            .collect();
        Ok(replayed)
    }

    /// Returns the ID of the stream that `event_id` was stored under, or `None` if unknown.
    async fn get_stream_for_event(&self, event_id: &str) -> Result<Option<String>> {
        Ok(self.event_to_stream.read().get(event_id).cloned())
    }
}
/// Callback that receives a session ID; used for the session lifecycle hooks in the config.
type SessionCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Configuration for the streamable HTTP server.
pub struct StreamableHttpServerConfig {
/// Produces IDs for new sessions. `None` turns session tracking off: the handlers in this
/// file then skip session validation and refuse to create SSE sessions.
pub session_id_generator: Option<Box<dyn Fn() -> String + Send + Sync>>,
/// When `true`, `build_response` returns responses as a plain JSON body instead of using SSE.
pub enable_json_response: bool,
/// Store for emitted events; consulted when a client reconnects with a last-event-ID header.
// NOTE(review): typed as the concrete in-memory store rather than `Arc<dyn EventStore>` —
// confirm whether custom stores are meant to be pluggable here.
pub event_store: Option<Arc<InMemoryEventStore>>,
/// Called with the new session ID whenever a handler creates a session.
pub on_session_initialized: Option<SessionCallback>,
/// Called with the session ID when a DELETE request closes a session.
pub on_session_closed: Option<SessionCallback>,
/// Optional HTTP middleware chain; when set, POST requests take the middleware path.
pub http_middleware: Option<Arc<ServerHttpMiddlewareChain>>,
/// Origins handed to the CORS and DNS-rebinding layers. `None` falls back to
/// `AllowedOrigins::localhost()` when the server state is built.
pub allowed_origins: Option<AllowedOrigins>,
/// Upper bound, in bytes, on the POST request body that will be read.
pub max_request_bytes: usize,
}
impl std::fmt::Debug for StreamableHttpServerConfig {
    /// Formats the configuration for diagnostics.
    ///
    /// Closures, the event store and the middleware chain are shown only as "is set" booleans;
    /// the remaining fields are printed with their own `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring makes the compiler flag this impl when a field is added to the struct.
        let Self {
            session_id_generator,
            enable_json_response,
            event_store,
            on_session_initialized,
            on_session_closed,
            http_middleware,
            allowed_origins,
            max_request_bytes,
        } = self;

        let mut out = f.debug_struct("StreamableHttpServerConfig");
        out.field("session_id_generator", &session_id_generator.is_some());
        out.field("enable_json_response", enable_json_response);
        out.field("event_store", &event_store.is_some());
        out.field("on_session_initialized", &on_session_initialized.is_some());
        out.field("on_session_closed", &on_session_closed.is_some());
        out.field("http_middleware", &http_middleware.is_some());
        out.field("allowed_origins", allowed_origins);
        out.field("max_request_bytes", max_request_bytes);
        out.finish()
    }
}
impl Default for StreamableHttpServerConfig {
    /// Session-tracking configuration: UUID v4 session IDs, a fresh in-memory event store,
    /// JSON responses disabled, no callbacks, no middleware, no explicit origin list, and the
    /// crate's default request size limit.
    fn default() -> Self {
        let uuid_generator: Box<dyn Fn() -> String + Send + Sync> =
            Box::new(|| Uuid::new_v4().to_string());
        let in_memory_store = Arc::new(InMemoryEventStore::default());

        Self {
            session_id_generator: Some(uuid_generator),
            event_store: Some(in_memory_store),
            enable_json_response: false,
            allowed_origins: None,
            http_middleware: None,
            on_session_initialized: None,
            on_session_closed: None,
            max_request_bytes: crate::server::limits::DEFAULT_MAX_REQUEST_BYTES,
        }
    }
}
impl StreamableHttpServerConfig {
    /// Configuration without session tracking: no session ID generator, no event store,
    /// JSON responses enabled, and any origin allowed.
    ///
    /// The fields not listed here (callbacks, middleware, request size limit) take the same
    /// values as [`Default::default`].
    pub fn stateless() -> Self {
        Self {
            session_id_generator: None,
            event_store: None,
            enable_json_response: true,
            allowed_origins: Some(AllowedOrigins::any()),
            ..Self::default()
        }
    }
}
/// Per-session bookkeeping kept in `ServerState::sessions`.
#[derive(Debug, Clone)]
struct SessionInfo {
/// Set once an initialize request for this session has been answered.
initialized: bool,
/// Protocol version recorded for the session, if any; later requests are checked against it.
protocol_version: Option<String>,
}
/// Shared state handed to every HTTP handler; cloning it only clones the inner handles.
#[derive(Clone)]
pub(crate) struct ServerState {
/// The MCP server that processes requests, guarded by an async mutex.
server: Arc<tokio::sync::Mutex<Server>>,
/// Transport configuration.
config: Arc<StreamableHttpServerConfig>,
/// Origins resolved from the config (or the localhost default) when the state was built.
allowed_origins: AllowedOrigins,
/// Senders for the SSE streams opened via GET, keyed by session ID.
sse_streams: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<TransportMessage>>>>,
/// Known sessions, keyed by session ID.
sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
}
/// Builds the router that serves the MCP endpoint at `/`.
///
/// POST carries JSON-RPC messages, GET opens an SSE stream and DELETE closes a session.
pub(crate) fn build_mcp_router(state: ServerState) -> Router<()> {
    let mcp_endpoint = post(handle_post_request)
        .get(handle_get_sse)
        .delete(handle_delete_session);
    Router::new().route("/", mcp_endpoint).with_state(state)
}
/// Assembles the shared handler state from a server handle and a configuration.
///
/// The allowed origins are taken from `config`; when none are configured,
/// `AllowedOrigins::localhost()` is used. The stream and session tables start empty.
pub(crate) fn make_server_state(
    server: Arc<tokio::sync::Mutex<Server>>,
    config: StreamableHttpServerConfig,
) -> ServerState {
    let allowed_origins = match &config.allowed_origins {
        Some(origins) => origins.clone(),
        None => AllowedOrigins::localhost(),
    };
    let sse_streams = Arc::new(RwLock::new(HashMap::new()));
    let sessions = Arc::new(RwLock::new(HashMap::new()));

    ServerState {
        server,
        config: Arc::new(config),
        allowed_origins,
        sse_streams,
        sessions,
    }
}
/// HTTP server that exposes an MCP [`Server`] over the streamable HTTP transport.
pub struct StreamableHttpServer {
/// Address the listener binds to when the server is started.
addr: SocketAddr,
/// State shared with the request handlers.
state: ServerState,
}
impl std::fmt::Debug for StreamableHttpServer {
    /// Prints the bind address; the shared state is shown as a fixed placeholder string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StreamableHttpServer");
        out.field("addr", &self.addr);
        out.field("state", &"ServerState { ... }");
        out.finish()
    }
}
fn create_error_response(status: StatusCode, code: i32, message: &str) -> Response {
let error_body = json!({
"jsonrpc": "2.0",
"error": {
"code": code,
"message": message
},
"id": null
});
(status, Json(error_body)).into_response()
}
impl StreamableHttpServer {
pub fn new(addr: SocketAddr, server: Arc<tokio::sync::Mutex<Server>>) -> Self {
Self::with_config(addr, server, StreamableHttpServerConfig::default())
}
pub fn with_config(
addr: SocketAddr,
server: Arc<tokio::sync::Mutex<Server>>,
config: StreamableHttpServerConfig,
) -> Self {
let state = make_server_state(server, config);
Self { addr, state }
}
pub async fn start(self) -> Result<(SocketAddr, tokio::task::JoinHandle<()>)> {
let allowed = self.state.allowed_origins.clone();
let cors = crate::server::tower_layers::build_mcp_cors_layer(&allowed);
let app = build_mcp_router(self.state)
.layer(SecurityHeadersLayer::default())
.layer(DnsRebindingLayer::new(allowed))
.layer(cors);
let listener = tokio::net::TcpListener::bind(self.addr).await?;
let local_addr = listener.local_addr()?;
let server_task = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
Ok((local_addr, server_task))
}
}
fn validate_headers(headers: &HeaderMap, method: &str) -> std::result::Result<(), Response> {
match method {
"POST" => {
if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
let ct = content_type.to_str().unwrap_or("");
if !ct.contains(APPLICATION_JSON) {
return Err(create_error_response(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
-32700,
"Content-Type must be application/json",
));
}
} else {
return Err(create_error_response(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
-32700,
"Content-Type header is required",
));
}
if let Some(accept) = headers.get(header::ACCEPT) {
let accept_str = accept.to_str().unwrap_or("");
if !accept_str.contains(APPLICATION_JSON) && !accept_str.contains(TEXT_EVENT_STREAM)
{
return Err(create_error_response(
StatusCode::NOT_ACCEPTABLE,
-32700,
"Accept header must include application/json or text/event-stream",
));
}
} else {
return Err(create_error_response(
StatusCode::NOT_ACCEPTABLE,
-32700,
"Accept header is required",
));
}
},
"GET" => {
if let Some(accept) = headers.get(header::ACCEPT) {
let accept_str = accept.to_str().unwrap_or("");
if !accept_str.contains(TEXT_EVENT_STREAM) {
return Err(create_error_response(
StatusCode::NOT_ACCEPTABLE,
-32700,
"Accept header must be text/event-stream for SSE",
));
}
} else {
return Err(create_error_response(
StatusCode::NOT_ACCEPTABLE,
-32700,
"Accept header is required for SSE",
));
}
},
_ => {},
}
Ok(())
}
/// Resolves the session for an initialize request.
///
/// Returns `(session_id, is_new)`:
/// * without a session ID generator, `(None, false)`;
/// * when the client supplied an ID, that ID with `is_new == false`, unless the session is
///   already marked initialized, which is rejected with 400;
/// * otherwise a freshly generated ID is registered (not yet initialized, carrying the
///   client's `protocol_version`), the `on_session_initialized` callback is invoked, and
///   `is_new` is `true`.
fn process_init_session(
    state: &ServerState,
    session_id: Option<String>,
    protocol_version: Option<String>,
) -> std::result::Result<(Option<String>, bool), Response> {
    let Some(generator) = &state.config.session_id_generator else {
        return Ok((None, false));
    };

    if let Some(sid) = session_id {
        let already_initialized = state
            .sessions
            .read()
            .get(&sid)
            .map_or(false, |info| info.initialized);
        if already_initialized {
            return Err(create_error_response(
                StatusCode::BAD_REQUEST,
                -32600,
                "Session already initialized",
            ));
        }
        return Ok((Some(sid), false));
    }

    let new_id = generator();
    let fresh_session = SessionInfo {
        initialized: false,
        protocol_version,
    };
    state.sessions.write().insert(new_id.clone(), fresh_session);
    if let Some(callback) = &state.config.on_session_initialized {
        callback(&new_id);
    }
    Ok((Some(new_id), true))
}
/// Validates the session of a request that is not an initialize request.
///
/// Without a session ID generator every request passes and `None` is returned. Otherwise a
/// missing session ID is rejected with 400 and an unknown one with 404; a known ID is handed
/// back unchanged.
fn validate_non_init_session(
    state: &ServerState,
    session_id: Option<String>,
) -> std::result::Result<Option<String>, Response> {
    if state.config.session_id_generator.is_none() {
        return Ok(None);
    }

    let Some(sid) = session_id else {
        return Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32600,
            "Session ID required for non-initialization requests",
        ));
    };

    if state.sessions.read().contains_key(&sid) {
        Ok(Some(sid))
    } else {
        Err(create_error_response(
            StatusCode::NOT_FOUND,
            -32600,
            "Unknown session ID",
        ))
    }
}
/// Pulls the protocol version out of a successful initialize response.
///
/// Returns `None` when `response` is not a response message, carries no result payload, or
/// the payload does not deserialize as an `InitializeResult`.
fn extract_negotiated_version(response: &TransportMessage) -> Option<String> {
    let TransportMessage::Response(json_resp) = response else {
        return None;
    };
    let crate::types::jsonrpc::ResponsePayload::Result(value) = &json_resp.payload else {
        return None;
    };
    serde_json::from_value::<crate::types::InitializeResult>(value.clone())
        .ok()
        .map(|init_result| init_result.protocol_version.0)
}
/// Marks a session as initialized and records its protocol version.
///
/// Does nothing when `session_id` is `None` or names an unknown session. When no version was
/// negotiated, the crate's default protocol version is recorded instead.
fn update_session_after_init(
    state: &ServerState,
    session_id: Option<&String>,
    negotiated_version: Option<String>,
) {
    let Some(sid) = session_id else {
        return;
    };

    let mut sessions = state.sessions.write();
    if let Some(info) = sessions.get_mut(sid) {
        let version =
            negotiated_version.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string());
        info.initialized = true;
        info.protocol_version = Some(version);
    }
}
/// Turns a server response message into the HTTP response for a POST request.
///
/// * JSON mode (`enable_json_response`): the message is returned as a `200` JSON body.
/// * Otherwise, with a session ID:
///   * if an SSE stream is registered for the session, the message is pushed onto it and
///     `202 Accepted` is returned (a failed send is ignored);
///   * if not, the message is returned as a one-event SSE response.
/// * Otherwise, without a session ID: the message is returned as a `200` JSON body.
///
/// Serialization problems on the JSON paths produce a `500` JSON-RPC error response.
fn build_response(
    state: &ServerState,
    response: TransportMessage,
    session_id: Option<&String>,
) -> Response {
    if !state.config.enable_json_response {
        let Some(sid) = session_id else {
            // No session to stream to: answer inline with a JSON body.
            let raw = match crate::shared::StdioTransport::serialize_message(&response) {
                Ok(bytes) => bytes,
                Err(e) => {
                    return create_error_response(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        -32603,
                        &format!("Failed to serialize response: {}", e),
                    );
                },
            };
            let payload: serde_json::Value = match serde_json::from_slice(&raw) {
                Ok(val) => val,
                Err(e) => {
                    return create_error_response(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        -32603,
                        &format!("Failed to parse JSON response: {}", e),
                    );
                },
            };
            return (StatusCode::OK, Json(payload)).into_response();
        };

        // Copy the sender out so the table lock is not held while sending.
        let open_stream = state.sse_streams.read().get(sid).cloned();
        if let Some(sender) = open_stream {
            let _ = sender.send(response);
            return StatusCode::ACCEPTED.into_response();
        }

        // No registered stream: answer with an SSE body that carries just this message.
        // The sender is dropped when this function returns, which ends the stream.
        let (tx, rx) = mpsc::unbounded_channel();
        tx.send(response).unwrap();
        let one_shot = UnboundedReceiverStream::new(rx).map(|msg| {
            let event_id = Uuid::new_v4().to_string();
            let raw = match crate::shared::StdioTransport::serialize_message(&msg) {
                Ok(bytes) => bytes,
                Err(e) => {
                    tracing::error!(target: "mcp.sse", error = %e, "Failed to serialize SSE message");
                    Vec::new()
                },
            };
            let data = match String::from_utf8(raw) {
                Ok(text) => text,
                Err(_) => "{}".to_string(),
            };
            Ok::<_, Infallible>(Event::default().id(event_id).event("message").data(data))
        });
        return Sse::new(one_shot).into_response();
    }

    // JSON mode.
    let raw = match crate::shared::StdioTransport::serialize_message(&response) {
        Ok(bytes) => bytes,
        Err(e) => {
            return create_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                -32603,
                &format!("Failed to serialize response: {}", e),
            );
        },
    };
    tracing::debug!(
        target: "mcp.http",
        response = %String::from_utf8_lossy(&raw),
        "HTTP response serialized bytes"
    );
    let payload: serde_json::Value = match serde_json::from_slice(&raw) {
        Ok(val) => val,
        Err(e) => {
            return create_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                -32603,
                &format!("Failed to parse JSON response: {}", e),
            );
        },
    };
    tracing::debug!(
        target: "mcp.http",
        response = %serde_json::to_string(&payload).unwrap_or_default(),
        "HTTP response (JSON mode)"
    );
    (StatusCode::OK, Json(payload)).into_response()
}
/// Validates the protocol version header of a non-initialize request.
///
/// A request without the header always passes. A supplied version must be one of the
/// supported versions and, when sessions are tracked and the session has a recorded version,
/// must equal that recorded version. Violations produce a `400` error response.
fn validate_protocol_version(
    state: &ServerState,
    session_id: Option<&String>,
    protocol_version: Option<&String>,
) -> std::result::Result<(), Response> {
    // Nothing to check when the client did not send a version.
    let Some(provided) = protocol_version else {
        return Ok(());
    };

    if !crate::SUPPORTED_PROTOCOL_VERSIONS.contains(&provided.as_str()) {
        return Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32600,
            &format!("Unsupported protocol version: {}", provided),
        ));
    }

    if state.config.session_id_generator.is_none() {
        return Ok(());
    }
    let Some(sid) = session_id else {
        return Ok(());
    };

    let recorded = state
        .sessions
        .read()
        .get(sid.as_str())
        .and_then(|info| info.protocol_version.clone());
    match recorded {
        Some(expected) if expected != *provided => Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32600,
            &format!(
                "Protocol version mismatch: expected {}, got {}",
                expected, provided
            ),
        )),
        _ => Ok(()),
    }
}
/// POST entry point: dispatches to the middleware-aware handler when an HTTP middleware chain
/// is configured, and to the fast path otherwise.
async fn handle_post_request(
    State(state): State<ServerState>,
    request: axum::extract::Request<Body>,
) -> impl IntoResponse {
    let uses_middleware = state.config.http_middleware.is_some();
    if uses_middleware {
        handle_post_with_middleware(state, request).await
    } else {
        handle_post_fast_path(state, request).await
    }
}
/// Determines the authentication context for a request.
///
/// With an auth provider configured on the server, the `Authorization` header (if readable
/// as text) is validated by the provider; a validation failure becomes a `401` error
/// response. Without a provider, the context is derived from the proxy identity headers.
///
/// The server lock is held for the duration of the call.
async fn extract_and_validate_auth(
    state: &ServerState,
    headers: &HeaderMap,
) -> std::result::Result<Option<crate::server::auth::AuthContext>, Response> {
    let server = state.server.lock().await;
    let Some(auth_provider) = server.get_auth_provider() else {
        return Ok(extract_auth_from_proxy_headers(headers));
    };

    let auth_header = headers
        .get(http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok());
    auth_provider
        .validate_request(auth_header)
        .await
        .map_err(|e| {
            create_error_response(
                StatusCode::UNAUTHORIZED,
                -32003,
                &format!("Authentication failed: {}", e),
            )
        })
}
fn extract_auth_from_proxy_headers(
headers: &HeaderMap,
) -> Option<crate::server::auth::AuthContext> {
let user_id = headers
.get("x-pmcp-user-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())?;
let email = headers
.get("x-pmcp-user-email")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let name = headers
.get("x-pmcp-user-name")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let groups = headers
.get("x-pmcp-user-groups")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let tenant_id = headers
.get("x-pmcp-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let mut claims = std::collections::HashMap::new();
if let Some(ref email) = email {
claims.insert(
"email".to_string(),
serde_json::Value::String(email.clone()),
);
}
if let Some(ref name) = name {
claims.insert("name".to_string(), serde_json::Value::String(name.clone()));
}
if let Some(ref groups) = groups {
let groups_array: Vec<serde_json::Value> = groups
.split(',')
.map(|g| serde_json::Value::String(g.trim().to_string()))
.filter(|v| v.as_str() != Some(""))
.collect();
claims.insert("groups".to_string(), serde_json::Value::Array(groups_array));
}
if let Some(ref tenant_id) = tenant_id {
claims.insert(
"tenant_id".to_string(),
serde_json::Value::String(tenant_id.clone()),
);
}
tracing::debug!(
user_id = %user_id,
email = ?email,
"Extracted auth context from proxy headers"
);
Some(crate::server::auth::AuthContext {
subject: user_id,
scopes: vec![],
claims,
token: None,
client_id: None,
expires_at: None,
authenticated: true,
})
}
async fn handle_post_fast_path(
state: ServerState,
request: axum::extract::Request<Body>,
) -> Response {
let (parts, body) = request.into_parts();
let headers = parts.headers;
let body_bytes = match axum::body::to_bytes(body, state.config.max_request_bytes).await {
Ok(b) => b,
Err(e) => {
return create_error_response(
StatusCode::PAYLOAD_TOO_LARGE,
-32600,
&format!("Request body exceeds limit: {}", e),
);
},
};
let body = String::from_utf8_lossy(&body_bytes).to_string();
if let Err(error_response) = validate_headers(&headers, "POST") {
return error_response;
}
let message: TransportMessage =
match crate::shared::StdioTransport::parse_message(body.as_bytes()) {
Ok(msg) => msg,
Err(e) => {
return create_error_response(
StatusCode::BAD_REQUEST,
-32700,
&format!("Invalid JSON: {}", e),
);
},
};
let session_id = headers
.get(MCP_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let protocol_version = headers
.get(MCP_PROTOCOL_VERSION)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let is_init_request = matches!(
&message,
TransportMessage::Request { request: Request::Client(boxed), .. }
if matches!(**boxed, ClientRequest::Initialize(_))
);
let (response_session_id, _is_new_session) = if is_init_request {
match process_init_session(&state, session_id.clone(), protocol_version.clone()) {
Ok(result) => result,
Err(error_response) => return error_response,
}
} else {
match validate_non_init_session(&state, session_id.clone()) {
Ok(sid) => (sid, false),
Err(error_response) => return error_response,
}
};
if !is_init_request {
if let Err(error_response) =
validate_protocol_version(&state, session_id.as_ref(), protocol_version.as_ref())
{
return error_response;
}
}
let auth_context = match extract_and_validate_auth(&state, &headers).await {
Ok(ctx) => ctx,
Err(response) => return response,
};
match message {
TransportMessage::Request { id, request } => {
let server = state.server.lock().await;
let json_response = server.handle_request(id, request, auth_context).await;
tracing::debug!(
target: "mcp.http",
response = %serde_json::to_string(&json_response).unwrap_or_default(),
"StreamableHttpServer response"
);
let response = TransportMessage::Response(json_response.clone());
let negotiated_version = if is_init_request {
let version = extract_negotiated_version(&response);
update_session_after_init(&state, response_session_id.as_ref(), version.clone());
version
} else {
None
};
if let Some(event_store) = &state.config.event_store {
if let Some(sid) = &response_session_id {
let event_id = Uuid::new_v4().to_string();
let _ = event_store.store_event(sid, &event_id, &response).await;
}
}
let mut response = build_response(&state, response, session_id.as_ref());
if let Some(sid) = &response_session_id {
response
.headers_mut()
.insert(MCP_SESSION_ID, sid.parse().unwrap());
}
let version_to_send = if is_init_request {
negotiated_version.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string())
} else {
if let Some(ref sid) = response_session_id {
if let Some(session_info) = state.sessions.read().get(sid) {
session_info
.protocol_version
.clone()
.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string())
} else {
crate::DEFAULT_PROTOCOL_VERSION.to_string()
}
} else {
crate::DEFAULT_PROTOCOL_VERSION.to_string()
}
};
response
.headers_mut()
.insert(MCP_PROTOCOL_VERSION, version_to_send.parse().unwrap());
response
},
TransportMessage::Notification { .. } => {
StatusCode::ACCEPTED.into_response()
},
TransportMessage::Response(_) => StatusCode::ACCEPTED.into_response(),
}
}
/// Handles a POST request when an HTTP middleware chain is configured.
///
/// The request is converted into the middleware's request type, run through
/// `process_request`, then validated and dispatched much like the fast path. The final
/// response is run through `process_response` before being converted back to an axum
/// response. Most early rejections are reported to the chain via `handle_error`.
///
/// # Panics
///
/// Panics if no middleware chain is configured; the POST entry point only calls this
/// function after checking that one is.
#[allow(clippy::cognitive_complexity)]
async fn handle_post_with_middleware(
state: ServerState,
request: axum::extract::Request<Body>,
) -> Response {
let http_middleware = state
.config
.http_middleware
.as_ref()
.expect("Middleware chain must exist");
let (parts, body) = request.into_parts();
// Any conversion failure is reported to the client as "payload too large".
let mut server_request =
match from_axum_with_limit(parts, body, state.config.max_request_bytes).await {
Ok(req) => req,
Err(e) => {
return create_error_response(
StatusCode::PAYLOAD_TOO_LARGE,
-32600,
&format!("Request body exceeds limit: {}", e),
);
},
};
// The session ID is read before the request middleware runs, so a middleware that rewrites
// this header does not affect the value used below.
let session_id = server_request
.get_header(MCP_SESSION_ID)
.map(|s| s.to_string());
// Reuse the caller's request ID when present; otherwise generate one.
let request_id = server_request
.get_header("x-request-id")
.map_or_else(|| Uuid::new_v4().to_string(), |s| s.to_string());
let http_context = ServerHttpContext {
request_id: request_id.clone(),
start_time: std::time::Instant::now(),
session_id: session_id.clone(),
};
// A failing request middleware aborts the request with a 500 after notifying the chain.
if let Err(e) = http_middleware
.process_request(&mut server_request, &http_context)
.await
{
let _ = http_middleware.handle_error(&e, &http_context).await;
return create_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
-32603,
&format!("Middleware rejected request: {}", e),
);
}
// Headers are validated after the middleware ran, i.e. on the possibly modified request.
if let Err(error_response) = validate_headers(&server_request.headers, "POST") {
let validation_error = crate::Error::protocol_msg("Header validation failed");
let _ = http_middleware
.handle_error(&validation_error, &http_context)
.await;
return error_response;
}
let body_str = String::from_utf8_lossy(&server_request.body);
let message: TransportMessage =
match crate::shared::StdioTransport::parse_message(body_str.as_bytes()) {
Ok(msg) => msg,
Err(e) => {
// NOTE(review): this body is assembled by string formatting and is not a JSON-RPC error
// envelope like the other rejections; the error text is not escaped — confirm intent.
let mut error_response = ServerHttpResponse::new(
StatusCode::BAD_REQUEST,
HeaderMap::new(),
format!("{{\"error\":\"Invalid JSON: {}\"}}", e).into_bytes(),
);
let _ = http_middleware
.process_response(&mut error_response, &http_context)
.await;
return into_axum(error_response);
},
};
let protocol_version = server_request
.get_header(MCP_PROTOCOL_VERSION)
.map(|s| s.to_string());
let is_init_request = matches!(
&message,
TransportMessage::Request { request: Request::Client(boxed), .. }
if matches!(**boxed, ClientRequest::Initialize(_))
);
// Initialize requests may create a session; all other requests must reference a known one.
let (response_session_id, _) = if is_init_request {
match process_init_session(&state, session_id.clone(), protocol_version.clone()) {
Ok(result) => result,
Err(error_response) => {
let session_error = crate::Error::protocol_msg("Session initialization failed");
let _ = http_middleware
.handle_error(&session_error, &http_context)
.await;
return error_response;
},
}
} else {
match validate_non_init_session(&state, session_id.clone()) {
Ok(sid) => (sid, false),
Err(error_response) => {
let session_error = crate::Error::protocol_msg("Session validation failed");
let _ = http_middleware
.handle_error(&session_error, &http_context)
.await;
return error_response;
},
}
};
if !is_init_request {
if let Err(error_response) =
validate_protocol_version(&state, session_id.as_ref(), protocol_version.as_ref())
{
let version_error = crate::Error::protocol_msg("Protocol version validation failed");
let _ = http_middleware
.handle_error(&version_error, &http_context)
.await;
return error_response;
}
}
// NOTE(review): unlike the fast path, no proxy-header fallback is applied when the server has
// no auth provider — the auth context is simply `None` here. Confirm this is intended.
let auth_context = {
// The server lock is scoped to this block and released before the request is dispatched.
let server = state.server.lock().await;
if let Some(auth_provider) = server.get_auth_provider() {
let auth_header = server_request.get_header("authorization");
match auth_provider.validate_request(auth_header).await {
Ok(ctx) => ctx,
Err(e) => {
let auth_error =
crate::Error::authentication(format!("Authentication failed: {}", e));
let _ = http_middleware
.handle_error(&auth_error, &http_context)
.await;
return create_error_response(
StatusCode::UNAUTHORIZED,
-32003,
&format!("Authentication failed: {}", e),
);
},
}
} else {
None
}
};
match message {
TransportMessage::Request { id, request } => {
let server = state.server.lock().await;
let json_response = server.handle_request(id, request, auth_context).await;
let response_msg = TransportMessage::Response(json_response.clone());
// For initialize requests, record the negotiated version on the session.
let negotiated_version = if is_init_request {
let version = extract_negotiated_version(&response_msg);
update_session_after_init(&state, response_session_id.as_ref(), version.clone());
version
} else {
None
};
// Store the response for later replay; a storage failure is ignored.
if let Some(event_store) = &state.config.event_store {
if let Some(sid) = &response_session_id {
let event_id = Uuid::new_v4().to_string();
let _ = event_store.store_event(sid, &event_id, &response_msg).await;
}
}
// NOTE(review): this path always answers with a 200 JSON body; `enable_json_response` and
// the session's SSE stream are not consulted here, unlike the fast path — confirm.
let response_body = match serde_json::to_vec(&response_msg) {
Ok(b) => b,
Err(e) => {
let serialization_error =
crate::Error::internal(format!("Failed to serialize response: {}", e));
let _ = http_middleware
.handle_error(&serialization_error, &http_context)
.await;
return create_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
-32603,
&format!("Failed to serialize response: {}", e),
);
},
};
let mut response_headers = HeaderMap::new();
response_headers.insert(header::CONTENT_TYPE, APPLICATION_JSON.parse().unwrap());
if let Some(sid) = &response_session_id {
response_headers.insert(MCP_SESSION_ID, sid.parse().unwrap());
}
// Initialize replies advertise the negotiated version; other replies repeat the version
// recorded for the session. Both fall back to the crate default.
let version_to_send = if is_init_request {
negotiated_version.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string())
} else if let Some(ref sid) = response_session_id {
if let Some(session_info) = state.sessions.read().get(sid) {
session_info
.protocol_version
.clone()
.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string())
} else {
crate::DEFAULT_PROTOCOL_VERSION.to_string()
}
} else {
crate::DEFAULT_PROTOCOL_VERSION.to_string()
};
response_headers.insert(MCP_PROTOCOL_VERSION, version_to_send.parse().unwrap());
let mut server_response =
ServerHttpResponse::new(StatusCode::OK, response_headers, response_body);
// A failing response middleware is only logged; the response is still sent.
if let Err(e) = http_middleware
.process_response(&mut server_response, &http_context)
.await
{
tracing::warn!("Response middleware processing failed: {}", e);
}
into_axum(server_response)
},
// Notifications and responses are acknowledged without being passed to the server or to the
// response middleware.
TransportMessage::Notification { .. } => StatusCode::ACCEPTED.into_response(),
TransportMessage::Response(_) => StatusCode::ACCEPTED.into_response(),
}
}
async fn handle_get_sse(State(state): State<ServerState>, headers: HeaderMap) -> impl IntoResponse {
if let Err(error_response) = validate_headers(&headers, "GET") {
return error_response;
}
let session_id = headers
.get(MCP_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let session_id = if let Some(sid) = session_id {
if state.config.session_id_generator.is_some() && !state.sessions.read().contains_key(&sid)
{
return create_error_response(StatusCode::NOT_FOUND, -32600, "Unknown session ID");
}
sid
} else if let Some(generator) = &state.config.session_id_generator {
let new_id = generator();
state.sessions.write().insert(
new_id.clone(),
SessionInfo {
initialized: true, protocol_version: None,
},
);
if let Some(callback) = &state.config.on_session_initialized {
callback(&new_id);
}
new_id
} else {
return create_error_response(
StatusCode::METHOD_NOT_ALLOWED,
-32601,
"SSE not supported in stateless mode",
);
};
if state.sse_streams.read().contains_key(&session_id) {
return create_error_response(
StatusCode::CONFLICT,
-32600,
"SSE stream already exists for this session",
);
}
let (tx, rx) = mpsc::unbounded_channel();
state
.sse_streams
.write()
.insert(session_id.clone(), tx.clone());
if let Some(last_event_id) = headers.get(LAST_EVENT_ID) {
if let Ok(last_id) = last_event_id.to_str() {
if let Some(event_store) = &state.config.event_store {
if let Ok(events) = event_store.replay_events_after(last_id).await {
for (_event_id, msg) in events {
let _ = tx.send(msg);
}
}
}
}
}
let stream = UnboundedReceiverStream::new(rx);
let session_id_header = session_id.clone();
let sse = Sse::new(stream.map(move |msg| {
let event_id = Uuid::new_v4().to_string();
if let Some(event_store) = &state.config.event_store {
let sid = session_id.clone();
let msg_clone = msg.clone();
let store = event_store.clone();
let event_id_clone = event_id.clone();
tokio::spawn(async move {
let _ = store.store_event(&sid, &event_id_clone, &msg_clone).await;
});
}
Ok::<_, Infallible>(
Event::default()
.id(event_id)
.event("message")
.data(serde_json::to_string(&msg).unwrap()),
)
}));
let mut response = sse.into_response();
response
.headers_mut()
.insert(MCP_SESSION_ID, session_id_header.parse().unwrap());
response.headers_mut().insert(
header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-transform"),
);
response
.headers_mut()
.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
response
}
/// Closes a session on DELETE.
///
/// Responds with 404 when the session ID header is missing, or when sessions are tracked and
/// the ID is unknown. Otherwise the session's SSE stream and session entry are removed, the
/// `on_session_closed` callback is invoked, and `{"status": "ok"}` is returned.
async fn handle_delete_session(
    State(state): State<ServerState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let Some(sid) = headers.get(MCP_SESSION_ID).and_then(|v| v.to_str().ok()) else {
        return create_error_response(StatusCode::NOT_FOUND, -32600, "No session ID provided");
    };

    let tracks_sessions = state.config.session_id_generator.is_some();
    if tracks_sessions && !state.sessions.read().contains_key(sid) {
        return create_error_response(StatusCode::NOT_FOUND, -32600, "Unknown session ID");
    }

    state.sse_streams.write().remove(sid);
    state.sessions.write().remove(sid);
    if let Some(callback) = &state.config.on_session_closed {
        callback(sid);
    }
    (StatusCode::OK, Json(json!({"status": "ok"}))).into_response()
}