use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use futures::FutureExt as _;
use tracing::Instrument as _;
use crate::config::{AutumnConfig, ConfigLoader};
use crate::error_pages::{ErrorPageRenderer, SharedRenderer};
use crate::middleware::exception_filter::ExceptionFilter;
#[cfg(feature = "db")]
use crate::migrate;
use crate::route::Route;
use crate::state::AppState;
/// Entry point for building an Autumn application.
///
/// Returns an [`AppBuilder`] with every registry empty and every optional
/// subsystem (database, sessions, cache, mail, …) unset; callers chain the
/// builder methods and finish with `.run().await`.
#[must_use]
pub fn app() -> AppBuilder {
AppBuilder {
routes: Vec::new(),
route_sources: Vec::new(),
// `None` means routes registered now are attributed to user code,
// not to a plugin (see `AppBuilder::plugin`).
current_plugin: None,
tasks: Vec::new(),
one_off_tasks: Vec::new(),
jobs: Vec::new(),
static_metas: Vec::new(),
exception_filters: Vec::new(),
scoped_groups: Vec::new(),
merge_routers: Vec::new(),
nest_routers: Vec::new(),
custom_layers: Vec::new(),
startup_hooks: Vec::new(),
shutdown_hooks: Vec::new(),
extensions: HashMap::new(),
registered_plugins: HashSet::new(),
error_page_renderer: None,
#[cfg(feature = "db")]
migrations: Vec::new(),
config_loader_factory: None,
#[cfg(feature = "db")]
pool_provider_factory: None,
telemetry_provider: None,
session_store: None,
#[cfg(feature = "ws")]
channels_backend: None,
#[cfg(feature = "storage")]
blob_store: None,
cache_backend: None,
#[cfg(feature = "openapi")]
openapi: None,
audit_logger: None,
#[cfg(feature = "i18n")]
i18n_bundle: None,
#[cfg(feature = "i18n")]
i18n_auto_load: false,
policy_registrations: Vec::new(),
#[cfg(feature = "mail")]
mail_delivery_queue_factory: None,
#[cfg(feature = "mail")]
mail_previews: Vec::new(),
declared_routes: Vec::new(),
}
}
// Boxed future returned by a startup hook; may fail and abort boot.
type StartupHookFuture = Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send>>;
// Startup hooks receive the fully-built `AppState` and run before the
// task scheduler starts (see `AppBuilder::run`).
type StartupHook = Box<dyn Fn(AppState) -> StartupHookFuture + Send + Sync>;
// Boxed future returned by a shutdown hook; infallible by design.
type ShutdownHookFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
type ShutdownHook = Box<dyn Fn() -> ShutdownHookFuture + Send + Sync>;
// One-shot factory producing the application config; installed via
// `AppBuilder::with_config_loader` and consumed during boot.
type ConfigLoaderFactory = Box<
dyn FnOnce() -> Pin<
Box<dyn Future<Output = Result<AutumnConfig, crate::config::ConfigError>> + Send>,
> + Send,
>;
// One-shot factory building the database topology from config; `Ok(None)`
// signals "no database configured" (see `AppBuilder::with_pool_provider`).
#[cfg(feature = "db")]
type PoolProviderFactory = Box<
dyn FnOnce(
crate::config::DatabaseConfig,
) -> Pin<
Box<
dyn Future<
Output = Result<Option<crate::db::DatabaseTopology>, crate::db::PoolError>,
> + Send,
>,
> + Send,
>;
// Deferred policy/scope registration, replayed against the state's
// `PolicyRegistry` once it exists (see `AppBuilder::policy` / `scope`).
type PolicyRegistration = Box<dyn FnOnce(&crate::authorization::PolicyRegistry) + Send>;
/// Accumulates everything an Autumn application is made of — routes, tasks,
/// middleware, providers — until one of the `run*` methods consumes it.
///
/// Constructed via [`app()`]; all setters take `self` by value and return it,
/// builder-style.
pub struct AppBuilder {
routes: Vec<Route>,
/// Parallel to `routes`: records whether each route came from user code
/// or a plugin (same index).
route_sources: Vec<crate::route_listing::RouteSource>,
/// Name of the plugin whose `build()` is currently executing, if any;
/// used to attribute routes registered during that call.
current_plugin: Option<String>,
tasks: Vec<crate::task::TaskInfo>,
one_off_tasks: Vec<crate::task::OneOffTaskInfo>,
jobs: Vec<crate::job::JobInfo>,
pub(crate) static_metas: Vec<crate::static_gen::StaticRouteMeta>,
exception_filters: Vec<Arc<dyn ExceptionFilter>>,
scoped_groups: Vec<ScopedGroup>,
/// Raw axum routers merged verbatim; not enumerable by route listing.
merge_routers: Vec<axum::Router<AppState>>,
/// Raw axum routers nested under a path prefix; not enumerable either.
nest_routers: Vec<(String, axum::Router<AppState>)>,
custom_layers: Vec<CustomLayerRegistration>,
startup_hooks: Vec<StartupHook>,
shutdown_hooks: Vec<ShutdownHook>,
/// Type-keyed bag of builder-time extension values (see `with_extension`).
extensions: HashMap<TypeId, Box<dyn Any + Send>>,
/// Plugin names already applied, used to skip duplicates.
registered_plugins: HashSet<String>,
error_page_renderer: Option<SharedRenderer>,
#[cfg(feature = "db")]
migrations: Vec<migrate::EmbeddedMigrations>,
config_loader_factory: Option<ConfigLoaderFactory>,
#[cfg(feature = "db")]
pool_provider_factory: Option<PoolProviderFactory>,
telemetry_provider: Option<Box<dyn crate::telemetry::TelemetryProvider>>,
session_store: Option<Arc<dyn crate::session::BoxedSessionStore>>,
#[cfg(feature = "ws")]
channels_backend: Option<Arc<dyn crate::channels::ChannelsBackend>>,
#[cfg(feature = "storage")]
blob_store: Option<crate::storage::SharedBlobStore>,
cache_backend: Option<Arc<dyn crate::cache::Cache>>,
#[cfg(feature = "openapi")]
openapi: Option<crate::openapi::OpenApiConfig>,
audit_logger: Option<Arc<crate::audit::AuditLogger>>,
#[cfg(feature = "i18n")]
i18n_bundle: Option<Arc<crate::i18n::Bundle>>,
/// When `true` (and no explicit bundle is set), the bundle is resolved
/// at boot via `resolve_i18n_bundle`.
#[cfg(feature = "i18n")]
i18n_auto_load: bool,
policy_registrations: Vec<PolicyRegistration>,
#[cfg(feature = "mail")]
mail_delivery_queue_factory: Option<MailDeliveryQueueFactory>,
#[cfg(feature = "mail")]
mail_previews: Vec<crate::mail::MailPreview>,
/// Routes declared (but not mounted) by plugins, merged into the
/// `--dump-routes` listing (see `declare_plugin_routes`).
declared_routes: Vec<crate::route_listing::RouteInfo>,
}
// One-shot factory producing the mail delivery queue from the built state;
// fallible so misconfiguration can abort boot.
#[cfg(feature = "mail")]
pub(crate) type MailDeliveryQueueFactory = Box<
dyn FnOnce(&AppState) -> crate::AutumnResult<Arc<dyn crate::mail::MailDeliveryQueue>> + Send,
>;
/// A group of routes mounted under a shared prefix with a shared Tower layer
/// (see [`AppBuilder::scoped`]).
pub(crate) struct ScopedGroup {
pub(crate) prefix: String,
pub(crate) routes: Vec<Route>,
/// Origin (user code vs. plugin) for route-listing attribution.
pub(crate) source: crate::route_listing::RouteSource,
/// Applies the group's layer to the sub-router; boxed because the
/// concrete layer type is erased at registration time.
pub(crate) apply_layer:
Box<dyn FnOnce(axum::Router<AppState>) -> axum::Router<AppState> + Send>,
}
// Type-erased function that layers a middleware onto the app router.
pub(crate) type CustomLayerApplier =
Box<dyn FnOnce(axum::Router<AppState>) -> axum::Router<AppState> + Send>;
/// A user-registered app-wide layer plus the `TypeId` of its concrete type,
/// kept so `has_layer`/`get_layer_types` can answer queries by type.
pub(crate) struct CustomLayerRegistration {
pub(crate) type_id: TypeId,
pub(crate) apply: CustomLayerApplier,
}
// Sealing module: prevents downstream crates from implementing
// `IntoAppLayer` directly; only the blanket impl below applies.
mod sealed {
pub trait Sealed {}
}
/// Sealed trait for types accepted by [`AppBuilder::layer`].
///
/// The `on_unimplemented` diagnostic spells out the exact `tower::Layer`
/// bounds so users get an actionable error instead of a wall of generics.
#[diagnostic::on_unimplemented(
message = "`{Self}` is not a usable Autumn app-wide Tower layer",
label = "this type does not implement `tower::Layer<axum::routing::Route>` with the required service bounds",
note = "`AppBuilder::layer(..)` requires:\n L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,\n L::Service: Service<axum::extract::Request, Response = axum::response::Response, Error = Infallible> + Clone + Send + Sync + 'static,\n <L::Service as Service<axum::extract::Request>>::Future: Send + 'static\nSee docs/guide/middleware.md for common patterns and how to wrap raw-error layers (e.g. TimeoutLayer) with HandleErrorLayer."
)]
pub trait IntoAppLayer: sealed::Sealed + Send + Sync + 'static {
#[doc(hidden)]
fn apply_to(self, router: axum::Router<AppState>) -> axum::Router<AppState>;
}
// Blanket seal: any type satisfying axum's router-layer bounds is sealed.
// The bound list must stay identical to the `IntoAppLayer` impl below.
impl<L> sealed::Sealed for L
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<
axum::extract::Request,
Response = axum::response::Response,
Error = std::convert::Infallible,
> + Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::extract::Request>>::Future: Send + 'static,
{
}
// Blanket impl: anything meeting axum's `Router::layer` requirements is a
// valid app-wide layer; application is simply `router.layer(self)`.
impl<L> IntoAppLayer for L
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<
axum::extract::Request,
Response = axum::response::Response,
Error = std::convert::Infallible,
> + Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::extract::Request>>::Future: Send + 'static,
{
fn apply_to(self, router: axum::Router<AppState>) -> axum::Router<AppState> {
router.layer(self)
}
}
impl AppBuilder {
/// Registers handler routes, tagging each one with its origin — user code,
/// or the plugin whose `build()` is currently executing — so route listings
/// can attribute them.
#[must_use]
pub fn routes(mut self, routes: Vec<Route>) -> Self {
let source = self
.current_plugin
.as_ref()
.map_or(crate::route_listing::RouteSource::User, |name| {
crate::route_listing::RouteSource::Plugin(name.clone())
});
// `route_sources` stays index-parallel with `routes`: one source entry
// per registered route (idiomatic replacement for a push-per-clone loop).
self.route_sources
.extend(std::iter::repeat(source).take(routes.len()));
self.routes.extend(routes);
self
}
/// Registers scheduled (recurring) tasks to run under the task scheduler.
#[must_use]
pub fn tasks(mut self, tasks: Vec<crate::task::TaskInfo>) -> Self {
self.tasks.extend(tasks);
self
}
/// Registers one-off tasks, runnable by name via the one-off task mode.
#[must_use]
pub fn one_off_tasks(mut self, tasks: Vec<crate::task::OneOffTaskInfo>) -> Self {
self.one_off_tasks.extend(tasks);
self
}
/// Registers background jobs for the job runtime started in `run`.
#[must_use]
pub fn jobs(mut self, jobs: Vec<crate::job::JobInfo>) -> Self {
self.jobs.extend(jobs);
self
}
/// Registers routes to be pre-rendered by the static build mode.
#[must_use]
pub fn static_routes(mut self, metas: Vec<crate::static_gen::StaticRouteMeta>) -> Self {
self.static_metas.extend(metas);
self
}
/// Sets the OpenAPI configuration used for runtime docs and spec generation.
#[cfg(feature = "openapi")]
#[must_use]
pub fn openapi(mut self, config: crate::openapi::OpenApiConfig) -> Self {
self.openapi = Some(config);
self
}
/// Adds an exception filter; filters are consulted in registration order.
#[must_use]
pub fn exception_filter(mut self, filter: impl ExceptionFilter) -> Self {
self.exception_filters.push(Arc::new(filter));
self
}
/// Installs a custom renderer for error pages (replaces any previous one).
#[must_use]
pub fn error_pages(mut self, renderer: impl ErrorPageRenderer) -> Self {
self.error_page_renderer = Some(Arc::new(renderer));
self
}
/// Mounts `routes` under `prefix` with `layer` applied to the whole group.
///
/// The layer's concrete type is erased into a boxed closure so the group can
/// be stored until router construction; route origin (user vs. plugin) is
/// recorded for listings, exactly like [`AppBuilder::routes`].
#[must_use]
pub fn scoped<L>(mut self, prefix: &str, layer: L, routes: Vec<Route>) -> Self
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<
axum::http::Request<axum::body::Body>,
Response = axum::http::Response<axum::body::Body>,
Error = std::convert::Infallible,
> + Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
Send + 'static,
{
// Attribute the group to the plugin being built, if any.
let source = match self.current_plugin.as_ref() {
Some(name) => crate::route_listing::RouteSource::Plugin(name.clone()),
None => crate::route_listing::RouteSource::User,
};
let group = ScopedGroup {
prefix: prefix.to_owned(),
routes,
source,
apply_layer: Box::new(move |router| router.layer(layer)),
};
self.scoped_groups.push(group);
self
}
/// Adds an app-wide Tower layer; its `TypeId` is recorded so
/// `has_layer`/`get_layer_types` can query registrations by type.
#[must_use]
pub fn layer<L: IntoAppLayer>(mut self, layer: L) -> Self {
self.custom_layers.push(CustomLayerRegistration {
type_id: TypeId::of::<L>(),
apply: Box::new(move |router| layer.apply_to(router)),
});
self
}
/// Returns `true` when a layer of concrete type `L` has been registered
/// via [`AppBuilder::layer`].
#[must_use]
pub fn has_layer<L: 'static>(&self) -> bool {
self.custom_layers
.iter()
.any(|registration| registration.type_id == TypeId::of::<L>())
}
/// Returns the `TypeId`s of all registered app-wide layers, in
/// registration order.
#[must_use]
pub fn get_layer_types(&self) -> Vec<TypeId> {
let mut types = Vec::with_capacity(self.custom_layers.len());
for registration in &self.custom_layers {
types.push(registration.type_id);
}
types
}
/// Merges a raw axum router into the app. Such routers are not enumerable
/// by the route-listing machinery (see the warning in dump-routes mode).
#[must_use]
pub fn merge(mut self, router: axum::Router<AppState>) -> Self {
self.merge_routers.push(router);
self
}
/// Nests a raw axum router under `path`. Like `merge`, these routes are
/// not enumerable by the route-listing machinery.
#[must_use]
pub fn nest(mut self, path: &str, router: axum::Router<AppState>) -> Self {
self.nest_routers.push((path.to_owned(), router));
self
}
/// Declares routes for the `--dump-routes` listing without mounting them.
///
/// Each declared route's `source` is overwritten with the current origin
/// (the plugin being built, or `User` outside a plugin).
#[must_use]
pub fn declare_plugin_routes(
mut self,
routes: impl IntoIterator<Item = crate::route_listing::RouteInfo>,
) -> Self {
let source = match self.current_plugin.as_deref() {
Some(name) => crate::route_listing::RouteSource::Plugin(name.to_owned()),
None => crate::route_listing::RouteSource::User,
};
self.declared_routes
.extend(routes.into_iter().map(|mut route| {
route.source = source.clone();
route
}));
self
}
/// Registers an async hook run with the built `AppState` before the server
/// starts accepting work; a hook returning `Err` aborts boot.
#[must_use]
pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(AppState) -> Fut + Send + Sync + 'static,
Fut: Future<Output = crate::AutumnResult<()>> + Send + 'static,
{
self.startup_hooks
.push(Box::new(move |state| Box::pin(hook(state))));
self
}
/// Registers an async hook run during graceful shutdown, before the server
/// token is cancelled. Hooks are infallible.
#[must_use]
pub fn on_shutdown<F, Fut>(mut self, hook: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.shutdown_hooks.push(Box::new(move || Box::pin(hook())));
self
}
/// Stores a builder-time extension value keyed by its type, replacing any
/// existing value of the same type.
#[must_use]
pub fn with_extension<T>(mut self, value: T) -> Self
where
T: Any + Send + 'static,
{
self.extensions.insert(TypeId::of::<T>(), Box::new(value));
self
}
/// Mutates the extension value of type `T` in place, creating it with
/// `init` if absent.
///
/// # Panics
/// Panics if the stored value's type does not match its `TypeId` key —
/// impossible unless the map invariant maintained by `with_extension`
/// is broken.
#[must_use]
pub fn update_extension<T, Init, Update>(mut self, init: Init, update: Update) -> Self
where
T: Any + Send + 'static,
Init: FnOnce() -> T,
Update: FnOnce(&mut T),
{
let type_id = TypeId::of::<T>();
// Entry API: single lookup for the get-or-insert.
let entry = self
.extensions
.entry(type_id)
.or_insert_with(|| Box::new(init()));
let typed = entry
.downcast_mut::<T>()
.expect("extension type map corrupted");
update(typed);
self
}
/// Returns a reference to the stored extension of type `T`, if any.
#[must_use]
pub fn extension<T>(&self) -> Option<&T>
where
T: Any + Send + 'static,
{
self.extensions
.get(&TypeId::of::<T>())
.and_then(|boxed| boxed.downcast_ref::<T>())
}
/// Installs an explicit i18n bundle; disables auto-loading so the explicit
/// bundle always wins.
#[cfg(feature = "i18n")]
#[must_use]
pub fn i18n(mut self, bundle: crate::i18n::Bundle) -> Self {
self.i18n_bundle = Some(Arc::new(bundle));
self.i18n_auto_load = false;
self
}
/// Enables automatic i18n bundle loading at boot; clears any explicit
/// bundle so auto-loading takes effect.
#[cfg(feature = "i18n")]
#[must_use]
pub fn i18n_auto(mut self) -> Self {
self.i18n_bundle = None;
self.i18n_auto_load = true;
self
}
/// Installs a custom config loader; replaces (with a warning) any
/// previously-installed loader.
#[must_use]
pub fn with_config_loader<L>(mut self, loader: L) -> Self
where
L: crate::config::ConfigLoader,
{
if self.config_loader_factory.is_some() {
tracing::warn!(
"config loader replaced; the previously-installed loader was overwritten"
);
}
// Defer `load()` into a boxed one-shot future so the concrete loader
// type is erased.
self.config_loader_factory = Some(Box::new(move || {
Box::pin(async move { loader.load().await })
}));
self
}
/// Installs a custom database pool provider; replaces (with a warning) any
/// previously-installed provider.
#[cfg(feature = "db")]
#[must_use]
pub fn with_pool_provider<P>(mut self, provider: P) -> Self
where
P: crate::db::DatabasePoolProvider,
{
if self.pool_provider_factory.is_some() {
tracing::warn!(
"database pool provider replaced; the previously-installed provider was overwritten"
);
}
self.pool_provider_factory =
Some(Box::new(move |config: crate::config::DatabaseConfig| {
Box::pin(async move { provider.create_topology(&config).await })
}));
self
}
/// Installs a custom telemetry provider; replaces (with a warning) any
/// previously-installed provider.
#[must_use]
pub fn with_telemetry_provider<T>(mut self, provider: T) -> Self
where
T: crate::telemetry::TelemetryProvider,
{
if self.telemetry_provider.is_some() {
tracing::warn!(
"telemetry provider replaced; the previously-installed provider was overwritten"
);
}
self.telemetry_provider = Some(Box::new(provider));
self
}
/// Installs a custom session store; replaces (with a warning) any
/// previously-installed store.
#[must_use]
pub fn with_session_store<S>(mut self, store: S) -> Self
where
S: crate::session::SessionStore,
{
if self.session_store.is_some() {
tracing::warn!(
"session store replaced; the previously-installed store was overwritten"
);
}
self.session_store = Some(Arc::new(store));
self
}
/// Installs a custom channels (websocket pub/sub) backend; replaces (with a
/// warning) any previously-installed backend.
#[cfg(feature = "ws")]
#[must_use]
pub fn with_channels_backend<B>(mut self, backend: B) -> Self
where
B: crate::channels::ChannelsBackend,
{
if self.channels_backend.is_some() {
tracing::warn!(
"channels backend replaced; the previously-installed backend was overwritten"
);
}
self.channels_backend = Some(Arc::new(backend));
self
}
/// Installs a custom blob store; replaces (with a warning) any
/// previously-installed store.
#[cfg(feature = "storage")]
#[must_use]
pub fn with_blob_store<B>(mut self, store: B) -> Self
where
B: crate::storage::BlobStore,
{
if self.blob_store.is_some() {
tracing::warn!("blob store replaced; the previously-installed store was overwritten");
}
// Use the file-level `Arc` import, consistent with every other setter
// (the original spelled out `std::sync::Arc` here only).
self.blob_store = Some(Arc::new(store));
self
}
/// Installs a custom cache backend; replaces (with a warning) any
/// previously-installed backend.
#[must_use]
pub fn with_cache_backend<C: crate::cache::Cache>(mut self, cache: C) -> Self {
if self.cache_backend.is_some() {
tracing::warn!(
"cache backend replaced; the previously-installed backend was overwritten"
);
}
// Coerce to the trait object via an annotated binding.
let shared: Arc<dyn crate::cache::Cache> = Arc::new(cache);
self.cache_backend = Some(shared);
self
}
/// Installs a ready-made mail delivery queue (wrapped in a factory that
/// ignores the state argument).
#[cfg(feature = "mail")]
#[must_use]
pub fn with_mail_delivery_queue(
mut self,
queue: impl crate::mail::MailDeliveryQueue + 'static,
) -> Self {
let arc: Arc<dyn crate::mail::MailDeliveryQueue> = Arc::new(queue);
self.mail_delivery_queue_factory = Some(Box::new(move |_state| Ok(arc)));
self
}
/// Installs a fallible factory that builds the mail delivery queue from the
/// fully-constructed `AppState` at boot time.
#[cfg(feature = "mail")]
#[must_use]
pub fn with_mail_delivery_queue_factory<F, Q>(mut self, factory: F) -> Self
where
F: FnOnce(&AppState) -> crate::AutumnResult<Q> + Send + 'static,
Q: crate::mail::MailDeliveryQueue + 'static,
{
self.mail_delivery_queue_factory = Some(Box::new(move |state| {
factory(state).map(|q| Arc::new(q) as Arc<dyn crate::mail::MailDeliveryQueue>)
}));
self
}
/// Registers mail previews, exposed via the mail preview registry at boot.
#[cfg(feature = "mail")]
#[must_use]
pub fn mail_previews(
mut self,
previews: impl IntoIterator<Item = crate::mail::MailPreview>,
) -> Self {
self.mail_previews.extend(previews);
self
}
/// Adds an audit sink. Sinks accumulate: the existing logger (if any) is
/// cloned and extended with the new sink rather than replaced.
#[must_use]
pub fn with_audit_sink<S>(mut self, sink: S) -> Self
where
S: crate::audit::AuditSink,
{
let logger = self
.audit_logger
.take()
.map_or_else(crate::audit::AuditLogger::new, |logger| (*logger).clone())
.with_sink(Arc::new(sink));
self.audit_logger = Some(Arc::new(logger));
self
}
/// Registers an authorization policy for resource type `R`. Registration is
/// deferred: it is replayed against the state's policy registry at boot.
#[must_use]
pub fn policy<R, P>(mut self, policy: P) -> Self
where
R: Send + Sync + 'static,
P: crate::authorization::Policy<R>,
{
self.policy_registrations.push(Box::new(move |registry| {
registry.register_policy::<R, _>(policy);
}));
self
}
/// Registers an authorization scope for resource type `R`; deferred and
/// replayed at boot, like [`AppBuilder::policy`].
#[must_use]
pub fn scope<R, S>(mut self, scope: S) -> Self
where
R: Send + Sync + 'static,
S: crate::authorization::Scope<R>,
{
self.policy_registrations.push(Box::new(move |registry| {
registry.register_scope::<R, _>(scope);
}));
self
}
/// Applies a plugin, letting it mutate the builder via `Plugin::build`.
///
/// Duplicate plugin names are skipped with a warning. While `build` runs,
/// `current_plugin` is set to the plugin's name so routes it registers are
/// attributed to it; the previous value is restored afterwards so nested
/// plugin application attributes correctly.
#[must_use]
#[track_caller]
pub fn plugin<P>(mut self, plugin: P) -> Self
where
P: crate::plugin::Plugin,
{
let name = plugin.name();
if self.registered_plugins.contains(name.as_ref()) {
tracing::warn!(
plugin = name.as_ref(),
"plugin already registered; skipping duplicate"
);
return self;
}
let name_str = name.into_owned();
self.registered_plugins.insert(name_str.clone());
// Save the outer plugin context (if nested) and restore it after build.
let outer_plugin = self.current_plugin.replace(name_str);
let mut result = plugin.build(self);
result.current_plugin = outer_plugin;
result
}
/// Applies a collection of plugins (tuple or other `Plugins` impl) in order.
#[must_use]
pub fn plugins<P>(self, plugins: P) -> Self
where
P: crate::plugin::Plugins,
{
plugins.apply(self)
}
/// Returns `true` when a plugin with the given name has been registered.
#[must_use]
pub fn has_plugin(&self, name: &str) -> bool {
self.registered_plugins.contains(name)
}
/// Registers embedded database migrations, applied during database setup.
#[cfg(feature = "db")]
#[must_use]
pub fn migrations(mut self, migrations: migrate::EmbeddedMigrations) -> Self {
self.migrations.push(migrations);
self
}
/// Boots and runs the application to completion.
///
/// First dispatches to the alternate entrypoints (static build, route dump,
/// one-off task listing/execution) when the corresponding env-driven mode is
/// active; otherwise performs the full server bootstrap: config + telemetry,
/// preflight validation, database/state construction, subsystem installation,
/// router build, bind, and serve with graceful shutdown. Fatal errors exit
/// the process with status 1. The statement order here is load-bearing —
/// e.g. policies must be registered before route policy validation, and the
/// job runtime must start before serving.
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
pub async fn run(self) {
// Alternate CLI-style modes short-circuit the normal server path.
if is_static_build_mode() {
self.run_build_mode().await;
return;
}
if is_dump_routes_mode() {
self.run_dump_routes_mode().await;
return;
}
if is_list_one_off_tasks_mode() {
self.run_list_one_off_tasks_mode();
return;
}
if let Some(task_name) = one_off_task_name_from_env() {
self.run_one_off_task_mode(task_name).await;
return;
}
// Exhaustive destructure: a new builder field is a compile error here
// until this method decides what to do with it.
let Self {
routes,
route_sources: _,
current_plugin: _,
tasks,
one_off_tasks: _,
jobs,
static_metas: _,
exception_filters,
scoped_groups,
merge_routers,
nest_routers,
custom_layers,
startup_hooks,
shutdown_hooks,
extensions: _,
registered_plugins: _,
error_page_renderer,
#[cfg(feature = "db")]
migrations,
config_loader_factory,
#[cfg(feature = "db")]
pool_provider_factory,
telemetry_provider,
session_store,
#[cfg(feature = "ws")]
channels_backend,
#[cfg(feature = "storage")]
blob_store,
cache_backend,
#[cfg(feature = "openapi")]
openapi,
audit_logger,
#[cfg(feature = "i18n")]
i18n_bundle,
#[cfg(feature = "i18n")]
i18n_auto_load,
policy_registrations,
#[cfg(feature = "mail")]
mail_delivery_queue_factory,
#[cfg(feature = "mail")]
mail_previews,
declared_routes: _,
} = self;
let all_routes = routes;
// Config must load before anything else; telemetry guard lives for the
// rest of this function.
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
#[cfg(feature = "i18n")]
let i18n_bundle =
resolve_i18n_bundle(i18n_bundle, i18n_auto_load, &config, &crate::config::OsEnv);
assert!(
!all_routes.is_empty(),
"No routes registered. Did you forget to call .routes()?"
);
let profile_display = config.profile.as_deref().unwrap_or("none");
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
profile = profile_display,
"Autumn starting"
);
let show_config = std::env::var("AUTUMN_SHOW_CONFIG").as_deref() == Ok("1");
if show_config {
log_startup_transparency(&all_routes, &tasks, &scoped_groups, &config);
}
// Preflight: abort before binding anything if config is invalid.
fail_fast_on_invalid_session_config(&config, session_store.is_some());
fail_fast_on_invalid_signing_secret(&config);
fail_fast_on_invalid_webhook_config(&config);
#[cfg(feature = "storage")]
let storage_bootstrap = blob_store.map_or_else(
|| preflight_storage(&config),
|store| {
Some(StorageBootstrap {
store,
serving: None,
})
},
);
#[cfg(feature = "db")]
let database = setup_database(&config, migrations, pool_provider_factory)
.await
.unwrap_or_else(|e| {
tracing::error!("{e}");
std::process::exit(1);
});
#[cfg(feature = "db")]
let pool = database.topology;
#[cfg(feature = "db")]
let replica_readiness = database.replica_readiness;
#[cfg(feature = "db")]
let replica_migration_check = database.replica_migration_check;
#[cfg(feature = "db")]
if pool.is_some() {
tracing::info!(
primary_max_connections = config.database.effective_primary_pool_size(),
replica_configured = config.database.replica_url.is_some(),
replica_max_connections = config.database.effective_replica_pool_size(),
"Database topology configured"
);
} else {
tracing::info!("Database not configured");
}
validate_repository_api_policies(&all_routes, &scoped_groups, &config);
let mut state = build_state(
&config,
#[cfg(feature = "db")]
pool.as_ref(),
#[cfg(feature = "ws")]
channels_backend,
);
#[cfg(feature = "db")]
configure_replica_migration_check(&state, replica_migration_check);
#[cfg(feature = "db")]
apply_replica_migration_readiness(&state, replica_readiness);
// Keep the process-global cache in sync with the state's cache so both
// access paths agree.
if let Some(cache) = cache_backend {
crate::cache::set_global_cache(cache.clone());
state.shared_cache = Some(cache);
} else {
crate::cache::clear_global_cache();
}
// Replay deferred policy registrations before validating that routes'
// required policies exist.
for register in policy_registrations {
register(state.policy_registry());
}
validate_repository_policies_registered(&all_routes, &scoped_groups, &state, &config);
#[cfg(feature = "mail")]
crate::mail::install_mailer_with_factory(
&state,
&config.mail,
mail_delivery_queue_factory,
true,
)
.unwrap_or_else(|error| {
tracing::error!(error = %error, "Failed to configure mailer");
std::process::exit(1);
});
#[cfg(feature = "mail")]
state.insert_extension(crate::mail::MailPreviewRegistry::new(mail_previews));
if let Some(logger) = audit_logger {
state.insert_extension::<crate::audit::AuditLogger>((*logger).clone());
}
#[cfg(feature = "i18n")]
let custom_layers = install_i18n_bundle_layer(custom_layers, &state, i18n_bundle);
#[cfg(feature = "storage")]
let storage_router = storage_bootstrap.and_then(|b| b.install(&state));
install_webhook_registry(&state, &config);
let env = crate::config::OsEnv;
let dist_dir = project_dir("dist", &env);
// Serve pre-built static assets only when a dist directory exists.
let dist_ref = if dist_dir.exists() {
Some(dist_dir.as_path())
} else {
None
};
#[cfg_attr(not(feature = "storage"), allow(unused_mut))]
let mut merge_routers = merge_routers;
#[cfg(feature = "storage")]
if let Some(router) = storage_router {
merge_routers.push(router);
}
let router = crate::router::try_build_router_with_static_inner(
all_routes,
&config,
state.clone(),
dist_ref,
crate::router::RouterContext {
exception_filters,
scoped_groups,
merge_routers,
nest_routers,
custom_layers,
error_page_renderer,
session_store,
#[cfg(feature = "openapi")]
openapi: if config.openapi_runtime.enabled {
openapi
} else {
None
},
},
)
.unwrap_or_else(|error| {
tracing::error!(error = %error, "Failed to build router");
std::process::exit(1);
});
let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr)
.await
.unwrap_or_else(|e| {
tracing::error!(addr = %addr, "Failed to bind: {e}");
std::process::exit(1);
});
let shutdown_timeout = config.server.shutdown_timeout_secs;
// Cancelling this token triggers the server's graceful shutdown.
let server_shutdown = tokio_util::sync::CancellationToken::new();
if let Err(error) = initialize_job_runtime(jobs, &state, &server_shutdown, &config.jobs) {
tracing::error!(error = %error, "job runtime initialization failed");
std::process::exit(1);
}
tracing::info!(addr = %addr, "Listening");
let server_shutdown_wait = server_shutdown.clone();
let server_task = tokio::spawn(async move {
axum::serve(
listener,
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(async move {
server_shutdown_wait.cancelled().await;
})
.await
});
let shutdown_state = state.clone();
let shutdown_signal_token = server_shutdown.clone();
#[cfg(feature = "ws")]
let websocket_shutdown = state.shutdown.clone();
// Signal handler: mark shutdown, run hooks, then cancel the server.
let shutdown_task = tokio::spawn(async move {
shutdown_signal().await;
shutdown_state.begin_shutdown();
#[cfg(feature = "ws")]
websocket_shutdown.cancel();
// Warn shortly before the configured drain timeout elapses.
if shutdown_timeout > 5 {
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(
shutdown_timeout.saturating_sub(5),
))
.await;
tracing::warn!(
timeout_secs = shutdown_timeout,
"Shutdown draining near timeout, force-kill may be imminent"
);
});
}
run_shutdown_hooks(&shutdown_hooks).await;
shutdown_signal_token.cancel();
});
// Startup hooks run after the listener is up; failure aborts the server.
if let Err(error) = run_startup_hooks(&startup_hooks, state.clone()).await {
tracing::error!(error = %error, "startup hook failed");
server_shutdown.cancel();
server_task.abort();
std::process::exit(1);
}
// Don't start the scheduler if a shutdown signal already arrived.
if !state.probes().is_shutting_down() {
if !tasks.is_empty()
&& let Err(error) = start_task_scheduler_with_config(
tasks,
&state,
&server_shutdown,
&config.scheduler,
)
{
tracing::error!(error = %error, "scheduled task runtime initialization failed");
server_shutdown.cancel();
server_task.abort();
std::process::exit(1);
}
state.probes().mark_startup_complete();
}
let server_result = server_task.await.unwrap_or_else(|e| {
tracing::error!("Server task join error: {e}");
std::process::exit(1);
});
shutdown_task.abort();
server_result.unwrap_or_else(|e| {
tracing::error!("Server error: {e}");
std::process::exit(1);
});
tracing::info!("Server shut down cleanly");
}
/// Static-site build mode: renders all registered static routes to `dist/`
/// (and, with the `openapi` feature, writes the OpenAPI spec), then returns.
///
/// Uses `eprintln!` instead of `tracing` for user-facing output since this is
/// a CLI-style invocation. Migrations are deliberately NOT run (empty vec is
/// passed to `setup_database`) — building static pages must not mutate the
/// database. Exits with status 1 on any failure.
#[allow(clippy::too_many_lines)]
async fn run_build_mode(self) {
let Self {
routes,
route_sources: _,
current_plugin: _,
tasks: _,
one_off_tasks: _,
jobs: _,
static_metas,
exception_filters: _,
// scoped_groups are only needed to expand OpenAPI docs.
#[cfg(feature = "openapi")]
scoped_groups,
#[cfg(not(feature = "openapi"))]
scoped_groups: _,
merge_routers: _,
nest_routers: _,
custom_layers,
startup_hooks: _,
shutdown_hooks: _,
extensions: _,
registered_plugins: _,
error_page_renderer: _,
#[cfg(feature = "db")]
migrations: _,
config_loader_factory,
#[cfg(feature = "db")]
pool_provider_factory,
telemetry_provider,
session_store,
#[cfg(feature = "ws")]
channels_backend,
#[cfg(feature = "storage")]
blob_store,
cache_backend,
#[cfg(feature = "openapi")]
openapi,
audit_logger: _,
#[cfg(feature = "i18n")]
i18n_bundle,
#[cfg(feature = "i18n")]
i18n_auto_load,
policy_registrations,
#[cfg(feature = "mail")]
mail_delivery_queue_factory,
#[cfg(feature = "mail")]
mail_previews,
declared_routes: _,
} = self;
let all_routes = routes;
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
#[cfg(feature = "i18n")]
let i18n_bundle =
resolve_i18n_bundle(i18n_bundle, i18n_auto_load, &config, &crate::config::OsEnv);
// Snapshot API docs up front, expanding scoped-group prefixes into full
// paths. `Box::leak` is used because ApiDoc stores `&'static str`; this
// is a short-lived build process, so the leaks are bounded and harmless.
#[cfg(feature = "openapi")]
let api_docs_snapshot: Vec<crate::openapi::ApiDoc> = {
let mut docs: Vec<crate::openapi::ApiDoc> =
all_routes.iter().map(|r| r.api_doc.clone()).collect();
for group in &scoped_groups {
let prefix_params = crate::router::extract_path_params(&group.prefix);
for route in &group.routes {
let mut doc = route.api_doc.clone();
let full = crate::router::join_nested_path(&group.prefix, route.api_doc.path);
doc.path = Box::leak(full.into_boxed_str());
if !prefix_params.is_empty() {
// Prefix path params come before the route's own params.
let mut merged: Vec<&'static str> = prefix_params
.iter()
.map(|p| &*Box::leak(p.clone().into_boxed_str()))
.collect();
merged.extend_from_slice(doc.path_params);
doc.path_params = Box::leak(merged.into_boxed_slice());
}
docs.push(doc);
}
}
docs
};
if static_metas.is_empty() {
eprintln!("No static routes registered. Nothing to build.");
eprintln!("Hint: use .static_routes(static_routes![...]) on your AppBuilder.");
std::process::exit(1);
}
fail_fast_on_invalid_session_config(&config, session_store.is_some());
fail_fast_on_invalid_signing_secret(&config);
#[cfg(feature = "storage")]
let storage_bootstrap = blob_store.map_or_else(
|| preflight_storage(&config),
|store| {
Some(StorageBootstrap {
store,
serving: None,
})
},
);
// Note: empty migrations vec — build mode never migrates the database.
#[cfg(feature = "db")]
let database = setup_database(&config, vec![], pool_provider_factory)
.await
.unwrap_or_else(|e| {
eprintln!("{e}");
std::process::exit(1);
});
#[cfg(feature = "db")]
let pool = database.topology;
#[cfg(feature = "db")]
let replica_readiness = database.replica_readiness;
#[cfg(feature = "db")]
let replica_migration_check = database.replica_migration_check;
let mut state = build_state(
&config,
#[cfg(feature = "db")]
pool.as_ref(),
#[cfg(feature = "ws")]
channels_backend,
);
#[cfg(feature = "db")]
configure_replica_migration_check(&state, replica_migration_check);
#[cfg(feature = "db")]
apply_replica_migration_readiness(&state, replica_readiness);
if let Some(cache) = cache_backend {
crate::cache::set_global_cache(cache.clone());
state.shared_cache = Some(cache);
} else {
crate::cache::clear_global_cache();
}
#[cfg(feature = "mail")]
crate::mail::install_mailer_with_factory(
&state,
&config.mail,
mail_delivery_queue_factory,
false,
)
.unwrap_or_else(|error| {
eprintln!("Failed to configure mailer: {error}");
std::process::exit(1);
});
#[cfg(feature = "mail")]
state.insert_extension(crate::mail::MailPreviewRegistry::new(mail_previews));
// Fresh probe state so build-mode rendering isn't gated on readiness.
state.probes = crate::probe::ProbeState::default();
for register in policy_registrations {
register(state.policy_registry());
}
#[cfg(feature = "i18n")]
let custom_layers = install_i18n_bundle_layer(custom_layers, &state, i18n_bundle);
#[cfg(feature = "storage")]
let storage_router = storage_bootstrap.and_then(|b| b.install(&state));
#[cfg_attr(not(feature = "storage"), allow(unused_mut))]
let mut merge_routers: Vec<axum::Router<AppState>> = Vec::new();
#[cfg(feature = "storage")]
if let Some(router) = storage_router {
merge_routers.push(router);
}
// Minimal router: no filters/layers/error pages — just enough to render.
let router = crate::router::try_build_router_inner(
all_routes,
&config,
state,
crate::router::RouterContext {
exception_filters: Vec::new(),
scoped_groups: Vec::new(),
merge_routers,
nest_routers: Vec::new(),
custom_layers,
error_page_renderer: None,
session_store,
#[cfg(feature = "openapi")]
openapi: None,
},
)
.unwrap_or_else(|error| {
eprintln!("Failed to build router: {error}");
std::process::exit(1);
});
let env = crate::config::OsEnv;
let dist_dir = project_dir("dist", &env);
eprintln!("Building {} static route(s)...", static_metas.len());
match crate::static_gen::render_static_routes(router, &static_metas, &dist_dir).await {
Ok(()) => {
eprintln!(
"\n \u{2713} Static build complete \u{2192} {}",
dist_dir.display()
);
}
Err(e) => {
eprintln!("\n \u{2717} Static build failed: {e}");
std::process::exit(1);
}
}
// A spec-write failure is a warning, not a build failure.
#[cfg(feature = "openapi")]
if let Some(openapi_config) = openapi {
let openapi_config =
openapi_config.session_cookie_name(config.session.cookie_name.clone());
let docs: Vec<&crate::openapi::ApiDoc> = api_docs_snapshot.iter().collect();
let spec = crate::openapi::generate_spec(&openapi_config, &docs);
match crate::openapi::write_openapi_spec_to_dist(&spec, &dist_dir) {
Ok(()) => {
eprintln!(
" \u{2713} OpenAPI spec written \u{2192} {}/openapi.json",
dist_dir.display()
);
}
Err(e) => {
eprintln!(" \u{26A0} Failed to write OpenAPI spec: {e}");
}
}
}
}
/// Route-dump mode: prints the full route listing as pretty JSON to stdout
/// and exits 0 (exits 1 if serialization fails).
///
/// Raw routers added via `.merge()`/`.nest()` cannot be enumerated, so a
/// warning with their count is printed to stderr instead.
async fn run_dump_routes_mode(self) {
let Self {
routes,
route_sources,
scoped_groups,
merge_routers,
nest_routers,
declared_routes,
config_loader_factory,
telemetry_provider,
#[cfg(feature = "openapi")]
openapi,
..
} = self;
let hidden = merge_routers.len() + nest_routers.len();
if hidden > 0 {
eprintln!(
"[autumn routes] warning: {hidden} raw router(s) added via \
.merge()/.nest() are not enumerable and are omitted from this listing"
);
}
// Config is needed to know which framework/dev routes are enabled.
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
let mut infos =
crate::route_listing::collect_route_infos(&routes, &route_sources, &scoped_groups);
infos.extend(declared_routes);
crate::route_listing::append_framework_routes(&mut infos, &config);
#[cfg(feature = "openapi")]
if let Some(ref oa) = openapi {
crate::route_listing::append_openapi_routes(&mut infos, oa);
}
crate::route_listing::append_dev_reload_routes(&mut infos);
crate::route_listing::sort_route_infos(&mut infos);
let json = serde_json::to_string_pretty(&infos).unwrap_or_else(|e| {
eprintln!("Failed to serialize route listing: {e}");
std::process::exit(1);
});
println!("{json}");
std::process::exit(0);
}
/// One-off-task listing mode: validates task-name uniqueness, prints the
/// task listing as pretty JSON to stdout, and exits 0 (1 on any failure).
fn run_list_one_off_tasks_mode(self) {
let Self { one_off_tasks, .. } = self;
if let Err(error) = crate::task::validate_unique_one_off_task_names(&one_off_tasks) {
eprintln!("Invalid task registration: {error}");
std::process::exit(1);
}
let listing = crate::task::list_one_off_tasks(&one_off_tasks);
let json = serde_json::to_string_pretty(&listing).unwrap_or_else(|error| {
eprintln!("Failed to serialize task listing: {error}");
std::process::exit(1);
});
println!("{json}");
std::process::exit(0);
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
/// Runs a single registered one-off task and then exits the process.
///
/// Bootstraps a near-complete application environment (config, telemetry,
/// database + migrations, state, cache, policies, mailer, jobs, startup
/// hooks) so the task handler sees the same `AppState` a request handler
/// would, executes the handler inside its own tracing span, runs shutdown
/// hooks, and exits: 0 on success, 1 on any validation/bootstrap/handler
/// failure.
async fn run_one_off_task_mode(self, requested_name: String) {
// Take only the builder pieces task execution needs; routes, static
// metadata, etc. are dropped via `..`.
let Self {
one_off_tasks,
jobs,
#[cfg(feature = "i18n")]
custom_layers,
#[cfg(not(feature = "i18n"))]
custom_layers: _,
startup_hooks,
shutdown_hooks,
config_loader_factory,
#[cfg(feature = "db")]
migrations,
#[cfg(feature = "db")]
pool_provider_factory,
telemetry_provider,
session_store,
#[cfg(feature = "ws")]
channels_backend,
#[cfg(feature = "storage")]
blob_store,
audit_logger,
#[cfg(feature = "i18n")]
i18n_bundle,
#[cfg(feature = "i18n")]
i18n_auto_load,
policy_registrations,
cache_backend,
#[cfg(feature = "mail")]
mail_delivery_queue_factory,
..
} = self;
if let Err(error) = crate::task::validate_unique_one_off_task_names(&one_off_tasks) {
eprintln!("Invalid task registration: {error}");
std::process::exit(1);
}
// Resolve the requested task before doing any expensive bootstrap.
let Some((task_name, task_handler)) = one_off_tasks
.iter()
.find(|task| task.name == requested_name)
.map(|task| (task.name.clone(), task.handler))
else {
eprintln!("No one-off task named '{requested_name}' is registered.");
print_available_one_off_tasks(&one_off_tasks);
std::process::exit(1);
};
let args = one_off_task_args_from_env().unwrap_or_else(|error| {
eprintln!("Invalid task args: {error}");
std::process::exit(1);
});
// `_telemetry_guard` is bound (not `_`) so it lives for the rest of the run.
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
#[cfg(feature = "i18n")]
let i18n_bundle =
resolve_i18n_bundle(i18n_bundle, i18n_auto_load, &config, &crate::config::OsEnv);
// Config validation that would abort a normal server start also aborts
// task runs.
fail_fast_on_invalid_session_config(&config, session_store.is_some());
fail_fast_on_invalid_signing_secret(&config);
// A user-supplied blob store skips preflight and gets no serving router.
#[cfg(feature = "storage")]
let storage_bootstrap = blob_store.map_or_else(
|| preflight_storage(&config),
|store| {
Some(StorageBootstrap {
store,
serving: None,
})
},
);
#[cfg(feature = "db")]
let database = setup_database(&config, migrations, pool_provider_factory)
.await
.unwrap_or_else(|error| {
eprintln!("{error}");
std::process::exit(1);
});
#[cfg(feature = "db")]
let pool = database.topology;
#[cfg(feature = "db")]
let replica_readiness = database.replica_readiness;
#[cfg(feature = "db")]
let replica_migration_check = database.replica_migration_check;
let mut state = build_state(
&config,
#[cfg(feature = "db")]
pool.as_ref(),
#[cfg(feature = "ws")]
channels_backend,
);
#[cfg(feature = "db")]
configure_replica_migration_check(&state, replica_migration_check);
#[cfg(feature = "db")]
apply_replica_migration_readiness(&state, replica_readiness);
// Install or clear the process-global cache to match the builder's choice.
if let Some(cache) = cache_backend {
crate::cache::set_global_cache(cache.clone());
state.shared_cache = Some(cache);
} else {
crate::cache::clear_global_cache();
}
for register in policy_registrations {
register(state.policy_registry());
}
#[cfg(feature = "mail")]
crate::mail::install_mailer_with_factory(
&state,
&config.mail,
mail_delivery_queue_factory,
true,
)
.unwrap_or_else(|error| {
eprintln!("Failed to configure mailer: {error}");
std::process::exit(1);
});
if let Some(logger) = audit_logger {
state.insert_extension::<crate::audit::AuditLogger>((*logger).clone());
}
// These installs matter only for their side effects on `state` here;
// there is no HTTP server for the returned layers/router to attach to.
#[cfg(feature = "i18n")]
let _custom_layers = install_i18n_bundle_layer(custom_layers, &state, i18n_bundle);
#[cfg(feature = "storage")]
let _storage_router = storage_bootstrap.and_then(|bootstrap| bootstrap.install(&state));
let task_shutdown = tokio_util::sync::CancellationToken::new();
if let Err(error) = initialize_job_runtime(jobs, &state, &task_shutdown, &config.jobs) {
eprintln!("job runtime initialization failed: {error}");
std::process::exit(1);
}
if let Err(error) = run_startup_hooks(&startup_hooks, state.clone()).await {
eprintln!("startup hook failed: {error}");
task_shutdown.cancel();
std::process::exit(1);
}
state.probes().mark_startup_complete();
tracing::info!(task = %task_name, "Running one-off task");
let span = tracing::info_span!("one_off_task", task = %task_name);
let result = (task_handler)(state.clone(), args).instrument(span).await;
// Stop background jobs before reporting the task result.
task_shutdown.cancel();
run_shutdown_hooks(&shutdown_hooks).await;
match result {
Ok(()) => {
tracing::info!(task = %task_name, "One-off task completed");
}
Err(error) => {
tracing::error!(task = %task_name, error = %error, "One-off task failed");
eprintln!("Task '{task_name}' failed: {error}");
// Print the full cause chain for debuggability.
for cause in error.source_chain() {
eprintln!("Caused by: {cause}");
}
std::process::exit(1);
}
}
}
}
/// True when the process was launched in static-site build mode, i.e. the
/// `AUTUMN_BUILD_STATIC` environment variable is set to exactly "1".
pub(crate) fn is_static_build_mode() -> bool {
    std::env::var("AUTUMN_BUILD_STATIC").is_ok_and(|value| value == "1")
}
/// True when the process should dump the route listing and exit, i.e. the
/// `AUTUMN_DUMP_ROUTES` environment variable is set to exactly "1".
pub(crate) fn is_dump_routes_mode() -> bool {
    std::env::var("AUTUMN_DUMP_ROUTES").is_ok_and(|value| value == "1")
}
/// True when the process should list one-off tasks and exit, i.e. the
/// `AUTUMN_LIST_TASKS` environment variable is set to exactly "1".
pub(crate) fn is_list_one_off_tasks_mode() -> bool {
    std::env::var("AUTUMN_LIST_TASKS").is_ok_and(|value| value == "1")
}
/// Reads the requested one-off task name from `AUTUMN_RUN_TASK`.
///
/// Returns `None` when the variable is unset or its value is blank after
/// trimming; otherwise the trimmed name.
fn one_off_task_name_from_env() -> Option<String> {
    let raw = std::env::var("AUTUMN_RUN_TASK").ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
/// Parses optional one-off task arguments from `AUTUMN_TASK_ARGS_JSON`.
///
/// When the variable is present and non-blank it must hold a JSON array of
/// strings; anything else produces a descriptive error. Absent or blank
/// values mean "no arguments".
fn one_off_task_args_from_env() -> Result<Vec<String>, String> {
    let Ok(raw) = std::env::var("AUTUMN_TASK_ARGS_JSON") else {
        return Ok(Vec::new());
    };
    if raw.trim().is_empty() {
        return Ok(Vec::new());
    }
    serde_json::from_str(&raw)
        .map_err(|error| format!("AUTUMN_TASK_ARGS_JSON must be a JSON string array: {error}"))
}
/// Writes a human-readable list of registered one-off tasks to stderr, or a
/// hint about how to register tasks when none exist.
fn print_available_one_off_tasks(tasks: &[crate::task::OneOffTaskInfo]) {
    let listing = crate::task::list_one_off_tasks(tasks);
    if listing.is_empty() {
        eprintln!("No one-off tasks are registered. Add .one_off_tasks(one_off_tasks![...]).");
        return;
    }
    eprintln!("Available tasks:");
    for entry in listing {
        if !entry.description.is_empty() {
            // Pad the name so descriptions line up in a column.
            eprintln!(" {:<24} {}", entry.name, entry.description);
        } else {
            eprintln!(" {}", entry.name);
        }
    }
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cognitive_complexity)]
#[allow(dead_code)]
fn start_task_scheduler(
tasks: Vec<crate::task::TaskInfo>,
state: &AppState,
shutdown: &tokio_util::sync::CancellationToken,
) {
if let Err(error) = start_task_scheduler_with_config(
tasks,
state,
shutdown,
&crate::config::SchedulerConfig::default(),
) {
tracing::error!(error = %error, "scheduled task runtime initialization failed");
}
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cognitive_complexity)]
/// Starts the scheduled-task runtime: logs every registration, records each
/// task in the task registry, spawns one detached loop per fixed-delay task,
/// and hands cron tasks to `run_cron_scheduler`.
///
/// # Errors
/// Fails only when the scheduler coordinator cannot be built from config;
/// individual task failures are reported via the registry and logs at
/// runtime instead.
fn start_task_scheduler_with_config(
tasks: Vec<crate::task::TaskInfo>,
state: &AppState,
shutdown: &tokio_util::sync::CancellationToken,
scheduler_config: &crate::config::SchedulerConfig,
) -> crate::AutumnResult<()> {
tracing::info!(count = tasks.len(), "Starting scheduled tasks");
// The coordinator arbitrates (via leases) which replica may run a tick.
let coordinator = crate::scheduler::coordinator_from_config(scheduler_config, state)?;
let lease_ttl = std::time::Duration::from_secs(scheduler_config.lease_ttl_secs);
// First pass: log all registrations up front for startup transparency.
for task_info in &tasks {
let schedule_desc = task_info.schedule.to_string();
tracing::info!(
name = %task_info.name,
schedule = %schedule_desc,
coordination = %task_info.coordination,
scheduler_backend = coordinator.backend(),
replica_id = coordinator.replica_id(),
lease_ttl_secs = scheduler_config.lease_ttl_secs,
"Registered task"
);
}
let mut cron_tasks: Vec<CronTaskSpec> = Vec::new();
// Second pass: record each task in the registry, then either spawn its
// fixed-delay loop immediately or queue it for the cron scheduler.
for task_info in tasks {
let state = state.clone();
let name = task_info.name.clone();
let handler = task_info.handler;
let coordination = task_info.coordination;
let schedule_desc = task_info.schedule.to_string();
state.task_registry.register_scheduled(
&name,
&schedule_desc,
coordination,
coordinator.backend(),
coordinator.replica_id(),
);
match task_info.schedule {
crate::task::Schedule::FixedDelay(delay) => {
let coordinator = Arc::clone(&coordinator);
let shutdown = shutdown.child_token();
tokio::spawn(async move {
loop {
// Publish the projected next run before sleeping so the
// registry reflects the upcoming tick.
state
.task_registry
.record_next_run_at(&name, &format_next_task_run_after(delay));
tokio::select! {
() = shutdown.cancelled() => break,
() = tokio::time::sleep(delay) => {
execute_fixed_delay_task(
name.clone(),
state.clone(),
handler,
delay,
coordination,
Arc::clone(&coordinator),
lease_ttl,
)
.await;
}
}
}
});
}
crate::task::Schedule::Cron {
expression,
timezone,
} => {
cron_tasks.push(CronTaskSpec {
name,
expression,
timezone,
coordination,
handler,
});
}
}
}
run_cron_scheduler(cron_tasks, state, shutdown, &coordinator, lease_ttl);
Ok(())
}
#[allow(unused_variables, clippy::needless_pass_by_value)]
/// Broadcasts a task lifecycle event on the "sys:tasks" websocket channel.
///
/// Compiles to a no-op without the `ws` feature. Sending is best-effort:
/// failures (e.g. no subscribers) are deliberately ignored.
fn send_ws_sys_task_msg(
    state: &AppState,
    event: &str,
    name: &str,
    extra: Vec<(&str, serde_json::Value)>,
) {
    #[cfg(feature = "ws")]
    {
        // Base envelope; any `extra` fields are merged in afterwards.
        let mut payload = serde_json::json!({
            "event": event,
            "task": name,
            "timestamp": chrono::Utc::now().to_rfc3339(),
        });
        if let serde_json::Value::Object(fields) = &mut payload {
            for (key, value) in extra {
                fields.insert(key.to_owned(), value);
            }
        }
        let _ = state.channels().sender("sys:tasks").send(payload.to_string());
    }
}
/// Invokes a task handler inside a fresh tracing span, converting both
/// `Err` results and panics into `(elapsed_ms, message)` failures.
///
/// Panics are caught in two places: while *constructing* the future (the
/// handler call itself may unwind) and while *polling* it (the async body
/// may unwind). Returns the elapsed milliseconds on success.
async fn execute_task_result(
state: &AppState,
handler: crate::task::TaskHandler,
start: std::time::Instant,
name: &str,
schedule: &'static str,
) -> Result<u64, (u64, String)> {
// `parent: None` detaches the span from any ambient span.
let task_span = tracing::info_span!(
parent: None,
"scheduled_task",
otel.kind = "internal",
task = %name,
schedule = schedule,
);
// Guard the handler call: a panicking handler must not take down the
// scheduler loop.
let future = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
(handler)(state.clone()).instrument(task_span)
})) {
Ok(future) => future,
Err(panic) => {
let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
return Err((duration_ms, format_scheduled_task_panic(panic.as_ref())));
}
};
// Guard polling as well; `catch_unwind` on the future captures panics
// raised while the async body runs.
let result = std::panic::AssertUnwindSafe(future).catch_unwind().await;
let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
match result {
Ok(Ok(())) => Ok(duration_ms),
Ok(Err(e)) => Err((duration_ms, e.to_string())),
Err(panic) => Err((duration_ms, format_scheduled_task_panic(panic.as_ref()))),
}
}
/// Renders a panic payload captured from a task handler into a log-friendly
/// message. Only `String` and `&'static str` payloads (the forms `panic!`
/// typically produces) are inspected; other payload types get a generic
/// placeholder.
fn format_scheduled_task_panic(panic: &(dyn Any + Send)) -> String {
    let detail = if let Some(owned) = panic.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(literal) = panic.downcast_ref::<&'static str>() {
        literal
    } else {
        "non-string panic payload"
    };
    format!("scheduled task handler panicked: {detail}")
}
/// Runs the task handler, optionally bounding its execution by `lease_ttl`.
///
/// Without a TTL the handler simply runs to completion. With a TTL, a
/// `tokio` timeout converts an overrun into a failure carrying the elapsed
/// milliseconds and a descriptive message, mirroring ordinary handler
/// errors.
async fn execute_task_result_with_optional_lease_ttl(
    state: &AppState,
    handler: crate::task::TaskHandler,
    start: std::time::Instant,
    name: &str,
    schedule: &'static str,
    lease_ttl: Option<std::time::Duration>,
) -> Result<u64, (u64, String)> {
    let Some(ttl) = lease_ttl else {
        return execute_task_result(state, handler, start, name, schedule).await;
    };
    let run = execute_task_result(state, handler, start, name, schedule);
    match tokio::time::timeout(ttl, run).await {
        Ok(outcome) => outcome,
        Err(_elapsed) => {
            let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
            Err((
                duration_ms,
                format!(
                    "scheduled task exceeded lease TTL of {}s",
                    ttl.as_secs()
                ),
            ))
        }
    }
}
#[allow(clippy::cognitive_complexity)]
/// One tick of a fixed-delay task: acquire the coordinator lease for this
/// tick, run the handler (TTL-bounded when applicable), record the outcome
/// in the task registry, broadcast lifecycle events, and release the lease.
///
/// Returns early — without running the handler — when another replica has
/// already claimed the tick or lease acquisition fails.
async fn execute_fixed_delay_task(
name: String,
state: AppState,
handler: crate::task::TaskHandler,
delay: std::time::Duration,
coordination: crate::task::TaskCoordination,
coordinator: Arc<dyn crate::scheduler::SchedulerCoordinator>,
lease_ttl: std::time::Duration,
) {
// Tick key derives from name + delay + current time so that replicas
// compete for the same logical tick.
let tick_key =
crate::scheduler::fixed_delay_tick_key(&name, delay, crate::scheduler::now_unix_duration());
let lease = match coordinator
.try_acquire(&name, &tick_key, coordination)
.await
{
Ok(Some(lease)) => lease,
Ok(None) => {
// Another replica won this tick; nothing to do here.
tracing::debug!(task = %name, tick = %tick_key, "Scheduled task tick already claimed");
return;
}
Err(error) => {
tracing::warn!(task = %name, tick = %tick_key, error = %error, "Failed to acquire scheduled task lease");
return;
}
};
state
.task_registry
.record_leader(&name, lease.leader_id(), &tick_key);
tracing::debug!(task = %name, "Running scheduled task");
state.task_registry.record_start(&name);
send_ws_sys_task_msg(&state, "started", &name, vec![]);
let start = std::time::Instant::now();
// A TTL bound only applies to fleet-coordinated, postgres-backed leases
// (see `lease_ttl_for_run`).
let lease_ttl = lease_ttl_for_run(&lease, coordination, lease_ttl);
match execute_task_result_with_optional_lease_ttl(
&state,
handler,
start,
&name,
"fixed_delay",
lease_ttl,
)
.await
{
Ok(duration_ms) => {
state.task_registry.record_success(&name, duration_ms);
tracing::debug!(task = %name, "Task completed");
send_ws_sys_task_msg(
&state,
"success",
&name,
vec![("duration_ms", serde_json::json!(duration_ms))],
);
}
Err((duration_ms, error_str)) => {
state
.task_registry
.record_failure(&name, duration_ms, &error_str);
tracing::warn!(task = %name, error = %error_str, "Task failed");
send_ws_sys_task_msg(
&state,
"failure",
&name,
vec![
("duration_ms", serde_json::json!(duration_ms)),
("error", serde_json::json!(error_str)),
],
);
}
}
// Release is best-effort; failure is logged, not retried.
if let Err(error) = lease.release().await {
tracing::warn!(task = %name, tick = %tick_key, error = %error, "Failed to release scheduled task lease");
}
}
#[allow(clippy::cognitive_complexity)]
/// One tick of a cron task, keyed by its scheduled Unix time: acquire the
/// coordinator lease for the tick, run the handler (TTL-bounded when
/// applicable), record the outcome, broadcast lifecycle events, and release
/// the lease.
///
/// Returns early — without running the handler — when another replica has
/// already claimed the tick or lease acquisition fails.
async fn execute_cron_task(
name: String,
state: AppState,
handler: crate::task::TaskHandler,
coordination: crate::task::TaskCoordination,
coordinator: Arc<dyn crate::scheduler::SchedulerCoordinator>,
lease_ttl: std::time::Duration,
scheduled_unix_secs: u64,
) {
// Keyed on the scheduled time (not "now") so every replica contends for
// the same logical occurrence.
let tick_key = crate::scheduler::cron_tick_key(&name, scheduled_unix_secs);
let lease = match coordinator
.try_acquire(&name, &tick_key, coordination)
.await
{
Ok(Some(lease)) => lease,
Ok(None) => {
tracing::debug!(task = %name, tick = %tick_key, "Cron task tick already claimed");
return;
}
Err(error) => {
tracing::warn!(task = %name, tick = %tick_key, error = %error, "Failed to acquire cron task lease");
return;
}
};
state
.task_registry
.record_leader(&name, lease.leader_id(), &tick_key);
tracing::debug!(task = %name, "Running cron task");
state.task_registry.record_start(&name);
send_ws_sys_task_msg(&state, "started", &name, vec![]);
let start = std::time::Instant::now();
// A TTL bound only applies to fleet-coordinated, postgres-backed leases
// (see `lease_ttl_for_run`).
let lease_ttl = lease_ttl_for_run(&lease, coordination, lease_ttl);
match execute_task_result_with_optional_lease_ttl(
&state, handler, start, &name, "cron", lease_ttl,
)
.await
{
Ok(duration_ms) => {
state.task_registry.record_success(&name, duration_ms);
tracing::debug!(task = %name, "Cron task completed");
send_ws_sys_task_msg(
&state,
"success",
&name,
vec![("duration_ms", serde_json::json!(duration_ms))],
);
}
Err((duration_ms, error_str)) => {
state
.task_registry
.record_failure(&name, duration_ms, &error_str);
tracing::warn!(task = %name, error = %error_str, "Cron task failed");
send_ws_sys_task_msg(
&state,
"failure",
&name,
vec![
("duration_ms", serde_json::json!(duration_ms)),
("error", serde_json::json!(error_str)),
],
);
}
}
// Release is best-effort; failure is logged, not retried.
if let Err(error) = lease.release().await {
tracing::warn!(task = %name, tick = %tick_key, error = %error, "Failed to release cron task lease");
}
}
/// A cron-scheduled task captured during registration, consumed by
/// `run_cron_scheduler` to spawn one worker loop per task.
struct CronTaskSpec {
// Task name, used in logs, the task registry, and tick keys.
name: String,
// Raw cron expression; parsed with `croner` inside the worker loop.
expression: String,
// Optional timezone name; unparseable values fall back to UTC with a warning.
timezone: Option<String>,
// Coordination mode passed to the scheduler's lease acquisition.
coordination: crate::task::TaskCoordination,
// The async task body, invoked with a cloned `AppState` on each tick.
handler: crate::task::TaskHandler,
}
/// Decides whether a run should be bounded by the lease TTL.
///
/// Only fleet-coordinated tasks whose lease is backed by "postgres" get a
/// TTL; every other combination runs unbounded (`None`).
fn lease_ttl_for_run(
    lease: &crate::scheduler::SchedulerLease,
    coordination: crate::task::TaskCoordination,
    lease_ttl: std::time::Duration,
) -> Option<std::time::Duration> {
    if coordination == crate::task::TaskCoordination::Fleet && lease.backend() == "postgres" {
        Some(lease_ttl)
    } else {
        None
    }
}
/// Spawns one detached worker loop per cron task. No-op when `tasks` is
/// empty (no "started" log line is emitted in that case).
fn run_cron_scheduler(
    tasks: Vec<CronTaskSpec>,
    state: &AppState,
    shutdown: &tokio_util::sync::CancellationToken,
    coordinator: &Arc<dyn crate::scheduler::SchedulerCoordinator>,
    lease_ttl: std::time::Duration,
) {
    if tasks.is_empty() {
        return;
    }
    tracing::info!(count = tasks.len(), "Cron scheduler started");
    for spec in tasks {
        // Each worker gets its own state handle, coordinator handle, and
        // child cancellation token.
        let task_state = state.clone();
        let task_coordinator = Arc::clone(coordinator);
        let task_shutdown = shutdown.child_token();
        tokio::spawn(run_cron_task_loop(
            spec,
            task_state,
            task_shutdown,
            task_coordinator,
            lease_ttl,
        ));
    }
}
#[allow(clippy::cognitive_complexity)]
/// Per-task cron worker loop: parses the expression and timezone once, then
/// repeatedly computes the next occurrence, publishes it to the registry,
/// sleeps until it is due, and spawns `execute_cron_task` for the tick.
///
/// Exits when the shutdown token fires or when the cron expression cannot
/// be parsed/evaluated. Overdue ticks (the loop woke after the *following*
/// occurrence was already due) are skipped with a warning.
async fn run_cron_task_loop(
task: CronTaskSpec,
state: AppState,
shutdown: tokio_util::sync::CancellationToken,
coordinator: Arc<dyn crate::scheduler::SchedulerCoordinator>,
lease_ttl: std::time::Duration,
) {
let CronTaskSpec {
name,
expression,
timezone,
coordination,
handler,
} = task;
let cron = match expression.parse::<croner::Cron>() {
Ok(cron) => cron,
Err(error) => {
tracing::error!(task = %name, expression = %expression, error = %error, "Failed to create cron job");
return;
}
};
// Unrecognized timezone names degrade to UTC with a warning rather than
// disabling the task.
let timezone = timezone
.as_deref()
.and_then(|timezone| {
timezone.parse::<chrono_tz::Tz>().map_or_else(
|_| {
tracing::warn!(task = %name, timezone = %timezone, "Unrecognized timezone; falling back to UTC");
None
},
Some,
)
})
.unwrap_or(chrono_tz::UTC);
// `cursor` tracks the last tick handled (or startup time) so the next
// occurrence is always computed relative to forward progress.
let mut cursor = chrono::Utc::now().with_timezone(&timezone);
loop {
let now = chrono::Utc::now().with_timezone(&timezone);
let scheduled_at = match next_cron_occurrence_after(&cron, &cursor, &now) {
Ok(scheduled_at) => scheduled_at,
Err(error) => {
tracing::error!(task = %name, expression = %expression, error = %error, "Failed to compute next cron tick");
return;
}
};
// Advertise the upcoming run (normalized to UTC) before sleeping.
state.task_registry.record_next_run_at(
&name,
&scheduled_at.with_timezone(&chrono::Utc).to_rfc3339(),
);
let sleep_for = cron_sleep_duration_until(&scheduled_at);
tokio::select! {
() = shutdown.cancelled() => break,
() = tokio::time::sleep(sleep_for) => {
let woke_at = chrono::Utc::now().with_timezone(&timezone);
match cron_occurrence_is_overdue(&cron, &scheduled_at, &woke_at) {
Ok(true) => {
tracing::warn!(
task = %name,
scheduled_at = %scheduled_at,
woke_at = %woke_at,
"Skipping overdue cron task tick"
);
cursor = woke_at;
continue;
}
Ok(false) => {}
Err(error) => {
tracing::error!(task = %name, expression = %expression, error = %error, "Failed to evaluate cron tick lateness");
return;
}
}
// Pre-epoch timestamps saturate to 0 rather than failing.
let scheduled_unix_secs = u64::try_from(scheduled_at.timestamp()).unwrap_or_default();
// Run the tick concurrently so a slow handler cannot delay the
// computation of subsequent occurrences.
tokio::spawn(execute_cron_task(
name.clone(),
state.clone(),
handler,
coordination,
Arc::clone(&coordinator),
lease_ttl,
scheduled_unix_secs,
));
cursor = scheduled_at;
}
}
}
}
/// RFC 3339 timestamp of "now + delay"; falls back to plain "now" when the
/// delay cannot be represented as a `chrono::TimeDelta`.
fn format_next_task_run_after(delay: std::time::Duration) -> String {
    let now = chrono::Utc::now();
    chrono::TimeDelta::from_std(delay)
        .map(|delta| (now + delta).to_rfc3339())
        .unwrap_or_else(|_| now.to_rfc3339())
}
/// Computes the next cron occurrence anchored at the later of `cursor` and
/// `now`.
///
/// Anchoring at the maximum prevents re-firing a tick that was already
/// handled (cursor ahead of now) while also skipping occurrences missed
/// while the loop slept (now ahead of cursor).
fn next_cron_occurrence_after<Tz: chrono::TimeZone>(
    cron: &croner::Cron,
    cursor: &chrono::DateTime<Tz>,
    now: &chrono::DateTime<Tz>,
) -> Result<chrono::DateTime<Tz>, croner::errors::CronError> {
    let anchor = std::cmp::max(cursor, now);
    cron.find_next_occurrence(anchor, false)
}
/// Reports whether `scheduled_at` is stale: true when the cron expression
/// already has a later occurrence at or before `now`.
fn cron_occurrence_is_overdue<Tz: chrono::TimeZone>(
    cron: &croner::Cron,
    scheduled_at: &chrono::DateTime<Tz>,
    now: &chrono::DateTime<Tz>,
) -> Result<bool, croner::errors::CronError> {
    cron.find_next_occurrence(scheduled_at, false)
        .map(|next_after_scheduled| &next_after_scheduled <= now)
}
/// Duration to sleep until `scheduled_at`, saturating to zero when the
/// moment is already in the past (negative deltas fail `to_std`).
fn cron_sleep_duration_until<Tz: chrono::TimeZone>(
    scheduled_at: &chrono::DateTime<Tz>,
) -> std::time::Duration {
    let remaining = scheduled_at
        .with_timezone(&chrono::Utc)
        .signed_duration_since(chrono::Utc::now());
    remaining.to_std().unwrap_or_default()
}
/// Runs every registered startup hook in registration order, stopping at —
/// and returning — the first error.
async fn run_startup_hooks(hooks: &[StartupHook], state: AppState) -> crate::AutumnResult<()> {
for hook in hooks {
// Each hook receives its own clone of the shared state handle.
hook(state.clone()).await?;
}
Ok(())
}
/// Resets the global job client and, when jobs are registered, boots the
/// job runtime. An empty job list is a successful no-op (after the reset).
fn initialize_job_runtime(
    jobs: Vec<crate::job::JobInfo>,
    state: &AppState,
    shutdown: &tokio_util::sync::CancellationToken,
    config: &crate::config::JobConfig,
) -> crate::AutumnResult<()> {
    // Always clear any client left over from a previous runtime first.
    crate::job::clear_global_job_client();
    if jobs.is_empty() {
        return Ok(());
    }
    crate::job::start_runtime(jobs, state, shutdown, config)
}
/// Runs shutdown hooks in reverse registration order (LIFO), awaiting each;
/// the hook type returns `()`, so there are no failures to surface.
async fn run_shutdown_hooks(hooks: &[ShutdownHook]) {
for hook in hooks.iter().rev() {
hook().await;
}
}
#[allow(clippy::cognitive_complexity)]
fn log_startup_transparency(
routes: &[Route],
tasks: &[crate::task::TaskInfo],
scoped_groups: &[ScopedGroup],
config: &AutumnConfig,
) {
tracing::info!(
"Registered routes:{}",
format_route_lines(routes, scoped_groups, config)
);
if let Some(task_lines) = format_task_lines(tasks) {
tracing::info!("Scheduled tasks:{task_lines}");
}
tracing::info!("Active middleware: {}", format_middleware_list(config));
tracing::info!("Configuration:{}", format_config_summary(config));
}
/// Validates the session backend configuration and aborts the process (exit
/// code 1) when it is invalid. Skipped entirely when the application
/// supplied its own session store.
fn fail_fast_on_invalid_session_config(config: &AutumnConfig, has_custom_session_store: bool) {
    if has_custom_session_store {
        return;
    }
    match config.session.backend_plan(config.profile.as_deref()) {
        Ok(_) => {}
        Err(error) => {
            eprintln!("Invalid session backend config: {error}");
            std::process::exit(1);
        }
    }
}
/// Validates the signing secret — and, in production profiles, every
/// rotated previous secret — aborting the process with actionable hints on
/// stderr when any check fails.
fn fail_fast_on_invalid_signing_secret(config: &AutumnConfig) {
use crate::security::config::validate_signing_secret;
let is_production = matches!(config.profile.as_deref(), Some("prod" | "production"));
let secret = config.security.signing_secret.secret.as_deref();
if let Err(error) = validate_signing_secret(secret, is_production) {
eprintln!("Invalid signing secret configuration: {error}");
eprintln!(
" hint: generate a secret with `openssl rand -hex 32` and set \
AUTUMN_SECURITY__SIGNING_SECRET"
);
std::process::exit(1);
}
// Rotation secrets are only held to the same standard in production.
if is_production {
for (i, prev) in config
.security
.signing_secret
.previous_secrets
.iter()
.enumerate()
{
if let Err(error) = validate_signing_secret(Some(prev.as_str()), true) {
eprintln!("Invalid signing secret configuration: previous_secrets[{i}]: {error}");
eprintln!(
" hint: every previous secret must meet the same entropy requirement \
as the current secret"
);
std::process::exit(1);
}
}
}
}
/// Validates the signed-webhook configuration (with stricter rules under
/// the production profile) and aborts the process when it is invalid.
fn fail_fast_on_invalid_webhook_config(config: &AutumnConfig) {
    let is_production = matches!(config.profile.as_deref(), Some("prod" | "production"));
    match config.security.webhooks.validate(is_production) {
        Ok(_) => {}
        Err(error) => {
            eprintln!("Invalid signed webhook configuration: {error}");
            std::process::exit(1);
        }
    }
}
/// Installs the signed-webhook registry from config into app state,
/// aborting the process when the configuration is invalid.
pub(crate) fn install_webhook_registry(state: &AppState, config: &AutumnConfig) {
    let installed = crate::webhook::install_registry_from_config(state, &config.security.webhooks);
    if let Err(error) = installed {
        eprintln!("Invalid signed webhook configuration: {error}");
        std::process::exit(1);
    }
}
#[cfg(feature = "storage")]
/// Result of storage preflight: the blob store to register on app state
/// plus an optional router that serves blobs (present for the local-disk
/// backend, absent for user-supplied stores).
struct StorageBootstrap {
// Shared handle to the configured blob store implementation.
store: crate::storage::SharedBlobStore,
// Router produced by the local backend's `serve_router`, if any.
serving: Option<axum::Router<AppState>>,
}
#[cfg(feature = "storage")]
impl StorageBootstrap {
    /// Registers the blob store with app state and hands back the optional
    /// serving router for the caller to mount.
    fn install(self, state: &AppState) -> Option<axum::Router<AppState>> {
        let Self { store, serving } = self;
        state.insert_extension::<crate::storage::BlobStoreState>(
            crate::storage::BlobStoreState::new(store),
        );
        serving
    }
}
#[cfg(feature = "storage")]
/// Resolves the storage backend from config when no custom blob store was
/// supplied. Returns `None` when storage is disabled; aborts the process on
/// invalid config or an s3 backend (which requires an external plugin).
#[allow(clippy::too_many_lines)] fn preflight_storage(config: &AutumnConfig) -> Option<StorageBootstrap> {
use crate::storage::StorageBackendPlan;
let plan = config
.storage
.backend_plan(config.profile.as_deref())
.unwrap_or_else(|error| {
tracing::error!(%error, "invalid storage backend config; aborting startup");
std::process::exit(1);
});
match plan {
StorageBackendPlan::Disabled => None,
StorageBackendPlan::Local {
provider_id,
root,
mount_path,
default_url_expiry_secs,
warn_in_production,
} => Some(bootstrap_local_storage(
config,
&provider_id,
&root,
&mount_path,
default_url_expiry_secs,
warn_in_production,
)),
// The framework ships no S3 client; direct users to the plugin crate.
StorageBackendPlan::S3 { .. } => {
tracing::error!(
"storage.backend=s3 requires the `autumn-storage-s3` plugin. \
Add it to your Cargo.toml, build an S3BlobStore from your config, \
and call `.with_blob_store(store)` on your AppBuilder. \
Aborting startup."
);
std::process::exit(1);
}
}
}
#[cfg(feature = "storage")]
/// Constructs the local-disk blob store and its serving router, aborting
/// the process when the store cannot be initialized.
///
/// URL-signing key selection, first match wins:
/// 1. `security.signing_secret.secret` (its `previous_secrets` become
///    rotation keys),
/// 2. legacy `storage.local.signing_key` (no rotation keys),
/// 3. a random per-process key (with a warning under the prod profile,
///    since signatures then break across restarts).
fn bootstrap_local_storage(
config: &AutumnConfig,
provider_id: &str,
root: &std::path::Path,
mount_path: &str,
default_url_expiry_secs: u64,
warn_in_production: bool,
) -> StorageBootstrap {
use crate::storage::{LocalBlobStore, SharedBlobStore, local::SigningKey};
if warn_in_production {
tracing::warn!(
"prod profile is using the local-disk blob store; \
bytes won't survive replica turnover. Set \
storage.backend=s3 or storage.allow_local_in_production=true \
to acknowledge"
);
}
let (signing_key, previous_signing_keys) = config
.security
.signing_secret
.secret
.as_deref()
.filter(|s| !s.is_empty())
.map_or_else(
// No security signing secret: fall back to the legacy storage key,
// then to a random key.
|| {
config
.storage
.local
.signing_key
.as_deref()
.filter(|s| !s.is_empty())
.map_or_else(
|| {
if matches!(config.profile.as_deref(), Some("prod" | "production")) {
tracing::warn!(
"no signing secret configured in prod; blob URL signatures \
won't survive a process restart. Set \
AUTUMN_SECURITY__SIGNING_SECRET."
);
}
(SigningKey::random(), vec![])
},
|legacy| (SigningKey::new(legacy.as_bytes().to_vec()), vec![]),
)
},
// Security signing secret present: use it, carrying any rotated
// secrets along as previous keys.
|secret| {
let current = SigningKey::new(secret.as_bytes().to_vec());
let previous = config
.security
.signing_secret
.previous_secrets
.iter()
.map(|s| SigningKey::new(s.as_bytes().to_vec()))
.collect::<Vec<_>>();
(current, previous)
},
);
let store = match LocalBlobStore::new(
provider_id.to_string(),
root.to_path_buf(),
mount_path.to_string(),
std::time::Duration::from_secs(default_url_expiry_secs),
signing_key,
previous_signing_keys,
) {
Ok(store) => store,
Err(err) => {
tracing::error!(
error = %err,
root = %root.display(),
"failed to initialize local blob store; aborting startup"
);
std::process::exit(1);
}
};
let serving = crate::storage::local::serve_router(&store);
let arc: SharedBlobStore = std::sync::Arc::new(store);
tracing::info!(
provider = %provider_id,
root = %root.display(),
mount = %mount_path,
"Local blob store mounted"
);
StorageBootstrap {
store: arc,
serving: Some(serving),
}
}
/// Loads the application configuration and initializes telemetry, aborting
/// the process (exit code 1) when either step fails.
///
/// A custom `config_loader` factory takes precedence over the default
/// TOML + environment loader; likewise a custom telemetry provider over
/// the default tracing/OTLP provider. The returned guard must be held for
/// as long as telemetry should stay active.
async fn load_config_and_telemetry(
    config_loader: Option<ConfigLoaderFactory>,
    telemetry_provider: Option<Box<dyn crate::telemetry::TelemetryProvider>>,
) -> (AutumnConfig, crate::telemetry::TelemetryGuard) {
    let loaded = if let Some(factory) = config_loader {
        factory().await
    } else {
        crate::config::TomlEnvConfigLoader::new().load().await
    };
    let config = loaded.unwrap_or_else(|e| {
        eprintln!("Failed to load configuration: {e}");
        std::process::exit(1);
    });
    let provider: Box<dyn crate::telemetry::TelemetryProvider> = telemetry_provider
        .unwrap_or_else(|| Box::new(crate::telemetry::TracingOtlpTelemetryProvider::new()));
    let guard = provider
        .init(&config.log, &config.telemetry, config.profile.as_deref())
        .unwrap_or_else(|error| {
            eprintln!("Failed to initialize telemetry: {error}");
            std::process::exit(1);
        });
    (config, guard)
}
#[cfg(feature = "i18n")]
/// Resolves the i18n bundle: an explicitly supplied bundle always wins;
/// otherwise, when auto-loading is enabled, the bundle is loaded from the
/// configured project directory (panicking on load failure); otherwise
/// `None`.
fn resolve_i18n_bundle(
    explicit_bundle: Option<Arc<crate::i18n::Bundle>>,
    auto_load: bool,
    config: &AutumnConfig,
    env: &dyn crate::config::Env,
) -> Option<Arc<crate::i18n::Bundle>> {
    if let Some(bundle) = explicit_bundle {
        return Some(bundle);
    }
    if !auto_load {
        return None;
    }
    let dir = project_dir(&config.i18n.dir, env);
    let bundle = crate::i18n::Bundle::load_from_dir(&dir, &config.i18n)
        .unwrap_or_else(|e| panic!("i18n_auto: {e}"));
    Some(Arc::new(bundle))
}
#[cfg(feature = "i18n")]
/// Registers a resolved i18n bundle with app state and appends an
/// `axum::Extension` layer exposing it to handlers. Returns the layer list
/// unchanged when no bundle was resolved.
fn install_i18n_bundle_layer(
    mut custom_layers: Vec<CustomLayerRegistration>,
    state: &AppState,
    bundle: Option<Arc<crate::i18n::Bundle>>,
) -> Vec<CustomLayerRegistration> {
    let bundle = match bundle {
        Some(bundle) => bundle,
        None => return custom_layers,
    };
    tracing::info!(
        locales = ?bundle.locales(),
        default = bundle.default_locale(),
        "i18n bundle loaded"
    );
    state.insert_extension::<Arc<crate::i18n::Bundle>>(Arc::clone(&bundle));
    let extension_layer = axum::Extension(bundle);
    custom_layers.push(CustomLayerRegistration {
        type_id: TypeId::of::<axum::Extension<Arc<crate::i18n::Bundle>>>(),
        apply: Box::new(move |router| router.layer(extension_layer)),
    });
    custom_layers
}
#[cfg(feature = "db")]
/// Output of `setup_database`: the connection topology plus optional
/// replica-migration bookkeeping used to drive readiness probes.
struct DatabaseBootstrap {
// Primary/replica pool topology; `None` when no topology was created.
topology: Option<crate::db::DatabaseTopology>,
// One-shot readiness result computed at startup when a replica exists
// and migrations are registered.
replica_readiness: Option<crate::migrate::ReplicaMigrationReadiness>,
// `(primary_url, replica_url)` pair for ongoing probe-driven re-checks.
replica_migration_check: Option<(String, String)>,
}
#[cfg(feature = "db")]
/// Creates the database topology (via the custom pool provider when set),
/// applies embedded migrations against the primary, and — when a replica is
/// configured and migrations exist — performs an initial replica migration
/// readiness check.
///
/// # Errors
/// Returns a human-readable message when pool creation fails; other
/// failure handling happens inside the helpers this calls.
async fn setup_database(
config: &AutumnConfig,
migrations: Vec<crate::migrate::EmbeddedMigrations>,
pool_provider: Option<PoolProviderFactory>,
) -> Result<DatabaseBootstrap, String> {
// Replica migration checks only matter when migrations are registered.
let check_replica_migrations = !migrations.is_empty();
let topology = match pool_provider {
Some(factory) => factory(config.database.clone()).await,
None => crate::db::create_topology(&config.database),
}
.map_err(|e| format!("Failed to create database pool: {e}"))?;
// Auto-migrate runs only against the primary URL, and only when a
// topology actually exists.
if topology.is_some()
&& let Some(url) = config.database.effective_primary_url()
{
for mig in migrations {
crate::migrate::auto_migrate(
url,
config.profile.as_deref(),
config.database.auto_migrate_in_production,
mig,
);
}
}
let (replica_readiness, replica_migration_check) = if topology
.as_ref()
.is_some_and(|topology| check_replica_migrations && topology.replica().is_some())
{
match (
config.database.effective_primary_url(),
config.database.replica_url.as_deref(),
) {
(Some(primary_url), Some(replica_url)) => {
let primary_url = primary_url.to_owned();
let replica_url = replica_url.to_owned();
let readiness = crate::migrate::check_replica_migration_readiness_blocking(
primary_url.clone(),
replica_url.clone(),
)
.await;
// Keep the URL pair so probes can re-check readiness later.
(Some(readiness), Some((primary_url, replica_url)))
}
_ => (None, None),
}
} else {
(None, None)
};
Ok(DatabaseBootstrap {
topology,
replica_readiness,
replica_migration_check,
})
}
#[cfg(feature = "db")]
/// Feeds a startup replica-migration readiness result into the probe
/// system: ready marks the probe healthy; otherwise the failure detail
/// (when present) marks it unready. `None` readiness is a no-op.
fn apply_replica_migration_readiness(
    state: &AppState,
    readiness: Option<crate::migrate::ReplicaMigrationReadiness>,
) {
    match readiness {
        None => {}
        Some(readiness) if readiness.is_ready() => {
            state.probes().mark_replica_migrations_ready();
        }
        Some(readiness) => {
            if let Some(detail) = readiness.detail() {
                state.probes().mark_replica_migrations_unready(detail);
            }
        }
    }
}
#[cfg(feature = "db")]
/// Hands the probe system the primary/replica URL pair so it can keep
/// re-checking replica migration status; no-op when no pair is supplied.
fn configure_replica_migration_check(state: &AppState, check: Option<(String, String)>) {
    if let Some((primary_url, replica_url)) = check {
        state
            .probes()
            .configure_replica_migration_check(primary_url, replica_url);
    }
}
/// Collects mutating `#[repository(api = ...)]` routes that carry no
/// `policy = ...` guard, deduplicated by (resource type, api path).
///
/// Scans the top-level routes first, then every scoped group's routes, so
/// the returned order follows registration order.
fn collect_unguarded_repository_writes(
    routes: &[Route],
    scoped_groups: &[ScopedGroup],
) -> Vec<(String, String)> {
    let mut offenders: Vec<(String, String)> = Vec::new();
    let mut seen: std::collections::HashSet<(&'static str, &'static str)> =
        std::collections::HashSet::new();
    let grouped = scoped_groups.iter().flat_map(|group| group.routes.iter());
    for route in routes.iter().chain(grouped) {
        if let Some(meta) = route.repository
            && !meta.has_policy
            && is_mutating_method(&route.method)
            && seen.insert((meta.resource_type_name, meta.api_path))
        {
            offenders.push((meta.resource_type_name.to_owned(), meta.api_path.to_owned()));
        }
    }
    offenders
}
/// Renders offender pairs as one `#[repository(...)]` bullet per line
/// (newline-separated, no trailing newline). Empty input yields "".
fn format_unguarded_repository_listing(offenders: &[(String, String)]) -> String {
    offenders
        .iter()
        .map(|(name, path)| format!(" - #[repository({name}, api = \"{path}\")]"))
        .collect::<Vec<_>>()
        .join("\n")
}
/// Warns about — or, in strict mode, refuses startup over — mutating
/// repository API routes that declare no `policy = ...` guard.
///
/// Strict mode is the production profile without the explicit
/// `security.allow_unauthorized_repository_api` opt-out; it exits the
/// process with code 1. Non-strict mode only logs a warning.
fn validate_repository_api_policies(
routes: &[Route],
scoped_groups: &[ScopedGroup],
config: &AutumnConfig,
) {
let profile = config.profile.as_deref().unwrap_or("default");
let strict =
is_production_profile(profile) && !config.security.allow_unauthorized_repository_api;
let offenders = collect_unguarded_repository_writes(routes, scoped_groups);
if offenders.is_empty() {
return;
}
let listing = format_unguarded_repository_listing(&offenders);
if strict {
tracing::error!(
"refusing to start: the following #[repository(api = ...)] mutating endpoints have no paired `policy = ...` argument:\n{listing}\n\
Add `policy = SomePolicy` to each, or set `[security] allow_unauthorized_repository_api = true` to opt out explicitly."
);
std::process::exit(1);
} else {
tracing::warn!(
"the following #[repository(api = ...)] mutating endpoints have no paired `policy = ...` argument; \
auto-generated POST/PUT/PATCH/DELETE handlers will accept writes from any authenticated user:\n{listing}\n\
This will become a startup-time error in `prod` profile builds."
);
}
}
/// `(resource type name, API path)` pair identifying a repository route
/// whose declared policy/scope has no matching registration.
type MissingRepositoryRegistration = (String, String);
/// Finds repository routes whose declared `policy = ...` / `scope = ...`
/// arguments have no matching registration in the policy registry.
///
/// Returns `(missing_policies, missing_scopes)`, each deduplicated by
/// (resource type, api path) and ordered by registration order (top-level
/// routes first, then scoped-group routes).
fn collect_unregistered_repository_handlers(
    routes: &[Route],
    scoped_groups: &[ScopedGroup],
    registry: &crate::authorization::PolicyRegistry,
) -> (
    Vec<MissingRepositoryRegistration>,
    Vec<MissingRepositoryRegistration>,
) {
    let mut missing_policies: Vec<(String, String)> = Vec::new();
    let mut missing_scopes: Vec<(String, String)> = Vec::new();
    let mut seen_policies: std::collections::HashSet<(&'static str, &'static str)> =
        std::collections::HashSet::new();
    let mut seen_scopes: std::collections::HashSet<(&'static str, &'static str)> =
        std::collections::HashSet::new();
    let grouped = scoped_groups.iter().flat_map(|group| group.routes.iter());
    for route in routes.iter().chain(grouped) {
        let Some(meta) = route.repository else {
            continue;
        };
        // The policy and scope checks are independent: a route may be
        // missing either, both, or neither.
        if let Some(check) = meta.policy_check
            && !check(registry)
            && seen_policies.insert((meta.resource_type_name, meta.api_path))
        {
            missing_policies.push((meta.resource_type_name.to_owned(), meta.api_path.to_owned()));
        }
        if let Some(check) = meta.scope_check
            && !check(registry)
            && seen_scopes.insert((meta.resource_type_name, meta.api_path))
        {
            missing_scopes.push((meta.resource_type_name.to_owned(), meta.api_path.to_owned()));
        }
    }
    (missing_policies, missing_scopes)
}
/// Renders one bullet per missing policy registration, telling the
/// developer exactly which `.policy::<R, _>(...)` call to add. Lines are
/// newline-separated with no trailing newline; empty input yields "".
fn format_missing_policy_listing(missing: &[(String, String)]) -> String {
    missing
        .iter()
        .map(|(name, path)| {
            format!(" - #[repository({name}, api = \"{path}\", policy = ...)]: call `.policy::<{name}, _>(...)` on the app builder")
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Renders one bullet per missing scope registration, telling the developer
/// exactly which `.scope::<R, _>(...)` call to add. Lines are
/// newline-separated with no trailing newline; empty input yields "".
fn format_missing_scope_listing(missing: &[(String, String)]) -> String {
    missing
        .iter()
        .map(|(name, path)| {
            format!(" - #[repository({name}, api = \"{path}\", scope = ...)]: call `.scope::<{name}, _>(...)` on the app builder")
        })
        .collect::<Vec<_>>()
        .join("\n")
}
#[allow(clippy::cognitive_complexity)]
/// Checks that every repository route declaring `policy = ...` or
/// `scope = ...` has a matching registration on the app builder.
///
/// In the production profile any missing registration aborts startup (exit
/// code 1) after logging both listings; otherwise the problems are only
/// warned about, since protected requests would 500 at runtime.
fn validate_repository_policies_registered(
routes: &[Route],
scoped_groups: &[ScopedGroup],
state: &AppState,
config: &AutumnConfig,
) {
let profile = config.profile.as_deref().unwrap_or("default");
let strict = is_production_profile(profile);
let (missing_policies, missing_scopes) =
collect_unregistered_repository_handlers(routes, scoped_groups, state.policy_registry());
if missing_policies.is_empty() && missing_scopes.is_empty() {
return;
}
if !missing_policies.is_empty() {
let listing = format_missing_policy_listing(&missing_policies);
if strict {
tracing::error!(
"refusing to start: the following #[repository] routes declare a `policy = ...` argument, but no policy is registered for the resource type. Without registration, every protected request would fail at runtime with `500 no policy registered`:\n{listing}"
);
} else {
tracing::warn!(
"the following #[repository] routes declare `policy = ...` but no matching `.policy::<R, _>(...)` registration is on the app builder. Protected requests will 500 at runtime:\n{listing}\n\
This will become a startup-time error in `prod` profile builds."
);
}
}
if !missing_scopes.is_empty() {
let listing = format_missing_scope_listing(&missing_scopes);
if strict {
tracing::error!(
"refusing to start: the following #[repository] routes declare a `scope = ...` argument, but no scope is registered for the resource type. Without registration, every list request would fail at runtime with `500 missing scope registration`:\n{listing}"
);
} else {
tracing::warn!(
"the following #[repository] routes declare `scope = ...` but no matching `.scope::<R, _>(...)` registration is on the app builder. List requests will 500 at runtime:\n{listing}\n\
This will become a startup-time error in `prod` profile builds."
);
}
}
// Both listings are logged before exiting so the developer sees every
// problem in one run.
if strict {
std::process::exit(1);
}
}
/// `true` for the HTTP methods that can mutate repository state
/// (POST / PUT / PATCH / DELETE); everything else is treated as read-only.
const fn is_mutating_method(method: &http::Method) -> bool {
    match *method {
        http::Method::POST
        | http::Method::PUT
        | http::Method::PATCH
        | http::Method::DELETE => true,
        _ => false,
    }
}
/// Whether the active profile name selects strict production behavior.
/// Only the exact, lowercase names `prod` and `production` qualify.
fn is_production_profile(profile: &str) -> bool {
    profile == "prod" || profile == "production"
}
#[cfg(test)]
mod validate_repository_api_policies_tests {
use super::*;
use crate::RepositoryApiMeta;
/// Build a minimal `Route` for these tests; the handler body is never
/// executed — only the route metadata is inspected by the collectors.
fn build_route(
    method: http::Method,
    path: &'static str,
    meta: Option<RepositoryApiMeta>,
) -> Route {
    Route {
        method,
        path,
        handler: axum::routing::any(|| async { "" }),
        name: "test_route",
        api_doc: crate::openapi::ApiDoc::default(),
        repository: meta,
    }
}
/// Repository metadata with no policy or scope declared and no check fns.
fn unguarded(path: &'static str, type_name: &'static str) -> RepositoryApiMeta {
    RepositoryApiMeta {
        resource_type_name: type_name,
        api_path: path,
        has_policy: false,
        policy_check: None,
        scope_check: None,
    }
}
/// Shorthand: run the unguarded-write collector with no scoped groups.
fn collect_offenders(routes: &[Route]) -> Vec<(String, String)> {
    collect_unguarded_repository_writes(routes, &[])
}
#[test]
fn read_only_mount_without_policy_is_not_an_offender() {
    // GET-only mounts are allowed to omit a policy entirely.
    let routes = vec![
        build_route(
            http::Method::GET,
            "/api/posts",
            Some(unguarded("/api/posts", "Post")),
        ),
        build_route(
            http::Method::GET,
            "/api/posts/{id}",
            Some(unguarded("/api/posts", "Post")),
        ),
    ];
    let offenders = collect_offenders(&routes);
    assert!(
        offenders.is_empty(),
        "read-only mounts should not trigger the unauthorized-repo guard"
    );
}
#[test]
fn write_mount_without_policy_is_an_offender() {
    // A single unguarded POST mount must be reported.
    let routes = vec![build_route(
        http::Method::POST,
        "/api/posts",
        Some(unguarded("/api/posts", "Post")),
    )];
    let offenders = collect_offenders(&routes);
    assert_eq!(offenders.len(), 1);
    assert_eq!(offenders[0].0, "Post");
    assert_eq!(offenders[0].1, "/api/posts");
}
#[test]
fn mixed_mount_only_dedups_one_offender_per_repository() {
    // Several mutating mounts of the same repository collapse into one entry.
    let routes = vec![
        build_route(
            http::Method::GET,
            "/api/posts",
            Some(unguarded("/api/posts", "Post")),
        ),
        build_route(
            http::Method::POST,
            "/api/posts",
            Some(unguarded("/api/posts", "Post")),
        ),
        build_route(
            http::Method::PUT,
            "/api/posts/{id}",
            Some(unguarded("/api/posts", "Post")),
        ),
        build_route(
            http::Method::DELETE,
            "/api/posts/{id}",
            Some(unguarded("/api/posts", "Post")),
        ),
    ];
    let offenders = collect_offenders(&routes);
    assert_eq!(offenders.len(), 1);
}
#[test]
fn is_mutating_method_classifies_methods() {
    // POST/PUT/PATCH/DELETE mutate; GET/HEAD/OPTIONS do not.
    assert!(is_mutating_method(&http::Method::POST));
    assert!(is_mutating_method(&http::Method::PUT));
    assert!(is_mutating_method(&http::Method::PATCH));
    assert!(is_mutating_method(&http::Method::DELETE));
    assert!(!is_mutating_method(&http::Method::GET));
    assert!(!is_mutating_method(&http::Method::HEAD));
    assert!(!is_mutating_method(&http::Method::OPTIONS));
}
use crate::authorization::{Policy, PolicyRegistry};
// Marker resource type plus a do-nothing policy for registry tests.
#[derive(Debug, Clone, PartialEq)]
struct TestPost;
#[derive(Default)]
struct TestPostPolicy;
impl Policy<TestPost> for TestPostPolicy {}
/// Repository metadata whose `policy_check` consults the live registry for
/// a `TestPost` policy registration.
fn guarded_with_check(path: &'static str, type_name: &'static str) -> RepositoryApiMeta {
    RepositoryApiMeta {
        resource_type_name: type_name,
        api_path: path,
        has_policy: true,
        policy_check: Some(|registry: &PolicyRegistry| registry.has_policy::<TestPost>()),
        scope_check: None,
    }
}
/// Shorthand: run the registration collector and keep only missing policies.
fn collect_missing(routes: &[Route], registry: &PolicyRegistry) -> Vec<(String, String)> {
    let (missing_policies, _) = collect_unregistered_repository_handlers(routes, &[], registry);
    missing_policies
}
// Fix: the original calls were corrupted by an HTML-entity mangling pass —
// `&registry` had its `&reg` turned into `®` (`®istry`), which does not
// compile. Restored to `&registry` throughout.
#[test]
fn registry_check_flags_routes_missing_their_policy_registration() {
    // Empty registry: the route's policy_check must report the policy missing.
    let registry = PolicyRegistry::default();
    let routes = vec![build_route(
        http::Method::POST,
        "/api/posts",
        Some(guarded_with_check("/api/posts", "TestPost")),
    )];
    let missing = collect_missing(&routes, &registry);
    assert_eq!(missing.len(), 1);
    assert_eq!(missing[0].0, "TestPost");
    assert_eq!(missing[0].1, "/api/posts");
}
#[test]
fn registry_check_passes_when_policy_is_registered() {
    // Registering the policy satisfies the check, so nothing is reported.
    let registry = PolicyRegistry::default();
    registry.register_policy::<TestPost, _>(TestPostPolicy);
    let routes = vec![build_route(
        http::Method::POST,
        "/api/posts",
        Some(guarded_with_check("/api/posts", "TestPost")),
    )];
    let missing = collect_missing(&routes, &registry);
    assert!(missing.is_empty(), "policy is registered, no offenders");
}
#[test]
fn registry_check_skips_routes_without_policy_check_fn() {
    // Routes without a `policy_check` fn are outside this validation pass.
    let registry = PolicyRegistry::default();
    let routes = vec![build_route(
        http::Method::POST,
        "/api/posts",
        Some(unguarded("/api/posts", "TestPost")),
    )];
    let missing = collect_missing(&routes, &registry);
    assert!(missing.is_empty());
}
#[test]
fn registry_check_dedups_one_offender_per_repository() {
    // Multiple mounts of the same repository collapse into a single report.
    let registry = PolicyRegistry::default();
    let routes = vec![
        build_route(
            http::Method::GET,
            "/api/posts",
            Some(guarded_with_check("/api/posts", "TestPost")),
        ),
        build_route(
            http::Method::POST,
            "/api/posts",
            Some(guarded_with_check("/api/posts", "TestPost")),
        ),
        build_route(
            http::Method::DELETE,
            "/api/posts/{id}",
            Some(guarded_with_check("/api/posts", "TestPost")),
        ),
    ];
    let missing = collect_missing(&routes, &registry);
    assert_eq!(missing.len(), 1);
}
use crate::authorization::{BoxFuture, PolicyContext, Scope};
// A scope that returns an empty list; only its *registration* matters here.
#[derive(Default)]
struct TestPostScope;
impl Scope<TestPost> for TestPostScope {
    fn list<'a>(
        &'a self,
        _ctx: &'a PolicyContext,
        _conn: &'a mut diesel_async::AsyncPgConnection,
    ) -> BoxFuture<'a, crate::AutumnResult<Vec<TestPost>>> {
        Box::pin(async { Ok(Vec::new()) })
    }
}
/// Repository metadata with only a `scope_check` (no policy declared), which
/// consults the registry for a `TestPost` scope registration.
fn scope_only_meta(path: &'static str, type_name: &'static str) -> RepositoryApiMeta {
    RepositoryApiMeta {
        resource_type_name: type_name,
        api_path: path,
        has_policy: false,
        policy_check: None,
        scope_check: Some(|registry: &PolicyRegistry| registry.scope::<TestPost>().is_some()),
    }
}
/// Shorthand: run the registration collector and keep only missing scopes.
fn collect_missing_scopes(
    routes: &[Route],
    registry: &PolicyRegistry,
) -> Vec<(String, String)> {
    let (_, missing_scopes) = collect_unregistered_repository_handlers(routes, &[], registry);
    missing_scopes
}
// Fix: restored `&registry` arguments that had been mangled into `®istry`
// (an `&reg` → `®` HTML-entity corruption); the originals did not compile.
#[test]
fn scope_check_flags_unregistered_scope() {
    // Empty registry: a scope-declaring route must be reported.
    let registry = PolicyRegistry::default();
    let routes = vec![build_route(
        http::Method::GET,
        "/api/posts",
        Some(scope_only_meta("/api/posts", "TestPost")),
    )];
    let missing = collect_missing_scopes(&routes, &registry);
    assert_eq!(missing.len(), 1);
    assert_eq!(missing[0].0, "TestPost");
}
#[test]
fn scope_check_passes_when_scope_is_registered() {
    // Registering the scope satisfies the check.
    let registry = PolicyRegistry::default();
    registry.register_scope::<TestPost, _>(TestPostScope);
    let routes = vec![build_route(
        http::Method::GET,
        "/api/posts",
        Some(scope_only_meta("/api/posts", "TestPost")),
    )];
    let missing = collect_missing_scopes(&routes, &registry);
    assert!(missing.is_empty());
}
#[test]
fn scope_check_skips_routes_without_scope_check_fn() {
    // Routes without a `scope_check` fn are outside this validation pass.
    let registry = PolicyRegistry::default();
    let routes = vec![build_route(
        http::Method::POST,
        "/api/posts",
        Some(unguarded("/api/posts", "TestPost")),
    )];
    let missing = collect_missing_scopes(&routes, &registry);
    assert!(missing.is_empty());
}
#[test]
fn is_production_profile_matches_both_aliases() {
    // Only exact lowercase "prod"/"production" are strict.
    assert!(is_production_profile("prod"));
    assert!(is_production_profile("production"));
    assert!(!is_production_profile("dev"));
    assert!(!is_production_profile("staging"));
    assert!(!is_production_profile("test"));
    assert!(!is_production_profile("default"));
    assert!(!is_production_profile("Prod"));
    assert!(!is_production_profile("Production"));
}
#[test]
fn format_unguarded_listing_renders_one_bullet_per_offender() {
    let offenders = vec![
        ("Post".to_owned(), "/api/posts".to_owned()),
        ("Comment".to_owned(), "/api/comments".to_owned()),
    ];
    let listing = format_unguarded_repository_listing(&offenders);
    assert!(listing.contains("Post"));
    assert!(listing.contains("/api/posts"));
    assert!(listing.contains("Comment"));
    assert!(listing.contains("/api/comments"));
    // Two bullets joined by exactly one newline separator.
    assert_eq!(listing.matches("\n - ").count() + 1, 2);
}
#[test]
fn format_unguarded_listing_empty_input_yields_empty_string() {
    let listing = format_unguarded_repository_listing(&[]);
    assert!(listing.is_empty());
}
#[test]
fn format_missing_policy_listing_includes_policy_call_hint() {
    // The bullet must tell the user which builder call fixes the problem.
    let missing = vec![("Post".to_owned(), "/api/posts".to_owned())];
    let listing = format_missing_policy_listing(&missing);
    assert!(listing.contains("Post"));
    assert!(listing.contains("/api/posts"));
    assert!(listing.contains(".policy::<Post, _>"));
    assert!(listing.contains("policy = ..."));
}
#[test]
fn format_missing_scope_listing_includes_scope_call_hint() {
    let missing = vec![("Post".to_owned(), "/api/posts".to_owned())];
    let listing = format_missing_scope_listing(&missing);
    assert!(listing.contains("Post"));
    assert!(listing.contains("/api/posts"));
    assert!(listing.contains(".scope::<Post, _>"));
    assert!(listing.contains("scope = ..."));
}
#[test]
fn collect_unguarded_walks_scoped_groups() {
    // Routes nested inside scoped groups must be validated like top-level ones.
    let group_route = build_route(
        http::Method::POST,
        "/api/posts",
        Some(unguarded("/api/posts", "Post")),
    );
    let group = ScopedGroup {
        prefix: "/scoped".to_owned(),
        routes: vec![group_route],
        source: crate::route_listing::RouteSource::User,
        apply_layer: Box::new(|r| r),
    };
    let offenders = collect_unguarded_repository_writes(&[], std::slice::from_ref(&group));
    assert_eq!(offenders.len(), 1);
    assert_eq!(offenders[0].0, "Post");
}
// Fix: restored the `&registry` argument that had been mangled into
// `®istry` (`&reg` → `®` HTML-entity corruption); the original did not compile.
#[test]
fn collect_unregistered_walks_scoped_groups() {
    // Scoped-group routes must be checked against the registry too.
    let group_route = build_route(
        http::Method::POST,
        "/api/posts",
        Some(guarded_with_check("/api/posts", "TestPost")),
    );
    let group = ScopedGroup {
        prefix: "/scoped".to_owned(),
        routes: vec![group_route],
        source: crate::route_listing::RouteSource::User,
        apply_layer: Box::new(|r| r),
    };
    let registry = PolicyRegistry::default();
    let (missing, _) =
        collect_unregistered_repository_handlers(&[], std::slice::from_ref(&group), &registry);
    assert_eq!(missing.len(), 1);
    assert_eq!(missing[0].0, "TestPost");
}
}
/// Assemble the shared `AppState` from the loaded config plus the optional
/// database topology (`db` feature) and channels backend override (`ws`
/// feature).
///
/// A misconfigured channels backend aborts the process (exit 1) — startup
/// cannot proceed without it when channels are enabled.
fn build_state(
    config: &AutumnConfig,
    #[cfg(feature = "db")] database_topology: Option<&crate::db::DatabaseTopology>,
    #[cfg(feature = "ws")] channels_backend: Option<Arc<dyn crate::channels::ChannelsBackend>>,
) -> AppState {
    #[cfg(feature = "ws")]
    let shutdown = tokio_util::sync::CancellationToken::new();
    // An explicit backend override wins; otherwise the backend is derived
    // from config, which may fail (e.g. Redis selected without a URL).
    #[cfg(feature = "ws")]
    let channels = channels_backend.map_or_else(
        || {
            crate::channels::Channels::from_config(&config.channels, shutdown.child_token())
                .unwrap_or_else(|error| {
                    tracing::error!(error = %error, "Failed to configure channels backend");
                    std::process::exit(1);
                })
        },
        crate::channels::Channels::with_shared_backend,
    );
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
        #[cfg(feature = "db")]
        pool: database_topology.map(|topology| topology.primary().clone()),
        #[cfg(feature = "db")]
        replica_pool: database_topology.and_then(|topology| topology.replica().cloned()),
        profile: config.profile.clone(),
        started_at: std::time::Instant::now(),
        health_detailed: config.health.detailed,
        probes: crate::probe::ProbeState::pending_startup(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new(&config.log.level),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::from_config(config),
        #[cfg(feature = "ws")]
        channels,
        #[cfg(feature = "ws")]
        shutdown,
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: config.security.forbidden_response,
        auth_session_key: config.auth.session_key.clone(),
        shared_cache: None,
    };
    // Readiness tracks the replica only when one is actually configured.
    #[cfg(feature = "db")]
    if state.replica_pool.is_some() {
        state
            .probes()
            .configure_replica_dependency(config.database.replica_fallback);
    }
    // Expose the full config as a typed extension for handler access.
    state.insert_extension(config.clone());
    state
}
/// Build the multi-line route listing printed at startup: user routes,
/// scoped-group routes, de-duplicated health-probe paths, the actuator
/// glob, and (with the `htmx` feature) the bundled static JS assets.
fn format_route_lines(
    routes: &[Route],
    scoped_groups: &[ScopedGroup],
    config: &AutumnConfig,
) -> String {
    use std::fmt::Write as _;
    let mut listing = String::new();
    for route in routes {
        let _ = write!(
            listing,
            "\n {} {:<8} -> {}",
            route.path, route.method, route.name
        );
    }
    for group in scoped_groups {
        for route in &group.routes {
            let _ = write!(
                listing,
                "\n {}{} {:<8} -> {} (scoped)",
                group.prefix, route.path, route.method, route.name
            );
        }
    }
    // Probe endpoints may share a path (e.g. ready == health); print each once.
    let probe_entries = [
        (config.health.live_path.as_str(), "live"),
        (config.health.ready_path.as_str(), "ready"),
        (config.health.startup_path.as_str(), "startup"),
        (config.health.path.as_str(), "health"),
    ];
    let mut seen_probe_paths = std::collections::HashSet::new();
    for (path, name) in probe_entries {
        if seen_probe_paths.insert(path) {
            let _ = write!(listing, "\n {} {:<8} -> {}", path, "GET", name);
        }
    }
    let _ = write!(
        listing,
        "\n {} {:<8} -> actuator",
        crate::actuator::actuator_route_glob(&config.actuator.prefix),
        "GET"
    );
    #[cfg(feature = "htmx")]
    {
        listing.push_str("\n /static/js/htmx.min.js GET -> htmx");
        listing.push_str("\n /static/js/autumn-htmx-csrf.js GET -> htmx csrf");
    }
    listing
}
/// Render one indented line per scheduled task, or `None` when there are
/// no tasks so the caller can skip the section entirely.
fn format_task_lines(tasks: &[crate::task::TaskInfo]) -> Option<String> {
    use std::fmt::Write as _;
    (!tasks.is_empty()).then(|| {
        tasks.iter().fold(String::new(), |mut acc, task| {
            let schedule = task.schedule.to_string();
            let _ = write!(acc, "\n {} ({schedule})", task.name);
            acc
        })
    })
}
/// Summarize the middleware stack for the startup banner. The baseline
/// layers are always present; CORS and CSRF appear only when configured.
fn format_middleware_list(config: &AutumnConfig) -> String {
    let mut middleware = Vec::with_capacity(7);
    middleware.extend([
        "RequestId",
        "SecurityHeaders",
        "Session (in-memory)",
        "ErrorPages",
    ]);
    if !config.cors.allowed_origins.is_empty() {
        middleware.push("CORS");
    }
    if config.security.csrf.enabled {
        middleware.push("CSRF");
    }
    middleware.push("Metrics");
    middleware.join(", ")
}
/// Mask any password embedded in a database URL before it is logged.
///
/// Unparseable URLs are fully redacted (`****`) rather than echoed, since a
/// malformed connection string may still contain credentials.
///
/// Improvement: the original duplicated the `format!` call in both arms of
/// the password check; the masked and unmasked paths now share one exit.
fn mask_database_url(url: &str, pool_size: usize) -> String {
    url::Url::parse(url).map_or_else(
        |_| format!("**** (pool_size={pool_size})"),
        |mut parsed_url| {
            // `set_password` can fail for URLs without an authority part, so
            // only attempt it when a password is actually present; the
            // failure is deliberately ignored (nothing sensitive to mask).
            if parsed_url.password().is_some() {
                let _ = parsed_url.set_password(Some("****"));
            }
            format!("{parsed_url} (pool_size={pool_size})")
        },
    )
}
fn format_config_summary(config: &AutumnConfig) -> String {
let profile = config.profile.as_deref().unwrap_or("none");
let db_status = config.database.effective_primary_url().map_or_else(
|| "not configured".to_owned(),
|url| {
let primary = mask_database_url(url, config.database.effective_primary_pool_size());
if config.database.replica_url.is_some() {
format!(
"primary={primary}, replica=configured (pool_size={})",
config.database.effective_replica_pool_size()
)
} else {
primary
}
},
);
let telemetry_status = if config.telemetry.enabled {
let endpoint = config
.telemetry
.otlp_endpoint
.as_deref()
.unwrap_or("<missing endpoint>");
format!("{:?} -> {endpoint}", config.telemetry.protocol)
} else {
"disabled".to_owned()
};
format!(
"\
\n profile: {profile}\
\n server: {}:{}\
\n database: {db_status}\
\n log_level: {}\
\n log_format: {:?}\
\n telemetry: {telemetry_status}\
\n health: {} (detailed={})\
\n actuator: sensitive={}\
\n shutdown: {}s",
config.server.host,
config.server.port,
config.log.level,
config.log.format,
config.health.path,
config.health.detailed,
config.actuator.sensitive,
config.server.shutdown_timeout_secs,
)
}
/// Resolve a project-relative directory: rooted at `AUTUMN_MANIFEST_DIR`
/// when that env var is set, otherwise relative to the working directory.
pub(crate) fn project_dir(subdir: &str, env: &dyn crate::config::Env) -> std::path::PathBuf {
    match env.var("AUTUMN_MANIFEST_DIR") {
        Ok(manifest_dir) => std::path::PathBuf::from(manifest_dir).join(subdir),
        Err(_) => std::path::PathBuf::from(subdir),
    }
}
/// Resolve once the process receives Ctrl+C or (on Unix) SIGTERM, logging
/// which signal triggered the graceful shutdown. On non-Unix targets only
/// Ctrl+C is wired up; the SIGTERM arm never resolves.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
        tracing::info!("Received Ctrl+C, starting graceful shutdown");
    };
    #[cfg(unix)]
    let sigterm = async {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler")
            .recv()
            .await;
        tracing::info!("Received SIGTERM, starting graceful shutdown");
    };
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();
    // Whichever signal arrives first wins; both mean "shut down".
    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use std::sync::atomic::{AtomicUsize, Ordering};
