use axum::body::Body;
use axum::http::{Method, Request, StatusCode};
use tower::ServiceExt;
use crate::config::AutumnConfig;
use crate::route::Route;
use crate::state::AppState;
#[cfg(feature = "db")]
use diesel_async::AsyncPgConnection;
#[cfg(feature = "db")]
use diesel_async::pooled_connection::deadpool::Pool;
/// Builder for an in-process test application.
///
/// Collects routes, extra routers, layers and configuration; [`TestApp::build`]
/// turns them into a [`TestClient`] whose requests are dispatched to the router
/// with `oneshot` rather than over a network socket.
pub struct TestApp {
    // Routes registered through `routes()`.
    routes: Vec<Route>,
    // Routers registered through `merge()`, handed to the router builder.
    merge_routers: Vec<axum::Router<crate::state::AppState>>,
    // `(path prefix, router)` pairs registered through `nest()`.
    nest_routers: Vec<(String, axum::Router<crate::state::AppState>)>,
    // Layers registered through `layer()`, handed to the router builder in `build()`.
    custom_layers: Vec<crate::app::CustomLayerRegistration>,
    // Configuration used by `build()`; `new()` starts from the default with the
    // `test` profile and CSRF disabled.
    config: AutumnConfig,
    // Optional OpenAPI configuration forwarded to the router builder.
    #[cfg(feature = "openapi")]
    openapi: Option<crate::openapi::OpenApiConfig>,
    // Optional database pool stored in the `AppState` created by `build()`.
    #[cfg(feature = "db")]
    pool: Option<Pool<AsyncPgConnection>>,
}
impl TestApp {
    /// Creates an empty builder.
    ///
    /// The configuration is `AutumnConfig::default()` with the profile set to
    /// `"test"` and CSRF protection switched off.
    #[must_use]
    pub fn new() -> Self {
        let config = {
            let mut cfg = AutumnConfig::default();
            cfg.profile = Some(String::from("test"));
            cfg.security.csrf.enabled = false;
            cfg
        };
        Self {
            config,
            routes: Vec::new(),
            merge_routers: Vec::new(),
            nest_routers: Vec::new(),
            custom_layers: Vec::new(),
            #[cfg(feature = "openapi")]
            openapi: None,
            #[cfg(feature = "db")]
            pool: None,
        }
    }

    /// Sets the OpenAPI configuration forwarded to the router builder.
    #[cfg(feature = "openapi")]
    #[must_use]
    pub fn openapi(self, config: crate::openapi::OpenApiConfig) -> Self {
        Self {
            openapi: Some(config),
            ..self
        }
    }

    /// Queues a router to be merged into the application router.
    #[must_use]
    pub fn merge(mut self, router: axum::Router<crate::state::AppState>) -> Self {
        self.merge_routers.push(router);
        self
    }

    /// Queues a router to be nested under `path`.
    #[must_use]
    pub fn nest(mut self, path: &str, router: axum::Router<crate::state::AppState>) -> Self {
        let entry = (String::from(path), router);
        self.nest_routers.push(entry);
        self
    }

    /// Registers a custom layer; the layer's type id is recorded alongside it.
    #[must_use]
    pub fn layer<L: crate::app::IntoAppLayer>(mut self, layer: L) -> Self {
        let registration = crate::app::CustomLayerRegistration {
            type_id: std::any::TypeId::of::<L>(),
            apply: Box::new(move |router| layer.apply_to(router)),
        };
        self.custom_layers.push(registration);
        self
    }

    /// Wraps an already-built router in a [`TestClient`], bypassing this builder.
    #[must_use]
    pub const fn from_router(router: axum::Router) -> TestClient {
        TestClient { router }
    }

    /// Appends `routes` to the routes registered so far.
    #[must_use]
    pub fn routes(mut self, routes: Vec<Route>) -> Self {
        self.routes.extend(routes);
        self
    }

    /// Replaces the whole configuration.
    ///
    /// This also discards the `test` profile and the disabled CSRF setting
    /// applied by [`TestApp::new`], as well as any earlier `profile()` call.
    #[must_use]
    pub fn config(self, config: AutumnConfig) -> Self {
        Self { config, ..self }
    }

    /// Overrides the active profile name in the current configuration.
    #[must_use]
    pub fn profile(mut self, profile: &str) -> Self {
        self.config.profile = Some(String::from(profile));
        self
    }

