#![cfg(all(
feature = "transport-streamable-http-client",
feature = "transport-streamable-http-client-reqwest",
feature = "transport-streamable-http-server",
not(feature = "local")
))]
use std::{collections::HashMap, sync::Arc};
use rmcp::{
ServiceError, ServiceExt,
model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId},
transport::{
StreamableHttpClientTransport,
streamable_http_client::{
StreamableHttpClient, StreamableHttpClientTransportConfig, StreamableHttpError,
},
streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
},
};
use tokio_util::sync::CancellationToken;
mod common;
use common::calculator::Calculator;
#[tokio::test]
async fn test_stale_session_id_returns_status_aware_error() -> anyhow::Result<()> {
let ct = CancellationToken::new();
let service: StreamableHttpService<Calculator, LocalSessionManager> =
StreamableHttpService::new(
|| Ok(Calculator::new()),
Default::default(),
StreamableHttpServerConfig::default()
.with_sse_keep_alive(None)
.with_cancellation_token(ct.child_token()),
);
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let uri = Arc::<str>::from(format!("http://{addr}/mcp"));
let message = ClientJsonRpcMessage::request(
ClientRequest::PingRequest(PingRequest::default()),
RequestId::Number(1),
);
let client = reqwest::Client::new();
let result = client
.post_message(
uri.clone(),
message,
Some(Arc::from("stale-session-id")),
None,
HashMap::new(),
)
.await;
let raw_response = reqwest::Client::new()
.post(uri.as_ref())
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.header("mcp-session-id", "stale-session-id")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}"#)
.send()
.await?;
assert_eq!(raw_response.status(), reqwest::StatusCode::NOT_FOUND);
match result {
Err(StreamableHttpError::SessionExpired) => {
}
other => panic!("expected SessionExpired, got: {other:?}"),
}
ct.cancel();
handle.await?;
Ok(())
}
#[tokio::test]
async fn test_transparent_reinitialization_on_session_expiry() -> anyhow::Result<()> {
let ct = CancellationToken::new();
let session_manager = Arc::new(LocalSessionManager::default());
let service = StreamableHttpService::new(
|| Ok(Calculator::new()),
session_manager.clone(),
StreamableHttpServerConfig::default()
.with_sse_keep_alive(None)
.with_cancellation_token(ct.child_token()),
);
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let server_handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp"))
.reinit_on_expired_session(true),
);
let client = ().serve(transport).await?;
let _resources = client.list_all_resources().await?;
let original_session_id = {
let sessions = session_manager.sessions.read().await;
sessions
.keys()
.next()
.cloned()
.expect("session should exist")
};
{
let mut sessions = session_manager.sessions.write().await;
sessions.clear();
}
let _resources_after = client.list_all_resources().await?;
{
let sessions = session_manager.sessions.read().await;
let new_session_id = sessions
.keys()
.next()
.expect("new session should exist after re-initialization");
assert_ne!(
new_session_id, &original_session_id,
"new session ID should differ from the original"
);
}
let _ = client.cancel().await;
ct.cancel();
server_handle.await?;
Ok(())
}
#[tokio::test]
async fn test_session_expired_error_when_reinit_disabled() -> anyhow::Result<()> {
let ct = CancellationToken::new();
let session_manager = Arc::new(LocalSessionManager::default());
let service = StreamableHttpService::new(
|| Ok(Calculator::new()),
session_manager.clone(),
StreamableHttpServerConfig::default()
.with_sse_keep_alive(None)
.with_cancellation_token(ct.child_token()),
);
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let server_handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp"))
.reinit_on_expired_session(false),
);
let client = ().serve(transport).await?;
let _resources = client.list_all_resources().await?;
{
let mut sessions = session_manager.sessions.write().await;
sessions.clear();
}
let result = client.list_all_resources().await;
match result {
Err(ServiceError::TransportSend(ref dyn_err)) => {
let err_msg = format!("{dyn_err}");
assert!(
err_msg.contains("Session expired"),
"expected 'Session expired' in error message, got: {err_msg}"
);
}
other => panic!("expected TransportSend(SessionExpired), got: {other:?}"),
}
let _ = client.cancel().await;
ct.cancel();
server_handle.await?;
Ok(())
}