use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::http::{Method, Request};
use axum::middleware;
use http_body_util::BodyExt;
use tower::ServiceExt;
use rusty_gasket::auth::{AuthChain, AuthMiddlewareState, UnauthenticatedPolicy, auth_middleware};
use rusty_gasket::observability;
use rusty_gasket::testing::mock_auth::MockAuthBackend;
use rusty_gasket::testing::test_response::TestResponse;
/// In-process harness that drives an axum [`Router`] without binding a socket.
///
/// Construct one with [`TestApp::builder`]; every request is dispatched with `oneshot`
/// on a clone of the stored router.
pub struct TestApp {
    /// Router with any configured auth / logging layers already applied.
    router: Router,
}

impl std::fmt::Debug for TestApp {
    /// Renders as `TestApp { .. }`; the router is intentionally not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("TestApp");
        debug.finish_non_exhaustive()
    }
}
impl TestApp {
pub const fn builder() -> TestAppBuilder {
TestAppBuilder {
router: None,
mock_auth: None,
auth_state: None,
add_logging: false,
}
}
pub async fn get(&self, path: &str) -> TestResponse {
self.request(Method::GET, path, Body::empty()).await
}
pub async fn post_json(&self, path: &str, body: &impl serde::Serialize) -> TestResponse {
let json = serde_json::to_vec(body).expect("serialize JSON body");
let request = Request::builder()
.method(Method::POST)
.uri(path)
.header("content-type", "application/json")
.body(Body::from(json))
.expect("build request");
self.send(request).await
}
pub async fn put_json(&self, path: &str, body: &impl serde::Serialize) -> TestResponse {
let json = serde_json::to_vec(body).expect("serialize JSON body");
let request = Request::builder()
.method(Method::PUT)
.uri(path)
.header("content-type", "application/json")
.body(Body::from(json))
.expect("build request");
self.send(request).await
}
pub async fn delete(&self, path: &str) -> TestResponse {
self.request(Method::DELETE, path, Body::empty()).await
}
pub async fn patch_json(&self, path: &str, body: &impl serde::Serialize) -> TestResponse {
let json = serde_json::to_vec(body).expect("serialize JSON body");
let request = Request::builder()
.method(Method::PATCH)
.uri(path)
.header("content-type", "application/json")
.body(Body::from(json))
.expect("build request");
self.send(request).await
}
pub async fn request(&self, method: Method, path: &str, body: Body) -> TestResponse {
let request = Request::builder()
.method(method)
.uri(path)
.body(body)
.expect("build request");
self.send(request).await
}
pub async fn send(&self, request: Request<Body>) -> TestResponse {
let response = self
.router
.clone()
.oneshot(request)
.await
.expect("router should not fail");
let status = response.status();
let headers = response.headers().clone();
let body = response
.into_body()
.collect()
.await
.expect("collect response body")
.to_bytes();
TestResponse::new(status, headers, body)
}
}
/// Builder for [`TestApp`]; obtain one from `TestApp::builder()` and finish with `build()`.
///
/// At most one auth source may be configured: either an explicit `auth_state` or a
/// mock backend (`mock_auth`, `mock_auth_identity`, `anonymous_auth`).
#[must_use = "TestAppBuilder must be consumed by .build() to produce a TestApp"]
pub struct TestAppBuilder {
    /// Router under test; required by `build()`.
    router: Option<Router>,
    /// Mock authentication backend, if one of the `mock_auth*` / `anonymous_auth` setters ran.
    mock_auth: Option<MockAuthBackend>,
    /// Explicit auth middleware state supplied by the caller.
    auth_state: Option<Arc<AuthMiddlewareState>>,
    /// Whether to wrap the router in the observability logging middleware.
    add_logging: bool,
}

impl std::fmt::Debug for TestAppBuilder {
    /// Reports which optional pieces are configured without printing their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TestAppBuilder")
            .field("has_router", &self.router.is_some())
            .field("has_mock_auth", &self.mock_auth.is_some())
            // Include the explicit auth state too, so a conflicting configuration
            // (both auth sources set) is visible in debug output.
            .field("has_auth_state", &self.auth_state.is_some())
            .field("add_logging", &self.add_logging)
            .finish()
    }
}
impl TestAppBuilder {
pub fn router(mut self, router: Router) -> Self {
self.router = Some(router);
self
}
pub fn mock_auth(mut self, subject: &str) -> Self {
self.mock_auth = Some(MockAuthBackend::authenticated(subject));
self
}
pub fn mock_auth_identity(mut self, identity: rusty_gasket::auth::Identity) -> Self {
self.mock_auth = Some(MockAuthBackend::with_identity(identity));
self
}
pub fn anonymous_auth(mut self) -> Self {
self.mock_auth = Some(MockAuthBackend::anonymous());
self
}
pub fn auth_state(mut self, state: Arc<AuthMiddlewareState>) -> Self {
self.auth_state = Some(state);
self
}
pub const fn logging(mut self, enabled: bool) -> Self {
self.add_logging = enabled;
self
}
pub fn build(self) -> TestApp {
let router = self.router.expect("TestApp requires a router");
assert!(
!(self.auth_state.is_some() && self.mock_auth.is_some()),
"TestAppBuilder cannot combine with_auth_state with with_mock_auth*; \
pick one auth source per TestApp",
);
let router = if let Some(state) = self.auth_state {
router.layer(middleware::from_fn_with_state(state, auth_middleware))
} else if let Some(mock) = self.mock_auth {
let fallback = if mock.is_anonymous() {
UnauthenticatedPolicy::AllowAnonymous
} else {
UnauthenticatedPolicy::Reject
};
let state = Arc::new(AuthMiddlewareState::new(
AuthChain::new().backend(mock).with_fallback(fallback),
));
router.layer(middleware::from_fn_with_state(state, auth_middleware))
} else {
router
};
let router = if self.add_logging {
router.layer(middleware::from_fn(observability::logging_middleware))
} else {
router
};
TestApp { router }
}
}