    /// Supplies the database pool placed into the application state.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn with_db(self, pool: Pool<AsyncPgConnection>) -> Self {
        Self {
            pool: Some(pool),
            ..self
        }
    }

    /// Builds a fresh `AppState` and router from the collected pieces.
    ///
    /// # Panics
    ///
    /// Panics with "failed to build test router" if
    /// `crate::router::try_build_router_inner` returns an error.
    #[must_use]
    pub fn build(self) -> TestClient {
        let Self {
            routes,
            merge_routers,
            nest_routers,
            custom_layers,
            config,
            #[cfg(feature = "openapi")]
            openapi,
            #[cfg(feature = "db")]
            pool,
        } = self;

        let state = AppState {
            extensions: std::sync::Arc::new(std::sync::RwLock::new(
                std::collections::HashMap::new(),
            )),
            #[cfg(feature = "db")]
            pool,
            profile: config.profile.clone(),
            started_at: std::time::Instant::now(),
            health_detailed: config.health.detailed,
            probes: crate::probe::ProbeState::ready_for_test(),
            metrics: crate::middleware::MetricsCollector::new(),
            log_levels: crate::actuator::LogLevels::new(&config.log.level),
            task_registry: crate::actuator::TaskRegistry::new(),
            config_props: crate::actuator::ConfigProperties::default(),
            #[cfg(feature = "ws")]
            channels: crate::channels::Channels::new(32),
            #[cfg(feature = "ws")]
            shutdown: tokio_util::sync::CancellationToken::new(),
        };

        let context = crate::router::RouterContext {
            exception_filters: Vec::new(),
            scoped_groups: Vec::new(),
            merge_routers,
            nest_routers,
            custom_layers,
            error_page_renderer: None,
            session_store: None,
            #[cfg(feature = "openapi")]
            openapi,
        };

        let router = crate::router::try_build_router_inner(routes, &config, state, context)
            .expect("failed to build test router");
        TestClient { router }
    }
}
impl Default for TestApp {
fn default() -> Self {
Self::new()
}
}
/// Client for sending requests to a built router in-process.
///
/// Each request clones the router into a [`RequestBuilder`], so one client can
/// issue any number of requests.
pub struct TestClient {
    // The fully built application router.
    router: axum::Router,
}
impl TestClient {
pub fn into_router(self) -> axum::Router {
self.router
}
#[must_use]
pub fn get(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::GET, uri)
}
#[must_use]
pub fn post(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::POST, uri)
}
#[must_use]
pub fn put(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::PUT, uri)
}
#[must_use]
pub fn delete(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::DELETE, uri)
}
#[must_use]
pub fn patch(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::PATCH, uri)
}
}
/// A pending request against a test router; finish it with `send().await`.
pub struct RequestBuilder {
    // Router the request is dispatched to (a clone of the client's router).
    router: axum::Router,
    // HTTP method of the request.
    method: Method,
    // Request URI, parsed only when the request is built in `send()`.
    uri: String,
    // `(name, value)` pairs added to the request in insertion order.
    headers: Vec<(String, String)>,
    // Request body; empty unless `json()`, `form()` or `body()` is called.
    body: Body,
}
impl RequestBuilder {
    /// Creates a builder with no headers and an empty body.
    fn new(router: axum::Router, method: Method, uri: &str) -> Self {
        Self {
            router,
            method,
            uri: uri.to_owned(),
            headers: Vec::new(),
            body: Body::empty(),
        }
    }

