#[cfg(feature = "test-auth-bypass")]
use crate::auth::extractors::{TEST_CLAIMS_HEADER, TEST_USER_ID_HEADER, encode_test_claims_header};
use axum::{
Router,
body::{Body, Bytes},
http::{Method, Request, StatusCode, header},
};
use serde::{Deserialize, Serialize};
use tower::ServiceExt;
/// A single HTTP request to be sent to an axum [`Router`] in a test.
///
/// Build it with [`Scenario::new`] (or the `get`/`post`/... helpers), adjust the
/// request with the builder methods, then `send`/`execute` it to obtain a
/// [`ScenarioAssert`] for checking the response.
#[derive(Debug)]
#[must_use = "a Scenario does nothing until `send` or `execute` is awaited"]
pub struct Scenario {
    /// Router the request is dispatched to.
    app: Router,
    /// Request being assembled by the builder methods.
    request: Request<Body>,
}
impl Scenario {
    /// Creates a scenario that sends `GET /` with an empty body to `app`;
    /// the builder methods below adjust the request before it is sent.
    pub fn new(app: Router) -> Self {
        Self {
            app,
            // `Request::new` starts out as `GET /` with no headers.
            request: Request::new(Body::empty()),
        }
    }

    /// Sets the HTTP method of the request.
    pub fn method(mut self, method: Method) -> Self {
        *self.request.method_mut() = method;
        self
    }

    /// Replaces the request URI (e.g. `/users/1?verbose=true`).
    ///
    /// # Panics
    /// Panics if `uri` cannot be parsed as a URI.
    pub fn uri(mut self, uri: &str) -> Self {
        *self.request.uri_mut() = uri
            .parse::<axum::http::Uri>()
            .unwrap_or_else(|err| panic!("invalid request URI {:?}: {}", uri, err));
        self
    }

    /// Sets the header `key` to `value`, replacing any value already present
    /// under that name.
    ///
    /// # Panics
    /// Panics if `key` is not a valid header name or `value` is not a valid
    /// header value.
    pub fn header(mut self, key: &str, value: &str) -> Self {
        use axum::http::{HeaderName, HeaderValue};
        let name = HeaderName::from_bytes(key.as_bytes())
            .unwrap_or_else(|err| panic!("invalid header name {:?}: {}", key, err));
        let value = HeaderValue::from_str(value)
            .unwrap_or_else(|err| panic!("invalid value for header {:?}: {}", key, err));
        self.request.headers_mut().insert(name, value);
        self
    }

    /// Alias for [`Scenario::header`].
    pub fn with_header(self, key: &str, value: &str) -> Self {
        self.header(key, value)
    }

    /// Alias for [`Scenario::header`].
    pub fn with_request_header(self, key: &str, value: &str) -> Self {
        self.header(key, value)
    }

    /// Sets `Authorization: Bearer <token>`.
    pub fn bearer_token(self, token: &str) -> Self {
        self.header("Authorization", &format!("Bearer {}", token))
    }

    /// Alias for [`Scenario::bearer_token`].
    pub fn with_auth(self, token: &str) -> Self {
        self.bearer_token(token)
    }

    /// Sets the test-only user id header understood by the auth bypass.
    #[cfg(feature = "test-auth-bypass")]
    pub fn with_test_user(self, user_id: &str) -> Self {
        self.header(TEST_USER_ID_HEADER, user_id)
    }

    /// Encodes `claims` with `encode_test_claims_header` and sends them in the
    /// test-only claims header.
    #[cfg(feature = "test-auth-bypass")]
    pub fn with_test_claims<T: Serialize>(self, claims: &T) -> Self {
        let encoded = encode_test_claims_header(claims);
        self.header(TEST_CLAIMS_HEADER, &encoded)
    }

