use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, LazyLock, Mutex};
use axum::{
Router,
body::{Body, to_bytes},
http::{HeaderMap, Method, Request, StatusCode, Uri, header},
};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tower::ServiceExt;
#[cfg(feature = "test-auth-bypass")]
use crate::auth::extractors::{TEST_CLAIMS_HEADER, TEST_USER_ID_HEADER, encode_test_claims_header};
use crate::{App, Config, ConfigBuilder};
/// Boxed `Send` future returned by the asynchronous hook variants.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Hook that may mutate each request before it is dispatched to the router.
type BeforeEachHook = Arc<dyn for<'a> Fn(&'a mut Request<Body>) -> BoxFuture<'a, ()> + Send + Sync>;
/// Hook that observes each captured [`ScenarioOutcome`]. It runs after the response body
/// has been read and before the scenario's status check and assertions are evaluated.
type AfterEachHook = Arc<dyn for<'a> Fn(&'a ScenarioOutcome) -> BoxFuture<'a, ()> + Send + Sync>;
/// Deferred check on a completed scenario; `Err` carries a human-readable failure message.
type ScenarioAssertion = Box<dyn Fn(&ScenarioOutcome) -> Result<(), String> + Send + Sync>;
/// Deferred `App` transformation, applied in registration order by `TestHostBootstrap::try_build`.
type AppTransform = Box<dyn FnOnce(App) -> App + Send>;
/// Deferred `Config` transformation, applied in registration order by `TestHostBootstrap::try_build`.
type ConfigTransform = Box<dyn FnOnce(Config) -> Config + Send>;
/// Serializes `TestHostBootstrap::try_build` calls, which temporarily mutate the process
/// environment. In this file it is only acquired by `try_build`; environment access made
/// elsewhere is not serialized by it.
static TEST_HOST_ENV_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Origin of a bootstrap's configuration.
enum ConfigSource {
    /// A fully built `Config`. It is converted back into a `ConfigBuilder` (via
    /// `ConfigBuilder::from_config`) when it is customised or finalized.
    Built(Config),
    /// A builder that is finalized during `TestHostBootstrap::try_build`.
    Builder(ConfigBuilder),
}
/// One environment-variable override recorded by a `TestHostBootstrap`.
struct EnvOverride {
    /// Variable name.
    key: String,
    /// Value to set, or `None` to remove the variable while the override is active.
    value: Option<String>,
}
/// RAII guard that applies a set of environment-variable overrides and restores the
/// previous values when dropped.
///
/// Previous values are captured with [`std::env::var_os`]. Variables whose current value
/// is not valid UTF-8 are therefore restored verbatim instead of being deleted.
struct ScopedEnvOverrides {
    /// `(key, value before the override)` for every override that was actually applied,
    /// in application order. `None` means the variable was unset beforehand.
    previous: Vec<(String, Option<std::ffi::OsString>)>,
}

impl ScopedEnvOverrides {
    /// Applies `overrides` in order and returns a guard that undoes them on drop.
    ///
    /// # Panics
    ///
    /// May panic if the platform rejects a key or value (`std::env::set_var` and
    /// `std::env::remove_var` document panics for keys that are empty or contain `=` or
    /// NUL, and for values containing NUL). The guard is created before any mutation, so
    /// overrides applied before the failing one are still restored while the panic
    /// unwinds.
    fn apply(overrides: &[EnvOverride]) -> Self {
        let mut guard = Self {
            previous: Vec::with_capacity(overrides.len()),
        };
        for override_ in overrides {
            let previous = std::env::var_os(&override_.key);
            // SAFETY: mutating the process environment is only sound while no other thread
            // reads or writes it concurrently. `TestHostBootstrap::try_build` holds
            // `TEST_HOST_ENV_LOCK` around this guard, which serializes test-host builds.
            // NOTE(review): env access outside that lock is not covered — confirm tests do
            // not touch the environment concurrently.
            unsafe {
                match &override_.value {
                    Some(value) => std::env::set_var(&override_.key, value),
                    None => std::env::remove_var(&override_.key),
                }
            }
            // Record only after the mutation succeeded. `Drop` then never replays an
            // operation on a key the platform rejected, which would panic again mid-unwind.
            guard.previous.push((override_.key.clone(), previous));
        }
        guard
    }
}

