use std::{sync::Arc, time::Duration};
use actix_web::{
HttpRequest, HttpResponse, Result, Scope,
error::InternalError,
http::{
StatusCode,
header::{self, CACHE_CONTROL},
},
middleware,
web::{self, Bytes, Data},
};
use futures::{Stream, StreamExt};
use tokio_stream::wrappers::ReceiverStream;
/// User callback invoked for JSON-RPC *requests* received over `POST` (not for notifications or
/// responses). It receives the originating HTTP request and a mutable borrow of the MCP request's
/// extensions, so it can attach request-scoped data for the MCP service to read.
pub type OnRequestHook = dyn Fn(&HttpRequest, &mut rmcp::model::Extensions) + Send + Sync + 'static;
use rmcp::{
RoleServer,
model::{ClientJsonRpcMessage, ClientRequest},
serve_server,
service::serve_directly,
transport::{
OneshotTransport, TransportAdapterIdentity,
common::http_header::{HEADER_LAST_EVENT_ID, HEADER_SESSION_ID},
streamable_http_server::session::SessionManager,
},
};
use rmcp::model::GetExtensions;
#[cfg(feature = "authorization-token-passthrough")]
use super::AuthorizationHeader;
/// Response header set to `no` on SSE responses so buffering reverse proxies (e.g. nginx) flush
/// events immediately instead of holding them.
const HEADER_X_ACCEL_BUFFERING: &str = "X-Accel-Buffering";
/// MIME type of Server-Sent Events responses; clients must list it in `Accept` for GET and POST.
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
/// MIME type required in the POST `Content-Type` header and (together with SSE) in POST `Accept`.
const JSON_MIME_TYPE: &str = "application/json";
/// Standalone settings bundle for a streamable HTTP MCP server.
///
/// NOTE(review): not consumed by the code in this file section; presumably mapped onto the
/// same-named fields of `StreamableHttpService` by callers — confirm.
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Whether the server keeps per-client sessions (`true`) or serves each request on its own.
    pub stateful_mode: bool,
    /// Optional interval for SSE keep-alive pings; `None` means no interval is configured.
    pub sse_keep_alive: Option<Duration>,
}

impl Default for StreamableHttpServerConfig {
    /// Stateful mode on, no keep-alive interval.
    fn default() -> Self {
        let sse_keep_alive: Option<Duration> = None;
        StreamableHttpServerConfig {
            stateful_mode: true,
            sse_keep_alive,
        }
    }
}
/// actix-web front end for an rmcp MCP server speaking the streamable HTTP transport.
///
/// Build it with the generated builder (`#[derive(bon::Builder)]`), then mount the routes
/// returned by `scope` / `scope_with_path`.
#[derive(bon::Builder)]
pub struct StreamableHttpService<
    S,
    M = rmcp::transport::streamable_http_server::session::local::LocalSessionManager,
