#![allow(clippy::type_complexity, clippy::too_many_lines)]
use axum::body::Body;
use axum::http::{Method, Request, StatusCode};
use tower::ServiceExt;
use crate::config::AutumnConfig;
use crate::route::Route;
use crate::state::AppState;
#[cfg(feature = "db")]
use diesel_async::AsyncPgConnection;
#[cfg(feature = "db")]
use diesel_async::RunQueryDsl;
#[cfg(feature = "db")]
use diesel_async::pooled_connection::deadpool::Pool;
/// Builder for an in-process test application.
///
/// Collects routes, routers, layers, interceptors, plugins and configuration
/// overrides; [`TestApp::build`] assembles them into an [`AppState`] plus an
/// axum router wrapped in a [`TestClient`].
pub struct TestApp {
/// Plain routes registered via `routes()` or contributed by plugins.
routes: Vec<Route>,
/// Route groups mounted under a prefix with a group-local tower layer.
scoped_groups: Vec<crate::app::ScopedGroup>,
/// Routers merged at the application root.
merge_routers: Vec<axum::Router<crate::state::AppState>>,
/// `(path, router)` pairs nested under a path prefix.
nest_routers: Vec<(String, axum::Router<crate::state::AppState>)>,
/// Application-wide layers registered via `layer()`.
custom_layers: Vec<crate::app::CustomLayerRegistration>,
/// Effective configuration; `new()` starts from defaults with the `test`
/// profile and CSRF disabled.
config: AutumnConfig,
/// OpenAPI configuration forwarded to the router builder.
#[cfg(feature = "openapi")]
openapi: Option<crate::openapi::OpenApiConfig>,
/// MCP runtime settings (mount path, expose-all flag, endpoint layer).
#[cfg(feature = "mcp")]
mcp: Option<crate::mcp::McpRuntime>,
/// Primary pool from `with_db`; ignored when transactional mode is on.
#[cfg(feature = "db")]
pool: Option<Pool<AsyncPgConnection>>,
/// Read-replica pool; no setter is visible in this impl, and it is discarded
/// in transactional mode.
#[cfg(feature = "db")]
replica_pool: Option<Pool<AsyncPgConnection>>,
/// When true, `build()` creates a single-connection pool whose connection
/// runs inside a diesel test transaction.
#[cfg(feature = "db")]
transactional: bool,
/// Explicit URL for transactional mode; falls back to the configured
/// primary database URL.
#[cfg(feature = "db")]
transactional_url: Option<String>,
/// Deferred policy/scope registrations applied in `build()`.
policy_registrations: Vec<TestPolicyRegistration>,
/// Overrides `config.security.forbidden_response` when set.
forbidden_response_override: Option<crate::authorization::ForbiddenResponse>,
/// Interceptors below are published as state extensions in `build()`.
#[cfg(feature = "mail")]
mail_interceptor: Option<std::sync::Arc<dyn crate::interceptor::MailInterceptor>>,
job_interceptor: Option<std::sync::Arc<dyn crate::interceptor::JobInterceptor>>,
/// User DB interceptor; composed with the transactional interceptor when
/// transactional mode is on.
#[cfg(feature = "db")]
db_interceptor: Option<std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>>,
/// Also wraps the channels backend in `build()`.
#[cfg(feature = "ws")]
channels_interceptor: Option<std::sync::Arc<dyn crate::interceptor::ChannelsInterceptor>>,
#[cfg(feature = "oauth2")]
http_interceptor: Option<std::sync::Arc<dyn crate::interceptor::HttpInterceptor>>,
/// Registry shared with every `MockSetupBuilder` returned by `http_mock()`.
#[cfg(feature = "http-client")]
http_mock_registry: Option<std::sync::Arc<crate::http_client::MockRegistry>>,
/// Callbacks run against the freshly built state, in registration order.
state_initializers: Vec<Box<dyn FnOnce(&AppState) + Send>>,
/// Background jobs; a job runtime is started only when this is non-empty.
jobs: Vec<crate::job::JobInfo>,
exception_filters: Vec<std::sync::Arc<dyn crate::middleware::ExceptionFilter>>,
/// Names of plugins already applied; used to skip duplicates.
registered_plugins: std::collections::HashSet<String>,
/// Builder-time extensions keyed by type, shared with plugins.
extensions: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
/// Clock installed into the state; `SystemClock` when unset.
clock: Option<std::sync::Arc<dyn crate::time::ClockSource>>,
/// The same clock, type-erased so `TestClient::advance_clock` can downcast it.
clock_as_any: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
api_versions: Vec<crate::app::ApiVersion>,
/// `(name, source)` pairs registered with the metrics source registry.
metrics_sources: Vec<(String, std::sync::Arc<dyn crate::actuator::MetricsSource>)>,
/// `(name, group, indicator)` triples registered with the health registry.
health_indicators: Vec<(
String,
crate::actuator::IndicatorGroup,
std::sync::Arc<dyn crate::actuator::HealthIndicator>,
)>,
/// Router whose webhook routes are mounted during `build()`.
#[cfg(feature = "inbound-mail")]
inbound_mail_router: Option<std::sync::Arc<crate::inbound_mail::InboundMailRouter>>,
}
/// Deferred registration callback; invoked with the state's
/// `PolicyRegistry` once `TestApp::build` has created the state.
type TestPolicyRegistration = Box<dyn FnOnce(&crate::authorization::PolicyRegistry) + Send>;
impl TestApp {
/// Create a builder with test-friendly defaults.
///
/// The configuration starts from `AutumnConfig::default()`, activates the
/// `test` profile and turns CSRF protection off. Every collection starts
/// empty and every optional component starts unset.
#[must_use]
pub fn new() -> Self {
    let config = {
        let mut base = AutumnConfig::default();
        base.profile = Some("test".into());
        base.security.csrf.enabled = false;
        base
    };
    Self {
        config,
        // Routing pieces.
        routes: Default::default(),
        scoped_groups: Default::default(),
        merge_routers: Default::default(),
        nest_routers: Default::default(),
        custom_layers: Default::default(),
        api_versions: Default::default(),
        #[cfg(feature = "openapi")]
        openapi: None,
        #[cfg(feature = "mcp")]
        mcp: None,
        #[cfg(feature = "inbound-mail")]
        inbound_mail_router: None,
        // Database.
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        #[cfg(feature = "db")]
        transactional: false,
        #[cfg(feature = "db")]
        transactional_url: None,
        // Authorization.
        policy_registrations: Default::default(),
        forbidden_response_override: None,
        // Interceptors and mocks.
        #[cfg(feature = "mail")]
        mail_interceptor: None,
        job_interceptor: None,
        #[cfg(feature = "db")]
        db_interceptor: None,
        #[cfg(feature = "ws")]
        channels_interceptor: None,
        #[cfg(feature = "oauth2")]
        http_interceptor: None,
        #[cfg(feature = "http-client")]
        http_mock_registry: None,
        // Runtime wiring.
        state_initializers: Default::default(),
        jobs: Default::default(),
        exception_filters: Default::default(),
        registered_plugins: Default::default(),
        extensions: Default::default(),
        clock: None,
        clock_as_any: None,
        metrics_sources: Default::default(),
        health_indicators: Default::default(),
    }
}
#[must_use]
pub fn policy<R, P>(mut self, policy: P) -> Self
where
R: Send + Sync + 'static,
P: crate::authorization::Policy<R>,
{
self.policy_registrations.push(Box::new(move |registry| {
registry.register_policy::<R, _>(policy);
}));
self
}
#[must_use]
pub fn scope<R, S>(mut self, scope: S) -> Self
where
R: Send + Sync + 'static,
S: crate::authorization::Scope<R>,
{
self.policy_registrations.push(Box::new(move |registry| {
registry.register_scope::<R, _>(scope);
}));
self
}
/// Install the inbound-mail router whose webhook routes `build()` mounts.
///
/// A later call replaces an earlier one.
#[cfg(feature = "inbound-mail")]
#[must_use]
pub fn inbound_mail_router(self, router: crate::inbound_mail::InboundMailRouter) -> Self {
    Self {
        inbound_mail_router: Some(std::sync::Arc::new(router)),
        ..self
    }
}
/// Override the response used for authorization failures.
///
/// Takes precedence over `config.security.forbidden_response` in `build()`.
#[must_use]
pub const fn forbidden_response(
mut self,
value: crate::authorization::ForbiddenResponse,
) -> Self {
self.forbidden_response_override = Some(value);
self
}
/// Enable OpenAPI generation with the given configuration.
///
/// The configuration is forwarded to the router builder in `build()`.
#[cfg(feature = "openapi")]
#[must_use]
pub fn openapi(self, config: crate::openapi::OpenApiConfig) -> Self {
    Self {
        openapi: Some(config),
        ..self
    }
}
/// Mount the MCP endpoint at `path`.
///
/// If an MCP runtime already exists, only its mount path changes. Otherwise
/// a new runtime is created at `path`.
#[cfg(feature = "mcp")]
#[must_use]
pub fn mount_mcp(mut self, path: impl Into<String>) -> Self {
    let path = path.into();
    match self.mcp.as_mut() {
        Some(runtime) => runtime.mount_path = path,
        None => self.mcp = Some(crate::mcp::McpRuntime::new(path)),
    }
    self
}
/// Expose every route as an MCP tool.
///
/// If no MCP runtime exists yet, one is created at the default `/mcp` path.
#[cfg(feature = "mcp")]
#[must_use]
pub fn expose_all_as_mcp(mut self) -> Self {
    self.mcp
        .get_or_insert_with(|| crate::mcp::McpRuntime::new("/mcp"))
        .expose_all = true;
    self
}
/// Wrap the MCP endpoint in `layer`, e.g. an authentication layer.
///
/// If no MCP runtime exists yet, one is created at the default `/mcp` path.
/// A later call replaces the previously installed layer.
#[cfg(feature = "mcp")]
#[must_use]
pub fn secure_mcp<L>(mut self, layer: L) -> Self
where
    L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
    L::Service: tower::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::http::Response<axum::body::Body>,
            Error = std::convert::Infallible,
        > + Clone
        + Send
        + Sync
        + 'static,
    <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
        Send + 'static,
{
    let applier: crate::mcp::McpEndpointLayer = Box::new(move |router| router.layer(layer));
    self.mcp
        .get_or_insert_with(|| crate::mcp::McpRuntime::new("/mcp"))
        .endpoint_layer = Some(applier);
    self
}
/// Merge an additional router at the application root.
#[must_use]
pub fn merge(mut self, router: axum::Router<crate::state::AppState>) -> Self {
self.merge_routers.push(router);
self
}
/// Mount `routes` under `prefix`, with `layer` applied to that group only.
///
/// The group is tagged as user-defined (`RouteSource::User`) for route listings.
#[must_use]
pub fn scoped<L>(mut self, prefix: &str, layer: L, routes: Vec<Route>) -> Self
where
    L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
    L::Service: tower::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::http::Response<axum::body::Body>,
            Error = std::convert::Infallible,
        > + Clone
        + Send
        + Sync
        + 'static,
    <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
        Send + 'static,
{
    let group = crate::app::ScopedGroup {
        apply_layer: Box::new(move |router| router.layer(layer)),
        source: crate::route_listing::RouteSource::User,
        routes,
        prefix: prefix.to_owned(),
    };
    self.scoped_groups.push(group);
    self
}
/// Nest `router` under the `path` prefix.
#[must_use]
pub fn nest(mut self, path: &str, router: axum::Router<crate::state::AppState>) -> Self {
self.nest_routers.push((path.to_owned(), router));
self
}
/// Register an application-wide layer.
///
/// The layer's `TypeId` and type name are recorded alongside it. The layer
/// is applied to the router in `build()`.
#[must_use]
pub fn layer<L: crate::app::IntoAppLayer>(mut self, layer: L) -> Self {
    let registration = crate::app::CustomLayerRegistration {
        apply: Box::new(move |router| layer.apply_to(router)),
        type_name: std::any::type_name::<L>(),
        type_id: std::any::TypeId::of::<L>(),
    };
    self.custom_layers.push(registration);
    self
}
/// Add an error reporter.
///
/// At build time, the reporter is appended to any `RegisteredReporters`
/// already present in the state.
#[cfg(feature = "reporting")]
#[must_use]
pub fn with_error_reporter<R: crate::reporting::ErrorReporter>(mut self, reporter: R) -> Self {
    let shared: std::sync::Arc<dyn crate::reporting::ErrorReporter> =
        std::sync::Arc::new(reporter);
    self.state_initializers.push(Box::new(move |state| {
        let mut combined = match state.extension::<crate::reporting::RegisteredReporters>() {
            Some(registered) => registered.0.clone(),
            None => Default::default(),
        };
        combined.push(shared);
        state.insert_extension(crate::reporting::RegisteredReporters(combined));
    }));
    self
}
/// Turn on idempotency support (`config.idempotency.enabled = Some(true)`).
///
/// A later call to `config()` replaces the whole config and discards this.
#[must_use]
pub const fn idempotent(mut self) -> Self {
self.config.idempotency.enabled = Some(true);
self
}
/// Wrap an already-built router and state in a [`TestClient`].
///
/// The client has ready probes, no job runtime and no controllable clock,
/// so `advance_clock` is a no-op on it.
#[must_use]
pub fn from_router(router: axum::Router, state: AppState) -> TestClient {
    let probes = crate::probe::ProbeState::ready_for_test();
    TestClient {
        state,
        router,
        probes,
        clock_as_any: None,
        _job_runtime: None,
    }
}
/// Append routes; may be called repeatedly.
#[must_use]
pub fn routes(mut self, routes: Vec<Route>) -> Self {
self.routes.extend(routes);
self
}
/// Register a callback run against the freshly built [`AppState`].
///
/// Callbacks run in registration order, after built-in extensions,
/// metrics sources and health indicators are installed.
#[must_use]
pub fn state_initializer<F>(mut self, f: F) -> Self
where
    F: FnOnce(&AppState) + Send + 'static,
{
    let boxed: Box<dyn FnOnce(&AppState) + Send> = Box::new(f);
    self.state_initializers.push(boxed);
    self
}
#[must_use]
pub fn with_flag_store<S>(mut self, store: S) -> Self
where
S: crate::feature_flags::FlagStore,
{
use std::sync::Arc;
let service = crate::feature_flags::FeatureFlagService::new(Arc::new(store) as Arc<_>);
self.state_initializers.push(Box::new(move |state| {
state.insert_extension(service);
}));
self
}
/// Apply a plugin to this test builder.
///
/// A plugin whose `name()` is already registered is skipped with a warning.
/// Otherwise the plugin runs against a temporary `crate::app()` builder that
/// is seeded with this builder's plugin names, extensions and state
/// initializers. Its contributions are then copied back:
/// - routes, scoped groups, merged and nested routers, layers;
/// - jobs, exception filters, metrics sources, health indicators;
/// - the inbound-mail router (feature `inbound-mail`);
/// - error reporters (feature `reporting`, installed via a state initializer).
///
/// NOTE(review): temporary-builder fields not listed above are not copied
/// back. Confirm no plugin relies on them in tests.
///
/// Each plugin startup hook becomes a state initializer. During `build()`,
/// that initializer runs the hook to completion on a dedicated OS thread:
/// - through the current Tokio runtime's handle if one exists;
/// - otherwise on a freshly built multi-thread runtime.
///
/// A hook error or a panic on that thread makes `build()` panic.
///
/// NOTE(review): tokio documents that on a `current_thread` runtime,
/// `Handle::block_on` cannot drive IO/timers unless another thread is in
/// `Runtime::block_on`. Here the runtime thread is blocked in `join()`, so
/// hooks doing IO may hang under `#[tokio::test]`'s default flavor. Confirm.
#[must_use]
pub fn plugin<P: crate::plugin::Plugin>(mut self, plugin: P) -> Self {
let name = plugin.name().into_owned();
if self.registered_plugins.contains(&name) {
tracing::warn!(plugin = %name, "Duplicate plugin registration in TestApp; skipping");
return self;
}
let mut app_builder = crate::app();
// Seed the temporary builder so the plugin sees (and can extend) our state.
// `extensions` and `state_initializers` are moved out here and moved back below.
app_builder
.registered_plugins
.clone_from(&self.registered_plugins);
app_builder.extensions = self.extensions;
app_builder.state_initializers = std::mem::take(&mut self.state_initializers);
app_builder = app_builder.plugin(plugin);
// Move the (possibly extended) shared pieces back into this builder.
self.registered_plugins = app_builder.registered_plugins;
self.extensions = app_builder.extensions;
self.state_initializers = app_builder.state_initializers;
// Append everything else the plugin contributed.
self.routes.extend(app_builder.routes);
self.scoped_groups.extend(app_builder.scoped_groups);
self.merge_routers.extend(app_builder.merge_routers);
self.nest_routers.extend(app_builder.nest_routers);
self.custom_layers.extend(app_builder.custom_layers);
self.jobs.extend(app_builder.jobs);
self.exception_filters.extend(app_builder.exception_filters);
self.metrics_sources.extend(app_builder.metrics_sources);
self.health_indicators.extend(app_builder.health_indicators);
// A plugin-provided inbound-mail router replaces any existing one.
#[cfg(feature = "inbound-mail")]
if let Some(router) = app_builder.inbound_mail_router {
self.inbound_mail_router = Some(router);
}
// Error reporters are merged into the state's RegisteredReporters at build time.
#[cfg(feature = "reporting")]
{
let reporters = std::mem::take(&mut app_builder.error_reporters);
if !reporters.is_empty() {
self.state_initializers.push(Box::new(move |state| {
let mut existing = state
.extension::<crate::reporting::RegisteredReporters>()
.map(|registered| registered.0.clone())
.unwrap_or_default();
existing.extend(reporters.iter().cloned());
state.insert_extension(crate::reporting::RegisteredReporters(existing));
}));
}
}
// Startup hooks are async, but state initializers are sync. Each hook is
// driven on its own thread, presumably because blocking on a future from a
// runtime worker thread is not permitted by tokio.
for hook in app_builder.startup_hooks {
self.state_initializers.push(Box::new(move |state| {
let state_owned = state.clone();
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let thread_handle =
std::thread::spawn(move || handle.block_on(hook(state_owned)));
thread_handle
.join()
.expect("Plugin startup hook thread panicked")
.expect("Plugin startup hook failed");
} else {
// No ambient runtime: build a throwaway one on the helper thread.
let thread_handle = std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to build tokio runtime for test plugin startup hook");
rt.block_on(hook(state_owned))
});
thread_handle
.join()
.expect("Plugin startup hook thread panicked")
.expect("Plugin startup hook failed");
}
}));
}
self
}
/// Install a mail interceptor; `build()` publishes it as a state extension.
#[cfg(feature = "mail")]
#[must_use]
pub fn with_mail_interceptor(
    self,
    interceptor: impl crate::interceptor::MailInterceptor,
) -> Self {
    Self {
        mail_interceptor: Some(std::sync::Arc::new(interceptor)),
        ..self
    }
}
#[must_use]
pub fn with_job_interceptor(
mut self,
interceptor: impl crate::interceptor::JobInterceptor,
) -> Self {
self.job_interceptor = Some(std::sync::Arc::new(interceptor));
self
}
/// Install a DB connection interceptor.
///
/// In transactional mode it is composed with the transactional interceptor.
#[cfg(feature = "db")]
#[must_use]
pub fn with_db_interceptor(
    self,
    interceptor: impl crate::interceptor::DbConnectionInterceptor,
) -> Self {
    Self {
        db_interceptor: Some(std::sync::Arc::new(interceptor)),
        ..self
    }
}
/// Install a channels interceptor.
///
/// `build()` also wraps the channels backend with it.
#[cfg(feature = "ws")]
#[must_use]
pub fn with_channels_interceptor(
    self,
    interceptor: impl crate::interceptor::ChannelsInterceptor,
) -> Self {
    Self {
        channels_interceptor: Some(std::sync::Arc::new(interceptor)),
        ..self
    }
}
/// Install an outbound HTTP interceptor; `build()` publishes it as a state extension.
#[cfg(feature = "oauth2")]
#[must_use]
pub fn with_http_interceptor(
    self,
    interceptor: impl crate::interceptor::HttpInterceptor,
) -> Self {
    Self {
        http_interceptor: Some(std::sync::Arc::new(interceptor)),
        ..self
    }
}
#[must_use]
pub fn config(mut self, config: AutumnConfig) -> Self {
self.config = config;
self
}
/// Set the active configuration profile (the default from `new()` is `test`).
#[must_use]
pub fn profile(mut self, profile: &str) -> Self {
self.config.profile = Some(profile.to_owned());
self
}
#[must_use]
pub fn with_clock<C>(mut self, clock: C) -> Self
where
C: crate::time::ClockSource + 'static,
{
let arc: std::sync::Arc<C> = std::sync::Arc::new(clock);
self.clock_as_any = Some(arc.clone() as std::sync::Arc<dyn std::any::Any + Send + Sync>);
self.clock = Some(arc as std::sync::Arc<dyn crate::time::ClockSource>);
self
}
/// Register one API version; published as `RegisteredApiVersions` in `build()`.
#[must_use]
pub fn api_version(mut self, version: crate::app::ApiVersion) -> Self {
self.api_versions.push(version);
self
}
/// Register several API versions, appended in iteration order.
#[must_use]
pub fn api_versions(
mut self,
versions: impl IntoIterator<Item = crate::app::ApiVersion>,
) -> Self {
self.api_versions.extend(versions);
self
}
/// Use `pool` as the primary database pool.
///
/// Ignored when transactional mode is enabled.
#[cfg(feature = "db")]
#[must_use]
pub fn with_db(self, pool: Pool<AsyncPgConnection>) -> Self {
    Self {
        pool: Some(pool),
        ..self
    }
}
/// Enable transactional isolation using the configured primary database URL.
///
/// `build()` panics if no URL is available.
#[cfg(feature = "db")]
#[must_use]
pub const fn transactional(mut self) -> Self {
self.transactional = true;
self
}
/// Enable transactional isolation against an explicit database `url`.
///
/// `url` takes precedence over the configured primary URL.
#[cfg(feature = "db")]
#[must_use]
pub fn with_transactional_db(self, url: impl Into<String>) -> Self {
    Self {
        transactional: true,
        transactional_url: Some(url.into()),
        ..self
    }
}
/// Start configuring an HTTP mock for the client registered as `alias`.
///
/// Every builder returned from this method shares one `MockRegistry`,
/// created on first use. `build()` publishes that registry to the state.
#[cfg(feature = "http-client")]
pub fn http_mock(&mut self, alias: &str) -> crate::http_client::MockSetupBuilder {
    let shared = self
        .http_mock_registry
        .get_or_insert_with(|| std::sync::Arc::new(crate::http_client::MockRegistry::new()));
    crate::http_client::MockSetupBuilder {
        alias: alias.to_owned(),
        registry: std::sync::Arc::clone(shared),
        method: None,
        path: None,
    }
}
/// Assemble everything registered on this builder into a [`TestClient`].
///
/// In order, this:
/// 1. clears the global cache so nothing leaks in from an earlier test;
/// 2. (`inbound-mail`) mounts inbound-mail webhook routes and adds their
///    paths to the CSRF/captcha exemption lists in `self.config`. This must
///    happen before step 4, so the `AutumnConfig` extension published into
///    the state matches the config the router is built from;
/// 3. (`db`) resolves the pools and DB interceptor. In transactional mode
///    this is a one-connection pool whose connection opens a test
///    transaction, plus the transactional interceptor (composed with any
///    user interceptor); the replica pool is dropped;
/// 4. builds the [`AppState`], then:
///    - applies policy registrations;
///    - publishes API versions, config and interceptors;
///    - installs the mailer and HTTP-client settings;
///    - registers metrics sources and health indicators (failures are logged);
///    - runs the state initializers in registration order;
/// 5. registers jobs and, when any exist, starts the job runtime;
/// 6. builds the router, adding the fallback access-log layer when enabled.
///
/// # Panics
///
/// Panics if:
/// - transactional mode has no database URL;
/// - the transactional pool cannot be built;
/// - the test mailer cannot be installed;
/// - the job runtime fails to start;
/// - the router cannot be built.
#[must_use]
#[cfg_attr(not(feature = "inbound-mail"), allow(unused_mut))]
pub fn build(mut self) -> TestClient {
    crate::cache::clear_global_cache();
    // Must precede `state.insert_extension(self.config.clone())` below: mounting
    // adds CSRF/captcha exemptions to `self.config`, and the state's config
    // extension has to agree with the config handed to the router.
    #[cfg(feature = "inbound-mail")]
    self.mount_inbound_mail_routes();
    #[cfg(feature = "db")]
    let (pool, replica_pool, db_interceptor) = if self.transactional {
        let pool = self.build_transactional_pool();
        let trans_interceptor = std::sync::Arc::new(TransactionalDbInterceptor);
        let interceptor = if let Some(user_interceptor) = self.db_interceptor {
            std::sync::Arc::new(ComposedDbInterceptor {
                first: user_interceptor,
                second: trans_interceptor,
            }) as std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>
        } else {
            trans_interceptor as std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>
        };
        // Only the single transactional connection is exposed; no replica pool.
        (Some(pool), None, Some(interceptor))
    } else {
        (self.pool, self.replica_pool, self.db_interceptor)
    };
    let probes = crate::probe::ProbeState::ready_for_test();
    #[cfg(feature = "ws")]
    let test_channels = crate::channels::Channels::new(32);
    #[cfg_attr(not(feature = "ws"), allow(unused_mut))]
    let mut state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool,
        #[cfg(feature = "db")]
        replica_pool,
        profile: self.config.profile.clone(),
        started_at: std::time::Instant::now(),
        health_detailed: self.config.health.detailed,
        probes: probes.clone(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new(&self.config.log.level),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        metrics_source_registry: crate::actuator::MetricsSourceRegistry::new(),
        health_indicator_registry: crate::actuator::HealthIndicatorRegistry::new(),
        #[cfg(feature = "presence")]
        presence: crate::presence::Presence::new(test_channels.clone()),
        #[cfg(feature = "ws")]
        channels: test_channels,
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: self
            .forbidden_response_override
            .unwrap_or(self.config.security.forbidden_response),
        auth_session_key: self.config.auth.session_key.clone(),
        shared_cache: None,
        clock: self
            .clock
            .unwrap_or_else(|| std::sync::Arc::new(crate::time::SystemClock)),
    };
    for register in self.policy_registrations {
        register(state.policy_registry());
    }
    state.insert_extension(crate::app::RegisteredApiVersions(self.api_versions));
    crate::app::install_webhook_registry(&state, &self.config);
    state.insert_extension(self.config.clone());
    #[cfg(feature = "mail")]
    if let Some(interceptor) = self.mail_interceptor {
        state.insert_extension(interceptor);
    }
    if let Some(interceptor) = self.job_interceptor {
        state.insert_extension(interceptor);
    }
    #[cfg(feature = "db")]
    if let Some(interceptor) = db_interceptor {
        state.insert_extension(interceptor);
    }
    #[cfg(feature = "ws")]
    if let Some(interceptor) = self.channels_interceptor {
        state.insert_extension(interceptor.clone());
        // Route all channel traffic through the interceptor.
        state.channels = crate::channels::Channels::with_shared_backend(std::sync::Arc::new(
            crate::channels::InterceptedChannelsBackend::new(
                state.channels.backend().clone(),
                vec![interceptor],
            ),
        ));
        // Presence was built on the un-intercepted channels; rebuild it.
        #[cfg(feature = "presence")]
        {
            state.presence = crate::presence::Presence::new(state.channels.clone());
        }
    }
    #[cfg(feature = "oauth2")]
    if let Some(interceptor) = self.http_interceptor {
        state.insert_extension(interceptor);
    }
    #[cfg(feature = "mail")]
    {
        crate::mail::install_mailer(&state, &self.config.mail, false)
            .expect("Failed to configure test mailer");
    }
    #[cfg(feature = "http-client")]
    state.insert_extension(self.config.http.clone());
    #[cfg(feature = "http-client")]
    if let Some(registry) = self.http_mock_registry {
        state.insert_extension(crate::http_client::HttpMockRegistryExt(registry));
    }
    for (name, source) in self.metrics_sources {
        if let Err(e) = state.metrics_source_registry.register(name, source) {
            tracing::warn!("{e}");
        }
    }
    for (name, group, indicator) in self.health_indicators {
        if let Err(e) = state
            .health_indicator_registry
            .register(name, group, indicator)
        {
            tracing::warn!("{e}");
        }
    }
    for initializer in self.state_initializers {
        initializer(&state);
    }
    for job in &self.jobs {
        state.job_registry.register(&job.name);
    }
    let job_runtime = if self.jobs.is_empty() {
        None
    } else {
        let shutdown = tokio_util::sync::CancellationToken::new();
        crate::job::start_runtime(self.jobs.clone(), &state, &shutdown, &self.config.jobs)
            .expect("Failed to start job runtime in test");
        Some(TestJobRuntime { shutdown })
    };
    let router = crate::router::try_build_router_inner(
        self.routes,
        &self.config,
        state.clone(),
        crate::router::RouterContext {
            exception_filters: self.exception_filters,
            scoped_groups: self.scoped_groups,
            merge_routers: self.merge_routers,
            nest_routers: self.nest_routers,
            custom_layers: self.custom_layers,
            error_page_renderer: None,
            session_store: None,
            #[cfg(feature = "openapi")]
            openapi: self.openapi,
            #[cfg(feature = "mcp")]
            mcp: self.mcp,
        },
    )
    .expect("failed to build test router");
    let router = if self.config.log.access_log {
        router.layer(crate::middleware::AccessLogLayer::fallback(
            self.config.log.access_log_exclude.clone(),
        ))
    } else {
        router
    };
    TestClient {
        router,
        probes,
        state,
        _job_runtime: job_runtime,
        clock_as_any: self.clock_as_any,
    }
}

