use super::{ResponseSnapshot, SnapshotError, snapshot_response};
use axum::http::{HeaderName, HeaderValue, Method};
use axum_test::TestServer;
use bytes::Bytes;
use serde_json::Value;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::timeout;
use urlencoding::encode;
/// Optional multipart body for `TestClient::post`: `(form fields, file parts)`, handed to
/// `super::build_multipart_body` to produce the encoded body and boundary.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
/// Per-message deadline when waiting for GraphQL WebSocket protocol messages
/// (connection_ack, subscription results, and the trailing `complete`).
const GRAPHQL_WS_MESSAGE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum number of protocol messages read while waiting for a connection_ack or for a
/// subscription result before giving up.
const GRAPHQL_WS_MAX_CONTROL_MESSAGES: usize = 32;
/// Outcome of one `graphql-transport-ws` subscription exchange, as returned by
/// `TestClient::graphql_subscription_at`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphQLSubscriptionSnapshot {
/// The `id` the client sent with its `subscribe` message.
pub operation_id: String,
/// Whether the server acknowledged `connection_init`; snapshots are only returned after an
/// ack, so this is `true` in every successfully returned snapshot.
pub acknowledged: bool,
/// `payload` of the first `next` message for the operation, if one arrived.
pub event: Option<Value>,
/// `payload` of each received `error` message (the whole message when it has no payload).
pub errors: Vec<Value>,
/// Whether a `complete` message for the operation was received from the server.
pub complete_received: bool,
}
/// Test harness that drives an `axum::Router` through `axum_test` servers and returns
/// response snapshots.
pub struct TestClient {
/// Server used by the plain request helpers (`get`, `post`, `put`, ...), created eagerly in
/// `from_router` with `TestServer::try_new`'s default transport.
mock_server: Arc<TestServer>,
/// Copy of the router, kept so the HTTP-transport server can be built on demand.
router: axum::Router,
/// Lazily created HTTP-transport server used for WebSocket upgrades; `None` until
/// `http_server()` is first called.
http_server: Mutex<Option<Arc<TestServer>>>,
}
impl TestClient {
    /// Builds a client around `router`.
    ///
    /// The request helpers (`get`, `post`, ...) use a `TestServer` created here with its
    /// default transport. The HTTP-transport server needed for WebSocket upgrades is not
    /// created until [`TestClient::http_server`] is first called.
    ///
    /// # Errors
    /// Returns a message string when `TestServer::try_new` fails.
    pub fn from_router(router: axum::Router) -> Result<Self, String> {
        let mock_server =
            TestServer::try_new(router.clone()).map_err(|e| format!("Failed to create test server: {}", e))?;
        Ok(Self {
            mock_server: Arc::new(mock_server),
            router,
            http_server: Mutex::new(None),
        })
    }

    /// Returns the shared HTTP-transport test server, creating it on first use.
    ///
    /// The server is cached, so later calls return clones of the same `Arc`.
    ///
    /// # Errors
    /// Fails when the cache lock is poisoned, when no Tokio runtime is active on the calling
    /// thread, or when the server cannot be built.
    pub fn http_server(&self) -> Result<Arc<TestServer>, SnapshotError> {
        let mut guard = self
            .http_server
            .lock()
            .map_err(|_| SnapshotError::Decompression("Failed to lock HTTP test server state".to_string()))?;
        if let Some(server) = guard.as_ref() {
            return Ok(Arc::clone(server));
        }
        // The HTTP transport is only built inside an active Tokio runtime; bail out early otherwise.
        if tokio::runtime::Handle::try_current().is_err() {
            return Err(SnapshotError::Decompression(
                "WebSocket test transport requires an active Tokio runtime".to_string(),
            ));
        }
        let server = Arc::new(
            TestServer::builder()
                .http_transport()
                .try_build(self.router.clone())
                .map_err(|e| SnapshotError::Decompression(format!("Failed to create test server: {}", e)))?,
        );
        *guard = Some(Arc::clone(&server));
        Ok(server)
    }