    /// Appends a request header.
    ///
    /// Repeated calls with the same name add further values rather than
    /// replacing earlier ones.
    #[must_use]
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_owned(), value.to_owned()));
        self
    }

    /// Uses `value`, serialized as JSON, as the request body and sets
    /// `content-type: application/json`.
    ///
    /// Any Content-Type set earlier on this builder is replaced, so a request
    /// never carries two conflicting Content-Type headers.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized.
    #[must_use]
    pub fn json(mut self, value: &serde_json::Value) -> Self {
        self.set_content_type("application/json");
        self.body = Body::from(serde_json::to_vec(value).expect("failed to serialize JSON body"));
        self
    }

    /// Uses `body` verbatim as a URL-encoded form body and sets
    /// `content-type: application/x-www-form-urlencoded`.
    ///
    /// Any Content-Type set earlier on this builder is replaced.
    #[must_use]
    pub fn form(mut self, body: &str) -> Self {
        self.set_content_type("application/x-www-form-urlencoded");
        self.body = Body::from(body.to_owned());
        self
    }

    /// Sets a raw request body without touching any headers.
    #[must_use]
    pub fn body(mut self, body: impl Into<Body>) -> Self {
        self.body = body.into();
        self
    }

    /// Removes previously set Content-Type headers (matched case-insensitively)
    /// and records `value` as the only one.
    fn set_content_type(&mut self, value: &str) {
        self.headers
            .retain(|(name, _)| !name.eq_ignore_ascii_case("content-type"));
        self.headers
            .push(("content-type".to_owned(), value.to_owned()));
    }

    /// Builds the request, runs it through the router and buffers the response.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be built (invalid URI or header), if the
    /// router returns an error, or if the response body cannot be read.
    pub async fn send(self) -> TestResponse {
        let Self {
            router,
            method,
            uri,
            headers: request_headers,
            body,
        } = self;

        let mut builder = Request::builder().method(method).uri(&uri);
        for (name, value) in &request_headers {
            builder = builder.header(name.as_str(), value.as_str());
        }
        let request = builder.body(body).expect("failed to build request");

        let response = router.oneshot(request).await.expect("request failed");
        let status = response.status();
        // `HeaderValue::to_str` rejects anything that is not visible ASCII,
        // which would turn e.g. UTF-8 values into "". Decode lossily so tests
        // still see the value.
        let headers: Vec<(String, String)> = response
            .headers()
            .iter()
            .map(|(k, v)| {
                (
                    k.to_string(),
                    String::from_utf8_lossy(v.as_bytes()).into_owned(),
                )
            })
            .collect();
        let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("failed to read response body");

        TestResponse {
            status,
            headers,
            body: body_bytes.to_vec(),
        }
    }
}
/// A fully buffered response returned by [`RequestBuilder::send`].
pub struct TestResponse {
    /// Response status code.
    pub status: StatusCode,
    /// Response headers as `(name, value)` pairs; a header with several values
    /// appears once per value.
    pub headers: Vec<(String, String)>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
impl TestResponse {
    /// Returns the body decoded as UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid UTF-8. `#[track_caller]` makes the panic
    /// point at the calling test (or at the test calling an `assert_*` helper).
    #[must_use]
    #[track_caller]
    pub fn text(&self) -> String {
        // `match` instead of `expect`/`unwrap_or_else` keeps the panic inside
        // this `#[track_caller]` function, so the caller's location is reported.
        match std::str::from_utf8(&self.body) {
            Ok(text) => text.to_owned(),
            Err(e) => panic!("response body is not valid UTF-8: {e}"),
        }
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Panics
    ///
    /// Panics, showing the serde error and the body, if deserialization fails.
    #[must_use]
    #[track_caller]
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
        match serde_json::from_slice(&self.body) {
            Ok(value) => value,
            Err(e) => panic!(
                "failed to parse response body as JSON: {e}\nBody: {}",
                String::from_utf8_lossy(&self.body)
            ),
        }
    }

