use std::collections::HashMap;
use axum::{
Router,
body::{Body, to_bytes},
http::{
HeaderMap, Method, Request, StatusCode,
header::{CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HeaderValue, LOCATION, SET_COOKIE},
},
};
use tower::ServiceExt;
use url::Url;
use crate::conf::global_settings;
/// Upper bound on the redirects followed for a single request when
/// `follow_redirects` is enabled; one more redirect than this makes the request panic.
const MAX_REDIRECTS: usize = 10;
/// Snapshot of one response produced by [`TestClient`].
#[derive(Debug, Clone)]
pub struct TestResponse {
/// HTTP status code of the response.
pub status_code: StatusCode,
/// Response body decoded as UTF-8; invalid byte sequences are replaced (lossy decode).
pub content: String,
/// All response headers, including any `Set-Cookie` values.
pub headers: HeaderMap,
/// Request target sent to the router for this response (absolute `http(s)://` URLs are
/// reduced to path and query); when redirects are followed it is the final target.
pub url: String,
}
/// In-process test client that sends requests straight into an axum [`Router`]
/// (via `oneshot`, no network) and keeps a simple cookie jar between requests.
pub struct TestClient {
router: Router,
/// Cookies keyed by name, sent on every request as one `Cookie` header and updated
/// from `Set-Cookie` response headers.
pub cookie_jar: HashMap<String, String>,
/// Username recorded by `login`; reset to `None` by `logout`.
pub session_user: Option<String>,
/// Whether requests follow 301/302/303/307/308 responses (up to `MAX_REDIRECTS`).
pub follow_redirects: bool,
}
impl TestClient {
#[must_use]
pub fn new(router: Router) -> Self {
Self {
router,
cookie_jar: HashMap::new(),
session_user: None,
follow_redirects: false,
}
}
#[must_use]
pub fn with_follow_redirects(router: Router, follow_redirects: bool) -> Self {
let mut client = Self::new(router);
client.follow_redirects = follow_redirects;
client
}
#[must_use]
pub fn login(&mut self, username: &str, password: &str) -> bool {
if username.is_empty() || password.is_empty() {
return false;
}
self.cookie_jar.insert(
global_settings::SESSION_COOKIE_NAME.to_string(),
format!("session-{username}"),
);
self.session_user = Some(username.to_string());
true
}
pub fn logout(&mut self) {
self.cookie_jar.remove(global_settings::SESSION_COOKIE_NAME);
self.session_user = None;
}
pub fn get(&mut self, url: &str) -> TestResponse {
crate::runtime::block_on(self.request(Method::GET, url, None))
}
pub fn post(&mut self, url: &str, data: impl AsRef<str>) -> TestResponse {
crate::runtime::block_on(self.request(Method::POST, url, Some(data.as_ref().to_string())))
}
pub fn put(&mut self, url: &str, data: impl AsRef<str>) -> TestResponse {
crate::runtime::block_on(self.request(Method::PUT, url, Some(data.as_ref().to_string())))
}
pub fn patch(&mut self, url: &str, data: impl AsRef<str>) -> TestResponse {
crate::runtime::block_on(self.request(Method::PATCH, url, Some(data.as_ref().to_string())))
}
pub fn delete(&mut self, url: &str) -> TestResponse {
crate::runtime::block_on(self.request(Method::DELETE, url, None))
}
pub fn head(&mut self, url: &str) -> TestResponse {
crate::runtime::block_on(self.request(Method::HEAD, url, None))
}
pub fn options(&mut self, url: &str) -> TestResponse {
crate::runtime::block_on(self.request(Method::OPTIONS, url, None))
}
async fn request(&mut self, method: Method, url: &str, body: Option<String>) -> TestResponse {
let mut current_method = method;
let mut current_url = normalize_url(url);
let mut current_body = body.unwrap_or_default();
for redirect_count in 0..=MAX_REDIRECTS {
let request = self.build_request(¤t_method, ¤t_url, ¤t_body);
let response = self
.router
.clone()
.oneshot(request)
.await
.expect("test client request should be accepted by router");
let status_code = response.status();
let headers = response.headers().clone();
let content = String::from_utf8_lossy(
&to_bytes(response.into_body(), usize::MAX)
.await
.expect("test client response body should be readable"),
)
.into_owned();
self.update_cookies(&headers);
let test_response = TestResponse {
status_code,
content,
headers,
url: current_url.clone(),
};
if !self.follow_redirects || !is_redirect(test_response.status_code) {
return test_response;
}
let Some(location) = redirect_location(&test_response.headers) else {
return test_response;
};
assert!(
redirect_count < MAX_REDIRECTS,
"maximum redirect limit ({MAX_REDIRECTS}) exceeded while requesting {}",
test_response.url
);
current_method = redirected_method(test_response.status_code, ¤t_method);
current_url = resolve_redirect_url(¤t_url, location);
current_body.clear();
}
unreachable!("redirect loop should return or panic before exhausting the loop")
}
#[must_use]
fn build_request(&self, method: &Method, url: &str, body: &str) -> Request<Body> {
let mut request = Request::builder()
.method(method.clone())
.uri(url)
.body(if body.is_empty() {
Body::empty()
} else {
Body::from(body.to_string())
})
.expect("test client request should build");
let headers = request.headers_mut();
if !self.cookie_jar.is_empty() {
headers.insert(
COOKIE,
HeaderValue::from_str(&cookie_header_value(&self.cookie_jar))
.expect("cookie header should be valid"),
);
}
if method_has_body(method) {
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/x-www-form-urlencoded"),
);
headers.insert(
CONTENT_LENGTH,
HeaderValue::from_str(&body.len().to_string())
.expect("content length should be valid"),
);
if let Some(csrf_token) = self.cookie_jar.get(global_settings::CSRF_COOKIE_NAME) {
headers.insert(
"x-csrftoken",
HeaderValue::from_str(csrf_token).expect("csrf header should be valid"),
);
}
}
request
}
fn update_cookies(&mut self, headers: &HeaderMap) {
for value in headers.get_all(SET_COOKIE) {
let Ok(cookie) = value.to_str() else {
continue;
};
let Some((name, cookie_value, should_remove)) = parse_set_cookie(cookie) else {
continue;
};
if should_remove {
self.cookie_jar.remove(name);
continue;
}
self.cookie_jar
.insert(name.to_string(), cookie_value.to_string());
}
}
}
/// Reports whether the client sends a form-encoded body (with `Content-Type` and
/// `Content-Length`) for `method`: only POST, PUT and PATCH qualify.
#[must_use]
fn method_has_body(method: &Method) -> bool {
    [Method::POST, Method::PUT, Method::PATCH].contains(method)
}
#[must_use]
fn is_redirect(status_code: StatusCode) -> bool {
matches!(
status_code,
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::SEE_OTHER
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
)
}
/// Returns the first `Location` header as text, or `None` when it is absent or not a
/// valid string header value.
fn redirect_location(headers: &HeaderMap) -> Option<&str> {
    let location = headers.get(LOCATION)?;
    location.to_str().ok()
}
/// Chooses the method for the request that follows a redirect with `status_code`.
///
/// 307 and 308 always replay the original method. Any other redirect switches to GET,
/// except that HEAD stays HEAD: RFC 9110 only lets a user agent rewrite POST to GET on
/// 301/302, and follows a 303 with "a GET or HEAD request", so a HEAD caller must never
/// receive a GET response body.
#[must_use]
fn redirected_method(status_code: StatusCode, method: &Method) -> Method {
    match status_code {
        StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => method.clone(),
        _ if *method == Method::HEAD => Method::HEAD,
        _ => Method::GET,
    }
}
/// Reduces an absolute `http://` / `https://` URL to its origin-form target (path plus
/// `?query` when the URL has one), dropping scheme, host and fragment. Any other input is
/// returned unchanged.
///
/// # Panics
/// Panics if a string with an `http(s)://` prefix does not parse as a URL.
#[must_use]
fn normalize_url(url: &str) -> String {
    let is_absolute = ["http://", "https://"]
        .iter()
        .any(|&scheme| url.starts_with(scheme));
    if !is_absolute {
        return url.to_owned();
    }
    let parsed = Url::parse(url).expect("absolute test URL should parse");
    let mut target = parsed.path().to_owned();
    if let Some(query) = parsed.query() {
        target.push('?');
        target.push_str(query);
    }
    target
}
/// Resolves a redirect `Location` against the target of the request that produced it and
/// returns the next origin-form target (path plus `?query` when present).
///
/// Absolute paths are used verbatim. Everything else is resolved against
/// `http://testserver` + `current_url`: relative references (`next/`, `../x`, `?page=2`),
/// absolute URLs and scheme-relative `//host/...` URLs. The host is discarded because
/// every request goes to the same router. Unlike a bare `.path()`, the resolved query
/// string is preserved.
///
/// # Panics
/// Panics if `current_url` or `location` cannot be parsed/resolved as a URL.
#[must_use]
fn resolve_redirect_url(current_url: &str, location: &str) -> String {
    // `//host/path` names another authority; treating it as a literal path would send a
    // request target the router never matches.
    if location.starts_with('/') && !location.starts_with("//") {
        return location.to_string();
    }
    let current = Url::parse(&format!("http://testserver{}", normalize_url(current_url)))
        .expect("current redirect URL should parse");
    let resolved = current
        .join(location)
        .expect("redirect location should resolve");
    match resolved.query() {
        Some(query) => format!("{}?{query}", resolved.path()),
        None => resolved.path().to_string(),
    }
}
/// Serialises the jar into a `Cookie` request-header value: `name=value` pairs ordered
/// by cookie name and joined with `"; "`. An empty jar yields an empty string.
#[must_use]
fn cookie_header_value(cookie_jar: &HashMap<String, String>) -> String {
    let mut entries: Vec<(&String, &String)> = cookie_jar.iter().collect();
    // HashMap iteration order is arbitrary; sorting keeps the header deterministic.
    entries.sort_by_key(|&(name, _)| name);
    let mut header = String::new();
    for (position, (name, value)) in entries.into_iter().enumerate() {
        if position > 0 {
            header.push_str("; ");
        }
        header.push_str(name);
        header.push('=');
        header.push_str(value);
    }
    header
}
/// Parses one `Set-Cookie` header value into `(name, value, should_remove)`.
///
/// Only the leading `name=value` pair and any `Max-Age` attribute are inspected; `Path`,
/// `Domain`, `Expires` and other attributes are ignored. `should_remove` is `true` when
/// the value is empty or `Max-Age` is zero or negative. RFC 6265 §5.2.2 treats both of
/// those `Max-Age` values as immediate expiry, and attribute names match
/// case-insensitively.
///
/// Returns `None` when the first pair has no `=` or has an empty name; RFC 6265 §5.2 says
/// such a header is ignored entirely.
fn parse_set_cookie(cookie: &str) -> Option<(&str, &str, bool)> {
    let mut parts = cookie.split(';');
    let (name, value) = parts.next()?.split_once('=')?;
    let (name, value) = (name.trim(), value.trim());
    if name.is_empty() {
        return None;
    }
    let expired = parts.any(|attribute| match attribute.split_once('=') {
        Some((key, seconds)) => {
            key.trim().eq_ignore_ascii_case("max-age")
                && matches!(seconds.trim().parse::<i64>(), Ok(delta) if delta <= 0)
        }
        None => false,
    });
    Some((name, value, value.is_empty() || expired))
}
#[cfg(test)]
mod tests {
    // `use super::*` already brings in the parent's imports (`COOKIE`, `SET_COOKIE`,
    // `LOCATION`, `HeaderValue`, `HashMap`, ...), so only test-only items are imported here.
    use super::*;
    use axum::{
        response::{IntoResponse, Response},
        routing::{delete, get, head, options, patch, post, put},
    };

