use crate::{
auth::Claims,
error::{Error, ErrorCode},
types::{Message, RequestId, Response},
};
use bytes::Bytes;
use futures_util::{Stream, StreamExt, future::Either, stream};
use http::{HeaderMap, HeaderValue};
use std::pin::Pin;
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
use super::{
context::HttpContext,
engine::HttpEngine,
types::{HttpRequest, HttpResponse, SseResponse},
};
/// Name of the HTTP header carrying the MCP session identifier.
///
/// Its value is parsed as a UUID on incoming requests and echoed back on responses.
pub(crate) const MCP_SESSION_ID: &str = "Mcp-Session-Id";
/// Engine entry point for `POST`.
///
/// Converts the engine request into the neutral [`HttpRequest`], runs
/// [`handle_post`], and converts the neutral response back into the engine's type.
///
/// # Errors
/// Propagates any error returned by `E::adapt_request`.
pub async fn dispatch_post<E: HttpEngine>(
    req: E::Request,
    ctx: &HttpContext,
) -> Result<E::Response, Error> {
    let request = E::adapt_request(req).await?;
    Ok(E::adapt_response(handle_post(request, ctx).await))
}
/// Engine entry point for `DELETE`.
///
/// Converts the engine request into the neutral [`HttpRequest`], runs
/// [`handle_delete`], and converts the neutral response back into the engine's type.
///
/// # Errors
/// Propagates any error returned by `E::adapt_request`.
pub async fn dispatch_delete<E: HttpEngine>(
    req: E::Request,
    ctx: &HttpContext,
) -> Result<E::Response, Error> {
    let request = E::adapt_request(req).await?;
    Ok(E::adapt_response(handle_delete(request, ctx).await))
}
/// Engine entry point for the SSE `GET`.
///
/// Converts the engine request into the neutral [`HttpRequest`] and delegates to
/// [`handle_get_sse`], which produces either a plain status response or an event stream.
///
/// # Errors
/// Propagates any error returned by `E::adapt_request`.
pub async fn dispatch_get_sse<E: HttpEngine>(
    req: E::Request,
    ctx: &HttpContext,
) -> Result<SseResponse<impl Stream<Item = E::SseEvent> + Send + 'static>, Error> {
    let request = E::adapt_request(req).await?;
    let response = handle_get_sse::<E>(request, ctx).await;
    Ok(response)
}
/// Handles a `POST` carrying a single JSON-RPC message or a batch.
///
/// The session id comes from the `Mcp-Session-Id` header. A fresh UUID is minted
/// when the header is missing or invalid. The id is echoed on every response.
///
/// * A body that fails to parse is answered with a JSON-RPC error response
///   (HTTP 200, `id: null`).
/// * A request whose method is `crate::commands::INIT` pre-registers the
///   session with the SSE registry.
/// * Notifications, and batches containing neither requests nor error
///   responses, are forwarded and answered with `202 Accepted`.
/// * Everything else is forwarded with the request headers (minus
///   `Authorization`) and any attached `Claims`. The handler then waits on a
///   oneshot registered in `ctx.pending` under the message's `full_id()`.
///
/// Returns `500 Internal Server Error` if the inbound channel is closed or the
/// reply sender is dropped before answering.
pub async fn handle_post(req: HttpRequest, ctx: &HttpContext) -> HttpResponse {
    let mut headers = req.headers().clone();
    let id = get_or_create_mcp_session(&headers);
    let claims = req.extensions().get::<Arc<dyn Claims>>().cloned();
    let body = req.into_body();

    let msg = match parse_message(&body) {
        Ok(msg) => msg,
        Err(code) => {
            let resp = Response::error(RequestId::Null, Error::from(code));
            return build_json_response(http::StatusCode::OK, id, &Message::Response(resp));
        }
    };

    if let Message::Request(ref r) = msg
        && r.method == crate::commands::INIT
    {
        ctx.sse_registry.pre_register(id);
    }

    // Messages that expect no reply are forwarded and acknowledged immediately.
    let expects_no_reply = match &msg {
        Message::Notification(_) => true,
        Message::Batch(batch) => !batch.has_requests() && !batch.has_error_responses(),
        _ => false,
    };
    if expects_no_reply {
        let msg = msg.set_session_id(id);
        // A closed inbound channel is reported the same way for single
        // notifications and for notification-only batches.
        if ctx.inbound_tx.send(Ok(msg)).await.is_err() {
            return status_response(http::StatusCode::INTERNAL_SERVER_ERROR, id);
        }
        return status_response(http::StatusCode::ACCEPTED, id);
    }

    // Credentials are passed on as claims, not as a raw header.
    headers.remove(http::header::AUTHORIZATION);
    let mut msg = msg.set_session_id(id).set_headers(headers);
    if let Some(c) = claims {
        msg = msg.set_claims(c);
    }

    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel::<Message>();
    // Register the reply slot before forwarding, so it already exists by the
    // time the message is processed.
    ctx.pending.insert(msg.full_id(), resp_tx);
    if let Err(unsent) = ctx.inbound_tx.send(Ok(msg)).await {
        // The message never left this handler, so nobody will ever answer:
        // remove the slot just registered instead of leaking it.
        if let Ok(msg) = unsent.0 {
            ctx.pending.remove(&msg.full_id());
        }
        return status_response(http::StatusCode::INTERNAL_SERVER_ERROR, id);
    }

    match resp_rx.await {
        Ok(resp) => build_json_response(http::StatusCode::OK, id, &resp),
        Err(_) => status_response(http::StatusCode::INTERNAL_SERVER_ERROR, id),
    }
}
fn parse_message(body: &Bytes) -> Result<Message, ErrorCode> {
serde_json::from_slice::<Message>(body).map_err(|e| match e.classify() {
serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
ErrorCode::ParseError
}
_ => ErrorCode::InvalidRequest,
})
}
fn get_or_create_mcp_session(headers: &HeaderMap) -> uuid::Uuid {
headers
.get(MCP_SESSION_ID)
.and_then(|v| v.to_str().ok())
.and_then(|s| uuid::Uuid::parse_str(s).ok())
.unwrap_or_else(uuid::Uuid::new_v4)
}
fn build_json_response(
status: http::StatusCode,
session: uuid::Uuid,
body: &Message,
) -> HttpResponse {
let json = serde_json::to_vec(body).unwrap_or_default();
let mut resp = http::Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Bytes::from(json))
.unwrap_or_default();
if let Ok(v) = HeaderValue::from_str(&session.to_string()) {
resp.headers_mut().insert(MCP_SESSION_ID, v);
}
resp
}
/// Builds an empty-bodied response with the given status, tagged with the
/// `Mcp-Session-Id` header.
fn status_response(status: http::StatusCode, session: uuid::Uuid) -> HttpResponse {
    let session_value = HeaderValue::from_str(&session.to_string()).ok();
    let mut response = http::Response::builder()
        .status(status)
        .body(Bytes::new())
        .unwrap_or_default();
    if let Some(value) = session_value {
        response.headers_mut().insert(MCP_SESSION_ID, value);
    }
    response
}
/// Handles `DELETE`, which terminates a session.
///
/// Responds `400 Bad Request` when the `Mcp-Session-Id` header is missing or
/// invalid. Otherwise the session is removed from the log registry (with the
/// `tracing` feature) and terminated in the SSE registry, and the response is
/// `200 OK` echoing the session id.
pub async fn handle_delete(req: HttpRequest, ctx: &HttpContext) -> HttpResponse {
    match parse_session_id(req.headers()) {
        None => http::Response::builder()
            .status(http::StatusCode::BAD_REQUEST)
            .body(Bytes::new())
            .unwrap_or_default(),
        Some(session) => {
            #[cfg(feature = "tracing")]
            crate::types::notification::fmt::LOG_REGISTRY.unregister(&session);
            ctx.sse_registry.terminate(&session);
            status_response(http::StatusCode::OK, session)
        }
    }
}
/// Reads the `Mcp-Session-Id` header as a UUID.
///
/// Returns `None` if the header is absent, is not visible ASCII, or is not a
/// valid UUID.
fn parse_session_id(headers: &HeaderMap) -> Option<uuid::Uuid> {
    let raw = headers.get(MCP_SESSION_ID)?.to_str().ok()?;
    uuid::Uuid::parse_str(raw).ok()
}
/// An item on the merged SSE stream, before it is converted into an
/// engine-specific event by `handle_get_sse`.
enum SseItem {
    /// A message with its sequence number, taken from the replay buffer or the
    /// live queue. Rendered with `E::tracked_event`.
    Tracked(u64, Arc<Message>),
    /// A message without a sequence number, taken from the log queue.
    /// Rendered with `E::ephemeral_event`.
    Ephemeral(Box<Message>),
}
/// Drop guard owned by one SSE connection's stream.
///
/// On drop, it unregisters this connection's `generation` for session `id`.
/// With the `tracing` feature this first happens in the log registry, then
/// always in the SSE session registry.
struct SseConnectionCleanup {
    /// Session the connection belongs to.
    id: uuid::Uuid,
    /// Generation returned by `SseSessionRegistry::register` for this connection.
    generation: u64,
    /// Registry the connection was registered with.
    registry: Arc<crate::shared::SseSessionRegistry>,
}

