1#![allow(clippy::type_complexity, clippy::too_many_lines)]
2use axum::body::Body;
173use axum::http::{Method, Request, StatusCode};
174use tower::ServiceExt;
175
176use crate::config::AutumnConfig;
177use crate::route::Route;
178
179use crate::state::AppState;
180
181#[cfg(feature = "db")]
182use diesel_async::AsyncPgConnection;
183#[cfg(feature = "db")]
184use diesel_async::RunQueryDsl;
185#[cfg(feature = "db")]
186use diesel_async::pooled_connection::deadpool::Pool;
187
/// Builder for an in-process test instance of the application.
///
/// Collects routes, routers, layers, configuration and optional components;
/// [`TestApp::build`] assembles them into a [`TestClient`]. [`TestApp::new`]
/// starts from `AutumnConfig::default()` with the `test` profile and CSRF
/// protection disabled.
pub struct TestApp {
    /// Plain routes handed to the router builder in `build`.
    routes: Vec<Route>,
    /// Route groups registered under a path prefix with their own layer.
    scoped_groups: Vec<crate::app::ScopedGroup>,
    /// Extra routers merged into the app router (inbound-mail webhook routers
    /// are appended here too).
    merge_routers: Vec<axum::Router<crate::state::AppState>>,
    /// Routers nested under a path prefix: `(prefix, router)`.
    nest_routers: Vec<(String, axum::Router<crate::state::AppState>)>,
    /// Custom layers registered through `layer`.
    custom_layers: Vec<crate::app::CustomLayerRegistration>,
    /// Configuration used to build both the state and the router.
    config: AutumnConfig,
    /// OpenAPI configuration forwarded to the router builder.
    #[cfg(feature = "openapi")]
    openapi: Option<crate::openapi::OpenApiConfig>,
    /// MCP runtime settings (mount path, expose-all flag, endpoint layer).
    #[cfg(feature = "mcp")]
    mcp: Option<crate::mcp::McpRuntime>,
    /// Primary pool from `with_db`; ignored when `transactional` is set.
    #[cfg(feature = "db")]
    pool: Option<Pool<AsyncPgConnection>>,
    /// Replica pool; dropped (state gets `None`) when `transactional` is set.
    #[cfg(feature = "db")]
    replica_pool: Option<Pool<AsyncPgConnection>>,
    /// When true, `build` creates a single-connection pool whose connection
    /// runs inside a test transaction.
    #[cfg(feature = "db")]
    transactional: bool,
    /// Database URL for transactional mode; falls back to the configured
    /// effective primary URL when `None`.
    #[cfg(feature = "db")]
    transactional_url: Option<String>,
    /// Deferred policy/scope registrations, applied to the state's
    /// `PolicyRegistry` during `build`.
    policy_registrations: Vec<TestPolicyRegistration>,
    /// Overrides `config.security.forbidden_response` when set.
    forbidden_response_override: Option<crate::authorization::ForbiddenResponse>,
    /// Installed as a state extension during `build`.
    #[cfg(feature = "mail")]
    mail_interceptor: Option<std::sync::Arc<dyn crate::interceptor::MailInterceptor>>,
    /// Installed as a state extension during `build`.
    job_interceptor: Option<std::sync::Arc<dyn crate::interceptor::JobInterceptor>>,
    /// Installed as a state extension during `build`; in transactional mode it
    /// is composed (first) with the transactional interceptor.
    #[cfg(feature = "db")]
    db_interceptor: Option<std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>>,
    /// Installed as a state extension and used to wrap the channels backend.
    #[cfg(feature = "ws")]
    channels_interceptor: Option<std::sync::Arc<dyn crate::interceptor::ChannelsInterceptor>>,
    /// Installed as a state extension during `build`.
    #[cfg(feature = "oauth2")]
    http_interceptor: Option<std::sync::Arc<dyn crate::interceptor::HttpInterceptor>>,
    /// Shared mock registry created lazily by `http_mock`.
    #[cfg(feature = "http-client")]
    http_mock_registry: Option<std::sync::Arc<crate::http_client::MockRegistry>>,
    /// Callbacks run against the finished state, in registration order.
    state_initializers: Vec<Box<dyn FnOnce(&AppState) + Send>>,
    /// Jobs registered by plugins; a job runtime is started when non-empty.
    jobs: Vec<crate::job::JobInfo>,
    /// Exception filters forwarded to the router builder.
    exception_filters: Vec<std::sync::Arc<dyn crate::middleware::ExceptionFilter>>,
    /// Names of plugins already applied, used to skip duplicates.
    registered_plugins: std::collections::HashSet<String>,
    /// Plugin extensions, shuttled to/from the app builder in `plugin`.
    extensions: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
    /// Clock installed in the state; `SystemClock` is used when `None`.
    clock: Option<std::sync::Arc<dyn crate::time::ClockSource>>,
    /// Type-erased handle to the same clock, handed to `TestClient` so that
    /// `advance_clock` can downcast it.
    clock_as_any: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
    /// API versions, stored in the state as `RegisteredApiVersions`.
    api_versions: Vec<crate::app::ApiVersion>,
    /// Named metrics sources registered in the state's registry.
    metrics_sources: Vec<(String, std::sync::Arc<dyn crate::actuator::MetricsSource>)>,
    /// Health indicators registered in the state's registry:
    /// `(name, group, indicator)`.
    health_indicators: Vec<(
        String,
        crate::actuator::IndicatorGroup,
        std::sync::Arc<dyn crate::actuator::HealthIndicator>,
    )>,
    /// Router whose webhook routes are merged during `build`.
    #[cfg(feature = "inbound-mail")]
    inbound_mail_router: Option<std::sync::Arc<crate::inbound_mail::InboundMailRouter>>,
}
277
278type TestPolicyRegistration = Box<dyn FnOnce(&crate::authorization::PolicyRegistry) + Send>;
279
280impl TestApp {
281 #[must_use]
283 pub fn new() -> Self {
284 let mut config = AutumnConfig::default();
285 config.profile = Some("test".into());
286 config.security.csrf.enabled = false;
288
289 Self {
290 routes: Vec::new(),
291 scoped_groups: Vec::new(),
292 merge_routers: Vec::new(),
293 nest_routers: Vec::new(),
294 custom_layers: Vec::new(),
295 config,
296 #[cfg(feature = "openapi")]
297 openapi: None,
298 #[cfg(feature = "mcp")]
299 mcp: None,
300 #[cfg(feature = "db")]
301 pool: None,
302 #[cfg(feature = "db")]
303 replica_pool: None,
304 #[cfg(feature = "db")]
305 transactional: false,
306 #[cfg(feature = "db")]
307 transactional_url: None,
308 policy_registrations: Vec::new(),
309 forbidden_response_override: None,
310 #[cfg(feature = "mail")]
311 mail_interceptor: None,
312 job_interceptor: None,
313 #[cfg(feature = "db")]
314 db_interceptor: None,
315 #[cfg(feature = "ws")]
316 channels_interceptor: None,
317 #[cfg(feature = "oauth2")]
318 http_interceptor: None,
319 #[cfg(feature = "http-client")]
320 http_mock_registry: None,
321 state_initializers: Vec::new(),
322 jobs: Vec::new(),
323 exception_filters: Vec::new(),
324 registered_plugins: std::collections::HashSet::new(),
325 extensions: std::collections::HashMap::new(),
326 clock: None,
327 clock_as_any: None,
328 api_versions: Vec::new(),
329 metrics_sources: Vec::new(),
330 health_indicators: Vec::new(),
331 #[cfg(feature = "inbound-mail")]
332 inbound_mail_router: None,
333 }
334 }
335
336 #[must_use]
340 pub fn policy<R, P>(mut self, policy: P) -> Self
341 where
342 R: Send + Sync + 'static,
343 P: crate::authorization::Policy<R>,
344 {
345 self.policy_registrations.push(Box::new(move |registry| {
346 registry.register_policy::<R, _>(policy);
347 }));
348 self
349 }
350
351 #[must_use]
355 pub fn scope<R, S>(mut self, scope: S) -> Self
356 where
357 R: Send + Sync + 'static,
358 S: crate::authorization::Scope<R>,
359 {
360 self.policy_registrations.push(Box::new(move |registry| {
361 registry.register_scope::<R, _>(scope);
362 }));
363 self
364 }
365
366 #[cfg(feature = "inbound-mail")]
370 #[must_use]
371 pub fn inbound_mail_router(mut self, router: crate::inbound_mail::InboundMailRouter) -> Self {
372 self.inbound_mail_router = Some(std::sync::Arc::new(router));
373 self
374 }
375
376 #[must_use]
380 pub const fn forbidden_response(
381 mut self,
382 value: crate::authorization::ForbiddenResponse,
383 ) -> Self {
384 self.forbidden_response_override = Some(value);
385 self
386 }
387
388 #[cfg(feature = "openapi")]
395 #[must_use]
396 pub fn openapi(mut self, config: crate::openapi::OpenApiConfig) -> Self {
397 self.openapi = Some(config);
398 self
399 }
400
401 #[cfg(feature = "mcp")]
408 #[must_use]
409 pub fn mount_mcp(mut self, path: impl Into<String>) -> Self {
410 let path = path.into();
411 if let Some(rt) = self.mcp.as_mut() {
412 rt.mount_path = path;
413 } else {
414 self.mcp = Some(crate::mcp::McpRuntime::new(path));
415 }
416 self
417 }
418
419 #[cfg(feature = "mcp")]
424 #[must_use]
425 pub fn expose_all_as_mcp(mut self) -> Self {
426 if let Some(rt) = self.mcp.as_mut() {
427 rt.expose_all = true;
428 } else {
429 let mut rt = crate::mcp::McpRuntime::new("/mcp");
430 rt.expose_all = true;
431 self.mcp = Some(rt);
432 }
433 self
434 }
435
436 #[cfg(feature = "mcp")]
441 #[must_use]
442 pub fn secure_mcp<L>(mut self, layer: L) -> Self
443 where
444 L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
445 L::Service: tower::Service<
446 axum::http::Request<axum::body::Body>,
447 Response = axum::http::Response<axum::body::Body>,
448 Error = std::convert::Infallible,
449 > + Clone
450 + Send
451 + Sync
452 + 'static,
453 <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
454 Send + 'static,
455 {
456 let applier: crate::mcp::McpEndpointLayer = Box::new(move |router| router.layer(layer));
457 if let Some(rt) = self.mcp.as_mut() {
458 rt.endpoint_layer = Some(applier);
459 } else {
460 let mut rt = crate::mcp::McpRuntime::new("/mcp");
461 rt.endpoint_layer = Some(applier);
462 self.mcp = Some(rt);
463 }
464 self
465 }
466
467 #[must_use]
472 pub fn merge(mut self, router: axum::Router<crate::state::AppState>) -> Self {
473 self.merge_routers.push(router);
474 self
475 }
476
477 #[must_use]
479 pub fn scoped<L>(mut self, prefix: &str, layer: L, routes: Vec<Route>) -> Self
480 where
481 L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
482 L::Service: tower::Service<
483 axum::http::Request<axum::body::Body>,
484 Response = axum::http::Response<axum::body::Body>,
485 Error = std::convert::Infallible,
486 > + Clone
487 + Send
488 + Sync
489 + 'static,
490 <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
491 Send + 'static,
492 {
493 self.scoped_groups.push(crate::app::ScopedGroup {
494 prefix: prefix.to_owned(),
495 routes,
496 source: crate::route_listing::RouteSource::User,
497 apply_layer: Box::new(move |router| router.layer(layer)),
498 });
499 self
500 }
501
502 #[must_use]
506 pub fn nest(mut self, path: &str, router: axum::Router<crate::state::AppState>) -> Self {
507 self.nest_routers.push((path.to_owned(), router));
508 self
509 }
510
511 #[must_use]
516 pub fn layer<L: crate::app::IntoAppLayer>(mut self, layer: L) -> Self {
517 self.custom_layers
518 .push(crate::app::CustomLayerRegistration {
519 type_id: std::any::TypeId::of::<L>(),
520 type_name: std::any::type_name::<L>(),
521 apply: Box::new(move |router| layer.apply_to(router)),
522 });
523 self
524 }
525
526 #[cfg(feature = "reporting")]
532 #[must_use]
533 pub fn with_error_reporter<R: crate::reporting::ErrorReporter>(mut self, reporter: R) -> Self {
534 let reporter =
535 std::sync::Arc::new(reporter) as std::sync::Arc<dyn crate::reporting::ErrorReporter>;
536 self.state_initializers.push(Box::new(move |state| {
537 let mut reporters = state
538 .extension::<crate::reporting::RegisteredReporters>()
539 .map(|registered| registered.0.clone())
540 .unwrap_or_default();
541 reporters.push(reporter.clone());
542 state.insert_extension(crate::reporting::RegisteredReporters(reporters));
543 }));
544 self
545 }
546
547 #[must_use]
554 pub const fn idempotent(mut self) -> Self {
555 self.config.idempotency.enabled = Some(true);
556 self
557 }
558
559 #[must_use]
566 pub fn from_router(router: axum::Router, state: AppState) -> TestClient {
567 TestClient {
568 router,
569 probes: crate::probe::ProbeState::ready_for_test(),
570 state,
571 _job_runtime: None,
572 clock_as_any: None,
573 }
574 }
575
576 #[must_use]
578 pub fn routes(mut self, routes: Vec<Route>) -> Self {
579 self.routes.extend(routes);
580 self
581 }
582
583 #[must_use]
585 pub fn state_initializer<F>(mut self, f: F) -> Self
586 where
587 F: FnOnce(&AppState) + Send + 'static,
588 {
589 self.state_initializers.push(Box::new(f));
590 self
591 }
592
593 #[must_use]
598 pub fn with_flag_store<S>(mut self, store: S) -> Self
599 where
600 S: crate::feature_flags::FlagStore,
601 {
602 use std::sync::Arc;
603 let service = crate::feature_flags::FeatureFlagService::new(Arc::new(store) as Arc<_>);
604 self.state_initializers.push(Box::new(move |state| {
605 state.insert_extension(service);
606 }));
607 self
608 }
609
610 #[must_use]
612 pub fn plugin<P: crate::plugin::Plugin>(mut self, plugin: P) -> Self {
613 let name = plugin.name().into_owned();
614 if self.registered_plugins.contains(&name) {
615 tracing::warn!(plugin = %name, "Duplicate plugin registration in TestApp; skipping");
616 return self;
617 }
618
619 let mut app_builder = crate::app();
620 app_builder
621 .registered_plugins
622 .clone_from(&self.registered_plugins);
623 app_builder.extensions = self.extensions;
624 app_builder.state_initializers = std::mem::take(&mut self.state_initializers);
625
626 app_builder = app_builder.plugin(plugin);
627
628 self.registered_plugins = app_builder.registered_plugins;
629 self.extensions = app_builder.extensions;
630 self.state_initializers = app_builder.state_initializers;
631
632 self.routes.extend(app_builder.routes);
634 self.scoped_groups.extend(app_builder.scoped_groups);
635 self.merge_routers.extend(app_builder.merge_routers);
636 self.nest_routers.extend(app_builder.nest_routers);
637 self.custom_layers.extend(app_builder.custom_layers);
638 self.jobs.extend(app_builder.jobs);
639 self.exception_filters.extend(app_builder.exception_filters);
640 self.metrics_sources.extend(app_builder.metrics_sources);
641 self.health_indicators.extend(app_builder.health_indicators);
642 #[cfg(feature = "inbound-mail")]
645 if let Some(router) = app_builder.inbound_mail_router {
646 self.inbound_mail_router = Some(router);
647 }
648
649 #[cfg(feature = "reporting")]
653 {
654 let reporters = std::mem::take(&mut app_builder.error_reporters);
655 if !reporters.is_empty() {
656 self.state_initializers.push(Box::new(move |state| {
657 let mut existing = state
658 .extension::<crate::reporting::RegisteredReporters>()
659 .map(|registered| registered.0.clone())
660 .unwrap_or_default();
661 existing.extend(reporters.iter().cloned());
662 state.insert_extension(crate::reporting::RegisteredReporters(existing));
663 }));
664 }
665 }
666
667 for hook in app_builder.startup_hooks {
668 self.state_initializers.push(Box::new(move |state| {
669 let state_owned = state.clone();
670 if let Ok(handle) = tokio::runtime::Handle::try_current() {
671 let thread_handle =
672 std::thread::spawn(move || handle.block_on(hook(state_owned)));
673 thread_handle
674 .join()
675 .expect("Plugin startup hook thread panicked")
676 .expect("Plugin startup hook failed");
677 } else {
678 let thread_handle = std::thread::spawn(move || {
679 let rt = tokio::runtime::Builder::new_multi_thread()
680 .enable_all()
681 .build()
682 .expect("failed to build tokio runtime for test plugin startup hook");
683 rt.block_on(hook(state_owned))
684 });
685 thread_handle
686 .join()
687 .expect("Plugin startup hook thread panicked")
688 .expect("Plugin startup hook failed");
689 }
690 }));
691 }
692 self
693 }
694
695 #[cfg(feature = "mail")]
696 #[must_use]
697 pub fn with_mail_interceptor(
698 mut self,
699 interceptor: impl crate::interceptor::MailInterceptor,
700 ) -> Self {
701 self.mail_interceptor = Some(std::sync::Arc::new(interceptor));
702 self
703 }
704
705 #[must_use]
706 pub fn with_job_interceptor(
707 mut self,
708 interceptor: impl crate::interceptor::JobInterceptor,
709 ) -> Self {
710 self.job_interceptor = Some(std::sync::Arc::new(interceptor));
711 self
712 }
713
714 #[cfg(feature = "db")]
715 #[must_use]
716 pub fn with_db_interceptor(
717 mut self,
718 interceptor: impl crate::interceptor::DbConnectionInterceptor,
719 ) -> Self {
720 self.db_interceptor = Some(std::sync::Arc::new(interceptor));
721 self
722 }
723
724 #[cfg(feature = "ws")]
725 #[must_use]
726 pub fn with_channels_interceptor(
727 mut self,
728 interceptor: impl crate::interceptor::ChannelsInterceptor,
729 ) -> Self {
730 self.channels_interceptor = Some(std::sync::Arc::new(interceptor));
731 self
732 }
733
734 #[cfg(feature = "oauth2")]
735 #[must_use]
736 pub fn with_http_interceptor(
737 mut self,
738 interceptor: impl crate::interceptor::HttpInterceptor,
739 ) -> Self {
740 self.http_interceptor = Some(std::sync::Arc::new(interceptor));
741 self
742 }
743
744 #[must_use]
746 pub fn config(mut self, config: AutumnConfig) -> Self {
747 self.config = config;
748 self
749 }
750
751 #[must_use]
753 pub fn profile(mut self, profile: &str) -> Self {
754 self.config.profile = Some(profile.to_owned());
755 self
756 }
757
758 #[must_use]
784 pub fn with_clock<C>(mut self, clock: C) -> Self
785 where
786 C: crate::time::ClockSource + 'static,
787 {
788 let arc: std::sync::Arc<C> = std::sync::Arc::new(clock);
789 self.clock_as_any = Some(arc.clone() as std::sync::Arc<dyn std::any::Any + Send + Sync>);
791 self.clock = Some(arc as std::sync::Arc<dyn crate::time::ClockSource>);
792 self
793 }
794
795 #[must_use]
797 pub fn api_version(mut self, version: crate::app::ApiVersion) -> Self {
798 self.api_versions.push(version);
799 self
800 }
801
802 #[must_use]
804 pub fn api_versions(
805 mut self,
806 versions: impl IntoIterator<Item = crate::app::ApiVersion>,
807 ) -> Self {
808 self.api_versions.extend(versions);
809 self
810 }
811
812 #[cfg(feature = "db")]
814 #[must_use]
815 pub fn with_db(mut self, pool: Pool<AsyncPgConnection>) -> Self {
816 self.pool = Some(pool);
817 self
818 }
819
820 #[cfg(feature = "db")]
823 #[must_use]
824 pub const fn transactional(mut self) -> Self {
825 self.transactional = true;
826 self
827 }
828
829 #[cfg(feature = "db")]
831 #[must_use]
832 pub fn with_transactional_db(mut self, url: impl Into<String>) -> Self {
833 self.transactional = true;
834 self.transactional_url = Some(url.into());
835 self
836 }
837
838 #[cfg(feature = "http-client")]
870 pub fn http_mock(&mut self, alias: &str) -> crate::http_client::MockSetupBuilder {
871 let registry = self
872 .http_mock_registry
873 .get_or_insert_with(|| std::sync::Arc::new(crate::http_client::MockRegistry::new()))
874 .clone();
875
876 crate::http_client::MockSetupBuilder {
877 registry,
878 alias: alias.to_owned(),
879 method: None,
880 path: None,
881 }
882 }
883
884 #[must_use]
895 #[cfg_attr(not(feature = "inbound-mail"), allow(unused_mut))]
896 pub fn build(mut self) -> TestClient {
897 crate::cache::clear_global_cache();
899
900 #[cfg(feature = "db")]
901 let (pool, replica_pool, db_interceptor) = if self.transactional {
902 let url = self.transactional_url.as_deref()
903 .or_else(|| self.config.database.effective_primary_url())
904 .expect("Transactional isolation enabled but database URL is not configured. Use `with_transactional_db(url)` or configure database.primary_url/database.url");
905
906 let connect_timeout_secs = self.config.database.connect_timeout_secs;
907 let timeout = std::time::Duration::from_secs(connect_timeout_secs);
908
909 let manager = diesel_async::pooled_connection::AsyncDieselConnectionManager::<
910 diesel_async::AsyncPgConnection,
911 >::new(url);
912 let pool = Pool::builder(manager)
913 .max_size(1)
914 .wait_timeout(Some(timeout))
915 .create_timeout(Some(timeout))
916 .runtime(deadpool::Runtime::Tokio1)
917 .post_create(deadpool::managed::Hook::async_fn(
918 |conn: &mut diesel_async::AsyncPgConnection, _metrics| {
919 Box::pin(async move {
920 use diesel_async::AsyncConnection;
921 use diesel_async::RunQueryDsl;
922
923 conn.begin_test_transaction().await.map_err(|e| {
924 deadpool::managed::HookError::Backend(
925 diesel_async::pooled_connection::PoolError::QueryError(e),
926 )
927 })?;
928
929 diesel::sql_query("SET autumn.test_transaction_started = 'true'")
930 .execute(conn)
931 .await
932 .map_err(|e| {
933 deadpool::managed::HookError::Backend(
934 diesel_async::pooled_connection::PoolError::QueryError(e),
935 )
936 })?;
937
938 Ok(())
939 })
940 },
941 ))
942 .build()
943 .expect("failed to build transactional pool of size 1");
944
945 let trans_interceptor = std::sync::Arc::new(TransactionalDbInterceptor);
946 let interceptor = if let Some(user_interceptor) = self.db_interceptor {
947 std::sync::Arc::new(ComposedDbInterceptor {
948 first: user_interceptor,
949 second: trans_interceptor,
950 })
951 as std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>
952 } else {
953 trans_interceptor as std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>
954 };
955
956 (Some(pool), None, Some(interceptor))
957 } else {
958 (self.pool, self.replica_pool, self.db_interceptor)
959 };
960
961 let probes = crate::probe::ProbeState::ready_for_test();
962 #[cfg(feature = "ws")]
963 let test_channels = crate::channels::Channels::new(32);
964 #[cfg_attr(not(feature = "ws"), allow(unused_mut))]
965 let mut state = AppState {
966 extensions: std::sync::Arc::new(std::sync::RwLock::new(
967 std::collections::HashMap::new(),
968 )),
969 #[cfg(feature = "db")]
970 pool,
971 #[cfg(feature = "db")]
972 replica_pool,
973 profile: self.config.profile.clone(),
974 started_at: std::time::Instant::now(),
975 health_detailed: self.config.health.detailed,
976 probes: probes.clone(),
977 metrics: crate::middleware::MetricsCollector::new(),
978 log_levels: crate::actuator::LogLevels::new(&self.config.log.level),
979 task_registry: crate::actuator::TaskRegistry::new(),
980 job_registry: crate::actuator::JobRegistry::new(),
981 config_props: crate::actuator::ConfigProperties::default(),
982 metrics_source_registry: crate::actuator::MetricsSourceRegistry::new(),
983 health_indicator_registry: crate::actuator::HealthIndicatorRegistry::new(),
984 #[cfg(feature = "presence")]
985 presence: crate::presence::Presence::new(test_channels.clone()),
986 #[cfg(feature = "ws")]
987 channels: test_channels,
988
989 #[cfg(feature = "ws")]
990 shutdown: tokio_util::sync::CancellationToken::new(),
991 policy_registry: crate::authorization::PolicyRegistry::default(),
992 forbidden_response: self
993 .forbidden_response_override
994 .unwrap_or(self.config.security.forbidden_response),
995 auth_session_key: self.config.auth.session_key.clone(),
996 shared_cache: None,
997 clock: self
998 .clock
999 .unwrap_or_else(|| std::sync::Arc::new(crate::time::SystemClock)),
1000 };
1001
1002 for register in self.policy_registrations {
1003 register(state.policy_registry());
1004 }
1005 state.insert_extension(crate::app::RegisteredApiVersions(self.api_versions));
1006 crate::app::install_webhook_registry(&state, &self.config);
1007
1008 state.insert_extension(self.config.clone());
1011
1012 #[cfg(feature = "mail")]
1013 if let Some(interceptor) = self.mail_interceptor {
1014 state.insert_extension(interceptor);
1015 }
1016 if let Some(interceptor) = self.job_interceptor {
1017 state.insert_extension(interceptor);
1018 }
1019 #[cfg(feature = "db")]
1020 if let Some(interceptor) = db_interceptor {
1021 state.insert_extension(interceptor);
1022 }
1023 #[cfg(feature = "ws")]
1024 if let Some(interceptor) = self.channels_interceptor {
1025 state.insert_extension(interceptor.clone());
1026 state.channels = crate::channels::Channels::with_shared_backend(std::sync::Arc::new(
1027 crate::channels::InterceptedChannelsBackend::new(
1028 state.channels.backend().clone(),
1029 vec![interceptor],
1030 ),
1031 ));
1032 #[cfg(feature = "presence")]
1033 {
1034 state.presence = crate::presence::Presence::new(state.channels.clone());
1035 }
1036 }
1037 #[cfg(feature = "oauth2")]
1038 if let Some(interceptor) = self.http_interceptor {
1039 state.insert_extension(interceptor);
1040 }
1041
1042 #[cfg(feature = "mail")]
1043 {
1044 crate::mail::install_mailer(&state, &self.config.mail, false)
1045 .expect("Failed to configure test mailer");
1046 }
1047
1048 #[cfg(feature = "http-client")]
1050 state.insert_extension(self.config.http.clone());
1051
1052 #[cfg(feature = "http-client")]
1054 if let Some(registry) = self.http_mock_registry {
1055 state.insert_extension(crate::http_client::HttpMockRegistryExt(registry));
1056 }
1057
1058 for (name, source) in self.metrics_sources {
1061 if let Err(e) = state.metrics_source_registry.register(name, source) {
1062 tracing::warn!("{e}");
1063 }
1064 }
1065 for (name, group, indicator) in self.health_indicators {
1066 if let Err(e) = state
1067 .health_indicator_registry
1068 .register(name, group, indicator)
1069 {
1070 tracing::warn!("{e}");
1071 }
1072 }
1073
1074 for initializer in self.state_initializers {
1075 initializer(&state);
1076 }
1077
1078 for job in &self.jobs {
1079 state.job_registry.register(&job.name);
1080 }
1081
1082 let job_runtime = if self.jobs.is_empty() {
1083 None
1084 } else {
1085 let shutdown = tokio_util::sync::CancellationToken::new();
1086 crate::job::start_runtime(self.jobs.clone(), &state, &shutdown, &self.config.jobs)
1087 .expect("Failed to start job runtime in test");
1088 Some(TestJobRuntime { shutdown })
1089 };
1090
1091 #[cfg_attr(not(feature = "inbound-mail"), allow(unused_mut))]
1092 let mut merge_routers = self.merge_routers;
1093 #[cfg(feature = "inbound-mail")]
1094 if let Some(ref im_router) = self.inbound_mail_router {
1095 let mut registered_inbound: std::collections::HashSet<String> =
1096 std::collections::HashSet::new();
1097 for (path, axum_router) in crate::inbound_mail::build_routes(im_router) {
1098 if self
1099 .routes
1100 .iter()
1101 .any(|r| r.method == Method::POST && r.path == path)
1102 || self.scoped_groups.iter().any(|g| {
1103 g.routes.iter().any(|r| {
1104 r.method == Method::POST
1105 && crate::router::join_nested_path(&g.prefix, r.path)
1106 == path.as_str()
1107 })
1108 })
1109 || self.nest_routers.iter().any(|(nest_path, _)| {
1110 let p = nest_path.as_str();
1111 path.as_str() == p
1112 || path.starts_with(p)
1113 && (p.ends_with('/') || path.as_bytes().get(p.len()) == Some(&b'/'))
1114 })
1115 {
1116 tracing::warn!(
1117 path = %path,
1118 "inbound_mail: skipping webhook route — a POST handler is \
1119 already registered at this path by the application"
1120 );
1121 continue;
1122 }
1123 if !registered_inbound.insert(path.clone()) {
1124 tracing::warn!(
1125 path = %path,
1126 "inbound_mail: skipping duplicate inbound webhook path"
1127 );
1128 continue;
1129 }
1130 self.config.security.csrf.exempt_paths.push(path.clone());
1131 self.config.security.captcha_exempt_paths.push(path);
1132 merge_routers.push(axum_router);
1133 }
1134 }
1135
1136 let router = crate::router::try_build_router_inner(
1137 self.routes,
1138 &self.config,
1139 state.clone(),
1140 crate::router::RouterContext {
1141 exception_filters: self.exception_filters,
1142 scoped_groups: self.scoped_groups,
1143 merge_routers,
1144 nest_routers: self.nest_routers,
1145 custom_layers: self.custom_layers,
1146 error_page_renderer: None,
1147 session_store: None,
1148 #[cfg(feature = "openapi")]
1149 openapi: self.openapi,
1150 #[cfg(feature = "mcp")]
1151 mcp: self.mcp,
1152 },
1153 )
1154 .expect("failed to build test router");
1155 let router = if self.config.log.access_log {
1162 router.layer(crate::middleware::AccessLogLayer::fallback(
1163 self.config.log.access_log_exclude.clone(),
1164 ))
1165 } else {
1166 router
1167 };
1168 TestClient {
1169 router,
1170 probes,
1171 state,
1172 _job_runtime: job_runtime,
1173 clock_as_any: self.clock_as_any,
1174 }
1175 }
1176}
1177
1178impl Default for TestApp {
1179 fn default() -> Self {
1180 Self::new()
1181 }
1182}
1183
/// Handle for sending in-process requests to an application produced by
/// [`TestApp::build`] or wrapped with [`TestApp::from_router`].
pub struct TestClient {
    /// Fully built router; each request builder receives a clone.
    router: axum::Router,
    /// Probe state (`ProbeState::ready_for_test()` in both constructors).
    probes: crate::probe::ProbeState,
    /// Application state the router was built with.
    pub(crate) state: AppState,
    /// Keeps the job runtime alive. Dropping it cancels the runtime's shutdown
    /// token and clears the global job client. `None` when no jobs were
    /// registered or the client came from `from_router`.
    _job_runtime: Option<TestJobRuntime>,
    /// Type-erased clock from `TestApp::with_clock`, which `advance_clock`
    /// downcasts to `TickingClock`.
    clock_as_any: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
}
1223
/// Guard for the job runtime started by [`TestApp::build`]; see its `Drop`
/// impl for what happens when it goes away.
struct TestJobRuntime {
    /// Token passed to `crate::job::start_runtime` and cancelled on drop.
    shutdown: tokio_util::sync::CancellationToken,
}
1227
1228impl Drop for TestJobRuntime {
1229 fn drop(&mut self) {
1230 self.shutdown.cancel();
1231 crate::job::clear_global_job_client();
1232 }
1233}
1234
1235impl TestClient {
1236 #[must_use]
1238 pub const fn state(&self) -> &AppState {
1239 &self.state
1240 }
1241
1242 pub fn advance_clock(&self, duration: std::time::Duration) {
1268 if let Some(any) = &self.clock_as_any {
1269 let cloned = std::sync::Arc::clone(any);
1270 if let Ok(ticking) = cloned.downcast::<crate::time::TickingClock>() {
1271 ticking.advance(duration);
1272 }
1273 }
1275 }
1277
1278 pub fn into_router(self) -> axum::Router {
1280 self.router
1281 }
1282
1283 pub const fn probes(&self) -> &crate::probe::ProbeState {
1288 &self.probes
1289 }
1290
1291 #[must_use]
1293 pub fn get(&self, uri: &str) -> RequestBuilder {
1294 RequestBuilder::new(self.router.clone(), Method::GET, uri)
1295 }
1296
1297 #[must_use]
1299 pub fn post(&self, uri: &str) -> RequestBuilder {
1300 RequestBuilder::new(self.router.clone(), Method::POST, uri)
1301 }
1302
1303 #[must_use]
1305 pub fn put(&self, uri: &str) -> RequestBuilder {
1306 RequestBuilder::new(self.router.clone(), Method::PUT, uri)
1307 }
1308
1309 #[must_use]
1311 pub fn delete(&self, uri: &str) -> RequestBuilder {
1312 RequestBuilder::new(self.router.clone(), Method::DELETE, uri)
1313 }
1314
1315 #[must_use]
1317 pub fn patch(&self, uri: &str) -> RequestBuilder {
1318 RequestBuilder::new(self.router.clone(), Method::PATCH, uri)
1319 }
1320
1321 #[must_use]
1323 pub fn options(&self, uri: &str) -> RequestBuilder {
1324 RequestBuilder::new(self.router.clone(), Method::OPTIONS, uri)
1325 }
1326}
1327
/// A single request being prepared by a [`TestClient`]; dispatch it with
/// [`RequestBuilder::send`].
pub struct RequestBuilder {
    /// Router the request is sent to (a clone of the client's).
    router: axum::Router,
    /// HTTP method.
    method: Method,
    /// Request URI (path plus optional query), parsed in `send`.
    uri: String,
    /// Headers in insertion order; repeated names are not deduplicated here.
    headers: Vec<(String, String)>,
    /// Request body; empty unless `json`, `form` or `body` set it.
    body: Body,
}
1342
1343impl RequestBuilder {
1344 fn new(router: axum::Router, method: Method, uri: &str) -> Self {
1345 Self {
1346 router,
1347 method,
1348 uri: uri.to_owned(),
1349 headers: Vec::new(),
1350 body: Body::empty(),
1351 }
1352 }
1353
1354 #[must_use]
1356 pub fn header(mut self, name: &str, value: &str) -> Self {
1357 self.headers.push((name.to_owned(), value.to_owned()));
1358 self
1359 }
1360
1361 #[must_use]
1365 pub fn json(mut self, value: &serde_json::Value) -> Self {
1366 self.headers
1367 .push(("content-type".to_owned(), "application/json".to_owned()));
1368 self.body = Body::from(serde_json::to_vec(value).expect("failed to serialize JSON body"));
1369 self
1370 }
1371
1372 #[must_use]
1380 pub fn form(mut self, body: &str) -> Self {
1381 self.headers.push((
1382 "content-type".to_owned(),
1383 "application/x-www-form-urlencoded".to_owned(),
1384 ));
1385 self.headers
1386 .push(("sec-fetch-site".to_owned(), "same-origin".to_owned()));
1387 self.body = Body::from(body.to_owned());
1388 self
1389 }
1390
1391 #[must_use]
1393 pub fn body(mut self, body: impl Into<Body>) -> Self {
1394 self.body = body.into();
1395 self
1396 }
1397
1398 pub async fn send(self) -> TestResponse {
1401 let mut builder = Request::builder().method(self.method).uri(&self.uri);
1402
1403 for (name, value) in &self.headers {
1404 builder = builder.header(name.as_str(), value.as_str());
1405 }
1406
1407 let request = builder.body(self.body).expect("failed to build request");
1408
1409 let service =
1415 tower::Layer::layer(&crate::middleware::MethodOverrideLayer::new(), self.router);
1416 let response = service.oneshot(request).await.expect("request failed");
1417
1418 let status = response.status();
1419 let headers: Vec<(String, String)> = response
1420 .headers()
1421 .iter()
1422 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_owned()))
1423 .collect();
1424 let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
1425 .await
1426 .expect("failed to read response body");
1427
1428 TestResponse {
1429 status,
1430 headers,
1431 body: body_bytes.to_vec(),
1432 }
1433 }
1434}
1435
/// Fully buffered response returned by [`RequestBuilder::send`].
pub struct TestResponse {
    /// Response status code.
    pub status: StatusCode,
    /// Response headers as `(name, value)` pairs in the order the response
    /// map yields them. Names come from `HeaderName::to_string`.
    pub headers: Vec<(String, String)>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
1479
1480impl TestResponse {
1481 #[must_use]
1487 pub fn text(&self) -> String {
1488 String::from_utf8(self.body.clone()).unwrap_or_else(|e| {
1489 panic!(
1490 "response body is not valid UTF-8: {e}\nRaw bytes: {:?}",
1491 self.body
1492 )
1493 })
1494 }
1495
1496 #[must_use]
1503 pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
1504 serde_json::from_slice(&self.body).unwrap_or_else(|e| {
1505 panic!(
1506 "failed to parse response body as JSON: {e}\nBody: {}",
1507 String::from_utf8_lossy(&self.body)
1508 )
1509 })
1510 }
1511
1512 #[must_use]
1514 pub fn header(&self, name: &str) -> Option<&str> {
1515 let name_lower = name.to_lowercase();
1516 self.headers
1517 .iter()
1518 .find(|(k, _)| k.to_lowercase() == name_lower)
1519 .map(|(_, v)| v.as_str())
1520 }
1521
1522 #[track_caller]
1526 pub fn assert_ok(&self) -> &Self {
1527 assert_eq!(
1528 self.status,
1529 StatusCode::OK,
1530 "expected 200 OK, got {}.\nBody: {}",
1531 self.status,
1532 String::from_utf8_lossy(&self.body)
1533 );
1534 self
1535 }
1536
1537 #[track_caller]
1539 pub fn assert_status(&self, expected: u16) -> &Self {
1540 assert_eq!(
1541 self.status.as_u16(),
1542 expected,
1543 "expected status {expected}, got {}.\nBody: {}",
1544 self.status,
1545 String::from_utf8_lossy(&self.body)
1546 );
1547 self
1548 }
1549
1550 #[track_caller]
1552 pub fn assert_success(&self) -> &Self {
1553 assert!(
1554 self.status.is_success(),
1555 "expected 2xx success, got {}.\nBody: {}",
1556 self.status,
1557 String::from_utf8_lossy(&self.body)
1558 );
1559 self
1560 }
1561
1562 #[track_caller]
1564 pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
1565 let value = self.header(name).unwrap_or_else(|| {
1566 panic!(
1567 "expected header `{name}` to be present.\nAvailable headers: {:?}",
1568 self.headers
1569 )
1570 });
1571 assert_eq!(
1572 value, expected,
1573 "header `{name}`: expected `{expected}`, got `{value}`"
1574 );
1575 self
1576 }
1577
1578 #[track_caller]
1580 pub fn assert_header_contains(&self, name: &str, substring: &str) -> &Self {
1581 let value = self.header(name).unwrap_or_else(|| {
1582 panic!(
1583 "expected header `{name}` to be present.\nAvailable headers: {:?}",
1584 self.headers
1585 )
1586 });
1587 assert!(
1588 value.contains(substring),
1589 "header `{name}`: expected `{value}` to contain `{substring}`"
1590 );
1591 self
1592 }
1593
1594 #[track_caller]
1596 pub fn assert_body_contains(&self, substring: &str) -> &Self {
1597 let body = self.text();
1598 assert!(
1599 body.contains(substring),
1600 "expected body to contain `{substring}`.\nBody: {body}"
1601 );
1602 self
1603 }
1604
1605 #[track_caller]
1607 pub fn assert_body_eq(&self, expected: &str) -> &Self {
1608 let body = self.text();
1609 assert_eq!(body, expected, "body mismatch.\nActual Body: {body}");
1610 self
1611 }
1612
1613 #[track_caller]
1615 pub fn assert_json<T, F>(&self, predicate: F) -> &Self
1616 where
1617 T: serde::de::DeserializeOwned,
1618 F: FnOnce(&T),
1619 {
1620 let value: T = self.json();
1621 predicate(&value);
1622 self
1623 }
1624
1625 #[track_caller]
1627 pub fn assert_body_empty(&self) -> &Self {
1628 assert!(
1629 self.body.is_empty(),
1630 "expected empty body, got {} bytes: {}",
1631 self.body.len(),
1632 String::from_utf8_lossy(&self.body)
1633 );
1634 self
1635 }
1636
1637 fn parse_html(&self) -> Vec<crate::test_html::Node> {
1661 crate::test_html::parse(&self.text())
1662 }
1663
1664 #[track_caller]
1667 fn compile_selector(css: &str) -> crate::test_html::SelectorList {
1668 crate::test_html::SelectorList::parse(css)
1669 .unwrap_or_else(|e| panic!("invalid CSS selector `{css}`: {e}"))
1670 }
1671
1672 fn html_outline(nodes: &[crate::test_html::Node]) -> String {
1674 crate::test_html::outline(nodes, 1200)
1675 }
1676
1677 #[must_use]
1683 #[track_caller]
1684 pub fn selector_text(&self, css: &str) -> Vec<String> {
1685 let selector = Self::compile_selector(css);
1686 let nodes = self.parse_html();
1687 selector
1688 .matches(&nodes)
1689 .iter()
1690 .map(|el| crate::test_html::normalize_ws(&el.text()))
1691 .collect()
1692 }
1693
1694 #[must_use]
1698 #[track_caller]
1699 pub fn selector_attr(&self, css: &str, attr: &str) -> Vec<Option<String>> {
1700 let selector = Self::compile_selector(css);
1701 let nodes = self.parse_html();
1702 selector
1703 .matches(&nodes)
1704 .iter()
1705 .map(|el| el.attr(attr).map(str::to_string))
1706 .collect()
1707 }
1708
1709 #[must_use]
1711 #[track_caller]
1712 pub fn selector_count(&self, css: &str) -> usize {
1713 let selector = Self::compile_selector(css);
1714 let nodes = self.parse_html();
1715 selector.matches(&nodes).len()
1716 }
1717
1718 #[track_caller]
1720 pub fn assert_selector(&self, css: &str) -> &Self {
1721 let selector = Self::compile_selector(css);
1722 let nodes = self.parse_html();
1723 let count = selector.matches(&nodes).len();
1724 assert!(
1725 count > 0,
1726 "no elements matched selector `{css}`.\nParsed HTML:\n{}",
1727 Self::html_outline(&nodes)
1728 );
1729 self
1730 }
1731
1732 #[track_caller]
1734 pub fn assert_no_selector(&self, css: &str) -> &Self {
1735 let selector = Self::compile_selector(css);
1736 let nodes = self.parse_html();
1737 let count = selector.matches(&nodes).len();
1738 assert!(
1739 count == 0,
1740 "expected no elements matching selector `{css}`, but found {count}.\nParsed HTML:\n{}",
1741 Self::html_outline(&nodes)
1742 );
1743 self
1744 }
1745
1746 #[track_caller]
1748 pub fn assert_selector_count(&self, css: &str, expected: usize) -> &Self {
1749 let selector = Self::compile_selector(css);
1750 let nodes = self.parse_html();
1751 let actual = selector.matches(&nodes).len();
1752 assert!(
1753 actual == expected,
1754 "expected {expected} element(s) matching selector `{css}`, found {actual}.\n\
1755 Parsed HTML:\n{}",
1756 Self::html_outline(&nodes)
1757 );
1758 self
1759 }
1760
1761 #[track_caller]
1764 pub fn assert_text(&self, css: &str, expected: &str) -> &Self {
1765 let selector = Self::compile_selector(css);
1766 let nodes = self.parse_html();
1767 let matched = selector.matches(&nodes);
1768 let Some(first) = matched.into_iter().next() else {
1769 panic!(
1770 "no elements matched selector `{css}`.\nParsed HTML:\n{}",
1771 Self::html_outline(&nodes)
1772 );
1773 };
1774 let actual = crate::test_html::normalize_ws(&first.text());
1775 let expected_norm = crate::test_html::normalize_ws(expected);
1776 assert!(
1777 actual == expected_norm,
1778 "text mismatch for selector `{css}`:\n expected: {expected_norm:?}\n \
1779 actual: {actual:?}\nParsed HTML:\n{}",
1780 Self::html_outline(&nodes)
1781 );
1782 self
1783 }
1784
1785 #[track_caller]
1788 pub fn assert_text_contains(&self, css: &str, substring: &str) -> &Self {
1789 let selector = Self::compile_selector(css);
1790 let nodes = self.parse_html();
1791 let matched = selector.matches(&nodes);
1792 let Some(first) = matched.into_iter().next() else {
1793 panic!(
1794 "no elements matched selector `{css}`.\nParsed HTML:\n{}",
1795 Self::html_outline(&nodes)
1796 );
1797 };
1798 let actual = crate::test_html::normalize_ws(&first.text());
1799 let needle = crate::test_html::normalize_ws(substring);
1800 assert!(
1801 actual.contains(&needle),
1802 "text for selector `{css}` did not contain {needle:?}.\n actual: {actual:?}\n\
1803 Parsed HTML:\n{}",
1804 Self::html_outline(&nodes)
1805 );
1806 self
1807 }
1808
1809 #[track_caller]
1812 pub fn assert_attr(&self, css: &str, attr: &str, expected: &str) -> &Self {
1813 let selector = Self::compile_selector(css);
1814 let nodes = self.parse_html();
1815 let matched = selector.matches(&nodes);
1816 let Some(first) = matched.into_iter().next() else {
1817 panic!(
1818 "no elements matched selector `{css}`.\nParsed HTML:\n{}",
1819 Self::html_outline(&nodes)
1820 );
1821 };
1822 match first.attr(attr) {
1823 Some(actual) => assert!(
1824 actual == expected,
1825 "attribute `{attr}` mismatch for selector `{css}`:\n expected: {expected:?}\n \
1826 actual: {actual:?}\nParsed HTML:\n{}",
1827 Self::html_outline(&nodes)
1828 ),
1829 None => panic!(
1830 "element matching selector `{css}` has no `{attr}` attribute.\n\
1831 Parsed HTML:\n{}",
1832 Self::html_outline(&nodes)
1833 ),
1834 }
1835 self
1836 }
1837}
1838
/// Checkout interceptor that places each pooled connection inside a diesel
/// test transaction, which is never committed.
///
/// On checkout it reads the session setting `autumn.test_transaction_started`.
/// If the setting is not `"true"`, it calls `begin_test_transaction` and then
/// sets the marker, so later checkouts of the same connection skip this step.
#[cfg(feature = "db")]
struct TransactionalDbInterceptor;

