use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use axum::body::{to_bytes, Body};
use axum::http::{HeaderName, HeaderValue, Method, Request, StatusCode};
use axum::Router;
use tower::ServiceExt;
/// In-process HTTP client that drives an axum [`Router`] directly, without a
/// network socket, and keeps a cookie jar between requests.
///
/// The jar lives behind an `Arc`, so clones of a client share (and update) the
/// same set of cookies.
#[derive(Clone)]
pub struct TestClient {
    /// Cookies replayed on every request and updated from `Set-Cookie`
    /// response headers, keyed by cookie name.
    cookies: Arc<Mutex<HashMap<String, String>>>,
    /// The application under test; each request is served by a clone of it.
    router: Router,
}
impl TestClient {
    /// Wraps `router` in a client that starts with an empty cookie jar.
    #[must_use]
    pub fn new(router: Router) -> Self {
        Self {
            router,
            cookies: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns a snapshot copy of every cookie currently in the jar.
    ///
    /// # Panics
    ///
    /// Panics if the cookie-jar mutex is poisoned.
    #[must_use]
    pub fn cookies(&self) -> HashMap<String, String> {
        self.cookies.lock().expect("cookie jar poisoned").clone()
    }

    /// Returns the value of cookie `name`, if it is in the jar.
    ///
    /// # Panics
    ///
    /// Panics if the cookie-jar mutex is poisoned.
    #[must_use]
    pub fn cookie(&self, name: &str) -> Option<String> {
        self.cookies
            .lock()
            .expect("cookie jar poisoned")
            .get(name)
            .cloned()
    }

    /// Inserts or overwrites cookie `name`; it is sent on every later request.
    ///
    /// # Panics
    ///
    /// Panics if the cookie-jar mutex is poisoned.
    pub fn set_cookie(&self, name: impl Into<String>, value: impl Into<String>) {
        self.cookies
            .lock()
            .expect("cookie jar poisoned")
            .insert(name.into(), value.into());
    }

    /// Removes every cookie from the jar.
    ///
    /// # Panics
    ///
    /// Panics if the cookie-jar mutex is poisoned.
    pub fn clear_cookies(&self) {
        self.cookies.lock().expect("cookie jar poisoned").clear();
    }

    /// POSTs `fields` as an `application/x-www-form-urlencoded` body to `path`.
    ///
    /// Any `Set-Cookie` headers in the response are stored in the jar, exactly
    /// as for [`RequestBuilder::send`].
    pub async fn login(&self, path: impl Into<String>, fields: &[(&str, &str)]) -> TestResponse {
        self.post(path).form(fields).send().await
    }

    /// Mints a tenant-console session cookie for `user_id` in tenant `slug` and
    /// stores it in the jar, bypassing any login endpoint.
    ///
    /// `ttl_secs` is passed straight to `TenantSessionPayload::new`
    /// (presumably the session lifetime in seconds — confirm there).
    #[cfg(feature = "tenancy")]
    pub fn force_login_tenant_user(
        &self,
        secret: &crate::tenancy::session::SessionSecret,
        slug: impl Into<String>,
        user_id: i64,
        ttl_secs: i64,
    ) -> &Self {
        use crate::tenancy::tenant_console;
        let payload = tenant_console::TenantSessionPayload::new(user_id, slug, ttl_secs);
        let cookie = tenant_console::encode(secret, &payload);
        self.set_cookie(tenant_console::COOKIE_NAME, cookie);
        self
    }

    /// Mints an operator session cookie for `operator_id` and stores it in the
    /// jar, bypassing any login endpoint.
    ///
    /// `ttl_secs` is passed straight to `SessionPayload::new` (presumably the
    /// session lifetime in seconds — confirm there).
    #[cfg(feature = "tenancy")]
    pub fn force_login_operator(
        &self,
        secret: &crate::tenancy::session::SessionSecret,
        operator_id: i64,
        ttl_secs: i64,
    ) -> &Self {
        use crate::tenancy::session;
        let payload = session::SessionPayload::new(operator_id, ttl_secs);
        let cookie = session::encode(secret, &payload);
        self.set_cookie(session::COOKIE_NAME, cookie);
        self
    }

    /// Logs the client out.
    ///
    /// With `Some(path)`, POSTs to `path` *with the current cookies attached*,
    /// so the server can identify and invalidate the session. Afterwards every
    /// cookie that was in the jar before the call is discarded. Only cookies
    /// that the logout response newly set, or set to a different value, are
    /// kept. A cookie re-issued with an identical value is dropped as well.
    /// Returns the logout response.
    ///
    /// With `None`, only clears the local jar and returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the cookie-jar mutex is poisoned.
    pub async fn logout(&self, path: Option<&str>) -> Option<TestResponse> {
        let Some(path) = path else {
            self.clear_cookies();
            return None;
        };
        let before = self.cookies();
        let response = self.post(path).send().await;
        // Drop the pre-logout session state, but keep anything the logout
        // response itself issued (e.g. a fresh anonymous cookie).
        self.cookies
            .lock()
            .expect("cookie jar poisoned")
            .retain(|name, value| before.get(name) != Some(&*value));
        Some(response)
    }

    /// Starts a `GET` request to `path`.
    #[must_use]
    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::GET, path)
    }

