#[cfg(feature = "streamable-http")]
mod session_validation_tests {
    //! Integration tests for session and protocol-version validation in the
    //! streamable HTTP server.
    //!
    //! Each test starts its own server via `create_test_server` and drives it
    //! either through `StreamableHttpTransport` or, when exact HTTP status codes
    //! or error bodies are asserted, through a raw `reqwest` client.

    use pmcp::server::streamable_http_server::StreamableHttpServer;
    use pmcp::server::Server;
    use pmcp::shared::streamable_http::{StreamableHttpTransport, StreamableHttpTransportConfig};
    use pmcp::shared::{Transport, TransportMessage};
    use pmcp::types::{
        ClientCapabilities, ClientRequest, Implementation, InitializeRequest, Request,
    };
    use std::net::{Ipv4Addr, SocketAddr};
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use url::Url;

    /// Test-local result type. Any `Error + Send + Sync + 'static` error converts
    /// into it through `?` (std's blanket `From` impl), so no manual boxing is needed.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    /// Starts a minimal MCP server ("test-server" 1.0.0) behind a streamable HTTP endpoint.
    ///
    /// It binds to port 0 on localhost. It returns the address reported by
    /// `StreamableHttpServer::start` (the tests connect to it, so it is the address
    /// actually being served) together with the server's task handle, so the caller
    /// can abort it.
    async fn create_test_server() -> Result<(SocketAddr, tokio::task::JoinHandle<()>)> {
        let server = Server::builder()
            .name("test-server")
            .version("1.0.0")
            .build()?;
        let bind_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let http_server = StreamableHttpServer::new(bind_addr, Arc::new(Mutex::new(server)));
        Ok(http_server.start().await?)
    }

    /// Plain-HTTP base URL for the server listening on `addr`.
    fn base_url(addr: SocketAddr) -> String {
        format!("http://{addr}")
    }

    /// Creates a transport client for `addr`, optionally pre-seeded with a session id.
    ///
    /// Passing `Some(..)` lets a test impersonate a client that reuses a live,
    /// unknown, or already-deleted session. All other config fields use the same
    /// values in every test: no extra headers, no auth, SSE rather than JSON responses.
    fn new_client(
        addr: SocketAddr,
        session_id: Option<String>,
    ) -> Result<StreamableHttpTransport> {
        let config = StreamableHttpTransportConfig {
            url: Url::parse(&base_url(addr))?,
            extra_headers: vec![],
            auth_provider: None,
            session_id,
            enable_json_response: false,
            on_resumption_token: None,
            http_middleware_chain: None,
        };
        Ok(StreamableHttpTransport::new(config))
    }

