use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tracing::Instrument as _;
use crate::config::{AutumnConfig, ConfigLoader};
#[cfg(feature = "db")]
use crate::db::DatabasePoolProvider;
use crate::error_pages::{ErrorPageRenderer, SharedRenderer};
use crate::middleware::exception_filter::ExceptionFilter;
#[cfg(feature = "db")]
use crate::migrate;
use crate::route::Route;
use crate::state::AppState;
/// Creates an empty [`AppBuilder`].
///
/// Every collection starts empty and every optional component starts as
/// `None`; the builder methods fill them in.
#[must_use]
pub fn app() -> AppBuilder {
    AppBuilder {
        // Routing and scheduled work.
        routes: Vec::new(),
        tasks: Vec::new(),
        static_metas: Vec::new(),
        scoped_groups: Vec::new(),
        merge_routers: Vec::new(),
        nest_routers: Vec::new(),
        // Middleware and error handling.
        exception_filters: Vec::new(),
        custom_layers: Vec::new(),
        error_page_renderer: None,
        // Lifecycle hooks.
        startup_hooks: Vec::new(),
        shutdown_hooks: Vec::new(),
        // Build-time bookkeeping.
        extensions: HashMap::new(),
        registered_plugins: HashSet::new(),
        // Replaceable subsystems.
        #[cfg(feature = "db")]
        migrations: Vec::new(),
        config_loader_factory: None,
        #[cfg(feature = "db")]
        pool_provider_factory: None,
        telemetry_provider: None,
        session_store: None,
        #[cfg(feature = "openapi")]
        openapi: None,
        audit_logger: None,
    }
}
/// Boxed future returned by a startup hook; resolves to an `AutumnResult<()>`.
type StartupHookFuture = Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send>>;
/// Type-erased startup hook: receives an `AppState` and returns a boxed future.
type StartupHook = Box<dyn Fn(AppState) -> StartupHookFuture + Send + Sync>;
/// Boxed future returned by a shutdown hook; it has no result value.
type ShutdownHookFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased shutdown hook: takes no arguments and returns a boxed future.
type ShutdownHook = Box<dyn Fn() -> ShutdownHookFuture + Send + Sync>;
/// One-shot closure that asynchronously produces the `AutumnConfig`
/// (or a `ConfigError`). Installed by `AppBuilder::with_config_loader`.
type ConfigLoaderFactory = Box<
dyn FnOnce() -> Pin<
Box<dyn Future<Output = Result<AutumnConfig, crate::config::ConfigError>> + Send>,
> + Send,
>;
/// One-shot closure that asynchronously creates the database pool from a
/// `DatabaseConfig`. It yields `Ok(None)` when no pool was created, or a
/// `PoolError` on failure. Installed by `AppBuilder::with_pool_provider`.
#[cfg(feature = "db")]
type PoolProviderFactory = Box<
dyn FnOnce(
crate::config::DatabaseConfig,
) -> Pin<
Box<
dyn Future<
Output = Result<
Option<
diesel_async::pooled_connection::deadpool::Pool<
diesel_async::AsyncPgConnection,
>,
>,
crate::db::PoolError,
>,
> + Send,
>,
> + Send,
>;
/// Accumulates everything needed to start the application.
///
/// Created by [`app`], filled in through the chained builder methods, and
/// consumed by `run`.
pub struct AppBuilder {
// Routes added through `routes()`.
routes: Vec<Route>,
// Scheduled tasks added through `tasks()`.
tasks: Vec<crate::task::TaskInfo>,
// Static route metadata added through `static_routes()`; read in build mode.
pub(crate) static_metas: Vec<crate::static_gen::StaticRouteMeta>,
// Filters added through `exception_filter()`.
exception_filters: Vec<Arc<dyn ExceptionFilter>>,
// Route groups with their own prefix and layer, added through `scoped()`.
scoped_groups: Vec<ScopedGroup>,
// Raw axum routers added through `merge()`.
merge_routers: Vec<axum::Router<AppState>>,
// (path, router) pairs added through `nest()`.
nest_routers: Vec<(String, axum::Router<AppState>)>,
// App-wide layers added through `layer()`, tagged with their `TypeId`.
custom_layers: Vec<CustomLayerRegistration>,
// Hooks added through `on_startup()`.
startup_hooks: Vec<StartupHook>,
// Hooks added through `on_shutdown()`.
shutdown_hooks: Vec<ShutdownHook>,
// Typed values keyed by `TypeId`; see `with_extension()` / `update_extension()`.
extensions: HashMap<TypeId, Box<dyn Any + Send>>,
// Names of plugins already applied; `plugin()` uses it to skip duplicates.
registered_plugins: HashSet<String>,
// Renderer installed through `error_pages()`.
error_page_renderer: Option<SharedRenderer>,
// Embedded migration sets added through `migrations()`.
#[cfg(feature = "db")]
migrations: Vec<migrate::EmbeddedMigrations>,
// Custom configuration loader; when `None` the default loader is used.
config_loader_factory: Option<ConfigLoaderFactory>,
// Custom pool provider; when `None` the default provider is used.
#[cfg(feature = "db")]
pool_provider_factory: Option<PoolProviderFactory>,
// Custom telemetry provider; when `None` the default provider is used.
telemetry_provider: Option<Box<dyn crate::telemetry::TelemetryProvider>>,
// Custom session store installed through `with_session_store()`.
session_store: Option<Arc<dyn crate::session::BoxedSessionStore>>,
// OpenAPI configuration installed through `openapi()`.
#[cfg(feature = "openapi")]
openapi: Option<crate::openapi::OpenApiConfig>,
// Audit logger built up by `with_audit_sink()`.
audit_logger: Option<Arc<crate::audit::AuditLogger>>,
}
/// A set of routes registered through `AppBuilder::scoped`, together with the
/// path prefix and the layer that should wrap them.
pub(crate) struct ScopedGroup {
// Path prefix supplied by the caller of `scoped()`.
pub(crate) prefix: String,
// Routes belonging to this group.
pub(crate) routes: Vec<Route>,
// One-shot closure that applies the group's layer to a router.
pub(crate) apply_layer:
Box<dyn FnOnce(axum::Router<AppState>) -> axum::Router<AppState> + Send>,
}
/// One-shot closure that applies a type-erased layer to a router.
pub(crate) type CustomLayerApplier =
Box<dyn FnOnce(axum::Router<AppState>) -> axum::Router<AppState> + Send>;
/// An app-wide layer registered through `AppBuilder::layer`.
pub(crate) struct CustomLayerRegistration {
// `TypeId` of the concrete layer type; queried by `has_layer()` and
// `get_layer_types()`.
pub(crate) type_id: TypeId,
// Closure that applies the layer to a router.
pub(crate) apply: CustomLayerApplier,
}
// Private module holding the supertrait of `IntoAppLayer`. Because the module
// itself is not public, code outside this file cannot name `Sealed`.
mod sealed {
/// Marker supertrait; implemented below for every type that satisfies the
/// layer bounds.
pub trait Sealed {}
}
/// A value that can be applied to the application router as an app-wide layer.
///
/// The trait requires `sealed::Sealed` and is implemented in this file by a
/// blanket impl for every Tower layer meeting the bounds listed in the
/// diagnostic note below. The `on_unimplemented` attribute customises the
/// compiler error shown when a type does not implement the trait.
#[diagnostic::on_unimplemented(
message = "`{Self}` is not a usable Autumn app-wide Tower layer",
label = "this type does not implement `tower::Layer<axum::routing::Route>` with the required service bounds",
note = "`AppBuilder::layer(..)` requires:\n L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,\n L::Service: Service<axum::extract::Request, Response = axum::response::Response, Error = Infallible> + Clone + Send + Sync + 'static,\n <L::Service as Service<axum::extract::Request>>::Future: Send + 'static\nSee docs/guide/middleware.md for common patterns and how to wrap raw-error layers (e.g. TimeoutLayer) with HandleErrorLayer."
)]
pub trait IntoAppLayer: sealed::Sealed + Send + Sync + 'static {
// Consumes the layer and returns `router` with the layer applied.
#[doc(hidden)]
fn apply_to(self, router: axum::Router<AppState>) -> axum::Router<AppState>;
}
// Blanket `Sealed` impl. The bounds are the same as on the `IntoAppLayer`
// blanket impl, so both traits cover the same set of types.
impl<L> sealed::Sealed for L
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<
axum::extract::Request,
Response = axum::response::Response,
Error = std::convert::Infallible,
> + Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::extract::Request>>::Future: Send + 'static,
{
}
// Blanket `IntoAppLayer` impl for every Tower layer whose service handles an
// axum request, produces an axum response, and never fails (`Infallible`).
impl<L> IntoAppLayer for L
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<
axum::extract::Request,
Response = axum::response::Response,
Error = std::convert::Infallible,
> + Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::extract::Request>>::Future: Send + 'static,
{
// Delegates to `axum::Router::layer`.
fn apply_to(self, router: axum::Router<AppState>) -> axum::Router<AppState> {
router.layer(self)
}
}
impl AppBuilder {
#[must_use]
pub fn routes(mut self, routes: Vec<Route>) -> Self {
self.routes.extend(routes);
self
}
/// Appends `tasks` to the scheduled tasks already registered on this builder.
#[must_use]
pub fn tasks(self, tasks: Vec<crate::task::TaskInfo>) -> Self {
    let mut builder = self;
    builder.tasks.extend(tasks);
    builder
}
#[must_use]
pub fn static_routes(mut self, metas: Vec<crate::static_gen::StaticRouteMeta>) -> Self {
self.static_metas.extend(metas);
self
}
/// Stores the OpenAPI configuration, replacing any previously stored one.
#[cfg(feature = "openapi")]
#[must_use]
pub fn openapi(self, config: crate::openapi::OpenApiConfig) -> Self {
    Self {
        openapi: Some(config),
        ..self
    }
}
/// Registers an exception filter; filters are kept in registration order.
#[must_use]
pub fn exception_filter(self, filter: impl ExceptionFilter) -> Self {
    let shared: Arc<dyn ExceptionFilter> = Arc::new(filter);
    let mut builder = self;
    builder.exception_filters.push(shared);
    builder
}
#[must_use]
pub fn error_pages(mut self, renderer: impl ErrorPageRenderer) -> Self {
self.error_page_renderer = Some(Arc::new(renderer));
self
}
/// Registers `routes` as a group under `prefix`, wrapped by `layer`.
///
/// The layer is captured in a closure and only applied when the group's
/// `apply_layer` is later invoked on a router.
#[must_use]
pub fn scoped<L>(mut self, prefix: &str, layer: L, routes: Vec<Route>) -> Self
where
    L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
    L::Service: tower::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::http::Response<axum::body::Body>,
            Error = std::convert::Infallible,
        > + Clone
        + Send
        + Sync
        + 'static,
    <L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
        Send + 'static,
{
    let group = ScopedGroup {
        prefix: String::from(prefix),
        routes,
        apply_layer: Box::new(move |router| router.layer(layer)),
    };
    self.scoped_groups.push(group);
    self
}
/// Registers an app-wide layer.
///
/// The layer's `TypeId` is recorded next to it so `has_layer` and
/// `get_layer_types` can report which layer types were added.
#[must_use]
pub fn layer<L: IntoAppLayer>(mut self, layer: L) -> Self {
    let registration = CustomLayerRegistration {
        type_id: TypeId::of::<L>(),
        apply: Box::new(move |router| layer.apply_to(router)),
    };
    self.custom_layers.push(registration);
    self
}
/// Returns `true` when a layer of concrete type `L` was added via `layer`.
#[must_use]
pub fn has_layer<L: 'static>(&self) -> bool {
    let wanted = TypeId::of::<L>();
    self.custom_layers
        .iter()
        .map(|registration| registration.type_id)
        .any(|registered_id| registered_id == wanted)
}
/// Returns the `TypeId` of every registered app-wide layer, in registration
/// order.
#[must_use]
pub fn get_layer_types(&self) -> Vec<TypeId> {
    let mut type_ids = Vec::with_capacity(self.custom_layers.len());
    for registration in &self.custom_layers {
        type_ids.push(registration.type_id);
    }
    type_ids
}
/// Queues a raw axum router to be merged into the application router.
#[must_use]
pub fn merge(self, router: axum::Router<AppState>) -> Self {
    let mut builder = self;
    builder.merge_routers.push(router);
    builder
}
/// Queues a raw axum router to be nested under `path`.
#[must_use]
pub fn nest(mut self, path: &str, router: axum::Router<AppState>) -> Self {
    let mount = (String::from(path), router);
    self.nest_routers.push(mount);
    self
}
#[must_use]
pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(AppState) -> Fut + Send + Sync + 'static,
Fut: Future<Output = crate::AutumnResult<()>> + Send + 'static,
{
self.startup_hooks
.push(Box::new(move |state| Box::pin(hook(state))));
self
}
#[must_use]
pub fn on_shutdown<F, Fut>(mut self, hook: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.shutdown_hooks.push(Box::new(move || Box::pin(hook())));
self
}
/// Stores `value` in the builder's typed extension map.
///
/// The map is keyed by `TypeId`, so a value of the same type stored earlier
/// is replaced.
#[must_use]
pub fn with_extension<T>(mut self, value: T) -> Self
where
    T: Any + Send + 'static,
{
    let key = TypeId::of::<T>();
    let boxed: Box<dyn Any + Send> = Box::new(value);
    self.extensions.insert(key, boxed);
    self
}
/// Mutates the stored extension of type `T` in place.
///
/// If no value of type `T` is stored yet, `init` creates one first; `update`
/// then runs on the stored value.
///
/// # Panics
///
/// Panics if the entry stored under `T`'s `TypeId` is not actually a `T`.
#[must_use]
pub fn update_extension<T, Init, Update>(mut self, init: Init, update: Update) -> Self
where
    T: Any + Send + 'static,
    Init: FnOnce() -> T,
    Update: FnOnce(&mut T),
{
    let slot = self
        .extensions
        .entry(TypeId::of::<T>())
        .or_insert_with(|| Box::new(init()));
    update(
        slot.downcast_mut::<T>()
            .expect("extension type map corrupted"),
    );
    self
}
/// Returns a reference to the stored extension of type `T`, if any.
#[must_use]
pub fn extension<T>(&self) -> Option<&T>
where
    T: Any + Send + 'static,
{
    let key = TypeId::of::<T>();
    self.extensions
        .get(&key)
        .and_then(|boxed| boxed.downcast_ref::<T>())
}
#[must_use]
pub fn with_config_loader<L>(mut self, loader: L) -> Self
where
L: crate::config::ConfigLoader,
{
if self.config_loader_factory.is_some() {
tracing::warn!(
"config loader replaced; the previously-installed loader was overwritten"
);
}
self.config_loader_factory = Some(Box::new(move || {
Box::pin(async move { loader.load().await })
}));
self
}
/// Installs a custom database pool provider.
///
/// The provider is wrapped in a one-shot factory that calls `create_pool`
/// with the database configuration. Replacing an already-installed provider
/// logs a warning.
#[cfg(feature = "db")]
#[must_use]
pub fn with_pool_provider<P>(mut self, provider: P) -> Self
where
    P: crate::db::DatabasePoolProvider,
{
    let factory: PoolProviderFactory =
        Box::new(move |config: crate::config::DatabaseConfig| {
            Box::pin(async move { provider.create_pool(&config).await })
        });
    let previous = self.pool_provider_factory.replace(factory);
    if previous.is_some() {
        tracing::warn!(
            "database pool provider replaced; the previously-installed provider was overwritten"
        );
    }
    self
}
/// Installs a custom telemetry provider; replacing an existing one logs a
/// warning.
#[must_use]
pub fn with_telemetry_provider<T>(mut self, provider: T) -> Self
where
    T: crate::telemetry::TelemetryProvider,
{
    let boxed: Box<dyn crate::telemetry::TelemetryProvider> = Box::new(provider);
    let previous = self.telemetry_provider.replace(boxed);
    if previous.is_some() {
        tracing::warn!(
            "telemetry provider replaced; the previously-installed provider was overwritten"
        );
    }
    self
}
/// Installs a custom session store; replacing an existing one logs a warning.
#[must_use]
pub fn with_session_store<S>(mut self, store: S) -> Self
where
    S: crate::session::SessionStore,
{
    let shared: Arc<dyn crate::session::BoxedSessionStore> = Arc::new(store);
    let previous = self.session_store.replace(shared);
    if previous.is_some() {
        tracing::warn!(
            "session store replaced; the previously-installed store was overwritten"
        );
    }
    self
}
/// Adds an audit sink.
///
/// Starts from a clone of the current audit logger when one exists, or from
/// a new logger otherwise, attaches `sink` to it and stores the result.
#[must_use]
pub fn with_audit_sink<S>(mut self, sink: S) -> Self
where
    S: crate::audit::AuditSink,
{
    let base = match self.audit_logger.take() {
        Some(existing) => (*existing).clone(),
        None => crate::audit::AuditLogger::new(),
    };
    let extended = base.with_sink(Arc::new(sink));
    self.audit_logger = Some(Arc::new(extended));
    self
}
/// Applies a plugin once.
///
/// The plugin's name is recorded before `build` runs. A plugin whose name is
/// already recorded is skipped with a warning and the builder is returned
/// unchanged.
#[must_use]
#[track_caller]
pub fn plugin<P>(mut self, plugin: P) -> Self
where
    P: crate::plugin::Plugin,
{
    let name = plugin.name();
    let is_duplicate = self.registered_plugins.contains(name.as_ref());
    if !is_duplicate {
        self.registered_plugins.insert(name.into_owned());
        return plugin.build(self);
    }
    tracing::warn!(
        plugin = name.as_ref(),
        "plugin already registered; skipping duplicate"
    );
    self
}
/// Applies a collection of plugins by delegating to `Plugins::apply`.
#[must_use]
pub fn plugins<P>(self, plugins: P) -> Self
where
    P: crate::plugin::Plugins,
{
    let builder = self;
    plugins.apply(builder)
}
/// Returns `true` when a plugin named `name` has already been registered.
#[must_use]
pub fn has_plugin(&self, name: &str) -> bool {
    self.registered_plugins.get(name).is_some()
}
/// Adds a set of embedded migrations; sets are kept in registration order.
#[cfg(feature = "db")]
#[must_use]
pub fn migrations(self, migrations: migrate::EmbeddedMigrations) -> Self {
    let mut builder = self;
    builder.migrations.push(migrations);
    builder
}
/// Runs the application until the HTTP server stops.
///
/// When `is_static_build_mode()` is true this delegates to `run_build_mode`
/// and returns. Otherwise it loads configuration and telemetry, sets up the
/// database (with the `db` feature), builds the shared state and the router,
/// binds the listener, serves with graceful shutdown, runs the startup
/// hooks, starts the scheduled tasks, and finally waits for the server task.
///
/// Failures while creating the pool, building the router, binding the
/// listener, running a startup hook, or serving are logged and end the
/// process with `std::process::exit(1)`.
///
/// # Panics
///
/// Panics if no routes were registered.
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
pub async fn run(self) {
if is_static_build_mode() {
self.run_build_mode().await;
return;
}
// Take the builder apart; fields bound to `_` are not used on this path.
let Self {
routes,
tasks,
static_metas: _,
exception_filters,
scoped_groups,
merge_routers,
nest_routers,
custom_layers,
startup_hooks,
shutdown_hooks,
extensions: _,
registered_plugins: _,
error_page_renderer,
#[cfg(feature = "db")]
migrations,
config_loader_factory,
#[cfg(feature = "db")]
pool_provider_factory,
telemetry_provider,
session_store,
#[cfg(feature = "openapi")]
openapi,
audit_logger,
} = self;
let all_routes = routes;
// The telemetry guard is held in a named binding so it lives until the
// end of this function.
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
assert!(
!all_routes.is_empty(),
"No routes registered. Did you forget to call .routes()?"
);
let profile_display = config.profile.as_deref().unwrap_or("none");
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
profile = profile_display,
"Autumn starting"
);
// The route/task/middleware/config dump is opt-in via AUTUMN_SHOW_CONFIG=1.
let show_config = std::env::var("AUTUMN_SHOW_CONFIG").as_deref() == Ok("1");
if show_config {
log_startup_transparency(&all_routes, &tasks, &scoped_groups, &config);
}
// Skipped inside the helper when a custom session store was installed.
fail_fast_on_invalid_session_config(&config, session_store.is_some());
#[cfg(feature = "db")]
let pool = setup_database(&config, migrations, pool_provider_factory)
.await
.unwrap_or_else(|e| {
tracing::error!("{e}");
std::process::exit(1);
});
#[cfg(feature = "db")]
if pool.is_some() {
tracing::info!(
max_connections = config.database.pool_size,
"Database pool configured"
);
} else {
tracing::info!("Database not configured");
}
let state = build_state(
&config,
#[cfg(feature = "db")]
pool,
);
// Expose a clone of the audit logger through the state's extensions.
if let Some(logger) = audit_logger {
state.insert_extension::<crate::audit::AuditLogger>((*logger).clone());
}
// Pass the `dist` directory to the router builder only when it exists.
let env = crate::config::OsEnv;
let dist_dir = project_dir("dist", &env);
let dist_ref = if dist_dir.exists() {
Some(dist_dir.as_path())
} else {
None
};
let router = crate::router::try_build_router_with_static_inner(
all_routes,
&config,
state.clone(),
dist_ref,
crate::router::RouterContext {
exception_filters,
scoped_groups,
merge_routers,
nest_routers,
custom_layers,
error_page_renderer,
session_store,
#[cfg(feature = "openapi")]
openapi,
},
)
.unwrap_or_else(|error| {
tracing::error!(error = %error, "Failed to build router");
std::process::exit(1);
});
let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr)
.await
.unwrap_or_else(|e| {
tracing::error!(addr = %addr, "Failed to bind: {e}");
std::process::exit(1);
});
tracing::info!(addr = %addr, "Listening");
let shutdown_timeout = config.server.shutdown_timeout_secs;
// Cancelling this token ends the server's graceful-shutdown wait.
let server_shutdown = tokio_util::sync::CancellationToken::new();
let server_shutdown_wait = server_shutdown.clone();
let server_task = tokio::spawn(async move {
axum::serve(
listener,
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(async move {
server_shutdown_wait.cancelled().await;
})
.await
});
let shutdown_state = state.clone();
let shutdown_signal_token = server_shutdown.clone();
#[cfg(feature = "ws")]
let websocket_shutdown = state.shutdown.clone();
// Background task: wait for Ctrl+C / SIGTERM, mark the state as shutting
// down, run the shutdown hooks, then cancel the server token.
let shutdown_task = tokio::spawn(async move {
shutdown_signal().await;
shutdown_state.begin_shutdown();
#[cfg(feature = "ws")]
websocket_shutdown.cancel();
// For timeouts above five seconds, log a warning five seconds before the
// configured timeout elapses.
if shutdown_timeout > 5 {
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(
shutdown_timeout.saturating_sub(5),
))
.await;
tracing::warn!(
timeout_secs = shutdown_timeout,
"Shutdown draining near timeout, force-kill may be imminent"
);
});
}
run_shutdown_hooks(&shutdown_hooks).await;
shutdown_signal_token.cancel();
});
// Startup hooks run after the server task has been spawned.
if let Err(error) = run_startup_hooks(&startup_hooks, state.clone()).await {
tracing::error!(error = %error, "startup hook failed");
server_shutdown.cancel();
server_task.abort();
std::process::exit(1);
}
// If shutdown already began while the hooks were running, skip starting
// tasks and do not mark startup as complete.
if !state.probes().is_shutting_down() {
if !tasks.is_empty() {
start_task_scheduler(tasks, &state, server_shutdown.clone());
}
state.probes().mark_startup_complete();
}
let server_result = server_task.await.unwrap_or_else(|e| {
tracing::error!("Server task join error: {e}");
std::process::exit(1);
});
// The server has stopped, so the signal-waiting task is no longer needed.
shutdown_task.abort();
server_result.unwrap_or_else(|e| {
tracing::error!("Server error: {e}");
std::process::exit(1);
});
tracing::info!("Server shut down cleanly");
}
/// Renders the registered static routes into the `dist` directory instead of
/// serving requests.
///
/// Progress and errors are written to stderr. The process exits with status
/// 1 when no static routes are registered, when the database pool or router
/// cannot be built, or when rendering fails.
async fn run_build_mode(self) {
// Only the pieces needed for rendering are kept; the rest is discarded.
let Self {
routes,
tasks: _,
static_metas,
exception_filters: _,
scoped_groups: _,
merge_routers: _,
nest_routers: _,
custom_layers,
startup_hooks: _,
shutdown_hooks: _,
extensions: _,
registered_plugins: _,
error_page_renderer: _,
#[cfg(feature = "db")]
migrations: _,
config_loader_factory,
#[cfg(feature = "db")]
pool_provider_factory,
telemetry_provider,
session_store,
#[cfg(feature = "openapi")]
openapi: _,
audit_logger: _,
} = self;
let all_routes = routes;
let (config, _telemetry_guard) =
load_config_and_telemetry(config_loader_factory, telemetry_provider).await;
if static_metas.is_empty() {
eprintln!("No static routes registered. Nothing to build.");
eprintln!("Hint: use .static_routes(static_routes![...]) on your AppBuilder.");
std::process::exit(1);
}
fail_fast_on_invalid_session_config(&config, session_store.is_some());
// An empty migration list is passed, so no migrations run in build mode.
#[cfg(feature = "db")]
let pool = setup_database(&config, vec![], pool_provider_factory)
.await
.unwrap_or_else(|e| {
eprintln!("{e}");
std::process::exit(1);
});
let mut state = build_state(
&config,
#[cfg(feature = "db")]
pool,
);
// Replace the `pending_startup` probe state set by `build_state` with the
// default probe state.
state.probes = crate::probe::ProbeState::default();
// Custom layers and the session store are forwarded; the other router
// context fields are left empty.
let router = crate::router::try_build_router_inner(
all_routes,
&config,
state,
crate::router::RouterContext {
exception_filters: Vec::new(),
scoped_groups: Vec::new(),
merge_routers: Vec::new(),
nest_routers: Vec::new(),
custom_layers,
error_page_renderer: None,
session_store,
#[cfg(feature = "openapi")]
openapi: None,
},
)
.unwrap_or_else(|error| {
eprintln!("Failed to build router: {error}");
std::process::exit(1);
});
let env = crate::config::OsEnv;
let dist_dir = project_dir("dist", &env);
eprintln!("Building {} static route(s)...", static_metas.len());
match crate::static_gen::render_static_routes(router, &static_metas, &dist_dir).await {
Ok(()) => {
eprintln!(
"\n \u{2713} Static build complete \u{2192} {}",
dist_dir.display()
);
}
Err(e) => {
eprintln!("\n \u{2717} Static build failed: {e}");
std::process::exit(1);
}
}
}
}
/// Returns `true` only when the `AUTUMN_BUILD_STATIC` environment variable is
/// set to exactly `"1"`.
pub(crate) fn is_static_build_mode() -> bool {
    matches!(std::env::var("AUTUMN_BUILD_STATIC"), Ok(flag) if flag == "1")
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cognitive_complexity)]
fn start_task_scheduler(
tasks: Vec<crate::task::TaskInfo>,
state: &AppState,
shutdown: tokio_util::sync::CancellationToken,
) {
tracing::info!(count = tasks.len(), "Starting scheduled tasks");
for task_info in &tasks {
let schedule_desc = task_info.schedule.to_string();
tracing::info!(name = %task_info.name, schedule = %schedule_desc, "Registered task");
}
let mut cron_tasks: Vec<(String, String, Option<String>, crate::task::TaskHandler)> =
Vec::new();
for task_info in tasks {
let state = state.clone();
let name = task_info.name.clone();
let handler = task_info.handler;
let schedule_desc = task_info.schedule.to_string();
match task_info.schedule {
crate::task::Schedule::FixedDelay(delay) => {
state.task_registry.register(&name, &schedule_desc);
tokio::spawn(async move {
loop {
tokio::time::sleep(delay).await;
execute_fixed_delay_task(name.clone(), state.clone(), handler).await;
}
});
}
crate::task::Schedule::Cron {
expression,
timezone,
} => {
state.task_registry.register(&name, &schedule_desc);
cron_tasks.push((name, expression, timezone, handler));
}
}
}
if !cron_tasks.is_empty() {
let state = state.clone();
tokio::spawn(async move {
run_cron_scheduler(cron_tasks, state, shutdown).await;
});
}
}
/// Publishes a task lifecycle event on the `sys:tasks` channel.
///
/// The message is a JSON object with `event`, `task` and an RFC 3339
/// `timestamp`, plus every key/value pair from `extra`. The send result is
/// ignored. Without the `ws` feature this function does nothing.
#[allow(unused_variables, clippy::needless_pass_by_value)]
fn send_ws_sys_task_msg(
    state: &AppState,
    event: &str,
    name: &str,
    extra: Vec<(&str, serde_json::Value)>,
) {
    #[cfg(feature = "ws")]
    {
        let mut payload = serde_json::json!({
            "event": event,
            "task": name,
            "timestamp": chrono::Utc::now().to_rfc3339(),
        });
        if let serde_json::Value::Object(fields) = &mut payload {
            fields.extend(
                extra
                    .into_iter()
                    .map(|(key, value)| (key.to_string(), value)),
            );
        }
        let _ = state
            .channels()
            .sender("sys:tasks")
            .send(payload.to_string());
    }
}
/// Runs `handler` inside a root `scheduled_task` span and times it.
///
/// The duration is measured from `start` to the moment the handler
/// finishes, in milliseconds, saturating at `u64::MAX`. Returns the duration
/// on success, or the duration and the error's string form on failure.
async fn execute_task_result(
    state: &AppState,
    handler: crate::task::TaskHandler,
    start: std::time::Instant,
    name: &str,
    schedule: &'static str,
) -> Result<u64, (u64, String)> {
    let span = tracing::info_span!(
        parent: None,
        "scheduled_task",
        otel.kind = "internal",
        task = %name,
        schedule = schedule,
    );
    let outcome = (handler)(state.clone()).instrument(span).await;
    let elapsed_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
    outcome
        .map(|()| elapsed_ms)
        .map_err(|error| (elapsed_ms, error.to_string()))
}
/// Runs one fixed-delay task and reports its outcome.
///
/// The start, and then the success or failure, are recorded in the task
/// registry, logged, and published through `send_ws_sys_task_msg`.
async fn execute_fixed_delay_task(
    name: String,
    state: AppState,
    handler: crate::task::TaskHandler,
) {
    tracing::debug!(task = %name, "Running scheduled task");
    state.task_registry.record_start(&name);
    send_ws_sys_task_msg(&state, "started", &name, vec![]);
    let started_at = std::time::Instant::now();
    let outcome = execute_task_result(&state, handler, started_at, &name, "fixed_delay").await;
    let duration_ms = match outcome {
        Ok(elapsed) => elapsed,
        Err((duration_ms, error_str)) => {
            state
                .task_registry
                .record_failure(&name, duration_ms, &error_str);
            tracing::warn!(task = %name, error = %error_str, "Task failed");
            let details = vec![
                ("duration_ms", serde_json::json!(duration_ms)),
                ("error", serde_json::json!(error_str)),
            ];
            send_ws_sys_task_msg(&state, "failure", &name, details);
            return;
        }
    };
    state.task_registry.record_success(&name, duration_ms);
    tracing::debug!(task = %name, "Task completed");
    let details = vec![("duration_ms", serde_json::json!(duration_ms))];
    send_ws_sys_task_msg(&state, "success", &name, details);
}
/// Runs one cron task and reports its outcome.
///
/// The start, and then the success or failure, are recorded in the task
/// registry, logged, and published through `send_ws_sys_task_msg`.
async fn execute_cron_task(name: String, state: AppState, handler: crate::task::TaskHandler) {
    tracing::debug!(task = %name, "Running cron task");
    state.task_registry.record_start(&name);
    send_ws_sys_task_msg(&state, "started", &name, vec![]);
    let started_at = std::time::Instant::now();
    let outcome = execute_task_result(&state, handler, started_at, &name, "cron").await;
    let duration_ms = match outcome {
        Ok(elapsed) => elapsed,
        Err((duration_ms, error_str)) => {
            state
                .task_registry
                .record_failure(&name, duration_ms, &error_str);
            tracing::warn!(task = %name, error = %error_str, "Cron task failed");
            let details = vec![
                ("duration_ms", serde_json::json!(duration_ms)),
                ("error", serde_json::json!(error_str)),
            ];
            send_ws_sys_task_msg(&state, "failure", &name, details);
            return;
        }
    };
    state.task_registry.record_success(&name, duration_ms);
    tracing::debug!(task = %name, "Cron task completed");
    let details = vec![("duration_ms", serde_json::json!(duration_ms))];
    send_ws_sys_task_msg(&state, "success", &name, details);
}
/// Builds a cron job for one task and adds it to `sched`.
///
/// Failure to create the job or to add it is logged; it is not returned to
/// the caller.
async fn register_cron_task(
    sched: &tokio_cron_scheduler::JobScheduler,
    name: String,
    expression: String,
    timezone: Option<String>,
    handler: crate::task::TaskHandler,
    state: AppState,
) {
    // The job closure gets its own copies because it may run many times.
    let job_state = state.clone();
    let job_name = name.clone();
    let on_tick = move |_uuid, _lock| {
        let state = job_state.clone();
        let name = job_name.clone();
        Box::pin(async move {
            execute_cron_task(name, state, handler).await;
        }) as std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
    };
    let job = match build_cron_job(&expression, timezone.as_deref(), on_tick) {
        Ok(job) => job,
        Err(e) => {
            tracing::error!(task = %name, error = %e, "Failed to create cron job");
            return;
        }
    };
    if let Err(e) = sched.add(job).await {
        tracing::error!(task = %name, error = %e, "Failed to add cron task to scheduler");
    }
}
/// Creates a cron scheduler, registers every task on it and starts it.
///
/// Returns `None`, after logging the error, when the scheduler cannot be
/// created or started.
async fn setup_cron_scheduler(
    tasks: Vec<(String, String, Option<String>, crate::task::TaskHandler)>,
    state: AppState,
) -> Option<tokio_cron_scheduler::JobScheduler> {
    use tokio_cron_scheduler::JobScheduler;
    let sched = JobScheduler::new()
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "Failed to create cron job scheduler");
        })
        .ok()?;
    for (name, expression, timezone, handler) in tasks {
        register_cron_task(&sched, name, expression, timezone, handler, state.clone()).await;
    }
    match sched.start().await {
        Ok(_) => Some(sched),
        Err(e) => {
            tracing::error!(error = %e, "Failed to start cron scheduler");
            None
        }
    }
}
/// Runs the cron scheduler until `shutdown` is cancelled, then shuts it down.
///
/// Returns immediately when the scheduler could not be set up.
#[allow(clippy::cognitive_complexity)]
async fn run_cron_scheduler(
    tasks: Vec<(String, String, Option<String>, crate::task::TaskHandler)>,
    state: AppState,
    shutdown: tokio_util::sync::CancellationToken,
) {
    if let Some(mut scheduler) = setup_cron_scheduler(tasks, state).await {
        tracing::info!("Cron scheduler started");
        shutdown.cancelled().await;
        tracing::info!("Shutting down cron scheduler");
        if let Err(e) = scheduler.shutdown().await {
            tracing::error!(error = %e, "Failed to shut down cron scheduler");
        }
    }
}
/// Builds a cron job for `expression`.
///
/// When `timezone` parses as a `chrono_tz::Tz`, the job is created with that
/// timezone. When it is absent, or fails to parse (which logs a warning), the
/// job is created with `Job::new_async`.
fn build_cron_job<F>(
    expression: &str,
    timezone: Option<&str>,
    run: F,
) -> Result<tokio_cron_scheduler::Job, tokio_cron_scheduler::JobSchedulerError>
where
    F: 'static
        + FnMut(
            uuid::Uuid,
            tokio_cron_scheduler::JobScheduler,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
        + Send
        + Sync,
{
    use tokio_cron_scheduler::Job;
    let parsed_tz = timezone.and_then(|tz_str| {
        let parsed = tz_str.parse::<chrono_tz::Tz>().ok();
        if parsed.is_none() {
            tracing::warn!(
                timezone = %tz_str,
                "Unrecognized timezone; falling back to UTC"
            );
        }
        parsed
    });
    match parsed_tz {
        Some(tz) => Job::new_async_tz(expression, tz, run),
        None => Job::new_async(expression, run),
    }
}
/// Runs the startup hooks one after another in registration order.
///
/// Each hook gets its own clone of `state`. The first error stops the
/// sequence and is returned; later hooks are not run.
async fn run_startup_hooks(hooks: &[StartupHook], state: AppState) -> crate::AutumnResult<()> {
    for run_hook in hooks {
        let outcome = run_hook(state.clone()).await;
        outcome?;
    }
    Ok(())
}
/// Runs the shutdown hooks one after another, last registered first.
async fn run_shutdown_hooks(hooks: &[ShutdownHook]) {
    let mut remaining = hooks;
    while let Some((last, earlier)) = remaining.split_last() {
        last().await;
        remaining = earlier;
    }
}
#[allow(clippy::cognitive_complexity)]
fn log_startup_transparency(
routes: &[Route],
tasks: &[crate::task::TaskInfo],
scoped_groups: &[ScopedGroup],
config: &AutumnConfig,
) {
tracing::info!(
"Registered routes:{}",
format_route_lines(routes, scoped_groups, config)
);
if let Some(task_lines) = format_task_lines(tasks) {
tracing::info!("Scheduled tasks:{task_lines}");
}
tracing::info!("Active middleware: {}", format_middleware_list(config));
tracing::info!("Configuration:{}", format_config_summary(config));
}
/// Exits the process with status 1 when the configured session backend is
/// invalid.
///
/// The check is skipped when a custom session store was installed.
fn fail_fast_on_invalid_session_config(config: &AutumnConfig, has_custom_session_store: bool) {
    if !has_custom_session_store {
        let plan = config.session.backend_plan(config.profile.as_deref());
        if let Err(error) = plan {
            eprintln!("Invalid session backend config: {error}");
            std::process::exit(1);
        }
    }
}
/// Loads the configuration and initialises telemetry.
///
/// Uses the supplied loader factory and telemetry provider when present,
/// otherwise `TomlEnvConfigLoader` and `TracingOtlpTelemetryProvider`. A
/// failure in either step is printed to stderr and ends the process with
/// status 1.
async fn load_config_and_telemetry(
    config_loader: Option<ConfigLoaderFactory>,
    telemetry_provider: Option<Box<dyn crate::telemetry::TelemetryProvider>>,
) -> (AutumnConfig, crate::telemetry::TelemetryGuard) {
    let loaded = if let Some(factory) = config_loader {
        factory().await
    } else {
        crate::config::TomlEnvConfigLoader::new().load().await
    };
    let config = match loaded {
        Ok(config) => config,
        Err(e) => {
            eprintln!("Failed to load configuration: {e}");
            std::process::exit(1);
        }
    };
    let provider: Box<dyn crate::telemetry::TelemetryProvider> = match telemetry_provider {
        Some(custom) => custom,
        None => Box::new(crate::telemetry::TracingOtlpTelemetryProvider::new()),
    };
    let initialised = provider.init(&config.log, &config.telemetry, config.profile.as_deref());
    let telemetry_guard = match initialised {
        Ok(guard) => guard,
        Err(error) => {
            eprintln!("Failed to initialize telemetry: {error}");
            std::process::exit(1);
        }
    };
    (config, telemetry_guard)
}
/// Creates the database pool and runs the supplied migrations.
///
/// The pool comes from `pool_provider` when given, otherwise from
/// `DieselDeadpoolPoolProvider`. Migrations run only when a pool was created
/// and a database URL is configured.
///
/// # Errors
///
/// Returns a message prefixed with "Failed to create database pool" when
/// pool creation fails.
#[cfg(feature = "db")]
async fn setup_database(
    config: &AutumnConfig,
    migrations: Vec<crate::migrate::EmbeddedMigrations>,
    pool_provider: Option<PoolProviderFactory>,
) -> Result<
    Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
    String,
> {
    let created = if let Some(factory) = pool_provider {
        factory(config.database.clone()).await
    } else {
        crate::db::DieselDeadpoolPoolProvider::new()
            .create_pool(&config.database)
            .await
    };
    let pool = created.map_err(|e| format!("Failed to create database pool: {e}"))?;
    if let (Some(_), Some(url)) = (pool.as_ref(), config.database.url.as_ref()) {
        for migration_set in migrations {
            crate::migrate::auto_migrate(
                url,
                config.profile.as_deref(),
                config.database.auto_migrate_in_production,
                migration_set,
            );
        }
    }
    Ok(pool)
}
fn build_state(
config: &AutumnConfig,
#[cfg(feature = "db")] pool: Option<
diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
>,
) -> AppState {
AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
#[cfg(feature = "db")]
pool,
profile: config.profile.clone(),
started_at: std::time::Instant::now(),
health_detailed: config.health.detailed,
probes: crate::probe::ProbeState::pending_startup(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new(&config.log.level),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::from_config(config),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
}
}
fn format_route_lines(
routes: &[Route],
scoped_groups: &[ScopedGroup],
config: &AutumnConfig,
) -> String {
use std::fmt::Write as _;
let mut out = String::new();
for route in routes {
let _ = write!(
out,
"\n {} {:<8} -> {}",
route.path, route.method, route.name
);
}
for group in scoped_groups {
for route in &group.routes {
let _ = write!(
out,
"\n {}{} {:<8} -> {} (scoped)",
group.prefix, route.path, route.method, route.name
);
}
}
let mut probe_paths = std::collections::HashSet::new();
for (path, name) in [
(config.health.live_path.as_str(), "live"),
(config.health.ready_path.as_str(), "ready"),
(config.health.startup_path.as_str(), "startup"),
(config.health.path.as_str(), "health"),
] {
if probe_paths.insert(path) {
let _ = write!(out, "\n {} {:<8} -> {}", path, "GET", name);
}
}
let _ = write!(
out,
"\n {} {:<8} -> actuator",
crate::actuator::actuator_route_glob(&config.actuator.prefix),
"GET"
);
#[cfg(feature = "htmx")]
{
out.push_str("\n /static/js/htmx.min.js GET -> htmx");
out.push_str("\n /static/js/autumn-htmx-csrf.js GET -> htmx csrf");
}
out
}
/// Formats one line per scheduled task as `name (schedule)`.
///
/// Returns `None` when there are no tasks.
fn format_task_lines(tasks: &[crate::task::TaskInfo]) -> Option<String> {
    use std::fmt::Write as _;
    if tasks.is_empty() {
        None
    } else {
        let lines = tasks.iter().fold(String::new(), |mut lines, task| {
            let schedule = task.schedule.to_string();
            let _ = write!(lines, "\n {} ({schedule})", task.name);
            lines
        });
        Some(lines)
    }
}
/// Returns the comma-separated middleware names shown in the startup log.
///
/// Four names are always listed; `CORS` is added when allowed origins are
/// configured and `CSRF` when CSRF protection is enabled; `Metrics` always
/// comes last.
fn format_middleware_list(config: &AutumnConfig) -> String {
    let optional = [
        (!config.cors.allowed_origins.is_empty(), "CORS"),
        (config.security.csrf.enabled, "CSRF"),
    ];
    let mut names = vec![
        "RequestId",
        "SecurityHeaders",
        "Session (in-memory)",
        "ErrorPages",
    ];
    names.extend(
        optional
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|(_, label)| *label),
    );
    names.push("Metrics");
    names.join(", ")
}
fn mask_database_url(url: &str, pool_size: usize) -> String {
if let Ok(mut parsed_url) = url::Url::parse(url) {
if parsed_url.password().is_some() {
let _ = parsed_url.set_password(Some("****"));
return format!("{parsed_url} (pool_size={pool_size})");
}
format!("{parsed_url} (pool_size={pool_size})")
} else {
format!("**** (pool_size={pool_size})")
}
}
/// Formats the multi-line configuration summary for the startup log.
///
/// The database URL is passed through `mask_database_url`; telemetry is
/// shown as `protocol -> endpoint` when enabled and `disabled` otherwise.
fn format_config_summary(config: &AutumnConfig) -> String {
    let profile = match config.profile.as_deref() {
        Some(active) => active,
        None => "none",
    };
    let db_status = match config.database.url.as_deref() {
        Some(url) => mask_database_url(url, config.database.pool_size),
        None => "not configured".to_owned(),
    };
    let telemetry = &config.telemetry;
    let telemetry_status = if !telemetry.enabled {
        "disabled".to_owned()
    } else {
        let endpoint = telemetry
            .otlp_endpoint
            .as_deref()
            .unwrap_or("<missing endpoint>");
        format!("{:?} -> {endpoint}", telemetry.protocol)
    };
    format!(
        "\
\n profile: {profile}\
\n server: {}:{}\
\n database: {db_status}\
\n log_level: {}\
\n log_format: {:?}\
\n telemetry: {telemetry_status}\
\n health: {} (detailed={})\
\n actuator: sensitive={}\
\n shutdown: {}s",
        config.server.host,
        config.server.port,
        config.log.level,
        config.log.format,
        config.health.path,
        config.health.detailed,
        config.actuator.sensitive,
        config.server.shutdown_timeout_secs,
    )
}
/// Resolves `subdir` against the directory named by `AUTUMN_MANIFEST_DIR`.
///
/// When the variable cannot be read from `env`, `subdir` is returned as a
/// relative path.
pub(crate) fn project_dir(subdir: &str, env: &dyn crate::config::Env) -> std::path::PathBuf {
    use std::path::PathBuf;
    match env.var("AUTUMN_MANIFEST_DIR") {
        Ok(manifest_dir) => PathBuf::from(manifest_dir).join(subdir),
        Err(_) => PathBuf::from(subdir),
    }
}
/// Completes when the process receives Ctrl+C or, on Unix, SIGTERM.
///
/// On non-Unix targets only Ctrl+C is observed.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
        tracing::info!("Received Ctrl+C, starting graceful shutdown");
    };
    #[cfg(unix)]
    let sigterm = async {
        use tokio::signal::unix::{signal, SignalKind};
        let mut stream =
            signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
        stream.recv().await;
        tracing::info!("Received SIGTERM, starting graceful shutdown");
    };
    // Never resolves, so only Ctrl+C can complete the select below.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();
    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