    /// Handler returning the fixed body `"ok"`.
    async fn ok() -> &'static str {
        "ok"
    }

    /// Handler that echoes the request body back as the response body.
    async fn echo_body(body: String) -> String {
        body
    }

    /// Handler that sets a `theme=dark` cookie.
    async fn set_cookie() -> Response {
        (
            [(SET_COOKIE, HeaderValue::from_static("theme=dark; Path=/"))],
            "cookie set",
        )
            .into_response()
    }

    /// Handler that echoes the incoming `Cookie` header (empty when absent).
    async fn echo_cookie(headers: HeaderMap) -> String {
        headers
            .get(COOKIE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default()
            .to_string()
    }

    /// Handler that issues a `csrftoken` cookie, standing in for a page that renders a form.
    async fn issue_csrf_cookie() -> Response {
        (
            [(
                SET_COOKIE,
                HeaderValue::from_static("csrftoken=securetoken; Path=/"),
            )],
            "form",
        )
            .into_response()
    }

    /// Handler returning 200 only when the CSRF cookie and a matching `x-csrftoken`
    /// header are both present, and 403 otherwise.
    async fn require_csrf(headers: HeaderMap) -> Response {
        let cookie = headers
            .get(COOKIE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();
        let header = headers
            .get("x-csrftoken")
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();
        if cookie.contains("csrftoken=securetoken") && header == "securetoken" {
            StatusCode::OK.into_response()
        } else {
            StatusCode::FORBIDDEN.into_response()
        }
    }

    /// Handler answering `302 Found` with `Location: /target/`.
    async fn found_redirect() -> Response {
        (
            StatusCode::FOUND,
            [(LOCATION, HeaderValue::from_static("/target/"))],
        )
            .into_response()
    }

    /// Router exposing every handler above on its own path.
    fn method_router() -> Router {
        Router::new()
            .route("/", get(ok))
            .route("/submit/", post(echo_body))
            .route("/put/", put(echo_body))
            .route("/patch/", patch(echo_body))
            .route("/delete/", delete(ok))
            .route("/head/", head(ok))
            .route("/options/", options(ok))
            .route("/set-cookie/", get(set_cookie))
            .route("/echo-cookie/", get(echo_cookie))
            .route("/form/", get(issue_csrf_cookie))
            .route("/csrf-submit/", post(require_csrf))
            .route("/redirect/", get(found_redirect))
            .route("/target/", get(|| async { "target" }))
    }

    #[test]
    fn test_client_get_returns_200() {
        let router = Router::new().route("/", get(ok));
        let mut client = TestClient::new(router);
        let response = client.get("/");
        assert_eq!(response.status_code, StatusCode::OK);
        assert_eq!(response.content, "ok");
        assert_eq!(response.url, "/");
    }

    #[test]
    fn test_client_post_sends_body() {
        let router = Router::new().route("/submit/", post(echo_body));
        let mut client = TestClient::new(router);
        let response = client.post("/submit/", "name=alice");
        assert_eq!(response.status_code, StatusCode::OK);
        assert_eq!(response.content, "name=alice");
    }

    #[test]
    fn test_client_default_no_follow_redirects() {
        let router = method_router();
        let mut client = TestClient::new(router);
        let response = client.get("/redirect/");
        assert_eq!(response.status_code, StatusCode::FOUND);
        assert_eq!(response.url, "/redirect/");
    }

    #[test]
    fn test_client_cookie_persistence() {
        let router = method_router();
        let mut client = TestClient::new(router);
        let initial = client.get("/set-cookie/");
        assert_eq!(initial.status_code, StatusCode::OK);
        let response = client.get("/echo-cookie/");
        assert!(response.content.contains("theme=dark"));
    }

    #[test]
    fn client_follow_redirects_returns_final_response() {
        let router = method_router();
        let mut client = TestClient::with_follow_redirects(router, true);
        let response = client.get("/redirect/");
        assert_eq!(response.status_code, StatusCode::OK);
        assert_eq!(response.content, "target");
        assert_eq!(response.url, "/target/");
    }

    #[test]
    fn client_get_handles_empty_body_gracefully() {
        let router = Router::new().route("/", get(|| async { StatusCode::NO_CONTENT }));
        let mut client = TestClient::new(router);
        let response = client.get("/");
        assert_eq!(response.status_code, StatusCode::NO_CONTENT);
        assert!(response.content.is_empty());
    }

    #[test]
    fn client_post_injects_csrf_cookie_into_header() {
        let router = method_router();
        let mut client = TestClient::new(router);
        let form = client.get("/form/");
        assert_eq!(form.status_code, StatusCode::OK);
        let response = client.post("/csrf-submit/", "name=alice");
        assert_eq!(response.status_code, StatusCode::OK);
    }

    #[test]
    fn client_login_sets_session_cookie_and_user() {
        let router = method_router();
        let mut client = TestClient::new(router);
        assert!(client.login("alice", "secret"));
        assert_eq!(client.session_user.as_deref(), Some("alice"));
        assert_eq!(
            client.cookie_jar.get(global_settings::SESSION_COOKIE_NAME),
            Some(&"session-alice".to_string())
        );
    }

    #[test]
    fn client_login_rejects_empty_credentials() {
        let mut client = TestClient::new(method_router());
        assert!(!client.login("", "secret"));
        assert!(!client.login("alice", ""));
        assert!(client.session_user.is_none());
        assert!(client.cookie_jar.is_empty());
    }

    #[test]
    fn client_logout_clears_session_state() {
        let router = method_router();
        let mut client = TestClient::new(router);
        let _ = client.login("alice", "secret");
        client.logout();
        assert!(client.session_user.is_none());
        assert!(
            !client
                .cookie_jar
                .contains_key(global_settings::SESSION_COOKIE_NAME)
        );
    }

    #[test]
    fn client_supports_put_patch_delete_head_and_options_requests() {
        let router = method_router();
        let mut client = TestClient::new(router);
        let put_response = client.put("/put/", "value=1");
        let patch_response = client.patch("/patch/", "value=2");
        let delete_response = client.delete("/delete/");
        let head_response = client.head("/head/");
        let options_response = client.options("/options/");
        assert_eq!(put_response.content, "value=1");
        assert_eq!(patch_response.content, "value=2");
        assert_eq!(delete_response.status_code, StatusCode::OK);
        assert_eq!(head_response.status_code, StatusCode::OK);
        assert_eq!(options_response.status_code, StatusCode::OK);
    }

    #[test]
    fn test_response_clones_cleanly() {
        let response = TestResponse {
            status_code: StatusCode::ACCEPTED,
            content: "payload".to_string(),
            headers: HeaderMap::new(),
            url: "/items/".to_string(),
        };
        let cloned = response.clone();
        assert_eq!(cloned.status_code, StatusCode::ACCEPTED);
        assert_eq!(cloned.content, "payload");
        assert_eq!(cloned.url, "/items/");
    }

    #[test]
    fn parse_set_cookie_marks_empty_values_for_removal() {
        assert_eq!(
            parse_set_cookie("sessionid=; Max-Age=0; Path=/"),
            Some(("sessionid", "", true))
        );
    }

    #[test]
    fn parse_set_cookie_ignores_pair_without_equals() {
        assert_eq!(parse_set_cookie("garbage; Path=/"), None);
        assert_eq!(
            parse_set_cookie("theme=dark; Path=/"),
            Some(("theme", "dark", false))
        );
    }

    #[test]
    fn cookie_header_value_orders_cookies_by_name() {
        let mut jar = HashMap::new();
        jar.insert("theme".to_string(), "dark".to_string());
        jar.insert("csrftoken".to_string(), "abc".to_string());
        assert_eq!(cookie_header_value(&jar), "csrftoken=abc; theme=dark");
        assert_eq!(cookie_header_value(&HashMap::new()), "");
    }

    #[test]
    fn normalize_url_accepts_absolute_urls() {
        assert_eq!(
            normalize_url("https://example.com/articles/?page=2"),
            "/articles/?page=2"
        );
    }
}