impl Drop for ScopedEnvOverrides {
    /// Restores the recorded values in reverse application order, so a key overridden
    /// more than once ends up with its original value.
    fn drop(&mut self) {
        for (key, previous) in self.previous.drain(..).rev() {
            // SAFETY: same contract as in `apply`. In `try_build` this guard is declared
            // after the lock guard, so it is dropped while the lock is still held.
            unsafe {
                match previous {
                    Some(value) => std::env::set_var(&key, value),
                    None => std::env::remove_var(&key),
                }
            }
        }
    }
}
fn build_host(
app: App,
with_middleware: bool,
before_each: Option<BeforeEachHook>,
after_each: Option<AfterEachHook>,
) -> TestHost {
let mut host = if with_middleware {
TestHost::new(app)
} else {
TestHost::without_middleware(app)
};
host.before_each = before_each;
host.after_each = after_each;
host
}
/// Deferred recipe for a [`TestHost`]. It collects configuration, app, and environment
/// customisations and applies them when built.
pub struct TestHostBootstrap {
    /// Starting configuration: an already-built `Config` or a `ConfigBuilder`.
    config_source: ConfigSource,
    /// Applied in registration order to the built `Config`.
    config_transforms: Vec<ConfigTransform>,
    /// Applied in registration order to the `App` created from the final config.
    app_transforms: Vec<AppTransform>,
    /// Environment variables set or removed only while the host is being built.
    env_overrides: Vec<EnvOverride>,
    /// Whether the config builder's `from_env()` is called before building.
    load_from_env: bool,
    /// Whether the host's router includes the app's middleware stack.
    with_middleware: bool,
    /// Hook installed on the resulting host, run before every request.
    before_each: Option<BeforeEachHook>,
    /// Hook installed on the resulting host, run after every captured outcome.
    after_each: Option<AfterEachHook>,
}
/// Builder for a [`TestHost`] around an already-constructed `App`.
pub struct TestHostBuilder {
    /// The application to host.
    app: App,
    /// Whether the host's router includes the app's middleware stack.
    with_middleware: bool,
    /// Hook installed on the resulting host, run before every request.
    before_each: Option<BeforeEachHook>,
    /// Hook installed on the resulting host, run after every captured outcome.
    after_each: Option<AfterEachHook>,
}
/// In-process HTTP host. Each scenario sends one request to a clone of the router via
/// `tower::ServiceExt::oneshot`.
pub struct TestHost {
    /// Router that receives scenario requests; cloned for every scenario.
    router: Router,
    /// Optional hook that may mutate each request before dispatch.
    before_each: Option<BeforeEachHook>,
    /// Optional hook that observes each outcome before assertions run.
    after_each: Option<AfterEachHook>,
}
impl TestHost {
pub fn bootstrap() -> TestHostBootstrap {
TestHostBootstrap::new()
}
pub fn from_config(config: Config) -> TestHostBootstrap {
TestHostBootstrap::from_config(config)
}
pub fn from_config_builder(builder: ConfigBuilder) -> TestHostBootstrap {
TestHostBootstrap::from_config_builder(builder)
}
pub fn builder(app: App) -> TestHostBuilder {
TestHostBuilder::new(app)
}
pub fn new(app: App) -> Self {
Self::from_router(app.into_router_with_middleware())
}
pub fn without_middleware(app: App) -> Self {
Self::from_router(app.into_router())
}
pub fn from_router(router: Router) -> Self {
Self {
router,
before_each: None,
after_each: None,
}
}
pub fn before_each<F>(mut self, hook: F) -> Self
where
F: Fn(&mut Request<Body>) + Send + Sync + 'static,
{
self.before_each = Some(Arc::new(move |request| {
hook(request);
Box::pin(async {})
}));
self
}
pub fn before_each_async<F>(mut self, hook: F) -> Self
where
F: for<'a> Fn(&'a mut Request<Body>) -> BoxFuture<'a, ()> + Send + Sync + 'static,
{
self.before_each = Some(Arc::new(hook));
self
}
pub fn after_each<F>(mut self, hook: F) -> Self
where
F: Fn(&ScenarioOutcome) + Send + Sync + 'static,
{
self.after_each = Some(Arc::new(move |outcome| {
hook(outcome);
Box::pin(async {})
}));
self
}
pub fn after_each_async<F>(mut self, hook: F) -> Self
where
F: for<'a> Fn(&'a ScenarioOutcome) -> BoxFuture<'a, ()> + Send + Sync + 'static,
{
self.after_each = Some(Arc::new(hook));
self
}
pub async fn scenario<F>(&self, configure: F) -> ScenarioOutcome
where
F: FnOnce(&mut HostScenario),
{
self.try_scenario(configure)
.await
.unwrap_or_else(|error| panic!("{error}"))
}
pub async fn try_scenario<F>(&self, configure: F) -> Result<ScenarioOutcome, ScenarioFailure>
where
F: FnOnce(&mut HostScenario),
{
let mut scenario = HostScenario::new();
configure(&mut scenario);
let HostScenario {
request,
expected_status,
ignore_status_code,
assertions,
} = scenario;
let mut request = request;
if let Some(hook) = &self.before_each {
hook(&mut request).await;
}
let request_summary = RequestSummary::from_request(&request);
let response = self
.router
.clone()
.oneshot(request)
.await
.expect("test host request should succeed");
let outcome = ScenarioOutcome::from_response(request_summary.clone(), response).await;
if let Some(hook) = &self.after_each {
hook(&outcome).await;
}
let mut failures = Vec::new();
if !ignore_status_code {
let expected_status = expected_status.unwrap_or(StatusCode::OK);
if outcome.status() != expected_status {
failures.push(format!(
"Expected status {}, got {}",
expected_status,
outcome.status()
));
}
}
for assertion in assertions {
if let Err(message) = assertion(&outcome) {
failures.push(message);
}
}
if failures.is_empty() {
Ok(outcome)
} else {
Err(ScenarioFailure {
request: request_summary,
failures,
})
}
}
}
impl TestHostBootstrap {
    /// Creates a bootstrap with these defaults: it starts from `ConfigBuilder::new()`,
    /// includes middleware, does not read configuration from the environment, and has no
    /// hooks, transforms, or environment overrides.
    pub fn new() -> Self {
        Self {
            config_source: ConfigSource::Builder(ConfigBuilder::new()),
            config_transforms: Vec::new(),
            app_transforms: Vec::new(),
            env_overrides: Vec::new(),
            load_from_env: false,
            with_middleware: true,
            before_each: None,
            after_each: None,
        }
    }

