use axum::body::to_bytes;
use axum::http::header;
use axum::response::Response;
/// Upper bound (1 MiB) on how many response-body bytes the body-reading helpers will
/// buffer via `to_bytes`; a larger body makes `to_bytes` fail, which those helpers turn
/// into a panic.
const MAX_BODY_BYTES: usize = 1 << 20;
/// Panics unless the numeric status code of `res` is exactly `expected`.
pub fn assert_status(res: &Response, expected: u16) {
    let got = res.status().as_u16();
    assert_eq!(got, expected, "expected HTTP status {expected}, got {got}");
}
/// Panics unless the numeric status code of `res` appears in `allowed`.
pub fn assert_status_in(res: &Response, allowed: &[u16]) {
    let got = res.status().as_u16();
    let permitted = allowed.iter().any(|&code| code == got);
    assert!(
        permitted,
        "expected HTTP status to be one of {allowed:?}, got {got}"
    );
}
/// Panics unless `res` has a success status (200..=299).
pub fn assert_status_2xx(res: &Response) {
    let status = res.status();
    assert!(
        status.is_success(),
        "expected a 2xx status, got {}",
        status.as_u16()
    );
}
/// Panics unless `res` has a client-error status (400..=499).
pub fn assert_status_4xx(res: &Response) {
    let status = res.status();
    assert!(
        status.is_client_error(),
        "expected a 4xx status, got {}",
        status.as_u16()
    );
}
/// Panics unless `res` has a server-error status (500..=599).
pub fn assert_status_5xx(res: &Response) {
    let status = res.status();
    assert!(
        status.is_server_error(),
        "expected a 5xx status, got {}",
        status.as_u16()
    );
}
pub async fn assert_contains(res: Response, fragment: &str) {
let body = to_bytes(res.into_body(), MAX_BODY_BYTES)
.await
.unwrap_or_else(|e| panic!("assert_contains: failed to read body: {e}"));
let body_str = std::str::from_utf8(&body)
.unwrap_or_else(|e| panic!("assert_contains: body is not UTF-8: {e}"));
assert!(
body_str.contains(fragment),
"expected body to contain `{fragment}`, got:\n{}",
truncate(body_str, 500)
);
}
pub async fn assert_not_contains(res: Response, fragment: &str) {
let body = to_bytes(res.into_body(), MAX_BODY_BYTES)
.await
.unwrap_or_else(|e| panic!("assert_not_contains: failed to read body: {e}"));
let body_str = std::str::from_utf8(&body)
.unwrap_or_else(|e| panic!("assert_not_contains: body is not UTF-8: {e}"));
assert!(
!body_str.contains(fragment),
"expected body to NOT contain `{fragment}`, got:\n{}",
truncate(body_str, 500)
);
}
/// Panics unless `res` is a 3xx response whose `Location` header is exactly `target`.
///
/// The header is compared byte-for-byte against `target`. `HeaderValue::to_str` rejects
/// any byte outside visible ASCII, so going through it would misreport a present but
/// non-ASCII `Location` (e.g. raw UTF-8) as missing.
///
/// # Panics
/// If the status is not a redirection, if there is no `Location` header, or if its
/// value differs from `target`.
pub fn assert_redirects(res: &Response, target: &str) {
    let status = res.status();
    assert!(
        status.is_redirection(),
        "assert_redirects: status was {status}, expected 3xx"
    );
    let Some(raw) = res.headers().get(header::LOCATION) else {
        panic!("assert_redirects: no Location header on {status} response");
    };
    if raw.as_bytes() != target.as_bytes() {
        // Lossy rendering is only for the diagnostic; the comparison above is exact.
        let loc = String::from_utf8_lossy(raw.as_bytes());
        panic!("assert_redirects: expected Location `{target}`, got `{loc}`");
    }
}
/// Panics unless the messages cookie set on `res` reads back (via
/// `crate::messages::drain` with `secret`) as exactly `expected`, in order.
///
/// `expected` is a list of `(level, body)` pairs; each `level` is compared with the
/// decoded message's `level.as_str()`. An empty `expected` passes when no messages
/// cookie was set, or when the cookie reads back as no messages. If several
/// `Set-Cookie` headers carry the messages cookie, the last one is used.
///
/// # Panics
/// - if any expected level does not parse as a `Level` (checked first, so a typo in
///   the test is reported as such rather than as a generic mismatch);
/// - if `expected` is non-empty but no messages cookie was set;
/// - if the read-back messages differ from `expected`.
#[cfg(feature = "template_views")]
pub fn assert_messages(res: &Response, secret: &[u8], expected: &[(&str, &str)]) {
    use crate::messages::{Level, MESSAGES_COOKIE};
    use std::str::FromStr as _;

    // Validate the caller's levels up front. If this ran after the equality check it
    // could only ever execute once everything already matched, and a bad level would
    // instead show up as an unhelpful "messages don't match" failure.
    for (lvl, _) in expected {
        Level::from_str(lvl)
            .unwrap_or_else(|_| panic!("assert_messages: `{lvl}` is not a valid Level"));
    }

    // Only the `name=value` segment before the first `;` identifies the cookie;
    // attributes (Path, HttpOnly, ...) are ignored. Non-ASCII headers are skipped.
    let prefix = format!("{MESSAGES_COOKIE}=");
    let cookie_value = res
        .headers()
        .get_all(header::SET_COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .filter_map(|s| {
            let pair = s.split(';').next().unwrap_or("");
            pair.trim().strip_prefix(prefix.as_str()).map(str::to_owned)
        })
        .last();
    let Some(raw) = cookie_value else {
        if expected.is_empty() {
            return;
        }
        panic!("assert_messages: no `{MESSAGES_COOKIE}` Set-Cookie header found");
    };

    // Replay the cookie as a request `Cookie` header so `drain` can read it back.
    let mut headers = axum::http::HeaderMap::new();
    headers.insert(
        header::COOKIE,
        axum::http::HeaderValue::from_str(&format!("{MESSAGES_COOKIE}={raw}"))
            .expect("cookie value is header-safe (just produced it)"),
    );
    let (msgs, _) = crate::messages::drain(secret, &headers);

    if expected.is_empty() {
        assert!(
            msgs.is_empty(),
            "assert_messages: expected no messages, got {msgs:?}"
        );
        return;
    }

    let actual: Vec<(String, String)> = msgs
        .iter()
        .map(|m| (m.level.as_str().to_owned(), m.body.clone()))
        .collect();
    let expected_owned: Vec<(String, String)> = expected
        .iter()
        .map(|(lvl, body)| ((*lvl).to_owned(), (*body).to_owned()))
        .collect();
    assert_eq!(
        actual, expected_owned,
        "assert_messages: messages don't match — left=actual right=expected"
    );
}
/// Panics unless `res` carries header `name` whose value is exactly `value`.
///
/// `name` is looked up through `HeaderMap::get`, so it is case-insensitive. If the
/// header is repeated, only its first value is checked. The value is compared
/// byte-for-byte. `HeaderValue::to_str` rejects anything outside visible ASCII, so
/// routing through it would make valid UTF-8 values never match, and it would let a
/// non-ASCII header falsely match a placeholder string.
///
/// # Panics
/// If the header is missing or its value differs from `value`.
pub fn assert_header(res: &Response, name: &str, value: &str) {
    let Some(actual) = res.headers().get(name) else {
        panic!("expected header `{name}: {value}`, but header was missing");
    };
    if actual.as_bytes() != value.as_bytes() {
        // Lossy decoding is only for display; the comparison above is exact.
        let shown = String::from_utf8_lossy(actual.as_bytes());
        panic!("expected header `{name}: {value}`, got `{name}: {shown}`");
    }
}
pub fn assert_content_type(res: &Response, expected: &str) {
assert_header(res, "content-type", expected);
}
/// Consumes `res`, parses its body as JSON, and panics unless it equals `expected`.
///
/// Comparison is on parsed `serde_json::Value`s, so object key order and whitespace
/// do not matter.
///
/// # Panics
/// If the body cannot be read, is not valid JSON, or differs from `expected` (both
/// sides are pretty-printed in the message).
pub async fn assert_json_eq(res: Response, expected: &serde_json::Value) {
    let raw = to_bytes(res.into_body(), MAX_BODY_BYTES)
        .await
        .expect("read response body");
    let parsed = serde_json::from_slice::<serde_json::Value>(&raw);
    let actual = parsed.unwrap_or_else(|e| {
        panic!(
            "assert_json_eq: body is not valid JSON ({e}). Raw body:\n{}",
            truncate(&String::from_utf8_lossy(&raw), 500),
        )
    });
    if actual == *expected {
        return;
    }
    let expected_pp = serde_json::to_string_pretty(expected).unwrap_or_default();
    let actual_pp = serde_json::to_string_pretty(&actual).unwrap_or_default();
    panic!("assert_json_eq mismatch.\nexpected:\n{expected_pp}\nactual:\n{actual_pp}");
}
/// Consumes `res`, parses its body as JSON, and panics if it equals `unexpected`.
///
/// Comparison is on parsed `serde_json::Value`s, so object key order and whitespace
/// do not matter.
///
/// # Panics
/// If the body cannot be read, is not valid JSON, or equals `unexpected`.
pub async fn assert_json_not_eq(res: Response, unexpected: &serde_json::Value) {
    let raw = to_bytes(res.into_body(), MAX_BODY_BYTES)
        .await
        .expect("read response body");
    let parsed = serde_json::from_slice::<serde_json::Value>(&raw);
    let actual = parsed.unwrap_or_else(|e| {
        panic!(
            "assert_json_not_eq: body is not valid JSON ({e}). Raw body:\n{}",
            truncate(&String::from_utf8_lossy(&raw), 500),
        )
    });
    if actual != *unexpected {
        return;
    }
    let actual_pp = serde_json::to_string_pretty(&actual).unwrap_or_default();
    panic!("assert_json_not_eq: body equals the unexpected value:\n{actual_pp}");
}
/// Panics unless the last hop of `chain` is `(final_status, final_path)`.
///
/// `chain` is a sequence of `(status, path)` hops; only the final hop is checked.
/// On failure, the whole chain is listed, one numbered hop per line.
///
/// # Panics
/// If `chain` is empty or its final hop does not match.
pub fn assert_redirect_chain(chain: &[(u16, String)], final_path: &str, final_status: u16) {
    let Some((status, path)) = chain.last() else {
        panic!("assert_redirect_chain: chain is empty");
    };
    if *status == final_status && path == final_path {
        return;
    }
    let hops: Vec<String> = chain
        .iter()
        .enumerate()
        .map(|(idx, (code, hop))| format!(" {idx}: {code} {hop}"))
        .collect();
    panic!(
        "assert_redirect_chain: expected final hop to be `{final_status} {final_path}`, got `{status} {path}`.\nFull chain:\n{}",
        hops.join("\n"),
    );
}
/// Consumes `res` and panics unless `fragment` occurs exactly `count` times in its body.
///
/// Occurrences are counted with `str::matches`, i.e. non-overlapping and left to right
/// (`"aaa"` contains `"aa"` once). The body is decoded lossily, so invalid UTF-8 does
/// not by itself cause a panic.
///
/// # Panics
/// - if `fragment` is empty (it "matches" at every character boundary, so any count
///   would be meaningless);
/// - if the body cannot be read (e.g. it exceeds `MAX_BODY_BYTES`);
/// - if the number of occurrences differs from `count`.
pub async fn assert_contains_count(res: Response, fragment: &str, count: usize) {
    assert!(
        !fragment.is_empty(),
        "assert_contains_count: `fragment` must not be empty"
    );
    let bytes = to_bytes(res.into_body(), MAX_BODY_BYTES)
        .await
        .unwrap_or_else(|e| panic!("assert_contains_count: failed to read body: {e}"));
    let body = String::from_utf8_lossy(&bytes);
    let actual = body.matches(fragment).count();
    // `truncate` budgets bytes, so the label says bytes rather than chars.
    assert_eq!(
        actual, count,
        "expected `{fragment}` to appear {count} times, found {actual}.\nBody (first 500 bytes):\n{}",
        truncate(&body, 500),
    );
}
/// Panics unless `res` emits at least one `Set-Cookie` for cookie `name`.
///
/// Only the `name=value` segment before the first `;` is inspected. Headers that are
/// not visible ASCII are ignored. When `expected_value` is `Some`, at least one of the
/// emitted values for `name` must equal it exactly.
///
/// # Panics
/// If no matching cookie is emitted (all `Set-Cookie` headers are listed), or if none
/// of the emitted values equals `expected_value`.
pub fn assert_cookie_set(res: &Response, name: &str, expected_value: Option<&str>) {
    let prefix = format!("{name}=");
    let emitted: Vec<&str> = res
        .headers()
        .get_all(axum::http::header::SET_COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .collect();
    let matches: Vec<String> = emitted
        .iter()
        .filter_map(|s| {
            let pair = s.split(';').next().unwrap_or("");
            pair.trim().strip_prefix(prefix.as_str()).map(str::to_owned)
        })
        .collect();
    if matches.is_empty() {
        panic!(
            "assert_cookie_set: no `Set-Cookie` for `{name}` found. \
             Found {} Set-Cookie header(s): {:?}",
            emitted.len(),
            emitted
        );
    }
    let Some(want) = expected_value else {
        return;
    };
    if !matches.iter().any(|v| v == want) {
        panic!(
            "assert_cookie_set: `{name}` was set, but its value didn't match. \
             Expected `{want}`, got: {matches:?}"
        );
    }
}
/// Panics if `res` emits any `Set-Cookie` whose `name=value` segment is for `name`.
///
/// Headers that are not visible ASCII are ignored. The first offending header is
/// quoted in the panic message.
pub fn assert_cookie_not_set(res: &Response, name: &str) {
    let prefix = format!("{name}=");
    let offending = res
        .headers()
        .get_all(axum::http::header::SET_COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .find(|s| {
            let pair = s.split(';').next().unwrap_or("");
            pair.trim().starts_with(prefix.as_str())
        });
    if let Some(s) = offending {
        panic!(
            "assert_cookie_not_set: `Set-Cookie: {name}=...` was unexpectedly emitted: `{s}`"
        );
    }
}
/// Shortens `s` for inclusion in panic messages.
///
/// `max` is a budget in bytes. The kept prefix is at most `max` bytes, cut back to the
/// nearest UTF-8 character boundary so no character is split. If anything is dropped,
/// a `...(+N more chars)` suffix is appended, where `N` counts the dropped characters
/// (not bytes, which is what the suffix's wording promises).
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // `s.len() > max`, so `max` is a valid index to probe; 0 is always a boundary.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    let (kept, dropped) = s.split_at(cut);
    format!("{kept}...(+{} more chars)", dropped.chars().count())
}
// Unit tests for the assertion helpers in this file. Each helper has a passing case
// and, where it can fail, a `#[should_panic]` case pinning a substring of its panic
// message.
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::StatusCode;
// Builds a response with the given status and body and no headers.
fn html_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.body(Body::from(body.to_owned()))
.unwrap()
}
// Builds an empty-bodied response whose only header is `Location`.
fn redirect_response(status: StatusCode, location: &str) -> Response {
Response::builder()
.status(status)
.header(header::LOCATION, location)
.body(Body::empty())
.unwrap()
}
#[test]
fn assert_status_passes_on_match() {
let res = html_response(StatusCode::OK, "");
assert_status(&res, 200);
}
#[test]
#[should_panic(expected = "expected HTTP status 404, got 200")]
fn assert_status_panics_on_mismatch() {
let res = html_response(StatusCode::OK, "");
assert_status(&res, 404);
}
#[tokio::test]
async fn assert_contains_passes_when_body_includes_fragment() {
let res = html_response(StatusCode::OK, "Hello, world!");
assert_contains(res, "world").await;
}
#[tokio::test]
#[should_panic(expected = "expected body to contain `nope`")]
async fn assert_contains_panics_when_fragment_missing() {
let res = html_response(StatusCode::OK, "Hello, world!");
assert_contains(res, "nope").await;
}
// `assert_contains` looks only at the body; an error status does not make it fail.
#[tokio::test]
async fn assert_contains_passes_on_error_status_when_fragment_present() {
let res = html_response(StatusCode::NOT_FOUND, "Not Found");
assert_contains(res, "Not Found").await;
}
#[test]
fn truncate_short_input_passes_through() {
assert_eq!(truncate("hello", 500), "hello");
}
#[test]
fn truncate_long_input_appends_more_chars_indicator() {
let long = "x".repeat(1000);
let out = truncate(&long, 500);
assert!(out.starts_with(&"x".repeat(500)));
assert!(out.contains("...(+500 more chars)"), "got: {out}");
}
// "é" is two bytes in UTF-8, so a 1-byte budget would land mid-character;
// `truncate` must back off to index 0 instead of slicing (which would panic).
#[test]
fn truncate_clips_at_utf8_boundary_no_mid_codepoint_slice() {
let s = "é";
let out = truncate(s, 1);
assert!(out.starts_with("..."), "got: {out}");
}
#[tokio::test]
async fn assert_not_contains_passes_when_fragment_missing() {
let res = html_response(StatusCode::OK, "Goodbye, sky.");
assert_not_contains(res, "world").await;
}
#[tokio::test]
#[should_panic(expected = "expected body to NOT contain `world`")]
async fn assert_not_contains_panics_when_fragment_present() {
let res = html_response(StatusCode::OK, "world peace");
assert_not_contains(res, "world").await;
}
// The Location value is compared verbatim, including its percent-encoded query.
#[test]
fn assert_redirects_passes_on_302_with_location() {
let res = redirect_response(StatusCode::FOUND, "/login?next=%2Fprofile");
assert_redirects(&res, "/login?next=%2Fprofile");
}
#[test]
fn assert_redirects_passes_on_301() {
let res = redirect_response(StatusCode::MOVED_PERMANENTLY, "/new-home");
assert_redirects(&res, "/new-home");
}
#[test]
#[should_panic(expected = "expected 3xx")]
fn assert_redirects_panics_on_non_redirect_status() {
let res = html_response(StatusCode::OK, "");
assert_redirects(&res, "/login");
}
#[test]
#[should_panic(expected = "expected Location `/wrong`")]
fn assert_redirects_panics_on_location_mismatch() {
let res = redirect_response(StatusCode::FOUND, "/login");
assert_redirects(&res, "/wrong");
}
// The `assert_messages` tests mint a real messages cookie with `messages::success`
// and attach it as `Set-Cookie`, so the helper's read-back path is exercised.
#[cfg(feature = "template_views")]
#[test]
fn assert_messages_passes_on_staged_match() {
use crate::messages;
const SECRET: &[u8] = b"test-secret-32-bytes-aaaaaaaaaaaa";
let cookie = messages::success(SECRET, &axum::http::HeaderMap::new(), "Item created.");
let res = Response::builder()
.status(StatusCode::SEE_OTHER)
.header(header::SET_COOKIE, cookie)
.body(Body::empty())
.unwrap();
assert_messages(&res, SECRET, &[("success", "Item created.")]);
}
// With no messages cookie at all, an empty expectation passes without touching the secret.
#[cfg(feature = "template_views")]
#[test]
fn assert_messages_passes_on_empty_when_no_cookie_set() {
let res = html_response(StatusCode::OK, "");
assert_messages(&res, b"any-secret", &[]);
}
#[cfg(feature = "template_views")]
#[test]
#[should_panic(expected = "messages don't match")]
fn assert_messages_panics_on_mismatch() {
use crate::messages;
const SECRET: &[u8] = b"test-secret-32-bytes-aaaaaaaaaaaa";
let cookie = messages::success(SECRET, &axum::http::HeaderMap::new(), "Item created.");
let res = Response::builder()
.status(StatusCode::SEE_OTHER)
.header(header::SET_COOKIE, cookie)
.body(Body::empty())
.unwrap();
assert_messages(&res, SECRET, &[("error", "Something broke.")]);
}
// Builds an empty-bodied 200 response with a single header.
fn header_response(name: &'static str, value: &'static str) -> Response {
Response::builder()
.status(StatusCode::OK)
.header(name, value)
.body(Body::empty())
.unwrap()
}
// Header-name lookup is case-insensitive: both spellings find `X-Request-Id`.
#[test]
fn assert_header_passes_on_exact_match() {
let res = header_response("X-Request-Id", "abc-123");
assert_header(&res, "x-request-id", "abc-123");
assert_header(&res, "X-REQUEST-ID", "abc-123");
}
#[test]
#[should_panic(expected = "header was missing")]
fn assert_header_panics_when_missing() {
let res = html_response(StatusCode::OK, "");
assert_header(&res, "x-not-set", "anything");
}
#[test]
#[should_panic(expected = "expected header")]
fn assert_header_panics_on_value_mismatch() {
let res = header_response("x-tag", "actual");
assert_header(&res, "x-tag", "expected");
}
#[test]
fn assert_content_type_passes() {
let res = header_response("content-type", "application/json");
assert_content_type(&res, "application/json");
}
// Content-type matching is exact; a different media type must fail.
#[test]
#[should_panic(expected = "expected header `content-type:")]
fn assert_content_type_panics_on_mismatch() {
let res = header_response("content-type", "text/html; charset=utf-8");
assert_content_type(&res, "application/json");
}
// Key order differs between body and expectation; comparison is on parsed JSON values.
#[tokio::test]
async fn assert_json_eq_passes_on_structural_match() {
let res = Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(r#"{"id": 1, "name": "Alice"}"#))
.unwrap();
assert_json_eq(res, &serde_json::json!({"name": "Alice", "id": 1})).await;
}
#[tokio::test]
#[should_panic(expected = "assert_json_eq mismatch")]
async fn assert_json_eq_panics_on_value_mismatch() {
let res = Response::builder()
.status(StatusCode::OK)
.body(Body::from(r#"{"id": 1}"#))
.unwrap();
assert_json_eq(res, &serde_json::json!({"id": 2})).await;
}
#[tokio::test]
#[should_panic(expected = "body is not valid JSON")]
async fn assert_json_eq_panics_on_malformed_body() {
let res = html_response(StatusCode::OK, "<html>not json</html>");
assert_json_eq(res, &serde_json::json!({})).await;
}
#[tokio::test]
async fn assert_contains_count_passes_on_exact_count() {
let body = "<li>a</li><li>b</li><li>c</li>";
let res = html_response(StatusCode::OK, body);
assert_contains_count(res, "<li>", 3).await;
}
// A count of 0 doubles as an absence check.
#[tokio::test]
async fn assert_contains_count_zero_means_absent() {
let res = html_response(StatusCode::OK, "no nope");
assert_contains_count(res, "yes", 0).await;
}
#[tokio::test]
#[should_panic(expected = "expected `<li>` to appear 5 times, found 3")]
async fn assert_contains_count_panics_on_wrong_count() {
let res = html_response(StatusCode::OK, "<li>a</li><li>b</li><li>c</li>");
assert_contains_count(res, "<li>", 5).await;
}
#[tokio::test]
async fn assert_json_not_eq_passes_when_values_differ() {
let res = Response::builder()
.status(StatusCode::OK)
.body(Body::from(r#"{"id": 1}"#))
.unwrap();
assert_json_not_eq(res, &serde_json::json!({"id": 2})).await;
}
// Structural equality (key order ignored) counts as "equal" for `assert_json_not_eq`.
#[tokio::test]
#[should_panic(expected = "body equals the unexpected value")]
async fn assert_json_not_eq_panics_on_structural_match() {
let res = Response::builder()
.status(StatusCode::OK)
.body(Body::from(r#"{"id": 1, "name": "Alice"}"#))
.unwrap();
assert_json_not_eq(res, &serde_json::json!({"name": "Alice", "id": 1})).await;
}
#[tokio::test]
#[should_panic(expected = "body is not valid JSON")]
async fn assert_json_not_eq_panics_on_malformed_body() {
let res = html_response(StatusCode::OK, "<html>not json</html>");
assert_json_not_eq(res, &serde_json::json!({})).await;
}
// Only the last hop is checked; intermediate hops are free-form.
#[test]
fn assert_redirect_chain_passes_on_matching_final_hop() {
let chain = vec![
(302u16, "/old".to_owned()),
(302, "/intermediate".to_owned()),
(200, "/canonical".to_owned()),
];
assert_redirect_chain(&chain, "/canonical", 200);
}
#[test]
#[should_panic(expected = "chain is empty")]
fn assert_redirect_chain_panics_on_empty_chain() {
assert_redirect_chain(&[], "/anywhere", 200);
}
#[test]
#[should_panic(expected = "expected final hop to be `200 /canonical`")]
fn assert_redirect_chain_panics_on_wrong_final_path() {
let chain = vec![(302u16, "/old".to_owned()), (200, "/elsewhere".to_owned())];
assert_redirect_chain(&chain, "/canonical", 200);
}
#[test]
#[should_panic(expected = "expected final hop to be `200 /canonical`")]
fn assert_redirect_chain_panics_on_wrong_final_status() {
let chain = vec![(302u16, "/old".to_owned()), (404, "/canonical".to_owned())];
assert_redirect_chain(&chain, "/canonical", 200);
}
// Builds an empty-bodied 200 response with one `Set-Cookie` header per entry.
fn cookie_response(set_cookies: &[&str]) -> Response {
let mut builder = Response::builder().status(StatusCode::OK);
for c in set_cookies {
builder = builder.header(axum::http::header::SET_COOKIE, *c);
}
builder.body(Body::empty()).unwrap()
}
#[test]
fn assert_cookie_set_passes_when_cookie_present() {
let res = cookie_response(&["session=abc123; Path=/; HttpOnly"]);
assert_cookie_set(&res, "session", None);
}
// The value is the part after `name=` and before the first `;`; attributes are ignored.
#[test]
fn assert_cookie_set_passes_with_exact_value_match() {
let res = cookie_response(&["session=abc123; Path=/; HttpOnly"]);
assert_cookie_set(&res, "session", Some("abc123"));
}
#[test]
#[should_panic(expected = "no `Set-Cookie` for `session` found")]
fn assert_cookie_set_panics_when_cookie_absent() {
let res = cookie_response(&["other=value"]);
assert_cookie_set(&res, "session", None);
}
#[test]
#[should_panic(expected = "value didn't match")]
fn assert_cookie_set_panics_on_value_mismatch() {
let res = cookie_response(&["session=abc; Path=/"]);
assert_cookie_set(&res, "session", Some("xyz"));
}
// Each `Set-Cookie` header is scanned, not just the first.
#[test]
fn assert_cookie_set_handles_multiple_set_cookie_headers() {
let res = cookie_response(&["csrftoken=tok; Path=/", "session=abc; Path=/; HttpOnly"]);
assert_cookie_set(&res, "csrftoken", Some("tok"));
assert_cookie_set(&res, "session", Some("abc"));
}
#[test]
fn assert_cookie_not_set_passes_when_cookie_absent() {
let res = cookie_response(&["other=value"]);
assert_cookie_not_set(&res, "session");
}
#[test]
fn assert_cookie_not_set_passes_when_no_cookies_at_all() {
let res = cookie_response(&[]);
assert_cookie_not_set(&res, "session");
}
#[test]
#[should_panic(expected = "unexpectedly emitted")]
fn assert_cookie_not_set_panics_when_cookie_present() {
let res = cookie_response(&["session=abc; Path=/"]);
assert_cookie_not_set(&res, "session");
}
#[test]
fn assert_status_in_passes_on_allowed_match() {
let res = html_response(StatusCode::CREATED, "");
assert_status_in(&res, &[200, 201, 202]);
}
#[test]
#[should_panic(expected = "expected HTTP status to be one of")]
fn assert_status_in_panics_on_mismatch() {
let res = html_response(StatusCode::OK, "");
assert_status_in(&res, &[201, 202]);
}
// Range checks include both ends of each class (x00 and x99).
#[test]
fn assert_status_2xx_accepts_range() {
for code in [200, 201, 202, 204, 299] {
let res = html_response(StatusCode::from_u16(code).unwrap(), "");
assert_status_2xx(&res);
}
}
#[test]
#[should_panic(expected = "expected a 2xx status, got 301")]
fn assert_status_2xx_panics_on_redirect() {
let res = html_response(StatusCode::MOVED_PERMANENTLY, "");
assert_status_2xx(&res);
}
#[test]
fn assert_status_4xx_accepts_range() {
for code in [400, 401, 403, 404, 422, 499] {
let res = html_response(StatusCode::from_u16(code).unwrap(), "");
assert_status_4xx(&res);
}
}
#[test]
#[should_panic(expected = "expected a 4xx status, got 200")]
fn assert_status_4xx_panics_on_success() {
let res = html_response(StatusCode::OK, "");
assert_status_4xx(&res);
}
#[test]
fn assert_status_5xx_accepts_range() {
for code in [500, 502, 503, 504, 599] {
let res = html_response(StatusCode::from_u16(code).unwrap(), "");
assert_status_5xx(&res);
}
}
#[test]
#[should_panic(expected = "expected a 5xx status, got 400")]
fn assert_status_5xx_panics_on_4xx() {
let res = html_response(StatusCode::BAD_REQUEST, "");
assert_status_5xx(&res);
}
}