/// Build the single-connection pool used for transactional test isolation.
///
/// The URL comes from `with_transactional_db(url)`, falling back to the
/// configured primary database URL. The pool holds at most one connection.
/// Its `post_create` hook opens a diesel test transaction and sets the
/// `autumn.test_transaction_started` session setting, which
/// `TransactionalDbInterceptor` checks on every checkout.
///
/// # Panics
///
/// Panics if no database URL is available or the pool cannot be built.
#[cfg(feature = "db")]
fn build_transactional_pool(&self) -> Pool<AsyncPgConnection> {
    let url = self.transactional_url.as_deref()
        .or_else(|| self.config.database.effective_primary_url())
        .expect("Transactional isolation enabled but database URL is not configured. Use `with_transactional_db(url)` or configure database.primary_url/database.url");
    let timeout = std::time::Duration::from_secs(self.config.database.connect_timeout_secs);
    let manager = diesel_async::pooled_connection::AsyncDieselConnectionManager::<
        diesel_async::AsyncPgConnection,
    >::new(url);
    Pool::builder(manager)
        .max_size(1)
        .wait_timeout(Some(timeout))
        .create_timeout(Some(timeout))
        .runtime(deadpool::Runtime::Tokio1)
        .post_create(deadpool::managed::Hook::async_fn(
            |conn: &mut diesel_async::AsyncPgConnection, _metrics| {
                Box::pin(async move {
                    use diesel_async::AsyncConnection;
                    use diesel_async::RunQueryDsl;
                    conn.begin_test_transaction().await.map_err(|e| {
                        deadpool::managed::HookError::Backend(
                            diesel_async::pooled_connection::PoolError::QueryError(e),
                        )
                    })?;
                    diesel::sql_query("SET autumn.test_transaction_started = 'true'")
                        .execute(conn)
                        .await
                        .map_err(|e| {
                            deadpool::managed::HookError::Backend(
                                diesel_async::pooled_connection::PoolError::QueryError(e),
                            )
                        })?;
                    Ok(())
                })
            },
        ))
        .build()
        .expect("failed to build transactional pool of size 1")
}