    /// Creates a bootstrap that starts from an already-built `config`.
    pub fn from_config(config: Config) -> Self {
        Self {
            config_source: ConfigSource::Built(config),
            ..Self::new()
        }
    }

    /// Creates a bootstrap that starts from `builder`.
    pub fn from_config_builder(builder: ConfigBuilder) -> Self {
        Self {
            config_source: ConfigSource::Builder(builder),
            ..Self::new()
        }
    }

    /// Registers a transformation of the built `Config`.
    ///
    /// Transforms run in registration order inside [`try_build`](Self::try_build), after
    /// the builder has been finalized.
    pub fn configure_config<F>(mut self, configure: F) -> Self
    where
        F: FnOnce(Config) -> Config + Send + 'static,
    {
        self.config_transforms.push(Box::new(configure));
        self
    }

    /// Immediately applies `configure` to the underlying `ConfigBuilder`.
    ///
    /// A previously supplied built `Config` is first converted with
    /// `ConfigBuilder::from_config`.
    pub fn configure_config_builder<F>(mut self, configure: F) -> Self
    where
        F: FnOnce(ConfigBuilder) -> ConfigBuilder,
    {
        let builder = match self.config_source {
            ConfigSource::Built(config) => ConfigBuilder::from_config(config),
            ConfigSource::Builder(builder) => builder,
        };
        self.config_source = ConfigSource::Builder(configure(builder));
        self
    }

    /// Makes [`try_build`](Self::try_build) call `from_env()` on the config builder
    /// before building it.
    pub fn from_env(mut self) -> Self {
        self.load_from_env = true;
        self
    }

