use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use axum::{
Router,
extract::State,
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response, Sse, sse::Event},
routing::{delete, get, post},
};
use tokio::sync::{Mutex, RwLock, broadcast, oneshot};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::context::{
ChannelClientRequester, ClientRequesterHandle, NotificationReceiver, OutgoingRequestReceiver,
notification_channel, outgoing_request_channel,
};
use crate::error::{Error, JsonRpcError, Result};
use crate::jsonrpc::JsonRpcService;
use crate::protocol::{
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, LATEST_PROTOCOL_VERSION, McpNotification,
RequestId, SUPPORTED_PROTOCOL_VERSIONS,
};
use crate::router::{McpRouter, RouterRequest, RouterResponse};
use crate::transport::service::{
CatchError, InjectAnnotations, McpBoxService, ServiceFactory, identity_factory,
};
use tower::util::BoxCloneService;
/// Name of the HTTP header that carries the MCP session id between client and server.
pub const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
/// Name of the HTTP header that carries the MCP protocol version.
pub const MCP_PROTOCOL_VERSION_HEADER: &str = "mcp-protocol-version";
// NOTE(review): not referenced in this chunk; presumably the SSE `event:` name
// used when framing outgoing messages — confirm against the SSE handlers.
const SSE_MESSAGE_EVENT: &str = "message";
/// SSE reconnection header; `get_last_event_id` parses its value as a `u64`.
const LAST_EVENT_ID_HEADER: &str = "last-event-id";
/// An outstanding request, stored in `Session::pending_requests` under its
/// JSON-RPC id until `Session::complete_pending_request` resolves it.
struct PendingRequest {
/// One-shot channel through which the eventual result (or error) is delivered
/// to whoever is awaiting this request.
response_tx: oneshot::Sender<Result<serde_json::Value>>,
}
/// Where a session obtains the boxed tower service that handles its requests
/// (see `Session::make_service`).
enum SessionServiceSource {
/// A session-owned router plus the factory that wraps a clone of it into a
/// boxed service each time one is needed.
Router {
router: McpRouter,
factory: ServiceFactory,
},
/// A pre-built boxed service; `make_service` clones it while holding the lock.
// NOTE(review): the `std::sync::Mutex` is presumably there because the boxed
// service is `Send` but not `Sync` — confirm before replacing it.
Boxed(std::sync::Mutex<McpBoxService>),
}
/// Per-session state held by the HTTP transport.
struct Session {
/// Session id: a fresh UUID v4 for new sessions, or the persisted id when restored.
id: String,
/// How this session builds the service that handles its requests.
service_source: SessionServiceSource,
/// Broadcast channel carrying serialized JSON notifications for this session.
notifications_tx: broadcast::Sender<String>,
/// When this in-memory `Session` was constructed (set to "now" on restore too).
created_at: Instant,
/// Last time the session was used; compared against the TTL in `is_expired`.
last_accessed: RwLock<Instant>,
/// Requests awaiting a response, keyed by JSON-RPC request id.
pending_requests: Mutex<HashMap<RequestId, PendingRequest>>,
/// Receiver for outgoing (server-to-client) requests; `Some` only for
/// router-backed sessions created with sampling enabled.
request_rx: Mutex<Option<OutgoingRequestReceiver>>,
/// Protocol version associated with this session.
protocol_version: RwLock<String>,
/// Source of per-session event ids handed out by `next_event_id`.
event_counter: AtomicU64,
/// Store that buffered events are appended to and replayed from.
event_store: Arc<dyn crate::event_store::EventStore>,
}
impl Session {
fn new(
router: McpRouter,
sampling_enabled: bool,
service_factory: ServiceFactory,
event_store: Arc<dyn crate::event_store::EventStore>,
) -> Self {
let (notifications_tx, _) = broadcast::channel(100);
let (notif_sender, mut notif_receiver) = notification_channel(256);
let router = router.with_notification_sender(notif_sender);
let broadcast_tx = notifications_tx.clone();
tokio::spawn(async move {
while let Some(notification) = notif_receiver.recv().await {
if let Some(json) = crate::transport::stdio::serialize_notification(¬ification) {
let _ = broadcast_tx.send(json);
}
}
});
let (router, request_rx) = if sampling_enabled {
let (request_tx, request_rx) = outgoing_request_channel(32);
let client_requester: ClientRequesterHandle =
Arc::new(ChannelClientRequester::new(request_tx));
let router = router.with_client_requester(client_requester);
(router, Some(request_rx))
} else {
(router, None)
};
let now = Instant::now();
Self {
id: uuid::Uuid::new_v4().to_string(),
service_source: SessionServiceSource::Router {
router,
factory: service_factory,
},
notifications_tx,
created_at: now,
last_accessed: RwLock::new(now),
pending_requests: Mutex::new(HashMap::new()),
request_rx: Mutex::new(request_rx),
protocol_version: RwLock::new(LATEST_PROTOCOL_VERSION.to_string()),
event_counter: AtomicU64::new(0),
event_store,
}
}
fn from_service(
service: McpBoxService,
event_store: Arc<dyn crate::event_store::EventStore>,
) -> Self {
let (notifications_tx, _) = broadcast::channel(100);
let now = Instant::now();
Self {
id: uuid::Uuid::new_v4().to_string(),
service_source: SessionServiceSource::Boxed(std::sync::Mutex::new(service)),
notifications_tx,
created_at: now,
last_accessed: RwLock::new(now),
pending_requests: Mutex::new(HashMap::new()),
request_rx: Mutex::new(None),
protocol_version: RwLock::new(LATEST_PROTOCOL_VERSION.to_string()),
event_counter: AtomicU64::new(0),
event_store,
}
}
fn restored(
record: &crate::session_store::SessionRecord,
router: McpRouter,
sampling_enabled: bool,
service_factory: ServiceFactory,
event_store: Arc<dyn crate::event_store::EventStore>,
) -> Self {
router.session().mark_initialized();
let (notifications_tx, _) = broadcast::channel(100);
let (notif_sender, mut notif_receiver) = notification_channel(256);
let router = router.with_notification_sender(notif_sender);
let broadcast_tx = notifications_tx.clone();
tokio::spawn(async move {
while let Some(notification) = notif_receiver.recv().await {
if let Some(json) = crate::transport::stdio::serialize_notification(¬ification) {
let _ = broadcast_tx.send(json);
}
}
});
let (router, request_rx) = if sampling_enabled {
let (request_tx, request_rx) = outgoing_request_channel(32);
let client_requester: ClientRequesterHandle =
Arc::new(ChannelClientRequester::new(request_tx));
let router = router.with_client_requester(client_requester);
(router, Some(request_rx))
} else {
(router, None)
};
let now = Instant::now();
Self {
id: record.id.clone(),
service_source: SessionServiceSource::Router {
router,
factory: service_factory,
},
notifications_tx,
created_at: now,
last_accessed: RwLock::new(now),
pending_requests: Mutex::new(HashMap::new()),
request_rx: Mutex::new(request_rx),
protocol_version: RwLock::new(record.protocol_version.clone()),
event_counter: AtomicU64::new(0),
event_store,
}
}
fn from_service_restored(
service: McpBoxService,
record: &crate::session_store::SessionRecord,
event_store: Arc<dyn crate::event_store::EventStore>,
) -> Self {
let (notifications_tx, _) = broadcast::channel(100);
let now = Instant::now();
Self {
id: record.id.clone(),
service_source: SessionServiceSource::Boxed(std::sync::Mutex::new(service)),
notifications_tx,
created_at: now,
last_accessed: RwLock::new(now),
pending_requests: Mutex::new(HashMap::new()),
request_rx: Mutex::new(None),
protocol_version: RwLock::new(record.protocol_version.clone()),
event_counter: AtomicU64::new(0),
event_store,
}
}
fn make_service(&self) -> McpBoxService {
match &self.service_source {
SessionServiceSource::Router { router, factory } => (factory)(router.clone()),
SessionServiceSource::Boxed(mutex) => mutex.lock().unwrap().clone(),
}
}
fn handle_notification(&self, notification: McpNotification) {
match &self.service_source {
SessionServiceSource::Router { router, .. } => {
router.handle_notification(notification);
}
SessionServiceSource::Boxed(_) => {
tracing::debug!(
notification = ?notification,
"Notification received on service-based session (not forwarded)"
);
}
}
}
fn next_event_id(&self) -> u64 {
self.event_counter.fetch_add(1, Ordering::SeqCst)
}
async fn buffer_event(&self, id: u64, data: String) {
let record = crate::event_store::EventRecord::new(id, data);
if let Err(e) = self.event_store.append(&self.id, record).await {
tracing::warn!(session_id = %self.id, event_id = id, error = %e, "Failed to append event to event store");
}
}
async fn get_events_after(&self, after_id: u64) -> Vec<crate::event_store::EventRecord> {
match self.event_store.replay_after(&self.id, after_id).await {
Ok(events) => events,
Err(e) => {
tracing::warn!(session_id = %self.id, error = %e, "Failed to replay events from event store");
Vec::new()
}
}
}
async fn touch(&self) {
*self.last_accessed.write().await = Instant::now();
}
async fn is_expired(&self, ttl: Duration) -> bool {
self.last_accessed.read().await.elapsed() > ttl
}
async fn add_pending_request(
&self,
id: RequestId,
response_tx: oneshot::Sender<Result<serde_json::Value>>,
) {
let mut pending = self.pending_requests.lock().await;
pending.insert(id, PendingRequest { response_tx });
}
async fn complete_pending_request(
&self,
id: &RequestId,
result: Result<serde_json::Value>,
) -> bool {
let pending = {
let mut pending_requests = self.pending_requests.lock().await;
pending_requests.remove(id)
};
match pending {
Some(pending) => {
let _ = pending.response_tx.send(result);
true
}
None => false,
}
}
}
/// Default idle time-to-live for a session: thirty minutes.
pub const DEFAULT_SESSION_TTL: Duration = Duration::from_secs(30 * 60);
/// Default period between sweeps for expired sessions: one minute.
const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);