#[cfg(feature = "db")]
impl crate::interceptor::DbConnectionInterceptor for TransactionalDbInterceptor {
    fn intercept_checkout<'a>(
        &'a self,
        _ctx: crate::interceptor::DbCheckoutContext,
        next: std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                    > + Send
                    + 'a,
            >,
        >,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                > + Send
                + 'a,
        >,
    > {
        Box::pin(async move {
            let mut conn = next.await?;

            // With `missing_ok = true`, `current_setting` yields NULL rather than an
            // error when the setting was never defined on this session.
            let marker = diesel::select(diesel::dsl::sql::<
                diesel::sql_types::Nullable<diesel::sql_types::Text>,
            >(
                "current_setting('autumn.test_transaction_started', true)",
            ))
            .get_result::<Option<String>>(&mut *conn)
            .await;

            let needs_test_transaction = match marker {
                // The marker already reads "true": the transaction was opened on an
                // earlier checkout of this connection.
                Ok(Some(value)) => value != "true",
                // Never set on this session: open the transaction now.
                Ok(None) => true,
                // The probe itself failed; the connection is handed back untouched.
                // NOTE(review): presumably deliberate — e.g. a connection whose test
                // transaction has entered an aborted state rejects every query — confirm.
                Err(_) => false,
            };

            if needs_test_transaction {
                use diesel_async::AsyncConnection;
                use diesel_async::RunQueryDsl;

                conn.begin_test_transaction().await.map_err(|e| {
                    crate::AutumnError::internal_server_error_msg(format!(
                        "failed to start test transaction: {e}"
                    ))
                })?;

                // Record on the session that the test transaction is open.
                diesel::sql_query("SET autumn.test_transaction_started = 'true'")
                    .execute(&mut *conn)
                    .await
                    .map_err(|e| {
                        crate::AutumnError::internal_server_error_msg(format!(
                            "failed to set transaction session GUC: {e}"
                        ))
                    })?;
            }

            Ok(conn)
        })
    }

