use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use std::time::Duration;
use tokio_tungstenite::{
tungstenite::protocol::Message,
};
mod mcp_test_helpers;
use mcp_test_helpers::{with_mcp_test_server, init_test_tracing};
#[tokio::test]
async fn test_websocket_upgrade_handshake() {
init_test_tracing();
with_mcp_test_server("websocket_upgrade_handshake", |server| async move {
let ws_url = server.ws_url();
let (ws_stream, response) = tokio_tungstenite::connect_async(&ws_url).await?;
assert_eq!(response.status(), 101);
let upgrade_header = response.headers().get("upgrade").unwrap().to_str().unwrap();
assert_eq!(upgrade_header.to_lowercase(), "websocket");
let connection_header = response.headers().get("connection").unwrap().to_str().unwrap();
assert!(connection_header.to_lowercase().contains("upgrade"));
assert!(response.headers().get("sec-websocket-accept").is_some());
let (mut write, mut read) = ws_stream.split();
let init_message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "upgrade-test", "version": "1.0.0"}
}
});
write.send(Message::Text(init_message.to_string().into())).await?;
if let Some(Ok(Message::Text(response_text))) = read.next().await {
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert!(response["result"]["capabilities"].is_object());
} else {
return Err("Did not receive initialization response".into());
}
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_websocket_upgrade_invalid_headers() {
init_test_tracing();
with_mcp_test_server("websocket_upgrade_invalid_headers", |server| async move {
let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{}/mcp", server.port);
let response = client
.get(&url)
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("Sec-WebSocket-Version", "13")
.send()
.await?;
assert_ne!(response.status(), 101);
let response = client
.get(&url)
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("Sec-WebSocket-Version", "13")
.send()
.await?;
assert_ne!(response.status(), 101);
let response = client
.get(&url)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Version", "13")
.send()
.await?;
assert_ne!(response.status(), 101);
let response = client
.get(&url)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("Sec-WebSocket-Version", "12") .send()
.await?;
assert_ne!(response.status(), 101);
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_websocket_upgrade_with_transport_discovery() {
init_test_tracing();
with_mcp_test_server("websocket_upgrade_transport_discovery", |server| async move {
let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{}/mcp", server.port);
let response = client
.get(&url)
.header("Accept", "application/json")
.send()
.await?;
assert_eq!(response.status(), 200);
let transport_info: Value = response.json().await?;
let available_transports = transport_info["mcp_server"]["available_transports"].as_object()
.ok_or("No available_transports object found")?;
let websocket_transport = available_transports.get("websocket")
.ok_or("WebSocket transport not found")?;
assert!(websocket_transport["uri"].as_str().unwrap().starts_with("ws://"));
let ws_url = websocket_transport["uri"].as_str().unwrap();
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url).await?;
let (mut write, mut read) = ws_stream.split();
let ping_message = Message::Ping(vec![1, 2, 3].into());
write.send(ping_message).await?;
tokio::time::timeout(Duration::from_secs(5), async {
while let Some(Ok(msg)) = read.next().await {
if let Message::Pong(data) = msg {
assert_eq!(data.as_ref(), &[1, 2, 3]);
return Ok(());
}
}
Err("Did not receive pong")
}).await??;
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_http_to_websocket_upgrade_context() {
init_test_tracing();
with_mcp_test_server("http_to_websocket_upgrade", |server| async move {
let client = reqwest::Client::new();
let http_url = format!("http://127.0.0.1:{}/mcp", server.port);
let http_request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "http-client", "version": "1.0.0"}
}
});
let http_response = client
.post(&http_url)
.json(&http_request)
.send()
.await?;
assert_eq!(http_response.status(), 200);
let http_result: Value = http_response.json().await?;
assert_eq!(http_result["jsonrpc"], "2.0");
assert_eq!(http_result["id"], 1);
let ws_url = format!("ws://127.0.0.1:{}/mcp", server.port);
let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(&ws_url).await?;
assert_eq!(upgrade_response.status(), 101);
let (mut write, mut read) = ws_stream.split();
let init_request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "ws-client", "version": "1.0.0"}
}
});
write.send(Message::Text(init_request.to_string().into())).await?;
if let Some(Ok(Message::Text(response_text))) = read.next().await {
let init_result: Value = serde_json::from_str(&response_text)?;
assert_eq!(init_result["jsonrpc"], "2.0");
assert_eq!(init_result["id"], 1);
} else {
return Err("Did not receive initialization response".into());
}
let ws_request = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {}
});
write.send(Message::Text(ws_request.to_string().into())).await?;
if let Some(Ok(Message::Text(response_text))) = read.next().await {
let ws_result: Value = serde_json::from_str(&response_text)?;
assert_eq!(ws_result["jsonrpc"], "2.0");
assert_eq!(ws_result["id"], 2);
assert!(ws_result["result"]["tools"].is_array());
} else {
return Err("Did not receive WebSocket response".into());
}
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_websocket_subprotocol_negotiation() {
init_test_tracing();
with_mcp_test_server("websocket_subprotocol_negotiation", |server| async move {
let ws_url = server.ws_url();
let (ws_stream, response) = tokio_tungstenite::connect_async(&ws_url).await?;
assert_eq!(response.status(), 101);
let (mut write, mut read) = ws_stream.split();
write.send(Message::Pong(vec![42].into())).await?;
write.send(Message::Ping(vec![1, 2, 3].into())).await?;
tokio::time::timeout(Duration::from_secs(2), async {
while let Some(Ok(msg)) = read.next().await {
if matches!(msg, Message::Pong(_)) {
return Ok(());
}
}
Err("Did not receive pong response")
}).await??;
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_concurrent_http_websocket_same_server() {
init_test_tracing();
with_mcp_test_server("concurrent_http_websocket", |server| async move {
let http_url = format!("http://127.0.0.1:{}/mcp", server.port);
let ws_url = format!("ws://127.0.0.1:{}/mcp", server.port);
let http_task = tokio::spawn({
let http_url = http_url.clone();
async move {
let client = reqwest::Client::new();
for i in 0..10 {
let request = json!({
"jsonrpc": "2.0",
"id": i + 100,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": format!("http-client-{}", i), "version": "1.0.0"}
}
});
let response = client.post(&http_url).json(&request).send().await?;
assert_eq!(response.status(), 200);
let result: Value = response.json().await?;
assert_eq!(result["jsonrpc"], "2.0");
assert_eq!(result["id"], i + 100);
}
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
}
});
let ws_task = tokio::spawn({
let ws_url = ws_url.clone();
async move {
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
let init_request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "ws-client", "version": "1.0.0"}
}
});
write.send(Message::Text(init_request.to_string().into())).await?;
let _ = read.next().await;
for i in 0..10 {
let request = json!({
"jsonrpc": "2.0",
"id": i + 200,
"method": "tools/list",
"params": {}
});
write.send(Message::Text(request.to_string().into())).await?;
if let Some(Ok(Message::Text(response_text))) = read.next().await {
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], i + 200);
}
}
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
}
});
let (http_result, ws_result) = tokio::try_join!(http_task, ws_task)?;
http_result?;
ws_result?;
Ok(())
})
.await
.unwrap();
}
#[tokio::test]
async fn test_websocket_upgrade_error_conditions() {
init_test_tracing();
with_mcp_test_server("websocket_upgrade_errors", |server| async move {
let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{}/mcp", server.port);
let response = client
.post(&url)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("Sec-WebSocket-Version", "13")
.send()
.await?;
assert_ne!(response.status(), 101);
let response = client
.get(&url)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", "invalid-key-format")
.header("Sec-WebSocket-Version", "13")
.send()
.await?;
assert_ne!(response.status(), 101);
let response = client
.get(&url)
.header("upgrade", "websocket") .header("connection", "upgrade") .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("sec-websocket-version", "13")
.send()
.await?;
if response.status() != 101 {
assert!(response.status().is_client_error() || response.status().is_server_error());
}
Ok(())
})
.await
.unwrap();
}