    /// Registers a transformation of the `App`.
    ///
    /// Transforms run in registration order after the `App` has been created from the
    /// final config.
    pub fn configure_app<F>(mut self, configure: F) -> Self
    where
        F: FnOnce(App) -> App + Send + 'static,
    {
        self.app_transforms.push(Box::new(configure));
        self
    }

    /// Registers a transformation of the app context (applied through `App::map_context`
    /// as an app transform).
    pub fn configure_context<F>(self, configure: F) -> Self
    where
        F: FnOnce(crate::app::AppContextBuilder) -> crate::app::AppContextBuilder + Send + 'static,
    {
        self.configure_app(move |app| app.map_context(configure))
    }

    /// Sets `key` to `value` while [`try_build`](Self::try_build) runs.
    ///
    /// Repeating a key replaces that key's earlier override.
    pub fn with_env_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.set_env_override(key.into(), Some(value.into()));
        self
    }

    /// Removes `key` from the environment while [`try_build`](Self::try_build) runs.
    ///
    /// Repeating a key replaces that key's earlier override.
    pub fn without_env_var(mut self, key: impl Into<String>) -> Self {
        self.set_env_override(key.into(), None);
        self
    }

    /// Builds the host without the app's middleware stack.
    pub fn without_middleware(mut self) -> Self {
        self.with_middleware = false;
        self
    }

    /// Builds the host with the app's middleware stack (the default).
    pub fn with_middleware(mut self) -> Self {
        self.with_middleware = true;
        self
    }

    /// Installs a synchronous before-each hook on the resulting host, replacing any
    /// previous one.
    pub fn before_each<F>(mut self, hook: F) -> Self
    where
        F: Fn(&mut Request<Body>) + Send + Sync + 'static,
    {
        self.before_each = Some(Arc::new(move |request| {
            hook(request);
            Box::pin(async {})
        }));
        self
    }

    /// Installs an asynchronous before-each hook on the resulting host, replacing any
    /// previous one.
    pub fn before_each_async<F>(mut self, hook: F) -> Self
    where
        F: for<'a> Fn(&'a mut Request<Body>) -> BoxFuture<'a, ()> + Send + Sync + 'static,
    {
        self.before_each = Some(Arc::new(hook));
        self
    }

    /// Installs a synchronous after-each hook on the resulting host, replacing any
    /// previous one.
    pub fn after_each<F>(mut self, hook: F) -> Self
    where
        F: Fn(&ScenarioOutcome) + Send + Sync + 'static,
    {
        self.after_each = Some(Arc::new(move |outcome| {
            hook(outcome);
            Box::pin(async {})
        }));
        self
    }

    /// Installs an asynchronous after-each hook on the resulting host, replacing any
    /// previous one.
    pub fn after_each_async<F>(mut self, hook: F) -> Self
    where
        F: for<'a> Fn(&'a ScenarioOutcome) -> BoxFuture<'a, ()> + Send + Sync + 'static,
    {
        self.after_each = Some(Arc::new(hook));
        self
    }

    /// Builds the host.
    ///
    /// # Panics
    ///
    /// Panics if [`try_build`](Self::try_build) returns an error.
    pub fn build(self) -> TestHost {
        self.try_build()
            .expect("test host bootstrap should build successfully")
    }

    /// Builds the host, returning configuration errors instead of panicking.
    ///
    /// Steps, in order:
    /// 1. Acquire `TEST_HOST_ENV_LOCK` and apply the environment overrides.
    /// 2. Finalize the config, calling `from_env()` first if requested.
    /// 3. Apply config transforms, then pass the result through
    ///    `ConfigBuilder::from_config(..).build()` again.
    /// 4. Create the `App`, apply app transforms, and wrap it in a [`TestHost`].
    ///
    /// Environment overrides are restored before this method returns. They are in effect
    /// for config building and for the transforms, but not for requests made later
    /// through the returned host.
    ///
    /// # Errors
    ///
    /// Returns the error from either config `build()` call.
    pub fn try_build(self) -> crate::Result<TestHost> {
        // The mutex guards `()`, so a poisoned lock carries no broken invariant. Recover
        // it instead of failing every later build because one earlier build panicked.
        let _lock = TEST_HOST_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Declared after `_lock`, so the environment is restored before the lock is
        // released (locals drop in reverse declaration order).
        let _env = ScopedEnvOverrides::apply(&self.env_overrides);

        let builder = match self.config_source {
            ConfigSource::Built(config) => ConfigBuilder::from_config(config),
            ConfigSource::Builder(builder) => builder,
        };
        let mut config = if self.load_from_env {
            builder.from_env().build()?
        } else {
            builder.build()?
        };
        for transform in self.config_transforms {
            config = transform(config);
        }
        // Run the transformed config through the builder again so that transforms go
        // through the same `build()` step as the initial config.
        let config = ConfigBuilder::from_config(config).build()?;

        let mut app = App::with_config(config);
        for transform in self.app_transforms {
            app = transform(app);
        }
        Ok(build_host(
            app,
            self.with_middleware,
            self.before_each,
            self.after_each,
        ))
    }

    /// Records an override for `key`, replacing any earlier override for the same key so
    /// that each key is applied at most once.
    fn set_env_override(&mut self, key: String, value: Option<String>) {
        if let Some(existing) = self
            .env_overrides
            .iter_mut()
            .find(|existing| existing.key == key)
        {
            existing.value = value;
        } else {
            self.env_overrides.push(EnvOverride { key, value });
        }
    }
}
impl TestHostBuilder {
    /// Wraps `app`, with middleware enabled and no hooks installed.
    pub fn new(app: App) -> Self {
        Self {
            app,
            with_middleware: true,
            before_each: None,
            after_each: None,
        }
    }

