use crate::error::Result;
use crate::server::http_middleware::{
adapters::{from_axum_with_limit, into_axum},
ServerHttpContext, ServerHttpMiddlewareChain, ServerHttpResponse,
};
use crate::server::tower_layers::{AllowedOrigins, DnsRebindingLayer, SecurityHeadersLayer};
use crate::server::Server;
use crate::shared::http_constants::{
APPLICATION_JSON, LAST_EVENT_ID, MCP_PROTOCOL_VERSION, MCP_SESSION_ID, TEXT_EVENT_STREAM,
};
use crate::shared::TransportMessage;
use crate::types::{ClientRequest, Request};
use async_trait::async_trait;
use axum::{
body::Body,
extract::State,
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{sse::Event, IntoResponse, Response, Sse},
routing::{delete, get, post},
Json, Router,
};
use futures_util::StreamExt;
use parking_lot::RwLock;
use serde_json::json;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
/// Storage backend for SSE events, used to re-send missed events when a client
/// reconnects with a `Last-Event-ID` header.
///
/// Events are identified by an event id and grouped by a stream id; the callers in
/// this file pass the MCP session id as the stream id.
#[async_trait]
pub trait EventStore: Send + Sync {
/// Records `message` under `event_id` in the stream `stream_id`.
async fn store_event(
&self,
stream_id: &str,
event_id: &str,
message: &TransportMessage,
) -> Result<()>;
/// Returns the `(event_id, message)` pairs that should be re-sent to a client whose
/// last seen event was `last_event_id`.
/// NOTE(review): whether the result is restricted to `last_event_id`'s own stream is
/// left to the implementation — confirm for the implementor in use.
async fn replay_events_after(
&self,
last_event_id: &str,
) -> Result<Vec<(String, TransportMessage)>>;
/// Returns the stream id `event_id` was stored under, or `None` if it is unknown.
async fn get_stream_for_event(&self, event_id: &str) -> Result<Option<String>>;
}
/// Events of a single stream as `(event_id, message)` pairs, in insertion order.
type EventList = Vec<(String, TransportMessage)>;
/// Stream id -> events stored for that stream.
type EventsMap = HashMap<String, EventList>;
/// Process-local [`EventStore`] that keeps every event in memory.
///
/// NOTE(review): the visible code never removes entries, so memory grows with every
/// stored event for the lifetime of the store.
#[derive(Debug, Default)]
pub struct InMemoryEventStore {
// Stream id -> that stream's events, in insertion order.
events: Arc<RwLock<EventsMap>>,
// Event id -> the stream id it was stored under.
event_to_stream: Arc<RwLock<HashMap<String, String>>>,
// Every stored event id, in global insertion order.
event_order: Arc<RwLock<Vec<String>>>,
}
#[async_trait]
impl EventStore for InMemoryEventStore {
    /// Appends `message` to stream `stream_id` and indexes it by `event_id`.
    ///
    /// All three maps are locked together, always in the order
    /// `events` -> `event_to_stream` -> `event_order`. This has two effects:
    /// - readers never observe a half-recorded event;
    /// - the acquisition order matches `replay_events_after`, so the two methods cannot
    ///   deadlock on each other. The previous code held `events` while taking
    ///   `event_order`, whereas replay held `event_order` while taking `events`.
    async fn store_event(
        &self,
        stream_id: &str,
        event_id: &str,
        message: &TransportMessage,
    ) -> Result<()> {
        let mut events = self.events.write();
        let mut event_to_stream = self.event_to_stream.write();
        let mut event_order = self.event_order.write();
        events
            .entry(stream_id.to_string())
            .or_default()
            .push((event_id.to_string(), message.clone()));
        event_to_stream.insert(event_id.to_string(), stream_id.to_string());
        event_order.push(event_id.to_string());
        Ok(())
    }

    /// Returns the events stored after `last_event_id` **in the same stream**, oldest first.
    ///
    /// Only the stream that `last_event_id` belongs to is replayed; other streams (other
    /// sessions) are never included. An unknown `last_event_id` yields an empty list
    /// rather than the whole store, so a bogus `Last-Event-ID` cannot dump every
    /// session's messages.
    async fn replay_events_after(
        &self,
        last_event_id: &str,
    ) -> Result<Vec<(String, TransportMessage)>> {
        // Same lock order as `store_event`.
        let events = self.events.read();
        let event_to_stream = self.event_to_stream.read();
        let Some(stream_events) = event_to_stream
            .get(last_event_id)
            .and_then(|stream_id| events.get(stream_id))
        else {
            return Ok(Vec::new());
        };
        // Per-stream lists are kept in insertion order, so everything after the
        // position of `last_event_id` is exactly what the client missed.
        let start = stream_events
            .iter()
            .position(|(event_id, _)| event_id == last_event_id)
            .map_or(stream_events.len(), |pos| pos + 1);
        Ok(stream_events[start..].to_vec())
    }

    /// Returns the stream id `event_id` was stored under, if known.
    async fn get_stream_for_event(&self, event_id: &str) -> Result<Option<String>> {
        Ok(self.event_to_stream.read().get(event_id).cloned())
    }
}
/// Callback invoked with a session id.
type SessionCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Configuration for [`StreamableHttpServer`].
pub struct StreamableHttpServerConfig {
/// Produces new session ids. `None` runs the server statelessly: no session ids are
/// issued or required, and GET/SSE without a session id is rejected.
pub session_id_generator: Option<Box<dyn Fn() -> String + Send + Sync>>,
/// When `true`, the non-middleware POST path returns responses as plain JSON bodies
/// instead of via SSE.
pub enable_json_response: bool,
/// Store that records responses and SSE events (keyed by session id) for
/// `Last-Event-ID` replay; `None` disables recording and replay.
pub event_store: Option<Arc<InMemoryEventStore>>,
/// Called with the id of each newly created session, at creation time (for POST
/// initialize this is before the initialize request has been handled).
pub on_session_initialized: Option<SessionCallback>,
/// Session-close callback; its call site is not in this part of the file.
pub on_session_closed: Option<SessionCallback>,
/// Optional HTTP middleware chain; when set, POST requests take the middleware path.
pub http_middleware: Option<Arc<ServerHttpMiddlewareChain>>,
/// Origins accepted by the DNS-rebinding and CORS layers; `None` means localhost only.
pub allowed_origins: Option<AllowedOrigins>,
/// Maximum accepted request body size in bytes.
pub max_request_bytes: usize,
}
impl std::fmt::Debug for StreamableHttpServerConfig {
    /// Prints the configuration; closure- and trait-object-valued fields are not
    /// `Debug`, so only their presence is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_generator = self.session_id_generator.is_some();
        let has_event_store = self.event_store.is_some();
        let has_init_callback = self.on_session_initialized.is_some();
        let has_close_callback = self.on_session_closed.is_some();
        let has_middleware = self.http_middleware.is_some();