    /// Builds a `method` request for `path` (with URL-encoded `query_params` appended) on the
    /// default-transport server and applies any caller-supplied `headers`.
    ///
    /// Headers are applied here, before callers attach a body, so body-specific headers such
    /// as `content-type` are always added after the caller's headers.
    ///
    /// # Errors
    /// Returns `SnapshotError::InvalidHeader` when a header name or value is invalid.
    fn build_request(
        &self,
        method: Method,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<axum_test::TestRequest, SnapshotError> {
        let full_path = build_full_path(path, query_params.as_deref());
        let request = self.mock_server.method(method, &full_path);
        match headers {
            Some(headers_vec) => self.add_headers(request, headers_vec),
            None => Ok(request),
        }
    }

    /// Sends a GET request and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn get(
        &self,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(Method::GET, path, query_params, headers)?;
        snapshot_response(request.await).await
    }

    /// Sends a POST request and snapshots the response.
    ///
    /// At most one body is attached, chosen in this order: `multipart` (sent as
    /// `multipart/form-data` with a generated boundary), then `form_data` (sent as
    /// `application/x-www-form-urlencoded`), then `json`.
    ///
    /// # Errors
    /// Fails on an invalid header, when form fields cannot be serialized or encoded, or when
    /// the response cannot be snapshotted.
    pub async fn post(
        &self,
        path: &str,
        json: Option<Value>,
        form_data: Option<Vec<(String, String)>>,
        multipart: MultipartPayload,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let mut request = self.build_request(Method::POST, path, query_params, headers)?;
        if let Some((form_fields, files)) = multipart {
            let (body, boundary) = super::build_multipart_body(&form_fields, &files);
            let content_type = format!("multipart/form-data; boundary={}", boundary);
            request = request.add_header("content-type", &content_type);
            request = request.bytes(Bytes::from(body));
        } else if let Some(form_fields) = form_data {
            let fields_value = serde_json::to_value(&form_fields)
                .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
            let encoded = super::encode_urlencoded_body(&fields_value)
                .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
            request = request.add_header("content-type", "application/x-www-form-urlencoded");
            request = request.bytes(Bytes::from(encoded));
        } else if let Some(json_value) = json {
            request = request.json(&json_value);
        }
        snapshot_response(request.await).await
    }

