1use axum::body::Body;
65use axum::http::{Method, Request, StatusCode};
66use tower::ServiceExt;
67
68use crate::config::AutumnConfig;
69use crate::route::Route;
70
71use crate::state::AppState;
72
73#[cfg(feature = "db")]
74use diesel_async::AsyncPgConnection;
75#[cfg(feature = "db")]
76use diesel_async::pooled_connection::deadpool::Pool;
77
78pub struct TestApp {
105 routes: Vec<Route>,
106 merge_routers: Vec<axum::Router<crate::state::AppState>>,
107 nest_routers: Vec<(String, axum::Router<crate::state::AppState>)>,
108 custom_layers: Vec<crate::app::CustomLayerRegistration>,
109 config: AutumnConfig,
110 #[cfg(feature = "openapi")]
111 openapi: Option<crate::openapi::OpenApiConfig>,
112 #[cfg(feature = "db")]
113 pool: Option<Pool<AsyncPgConnection>>,
114}
115
116impl TestApp {
117 #[must_use]
119 pub fn new() -> Self {
120 let mut config = AutumnConfig::default();
121 config.profile = Some("test".into());
122 config.security.csrf.enabled = false;
124
125 Self {
126 routes: Vec::new(),
127 merge_routers: Vec::new(),
128 nest_routers: Vec::new(),
129 custom_layers: Vec::new(),
130 config,
131 #[cfg(feature = "openapi")]
132 openapi: None,
133 #[cfg(feature = "db")]
134 pool: None,
135 }
136 }
137
138 #[cfg(feature = "openapi")]
145 #[must_use]
146 pub fn openapi(mut self, config: crate::openapi::OpenApiConfig) -> Self {
147 self.openapi = Some(config);
148 self
149 }
150
151 #[must_use]
156 pub fn merge(mut self, router: axum::Router<crate::state::AppState>) -> Self {
157 self.merge_routers.push(router);
158 self
159 }
160
161 #[must_use]
165 pub fn nest(mut self, path: &str, router: axum::Router<crate::state::AppState>) -> Self {
166 self.nest_routers.push((path.to_owned(), router));
167 self
168 }
169
170 #[must_use]
175 pub fn layer<L: crate::app::IntoAppLayer>(mut self, layer: L) -> Self {
176 self.custom_layers
177 .push(crate::app::CustomLayerRegistration {
178 type_id: std::any::TypeId::of::<L>(),
179 apply: Box::new(move |router| layer.apply_to(router)),
180 });
181 self
182 }
183
184 #[must_use]
189 pub const fn from_router(router: axum::Router) -> TestClient {
190 TestClient { router }
191 }
192
193 #[must_use]
195 pub fn routes(mut self, routes: Vec<Route>) -> Self {
196 self.routes.extend(routes);
197 self
198 }
199
200 #[must_use]
202 pub fn config(mut self, config: AutumnConfig) -> Self {
203 self.config = config;
204 self
205 }
206
207 #[must_use]
209 pub fn profile(mut self, profile: &str) -> Self {
210 self.config.profile = Some(profile.to_owned());
211 self
212 }
213
214 #[cfg(feature = "db")]
216 #[must_use]
217 pub fn with_db(mut self, pool: Pool<AsyncPgConnection>) -> Self {
218 self.pool = Some(pool);
219 self
220 }
221
222 #[must_use]
228 pub fn build(self) -> TestClient {
229 let state = AppState {
230 extensions: std::sync::Arc::new(std::sync::RwLock::new(
231 std::collections::HashMap::new(),
232 )),
233 #[cfg(feature = "db")]
234 pool: self.pool,
235 profile: self.config.profile.clone(),
236 started_at: std::time::Instant::now(),
237 health_detailed: self.config.health.detailed,
238 probes: crate::probe::ProbeState::ready_for_test(),
239 metrics: crate::middleware::MetricsCollector::new(),
240 log_levels: crate::actuator::LogLevels::new(&self.config.log.level),
241 task_registry: crate::actuator::TaskRegistry::new(),
242 config_props: crate::actuator::ConfigProperties::default(),
243 #[cfg(feature = "ws")]
244 channels: crate::channels::Channels::new(32),
245 #[cfg(feature = "ws")]
246 shutdown: tokio_util::sync::CancellationToken::new(),
247 };
248
249 let router = crate::router::try_build_router_inner(
250 self.routes,
251 &self.config,
252 state,
253 crate::router::RouterContext {
254 exception_filters: Vec::new(),
255 scoped_groups: Vec::new(),
256 merge_routers: self.merge_routers,
257 nest_routers: self.nest_routers,
258 custom_layers: self.custom_layers,
259 error_page_renderer: None,
260 session_store: None,
261 #[cfg(feature = "openapi")]
262 openapi: self.openapi,
263 },
264 )
265 .expect("failed to build test router");
266 TestClient { router }
267 }
268}
269
270impl Default for TestApp {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276pub struct TestClient {
308 router: axum::Router,
309}
310
311impl TestClient {
312 pub fn into_router(self) -> axum::Router {
314 self.router
315 }
316
317 #[must_use]
319 pub fn get(&self, uri: &str) -> RequestBuilder {
320 RequestBuilder::new(self.router.clone(), Method::GET, uri)
321 }
322
323 #[must_use]
325 pub fn post(&self, uri: &str) -> RequestBuilder {
326 RequestBuilder::new(self.router.clone(), Method::POST, uri)
327 }
328
329 #[must_use]
331 pub fn put(&self, uri: &str) -> RequestBuilder {
332 RequestBuilder::new(self.router.clone(), Method::PUT, uri)
333 }
334
335 #[must_use]
337 pub fn delete(&self, uri: &str) -> RequestBuilder {
338 RequestBuilder::new(self.router.clone(), Method::DELETE, uri)
339 }
340
341 #[must_use]
343 pub fn patch(&self, uri: &str) -> RequestBuilder {
344 RequestBuilder::new(self.router.clone(), Method::PATCH, uri)
345 }
346}
347
348pub struct RequestBuilder {
356 router: axum::Router,
357 method: Method,
358 uri: String,
359 headers: Vec<(String, String)>,
360 body: Body,
361}
362
363impl RequestBuilder {
364 fn new(router: axum::Router, method: Method, uri: &str) -> Self {
365 Self {
366 router,
367 method,
368 uri: uri.to_owned(),
369 headers: Vec::new(),
370 body: Body::empty(),
371 }
372 }
373
374 #[must_use]
376 pub fn header(mut self, name: &str, value: &str) -> Self {
377 self.headers.push((name.to_owned(), value.to_owned()));
378 self
379 }
380
381 #[must_use]
385 pub fn json(mut self, value: &serde_json::Value) -> Self {
386 self.headers
387 .push(("content-type".to_owned(), "application/json".to_owned()));
388 self.body = Body::from(serde_json::to_vec(value).expect("failed to serialize JSON body"));
389 self
390 }
391
392 #[must_use]
396 pub fn form(mut self, body: &str) -> Self {
397 self.headers.push((
398 "content-type".to_owned(),
399 "application/x-www-form-urlencoded".to_owned(),
400 ));
401 self.body = Body::from(body.to_owned());
402 self
403 }
404
405 #[must_use]
407 pub fn body(mut self, body: impl Into<Body>) -> Self {
408 self.body = body.into();
409 self
410 }
411
412 pub async fn send(self) -> TestResponse {
415 let mut builder = Request::builder().method(self.method).uri(&self.uri);
416
417 for (name, value) in &self.headers {
418 builder = builder.header(name.as_str(), value.as_str());
419 }
420
421 let request = builder.body(self.body).expect("failed to build request");
422
423 let response = self.router.oneshot(request).await.expect("request failed");
424
425 let status = response.status();
426 let headers: Vec<(String, String)> = response
427 .headers()
428 .iter()
429 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_owned()))
430 .collect();
431 let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
432 .await
433 .expect("failed to read response body");
434
435 TestResponse {
436 status,
437 headers,
438 body: body_bytes.to_vec(),
439 }
440 }
441}
442
443pub struct TestResponse {
456 pub status: StatusCode,
458 pub headers: Vec<(String, String)>,
460 pub body: Vec<u8>,
462}
463
464impl TestResponse {
465 #[must_use]
471 pub fn text(&self) -> String {
472 String::from_utf8(self.body.clone()).expect("response body is not valid UTF-8")
473 }
474
475 #[must_use]
482 pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
483 serde_json::from_slice(&self.body).expect("failed to parse response body as JSON")
484 }
485
486 #[must_use]
488 pub fn header(&self, name: &str) -> Option<&str> {
489 let name_lower = name.to_lowercase();
490 self.headers
491 .iter()
492 .find(|(k, _)| k.to_lowercase() == name_lower)
493 .map(|(_, v)| v.as_str())
494 }
495
496 #[track_caller]
500 pub fn assert_ok(&self) -> &Self {
501 assert_eq!(
502 self.status,
503 StatusCode::OK,
504 "expected 200 OK, got {}.\nBody: {}",
505 self.status,
506 String::from_utf8_lossy(&self.body)
507 );
508 self
509 }
510
511 #[track_caller]
513 pub fn assert_status(&self, expected: u16) -> &Self {
514 assert_eq!(
515 self.status.as_u16(),
516 expected,
517 "expected status {expected}, got {}.\nBody: {}",
518 self.status,
519 String::from_utf8_lossy(&self.body)
520 );
521 self
522 }
523
524 #[track_caller]
526 pub fn assert_success(&self) -> &Self {
527 assert!(
528 self.status.is_success(),
529 "expected 2xx success, got {}.\nBody: {}",
530 self.status,
531 String::from_utf8_lossy(&self.body)
532 );
533 self
534 }
535
536 #[track_caller]
538 pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
539 let value = self
540 .header(name)
541 .unwrap_or_else(|| panic!("expected header `{name}` to be present"));
542 assert_eq!(
543 value, expected,
544 "header `{name}`: expected `{expected}`, got `{value}`"
545 );
546 self
547 }
548
549 #[track_caller]
551 pub fn assert_header_contains(&self, name: &str, substring: &str) -> &Self {
552 let value = self
553 .header(name)
554 .unwrap_or_else(|| panic!("expected header `{name}` to be present"));
555 assert!(
556 value.contains(substring),
557 "header `{name}`: expected `{value}` to contain `{substring}`"
558 );
559 self
560 }
561
562 #[track_caller]
564 pub fn assert_body_contains(&self, substring: &str) -> &Self {
565 let body = self.text();
566 assert!(
567 body.contains(substring),
568 "expected body to contain `{substring}`.\nBody: {body}"
569 );
570 self
571 }
572
573 #[track_caller]
575 pub fn assert_body_eq(&self, expected: &str) -> &Self {
576 let body = self.text();
577 assert_eq!(body, expected, "body mismatch");
578 self
579 }
580
581 #[track_caller]
583 pub fn assert_json<T, F>(&self, predicate: F) -> &Self
584 where
585 T: serde::de::DeserializeOwned,
586 F: FnOnce(&T),
587 {
588 let value: T = self.json();
589 predicate(&value);
590 self
591 }
592
593 #[track_caller]
595 pub fn assert_body_empty(&self) -> &Self {
596 assert!(
597 self.body.is_empty(),
598 "expected empty body, got {} bytes: {}",
599 self.body.len(),
600 String::from_utf8_lossy(&self.body)
601 );
602 self
603 }
604}
605
/// A Postgres instance running in a testcontainer, plus a connection pool
/// pointing at it.
#[cfg(all(feature = "db", feature = "test-support"))]
pub struct TestDb {
    /// Container handle; never read, only held for as long as the `TestDb`
    /// exists.
    /// NOTE(review): presumably the container is torn down when this handle is
    /// dropped — confirm against the testcontainers documentation.
    _container: testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>,
    /// Connection pool built against `url`.
    pool: Pool<AsyncPgConnection>,
    /// Connection URL of the container's database.
    url: String,
}