        let mut out = f.debug_struct("StreamableHttpServerConfig");
        out.field("session_id_generator", &has_generator);
        out.field("enable_json_response", &self.enable_json_response);
        out.field("event_store", &has_event_store);
        out.field("on_session_initialized", &has_init_callback);
        out.field("on_session_closed", &has_close_callback);
        out.field("http_middleware", &has_middleware);
        out.field("allowed_origins", &self.allowed_origins);
        out.field("max_request_bytes", &self.max_request_bytes);
        out.finish()
    }
}
impl Default for StreamableHttpServerConfig {
    /// Stateful defaults: UUID v4 session ids, SSE responses, an in-memory event store,
    /// no callbacks or middleware, and no explicit origins (the server state then falls
    /// back to localhost-only origins).
    fn default() -> Self {
        let generate_uuid = || Uuid::new_v4().to_string();
        Self {
            max_request_bytes: crate::server::limits::DEFAULT_MAX_REQUEST_BYTES,
            allowed_origins: None,
            http_middleware: None,
            on_session_closed: None,
            on_session_initialized: None,
            event_store: Some(Arc::new(InMemoryEventStore::default())),
            enable_json_response: false,
            session_id_generator: Some(Box::new(generate_uuid)),
        }
    }
}
impl StreamableHttpServerConfig {
    /// Configuration for stateless operation.
    ///
    /// No session ids, no event store, plain JSON responses, and any origin accepted.
    /// All remaining fields keep their [`Default`] values.
    pub fn stateless() -> Self {
        Self {
            session_id_generator: None,
            enable_json_response: true,
            event_store: None,
            allowed_origins: Some(AllowedOrigins::any()),
            ..Self::default()
        }
    }
}
/// Per-session state tracked by a stateful server.
#[derive(Debug, Clone)]
struct SessionInfo {
// Set once an initialize request for the session has been handled, or immediately
// for sessions opened via GET without a session id.
initialized: bool,
// Negotiated protocol version; later requests' protocol-version header must match it.
protocol_version: Option<String>,
}
/// Shared state handed to the MCP axum handlers.
#[derive(Clone)]
pub(crate) struct ServerState {
// The MCP server; locked for each request dispatch and each auth-provider lookup.
server: Arc<tokio::sync::Mutex<Server>>,
// Immutable transport configuration.
config: Arc<StreamableHttpServerConfig>,
// Resolved origin allow-list (config value, or localhost when unset).
allowed_origins: AllowedOrigins,
// Session id -> sender feeding that session's open SSE stream (presumably registered
// by the GET handler, which is outside this view).
sse_streams: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<TransportMessage>>>>,
// Session id -> session bookkeeping.
sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
}
/// Builds the MCP endpoint router: POST, GET (SSE) and DELETE all served on `/`.
pub(crate) fn build_mcp_router(state: ServerState) -> Router<()> {
    let mcp_endpoint = post(handle_post_request)
        .merge(get(handle_get_sse))
        .merge(delete(handle_delete_session));
    Router::new().route("/", mcp_endpoint).with_state(state)
}
/// Creates fresh handler state (no sessions, no SSE streams) for `server` and `config`.
pub(crate) fn make_server_state(
    server: Arc<tokio::sync::Mutex<Server>>,
    config: StreamableHttpServerConfig,
) -> ServerState {
    // Without an explicit allow-list only localhost origins are accepted.
    let allowed_origins = match &config.allowed_origins {
        Some(origins) => origins.clone(),
        None => AllowedOrigins::localhost(),
    };
    let sse_streams = Arc::new(RwLock::new(HashMap::new()));
    let sessions = Arc::new(RwLock::new(HashMap::new()));
    ServerState {
        server,
        config: Arc::new(config),
        allowed_origins,
        sse_streams,
        sessions,
    }
}
/// Standalone axum server exposing the MCP Streamable HTTP transport on the `/` route.
pub struct StreamableHttpServer {
// Address handed to `TcpListener::bind` in `start`; `start` returns the bound address.
addr: SocketAddr,
state: ServerState,
}
impl std::fmt::Debug for StreamableHttpServer {
    /// Prints the bind address; the handler state is summarised by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const STATE_PLACEHOLDER: &str = "ServerState { ... }";
        let mut builder = f.debug_struct("StreamableHttpServer");
        builder.field("addr", &self.addr);
        builder.field("state", &STATE_PLACEHOLDER);
        builder.finish()
    }
}
fn create_error_response(status: StatusCode, code: i32, message: &str) -> Response {
let error_body = json!({
"jsonrpc": "2.0",
"error": {
"code": code,
"message": message
},
"id": null
});
(status, Json(error_body)).into_response()
}
impl StreamableHttpServer {
pub fn new(addr: SocketAddr, server: Arc<tokio::sync::Mutex<Server>>) -> Self {
Self::with_config(addr, server, StreamableHttpServerConfig::default())
}
pub fn with_config(
addr: SocketAddr,
server: Arc<tokio::sync::Mutex<Server>>,
config: StreamableHttpServerConfig,
) -> Self {
let state = make_server_state(server, config);
Self { addr, state }
}
pub async fn start(self) -> Result<(SocketAddr, tokio::task::JoinHandle<()>)> {
let allowed = self.state.allowed_origins.clone();
let cors = crate::server::tower_layers::build_mcp_cors_layer(&allowed);
let app = build_mcp_router(self.state)
.layer(SecurityHeadersLayer::default())
.layer(DnsRebindingLayer::new(allowed))
.layer(cors);
let listener = tokio::net::TcpListener::bind(self.addr).await?;
let local_addr = listener.local_addr()?;
let server_task = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
Ok((local_addr, server_task))
}
}
/// Rejects a request whose `Content-Type` is missing or does not name JSON.
///
/// Media types are case-insensitive (RFC 9110), so both sides are ASCII-lowercased
/// before comparing. A header value that is not visible ASCII is treated as empty and
/// therefore rejected.
fn validate_content_type_json(headers: &HeaderMap) -> std::result::Result<(), Response> {
    let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
        return Err(create_error_response(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            -32700,
            "Content-Type header is required",
        ));
    };
    let ct = content_type.to_str().unwrap_or("").to_ascii_lowercase();
    if !ct.contains(&APPLICATION_JSON.to_ascii_lowercase()) {
        return Err(create_error_response(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            -32700,
            "Content-Type must be application/json",
        ));
    }
    Ok(())
}
/// Rejects a POST whose `Accept` header is missing or lists neither JSON nor SSE.
///
/// Media types are compared case-insensitively (RFC 9110).
fn validate_accept_post(headers: &HeaderMap) -> std::result::Result<(), Response> {
    let Some(accept) = headers.get(header::ACCEPT) else {
        return Err(create_error_response(
            StatusCode::NOT_ACCEPTABLE,
            -32700,
            "Accept header is required",
        ));
    };
    let accept_str = accept.to_str().unwrap_or("").to_ascii_lowercase();
    let accepts_json = accept_str.contains(&APPLICATION_JSON.to_ascii_lowercase());
    let accepts_sse = accept_str.contains(&TEXT_EVENT_STREAM.to_ascii_lowercase());
    if !accepts_json && !accepts_sse {
        return Err(create_error_response(
            StatusCode::NOT_ACCEPTABLE,
            -32700,
            "Accept header must include application/json or text/event-stream",
        ));
    }
    Ok(())
}
/// Rejects a GET (SSE) request whose `Accept` header is missing or omits
/// `text/event-stream`.
///
/// Media types are compared case-insensitively (RFC 9110).
fn validate_accept_sse(headers: &HeaderMap) -> std::result::Result<(), Response> {
    let Some(accept) = headers.get(header::ACCEPT) else {
        return Err(create_error_response(
            StatusCode::NOT_ACCEPTABLE,
            -32700,
            "Accept header is required for SSE",
        ));
    };
    let accept_str = accept.to_str().unwrap_or("").to_ascii_lowercase();
    if !accept_str.contains(&TEXT_EVENT_STREAM.to_ascii_lowercase()) {
        return Err(create_error_response(
            StatusCode::NOT_ACCEPTABLE,
            -32700,
            "Accept header must be text/event-stream for SSE",
        ));
    }
    Ok(())
}
/// Applies the header checks required for `method`.
///
/// POST needs a JSON Content-Type and a suitable Accept header; GET needs an SSE
/// Accept header; other methods are not checked here.
fn validate_headers(headers: &HeaderMap, method: &str) -> std::result::Result<(), Response> {
    if method == "POST" {
        validate_content_type_json(headers)?;
        return validate_accept_post(headers);
    }
    if method == "GET" {
        return validate_accept_sse(headers);
    }
    Ok(())
}
/// Resolves the session for an `initialize` request.
///
/// Returns `(session_id, newly_created)`:
/// - stateless mode yields `(None, false)`;
/// - a supplied session id is reused unless it already belongs to an initialized
///   session, in which case the request is rejected with 400;
/// - otherwise a new session is generated, stored as not-yet-initialized, and
///   announced via `on_session_initialized`.
fn process_init_session(
    state: &ServerState,
    session_id: Option<String>,
    protocol_version: Option<String>,
) -> std::result::Result<(Option<String>, bool), Response> {
    let Some(generator) = state.config.session_id_generator.as_ref() else {
        return Ok((None, false));
    };

    if let Some(sid) = session_id {
        let already_initialized = matches!(
            state.sessions.read().get(&sid),
            Some(info) if info.initialized
        );
        if already_initialized {
            return Err(create_error_response(
                StatusCode::BAD_REQUEST,
                -32600,
                "Session already initialized",
            ));
        }
        // NOTE(review): a client-supplied id that is not in `sessions` is echoed back
        // without being registered, so later requests carrying it fail with
        // "Unknown session ID". Confirm whether it should be rejected or registered.
        return Ok((Some(sid), false));
    }

    let fresh_id = generator();
    let info = SessionInfo {
        initialized: false,
        protocol_version,
    };
    state.sessions.write().insert(fresh_id.clone(), info);
    if let Some(on_init) = &state.config.on_session_initialized {
        on_init(&fresh_id);
    }
    Ok((Some(fresh_id), true))
}
/// Checks the session header of a non-initialize request.
///
/// Stateless servers accept anything and return `None`. Stateful servers require a
/// session id (400 if missing) that refers to a known session (404 otherwise).
fn validate_non_init_session(
    state: &ServerState,
    session_id: Option<String>,
) -> std::result::Result<Option<String>, Response> {
    if state.config.session_id_generator.is_none() {
        return Ok(None);
    }
    let Some(sid) = session_id else {
        return Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32600,
            "Session ID required for non-initialization requests",
        ));
    };
    if state.sessions.read().contains_key(&sid) {
        Ok(Some(sid))
    } else {
        Err(create_error_response(
            StatusCode::NOT_FOUND,
            -32600,
            "Unknown session ID",
        ))
    }
}
/// Pulls the negotiated protocol version out of a successful `initialize` response.
///
/// Returns `None` for non-response messages, error payloads, or results that do not
/// deserialize as `InitializeResult`.
fn extract_negotiated_version(response: &TransportMessage) -> Option<String> {
    let TransportMessage::Response(json_resp) = response else {
        return None;
    };
    let crate::types::jsonrpc::ResponsePayload::Result(value) = &json_resp.payload else {
        return None;
    };
    serde_json::from_value::<crate::types::InitializeResult>(value.clone())
        .ok()
        .map(|init_result| init_result.protocol_version.0)
}
/// Marks the session as initialized and records its protocol version.
///
/// If no version was negotiated, `DEFAULT_PROTOCOL_VERSION` is recorded. Does nothing
/// when there is no session id or the session is unknown.
fn update_session_after_init(
    state: &ServerState,
    session_id: Option<&String>,
    negotiated_version: Option<String>,
) {
    let Some(sid) = session_id else {
        return;
    };
    let mut sessions = state.sessions.write();
    if let Some(info) = sessions.get_mut(sid) {
        info.initialized = true;
        let version = negotiated_version
            .unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string());
        info.protocol_version = Some(version);
    }
}
/// Serializes `response` with the transport serializer and re-parses it as a JSON value.
///
/// Either failure becomes a 500 JSON-RPC internal error (-32603).
fn serialize_response_as_json_value(
    response: &TransportMessage,
) -> std::result::Result<serde_json::Value, Response> {
    let encoded = match crate::shared::StdioTransport::serialize_message(response) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(create_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                -32603,
                &format!("Failed to serialize response: {}", e),
            ));
        },
    };
    tracing::debug!(
        target: "mcp.http",
        response = %String::from_utf8_lossy(&encoded),
        "HTTP response serialized bytes"
    );
    serde_json::from_slice::<serde_json::Value>(&encoded).map_err(|e| {
        create_error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            -32603,
            &format!("Failed to parse JSON response: {}", e),
        )
    })
}
/// Returns `response` as a 200 JSON body.
///
/// Serialization errors are returned as their 500 error response. `trace_source` only
/// labels the debug trace.
fn build_json_response(response: &TransportMessage, trace_source: &'static str) -> Response {
    match serialize_response_as_json_value(response) {
        Err(error_response) => error_response,
        Ok(json_value) => {
            tracing::debug!(
                target: "mcp.http",
                source = trace_source,
                response = %serde_json::to_string(&json_value).unwrap_or_default(),
                "HTTP response (JSON mode)"
            );
            (StatusCode::OK, Json(json_value)).into_response()
        },
    }
}
fn build_sse_response_from_single_message(response: TransportMessage) -> Response {
let (tx, rx) = mpsc::unbounded_channel();
tx.send(response).unwrap();
let stream = UnboundedReceiverStream::new(rx);
let sse = Sse::new(stream.map(|msg| {
let event_id = Uuid::new_v4().to_string();
let json_bytes =
crate::shared::StdioTransport::serialize_message(&msg).unwrap_or_else(|e| {
tracing::error!(target: "mcp.sse", error = %e, "Failed to serialize SSE message");
Vec::new()
});
let json_str = String::from_utf8(json_bytes).unwrap_or_else(|_| "{}".to_string());
Ok::<_, Infallible>(
Event::default()
.id(event_id)
.event("message")
.data(json_str),
)
}));
sse.into_response()
}
fn build_response(
state: &ServerState,
response: TransportMessage,
session_id: Option<&String>,
) -> Response {
if state.config.enable_json_response {
return build_json_response(&response, "JSON mode");
}
let Some(sid) = session_id else {
return build_json_response(&response, "SSE no-session fallback");
};
if let Some(sender) = state.sse_streams.read().get(sid) {
let _ = sender.send(response);
return StatusCode::ACCEPTED.into_response();
}
build_sse_response_from_single_message(response)
}
/// Rejects (400) a protocol-version header that is not in `SUPPORTED_PROTOCOL_VERSIONS`.
///
/// An absent header is accepted.
fn validate_protocol_version_supported(
    protocol_version: Option<&String>,
) -> std::result::Result<(), Response> {
    match protocol_version {
        Some(version) if !crate::SUPPORTED_PROTOCOL_VERSIONS.contains(&version.as_str()) => {
            Err(create_error_response(
                StatusCode::BAD_REQUEST,
                -32600,
                &format!("Unsupported protocol version: {}", version),
            ))
        },
        _ => Ok(()),
    }
}
/// In stateful mode, rejects (400) a protocol-version header that differs from the
/// version negotiated for the session.
///
/// Missing session id, unknown session, missing negotiated version, or a missing
/// header all pass.
fn validate_protocol_version_matches_session(
    state: &ServerState,
    session_id: Option<&String>,
    protocol_version: Option<&String>,
) -> std::result::Result<(), Response> {
    if state.config.session_id_generator.is_none() {
        return Ok(());
    }
    let (Some(sid), Some(provided)) = (session_id, protocol_version) else {
        return Ok(());
    };
    let sessions = state.sessions.read();
    let negotiated = sessions
        .get(sid.as_str())
        .and_then(|info| info.protocol_version.as_ref());
    match negotiated {
        Some(expected) if expected != provided => Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32600,
            &format!(
                "Protocol version mismatch: expected {}, got {}",
                expected, provided
            ),
        )),
        _ => Ok(()),
    }
}
/// Runs both protocol-version checks: first support, then agreement with the session.
fn validate_protocol_version(
    state: &ServerState,
    session_id: Option<&String>,
    protocol_version: Option<&String>,
) -> std::result::Result<(), Response> {
    validate_protocol_version_supported(protocol_version).and_then(|()| {
        validate_protocol_version_matches_session(state, session_id, protocol_version)
    })
}
/// POST entry point.
///
/// Uses the middleware-aware path when an HTTP middleware chain is configured, and the
/// leaner fast path otherwise.
async fn handle_post_request(
    State(state): State<ServerState>,
    request: axum::extract::Request<Body>,
) -> impl IntoResponse {
    let has_middleware = state.config.http_middleware.is_some();
    if has_middleware {
        handle_post_with_middleware(state, request).await
    } else {
        handle_post_fast_path(state, request).await
    }
}
/// Authenticates a fast-path request.
///
/// With an auth provider, the `Authorization` header is validated (401 / -32003 on
/// failure). Without one, identity is read from `x-pmcp-*` proxy headers.
///
/// The server mutex stays locked for the whole call, including the provider's
/// validation.
async fn extract_and_validate_auth(
    state: &ServerState,
    headers: &HeaderMap,
) -> std::result::Result<Option<crate::server::auth::AuthContext>, Response> {
    let server = state.server.lock().await;
    let Some(auth_provider) = server.get_auth_provider() else {
        return Ok(extract_auth_from_proxy_headers(headers));
    };
    let auth_header = headers
        .get(http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok());
    auth_provider
        .validate_request(auth_header)
        .await
        .map_err(|e| {
            create_error_response(
                StatusCode::UNAUTHORIZED,
                -32003,
                &format!("Authentication failed: {}", e),
            )
        })
}
fn extract_auth_from_proxy_headers(
headers: &HeaderMap,
) -> Option<crate::server::auth::AuthContext> {
let user_id = headers
.get("x-pmcp-user-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())?;
let email = headers
.get("x-pmcp-user-email")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let name = headers
.get("x-pmcp-user-name")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let groups = headers
.get("x-pmcp-user-groups")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let tenant_id = headers
.get("x-pmcp-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let mut claims = std::collections::HashMap::new();
if let Some(ref email) = email {
claims.insert(
"email".to_string(),
serde_json::Value::String(email.clone()),
);
}
if let Some(ref name) = name {
claims.insert("name".to_string(), serde_json::Value::String(name.clone()));
}
if let Some(ref groups) = groups {
let groups_array: Vec<serde_json::Value> = groups
.split(',')
.map(|g| serde_json::Value::String(g.trim().to_string()))
.filter(|v| v.as_str() != Some(""))
.collect();
claims.insert("groups".to_string(), serde_json::Value::Array(groups_array));
}
if let Some(ref tenant_id) = tenant_id {
claims.insert(
"tenant_id".to_string(),
serde_json::Value::String(tenant_id.clone()),
);
}
for (name, value) in headers {
let Some(suffix) = name.as_str().strip_prefix("x-pmcp-claim-custom-") else {
continue;
};
let Ok(val_str) = value.to_str() else {
continue;
};
if suffix.is_empty() || val_str.is_empty() {
continue;
}
let snake: String = suffix
.chars()
.map(|c| if c == '-' { '_' } else { c })
.collect();
claims.insert(
format!("custom:{}", snake),
serde_json::Value::String(val_str.to_string()),
);
}
tracing::debug!(
user_id = %user_id,
email = ?email,
"Extracted auth context from proxy headers"
);
Some(crate::server::auth::AuthContext {
subject: user_id,
scopes: vec![],
claims,
token: None,
client_id: None,
expires_at: None,
authenticated: true,
})
}
/// Reads the MCP session-id and protocol-version headers as owned strings.
///
/// Missing or non-visible-ASCII values become `None`.
fn extract_session_and_protocol_headers(headers: &HeaderMap) -> (Option<String>, Option<String>) {
    fn header_text(headers: &HeaderMap, name: &str) -> Option<String> {
        let value = headers.get(name)?;
        value.to_str().ok().map(str::to_owned)
    }
    (
        header_text(headers, MCP_SESSION_ID),
        header_text(headers, MCP_PROTOCOL_VERSION),
    )
}
/// True when `message` is a client `initialize` request.
fn is_initialize_request(message: &TransportMessage) -> bool {
    match message {
        TransportMessage::Request {
            request: Request::Client(client_request),
            ..
        } => matches!(**client_request, ClientRequest::Initialize(_)),
        _ => false,
    }
}
/// Resolves the session id to use for a POST.
///
/// Initialize requests may create a session; all other requests must refer to an
/// existing one.
fn resolve_session_for_request(
    state: &ServerState,
    is_init_request: bool,
    session_id: Option<String>,
    protocol_version: Option<String>,
) -> std::result::Result<Option<String>, Response> {
    if !is_init_request {
        return validate_non_init_session(state, session_id);
    }
    process_init_session(state, session_id, protocol_version).map(|(sid, _newly_created)| sid)
}
/// Picks the protocol-version value to send back.
///
/// Initialize responses use the freshly negotiated version. Other responses use the
/// session's stored version. Anything missing falls back to `DEFAULT_PROTOCOL_VERSION`.
fn compute_outbound_protocol_version(
    state: &ServerState,
    response_session_id: Option<&String>,
    is_init_request: bool,
    negotiated_version: Option<&str>,
) -> String {
    let chosen: Option<String> = if is_init_request {
        negotiated_version.map(str::to_owned)
    } else {
        response_session_id.and_then(|sid| {
            state
                .sessions
                .read()
                .get(sid)
                .and_then(|info| info.protocol_version.clone())
        })
    };
    chosen.unwrap_or_else(|| crate::DEFAULT_PROTOCOL_VERSION.to_string())
}
async fn report_middleware_error(
http_middleware: &ServerHttpMiddlewareChain,
context: &ServerHttpContext,
error_kind: &str,
) {
let err = crate::Error::protocol_msg(error_kind);
let _ = http_middleware.handle_error(&err, context).await;
}
/// Runs the request middleware over `server_request`.
///
/// On rejection, the error hook is invoked and a 500 / -32603 response is returned.
async fn run_request_middleware(
    http_middleware: &ServerHttpMiddlewareChain,
    server_request: &mut crate::server::http_middleware::ServerHttpRequest,
    context: &ServerHttpContext,
) -> std::result::Result<(), Response> {
    match http_middleware
        .process_request(server_request, context)
        .await
    {
        Ok(_) => Ok(()),
        Err(rejection) => {
            let _ = http_middleware.handle_error(&rejection, context).await;
            Err(create_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                -32603,
                &format!("Middleware rejected request: {}", rejection),
            ))
        },
    }
}
async fn parse_transport_message_with_middleware(
body: &[u8],
http_middleware: &ServerHttpMiddlewareChain,
context: &ServerHttpContext,
) -> std::result::Result<TransportMessage, Response> {
match crate::shared::StdioTransport::parse_message(body) {
Ok(msg) => Ok(msg),
Err(e) => {
let mut error_response = ServerHttpResponse::new(
StatusCode::BAD_REQUEST,
HeaderMap::new(),
format!("{{\"error\":\"Invalid JSON: {}\"}}", e).into_bytes(),
);
let _ = http_middleware
.process_response(&mut error_response, context)
.await;
Err(into_axum(error_response))
},
}
}
async fn extract_auth_with_middleware(
state: &ServerState,
server_request: &crate::server::http_middleware::ServerHttpRequest,
http_middleware: &ServerHttpMiddlewareChain,
context: &ServerHttpContext,
) -> std::result::Result<Option<crate::server::auth::AuthContext>, Response> {
let server = state.server.lock().await;
let Some(auth_provider) = server.get_auth_provider() else {
return Ok(None);
};
let auth_header = server_request.get_header("authorization");
match auth_provider.validate_request(auth_header).await {
Ok(ctx) => Ok(ctx),
Err(e) => {
let auth_error = crate::Error::authentication(format!("Authentication failed: {}", e));
let _ = http_middleware.handle_error(&auth_error, context).await;
Err(create_error_response(
StatusCode::UNAUTHORIZED,
-32003,
&format!("Authentication failed: {}", e),
))
},
}
}
async fn build_success_response_with_middleware(
response_msg: &TransportMessage,
response_session_id: Option<&String>,
version_to_send: &str,
http_middleware: &ServerHttpMiddlewareChain,
context: &ServerHttpContext,
) -> Response {
let response_body = match serde_json::to_vec(response_msg) {
Ok(b) => b,
Err(e) => {
let serialization_error =
crate::Error::internal(format!("Failed to serialize response: {}", e));
let _ = http_middleware
.handle_error(&serialization_error, context)
.await;
return create_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
-32603,
&format!("Failed to serialize response: {}", e),
);
},
};
let mut response_headers = HeaderMap::new();
response_headers.insert(header::CONTENT_TYPE, APPLICATION_JSON.parse().unwrap());
if let Some(sid) = response_session_id {
response_headers.insert(MCP_SESSION_ID, sid.parse().unwrap());
}
response_headers.insert(MCP_PROTOCOL_VERSION, version_to_send.parse().unwrap());
let mut server_response =
ServerHttpResponse::new(StatusCode::OK, response_headers, response_body);
if let Err(e) = http_middleware
.process_response(&mut server_response, context)
.await
{
tracing::warn!("Response middleware processing failed: {}", e);
}
into_axum(server_response)
}
/// Records `response_msg` in the event store under a fresh UUID event id.
///
/// Runs only when both an event store and a session id are present. Storage errors
/// are ignored.
async fn store_response_event(
    state: &ServerState,
    response_session_id: Option<&String>,
    response_msg: &TransportMessage,
) {
    let (Some(event_store), Some(sid)) = (&state.config.event_store, response_session_id) else {
        return;
    };
    let event_id = Uuid::new_v4().to_string();
    let _ = event_store.store_event(sid, &event_id, response_msg).await;
}
async fn read_body_with_limit(
body: Body,
max_bytes: usize,
) -> std::result::Result<String, Response> {
let body_bytes = axum::body::to_bytes(body, max_bytes).await.map_err(|e| {
create_error_response(
StatusCode::PAYLOAD_TOO_LARGE,
-32600,
&format!("Request body exceeds limit: {}", e),
)
})?;
Ok(String::from_utf8_lossy(&body_bytes).to_string())
}
/// Parses a fast-path body; malformed input yields a 400 JSON-RPC parse error (-32700).
fn parse_transport_message_fast(body: &[u8]) -> std::result::Result<TransportMessage, Response> {
    match crate::shared::StdioTransport::parse_message(body) {
        Ok(message) => Ok(message),
        Err(e) => Err(create_error_response(
            StatusCode::BAD_REQUEST,
            -32700,
            &format!("Invalid JSON: {}", e),
        )),
    }
}
async fn handle_fast_path_request(
state: &ServerState,
id: crate::types::RequestId,
request: Request,
auth_context: Option<crate::server::auth::AuthContext>,
is_init_request: bool,
response_session_id: Option<String>,
session_id: Option<&String>,
) -> Response {
let json_response = {
let server = state.server.lock().await;
server.handle_request(id, request, auth_context).await
};
tracing::debug!(
target: "mcp.http",
response = %serde_json::to_string(&json_response).unwrap_or_default(),
"StreamableHttpServer response"
);
let response_msg = TransportMessage::Response(json_response);
let negotiated_version = if is_init_request {
let version = extract_negotiated_version(&response_msg);
update_session_after_init(state, response_session_id.as_ref(), version.clone());
version
} else {
None
};
store_response_event(state, response_session_id.as_ref(), &response_msg).await;
let mut response = build_response(state, response_msg, session_id);
if let Some(sid) = &response_session_id {
response
.headers_mut()
.insert(MCP_SESSION_ID, sid.parse().unwrap());
}
let version_to_send = compute_outbound_protocol_version(
state,
response_session_id.as_ref(),
is_init_request,
negotiated_version.as_deref(),
);
response
.headers_mut()
.insert(MCP_PROTOCOL_VERSION, version_to_send.parse().unwrap());
response
}
/// Handles a POST when no HTTP middleware chain is configured.
///
/// Checks run in this order:
/// 1. headers;
/// 2. body (size-limited);
/// 3. JSON-RPC parse;
/// 4. session resolution;
/// 5. protocol version (non-initialize only);
/// 6. authentication.
///
/// Requests are then dispatched; notifications and client responses get 202.
///
/// Headers are validated before the body is buffered, so a request with a bad
/// Content-Type/Accept is rejected without reading up to `max_request_bytes` of payload.
async fn handle_post_fast_path(
    state: ServerState,
    request: axum::extract::Request<Body>,
) -> Response {
    let (parts, body) = request.into_parts();
    let headers = parts.headers;
    if let Err(error_response) = validate_headers(&headers, "POST") {
        return error_response;
    }
    let body = match read_body_with_limit(body, state.config.max_request_bytes).await {
        Ok(b) => b,
        Err(response) => return response,
    };
    let message = match parse_transport_message_fast(body.as_bytes()) {
        Ok(msg) => msg,
        Err(response) => return response,
    };
    let (session_id, protocol_version) = extract_session_and_protocol_headers(&headers);
    let is_init_request = is_initialize_request(&message);
    let response_session_id = match resolve_session_for_request(
        &state,
        is_init_request,
        session_id.clone(),
        protocol_version.clone(),
    ) {
        Ok(sid) => sid,
        Err(error_response) => return error_response,
    };
    if !is_init_request {
        if let Err(error_response) =
            validate_protocol_version(&state, session_id.as_ref(), protocol_version.as_ref())
        {
            return error_response;
        }
    }
    let auth_context = match extract_and_validate_auth(&state, &headers).await {
        Ok(ctx) => ctx,
        Err(response) => return response,
    };
    match message {
        TransportMessage::Request { id, request } => {
            handle_fast_path_request(
                &state,
                id,
                request,
                auth_context,
                is_init_request,
                response_session_id,
                session_id.as_ref(),
            )
            .await
        },
        TransportMessage::Notification { .. } | TransportMessage::Response(_) => {
            StatusCode::ACCEPTED.into_response()
        },
    }
}
/// Creates the per-request middleware context.
///
/// Uses a caller-supplied `x-request-id` (or a fresh UUID), the MCP session-id header if
/// present, and the current instant.
fn build_middleware_context(
    server_request: &crate::server::http_middleware::ServerHttpRequest,
) -> ServerHttpContext {
    let request_id = match server_request.get_header("x-request-id") {
        Some(supplied) => supplied.to_string(),
        None => Uuid::new_v4().to_string(),
    };
    ServerHttpContext {
        session_id: server_request.get_header(MCP_SESSION_ID).map(str::to_string),
        start_time: std::time::Instant::now(),
        request_id,
    }
}
/// Converts an axum request into the middleware request type.
///
/// The body is buffered up to `max_request_bytes`; any conversion error is reported as
/// 413.
async fn convert_axum_to_middleware_request(
    request: axum::extract::Request<Body>,
    max_request_bytes: usize,
) -> std::result::Result<crate::server::http_middleware::ServerHttpRequest, Response> {
    let (head, payload) = request.into_parts();
    match from_axum_with_limit(head, payload, max_request_bytes).await {
        Ok(converted) => Ok(converted),
        Err(e) => Err(create_error_response(
            StatusCode::PAYLOAD_TOO_LARGE,
            -32600,
            &format!("Request body exceeds limit: {}", e),
        )),
    }
}
/// `resolve_session_for_request`, plus a middleware error-hook notification on failure.
async fn resolve_session_with_error_hook(
    state: &ServerState,
    is_init_request: bool,
    session_id: Option<String>,
    protocol_version: Option<String>,
    http_middleware: &ServerHttpMiddlewareChain,
    http_context: &ServerHttpContext,
) -> std::result::Result<Option<String>, Response> {
    let outcome = resolve_session_for_request(state, is_init_request, session_id, protocol_version);
    if outcome.is_err() {
        let kind = match is_init_request {
            true => "Session initialization failed",
            false => "Session validation failed",
        };
        report_middleware_error(http_middleware, http_context, kind).await;
    }
    outcome
}
/// `validate_protocol_version` for non-initialize requests.
///
/// A failure is also reported to the middleware error hook. Initialize requests skip
/// the check.
async fn validate_protocol_version_with_error_hook(
    state: &ServerState,
    is_init_request: bool,
    session_id: Option<&String>,
    protocol_version: Option<&String>,
    http_middleware: &ServerHttpMiddlewareChain,
    http_context: &ServerHttpContext,
) -> std::result::Result<(), Response> {
    if is_init_request {
        return Ok(());
    }
    let verdict = validate_protocol_version(state, session_id, protocol_version);
    if verdict.is_err() {
        report_middleware_error(
            http_middleware,
            http_context,
            "Protocol version validation failed",
        )
        .await;
    }
    verdict
}
/// Dispatches a parsed message on the middleware path.
///
/// Notifications and client responses get 202. Requests are handled by the server;
/// initialize responses update the session. The reply is stored in the event store and
/// returned as JSON through the response middleware.
async fn dispatch_message_with_middleware(
    state: &ServerState,
    message: TransportMessage,
    is_init_request: bool,
    response_session_id: Option<String>,
    auth_context: Option<crate::server::auth::AuthContext>,
    http_middleware: &ServerHttpMiddlewareChain,
    http_context: &ServerHttpContext,
) -> Response {
    let TransportMessage::Request { id, request } = message else {
        return StatusCode::ACCEPTED.into_response();
    };
    let reply = {
        let server = state.server.lock().await;
        server.handle_request(id, request, auth_context).await
    };
    let reply_msg = TransportMessage::Response(reply);

    let mut negotiated_version = None;
    if is_init_request {
        negotiated_version = extract_negotiated_version(&reply_msg);
        update_session_after_init(
            state,
            response_session_id.as_ref(),
            negotiated_version.clone(),
        );
    }

    store_response_event(state, response_session_id.as_ref(), &reply_msg).await;
    let outbound_version = compute_outbound_protocol_version(
        state,
        response_session_id.as_ref(),
        is_init_request,
        negotiated_version.as_deref(),
    );
    build_success_response_with_middleware(
        &reply_msg,
        response_session_id.as_ref(),
        &outbound_version,
        http_middleware,
        http_context,
    )
    .await
}
/// Handles a POST through the configured HTTP middleware chain.
///
/// The middleware sees the request before any MCP validation, and its error hook is
/// notified of each validation failure.
async fn handle_post_with_middleware(
    state: ServerState,
    request: axum::extract::Request<Body>,
) -> Response {
    match handle_post_with_middleware_inner(&state, request).await {
        Ok(response) | Err(response) => response,
    }
}