/// Lifecycle settings for HTTP sessions.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Idle duration after which a session counts as expired.
    pub ttl: Duration,
    /// Maximum number of in-memory sessions; `None` means no limit.
    pub max_sessions: Option<usize>,
    /// How often the background task looks for expired sessions.
    pub cleanup_interval: Duration,
}

impl Default for SessionConfig {
    /// A 30-minute TTL, no session limit, and a 60-second cleanup interval.
    fn default() -> Self {
        SessionConfig {
            cleanup_interval: DEFAULT_CLEANUP_INTERVAL,
            max_sessions: None,
            ttl: DEFAULT_SESSION_TTL,
        }
    }
}

impl SessionConfig {
    /// Default configuration with the given idle TTL.
    pub fn with_ttl(ttl: Duration) -> Self {
        let defaults = Self::default();
        Self { ttl, ..defaults }
    }

    /// Caps the number of concurrent sessions at `max`.
    pub fn max_sessions(self, max: usize) -> Self {
        Self {
            max_sessions: Some(max),
            ..self
        }
    }

    /// Sets how often expired sessions are swept.
    pub fn cleanup_interval(self, interval: Duration) -> Self {
        Self {
            cleanup_interval: interval,
            ..self
        }
    }
}
/// In-memory table of live sessions, backed by a persistent session store
/// and an event store.
struct SessionRegistry {
/// Live sessions keyed by session id.
sessions: RwLock<HashMap<String, Arc<Session>>>,
/// TTL, session limit and cleanup interval.
config: SessionConfig,
/// Whether router-backed sessions get a client requester attached.
sampling_enabled: bool,
/// Durable session records, used to restore sessions missing from memory.
persistent: Arc<dyn crate::session_store::SessionStore>,
/// Event store shared with every session (for buffering and replay).
events: Arc<dyn crate::event_store::EventStore>,
/// Source used to rebuild sessions on restore / auto-reinitialization.
service_source: ServiceSource,
/// When `true`, an unknown session id is recreated instead of rejected.
auto_reinit: bool,
}
impl SessionRegistry {
    /// Creates an empty registry.
    ///
    /// `service_source` is retained so sessions found only in the persistent
    /// store (or auto-reinitialized) can be rebuilt on demand.
    fn new(
        config: SessionConfig,
        sampling_enabled: bool,
        persistent: Arc<dyn crate::session_store::SessionStore>,
        events: Arc<dyn crate::event_store::EventStore>,
        service_source: ServiceSource,
        auto_reinit: bool,
    ) -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
            config,
            sampling_enabled,
            persistent,
            events,
            service_source,
            auto_reinit,
        }
    }

    /// Builds a persistable record describing `session`.
    ///
    /// Sessions track time with monotonic `Instant`s, so wall-clock timestamps
    /// are reconstructed by subtracting elapsed time from `SystemTime::now()`.
    /// If the subtraction underflows, "now" is used instead.
    async fn record_for(&self, session: &Session) -> crate::session_store::SessionRecord {
        let protocol_version = session.protocol_version.read().await.clone();
        let last_accessed = *session.last_accessed.read().await;
        let mut record = crate::session_store::SessionRecord::new(
            session.id.clone(),
            protocol_version,
            self.config.ttl,
        );
        let now = std::time::SystemTime::now();
        record.created_at = now
            .checked_sub(session.created_at.elapsed())
            .unwrap_or(now);
        record.last_accessed = now.checked_sub(last_accessed.elapsed()).unwrap_or(now);
        record.expires_at = record.last_accessed + self.config.ttl;
        record
    }

    /// Writes a newly created session to the persistent store.
    ///
    /// This is best-effort: a failure is logged and the session stays usable
    /// in memory.
    async fn persist_new(&self, session: &Session) {
        let mut record = self.record_for(session).await;
        if let Err(e) = self.persistent.create(&mut record).await {
            tracing::warn!(session_id = %session.id, error = %e, "Failed to persist session record");
        }
    }

    /// Inserts the session produced by `build` into the in-memory table.
    ///
    /// Returns `None` (and logs a warning) if the configured session limit has
    /// been reached; `build` is not called in that case. The limit check and
    /// the insert happen under one write lock, so concurrent creations cannot
    /// exceed the limit.
    async fn insert_new(&self, build: impl FnOnce() -> Session) -> Option<Arc<Session>> {
        let mut sessions = self.sessions.write().await;
        if let Some(max) = self.config.max_sessions
            && sessions.len() >= max
        {
            tracing::warn!(
                max_sessions = max,
                current = sessions.len(),
                "Session limit reached, rejecting new session"
            );
            return None;
        }
        let session = Arc::new(build());
        sessions.insert(session.id.clone(), session.clone());
        Some(session)
    }

    /// Creates and persists a new router-backed session (awaiting `initialize`).
    async fn create(
        &self,
        router: McpRouter,
        service_factory: ServiceFactory,
    ) -> Option<Arc<Session>> {
        let session = self
            .insert_new(|| {
                Session::new(
                    router,
                    self.sampling_enabled,
                    service_factory,
                    self.events.clone(),
                )
            })
            .await?;
        tracing::debug!(session_id = %session.id, sampling = self.sampling_enabled, "Created new session");
        self.persist_new(&session).await;
        Some(session)
    }

    /// Creates and persists a new session around a pre-built boxed service.
    async fn create_from_service(&self, service: McpBoxService) -> Option<Arc<Session>> {
        let session = self
            .insert_new(|| Session::from_service(service, self.events.clone()))
            .await?;
        tracing::debug!(session_id = %session.id, "Created new session from service");
        self.persist_new(&session).await;
        Some(session)
    }

    /// Creates a router-backed session that is already marked initialized.
    ///
    /// Used when a request arrives without a session id and optional sessions
    /// are enabled.
    async fn create_initialized(
        &self,
        router: McpRouter,
        service_factory: ServiceFactory,
    ) -> Option<Arc<Session>> {
        router.session().mark_initialized();
        let session = self
            .insert_new(|| {
                Session::new(
                    router,
                    self.sampling_enabled,
                    service_factory,
                    self.events.clone(),
                )
            })
            .await?;
        tracing::debug!(session_id = %session.id, "Created pre-initialized session (optional_sessions)");
        self.persist_new(&session).await;
        Some(session)
    }

    /// Service-backed counterpart of `create_initialized`.
    async fn create_initialized_from_service(
        &self,
        service: McpBoxService,
    ) -> Option<Arc<Session>> {
        let session = self
            .insert_new(|| Session::from_service(service, self.events.clone()))
            .await?;
        tracing::debug!(session_id = %session.id, "Created pre-initialized session from service (optional_sessions)");
        self.persist_new(&session).await;
        Some(session)
    }

    /// Looks up a session by id, refreshing its last-access time.
    ///
    /// Falls back to restoring from the persistent store, then (if enabled)
    /// to auto-reinitialization. Returns `None` if all of these fail.
    async fn get(&self, id: &str) -> Option<Arc<Session>> {
        // Clone the Arc out so the registry lock is released before awaiting.
        let cached = self.sessions.read().await.get(id).cloned();
        if let Some(session) = cached {
            session.touch().await;
            return Some(session);
        }
        match self.persistent.load(id).await {
            Ok(Some(record)) => {
                tracing::info!(session_id = %id, "Restoring session from persistent store");
                if let Some(session) = self.restore_from_record(record).await {
                    return Some(session);
                }
            }
            Ok(None) => {}
            Err(e) => {
                tracing::warn!(session_id = %id, error = %e, "Failed to load session record");
            }
        }
        if self.auto_reinit {
            tracing::info!(session_id = %id, "Auto-reinitializing unknown session");
            return self.auto_reinitialize(id).await;
        }
        None
    }

    /// Rebuilds a session from `record` and inserts it into the in-memory table.
    ///
    /// If the id is already live (e.g. restored by a concurrent request), that
    /// session is returned. Returns `None` only when a new session would
    /// exceed the session limit.
    async fn restore_from_record(
        &self,
        record: crate::session_store::SessionRecord,
    ) -> Option<Arc<Session>> {
        let session = {
            let mut sessions = self.sessions.write().await;
            // Reuse an existing copy *before* the capacity check: returning an
            // already-live session adds nothing, so the limit must not block it.
            if let Some(existing) = sessions.get(&record.id).cloned() {
                existing.touch().await;
                return Some(existing);
            }
            if let Some(max) = self.config.max_sessions
                && sessions.len() >= max
            {
                tracing::warn!(
                    max_sessions = max,
                    "Session limit reached, cannot restore session"
                );
                return None;
            }
            let session: Arc<Session> = match &self.service_source {
                ServiceSource::Router { router, factory } => Arc::new(Session::restored(
                    &record,
                    router.with_fresh_session(),
                    self.sampling_enabled,
                    factory.clone(),
                    self.events.clone(),
                )),
                ServiceSource::Service(svc) => {
                    // The lock only guards a clone; recover from poisoning.
                    let service = svc
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .clone();
                    Arc::new(Session::from_service_restored(
                        service,
                        &record,
                        self.events.clone(),
                    ))
                }
            };
            sessions.insert(record.id.clone(), session.clone());
            tracing::debug!(session_id = %session.id, "Restored session into local registry");
            session
        };
        // Continue numbering after the highest stored event id so new events do
        // not reuse ids that already exist in the store.
        // NOTE(review): if `replay_after(_, 0)` excludes id 0, a store holding
        // only event 0 leaves the counter at 0 — confirm the store's semantics.
        if let Ok(events) = self.events.replay_after(&record.id, 0).await
            && let Some(max_id) = events.iter().map(|e| e.id).max()
        {
            session
                .event_counter
                .store(max_id.saturating_add(1), Ordering::SeqCst);
        }
        let mut refreshed = record;
        refreshed.touch(self.config.ttl);
        if let Err(e) = self.persistent.save(&refreshed).await {
            tracing::warn!(session_id = %refreshed.id, error = %e, "Failed to refresh restored session record");
        }
        Some(session)
    }

    /// Creates a placeholder record for an unknown `id` and restores a session from it.
    ///
    /// The record is tagged with an `"auto-recovered"` client identity and
    /// default client capabilities.
    async fn auto_reinitialize(&self, id: &str) -> Option<Arc<Session>> {
        let mut record = crate::session_store::SessionRecord::new(
            id.to_string(),
            LATEST_PROTOCOL_VERSION.to_string(),
            self.config.ttl,
        );
        record.client_info = Some(crate::protocol::Implementation {
            name: "auto-recovered".into(),
            version: "unknown".into(),
            title: None,
            description: None,
            icons: None,
            website_url: None,
            meta: None,
        });
        record.client_capabilities = Some(crate::protocol::ClientCapabilities::default());
        if let Err(e) = self.persistent.create(&mut record).await {
            tracing::warn!(session_id = %id, error = %e, "Failed to persist auto-reinitialized session");
        }
        self.restore_from_record(record).await
    }

    /// Removes a live session and deletes its persisted record and events.
    ///
    /// Returns `false` if `id` was not in the in-memory table; in that case
    /// the stores are left untouched.
    async fn remove(&self, id: &str) -> bool {
        let removed = self.sessions.write().await.remove(id).is_some();
        if removed {
            tracing::debug!(session_id = %id, "Removed session");
            if let Err(e) = self.persistent.delete(id).await {
                tracing::warn!(session_id = %id, error = %e, "Failed to delete session record");
            }
            if let Err(e) = self.events.purge_session(id).await {
                tracing::warn!(session_id = %id, error = %e, "Failed to purge session events");
            }
        }
        removed
    }

    /// Sends `json` on every live session's notification channel.
    ///
    /// Sessions without subscribers simply drop the message.
    async fn broadcast_to_all(&self, json: &str) {
        let sessions = self.sessions.read().await;
        for session in sessions.values() {
            let _ = session.notifications_tx.send(json.to_string());
        }
    }

    /// Drops every session idle longer than the TTL, then deletes their
    /// persisted records and events.
    ///
    /// Returns the number of sessions removed.
    async fn cleanup_expired(&self) -> usize {
        let expired = {
            let mut sessions = self.sessions.write().await;
            let ttl = self.config.ttl;
            let mut expired = Vec::new();
            for (id, session) in sessions.iter() {
                if session.is_expired(ttl).await {
                    expired.push(id.clone());
                }
            }
            for id in &expired {
                sessions.remove(id);
                tracing::debug!(session_id = %id, "Expired session removed");
            }
            if !expired.is_empty() {
                tracing::info!(
                    expired_count = expired.len(),
                    remaining = sessions.len(),
                    "Session cleanup completed"
                );
            }
            expired
        };
        // Store I/O happens after the registry lock has been released.
        for id in &expired {
            if let Err(e) = self.persistent.delete(id).await {
                tracing::warn!(session_id = %id, error = %e, "Failed to delete expired session record");
            }
            if let Err(e) = self.events.purge_session(id).await {
                tracing::warn!(session_id = %id, error = %e, "Failed to purge expired session events");
            }
        }
        expired.len()
    }
}
#[derive(Debug, Clone)]
/// Snapshot of one live session, as returned by `SessionHandle::list_sessions`.
pub struct SessionInfo {
/// Session id.
pub id: String,
/// Age of the session: time elapsed since its in-memory creation (not a timestamp).
pub created_at: Duration,
/// Time elapsed since the session was last accessed.
pub last_activity: Duration,
}
#[derive(Clone)]
/// Cloneable handle for inspecting and terminating the transport's sessions
/// from outside the HTTP handlers.
pub struct SessionHandle {
store: Arc<SessionRegistry>,
}
impl SessionHandle {
pub async fn session_count(&self) -> usize {
self.store.sessions.read().await.len()
}
pub async fn list_sessions(&self) -> Vec<SessionInfo> {
let sessions = self.store.sessions.read().await;
let mut infos = Vec::with_capacity(sessions.len());
for session in sessions.values() {
let last_accessed = session.last_accessed.read().await;
infos.push(SessionInfo {
id: session.id.clone(),
created_at: session.created_at.elapsed(),
last_activity: last_accessed.elapsed(),
});
}
infos
}
pub async fn terminate_session(&self, id: &str) -> bool {
self.store.remove(id).await
}
}
#[derive(Clone)]
/// Transport-wide source of request handlers, from which per-session (and
/// stateless per-request) services are derived.
enum ServiceSource {
/// A template router (cloned via `with_fresh_session` per session) and the
/// factory that turns a router into a boxed service.
Router {
router: McpRouter,
factory: ServiceFactory,
},
/// A user-supplied boxed service, shared behind a lock and cloned on use.
Service(Arc<std::sync::Mutex<McpBoxService>>),
}
/// Shared state given to every axum handler via `State<Arc<AppState>>`.
struct AppState {
/// Source of request handlers for new sessions and stateless requests.
service_source: ServiceSource,
/// Registry of live sessions.
sessions: Arc<SessionRegistry>,
/// Whether the `Origin` header is checked (see `validate_origin`).
validate_origin: bool,
/// Non-localhost origins allowed when origin validation is on (`"*"` allows any).
allowed_origins: Vec<String>,
/// Whether the `Host` header is checked (see `validate_host`).
validate_host: bool,
/// Non-localhost hosts allowed when host validation is on; empty allows any.
allowed_hosts: Vec<String>,
/// Whether sessions get a client requester attached.
sampling_enabled: bool,
/// When `true`, requests without a session id get an auto-created,
/// pre-initialized session instead of being rejected.
optional_sessions: bool,
#[cfg(feature = "stateless")]
stateless_config: Option<crate::stateless::StatelessConfig>,
}
/// OAuth settings; the metadata is served at the protected-resource
/// well-known path by `HttpTransport::add_oauth_route`.
#[cfg(feature = "oauth")]
#[derive(Clone)]
pub(crate) struct OAuthConfig {
pub(crate) metadata: crate::oauth::ProtectedResourceMetadata,
}
/// Builder for the MCP streamable-HTTP transport.
///
/// Configure it with the builder methods, then turn it into an axum `Router`
/// (`into_router*`) or run it directly with `serve`.
pub struct HttpTransport {
/// Router or boxed service that handles requests.
service_source: ServiceSource,
/// Check the `Origin` header (on by default).
validate_origin: bool,
/// Extra non-localhost origins to accept.
allowed_origins: Vec<String>,
/// Check the `Host` header (on by default).
validate_host: bool,
/// Extra non-localhost hosts to accept; empty accepts any host.
allowed_hosts: Vec<String>,
/// Session TTL, limit and cleanup interval.
session_config: SessionConfig,
/// Attach a client requester to router-backed sessions.
sampling_enabled: bool,
/// Auto-create sessions for requests without a session id (on by default).
optional_sessions: bool,
/// Durable session records (in-memory by default).
session_store: Arc<dyn crate::session_store::SessionStore>,
/// Buffered SSE events (in-memory by default).
event_store: Arc<dyn crate::event_store::EventStore>,
/// Recreate sessions for unknown session ids instead of rejecting them.
auto_reinit_sessions: bool,
/// Optional channel whose notifications are broadcast to every session;
/// taken when the router is built.
external_notifications: Option<NotificationReceiver>,
#[cfg(feature = "stateless")]
stateless_config: Option<crate::stateless::StatelessConfig>,
#[cfg(feature = "oauth")]
oauth_config: Option<OAuthConfig>,
}
impl HttpTransport {
pub fn new(router: McpRouter) -> Self {
Self {
service_source: ServiceSource::Router {
router,
factory: identity_factory(),
},
validate_origin: true,
allowed_origins: vec![],
validate_host: true,
allowed_hosts: vec![],
session_config: SessionConfig::default(),
sampling_enabled: false,
optional_sessions: true,
session_store: Arc::new(crate::session_store::MemorySessionStore::new()),
event_store: Arc::new(crate::event_store::MemoryEventStore::new()),
auto_reinit_sessions: false,
external_notifications: None,
#[cfg(feature = "stateless")]
stateless_config: None,
#[cfg(feature = "oauth")]
oauth_config: None,
}
}
pub fn from_service<S>(service: S) -> Self
where
S: tower::Service<
RouterRequest,
Response = RouterResponse,
Error = std::convert::Infallible,
> + Clone
+ Send
+ 'static,
S::Future: Send,
{
Self {
service_source: ServiceSource::Service(Arc::new(std::sync::Mutex::new(
BoxCloneService::new(service),
))),
validate_origin: true,
allowed_origins: vec![],
validate_host: true,
allowed_hosts: vec![],
session_config: SessionConfig::default(),
sampling_enabled: false,
optional_sessions: true,
session_store: Arc::new(crate::session_store::MemorySessionStore::new()),
event_store: Arc::new(crate::event_store::MemoryEventStore::new()),
auto_reinit_sessions: false,
external_notifications: None,
#[cfg(feature = "stateless")]
stateless_config: None,
#[cfg(feature = "oauth")]
oauth_config: None,
}
}
pub fn with_notifications(router: McpRouter, notification_rx: NotificationReceiver) -> Self {
Self {
external_notifications: Some(notification_rx),
..Self::new(router)
}
}
pub fn external_notifications(mut self, notification_rx: NotificationReceiver) -> Self {
self.external_notifications = Some(notification_rx);
self
}
pub fn with_sampling(mut self) -> Self {
self.sampling_enabled = true;
self
}
pub fn require_sessions(mut self) -> Self {
self.optional_sessions = false;
self
}
#[cfg(feature = "stateless")]
pub fn stateless(mut self, config: crate::stateless::StatelessConfig) -> Self {
self.stateless_config = Some(config);
self
}
pub fn disable_origin_validation(mut self) -> Self {
self.validate_origin = false;
self
}
pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
self.allowed_origins = origins;
self
}
pub fn disable_host_validation(mut self) -> Self {
self.validate_host = false;
self
}
pub fn allowed_hosts(mut self, hosts: Vec<String>) -> Self {
self.allowed_hosts = hosts;
self
}
pub fn session_config(mut self, config: SessionConfig) -> Self {
self.session_config = config;
self
}
pub fn session_ttl(mut self, ttl: Duration) -> Self {
self.session_config.ttl = ttl;
self
}
pub fn max_sessions(mut self, max: usize) -> Self {
self.session_config.max_sessions = Some(max);
self
}
pub fn session_store(mut self, store: Arc<dyn crate::session_store::SessionStore>) -> Self {
self.session_store = store;
self
}
pub fn event_store(mut self, store: Arc<dyn crate::event_store::EventStore>) -> Self {
self.event_store = store;
self
}
pub fn auto_reinitialize_sessions(mut self, enabled: bool) -> Self {
self.auto_reinit_sessions = enabled;
self
}
#[cfg(feature = "oauth")]
pub fn oauth(mut self, metadata: crate::oauth::ProtectedResourceMetadata) -> Self {
self.oauth_config = Some(OAuthConfig { metadata });
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<McpRouter> + Send + Sync + 'static,
L::Service:
tower::Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
<L::Service as tower::Service<RouterRequest>>::Error: std::fmt::Display + Send,
<L::Service as tower::Service<RouterRequest>>::Future: Send,
{
match &mut self.service_source {
ServiceSource::Router { factory, .. } => {
*factory = Arc::new(move |router: McpRouter| {
let annotations = router.tool_annotations_map();
let wrapped = layer.layer(router);
tower::util::BoxCloneService::new(InjectAnnotations::new(
CatchError::new(wrapped),
annotations,
))
});
}
ServiceSource::Service(_) => {
panic!(
"layer() cannot be used with from_service() — \
wrap the service with middleware before passing it in"
);
}
}
self
}
fn build_state(&self) -> Arc<AppState> {
let sessions = Arc::new(SessionRegistry::new(
self.session_config.clone(),
self.sampling_enabled,
self.session_store.clone(),
self.event_store.clone(),
self.service_source.clone(),
self.auto_reinit_sessions,
));
let cleanup_sessions = sessions.clone();
let cleanup_interval = self.session_config.cleanup_interval;
tokio::spawn(async move {
loop {
tokio::time::sleep(cleanup_interval).await;
cleanup_sessions.cleanup_expired().await;
}
});
Arc::new(AppState {
service_source: self.service_source.clone(),
sessions,
validate_origin: self.validate_origin,
allowed_origins: self.allowed_origins.clone(),
validate_host: self.validate_host,
allowed_hosts: self.allowed_hosts.clone(),
sampling_enabled: self.sampling_enabled,
optional_sessions: self.optional_sessions,
#[cfg(feature = "stateless")]
stateless_config: self.stateless_config.clone(),
})
}
pub fn into_router(self) -> Router {
let (router, _handle) = self.into_router_with_handle();
router
}
pub fn into_router_with_handle(mut self) -> (Router, SessionHandle) {
let external_rx = self.external_notifications.take();
let state = self.build_state();
let handle = SessionHandle {
store: state.sessions.clone(),
};
spawn_external_notification_fanout(external_rx, state.sessions.clone());
let router = Router::new()
.route("/", post(handle_post))
.route("/", get(handle_get))
.route("/", delete(handle_delete))
.route("/health", get(handle_health))
.with_state(state);
#[cfg(feature = "oauth")]
let router = self.add_oauth_route(router, "");
(router, handle)
}
pub fn into_router_at(self, path: &str) -> Router {
let (router, _handle) = self.into_router_at_with_handle(path);
router
}
pub fn into_router_at_with_handle(mut self, path: &str) -> (Router, SessionHandle) {
let external_rx = self.external_notifications.take();
let state = self.build_state();
let handle = SessionHandle {
store: state.sessions.clone(),
};
spawn_external_notification_fanout(external_rx, state.sessions.clone());
let mcp_router = Router::new()
.route("/", post(handle_post))
.route("/", get(handle_get))
.route("/", delete(handle_delete))
.route("/health", get(handle_health))
.with_state(state);
let router = Router::new().nest(path, mcp_router);
#[cfg(feature = "oauth")]
let router = self.add_oauth_route(router, path);
(router, handle)
}
pub async fn serve(self, addr: &str) -> Result<()> {
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(|e| Error::Transport(format!("Failed to bind to {}: {}", addr, e)))?;
tracing::info!("MCP HTTP transport listening on {}", addr);
let router = self.into_router();
axum::serve(listener, router)
.await
.map_err(|e| Error::Transport(format!("Server error: {}", e)))?;
Ok(())
}
#[cfg(feature = "oauth")]
fn add_oauth_route(&self, router: Router, base_path: &str) -> Router {
if let Some(ref config) = self.oauth_config {
let metadata = config.metadata.clone();
let well_known_path = if base_path.is_empty() {
crate::oauth::ProtectedResourceMetadata::well_known_path().to_string()
} else {
format!(
"{}{}",
base_path.trim_end_matches('/'),
crate::oauth::ProtectedResourceMetadata::well_known_path()
)
};
router.route(
&well_known_path,
get(move || {
let m = metadata.clone();
async move { axum::Json(m) }
}),
)
} else {
router
}
}
}
/// Forwards notifications from an externally supplied channel to every live session.
///
/// Does nothing when `rx` is `None`. Notifications that fail to serialize are
/// skipped. The spawned task exits once the channel's sending side is closed.
fn spawn_external_notification_fanout(
    rx: Option<NotificationReceiver>,
    sessions: Arc<SessionRegistry>,
) {
    let Some(mut rx) = rx else {
        return;
    };
    tokio::spawn(async move {
        while let Some(notification) = rx.recv().await {
            // Serialize once, then share the JSON text with all sessions.
            if let Some(json) = crate::transport::stdio::serialize_notification(&notification) {
                sessions.broadcast_to_all(&json).await;
            }
        }
        tracing::debug!("External notification channel closed; fan-out task exiting");
    });
}
/// `true` if `origin` is an `http://` or `https://` origin whose host is localhost.
///
/// The scheme is matched case-insensitively (RFC 3986 §3.1). See
/// `is_localhost_host` for the host rules.
fn is_localhost_origin(origin: &str) -> bool {
    strip_scheme(origin, "http://")
        .or_else(|| strip_scheme(origin, "https://"))
        .is_some_and(is_localhost_host)
}