    /// Appends URL-encoded `params` to the request's query string. Any query
    /// that is already present is kept, and so are the URI's scheme and
    /// authority.
    ///
    /// # Panics
    /// Panics if the rebuilt URI is invalid.
    pub fn with_query(mut self, params: &[(&str, &str)]) -> Self {
        let uri = self.request.uri().clone();

        let mut query_parts: Vec<String> = Vec::with_capacity(params.len() + 1);
        // An empty existing query (`/path?`) would otherwise yield `?&key=value`.
        if let Some(existing) = uri.query().filter(|query| !query.is_empty()) {
            query_parts.push(existing.to_string());
        }
        query_parts.extend(params.iter().map(|(key, value)| {
            format!("{}={}", urlencoding::encode(key), urlencoding::encode(value))
        }));

        // Absolute URIs without a path report "", which cannot prefix a query.
        let path = match uri.path() {
            "" => "/",
            path => path,
        };
        let path_and_query = if query_parts.is_empty() {
            path.to_string()
        } else {
            format!("{}?{}", path, query_parts.join("&"))
        };

        // Rebuild from parts so scheme and authority survive the rewrite.
        let mut parts = uri.into_parts();
        parts.path_and_query = Some(
            path_and_query
                .parse::<axum::http::uri::PathAndQuery>()
                .unwrap_or_else(|err| {
                    panic!("invalid path and query {:?}: {}", path_and_query, err)
                }),
        );
        *self.request.uri_mut() = axum::http::Uri::from_parts(parts)
            .unwrap_or_else(|err| panic!("cannot rebuild request URI: {}", err));
        self
    }

    /// Serializes `body` as the JSON request body and sets both `Content-Type`
    /// and `Accept` to `application/json`.
    ///
    /// # Panics
    /// Panics if `body` cannot be serialized to JSON.
    pub fn json_body<T: Serialize>(mut self, body: &T) -> Self {
        let json = serde_json::to_string(body).expect("failed to serialize JSON request body");
        *self.request.body_mut() = Body::from(json);
        let headers = self.request.headers_mut();
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );
        headers.insert(
            header::ACCEPT,
            header::HeaderValue::from_static("application/json"),
        );
        self
    }

    /// Alias for [`Scenario::json_body`].
    pub fn json<T: Serialize>(self, body: &T) -> Self {
        self.json_body(body)
    }

    /// Alias for [`Scenario::json_body`].
    pub fn with_json<T: Serialize>(self, body: &T) -> Self {
        self.json_body(body)
    }

    /// Serializes `body` as an `application/x-www-form-urlencoded` request body.
    ///
    /// # Panics
    /// Panics if `body` cannot be form-encoded.
    pub fn form_body<T: Serialize>(mut self, body: &T) -> Self {
        let encoded =
            serde_urlencoded::to_string(body).expect("failed to form-encode request body");
        *self.request.body_mut() = Body::from(encoded);
        self.request.headers_mut().insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/x-www-form-urlencoded"),
        );
        self
    }

    /// Alias for [`Scenario::form_body`].
    pub fn form<T: Serialize>(self, body: &T) -> Self {
        self.form_body(body)
    }

    /// Alias for [`Scenario::form_body`].
    pub fn with_form<T: Serialize>(self, body: &T) -> Self {
        self.form_body(body)
    }

    /// Uses `body` verbatim as the request body. No `Content-Type` is set.
    pub fn text_body(mut self, body: impl Into<String>) -> Self {
        *self.request.body_mut() = Body::from(body.into());
        self
    }

    /// Dispatches the request to the router once and wraps the response for
    /// assertions.
    pub async fn execute(self) -> ScenarioAssert {
        let response = self
            .app
            .oneshot(self.request)
            .await
            .expect("router service failed");
        ScenarioAssert { response }
    }

    /// Alias for [`Scenario::execute`].
    pub async fn send(self) -> ScenarioAssert {
        self.execute().await
    }
}
/// The response produced by a [`Scenario`], with chainable assertion helpers.
///
/// Methods that need to read the body buffer it and rebuild the response, so
/// assertions can keep being chained after them.
#[derive(Debug)]
pub struct ScenarioAssert {
    /// Response returned by the router.
    response: axum::response::Response,
}
impl ScenarioAssert {
    /// Asserts that the response status equals `expected`.
    pub fn assert_status(self, expected: StatusCode) -> Self {
        let actual = self.response.status();
        assert_eq!(
            actual, expected,
            "Expected status {}, got {}",
            expected, actual
        );
        self
    }

