#![cfg(feature = "websocket")]
use futures::{SinkExt, StreamExt};
use mcpkit::protocol::{Message, Request, RequestId, Response};
use serde_json::json;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::{accept_async, connect_async, tungstenite::Message as WsMessage};
/// Asks the OS for a free loopback port and returns the resulting address.
///
/// A probe listener is bound to `127.0.0.1:0` (port 0 lets the OS choose) and
/// is dropped when this function returns, so the port is released again and
/// callers bind the returned address themselves.
///
/// NOTE(review): between the probe being dropped and the caller re-binding,
/// another socket could presumably claim the same port — confirm this is
/// acceptable for the test suite.
///
/// # Errors
///
/// Returns an error if binding the probe listener or reading its local
/// address fails.
async fn get_available_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> {
    let probe = TcpListener::bind("127.0.0.1:0").await?;
    probe.local_addr().map_err(Into::into)
}
/// Wraps an owned `String` in the `Text` variant of `WsMessage`.
///
/// Used by the server task and the tests to turn serialized JSON into a
/// message that can be sent over the WebSocket sink.
const fn ws_text(s: String) -> WsMessage {
WsMessage::Text(s)
}
/// Returns the payload of a `Text` message, or `None` for every other
/// message variant.
///
/// The message is consumed, so the string is moved out rather than copied.
fn extract_text(msg: WsMessage) -> Option<String> {
    if let WsMessage::Text(payload) = msg {
        Some(payload)
    } else {
        None
    }
}
async fn spawn_test_server(listener: TcpListener) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
if let Ok((stream, _addr)) = listener.accept().await {
if let Ok(ws_stream) = accept_async(stream).await {
let (mut tx, mut rx) = ws_stream.split();
while let Some(Ok(msg)) = rx.next().await {
if let Some(text) = extract_text(msg.clone()) {
if let Ok(mcp_msg) = serde_json::from_str::<Message>(&text) {
let response = match mcp_msg {
Message::Request(req) => {
let resp = match req.method.as_ref() {
"initialize" => Response::success(
req.id,
json!({
"protocolVersion": "2025-11-25",
"serverInfo": {
"name": "test-ws-server",
"version": "1.0.0"
},
"capabilities": {}
}),
),
"tools/list" => {
Response::success(req.id, json!({ "tools": [] }))
}
"ping" => Response::success(req.id, json!({})),
_ => Response::error(
req.id,
mcpkit::error::JsonRpcError::method_not_found(
req.method.to_string(),
),
),
};
Some(Message::Response(resp))
}
Message::Notification(_) => None,
Message::Response(_) => None,
};
if let Some(resp) = response {
if let Ok(json) = serde_json::to_string(&resp) {
if tx.send(ws_text(json)).await.is_err() {
break;
}
}
}
}
} else if let WsMessage::Close(_) = msg {
break;
}
}
}
}
})
}
/// Checks that a client can complete the WebSocket handshake with the test
/// server within five seconds.
#[tokio::test]
async fn test_websocket_connect() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    // First `?` surfaces the timeout, second `?` the handshake failure.
    let (_ws_stream, _response) =
        timeout(Duration::from_secs(5), connect_async(&endpoint)).await??;
    Ok(())
}
/// Sends an `initialize` request and checks that the server answers with a
/// successful response carrying the same id and the expected protocol version.
#[tokio::test]
async fn test_websocket_initialize_handshake() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    let (socket, _) = connect_async(&endpoint).await?;
    let (mut sink, mut stream) = socket.split();

    // Build and send the initialize request.
    let params = json!({
        "protocolVersion": "2025-11-25",
        "capabilities": {},
        "clientInfo": {
            "name": "test-client",
            "version": "1.0.0"
        }
    });
    let handshake = Message::Request(Request::with_params("initialize", 1u64, params));
    sink.send(ws_text(serde_json::to_string(&handshake)?))
        .await?;

    // Wait for the reply: timeout, end of stream and transport errors each
    // fail the test with their own message.
    let frame = timeout(Duration::from_secs(5), stream.next())
        .await?
        .ok_or("Stream ended")?
        .map_err(|e| format!("WebSocket error: {e}"))?;
    let payload = extract_text(frame).ok_or("Expected text message")?;
    let reply: Message = serde_json::from_str(&payload)?;

    assert!(reply.is_response());
    let resp = reply.as_response().ok_or("Expected response")?;
    assert!(resp.is_success());
    assert_eq!(resp.id, RequestId::Number(1));

    let result = resp.result.as_ref().ok_or("Expected result")?;
    assert_eq!(result["protocolVersion"], "2025-11-25");
    Ok(())
}
/// Sends three `ping` requests back to back, then checks that three
/// successful responses come back with ids 1, 2 and 3 in that order.
#[tokio::test]
async fn test_websocket_request_response_cycle() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    let (socket, _) = connect_async(&endpoint).await?;
    let (mut sink, mut stream) = socket.split();

    // Queue all requests before reading any reply.
    for seq in 1..=3 {
        let ping = Message::Request(Request::new("ping", seq as u64));
        sink.send(ws_text(serde_json::to_string(&ping)?)).await?;
    }

    // Replies are expected in the order the requests were sent.
    for expected in 1..=3 {
        let frame = timeout(Duration::from_secs(5), stream.next())
            .await?
            .ok_or("Stream ended")?
            .map_err(|e| format!("WebSocket error: {e}"))?;
        let payload = extract_text(frame).ok_or("Expected text")?;
        let reply: Message = serde_json::from_str(&payload)?;
        let resp = reply.as_response().ok_or("Expected response")?;
        assert!(resp.is_success());
        assert_eq!(resp.id, RequestId::Number(expected));
    }
    Ok(())
}
/// Sends `tools/list` and checks that the server replies successfully with an
/// empty `tools` array.
#[tokio::test]
async fn test_websocket_tools_list() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    let (socket, _) = connect_async(&endpoint).await?;
    let (mut sink, mut stream) = socket.split();

    let list_tools = Message::Request(Request::new("tools/list", 1u64));
    sink.send(ws_text(serde_json::to_string(&list_tools)?))
        .await?;

    let frame = timeout(Duration::from_secs(5), stream.next())
        .await?
        .ok_or("Stream ended")?
        .map_err(|e| format!("WebSocket error: {e}"))?;
    let payload = extract_text(frame).ok_or("Expected text")?;
    let reply: Message = serde_json::from_str(&payload)?;
    let resp = reply.as_response().ok_or("Expected response")?;
    assert!(resp.is_success());

    let result = resp.result.as_ref().ok_or("Expected result")?;
    let tools = result["tools"].as_array().ok_or("Expected array")?;
    assert!(tools.is_empty());
    Ok(())
}
/// Sends a request for a method the server does not implement and checks that
/// the reply is an error with JSON-RPC code -32601.
#[tokio::test]
async fn test_websocket_method_not_found() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    let (socket, _) = connect_async(&endpoint).await?;
    let (mut sink, mut stream) = socket.split();

    let unknown = Message::Request(Request::new("unknown/method", 1u64));
    sink.send(ws_text(serde_json::to_string(&unknown)?)).await?;

    let frame = timeout(Duration::from_secs(5), stream.next())
        .await?
        .ok_or("Stream ended")?
        .map_err(|e| format!("WebSocket error: {e}"))?;
    let payload = extract_text(frame).ok_or("Expected text")?;
    let reply: Message = serde_json::from_str(&payload)?;
    let resp = reply.as_response().ok_or("Expected response")?;
    assert!(resp.is_error());

    let failure = resp.error.as_ref().ok_or("Expected error")?;
    assert_eq!(failure.code, -32601);
    Ok(())
}
#[tokio::test]
async fn test_websocket_bidirectional() -> Result<(), Box<dyn std::error::Error>> {
let addr = get_available_addr().await?;
let listener = TcpListener::bind(addr).await?;
let _server = spawn_test_server(listener).await;
let url = format!("ws://{addr}");
let (ws_stream, _) = connect_async(&url).await?;
let (mut tx, mut rx) = ws_stream.split();
let request1 = Request::new("ping", 1u64);
tx.send(ws_text(serde_json::to_string(&Message::Request(request1))?))
.await?;
let resp1 = timeout(Duration::from_secs(5), rx.next()).await;
assert!(resp1.is_ok());
let request2 = Request::new("tools/list", 2u64);
tx.send(ws_text(serde_json::to_string(&Message::Request(request2))?))
.await?;
let resp2 = timeout(Duration::from_secs(5), rx.next()).await;
assert!(resp2.is_ok());
Ok(())
}
/// Connects to the test server and checks that a `Close` frame can be sent
/// without error.
#[tokio::test]
async fn test_websocket_graceful_close() -> Result<(), Box<dyn std::error::Error>> {
    let addr = get_available_addr().await?;
    let _server = spawn_test_server(TcpListener::bind(addr).await?).await;

    let endpoint = format!("ws://{addr}");
    let (socket, _) = connect_async(&endpoint).await?;
    // The read half stays bound until the end of the test.
    let (mut sink, _stream) = socket.split();

    assert!(sink.send(WsMessage::Close(None)).await.is_ok());
    Ok(())
}