    /// Transforms the wrapped `App` immediately.
    pub fn configure_app<F>(self, configure: F) -> Self
    where
        F: FnOnce(App) -> App,
    {
        let Self {
            app,
            with_middleware,
            before_each,
            after_each,
        } = self;
        Self {
            app: configure(app),
            with_middleware,
            before_each,
            after_each,
        }
    }

    /// Transforms the app context immediately, through `App::map_context`.
    pub fn configure_context<F>(self, configure: F) -> Self
    where
        F: FnOnce(crate::app::AppContextBuilder) -> crate::app::AppContextBuilder,
    {
        self.configure_app(move |app| app.map_context(configure))
    }

    /// Builds the host without the app's middleware stack.
    pub fn without_middleware(self) -> Self {
        Self {
            with_middleware: false,
            ..self
        }
    }

    /// Builds the host with the app's middleware stack (the default).
    pub fn with_middleware(self) -> Self {
        Self {
            with_middleware: true,
            ..self
        }
    }

    /// Installs a synchronous before-each hook, replacing any previous one.
    pub fn before_each<F>(self, hook: F) -> Self
    where
        F: Fn(&mut Request<Body>) + Send + Sync + 'static,
    {
        self.before_each_async(move |request| {
            hook(request);
            Box::pin(async {})
        })
    }