    /// Issues a `GET` to `path` and follows `Location` headers on 3xx
    /// responses, for at most `max_hops` hops. Every hop is re-issued as `GET`.
    ///
    /// Returns the last response received together with the chain. The chain
    /// has one `(status, location)` entry per followed redirect, plus a final
    /// `(status, path)` entry for the path that produced the returned response.
    /// The chain stops early on a non-3xx status or on a 3xx without
    /// `Location`.
    pub async fn get_following_redirects(
        &self,
        path: impl Into<String>,
        max_hops: usize,
    ) -> (TestResponse, Vec<(u16, String)>) {
        let mut current_path: String = path.into();
        let mut chain: Vec<(u16, String)> = Vec::new();
        let mut last: TestResponse = self.get(current_path.clone()).send().await;
        for _ in 0..max_hops {
            let status = last.status;
            if !(300..400).contains(&status) {
                break;
            }
            let location = match last.header("location") {
                Some(loc) => loc.to_owned(),
                None => break,
            };
            chain.push((status, location.clone()));
            current_path = location;
            last = self.get(current_path.clone()).send().await;
        }
        chain.push((last.status, current_path));
        (last, chain)
    }

    /// Starts a `POST` request to `path`.
    #[must_use]
    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::POST, path)
    }

    /// Starts a `PUT` request to `path`.
    #[must_use]
    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PUT, path)
    }

    /// Starts a `PATCH` request to `path`.
    #[must_use]
    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PATCH, path)
    }

    /// Starts a `DELETE` request to `path`.
    #[must_use]
    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::DELETE, path)
    }

    /// Starts a `HEAD` request to `path`.
    #[must_use]
    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::HEAD, path)
    }

    /// Starts a request with an arbitrary `method`, an empty body, no extra
    /// headers and no content type.
    #[must_use]
    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder {
            client: self,
            method,
            path: path.into(),
            headers: Vec::new(),
            body: Body::empty(),
            content_type: None,
        }
    }
}
/// A request being assembled against a [`TestClient`]; created by
/// [`TestClient::request`] (or `get`/`post`/…) and executed with
/// [`RequestBuilder::send`].
pub struct RequestBuilder<'a> {
    /// Client whose router serves the request and whose jar supplies cookies.
    client: &'a TestClient,
    /// HTTP method of the request.
    method: Method,
    /// Request URI as given by the caller, e.g. `/users/1`.
    path: String,
    /// `content-type` value chosen by `json`/`form`, if any.
    content_type: Option<&'static str>,
    /// Extra headers added via `header`, sent after the built-in ones.
    headers: Vec<(HeaderName, HeaderValue)>,
    /// Request body; empty until `json`, `form` or `body` sets it.
    body: Body,
}
impl<'a> RequestBuilder<'a> {
    /// Appends an extra request header.
    ///
    /// A header whose name or value is not valid HTTP is silently skipped.
    /// Extra headers are added after the builder's own `content-type` and
    /// `cookie` headers, so they are sent in addition to them, not instead.
    #[must_use]
    pub fn header(mut self, name: &str, value: &str) -> Self {
        if let (Ok(n), Ok(v)) = (HeaderName::try_from(name), HeaderValue::try_from(value)) {
            self.headers.push((n, v));
        }
        self
    }