    /// Returns the first value of header `name`, matched ASCII case-insensitively.
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str())
    }

    /// Asserts the status is `200 OK`.
    #[track_caller]
    pub fn assert_ok(&self) -> &Self {
        assert_eq!(
            self.status,
            StatusCode::OK,
            "expected 200 OK, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Asserts the numeric status code equals `expected`.
    #[track_caller]
    pub fn assert_status(&self, expected: u16) -> &Self {
        assert_eq!(
            self.status.as_u16(),
            expected,
            "expected status {expected}, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Asserts the status is in the 2xx range.
    #[track_caller]
    pub fn assert_success(&self) -> &Self {
        assert!(
            self.status.is_success(),
            "expected 2xx success, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Asserts header `name` is present and its first value equals `expected`.
    #[track_caller]
    pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
        // `let else` (not a closure) so the panic location is the caller's.
        let Some(value) = self.header(name) else {
            panic!("expected header `{name}` to be present");
        };
        assert_eq!(
            value, expected,
            "header `{name}`: expected `{expected}`, got `{value}`"
        );
        self
    }

    /// Asserts header `name` is present and its first value contains `substring`.
    #[track_caller]
    pub fn assert_header_contains(&self, name: &str, substring: &str) -> &Self {
        let Some(value) = self.header(name) else {
            panic!("expected header `{name}` to be present");
        };
        assert!(
            value.contains(substring),
            "header `{name}`: expected `{value}` to contain `{substring}`"
        );
        self
    }

    /// Asserts the UTF-8 body contains `substring`.
    #[track_caller]
    pub fn assert_body_contains(&self, substring: &str) -> &Self {
        let body = self.text();
        assert!(
            body.contains(substring),
            "expected body to contain `{substring}`.\nBody: {body}"
        );
        self
    }

    /// Asserts the UTF-8 body equals `expected` exactly.
    #[track_caller]
    pub fn assert_body_eq(&self, expected: &str) -> &Self {
        let body = self.text();
        assert_eq!(body, expected, "body mismatch");
        self
    }

    /// Deserializes the body as JSON and passes it to `predicate`, which is
    /// expected to make its own assertions.
    #[track_caller]
    pub fn assert_json<T, F>(&self, predicate: F) -> &Self
    where
        T: serde::de::DeserializeOwned,
        F: FnOnce(&T),
    {
        let value: T = self.json();
        predicate(&value);
        self
    }

    /// Asserts the body has zero bytes.
    #[track_caller]
    pub fn assert_body_empty(&self) -> &Self {
        assert!(
            self.body.is_empty(),
            "expected empty body, got {} bytes: {}",
            self.body.len(),
            String::from_utf8_lossy(&self.body)
        );
        self
    }
}
/// A Postgres database running in a testcontainer, plus a connection pool to it.
#[cfg(all(feature = "db", feature = "test-support"))]
pub struct TestDb {
    // Never read; stored so the container handle lives as long as this value.
    // NOTE(review): presumably dropping it stops the container — confirm
    // against the testcontainers docs.
    _container: testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>,
    // Pool of async Diesel connections to the container's database.
    pool: Pool<AsyncPgConnection>,
    // `postgres://` connection URL for the container's database.
    url: String,
}
#[cfg(all(feature = "db", feature = "test-support"))]
impl TestDb {
    /// Starts a fresh Postgres testcontainer and builds a pool (max 5
    /// connections) pointing at it.
    ///
    /// # Panics
    ///
    /// Panics if the container cannot be started (e.g. Docker is not
    /// running), if its host or mapped port cannot be determined, or if the
    /// pool cannot be built.
    pub async fn new() -> Self {
        use diesel_async::pooled_connection::AsyncDieselConnectionManager;
        use testcontainers::runners::AsyncRunner;
        use testcontainers_modules::postgres::Postgres;
        let container = Postgres::default()
            .start()
            .await
            .expect("failed to start Postgres testcontainer (is Docker running?)");
        let host = container
            .get_host()
            .await
            .expect("failed to get Postgres testcontainer host");
        let port = container
            .get_host_port_ipv4(5432)
            .await
            .expect("failed to get mapped port for Postgres testcontainer");
        let url = format!("postgres://postgres:postgres@{host}:{port}/postgres");
        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&url);
        let pool = Pool::builder(manager)
            .max_size(5)
            .build()
            .expect("failed to build connection pool");
        Self {
            _container: container,
            pool,
            url,
        }
    }

    /// Returns a process-wide `TestDb`, starting the container on the first call.
    ///
    /// Concurrent first callers await the same initialization, so only one
    /// container is started.
    ///
    /// NOTE(review): each `#[tokio::test]` runs on its own runtime; confirm that
    /// pooled connections created on an earlier test's runtime are recycled
    /// correctly when reused from a later one.
    pub async fn shared() -> &'static Self {
        use std::sync::OnceLock;
        use tokio::sync::OnceCell;
        static CELL: OnceLock<OnceCell<TestDb>> = OnceLock::new();
        let once = CELL.get_or_init(OnceCell::new);
        once.get_or_init(Self::new).await
    }

    /// Returns a clone of the connection pool handle.
    #[must_use]
    pub fn pool(&self) -> Pool<AsyncPgConnection> {
        self.pool.clone()
    }

    /// Returns the `postgres://` URL of the container's database.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Runs `sql` through `diesel::sql_query` on a pooled connection.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be obtained or the query fails; the failing
    /// SQL is included in the message.
    pub async fn execute_sql(&self, sql: &str) {
        use diesel_async::RunQueryDsl;
        let mut conn = self.pool.get().await.expect("failed to get connection");
        diesel::sql_query(sql)
            .execute(&mut *conn)
            .await
            .unwrap_or_else(|e| panic!("SQL execution failed: {e}\nSQL: {sql}"));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Routes shared by the tests below:
    // - `GET /hello` returns "hello"
    // - `POST /echo` echoes its JSON body
    // - `POST /create` returns 201 "created"
    // - `GET /header` echoes the `x-custom` request header (or "")
    fn test_routes() -> Vec<Route> {
        use axum::routing;
        async fn hello() -> &'static str {
            "hello"
        }
        async fn echo_json(
            axum::Json(value): axum::Json<serde_json::Value>,
        ) -> axum::Json<serde_json::Value> {
            axum::Json(value)
        }
        async fn status_201() -> (StatusCode, &'static str) {
            (StatusCode::CREATED, "created")
        }
        async fn echo_header(headers: axum::http::HeaderMap) -> String {
            headers
                .get("x-custom")
                .and_then(|value| value.to_str().ok())
                .unwrap_or_default()
                .to_owned()
        }
        vec![
            Route {
                method: Method::GET,
                path: "/hello",
                handler: routing::get(hello),
                name: "hello",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/hello",
                    operation_id: "hello",
                    success_status: 200,
                    ..Default::default()
                },
            },
            Route {
                method: Method::POST,
                path: "/echo",
                handler: routing::post(echo_json),
                name: "echo",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/echo",
                    operation_id: "echo",
                    success_status: 200,
                    ..Default::default()
                },
            },
            Route {
                method: Method::POST,
                path: "/create",
                handler: routing::post(status_201),
                name: "create",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/create",
                    operation_id: "create",
                    success_status: 201,
                    ..Default::default()
                },
            },
            Route {
                method: Method::GET,
                path: "/header",
                handler: routing::get(echo_header),
                name: "echo_header",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/header",
                    operation_id: "echo_header",
                    success_status: 200,
                    ..Default::default()
                },
            },
        ]
    }

    #[tokio::test]
    async fn test_app_get_request() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_ok();
    }

    #[tokio::test]
    async fn test_app_post_json() {
        let client = TestApp::new().routes(test_routes()).build();
        client
            .post("/echo")
            .json(&serde_json::json!({"key": "value"}))
            .send()
            .await
            .assert_ok()
            .assert_body_contains("key");
    }

    #[tokio::test]
    async fn test_response_assert_status() {
        let client = TestApp::new().routes(test_routes()).build();
        client
            .post("/create")
            .send()
            .await
            .assert_status(201)
            .assert_body_eq("created");
    }

    #[tokio::test]
    async fn test_response_assert_success() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_success();
    }

    #[tokio::test]
    async fn test_not_found() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/nonexistent").send().await.assert_status(404);
    }

    #[tokio::test]
    async fn test_response_json_deserialization() {
        let client = TestApp::new().routes(test_routes()).build();
        let resp = client
            .post("/echo")
            .json(&serde_json::json!({"count": 42}))
            .send()
            .await;
        resp.assert_ok().assert_json::<serde_json::Value, _>(|v| {
            assert_eq!(v["count"], 42);
        });
    }

    // Verifies the header actually reaches the handler, not just that the
    // request succeeds.
    #[tokio::test]
    async fn test_custom_header() {
        let client = TestApp::new().routes(test_routes()).build();
        let resp = client
            .get("/header")
            .header("x-custom", "test-value")
            .send()
            .await;
        resp.assert_ok().assert_body_eq("test-value");
    }

    #[tokio::test]
    async fn test_header_lookup_is_case_insensitive() {
        let client = TestApp::new().routes(test_routes()).build();
        let resp = client.get("/hello").send().await;
        resp.assert_ok()
            .assert_header_contains("Content-Type", "text/plain");
        assert_eq!(resp.header("CONTENT-TYPE"), resp.header("content-type"));
    }

    #[test]
    fn test_app_defaults_to_test_profile() {
        let app = TestApp::default();
        assert_eq!(app.config.profile.as_deref(), Some("test"));
        assert!(!app.config.security.csrf.enabled);
    }

    #[tokio::test]
    async fn test_client_default() {
        let client = TestApp::default().routes(test_routes()).build();
        client.get("/hello").send().await.assert_ok();
    }
}