use serde_json::{json, Value};
use std::time::Duration;
use tokio_tungstenite::tungstenite::Message;
use futures_util::{SinkExt, StreamExt};
use anyhow::Result;
mod mcp_test_helpers;
mod prompt_test_helpers;
use mcp_test_helpers::*;
use prompt_test_helpers::*;
#[tokio::test]
async fn test_prompt_list_websocket() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
let port = find_available_port().await?;
let server = create_test_server_with_prompts().await?;
let server_handle = tokio::spawn(async move {
let mut server = server;
if let Err(e) = server.start(port).await {
eprintln!("Test server error: {e}");
}
});
tokio::time::sleep(Duration::from_millis(100)).await;
let ws_url = format!("ws://127.0.0.1:{}/mcp", port);
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
let init_message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
});
write.send(Message::Text(serde_json::to_string(&init_message)?.into())).await?;
let _init_response = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let list_message = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "prompts/list",
"params": {}
});
write.send(Message::Text(serde_json::to_string(&list_message)?.into())).await?;
let response_text = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 2);
assert!(response.get("result").is_some());
let result = &response["result"];
assert!(result.get("prompts").is_some());
let prompts = result["prompts"].as_array().unwrap();
assert_eq!(prompts.len(), 2);
let hello_prompt = &prompts[0];
assert_eq!(hello_prompt["name"], "hello_world");
assert!(hello_prompt["arguments"].is_array());
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_prompt_list_http() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
let port = find_available_port().await?;
let server = create_test_server_with_prompts().await?;
let server_handle = tokio::spawn(async move {
let mut server = server;
if let Err(e) = server.start(port).await {
eprintln!("Test server error: {e}");
}
});
tokio::time::sleep(Duration::from_millis(100)).await;
let http_url = format!("http://127.0.0.1:{}/mcp", port);
let client = reqwest::Client::new();
let init_response = client
.post(&http_url)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
}))
.send()
.await?;
let cookies = init_response.cookies().collect::<Vec<_>>();
let mut request_builder = client
.post(&http_url)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"jsonrpc": "2.0",
"id": 2,
"method": "prompts/list",
"params": {}
}));
for cookie in cookies {
request_builder = request_builder.header("Cookie", format!("{}={}", cookie.name(), cookie.value()));
}
let list_response = request_builder.send().await?;
let response_json: Value = list_response.json().await?;
assert_eq!(response_json["jsonrpc"], "2.0");
assert_eq!(response_json["id"], 2);
assert!(response_json.get("result").is_some());
let result = &response_json["result"];
assert!(result.get("prompts").is_some());
let prompts = result["prompts"].as_array().unwrap();
assert_eq!(prompts.len(), 2);
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_prompt_get_not_found() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
let port = find_available_port().await?;
let server = create_test_server_with_prompts().await?;
let server_handle = tokio::spawn(async move {
let mut server = server;
if let Err(e) = server.start(port).await {
eprintln!("Test server error: {e}");
}
});
tokio::time::sleep(Duration::from_millis(100)).await;
let ws_url = format!("ws://127.0.0.1:{}/mcp", port);
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
let init_message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
});
write.send(Message::Text(serde_json::to_string(&init_message)?.into())).await?;
let _init_response = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let get_message = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "prompts/get",
"params": {
"name": "nonexistent_prompt",
"arguments": {}
}
});
write.send(Message::Text(serde_json::to_string(&get_message)?.into())).await?;
let response_text = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 2);
assert!(response.get("error").is_some());
let error = &response["error"];
assert!(error["message"].as_str().unwrap().contains("Prompt not found"));
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_prompt_missing_params() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
let port = find_available_port().await?;
let server = create_test_server_with_prompts().await?;
let server_handle = tokio::spawn(async move {
let mut server = server;
if let Err(e) = server.start(port).await {
eprintln!("Test server error: {e}");
}
});
tokio::time::sleep(Duration::from_millis(100)).await;
let ws_url = format!("ws://127.0.0.1:{}/mcp", port);
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
let init_message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
});
write.send(Message::Text(serde_json::to_string(&init_message)?.into())).await?;
let _init_response = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let missing_params_message = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "prompts/get",
"params": {
"name": "hello_world",
"arguments": {}
}
});
write.send(Message::Text(serde_json::to_string(&missing_params_message)?.into())).await?;
let response_text = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 2);
assert!(response.get("error").is_some());
let error = &response["error"];
let error_message = error["message"].as_str().unwrap();
assert!(error_message.contains("Missing required argument: name") || error_message.contains("Prompt not found"));
server_handle.abort();
Ok(())
}
#[tokio::test]
async fn test_prompt_invalid_method() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
with_mcp_connection("prompt_invalid_method", |_server, mut write, mut read| async move {
let invalid_request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "prompts/invalid",
"params": {}
});
write.send(Message::Text(invalid_request.to_string().into())).await?;
let response_text = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert!(response.get("error").is_some());
let error = &response["error"];
assert_eq!(error["code"], -32601);
Ok(())
}).await
}
#[tokio::test]
async fn test_prompt_malformed_request() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
init_test_tracing();
with_mcp_connection("prompt_malformed", |_server, mut write, mut read| async move {
let malformed_json = r#"{"jsonrpc": "2.0", "id": 1, "method": "prompts/list", "params": {INVALID}"#;
write.send(Message::Text(malformed_json.to_string().into())).await?;
let response_text = receive_ws_message(&mut read, Duration::from_secs(5)).await?;
let response: Value = serde_json::from_str(&response_text)?;
assert_eq!(response["jsonrpc"], "2.0");
assert!(response.get("error").is_some());
let error = &response["error"];
assert_eq!(error["code"], -32700);
Ok(())
}).await
}