    /// Installs an asynchronous before-each hook, replacing any previous one.
    pub fn before_each_async<F>(self, hook: F) -> Self
    where
        F: for<'a> Fn(&'a mut Request<Body>) -> BoxFuture<'a, ()> + Send + Sync + 'static,
    {
        Self {
            before_each: Some(Arc::new(hook)),
            ..self
        }
    }

    /// Installs a synchronous after-each hook, replacing any previous one.
    pub fn after_each<F>(self, hook: F) -> Self
    where
        F: Fn(&ScenarioOutcome) + Send + Sync + 'static,
    {
        self.after_each_async(move |outcome| {
            hook(outcome);
            Box::pin(async {})
        })
    }

    /// Installs an asynchronous after-each hook, replacing any previous one.
    pub fn after_each_async<F>(self, hook: F) -> Self
    where
        F: for<'a> Fn(&'a ScenarioOutcome) -> BoxFuture<'a, ()> + Send + Sync + 'static,
    {
        Self {
            after_each: Some(Arc::new(hook)),
            ..self
        }
    }

    /// Produces the [`TestHost`].
    pub fn build(self) -> TestHost {
        let Self {
            app,
            with_middleware,
            before_each,
            after_each,
        } = self;
        build_host(app, with_middleware, before_each, after_each)
    }
}
/// Mutable description of one scenario. It is filled in by the `configure` closure passed
/// to `TestHost::scenario` / `TestHost::try_scenario`.
pub struct HostScenario {
    /// Request sent to the router; starts as `GET /` with an empty body.
    request: Request<Body>,
    /// Expected response status; `None` means `200 OK` unless the check is ignored.
    expected_status: Option<StatusCode>,
    /// When true, the status check is skipped entirely.
    ignore_status_code: bool,
    /// Custom checks, evaluated in insertion order after the response is captured.
    assertions: Vec<ScenarioAssertion>,
}
impl HostScenario {
fn new() -> Self {
Self {
request: Request::builder()
.method(Method::GET)
.uri("/")
.body(Body::empty())
.unwrap(),
expected_status: None,
ignore_status_code: false,
assertions: Vec::new(),
}
}
pub fn method(&mut self, method: Method) -> &mut Self {
*self.request.method_mut() = method;
self
}
pub fn uri(&mut self, uri: &str) -> &mut Self {
*self.request.uri_mut() = uri.parse().unwrap();
self
}
pub fn get(&mut self, uri: &str) -> &mut Self {
self.method(Method::GET).uri(uri)
}
pub fn post(&mut self, uri: &str) -> &mut Self {
self.method(Method::POST).uri(uri)
}
pub fn put(&mut self, uri: &str) -> &mut Self {
self.method(Method::PUT).uri(uri)
}
pub fn delete(&mut self, uri: &str) -> &mut Self {
self.method(Method::DELETE).uri(uri)
}
pub fn patch(&mut self, uri: &str) -> &mut Self {
self.method(Method::PATCH).uri(uri)
}
pub fn header(&mut self, key: &str, value: &str) -> &mut Self {
use axum::http::HeaderName;
self.request.headers_mut().insert(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
value.parse().unwrap(),
);
self
}
pub fn with_request_header(&mut self, key: &str, value: &str) -> &mut Self {
self.header(key, value)
}
pub fn with_header(&mut self, key: &str, value: &str) -> &mut Self {
self.header(key, value)
}
pub fn bearer_token(&mut self, token: &str) -> &mut Self {
self.header("Authorization", &format!("Bearer {}", token))
}
pub fn with_auth(&mut self, token: &str) -> &mut Self {
self.bearer_token(token)
}
#[cfg(feature = "test-auth-bypass")]
pub fn with_test_user(&mut self, user_id: &str) -> &mut Self {
self.header(TEST_USER_ID_HEADER, user_id)
}
#[cfg(feature = "test-auth-bypass")]
pub fn with_test_claims<T: Serialize>(&mut self, claims: &T) -> &mut Self {
let encoded = encode_test_claims_header(claims);
self.header(TEST_CLAIMS_HEADER, &encoded)
}
pub fn with_query(&mut self, params: &[(&str, &str)]) -> &mut Self {
let uri = self.request.uri().clone();
let mut query_parts = vec![];
if let Some(query) = uri.query() {
query_parts.push(query.to_string());
}
for (key, value) in params {
query_parts.push(format!(
"{}={}",
urlencoding::encode(key),
urlencoding::encode(value)
));
}
let path = uri.path();
let new_uri = if query_parts.is_empty() {
path.to_string()
} else {
format!("{}?{}", path, query_parts.join("&"))
};
*self.request.uri_mut() = new_uri.parse().unwrap();
self
}
pub fn json<T: Serialize>(&mut self, body: &T) -> &mut Self {
let json = serde_json::to_string(body).unwrap();
*self.request.body_mut() = Body::from(json);
self.request
.headers_mut()
.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
self.request
.headers_mut()
.insert(header::ACCEPT, "application/json".parse().unwrap());
self
}
pub fn with_json<T: Serialize>(&mut self, body: &T) -> &mut Self {
self.json(body)
}
pub fn form<T: Serialize>(&mut self, body: &T) -> &mut Self {
let encoded = serde_urlencoded::to_string(body).unwrap();
*self.request.body_mut() = Body::from(encoded);
self.request.headers_mut().insert(
header::CONTENT_TYPE,
"application/x-www-form-urlencoded".parse().unwrap(),
);
self
}
pub fn with_form<T: Serialize>(&mut self, body: &T) -> &mut Self {
self.form(body)
}
pub fn text_body(&mut self, body: impl Into<String>) -> &mut Self {
*self.request.body_mut() = Body::from(body.into());
self
}
pub fn status_code_should_be(&mut self, status: u16) -> &mut Self {
self.expected_status = Some(
StatusCode::from_u16(status)
.unwrap_or_else(|_| panic!("Invalid HTTP status code: {status}")),
);
self.ignore_status_code = false;
self
}
pub fn status_code_should_be_ok(&mut self) -> &mut Self {
self.status_code_should_be(StatusCode::OK.as_u16())
}
pub fn ignore_status_code(&mut self) -> &mut Self {
self.ignore_status_code = true;
self
}
pub fn header_should_be(&mut self, key: &str, expected: &str) -> &mut Self {
let key = key.to_string();
let expected = expected.to_string();
self.assert_with(move |outcome| {
let Some(actual) = outcome.header(&key) else {
return Err(format!("Expected header '{key}' to exist"));
};
if actual == expected {
Ok(())
} else {
Err(format!(
"Expected header '{key}' to be '{expected}', got '{actual}'"
))
}
})
}
pub fn header_should_exist(&mut self, key: &str) -> &mut Self {
let key = key.to_string();
self.assert_with(move |outcome| {
if outcome.header(&key).is_some() {
Ok(())
} else {
Err(format!("Expected header '{key}' to exist"))
}
})
}
pub fn redirect_to_should_be(&mut self, expected: &str) -> &mut Self {
let expected = expected.to_string();
self.assert_with(move |outcome| {
if !outcome.status().is_redirection() {
return Err(format!(
"Expected redirect status, got {}",
outcome.status()
));
}
let Some(location) = outcome.header(header::LOCATION.as_str()) else {
return Err("Expected Location header to exist".to_string());
};
if location == expected {
Ok(())
} else {
Err(format!(
"Expected redirect location '{expected}', got '{location}'"
))
}
})
}
pub fn content_should_contain(&mut self, text: &str) -> &mut Self {
let text = text.to_string();
self.assert_with(move |outcome| {
let body = outcome.body_string();
if body.contains(&text) {
Ok(())
} else {
Err(format!("Expected body to contain '{text}', got '{body}'"))
}
})
}
pub fn json_path_should_be(&mut self, path: &str, expected: serde_json::Value) -> &mut Self {
let path = path.to_string();
self.assert_with(move |outcome| {
let json = parse_json_body(outcome)?;
let Some(actual) = json_path_get(&json, &path) else {
return Err(format!("Path '{path}' not found in JSON response"));
};
if actual == &expected {
Ok(())
} else {
Err(format!(
"Expected JSON path '{path}' to equal {expected}, got {actual}"
))
}
})
}
pub fn json_should_contain(&mut self, expected: serde_json::Value) -> &mut Self {
self.assert_with(move |outcome| {
let json = parse_json_body(outcome)?;
if json_contains(&json, &expected) {
Ok(())
} else {
Err(format!("Expected JSON to contain {expected}, got {json}"))
}
})
}
pub fn assert_with<F>(&mut self, assertion: F) -> &mut Self
where
F: Fn(&ScenarioOutcome) -> Result<(), String> + Send + Sync + 'static,
{
self.assertions.push(Box::new(assertion));
self
}
}
/// Snapshot of the request actually sent by a scenario (after any before-each hook).
/// The body is not captured.
#[derive(Clone, Debug)]
pub struct RequestSummary {
    method: Method,
    uri: Uri,
    headers: HeaderMap,
}