    /// Always `true`: this interceptor reports itself as transactional.
    fn is_transactional_test(&self) -> bool {
        true
    }
}
1911
/// Chains two checkout interceptors into one.
///
/// `second` receives the original `next` future. `first` then receives the
/// future returned by `second`, which makes `first` the outer layer.
#[cfg(feature = "db")]
struct ComposedDbInterceptor {
    first: std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>,
    second: std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>,
}

#[cfg(feature = "db")]
impl crate::interceptor::DbConnectionInterceptor for ComposedDbInterceptor {
    fn intercept_checkout<'a>(
        &'a self,
        ctx: crate::interceptor::DbCheckoutContext,
        next: std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                    > + Send
                    + 'a,
            >,
        >,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                > + Send
                + 'a,
        >,
    > {
        // Build the inner layer first so the outer layer wraps its future.
        let inner = self.second.intercept_checkout(ctx.clone(), next);
        self.first.intercept_checkout(ctx, inner)
    }

    /// Transactional if either layer is, checking `first` before `second`.
    fn is_transactional_test(&self) -> bool {
        [&self.first, &self.second]
            .into_iter()
            .any(|layer| layer.is_transactional_test())
    }
}
1947
/// A throwaway Postgres database backed by a testcontainers container, plus a
/// connection pool (max 5 connections) pointing at it.
#[cfg(all(feature = "db", feature = "test-support"))]
pub struct TestDb {
    // Held only to keep the container alive. testcontainers stops it when this
    // handle is dropped.
    _container: testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>,
    pool: Pool<AsyncPgConnection>,
    url: String,
}