/// Builds a router for `routes` using `AutumnConfig::default()` and a
/// hand-assembled `AppState` (no database pool, no profile, detailed health
/// enabled).
pub fn test_router(routes: Vec<Route>) -> axum::Router {
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let default_config = AutumnConfig::default();
    crate::router::build_router(routes, &default_config, app_state)
}
/// Creates a GET [`Route`] at `path` whose handler responds with the body
/// `"ok"`; `name` is used both as the route name and as the API doc
/// operation id.
pub fn test_get_route(path: &'static str, name: &'static str) -> Route {
    let api_doc = crate::openapi::ApiDoc {
        method: "GET",
        path,
        operation_id: name,
        success_status: 200,
        ..Default::default()
    };
    Route {
        method: http::Method::GET,
        path,
        handler: axum::routing::get(|| async { "ok" }),
        name,
        api_doc,
    }
}
/// Successive `routes(..)` calls append to the builder, preserving the order
/// in which routes were supplied.
#[test]
fn app_builder_routes_adds_routes() {
    let empty = app();
    assert_eq!(empty.routes.len(), 0);

    let with_one = empty.routes(vec![test_get_route("/1", "route1")]);
    assert_eq!(with_one.routes.len(), 1);

    let with_three = with_one.routes(vec![
        test_get_route("/2", "route2"),
        test_get_route("/3", "route3"),
    ]);
    assert_eq!(with_three.routes.len(), 3);

    let registered = &with_three.routes;
    assert_eq!(registered[0].path, "/1");
    assert_eq!(registered[1].path, "/2");
    assert_eq!(registered[2].path, "/3");
}
/// A typed extension stored with `with_extension` can be mutated through
/// `update_extension` and read back with `extension`.
#[test]
fn app_builder_extensions_store_and_update_typed_values() {
    let seeded = app().with_extension::<String>("haunted".into());
    let updated =
        seeded.update_extension::<String, _, _>(String::new, |value| value.push_str(" harvest"));

    let stored = updated
        .extension::<String>()
        .expect("string extension should be present");
    assert_eq!(stored, "haunted harvest");
}
/// Runs one startup hook and two shutdown hooks and checks the recorded
/// order: the startup hook first, then the shutdown hooks in the reverse of
/// their registration order.
#[tokio::test]
async fn startup_and_shutdown_hooks_run_in_expected_order() {
    let timeline = Arc::new(std::sync::Mutex::new(Vec::<&'static str>::new()));
    let for_startup = Arc::clone(&timeline);
    let for_first_stop = Arc::clone(&timeline);
    let for_second_stop = Arc::clone(&timeline);

    let builder = app()
        .on_startup(move |_state| {
            // Each invocation gets its own handle so the hook stays callable.
            let sink = Arc::clone(&for_startup);
            async move {
                sink.lock().expect("events lock poisoned").push("start");
                Ok(())
            }
        })
        .on_shutdown(move || {
            let sink = Arc::clone(&for_first_stop);
            async move {
                sink.lock().expect("events lock poisoned").push("stop-a");
            }
        })
        .on_shutdown(move || {
            let sink = Arc::clone(&for_second_stop);
            async move {
                sink.lock().expect("events lock poisoned").push("stop-b");
            }
        });

    run_startup_hooks(&builder.startup_hooks, AppState::for_test())
        .await
        .expect("startup hooks should succeed");
    run_shutdown_hooks(&builder.shutdown_hooks).await;

    let observed = timeline.lock().expect("events lock poisoned").clone();
    assert_eq!(observed, vec!["start", "stop-b", "stop-a"]);
}
/// An error returned by a startup hook is surfaced by `run_startup_hooks`.
#[tokio::test]
async fn startup_hook_errors_propagate() {
    let failing_hook = |_state| async {
        Err(crate::AutumnError::service_unavailable_msg(
            "startup ritual failed",
        ))
    };
    let builder = app().on_startup(failing_hook);

    let outcome = run_startup_hooks(&builder.startup_hooks, AppState::for_test()).await;
    let error = outcome.expect_err("startup hook should fail");
    assert!(error.to_string().contains("startup ritual failed"));
}
/// A user-supplied GET route is reachable and returns the handler's body.
#[tokio::test]
async fn build_router_mounts_user_routes() {
    let router = test_router(vec![test_get_route("/test", "test_handler")]);
    let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&payload[..], b"ok");
}
/// With the default configuration, `/health` answers 200 with a JSON body
/// whose `status` field is `"ok"`.
#[tokio::test]
async fn build_router_mounts_health_check_at_default_path() {
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let request = Request::builder()
        .uri("/health")
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["status"], "ok");
}
/// Setting `health.path` to `/healthz` makes the health endpoint answer at
/// that path.
#[tokio::test]
async fn build_router_mounts_health_check_at_custom_path() {
    let mut custom_config = AutumnConfig::default();
    custom_config.health.path = "/healthz".to_owned();

    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };

    let routes = vec![test_get_route("/dummy", "dummy")];
    let router = crate::router::build_router(routes, &custom_config, app_state);

    let request = Request::builder()
        .uri("/healthz")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// Responses from the built router carry an `x-request-id` header.