    /// Asserts `200 OK`.
    pub fn assert_ok(self) -> Self {
        self.assert_status(StatusCode::OK)
    }

    /// Asserts `201 Created`.
    pub fn assert_created(self) -> Self {
        self.assert_status(StatusCode::CREATED)
    }

    /// Asserts `400 Bad Request`.
    pub fn assert_bad_request(self) -> Self {
        self.assert_status(StatusCode::BAD_REQUEST)
    }

    /// Asserts `401 Unauthorized`.
    pub fn assert_unauthorized(self) -> Self {
        self.assert_status(StatusCode::UNAUTHORIZED)
    }

    /// Asserts `404 Not Found`.
    pub fn assert_not_found(self) -> Self {
        self.assert_status(StatusCode::NOT_FOUND)
    }

    /// Asserts `500 Internal Server Error`.
    pub fn assert_server_error(self) -> Self {
        self.assert_status(StatusCode::INTERNAL_SERVER_ERROR)
    }

    /// Asserts that header `key` is present and its value equals `expected`.
    ///
    /// # Panics
    /// Panics if the header is missing, is not visible ASCII, or differs.
    pub fn assert_header(self, key: &str, expected: &str) -> Self {
        let value = self
            .response
            .headers()
            .get(key)
            .unwrap_or_else(|| panic!("Header '{}' not found", key))
            .to_str()
            .unwrap_or_else(|err| panic!("Header '{}' is not visible ASCII: {}", key, err));
        assert_eq!(value, expected, "Header '{}' value mismatch", key);
        self
    }

    /// Asserts that header `key` is present, whatever its value.
    pub fn assert_header_exists(self, key: &str) -> Self {
        self.response
            .headers()
            .get(key)
            .unwrap_or_else(|| panic!("Header '{}' not found", key));
        self
    }

    /// Asserts a 3xx status and the presence of a `Location` header.
    pub fn assert_redirect(self) -> Self {
        let status = self.response.status();
        assert!(
            status.is_redirection(),
            "Expected redirect status, got {}",
            status
        );
        self.assert_header_exists(header::LOCATION.as_str())
    }

    /// Asserts a redirect whose `Location` equals `expected`.
    pub fn assert_redirect_to(self, expected: &str) -> Self {
        self.assert_redirect()
            .assert_header(header::LOCATION.as_str(), expected)
    }

    /// Asserts that the status is one of `expected`.
    pub fn assert_status_any(self, expected: &[StatusCode]) -> Self {
        let status = self.response.status();
        assert!(
            expected.contains(&status),
            "Expected status in {:?}, got {}",
            expected,
            status
        );
        self
    }

    /// Asserts that `Content-Type` contains `application/json`.
    pub fn assert_json(self) -> Self {
        let content_type = self
            .response
            .headers()
            .get(header::CONTENT_TYPE)
            .expect("Content-Type header not found")
            .to_str()
            .expect("Content-Type header is not visible ASCII");
        assert!(
            content_type.contains("application/json"),
            "Expected JSON content type, got: {}",
            content_type
        );
        self
    }

    /// Asserts `200 OK` with a JSON content type.
    pub fn assert_json_ok(self) -> Self {
        self.assert_ok().assert_json()
    }

    /// Consumes the response and returns its whole body.
    pub async fn body_bytes(self) -> Vec<u8> {
        axum::body::to_bytes(self.response.into_body(), usize::MAX)
            .await
            .expect("failed to read response body")
            .to_vec()
    }

