use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};
use bytes::Bytes;
use futures::{StreamExt, future::BoxFuture};
use http::{HeaderMap, Method, Request, Response, header::ALLOW};
use http_body::Body;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use super::session::{
RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore,
};
use crate::{
RoleServer,
model::{
ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
InitializedNotification, ProtocolVersion,
},
serve_server,
service::serve_directly,
transport::{
OneshotTransport, TransportAdapterIdentity,
common::{
http_header::{
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
HEADER_SESSION_ID, JSON_MIME_TYPE,
},
server_side_http::{
BoxResponse, ServerSseMessage, accepted_response, expect_json,
internal_error_response, sse_stream_response, unexpected_message_response,
},
},
},
};
/// Configuration for [`StreamableHttpService`].
///
/// Start from [`Default`] and adjust it with the `with_*` / `disable_*` builder methods;
/// the struct is `#[non_exhaustive]`, so it cannot be built with a struct literal outside
/// this crate.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive setting forwarded to every SSE response the service builds
    /// (presumably the interval between keep-alive events — see `sse_stream_response`).
    /// Defaults to 15 seconds.
    pub sse_keep_alive: Option<Duration>,
    /// Retry duration carried by the priming SSE event (id `"0"`) that starts standalone
    /// `GET` streams and initialize responses. `None` omits the priming event.
    /// Defaults to 3 seconds.
    pub sse_retry: Option<Duration>,
    /// `true`: sessions are tracked and `GET`, `POST`, `DELETE` are accepted.
    /// `false`: every `POST` is served by a fresh service instance and only `POST` is
    /// accepted. Defaults to `true`.
    pub stateful_mode: bool,
    /// In stateless mode, answer a request with a single `application/json` body instead
    /// of an SSE stream. Not consulted in stateful mode. Defaults to `false`.
    pub json_response: bool,
    /// Parent token: each SSE response and each stateless JSON wait uses a child of it.
    /// Cancelling it aborts pending JSON waits (and presumably ends open SSE streams).
    pub cancellation_token: CancellationToken,
    /// Accepted `Host` authorities (`host` or `host:port`; IPv6 may be bracketed). An entry
    /// without a port matches any port; an empty list disables `Host` validation.
    /// Defaults to `localhost`, `127.0.0.1` and `::1`.
    pub allowed_hosts: Vec<String>,
    /// Accepted `Origin` values (`scheme://host[:port]` or `null`). An entry without a port
    /// matches any port. Requests without an `Origin` header are not checked; an empty list
    /// disables `Origin` validation. Defaults to empty.
    pub allowed_origins: Vec<String>,
    /// Optional external store. It persists the initialize parameters of new sessions and
    /// is used to restore sessions unknown to the session manager. Defaults to `None`.
    pub session_store: Option<Arc<dyn SessionStore>>,
}
/// Opaque `Debug` representation for trait-object session stores.
///
/// Lets [`StreamableHttpServerConfig`] derive `Debug` without requiring store
/// implementations to be `Debug`; the store's contents are never printed.
impl std::fmt::Debug for dyn SessionStore {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "<SessionStore>")
    }
}
impl Default for StreamableHttpServerConfig {
    /// Stateful mode with SSE responses, a 15 s keep-alive, a 3 s retry hint, `Host`
    /// validation limited to loopback names, no `Origin` restriction, and no session store.
    fn default() -> Self {
        let loopback_hosts: Vec<String> = ["localhost", "127.0.0.1", "::1"]
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            stateful_mode: true,
            json_response: false,
            sse_keep_alive: Some(Duration::from_secs(15)),
            sse_retry: Some(Duration::from_secs(3)),
            allowed_hosts: loopback_hosts,
            allowed_origins: Vec::new(),
            session_store: None,
            cancellation_token: CancellationToken::new(),
        }
    }
}
impl StreamableHttpServerConfig {
pub fn with_allowed_hosts(
mut self,
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
self
}
pub fn disable_allowed_hosts(mut self) -> Self {
self.allowed_hosts.clear();
self
}
pub fn with_allowed_origins(
mut self,
allowed_origins: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_origins = allowed_origins.into_iter().map(Into::into).collect();
self
}
pub fn disable_allowed_origins(mut self) -> Self {
self.allowed_origins.clear();
self
}
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
self.sse_keep_alive = duration;
self
}
pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self {
self.sse_retry = duration;
self
}
pub fn with_stateful_mode(mut self, stateful: bool) -> Self {
self.stateful_mode = stateful;
self
}
pub fn with_json_response(mut self, json_response: bool) -> Self {
self.json_response = json_response;
self
}
pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
self.cancellation_token = token;
self
}
}
/// Rejects requests whose `MCP-Protocol-Version` header is present but unusable.
///
/// A missing header is accepted. A `400 Bad Request` response is returned when the header
/// is not valid visible ASCII, or when it names a version outside
/// `ProtocolVersion::KNOWN_VERSIONS`.
#[expect(
    clippy::result_large_err,
    reason = "BoxResponse is intentionally large; matches other handlers in this file"
)]
fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
    let Some(raw_version) = headers.get(HEADER_MCP_PROTOCOL_VERSION) else {
        return Ok(());
    };
    let Ok(version) = raw_version.to_str() else {
        let body = Full::new(Bytes::from(
            "Bad Request: Invalid MCP-Protocol-Version header encoding",
        ))
        .boxed();
        return Err(Response::builder()
            .status(http::StatusCode::BAD_REQUEST)
            .body(body)
            .expect("valid response"));
    };
    let unsupported = ProtocolVersion::KNOWN_VERSIONS
        .iter()
        .all(|known| known.as_str() != version);
    if unsupported {
        let body = Full::new(Bytes::from(format!(
            "Bad Request: Unsupported MCP-Protocol-Version: {version}"
        )))
        .boxed();
        return Err(Response::builder()
            .status(http::StatusCode::BAD_REQUEST)
            .body(body)
            .expect("valid response"));
    }
    Ok(())
}
/// Builds a `403 Forbidden` response whose body is `message`.
fn forbidden_response(message: impl Into<String>) -> BoxResponse {
    let text: String = message.into();
    let body = Full::new(Bytes::from(text)).boxed();
    Response::builder()
        .status(http::StatusCode::FORBIDDEN)
        .body(body)
        .expect("valid response")
}
/// Canonicalizes a host for comparison.
///
/// Strips every leading/trailing `[` and then every leading/trailing `]` (so a bracketed
/// IPv6 literal such as `[::1]` becomes `::1`), then ASCII-lower-cases the result.
fn normalize_host(host: &str) -> String {
    let without_open = host.trim_matches('[');
    let unbracketed = without_open.trim_matches(']');
    unbracketed.to_ascii_lowercase()
}
/// A `Host`-style authority reduced to a comparable form.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
    /// Host as produced by `normalize_host`: ASCII lower-case, IPv6 brackets removed.
    host: String,
    /// Explicit port, or `None` when the source authority carried no usable port.
    port: Option<u16>,
}
/// Pairs a normalized host (see `normalize_host`) with an optional port.
fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
    let host = normalize_host(host);
    NormalizedAuthority { host, port }
}
/// Parses one `allowed_hosts` entry into a comparable authority.
///
/// Blank entries yield `None`. An entry that parses as an HTTP authority keeps its host
/// and port. Any other entry is used verbatim as a host with no port; this presumably
/// covers bare, unbracketed IPv6 addresses such as the default `::1`.
fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
    let trimmed = allowed.trim();
    if trimmed.is_empty() {
        return None;
    }
    let normalized = match http::uri::Authority::try_from(trimmed) {
        Ok(authority) => normalize_authority(authority.host(), authority.port_u16()),
        Err(_) => normalize_authority(trimmed, None),
    };
    Some(normalized)
}
fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
if allowed_hosts.is_empty() {
return true;
}
allowed_hosts
.iter()
.filter_map(|allowed| parse_allowed_authority(allowed))
.any(|allowed| {
allowed.host == host.host
&& match allowed.port {
Some(port) => host.port == Some(port),
None => true,
}
})
}
/// An `Origin` header value (or allow-list entry) in comparable form.
#[derive(Debug, Clone, PartialEq, Eq)]
enum NormalizedOrigin {
    /// The opaque `null` origin (recognized case-insensitively by `parse_origin_value`).
    Null,
    /// A `scheme://host[:port]` origin. Scheme is ASCII lower-cased, host is normalized by
    /// `normalize_host`, and port is `None` when not given explicitly.
    Tuple {
        scheme: String,
        host: String,
        port: Option<u16>,
    },
}
/// Parses an `Origin` header value or allow-list entry.
///
/// Blank input yields `None`. `null` (any ASCII case) maps to [`NormalizedOrigin::Null`].
/// Anything else must be a URI with both a scheme and an authority; any path or query is
/// ignored.
fn parse_origin_value(value: &str) -> Option<NormalizedOrigin> {
    let candidate = value.trim();
    if candidate.is_empty() {
        None
    } else if candidate.eq_ignore_ascii_case("null") {
        Some(NormalizedOrigin::Null)
    } else {
        let uri = http::Uri::try_from(candidate).ok()?;
        let authority = uri.authority()?;
        let scheme = uri.scheme_str()?;
        Some(NormalizedOrigin::Tuple {
            scheme: scheme.to_ascii_lowercase(),
            host: normalize_host(authority.host()),
            port: authority.port_u16(),
        })
    }
}
fn origin_is_allowed(origin: &NormalizedOrigin, allowed_origins: &[String]) -> bool {
if allowed_origins.is_empty() {
return true;
}
allowed_origins
.iter()
.filter_map(|raw| parse_origin_value(raw))
.any(|allowed| match (&allowed, origin) {
(NormalizedOrigin::Null, NormalizedOrigin::Null) => true,
(
NormalizedOrigin::Tuple {
scheme: a_scheme,
host: a_host,
port: a_port,
},
NormalizedOrigin::Tuple {
scheme: o_scheme,
host: o_host,
port: o_port,
},
) => a_scheme == o_scheme && a_host == o_host && (a_port.is_none() || a_port == o_port),
_ => false,
})
}
/// Builds a `400 Bad Request` response with a UTF-8 plain-text body containing `message`.
fn bad_request_response(message: &str) -> BoxResponse {
    let body = Full::from(String::from(message)).boxed();
    http::Response::builder()
        .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
        .status(http::StatusCode::BAD_REQUEST)
        .body(body)
        .expect("failed to build bad request response")
}
/// Determines the authority the client addressed, for `Host` validation.
///
/// The `Host` header wins when present. Otherwise the request URI's authority is used
/// (e.g. HTTP/2 `:authority`).
///
/// # Errors
/// Returns `400 Bad Request` when `Host` is not visible ASCII, is not a valid authority,
/// or when neither `Host` nor a URI authority is available.
fn parse_host_header(
    uri: &http::Uri,
    headers: &HeaderMap,
) -> Result<NormalizedAuthority, BoxResponse> {
    let Some(host) = headers.get(http::header::HOST) else {
        return match uri.authority() {
            Some(authority) => Ok(normalize_authority(authority.host(), authority.port_u16())),
            None => {
                tracing::warn!("rejected request with missing Host header and no :authority");
                Err(bad_request_response("Bad Request: missing Host header"))
            }
        };
    };
    let Ok(host_str) = host.to_str() else {
        tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
        return Err(bad_request_response(
            "Bad Request: Invalid Host header encoding",
        ));
    };
    let Ok(authority) = http::uri::Authority::try_from(host_str) else {
        tracing::warn!(
            host = host_str,
            "rejected request with malformed Host header"
        );
        return Err(bad_request_response("Bad Request: Invalid Host header"));
    };
    Ok(normalize_authority(authority.host(), authority.port_u16()))
}
/// Guards against DNS-rebinding and cross-origin requests.
///
/// First checks the request authority against `config.allowed_hosts`, then checks the
/// `Origin` header against `config.allowed_origins`.
///
/// # Errors
/// Returns the rejection response: 400 for malformed headers, 403 for disallowed ones.
fn validate_dns_rebinding_headers(
    uri: &http::Uri,
    headers: &HeaderMap,
    config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
    let host = parse_host_header(uri, headers)?;
    if host_is_allowed(&host, &config.allowed_hosts) {
        return validate_origin_header(headers, &config.allowed_origins);
    }
    tracing::warn!(
        host = ?host,
        "rejected request with disallowed Host header (possible DNS rebinding attempt)",
    );
    Err(forbidden_response("Forbidden: Host header is not allowed"))
}
/// Validates the `Origin` header against `allowed_origins`.
///
/// Nothing is checked when the allow-list is empty or the request carries no `Origin`
/// header (e.g. non-browser clients).
///
/// # Errors
/// Returns 400 when the header is not visible ASCII or cannot be parsed, and 403 when the
/// origin is not allowed.
fn validate_origin_header(
    headers: &HeaderMap,
    allowed_origins: &[String],
) -> Result<(), BoxResponse> {
    let origin_header = match headers.get(http::header::ORIGIN) {
        Some(value) if !allowed_origins.is_empty() => value,
        _ => return Ok(()),
    };
    let Ok(origin_str) = origin_header.to_str() else {
        tracing::warn!(origin = ?origin_header, "rejected request with non-UTF-8 Origin header");
        return Err(bad_request_response(
            "Bad Request: Invalid Origin header encoding",
        ));
    };
    let Some(origin) = parse_origin_value(origin_str) else {
        tracing::warn!(
            origin = origin_str,
            "rejected request with malformed Origin header"
        );
        return Err(bad_request_response("Bad Request: Invalid Origin header"));
    };
    if origin_is_allowed(&origin, allowed_origins) {
        Ok(())
    } else {
        tracing::warn!(
            origin = ?origin,
            "rejected request with disallowed Origin header (possible cross-origin attack)"
        );
        Err(forbidden_response("Forbidden: Origin header is not allowed"))
    }
}
/// Tower/HTTP service for the streamable HTTP transport.
///
/// Cloning shares the session manager, service factory and restore map through `Arc`s;
/// the config is cloned.
pub struct StreamableHttpService<S, M> {
    /// Server configuration, consulted on every request.
    pub config: StreamableHttpServerConfig,
    /// Tracks live sessions and routes client messages to them.
    session_manager: Arc<M>,
    /// Builds a fresh MCP service for each new or restored session (stateful mode) or for
    /// each request (stateless mode).
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
    /// In-flight session restores keyed by session id. Each sender publishes the outcome
    /// (`Some(true)` / `Some(false)`) to concurrent requests for the same session.
    /// `None` when no session store is configured.
    pending_restores: Option<
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    >,
}
impl<S, M> Clone for StreamableHttpService<S, M> {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
session_manager: self.session_manager.clone(),
service_factory: self.service_factory.clone(),
pending_restores: self.pending_restores.clone(),
}
}
}
/// Exposes the service to tower/hyper stacks. Every request is handled on its own clone,
/// and failures are rendered as HTTP responses, so the error type is [`Infallible`].
impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
where
    RequestBody: Body + Send + 'static,
    RequestBody::Error: Display,
    RequestBody::Data: Send + 'static,
    S: crate::Service<RoleServer> + Send + 'static,
    M: SessionManager,
{
    type Response = BoxResponse;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always ready: no per-service backpressure is applied.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
        let this = self.clone();
        Box::pin(async move { Ok(this.handle(req).await) })
    }
}
/// Drop guard owning one in-flight restore slot in `pending_restores`.
///
/// When dropped, it publishes `Some(result)` to requests waiting on `watch_tx` and removes
/// the `session_id` entry from the map, so the slot is released on every exit path of the
/// restore (including early returns and cancellation).
struct PendingRestoreGuard {
    /// Shared map of in-flight restores (the same map the slot was inserted into).
    pending_restores:
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    /// Key of the slot this guard owns.
    session_id: SessionId,
    /// Sender whose clone sits in the map; waiters subscribe to it.
    watch_tx: tokio::sync::watch::Sender<Option<bool>>,
    /// Outcome reported on drop; stays `false` unless the owner marks the restore successful.
    result: bool,
}
impl Drop for PendingRestoreGuard {
fn drop(&mut self) {
let _ = self.watch_tx.send(Some(self.result));
let pending_restores = self.pending_restores.clone();
let session_id = self.session_id.clone();
tokio::spawn(async move {
pending_restores.write().await.remove(&session_id);
});
}
}
impl<S, M> StreamableHttpService<S, M>
where
S: crate::Service<RoleServer> + Send + 'static,
M: SessionManager,
{
pub fn new(
service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
session_manager: Arc<M>,
config: StreamableHttpServerConfig,
) -> Self {
let pending_restores = config.session_store.is_some().then(|| {
Arc::new(tokio::sync::RwLock::new(HashMap::<
SessionId,
tokio::sync::watch::Sender<Option<bool>>,
>::new()))
});
Self {
config,
session_manager,
service_factory: Arc::new(service_factory),
pending_restores,
}
}
fn get_service(&self) -> Result<S, std::io::Error> {
(self.service_factory)()
}
fn spawn_session_worker(
session_manager: Arc<M>,
session_id: SessionId,
service: S,
transport: M::Transport,
init_done_tx: Option<tokio::sync::oneshot::Sender<()>>,
) where
S: crate::Service<RoleServer> + Send + 'static,
M: SessionManager,
{
tokio::spawn(async move {
let svc =
serve_server::<S, M::Transport, _, TransportAdapterIdentity>(service, transport)
.await;
match svc {
Ok(svc) => {
if let Some(tx) = init_done_tx {
let _ = tx.send(());
}
let _ = svc.waiting().await;
}
Err(e) => {
tracing::error!("Failed to serve session: {e}");
}
}
let _ = session_manager
.close_session(&session_id)
.await
.inspect_err(|e| {
tracing::error!("Failed to close session {session_id}: {e}");
});
});
}
async fn try_restore_from_store(
&self,
session_id: &SessionId,
parts: &http::request::Parts,
) -> Result<bool, std::io::Error>
where
S: crate::Service<RoleServer> + Send + 'static,
M: SessionManager,
{
let (Some(pending_restores), Some(store)) =
(&self.pending_restores, &self.config.session_store)
else {
return Ok(false);
};
let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::<bool>);
{
let mut pending = pending_restores.write().await;
if let Some(tx) = pending.get(session_id) {
let mut rx = tx.subscribe();
drop(pending);
let result = rx
.wait_for(|r| r.is_some())
.await
.map(|r| r.unwrap_or(false))
.unwrap_or(false);
return Ok(result);
}
pending.insert(session_id.clone(), watch_tx.clone());
}
let mut guard = PendingRestoreGuard {
pending_restores: pending_restores.clone(),
session_id: session_id.clone(),
watch_tx: watch_tx.clone(),
result: false,
};
let state = match store.load(session_id.as_ref()).await {
Ok(Some(s)) => s,
Ok(None) => {
return Ok(false);
}
Err(e) => {
tracing::error!(
session_id = session_id.as_ref(),
error = %e,
"session store load failed during restore"
);
return Err(std::io::Error::other(e));
}
};
let transport = match self
.session_manager
.restore_session(session_id.clone())
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
Ok(RestoreOutcome::Restored(t)) => t,
Ok(RestoreOutcome::AlreadyPresent) => {
return Err(std::io::Error::other(
"restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
));
}
Ok(RestoreOutcome::NotSupported) => {
return Ok(false);
}
Err(e) => {
return Err(e);
}
};
let service = match self.get_service() {
Ok(s) => s,
Err(e) => {
return Err(e);
}
};
let mut restore_init = ClientJsonRpcMessage::request(
ClientRequest::InitializeRequest(InitializeRequest {
params: state.initialize_params,
..Default::default()
}),
crate::model::NumberOrString::Number(0),
);
restore_init.insert_extension(parts.clone());
restore_init.insert_extension(SessionRestoreMarker {
id: session_id.clone(),
});
let mut restore_initialized = ClientJsonRpcMessage::notification(
ClientNotification::InitializedNotification(InitializedNotification {
..Default::default()
}),
);
restore_initialized.insert_extension(parts.clone());
restore_initialized.insert_extension(SessionRestoreMarker {
id: session_id.clone(),
});
let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>();
Self::spawn_session_worker(
self.session_manager.clone(),
session_id.clone(),
service,
transport,
Some(init_done_tx),
);
if let Err(e) = self
.session_manager
.initialize_session(session_id, restore_init)
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
return Err(e);
}
if let Err(e) = self
.session_manager
.accept_message(session_id, restore_initialized)
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
return Err(e);
}
if init_done_rx.await.is_err() {
return Err(std::io::Error::other(
"serve_server initialization failed during restore",
));
}
guard.result = true;
tracing::debug!(
session_id = session_id.as_ref(),
"session restored from external store"
);
Ok(true)
}
pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
where
B: Body + Send + 'static,
B::Error: Display,
{
if let Err(response) =
validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config)
{
return response;
}
let method = request.method().clone();
let allowed_methods = match self.config.stateful_mode {
true => "GET, POST, DELETE",
false => "POST",
};
let result = match (method, self.config.stateful_mode) {
(Method::POST, _) => self.handle_post(request).await,
(Method::GET, true) => self.handle_get(request).await,
(Method::DELETE, true) => self.handle_delete(request).await,
_ => {
let response = Response::builder()
.status(http::StatusCode::METHOD_NOT_ALLOWED)
.header(ALLOW, allowed_methods)
.body(Full::new(Bytes::from("Method Not Allowed")).boxed())
.expect("valid response");
return response;
}
};
match result {
Ok(response) => response,
Err(response) => response,
}
}
async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
where
B: Body + Send + 'static,
B::Error: Display,
{
if !request
.headers()
.get(http::header::ACCEPT)
.and_then(|header| header.to_str().ok())
.is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
{
return Ok(Response::builder()
.status(http::StatusCode::NOT_ACCEPTABLE)
.body(
Full::new(Bytes::from(
"Not Acceptable: Client must accept text/event-stream",
))
.boxed(),
)
.expect("valid response"));
}
let session_id = request
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
return Ok(Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
let has_session = self
.session_manager
.has_session(&session_id)
.await
.map_err(internal_error_response("check session"))?;
let (parts, _) = request.into_parts();
if !has_session {
let restored = self
.try_restore_from_store(&session_id, &parts)
.await
.map_err(internal_error_response("restore session"))?;
if !restored {
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
}
validate_protocol_version_header(&parts.headers)?;
let last_event_id = parts
.headers
.get(HEADER_LAST_EVENT_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned());
if let Some(last_event_id) = last_event_id {
match self
.session_manager
.resume(&session_id, last_event_id)
.await
{
Ok(stream) => {
return Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
));
}
Err(e) => {
tracing::warn!("Resume failed ({e}), returning empty stream");
return Ok(sse_stream_response(
futures::stream::empty(),
None,
self.config.cancellation_token.child_token(),
));
}
}
}
let stream = self
.session_manager
.create_standalone_stream(&session_id)
.await
.map_err(internal_error_response("create standalone stream"))?;
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
))
}
async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
where
B: Body + Send + 'static,
B::Error: Display,
{
if !request
.headers()
.get(http::header::ACCEPT)
.and_then(|header| header.to_str().ok())
.is_some_and(|header| {
header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
})
{
return Ok(Response::builder()
.status(http::StatusCode::NOT_ACCEPTABLE)
.body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
.expect("valid response"));
}
if !request
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|header| header.to_str().ok())
.is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
{
return Ok(Response::builder()
.status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(
Full::new(Bytes::from(
"Unsupported Media Type: Content-Type must be application/json",
))
.boxed(),
)
.expect("valid response"));
}
let (part, body) = request.into_parts();
let mut message = match expect_json(body).await {
Ok(message) => message,
Err(response) => return Ok(response),
};
if self.config.stateful_mode {
let session_id = part
.headers
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok());
if let Some(session_id) = session_id {
let session_id = session_id.to_owned().into();
let has_session = self
.session_manager
.has_session(&session_id)
.await
.map_err(internal_error_response("check session"))?;
if !has_session {
let restored = self
.try_restore_from_store(&session_id, &part)
.await
.map_err(internal_error_response("restore session"))?;
if !restored {
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
}
validate_protocol_version_header(&part.headers)?;
match &mut message {
ClientJsonRpcMessage::Request(req) => {
req.request.extensions_mut().insert(part);
}
ClientJsonRpcMessage::Notification(not) => {
not.notification.extensions_mut().insert(part);
}
_ => {
}
}
match message {
ClientJsonRpcMessage::Request(_) => {
let stream = self
.session_manager
.create_stream(&session_id, message)
.await
.map_err(internal_error_response("get session"))?;
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
))
}
ClientJsonRpcMessage::Notification(_)
| ClientJsonRpcMessage::Response(_)
| ClientJsonRpcMessage::Error(_) => {
self.session_manager
.accept_message(&session_id, message)
.await
.map_err(internal_error_response("accept message"))?;
Ok(accepted_response())
}
}
} else {
let (session_id, transport) = self
.session_manager
.create_session()
.await
.map_err(internal_error_response("create session"))?;
let stored_init_params = if self.config.session_store.is_some() {
if let ClientJsonRpcMessage::Request(req) = &message {
if let ClientRequest::InitializeRequest(init_req) = &req.request {
Some(init_req.params.clone())
} else {
None
}
} else {
None
}
} else {
None
};
if let ClientJsonRpcMessage::Request(req) = &mut message {
if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
return Err(unexpected_message_response("initialize request"));
}
req.request.extensions_mut().insert(part);
} else {
return Err(unexpected_message_response("initialize request"));
}
let service = self
.get_service()
.map_err(internal_error_response("get service"))?;
Self::spawn_session_worker(
self.session_manager.clone(),
session_id.clone(),
service,
transport,
None,
);
let response = self
.session_manager
.initialize_session(&session_id, message)
.await
.map_err(internal_error_response("create stream"))?;
if let (Some(store), Some(params)) =
(&self.config.session_store, stored_init_params)
{
let state = SessionState {
initialize_params: params,
};
let _ = store
.store(session_id.as_ref(), &state)
.await
.inspect_err(|e| {
tracing::warn!(
"Failed to persist session {} to store: {e}",
session_id
);
});
}
let stream =
futures::stream::once(async move { ServerSseMessage::from_message(response) });
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
let mut response = sse_stream_response(
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
);
response.headers_mut().insert(
HEADER_SESSION_ID,
session_id
.parse()
.map_err(internal_error_response("create session id header"))?,
);
Ok(response)
}
} else {
let is_init = matches!(
&message,
ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
);
if !is_init {
validate_protocol_version_header(&part.headers)?;
}
let service = self
.get_service()
.map_err(internal_error_response("get service"))?;
match message {
ClientJsonRpcMessage::Request(mut request) => {
request.request.extensions_mut().insert(part);
let (transport, mut receiver) =
OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
let service = serve_directly(service, transport, None);
tokio::spawn(async move {
let _ = service.waiting().await;
});
if self.config.json_response {
let cancel = self.config.cancellation_token.child_token();
match tokio::select! {
res = receiver.recv() => res,
_ = cancel.cancelled() => None,
} {
Some(message) => {
tracing::trace!(?message);
let body = serde_json::to_vec(&message).map_err(|e| {
internal_error_response("serialize json response")(e)
})?;
Ok(Response::builder()
.status(http::StatusCode::OK)
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
.body(Full::new(Bytes::from(body)).boxed())
.expect("valid response"))
}
None => Err(internal_error_response("empty response")(
std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"no response message received from handler",
),
)),
}
} else {
let stream = ReceiverStream::new(receiver).map(|message| {
tracing::trace!(?message);
ServerSseMessage::from_message(message)
});
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
))
}
}
ClientJsonRpcMessage::Notification(_notification) => {
Ok(accepted_response())
}
ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
}
}
}
async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
where
B: Body + Send + 'static,
B::Error: Display,
{
let session_id = request
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
return Ok(Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
validate_protocol_version_header(request.headers())?;
self.session_manager
.close_session(&session_id)
.await
.map_err(internal_error_response("close session"))?;
if let Some(store) = &self.config.session_store {
let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| {
tracing::warn!("Failed to delete session {} from store: {e}", session_id);
});
}
Ok(accepted_response())
}
}