/// Mount the inbound-mail webhook routers into `self.merge_routers`.
///
/// For each `(path, router)` from `crate::inbound_mail::build_routes`, the
/// route is skipped with a warning when either:
/// - the application already handles `POST` there (a plain route, a
///   scoped-group route, or anything under a nested router's prefix), or
/// - an earlier inbound route used the same path.
///
/// Every mounted path is added to the CSRF and captcha exemption lists in
/// `self.config`.
#[cfg(feature = "inbound-mail")]
fn mount_inbound_mail_routes(&mut self) {
    let Some(im_router) = &self.inbound_mail_router else {
        return;
    };
    let mut registered_inbound: std::collections::HashSet<String> =
        std::collections::HashSet::new();
    for (path, axum_router) in crate::inbound_mail::build_routes(im_router) {
        let shadowed_by_app = self
            .routes
            .iter()
            .any(|r| r.method == Method::POST && r.path == path)
            || self.scoped_groups.iter().any(|g| {
                g.routes.iter().any(|r| {
                    r.method == Method::POST
                        && crate::router::join_nested_path(&g.prefix, r.path) == path.as_str()
                })
            })
            || self.nest_routers.iter().any(|(nest_path, _)| {
                // Matches the prefix itself or anything below it on a segment boundary.
                let p = nest_path.as_str();
                path.as_str() == p
                    || (path.starts_with(p)
                        && (p.ends_with('/') || path.as_bytes().get(p.len()) == Some(&b'/')))
            });
        if shadowed_by_app {
            tracing::warn!(
                path = %path,
                "inbound_mail: skipping webhook route — a POST handler is \
                 already registered at this path by the application"
            );
            continue;
        }
        if !registered_inbound.insert(path.clone()) {
            tracing::warn!(
                path = %path,
                "inbound_mail: skipping duplicate inbound webhook path"
            );
            continue;
        }
        self.config.security.csrf.exempt_paths.push(path.clone());
        self.config.security.captcha_exempt_paths.push(path);
        self.merge_routers.push(axum_router);
    }
}
}
impl Default for TestApp {
fn default() -> Self {
Self::new()
}
}
/// Handle for sending requests to a built test application.
///
/// Created by [`TestApp::build`] or [`TestApp::from_router`].
pub struct TestClient {
/// Router every request is dispatched to (cloned per request).
router: axum::Router,
probes: crate::probe::ProbeState,
pub(crate) state: AppState,
/// Held only for its `Drop`: dropping the client cancels the job runtime.
_job_runtime: Option<TestJobRuntime>,
/// Type-erased clock from `TestApp::with_clock`, used by `advance_clock`.
clock_as_any: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
}
/// Guard for the background job runtime started by `TestApp::build`.
struct TestJobRuntime {
/// Token passed to `crate::job::start_runtime`; cancelled on drop.
shutdown: tokio_util::sync::CancellationToken,
}
impl Drop for TestJobRuntime {
/// Cancel the runtime's shutdown token, then clear the global job client
/// so later tests do not see this test's client.
fn drop(&mut self) {
self.shutdown.cancel();
crate::job::clear_global_job_client();
}
}
impl TestClient {
#[must_use]
pub const fn state(&self) -> &AppState {
&self.state
}
pub fn advance_clock(&self, duration: std::time::Duration) {
if let Some(any) = &self.clock_as_any {
let cloned = std::sync::Arc::clone(any);
if let Ok(ticking) = cloned.downcast::<crate::time::TickingClock>() {
ticking.advance(duration);
}
}
}
pub fn into_router(self) -> axum::Router {
self.router
}
pub const fn probes(&self) -> &crate::probe::ProbeState {
&self.probes
}
#[must_use]
pub fn get(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::GET, uri)
}
#[must_use]
pub fn post(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::POST, uri)
}
#[must_use]
pub fn put(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::PUT, uri)
}
#[must_use]
pub fn delete(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::DELETE, uri)
}
#[must_use]
pub fn patch(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::PATCH, uri)
}
#[must_use]
pub fn options(&self, uri: &str) -> RequestBuilder {
RequestBuilder::new(self.router.clone(), Method::OPTIONS, uri)
}
}
/// A request under construction; send it with [`RequestBuilder::send`].
pub struct RequestBuilder {
router: axum::Router,
method: Method,
uri: String,
/// Headers in insertion order; repeated names are sent as separate values.
headers: Vec<(String, String)>,
body: Body,
}
impl RequestBuilder {
    fn new(router: axum::Router, method: Method, uri: &str) -> Self {
        Self {
            router,
            method,
            uri: uri.to_owned(),
            headers: Vec::new(),
            body: Body::empty(),
        }
    }