/// Strips `scheme` from the front of `origin`, ignoring ASCII case.
fn strip_scheme<'a>(origin: &'a str, scheme: &str) -> Option<&'a str> {
    let head = origin.get(..scheme.len())?;
    let rest = origin.get(scheme.len()..)?;
    head.eq_ignore_ascii_case(scheme).then_some(rest)
}

/// `true` if `host` (optionally with a `:port`) names the local machine.
///
/// Accepts `localhost` (any ASCII case, since host names are case-insensitive
/// per RFC 3986 §3.2.2), `127.0.0.1`, and the bracketed IPv6 loopback `[::1]`.
fn is_localhost_host(host: &str) -> bool {
    let host_only = if host.starts_with('[') {
        host.split(']')
            .next()
            .unwrap_or(host)
            .trim_start_matches('[')
    } else {
        host.split(':').next().unwrap_or(host)
    };
    host_only.eq_ignore_ascii_case("localhost") || matches!(host_only, "127.0.0.1" | "::1")
}
/// Returns the request's target host.
///
/// Uses the `Host` header when it is present and valid visible ASCII;
/// otherwise falls back to the URI authority (as supplied by HTTP/2
/// `:authority`, for example).
fn effective_host<'a>(headers: &'a HeaderMap, uri: &'a axum::http::Uri) -> Option<&'a str> {
    headers
        .get(header::HOST)
        .and_then(|value| value.to_str().ok())
        .or_else(|| uri.authority().map(|authority| authority.as_str()))
}
/// Host-header check (DNS-rebinding defence).
///
/// Returns `Some(400 response)` to reject the request, or `None` to let it
/// through. With validation on:
/// - localhost hosts always pass;
/// - an empty allowlist lets every host through, including requests with no host;
/// - otherwise the host must appear exactly (including any port) in the allowlist.
fn validate_host(headers: &HeaderMap, uri: &axum::http::Uri, state: &AppState) -> Option<Response> {
    if !state.validate_host {
        return None;
    }
    let allowlist_empty = state.allowed_hosts.is_empty();
    match effective_host(headers, uri) {
        None if allowlist_empty => None,
        None => {
            tracing::warn!("Rejecting request: missing Host header and no :authority fallback");
            Some((StatusCode::BAD_REQUEST, "Missing Host header").into_response())
        }
        Some(host)
            if is_localhost_host(host)
                || allowlist_empty
                || state.allowed_hosts.iter().any(|allowed| allowed == host) =>
        {
            None
        }
        Some(host) => {
            tracing::warn!(host = %host, "Rejecting request: Host not in allowlist");
            Some((StatusCode::BAD_REQUEST, "Host not allowed").into_response())
        }
    }
}
/// Origin-header check.
///
/// Returns `Some(403 response)` to reject the request, or `None` to let it
/// through. Requests without an `Origin` header, and requests from localhost
/// origins, always pass. Any other origin must be in the allowlist (`"*"`
/// allows all); with an empty allowlist every cross-origin request is rejected.
fn validate_origin(headers: &HeaderMap, state: &AppState) -> Option<Response> {
    if !state.validate_origin {
        return None;
    }
    let origin = headers.get(header::ORIGIN)?;
    // A non-visible-ASCII header value is treated as an empty origin.
    let origin_str = origin.to_str().unwrap_or("");
    if is_localhost_origin(origin_str) {
        return None;
    }
    if state.allowed_origins.is_empty() {
        tracing::warn!(
            origin = %origin_str,
            "Rejecting request: cross-origin not allowed (no allowlist configured)"
        );
        return Some((StatusCode::FORBIDDEN, "Cross-origin requests not allowed").into_response());
    }
    let permitted = state
        .allowed_origins
        .iter()
        .any(|allowed| allowed == "*" || allowed == origin_str);
    if permitted {
        None
    } else {
        tracing::warn!(origin = %origin_str, "Rejecting request: Origin not in allowlist");
        Some((StatusCode::FORBIDDEN, "Origin not allowed").into_response())
    }
}
/// Returns header `name` as `&str`; `None` if it is absent or not visible ASCII.
fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
    headers.get(name)?.to_str().ok()
}