#[cfg(all(feature = "db", feature = "test-support"))]
impl TestDb {
    /// Starts a fresh Postgres container and builds a pool connected to it.
    ///
    /// # Panics
    ///
    /// Panics if the container cannot be started (e.g. Docker is not running),
    /// if its host or mapped port cannot be resolved, or if the pool cannot be built.
    pub async fn new() -> Self {
        use diesel_async::pooled_connection::AsyncDieselConnectionManager;
        use testcontainers::runners::AsyncRunner;
        use testcontainers_modules::postgres::Postgres;

        let container = Postgres::default()
            .start()
            .await
            .expect("failed to start Postgres testcontainer (is Docker running?)");

        let host = container
            .get_host()
            .await
            .expect("failed to resolve Postgres testcontainer host");
        let port = container
            .get_host_port_ipv4(5432)
            .await
            .expect("failed to resolve host port mapped to Postgres testcontainer port 5432");
        let url = format!("postgres://postgres:postgres@{host}:{port}/postgres");

        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&url);
        let pool = Pool::builder(manager)
            .max_size(5)
            .build()
            .expect("failed to build connection pool");

        Self {
            _container: container,
            pool,
            url,
        }
    }

    /// Returns a process-wide `TestDb`, starting it on the first call.
    ///
    /// Every later call returns the same instance.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TestDb::new`] on first initialization.
    pub async fn shared() -> &'static Self {
        use std::sync::OnceLock;
        use tokio::sync::OnceCell;

        static CELL: OnceLock<OnceCell<TestDb>> = OnceLock::new();
        let once = CELL.get_or_init(OnceCell::new);
        once.get_or_init(Self::new).await
    }

    /// Returns a handle to the connection pool. Clones share the same pool.
    #[must_use]
    pub fn pool(&self) -> Pool<AsyncPgConnection> {
        self.pool.clone()
    }

    /// Returns the `postgres://` connection URL of the container.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Executes raw SQL on a pooled connection.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be obtained or if the statement fails. The
    /// message includes the SQL.
    pub async fn execute_sql(&self, sql: &str) {
        use diesel_async::RunQueryDsl;
        let mut conn = self.pool.get().await.expect("failed to get connection");
        diesel::sql_query(sql)
            .execute(&mut *conn)
            .await
            .unwrap_or_else(|e| panic!("SQL execution failed: {e}\nSQL: {sql}"));
    }
}
2078
#[cfg(test)]
mod tests {
    use super::*;