    /// Consumes the response and returns its body as a UTF-8 string.
    ///
    /// # Panics
    /// Panics if the body is not valid UTF-8.
    pub async fn body_string(self) -> String {
        String::from_utf8(self.body_bytes().await).expect("Response body is not valid UTF-8")
    }

    /// Consumes the response and deserializes its body as JSON into `T`.
    ///
    /// # Panics
    /// Panics (showing the body) if the body is not valid JSON for `T`.
    pub async fn json<T: for<'de> Deserialize<'de>>(self) -> T {
        let bytes = self.body_bytes().await;
        Self::parse_json(&bytes)
    }

    /// Alias for [`ScenarioAssert::json`] returning an untyped value.
    pub async fn json_value(self) -> serde_json::Value {
        self.json().await
    }

    /// Alias for [`ScenarioAssert::json`].
    pub async fn json_into<T: for<'de> Deserialize<'de>>(self) -> T {
        self.json().await
    }

    /// Asserts that the JSON value at the dot-separated `path` equals `expected`.
    pub async fn assert_json_field(self, path: &str, expected: serde_json::Value) -> Self {
        let (parts, bytes) = self.into_parts_and_body().await;
        let json: serde_json::Value = Self::parse_json(&bytes);
        let actual = json_path_get(&json, path)
            .unwrap_or_else(|| panic!("Path '{}' not found in JSON", path));
        assert_eq!(actual, &expected, "JSON path '{}' value mismatch", path);
        Self::from_parts_and_body(parts, bytes)
    }

    /// Asserts that the JSON body structurally contains `expected`.
    pub async fn assert_json_contains(self, expected: serde_json::Value) -> Self {
        let (parts, bytes) = self.into_parts_and_body().await;
        let json: serde_json::Value = Self::parse_json(&bytes);
        assert!(
            json_contains(&json, &expected),
            "Expected JSON to contain {:?}, got {:?}",
            expected,
            json
        );
        Self::from_parts_and_body(parts, bytes)
    }

    /// Alias for [`ScenarioAssert::assert_json_field`].
    pub async fn assert_json_path(self, path: &str, expected: serde_json::Value) -> Self {
        self.assert_json_field(path, expected).await
    }

    /// Alias for [`ScenarioAssert::assert_json_field`].
    pub async fn json_path_eq(self, path: &str, expected: serde_json::Value) -> Self {
        self.assert_json_path(path, expected).await
    }

    /// Asserts that the (UTF-8) body contains `text`.
    ///
    /// # Panics
    /// Panics if the body is not valid UTF-8 or does not contain `text`.
    pub async fn assert_contains(self, text: &str) -> Self {
        let (parts, bytes) = self.into_parts_and_body().await;
        let body = std::str::from_utf8(&bytes).expect("Response body is not valid UTF-8");
        assert!(
            body.contains(text),
            "Response body does not contain '{}'. Body: {}",
            text,
            body
        );
        Self::from_parts_and_body(parts, bytes)
    }

    /// Prints status, headers and body to stderr, then returns the response
    /// unchanged.
    ///
    /// Non-UTF-8 bodies are printed lossily instead of panicking, and the
    /// original bytes are preserved.
    pub async fn dump(self) -> Self {
        let (parts, bytes) = self.into_parts_and_body().await;
        eprintln!("=== Response Dump ===");
        eprintln!("Status: {}", parts.status);
        eprintln!("Headers:");
        for (key, value) in &parts.headers {
            eprintln!(" {}: {}", key, value.to_str().unwrap_or("<invalid>"));
        }
        eprintln!("Body: {}", String::from_utf8_lossy(&bytes));
        eprintln!("===================");
        Self::from_parts_and_body(parts, bytes)
    }

    /// Returns the underlying response.
    pub fn response(self) -> axum::response::Response {
        self.response
    }