    /// Builds an `initialize` request with the given JSON-RPC id.
    fn initialize_request(id: i64) -> TransportMessage {
        TransportMessage::Request {
            id: id.into(),
            request: Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
                Implementation::new("test-client", "1.0.0"),
                ClientCapabilities::default(),
            )))),
        }
    }

    /// Builds a `ping` request with the given JSON-RPC id.
    fn ping_request(id: i64) -> TransportMessage {
        TransportMessage::Request {
            id: id.into(),
            request: Request::Client(Box::new(ClientRequest::Ping)),
        }
    }

    /// Sends `message`, then reads the next incoming message and discards it.
    ///
    /// Errors from either `send` or `receive` are propagated.
    async fn exchange(
        client: &mut StreamableHttpTransport,
        message: TransportMessage,
    ) -> Result<()> {
        client.send(message).await?;
        let _reply = client.receive().await?;
        Ok(())
    }

    /// Opens a fresh client (no session id) and performs the `initialize` exchange with id 1.
    async fn initialized_client(addr: SocketAddr) -> Result<StreamableHttpTransport> {
        let mut client = new_client(addr, None)?;
        exchange(&mut client, initialize_request(1)).await?;
        Ok(client)
    }

    /// Asserts that `body` is a JSON-RPC 2.0 error with code -32700 whose message
    /// contains `fragment`.
    fn assert_jsonrpc_parse_error(body: &serde_json::Value, fragment: &str) {
        assert_eq!(body["jsonrpc"], "2.0");
        assert_eq!(body["error"]["code"], -32700);
        let message = body["error"]["message"]
            .as_str()
            .expect("error.message should be a string");
        assert!(
            message.contains(fragment),
            "error message {message:?} should mention {fragment:?}"
        );
    }

    #[tokio::test]
    async fn test_double_initialization_rejected() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        // Kept alive for the whole test so the first session stays owned by a client.
        let client1 = initialized_client(server_addr).await?;
        let session_id = client1
            .session_id()
            .expect("Session should be set after first init");

        // A second client reusing the live session must not be able to re-initialize it.
        let mut client2 = new_client(server_addr, Some(session_id))?;
        let result = client2.send(initialize_request(1)).await;
        assert!(result.is_err(), "Double initialization should fail");

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_unknown_session_id_returns_404() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = new_client(server_addr, Some("invalid-session-id".to_string()))?;

        // Only failure is observable through the transport API; the raw status code
        // is not inspected here.
        let result = client.send(ping_request(1)).await;
        assert!(result.is_err(), "Request with unknown session should fail");

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_non_init_without_session_returns_400() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = new_client(server_addr, None)?;

        // Only failure is observable through the transport API; the raw status code
        // is not inspected here.
        let result = client.send(ping_request(1)).await;
        assert!(result.is_err(), "Request without session should fail");

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_session_id_in_all_responses() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = initialized_client(server_addr).await?;
        let session_id = client
            .session_id()
            .expect("Session ID should be set after init");

        exchange(&mut client, ping_request(2)).await?;
        assert_eq!(client.session_id().as_deref(), Some(session_id.as_str()));

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_sse_endpoint_with_unknown_session() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let url = base_url(server_addr);
        let response = reqwest::Client::new()
            .get(&url)
            .header("accept", "text/event-stream")
            .header("mcp-session-id", "unknown-session")
            .send()
            .await?;
        assert_eq!(response.status().as_u16(), 404);

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_delete_unknown_session() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let url = base_url(server_addr);
        let response = reqwest::Client::new()
            .delete(&url)
            .header("mcp-session-id", "unknown-session")
            .send()
            .await?;
        assert_eq!(response.status().as_u16(), 404);

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_protocol_version_validation() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = initialized_client(server_addr).await?;

        let negotiated_version = client
            .protocol_version()
            .expect("Protocol version should be negotiated");
        assert!(
            pmcp::SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_version.as_str()),
            "Negotiated version should be supported"
        );

        exchange(&mut client, ping_request(2)).await?;
        assert_eq!(
            client.protocol_version().as_deref(),
            Some(negotiated_version.as_str()),
            "Protocol version should remain consistent"
        );

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_protocol_version_in_responses() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = initialized_client(server_addr).await?;
        let negotiated_version = client
            .protocol_version()
            .expect("Protocol version should be set");

        exchange(&mut client, ping_request(2)).await?;
        assert_eq!(
            client.protocol_version().as_deref(),
            Some(negotiated_version.as_str()),
            "Protocol version should remain consistent"
        );

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_error_response_bodies() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let http = reqwest::Client::new();
        let url = base_url(server_addr);

        // No Content-Type header -> 415 with a JSON-RPC error body and a null id.
        let response = http
            .post(&url)
            .header("accept", "application/json, text/event-stream")
            .body("{}")
            .send()
            .await?;
        assert_eq!(response.status().as_u16(), 415);
        let error_body: serde_json::Value = response.json().await?;
        assert_jsonrpc_parse_error(&error_body, "Content-Type");
        assert_eq!(error_body["id"], serde_json::Value::Null);

        // Accept header not set by the test (only Content-Type is) -> 406 with a
        // JSON-RPC error body.
        let response = http
            .post(&url)
            .header("content-type", "application/json")
            .body("{}")
            .send()
            .await?;
        assert_eq!(response.status().as_u16(), 406);
        let error_body: serde_json::Value = response.json().await?;
        assert_jsonrpc_parse_error(&error_body, "Accept");

        // Establish a real session (kept alive until the end of the test), then check
        // that requests without a session id, or with an unknown one, are still rejected.
        let _init_client = initialized_client(server_addr).await?;

        let mut no_session_client = new_client(server_addr, None)?;
        let result = no_session_client.send(ping_request(2)).await;
        assert!(result.is_err());

        let mut unknown_session_client =
            new_client(server_addr, Some("non-existent-session".to_string()))?;
        let result2 = unknown_session_client.send(ping_request(3)).await;
        assert!(result2.is_err());

        server_task.abort();
        Ok(())
    }

    #[tokio::test]
    async fn test_successful_session_lifecycle() -> Result<()> {
        let (server_addr, server_task) = create_test_server().await?;
        let mut client = initialized_client(server_addr).await?;
        let session_id = client
            .session_id()
            .expect("Session should be set after init");

        exchange(&mut client, ping_request(2)).await?;
        assert_eq!(client.session_id().as_deref(), Some(session_id.as_str()));

        // Explicitly terminate the session over HTTP.
        let url = base_url(server_addr);
        let delete_response = reqwest::Client::new()
            .delete(&url)
            .header("mcp-session-id", session_id.as_str())
            .send()
            .await?;
        assert_eq!(delete_response.status().as_u16(), 200);

        // Reusing the deleted session id must now be rejected.
        let mut stale_client = new_client(server_addr, Some(session_id))?;
        let result = stale_client.send(ping_request(3)).await;
        assert!(result.is_err(), "Request with deleted session should fail");

        server_task.abort();
        Ok(())
    }
}