#[cfg(all(feature = "db", feature = "test-support"))]
impl TestDb {
    /// Starts a Postgres container with the module's default settings and
    /// builds a connection pool (maximum size 5) for it.
    ///
    /// # Panics
    ///
    /// Panics if the container cannot be started, if its host or the mapped
    /// port for 5432 cannot be determined, or if the pool cannot be built.
    pub async fn new() -> Self {
        use diesel_async::pooled_connection::AsyncDieselConnectionManager;
        use testcontainers::runners::AsyncRunner;
        use testcontainers_modules::postgres::Postgres;

        let container = Postgres::default()
            .start()
            .await
            .expect("failed to start Postgres testcontainer (is Docker running?)");

        // Each step has its own message so a failure names the step that broke.
        let host = container
            .get_host()
            .await
            .expect("failed to get container host");
        let port = container
            .get_host_port_ipv4(5432)
            .await
            .expect("failed to get container port");
        let url = format!("postgres://postgres:postgres@{host}:{port}/postgres");

        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&url);
        let pool = Pool::builder(manager)
            .max_size(5)
            .build()
            .expect("failed to build connection pool");

        Self {
            _container: container,
            pool,
            url,
        }
    }

    /// Returns a process-wide instance, creating it with [`TestDb::new`] the
    /// first time it is requested.
    ///
    /// NOTE(review): the instance lives in a `static`, so it is never dropped;
    /// check whether the container is still cleaned up when the process exits.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TestDb::new`].
    pub async fn shared() -> &'static Self {
        use std::sync::OnceLock;
        use tokio::sync::OnceCell;

        static CELL: OnceLock<OnceCell<TestDb>> = OnceLock::new();
        // The outer `OnceLock` only creates the async cell; the async cell
        // performs the actual (awaited) initialisation.
        let once = CELL.get_or_init(OnceCell::new);
        once.get_or_init(Self::new).await
    }

    /// Returns a clone of the connection pool.
    #[must_use]
    pub fn pool(&self) -> Pool<AsyncPgConnection> {
        self.pool.clone()
    }

    /// Returns the connection URL of the database.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Executes `sql` as a raw query on a connection taken from the pool.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be obtained, or if the query fails (the
    /// message then includes the error and the SQL text).
    pub async fn execute_sql(&self, sql: &str) {
        use diesel_async::RunQueryDsl;
        let mut conn = self.pool.get().await.expect("failed to get connection");
        diesel::sql_query(sql)
            .execute(&mut *conn)
            .await
            .unwrap_or_else(|e| panic!("SQL execution failed: {e}\nSQL: {sql}"));
    }
}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture routes: `GET /hello` returns `"hello"`, `POST /echo` returns the
    /// JSON it was sent, and `POST /create` answers `201` with `"created"`.
    fn test_routes() -> Vec<Route> {
        use axum::routing::{get, post};

        async fn greet() -> &'static str {
            "hello"
        }

        async fn echo(
            axum::Json(payload): axum::Json<serde_json::Value>,
        ) -> axum::Json<serde_json::Value> {
            axum::Json(payload)
        }

        async fn created() -> (StatusCode, &'static str) {
            (StatusCode::CREATED, "created")
        }

        let hello_route = Route {
            method: Method::GET,
            path: "/hello",
            handler: get(greet),
            name: "hello",
            api_doc: crate::openapi::ApiDoc {
                method: "GET",
                path: "/hello",
                operation_id: "hello",
                success_status: 200,
                ..Default::default()
            },
        };

        let echo_route = Route {
            method: Method::POST,
            path: "/echo",
            handler: post(echo),
            name: "echo",
            api_doc: crate::openapi::ApiDoc {
                method: "POST",
                path: "/echo",
                operation_id: "echo",
                success_status: 200,
                ..Default::default()
            },
        };

        let create_route = Route {
            method: Method::POST,
            path: "/create",
            handler: post(created),
            name: "create",
            api_doc: crate::openapi::ApiDoc {
                method: "POST",
                path: "/create",
                operation_id: "create",
                success_status: 201,
                ..Default::default()
            },
        };

        vec![hello_route, echo_route, create_route]
    }

    /// Client over a fresh `TestApp` that serves the fixture routes.
    fn client() -> TestClient {
        TestApp::new().routes(test_routes()).build()
    }

    #[tokio::test]
    async fn test_app_get_request() {
        let response = client().get("/hello").send().await;
        response.assert_ok();
    }

    #[tokio::test]
    async fn test_app_post_json() {
        let payload = serde_json::json!({"key": "value"});
        let response = client().post("/echo").json(&payload).send().await;
        response.assert_ok().assert_body_contains("key");
    }

    #[tokio::test]
    async fn test_response_assert_status() {
        let response = client().post("/create").send().await;
        response.assert_status(201).assert_body_eq("created");
    }

    #[tokio::test]
    async fn test_response_assert_success() {
        let response = client().get("/hello").send().await;
        response.assert_success();
    }

    #[tokio::test]
    async fn test_not_found() {
        let response = client().get("/nonexistent").send().await;
        response.assert_status(404);
    }

    #[tokio::test]
    async fn test_response_json_deserialization() {
        let payload = serde_json::json!({"count": 42});
        let response = client().post("/echo").json(&payload).send().await;

        response
            .assert_ok()
            .assert_json::<serde_json::Value, _>(|v| {
                assert_eq!(v["count"], 42);
            });
    }

    #[tokio::test]
    async fn test_custom_header() {
        client()
            .get("/hello")
            .header("x-custom", "test-value")
            .send()
            .await
            .assert_ok();
    }

    #[tokio::test]
    async fn test_client_default() {
        let _app = TestApp::default();
    }
}