    /// Splits the response into its head and fully-buffered body.
    async fn into_parts_and_body(self) -> (axum::http::response::Parts, Bytes) {
        let (parts, body) = self.response.into_parts();
        let bytes = axum::body::to_bytes(body, usize::MAX)
            .await
            .expect("failed to read response body");
        (parts, bytes)
    }

    /// Reassembles a response from a head and a buffered body.
    fn from_parts_and_body(parts: axum::http::response::Parts, body: Bytes) -> Self {
        Self {
            response: axum::response::Response::from_parts(parts, Body::from(body)),
        }
    }

    /// Deserializes `bytes` as JSON. On failure it panics, printing the parse
    /// error and the (lossily decoded) body.
    fn parse_json<T: for<'de> Deserialize<'de>>(bytes: &[u8]) -> T {
        serde_json::from_slice(bytes).unwrap_or_else(|err| {
            panic!(
                "Failed to parse JSON response: {}. Body: {}",
                err,
                String::from_utf8_lossy(bytes)
            )
        })
    }
}
/// Looks up a value in `json` by a dot-separated `path` such as `"items.0.name"`.
///
/// Each segment is resolved against the current value:
/// - array elements are addressed by decimal index;
/// - object members are addressed by key, including numeric-looking keys such
///   as `"2024"`.
///
/// Returns `None` as soon as a segment does not resolve, including when the
/// current value is a scalar.
fn json_path_get<'a>(json: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
    path.split('.')
        .try_fold(json, |current, segment| match current {
            serde_json::Value::Array(items) => items.get(segment.parse::<usize>().ok()?),
            serde_json::Value::Object(members) => members.get(segment),
            _ => None,
        })
}
fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
match (actual, expected) {
(serde_json::Value::Object(actual_map), serde_json::Value::Object(expected_map)) => {
expected_map.iter().all(|(key, expected_value)| {
actual_map
.get(key)
.map(|actual_value| json_contains(actual_value, expected_value))
.unwrap_or(false)
})
}
(serde_json::Value::Array(actual_array), serde_json::Value::Array(expected_array)) => {
expected_array.iter().all(|expected_value| {
actual_array
.iter()
.any(|actual_value| json_contains(actual_value, expected_value))
})
}
_ => actual == expected,
}
}
/// Starts a [`Scenario`] that issues a `GET` request for `uri` against `app`.
pub fn get(app: Router, uri: &str) -> Scenario {
    let scenario = Scenario::new(app).uri(uri);
    scenario.method(Method::GET)
}
/// Starts a [`Scenario`] that issues a `POST` request for `uri` against `app`.
pub fn post(app: Router, uri: &str) -> Scenario {
    let scenario = Scenario::new(app).uri(uri);
    scenario.method(Method::POST)
}
/// Starts a [`Scenario`] that issues a `PUT` request for `uri` against `app`.
pub fn put(app: Router, uri: &str) -> Scenario {
    let scenario = Scenario::new(app).uri(uri);
    scenario.method(Method::PUT)
}
/// Starts a [`Scenario`] that issues a `DELETE` request for `uri` against `app`.
pub fn delete(app: Router, uri: &str) -> Scenario {
    let scenario = Scenario::new(app).uri(uri);
    scenario.method(Method::DELETE)
}
/// Starts a [`Scenario`] that issues a `PATCH` request for `uri` against `app`.
pub fn patch(app: Router, uri: &str) -> Scenario {
    let scenario = Scenario::new(app).uri(uri);
    scenario.method(Method::PATCH)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, routing::get as axum_get};
    use serde_json::json;

    /// Returns a fixed JSON greeting.
    async fn hello_handler() -> Json<serde_json::Value> {
        Json(json!({"message": "Hello, World!"}))
    }

    /// Echoes the decoded query parameters back as `{"params": {...}}`.
    async fn echo_handler(
        axum::extract::Query(params): axum::extract::Query<
            std::collections::HashMap<String, String>,
        >,
    ) -> Json<serde_json::Value> {
        Json(json!({"params": params}))
    }

    #[tokio::test]
    async fn test_basic_get() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        let response = get(app, "/hello").send().await.assert_json_ok();
        let body = response.json_value().await;
        assert_eq!(body["message"], "Hello, World!");
    }

    #[tokio::test]
    async fn test_with_query_params() {
        let app = Router::new().route("/echo", axum_get(echo_handler));
        let response = get(app, "/echo")
            .with_query(&[("key", "value"), ("foo", "bar")])
            .execute()
            .await
            .assert_ok();
        let body: serde_json::Value = response.json().await;
        assert_eq!(body["params"], json!({"key": "value", "foo": "bar"}));
    }

    #[tokio::test]
    async fn test_with_query_appends_to_existing_query_and_encodes() {
        let app = Router::new().route("/echo", axum_get(echo_handler));
        get(app, "/echo?a=1")
            .with_query(&[("b", "x y&z")])
            .send()
            .await
            .assert_ok()
            .assert_json_field("params", json!({"a": "1", "b": "x y&z"}))
            .await;
    }

    #[tokio::test]
    async fn test_json_alias() {
        async fn post_handler(
            axum::Json(payload): axum::Json<serde_json::Value>,
        ) -> axum::Json<serde_json::Value> {
            axum::Json(payload)
        }
        let app = Router::new().route("/echo", axum::routing::post(post_handler));
        let response = post(app, "/echo")
            .json(&json!({"key": "value"}))
            .send()
            .await
            .assert_json_ok();
        let body: serde_json::Value = response.json().await;
        assert_eq!(body["key"], json!("value"));
    }

    #[tokio::test]
    async fn test_text_body_round_trip() {
        async fn echo_text(body: String) -> String {
            body
        }
        let app = Router::new().route("/text", axum::routing::post(echo_text));
        let body = post(app, "/text")
            .text_body("ping")
            .send()
            .await
            .assert_ok()
            .body_string()
            .await;
        assert_eq!(body, "ping");
    }

    #[tokio::test]
    async fn test_with_auth() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        get(app, "/hello")
            .with_auth("test-token-123")
            .execute()
            .await
            .assert_ok();
    }

    #[tokio::test]
    async fn test_bearer_token_sets_authorization_header() {
        async fn auth_echo(headers: axum::http::HeaderMap) -> String {
            headers
                .get(header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_string()
        }
        let app = Router::new().route("/auth", axum_get(auth_echo));
        let body = get(app, "/auth")
            .bearer_token("abc")
            .send()
            .await
            .assert_ok()
            .body_string()
            .await;
        assert_eq!(body, "Bearer abc");
    }

    #[tokio::test]
    async fn test_assert_json_path() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        let response = get(app, "/hello").send().await.assert_ok();
        response
            .json_path_eq("message", json!("Hello, World!"))
            .await;
    }

    #[tokio::test]
    async fn test_assert_contains() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        let response = get(app, "/hello").execute().await.assert_ok();
        response.assert_contains("Hello").await;
    }

    #[tokio::test]
    async fn test_assert_contains_preserves_headers() {
        async fn handler() -> axum::response::Response {
            axum::response::Response::builder()
                .status(StatusCode::OK)
                .header("x-test", "1")
                .body(Body::from("hello"))
                .unwrap()
        }
        let app = Router::new().route("/hello", axum::routing::get(handler));
        get(app, "/hello")
            .send()
            .await
            .assert_contains("hello")
            .await
            .assert_header("x-test", "1");
    }

    #[tokio::test]
    async fn test_assert_json_contains() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        get(app, "/hello")
            .send()
            .await
            .assert_json_ok()
            .assert_json_contains(json!({"message": "Hello, World!"}))
            .await;
    }

    #[tokio::test]
    async fn test_assert_json_field_preserves_headers() {
        async fn handler() -> axum::response::Response {
            axum::response::Response::builder()
                .status(StatusCode::OK)
                .header("x-test", "1")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(r#"{"message":"hello"}"#))
                .unwrap()
        }
        let app = Router::new().route("/hello", axum::routing::get(handler));
        get(app, "/hello")
            .send()
            .await
            .assert_json_field("message", json!("hello"))
            .await
            .assert_header("x-test", "1");
    }

    #[tokio::test]
    async fn test_assert_status_any_and_header_exists() {
        async fn handler() -> axum::response::Response {
            let mut response = axum::response::Response::new(Body::from("ok"));
            response
                .headers_mut()
                .insert("x-test", "1".parse().unwrap());
            *response.status_mut() = StatusCode::OK;
            response
        }
        let app = Router::new().route("/test", axum::routing::get(handler));
        get(app, "/test")
            .send()
            .await
            .assert_status_any(&[StatusCode::OK, StatusCode::CREATED])
            .assert_header_exists("x-test");
    }

    #[tokio::test]
    async fn test_unknown_route_is_not_found() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        get(app, "/missing").send().await.assert_not_found();
    }

    #[tokio::test]
    #[should_panic(expected = "Expected status 404 Not Found, got 200 OK")]
    async fn test_assert_status_mismatch_panics() {
        let app = Router::new().route("/hello", axum_get(hello_handler));
        get(app, "/hello").send().await.assert_not_found();
    }

    #[tokio::test]
    async fn test_form_alias() {
        #[derive(Deserialize)]
        struct LoginForm {
            email: String,
        }
        async fn form_handler(axum::Form(form): axum::Form<LoginForm>) -> Json<serde_json::Value> {
            Json(json!({ "email": form.email }))
        }
        let app = Router::new().route("/form", axum::routing::post(form_handler));
        let response = post(app, "/form")
            .with_form(&[("email", "test@example.com")])
            .send()
            .await
            .assert_json_ok();
        assert_eq!(
            response.json_value().await["email"],
            json!("test@example.com")
        );
    }

    #[tokio::test]
    async fn test_assert_redirect_to() {
        async fn redirect_handler() -> axum::response::Response {
            axum::response::Response::builder()
                .status(StatusCode::FOUND)
                .header(header::LOCATION, "/target")
                .body(Body::empty())
                .unwrap()
        }
        let app = Router::new().route("/redirect", axum::routing::get(redirect_handler));
        get(app, "/redirect")
            .send()
            .await
            .assert_redirect_to("/target");
    }

    #[test]
    fn test_json_path_get_nested_and_indexed() {
        let value = json!({"items": [{"id": 7}, {"id": 9}], "meta": {"count": 2}});
        assert_eq!(json_path_get(&value, "items.1.id"), Some(&json!(9)));
        assert_eq!(json_path_get(&value, "meta.count"), Some(&json!(2)));
        assert_eq!(json_path_get(&value, "items.5"), None);
        assert_eq!(json_path_get(&value, "meta.count.deeper"), None);
        assert_eq!(json_path_get(&value, "missing"), None);
    }

    #[test]
    fn test_json_contains_subset_semantics() {
        let actual = json!({"a": 1, "b": {"c": 2, "d": 3}, "list": [1, {"x": 1, "y": 2}]});
        assert!(json_contains(&actual, &json!({"b": {"c": 2}})));
        assert!(json_contains(&actual, &json!({"list": [{"x": 1}]})));
        assert!(json_contains(&actual, &json!({})));
        assert!(!json_contains(&actual, &json!({"a": 2})));
        assert!(!json_contains(&actual, &json!({"missing": 1})));
        assert!(!json_contains(&actual, &json!({"list": [3]})));
        assert!(!json_contains(&json!(1), &json!("1")));
    }
}