    /// Serializes `value` as the JSON request body and sets
    /// `content-type: application/json`.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized to JSON (for example a map with
    /// non-string keys). Sending an empty body instead would surface later as
    /// a confusing 4xx from the handler.
    #[must_use]
    pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
        let bytes = serde_json::to_vec(value)
            .unwrap_or_else(|err| panic!("failed to serialize JSON request body: {err}"));
        self.body = Body::from(bytes);
        self.content_type = Some("application/json");
        self
    }

    /// Encodes `fields` as `key=value` pairs joined by `&` (percent-encoding
    /// everything except unreserved characters) and sets
    /// `content-type: application/x-www-form-urlencoded`.
    #[must_use]
    pub fn form(mut self, fields: &[(&str, &str)]) -> Self {
        let body = fields
            .iter()
            .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
            .collect::<Vec<_>>()
            .join("&");
        self.body = Body::from(body);
        self.content_type = Some("application/x-www-form-urlencoded");
        self
    }

    /// Replaces the request body; does not touch the content type.
    #[must_use]
    pub fn body(mut self, body: impl Into<Body>) -> Self {
        self.body = body.into();
        self
    }

    /// Sends the request through the client's router and buffers the response.
    ///
    /// Every cookie in the jar is sent in a single `cookie` header. Afterwards
    /// each `Set-Cookie` response header updates the jar. A cookie whose value
    /// is empty, or whose `Max-Age` is zero or negative, is removed. Any other
    /// cookie is inserted or overwritten. `Expires` is not evaluated.
    ///
    /// # Panics
    ///
    /// Panics if the method/path/headers do not form a valid HTTP request, or
    /// if the cookie-jar mutex is poisoned.
    pub async fn send(self) -> TestResponse {
        let Self {
            client,
            method,
            path,
            headers,
            body,
            content_type,
        } = self;

        let mut req = Request::builder().method(&method).uri(&path);
        if let Some(ct) = content_type {
            req = req.header("content-type", ct);
        }
        {
            // Scoped so the (non-async) mutex guard is released before `.await`.
            let jar = client.cookies.lock().expect("cookie jar poisoned");
            if !jar.is_empty() {
                let cookie_header = jar
                    .iter()
                    .map(|(k, v)| format!("{k}={v}"))
                    .collect::<Vec<_>>()
                    .join("; ");
                req = req.header("cookie", cookie_header);
            }
        }
        for (name, value) in headers {
            req = req.header(name, value);
        }
        let req = req
            .body(body)
            .unwrap_or_else(|err| panic!("invalid test request `{method} {path}`: {err}"));

        let response = client
            .router
            .clone()
            .oneshot(req)
            .await
            .expect("test request panicked");

        let set_cookies: Vec<String> = response
            .headers()
            .get_all("set-cookie")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        {
            let mut jar = client.cookies.lock().expect("cookie jar poisoned");
            for raw in &set_cookies {
                let Some((name, value)) = parse_set_cookie(raw) else {
                    continue;
                };
                if value.is_empty() || Self::max_age_expires_now(raw) {
                    jar.remove(&name);
                } else {
                    jar.insert(name, value);
                }
            }
        }
        TestResponse::from_axum(response).await
    }

    /// Returns `true` when the raw `Set-Cookie` line carries a `Max-Age`
    /// attribute whose last parseable value is zero or negative. RFC 6265
    /// §5.2.2 says such a cookie expires immediately. Unparseable `Max-Age`
    /// values are ignored.
    fn max_age_expires_now(raw: &str) -> bool {
        raw.split(';')
            .skip(1) // the first segment is the name=value pair
            .filter_map(|attr| {
                let (key, value) = attr.split_once('=')?;
                if !key.trim().eq_ignore_ascii_case("max-age") {
                    return None;
                }
                value.trim().parse::<i64>().ok()
            })
            .last()
            .is_some_and(|secs| secs <= 0)
    }
}
/// Extracts the trimmed `(name, value)` pair from a raw `Set-Cookie` header.
///
/// Everything after the first `;` (the attributes) is ignored, and the value
/// may itself contain `=`. Returns `None` when the pair has no `=` or when the
/// name is empty; RFC 6265 §5.2 says such a set-cookie-string is ignored.
/// An empty value is returned as `Some((name, ""))` so callers can treat it
/// as a deletion.
fn parse_set_cookie(raw: &str) -> Option<(String, String)> {
    let pair = raw.split_once(';').map_or(raw, |(pair, _attributes)| pair);
    let (name, value) = pair.split_once('=')?;
    let name = name.trim();
    if name.is_empty() {
        return None;
    }
    Some((name.to_owned(), value.trim().to_owned()))
}
/// A fully buffered response returned by [`RequestBuilder::send`].
#[derive(Debug, Clone)]
pub struct TestResponse {
    /// Numeric HTTP status code, e.g. `200`.
    pub status: u16,
    /// Response headers keyed by header name; one value per name.
    pub headers: HashMap<String, String>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
impl TestResponse {
    /// Buffers an axum response into an owned [`TestResponse`].
    ///
    /// Header values that `HeaderValue::to_str` rejects are stored as `""`.
    /// The headers are collected into a map keyed by name, so only one value
    /// survives for a repeated header. The body is read up to 16 MiB. A read
    /// error, or a body over that limit, yields an empty `body`.
    async fn from_axum(response: axum::http::Response<Body>) -> Self {
        const MAX_BODY_BYTES: usize = 16 * 1024 * 1024;

        let (parts, body) = response.into_parts();
        let status = parts.status.as_u16();
        let headers: HashMap<String, String> = parts
            .headers
            .iter()
            .map(|(k, v)| (k.as_str().to_owned(), v.to_str().unwrap_or("").to_owned()))
            .collect();
        let body = to_bytes(body, MAX_BODY_BYTES)
            .await
            .unwrap_or_default()
            .to_vec();
        Self {
            status,
            headers,
            body,
        }
    }

