use std::{
collections::HashMap,
future::Future,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use axum::{
Router,
body::Bytes,
extract::{FromRequestParts, State},
http::{HeaderValue, StatusCode, header, header::ToStrError, request::Parts},
response::{IntoResponse, Response},
routing::{delete, get, post},
};
use rand::Rng;
use thiserror::Error;
pub use super::{
ToolHandler,
session_id::{HTTP_SESSION_ID_HEADER, McpSessionId, ParseSessionIdError},
};
use crate::{McpServer, McpServerBuilder, Output, ToolRegistry, parse_line};
/// Reasons the session ID request header could not be turned into an [`McpSessionId`].
///
/// Every variant is rendered as `400 Bad Request` by this type's `IntoResponse` impl.
#[derive(Debug, Error)]
pub enum SessionIdRejection {
    /// The request carried no session ID header.
    #[error("missing session ID header `{HTTP_SESSION_ID_HEADER}`")]
    Missing,
    /// The header value could not be read as a string (`HeaderValue::to_str` failed).
    #[error("session ID header not valid UTF-8")]
    InvalidUtf8(#[source] ToStrError),
    /// The header text did not parse as an [`McpSessionId`].
    #[error("invalid session ID")]
    InvalidFormat(#[source] ParseSessionIdError),
}
impl IntoResponse for SessionIdRejection {
    /// Turns any rejection into `400 Bad Request`, using the error's `Display`
    /// text as the plain-text body.
    fn into_response(self) -> Response {
        let message = self.to_string();
        (StatusCode::BAD_REQUEST, message).into_response()
    }
}
/// Extracts a mandatory [`McpSessionId`] from the session ID request header.
///
/// Rejects with [`SessionIdRejection::Missing`] when the header is absent,
/// [`SessionIdRejection::InvalidUtf8`] when its value is not a readable string,
/// and [`SessionIdRejection::InvalidFormat`] when the text does not parse.
impl<S> FromRequestParts<S> for McpSessionId
where
    S: Send + Sync,
{
    type Rejection = SessionIdRejection;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(raw) = parts.headers.get(HTTP_SESSION_ID_HEADER) else {
            return Err(SessionIdRejection::Missing);
        };
        raw.to_str()
            .map_err(SessionIdRejection::InvalidUtf8)
            .and_then(|text| {
                text.parse::<Self>()
                    .map_err(SessionIdRejection::InvalidFormat)
            })
    }
}
/// Extractor for an optional session ID header.
///
/// Yields `None` when the header is absent. A header that is present but
/// malformed is still rejected with [`SessionIdRejection`].
#[derive(Clone, Copy, Debug)]
pub struct OptionalSessionId(pub Option<McpSessionId>);
/// Extracts the session ID header when present, delegating parsing to the
/// [`McpSessionId`] extractor so malformed values are still rejected.
impl<S> FromRequestParts<S> for OptionalSessionId
where
    S: Send + Sync,
{
    type Rejection = SessionIdRejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        if parts.headers.contains_key(HTTP_SESSION_ID_HEADER) {
            // Present-but-invalid is an error, not "no session".
            McpSessionId::from_request_parts(parts, state)
                .await
                .map(|id| Self(Some(id)))
        } else {
            Ok(Self(None))
        }
    }
}
/// Backing store for the per-session [`McpServer`] instances used by the HTTP router.
///
/// Implementations are shared across request handlers, hence `Send + Sync + 'static`.
pub trait SessionStorage<R: ToolRegistry>: Send + Sync + 'static {
    /// Error reported by fallible operations; the router formats it with `Display`
    /// into a `storage error: ...` response body.
    type Error: std::fmt::Display + Send;
    /// Stores `server` under a new session ID and resolves to that ID.
    ///
    /// When this fails, the router answers `503 Service Unavailable`.
    fn create(
        &self,
        server: McpServer<R>,
    ) -> impl Future<Output = Result<McpSessionId, Self::Error>> + Send;
    /// Runs `f` against the server of session `id` and resolves to its result.
    ///
    /// Resolves to `Ok(None)` when no such session exists; the router then answers
    /// `404 Not Found`.
    fn with_session<T: Send>(
        &self,
        id: McpSessionId,
        f: impl FnOnce(&mut McpServer<R>) -> T + Send,
    ) -> impl Future<Output = Result<Option<T>, Self::Error>> + Send;
    /// Deletes session `id`, resolving to `true` if it existed.
    ///
    /// `DELETE /` maps `true` to `204 No Content` and `false` to `404 Not Found`.
    fn remove(&self, id: McpSessionId) -> impl Future<Output = bool> + Send;
}
/// Errors produced by [`InMemoryStorage`].
#[derive(Clone, Copy, Debug, Error)]
pub enum InMemoryStorageError {
    /// The store is full and no existing session has been idle long enough
    /// (`min_eviction_age`) to be evicted.
    #[error("session storage at capacity")]
    AtCapacity,
}
/// One stored session: its server plus the time it was last touched.
struct SessionEntry<R: ToolRegistry> {
    /// Server instance that handles this session's messages.
    server: McpServer<R>,
    /// Set on creation and refreshed on every `with_session` access; used to
    /// choose which session to evict when the store is full.
    last_accessed: Instant,
}
/// Process-local [`SessionStorage`] backed by a `HashMap` behind an `RwLock`.
///
/// Sessions live only in this process's memory. When the store is full, `create`
/// evicts the least recently accessed session that has been idle for
/// `min_eviction_age`, or fails with [`InMemoryStorageError::AtCapacity`].
pub struct InMemoryStorage<R: ToolRegistry> {
    /// All live sessions, keyed by session ID.
    sessions: RwLock<HashMap<McpSessionId, SessionEntry<R>>>,
    /// Maximum number of sessions held at once.
    capacity: usize,
    /// Idle time a session must accumulate before it may be evicted to make room.
    min_eviction_age: Duration,
}
impl<R: ToolRegistry> InMemoryStorage<R> {
    /// Creates an empty store that holds at most `capacity` sessions.
    ///
    /// `min_eviction_age` is how long a session must go unaccessed before `create`
    /// may evict it to make room for a new one. With a `capacity` of zero there is
    /// never room, so every `create` fails with [`InMemoryStorageError::AtCapacity`].
    pub fn new(capacity: usize, min_eviction_age: Duration) -> Self {
        let sessions = RwLock::new(HashMap::new());
        Self {
            capacity,
            min_eviction_age,
            sessions,
        }
    }
}
impl<R: ToolRegistry + Send + Sync + 'static> SessionStorage<R> for InMemoryStorage<R> {
    type Error = InMemoryStorageError;

    /// Registers `server` under a freshly generated random session ID.
    ///
    /// When the store already holds `capacity` sessions, the least recently
    /// accessed session whose idle time is at least `min_eviction_age` is evicted
    /// to make room.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryStorageError::AtCapacity`] when the store is full and
    /// every existing session has been idle for less than `min_eviction_age`.
    async fn create(&self, server: McpServer<R>) -> Result<McpSessionId, Self::Error> {
        let id: McpSessionId = rand::rng().random();
        // Recover from poisoning. A panic while one session was being handled must
        // not make every later request on every other session panic as well.
        let mut sessions = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Sample the clock under the lock so `now` is never earlier than a
        // `last_accessed` written by a concurrent `with_session`.
        let now = Instant::now();
        if sessions.len() >= self.capacity {
            let min_age = self.min_eviction_age;
            // Compare idle time (`now - last_accessed`) instead of computing
            // `now - min_age`, which panics when the result would precede the
            // `Instant` origin (e.g. a large age shortly after boot).
            let oldest = sessions
                .iter()
                .filter(|(_, entry)| now.saturating_duration_since(entry.last_accessed) >= min_age)
                .min_by_key(|(_, entry)| entry.last_accessed)
                .map(|(candidate, _)| *candidate);
            let Some(oldest_id) = oldest else {
                return Err(InMemoryStorageError::AtCapacity);
            };
            sessions.remove(&oldest_id);
        }
        sessions.insert(
            id,
            SessionEntry {
                server,
                last_accessed: now,
            },
        );
        Ok(id)
    }

    /// Runs `f` on session `id`'s server and refreshes its last-access time.
    ///
    /// Resolves to `Ok(None)` when the session does not exist. The map-wide write
    /// lock is held while `f` runs, so calls across all sessions are serialized.
    async fn with_session<T: Send>(
        &self,
        id: McpSessionId,
        f: impl FnOnce(&mut McpServer<R>) -> T + Send,
    ) -> Result<Option<T>, Self::Error> {
        let mut sessions = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let Some(entry) = sessions.get_mut(&id) else {
            return Ok(None);
        };
        entry.last_accessed = Instant::now();
        Ok(Some(f(&mut entry.server)))
    }

    /// Removes session `id`, resolving to `true` if it was present.
    async fn remove(&self, id: McpSessionId) -> bool {
        self.sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&id)
            .is_some()
    }
}
/// Session capacity of the [`InMemoryStorage`] installed by [`McpRouter::builder`].
pub const DEFAULT_CAPACITY: usize = 10_000;
/// Minimum idle time (120 seconds) before the default storage may evict a session.
pub const DEFAULT_MIN_EVICTION_AGE: Duration = Duration::from_secs(120);
/// Shared state handed to every route handler.
struct AppState<R: ToolRegistry, H: ToolHandler<R>, S: SessionStorage<R>> {
    /// Template used to build a fresh `McpServer` for each new session.
    builder: Arc<McpServerBuilder<R>>,
    /// Where sessions live between requests.
    storage: Arc<S>,
    /// Executes the tool calls that sessions request.
    handler: H,
}
/// Hand-written `Clone` so only `H` must be `Clone`. The builder and storage are
/// shared through `Arc`, so cloning the state bumps reference counts instead of
/// requiring `R: Clone` or `S: Clone`.
impl<R: ToolRegistry, H: ToolHandler<R> + Clone, S: SessionStorage<R>> Clone for AppState<R, H, S> {
    fn clone(&self) -> Self {
        let Self {
            builder,
            storage,
            handler,
        } = self;
        Self {
            builder: Arc::clone(builder),
            storage: Arc::clone(storage),
            handler: handler.clone(),
        }
    }
}
/// Entry point for configuring the MCP HTTP router; see [`McpRouter::builder`].
pub struct McpRouter;
impl McpRouter {
    /// Starts configuring an MCP router around `builder` (the per-session server
    /// template) and `handler` (the tool-call executor).
    ///
    /// Sessions are kept in an [`InMemoryStorage`] sized by [`DEFAULT_CAPACITY`]
    /// and [`DEFAULT_MIN_EVICTION_AGE`] unless replaced via
    /// [`McpRouterBuilder::storage`].
    pub fn builder<R, H>(
        builder: McpServerBuilder<R>,
        handler: H,
    ) -> McpRouterBuilder<R, H, InMemoryStorage<R>>
    where
        R: ToolRegistry + Send + Sync + 'static,
        H: ToolHandler<R> + Clone + Send + Sync + 'static,
    {
        let default_storage = InMemoryStorage::new(DEFAULT_CAPACITY, DEFAULT_MIN_EVICTION_AGE);
        McpRouterBuilder {
            builder,
            handler,
            storage: default_storage,
        }
    }
}
/// Configures an MCP router before building it; created by [`McpRouter::builder`].
pub struct McpRouterBuilder<R: ToolRegistry, H, S> {
    /// Template for each session's `McpServer`.
    builder: McpServerBuilder<R>,
    /// Tool-call executor shared by all sessions.
    handler: H,
    /// Session store; [`McpRouter::builder`] starts with an [`InMemoryStorage`].
    storage: S,
}
impl<R, H, S> McpRouterBuilder<R, H, S>
where
    R: ToolRegistry + Send + Sync + 'static,
    H: ToolHandler<R> + Clone + Send + Sync + 'static,
    S: SessionStorage<R>,
{
    /// Swaps in a different session store, keeping the server builder and handler.
    pub fn storage<S2: SessionStorage<R>>(self, storage: S2) -> McpRouterBuilder<R, H, S2> {
        let Self {
            builder, handler, ..
        } = self;
        McpRouterBuilder {
            builder,
            handler,
            storage,
        }
    }

    /// Produces the axum [`Router`] serving the MCP endpoint at `/`.
    ///
    /// `POST /` carries JSON-RPC messages and `DELETE /` ends a session. `GET /`
    /// is answered by `handle_get`.
    pub fn build(self) -> Router {
        let Self {
            builder,
            handler,
            storage,
        } = self;
        let state = AppState {
            builder: Arc::new(builder),
            storage: Arc::new(storage),
            handler,
        };
        // One method router for all three verbs on the same path.
        let endpoint = post(handle_post::<R, H, S>)
            .get(handle_get)
            .delete(handle_delete::<R, H, S>);
        Router::new().route("/", endpoint).with_state(state)
    }
}
/// Builds an MCP router that uses the default in-memory session storage.
///
/// Equivalent to `McpRouter::builder(builder, handler).build()`.
pub fn mcp_router<R, H>(builder: McpServerBuilder<R>, handler: H) -> Router
where
    R: ToolRegistry + Send + Sync + 'static,
    H: ToolHandler<R> + Clone + Send + Sync + 'static,
{
    let configured = McpRouter::builder(builder, handler);
    configured.build()
}
/// Handles `POST /`: decodes one JSON-RPC message and dispatches it to a session.
///
/// Without a session ID header, a new session is created from the shared server
/// builder before dispatch. With a header, the named session is used.
///
/// Answers `400` for a non-UTF-8 body or unparsable JSON-RPC, and `503` when a new
/// session cannot be stored. Otherwise it returns whatever `handle_session` produces.
async fn handle_post<R, H, S>(
    State(state): State<AppState<R, H, S>>,
    OptionalSessionId(session_id): OptionalSessionId,
    body: Bytes,
) -> Response
where
    R: ToolRegistry + Send + Sync + 'static,
    H: ToolHandler<R> + Clone + Send + Sync + 'static,
    S: SessionStorage<R>,
{
    let Ok(text) = std::str::from_utf8(&body) else {
        return (StatusCode::BAD_REQUEST, "invalid UTF-8").into_response();
    };
    let msg = match parse_line(text) {
        Ok(parsed) => parsed,
        Err(err) => {
            let detail = format!("invalid JSON-RPC: {err}");
            return (StatusCode::BAD_REQUEST, detail).into_response();
        }
    };
    let session_id = if let Some(existing) = session_id {
        existing
    } else {
        // NOTE(review): any message without a session header creates a session,
        // not only `initialize`. Confirm this is intended, because each creation
        // may evict an idle session when storage is full.
        let fresh_server = state.builder.build();
        match state.storage.create(fresh_server).await {
            Ok(created) => created,
            Err(err) => {
                let detail = format!("storage error: {err}");
                return (StatusCode::SERVICE_UNAVAILABLE, detail).into_response();
            }
        }
    };
    handle_session(&state, session_id, msg).await
}
async fn handle_get() -> Response {
StatusCode::METHOD_NOT_ALLOWED.into_response()
}
/// Feeds `msg` to session `session_id` and turns the server's output into a response.
///
/// - Storage failure: `500`.
/// - Unknown session: `404`.
/// - Nothing to send back: `202 Accepted`.
/// - A message to send, or a tool call: the JSON reply via `json_response`. A tool
///   call first runs the handler and then wraps its result.
/// - Protocol error: the session is removed and the client gets `400`.
async fn handle_session<R, H, S>(
    state: &AppState<R, H, S>,
    session_id: McpSessionId,
    msg: rust_mcp_schema::JsonrpcMessage,
) -> Response
where
    R: ToolRegistry + Send + Sync + 'static,
    H: ToolHandler<R> + Clone + Send + Sync + 'static,
    S: SessionStorage<R>,
{
    let lookup = state
        .storage
        .with_session(session_id, |server| server.handle(msg))
        .await;
    let output = match lookup {
        Err(err) => {
            let detail = format!("storage error: {err}");
            return (StatusCode::INTERNAL_SERVER_ERROR, detail).into_response();
        }
        Ok(None) => return StatusCode::NOT_FOUND.into_response(),
        Ok(Some(produced)) => produced,
    };
    match output {
        Output::None => StatusCode::ACCEPTED.into_response(),
        Output::Send(reply) => json_response(&reply, session_id),
        Output::ToolCall { tool, responder } => {
            let result = state.handler.handle(Some(session_id), tool).await;
            json_response(&responder.respond(result), session_id)
        }
        Output::ProtocolError(e) => {
            // A protocol violation ends the session before the error is reported.
            state.storage.remove(session_id).await;
            (StatusCode::BAD_REQUEST, format!("protocol error: {e}")).into_response()
        }
    }
}
/// Serializes `msg` as a `200 OK` JSON response carrying the session ID header.
///
/// Answers `500` with a plain-text body if serialization fails.
fn json_response(msg: &crate::OutgoingMessage, session_id: McpSessionId) -> Response {
    match serde_json::to_vec(msg.as_inner()) {
        Err(err) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("serialization error: {err}"),
        )
            .into_response(),
        Ok(payload) => {
            let mut response = (
                StatusCode::OK,
                [(header::CONTENT_TYPE, "application/json")],
                payload,
            )
                .into_response();
            // The session ID's text form is assumed to be hex, hence always a legal header value.
            let id_value =
                HeaderValue::from_str(&session_id.to_string()).expect("hex is valid header");
            response.headers_mut().insert(HTTP_SESSION_ID_HEADER, id_value);
            response
        }
    }
}
/// Handles `DELETE /`: ends the session named by the (required) session ID header.
///
/// Answers `204 No Content` when the session existed and `404 Not Found` otherwise.
/// A missing or malformed header is rejected with `400` by the extractor.
async fn handle_delete<R, H, S>(
    State(state): State<AppState<R, H, S>>,
    session_id: McpSessionId,
) -> Response
where
    R: ToolRegistry + Send + Sync + 'static,
    H: ToolHandler<R> + Clone + Send + Sync + 'static,
    S: SessionStorage<R>,
{
    let existed = state.storage.remove(session_id).await;
    let status = if existed {
        StatusCode::NO_CONTENT
    } else {
        StatusCode::NOT_FOUND
    };
    status.into_response()
}
/// End-to-end tests for the axum router, plus unit tests for [`InMemoryStorage`].
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    // Provides `oneshot`, which drives a single request through the router.
    use tower::util::ServiceExt;

    use super::{HTTP_SESSION_ID_HEADER, mcp_router};
    use crate::{McpServer, McpServerBuilder, NoTools};

    /// Server builder named `test`, version `1.0`, over the `NoTools` registry.
    fn test_builder() -> McpServerBuilder<NoTools> {
        let mut builder = McpServer::builder();
        builder.name("test").version("1.0");
        builder
    }

    /// Tool handler for `NoTools`; these tests never trigger a tool call.
    fn test_handler(_: NoTools) -> Result<String, std::convert::Infallible> {
        unreachable!("no tools")
    }

    // A sessionless `initialize` POST succeeds and returns a non-empty session ID header.
    #[tokio::test]
    async fn initialize_creates_session() {
        let router = mcp_router(test_builder(), |_, t| async { test_handler(t) });
        let body = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#;
        let response = router
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(HTTP_SESSION_ID_HEADER));
        let session_id = response
            .headers()
            .get(HTTP_SESSION_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(!session_id.is_empty());
    }

    // Full handshake: initialize, then the `initialized` notification (202),
    // then a `ping` request (200), all reusing the returned session ID.
    #[tokio::test]
    async fn subsequent_request_requires_session() {
        let router = mcp_router(test_builder(), |_, t| async { test_handler(t) });
        let init_body = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#;
        // `oneshot` consumes the router, so clone it for every request but the last.
        let init_response = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .body(Body::from(init_body))
                    .unwrap(),
            )
            .await
            .unwrap();
        let session_id = init_response
            .headers()
            .get(HTTP_SESSION_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let initialized_body = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let response = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .header(HTTP_SESSION_ID_HEADER, &session_id)
                    .body(Body::from(initialized_body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::ACCEPTED);
        let ping_body = r#"{"jsonrpc":"2.0","id":2,"method":"ping"}"#;
        let response = router
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .header(HTTP_SESSION_ID_HEADER, &session_id)
                    .body(Body::from(ping_body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A well-formed but unknown session ID yields 404.
    #[tokio::test]
    async fn invalid_session_returns_404() {
        let router = mcp_router(test_builder(), |_, t| async { test_handler(t) });
        let body = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
        let response = router
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .header(HTTP_SESSION_ID_HEADER, "00000000000000000000000000000000")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // DELETE ends a session (204); later requests on that ID get 404.
    #[tokio::test]
    async fn delete_removes_session() {
        let router = mcp_router(test_builder(), |_, t| async { test_handler(t) });
        let init_body = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#;
        let init_response = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .body(Body::from(init_body))
                    .unwrap(),
            )
            .await
            .unwrap();
        let session_id = init_response
            .headers()
            .get(HTTP_SESSION_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let delete_response = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("DELETE")
                    .uri("/")
                    .header(HTTP_SESSION_ID_HEADER, &session_id)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(delete_response.status(), StatusCode::NO_CONTENT);
        let ping_body = r#"{"jsonrpc":"2.0","id":2,"method":"ping"}"#;
        let response = router
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .header(HTTP_SESSION_ID_HEADER, &session_id)
                    .body(Body::from(ping_body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // GET on the endpoint is rejected with 405.
    #[tokio::test]
    async fn get_returns_405() {
        let router = mcp_router(test_builder(), |_, t| async { test_handler(t) });
        let response = router
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    // Direct tests of `InMemoryStorage`, without going through HTTP.
    mod storage {
        use std::time::Duration;

        use super::*;
        use crate::io::{
            McpSessionId,
            axum::{InMemoryStorage, InMemoryStorageError, SessionStorage},
        };

        // A created session is reachable through `with_session`.
        #[tokio::test]
        async fn create_and_access_session() {
            let storage = InMemoryStorage::new(10, Duration::from_secs(0));
            let server = test_builder().build();
            let id = storage.create(server).await.unwrap();
            let result = storage
                .with_session(id, |_server| "accessed")
                .await
                .unwrap();
            assert_eq!(result, Some("accessed"));
        }

        // An ID that was never created resolves to `Ok(None)`.
        #[tokio::test]
        async fn missing_session_returns_none() {
            let storage: InMemoryStorage<NoTools> =
                InMemoryStorage::new(10, Duration::from_secs(0));
            let fake_id = McpSessionId::from_raw(12345);
            let result = storage
                .with_session(fake_id, |_server| "accessed")
                .await
                .unwrap();
            assert_eq!(result, None);
        }

        // `remove` reports `true` once, then `false` for the already-removed ID.
        #[tokio::test]
        async fn remove_session() {
            let storage = InMemoryStorage::new(10, Duration::from_secs(0));
            let server = test_builder().build();
            let id = storage.create(server).await.unwrap();
            assert!(storage.remove(id).await);
            assert!(!storage.remove(id).await);
        }

        // With zero minimum age, a third session evicts the least recently accessed one.
        // NOTE(review): relies on the three creations getting distinct `Instant`s.
        #[tokio::test]
        async fn evicts_oldest_when_at_capacity() {
            let storage = InMemoryStorage::new(2, Duration::from_secs(0));
            let id1 = storage.create(test_builder().build()).await.unwrap();
            let id2 = storage.create(test_builder().build()).await.unwrap();
            let id3 = storage.create(test_builder().build()).await.unwrap();
            let r1 = storage.with_session(id1, |_| ()).await.unwrap();
            let r2 = storage.with_session(id2, |_| ()).await.unwrap();
            let r3 = storage.with_session(id3, |_| ()).await.unwrap();
            assert!(r1.is_none(), "oldest session should be evicted");
            assert!(r2.is_some());
            assert!(r3.is_some());
        }

        // When every stored session is younger than the minimum age, `create` fails.
        #[tokio::test]
        async fn at_capacity_when_sessions_too_young() {
            let storage = InMemoryStorage::new(2, Duration::from_secs(60));
            storage.create(test_builder().build()).await.unwrap();
            storage.create(test_builder().build()).await.unwrap();
            let result = storage.create(test_builder().build()).await;
            assert!(matches!(result, Err(InMemoryStorageError::AtCapacity)));
        }
    }
}