    // No-op job handler: always resolves to `Ok(())`. Used to probe whether a job
    // client still accepts work.
    fn cleanup_probe_job(
        _state: crate::state::AppState,
        _payload: serde_json::Value,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send + 'static>,
    > {
        Box::pin(async move { Ok(()) })
    }

    // Plugin that registers a single `cleanup_probe` job (1 attempt, 1 ms backoff).
    struct CleanupJobPlugin;

    impl crate::plugin::Plugin for CleanupJobPlugin {
        fn build(self, app: crate::app::AppBuilder) -> crate::app::AppBuilder {
            app.jobs(vec![crate::job::JobInfo {
                name: "cleanup_probe".to_string(),
                max_attempts: 1,
                initial_backoff_ms: 1,
                uniqueness: None,
                concurrency: None,
                handler: cleanup_probe_job,
            }])
        }
    }

    // Three fixture routes shared by the tests below:
    // GET /hello -> "hello", POST /echo -> echoes the JSON body,
    // POST /create -> 201 "created".
    fn test_routes() -> Vec<Route> {
        use axum::routing;

        async fn hello() -> &'static str {
            "hello"
        }

        async fn echo_json(
            axum::Json(value): axum::Json<serde_json::Value>,
        ) -> axum::Json<serde_json::Value> {
            axum::Json(value)
        }

        async fn status_201() -> (StatusCode, &'static str) {
            (StatusCode::CREATED, "created")
        }

        vec![
            Route {
                method: Method::GET,
                path: "/hello",
                handler: routing::get(hello),
                name: "hello",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/hello",
                    operation_id: "hello",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
            Route {
                method: Method::POST,
                path: "/echo",
                handler: routing::post(echo_json),
                name: "echo",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/echo",
                    operation_id: "echo",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
            Route {
                method: Method::POST,
                path: "/create",
                handler: routing::post(status_201),
                name: "create",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/create",
                    operation_id: "create",
                    success_status: 201,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
        ]
    }

    // A plain GET to a registered route returns 200.
    #[tokio::test]
    async fn test_app_get_request() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_ok();
    }

    // A JSON POST round-trips through the echo handler.
    #[tokio::test]
    async fn test_app_post_json() {
        let client = TestApp::new().routes(test_routes()).build();

        client
            .post("/echo")
            .json(&serde_json::json!({"key": "value"}))
            .send()
            .await
            .assert_ok()
            .assert_body_contains("key");
    }

    // Non-200 statuses and exact body matching.
    #[tokio::test]
    async fn test_response_assert_status() {
        let client = TestApp::new().routes(test_routes()).build();

        client
            .post("/create")
            .send()
            .await
            .assert_status(201)
            .assert_body_eq("created");
    }

    // `assert_success` accepts a 200.
    #[tokio::test]
    async fn test_response_assert_success() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_success();
    }

    // Unregistered paths produce 404.
    #[tokio::test]
    async fn test_not_found() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/nonexistent").send().await.assert_status(404);
    }

