#[cfg(feature = "streamable-http")]
mod spec_compliance_tests {
use pmcp::server::streamable_http_server::{StreamableHttpServer, StreamableHttpServerConfig};
use pmcp::server::Server;
use pmcp::shared::streamable_http::{StreamableHttpTransport, StreamableHttpTransportConfig};
use pmcp::shared::{Transport, TransportMessage};
use pmcp::types::{
ClientCapabilities, ClientRequest, Implementation, InitializeRequest, Request,
};
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use tokio::sync::Mutex;
use url::Url;
/// Result alias returned by every async test in this module; the error side is
/// a boxed, thread-safe `std::error::Error` trait object.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

/// Converts a `pmcp::Error` into the boxed error type used by [`Result`].
///
/// Written as a named function so it can be passed directly to `map_err`.
fn box_err(e: pmcp::Error) -> Box<dyn std::error::Error + Send + Sync> {
    e.into()
}
/// Baseline: a JSON POST with no `Accept` header, or with `Accept: text/html`,
/// is rejected with 406. For the missing-header case, the body is also checked
/// to be a JSON-RPC 2.0 error whose message mentions `Accept`.
#[tokio::test]
async fn test_baseline_accept_header_validation() -> Result<()> {
    let built = Server::builder()
        .name("test-server")
        .version("1.0.0")
        .build()
        .map_err(box_err)?;
    let shared_server = Arc::new(Mutex::new(built));
    let bind_to = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    // Kept bound for the whole test so the server value outlives the task.
    let http_server = StreamableHttpServer::new(bind_to, shared_server);
    let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;

    let web_client = reqwest::Client::new();
    let endpoint = format!("http://{}", server_addr);

    // Case 1: no Accept header at all.
    let missing_accept = web_client
        .post(&endpoint)
        .header("content-type", "application/json")
        .body("{}")
        .send()
        .await
        .unwrap();
    assert_eq!(
        missing_accept.status().as_u16(),
        406,
        "Missing Accept should return 406"
    );
    let error_json: serde_json::Value = missing_accept.json().await.unwrap();
    assert_eq!(error_json["jsonrpc"], "2.0");
    let message = error_json["error"]["message"].as_str().unwrap();
    assert!(message.contains("Accept"));

    // Case 2: an Accept header naming an unrelated media type.
    let wrong_accept = web_client
        .post(&endpoint)
        .header("content-type", "application/json")
        .header("accept", "text/html")
        .body("{}")
        .send()
        .await
        .unwrap();
    assert_eq!(
        wrong_accept.status().as_u16(),
        406,
        "Wrong Accept should return 406"
    );

    server_task.abort();
    Ok(())
}
/// Baseline: a POST that accepts both JSON and SSE is still rejected with 415
/// when its `Content-Type` is `text/plain` or missing. For the `text/plain`
/// case, the body is also checked to be a JSON-RPC 2.0 error whose message
/// mentions `Content-Type`.
#[tokio::test]
async fn test_baseline_content_type_validation() -> Result<()> {
    let built = Server::builder()
        .name("test-server")
        .version("1.0.0")
        .build()
        .map_err(box_err)?;
    let shared_server = Arc::new(Mutex::new(built));
    let bind_to = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    let http_server = StreamableHttpServer::new(bind_to, shared_server);
    let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;

    const ACCEPT_BOTH: &str = "application/json, text/event-stream";
    let web_client = reqwest::Client::new();
    let endpoint = format!("http://{}", server_addr);

    // Case 1: an explicit but unsupported Content-Type.
    let plain_text = web_client
        .post(&endpoint)
        .header("accept", ACCEPT_BOTH)
        .header("content-type", "text/plain")
        .body("{}")
        .send()
        .await
        .unwrap();
    assert_eq!(
        plain_text.status().as_u16(),
        415,
        "Wrong Content-Type should return 415"
    );
    let error_json: serde_json::Value = plain_text.json().await.unwrap();
    assert_eq!(error_json["jsonrpc"], "2.0");
    let message = error_json["error"]["message"].as_str().unwrap();
    assert!(message.contains("Content-Type"));

    // Case 2: no Content-Type header at all.
    let no_content_type = web_client
        .post(&endpoint)
        .header("accept", ACCEPT_BOTH)
        .body("{}")
        .send()
        .await
        .unwrap();
    assert_eq!(
        no_content_type.status().as_u16(),
        415,
        "Missing Content-Type should return 415"
    );

    server_task.abort();
    Ok(())
}
/// Baseline probe: after an initialize round-trip through the real transport,
/// POST a `ping` that carries the session ID (when one was issued) but no
/// `mcp-protocol-version` header, then print the status the server returns.
///
/// NOTE(review): this records current behaviour only and asserts nothing
/// about the status code.
#[tokio::test]
async fn test_baseline_protocol_version_required_non_init() -> Result<()> {
    let server = Arc::new(Mutex::new(
        Server::builder()
            .name("test-server")
            .version("1.0.0")
            .build()
            .map_err(box_err)?,
    ));
    let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    let http_server = StreamableHttpServer::new(addr, server);
    let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;

    // Initialize via the transport so the server can issue a session ID.
    let client_config = StreamableHttpTransportConfig {
        url: Url::parse(&format!("http://{}", server_addr))
            .map_err(|e| pmcp::Error::Internal(e.to_string()))
            .map_err(box_err)?,
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: false,
        on_resumption_token: None,
        http_middleware_chain: None,
    };
    let mut client = StreamableHttpTransport::new(client_config);
    let init_message = TransportMessage::Request {
        id: 1i64.into(),
        request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
            Implementation::new("test-client", "1.0.0"),
            ClientCapabilities::default(),
        )))),
    };
    client.send(init_message).await.map_err(box_err)?;
    let _response = client.receive().await.map_err(box_err)?;
    let session_id = client.session_id();

    // The body must be a well-formed JSON-RPC 2.0 request (`jsonrpc`, `id`,
    // top-level `method`). Otherwise the observed status could reflect a
    // malformed envelope rather than the missing protocol-version header.
    let ping_body = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 2,
        "method": "ping"
    });
    let reqwest_client = reqwest::Client::new();
    let url = format!("http://{}", server_addr);
    // Deliberately no `mcp-protocol-version` header.
    let mut request = reqwest_client
        .post(&url)
        .header("accept", "application/json, text/event-stream")
        .header("content-type", "application/json");
    if let Some(sid) = &session_id {
        request = request.header("mcp-session-id", sid);
    }
    let response = request.body(ping_body.to_string()).send().await.unwrap();
    let status = response.status().as_u16();
    println!(
        "Response status without protocol version header: {}",
        status
    );
    server_task.abort();
    Ok(())
}
/// Baseline probe: initialize, then POST a `ping` without the
/// `mcp-protocol-version` header and print the full status line.
///
/// NOTE(review): this is a near-duplicate of
/// `test_baseline_protocol_version_required_non_init`. It asserts nothing and
/// only records the server's current behaviour.
#[tokio::test]
async fn test_baseline_protocol_version_requirement() -> Result<()> {
    let server = Arc::new(Mutex::new(
        Server::builder()
            .name("test-server")
            .version("1.0.0")
            .build()
            .map_err(box_err)?,
    ));
    let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    let http_server = StreamableHttpServer::new(addr, server);
    let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;

    // Initialize via the transport so the server can issue a session ID.
    let client_config = StreamableHttpTransportConfig {
        url: Url::parse(&format!("http://{}", server_addr))
            .map_err(|e| pmcp::Error::Internal(e.to_string()))
            .map_err(box_err)?,
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: false,
        on_resumption_token: None,
        http_middleware_chain: None,
    };
    let mut client = StreamableHttpTransport::new(client_config);
    let init_message = TransportMessage::Request {
        id: 1i64.into(),
        request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
            Implementation::new("test-client", "1.0.0"),
            ClientCapabilities::default(),
        )))),
    };
    client.send(init_message).await.map_err(box_err)?;
    let _response = client.receive().await.map_err(box_err)?;
    let session_id = client.session_id();

    // The body must be a well-formed JSON-RPC 2.0 request (`jsonrpc`, `id`,
    // top-level `method`). Otherwise the observed status could reflect a
    // malformed envelope rather than the missing protocol-version header.
    let ping_body = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 2,
        "method": "ping"
    });
    let reqwest_client = reqwest::Client::new();
    let url = format!("http://{}", server_addr);
    // Deliberately no `mcp-protocol-version` header.
    let mut request = reqwest_client
        .post(&url)
        .header("accept", "application/json, text/event-stream")
        .header("content-type", "application/json");
    if let Some(sid) = &session_id {
        request = request.header("mcp-session-id", sid);
    }
    let response = request.body(ping_body.to_string()).send().await.unwrap();
    println!(
        "Response status without protocol version: {}",
        response.status()
    );
    server_task.abort();
    Ok(())
}
/// Baseline: after initializing, a POST whose body is only a JSON-RPC
/// notification (`notifications/initialized`) must be answered with
/// 202 Accepted.
///
/// The request carries the negotiated protocol version; if the transport
/// reports none, it falls back to `pmcp::LATEST_PROTOCOL_VERSION`. It also
/// carries the session ID when one was issued.
#[tokio::test]
async fn test_baseline_notifications_only_returns_202() -> Result<()> {
    let server = Arc::new(Mutex::new(
        Server::builder()
            .name("test-server")
            .version("1.0.0")
            .build()
            .map_err(box_err)?,
    ));
    let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    let http_server = StreamableHttpServer::new(addr, server);
    let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;

    let client_config = StreamableHttpTransportConfig {
        url: Url::parse(&format!("http://{}", server_addr))
            .map_err(|e| pmcp::Error::Internal(e.to_string()))
            .map_err(box_err)?,
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: false,
        on_resumption_token: None,
        http_middleware_chain: None,
    };
    let mut client = StreamableHttpTransport::new(client_config);
    let init_message = TransportMessage::Request {
        id: 1i64.into(),
        request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
            Implementation::new("test-client", "1.0.0"),
            ClientCapabilities::default(),
        )))),
    };
    client.send(init_message).await.map_err(box_err)?;
    let _response = client.receive().await.map_err(box_err)?;
    let protocol_version = client
        .protocol_version()
        .unwrap_or_else(|| pmcp::LATEST_PROTOCOL_VERSION.to_string());

    // Sent as raw JSON over reqwest (not through the transport) so the exact
    // HTTP status can be inspected.
    let reqwest_client = reqwest::Client::new();
    let url = format!("http://{}", server_addr);
    let notification_body = serde_json::json!({
        "jsonrpc": "2.0",
        "method": "notifications/initialized"
    });
    let mut request = reqwest_client
        .post(&url)
        .header("accept", "application/json, text/event-stream")
        .header("content-type", "application/json")
        .header("mcp-protocol-version", &protocol_version);
    if let Some(sid) = client.session_id() {
        request = request.header("mcp-session-id", sid);
    }
    let response = request
        .body(notification_body.to_string())
        .send()
        .await
        .unwrap();
    let status = response.status().as_u16();
    if status != 202 {
        // Best-effort body read for diagnostics only. A read error must not
        // replace the real failure (the wrong status) with an unrelated panic.
        let body = response.text().await.unwrap_or_default();
        println!("Response status: {}, body: {}", status, body);
        panic!("Notification should return 202 Accepted, got {}", status);
    }
    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_stateful_initialize_creates_session() -> Result<()> {
let server = Arc::new(Mutex::new(
Server::builder()
.name("test-server")
.version("1.0.0")
.build()
.map_err(box_err)?,
));
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let http_server = StreamableHttpServer::new(addr, server);
let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", server_addr))
.map_err(|e| pmcp::Error::Internal(e.to_string()))
.map_err(box_err)?,
extra_headers: vec![],
auth_provider: None,
session_id: None, enable_json_response: false,
on_resumption_token: None,
http_middleware_chain: None,
};
let mut client = StreamableHttpTransport::new(client_config);
let init_message = TransportMessage::Request {
id: 1i64.into(),
request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
Implementation::new("test-client", "1.0.0"),
ClientCapabilities::default(),
)))),
};
client.send(init_message).await.map_err(box_err)?;
let _response = client.receive().await.map_err(box_err)?;
assert!(
client.session_id().is_some(),
"Session ID should be created after initialization"
);
server_task.abort();
Ok(())
}
#[tokio::test]
async fn test_stateful_concurrent_sse_conflict() -> Result<()> {
let server = Arc::new(Mutex::new(
Server::builder()
.name("test-server")
.version("1.0.0")
.build()
.map_err(box_err)?,
));
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let http_server = StreamableHttpServer::new(addr, server);
let (server_addr, server_task) = http_server.start().await.map_err(box_err)?;
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", server_addr))
.map_err(|e| pmcp::Error::Internal(e.to_string()))
.map_err(box_err)?,
extra_headers: vec![],
auth_provider: None,
session_id: None,
enable_json_response: false,
on_resumption_token: None,
http_middleware_chain: None,
};
let mut client = StreamableHttpTransport::new(client_config);
let init_message = TransportMessage::Request {
id: 1i64.into(),
request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
Implementation::new("test-client", "1.0.0"),
ClientCapabilities::default(),
)))),
};
client.send(init_message).await.map_err(box_err)?;
let _response = client.receive().await.map_err(box_err)?;
let session_id = client.session_id().expect("Should have session ID");
let reqwest_client = reqwest::Client::new();
let url = format!("http://{}", server_addr);
let response1 = reqwest_client
.get(&url)
.header("accept", "text/event-stream")
.header("mcp-session-id", &session_id)
.send()
.await
.unwrap();
assert_eq!(response1.status().as_u16(), 200, "First SSE should succeed");
let response2 = reqwest_client
.get(&url)
.header("accept", "text/event-stream")
.header("mcp-session-id", &session_id)
.send()
.await
.unwrap();
assert_eq!(
response2.status().as_u16(),
409,
"Second concurrent SSE should return 409 Conflict"
);
server_task.abort();
Ok(())
}
/// Starts a streamable HTTP server on an ephemeral localhost port with
/// `session_id_generator: None` (stateless mode) and SSE-style responses
/// (`enable_json_response: false`).
///
/// Returns the bound address and the server's task handle; callers abort the
/// handle when they are done.
async fn create_stateless_server() -> Result<(SocketAddr, tokio::task::JoinHandle<()>)> {
    let built = Server::builder()
        .name("test-server")
        .version("1.0.0")
        .build()
        .map_err(box_err)?;
    let shared_server = Arc::new(Mutex::new(built));
    let listen_on = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    let stateless_config = StreamableHttpServerConfig {
        session_id_generator: None,
        enable_json_response: false,
        event_store: None,
        on_session_initialized: None,
        on_session_closed: None,
        http_middleware: None,
        allowed_origins: None,
        max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
    };
    let http_server = StreamableHttpServer::with_config(listen_on, shared_server, stateless_config);
    http_server.start().await.map_err(box_err)
}
#[tokio::test]
async fn test_stateless_no_session_id_in_response() -> Result<()> {
let (server_addr, server_task) = create_stateless_server().await?;
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", server_addr))
.map_err(|e| pmcp::Error::Internal(e.to_string()))
.map_err(box_err)?,
extra_headers: vec![],
auth_provider: None,
session_id: None,
enable_json_response: false,
on_resumption_token: None,
http_middleware_chain: None,
};
let mut client = StreamableHttpTransport::new(client_config);
let init_message = TransportMessage::Request {
id: 1i64.into(),
request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
Implementation::new("test-client", "1.0.0"),
ClientCapabilities::default(),
)))),
};
client.send(init_message).await.map_err(box_err)?;
let _response = client.receive().await.map_err(box_err)?;
assert!(
client.session_id().is_none(),
"No session ID should be created in stateless mode"
);
server_task.abort();
Ok(())
}
/// Stateless mode: the same client may send `initialize` twice (ids 1 and 2),
/// and both calls must complete the send/receive round-trip without error.
#[tokio::test]
async fn test_stateless_reinitialize_allowed() -> Result<()> {
    let (server_addr, server_task) = create_stateless_server().await?;
    let client_config = StreamableHttpTransportConfig {
        url: Url::parse(&format!("http://{}", server_addr))
            .map_err(|e| pmcp::Error::Internal(e.to_string()))
            .map_err(box_err)?,
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: false,
        on_resumption_token: None,
        http_middleware_chain: None,
    };
    let mut client = StreamableHttpTransport::new(client_config);

    // Builds an identical initialize request with the given JSON-RPC id. Each
    // message is built fresh and moved into `send`, so no clone is needed.
    let initialize = |request_id: i64| TransportMessage::Request {
        id: request_id.into(),
        request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
            Implementation::new("test-client", "1.0.0"),
            ClientCapabilities::default(),
        )))),
    };

    client.send(initialize(1)).await.map_err(box_err)?;
    let _response1 = client.receive().await.map_err(box_err)?;
    client.send(initialize(2)).await.map_err(box_err)?;
    let _response2 = client.receive().await.map_err(box_err)?;

    server_task.abort();
    Ok(())
}
#[tokio::test]
async fn test_stateless_no_session_required() -> Result<()> {
let (server_addr, server_task) = create_stateless_server().await?;
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", server_addr))
.map_err(|e| pmcp::Error::Internal(e.to_string()))
.map_err(box_err)?,
extra_headers: vec![],
auth_provider: None,
session_id: None, enable_json_response: false,
on_resumption_token: None,
http_middleware_chain: None,
};
let mut client = StreamableHttpTransport::new(client_config);
let ping_message = TransportMessage::Request {
id: 1i64.into(),
request: Request::Client(Box::new(ClientRequest::Ping)),
};
client.send(ping_message).await.map_err(box_err)?;
let _response = client.receive().await.map_err(box_err)?;
server_task.abort();
Ok(())
}
#[tokio::test]
async fn test_stateless_ignores_arbitrary_session_id() -> Result<()> {
let (server_addr, server_task) = create_stateless_server().await?;
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", server_addr))
.map_err(|e| pmcp::Error::Internal(e.to_string()))
.map_err(box_err)?,
extra_headers: vec![],
auth_provider: None,
session_id: Some("arbitrary-session-id".to_string()), enable_json_response: false,
on_resumption_token: None,
http_middleware_chain: None,
};
let mut client = StreamableHttpTransport::new(client_config);
let ping_message = TransportMessage::Request {
id: 1i64.into(),
request: Request::Client(Box::new(ClientRequest::Ping)),
};
client.send(ping_message).await.map_err(box_err)?;
let _response = client.receive().await.map_err(box_err)?;
server_task.abort();
Ok(())
}
/// Stateless mode: a GET requesting `text/event-stream` (no session ID sent)
/// must be answered with 405 Method Not Allowed.
#[tokio::test]
async fn test_stateless_sse_returns_405() -> Result<()> {
    let (bound, server_task) = create_stateless_server().await?;
    let web_client = reqwest::Client::new();

    let sse_attempt = web_client
        .get(format!("http://{}", bound))
        .header("accept", "text/event-stream")
        .send()
        .await
        .unwrap();
    assert_eq!(
        sse_attempt.status().as_u16(),
        405,
        "GET SSE should return 405 in stateless mode"
    );

    server_task.abort();
    Ok(())
}
/// Stateless mode: a DELETE sent without any session ID must be answered with
/// either 404 or 405.
#[tokio::test]
async fn test_stateless_delete_returns_404_or_405() -> Result<()> {
    let (bound, server_task) = create_stateless_server().await?;
    let web_client = reqwest::Client::new();

    let reply = web_client
        .delete(format!("http://{}", bound))
        .send()
        .await
        .unwrap();
    let code = reply.status().as_u16();
    assert!(
        matches!(code, 404 | 405),
        "DELETE without session should return 404 or 405 in stateless mode, got {}",
        code
    );

    server_task.abort();
    Ok(())
}
}