/// The `mcp-session-id` header value, if present and valid.
fn get_session_id(headers: &HeaderMap) -> Option<String> {
    header_str(headers, MCP_SESSION_ID_HEADER).map(str::to_owned)
}

/// The `mcp-protocol-version` header value, if present and valid.
fn get_protocol_version(headers: &HeaderMap) -> Option<String> {
    header_str(headers, MCP_PROTOCOL_VERSION_HEADER).map(str::to_owned)
}

/// The `last-event-id` header parsed as a `u64`; `None` if absent or not numeric.
fn get_last_event_id(headers: &HeaderMap) -> Option<u64> {
    header_str(headers, LAST_EVENT_ID_HEADER)?.parse().ok()
}
/// `true` if the JSON-RPC message has `"method": "initialize"`.
fn is_initialize_request(body: &serde_json::Value) -> bool {
    matches!(
        body.get("method").and_then(serde_json::Value::as_str),
        Some("initialize")
    )
}

/// `true` if the message looks like a JSON-RPC response: it has no `method`
/// and has a `result` or `error` member.
fn is_response(parsed: &serde_json::Value) -> bool {
    let has_outcome = ["result", "error"]
        .iter()
        .any(|key| parsed.get(*key).is_some());
    has_outcome && parsed.get("method").is_none()
}