    // `assert_json` deserializes the body and runs the predicate on it.
    #[tokio::test]
    async fn test_response_json_deserialization() {
        let client = TestApp::new().routes(test_routes()).build();

        let resp = client
            .post("/echo")
            .json(&serde_json::json!({"count": 42}))
            .send()
            .await;

        resp.assert_ok().assert_json::<serde_json::Value, _>(|v| {
            assert_eq!(v["count"], 42);
        });
    }

    // Sending a custom request header does not break the request. Only the 200
    // status is checked, not the header itself.
    #[tokio::test]
    async fn test_custom_header() {
        let client = TestApp::new().routes(test_routes()).build();

        let resp = client
            .get("/hello")
            .header("x-custom", "test-value")
            .send()
            .await;
        resp.assert_ok();
    }

    // `TestApp::default()` constructs without panicking.
    #[tokio::test]
    async fn test_client_default() {
        let _app = TestApp::default();
    }

    // Dropping the client must clear the global job client. A job client captured
    // before the drop must eventually reject enqueues with "failed to enqueue job".
    // The global job-runtime test lock serializes this against other tests that
    // touch the global job client.
    #[tokio::test]
    async fn dropping_test_client_stops_test_started_job_runtime() {
        let _guard = crate::job::global_job_runtime_test_lock().lock().await;
        crate::job::clear_global_job_client();

        let client = TestApp::new().plugin(CleanupJobPlugin).build();
        let leaked_client = crate::job::global_job_client().expect("test job runtime should start");

        drop(client);

        assert!(
            crate::job::global_job_client().is_none(),
            "dropping a TestClient with jobs must clear its global job client"
        );

        // Retry up to 25 times, 10 ms apart. Shutdown may not be observed
        // immediately, so stop at the first enqueue error.
        let mut last_enqueue_error = None;
        for _ in 0..25 {
            match leaked_client
                .enqueue("cleanup_probe", serde_json::json!({}))
                .await
            {
                Ok(()) => tokio::time::sleep(std::time::Duration::from_millis(10)).await,
                Err(error) => {
                    last_enqueue_error = Some(error.to_string());
                    break;
                }
            }
        }

        assert!(
            last_enqueue_error
                .as_deref()
                .is_some_and(|message| message.contains("failed to enqueue job")),
            "captured pre-drop job client must stop accepting jobs after TestClient drop; \
             last error: {last_enqueue_error:?}"
        );

        crate::job::clear_global_job_client();
    }