impl Drop for SseConnectionCleanup {
    fn drop(&mut self) {
        let Self {
            id,
            generation,
            registry,
        } = &*self;
        #[cfg(feature = "tracing")]
        crate::types::notification::fmt::LOG_REGISTRY.unregister_if_generation(id, *generation);
        registry.unregister(id, *generation);
    }
}
/// Handles the SSE `GET`, opening an event stream for a session.
///
/// Responds `400 Bad Request` when the `Mcp-Session-Id` header is missing or
/// not a valid UUID. Otherwise:
///
/// 1. A bounded live queue (`ctx.sse_live_queue_capacity`) is registered with
///    `ctx.sse_registry`, which returns this connection's `generation`.
/// 2. With the `tracing` feature, a bounded log queue
///    (`ctx.sse_log_queue_capacity`) is registered with `LOG_REGISTRY`.
///    Without that feature, the log sender is dropped when this function
///    returns, so the log side of the stream ends immediately.
/// 3. Buffered messages are replayed. If the `Last-Event-ID` header parses as
///    a `u64`, `replay_since` is used; otherwise `replay_all` is used.
/// 4. Replayed, live and log items are merged with `stream::select`, which
///    ends only after both inputs have ended. The items are then rendered via
///    `E::tracked_event` / `E::ephemeral_event`.
///
/// The returned stream owns an `SseConnectionCleanup`. Dropping the stream
/// therefore unregisters this connection's generation.
pub async fn handle_get_sse<E: HttpEngine>(
    req: HttpRequest,
    ctx: &HttpContext,
) -> SseResponse<impl Stream<Item = E::SseEvent> + Send + 'static> {
    let Some(id) = parse_session_id(req.headers()) else {
        return SseResponse::Status(
            http::Response::builder()
                .status(http::StatusCode::BAD_REQUEST)
                .body(Bytes::new())
                .unwrap_or_default(),
        );
    };
    let (msg_tx, msg_rx) =
        tokio::sync::mpsc::channel::<(u64, Arc<Message>)>(ctx.sse_live_queue_capacity);
    // Only consumed when the `tracing` feature is enabled; otherwise it is
    // dropped at the end of this function, closing `log_rx`.
    let (_log_tx, log_rx) = tokio::sync::mpsc::channel::<Message>(ctx.sse_log_queue_capacity);
    // Registration happens BEFORE the replay snapshot below. Presumably this is
    // so that nothing published in between is lost. Items that land in both the
    // snapshot and the live queue are de-duplicated by the seq filter further down.
    let generation = ctx.sse_registry.register(id, msg_tx);
    #[cfg(feature = "tracing")]
    crate::types::notification::fmt::LOG_REGISTRY.register(id, generation, _log_tx);
    // A missing or non-numeric Last-Event-ID falls back to a full replay.
    let last_seq: Option<u64> = req
        .headers()
        .get("last-event-id")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse().ok());
    let replay = match last_seq {
        Some(seq) => ctx.sse_registry.replay_since(&id, seq),
        None => ctx.sse_registry.replay_all(&id),
    };
    let msg_stream = if replay.is_empty() {
        Either::Left(ReceiverStream::new(msg_rx).map(|(seq, arc)| SseItem::Tracked(seq, arc)))
    } else {
        // NOTE(review): assumes `replay` is ordered by ascending seq, so the
        // last element carries the highest replayed seq. TODO confirm in the
        // registry. The `unwrap_or(0)` is unreachable here because `replay`
        // is non-empty.
        let replay_end_seq = replay.last().map(|(s, _)| *s).unwrap_or(0);
        let replay_stream = stream::iter(replay).map(|(seq, arc)| SseItem::Tracked(seq, arc));
        // Drop live items already covered by the replay, so each seq is emitted once.
        let live = ReceiverStream::new(msg_rx)
            .filter(move |&(seq, _)| {
                let keep = seq > replay_end_seq;
                async move { keep }
            })
            .map(|(seq, arc)| SseItem::Tracked(seq, arc));
        Either::Right(replay_stream.chain(live))
    };
    let log_stream = ReceiverStream::new(log_rx).map(|m| SseItem::Ephemeral(Box::new(m)));
    let merged = stream::select(log_stream, msg_stream);
    let cleanup = SseConnectionCleanup {
        id,
        generation,
        registry: ctx.sse_registry.clone(),
    };
    let mut merged = Box::pin(merged);
    let guarded = stream::poll_fn(move |cx| {
        // Mentioning `cleanup` makes this `move` closure own it. The guard's
        // Drop then runs exactly when the returned stream is dropped.
        let _cleanup = &cleanup;
        Pin::new(&mut merged).poll_next(cx)
    })
    .map(|item| match item {
        SseItem::Tracked(seq, msg) => E::tracked_event(seq, &msg),
        SseItem::Ephemeral(msg) => E::ephemeral_event(&msg),
    });
    let mut headers = HeaderMap::new();
    if let Ok(v) = HeaderValue::from_str(&id.to_string()) {
        headers.insert(MCP_SESSION_ID, v);
    }
    SseResponse::Stream {
        headers,
        stream: guarded,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shared::SseSessionRegistry;
    use bytes::Bytes;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Builds a context with small queues, plus the receiving end of its
    /// inbound channel. Callers must keep the receiver alive, because once it
    /// is dropped every `inbound_tx.send` fails.
    fn make_ctx() -> (
        HttpContext,
        mpsc::Receiver<Result<crate::types::Message, crate::error::Error>>,
    ) {
        let (inbound_tx, inbound_rx) =
            mpsc::channel::<Result<crate::types::Message, crate::error::Error>>(8);
        let ctx = HttpContext {
            addr: "127.0.0.1:0".into(),
            endpoint: "/mcp".into(),
            pending: Arc::new(DashMap::new()),
            sse_registry: Arc::new(SseSessionRegistry::new(8)),
            inbound_tx,
            sse_live_queue_capacity: 64,
            sse_log_queue_capacity: 64,
        };
        (ctx, inbound_rx)
    }

    /// Serializes a JSON-RPC request (`id: 1`) for `method`.
    fn make_request_body(method: &str) -> Bytes {
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "id": 1
        });
        Bytes::from(serde_json::to_vec(&body).unwrap())
    }

    /// Serializes a JSON-RPC notification (no `id`) for `method`.
    fn make_notification_body(method: &str) -> Bytes {
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method
        });
        Bytes::from(serde_json::to_vec(&body).unwrap())
    }

    #[tokio::test]
    async fn notification_returns_202_without_pending_entry() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(make_notification_body("notifications/cancelled"))
            .unwrap();
        let resp = handle_post(req, &ctx).await;
        assert_eq!(resp.status(), http::StatusCode::ACCEPTED);
        assert!(
            ctx.pending.is_empty(),
            "no pending oneshot for notifications"
        );
    }

    #[tokio::test]
    async fn malformed_json_returns_parse_error_response() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(Bytes::from_static(b"not json"))
            .unwrap();
        let resp = handle_post(req, &ctx).await;
        assert_eq!(resp.status(), http::StatusCode::OK);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn invalid_message_shape_returns_invalid_request() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(Bytes::from_static(b"{\"valid_json\": true}"))
            .unwrap();
        let resp = handle_post(req, &ctx).await;
        assert_eq!(resp.status(), http::StatusCode::OK);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"]["code"], -32600);
    }

    /// Only the pending reply slot is observable from here; the
    /// pre-registration itself is not asserted.
    #[tokio::test]
    async fn init_request_pre_registers_session() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(make_request_body(crate::commands::INIT))
            .unwrap();
        let ctx_arc = std::sync::Arc::new(ctx);
        let ctx_clone = ctx_arc.clone();
        // `handle_post` waits for a reply that never comes, so it runs in the
        // background. The test waits (bounded) for its pending entry instead of
        // sleeping for a fixed, timing-dependent interval.
        let handle = tokio::spawn(async move {
            handle_post(req, &ctx_clone).await;
        });
        tokio::time::timeout(std::time::Duration::from_secs(5), async {
            while ctx_arc.pending.is_empty() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("handle_post never registered a pending entry");
        assert_eq!(ctx_arc.pending.len(), 1);
        handle.abort();
    }

    #[tokio::test]
    async fn delete_without_session_id_returns_400() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("DELETE")
            .uri("/mcp")
            .body(Bytes::new())
            .unwrap();
        let resp = handle_delete(req, &ctx).await;
        assert_eq!(resp.status(), http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn delete_with_session_id_echoes_it_back() {
        let (ctx, _rx) = make_ctx();
        let id = uuid::Uuid::new_v4();
        let req = http::Request::builder()
            .method("DELETE")
            .uri("/mcp")
            .header(MCP_SESSION_ID, id.to_string())
            .body(Bytes::new())
            .unwrap();
        let resp = handle_delete(req, &ctx).await;
        assert_eq!(resp.status(), http::StatusCode::OK);
        assert_eq!(
            resp.headers()
                .get(MCP_SESSION_ID)
                .and_then(|v| v.to_str().ok()),
            Some(id.to_string().as_str())
        );
    }

    /// Minimal engine for exercising `handle_get_sse`. Each event is a pair of
    /// the optional seq and the serialized message.
    struct TestEngine;

    impl super::HttpEngine for TestEngine {
        type Request = HttpRequest;
        type Response = HttpResponse;
        type SseEvent = (Option<u64>, String);

        async fn adapt_request(_req: Self::Request) -> Result<HttpRequest, crate::error::Error> {
            unreachable!()
        }

        fn adapt_response(_resp: HttpResponse) -> Self::Response {
            unreachable!()
        }

        fn tracked_event(seq: u64, msg: &Message) -> Self::SseEvent {
            (Some(seq), serde_json::to_string(msg).unwrap())
        }

        fn ephemeral_event(msg: &Message) -> Self::SseEvent {
            (None, serde_json::to_string(msg).unwrap())
        }

        async fn run(
            self,
            _ctx: HttpContext,
            _token: tokio_util::sync::CancellationToken,
        ) -> Result<(), crate::error::Error> {
            unreachable!()
        }
    }

    #[tokio::test]
    async fn get_without_session_id_returns_400() {
        let (ctx, _rx) = make_ctx();
        let req = http::Request::builder()
            .method("GET")
            .uri("/mcp")
            .body(Bytes::new())
            .unwrap();
        let resp = handle_get_sse::<TestEngine>(req, &ctx).await;
        match resp {
            SseResponse::Status(r) => assert_eq!(r.status(), http::StatusCode::BAD_REQUEST),
            SseResponse::Stream { .. } => panic!("expected Status, got Stream"),
        }
    }

    #[tokio::test]
    async fn get_with_session_returns_stream_with_session_header() {
        let (ctx, _rx) = make_ctx();
        let id = uuid::Uuid::new_v4();
        ctx.sse_registry.pre_register(id);
        let req = http::Request::builder()
            .method("GET")
            .uri("/mcp")
            .header(MCP_SESSION_ID, id.to_string())
            .body(Bytes::new())
            .unwrap();
        let resp = handle_get_sse::<TestEngine>(req, &ctx).await;
        match resp {
            SseResponse::Stream { headers, stream: _ } => {
                assert_eq!(
                    headers.get(MCP_SESSION_ID).and_then(|v| v.to_str().ok()),
                    Some(id.to_string().as_str())
                );
            }
            SseResponse::Status(_) => panic!("expected Stream, got Status"),
        }
    }
}