impl RequestSummary {
    /// Copies the method, URI, and headers out of `request`.
    fn from_request(request: &Request<Body>) -> Self {
        let (method, uri, headers) = (request.method(), request.uri(), request.headers());
        Self {
            method: method.clone(),
            uri: uri.clone(),
            headers: headers.clone(),
        }
    }

    /// The request method.
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// The request URI.
    pub fn uri(&self) -> &Uri {
        &self.uri
    }

    /// All request headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// First value of header `key`, or `None` if it is absent or not visible ASCII.
    pub fn header(&self, key: &str) -> Option<&str> {
        let value = self.headers.get(key)?;
        value.to_str().ok()
    }
}
/// Fully buffered response of a scenario, together with a summary of its request.
#[derive(Clone, Debug)]
pub struct ScenarioOutcome {
    request: RequestSummary,
    status: StatusCode,
    headers: HeaderMap,
    body: Vec<u8>,
}

impl ScenarioOutcome {
    /// Captures the status, headers, and complete body of `response`.
    ///
    /// # Panics
    ///
    /// Panics if the response body cannot be read.
    async fn from_response(request: RequestSummary, response: axum::response::Response) -> Self {
        let (parts, body) = response.into_parts();
        let body = to_bytes(body, usize::MAX)
            .await
            .expect("test host response body should be readable")
            .to_vec();
        Self {
            request,
            status: parts.status,
            headers: parts.headers,
            body,
        }
    }