    // A form POST carrying `_method=DELETE` is dispatched to the DELETE handler.
    #[tokio::test]
    async fn test_app_routes_html_method_override_to_delete() {
        use axum::routing;
        async fn deleted() -> &'static str {
            "deleted"
        }
        let routes = vec![Route {
            method: Method::DELETE,
            path: "/items/{id}",
            handler: routing::delete(deleted),
            name: "items_delete",
            api_doc: crate::openapi::ApiDoc {
                method: "DELETE",
                path: "/items/{id}",
                operation_id: "items_delete",
                success_status: 200,
                ..Default::default()
            },
            repository: None,
            idempotency: crate::route::RouteIdempotency::Direct,
            api_version: None,
            sunset_opt_out: false,
        }];
        let client = TestApp::new().routes(routes).build();

        client
            .post("/items/1")
            .form("_method=DELETE")
            .send()
            .await
            .assert_ok()
            .assert_body_eq("deleted");
    }

    // Tests for the CSS-selector HTML assertions, rendered with maud fixtures.
    #[cfg(feature = "maud")]
    mod html_assertions {
        use super::*;
        use axum::routing::get;

        // Minimal markup: a `table.notes` with three `tr.note-row` rows linking to /notes/{id}.
        async fn notes_index_v1() -> maud::Markup {
            maud::html! {
                table.notes {
                    tbody {
                        @for id in 1..=3u32 {
                            tr.note-row {
                                td.title { a href=(format!("/notes/{id}")) { "Note " (id) } }
                            }
                        }
                    }
                }
            }
        }

        // Same three rows with extra wrappers, classes and attributes. Used to check
        // that selector-based assertions survive cosmetic markup changes.
        async fn notes_index_v2() -> maud::Markup {
            maud::html! {
                div.card {
                    table.notes.striped {
                        thead { tr { th { "Title" } } }
                        tbody.rows {
                            @for id in 1..=3u32 {
                                tr.note-row.is-clickable data-id=(id) {
                                    td.title {
                                        span.wrap {
                                            a.link href=(format!("/notes/{id}")) data-turbo="true" {
                                                "Note " (id)
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // A bare `<tr>` fragment (no surrounding table), htmx-style.
        async fn note_row_fragment() -> maud::Markup {
            maud::html! {
                tr.note-row #note-7 {
                    td.title { a.link href="/notes/7" { "Note 7" } }
                }
            }
        }

        // Builds a test client whose only route is `path` -> `handler`, merged as a router.
        fn client(
            path: &str,
            handler: axum::routing::MethodRouter<crate::state::AppState>,
        ) -> TestClient {
            let router = axum::Router::<crate::state::AppState>::new().route(path, handler);
            TestApp::new().merge(router).build()
        }

        #[tokio::test]
        async fn counts_rows_by_tag_and_class() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_ok()
                .assert_selector("table.notes")
                .assert_selector_count("tbody tr", 3)
                .assert_selector_count("tr.note-row", 3)
                .assert_no_selector("form");
        }

        #[tokio::test]
        async fn reads_text_and_attributes() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_text("tr.note-row td.title a", "Note 1")
                .assert_text_contains("tr.note-row", "Note 1")
                .assert_attr("tr.note-row td a", "href", "/notes/1");

            // The collecting accessors return every match, in document order.
            let links = resp.selector_text("tr.note-row a");
            assert_eq!(links, vec!["Note 1", "Note 2", "Note 3"]);
            let hrefs = resp.selector_attr("tr.note-row a", "href");
            assert_eq!(
                hrefs,
                vec![
                    Some("/notes/1".to_string()),
                    Some("/notes/2".to_string()),
                    Some("/notes/3".to_string()),
                ]
            );
            assert_eq!(resp.selector_count("tr.note-row"), 3);
        }

        // The same selector assertions must pass against both markup variants.
        #[tokio::test]
        async fn survives_cosmetic_refactor() {
            for handler in [get(notes_index_v1), get(notes_index_v2)] {
                let resp = client("/notes", handler).get("/notes").send().await;
                resp.assert_ok()
                    .assert_selector_count("tbody tr.note-row", 3);
                let hrefs = resp.selector_attr("tbody tr.note-row a", "href");
                assert_eq!(
                    hrefs,
                    vec![
                        Some("/notes/1".to_string()),
                        Some("/notes/2".to_string()),
                        Some("/notes/3".to_string()),
                    ],
                    "row links must survive the refactor"
                );
            }
        }

        // Selectors work on a fragment whose root is a `<tr>`.
        #[tokio::test]
        async fn works_for_htmx_fragment() {
            let resp = client("/rows/7", get(note_row_fragment))
                .get("/rows/7")
                .send()
                .await;
            resp.assert_selector("tr.note-row")
                .assert_selector("tr#note-7")
                .assert_attr("tr#note-7 a", "href", "/notes/7")
                .assert_text("tr#note-7 a.link", "Note 7");
        }

        // `#id`, `[attr="v"]` and `[attr^="prefix"]` selectors.
        #[tokio::test]
        async fn id_and_attribute_selectors() {
            let resp = client("/rows/7", get(note_row_fragment))
                .get("/rows/7")
                .send()
                .await;
            resp.assert_selector("#note-7")
                .assert_selector("a[href=\"/notes/7\"]")
                .assert_selector("a[href^=\"/notes/\"]")
                .assert_no_selector("a[href=\"/other\"]");
        }

        // A count mismatch panics with a message naming the expected count.
        #[tokio::test]
        #[should_panic(expected = "expected 5 element(s) matching selector")]
        async fn count_mismatch_panics_with_actionable_message() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_selector_count("tr.note-row", 5);
        }

        // A selector with no matches panics with a message naming the selector.
        #[tokio::test]
        #[should_panic(expected = "no elements matched selector `table.missing`")]
        async fn missing_selector_panics() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_selector("table.missing");
        }
    }

    // An unrecognized `_method` override value yields 400.
    #[tokio::test]
    async fn test_app_routes_invalid_method_override_rejected() {
        let client = TestApp::new().routes(test_routes()).build();

        client
            .post("/create")
            .form("_method=BREW")
            .send()
            .await
            .assert_status(400);
    }

    // The 400 from a rejected method override must still carry the framework
    // middleware headers (`x-request-id`, `x-content-type-options`).
    #[tokio::test]
    async fn invalid_method_override_response_carries_framework_middleware() {
        let client = TestApp::new().routes(test_routes()).build();

        let response = client.post("/create").form("_method=BREW").send().await;
        response.assert_status(400);

        assert!(
            response.header("x-request-id").is_some(),
            "framework request-id header must wrap method-override rejections; \
             observed headers: {:?}",
            response.headers
        );
        assert!(
            response.header("x-content-type-options").is_some(),
            "framework security headers must wrap method-override rejections; \
             observed headers: {:?}",
            response.headers
        );
    }
}