use tower::ServiceExt;
// A delivery queue that accepts every mail and does nothing — enough to
// exercise the builder's queue wiring without performing any I/O.
#[cfg(feature = "mail")]
struct MailTestNoopQueue;
#[cfg(feature = "mail")]
impl crate::mail::MailDeliveryQueue for MailTestNoopQueue {
    fn enqueue<'a>(
        &'a self,
        _mail: crate::mail::Mail,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<(), crate::mail::MailError>> + Send + 'a>,
    > {
        Box::pin(async { Ok(()) })
    }
}
/// Minimal valid mail used by the queue tests below.
#[cfg(feature = "mail")]
fn test_mail() -> crate::mail::Mail {
    crate::mail::Mail::builder()
        .to("test@example.com")
        .subject("hi")
        .text("hello")
        .build()
        .expect("test mail should build")
}
/// Build a full router around the given routes using a default config and a
/// hand-assembled test `AppState` (no database, probes pre-marked ready).
pub fn test_router(routes: Vec<Route>) -> axum::Router {
    let config = AutumnConfig::default();
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    crate::router::build_router(routes, &config, state)
}
#[cfg(feature = "db")]
#[test]
fn build_state_applies_replica_fallback_policy_to_read_routing() {
    // With fallback=Primary, an unready replica routes reads to the primary
    // pool (max_size 5) instead of the replica pool (max_size 2).
    let mut config = AutumnConfig::default();
    config.database.primary_url = Some("postgres://localhost/primary".to_owned());
    config.database.primary_pool_size = Some(5);
    config.database.replica_url = Some("postgres://localhost/replica".to_owned());
    config.database.replica_pool_size = Some(2);
    config.database.replica_fallback = crate::config::ReplicaFallback::Primary;
    let topology = crate::db::create_topology(&config.database)
        .expect("topology should build")
        .expect("database should be configured");
    let state = build_state(
        &config,
        Some(&topology),
        #[cfg(feature = "ws")]
        None,
    );
    state
        .probes()
        .mark_replica_unready("replica migrations lag primary");
    assert_eq!(state.read_pool().expect("read pool").status().max_size, 5);
}
#[cfg(feature = "db")]
#[tokio::test]
async fn custom_pool_provider_preserves_configured_replica_topology() {
    // A provider that defers to the default pool builder must still yield
    // both primary and replica pools with the configured sizes.
    struct PassthroughPoolProvider;
    impl crate::db::DatabasePoolProvider for PassthroughPoolProvider {
        async fn create_pool(
            &self,
            config: &crate::config::DatabaseConfig,
        ) -> Result<
            Option<
                diesel_async::pooled_connection::deadpool::Pool<
                    diesel_async::AsyncPgConnection,
                >,
            >,
            crate::db::PoolError,
        > {
            crate::db::create_pool(config)
        }
    }
    let mut config = AutumnConfig::default();
    config.database.primary_url = Some("postgres://localhost/primary".to_owned());
    config.database.primary_pool_size = Some(5);
    config.database.replica_url = Some("postgres://localhost/replica".to_owned());
    config.database.replica_pool_size = Some(2);
    config.database.replica_fallback = crate::config::ReplicaFallback::FailReadiness;
    let AppBuilder {
        pool_provider_factory,
        ..
    } = app().with_pool_provider(PassthroughPoolProvider);
    let database = setup_database(&config, Vec::new(), pool_provider_factory)
        .await
        .expect("custom provider should build database topology");
    let topology = database.topology.expect("database should be configured");
    assert_eq!(topology.primary().status().max_size, 5);
    assert_eq!(
        topology
            .replica()
            .expect("custom provider should create replica pool")
            .status()
            .max_size,
        2
    );
    let state = build_state(
        &config,
        Some(&topology),
        #[cfg(feature = "ws")]
        None,
    );
    // With fallback=FailReadiness, a dead replica removes the read pool and
    // the readiness probe reports the service unavailable.
    state
        .probes()
        .mark_replica_connection_unready("replica connection failed");
    assert!(state.read_pool().is_none());
    let (status, _) = crate::probe::readiness_response(&state).await;
    assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
}
#[cfg(feature = "db")]
#[test]
fn configure_replica_migration_check_stores_recheck_urls() {
    // build_state alone must not enable migration checks; only the explicit
    // configure call (run when migrations are registered) stores the URLs.
    let mut config = AutumnConfig::default();
    config.database.primary_url = Some("postgres://localhost/primary".to_owned());
    config.database.replica_url = Some("postgres://localhost/replica".to_owned());
    let topology = crate::db::create_topology(&config.database)
        .expect("topology should build")
        .expect("database should be configured");
    let state = build_state(
        &config,
        Some(&topology),
        #[cfg(feature = "ws")]
        None,
    );
    assert!(
        state.probes().replica_migration_check().is_none(),
        "build_state should not enable migration checks without registered migrations"
    );
    configure_replica_migration_check(
        &state,
        Some((
            "postgres://localhost/primary".to_owned(),
            "postgres://localhost/replica".to_owned(),
        )),
    );
    let check = state
        .probes()
        .replica_migration_check()
        .expect("replica migration check should be configured");
    assert_eq!(check.primary_url, "postgres://localhost/primary");
    assert_eq!(check.replica_url, "postgres://localhost/replica");
}
#[cfg(feature = "db")]
#[tokio::test]
async fn replica_migration_readiness_marks_ready_endpoint_degraded() {
    // A stale replica under fallback=FailReadiness must fail the /ready probe.
    let mut config = AutumnConfig::default();
    config.database.primary_url = Some("postgres://localhost/primary".to_owned());
    config.database.primary_pool_size = Some(5);
    config.database.replica_url = Some("postgres://localhost/replica".to_owned());
    config.database.replica_pool_size = Some(2);
    config.database.replica_fallback = crate::config::ReplicaFallback::FailReadiness;
    let topology = crate::db::create_topology(&config.database)
        .expect("topology should build")
        .expect("database should be configured");
    let state = build_state(
        &config,
        Some(&topology),
        #[cfg(feature = "ws")]
        None,
    );
    apply_replica_migration_readiness(
        &state,
        Some(crate::migrate::ReplicaMigrationReadiness::Stale {
            primary_latest: Some("00000000000002".to_owned()),
            replica_latest: Some("00000000000001".to_owned()),
        }),
    );
    let (status, _) = crate::probe::readiness_response(&state).await;
    assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
}
#[cfg(feature = "db")]
#[tokio::test]
async fn blocking_replica_migration_readiness_reports_unknown_connection_errors() {
    // Connection failures surface as Unknown rather than Stale/Ready.
    let readiness = crate::migrate::check_replica_migration_readiness_blocking(
        "not-a-primary-url".to_owned(),
        "not-a-replica-url".to_owned(),
    )
    .await;
    assert!(matches!(
        readiness,
        crate::migrate::ReplicaMigrationReadiness::Unknown(_)
    ));
}
#[cfg(feature = "ws")]
#[test]
fn with_channels_backend_overrides_config_driven_backend_selection() {
    // Even with a (broken) Redis config, an explicit local backend override
    // must be used — publish/subscribe should work in-process.
    let builder = app().with_channels_backend(crate::channels::LocalChannelsBackend::new(4));
    let AppBuilder {
        channels_backend, ..
    } = builder;
    assert!(channels_backend.is_some());
    let mut config = AutumnConfig::default();
    config.channels.backend = crate::config::ChannelBackend::Redis;
    config.channels.redis.url = None;
    let state = build_state(
        &config,
        #[cfg(feature = "db")]
        None,
        #[cfg(feature = "ws")]
        channels_backend,
    );
    let mut rx = state.channels().subscribe("override");
    state
        .broadcast()
        .publish("override", "ok")
        .expect("custom local backend should publish");
    assert_eq!(rx.try_recv().expect("message should arrive").as_str(), "ok");
}
/// Build a trivial GET route that always responds "ok", with matching
/// OpenAPI metadata, for router-level tests.
pub fn test_get_route(path: &'static str, name: &'static str) -> Route {
    Route {
        method: http::Method::GET,
        path,
        handler: axum::routing::get(|| async { "ok" }),
        name,
        api_doc: crate::openapi::ApiDoc {
            method: "GET",
            path,
            operation_id: name,
            success_status: 200,
            ..Default::default()
        },
        repository: None,
    }
}
/// Build an in-memory i18n bundle containing a single English message.
#[cfg(feature = "i18n")]
fn test_i18n_bundle(key: &str, value: &str) -> Arc<crate::i18n::Bundle> {
    let mut messages = std::collections::HashMap::new();
    let mut en = std::collections::HashMap::new();
    en.insert(key.to_owned(), value.to_owned());
    messages.insert("en".to_owned(), en);
    Arc::new(crate::i18n::Bundle::from_messages(
        messages,
        &crate::i18n::I18nConfig::default(),
    ))
}
#[cfg(feature = "i18n")]
#[test]
fn i18n_auto_defers_loading_until_runtime_config_is_available() {
    // `.i18n_auto()` only sets the flag; no bundle is built eagerly.
    let builder = app().i18n_auto();
    assert!(builder.i18n_bundle.is_none());
    assert!(builder.i18n_auto_load);
}
// Config loader that always returns a pre-built config, for deterministic
// startup-path tests.
#[cfg(feature = "i18n")]
#[derive(Clone)]
struct StaticConfigLoader {
    config: AutumnConfig,
}
#[cfg(feature = "i18n")]
impl crate::config::ConfigLoader for StaticConfigLoader {
    async fn load(&self) -> Result<AutumnConfig, crate::config::ConfigError> {
        Ok(self.config.clone())
    }
}
// Telemetry provider that installs nothing, so tests don't touch global
// tracing state.
#[cfg(feature = "i18n")]
struct NoopTelemetryProvider;
#[cfg(feature = "i18n")]
impl crate::telemetry::TelemetryProvider for NoopTelemetryProvider {
    fn init(
        &self,
        _log: &crate::config::LogConfig,
        _telemetry: &crate::config::TelemetryConfig,
        _profile: Option<&str>,
    ) -> Result<crate::telemetry::TelemetryGuard, crate::telemetry::TelemetryInitError>
    {
        Ok(crate::telemetry::TelemetryGuard::disabled())
    }
}
#[cfg(feature = "i18n")]
#[tokio::test]
async fn i18n_auto_uses_config_loader_output_for_bundle_dir() {
    // The bundle dir must come from the *loaded* config (custom-i18n), not a
    // default, and be resolved relative to AUTUMN_MANIFEST_DIR.
    let project = tempfile::tempdir().expect("project dir");
    let i18n_dir = project.path().join("custom-i18n");
    std::fs::create_dir_all(&i18n_dir).expect("i18n dir");
    std::fs::write(i18n_dir.join("en.ftl"), "nav.home = Loader Home\n").expect("bundle");
    let mut config = AutumnConfig::default();
    config.i18n.dir = "custom-i18n".to_owned();
    let builder = app()
        .with_config_loader(StaticConfigLoader { config })
        .with_telemetry_provider(NoopTelemetryProvider)
        .i18n_auto();
    let AppBuilder {
        config_loader_factory,
        telemetry_provider,
        i18n_bundle,
        i18n_auto_load,
        ..
    } = builder;
    let (loaded_config, _guard) =
        load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
    let env = crate::config::MockEnv::new().with(
        "AUTUMN_MANIFEST_DIR",
        project.path().to_str().expect("utf-8 path"),
    );
    let bundle = resolve_i18n_bundle(i18n_bundle, i18n_auto_load, &loaded_config, &env)
        .expect("bundle loaded from configured dir");
    assert_eq!(bundle.translate("en", "nav.home", &[]), "Loader Home");
}
#[cfg(feature = "i18n")]
#[tokio::test]
async fn i18n_bundle_layer_is_applied_to_static_route_rendering() {
    // Static-site generation must run through the same router layers, so a
    // localized handler renders translated text into the output HTML.
    async fn localized(locale: crate::i18n::Locale) -> String {
        locale.t("nav.home")
    }
    let config = AutumnConfig::default();
    let state = AppState::for_test();
    let custom_layers = install_i18n_bundle_layer(
        Vec::new(),
        &state,
        Some(test_i18n_bundle("nav.home", "Home")),
    );
    let router = crate::router::try_build_router_inner(
        vec![Route {
            method: http::Method::GET,
            path: "/about",
            handler: axum::routing::get(localized),
            name: "localized",
            api_doc: crate::openapi::ApiDoc {
                method: "GET",
                path: "/about",
                operation_id: "localized",
                success_status: 200,
                ..Default::default()
            },
            repository: None,
        }],
        &config,
        state,
        crate::router::RouterContext {
            exception_filters: Vec::new(),
            scoped_groups: Vec::new(),
            merge_routers: Vec::new(),
            nest_routers: Vec::new(),
            custom_layers,
            error_page_renderer: None,
            session_store: None,
            #[cfg(feature = "openapi")]
            openapi: None,
        },
    )
    .expect("router builds");
    let tmp = tempfile::tempdir().expect("dist parent");
    let dist = tmp.path().join("dist");
    crate::static_gen::render_static_routes(
        router,
        &[crate::static_gen::StaticRouteMeta {
            path: "/about",
            name: "localized",
            revalidate: None,
            params_fn: None,
        }],
        &dist,
    )
    .await
    .expect("static render succeeds");
    let html = std::fs::read_to_string(dist.join("about/index.html")).expect("rendered html");
    assert_eq!(html, "Home");
}
#[test]
fn app_builder_routes_adds_routes() {
    // `.routes()` appends in call order and never replaces earlier routes.
    let builder = app();
    assert_eq!(builder.routes.len(), 0);
    let builder = builder.routes(vec![test_get_route("/1", "route1")]);
    assert_eq!(builder.routes.len(), 1);
    let builder = builder.routes(vec![
        test_get_route("/2", "route2"),
        test_get_route("/3", "route3"),
    ]);
    assert_eq!(builder.routes.len(), 3);
    assert_eq!(builder.routes[0].path, "/1");
    assert_eq!(builder.routes[1].path, "/2");
    assert_eq!(builder.routes[2].path, "/3");
}
#[test]
fn app_builder_extensions_store_and_update_typed_values() {
    // update_extension must mutate the value stored by with_extension
    // (the `String::new` default is only used when no value exists yet).
    let builder = app()
        .with_extension::<String>("haunted".into())
        .update_extension::<String, _, _>(String::new, |value| value.push_str(" harvest"));
    let value = builder
        .extension::<String>()
        .expect("string extension should be present");
    assert_eq!(value, "haunted harvest");
}
#[cfg(feature = "mail")]
#[tokio::test]
async fn app_builder_with_mail_delivery_queue_stores_queue_for_install() {
    // A concrete queue is wrapped in a trivial factory that returns it.
    let builder = app().with_mail_delivery_queue(MailTestNoopQueue);
    let factory = builder
        .mail_delivery_queue_factory
        .expect("with_mail_delivery_queue should store a factory on the builder");
    let state = AppState::for_test();
    let queue = factory(&state).expect("trivial factory should produce the queue");
    assert!(Arc::strong_count(&queue) >= 1);
    queue
        .enqueue(test_mail())
        .await
        .expect("noop queue should always succeed");
}
#[cfg(feature = "mail")]
#[test]
fn app_builder_with_mail_delivery_queue_factory_runs_with_app_state() {
    // The factory must be invoked with the live AppState (here: profile
    // "dev"), captured via a shared mutex for inspection.
    let observed_profile: Arc<std::sync::Mutex<Option<String>>> =
        Arc::new(std::sync::Mutex::new(None));
    let captured = Arc::clone(&observed_profile);
    let builder = app().with_mail_delivery_queue_factory(move |state| {
        *captured.lock().expect("lock") = Some(state.profile().to_owned());
        Ok::<_, crate::AutumnError>(MailTestNoopQueue)
    });
    let factory = builder
        .mail_delivery_queue_factory
        .expect("factory should be stored on the builder");
    let state = AppState::for_test().with_profile("dev");
    let _queue = factory(&state).expect("factory should succeed");
    assert_eq!(
        observed_profile.lock().expect("lock").as_deref(),
        Some("dev"),
        "factory must run with the live AppState"
    );
}
#[cfg(feature = "mail")]
#[test]
fn app_builder_with_mail_delivery_queue_factory_propagates_errors() {
    // Factory errors must surface to the caller, not be swallowed.
    let builder = app().with_mail_delivery_queue_factory(|_state| {
        Err::<MailTestNoopQueue, _>(crate::AutumnError::service_unavailable_msg("factory boom"))
    });
    let factory = builder
        .mail_delivery_queue_factory
        .expect("factory present");
    let state = AppState::for_test();
    match factory(&state) {
        Ok(_) => panic!("factory should have errored"),
        Err(err) => assert!(err.to_string().contains("factory boom")),
    }
}
#[tokio::test]
async fn startup_and_shutdown_hooks_run_in_expected_order() {
    // Startup hooks run first; shutdown hooks run in reverse registration
    // order (LIFO), so "stop-b" is recorded before "stop-a".
    let log = Arc::new(std::sync::Mutex::new(Vec::<&'static str>::new()));
    let start_log = Arc::clone(&log);
    let stop_log_a = Arc::clone(&log);
    let stop_log_b = Arc::clone(&log);
    let builder = app()
        .on_startup(move |_state| {
            let start_log = Arc::clone(&start_log);
            async move {
                start_log.lock().expect("events lock poisoned").push("start");
                Ok(())
            }
        })
        .on_shutdown(move || {
            let stop_log_a = Arc::clone(&stop_log_a);
            async move {
                stop_log_a.lock().expect("events lock poisoned").push("stop-a");
            }
        })
        .on_shutdown(move || {
            let stop_log_b = Arc::clone(&stop_log_b);
            async move {
                stop_log_b.lock().expect("events lock poisoned").push("stop-b");
            }
        });
    run_startup_hooks(&builder.startup_hooks, AppState::for_test())
        .await
        .expect("startup hooks should succeed");
    run_shutdown_hooks(&builder.shutdown_hooks).await;
    let recorded = log.lock().expect("events lock poisoned").clone();
    assert_eq!(recorded, vec!["start", "stop-b", "stop-a"]);
}
/// Job handler that ignores state and payload and resolves immediately with
/// `Ok(())`; used to exercise the job runtime without doing real work.
fn startup_noop_job_handler(
    _state: AppState,
    _payload: serde_json::Value,
) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'static>> {
    // `ready` avoids the async-block machinery for an already-resolved value.
    Box::pin(std::future::ready(Ok(())))
}
/// Startup hooks must be able to enqueue jobs: `initialize_job_runtime` is
/// called before `run_startup_hooks`, so the global job client is live when
/// hooks execute.
#[tokio::test]
async fn startup_hooks_can_enqueue_jobs_after_runtime_init() {
    // The job runtime is process-global; serialize tests that touch it.
    let _guard = crate::job::global_job_runtime_test_lock().lock().await;
    crate::job::clear_global_job_client();
    let builder = app()
        .jobs(vec![crate::job::JobInfo {
            name: "startup-seed".to_string(),
            max_attempts: 1,
            initial_backoff_ms: 1,
            handler: startup_noop_job_handler,
        }])
        .on_startup(|_state| async {
            // Enqueue from inside a startup hook — this only works if the
            // runtime was initialized before the hooks ran.
            crate::job::enqueue("startup-seed", serde_json::json!({ "kind": "warmup" })).await
        });
    let state = AppState::for_test().with_profile("dev");
    let shutdown = tokio_util::sync::CancellationToken::new();
    initialize_job_runtime(
        builder.jobs.clone(),
        &state,
        &shutdown,
        &crate::config::JobConfig::default(),
    )
    .expect("job runtime should initialize before startup hooks");
    run_startup_hooks(&builder.startup_hooks, state.clone())
        .await
        .expect("startup hook should be able to enqueue jobs");
    // Poll the registry until the enqueued job reports one success; time out
    // after a second so a regression fails fast instead of hanging the suite.
    tokio::time::timeout(std::time::Duration::from_secs(1), async {
        loop {
            let snapshot = state.job_registry().snapshot();
            let status = snapshot
                .get("startup-seed")
                .expect("job should be registered before startup hooks run");
            if status.total_successes == 1 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("startup-enqueued job should complete");
    shutdown.cancel();
    // Leave no global client behind for the next test holding the lock.
    crate::job::clear_global_job_client();
}
/// Requesting the redis job backend must fail fast during runtime init: with
/// the `redis` feature on, the missing URL is reported; with it off, the
/// unavailable feature itself is reported.
#[tokio::test]
async fn initialize_job_runtime_propagates_redis_init_errors() {
    // The job runtime is process-global; serialize tests that touch it.
    let _guard = crate::job::global_job_runtime_test_lock().lock().await;
    crate::job::clear_global_job_client();
    let state = AppState::for_test().with_profile("dev");
    let shutdown = tokio_util::sync::CancellationToken::new();
    // backend=redis without jobs.redis.url configured — invalid either way.
    let config = crate::config::JobConfig {
        backend: "redis".to_string(),
        ..Default::default()
    };
    let error = initialize_job_runtime(
        vec![crate::job::JobInfo {
            name: "startup-seed".to_string(),
            max_attempts: 1,
            initial_backoff_ms: 1,
            handler: startup_noop_job_handler,
        }],
        &state,
        &shutdown,
        &config,
    )
    .expect_err("redis init errors should abort startup");
    // The expected message depends on whether the crate was built with redis.
    #[cfg(feature = "redis")]
    assert!(
        error
            .to_string()
            .contains("jobs.backend=redis requires jobs.redis.url"),
        "unexpected error: {error}"
    );
    #[cfg(not(feature = "redis"))]
    assert!(
        error
            .to_string()
            .contains("jobs.backend=redis requested but redis feature is disabled"),
        "unexpected error: {error}"
    );
}
#[tokio::test]
async fn startup_hook_errors_propagate() {
    // A failing startup hook must abort run_startup_hooks with its error.
    let builder = app().on_startup(|_state| async {
        Err(crate::AutumnError::service_unavailable_msg(
            "startup ritual failed",
        ))
    });
    let err = run_startup_hooks(&builder.startup_hooks, AppState::for_test())
        .await
        .expect_err("startup hook should fail");
    assert!(format!("{err}").contains("startup ritual failed"));
}
#[tokio::test]
async fn build_router_mounts_user_routes() {
    // A user-registered GET route must be served with its handler body.
    let router = test_router(vec![test_get_route("/test", "test_handler")]);
    let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&bytes[..], b"ok");
}
#[tokio::test]
async fn build_router_mounts_health_check_at_default_path() {
    // The health endpoint is mounted at /health by default and reports "ok".
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let req = Request::builder()
        .uri("/health")
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(json["status"], "ok");
}
#[tokio::test]
async fn build_router_mounts_health_check_at_custom_path() {
let mut config = AutumnConfig::default();
config.health.path = "/healthz".to_owned();
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
#[cfg(feature = "db")]
replica_pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: true,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
job_registry: crate::actuator::JobRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
policy_registry: crate::authorization::PolicyRegistry::default(),
forbidden_response: crate::authorization::ForbiddenResponse::default(),
auth_session_key: "user_id".to_owned(),
shared_cache: None,
};
let router =
crate::router::build_router(vec![test_get_route("/dummy", "dummy")], &config, state);
let response = router
.oneshot(
Request::builder()
.uri("/healthz")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn build_router_adds_request_id_header() {
    // Every response carries an x-request-id header from the request-id layer.
    let router = test_router(vec![test_get_route("/test", "test")]);
    let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert!(resp.headers().contains_key("x-request-id"));
}
#[tokio::test]
async fn build_router_unknown_route_returns_404() {
    // A path with no matching route must fall through to 404.
    let router = test_router(vec![test_get_route("/exists", "exists")]);
    let req = Request::builder().uri("/nope").body(Body::empty()).unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn build_router_multiple_routes() {
    // Both registered routes must be independently reachable.
    let router = test_router(vec![test_get_route("/a", "a"), test_get_route("/b", "b")]);
    for path in ["/a", "/b"] {
        let req = Request::builder().uri(path).body(Body::empty()).unwrap();
        let resp = router.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}
#[tokio::test]
async fn build_router_post_route() {
let post_routes = vec![Route {
method: http::Method::POST,
path: "/submit",
handler: axum::routing::post(|| async { "posted" }),
name: "submit",
api_doc: crate::openapi::ApiDoc {
method: "POST",
path: "/submit",
operation_id: "submit",
success_status: 200,
..Default::default()
},
repository: None,
}];
let config = AutumnConfig::default();
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
#[cfg(feature = "db")]
replica_pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: true,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
job_registry: crate::actuator::JobRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
policy_registry: crate::authorization::PolicyRegistry::default(),
forbidden_response: crate::authorization::ForbiddenResponse::default(),
auth_session_key: "user_id".to_owned(),
shared_cache: None,
};
let router = crate::router::build_router(post_routes, &config, state);
let response = router
.oneshot(
Request::builder()
.method("POST")
.uri("/submit")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn build_router_merges_methods_on_same_path() {
let route_list = vec![
Route {
method: http::Method::GET,
path: "/admin",
handler: axum::routing::get(|| async { "list" }),
name: "admin_list",
api_doc: crate::openapi::ApiDoc {
method: "GET",
path: "/admin",
operation_id: "admin_list",
success_status: 200,
..Default::default()
},
repository: None,
},
Route {
method: http::Method::POST,
path: "/admin",
handler: axum::routing::post(|| async { "created" }),
name: "create",
api_doc: crate::openapi::ApiDoc {
method: "POST",
path: "/admin",
operation_id: "create",
success_status: 200,
..Default::default()
},
repository: None,
},
];
let config = AutumnConfig::default();
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
#[cfg(feature = "db")]
replica_pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: true,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
job_registry: crate::actuator::JobRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
policy_registry: crate::authorization::PolicyRegistry::default(),
forbidden_response: crate::authorization::ForbiddenResponse::default(),
auth_session_key: "user_id".to_owned(),
shared_cache: None,
};
let router = crate::router::build_router(route_list, &config, state);
let resp = router
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
assert_eq!(&body[..], b"list");
let resp = router
.oneshot(
Request::builder()
.method("POST")
.uri("/admin")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
assert_eq!(&body[..], b"created");
}
#[cfg(feature = "htmx")]
#[tokio::test]
async fn htmx_handler_returns_javascript_with_correct_headers() {
    // The bundled htmx asset must be served whole, as JavaScript, with an
    // immutable cache policy.
    let js_router = axum::Router::new().route(
        crate::htmx::HTMX_JS_PATH,
        axum::routing::get(crate::router::htmx_handler),
    );
    let req = Request::builder()
        .uri(crate::htmx::HTMX_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = js_router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let content_type = resp
        .headers()
        .get("content-type")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(
        content_type.contains("application/javascript"),
        "Expected application/javascript, got {content_type}"
    );
    let cache_control = resp
        .headers()
        .get("cache-control")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(
        cache_control.contains("immutable"),
        "Expected immutable cache, got {cache_control}"
    );
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    // Full asset length, and the first bytes should look like htmx source.
    assert_eq!(body.len(), crate::htmx::HTMX_JS.len());
    let start = std::str::from_utf8(&body[..50]).expect("htmx should be valid UTF-8");
    assert!(
        start.contains("htmx") || start.contains("function"),
        "Response doesn't look like htmx JavaScript: {start}"
    );
}
#[cfg(feature = "htmx")]
#[tokio::test]
async fn htmx_csrf_handler_returns_csp_compatible_javascript() {
    // The CSRF helper ships as external JS (no inline <script> tag), so it
    // remains usable under a strict script-src CSP.
    let csrf_router = axum::Router::new().route(
        crate::htmx::HTMX_CSRF_JS_PATH,
        axum::routing::get(crate::router::htmx_csrf_handler),
    );
    let req = Request::builder()
        .uri(crate::htmx::HTMX_CSRF_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = csrf_router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let content_type = resp
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok());
    assert_eq!(content_type, Some("application/javascript"));
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let js = std::str::from_utf8(&bytes).expect("csrf helper should be valid utf-8");
    assert!(js.contains("htmx:configRequest"));
    assert!(js.contains("X-CSRF-Token"));
    assert!(!js.contains("<script"));
}
#[cfg(feature = "htmx")]
#[tokio::test]
async fn build_router_serves_htmx_js() {
    // The assembled router wires the bundled htmx asset at its canonical path.
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let req = Request::builder()
        .uri(crate::htmx::HTMX_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let ct = resp
        .headers()
        .get("content-type")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(ct.contains("javascript"));
}
#[cfg(feature = "htmx")]
#[tokio::test]
async fn build_router_serves_htmx_csrf_js() {
    // Framework-served JS goes through the same security-header middleware
    // as regular pages.
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let req = Request::builder()
        .uri(crate::htmx::HTMX_CSRF_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let csp = resp
        .headers()
        .get("content-security-policy")
        .expect("framework JS should still receive security headers")
        .to_str()
        .unwrap();
    assert!(csp.contains("script-src 'self'"), "csp = {csp}");
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let js = std::str::from_utf8(&bytes).expect("csrf helper should be valid utf-8");
    assert!(js.contains("htmx:configRequest"));
    assert!(js.contains("X-CSRF-Token"));
}
#[tokio::test]
async fn build_router_serves_default_favicon_without_404() {
    // With no user favicon route, the framework answers 204 with an empty
    // body instead of a noisy 404 — and still applies security headers.
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let req = Request::builder()
        .uri(crate::router::DEFAULT_FAVICON_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::NO_CONTENT);
    assert!(
        resp.headers().contains_key("content-security-policy"),
        "framework fallback responses should still receive security headers"
    );
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    assert!(bytes.is_empty());
}
#[tokio::test]
async fn build_router_does_not_override_user_favicon_route() {
    // A user-registered favicon route wins over the framework 204 fallback.
    let router = test_router(vec![test_get_route(
        crate::router::DEFAULT_FAVICON_PATH,
        "favicon",
    )]);
    let req = Request::builder()
        .uri(crate::router::DEFAULT_FAVICON_PATH)
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&bytes[..], b"ok");
}
/// Static-first serving: when a generated manifest maps a path to a file in
/// the dist directory, unmatched requests are answered from disk and still
/// pass through the security-header middleware.
#[tokio::test]
async fn build_router_serves_static_files_for_unmatched_paths() {
    use std::collections::HashMap;
    // Build a fake `dist` tree with one pre-rendered page plus its manifest.
    let tmp = tempfile::tempdir().expect("tempdir");
    let dist = tmp.path().join("dist");
    std::fs::create_dir_all(dist.join("docs")).expect("mkdir");
    std::fs::write(dist.join("docs/index.html"), "<h1>Static Docs</h1>").expect("write");
    let manifest = crate::static_gen::StaticManifest {
        generated_at: "2026-03-27T00:00:00Z".to_owned(),
        autumn_version: "0.2.0".to_owned(),
        routes: HashMap::from([(
            "/docs".to_owned(),
            crate::static_gen::ManifestEntry {
                file: "docs/index.html".to_owned(),
                revalidate: None,
            },
        )]),
    };
    let json = serde_json::to_string(&manifest).expect("serialize");
    std::fs::write(dist.join("manifest.json"), json).expect("write manifest");
    let config = AutumnConfig::default();
    // Minimal in-memory AppState; kept inline because this test needs
    // `build_router_with_static`, which the shared helper does not call.
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/other", "other_page")],
        &config,
        state,
        Some(dist.as_path()),
    );
    // "/docs/" has no dynamic route; it should be served from the manifest.
    let response = router
        .oneshot(
            Request::builder()
                .uri("/docs/")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let csp = response
        .headers()
        .get("content-security-policy")
        .expect("static-first HTML should still receive security headers")
        .to_str()
        .unwrap();
    assert!(csp.contains("script-src 'self'"), "csp = {csp}");
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(std::str::from_utf8(&body).unwrap(), "<h1>Static Docs</h1>");
}
/// With AUTUMN_BUILD_STATIC=1, static rendering must work even though the
/// startup barrier has not been released (`startup_complete == false`):
/// build-mode requests bypass the barrier.
#[tokio::test]
async fn build_mode_static_rendering_bypasses_startup_barrier() {
    // temp_env scopes the env var to this async block only.
    temp_env::async_with_vars([("AUTUMN_BUILD_STATIC", Some("1"))], async {
        let config = AutumnConfig::default();
        // Explicitly mark startup as incomplete to prove the bypass.
        let state = AppState::for_test().with_startup_complete(false);
        let router = crate::router::build_router(
            vec![Route {
                method: http::Method::GET,
                path: "/about",
                handler: axum::routing::get(|| async { "About Page Content" }),
                name: "about",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/about",
                    operation_id: "about",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
            }],
            &config,
            state,
        );
        let tmp = tempfile::tempdir().unwrap();
        let dist = tmp.path().join("dist");
        let result = crate::static_gen::render_static_routes(
            router,
            &[crate::static_gen::StaticRouteMeta {
                path: "/about",
                name: "about",
                revalidate: None,
                params_fn: None,
            }],
            &dist,
        )
        .await;
        assert!(result.is_ok(), "build failed: {:?}", result.err());
        // The rendered page is written to <dist>/about/index.html.
        let html = std::fs::read_to_string(dist.join("about/index.html")).unwrap();
        assert_eq!(html, "About Page Content");
    })
    .await;
}
/// In dev-reload mode (AUTUMN_DEV_RELOAD=1 plus a reload-state file), HTML
/// responses get the live-reload client script injected into the page body.
#[tokio::test]
async fn build_router_injects_live_reload_script_when_enabled() {
    // The reload-state file is what the injected client polls for changes.
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");
    temp_env::async_with_vars(
        [
            ("AUTUMN_DEV_RELOAD", Some("1")),
            (
                "AUTUMN_DEV_RELOAD_STATE",
                Some(reload_file.path().to_str().expect("utf-8 path")),
            ),
        ],
        async {
            // A route that returns HTML — only HTML bodies are rewritten.
            let router = test_router(vec![Route {
                method: http::Method::GET,
                path: "/page",
                handler: axum::routing::get(|| async {
                    axum::response::Html("<html><body><main>ok</main></body></html>")
                }),
                name: "page",
                api_doc: crate::openapi::ApiDoc {
                    method: "GET",
                    path: "/page",
                    operation_id: "page",
                    success_status: 200,
                    ..Default::default()
                },
                repository: None,
            }]);
            let response = router
                .oneshot(Request::builder().uri("/page").body(Body::empty()).unwrap())
                .await
                .unwrap();
            let body = axum::body::to_bytes(response.into_body(), usize::MAX)
                .await
                .unwrap();
            let html = std::str::from_utf8(&body).expect("utf-8");
            // The injected script references the live-reload endpoint.
            assert!(html.contains("/__autumn/live-reload"));
        },
    )
    .await;
}
/// Dev-reload mode mounts /__autumn/live-reload.js, serving the polling
/// client as JavaScript.
#[tokio::test]
async fn build_router_mounts_dev_reload_script_endpoint_when_enabled() {
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");
    temp_env::async_with_vars(
        [
            ("AUTUMN_DEV_RELOAD", Some("1")),
            (
                "AUTUMN_DEV_RELOAD_STATE",
                Some(reload_file.path().to_str().expect("utf-8 path")),
            ),
        ],
        async {
            let router = test_router(vec![test_get_route("/dummy", "dummy")]);
            let response = router
                .oneshot(
                    Request::builder()
                        .uri("/__autumn/live-reload.js")
                        .body(Body::empty())
                        .unwrap(),
                )
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response
                    .headers()
                    .get("content-type")
                    .and_then(|v| v.to_str().ok()),
                Some("application/javascript; charset=utf-8")
            );
            let body = axum::body::to_bytes(response.into_body(), usize::MAX)
                .await
                .unwrap();
            let js = std::str::from_utf8(&body).expect("utf-8");
            // The client polls via fetch(); a missing call means it is broken.
            assert!(js.contains("fetch("), "js body: {js}");
        },
    )
    .await;
}
/// Dev-reload mode mounts /__autumn/live-reload, which echoes the reload
/// state file verbatim with caching disabled so clients always see the
/// latest version.
#[tokio::test]
async fn build_router_mounts_dev_reload_endpoint_when_enabled() {
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    // Non-default contents so the echo assertion is meaningful.
    std::fs::write(reload_file.path(), r#"{"version":7,"kind":"css"}"#).expect("write");
    temp_env::async_with_vars(
        [
            ("AUTUMN_DEV_RELOAD", Some("1")),
            (
                "AUTUMN_DEV_RELOAD_STATE",
                Some(reload_file.path().to_str().expect("utf-8 path")),
            ),
        ],
        async {
            let router = test_router(vec![test_get_route("/dummy", "dummy")]);
            let response = router
                .oneshot(
                    Request::builder()
                        .uri("/__autumn/live-reload")
                        .body(Body::empty())
                        .unwrap(),
                )
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            // Poll responses must never be cached.
            assert_eq!(
                response.headers().get("cache-control").unwrap(),
                "no-store, no-cache, must-revalidate"
            );
            let body = axum::body::to_bytes(response.into_body(), usize::MAX)
                .await
                .unwrap();
            assert_eq!(&body[..], br#"{"version":7,"kind":"css"}"#);
        },
    )
    .await;
}
/// In dev-reload mode, files under the project's static dir are served with
/// caching disabled so edits show up immediately in the browser.
#[tokio::test]
async fn build_router_disables_cache_for_static_assets_in_dev_reload_mode() {
    // Fake project layout with one static file.
    let project = tempfile::tempdir().expect("project dir");
    let static_dir = project.path().join("static");
    std::fs::create_dir_all(&static_dir).expect("mkdir");
    std::fs::write(static_dir.join("demo.txt"), "hello").expect("write static file");
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");
    temp_env::async_with_vars(
        [
            // Point asset discovery at the temp project dir.
            (
                "AUTUMN_MANIFEST_DIR",
                Some(project.path().to_str().expect("utf-8 path")),
            ),
            ("AUTUMN_DEV_RELOAD", Some("1")),
            (
                "AUTUMN_DEV_RELOAD_STATE",
                Some(reload_file.path().to_str().expect("utf-8 path")),
            ),
        ],
        async {
            let router = test_router(vec![test_get_route("/dummy", "dummy")]);
            let response = router
                .oneshot(
                    Request::builder()
                        .uri("/static/demo.txt")
                        .body(Body::empty())
                        .unwrap(),
                )
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.headers().get("cache-control").unwrap(),
                "no-store, no-cache, must-revalidate"
            );
        },
    )
    .await;
}
#[test]
fn app_builder_accepts_static_routes() {
    use crate::static_gen::StaticRouteMeta;
    // Registering one static route meta stores exactly one entry.
    let about = StaticRouteMeta {
        path: "/about",
        name: "about",
        revalidate: None,
        params_fn: None,
    };
    let builder = app().static_routes(vec![about]);
    assert_eq!(builder.static_metas.len(), 1);
}
#[test]
fn project_dir_defaults_to_subdir() {
    // With no env overrides, project_dir resolves to the given subdirectory.
    let mock_env = crate::config::MockEnv::new();
    assert_eq!(
        super::project_dir("dist", &mock_env),
        std::path::PathBuf::from("dist")
    );
}
/// Builds a router over `routes` using `config` and a minimal, fully
/// in-memory test `AppState`: no DB pools, detailed health enabled, probes
/// already ready, and default auth/policy settings.
///
/// Tests that call `build_router_with_static` directly cannot use this helper
/// and carry their own copy of the same state literal.
pub fn test_router_with_config(routes: Vec<Route>, config: &AutumnConfig) -> axum::Router {
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    crate::router::build_router(routes, config, state)
}
#[tokio::test]
async fn cors_wildcard_allows_any_origin() {
    // A "*" origin list advertises the wildcard to every requesting origin.
    let mut config = AutumnConfig::default();
    config.cors.allowed_origins = vec!["*".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &config);
    let req = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let allow_origin = resp.headers().get("access-control-allow-origin").unwrap();
    assert_eq!(allow_origin, "*");
}
#[tokio::test]
async fn cors_specific_origin_reflected() {
    // An explicitly allowed origin is reflected back verbatim.
    let mut config = AutumnConfig::default();
    config.cors.allowed_origins = vec!["https://example.com".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &config);
    let req = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let allow_origin = resp.headers().get("access-control-allow-origin").unwrap();
    assert_eq!(allow_origin, "https://example.com");
}
#[tokio::test]
async fn cors_disabled_when_no_origins() {
    // The default config allows no origins, so no CORS headers are emitted.
    let config = AutumnConfig::default();
    assert!(config.cors.allowed_origins.is_empty());
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &config);
    let req = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    assert!(resp.headers().get("access-control-allow-origin").is_none());
}
/// Preflight OPTIONS requests get a CORS response advertising the allowed
/// methods.
/// NOTE(review): the test name says 204 but the assertion expects 200 OK —
/// tower-http's CorsLayer answers preflights with 200 by default. Consider
/// renaming the test (or asserting NO_CONTENT if the layer is reconfigured).
#[tokio::test]
async fn cors_preflight_returns_204() {
    let mut config = AutumnConfig::default();
    config.cors.allowed_origins = vec!["https://example.com".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &config);
    let response = router
        .oneshot(
            Request::builder()
                .method("OPTIONS")
                .uri("/test")
                .header("Origin", "https://example.com")
                .header("Access-Control-Request-Method", "GET")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    assert!(
        response
            .headers()
            .contains_key("access-control-allow-methods")
    );
}
/// A dist directory without a manifest.json must be ignored: static-first
/// serving is skipped and dynamic routes keep working.
#[tokio::test]
async fn build_router_with_static_skips_without_manifest() {
    // Empty dist dir: exists, but no manifest inside.
    let tmp = tempfile::tempdir().expect("tempdir");
    let dist = tmp.path().join("dist");
    std::fs::create_dir_all(&dist).expect("mkdir");
    let config = AutumnConfig::default();
    // Minimal in-memory AppState; kept inline because this test needs
    // `build_router_with_static`, which the shared helper does not call.
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/test", "test")],
        &config,
        state,
        Some(dist.as_path()),
    );
    let response = router
        .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// Passing `None` for the dist directory disables static-first serving
/// entirely; dynamic routes are unaffected.
#[tokio::test]
async fn build_router_with_static_none_dist() {
    let config = AutumnConfig::default();
    // Minimal in-memory AppState; kept inline because this test needs
    // `build_router_with_static`, which the shared helper does not call.
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/test", "test")],
        &config,
        state,
        None,
    );
    let response = router
        .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
#[test]
fn format_route_lines_lists_user_routes() {
    // Every user route appears with its path, method, and handler name.
    let config = AutumnConfig::default();
    let routes = vec![
        test_get_route("/", "index"),
        test_get_route("/users/{id}", "get_user"),
    ];
    let listing = format_route_lines(&routes, &[], &config);
    for expected in ["-> index", "/ GET", "/users/{id}", "-> get_user"] {
        assert!(listing.contains(expected));
    }
}
#[test]
fn config_runtime_drift_format_route_lines_uses_actuator_prefix() {
    // Actuator routes are listed under the configured prefix, not a default.
    let mut config = AutumnConfig::default();
    config.actuator.prefix = "/ops".to_owned();
    let listing = format_route_lines(&[], &[], &config);
    assert!(listing.contains("-> health"));
    assert!(listing.contains("/ops/*"));
}
#[test]
fn format_task_lines_none_when_empty() {
    // No registered tasks -> no task section is produced.
    let rendered = format_task_lines(&[]);
    assert!(rendered.is_none());
}
#[test]
fn format_task_lines_fixed_delay() {
    // Fixed-delay schedules render as "(every <secs>s)".
    let cleanup = crate::task::TaskInfo {
        name: "cleanup".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_secs(300)),
        coordination: crate::task::TaskCoordination::Fleet,
        handler: |_| Box::pin(async { Ok(()) }),
    };
    let rendered = format_task_lines(&[cleanup]).unwrap();
    assert!(rendered.contains("cleanup (every 300s)"));
}
#[test]
fn format_task_lines_cron() {
    // Cron schedules render as "(cron <expression>)".
    let nightly = crate::task::TaskInfo {
        name: "nightly".into(),
        schedule: crate::task::Schedule::Cron {
            expression: "0 0 * * *".into(),
            timezone: None,
        },
        coordination: crate::task::TaskCoordination::Fleet,
        handler: |_| Box::pin(async { Ok(()) }),
    };
    let rendered = format_task_lines(&[nightly]).unwrap();
    assert!(rendered.contains("nightly (cron 0 0 * * *)"));
}
#[test]
fn format_middleware_list_default() {
    // Baseline stack: request-id, security headers, in-memory session, and
    // metrics; CORS and CSRF are off by default.
    let output = format_middleware_list(&AutumnConfig::default());
    for present in ["RequestId", "SecurityHeaders", "Session (in-memory)", "Metrics"] {
        assert!(output.contains(present));
    }
    for absent in ["CORS", "CSRF"] {
        assert!(!output.contains(absent));
    }
}
#[test]
fn format_middleware_list_with_cors_and_csrf() {
    // Enabling CORS origins and CSRF makes both show up in the listing.
    let mut config = AutumnConfig::default();
    config.cors = crate::config::CorsConfig {
        allowed_origins: vec!["https://example.com".into()],
        ..crate::config::CorsConfig::default()
    };
    config.security = crate::security::config::SecurityConfig {
        csrf: crate::security::config::CsrfConfig {
            enabled: true,
            ..crate::security::config::CsrfConfig::default()
        },
        ..crate::security::config::SecurityConfig::default()
    };
    let output = format_middleware_list(&config);
    assert!(output.contains("CORS"));
    assert!(output.contains("CSRF"));
}
#[test]
fn mask_database_url_with_password() {
    // The password is replaced with **** while the rest of the URL survives,
    // and the pool size is appended.
    let masked = mask_database_url("postgres://user:secret@localhost:5432/mydb", 10);
    assert!(masked.contains("postgres://user:****@localhost:5432/mydb"));
    assert!(masked.contains("****"));
    assert!(masked.contains("pool_size=10"));
    assert!(!masked.contains("secret"));
}
#[test]
fn mask_database_url_without_password() {
    // A URL without credentials is passed through unmasked.
    let masked = mask_database_url("postgres://localhost/mydb", 5);
    assert!(masked.contains("postgres://localhost/mydb"));
    assert!(masked.contains("pool_size=5"));
    assert!(!masked.contains("****"));
}
#[test]
fn mask_database_url_edge_cases() {
    // Percent-encoded passwords are masked entirely, not just their raw form.
    let encoded = mask_database_url("postgres://user:p%40ssw%3Ard%21@localhost:5432/mydb", 10);
    assert!(encoded.contains("****"));
    assert!(encoded.contains("postgres://user:****@localhost:5432/mydb"));
    assert!(!encoded.contains("p%40ssw%3Ard%21"));
    // A password with an empty username is still masked.
    let no_user = mask_database_url("postgres://:secret@localhost:5432/mydb", 10);
    assert!(no_user.contains("****"));
    assert!(no_user.contains("postgres://:****@localhost:5432/mydb"));
    assert!(!no_user.contains("secret"));
}
#[test]
fn mask_database_url_invalid_url_fallback() {
    // Unparseable input falls back to masking the whole value rather than
    // risking a credential leak in logs.
    let rendered = mask_database_url("this is completely invalid as a URL with supersecret", 10);
    assert!(rendered.contains("****"));
    assert!(rendered.contains("pool_size=10"));
    assert!(!rendered.contains("supersecret"));
}
#[test]
fn format_config_summary_defaults() {
    // The default config summary mentions every section with its default value.
    let summary = format_config_summary(&AutumnConfig::default());
    for expected in [
        "profile: none",
        "server: 127.0.0.1:3000",
        "database: not configured",
        "log_level:",
        "telemetry: disabled",
        "health: /health",
    ] {
        assert!(summary.contains(expected));
    }
}
#[test]
fn format_config_summary_with_db() {
    // A configured database URL is masked and annotated with the pool size;
    // the raw password must never appear.
    let mut config = AutumnConfig::default();
    config.database = crate::config::DatabaseConfig {
        url: Some("postgres://user:pass@host/db".into()),
        pool_size: 20,
        ..crate::config::DatabaseConfig::default()
    };
    let summary = format_config_summary(&config);
    assert!(summary.contains("user:****@host/db"));
    assert!(summary.contains("pool_size=20"));
    assert!(!summary.contains("pass"));
}
#[test]
fn format_config_summary_with_profile() {
    // An explicit profile name shows up verbatim in the summary.
    let mut config = AutumnConfig::default();
    config.profile = Some("prod".into());
    let summary = format_config_summary(&config);
    assert!(summary.contains("profile: prod"));
}
#[test]
fn format_config_summary_with_telemetry() {
    // An enabled OTLP exporter is reported as "<protocol> -> <endpoint>".
    let mut config = AutumnConfig::default();
    config.telemetry = crate::config::TelemetryConfig {
        enabled: true,
        service_name: "orders-api".into(),
        otlp_endpoint: Some("http://otel-collector:4317".into()),
        ..crate::config::TelemetryConfig::default()
    };
    let summary = format_config_summary(&config);
    assert!(summary.contains("telemetry: Grpc -> http://otel-collector:4317"));
}
#[test]
fn log_startup_transparency_runs_without_panic() {
    // Smoke test: rendering the startup summary with one route and one
    // fixed-delay task must not panic.
    let config = AutumnConfig::default();
    let tasks = vec![crate::task::TaskInfo {
        name: "cleanup".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_secs(60)),
        coordination: crate::task::TaskCoordination::Fleet,
        handler: |_| Box::pin(async { Ok(()) }),
    }];
    let routes = vec![test_get_route("/", "index")];
    log_startup_transparency(&routes, &tasks, &[], &config);
}
#[test]
fn log_startup_transparency_no_tasks() {
    // The summary must also render cleanly with an empty task list.
    let config = AutumnConfig::default();
    let routes = vec![test_get_route("/health", "check")];
    log_startup_transparency(&routes, &[], &[], &config);
}
// ws-only: a successful scheduled run publishes `started` then `success`
// JSON events on the `sys:tasks` channel, including the task name and a
// `duration_ms` field on success.
#[cfg(feature = "ws")]
#[tokio::test]
async fn start_task_scheduler_broadcasts_events() {
    // AppState built field-by-field (instead of a helper) so every piece the
    // scheduler touches is explicit in this test.
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        channels: crate::channels::Channels::new(32),
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    // Subscribe before the scheduler starts so no event can be missed.
    let mut rx = state.channels().subscribe("sys:tasks");
    let task = crate::task::TaskInfo {
        name: "test_broadcaster".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_millis(1)),
        coordination: crate::task::TaskCoordination::Fleet,
        handler: |_| Box::pin(async { Ok(()) }),
    };
    let state_clone = state.clone();
    tokio::spawn(async move {
        super::start_task_scheduler(
            vec![task],
            &state_clone,
            &tokio_util::sync::CancellationToken::new(),
        );
    });
    // First event: the run was started.
    let msg1 = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
        .await
        .expect("timeout waiting for start event")
        .expect("channel closed");
    let json1: serde_json::Value = serde_json::from_str(msg1.as_str()).unwrap();
    assert_eq!(json1["event"], "started");
    assert_eq!(json1["task"], "test_broadcaster");
    // Second event: the run succeeded and reports its duration.
    let msg2 = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
        .await
        .expect("timeout waiting for success event")
        .expect("channel closed");
    let json2: serde_json::Value = serde_json::from_str(msg2.as_str()).unwrap();
    assert_eq!(json2["event"], "success");
    assert_eq!(json2["task"], "test_broadcaster");
    assert!(json2.get("duration_ms").is_some());
}
// ws-only: a failing scheduled run publishes a `failure` event on the
// `sys:tasks` channel carrying the task name and the handler's error message.
#[cfg(feature = "ws")]
#[tokio::test]
async fn start_task_scheduler_broadcasts_failure_events() {
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        #[cfg(feature = "db")]
        replica_pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        job_registry: crate::actuator::JobRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        channels: crate::channels::Channels::new(32),
        shutdown: tokio_util::sync::CancellationToken::new(),
        policy_registry: crate::authorization::PolicyRegistry::default(),
        forbidden_response: crate::authorization::ForbiddenResponse::default(),
        auth_session_key: "user_id".to_owned(),
        shared_cache: None,
    };
    // Subscribe before the scheduler starts so no event can be missed.
    let mut rx = state.channels().subscribe("sys:tasks");
    let task = crate::task::TaskInfo {
        name: "test_failing_task".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_millis(1)),
        coordination: crate::task::TaskCoordination::Fleet,
        handler: |_| {
            Box::pin(async { Err(crate::AutumnError::bad_request_msg("forced error")) })
        },
    };
    let state_clone = state.clone();
    tokio::spawn(async move {
        super::start_task_scheduler(
            vec![task],
            &state_clone,
            &tokio_util::sync::CancellationToken::new(),
        );
    });
    // Fix: the first receive was previously unbounded, so a scheduler that
    // never emits would hang the test forever instead of failing it. Bound it
    // with the same 1s timeout used everywhere else in this file.
    let _started = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
        .await
        .expect("timeout waiting for start event")
        .expect("channel closed");
    // Second event: the run failed, carrying the handler's error text.
    let msg2 = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
        .await
        .expect("timeout waiting for failure event")
        .expect("channel closed");
    let json2: serde_json::Value = serde_json::from_str(msg2.as_str()).unwrap();
    assert_eq!(json2["event"], "failure");
    assert_eq!(json2["task"], "test_failing_task");
    assert_eq!(json2["error"], "forced error");
}
#[tokio::test]
async fn execute_task_result_ok_returns_duration() {
    // A successful handler yields Ok carrying the measured duration.
    let state = AppState::for_test();
    let handler: crate::task::TaskHandler = |_| Box::pin(async { Ok(()) });
    let started = std::time::Instant::now();
    let outcome =
        super::execute_task_result(&state, handler, started, "test_task", "fixed_delay").await;
    assert!(outcome.is_ok(), "expected Ok from successful handler");
    let duration_ms = outcome.unwrap();
    assert!(duration_ms < u64::MAX);
}
#[tokio::test]
async fn execute_task_result_err_returns_duration_and_message() {
    // A failing handler yields Err carrying the duration and the error text.
    let state = AppState::for_test();
    let handler: crate::task::TaskHandler =
        |_| Box::pin(async { Err(crate::AutumnError::bad_request_msg("test error")) });
    let started = std::time::Instant::now();
    let outcome =
        super::execute_task_result(&state, handler, started, "test_task", "fixed_delay").await;
    assert!(outcome.is_err(), "expected Err from failing handler");
    let (duration_ms, message) = outcome.unwrap_err();
    assert!(duration_ms < u64::MAX);
    assert!(message.contains("test error"));
}
/// Test helper: a scheduled-task handler that panics synchronously, before it
/// ever returns a future, to exercise the scheduler's panic-catching path.
fn instantly_panicking_scheduled_handler(
    _state: AppState,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send>> {
    panic!("panic before scheduled future")
}
#[tokio::test]
async fn execute_task_result_reports_immediate_handler_panics() {
    // A handler that panics before returning its future must be converted
    // into an Err with a descriptive message rather than unwinding.
    let state = AppState::for_test();
    let outcome = super::execute_task_result(
        &state,
        instantly_panicking_scheduled_handler,
        std::time::Instant::now(),
        "test_task",
        "fixed_delay",
    )
    .await;
    let (duration_ms, message) = outcome.expect_err("expected Err from panicking handler");
    assert!(duration_ms < u64::MAX);
    assert!(message.contains("scheduled task handler panicked: panic before scheduled future"));
}
#[tokio::test]
async fn execute_fixed_delay_task_does_not_timeout_in_process_runs() {
    // With the in-process coordinator, a handler that outlives the (tiny)
    // lease duration passed below must still complete and be recorded as ok —
    // TTL enforcement is a distributed-backend concern (see the
    // records_distributed_lease_ttl_timeout test).
    let state = AppState::for_test();
    state.task_registry.register_scheduled(
        "slow_task",
        "every 1s",
        crate::task::TaskCoordination::Fleet,
        "in_process",
        "replica-a",
    );
    // Sleeps 30ms, well past the 10ms duration given to execute_fixed_delay_task.
    let handler: crate::task::TaskHandler = |_| {
        Box::pin(async {
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            Ok(())
        })
    };
    let coordinator = std::sync::Arc::new(
        crate::scheduler::InProcessSchedulerCoordinator::new("replica-a"),
    );
    super::execute_fixed_delay_task(
        "slow_task".to_owned(),
        state.clone(),
        handler,
        std::time::Duration::from_secs(1),
        crate::task::TaskCoordination::Fleet,
        coordinator,
        std::time::Duration::from_millis(10),
    )
    .await;
    // The run is recorded as a clean success despite exceeding the duration.
    let snapshot = state.task_registry.snapshot();
    let status = &snapshot["slow_task"];
    assert_eq!(status.status, "idle");
    assert_eq!(status.last_result.as_deref(), Some("ok"));
    assert_eq!(status.total_runs, 1);
    assert_eq!(status.total_failures, 0);
    assert!(status.last_error.is_none());
}
/// Invocation counter for `counted_scheduled_handler`; reset by each test
/// that uses it (tests share this process-wide static).
static SKIPPED_LEASE_HANDLER_CALLS: AtomicUsize = AtomicUsize::new(0);
/// Coordinator stub that never grants a lease, simulating a tick already
/// claimed by another replica.
struct DenyingSchedulerCoordinator;
impl crate::scheduler::SchedulerCoordinator for DenyingSchedulerCoordinator {
    fn backend(&self) -> &'static str {
        "postgres"
    }
    fn replica_id(&self) -> &'static str {
        "replica-a"
    }
    // Always answers `Ok(None)`: no lease for any task/tick combination.
    fn try_acquire<'a>(
        &'a self,
        _task_name: &'a str,
        _tick_key: &'a str,
        _coordination: crate::task::TaskCoordination,
    ) -> crate::scheduler::SchedulerFuture<
        'a,
        crate::AutumnResult<Option<crate::scheduler::SchedulerLease>>,
    > {
        Box::pin(async { Ok(None) })
    }
}
/// Coordinator stub that always grants the lease. It records every tick key
/// it is asked about and, when `release_count` is set, hands out tracked
/// leases so tests can observe lease release.
struct GrantingSchedulerCoordinator {
    // Backend name reported to the scheduler (e.g. "postgres").
    backend: &'static str,
    // Tick keys observed by try_acquire, in call order.
    tick_keys: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
    // When Some, granted leases are `tracked` and bump this counter on release.
    release_count: Option<std::sync::Arc<AtomicUsize>>,
}
impl crate::scheduler::SchedulerCoordinator for GrantingSchedulerCoordinator {
    fn backend(&self) -> &'static str {
        self.backend
    }
    fn replica_id(&self) -> &'static str {
        "replica-a"
    }
    fn try_acquire<'a>(
        &'a self,
        _task_name: &'a str,
        tick_key: &'a str,
        _coordination: crate::task::TaskCoordination,
    ) -> crate::scheduler::SchedulerFuture<
        'a,
        crate::AutumnResult<Option<crate::scheduler::SchedulerLease>>,
    > {
        Box::pin(async move {
            // Remember the key so tests can assert how ticks were identified.
            self.tick_keys.lock().unwrap().push(tick_key.to_owned());
            // Local lease when releases aren't tracked; tracked lease otherwise.
            let lease = self.release_count.as_ref().map_or_else(
                || crate::scheduler::SchedulerLease::local(self.backend, "replica-a"),
                |release_count| {
                    crate::scheduler::SchedulerLease::tracked(
                        self.backend,
                        "replica-a",
                        std::sync::Arc::clone(release_count),
                    )
                },
            );
            Ok(Some(lease))
        })
    }
}
/// Test helper: a handler that succeeds after bumping
/// `SKIPPED_LEASE_HANDLER_CALLS`, letting tests assert whether it ran at all.
fn counted_scheduled_handler(
    _state: AppState,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send>> {
    Box::pin(async {
        SKIPPED_LEASE_HANDLER_CALLS.fetch_add(1, Ordering::SeqCst);
        Ok(())
    })
}
#[tokio::test]
async fn execute_fixed_delay_task_skips_handler_when_lease_is_not_acquired() {
    // Reset the shared counter: the static is reused across tests in this file.
    SKIPPED_LEASE_HANDLER_CALLS.store(0, Ordering::SeqCst);
    let state = AppState::for_test();
    state.task_registry.register_scheduled(
        "claimed_elsewhere",
        "every 1s",
        crate::task::TaskCoordination::Fleet,
        "postgres",
        "replica-a",
    );
    // The denying coordinator simulates another replica holding the lease.
    let coordinator = std::sync::Arc::new(DenyingSchedulerCoordinator);
    super::execute_fixed_delay_task(
        "claimed_elsewhere".to_owned(),
        state.clone(),
        counted_scheduled_handler,
        std::time::Duration::from_secs(1),
        crate::task::TaskCoordination::Fleet,
        coordinator,
        std::time::Duration::from_secs(1),
    )
    .await;
    let snapshot = state.task_registry.snapshot();
    let status = &snapshot["claimed_elsewhere"];
    // The handler must not have run, and the registry records no run,
    // leader, or tick for the skipped execution.
    assert_eq!(SKIPPED_LEASE_HANDLER_CALLS.load(Ordering::SeqCst), 0);
    assert_eq!(status.total_runs, 0);
    assert!(status.current_leader.is_none());
    assert!(status.last_tick.is_none());
}
#[tokio::test]
async fn execute_fixed_delay_task_records_distributed_lease_ttl_timeout() {
    // With a distributed backend ("postgres"), a handler that outlives the
    // lease duration is recorded as a failed run whose error mentions the
    // lease TTL.
    let state = AppState::for_test();
    state.task_registry.register_scheduled(
        "slow_distributed_task",
        "every 1s",
        crate::task::TaskCoordination::Fleet,
        "postgres",
        "replica-a",
    );
    // Sleeps far longer (5s) than the 10ms duration passed below.
    let handler: crate::task::TaskHandler = |_| {
        Box::pin(async {
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            Ok(())
        })
    };
    let coordinator = std::sync::Arc::new(GrantingSchedulerCoordinator {
        backend: "postgres",
        tick_keys: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
        release_count: None,
    });
    super::execute_fixed_delay_task(
        "slow_distributed_task".to_owned(),
        state.clone(),
        handler,
        std::time::Duration::from_secs(1),
        crate::task::TaskCoordination::Fleet,
        coordinator,
        std::time::Duration::from_millis(10),
    )
    .await;
    let snapshot = state.task_registry.snapshot();
    let status = &snapshot["slow_distributed_task"];
    // The run counts as a failure and the error explains the TTL expiry.
    assert_eq!(status.status, "idle");
    assert_eq!(status.last_result.as_deref(), Some("failed"));
    assert_eq!(status.total_runs, 1);
    assert_eq!(status.total_failures, 1);
    assert!(
        status
            .last_error
            .as_deref()
            .is_some_and(|error| error.contains("lease TTL"))
    );
}
#[tokio::test]
async fn execute_cron_task_uses_scheduled_occurrence_for_tick_key() {
    // The coordination tick key must embed the *scheduled* occurrence time
    // (the unix timestamp passed in), not the wall-clock execution time, so
    // every replica derives the same key for the same tick.
    let state = AppState::for_test();
    state.task_registry.register_scheduled(
        "cron_review_task",
        "cron */10 * * * * *",
        crate::task::TaskCoordination::Fleet,
        "postgres",
        "replica-a",
    );
    let tick_keys = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let coordinator = std::sync::Arc::new(GrantingSchedulerCoordinator {
        backend: "postgres",
        tick_keys: std::sync::Arc::clone(&tick_keys),
        release_count: None,
    });
    let handler: crate::task::TaskHandler = |_| Box::pin(async { Ok(()) });
    let scheduled_unix_secs = 1_700_000_000;
    super::execute_cron_task(
        "cron_review_task".to_owned(),
        state.clone(),
        handler,
        crate::task::TaskCoordination::Fleet,
        coordinator,
        std::time::Duration::from_secs(30),
        scheduled_unix_secs,
    )
    .await;
    // Exactly one tick, keyed as "<task name>:<scheduled unix seconds>".
    assert_eq!(
        tick_keys.lock().unwrap().as_slice(),
        ["cron_review_task:1700000000"]
    );
}
#[tokio::test]
async fn execute_fixed_delay_task_releases_lease_when_handler_panics() {
    // Even when the handler panics mid-run, the acquired lease must be
    // released exactly once and the run recorded as a failure.
    let state = AppState::for_test();
    state.task_registry.register_scheduled(
        "panic_task",
        "every 1s",
        crate::task::TaskCoordination::Fleet,
        "postgres",
        "replica-a",
    );
    // Tracked lease: every release bumps this counter.
    let release_count = std::sync::Arc::new(AtomicUsize::new(0));
    let coordinator = std::sync::Arc::new(GrantingSchedulerCoordinator {
        backend: "postgres",
        tick_keys: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
        release_count: Some(std::sync::Arc::clone(&release_count)),
    });
    let handler: crate::task::TaskHandler = |_| {
        Box::pin(async {
            panic!("forced scheduled panic");
            // Unreachable — the panic above always fires; the Ok(()) gives the
            // async block a concrete success value for its Result type.
            #[allow(unreachable_code)]
            Ok(())
        })
    };
    super::execute_fixed_delay_task(
        "panic_task".to_owned(),
        state.clone(),
        handler,
        std::time::Duration::from_secs(1),
        crate::task::TaskCoordination::Fleet,
        coordinator,
        std::time::Duration::from_secs(30),
    )
    .await;
    let snapshot = state.task_registry.snapshot();
    let status = &snapshot["panic_task"];
    // Exactly one lease release despite the panic.
    assert_eq!(release_count.load(Ordering::SeqCst), 1);
    assert_eq!(status.status, "idle");
    assert_eq!(status.last_result.as_deref(), Some("failed"));
    assert_eq!(status.total_runs, 1);
    assert_eq!(status.total_failures, 1);
    assert!(
        status
            .last_error
            .as_deref()
            .is_some_and(|error| error.contains("scheduled task handler panicked"))
    );
}
#[test]
fn next_cron_occurrence_skips_overdue_slots() {
    use chrono::TimeZone as _;
    // Schedule: every minute at second 0.
    let cron: croner::Cron = "0 * * * * *".parse().expect("cron expression should parse");
    let at = |h, m, s| chrono_tz::UTC.with_ymd_and_hms(2026, 5, 5, h, m, s).unwrap();
    // The cursor is half an hour stale; the next occurrence must be computed
    // relative to `now` instead of replaying every missed minute.
    let stale_cursor = at(12, 0, 0);
    let now = at(12, 30, 5);
    let next = super::next_cron_occurrence_after(&cron, &stale_cursor, &now)
        .expect("next cron occurrence should resolve");
    assert_eq!(next, at(12, 31, 0));
}
#[test]
fn cron_occurrence_is_overdue_after_later_slot_passed() {
    use chrono::TimeZone as _;
    // Schedule: every minute at second 0.
    let cron: croner::Cron = "0 * * * * *".parse().expect("cron expression should parse");
    let at = |h, m, s| chrono_tz::UTC.with_ymd_and_hms(2026, 5, 5, h, m, s).unwrap();
    let scheduled_at = at(12, 1, 0);
    // A few seconds late is still the current slot: not overdue.
    let slightly_late = super::cron_occurrence_is_overdue(&cron, &scheduled_at, &at(12, 1, 5))
        .expect("overdue check should resolve");
    assert!(!slightly_late);
    // Once a later slot has passed, the occurrence is overdue.
    let much_later = super::cron_occurrence_is_overdue(&cron, &scheduled_at, &at(12, 30, 5))
        .expect("overdue check should resolve");
    assert!(much_later);
}
// Tests for the optional blob-storage subsystem: preflight provisioning,
// bootstrap installation onto AppState, and custom-store registration.
#[cfg(feature = "storage")]
mod storage_preflight {
    use super::super::{StorageBootstrap, preflight_storage};
    use crate::AppState;
    use crate::config::AutumnConfig;
    use crate::storage::{BlobStoreState, StorageBackend, StorageConfig, StorageLocalConfig};
    /// Builds an AutumnConfig (dev profile) around the given storage section.
    fn config_with_storage(storage: StorageConfig) -> AutumnConfig {
        AutumnConfig {
            profile: Some("dev".into()),
            storage,
            ..AutumnConfig::default()
        }
    }
    #[test]
    fn preflight_returns_none_when_disabled() {
        // A disabled backend yields no bootstrap at all.
        let cfg = config_with_storage(StorageConfig {
            backend: StorageBackend::Disabled,
            ..StorageConfig::default()
        });
        assert!(preflight_storage(&cfg).is_none());
    }
    #[test]
    fn preflight_provisions_local_backend_against_tempdir() {
        // The local backend provisions against an arbitrary directory and
        // exposes a serving route for the stored files.
        let dir = tempfile::tempdir().unwrap();
        let cfg = config_with_storage(StorageConfig {
            backend: StorageBackend::Local,
            local: StorageLocalConfig {
                root: dir.path().to_path_buf(),
                ..StorageLocalConfig::default()
            },
            ..StorageConfig::default()
        });
        let bootstrap = preflight_storage(&cfg).expect("local backend should provision");
        assert_eq!(bootstrap.store.provider_id(), "default");
        assert!(bootstrap.serving.is_some(), "local backend mounts a route");
    }
    #[tokio::test]
    async fn install_registers_blob_store_on_state() {
        // Installing the bootstrap registers a BlobStoreState extension and
        // returns the serving configuration.
        let dir = tempfile::tempdir().unwrap();
        let cfg = config_with_storage(StorageConfig {
            backend: StorageBackend::Local,
            local: StorageLocalConfig {
                root: dir.path().to_path_buf(),
                ..StorageLocalConfig::default()
            },
            ..StorageConfig::default()
        });
        let bootstrap: StorageBootstrap = preflight_storage(&cfg).unwrap();
        let state = AppState::for_test();
        assert!(state.extension::<BlobStoreState>().is_none());
        let serving = bootstrap.install(&state);
        assert!(serving.is_some());
        assert!(state.extension::<BlobStoreState>().is_some());
    }
    #[test]
    fn with_blob_store_stores_custom_store() {
        use crate::storage::{
            Blob, BlobFuture, BlobMeta, BlobStore, BlobStoreError, ByteStream,
        };
        use bytes::Bytes;
        use std::time::Duration;
        // Stub store whose every operation answers Unsupported; only its
        // presence on the builder matters for this test.
        struct FakeStore;
        impl BlobStore for FakeStore {
            fn provider_id(&self) -> &'static str {
                "fake"
            }
            fn put<'a>(&'a self, _k: &'a str, _ct: &'a str, _b: Bytes) -> BlobFuture<'a, Blob> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn put_stream<'a>(
                &'a self,
                _k: &'a str,
                _ct: &'a str,
                _d: ByteStream<'a>,
            ) -> BlobFuture<'a, Blob> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn get<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, Bytes> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn delete<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, ()> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn head<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, Option<BlobMeta>> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn presigned_url<'a>(
                &'a self,
                _k: &'a str,
                _e: Duration,
            ) -> BlobFuture<'a, String> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
        }
        let builder = crate::app().with_blob_store(FakeStore);
        assert!(builder.blob_store.is_some());
    }
    #[tokio::test]
    async fn with_blob_store_is_installed_on_state() {
        use crate::storage::{
            Blob, BlobFuture, BlobMeta, BlobStore, BlobStoreError, ByteStream,
        };
        use bytes::Bytes;
        use std::time::Duration;
        // Duplicate of the stub above with a distinct provider id; function-
        // local items cannot be shared between test fns without hoisting.
        struct FakeStore;
        impl BlobStore for FakeStore {
            fn provider_id(&self) -> &'static str {
                "fake-installed"
            }
            fn put<'a>(&'a self, _k: &'a str, _ct: &'a str, _b: Bytes) -> BlobFuture<'a, Blob> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn put_stream<'a>(
                &'a self,
                _k: &'a str,
                _ct: &'a str,
                _d: ByteStream<'a>,
            ) -> BlobFuture<'a, Blob> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn get<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, Bytes> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn delete<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, ()> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn head<'a>(&'a self, _k: &'a str) -> BlobFuture<'a, Option<BlobMeta>> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
            fn presigned_url<'a>(
                &'a self,
                _k: &'a str,
                _e: Duration,
            ) -> BlobFuture<'a, String> {
                Box::pin(async { Err(BlobStoreError::Unsupported("fake".into())) })
            }
        }
        let builder = crate::app().with_blob_store(FakeStore);
        // Wrap the custom store in a bootstrap with no serving route.
        let bootstrap = builder.blob_store.map(|store| StorageBootstrap {
            store,
            serving: None,
        });
        let state = AppState::for_test();
        assert!(state.extension::<BlobStoreState>().is_none());
        if let Some(b) = bootstrap {
            b.install(&state);
        }
        let installed = state
            .extension::<BlobStoreState>()
            .expect("store should be installed");
        assert_eq!(installed.store().provider_id(), "fake-installed");
    }
}
/// Minimal plugin for the route-attribution tests: registers exactly one
/// route under the given plugin name.
struct TestPlugin {
    name: &'static str,
    route: Route,
}
impl crate::plugin::Plugin for TestPlugin {
    fn name(&self) -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed(self.name)
    }
    fn build(self, app: AppBuilder) -> AppBuilder {
        app.routes(vec![self.route])
    }
}
#[test]
fn routes_registered_before_plugin_are_user_sourced() {
    // Routes added directly on the builder are attributed to the user.
    let builder = app().routes(vec![test_get_route("/home", "home")]);
    assert_eq!(
        builder.route_sources,
        vec![crate::route_listing::RouteSource::User]
    );
}
#[test]
fn routes_registered_inside_plugin_are_plugin_sourced() {
    // Routes a plugin registers during its build are attributed to that plugin.
    let builder = app().plugin(TestPlugin {
        name: "my-plugin",
        route: test_get_route("/plugin-page", "plugin_page"),
    });
    assert_eq!(
        builder.route_sources,
        vec![crate::route_listing::RouteSource::Plugin("my-plugin".to_owned())]
    );
}
#[test]
fn routes_registered_after_plugin_revert_to_user_sourced() {
    // Once a plugin's build finishes, later routes belong to the user again.
    let plugin = TestPlugin {
        name: "my-plugin",
        route: test_get_route("/plugin-page", "plugin_page"),
    };
    let builder = app()
        .plugin(plugin)
        .routes(vec![test_get_route("/home", "home")]);
    assert_eq!(
        builder.route_sources,
        vec![
            crate::route_listing::RouteSource::Plugin("my-plugin".to_owned()),
            crate::route_listing::RouteSource::User,
        ]
    );
}
/// Plugin that nests another plugin during its build and then registers a
/// route of its own — exercises restoring the "current plugin" attribution
/// after a nested plugin finishes.
struct OuterPlugin;
impl crate::plugin::Plugin for OuterPlugin {
    fn name(&self) -> std::borrow::Cow<'static, str> {
        "outer".into()
    }
    fn build(self, app: AppBuilder) -> AppBuilder {
        let inner = TestPlugin {
            name: "inner",
            route: test_get_route("/inner", "inner"),
        };
        app.plugin(inner)
            .routes(vec![test_get_route("/outer-after", "outer_after")])
    }
}
#[test]
fn outer_plugin_source_restored_after_nested_plugin() {
    // After the nested plugin returns, attribution must revert to the outer
    // plugin — not to the user.
    let builder = app().plugin(OuterPlugin);
    let sources = &builder.route_sources;
    assert_eq!(sources.len(), 2);
    assert_eq!(
        sources[0],
        crate::route_listing::RouteSource::Plugin("inner".to_owned()),
        "first route should be attributed to inner plugin"
    );
    assert_eq!(
        sources[1],
        crate::route_listing::RouteSource::Plugin("outer".to_owned()),
        "second route should be re-attributed to outer plugin after nested build"
    );
}
}