#[tokio::test]
async fn build_router_adds_request_id_header() {
    let router = test_router(vec![test_get_route("/test", "test")]);
    let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

    let response = router.oneshot(request).await.unwrap();
    let headers = response.headers();
    assert!(headers.contains_key("x-request-id"));
}
/// Requesting a path that no route matches yields 404 Not Found.
#[tokio::test]
async fn build_router_unknown_route_returns_404() {
    let router = test_router(vec![test_get_route("/exists", "exists")]);
    let request = Request::builder().uri("/nope").body(Body::empty()).unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
/// Two routes registered together are both served by the same router.
#[tokio::test]
async fn build_router_multiple_routes() {
    let router = test_router(vec![test_get_route("/a", "a"), test_get_route("/b", "b")]);

    let request_a = Request::builder().uri("/a").body(Body::empty()).unwrap();
    let resp_a = router.clone().oneshot(request_a).await.unwrap();
    assert_eq!(resp_a.status(), StatusCode::OK);

    let request_b = Request::builder().uri("/b").body(Body::empty()).unwrap();
    let resp_b = router.oneshot(request_b).await.unwrap();
    assert_eq!(resp_b.status(), StatusCode::OK);
}
/// A POST route registered with the router answers POST requests with 200.
#[tokio::test]
async fn build_router_post_route() {
    let submit_route = Route {
        method: http::Method::POST,
        path: "/submit",
        handler: axum::routing::post(|| async { "posted" }),
        name: "submit",
        api_doc: crate::openapi::ApiDoc {
            method: "POST",
            path: "/submit",
            operation_id: "submit",
            success_status: 200,
            ..Default::default()
        },
    };
    let post_routes = vec![submit_route];

    let default_config = AutumnConfig::default();
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let router = crate::router::build_router(post_routes, &default_config, app_state);

    let request = Request::builder()
        .method("POST")
        .uri("/submit")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// GET and POST routes registered on the same path are both dispatched,
/// each to its own handler.
#[tokio::test]
async fn build_router_merges_methods_on_same_path() {
    let list_route = Route {
        method: http::Method::GET,
        path: "/admin",
        handler: axum::routing::get(|| async { "list" }),
        name: "admin_list",
        api_doc: crate::openapi::ApiDoc {
            method: "GET",
            path: "/admin",
            operation_id: "admin_list",
            success_status: 200,
            ..Default::default()
        },
    };
    let create_route = Route {
        method: http::Method::POST,
        path: "/admin",
        handler: axum::routing::post(|| async { "created" }),
        name: "create",
        api_doc: crate::openapi::ApiDoc {
            method: "POST",
            path: "/admin",
            operation_id: "create",
            success_status: 200,
            ..Default::default()
        },
    };
    let route_list = vec![list_route, create_route];

    let default_config = AutumnConfig::default();
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let router = crate::router::build_router(route_list, &default_config, app_state);

    // GET hits the listing handler.
    let get_request = Request::builder()
        .uri("/admin")
        .body(Body::empty())
        .unwrap();
    let get_response = router.clone().oneshot(get_request).await.unwrap();
    assert_eq!(get_response.status(), StatusCode::OK);
    let get_body = axum::body::to_bytes(get_response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&get_body[..], b"list");

    // POST on the same path hits the creation handler.
    let post_request = Request::builder()
        .method("POST")
        .uri("/admin")
        .body(Body::empty())
        .unwrap();
    let post_response = router.oneshot(post_request).await.unwrap();
    assert_eq!(post_response.status(), StatusCode::OK);
    let post_body = axum::body::to_bytes(post_response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&post_body[..], b"created");
}
/// The htmx handler serves the embedded script with a JavaScript content
/// type, an `immutable` cache policy, and a body whose length matches
/// `HTMX_JS` and whose opening bytes look like JavaScript.
#[cfg(feature = "htmx")]
#[tokio::test]
async fn htmx_handler_returns_javascript_with_correct_headers() {
    let app = axum::Router::new().route(
        crate::htmx::HTMX_JS_PATH,
        axum::routing::get(crate::router::htmx_handler),
    );
    let request = Request::builder()
        .uri(crate::htmx::HTMX_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let content_type = response
        .headers()
        .get("content-type")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(
        content_type.contains("application/javascript"),
        "Expected application/javascript, got {content_type}"
    );

    let cache_control = response
        .headers()
        .get("cache-control")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(
        cache_control.contains("immutable"),
        "Expected immutable cache, got {cache_control}"
    );

    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(body.len(), crate::htmx::HTMX_JS.len());

    // Clamp the sniffed prefix to the body length so a short body fails the
    // content assertion below instead of panicking on an out-of-range slice.
    let prefix_len = body.len().min(50);
    let start = std::str::from_utf8(&body[..prefix_len]).expect("htmx should be valid UTF-8");
    assert!(
        start.contains("htmx") || start.contains("function"),
        "Response doesn't look like htmx JavaScript: {start}"
    );
}
/// The CSRF helper is served as plain `application/javascript` and contains
/// the htmx request hook and CSRF header name, with no inline `<script` tag.
#[cfg(feature = "htmx")]
#[tokio::test]
async fn htmx_csrf_handler_returns_csp_compatible_javascript() {
    let app = axum::Router::new().route(
        crate::htmx::HTMX_CSRF_JS_PATH,
        axum::routing::get(crate::router::htmx_csrf_handler),
    );
    let request = Request::builder()
        .uri(crate::htmx::HTMX_CSRF_JS_PATH)
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok());
    assert_eq!(content_type, Some("application/javascript"));

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let js = std::str::from_utf8(&payload).expect("csrf helper should be valid utf-8");
    assert!(js.contains("htmx:configRequest"));
    assert!(js.contains("X-CSRF-Token"));
    assert!(!js.contains("<script"));
}
/// The fully built router exposes the htmx script at `HTMX_JS_PATH` with a
/// JavaScript content type.
#[cfg(feature = "htmx")]
#[tokio::test]
async fn build_router_serves_htmx_js() {
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let request = Request::builder()
        .uri(crate::htmx::HTMX_JS_PATH)
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let content_type = response
        .headers()
        .get("content-type")
        .unwrap()
        .to_str()
        .unwrap();
    assert!(content_type.contains("javascript"));
}
/// The fully built router serves the CSRF helper script and still attaches a
/// `content-security-policy` header containing `script-src 'self'`.
#[cfg(feature = "htmx")]
#[tokio::test]
async fn build_router_serves_htmx_csrf_js() {
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let request = Request::builder()
        .uri(crate::htmx::HTMX_CSRF_JS_PATH)
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let csp = response
        .headers()
        .get("content-security-policy")
        .expect("framework JS should still receive security headers")
        .to_str()
        .unwrap();
    assert!(csp.contains("script-src 'self'"), "csp = {csp}");

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let js = std::str::from_utf8(&payload).expect("csrf helper should be valid utf-8");
    assert!(js.contains("htmx:configRequest"));
    assert!(js.contains("X-CSRF-Token"));
}
/// When no favicon route is registered, the default favicon path answers
/// 204 No Content with an empty body and security headers attached.
#[tokio::test]
async fn build_router_serves_default_favicon_without_404() {
    let router = test_router(vec![test_get_route("/dummy", "dummy")]);
    let request = Request::builder()
        .uri(crate::router::DEFAULT_FAVICON_PATH)
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::NO_CONTENT);
    assert!(
        response.headers().contains_key("content-security-policy"),
        "framework fallback responses should still receive security headers"
    );

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert!(payload.is_empty());
}
/// A user route registered at the default favicon path takes precedence
/// over the framework fallback.
#[tokio::test]
async fn build_router_does_not_override_user_favicon_route() {
    let user_favicon = test_get_route(crate::router::DEFAULT_FAVICON_PATH, "favicon");
    let router = test_router(vec![user_favicon]);
    let request = Request::builder()
        .uri(crate::router::DEFAULT_FAVICON_PATH)
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(&payload[..], b"ok");
}
/// With a `dist` directory containing a manifest entry for `/docs`, a
/// request to `/docs/` is answered from the pre-rendered file and still
/// receives a `content-security-policy` header.
#[tokio::test]
async fn build_router_serves_static_files_for_unmatched_paths() {
    use std::collections::HashMap;

    // Lay out dist/docs/index.html plus a manifest that points at it.
    let workspace = tempfile::tempdir().expect("tempdir");
    let dist = workspace.path().join("dist");
    std::fs::create_dir_all(dist.join("docs")).expect("mkdir");
    std::fs::write(dist.join("docs/index.html"), "<h1>Static Docs</h1>").expect("write");

    let docs_entry = crate::static_gen::ManifestEntry {
        file: "docs/index.html".to_owned(),
        revalidate: None,
    };
    let manifest = crate::static_gen::StaticManifest {
        generated_at: "2026-03-27T00:00:00Z".to_owned(),
        autumn_version: "0.2.0".to_owned(),
        routes: HashMap::from([("/docs".to_owned(), docs_entry)]),
    };
    let manifest_json = serde_json::to_string(&manifest).expect("serialize");
    std::fs::write(dist.join("manifest.json"), manifest_json).expect("write manifest");

    let default_config = AutumnConfig::default();
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/other", "other_page")],
        &default_config,
        app_state,
        Some(dist.as_path()),
    );

    let request = Request::builder()
        .uri("/docs/")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let csp = response
        .headers()
        .get("content-security-policy")
        .expect("static-first HTML should still receive security headers")
        .to_str()
        .unwrap();
    assert!(csp.contains("script-src 'self'"), "csp = {csp}");

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    assert_eq!(std::str::from_utf8(&payload).unwrap(), "<h1>Static Docs</h1>");
}
/// With `AUTUMN_BUILD_STATIC=1` set, static rendering succeeds even though
/// the state is created with `with_startup_complete(false)`.
#[tokio::test]
async fn build_mode_static_rendering_bypasses_startup_barrier() {
    let scenario = async {
        let default_config = AutumnConfig::default();
        let pending_state = AppState::for_test().with_startup_complete(false);

        let about_route = Route {
            method: http::Method::GET,
            path: "/about",
            handler: axum::routing::get(|| async { "About Page Content" }),
            name: "about",
            api_doc: crate::openapi::ApiDoc {
                method: "GET",
                path: "/about",
                operation_id: "about",
                success_status: 200,
                ..Default::default()
            },
        };
        let router =
            crate::router::build_router(vec![about_route], &default_config, pending_state);

        let workspace = tempfile::tempdir().unwrap();
        let dist = workspace.path().join("dist");
        let about_meta = crate::static_gen::StaticRouteMeta {
            path: "/about",
            name: "about",
            revalidate: None,
            params_fn: None,
        };
        let result = crate::static_gen::render_static_routes(router, &[about_meta], &dist).await;
        assert!(result.is_ok(), "build failed: {:?}", result.err());

        let html = std::fs::read_to_string(dist.join("about/index.html")).unwrap();
        assert_eq!(html, "About Page Content");
    };
    temp_env::async_with_vars([("AUTUMN_BUILD_STATIC", Some("1"))], scenario).await;
}
/// With dev reload enabled through the environment, HTML responses contain
/// a reference to `/__autumn/live-reload`.
#[tokio::test]
async fn build_router_injects_live_reload_script_when_enabled() {
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");

    let env_vars = [
        ("AUTUMN_DEV_RELOAD", Some("1")),
        (
            "AUTUMN_DEV_RELOAD_STATE",
            Some(reload_file.path().to_str().expect("utf-8 path")),
        ),
    ];
    let scenario = async {
        let page_route = Route {
            method: http::Method::GET,
            path: "/page",
            handler: axum::routing::get(|| async {
                axum::response::Html("<html><body><main>ok</main></body></html>")
            }),
            name: "page",
            api_doc: crate::openapi::ApiDoc {
                method: "GET",
                path: "/page",
                operation_id: "page",
                success_status: 200,
                ..Default::default()
            },
        };
        let router = test_router(vec![page_route]);

        let request = Request::builder().uri("/page").body(Body::empty()).unwrap();
        let response = router.oneshot(request).await.unwrap();
        let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let html = std::str::from_utf8(&payload).expect("utf-8");
        assert!(html.contains("/__autumn/live-reload"));
    };
    temp_env::async_with_vars(env_vars, scenario).await;
}
/// With dev reload enabled, `/__autumn/live-reload.js` serves a JavaScript
/// document (UTF-8 content type) whose body contains a `fetch(` call.
#[tokio::test]
async fn build_router_mounts_dev_reload_script_endpoint_when_enabled() {
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");

    let env_vars = [
        ("AUTUMN_DEV_RELOAD", Some("1")),
        (
            "AUTUMN_DEV_RELOAD_STATE",
            Some(reload_file.path().to_str().expect("utf-8 path")),
        ),
    ];
    let scenario = async {
        let router = test_router(vec![test_get_route("/dummy", "dummy")]);
        let request = Request::builder()
            .uri("/__autumn/live-reload.js")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok());
        assert_eq!(content_type, Some("application/javascript; charset=utf-8"));

        let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let js = std::str::from_utf8(&payload).expect("utf-8");
        assert!(js.contains("fetch("), "js body: {js}");
    };
    temp_env::async_with_vars(env_vars, scenario).await;
}
/// With dev reload enabled, `/__autumn/live-reload` returns the contents of
/// the reload state file with caching disabled.
#[tokio::test]
async fn build_router_mounts_dev_reload_endpoint_when_enabled() {
    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":7,"kind":"css"}"#).expect("write");

    let env_vars = [
        ("AUTUMN_DEV_RELOAD", Some("1")),
        (
            "AUTUMN_DEV_RELOAD_STATE",
            Some(reload_file.path().to_str().expect("utf-8 path")),
        ),
    ];
    let scenario = async {
        let router = test_router(vec![test_get_route("/dummy", "dummy")]);
        let request = Request::builder()
            .uri("/__autumn/live-reload")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("cache-control").unwrap(),
            "no-store, no-cache, must-revalidate"
        );

        let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&payload[..], br#"{"version":7,"kind":"css"}"#);
    };
    temp_env::async_with_vars(env_vars, scenario).await;
}
/// With dev reload enabled and `AUTUMN_MANIFEST_DIR` pointing at a project
/// containing `static/demo.txt`, that asset is served with caching disabled.
#[tokio::test]
async fn build_router_disables_cache_for_static_assets_in_dev_reload_mode() {
    let project = tempfile::tempdir().expect("project dir");
    let assets_dir = project.path().join("static");
    std::fs::create_dir_all(&assets_dir).expect("mkdir");
    std::fs::write(assets_dir.join("demo.txt"), "hello").expect("write static file");

    let reload_file = tempfile::NamedTempFile::new().expect("reload state file");
    std::fs::write(reload_file.path(), r#"{"version":0,"kind":"full"}"#).expect("write");

    let env_vars = [
        (
            "AUTUMN_MANIFEST_DIR",
            Some(project.path().to_str().expect("utf-8 path")),
        ),
        ("AUTUMN_DEV_RELOAD", Some("1")),
        (
            "AUTUMN_DEV_RELOAD_STATE",
            Some(reload_file.path().to_str().expect("utf-8 path")),
        ),
    ];
    let scenario = async {
        let router = test_router(vec![test_get_route("/dummy", "dummy")]);
        let request = Request::builder()
            .uri("/static/demo.txt")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("cache-control").unwrap(),
            "no-store, no-cache, must-revalidate"
        );
    };
    temp_env::async_with_vars(env_vars, scenario).await;
}
/// `static_routes` stores the supplied metadata on the builder.
#[test]
fn app_builder_accepts_static_routes() {
    use crate::static_gen::StaticRouteMeta;

    let about = StaticRouteMeta {
        path: "/about",
        name: "about",
        revalidate: None,
        params_fn: None,
    };
    let builder = app().static_routes(vec![about]);
    assert_eq!(builder.static_metas.len(), 1);
}
/// With a freshly created `MockEnv`, `project_dir` returns the bare
/// sub-directory as a relative path.
#[test]
fn project_dir_defaults_to_subdir() {
    let empty_env = crate::config::MockEnv::new();
    let resolved = super::project_dir("dist", &empty_env);
    let expected = std::path::PathBuf::from("dist");
    assert_eq!(resolved, expected);
}
/// Builds a router for `routes` using the caller-supplied `config` and a
/// hand-assembled `AppState` (no database pool, no profile, detailed health
/// enabled).
pub fn test_router_with_config(routes: Vec<Route>, config: &AutumnConfig) -> axum::Router {
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    crate::router::build_router(routes, config, app_state)
}
/// With `allowed_origins = ["*"]`, a cross-origin request is answered with
/// `access-control-allow-origin: *`.
#[tokio::test]
async fn cors_wildcard_allows_any_origin() {
    let mut cors_config = AutumnConfig::default();
    cors_config.cors.allowed_origins = vec!["*".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &cors_config);

    let request = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let allow_origin = response
        .headers()
        .get("access-control-allow-origin")
        .unwrap();
    assert_eq!(allow_origin, "*");
}
/// With a single allowed origin configured, a request from that origin gets
/// it echoed back in `access-control-allow-origin`.
#[tokio::test]
async fn cors_specific_origin_reflected() {
    let mut cors_config = AutumnConfig::default();
    cors_config.cors.allowed_origins = vec!["https://example.com".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &cors_config);

    let request = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let allow_origin = response
        .headers()
        .get("access-control-allow-origin")
        .unwrap();
    assert_eq!(allow_origin, "https://example.com");
}
/// The default configuration has no allowed origins, so responses carry no
/// `access-control-allow-origin` header even when an `Origin` is sent.
#[tokio::test]
async fn cors_disabled_when_no_origins() {
    let default_config = AutumnConfig::default();
    assert!(default_config.cors.allowed_origins.is_empty());
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &default_config);

    let request = Request::builder()
        .uri("/test")
        .header("Origin", "https://example.com")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let allow_origin = response.headers().get("access-control-allow-origin");
    assert!(allow_origin.is_none());
}
/// An `OPTIONS` preflight from an allowed origin succeeds and advertises
/// `access-control-allow-methods`.
///
/// NOTE(review): the test name says 204, but the assertion below expects
/// 200 OK — confirm which status is intended and rename or adjust.
#[tokio::test]
async fn cors_preflight_returns_204() {
    let mut cors_config = AutumnConfig::default();
    cors_config.cors.allowed_origins = vec!["https://example.com".to_owned()];
    let router = test_router_with_config(vec![test_get_route("/test", "test")], &cors_config);

    let preflight = Request::builder()
        .method("OPTIONS")
        .uri("/test")
        .header("Origin", "https://example.com")
        .header("Access-Control-Request-Method", "GET")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(preflight).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    let headers = response.headers();
    assert!(headers.contains_key("access-control-allow-methods"));
}
/// A `dist` directory without a manifest does not prevent regular routes
/// from being served.
#[tokio::test]
async fn build_router_with_static_skips_without_manifest() {
    let workspace = tempfile::tempdir().expect("tempdir");
    let dist = workspace.path().join("dist");
    std::fs::create_dir_all(&dist).expect("mkdir");

    let default_config = AutumnConfig::default();
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/test", "test")],
        &default_config,
        app_state,
        Some(dist.as_path()),
    );

    let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// Passing `None` as the dist directory still yields a router that serves
