#![cfg(all(
feature = "client",
feature = "server",
feature = "transport-streamable-http-client-reqwest",
feature = "transport-streamable-http-server",
not(feature = "local")
))]
use std::{collections::HashMap, sync::Arc};
use rmcp::{
ServiceExt,
transport::{
StreamableHttpClientTransport,
streamable_http_client::StreamableHttpClientTransportConfig,
streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService,
session::{SessionState, SessionStore, SessionStoreError, local::LocalSessionManager},
},
},
};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
mod common;
use common::calculator::Calculator;
#[derive(Default, Clone)]
struct InMemorySessionStore(Arc<RwLock<HashMap<String, SessionState>>>);
impl InMemorySessionStore {
fn new() -> Self {
Self::default()
}
async fn len(&self) -> usize {
self.0.read().await.len()
}
}
#[async_trait::async_trait]
impl SessionStore for InMemorySessionStore {
async fn load(&self, session_id: &str) -> Result<Option<SessionState>, SessionStoreError> {
Ok(self.0.read().await.get(session_id).cloned())
}
async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError> {
self.0
.write()
.await
.insert(session_id.to_owned(), state.clone());
Ok(())
}
async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError> {
self.0.write().await.remove(session_id);
Ok(())
}
}
fn make_service(
session_store: Arc<dyn SessionStore>,
ct: &CancellationToken,
) -> StreamableHttpService<Calculator, LocalSessionManager> {
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = Some(session_store);
cfg
})
}
#[tokio::test]
async fn test_session_state_persisted_to_store() -> anyhow::Result<()> {
let store = Arc::new(InMemorySessionStore::new());
let ct = CancellationToken::new();
let service = make_service(store.clone(), &ct);
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ().serve(transport).await?;
let _resources = client.list_all_resources().await?;
assert_eq!(
store.len().await,
1,
"session state should be persisted to the store after initialization"
);
let entries = store.0.read().await;
let state = entries.values().next().expect("store entry should exist");
assert_eq!(
state.initialize_params.client_info.name, "rmcp",
"stored client_info.name should match the rmcp client"
);
let _ = client.cancel().await;
ct.cancel();
handle.await?;
Ok(())
}
#[tokio::test]
async fn test_session_state_deleted_from_store_on_delete() -> anyhow::Result<()> {
let store = Arc::new(InMemorySessionStore::new());
let session_manager = Arc::new(LocalSessionManager::default());
let ct = CancellationToken::new();
let service = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager.clone(), {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = Some(store.clone());
cfg
});
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ().serve(transport).await?;
let _resources = client.list_all_resources().await?;
assert_eq!(store.len().await, 1, "store should have one entry");
let session_id = {
let sessions = session_manager.sessions.read().await;
sessions
.keys()
.next()
.cloned()
.expect("session should exist")
};
let http_client = reqwest::Client::new();
let response = http_client
.delete(format!("http://{addr}/mcp"))
.header("mcp-session-id", session_id.as_ref())
.send()
.await?;
assert_eq!(response.status(), 202);
assert_eq!(
store.len().await,
0,
"store entry should be removed after explicit DELETE"
);
let _ = client.cancel().await;
ct.cancel();
handle.await?;
Ok(())
}
fn spawn_server(
session_store: Option<Arc<dyn SessionStore>>,
session_manager: Arc<LocalSessionManager>,
ct: &CancellationToken,
) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
let svc = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager, {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = session_store;
cfg
});
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
std_listener.set_nonblocking(true).unwrap();
let addr = std_listener.local_addr().unwrap();
let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
let router = axum::Router::new().nest_service("/mcp", svc);
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
(addr, handle)
}
#[tokio::test]
async fn test_cross_instance_session_restore() -> anyhow::Result<()> {
let http = reqwest::Client::new();
{
let ct_a = CancellationToken::new();
let (addr_a, srv_a) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_a);
let init_resp = http
.post(format!("http://{addr_a}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#)
.send()
.await?;
assert_eq!(
init_resp.status(),
200,
"instance A: initialize should succeed"
);
let session_id = init_resp
.headers()
.get("mcp-session-id")
.expect("session ID header must be present")
.to_str()?
.to_owned();
ct_a.cancel();
srv_a.await?;
let ct_b = CancellationToken::new();
let (addr_b, srv_b) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_b);
let resp = http
.post(format!("http://{addr_b}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.header("mcp-session-id", &session_id)
.body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#)
.send()
.await?;
assert_eq!(
resp.status(),
reqwest::StatusCode::NOT_FOUND,
"without a session store, instance B must return 404 for an unknown session ID"
);
ct_b.cancel();
srv_b.await?;
}
{
let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new());
let ct_a = CancellationToken::new();
let sm_a = Arc::new(LocalSessionManager::default());
let (addr_a, srv_a) = spawn_server(Some(store.clone()), sm_a.clone(), &ct_a);
let init_resp = http
.post(format!("http://{addr_a}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#)
.send()
.await?;
assert_eq!(
init_resp.status(),
200,
"instance A: initialize should succeed"
);
let original_session_id = init_resp
.headers()
.get("mcp-session-id")
.expect("session ID header must be present")
.to_str()?
.to_owned();
let store_ref = store
.load(&original_session_id)
.await
.expect("store load should not error");
assert!(
store_ref.is_some(),
"store should hold the session after initialization"
);
ct_a.cancel();
srv_a.await?;
let ct_b = CancellationToken::new();
let sm_b = Arc::new(LocalSessionManager::default());
let (addr_b, srv_b) = spawn_server(Some(store.clone()), sm_b.clone(), &ct_b);
let resp = http
.post(format!("http://{addr_b}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.header("mcp-session-id", &original_session_id)
.body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#)
.send()
.await?;
assert_eq!(
resp.status(),
200,
"instance B: request must succeed after transparent restore"
);
{
let sessions = sm_b.sessions.read().await;
let restored_id = sessions
.keys()
.next()
.expect("session should exist in instance B after restore");
assert_eq!(
restored_id.as_ref(),
original_session_id.as_str(),
"restored session must keep the original session ID"
);
}
ct_b.cancel();
srv_b.await?;
}
Ok(())
}