/// Extracts the message `id` when it is an `i64`-representable number or a
/// string; any other `id` shape yields `None`.
fn extract_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
    let id = parsed.get("id")?;
    match id.as_i64() {
        Some(n) => Some(RequestId::Number(n)),
        None => id.as_str().map(|s| RequestId::String(s.to_owned())),
    }
}
async fn handle_post(
State(state): State<Arc<AppState>>,
request: axum::extract::Request,
) -> Response {
let (parts, body_bytes) = request.into_parts();
let headers = parts.headers;
let uri = parts.uri.clone();
if let Some(resp) = validate_host(&headers, &uri, &state) {
return resp;
}
if let Some(resp) = validate_origin(&headers, &state) {
return resp;
}
let body = match axum::body::to_bytes(body_bytes, usize::MAX).await {
Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
Ok(s) => s,
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::parse_error(format!("Invalid UTF-8: {}", e)),
);
}
},
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::parse_error(format!("Failed to read body: {}", e)),
);
}
};
#[cfg(feature = "oauth")]
let http_extensions = parts.extensions;
#[cfg(not(feature = "oauth"))]
let _ = parts.extensions;
let parsed: serde_json::Value = match serde_json::from_str(&body) {
Ok(v) => v,
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::parse_error(format!("Invalid JSON: {}", e)),
);
}
};
let is_init = is_initialize_request(&parsed);
#[cfg(feature = "stateless")]
if !is_init && state.stateless_config.is_some() && get_session_id(&headers).is_none() {
let version_from_header = get_protocol_version(&headers);
let params = parsed.get("params").unwrap_or(&parsed);
let version_from_meta = crate::stateless::StatelessRequestMeta::from_params(params)
.and_then(|m| m.protocol_version);
if let Some(version) = version_from_header.or(version_from_meta) {
if let Err(err) = crate::stateless::validate_protocol_version(&version) {
return json_rpc_error_response(None, err);
}
if parsed.get("id").is_none() || is_response(&parsed) {
return StatusCode::ACCEPTED.into_response();
}
let request: JsonRpcRequest = match serde_json::from_value(parsed) {
Ok(r) => r,
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::parse_error(format!("Invalid request: {}", e)),
);
}
};
let mut service = match &state.service_source {
ServiceSource::Router { router, factory } => {
let ephemeral = router.with_fresh_session();
ephemeral.session().mark_initialized();
JsonRpcService::new(factory(ephemeral))
}
ServiceSource::Service(mutex) => JsonRpcService::new(mutex.lock().unwrap().clone()),
};
#[cfg(feature = "oauth")]
{
if let Some(claims) = http_extensions.get::<crate::oauth::token::TokenClaims>() {
let mut ext = crate::router::Extensions::new();
ext.insert(claims.clone());
service = service.with_extensions(ext);
}
}
let response = match service.call_single(request).await {
Ok(resp) => resp,
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::internal_error(e.to_string()),
);
}
};
let mut resp = axum::Json(response).into_response();
resp.headers_mut().insert(
MCP_PROTOCOL_VERSION_HEADER,
HeaderValue::from_str(&version).unwrap(),
);
return resp;
}
}
let session = if is_init {
let create_result = match &state.service_source {
ServiceSource::Router { router, factory } => {
state
.sessions
.create(router.with_fresh_session(), factory.clone())
.await
}
ServiceSource::Service(mutex) => {
let service = mutex.lock().unwrap().clone();
state.sessions.create_from_service(service).await
}
};
match create_result {
Some(s) => s,
None => {
return (
StatusCode::SERVICE_UNAVAILABLE,
"Maximum session limit reached",
)
.into_response();
}
}
} else if let Some(session_id) = get_session_id(&headers) {
match state.sessions.get(&session_id).await {
Some(s) => s,
None => {
return json_rpc_error_response(
None,
JsonRpcError::session_not_found_with_id(&session_id),
);
}
}
} else if state.optional_sessions {
let create_result = match &state.service_source {
ServiceSource::Router { router, factory } => {
state
.sessions
.create_initialized(router.with_fresh_session(), factory.clone())
.await
}
ServiceSource::Service(mutex) => {
let service = mutex.lock().unwrap().clone();
state
.sessions
.create_initialized_from_service(service)
.await
}
};
match create_result {
Some(s) => s,
None => {
return (
StatusCode::SERVICE_UNAVAILABLE,
"Maximum session limit reached",
)
.into_response();
}
}
} else {
return json_rpc_error_response(None, JsonRpcError::session_required());
};
if !is_init
&& let Some(version) = get_protocol_version(&headers)
&& !SUPPORTED_PROTOCOL_VERSIONS.contains(&version.as_str())
{
return (
StatusCode::BAD_REQUEST,
format!("Unsupported protocol version: {}", version),
)
.into_response();
}
if is_response(&parsed) {
if let Some(id) = extract_request_id(&parsed) {
let result = if let Some(error) = parsed.get("error") {
let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
Err(Error::Internal(format!(
"Client error ({}): {}",
code, message
)))
} else if let Some(result) = parsed.get("result") {
Ok(result.clone())
} else {
Err(Error::Internal(
"Response has neither result nor error".to_string(),
))
};
if session.complete_pending_request(&id, result).await {
tracing::debug!(request_id = ?id, "Completed pending request");
} else {
tracing::warn!(request_id = ?id, "Received response for unknown request");
}
}
return StatusCode::ACCEPTED.into_response();
}
if parsed.get("id").is_none() {
if let Ok(notification) = serde_json::from_value::<JsonRpcNotification>(parsed)
&& let Ok(mcp_notification) = McpNotification::from_jsonrpc(¬ification)
{
session.handle_notification(mcp_notification);
}
return StatusCode::ACCEPTED.into_response();
}
let request: JsonRpcRequest = match serde_json::from_value(parsed) {
Ok(r) => r,
Err(e) => {
return json_rpc_error_response(
None,
JsonRpcError::parse_error(format!("Invalid request: {}", e)),
);
}
};
let mut service = JsonRpcService::new(session.make_service());
#[cfg(feature = "oauth")]
{
if let Some(claims) = http_extensions.get::<crate::oauth::token::TokenClaims>() {
let mut ext = crate::router::Extensions::new();
ext.insert(claims.clone());
service = service.with_extensions(ext);
}
}
let response = match service.call_single(request).await {
Ok(resp) => resp,
Err(e) => {
return json_rpc_error_response(None, JsonRpcError::internal_error(e.to_string()));
}
};
if is_init
&& let JsonRpcResponse::Result(ref result) = response
&& let Some(version) = result
.result
.get("protocolVersion")
.and_then(|v| v.as_str())
{
*session.protocol_version.write().await = version.to_string();
}
let mut resp = axum::Json(response).into_response();
if is_init {
resp.headers_mut().insert(
MCP_SESSION_ID_HEADER,
HeaderValue::from_str(&session.id).unwrap(),
);
}
let version = session.protocol_version.read().await;
resp.headers_mut().insert(
MCP_PROTOCOL_VERSION_HEADER,
HeaderValue::from_str(&version).unwrap(),
);
resp
}
async fn handle_get(
State(state): State<Arc<AppState>>,
request: axum::extract::Request,
) -> Response {
let (parts, _body) = request.into_parts();
let headers = parts.headers;
let uri = parts.uri.clone();
if let Some(resp) = validate_host(&headers, &uri, &state) {
return resp;
}
if let Some(resp) = validate_origin(&headers, &state) {
return resp;
}
let accept = headers
.get(header::ACCEPT)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !accept.contains("text/event-stream") {
return (
StatusCode::NOT_ACCEPTABLE,
"Accept header must include text/event-stream",
)
.into_response();
}
let session_id = match get_session_id(&headers) {
Some(id) => id,
None => {
return json_rpc_error_response(None, JsonRpcError::session_required());
}
};
let session = match state.sessions.get(&session_id).await {
Some(s) => s,
None => {
return json_rpc_error_response(
None,
JsonRpcError::session_not_found_with_id(&session_id),
);
}
};
let last_event_id = get_last_event_id(&headers);
if state.sampling_enabled {
return handle_get_bidirectional(session, last_event_id).await;
}
let rx = session.notifications_tx.subscribe();
let session_clone = session.clone();
let replay_events: Vec<_> = if let Some(after_id) = last_event_id {
let events = session.get_events_after(after_id).await;
tracing::debug!(
after_id = after_id,
replay_count = events.len(),
"Replaying buffered events for stream resumption"
);
events
.into_iter()
.map(|e| {
Ok::<_, Infallible>(
Event::default()
.id(e.id.to_string())
.event(SSE_MESSAGE_EVENT)
.data(e.data),
)
})
.collect()
} else {
Vec::new()
};
let replay_stream = tokio_stream::iter(replay_events);
let live_stream = BroadcastStream::new(rx)
.then(move |result: std::result::Result<String, _>| {
let session = session_clone.clone();
async move {
match result {
Ok(msg) => {
let event_id = session.next_event_id();
session.buffer_event(event_id, msg.clone()).await;
Some(Ok::<_, Infallible>(
Event::default()
.id(event_id.to_string())
.event(SSE_MESSAGE_EVENT)
.data(msg),
))
}
Err(_) => None,
}
}
})
.filter_map(|x| x);
let stream = replay_stream.chain(live_stream);
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(30))
.text("ping"),
)
.into_response()
}
/// Serves a `GET` SSE stream for a session when sampling (server-to-client
/// requests) is enabled.
///
/// The stream carries three kinds of events. Each is sent with the
/// `message` SSE event type. Live events get a fresh event id and are recorded
/// with `Session::buffer_event`, so a reconnecting client can resume:
///
/// * buffered events newer than `last_event_id`, replayed first when supplied;
/// * notifications broadcast on the session's `notifications_tx`;
/// * outgoing server-to-client JSON-RPC requests from the session's request
///   receiver. Each is registered as pending before it is sent, so the
///   client's reply (delivered by POST) can be matched to it.
///
/// Only one stream can own the request receiver at a time. A stream that
/// finds it already taken forwards notifications only. When the owning stream
/// ends (client disconnected, or the notification channel closed), the
/// receiver is put back on the session. A later reconnect can then pick up
/// server-to-client requests again, instead of losing them for the rest of
/// the session.
async fn handle_get_bidirectional(session: Arc<Session>, last_event_id: Option<u64>) -> Response {
    // Claim the outgoing-request receiver for this stream (None if another
    // stream currently owns it).
    let request_rx = session.request_rx.lock().await.take();
    // Subscribe before replaying so nothing broadcast during the replay is missed.
    let mut notification_rx = session.notifications_tx.subscribe();
    let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<Event, Infallible>>(100);

    // All sending happens on a background task that runs while axum consumes
    // the SSE body. The replay must happen here, not before the response is
    // returned: awaiting `tx.send` on the bounded channel before anyone polls
    // `rx` would block forever once more than 100 events are buffered.
    tokio::spawn(async move {
        let mut request_rx = request_rx;

        'stream: {
            if let Some(after_id) = last_event_id {
                let events = session.get_events_after(after_id).await;
                tracing::debug!(
                    after_id = after_id,
                    replay_count = events.len(),
                    "Replaying buffered events for bidirectional stream resumption"
                );
                for event in events {
                    let sse_event = Event::default()
                        .id(event.id.to_string())
                        .event(SSE_MESSAGE_EVENT)
                        .data(event.data);
                    if tx.send(Ok(sse_event)).await.is_err() {
                        break 'stream;
                    }
                }
            }

            loop {
                tokio::select! {
                    // The SSE body was dropped (client went away): stop now
                    // rather than waiting for the next send to fail.
                    _ = tx.closed() => break,
                    result = notification_rx.recv() => {
                        match result {
                            Ok(msg) => {
                                let event_id = session.next_event_id();
                                session.buffer_event(event_id, msg.clone()).await;
                                let event = Event::default()
                                    .id(event_id.to_string())
                                    .event(SSE_MESSAGE_EVENT)
                                    .data(msg);
                                if tx.send(Ok(event)).await.is_err() {
                                    break;
                                }
                            }
                            Err(broadcast::error::RecvError::Closed) => break,
                            // Missed some notifications; keep streaming the rest.
                            Err(broadcast::error::RecvError::Lagged(_)) => continue,
                        }
                    }
                    // Only polled when this stream owns the request receiver.
                    Some(outgoing) = async {
                        match request_rx.as_mut() {
                            Some(req_rx) => req_rx.recv().await,
                            None => None,
                        }
                    }, if request_rx.is_some() => {
                        let request = JsonRpcRequest {
                            jsonrpc: "2.0".to_string(),
                            id: outgoing.id.clone(),
                            method: outgoing.method,
                            params: Some(outgoing.params),
                        };
                        match serde_json::to_string(&request) {
                            Ok(request_json) => {
                                tracing::debug!(output = %request_json, "Sending request to client via SSE");
                                // Register before sending so the client's reply can be matched.
                                session
                                    .add_pending_request(outgoing.id, outgoing.response_tx)
                                    .await;
                                let event_id = session.next_event_id();
                                session.buffer_event(event_id, request_json.clone()).await;
                                let event = Event::default()
                                    .id(event_id.to_string())
                                    .event(SSE_MESSAGE_EVENT)
                                    .data(request_json);
                                if tx.send(Ok(event)).await.is_err() {
                                    break;
                                }
                            }
                            Err(e) => {
                                tracing::error!(error = %e, "Failed to serialize outgoing request");
                                let _ = outgoing.response_tx.send(Err(Error::Internal(
                                    format!("Failed to serialize request: {}", e),
                                )));
                            }
                        }
                    }
                }
            }
        }

        // Hand the request receiver back so the next GET stream for this
        // session can keep delivering server-to-client requests.
        if let Some(req_rx) = request_rx {
            let mut slot = session.request_rx.lock().await;
            if slot.is_none() {
                *slot = Some(req_rx);
            }
        }
    });

    Sse::new(tokio_stream::wrappers::ReceiverStream::new(rx))
        .keep_alive(
            axum::response::sse::KeepAlive::new()
                .interval(Duration::from_secs(30))
                .text("ping"),
        )
        .into_response()
}
/// Handles `DELETE` requests that explicitly terminate a session.
///
/// After the host and origin checks, the `mcp-session-id` header is required;
/// without it a JSON-RPC "session required" error is returned. The response
/// is `200 OK` whether or not the session still existed, so repeated deletes
/// are harmless. Only the log level differs between the two cases.
async fn handle_delete(
    State(state): State<Arc<AppState>>,
    request: axum::extract::Request,
) -> Response {
    let (parts, _body) = request.into_parts();
    let headers = &parts.headers;

    if let Some(rejection) = validate_host(headers, &parts.uri, &state)
        .or_else(|| validate_origin(headers, &state))
    {
        return rejection;
    }

    let Some(session_id) = get_session_id(headers) else {
        return json_rpc_error_response(None, JsonRpcError::session_required());
    };

    if state.sessions.remove(&session_id).await {
        tracing::info!(session_id = %session_id, "Session terminated");
    } else {
        tracing::debug!(session_id = %session_id, "Session already removed or never existed");
    }
    StatusCode::OK.into_response()
}
async fn handle_health() -> Response {
StatusCode::OK.into_response()
}
fn json_rpc_error_response(
id: Option<crate::protocol::RequestId>,
error: JsonRpcError,
) -> Response {
let response = JsonRpcResponse::error(id, error);
axum::Json(response).into_response()
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::Request;
use tower::ServiceExt;
fn create_test_router() -> McpRouter {
McpRouter::new().server_info("test-server", "1.0.0")
}
#[tokio::test]
async fn test_initialize_creates_session() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key(MCP_SESSION_ID_HEADER));
assert_eq!(
response
.headers()
.get(MCP_PROTOCOL_VERSION_HEADER)
.and_then(|v| v.to_str().ok()),
Some("2025-11-25")
);
}
#[tokio::test]
async fn test_protocol_version_header_on_subsequent_requests() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let init_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let init_response = app.clone().oneshot(init_request).await.unwrap();
let session_id = init_response
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
assert_eq!(
init_response
.headers()
.get(MCP_PROTOCOL_VERSION_HEADER)
.and_then(|v| v.to_str().ok()),
Some("2025-03-26")
);
let initialized_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header(MCP_SESSION_ID_HEADER, &session_id)
.header(MCP_PROTOCOL_VERSION_HEADER, "2025-03-26")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
})
.to_string(),
))
.unwrap();
app.clone().oneshot(initialized_request).await.unwrap();
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header(MCP_SESSION_ID_HEADER, &session_id)
.header(MCP_PROTOCOL_VERSION_HEADER, "2025-03-26")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response
.headers()
.get(MCP_PROTOCOL_VERSION_HEADER)
.and_then(|v| v.to_str().ok()),
Some("2025-03-26")
);
}
#[tokio::test]
async fn test_request_without_session_fails() {
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.require_sessions();
let app = transport.into_router();
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json.get("error").is_some());
assert_eq!(json["error"]["code"], -32006); }
#[tokio::test]
async fn test_delete_session() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let init_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.clone().oneshot(init_request).await.unwrap();
let session_id = response
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
let delete_request = Request::builder()
.method("DELETE")
.uri("/")
.header(MCP_SESSION_ID_HEADER, &session_id)
.body(Body::empty())
.unwrap();
let response = app.clone().oneshot(delete_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, &session_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json.get("error").is_some());
assert_eq!(json["error"]["code"], -32005); }
#[tokio::test]
async fn test_custom_session_store_receives_create_and_delete() {
use crate::session_store::{MemorySessionStore, SessionStore as PublicSessionStore};
let store = Arc::new(MemorySessionStore::new());
let store_dyn: Arc<dyn PublicSessionStore> = store.clone();
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.session_store(store_dyn);
let (app, handle) = transport.into_router_with_handle();
let init_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test-client", "version": "1.0.0" }
}
})
.to_string(),
))
.unwrap();
let response = app.clone().oneshot(init_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let session_id = response
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
assert_eq!(store.len().await, 1);
let record = store.load(&session_id).await.unwrap();
assert!(record.is_some(), "expected session to be persisted");
assert_eq!(record.unwrap().id, session_id);
assert!(handle.terminate_session(&session_id).await);
assert_eq!(store.len().await, 0);
assert!(store.load(&session_id).await.unwrap().is_none());
}
#[tokio::test]
async fn test_custom_event_store_buffers_and_purges() {
use crate::event_store::{EventStore as PublicEventStore, MemoryEventStore};
let events = Arc::new(MemoryEventStore::new());
let events_dyn: Arc<dyn PublicEventStore> = events.clone();
let session = Arc::new(Session::new(
create_test_router(),
false,
identity_factory(),
events_dyn,
));
session.buffer_event(0, "first".to_string()).await;
session.buffer_event(1, "second".to_string()).await;
assert_eq!(events.total_events().await, 2);
let replayed = events.replay_after(&session.id, 0).await.unwrap();
assert_eq!(replayed.len(), 1);
assert_eq!(replayed[0].id, 1);
assert_eq!(replayed[0].data, "second");
events.purge_session(&session.id).await.unwrap();
assert_eq!(events.total_events().await, 0);
}
#[tokio::test]
async fn test_restore_from_store_serves_unknown_session_id() {
use crate::session_store::{MemorySessionStore, SessionRecord, SessionStore};
let store = Arc::new(MemorySessionStore::new());
let store_dyn: Arc<dyn SessionStore> = store.clone();
let mut seeded = SessionRecord::new(
"shared-session".to_string(),
"2025-11-25".to_string(),
Duration::from_secs(60),
);
store.create(&mut seeded).await.unwrap();
let seeded_id = seeded.id;
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.session_store(store_dyn);
let app = transport.into_router();
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, &seeded_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
json.get("result").is_some(),
"expected tools/list result, got {json}"
);
}
#[tokio::test]
async fn test_auto_reinitialize_serves_unknown_session_without_store_record() {
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.auto_reinitialize_sessions(true);
let app = transport.into_router();
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, "client-made-up-id")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
json.get("result").is_some(),
"expected tools/list result, got {json}"
);
}
#[tokio::test]
async fn test_unknown_session_without_restore_or_auto_reinit_returns_error() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, "never-seen-before")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json.get("error").is_some(), "expected error, got {json}");
assert_eq!(json["error"]["code"], -32005); }
#[tokio::test]
async fn test_session_expiration() {
let config = SessionConfig::with_ttl(Duration::from_millis(50))
.cleanup_interval(Duration::from_millis(10));
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.session_config(config);
let app = transport.into_router();
let init_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.clone().oneshot(init_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let session_id = response
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
tokio::time::sleep(Duration::from_millis(100)).await;
let list_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, &session_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(list_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json.get("error").is_some());
assert_eq!(json["error"]["code"], -32005); }
#[tokio::test]
async fn test_layer_with_identity() {
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.layer(tower::layer::util::Identity::new());
let app = transport.into_router();
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key(MCP_SESSION_ID_HEADER));
}
#[tokio::test]
async fn test_layer_with_timeout() {
use std::time::Duration;
use tower::timeout::TimeoutLayer;
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.layer(TimeoutLayer::new(Duration::from_secs(30)));
let app = transport.into_router();
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key(MCP_SESSION_ID_HEADER));
}
#[tokio::test]
async fn test_layer_middleware_error_produces_jsonrpc_error() {
use std::time::Duration;
use tower::timeout::TimeoutLayer;
let slow_tool = crate::tool::ToolBuilder::new("slow")
.description("A slow tool")
.handler(|_: serde_json::Value| async move {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(crate::CallToolResult::text("done"))
})
.build();
let router = McpRouter::new()
.server_info("test-server", "1.0.0")
.tool(slow_tool);
let transport = HttpTransport::new(router)
.disable_origin_validation()
.layer(TimeoutLayer::new(Duration::from_millis(1)));
let app = transport.into_router();
let init_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.clone().oneshot(init_request).await.unwrap();
let session_id = response
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
let tool_request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header(MCP_SESSION_ID_HEADER, &session_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": {
"name": "slow",
"arguments": {}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(tool_request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
json.get("error").is_some(),
"Expected JSON-RPC error response, got: {}",
json
);
}
#[tokio::test]
async fn test_max_sessions_limit() {
let config = SessionConfig::default().max_sessions(1);
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.session_config(config);
let app = transport.into_router();
let init_request1 = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.clone().oneshot(init_request1).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let init_request2 = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client-2",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(init_request2).await.unwrap();
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
}
#[tokio::test]
async fn test_session_event_buffering() {
let session = Session::new(
create_test_router(),
false,
identity_factory(),
Arc::new(crate::event_store::MemoryEventStore::new()),
);
session.buffer_event(0, "event0".to_string()).await;
session.buffer_event(1, "event1".to_string()).await;
session.buffer_event(2, "event2".to_string()).await;
let events = session.get_events_after(0).await;
assert_eq!(events.len(), 2);
assert_eq!(events[0].id, 1);
assert_eq!(events[0].data, "event1");
assert_eq!(events[1].id, 2);
assert_eq!(events[1].data, "event2");
let events = session.get_events_after(1).await;
assert_eq!(events.len(), 1);
assert_eq!(events[0].id, 2);
let events = session.get_events_after(2).await;
assert!(events.is_empty());
}
#[tokio::test]
async fn test_session_event_counter_increments() {
let session = Session::new(
create_test_router(),
false,
identity_factory(),
Arc::new(crate::event_store::MemoryEventStore::new()),
);
assert_eq!(session.next_event_id(), 0);
assert_eq!(session.next_event_id(), 1);
assert_eq!(session.next_event_id(), 2);
}
#[tokio::test]
async fn test_session_event_buffer_limit() {
let session = Session::new(
create_test_router(),
false,
identity_factory(),
Arc::new(crate::event_store::MemoryEventStore::new()),
);
for i in 0..10 {
session.buffer_event(i, format!("event{}", i)).await;
}
let events = session.get_events_after(0).await;
assert_eq!(events.len(), 9);
}
#[tokio::test]
async fn test_session_handle_count() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let (app, handle) = transport.into_router_with_handle();
assert_eq!(handle.session_count().await, 0);
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), 200);
assert_eq!(handle.session_count().await, 1);
}
#[tokio::test]
async fn test_session_handle_list_and_terminate() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let (app, handle) = transport.into_router_with_handle();
assert!(handle.list_sessions().await.is_empty());
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), 200);
let sessions = handle.list_sessions().await;
assert_eq!(sessions.len(), 1);
assert!(!sessions[0].id.is_empty());
let session_id = sessions[0].id.clone();
assert!(handle.terminate_session(&session_id).await);
assert_eq!(handle.session_count().await, 0);
assert!(!handle.terminate_session(&session_id).await);
}
#[tokio::test]
async fn test_request_without_session_id_rejected() {
let transport = HttpTransport::new(create_test_router())
.disable_origin_validation()
.require_sessions();
let app = transport.into_router();
let request = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK); let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json["error"].is_object());
}
#[tokio::test]
async fn test_invalid_session_id_returns_error() {
    // A request naming a session id the server never issued is answered with
    // JSON-RPC error code -32005.
    let app = HttpTransport::new(create_test_router())
        .disable_origin_validation()
        .into_router();
    let payload = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
        "params": {}
    })
    .to_string();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json")
        .header("mcp-session-id", "nonexistent-session-id")
        .body(Body::from(payload))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let reply: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    let code = reply["error"]["code"].as_i64().unwrap();
    assert_eq!(code, -32005);
}
#[tokio::test]
async fn test_notification_returns_accepted() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let init_req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.clone().oneshot(init_req).await.unwrap();
let session_id = resp
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
let notif = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("mcp-session-id", &session_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
})
.to_string(),
))
.unwrap();
let response = app.oneshot(notif).await.unwrap();
assert_eq!(response.status(), StatusCode::ACCEPTED);
}
#[tokio::test]
async fn test_invalid_json_returns_parse_error() {
    // A body that is not valid JSON yields JSON-RPC error code -32700.
    let app = HttpTransport::new(create_test_router())
        .disable_origin_validation()
        .into_router();
    let malformed = Body::from("not valid json{{{");
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json")
        .body(malformed)
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let reply: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    let code = reply["error"]["code"].as_i64().unwrap();
    assert_eq!(code, -32700);
}
#[tokio::test]
async fn test_session_config_max_sessions() {
    // With `max_sessions(1)`, the first initialize succeeds and a second initialize
    // is refused with 503 Service Unavailable.
    let app = HttpTransport::new(create_test_router())
        .disable_origin_validation()
        .session_config(SessionConfig::default().max_sessions(1))
        .into_router();

    // Builds an initialize request that differs only in its JSON-RPC id and client name.
    let initialize = |id: u64, client_name: &str| {
        Request::builder()
            .method("POST")
            .uri("/")
            .header("Content-Type", "application/json")
            .header("Accept", "application/json, text/event-stream")
            .body(Body::from(
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": id,
                    "method": "initialize",
                    "params": {
                        "protocolVersion": "2025-11-25",
                        "capabilities": {},
                        "clientInfo": { "name": client_name, "version": "1.0" }
                    }
                })
                .to_string(),
            ))
            .unwrap()
    };

    let first = app.clone().oneshot(initialize(1, "test1")).await.unwrap();
    assert_eq!(first.status(), StatusCode::OK);
    let second = app.oneshot(initialize(2, "test2")).await.unwrap();
    assert_eq!(second.status(), StatusCode::SERVICE_UNAVAILABLE);
}
#[tokio::test]
async fn test_delete_terminates_session() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let init_req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.clone().oneshot(init_req).await.unwrap();
let session_id = resp
.headers()
.get(MCP_SESSION_ID_HEADER)
.unwrap()
.to_str()
.unwrap()
.to_string();
let delete_req = Request::builder()
.method("DELETE")
.uri("/")
.header("mcp-session-id", &session_id)
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(delete_req).await.unwrap();
assert!(resp.status().is_success());
let list_req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.header("mcp-session-id", &session_id)
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(list_req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["error"]["code"].as_i64().unwrap(), -32005);
}
#[test]
fn test_is_localhost_origin_http() {
    // Plain-HTTP loopback origins (name, IPv4, bracketed IPv6), with and without a port.
    let origins = [
        "http://localhost",
        "http://localhost:3000",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "http://[::1]",
        "http://[::1]:3000",
    ];
    for origin in origins {
        assert!(is_localhost_origin(origin), "{origin} should be treated as localhost");
    }
}
#[test]
fn test_is_localhost_origin_https() {
    // The HTTPS scheme is accepted for loopback origins as well.
    for origin in ["https://localhost", "https://127.0.0.1:443"] {
        assert!(is_localhost_origin(origin), "{origin} should be treated as localhost");
    }
}
#[test]
fn test_is_not_localhost_origin() {
    // Rejected: foreign hosts, look-alike hosts that merely contain "localhost",
    // non-HTTP schemes, scheme-less strings, and the empty string.
    let rejected = [
        "http://example.com",
        "http://evil-localhost.com",
        "http://localhost.evil.com",
        "ftp://localhost",
        "localhost",
        "",
    ];
    for origin in rejected {
        assert!(!is_localhost_origin(origin), "{origin:?} should not be treated as localhost");
    }
}
#[tokio::test]
async fn test_origin_validation_rejects_cross_origin() {
let transport = HttpTransport::new(create_test_router());
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("Origin", "http://evil.com")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn test_origin_validation_allows_localhost() {
let transport = HttpTransport::new(create_test_router());
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("Origin", "http://localhost:3000")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_origin_validation_allows_configured_origin() {
let transport = HttpTransport::new(create_test_router())
.allowed_origins(vec!["https://my-app.example.com".to_string()]);
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("Origin", "https://my-app.example.com")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_origin_validation_rejects_unconfigured_origin() {
let transport = HttpTransport::new(create_test_router())
.allowed_origins(vec!["https://my-app.example.com".to_string()]);
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("Origin", "https://other-app.example.com")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn test_origin_validation_no_header_allowed() {
let transport = HttpTransport::new(create_test_router());
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_disabled_origin_validation_allows_any() {
let transport = HttpTransport::new(create_test_router()).disable_origin_validation();
let app = transport.into_router();
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("Origin", "http://evil.com")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
/// Builds the body of a minimal JSON-RPC `initialize` request: id 1, protocol
/// version `2025-11-25`, empty capabilities, client `test` / `1.0`.
fn initialize_body() -> Body {
    let params = serde_json::json!({
        "protocolVersion": "2025-11-25",
        "capabilities": {},
        "clientInfo": { "name": "test", "version": "1.0" }
    });
    let request = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": params
    });
    Body::from(request.to_string())
}
#[test]
fn test_is_localhost_host_variants() {
    // Loopback hosts (name, IPv4, bracketed IPv6), each with and without a port.
    let loopback = [
        "localhost",
        "localhost:3000",
        "127.0.0.1",
        "127.0.0.1:8080",
        "[::1]",
        "[::1]:3000",
    ];
    for host in loopback {
        assert!(is_localhost_host(host), "{host} should be treated as localhost");
    }
    // Public names (with and without a port) and a private-range IPv4 are not loopback.
    for host in ["evil.com", "api.example.com:8443", "10.0.0.1"] {
        assert!(!is_localhost_host(host), "{host} should not be treated as localhost");
    }
}
#[tokio::test]
async fn test_host_validation_allows_localhost() {
    // A loopback Host is accepted even when it is missing from the configured allowlist.
    let app = HttpTransport::new(create_test_router())
        .allowed_hosts(vec![String::from("api.example.com")])
        .into_router();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Host", "127.0.0.1:3000")
        .body(initialize_body())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::OK);
}
#[tokio::test]
async fn test_host_validation_allows_configured_host() {
    // A Host that exactly matches an allowlist entry is accepted.
    let app = HttpTransport::new(create_test_router())
        .allowed_hosts(vec![String::from("api.example.com")])
        .into_router();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Host", "api.example.com")
        .body(initialize_body())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::OK);
}
#[tokio::test]
async fn test_host_validation_rejects_unconfigured_host() {
    // With an allowlist set, a non-loopback Host that is not listed yields 400 Bad Request.
    let app = HttpTransport::new(create_test_router())
        .allowed_hosts(vec![String::from("api.example.com")])
        .into_router();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Host", "evil.com")
        .body(initialize_body())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_host_validation_no_allowlist_accepts_any_host() {
    // Without an allowlist, an arbitrary Host header is accepted.
    let app = HttpTransport::new(create_test_router()).into_router();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Host", "any.example.com")
        .body(initialize_body())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::OK);
}
#[tokio::test]
async fn test_disabled_host_validation_allows_any_with_allowlist() {
    // `disable_host_validation` overrides a configured allowlist: an unlisted Host passes.
    let app = HttpTransport::new(create_test_router())
        .disable_host_validation()
        .allowed_hosts(vec![String::from("api.example.com")])
        .into_router();
    let request = Request::builder()
        .method("POST")
        .uri("/")
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Host", "evil.com")
        .body(initialize_body())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::OK);
}
#[test]
fn test_effective_host_prefers_header() {
    // When both a Host header and a URI authority are present, the header wins.
    let uri = "http://other.example.com/path"
        .parse::<axum::http::Uri>()
        .unwrap();
    let mut headers = HeaderMap::new();
    headers.insert(header::HOST, HeaderValue::from_static("api.example.com"));
    let host = effective_host(&headers, &uri);
    assert_eq!(host, Some("api.example.com"));
}
#[test]
fn test_effective_host_falls_back_to_authority() {
    // With no Host header, the authority of an absolute URI is used instead.
    let uri = "http://api.example.com/path"
        .parse::<axum::http::Uri>()
        .unwrap();
    let host = effective_host(&HeaderMap::new(), &uri);
    assert_eq!(host, Some("api.example.com"));
}
#[test]
fn test_effective_host_returns_none_when_both_missing() {
    // A path-only URI and no Host header leave nothing to report.
    let uri = "/path".parse::<axum::http::Uri>().unwrap();
    let host = effective_host(&HeaderMap::new(), &uri);
    assert_eq!(host, None);
}
async fn init_session(app: &Router) -> String {
let req = Request::builder()
.method("POST")
.uri("/")
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(Body::from(
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test", "version": "1.0" }
}
})
.to_string(),
))
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
resp.headers()
.get(MCP_SESSION_ID_HEADER)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.expect("initialize must return a session id")
}
#[tokio::test]
async fn test_external_notification_reaches_single_session() {
    // A notification pushed into the external channel shows up, serialized as JSON,
    // on the broadcast channel of the single live session.
    let (notif_tx, notif_rx) = notification_channel(8);
    let (app, session_handle) =
        HttpTransport::with_notifications(create_test_router(), notif_rx)
            .into_router_with_handle();
    let session_id = init_session(&app).await;

    // Subscribe before sending (a broadcast receiver only sees later messages); the
    // read guard is dropped at the end of this block.
    let mut rx = {
        let guard = session_handle.store.sessions.read().await;
        guard
            .get(&session_id)
            .expect("session should be registered")
            .notifications_tx
            .subscribe()
    };

    let uri = "claude://chats/abc".to_string();
    notif_tx
        .send(crate::context::ServerNotification::ResourceUpdated { uri })
        .await
        .unwrap();

    let received = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
    let json = received
        .expect("notification should arrive within timeout")
        .expect("broadcast channel closed");
    for needle in ["notifications/resources/updated", "claude://chats/abc"] {
        assert!(json.contains(needle));
    }
}
#[tokio::test]
async fn test_external_notification_fans_out_to_all_sessions() {
    // One external notification is delivered to every live session, not just one.
    let (notif_tx, notif_rx) = notification_channel(8);
    let (app, session_handle) =
        HttpTransport::with_notifications(create_test_router(), notif_rx)
            .into_router_with_handle();
    let session_a = init_session(&app).await;
    let session_b = init_session(&app).await;
    assert_ne!(session_a, session_b);

    // One receiver per session, subscribed before the notification is sent.
    let mut receivers = {
        let guard = session_handle.store.sessions.read().await;
        [&session_a, &session_b].map(|id| guard.get(id).unwrap().notifications_tx.subscribe())
    };

    notif_tx
        .send(crate::context::ServerNotification::ResourcesListChanged)
        .await
        .unwrap();

    for rx in &mut receivers {
        let json = tokio::time::timeout(Duration::from_secs(1), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert!(json.contains("notifications/resources/list_changed"));
    }
}
#[tokio::test]
async fn test_external_notifications_builder_method() {
    // Attaching the receiver through the `external_notifications` builder also
    // delivers external notifications to sessions.
    let (notif_tx, notif_rx) = notification_channel(8);
    let transport = HttpTransport::new(create_test_router()).external_notifications(notif_rx);
    let (app, session_handle) = transport.into_router_with_handle();
    let session_id = init_session(&app).await;

    let mut rx = {
        let guard = session_handle.store.sessions.read().await;
        let session = guard.get(&session_id).unwrap();
        session.notifications_tx.subscribe()
    };

    notif_tx
        .send(crate::context::ServerNotification::ToolsListChanged)
        .await
        .unwrap();

    let received = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
    let json = received.unwrap().unwrap();
    assert!(json.contains("notifications/tools/list_changed"));
}
#[tokio::test]
async fn test_default_transport_has_no_external_fanout_task() {
    // Smoke test: a transport built without an external notification receiver still
    // creates and registers sessions normally. The absence of a fan-out task is not
    // observed directly; this only checks that the default path is unaffected.
    let (app, handle) = HttpTransport::new(create_test_router()).into_router_with_handle();
    let session_id = init_session(&app).await;
    assert!(!session_id.is_empty());
    assert_eq!(handle.session_count().await, 1);
}
}