// Pipeline body of `handle_post_with_middleware`; every early-exit response travels as `Err`.
async fn handle_post_with_middleware_inner(
    state: &ServerState,
    request: axum::extract::Request<Body>,
) -> std::result::Result<Response, Response> {
    let chain = state
        .config
        .http_middleware
        .as_ref()
        .expect("Middleware chain must exist");

    let mut server_request =
        convert_axum_to_middleware_request(request, state.config.max_request_bytes).await?;
    let ctx = build_middleware_context(&server_request);
    run_request_middleware(chain, &mut server_request, &ctx).await?;

    if let Err(rejection) = validate_headers(&server_request.headers, "POST") {
        report_middleware_error(chain, &ctx, "Header validation failed").await;
        return Err(rejection);
    }

    let message =
        parse_transport_message_with_middleware(&server_request.body, chain, &ctx).await?;
    let (session_id, protocol_version) =
        extract_session_and_protocol_headers(&server_request.headers);
    let is_init_request = is_initialize_request(&message);

    let response_session_id = resolve_session_with_error_hook(
        state,
        is_init_request,
        session_id.clone(),
        protocol_version.clone(),
        chain,
        &ctx,
    )
    .await?;
    validate_protocol_version_with_error_hook(
        state,
        is_init_request,
        session_id.as_ref(),
        protocol_version.as_ref(),
        chain,
        &ctx,
    )
    .await?;
    let auth_context = extract_auth_with_middleware(state, &server_request, chain, &ctx).await?;

    Ok(dispatch_message_with_middleware(
        state,
        message,
        is_init_request,
        response_session_id,
        auth_context,
        chain,
        &ctx,
    )
    .await)
}
/// Determines the session for a GET/SSE request.
///
/// - A supplied id must exist when the server is stateful (404 otherwise).
/// - With no id, stateless servers reject the request (405).
/// - With no id, stateful servers create a new session, already marked initialized,
///   and fire `on_session_initialized`.
///
/// NOTE(review): a session created here is marked initialized with no protocol version,
/// without any initialize handshake. Confirm that is intended.
fn resolve_sse_session(
    state: &ServerState,
    incoming_session_id: Option<String>,
) -> std::result::Result<String, Response> {
    let stateful = state.config.session_id_generator.is_some();
    match incoming_session_id {
        Some(sid) if stateful && !state.sessions.read().contains_key(&sid) => {
            Err(create_error_response(
                StatusCode::NOT_FOUND,
                -32600,
                "Unknown session ID",
            ))
        },
        Some(sid) => Ok(sid),
        None => {
            let Some(generator) = &state.config.session_id_generator else {
                return Err(create_error_response(
                    StatusCode::METHOD_NOT_ALLOWED,
                    -32601,
                    "SSE not supported in stateless mode",
                ));
            };
            let fresh_id = generator();
            let info = SessionInfo {
                initialized: true,
                protocol_version: None,
            };
            state.sessions.write().insert(fresh_id.clone(), info);
            if let Some(on_init) = &state.config.on_session_initialized {
                on_init(&fresh_id);
            }
            Ok(fresh_id)
        },
    }
}
/// Re-sends previously stored events to a reconnecting SSE client.
///
/// Replay happens only when all of these hold:
/// * the request carries the `LAST_EVENT_ID` header,
/// * its value is visible-ASCII text,
/// * an event store is configured.
///
/// In that case every event returned by `replay_events_after` is pushed into `tx`. A missing
/// prerequisite, a store error, or a failed send is silently ignored, so replay is best-effort.
async fn replay_sse_events_from_header(
    headers: &HeaderMap,
    tx: &mpsc::UnboundedSender<TransportMessage>,
    event_store: Option<&Arc<InMemoryEventStore>>,
) {
    let last_id = headers
        .get(LAST_EVENT_ID)
        .and_then(|value| value.to_str().ok());
    let (Some(last_id), Some(store)) = (last_id, event_store) else {
        return;
    };
    let Ok(replayed) = store.replay_events_after(last_id).await else {
        return;
    };
    for message in replayed.into_iter().map(|(_, message)| message) {
        // A send only fails once the receiver is gone; there is nothing useful to do then.
        let _ = tx.send(message);
    }
}
/// Converts a transport message into an SSE `message` event with a fresh UUID v4 event ID.
///
/// When an event store is provided, the message is also recorded under `session_id` with
/// that same event ID. This function is synchronous, so the async `store_event` call runs
/// on a spawned tokio task and any storage error is discarded.
///
/// NOTE(review): separately spawned tasks are not ordered relative to each other, so
/// consecutive messages may be stored in a different order than they were emitted —
/// confirm whether replay ordering matters for callers.
///
/// # Panics
/// Panics if `msg` cannot be serialized to JSON.
fn sse_event_for_message(
    msg: &TransportMessage,
    session_id: &str,
    event_store: Option<&Arc<InMemoryEventStore>>,
) -> Event {
    let id = Uuid::new_v4().to_string();
    if let Some(store) = event_store.cloned() {
        let (stream_id, stored_id, message) = (session_id.to_owned(), id.clone(), msg.clone());
        tokio::spawn(async move {
            let _ = store.store_event(&stream_id, &stored_id, &message).await;
        });
    }
    let payload = serde_json::to_string(msg).unwrap();
    Event::default().id(id).event("message").data(payload)
}
/// Adds the session ID and streaming-related headers to an SSE response.
///
/// Sets the `MCP_SESSION_ID` header to `session_id`, `Cache-Control: no-cache, no-transform`,
/// and `Connection: keep-alive`.
///
/// If `session_id` is not a valid HTTP header value (for example, an ID returned by a custom
/// session ID generator that contains control characters), the session header is omitted
/// instead of panicking and tearing down the request.
fn attach_sse_response_headers(response: &mut Response, session_id: &str) {
    let headers = response.headers_mut();
    if let Ok(value) = HeaderValue::from_str(session_id) {
        headers.insert(MCP_SESSION_ID, value);
    }
    headers.insert(
        header::CACHE_CONTROL,
        HeaderValue::from_static("no-cache, no-transform"),
    );
    headers.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
}
/// Opens the server-to-client SSE stream for a session (the `GET` endpoint).
///
/// The handler proceeds in these steps:
/// 1. Validates the request headers.
/// 2. Resolves the session via `resolve_sse_session`.
/// 3. Registers a fresh channel sender for the session in `sse_streams`.
/// 4. Replays stored events after any `Last-Event-ID`.
/// 5. Returns an SSE response whose events are produced by `sse_event_for_message`.
///
/// Returns `409 Conflict` if the session already has a live stream. A registered sender
/// whose receiver has been dropped is treated as stale and replaced. The receiver lives
/// inside the previously returned SSE response, so it is gone once that response is
/// dropped. Replacing it lets a client reconnect to resume the session.
async fn handle_get_sse(State(state): State<ServerState>, headers: HeaderMap) -> impl IntoResponse {
    if let Err(error_response) = validate_headers(&headers, "GET") {
        return error_response;
    }
    let incoming_session_id = headers
        .get(MCP_SESSION_ID)
        .and_then(|v| v.to_str().ok())
        .map(str::to_string);
    let session_id = match resolve_sse_session(&state, incoming_session_id) {
        Ok(sid) => sid,
        Err(response) => return response,
    };
    let (tx, rx) = mpsc::unbounded_channel();
    {
        // Check and register under a single write lock so two concurrent GETs for the same
        // session cannot both pass the check and overwrite each other's sender.
        let mut streams = state.sse_streams.write();
        let has_live_stream =
            matches!(streams.get(&session_id), Some(existing) if !existing.is_closed());
        if has_live_stream {
            return create_error_response(
                StatusCode::CONFLICT,
                -32600,
                "SSE stream already exists for this session",
            );
        }
        streams.insert(session_id.clone(), tx.clone());
    }
    replay_sse_events_from_header(&headers, &tx, state.config.event_store.as_ref()).await;
    let stream = UnboundedReceiverStream::new(rx);
    let session_id_for_stream = session_id.clone();
    let event_store = state.config.event_store.clone();
    let sse = Sse::new(stream.map(move |msg| {
        Ok::<_, Infallible>(sse_event_for_message(
            &msg,
            &session_id_for_stream,
            event_store.as_ref(),
        ))
    }));
    let mut response = sse.into_response();
    attach_sse_response_headers(&mut response, &session_id);
    response
}
/// Terminates a session (the `DELETE` endpoint).
///
/// Reads the session ID from the `MCP_SESSION_ID` header and returns `404` with
/// "No session ID provided" when it is absent or not visible-ASCII. In stateful mode (a
/// `session_id_generator` is configured) an ID not present in `state.sessions` yields
/// `404` "Unknown session ID". Otherwise the session and its registered SSE sender are
/// removed, `on_session_closed` is invoked if set, and `200 {"status": "ok"}` is returned.
async fn handle_delete_session(
    State(state): State<ServerState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let Some(sid) = headers
        .get(MCP_SESSION_ID)
        .and_then(|v| v.to_str().ok())
        .map(str::to_string)
    else {
        return create_error_response(StatusCode::NOT_FOUND, -32600, "No session ID provided");
    };
    // Check-and-remove under one write lock: of two concurrent DELETEs for the same session,
    // only one observes the removal, so `on_session_closed` cannot fire twice.
    let removed = state.sessions.write().remove(&sid).is_some();
    if !removed && state.config.session_id_generator.is_some() {
        return create_error_response(StatusCode::NOT_FOUND, -32600, "Unknown session ID");
    }
    state.sse_streams.write().remove(&sid);
    if let Some(callback) = &state.config.on_session_closed {
        callback(&sid);
    }
    (StatusCode::OK, Json(json!({"status": "ok"}))).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `x-pmcp-claim-custom-*` header is surfaced as a `custom:`-prefixed claim.
    #[test]
    fn extract_custom_claim_header_inserted_under_cognito_key() {
        let mut h = HeaderMap::new();
        h.insert("x-pmcp-user-id", "user-123".parse().unwrap());
        h.insert(
            "x-pmcp-claim-custom-primary-creator",
            "rosen".parse().unwrap(),
        );
        let ctx = extract_auth_from_proxy_headers(&h).expect("auth ctx");
        assert_eq!(
            ctx.claims.get("custom:primary_creator"),
            Some(&serde_json::Value::String("rosen".into())),
        );
    }

    /// An empty custom-claim header value produces no claim at all.
    #[test]
    fn extract_custom_claim_empty_value_dropped() {
        let mut h = HeaderMap::new();
        h.insert("x-pmcp-user-id", "user-123".parse().unwrap());
        h.insert("x-pmcp-claim-custom-empty", "".parse().unwrap());
        let ctx = extract_auth_from_proxy_headers(&h).expect("auth ctx");
        assert!(!ctx.claims.contains_key("custom:empty"));
    }

    /// Kebab-case segments in the header name become snake_case in the claim key.
    #[test]
    fn extract_custom_claim_kebab_to_snake() {
        let mut h = HeaderMap::new();
        h.insert("x-pmcp-user-id", "u".parse().unwrap());
        h.insert(
            "x-pmcp-claim-custom-promo-code",
            "SUMMER25".parse().unwrap(),
        );
        let ctx = extract_auth_from_proxy_headers(&h).expect("auth ctx");
        assert_eq!(
            ctx.claims.get("custom:promo_code"),
            Some(&serde_json::Value::String("SUMMER25".into())),
        );
    }

    /// Custom claims are added alongside the standard user id / email headers.
    #[test]
    fn extract_custom_claim_coexists_with_standard_headers() {
        let mut h = HeaderMap::new();
        h.insert("x-pmcp-user-id", "u".parse().unwrap());
        h.insert("x-pmcp-user-email", "u@example.com".parse().unwrap());
        h.insert("x-pmcp-user-groups", "g1,g2".parse().unwrap());
        h.insert("x-pmcp-claim-custom-tier", "gold".parse().unwrap());
        let ctx = extract_auth_from_proxy_headers(&h).expect("auth ctx");
        assert_eq!(ctx.subject, "u");
        assert_eq!(ctx.claims["email"], "u@example.com");
        assert_eq!(ctx.claims["custom:tier"], "gold");
    }
}