    /// Summary of the request that produced this outcome.
    pub fn request(&self) -> &RequestSummary {
        &self.request
    }

    /// Response status code.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// All response headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// First value of response header `key`, or `None` if absent or not visible ASCII.
    pub fn header(&self, key: &str) -> Option<&str> {
        let value = self.headers.get(key)?;
        value.to_str().ok()
    }

    /// Raw response body.
    pub fn body_bytes(&self) -> &[u8] {
        self.body.as_slice()
    }

    /// Response body decoded as UTF-8, replacing invalid sequences.
    pub fn body_string(&self) -> String {
        String::from_utf8_lossy(self.body_bytes()).into_owned()
    }

    /// Deserializes the response body as JSON.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid JSON for `T`.
    pub fn json<T: DeserializeOwned>(&self) -> T {
        serde_json::from_slice(&self.body).expect("failed to parse JSON response")
    }

    /// Parses the response body as an untyped JSON value.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid JSON.
    pub fn json_value(&self) -> serde_json::Value {
        self.json::<serde_json::Value>()
    }
}
#[derive(Clone, Debug)]
pub struct ScenarioFailure {
request: RequestSummary,
failures: Vec<String>,
}
impl ScenarioFailure {
pub fn request(&self) -> &RequestSummary {
&self.request
}
pub fn failures(&self) -> &[String] {
&self.failures
}
}
impl fmt::Display for ScenarioFailure {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Scenario failed for {} {}:",
self.request.method(),
self.request.uri()
)?;
for failure in &self.failures {
writeln!(f, "- {failure}")?;
}
Ok(())
}
}
impl std::error::Error for ScenarioFailure {}
/// Parses the outcome's body as JSON, producing an assertion-style error message on failure.
fn parse_json_body(outcome: &ScenarioOutcome) -> Result<serde_json::Value, String> {
    let bytes = outcome.body_bytes();
    serde_json::from_slice::<serde_json::Value>(bytes).map_err(|parse_error| {
        format!("Expected JSON response body, got parse error: {parse_error}")
    })
}
/// Resolves a dot-separated `path` (e.g. `data.items.0.id`) inside `json`.
///
/// The meaning of each segment depends on the value it is applied to:
/// - on an array, the segment must parse as a `usize` index;
/// - on an object, the segment is a literal key, including numeric-looking keys such as
///   `"2024"` (previously these were always treated as indices and never matched);
/// - on any other value, resolution fails.
///
/// Returns `None` if any segment cannot be resolved.
fn json_path_get<'a>(json: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
    path.split('.').try_fold(json, |current, segment| match current {
        serde_json::Value::Array(items) => segment
            .parse::<usize>()
            .ok()
            .and_then(|index| items.get(index)),
        serde_json::Value::Object(map) => map.get(segment),
        _ => None,
    })
}
fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
match (actual, expected) {
(serde_json::Value::Object(actual_map), serde_json::Value::Object(expected_map)) => {
expected_map.iter().all(|(key, expected_value)| {
actual_map
.get(key)
.map(|actual_value| json_contains(actual_value, expected_value))
.unwrap_or(false)
})
}
(serde_json::Value::Array(actual_array), serde_json::Value::Array(expected_array)) => {
expected_array.iter().all(|expected_value| {
actual_array
.iter()
.any(|actual_value| json_contains(actual_value, expected_value))
})
}
_ => actual == expected,
}
}