    /// Append a request header.
    ///
    /// Calling this again with the same name adds another value; it does not
    /// replace the earlier one.
    #[must_use]
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_owned(), value.to_owned()));
        self
    }

    /// Use `value` serialized as JSON for the body, and append
    /// `content-type: application/json`.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized.
    #[must_use]
    pub fn json(mut self, value: &serde_json::Value) -> Self {
        self.headers
            .push(("content-type".to_owned(), "application/json".to_owned()));
        self.body = Body::from(serde_json::to_vec(value).expect("failed to serialize JSON body"));
        self
    }

    /// Use `body` as a URL-encoded form body.
    ///
    /// Appends `content-type: application/x-www-form-urlencoded` and
    /// `sec-fetch-site: same-origin`. The latter is presumably there so that
    /// origin-based CSRF checks accept the request; TODO confirm.
    #[must_use]
    pub fn form(mut self, body: &str) -> Self {
        self.headers.push((
            "content-type".to_owned(),
            "application/x-www-form-urlencoded".to_owned(),
        ));
        self.headers
            .push(("sec-fetch-site".to_owned(), "same-origin".to_owned()));
        self.body = Body::from(body.to_owned());
        self
    }

    /// Use an arbitrary request body (no content-type is added).
    #[must_use]
    pub fn body(mut self, body: impl Into<Body>) -> Self {
        self.body = body.into();
        self
    }

    /// Dispatch the request and collect the full response.
    ///
    /// The router is wrapped in `MethodOverrideLayer` before the request is
    /// sent. Response header values that are not visible ASCII are decoded
    /// lossily as UTF-8. Previously they were silently replaced by `""`,
    /// which hid their content from assertions.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be built (e.g. invalid URI, header name
    /// or header value) or if the response body cannot be read.
    pub async fn send(self) -> TestResponse {
        let Self {
            router,
            method,
            uri,
            headers,
            body,
        } = self;
        let mut builder = Request::builder().method(method).uri(&uri);
        for (name, value) in &headers {
            builder = builder.header(name.as_str(), value.as_str());
        }
        let request = builder.body(body).expect("failed to build request");
        let service = tower::Layer::layer(&crate::middleware::MethodOverrideLayer::new(), router);
        let response = service.oneshot(request).await.expect("request failed");
        let status = response.status();
        let headers: Vec<(String, String)> = response
            .headers()
            .iter()
            .map(|(name, value)| {
                (
                    name.to_string(),
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();
        let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("failed to read response body");
        TestResponse {
            status,
            headers,
            body: body_bytes.to_vec(),
        }
    }
}
/// A fully buffered response returned by [`RequestBuilder::send`].
pub struct TestResponse {
pub status: StatusCode,
/// `(name, value)` pairs in response order; names are as rendered by the
/// `http` crate, and a repeated header appears once per value.
pub headers: Vec<(String, String)>,
/// Complete response body.
pub body: Vec<u8>,
}
impl TestResponse {
    /// The body decoded as UTF-8.
    ///
    /// # Panics
    ///
    /// Panics, reporting the caller's location, if the body is not valid UTF-8.
    #[must_use]
    #[track_caller]
    pub fn text(&self) -> String {
        // Validate in place: no full body clone is needed just to check UTF-8.
        match std::str::from_utf8(&self.body) {
            Ok(text) => text.to_owned(),
            Err(e) => panic!(
                "response body is not valid UTF-8: {e}\nRaw bytes: {:?}",
                self.body
            ),
        }
    }

    /// Deserialize the body as JSON into `T`.
    ///
    /// # Panics
    ///
    /// Panics, reporting the caller's location, if the body is not valid JSON for `T`.
    #[must_use]
    #[track_caller]
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
        // `match` rather than `unwrap_or_else(|e| panic!(..))`: closures do not
        // inherit `#[track_caller]`, so a panic inside one would point here.
        match serde_json::from_slice(&self.body) {
            Ok(value) => value,
            Err(e) => panic!(
                "failed to parse response body as JSON: {e}\nBody: {}",
                String::from_utf8_lossy(&self.body)
            ),
        }
    }

    /// First value of the header `name`, compared ASCII case-insensitively
    /// (HTTP header names are ASCII tokens).
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str())
    }

    /// Assert the status is exactly `200 OK`.
    #[track_caller]
    pub fn assert_ok(&self) -> &Self {
        assert_eq!(
            self.status,
            StatusCode::OK,
            "expected 200 OK, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Assert the numeric status code equals `expected`.
    #[track_caller]
    pub fn assert_status(&self, expected: u16) -> &Self {
        assert_eq!(
            self.status.as_u16(),
            expected,
            "expected status {expected}, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Assert the status is in the 2xx range.
    #[track_caller]
    pub fn assert_success(&self) -> &Self {
        assert!(
            self.status.is_success(),
            "expected 2xx success, got {}.\nBody: {}",
            self.status,
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Assert header `name` is present and its first value equals `expected`.
    #[track_caller]
    pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
        let Some(value) = self.header(name) else {
            panic!(
                "expected header `{name}` to be present.\nAvailable headers: {:?}",
                self.headers
            );
        };
        assert_eq!(
            value, expected,
            "header `{name}`: expected `{expected}`, got `{value}`"
        );
        self
    }

    /// Assert header `name` is present and its first value contains `substring`.
    #[track_caller]
    pub fn assert_header_contains(&self, name: &str, substring: &str) -> &Self {
        let Some(value) = self.header(name) else {
            panic!(
                "expected header `{name}` to be present.\nAvailable headers: {:?}",
                self.headers
            );
        };
        assert!(
            value.contains(substring),
            "header `{name}`: expected `{value}` to contain `{substring}`"
        );
        self
    }

    /// Assert the UTF-8 body contains `substring`.
    #[track_caller]
    pub fn assert_body_contains(&self, substring: &str) -> &Self {
        let body = self.text();
        assert!(
            body.contains(substring),
            "expected body to contain `{substring}`.\nBody: {body}"
        );
        self
    }

    /// Assert the UTF-8 body equals `expected` exactly.
    #[track_caller]
    pub fn assert_body_eq(&self, expected: &str) -> &Self {
        let body = self.text();
        assert_eq!(body, expected, "body mismatch.\nActual Body: {body}");
        self
    }

    /// Deserialize the body as `T` and hand it to `predicate` for assertions.
    #[track_caller]
    pub fn assert_json<T, F>(&self, predicate: F) -> &Self
    where
        T: serde::de::DeserializeOwned,
        F: FnOnce(&T),
    {
        let value: T = self.json();
        predicate(&value);
        self
    }

    /// Assert the body has zero bytes.
    #[track_caller]
    pub fn assert_body_empty(&self) -> &Self {
        assert!(
            self.body.is_empty(),
            "expected empty body, got {} bytes: {}",
            self.body.len(),
            String::from_utf8_lossy(&self.body)
        );
        self
    }

    /// Parse the body as HTML (panics at the caller if it is not UTF-8).
    #[track_caller]
    fn parse_html(&self) -> Vec<crate::test_html::Node> {
        crate::test_html::parse(&self.text())
    }

    /// Compile a CSS selector, panicking at the caller's location if it is invalid.
    #[track_caller]
    fn compile_selector(css: &str) -> crate::test_html::SelectorList {
        // `match`, not `unwrap_or_else`: the panic must be in this
        // `#[track_caller]` function itself for the caller's location to be reported.
        match crate::test_html::SelectorList::parse(css) {
            Ok(selector) => selector,
            Err(e) => panic!("invalid CSS selector `{css}`: {e}"),
        }
    }

    /// Outline of the parsed document for failure messages (limit 1200 passed to `outline`).
    fn html_outline(nodes: &[crate::test_html::Node]) -> String {
        crate::test_html::outline(nodes, 1200)
    }

    /// Whitespace-normalized text of every element matching `css`.
    #[must_use]
    #[track_caller]
    pub fn selector_text(&self, css: &str) -> Vec<String> {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        selector
            .matches(&nodes)
            .iter()
            .map(|el| crate::test_html::normalize_ws(&el.text()))
            .collect()
    }

    /// Value of attribute `attr` on every element matching `css` (`None` where absent).
    #[must_use]
    #[track_caller]
    pub fn selector_attr(&self, css: &str, attr: &str) -> Vec<Option<String>> {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        selector
            .matches(&nodes)
            .iter()
            .map(|el| el.attr(attr).map(str::to_string))
            .collect()
    }

    /// Number of elements matching `css`.
    #[must_use]
    #[track_caller]
    pub fn selector_count(&self, css: &str) -> usize {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        selector.matches(&nodes).len()
    }

    /// Assert at least one element matches `css`.
    #[track_caller]
    pub fn assert_selector(&self, css: &str) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let count = selector.matches(&nodes).len();
        assert!(
            count > 0,
            "no elements matched selector `{css}`.\nParsed HTML:\n{}",
            Self::html_outline(&nodes)
        );
        self
    }

    /// Assert no element matches `css`.
    #[track_caller]
    pub fn assert_no_selector(&self, css: &str) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let count = selector.matches(&nodes).len();
        assert!(
            count == 0,
            "expected no elements matching selector `{css}`, but found {count}.\nParsed HTML:\n{}",
            Self::html_outline(&nodes)
        );
        self
    }

    /// Assert exactly `expected` elements match `css`.
    #[track_caller]
    pub fn assert_selector_count(&self, css: &str, expected: usize) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let actual = selector.matches(&nodes).len();
        assert!(
            actual == expected,
            "expected {expected} element(s) matching selector `{css}`, found {actual}.\n\
             Parsed HTML:\n{}",
            Self::html_outline(&nodes)
        );
        self
    }

    /// Assert the first match of `css` has text equal to `expected`
    /// (both whitespace-normalized).
    #[track_caller]
    pub fn assert_text(&self, css: &str, expected: &str) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let matched = selector.matches(&nodes);
        let Some(first) = matched.into_iter().next() else {
            panic!(
                "no elements matched selector `{css}`.\nParsed HTML:\n{}",
                Self::html_outline(&nodes)
            );
        };
        let actual = crate::test_html::normalize_ws(&first.text());
        let expected_norm = crate::test_html::normalize_ws(expected);
        assert!(
            actual == expected_norm,
            "text mismatch for selector `{css}`:\n expected: {expected_norm:?}\n \
             actual: {actual:?}\nParsed HTML:\n{}",
            Self::html_outline(&nodes)
        );
        self
    }

    /// Assert the first match of `css` has text containing `substring`
    /// (both whitespace-normalized).
    #[track_caller]
    pub fn assert_text_contains(&self, css: &str, substring: &str) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let matched = selector.matches(&nodes);
        let Some(first) = matched.into_iter().next() else {
            panic!(
                "no elements matched selector `{css}`.\nParsed HTML:\n{}",
                Self::html_outline(&nodes)
            );
        };
        let actual = crate::test_html::normalize_ws(&first.text());
        let needle = crate::test_html::normalize_ws(substring);
        assert!(
            actual.contains(&needle),
            "text for selector `{css}` did not contain {needle:?}.\n actual: {actual:?}\n\
             Parsed HTML:\n{}",
            Self::html_outline(&nodes)
        );
        self
    }

    /// Assert the first match of `css` has attribute `attr` equal to `expected`.
    #[track_caller]
    pub fn assert_attr(&self, css: &str, attr: &str, expected: &str) -> &Self {
        let selector = Self::compile_selector(css);
        let nodes = self.parse_html();
        let matched = selector.matches(&nodes);
        let Some(first) = matched.into_iter().next() else {
            panic!(
                "no elements matched selector `{css}`.\nParsed HTML:\n{}",
                Self::html_outline(&nodes)
            );
        };
        match first.attr(attr) {
            Some(actual) => assert!(
                actual == expected,
                "attribute `{attr}` mismatch for selector `{css}`:\n expected: {expected:?}\n \
                 actual: {actual:?}\nParsed HTML:\n{}",
                Self::html_outline(&nodes)
            ),
            None => panic!(
                "element matching selector `{css}` has no `{attr}` attribute.\n\
                 Parsed HTML:\n{}",
                Self::html_outline(&nodes)
            ),
        }
        self
    }
}
/// Database checkout interceptor that puts each checked-out connection inside
/// a Diesel test transaction (a transaction that is never committed).
///
/// Because pooled connections are reused, a session setting
/// (`autumn.test_transaction_started`) records whether the connection
/// already has a test transaction open. This way it is only started once per
/// connection.
#[cfg(feature = "db")]
struct TransactionalDbInterceptor;
#[cfg(feature = "db")]
impl crate::interceptor::DbConnectionInterceptor for TransactionalDbInterceptor {
    /// Check a connection out via `next`, then ensure it is inside a test
    /// transaction before returning it.
    ///
    /// Errors from `next` are propagated unchanged. Failing to begin the
    /// test transaction, or to set the marker setting, yields an
    /// `AutumnError::internal_server_error_msg` error. The checkout context
    /// is not used.
    fn intercept_checkout<'a>(
        &'a self,
        _ctx: crate::interceptor::DbCheckoutContext,
        next: std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                    > + Send
                    + 'a,
            >,
        >,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                > + Send
                + 'a,
        >,
    > {
        Box::pin(async move {
            let mut conn = next.await?;
            // Probe the marker. With `missing_ok = true`, Postgres'
            // `current_setting` returns NULL (-> `None`) instead of raising
            // an error when the setting is unknown on this session.
            let guc_result = diesel::select(diesel::dsl::sql::<
                diesel::sql_types::Nullable<diesel::sql_types::Text>,
            >(
                "current_setting('autumn.test_transaction_started', true)",
            ))
            .get_result::<Option<String>>(&mut *conn)
            .await;
            match guc_result {
                Ok(Some(ref s)) if s == "true" => {
                    // Marker present: a test transaction was already begun
                    // on this connection (presumably by an earlier checkout
                    // through this interceptor), so nothing to do.
                }
                Ok(_) => {
                    // Marker absent (NULL) or holding any other value: start
                    // the test transaction and then record it.
                    use diesel_async::AsyncConnection;
                    // NOTE(review): redundant with the file-level
                    // `#[cfg(feature = "db")] use diesel_async::RunQueryDsl;`.
                    use diesel_async::RunQueryDsl;
                    conn.begin_test_transaction().await.map_err(|e| {
                        crate::AutumnError::internal_server_error_msg(format!(
                            "failed to start test transaction: {e}"
                        ))
                    })?;
                    // This SET runs inside the just-opened transaction.
                    // Postgres undoes a SET when its transaction is rolled
                    // back, so the marker should not outlive the test
                    // transaction.
                    diesel::sql_query("SET autumn.test_transaction_started = 'true'")
                        .execute(&mut *conn)
                        .await
                        .map_err(|e| {
                            crate::AutumnError::internal_server_error_msg(format!(
                                "failed to set transaction session GUC: {e}"
                            ))
                        })?;
                }
                Err(_) => {
                    // The probe query itself failed (for example, Postgres
                    // rejects every statement in an aborted transaction).
                    // The connection is returned unchanged.
                    // NOTE(review): this silently skips starting a test
                    // transaction; confirm that is intended rather than
                    // surfacing the error.
                }
            }
            Ok(conn)
        })
    }
    /// Always `true`: this interceptor exists only for transactional tests.
    fn is_transactional_test(&self) -> bool {
        true
    }
}
/// Combines two database checkout interceptors into one.
///
/// `second` wraps the raw checkout future. `first` then wraps the future
/// that `second` produced, so `first`'s future is the one handed back to the
/// caller.
#[cfg(feature = "db")]
struct ComposedDbInterceptor {
    /// Outer layer: receives the future produced by `second`.
    first: std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>,
    /// Inner layer: receives the original checkout future.
    second: std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>,
}
#[cfg(feature = "db")]
impl crate::interceptor::DbConnectionInterceptor for ComposedDbInterceptor {
    /// Wrap `next` with `second`, then wrap that result with `first`.
    ///
    /// `second.intercept_checkout` is invoked before `first.intercept_checkout`.
    /// Each layer receives its own copy of the checkout context.
    fn intercept_checkout<'a>(
        &'a self,
        ctx: crate::interceptor::DbCheckoutContext,
        next: std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                    > + Send
                    + 'a,
            >,
        >,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::db::PooledConnection, crate::AutumnError>,
                > + Send
                + 'a,
        >,
    > {
        // Arguments evaluate left to right: the context is cloned for the
        // outer layer, then the inner layer consumes the original context
        // and `next`.
        self.first
            .intercept_checkout(ctx.clone(), self.second.intercept_checkout(ctx, next))
    }
    /// `true` when either layer reports a transactional test
    /// (`first` is consulted first and short-circuits).
    fn is_transactional_test(&self) -> bool {
        [&self.first, &self.second]
            .into_iter()
            .any(|layer| layer.is_transactional_test())
    }
}
/// A disposable PostgreSQL database for integration tests.
///
/// Starts a Postgres container through `testcontainers` and exposes a
/// `diesel_async` connection pool plus the connection URL for it.
#[cfg(all(feature = "db", feature = "test-support"))]
pub struct TestDb {
    // Held only to keep the container handle alive for the lifetime of this
    // value. Per testcontainers' behavior, dropping the handle is expected to
    // stop the container.
    _container: testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>,
    pool: Pool<AsyncPgConnection>,
    url: String,
}
#[cfg(all(feature = "db", feature = "test-support"))]
impl TestDb {
    /// Start a fresh Postgres container and build a connection pool for it.
    ///
    /// The URL uses the `postgres`/`postgres` credentials and the `postgres`
    /// database on the host and port the container maps for port 5432. The
    /// pool is capped at 5 connections.
    ///
    /// # Panics
    ///
    /// Panics if the container cannot be started (for example when Docker is
    /// not running), if its host or mapped port cannot be resolved, or if
    /// the pool cannot be built.
    pub async fn new() -> Self {
        use diesel_async::pooled_connection::AsyncDieselConnectionManager;
        use testcontainers::runners::AsyncRunner;
        use testcontainers_modules::postgres::Postgres;
        let container = Postgres::default()
            .start()
            .await
            .expect("failed to start Postgres testcontainer (is Docker running?)");
        let host = container
            .get_host()
            .await
            .expect("failed to resolve Postgres testcontainer host");
        let port = container
            .get_host_port_ipv4(5432)
            .await
            .expect("failed to resolve host port mapped to Postgres testcontainer port 5432");
        let url = format!("postgres://postgres:postgres@{host}:{port}/postgres");
        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&url);
        let pool = Pool::builder(manager)
            .max_size(5)
            .build()
            .expect("failed to build connection pool");
        Self {
            _container: container,
            pool,
            url,
        }
    }
    /// Return a process-wide `TestDb`, starting its container on first use.
    ///
    /// Concurrent first callers await the same `tokio::sync::OnceCell`
    /// initialization, so only one container is started. The instance lives
    /// in a `static` and is never dropped.
    ///
    /// NOTE(review): callers on different tokio runtimes (e.g. separate
    /// `#[tokio::test]`s) share this pool. Confirm that connections opened on
    /// a runtime that has since shut down are recycled by the pool.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TestDb::new`] on first use.
    pub async fn shared() -> &'static Self {
        use std::sync::OnceLock;
        use tokio::sync::OnceCell;
        static CELL: OnceLock<OnceCell<TestDb>> = OnceLock::new();
        let once = CELL.get_or_init(OnceCell::new);
        once.get_or_init(Self::new).await
    }
    /// A handle to this database's connection pool (clones share the same
    /// underlying pool).
    #[must_use]
    pub fn pool(&self) -> Pool<AsyncPgConnection> {
        self.pool.clone()
    }
    /// The `postgres://` connection URL for this database.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }
    /// Execute raw `sql` on a connection taken from the pool.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be obtained, or if the statement fails.
    /// The failure message includes the SQL text.
    pub async fn execute_sql(&self, sql: &str) {
        use diesel_async::RunQueryDsl;
        let mut conn = self.pool.get().await.expect("failed to get connection");
        diesel::sql_query(sql)
            .execute(&mut *conn)
            .await
            .unwrap_or_else(|e| panic!("SQL execution failed: {e}\nSQL: {sql}"));
    }
}
// Unit tests for `TestApp`/`TestClient`: request building, response
// assertions, job-runtime cleanup on drop, HTML form method override and
// (with the `maud` feature) CSS-selector based HTML assertions.
#[cfg(test)]
mod tests {
    use super::*;
    /// No-op job handler registered by `CleanupJobPlugin`; always succeeds.
    fn cleanup_probe_job(
        _state: crate::state::AppState,
        _payload: serde_json::Value,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send + 'static>,
    > {
        Box::pin(async move { Ok(()) })
    }
    /// Plugin registering the single job `cleanup_probe`, so that building a
    /// `TestApp` with it starts a job runtime.
    struct CleanupJobPlugin;
    impl crate::plugin::Plugin for CleanupJobPlugin {
        fn build(self, app: crate::app::AppBuilder) -> crate::app::AppBuilder {
            app.jobs(vec![crate::job::JobInfo {
                name: "cleanup_probe".to_string(),
                max_attempts: 1,
                initial_backoff_ms: 1,
                uniqueness: None,
                concurrency: None,
                handler: cleanup_probe_job,
            }])
        }
    }
    /// Routes shared by most tests:
    /// - `GET /hello` returns `"hello"`;
    /// - `POST /echo` echoes the JSON request body;
    /// - `POST /create` returns `201 Created` with body `"created"`.
    fn test_routes() -> Vec<Route> {
        use axum::routing;
        async fn hello() -> &'static str {
            "hello"
        }
        async fn echo_json(
            axum::Json(value): axum::Json<serde_json::Value>,
        ) -> axum::Json<serde_json::Value> {
            axum::Json(value)
        }
        async fn status_201() -> (StatusCode, &'static str) {
            (StatusCode::CREATED, "created")
        }
        vec![
            Route {
                method: Method::GET,
                path: "/hello",
                handler: routing::get(hello),
                name: "hello",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/hello",
                    operation_id: "hello",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
            Route {
                method: Method::POST,
                path: "/echo",
                handler: routing::post(echo_json),
                name: "echo",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/echo",
                    operation_id: "echo",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
            Route {
                method: Method::POST,
                path: "/create",
                handler: routing::post(status_201),
                name: "create",
                api_doc: crate::openapi::ApiDoc {
                    method: "POST",
                    path: "/create",
                    operation_id: "create",
                    success_status: 201,
                    ..Default::default()
                },
                repository: None,
                idempotency: crate::route::RouteIdempotency::Direct,
                api_version: None,
                sunset_opt_out: false,
            },
        ]
    }
    /// `GET /hello` succeeds with 200 OK.
    #[tokio::test]
    async fn test_app_get_request() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_ok();
    }
    /// A JSON body posted to `/echo` comes back in the response.
    #[tokio::test]
    async fn test_app_post_json() {
        let client = TestApp::new().routes(test_routes()).build();
        client
            .post("/echo")
            .json(&serde_json::json!({"key": "value"}))
            .send()
            .await
            .assert_ok()
            .assert_body_contains("key");
    }
    /// `assert_status` / `assert_body_eq` against the 201 route.
    #[tokio::test]
    async fn test_response_assert_status() {
        let client = TestApp::new().routes(test_routes()).build();
        client
            .post("/create")
            .send()
            .await
            .assert_status(201)
            .assert_body_eq("created");
    }
    /// `assert_success` accepts a 200 response.
    #[tokio::test]
    async fn test_response_assert_success() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/hello").send().await.assert_success();
    }
    /// An unregistered path yields 404.
    #[tokio::test]
    async fn test_not_found() {
        let client = TestApp::new().routes(test_routes()).build();
        client.get("/nonexistent").send().await.assert_status(404);
    }
    /// `assert_json` deserializes the echoed body for inspection.
    #[tokio::test]
    async fn test_response_json_deserialization() {
        let client = TestApp::new().routes(test_routes()).build();
        let resp = client
            .post("/echo")
            .json(&serde_json::json!({"count": 42}))
            .send()
            .await;
        resp.assert_ok().assert_json::<serde_json::Value, _>(|v| {
            assert_eq!(v["count"], 42);
        });
    }
    /// A request with a custom header still succeeds.
    // NOTE(review): only checks the status; nothing verifies that
    // `x-custom` actually reaches the handler.
    #[tokio::test]
    async fn test_custom_header() {
        let client = TestApp::new().routes(test_routes()).build();
        let resp = client
            .get("/hello")
            .header("x-custom", "test-value")
            .send()
            .await;
        resp.assert_ok();
    }
    /// Smoke test: `TestApp::default()` constructs without panicking.
    #[tokio::test]
    async fn test_client_default() {
        let _app = TestApp::default();
    }
    /// Dropping a `TestClient` whose app registered jobs must clear the global
    /// job client and stop the runtime behind it.
    #[tokio::test]
    async fn dropping_test_client_stops_test_started_job_runtime() {
        // The job client is process-global; serialize with other tests that
        // touch it.
        let _guard = crate::job::global_job_runtime_test_lock().lock().await;
        // Start clean in case an earlier test left a client installed.
        crate::job::clear_global_job_client();
        let client = TestApp::new().plugin(CleanupJobPlugin).build();
        // Keep a handle to the client started by this TestClient so it can be
        // probed after the drop.
        let leaked_client = crate::job::global_job_client().expect("test job runtime should start");
        drop(client);
        assert!(
            crate::job::global_job_client().is_none(),
            "dropping a TestClient with jobs must clear its global job client"
        );
        // The captured handle may keep accepting enqueues briefly while the
        // runtime winds down. Poll up to 25 times, 10 ms apart, until it fails.
        let mut last_enqueue_error = None;
        for _ in 0..25 {
            match leaked_client
                .enqueue("cleanup_probe", serde_json::json!({}))
                .await
            {
                Ok(()) => tokio::time::sleep(std::time::Duration::from_millis(10)).await,
                Err(error) => {
                    last_enqueue_error = Some(error.to_string());
                    break;
                }
            }
        }
        assert!(
            last_enqueue_error
                .as_deref()
                .is_some_and(|message| message.contains("failed to enqueue job")),
            "captured pre-drop job client must stop accepting jobs after TestClient drop; \
             last error: {last_enqueue_error:?}"
        );
        // Leave no global client behind for later tests.
        crate::job::clear_global_job_client();
    }
    /// A POST form carrying `_method=DELETE` is dispatched to the DELETE route.
    #[tokio::test]
    async fn test_app_routes_html_method_override_to_delete() {
        use axum::routing;
        async fn deleted() -> &'static str {
            "deleted"
        }
        let routes = vec![Route {
            method: Method::DELETE,
            path: "/items/{id}",
            handler: routing::delete(deleted),
            name: "items_delete",
            api_doc: crate::openapi::ApiDoc {
                method: "DELETE",
                path: "/items/{id}",
                operation_id: "items_delete",
                success_status: 200,
                ..Default::default()
            },
            repository: None,
            idempotency: crate::route::RouteIdempotency::Direct,
            api_version: None,
            sunset_opt_out: false,
        }];
        let client = TestApp::new().routes(routes).build();
        client
            .post("/items/1")
            .form("_method=DELETE")
            .send()
            .await
            .assert_ok()
            .assert_body_eq("deleted");
    }
    /// CSS-selector assertions over maud-rendered HTML.
    #[cfg(feature = "maud")]
    mod html_assertions {
        use super::*;
        use axum::routing::get;
        /// Minimal notes table: three `tr.note-row` rows, each linking to
        /// `/notes/{id}` with text `Note {id}`.
        async fn notes_index_v1() -> maud::Markup {
            maud::html! {
                table.notes {
                    tbody {
                        @for id in 1..=3u32 {
                            tr.note-row {
                                td.title { a href=(format!("/notes/{id}")) { "Note " (id) } }
                            }
                        }
                    }
                }
            }
        }
        /// The same three rows and links as v1, with extra wrappers, classes,
        /// a header row and data attributes (a "cosmetic refactor" of v1).
        async fn notes_index_v2() -> maud::Markup {
            maud::html! {
                div.card {
                    table.notes.striped {
                        thead { tr { th { "Title" } } }
                        tbody.rows {
                            @for id in 1..=3u32 {
                                tr.note-row.is-clickable data-id=(id) {
                                    td.title {
                                        span.wrap {
                                            a.link href=(format!("/notes/{id}")) data-turbo="true" {
                                                "Note " (id)
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        /// A lone `tr#note-7.note-row` fragment with no surrounding table.
        async fn note_row_fragment() -> maud::Markup {
            maud::html! {
                tr.note-row #note-7 {
                    td.title { a.link href="/notes/7" { "Note 7" } }
                }
            }
        }
        /// Build a `TestClient` that serves only `handler` at `path`, mounted
        /// through `TestApp::merge`.
        fn client(
            path: &str,
            handler: axum::routing::MethodRouter<crate::state::AppState>,
        ) -> TestClient {
            let router = axum::Router::<crate::state::AppState>::new().route(path, handler);
            TestApp::new().merge(router).build()
        }
        /// Presence, absence and count assertions on tag/class selectors.
        #[tokio::test]
        async fn counts_rows_by_tag_and_class() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_ok()
                .assert_selector("table.notes")
                .assert_selector_count("tbody tr", 3)
                .assert_selector_count("tr.note-row", 3)
                .assert_no_selector("form");
        }
        /// Text/attribute assertions on the first match, plus the collecting
        /// accessors `selector_text`, `selector_attr` and `selector_count`.
        #[tokio::test]
        async fn reads_text_and_attributes() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_text("tr.note-row td.title a", "Note 1")
                .assert_text_contains("tr.note-row", "Note 1")
                .assert_attr("tr.note-row td a", "href", "/notes/1");
            let links = resp.selector_text("tr.note-row a");
            assert_eq!(links, vec!["Note 1", "Note 2", "Note 3"]);
            let hrefs = resp.selector_attr("tr.note-row a", "href");
            assert_eq!(
                hrefs,
                vec![
                    Some("/notes/1".to_string()),
                    Some("/notes/2".to_string()),
                    Some("/notes/3".to_string()),
                ]
            );
            assert_eq!(resp.selector_count("tr.note-row"), 3);
        }
        /// The same selector-based assertions pass against both v1 and v2
        /// markup.
        #[tokio::test]
        async fn survives_cosmetic_refactor() {
            for handler in [get(notes_index_v1), get(notes_index_v2)] {
                let resp = client("/notes", handler).get("/notes").send().await;
                resp.assert_ok()
                    .assert_selector_count("tbody tr.note-row", 3);
                let hrefs = resp.selector_attr("tbody tr.note-row a", "href");
                assert_eq!(
                    hrefs,
                    vec![
                        Some("/notes/1".to_string()),
                        Some("/notes/2".to_string()),
                        Some("/notes/3".to_string()),
                    ],
                    "row links must survive the refactor"
                );
            }
        }
        /// Selectors also work on a bare `<tr>` fragment response.
        #[tokio::test]
        async fn works_for_htmx_fragment() {
            let resp = client("/rows/7", get(note_row_fragment))
                .get("/rows/7")
                .send()
                .await;
            resp.assert_selector("tr.note-row")
                .assert_selector("tr#note-7")
                .assert_attr("tr#note-7 a", "href", "/notes/7")
                .assert_text("tr#note-7 a.link", "Note 7");
        }
        /// `#id`, `[attr="value"]` and `[attr^="prefix"]` selectors.
        #[tokio::test]
        async fn id_and_attribute_selectors() {
            let resp = client("/rows/7", get(note_row_fragment))
                .get("/rows/7")
                .send()
                .await;
            resp.assert_selector("#note-7")
                .assert_selector("a[href=\"/notes/7\"]")
                .assert_selector("a[href^=\"/notes/\"]")
                .assert_no_selector("a[href=\"/other\"]");
        }
        /// A wrong expected count panics with a message naming that count.
        #[tokio::test]
        #[should_panic(expected = "expected 5 element(s) matching selector")]
        async fn count_mismatch_panics_with_actionable_message() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_selector_count("tr.note-row", 5);
        }
        /// `assert_selector` on an absent selector panics, naming the selector.
        #[tokio::test]
        #[should_panic(expected = "no elements matched selector `table.missing`")]
        async fn missing_selector_panics() {
            let resp = client("/notes", get(notes_index_v1))
                .get("/notes")
                .send()
                .await;
            resp.assert_selector("table.missing");
        }
    }
    /// An unknown override method (`_method=BREW`) is rejected with 400.
    #[tokio::test]
    async fn test_app_routes_invalid_method_override_rejected() {
        let client = TestApp::new().routes(test_routes()).build();
        client
            .post("/create")
            .form("_method=BREW")
            .send()
            .await
            .assert_status(400);
    }
    /// The 400 produced for an invalid override still passes through the
    /// framework middleware: it carries `x-request-id` and
    /// `x-content-type-options` headers.
    #[tokio::test]
    async fn invalid_method_override_response_carries_framework_middleware() {
        let client = TestApp::new().routes(test_routes()).build();
        let response = client.post("/create").form("_method=BREW").send().await;
        response.assert_status(400);
        assert!(
            response.header("x-request-id").is_some(),
            "framework request-id header must wrap method-override rejections; \
             observed headers: {:?}",
            response.headers
        );
        assert!(
            response.header("x-content-type-options").is_some(),
            "framework security headers must wrap method-override rejections; \
             observed headers: {:?}",
            response.headers
        );
    }
}