> {
    /// Produces a fresh server handler: once per new session in stateful mode, once per request
    /// in stateless mode.
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
    /// Session store used to create, look up, stream to, and close sessions.
    session_manager: Arc<M>,
    /// `true` (builder default): POSTs are tied to sessions via the session-id header.
    /// `false`: every POST request is served by its own one-shot service instance.
    #[builder(default = true)]
    stateful_mode: bool,
    /// Period of `:ping` SSE comments on streaming responses; `None` sends none.
    sse_keep_alive: Option<Duration>,
    /// Optional hook run on each incoming JSON-RPC request to fill its extensions.
    on_request: Option<Arc<OnRequestHook>>,
}
impl<S, M> Clone for StreamableHttpService<S, M> {
fn clone(&self) -> Self {
Self {
service_factory: self.service_factory.clone(),
session_manager: self.session_manager.clone(),
stateful_mode: self.stateful_mode,
sse_keep_alive: self.sse_keep_alive,
on_request: self.on_request.clone(),
}
}
}
impl<S, M, State: streamable_http_service_builder::State> StreamableHttpServiceBuilder<S, M, State>
where
State::OnRequest: streamable_http_service_builder::IsUnset,
{
pub fn on_request_fn(
self,
hook: impl Fn(&HttpRequest, &mut rmcp::model::Extensions) + Send + Sync + 'static,
) -> StreamableHttpServiceBuilder<S, M, streamable_http_service_builder::SetOnRequest<State>>
{
self.on_request(Arc::new(hook))
}
}
/// Route-handler state: a field-for-field copy of the `StreamableHttpService` settings,
/// registered on the actix scope with `Data::new` and extracted as `Data<AppData<S, M>>`.
#[derive(Clone)]
struct AppData<S, M> {
    /// Builds a new server handler instance; called through `get_service`.
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
    /// Shared session store.
    session_manager: Arc<M>,
    /// Selects the stateful or stateless POST handling path.
    stateful_mode: bool,
    /// Keep-alive ping period passed to `wrap_with_sse_keepalive`.
    sse_keep_alive: Option<Duration>,
    /// Optional per-request extensions hook.
    on_request: Option<Arc<OnRequestHook>>,
}
impl<S, M> AppData<S, M> {
    /// Runs the configured factory and returns the newly built server handler, or the
    /// factory's I/O error.
    fn get_service(&self) -> Result<S, std::io::Error> {
        let factory = &*self.service_factory;
        factory()
    }
}
/// Forwards every item of an SSE byte stream, interleaving `:ping` comments while it is idle.
///
/// When `keep_alive` is `Some` non-zero period, a `:ping\n\n` SSE comment is yielded each time
/// the period elapses without the wrapped stream finishing; items from `stream` are passed
/// through unchanged. The returned stream ends as soon as `stream` ends.
///
/// `None` and `Some(Duration::ZERO)` both disable pings: a zero period would otherwise make
/// `tokio::time::interval` panic inside the response stream.
fn wrap_with_sse_keepalive<S>(
    stream: S,
    keep_alive: Option<Duration>,
) -> impl Stream<Item = Result<Bytes, actix_web::Error>>
where
    S: Stream<Item = Result<Bytes, actix_web::Error>> + Send + 'static,
{
    async_stream::stream! {
        let mut stream = Box::pin(stream);
        let mut keep_alive_timer = keep_alive
            .filter(|period| !period.is_zero())
            .map(|period| {
                let mut timer = tokio::time::interval(period);
                // If the consumer stalls, send a single ping afterwards rather than a burst of
                // catch-up pings (the default `Burst` behaviour).
                timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
                timer
            });
        // A tokio interval's first tick completes immediately; consume it so the first ping is
        // emitted only after a full period.
        if let Some(ref mut timer) = keep_alive_timer {
            timer.tick().await;
        }
        loop {
            tokio::select! {
                result = stream.next() => {
                    match result {
                        Some(msg) => yield msg,
                        None => break,
                    }
                }
                _ = async {
                    match keep_alive_timer.as_mut() {
                        Some(timer) => {
                            timer.tick().await;
                        }
                        // No keep-alive configured: this branch never fires.
                        None => std::future::pending::<()>().await,
                    }
                } => {
                    yield Ok(Bytes::from(":ping\n\n"));
                }
            }
        }
    }
}
impl<S, M> StreamableHttpService<S, M>
where
S: Clone + rmcp::ServerHandler + Send + 'static,
M: SessionManager + 'static,
{
pub fn scope(
self,
) -> Scope<
impl actix_web::dev::ServiceFactory<
actix_web::dev::ServiceRequest,
Config = (),
Response = actix_web::dev::ServiceResponse,
Error = actix_web::Error,
InitError = (),
>,
> {
self.scope_with_path("")
}
pub fn scope_with_path(
self,
path: &str,
) -> Scope<
impl actix_web::dev::ServiceFactory<
actix_web::dev::ServiceRequest,
Config = (),
Response = actix_web::dev::ServiceResponse,
Error = actix_web::Error,
InitError = (),
>,
> {
let app_data = AppData {
service_factory: self.service_factory,
session_manager: self.session_manager,
stateful_mode: self.stateful_mode,
sse_keep_alive: self.sse_keep_alive,
on_request: self.on_request,
};
web::scope(path)
.app_data(Data::new(app_data))
.wrap(middleware::NormalizePath::trim())
.route("", web::get().to(Self::handle_get))
.route("", web::post().to(Self::handle_post))
.route("", web::delete().to(Self::handle_delete))
}
async fn handle_get(req: HttpRequest, service: Data<AppData<S, M>>) -> Result<HttpResponse> {
let accept = req
.headers()
.get(header::ACCEPT)
.and_then(|h| h.to_str().ok());
if !accept.is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) {
return Ok(HttpResponse::NotAcceptable()
.body("Not Acceptable: Client must accept text/event-stream"));
}
let session_id = req
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required"));
};
tracing::debug!(%session_id, "GET request for SSE stream");
let has_session = service
.session_manager
.has_session(&session_id)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
if !has_session {
return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found"));
}
let last_event_id = req
.headers()
.get(HEADER_LAST_EVENT_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned());
let sse_stream: std::pin::Pin<Box<dyn Stream<Item = _> + Send>> =
if let Some(last_event_id) = last_event_id {
tracing::debug!(%session_id, %last_event_id, "Resuming stream from last event");
Box::pin(
service
.session_manager
.resume(&session_id, last_event_id)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?,
)
} else {
tracing::debug!(%session_id, "Creating standalone stream");
Box::pin(
service
.session_manager
.create_standalone_stream(&session_id)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?,
)
};
let formatted_stream = sse_stream.map(|msg| {
let data = serde_json::to_string(&msg.message).unwrap_or_else(|_| "{}".to_string());
let mut output = String::new();
if let Some(id) = msg.event_id {
output.push_str(&format!("id: {id}\n"));
}
output.push_str(&format!("data: {data}\n\n"));
Ok::<_, actix_web::Error>(Bytes::from(output))
});
let sse_stream = wrap_with_sse_keepalive(formatted_stream, service.sse_keep_alive);
Ok(HttpResponse::Ok()
.content_type(EVENT_STREAM_MIME_TYPE)
.append_header((CACHE_CONTROL, "no-cache"))
.append_header((HEADER_X_ACCEL_BUFFERING, "no"))
.streaming(sse_stream))
}
async fn handle_post(
req: HttpRequest,
body: Bytes,
service: Data<AppData<S, M>>,
) -> Result<HttpResponse> {
let accept = req
.headers()
.get(header::ACCEPT)
.and_then(|h| h.to_str().ok());
if !accept.is_some_and(|header| {
header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
}) {
return Ok(HttpResponse::NotAcceptable().body(
"Not Acceptable: Client must accept both application/json and text/event-stream",
));
}
let content_type = req
.headers()
.get(header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok());
if !content_type.is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) {
return Ok(HttpResponse::UnsupportedMediaType()
.body("Unsupported Media Type: Content-Type must be application/json"));
}
let mut message: ClientJsonRpcMessage = serde_json::from_slice(&body)
.map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))?;
tracing::debug!(?message, "POST request with message");
if service.stateful_mode {
let session_id = req
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok());
if let Some(session_id) = session_id {
let session_id = session_id.to_owned().into();
tracing::debug!(%session_id, "POST request with existing session");
let has_session = service
.session_manager
.has_session(&session_id)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
if !has_session {
tracing::warn!(%session_id, "Session not found");
return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found"));
}
match message {
#[allow(unused_mut)]
ClientJsonRpcMessage::Request(mut request_msg) => {
if let Some(ref hook) = service.on_request {
hook(&req, request_msg.request.extensions_mut());
}
#[cfg(feature = "authorization-token-passthrough")]
if let Some(auth_value) = req.headers().get(header::AUTHORIZATION) {
match auth_value.to_str() {
Ok(auth_str)
if auth_str.starts_with("Bearer ") && auth_str.len() > 7 =>
{
tracing::debug!(
"Forwarding Authorization header to MCP service for existing session. \
Note: MCP services must not pass this token to upstream APIs per MCP spec. \
See SECURITY.md for details."
);
request_msg
.request
.extensions_mut()
.insert(AuthorizationHeader(auth_str.to_string()));
}
Ok(auth_str) if auth_str == "Bearer" || auth_str == "Bearer " => {
tracing::debug!(
"Malformed Bearer token in existing session: missing token value"
);
}
Ok(auth_str) if !auth_str.starts_with("Bearer ") => {
let auth_type =
auth_str.split_whitespace().next().unwrap_or("unknown");
tracing::warn!(
"Non-Bearer authorization header ignored for existing session: {}",
auth_type
);
}
Err(e) => {
tracing::debug!(
"Invalid Authorization header encoding in existing session: {}",
e
);
}
_ => {}
}
}
#[cfg(not(feature = "authorization-token-passthrough"))]
if req.headers().get(header::AUTHORIZATION).is_some() {
tracing::warn!(
"Authorization header present but not forwarded. \
Enable 'authorization-token-passthrough' feature to forward tokens to MCP services. \
Note: Token passthrough violates MCP specifications. See SECURITY.md for details."
);
}
let stream = service
.session_manager
.create_stream(&session_id, ClientJsonRpcMessage::Request(request_msg))
.await
.map_err(|e| {
InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR)
})?;
let formatted_stream = stream.map(|msg| {
let data = serde_json::to_string(&msg.message)
.unwrap_or_else(|_| "{}".to_string());
let mut output = String::new();
if let Some(id) = msg.event_id {
output.push_str(&format!("id: {id}\n"));
}
output.push_str(&format!("data: {data}\n\n"));
Ok::<_, actix_web::Error>(Bytes::from(output))
});
let sse_stream =
wrap_with_sse_keepalive(formatted_stream, service.sse_keep_alive);
Ok(HttpResponse::Ok()
.content_type(EVENT_STREAM_MIME_TYPE)
.append_header((CACHE_CONTROL, "no-cache"))
.append_header((HEADER_X_ACCEL_BUFFERING, "no"))
.streaming(sse_stream))
}
ClientJsonRpcMessage::Notification(_)
| ClientJsonRpcMessage::Response(_)
| ClientJsonRpcMessage::Error(_) => {
service
.session_manager
.accept_message(&session_id, message)
.await
.map_err(|e| {
InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR)
})?;
Ok(HttpResponse::Accepted().finish())
}
}
} else {
tracing::debug!("POST request without session, creating new session");
let (session_id, transport) = service
.session_manager
.create_session()
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
tracing::info!(%session_id, "Created new session");
if let ClientJsonRpcMessage::Request(request_msg) = &mut message {
if !matches!(request_msg.request, ClientRequest::InitializeRequest(_)) {
return Ok(
HttpResponse::UnprocessableEntity().body("Expected initialize request")
);
}
if let Some(ref hook) = service.on_request {
hook(&req, request_msg.request.extensions_mut());
}
#[cfg(feature = "authorization-token-passthrough")]
if let Some(auth_value) = req.headers().get(header::AUTHORIZATION) {
match auth_value.to_str() {
Ok(auth_str)
if auth_str.starts_with("Bearer ") && auth_str.len() > 7 =>
{
tracing::debug!(
"Forwarding Authorization header to MCP service for new session. \
Note: MCP services must not pass this token to upstream APIs per MCP spec. \
See SECURITY.md for details."
);
request_msg
.request
.extensions_mut()
.insert(AuthorizationHeader(auth_str.to_string()));
}
Ok(auth_str) if auth_str == "Bearer" || auth_str == "Bearer " => {
tracing::debug!(
"Malformed Bearer token in new session: missing token value"
);
}
Ok(auth_str) if !auth_str.starts_with("Bearer ") => {
let auth_type =
auth_str.split_whitespace().next().unwrap_or("unknown");
tracing::warn!(
"Non-Bearer authorization header ignored for new session: {}",
auth_type
);
}
Err(e) => {
tracing::debug!(
"Invalid Authorization header encoding in new session: {}",
e
);
}
_ => {}
}
}
#[cfg(not(feature = "authorization-token-passthrough"))]
if req.headers().get(header::AUTHORIZATION).is_some() {
tracing::warn!(
"Authorization header present but not forwarded for new session. \
Enable 'authorization-token-passthrough' feature to forward tokens to MCP services. \
Note: Token passthrough violates MCP specifications. See SECURITY.md for details."
);
}
} else {
return Ok(
HttpResponse::UnprocessableEntity().body("Expected initialize request")
);
}
let service_instance = service
.get_service()
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
tokio::spawn({
let session_manager = service.session_manager.clone();
let session_id = session_id.clone();
async move {
let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
service_instance,
transport,
)
.await;
match service {
Ok(service) => {
let _ = service.waiting().await;
}
Err(e) => {
tracing::error!("Failed to create service: {e}");
}
}
let _ = session_manager
.close_session(&session_id)
.await
.inspect_err(|e| {
tracing::error!("Failed to close session {session_id}: {e}");
});
}
});
let response = service
.session_manager
.initialize_session(&session_id, message)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
tracing::debug!(?response, "Initialization complete, creating SSE stream");
let sse_stream = async_stream::stream! {
yield Ok::<_, actix_web::Error>(Bytes::from(format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string())
)));
};
tracing::debug!("Created initialization response stream (closes after response)");
tracing::info!(
?session_id,
"Returning SSE streaming response for initialization"
);
Ok(HttpResponse::Ok()
.content_type(EVENT_STREAM_MIME_TYPE)
.append_header((CACHE_CONTROL, "no-cache"))
.append_header((HEADER_X_ACCEL_BUFFERING, "no"))
.append_header((HEADER_SESSION_ID, session_id.as_ref()))
.streaming(sse_stream))
}
} else {
tracing::debug!("POST request in stateless mode");
match message {
#[allow(unused_mut)]
ClientJsonRpcMessage::Request(mut request) => {
tracing::debug!(?request, "Processing request in stateless mode");
if let Some(ref hook) = service.on_request {
hook(&req, request.request.extensions_mut());
}
#[cfg(feature = "authorization-token-passthrough")]
if let Some(auth_value) = req.headers().get(header::AUTHORIZATION) {
match auth_value.to_str() {
Ok(auth_str)
if auth_str.starts_with("Bearer ") && auth_str.len() > 7 =>
{
tracing::debug!(
"Forwarding Authorization header to MCP service in stateless mode. \
Note: MCP services must not pass this token to upstream APIs per MCP spec. \
See SECURITY.md for details."
);
request
.request
.extensions_mut()
.insert(AuthorizationHeader(auth_str.to_string()));
}
Ok(auth_str) if auth_str == "Bearer" || auth_str == "Bearer " => {
tracing::debug!(
"Malformed Bearer token in stateless mode: missing token value"
);
}
Ok(auth_str) if !auth_str.starts_with("Bearer ") => {
let auth_type =
auth_str.split_whitespace().next().unwrap_or("unknown");
tracing::warn!(
"Non-Bearer authorization header ignored in stateless mode: {}",
auth_type
);
}
Err(e) => {
tracing::debug!(
"Invalid Authorization header encoding in stateless mode: {}",
e
);
}
_ => {}
}
}
#[cfg(not(feature = "authorization-token-passthrough"))]
if req.headers().get(header::AUTHORIZATION).is_some() {
tracing::warn!(
"Authorization header present but not forwarded in stateless mode. \
Enable 'authorization-token-passthrough' feature to forward tokens to MCP services. \
Note: Token passthrough violates MCP specifications. See SECURITY.md for details."
);
}
let service_instance = service
.get_service()
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
let (transport, receiver) =
OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
let service_handle = serve_directly(service_instance, transport, None);
tokio::spawn(async move {
let _ = service_handle.waiting().await;
});
let formatted_stream = ReceiverStream::new(receiver).map(|message| {
tracing::info!(?message);
let data =
serde_json::to_string(&message).unwrap_or_else(|_| "{}".to_string());
Ok::<_, actix_web::Error>(Bytes::from(format!("data: {data}\n\n")))
});
let sse_stream =
wrap_with_sse_keepalive(formatted_stream, service.sse_keep_alive);
Ok(HttpResponse::Ok()
.content_type(EVENT_STREAM_MIME_TYPE)
.append_header((CACHE_CONTROL, "no-cache"))
.append_header((HEADER_X_ACCEL_BUFFERING, "no"))
.streaming(sse_stream))
}
_ => Ok(HttpResponse::UnprocessableEntity().body("Unexpected message type")),
}
}
}
async fn handle_delete(req: HttpRequest, service: Data<AppData<S, M>>) -> Result<HttpResponse> {
let session_id = req
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required"));
};
tracing::debug!(%session_id, "DELETE request to close session");
service
.session_manager
.close_session(&session_id)
.await
.map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?;
tracing::info!(%session_id, "Session closed");
Ok(HttpResponse::NoContent().finish())
}
}