    /// Sends a request with an arbitrary `method` and raw `body` bytes, then snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn request_raw(
        &self,
        method: Method,
        path: &str,
        body: Bytes,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(method, path, query_params, headers)?;
        snapshot_response(request.bytes(body).await).await
    }

    /// Sends a PUT request with an optional JSON body and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn put(
        &self,
        path: &str,
        json: Option<Value>,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let mut request = self.build_request(Method::PUT, path, query_params, headers)?;
        if let Some(json_value) = json {
            request = request.json(&json_value);
        }
        snapshot_response(request.await).await
    }

    /// Sends a PATCH request with an optional JSON body and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn patch(
        &self,
        path: &str,
        json: Option<Value>,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let mut request = self.build_request(Method::PATCH, path, query_params, headers)?;
        if let Some(json_value) = json {
            request = request.json(&json_value);
        }
        snapshot_response(request.await).await
    }

    /// Sends a DELETE request and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn delete(
        &self,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(Method::DELETE, path, query_params, headers)?;
        snapshot_response(request.await).await
    }

    /// Sends an OPTIONS request and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn options(
        &self,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(Method::OPTIONS, path, query_params, headers)?;
        snapshot_response(request.await).await
    }

    /// Sends a HEAD request and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn head(
        &self,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(Method::HEAD, path, query_params, headers)?;
        snapshot_response(request.await).await
    }

    /// Sends a TRACE request and snapshots the response.
    ///
    /// # Errors
    /// Fails on an invalid header or when the response cannot be snapshotted.
    pub async fn trace(
        &self,
        path: &str,
        query_params: Option<Vec<(String, String)>>,
        headers: Option<Vec<(String, String)>>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let request = self.build_request(Method::TRACE, path, query_params, headers)?;
        snapshot_response(request.await).await
    }

    /// POSTs a GraphQL request body (see `build_graphql_body`) as JSON to `endpoint`.
    ///
    /// # Errors
    /// Propagates any error from [`TestClient::post`].
    pub async fn graphql_at(
        &self,
        endpoint: &str,
        query: &str,
        variables: Option<Value>,
        operation_name: Option<&str>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        let body = build_graphql_body(query, variables, operation_name);
        self.post(endpoint, Some(body), None, None, None, None).await
    }

    /// POSTs a GraphQL request to the default `/graphql` endpoint.
    ///
    /// # Errors
    /// Propagates any error from [`TestClient::graphql_at`].
    pub async fn graphql(
        &self,
        query: &str,
        variables: Option<Value>,
        operation_name: Option<&str>,
    ) -> Result<ResponseSnapshot, SnapshotError> {
        self.graphql_at("/graphql", query, variables, operation_name).await
    }

    /// Like [`TestClient::graphql`], but also returns the HTTP status code separately.
    ///
    /// # Errors
    /// Propagates any error from [`TestClient::graphql`].
    pub async fn graphql_with_status(
        &self,
        query: &str,
        variables: Option<Value>,
        operation_name: Option<&str>,
    ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
        let snapshot = self.graphql(query, variables, operation_name).await?;
        let status = snapshot.status;
        Ok((status, snapshot))
    }

    /// Runs one `graphql-transport-ws` subscription against `endpoint` over the
    /// HTTP-transport server and records what the server sent for it.
    ///
    /// Sequence: WebSocket upgrade, `connection_init` / `connection_ack`, then a `subscribe`
    /// with id `spikard-subscription-1`. The first `next` for that id is captured. The client
    /// then sends `complete` and waits up to `GRAPHQL_WS_MESSAGE_TIMEOUT` for a matching
    /// server `complete`. An `error` or `complete` for the operation ends the exchange early.
    /// Server `ping`s are answered with `pong`s. `next`, `error` and `complete` messages
    /// carrying a different id are ignored.
    ///
    /// # Errors
    /// Fails if the upgrade does not return status 101, the server rejects or never
    /// acknowledges the connection, a message times out or cannot be parsed, or no event,
    /// error, or completion is recorded within `GRAPHQL_WS_MAX_CONTROL_MESSAGES` messages.
    pub async fn graphql_subscription_at(
        &self,
        endpoint: &str,
        query: &str,
        variables: Option<Value>,
        operation_name: Option<&str>,
    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
        let operation_id = "spikard-subscription-1".to_string();
        let http_server = self.http_server()?;
        let upgrade = http_server
            .get_websocket(endpoint)
            .add_header("sec-websocket-protocol", "graphql-transport-ws")
            .await;
        if upgrade.status_code().as_u16() != 101 {
            return Err(SnapshotError::Decompression(format!(
                "GraphQL subscription upgrade failed with status {}",
                upgrade.status_code()
            )));
        }
        let mut websocket = super::WebSocketConnection::new(upgrade.into_websocket().await);
        websocket
            .send_json(&serde_json::json!({"type": "connection_init"}))
            .await;
        wait_for_graphql_ack(&mut websocket).await?;
        websocket
            .send_json(&serde_json::json!({
                "id": operation_id,
                "type": "subscribe",
                "payload": build_graphql_body(query, variables, operation_name),
            }))
            .await;

        let mut event = None;
        let mut errors = Vec::new();
        let mut complete_received = false;
        for _ in 0..GRAPHQL_WS_MAX_CONTROL_MESSAGES {
            let message = timeout(
                GRAPHQL_WS_MESSAGE_TIMEOUT,
                receive_graphql_protocol_message(&mut websocket),
            )
            .await
            .map_err(|_| {
                SnapshotError::Decompression("Timed out waiting for GraphQL subscription message".to_string())
            })??;
            let message_type = message.get("type").and_then(Value::as_str).unwrap_or_default();
            match message_type {
                "next" if Self::targets_operation(&message, &operation_id) => {
                    event = message.get("payload").cloned();
                    // Only the first event is needed: tell the server to stop the operation,
                    // then give it one message window to confirm with its own `complete`.
                    websocket
                        .send_json(&serde_json::json!({
                            "id": operation_id,
                            "type": "complete",
                        }))
                        .await;
                    if Self::await_server_complete(&mut websocket, &operation_id).await {
                        complete_received = true;
                    }
                    break;
                }
                "error" if Self::targets_operation(&message, &operation_id) => {
                    errors.push(message.get("payload").cloned().unwrap_or(message));
                    break;
                }
                "complete" if Self::targets_operation(&message, &operation_id) => {
                    complete_received = true;
                    break;
                }
                "ping" => Self::reply_to_ping(&mut websocket, &message).await,
                // `pong`, unknown types, and messages for other operation ids are skipped.
                _ => {}
            }
        }
        websocket.close().await;
        // Reached when the message budget ran out, or a `next` arrived without a payload and
        // no `complete` followed.
        if event.is_none() && errors.is_empty() && !complete_received {
            return Err(SnapshotError::Decompression(
                "No GraphQL subscription event received before timeout".to_string(),
            ));
        }
        Ok(GraphQLSubscriptionSnapshot {
            operation_id,
            // `wait_for_graphql_ack` returned Ok above, so the connection was acknowledged.
            acknowledged: true,
            event,
            errors,
            complete_received,
        })
    }

    /// Runs a subscription against the default `/graphql` endpoint.
    ///
    /// # Errors
    /// Propagates any error from [`TestClient::graphql_subscription_at`].
    pub async fn graphql_subscription(
        &self,
        query: &str,
        variables: Option<Value>,
        operation_name: Option<&str>,
    ) -> Result<GraphQLSubscriptionSnapshot, SnapshotError> {
        self.graphql_subscription_at("/graphql", query, variables, operation_name)
            .await
    }

    /// Returns `true` when `message` has no string `id` or its `id` equals `operation_id`.
    fn targets_operation(message: &Value, operation_id: &str) -> bool {
        message
            .get("id")
            .and_then(Value::as_str)
            .is_none_or(|id| id == operation_id)
    }

    /// Answers a protocol `ping` with a `pong`, echoing the ping's `payload` when present.
    async fn reply_to_ping(websocket: &mut super::WebSocketConnection, ping: &Value) {
        let mut pong = serde_json::json!({"type": "pong"});
        if let Some(payload) = ping.get("payload") {
            pong["payload"] = payload.clone();
        }
        websocket.send_json(&pong).await;
    }

    /// Waits up to `GRAPHQL_WS_MESSAGE_TIMEOUT` for one more protocol message and reports
    /// whether it is a `complete` for `operation_id`.
    ///
    /// A timeout or a receive/parse error counts as "no complete received".
    async fn await_server_complete(websocket: &mut super::WebSocketConnection, operation_id: &str) -> bool {
        match timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket)).await {
            Ok(Ok(message)) => {
                message.get("type").and_then(Value::as_str) == Some("complete")
                    && Self::targets_operation(&message, operation_id)
            }
            _ => false,
        }
    }

    /// Validates each `(name, value)` pair and adds it to `request` in order.
    ///
    /// # Errors
    /// Returns `SnapshotError::InvalidHeader` for the first invalid header name or value.
    fn add_headers(
        &self,
        mut request: axum_test::TestRequest,
        headers: Vec<(String, String)>,
    ) -> Result<axum_test::TestRequest, SnapshotError> {
        for (key, value) in headers {
            let header_name = HeaderName::from_bytes(key.as_bytes())
                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
            let header_value = HeaderValue::from_str(&value)
                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
            request = request.add_header(header_name, header_value);
        }
        Ok(request)
    }
}
/// Reads protocol messages until the server acknowledges `connection_init`.
///
/// `ping` messages are answered with a `pong` that echoes any `payload`; other message types
/// are skipped. At most `GRAPHQL_WS_MAX_CONTROL_MESSAGES` messages are read, each within
/// `GRAPHQL_WS_MESSAGE_TIMEOUT`.
///
/// # Errors
/// Fails on a per-message timeout, on a receive/parse error, when the server sends
/// `connection_error` or `error`, or when the message budget runs out without an ack.
async fn wait_for_graphql_ack(websocket: &mut super::WebSocketConnection) -> Result<(), SnapshotError> {
    let mut budget = GRAPHQL_WS_MAX_CONTROL_MESSAGES;
    while budget > 0 {
        budget -= 1;
        let frame = match timeout(GRAPHQL_WS_MESSAGE_TIMEOUT, receive_graphql_protocol_message(websocket)).await {
            Ok(received) => received?,
            Err(_) => {
                return Err(SnapshotError::Decompression(
                    "Timed out waiting for GraphQL connection_ack".to_string(),
                ));
            }
        };
        let kind = frame.get("type").and_then(Value::as_str).unwrap_or_default();
        if kind == "connection_ack" {
            return Ok(());
        }
        if kind == "connection_error" || kind == "error" {
            return Err(SnapshotError::Decompression(format!(
                "GraphQL subscription rejected during init: {}",
                frame
            )));
        }
        if kind == "ping" {
            let mut pong = serde_json::json!({"type": "pong"});
            if let Some(payload) = frame.get("payload") {
                pong["payload"] = payload.clone();
            }
            websocket.send_json(&pong).await;
        }
    }
    Err(SnapshotError::Decompression(
        "No GraphQL connection_ack received".to_string(),
    ))
}
/// Receives the next text or binary WebSocket frame and parses it as JSON.
///
/// WebSocket-level ping/pong frames are skipped transparently.
///
/// # Errors
/// Fails when the frame payload is not valid JSON, or when a close frame arrives first.
async fn receive_graphql_protocol_message(websocket: &mut super::WebSocketConnection) -> Result<Value, SnapshotError> {
    use super::WebSocketMessage as Frame;

    loop {
        let outcome = match websocket.receive_message().await {
            Frame::Ping(_) | Frame::Pong(_) => continue,
            Frame::Text(text) => serde_json::from_str::<Value>(&text).map_err(|err| {
                SnapshotError::Decompression(format!("Failed to parse GraphQL WebSocket message as JSON: {}", err))
            }),
            Frame::Binary(raw) => serde_json::from_slice::<Value>(&raw).map_err(|err| {
                SnapshotError::Decompression(format!(
                    "Failed to parse GraphQL binary WebSocket message as JSON: {}",
                    err
                ))
            }),
            Frame::Close(reason) => Err(SnapshotError::Decompression(format!(
                "GraphQL WebSocket connection closed before response: {:?}",
                reason
            ))),
        };
        break outcome;
    }
}
/// Assembles a GraphQL-over-HTTP request body.
///
/// The result always contains `query`; `variables` and `operationName` are included only
/// when supplied.
pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
    let mut fields = serde_json::Map::new();
    fields.insert("query".to_string(), Value::String(query.to_string()));
    if let Some(vars) = variables {
        fields.insert("variables".to_string(), vars);
    }
    if let Some(name) = operation_name {
        fields.insert("operationName".to_string(), Value::String(name.to_string()));
    }
    Value::Object(fields)
}
/// Appends URL-encoded `query_params` to `path`.
///
/// Keys and values are percent-encoded with `urlencoding::encode` and joined as `k=v` pairs
/// separated by `&`. If `path` already carries a query string, the pairs are appended to it.
/// A trailing `?` or `&` on `path` is reused rather than doubled, so no empty query
/// component is produced. `None` or an empty slice returns `path` unchanged.
fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
    let params = match query_params {
        Some(params) if !params.is_empty() => params,
        _ => return path.to_string(),
    };
    let query_string = params
        .iter()
        .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
        .collect::<Vec<_>>()
        .join("&");
    let separator = if !path.contains('?') {
        "?"
    } else if path.ends_with(['?', '&']) {
        // "/p?" or "/p?a=1&" already ends in a separator.
        ""
    } else {
        "&"
    };
    format!("{path}{separator}{query_string}")
}
#[cfg(test)]
mod tests {
    //! Unit tests for path/body builders and integration tests for the GraphQL helpers.
    use super::*;
    use axum::{
        Router,
        extract::ws::{Message, WebSocketUpgrade},
        routing::{get, post},
    };

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        let result = build_full_path(path, Some(&params));
        assert!(result.starts_with("/users?"));
        assert!(result.contains("id=123"));
        assert!(result.contains("name=test%20user"));
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }

    #[test]
    fn test_graphql_query_builder() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, Some(serde_json::json!({ "limit": 10 })), Some("GetUsers"));
        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["limit"], 10);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[tokio::test]
    async fn test_graphql_with_status_method() {
        // Echo route: responds with the raw request body so the sent GraphQL body can be inspected.
        let app = Router::new().route("/graphql", post(|body: String| async move { body }));
        let client = TestClient::from_router(app).expect("client");
        let (status, snapshot) = client
            .graphql_with_status("query { hello }", None, None)
            .await
            .expect("graphql response");
        assert_eq!(status, 200);
        assert_eq!(status, snapshot.status);
        let echoed: Value =
            serde_json::from_str(&snapshot.text().expect("body")).expect("echoed body is JSON");
        assert_eq!(echoed["query"], "query { hello }");
        assert!(echoed.get("variables").is_none());
        assert!(echoed.get("operationName").is_none());
    }

    #[test]
    fn test_build_graphql_body_basic() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, None, None);
        assert_eq!(body["query"], query);
        assert!(body.get("variables").is_none() || body["variables"].is_null());
        assert!(body.get("operationName").is_none() || body["operationName"].is_null());
    }

    #[test]
    fn test_build_graphql_body_with_variables() {
        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
        let variables = Some(serde_json::json!({ "id": "123" }));
        let body = build_graphql_body(query, variables, None);
        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["id"], "123");
    }

    #[test]
    fn test_build_graphql_body_with_operation_name() {
        let query = "query GetUsers { users { id } }";
        let op_name = Some("GetUsers");
        let body = build_graphql_body(query, None, op_name);
        assert_eq!(body["query"], query);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[test]
    fn test_build_graphql_body_all_fields() {
        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
        let variables = Some(serde_json::json!({ "name": "Alice" }));
        let op_name = Some("CreateUser");
        let body = build_graphql_body(query, variables, op_name);
        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["name"], "Alice");
        assert_eq!(body["operationName"], "CreateUser");
    }

    #[tokio::test]
    async fn graphql_subscription_returns_first_event_and_completes() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };
                        match message.get("type").and_then(Value::as_str) {
                            Some("connection_init") => {
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({"type":"connection_ack"}).to_string().into(),
                                    ))
                                    .await;
                            }
                            Some("subscribe") => {
                                let id = message.get("id").and_then(Value::as_str).unwrap_or("1");
                                let _ = socket
                                    .send(Message::Text(
                                        serde_json::json!({
                                            "id": id,
                                            "type": "next",
                                            "payload": {"data": {"ticker": "AAPL"}},
                                        })
                                        .to_string()
                                        .into(),
                                    ))
                                    .await;
                                if let Some(Ok(Message::Text(complete_text))) = socket.recv().await {
                                    let Ok(complete_message): Result<Value, _> = serde_json::from_str(&complete_text)
                                    else {
                                        break;
                                    };
                                    if complete_message.get("type").and_then(Value::as_str) == Some("complete") {
                                        let _ = socket
                                            .send(Message::Text(
                                                serde_json::json!({"id": id, "type":"complete"}).to_string().into(),
                                            ))
                                            .await;
                                    }
                                }
                                break;
                            }
                            _ => {}
                        }
                    }
                })
            }),
        );
        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());
        let snapshot = client
            .graphql_subscription("subscription { ticker }", None, None)
            .await
            .expect("subscription snapshot");
        assert!(snapshot.acknowledged);
        assert_eq!(snapshot.errors, Vec::<Value>::new());
        assert_eq!(snapshot.event, Some(serde_json::json!({"data": {"ticker": "AAPL"}})));
        assert!(snapshot.complete_received);
        assert!(client.http_server.lock().expect("lock").is_some());
    }

    #[tokio::test]
    async fn graphql_subscription_surfaces_connection_error() {
        let app = Router::new().route(
            "/graphql",
            get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket| async move {
                    while let Some(result) = socket.recv().await {
                        let Ok(Message::Text(text)) = result else {
                            continue;
                        };
                        let Ok(message): Result<Value, _> = serde_json::from_str(&text) else {
                            continue;
                        };
                        if message.get("type").and_then(Value::as_str) == Some("connection_init") {
                            let _ = socket
                                .send(Message::Text(
                                    serde_json::json!({
                                        "type": "connection_error",
                                        "payload": {"message": "not authorized"},
                                    })
                                    .to_string()
                                    .into(),
                                ))
                                .await;
                            break;
                        }
                    }
                })
            }),
        );
        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());
        let error = client
            .graphql_subscription("subscription { privateFeed }", None, None)
            .await
            .expect_err("expected connection error");
        assert!(error.to_string().contains("connection_error"));
        assert!(client.http_server.lock().expect("lock").is_some());
    }

    #[tokio::test]
    async fn http_requests_do_not_initialize_socket_transport() {
        let app = Router::new().route("/health", get(|| async { "ok" }));
        let client = TestClient::from_router(app).expect("client");
        assert!(client.http_server.lock().expect("lock").is_none());
        let snapshot = client.get("/health", None, None).await.expect("response snapshot");
        assert_eq!(snapshot.status, 200);
        assert_eq!(snapshot.text().expect("body"), "ok");
        assert!(client.http_server.lock().expect("lock").is_none());
    }
}