use anyhow::Result;
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use solidmcp::McpServer;
use tokio_tungstenite::{connect_async, tungstenite::Message};
/// Asks the OS for a free TCP port on the loopback interface.
///
/// Binds a temporary listener to `127.0.0.1:0`, reads back the port that was
/// assigned, and releases the listener before returning the port number.
///
/// Note: the port is no longer held once this function returns, so nothing
/// here prevents another socket from taking it before the caller binds it.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be read.
async fn find_available_port() -> u16 {
    let probe = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .expect("Failed to bind to port 0");
    let assigned_addr = probe.local_addr().unwrap();
    // Release the socket explicitly so the caller can bind the same port.
    drop(probe);
    assigned_addr.port()
}
/// Checks that a WebSocket client can connect to `/mcp` and complete an
/// `initialize` request/response round trip.
///
/// The test spawns `McpServer::start` on a probed port, sends a single
/// JSON-RPC `initialize` request and asserts on the `jsonrpc`, `id`,
/// `result.capabilities` and `result.protocolVersion` fields of the reply.
///
/// The wait for the reply is bounded to five seconds (matching the ping and
/// close tests in this file) so a silent server fails the test instead of
/// hanging it.
#[tokio::test]
async fn test_websocket_connection_and_init() -> Result<()> {
    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let url = format!("ws://127.0.0.1:{port}/mcp");
    let (ws_stream, _) = connect_async(&url).await?;
    let (mut write, mut read) = ws_stream.split();

    let init_request = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-06-18",
            "capabilities": {},
            "clientInfo": {
                "name": "ws-test-client",
                "version": "1.0.0"
            }
        }
    });
    write
        .send(Message::Text(init_request.to_string().into()))
        .await?;

    // An elapsed timeout is propagated as an error rather than blocking forever.
    let reply = tokio::time::timeout(tokio::time::Duration::from_secs(5), read.next()).await?;
    match reply {
        Some(Ok(Message::Text(response_text))) => {
            // Parse straight from the received text; no intermediate `String` is needed.
            let response: Value = serde_json::from_str(response_text.as_str())?;
            assert_eq!(response["jsonrpc"], "2.0");
            assert_eq!(response["id"], 1);
            assert!(response["result"]["capabilities"].is_object());
            assert_eq!(response["result"]["protocolVersion"], "2025-06-18");
        }
        _ => panic!("Did not receive expected response"),
    }

    server_handle.abort();
    Ok(())
}
/// Sends several `tools/list` requests back to back and checks that a
/// response carrying each request id comes back.
///
/// Only the *set* of returned ids is verified: the test asserts that every
/// request id appears among the responses, not that the responses arrive in
/// the order the requests were sent.
///
/// Every read is bounded to five seconds so a missing response fails the test
/// instead of hanging it.
#[tokio::test]
async fn test_websocket_message_ordering() -> Result<()> {
    let read_timeout = tokio::time::Duration::from_secs(5);

    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let url = format!("ws://127.0.0.1:{port}/mcp");
    let (ws_stream, _) = connect_async(&url).await?;
    let (mut write, mut read) = ws_stream.split();

    let init = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-06-18"
        }
    });
    write.send(Message::Text(init.to_string().into())).await?;
    // The `initialize` reply is consumed but its content is not inspected here.
    let _ = tokio::time::timeout(read_timeout, read.next()).await?;

    // Typed as i64 so the ids compare directly with `Value::as_i64` results.
    let request_ids: [i64; 5] = [100, 200, 300, 400, 500];
    for id in &request_ids {
        let request = json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": "tools/list",
            "params": {}
        });
        write
            .send(Message::Text(request.to_string().into()))
            .await?;
    }

    let mut received_ids = Vec::with_capacity(request_ids.len());
    for _ in 0..request_ids.len() {
        let frame = tokio::time::timeout(read_timeout, read.next()).await?;
        // Non-text frames and responses without an integer id are not recorded;
        // the length assertion below then reports the shortfall.
        if let Some(Ok(Message::Text(response_text))) = frame {
            let response: Value = serde_json::from_str(response_text.as_str())?;
            if let Some(id) = response["id"].as_i64() {
                received_ids.push(id);
            }
        }
    }

    assert_eq!(
        received_ids.len(),
        request_ids.len(),
        "expected one response per request, got ids {:?}",
        received_ids
    );
    for id in &request_ids {
        assert!(
            received_ids.contains(id),
            "no response received for request id {}",
            id
        );
    }

    server_handle.abort();
    Ok(())
}
/// Sends a WebSocket ping with payload `[1, 2, 3]` and expects a pong that
/// echoes the same payload within five seconds.
///
/// Frames other than a pong are skipped while waiting. The test panics if the
/// stream ends or yields an error before a pong is seen, and returns an error
/// if the five-second timeout elapses.
#[tokio::test]
async fn test_websocket_ping_pong() -> Result<()> {
    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_task = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let endpoint = format!("ws://127.0.0.1:{port}/mcp");
    let (socket, _) = connect_async(&endpoint).await?;
    let (mut sink, mut stream) = socket.split();

    sink.send(Message::Ping(vec![1, 2, 3].into())).await?;

    let wait_for_pong = async {
        loop {
            match stream.next().await {
                Some(Ok(Message::Pong(payload))) => {
                    assert_eq!(payload, vec![1, 2, 3]);
                    break;
                }
                // Any other successfully received frame is ignored.
                Some(Ok(_)) => continue,
                // Stream closed or errored before a pong arrived.
                _ => panic!("Did not receive pong"),
            }
        }
    };
    tokio::time::timeout(tokio::time::Duration::from_secs(5), wait_for_pong).await?;

    server_task.abort();
    Ok(())
}
/// Sends a close frame after an `initialize` exchange and waits for the
/// connection to wind down.
///
/// After sending `Close`, the test drains incoming frames until it sees a
/// close frame, or until the stream ends or yields an error. It fails only if
/// neither happens within five seconds; it does not require that a close
/// frame is actually echoed back.
///
/// The handshake read is also bounded to five seconds so a silent server
/// fails the test before the close logic instead of hanging it.
#[tokio::test]
async fn test_websocket_close_handling() -> Result<()> {
    let read_timeout = tokio::time::Duration::from_secs(5);

    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let url = format!("ws://127.0.0.1:{port}/mcp");
    let (ws_stream, _) = connect_async(&url).await?;
    let (mut write, mut read) = ws_stream.split();

    let init = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-06-18"
        }
    });
    write.send(Message::Text(init.to_string().into())).await?;
    // The `initialize` reply is consumed but its content is not inspected here.
    let _ = tokio::time::timeout(read_timeout, read.next()).await?;

    write.send(Message::Close(None)).await?;
    tokio::time::timeout(read_timeout, async {
        while let Some(Ok(msg)) = read.next().await {
            if matches!(msg, Message::Close(_)) {
                return;
            }
        }
        // Falling out of the loop means the stream ended or errored, which
        // this test also accepts as the connection having closed.
    })
    .await?;

    server_handle.abort();
    Ok(())
}
/// Sends an `initialize` request carrying a 10,000-element string array and
/// checks that the server still answers with a JSON-RPC result object.
///
/// The wait for the reply is bounded to five seconds so a server that drops
/// or stalls on the large frame fails the test instead of hanging it.
#[tokio::test]
async fn test_websocket_large_messages() -> Result<()> {
    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let url = format!("ws://127.0.0.1:{port}/mcp");
    let (ws_stream, _) = connect_async(&url).await?;
    let (mut write, mut read) = ws_stream.split();

    // Payload padding: "item_0" through "item_9999".
    let large_array: Vec<String> = (0..10000).map(|i| format!("item_{i}")).collect();
    let large_request = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-06-18",
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0",
                "metadata": large_array
            }
        }
    });
    write
        .send(Message::Text(large_request.to_string().into()))
        .await?;

    // An elapsed timeout is propagated as an error rather than blocking forever.
    let reply = tokio::time::timeout(tokio::time::Duration::from_secs(5), read.next()).await?;
    match reply {
        Some(Ok(Message::Text(response_text))) => {
            let response: Value = serde_json::from_str(response_text.as_str())?;
            assert_eq!(response["jsonrpc"], "2.0");
            assert!(response["result"].is_object());
        }
        _ => panic!("Did not receive response for large message"),
    }

    server_handle.abort();
    Ok(())
}
/// Runs five WebSocket clients concurrently against one server, each doing
/// an `initialize` exchange followed by ten sequential `tools/list` requests.
///
/// Each client asserts on `result.protocolVersion` of the `initialize` reply
/// and that every `tools/list` reply has an array at `result.tools`.
///
/// A missing or non-text reply panics inside the client task, which surfaces
/// as a failure when the task is joined. Previously the assertions were
/// skipped in that case and the test could pass without checking anything.
/// Every read is bounded to five seconds so a stalled server fails the test
/// instead of hanging it.
#[tokio::test]
async fn test_concurrent_websocket_connections() -> Result<()> {
    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let mut handles = Vec::with_capacity(5);
    for client_id in 0..5 {
        let url = format!("ws://127.0.0.1:{port}/mcp");
        let handle = tokio::spawn(async move {
            let read_timeout = tokio::time::Duration::from_secs(5);

            let (ws_stream, _) = connect_async(&url).await?;
            let (mut write, mut read) = ws_stream.split();

            let init = json!({
                "jsonrpc": "2.0",
                "id": 1,
                "method": "initialize",
                "params": {
                    "protocolVersion": "2025-06-18",
                    "clientInfo": {
                        "name": format!("client-{}", client_id),
                        "version": "1.0.0"
                    }
                }
            });
            write.send(Message::Text(init.to_string().into())).await?;
            match tokio::time::timeout(read_timeout, read.next()).await? {
                Some(Ok(Message::Text(response_text))) => {
                    let response: Value = serde_json::from_str(response_text.as_str())?;
                    assert_eq!(response["result"]["protocolVersion"], "2025-06-18");
                }
                _ => panic!(
                    "client {} did not receive an initialize response",
                    client_id
                ),
            }

            // Request ids start at 2 because id 1 was used by `initialize`.
            for i in 0..10 {
                let request = json!({
                    "jsonrpc": "2.0",
                    "id": i + 2,
                    "method": "tools/list",
                    "params": {}
                });
                write
                    .send(Message::Text(request.to_string().into()))
                    .await?;
                match tokio::time::timeout(read_timeout, read.next()).await? {
                    Some(Ok(Message::Text(response_text))) => {
                        let response: Value = serde_json::from_str(response_text.as_str())?;
                        assert!(response["result"]["tools"].is_array());
                    }
                    _ => panic!(
                        "client {} did not receive a tools/list response for request {}",
                        client_id,
                        i + 2
                    ),
                }
            }
            Ok::<(), anyhow::Error>(())
        });
        handles.push(handle);
    }

    // The first `?` surfaces a panicked or cancelled task, the second the task's own error.
    for handle in handles {
        handle.await??;
    }

    server_handle.abort();
    Ok(())
}
/// Checks that a client can connect, close, and then open a fresh connection
/// to the same server and initialize again.
///
/// The first connection performs an `initialize` exchange and sends a close
/// frame. The second connection repeats the `initialize` request and asserts
/// on `result.protocolVersion` of the reply.
///
/// A missing or non-text reply on the second connection now fails the test;
/// previously the assertion was skipped in that case. Reads are bounded to
/// five seconds so a stalled server fails the test instead of hanging it.
#[tokio::test]
async fn test_websocket_reconnection() -> Result<()> {
    let read_timeout = tokio::time::Duration::from_secs(5);

    let port = find_available_port().await;
    let mut server = McpServer::new().await?;
    let server_handle = tokio::spawn(async move { server.start(port).await });
    // Fixed delay before connecting; presumably lets the spawned server begin listening.
    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

    let url = format!("ws://127.0.0.1:{port}/mcp");

    // First connection: initialize, then close. The socket halves are dropped
    // at the end of this scope.
    {
        let (ws_stream, _) = connect_async(&url).await?;
        let (mut write, mut read) = ws_stream.split();
        let init = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18"
            }
        });
        write.send(Message::Text(init.to_string().into())).await?;
        // The reply is consumed but its content is not inspected here.
        let _ = tokio::time::timeout(read_timeout, read.next()).await?;
        write.send(Message::Close(None)).await?;
    }

    // Second connection: the server must accept it and answer `initialize` again.
    {
        let (ws_stream, _) = connect_async(&url).await?;
        let (mut write, mut read) = ws_stream.split();
        let init = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18"
            }
        });
        write.send(Message::Text(init.to_string().into())).await?;
        match tokio::time::timeout(read_timeout, read.next()).await? {
            Some(Ok(Message::Text(response_text))) => {
                let response: Value = serde_json::from_str(response_text.as_str())?;
                assert_eq!(response["result"]["protocolVersion"], "2025-06-18");
            }
            _ => panic!("Did not receive initialize response after reconnecting"),
        }
    }

    server_handle.abort();
    Ok(())
}