/// regular routes.
#[tokio::test]
async fn build_router_with_static_none_dist() {
    let default_config = AutumnConfig::default();
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let router = crate::router::build_router_with_static(
        vec![test_get_route("/test", "test")],
        &default_config,
        app_state,
        None,
    );

    let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// The route listing mentions each user route's path and handler name.
#[test]
fn format_route_lines_lists_user_routes() {
    let default_config = AutumnConfig::default();
    let user_routes = vec![
        test_get_route("/", "index"),
        test_get_route("/users/{id}", "get_user"),
    ];

    let listing = format_route_lines(&user_routes, &[], &default_config);

    assert!(listing.contains("-> index"));
    assert!(listing.contains("/ GET"));
    assert!(listing.contains("/users/{id}"));
    assert!(listing.contains("-> get_user"));
}
/// A customised `actuator.prefix` shows up in the route listing.
#[test]
fn config_runtime_drift_format_route_lines_uses_actuator_prefix() {
    let mut ops_config = AutumnConfig::default();
    ops_config.actuator.prefix = "/ops".to_owned();

    let listing = format_route_lines(&[], &[], &ops_config);

    assert!(listing.contains("-> health"));
    assert!(listing.contains("/ops/*"));
}
/// An empty task list produces no task section at all.
#[test]
fn format_task_lines_none_when_empty() {
    let rendered = format_task_lines(&[]);
    assert!(rendered.is_none());
}
/// A fixed-delay task is rendered as `<name> (every <secs>s)`.
#[test]
fn format_task_lines_fixed_delay() {
    let cleanup = crate::task::TaskInfo {
        name: "cleanup".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_secs(300)),
        handler: |_| Box::pin(async { Ok(()) }),
    };
    let tasks = vec![cleanup];

    let rendered = format_task_lines(&tasks).unwrap();
    assert!(rendered.contains("cleanup (every 300s)"));
}
/// A cron task is rendered as `<name> (cron <expression>)`.
#[test]
fn format_task_lines_cron() {
    let nightly = crate::task::TaskInfo {
        name: "nightly".into(),
        schedule: crate::task::Schedule::Cron {
            expression: "0 0 * * *".into(),
            timezone: None,
        },
        handler: |_| Box::pin(async { Ok(()) }),
    };
    let tasks = vec![nightly];

    let rendered = format_task_lines(&tasks).unwrap();
    assert!(rendered.contains("nightly (cron 0 0 * * *)"));
}
/// The default configuration lists the baseline middleware and omits CORS
/// and CSRF.
#[test]
fn format_middleware_list_default() {
    let default_config = AutumnConfig::default();
    let listing = format_middleware_list(&default_config);

    // Always-on middleware.
    assert!(listing.contains("RequestId"));
    assert!(listing.contains("SecurityHeaders"));
    assert!(listing.contains("Session (in-memory)"));
    assert!(listing.contains("Metrics"));

    // Opt-in middleware must not appear by default.
    assert!(!listing.contains("CORS"));
    assert!(!listing.contains("CSRF"));
}
/// Configuring an allowed origin and enabling CSRF makes both middleware
/// appear in the listing.
#[test]
fn format_middleware_list_with_cors_and_csrf() {
    let cors = crate::config::CorsConfig {
        allowed_origins: vec!["https://example.com".into()],
        ..crate::config::CorsConfig::default()
    };
    let csrf = crate::security::config::CsrfConfig {
        enabled: true,
        ..crate::security::config::CsrfConfig::default()
    };
    let security = crate::security::config::SecurityConfig {
        csrf,
        ..crate::security::config::SecurityConfig::default()
    };
    let config = AutumnConfig {
        cors,
        security,
        ..AutumnConfig::default()
    };

    let listing = format_middleware_list(&config);
    assert!(listing.contains("CORS"));
    assert!(listing.contains("CSRF"));
}
/// A URL with a password has the password replaced by `****` while the rest
/// of the URL and the pool size remain visible.
#[test]
fn mask_database_url_with_password() {
    let rendered = mask_database_url("postgres://user:secret@localhost:5432/mydb", 10);

    assert!(rendered.contains("****"));
    assert!(!rendered.contains("secret"));
    assert!(rendered.contains("postgres://user:****@localhost:5432/mydb"));
    assert!(rendered.contains("pool_size=10"));
}
/// A URL without credentials is shown unchanged, alongside the pool size.
#[test]
fn mask_database_url_without_password() {
    let rendered = mask_database_url("postgres://localhost/mydb", 5);

    assert!(!rendered.contains("****"));
    assert!(rendered.contains("postgres://localhost/mydb"));
    assert!(rendered.contains("pool_size=5"));
}
/// Percent-encoded passwords and passwords with an empty username are both
/// masked.
#[test]
fn mask_database_url_edge_cases() {
    // Percent-encoded password.
    let encoded = mask_database_url("postgres://user:p%40ssw%3Ard%21@localhost:5432/mydb", 10);
    assert!(encoded.contains("****"));
    assert!(!encoded.contains("p%40ssw%3Ard%21"));
    assert!(encoded.contains("postgres://user:****@localhost:5432/mydb"));

    // Password with no username.
    let anonymous = mask_database_url("postgres://:secret@localhost:5432/mydb", 10);
    assert!(anonymous.contains("****"));
    assert!(!anonymous.contains("secret"));
    assert!(anonymous.contains("postgres://:****@localhost:5432/mydb"));
}
/// Input that is not a valid URL is fully hidden behind `****`, with the
/// pool size still reported.
#[test]
fn mask_database_url_invalid_url_fallback() {
    let not_a_url = "this is completely invalid as a URL with supersecret";
    let rendered = mask_database_url(not_a_url, 10);

    assert!(rendered.contains("****"));
    assert!(!rendered.contains("supersecret"));
    assert!(rendered.contains("pool_size=10"));
}
/// The summary for the default configuration reports the expected default
/// profile, server address, database, telemetry and health values.
#[test]
fn format_config_summary_defaults() {
    let default_config = AutumnConfig::default();
    let summary = format_config_summary(&default_config);

    assert!(summary.contains("profile: none"));
    assert!(summary.contains("server: 127.0.0.1:3000"));
    assert!(summary.contains("database: not configured"));
    assert!(summary.contains("log_level:"));
    assert!(summary.contains("telemetry: disabled"));
    assert!(summary.contains("health: /health"));
}
/// With a database URL configured, the summary shows the masked URL and
/// pool size and never the raw password.
#[test]
fn format_config_summary_with_db() {
    let database = crate::config::DatabaseConfig {
        url: Some("postgres://user:pass@host/db".into()),
        pool_size: 20,
        ..crate::config::DatabaseConfig::default()
    };
    let config = AutumnConfig {
        database,
        ..AutumnConfig::default()
    };

    let summary = format_config_summary(&config);
    assert!(summary.contains("user:****@host/db"));
    assert!(summary.contains("pool_size=20"));
    assert!(!summary.contains("pass"));
}
/// A configured profile name is shown in the summary.
#[test]
fn format_config_summary_with_profile() {
    let prod_config = AutumnConfig {
        profile: Some("prod".into()),
        ..AutumnConfig::default()
    };

    let summary = format_config_summary(&prod_config);
    assert!(summary.contains("profile: prod"));
}
/// With telemetry enabled, the summary shows the protocol followed by the
/// configured OTLP endpoint.
#[test]
fn format_config_summary_with_telemetry() {
    let telemetry = crate::config::TelemetryConfig {
        enabled: true,
        service_name: "orders-api".into(),
        otlp_endpoint: Some("http://otel-collector:4317".into()),
        ..crate::config::TelemetryConfig::default()
    };
    let config = AutumnConfig {
        telemetry,
        ..AutumnConfig::default()
    };

    let summary = format_config_summary(&config);
    assert!(summary.contains("telemetry: Grpc -> http://otel-collector:4317"));
}
/// Smoke test: calls `log_startup_transparency` with one route, one
/// fixed-delay task, an empty third argument, and the default config. The
/// test has no assertions; it passes as long as the call does not panic.
#[test]
fn log_startup_transparency_runs_without_panic() {
    let registered_routes = vec![test_get_route("/", "index")];

    let cleanup_every_minute =
        crate::task::Schedule::FixedDelay(std::time::Duration::from_secs(60));
    let registered_tasks = vec![crate::task::TaskInfo {
        name: "cleanup".into(),
        schedule: cleanup_every_minute,
        handler: |_| Box::pin(async { Ok(()) }),
    }];

    log_startup_transparency(
        &registered_routes,
        &registered_tasks,
        &[],
        &AutumnConfig::default(),
    );
}
/// Smoke test: calls `log_startup_transparency` with one route and empty
/// slices for the remaining list arguments. The test has no assertions; it
/// passes as long as the call does not panic.
#[test]
fn log_startup_transparency_no_tasks() {
    let registered_routes = vec![test_get_route("/health", "check")];

    log_startup_transparency(&registered_routes, &[], &[], &AutumnConfig::default());
}
/// Runs a succeeding 1 ms fixed-delay task through `start_task_scheduler` and
/// checks the JSON messages received on the `sys:tasks` channel: first a
/// `started` event, then a `success` event that carries a `duration_ms` field.
/// Each receive is bounded by a one-second timeout.
#[cfg(feature = "ws")]
#[tokio::test]
async fn start_task_scheduler_broadcasts_events() {
    let app_state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        channels: crate::channels::Channels::new(32),
        shutdown: tokio_util::sync::CancellationToken::new(),
    };

    // The subscription is created before the scheduler is spawned.
    let mut events = app_state.channels().subscribe("sys:tasks");

    let scheduler_state = app_state.clone();
    let broadcaster = crate::task::TaskInfo {
        name: "test_broadcaster".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_millis(1)),
        handler: |_| Box::pin(async { Ok(()) }),
    };
    tokio::spawn(async move {
        super::start_task_scheduler(
            vec![broadcaster],
            &scheduler_state,
            tokio_util::sync::CancellationToken::new(),
        );
    });

    let per_event_wait = std::time::Duration::from_secs(1);

    let started_raw = tokio::time::timeout(per_event_wait, events.recv())
        .await
        .expect("timeout waiting for start event")
        .expect("channel closed");
    let started: serde_json::Value = serde_json::from_str(started_raw.as_str()).unwrap();
    assert_eq!(started["event"], "started");
    assert_eq!(started["task"], "test_broadcaster");

    let finished_raw = tokio::time::timeout(per_event_wait, events.recv())
        .await
        .expect("timeout waiting for success event")
        .expect("channel closed");
    let finished: serde_json::Value = serde_json::from_str(finished_raw.as_str()).unwrap();
    assert_eq!(finished["event"], "success");
    assert_eq!(finished["task"], "test_broadcaster");
    assert!(finished.get("duration_ms").is_some());
}
/// Runs a failing 1 ms fixed-delay task through `start_task_scheduler` and
/// checks that the second JSON message received on the `sys:tasks` channel is
/// a `failure` event naming the task and carrying the handler's error text.
///
/// Both receives are bounded by a one-second timeout so that a scheduler which
/// never publishes makes the test fail instead of hanging the test run.
#[cfg(feature = "ws")]
#[tokio::test]
async fn start_task_scheduler_broadcasts_failure_events() {
    let state = AppState {
        extensions: std::sync::Arc::new(std::sync::RwLock::new(
            std::collections::HashMap::new(),
        )),
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: true,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        channels: crate::channels::Channels::new(32),
        shutdown: tokio_util::sync::CancellationToken::new(),
    };
    let mut rx = state.channels().subscribe("sys:tasks");
    let task = crate::task::TaskInfo {
        name: "test_failing_task".into(),
        schedule: crate::task::Schedule::FixedDelay(std::time::Duration::from_millis(1)),
        handler: |_| {
            Box::pin(async { Err(crate::AutumnError::bad_request_msg("forced error")) })
        },
    };
    let state_clone = state.clone();
    tokio::spawn(async move {
        super::start_task_scheduler(
            vec![task],
            &state_clone,
            tokio_util::sync::CancellationToken::new(),
        );
    });

    let recv_timeout = std::time::Duration::from_secs(1);

    // The first message is consumed and discarded; only the second one is
    // inspected. The wait is bounded like every other receive in these tests.
    let _ = tokio::time::timeout(recv_timeout, rx.recv())
        .await
        .expect("timeout waiting for first event")
        .expect("channel closed");

    let msg2 = tokio::time::timeout(recv_timeout, rx.recv())
        .await
        .expect("timeout waiting for failure event")
        .expect("channel closed");
    let json2: serde_json::Value = serde_json::from_str(msg2.as_str()).unwrap();
    assert_eq!(json2["event"], "failure");
    assert_eq!(json2["task"], "test_failing_task");
    assert_eq!(json2["error"], "forced error");
}
/// Calls `execute_task_result` with a handler that returns `Ok(())` and checks
/// that the result is `Ok` with a duration value below `u64::MAX`.
#[tokio::test]
async fn execute_task_result_ok_returns_duration() {
    let app_state = AppState::for_test();
    let succeeding: crate::task::TaskHandler = |_| Box::pin(async { Ok(()) });
    let started_at = std::time::Instant::now();

    let outcome =
        super::execute_task_result(&app_state, succeeding, started_at, "test_task", "fixed_delay")
            .await;

    assert!(outcome.is_ok(), "expected Ok from successful handler");
    let elapsed_ms = outcome.unwrap();
    assert!(elapsed_ms < u64::MAX);
}
/// Calls `execute_task_result` with a handler that returns a bad-request error
/// and checks that the result is `Err` holding a duration below `u64::MAX`
/// and a message containing `test error`.
#[tokio::test]
async fn execute_task_result_err_returns_duration_and_message() {
    let app_state = AppState::for_test();
    let failing: crate::task::TaskHandler =
        |_| Box::pin(async { Err(crate::AutumnError::bad_request_msg("test error")) });
    let started_at = std::time::Instant::now();

    let outcome =
        super::execute_task_result(&app_state, failing, started_at, "test_task", "fixed_delay")
            .await;

    match outcome {
        Err((elapsed_ms, message)) => {
            assert!(elapsed_ms < u64::MAX);
            assert!(message.contains("test error"));
        }
        Ok(_) => panic!("expected Err from failing handler"),
    }
}
}