    /// Returns `true` for a 2xx status. An out-of-range code yields `false`.
    #[must_use]
    pub fn is_success(&self) -> bool {
        StatusCode::from_u16(self.status).is_ok_and(|code| code.is_success())
    }

    /// Returns the body as a `String`, or an empty string if it is not valid
    /// UTF-8.
    #[must_use]
    pub fn text(&self) -> String {
        // Validate in place; only the successful case allocates.
        std::str::from_utf8(&self.body)
            .map(str::to_owned)
            .unwrap_or_default()
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Panics
    ///
    /// Panics, printing the serde error and the body text, if the body is not
    /// valid JSON for `T`.
    #[must_use]
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
        serde_json::from_slice(&self.body).unwrap_or_else(|e| {
            panic!(
                "response body is not valid JSON: {e}\nbody: {}",
                self.text()
            )
        })
    }

    /// Parses the body as untyped JSON, yielding `Value::Null` if it is not
    /// valid JSON.
    #[must_use]
    pub fn json_value(&self) -> serde_json::Value {
        serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
    }

    /// Looks up a header value by name, ignoring ASCII case.
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str())
    }
}
/// Percent-encodes `s` byte-by-byte, leaving only RFC 3986 unreserved
/// characters (`A-Z a-z 0-9 - _ . ~`) as-is.
///
/// Every other byte, including space and each byte of a multi-byte UTF-8
/// character, becomes `%XX` with uppercase hex. Space is therefore `%20`,
/// not `+`.
fn url_encode(s: &str) -> String {
    use std::fmt::Write as _;

    // Single buffer instead of one temporary String per input byte.
    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(b));
        } else {
            // Formatting into a String cannot fail.
            let _ = write!(out, "%{b:02X}");
        }
    }
    out
}
// Unit tests for TestClient, RequestBuilder and TestResponse, each run against
// a small in-process axum Router defined below.
#[cfg(test)]
mod tests {
use super::*;
use axum::routing::{get, post};
use serde_json::json;
// Router for the basic request/response tests: plain text, body echo, JSON
// echo, a status code taken from the path, and an echo of the x-custom header.
fn app() -> Router {
Router::new()
.route("/hello", get(|| async { "hi" }))
.route("/echo", post(|body: String| async move { body }))
.route(
"/json",
post(|body: axum::Json<serde_json::Value>| async move {
axum::Json(json!({"received": body.0}))
}),
)
.route(
"/status/{code}",
get(
|axum::extract::Path(code): axum::extract::Path<u16>| async move {
axum::http::StatusCode::from_u16(code).unwrap_or(axum::http::StatusCode::OK)
},
),
)
.route(
"/header_check",
get(|h: axum::http::HeaderMap| async move {
h.get("x-custom")
.map_or("missing".to_owned(), |v| v.to_str().unwrap().to_owned())
}),
)
}
#[tokio::test]
async fn get_returns_text() {
let c = TestClient::new(app());
let r = c.get("/hello").send().await;
assert_eq!(r.status, 200);
assert_eq!(r.text(), "hi");
assert!(r.is_success());
}
#[tokio::test]
async fn post_with_text_body_echos() {
let c = TestClient::new(app());
let r = c.post("/echo").body("hello world").send().await;
assert_eq!(r.status, 200);
assert_eq!(r.text(), "hello world");
}
#[tokio::test]
async fn post_json_body_returns_json() {
let c = TestClient::new(app());
let r = c.post("/json").json(&json!({"a": 1})).send().await;
assert_eq!(r.status, 200);
let v = r.json_value();
assert_eq!(v["received"]["a"], 1);
}
#[tokio::test]
async fn header_round_trip() {
let c = TestClient::new(app());
let r = c
.get("/header_check")
.header("x-custom", "value42")
.send()
.await;
assert_eq!(r.text(), "value42");
}
#[tokio::test]
async fn status_path_param() {
let c = TestClient::new(app());
assert_eq!(c.get("/status/200").send().await.status, 200);
assert_eq!(c.get("/status/404").send().await.status, 404);
assert_eq!(c.get("/status/500").send().await.status, 500);
}
// The router is cloned per request, so one client can send many requests.
#[tokio::test]
async fn test_client_is_reusable() {
let c = TestClient::new(app());
for _ in 0..3 {
assert_eq!(c.get("/hello").send().await.status, 200);
}
}
#[tokio::test]
async fn header_lookup_case_insensitive() {
let c = TestClient::new(app());
let r = c.get("/hello").send().await;
assert!(r.header("Content-Type").is_some() || r.header("content-type").is_some());
}
// The /echo handler returns the raw body, so the encoded form is visible.
#[tokio::test]
async fn form_body_encodes_correctly() {
let c = TestClient::new(app());
let r = c
.post("/echo")
.form(&[("name", "alice & bob"), ("age", "30")])
.send()
.await;
let text = r.text();
assert!(text.contains("name=alice%20%26%20bob"));
assert!(text.contains("age=30"));
}
// Cookie fixture: /login sets two cookies via separate Set-Cookie headers,
// /me echoes the raw Cookie request header (or "(no cookies)"), and /logout
// re-sets `session` with an empty value and Max-Age=0.
fn cookie_app() -> Router {
use axum::http::{header, HeaderMap, HeaderValue};
use axum::response::IntoResponse;
async fn login() -> impl IntoResponse {
let mut h = HeaderMap::new();
h.append(
header::SET_COOKIE,
HeaderValue::from_static("session=abc123; Path=/; HttpOnly"),
);
h.append(
header::SET_COOKIE,
HeaderValue::from_static("csrftoken=xyz; Path=/"),
);
(h, "ok")
}
async fn whoami(h: HeaderMap) -> String {
h.get("cookie")
.and_then(|v| v.to_str().ok())
.unwrap_or("(no cookies)")
.to_owned()
}
async fn logout() -> impl IntoResponse {
let mut h = HeaderMap::new();
h.insert(
header::SET_COOKIE,
HeaderValue::from_static("session=; Path=/; Max-Age=0"),
);
(h, "bye")
}
Router::new()
.route("/login", post(login))
.route("/me", get(whoami))
.route("/logout", post(logout))
}
#[tokio::test]
async fn set_cookie_persists_into_jar() {
let c = TestClient::new(cookie_app());
c.post("/login").send().await;
let jar = c.cookies();
assert_eq!(jar.get("session").map(String::as_str), Some("abc123"));
assert_eq!(jar.get("csrftoken").map(String::as_str), Some("xyz"));
}
#[tokio::test]
async fn jar_replays_on_subsequent_request() {
let c = TestClient::new(cookie_app());
c.post("/login").send().await;
let echoed = c.get("/me").send().await.text();
assert!(echoed.contains("session=abc123"), "echoed: {echoed}");
assert!(echoed.contains("csrftoken=xyz"), "echoed: {echoed}");
}
#[tokio::test]
async fn clear_cookies_drops_jar() {
let c = TestClient::new(cookie_app());
c.post("/login").send().await;
assert!(!c.cookies().is_empty());
c.clear_cookies();
assert!(c.cookies().is_empty());
let echoed = c.get("/me").send().await.text();
assert!(echoed.contains("(no cookies)"), "echoed: {echoed}");
}
// Only the re-set cookie is removed; unrelated cookies stay in the jar.
#[tokio::test]
async fn empty_cookie_value_is_treated_as_deletion() {
let c = TestClient::new(cookie_app());
c.post("/login").send().await;
assert!(c.cookie("session").is_some());
c.post("/logout").send().await;
assert!(
c.cookie("session").is_none(),
"logout's Set-Cookie: session=; Max-Age=0 should delete the cookie"
);
assert_eq!(c.cookie("csrftoken").as_deref(), Some("xyz"));
}
#[tokio::test]
async fn set_cookie_manual_injection() {
let c = TestClient::new(cookie_app());
c.set_cookie("session", "manual-value");
let echoed = c.get("/me").send().await.text();
assert!(echoed.contains("session=manual-value"), "echoed: {echoed}");
}
#[tokio::test]
async fn login_helper_returns_response_and_persists_cookies() {
let c = TestClient::new(cookie_app());
let r = c.login("/login", &[]).await;
assert_eq!(r.status, 200);
assert!(c.cookie("session").is_some());
}
#[tokio::test]
async fn logout_with_path_clears_jar_and_hits_endpoint() {
let c = TestClient::new(cookie_app());
c.login("/login", &[]).await;
assert!(c.cookie("session").is_some());
let r = c.logout(Some("/logout")).await;
assert_eq!(r.expect("response").status, 200);
assert!(c.cookie("session").is_none());
}
#[tokio::test]
async fn logout_without_path_only_clears_locally() {
let c = TestClient::new(cookie_app());
c.login("/login", &[]).await;
let r = c.logout(None).await;
assert!(r.is_none());
assert!(c.cookies().is_empty());
}
#[tokio::test]
async fn no_cookie_header_emitted_when_jar_empty() {
let c = TestClient::new(cookie_app());
let echoed = c.get("/me").send().await.text();
assert_eq!(echoed, "(no cookies)");
}
#[test]
fn parse_set_cookie_handles_attributes() {
assert_eq!(
parse_set_cookie("session=abc; Path=/; HttpOnly").unwrap(),
("session".to_owned(), "abc".to_owned())
);
assert_eq!(
parse_set_cookie("foo=bar").unwrap(),
("foo".to_owned(), "bar".to_owned())
);
assert_eq!(
parse_set_cookie("expired=; Max-Age=0").unwrap(),
("expired".to_owned(), String::new())
);
assert!(parse_set_cookie("no-equals-sign").is_none());
}
// Redirect fixture: /old -302-> /middle -301-> /new (200 "final");
// /loop redirects to itself; /dangling is a 302 with no Location header;
// /direct answers 200 immediately.
fn redirect_app() -> Router {
use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
use axum::response::IntoResponse;
async fn old() -> impl IntoResponse {
let mut h = HeaderMap::new();
h.insert(header::LOCATION, HeaderValue::from_static("/middle"));
(StatusCode::FOUND, h, "")
}
async fn middle() -> impl IntoResponse {
let mut h = HeaderMap::new();
h.insert(header::LOCATION, HeaderValue::from_static("/new"));
(StatusCode::MOVED_PERMANENTLY, h, "")
}
async fn new_handler() -> impl IntoResponse {
(StatusCode::OK, "final")
}
async fn loops() -> impl IntoResponse {
let mut h = HeaderMap::new();
h.insert(header::LOCATION, HeaderValue::from_static("/loop"));
(StatusCode::FOUND, h, "")
}
async fn redirect_no_location() -> impl IntoResponse {
(StatusCode::FOUND, "")
}
Router::new()
.route("/old", get(old))
.route("/middle", get(middle))
.route("/new", get(new_handler))
.route("/loop", get(loops))
.route("/dangling", get(redirect_no_location))
.route("/direct", get(|| async { "hi" }))
}
// Two followed hops plus the trailing (final status, final path) entry.
#[tokio::test]
async fn follows_two_hop_chain_to_final_200() {
let c = TestClient::new(redirect_app());
let (final_res, chain) = c.get_following_redirects("/old", 5).await;
assert_eq!(final_res.status, 200);
assert_eq!(final_res.text(), "final");
assert_eq!(chain.len(), 3);
assert_eq!(chain[0], (302, "/middle".to_owned()));
assert_eq!(chain[1], (301, "/new".to_owned()));
assert_eq!(chain[2].0, 200);
assert_eq!(chain[2].1, "/new");
}
#[tokio::test]
async fn follow_no_op_when_first_response_is_200() {
let c = TestClient::new(redirect_app());
let (res, chain) = c.get_following_redirects("/direct", 5).await;
assert_eq!(res.status, 200);
assert_eq!(chain.len(), 1);
assert_eq!(chain[0].0, 200);
}
// max_hops = 3 followed hops plus the trailing entry; the returned response
// is still a 302 because /loop never terminates.
#[tokio::test]
async fn follow_stops_at_max_hops() {
let c = TestClient::new(redirect_app());
let (res, chain) = c.get_following_redirects("/loop", 3).await;
assert_eq!(res.status, 302);
assert_eq!(chain.len(), 4);
for hop in &chain[..3] {
assert_eq!(hop.0, 302);
assert_eq!(hop.1, "/loop");
}
}
#[tokio::test]
async fn follow_stops_when_3xx_has_no_location() {
let c = TestClient::new(redirect_app());
let (res, chain) = c.get_following_redirects("/dangling", 5).await;
assert_eq!(res.status, 302);
assert_eq!(chain.len(), 1);
}
#[tokio::test]
async fn follow_max_hops_zero_returns_first_response() {
let c = TestClient::new(redirect_app());
let (res, chain) = c.get_following_redirects("/old", 0).await;
assert_eq!(res.status, 302);
assert_eq!(chain.len(), 1);
assert_eq!(chain[0].1, "/old");
}
// The remaining tests require the `tenancy` feature. They check that the
// force_login_* helpers store cookies that the tenancy decode functions accept,
// and that decoding fails for a different slug or a different secret.
#[cfg(feature = "tenancy")]
#[tokio::test]
async fn force_login_tenant_user_writes_decodable_cookie() {
use crate::tenancy::session::SessionSecret;
use crate::tenancy::tenant_console::{decode, COOKIE_NAME};
let secret = SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec());
let c = TestClient::new(Router::new());
c.force_login_tenant_user(&secret, "acme", 42, 3600);
let cookie = c.cookie(COOKIE_NAME).expect("session cookie present");
let payload = decode(&secret, "acme", &cookie).expect("cookie decodes");
assert_eq!(payload.uid, 42);
assert_eq!(payload.slug, "acme");
assert!(!payload.is_impersonation());
}
#[cfg(feature = "tenancy")]
#[tokio::test]
async fn force_login_tenant_user_rejects_wrong_slug() {
use crate::tenancy::session::SessionSecret;
use crate::tenancy::tenant_console::decode;
let secret = SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec());
let c = TestClient::new(Router::new());
c.force_login_tenant_user(&secret, "acme", 42, 3600);
let cookie = c
.cookie(crate::tenancy::tenant_console::COOKIE_NAME)
.unwrap();
assert!(decode(&secret, "globex", &cookie).is_err());
}
#[cfg(feature = "tenancy")]
#[tokio::test]
async fn force_login_operator_writes_decodable_cookie() {
use crate::tenancy::session::{decode, SessionSecret, COOKIE_NAME};
let secret = SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec());
let c = TestClient::new(Router::new());
c.force_login_operator(&secret, 7, 3600);
let cookie = c.cookie(COOKIE_NAME).expect("operator cookie present");
let payload = decode(&secret, &cookie).expect("cookie decodes");
assert_eq!(payload.oid, 7);
}
#[cfg(feature = "tenancy")]
#[tokio::test]
async fn force_login_bad_secret_does_not_validate() {
use crate::tenancy::session::SessionSecret;
use crate::tenancy::tenant_console::decode;
let mint_secret = SessionSecret::from_bytes(b"mint-secret-thirty-two-bytes-xxx".to_vec());
let wrong_secret = SessionSecret::from_bytes(b"wrong-secret-thirty-two-bytes-xx".to_vec());
let c = TestClient::new(Router::new());
c.force_login_tenant_user(&mint_secret, "acme", 1, 3600);
let cookie = c
.cookie(crate::tenancy::tenant_console::COOKIE_NAME)
.unwrap();
assert!(decode(&wrong_secret, "acme", &cookie).is_err());
}
}