use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use crate::cache::Cache;
use crate::time::{ClockSource, SystemClock};
/// Newtype around a shared [`Cache`] handle.
///
/// `AppState::set_cache` stores one of these as a runtime extension and
/// `AppState::cache` reads it back in preference to the `shared_cache` field.
pub struct GlobalCacheEntry(pub Arc<dyn Cache>);
use crate::actuator;
use crate::authorization::{ForbiddenResponse, Policy, PolicyRegistry, Scope};
#[cfg(feature = "ws")]
use crate::channels::Channels;
#[cfg(feature = "db")]
use crate::db::DbState;
use crate::middleware;
#[cfg(feature = "presence")]
use crate::presence::Presence;
use crate::probe;
#[cfg(feature = "ws")]
use tokio_util::sync::CancellationToken;
/// Shared application state: runtime extensions, optional database pools,
/// probe and actuator registries, authorization settings, and the cache and
/// clock handles.
///
/// The type derives `Clone`. The extension map sits behind an `Arc`, so every
/// clone observes the same set of extensions.
#[derive(Clone)]
#[non_exhaustive]
pub struct AppState {
/// Runtime extensions keyed by `TypeId` (one value per type), shared between clones.
pub(crate) extensions: Arc<std::sync::RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
/// Primary database pool, if one has been attached (see `with_pool`).
#[cfg(feature = "db")]
pub(crate) pool:
Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
/// Replica database pool, if one has been attached (see `with_replica_pool`).
#[cfg(feature = "db")]
pub(crate) replica_pool:
Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
/// Profile name; `profile()` reports `"default"` while this is `None`.
pub(crate) profile: Option<String>,
/// Instant from which `uptime()` and `uptime_display()` are measured.
pub(crate) started_at: std::time::Instant,
/// Flag returned by the `health_detailed` trait accessors.
/// NOTE(review): presumably toggles detailed health output — confirm.
pub(crate) health_detailed: bool,
/// Probe state consulted by `read_pool` and updated by the startup/shutdown helpers.
pub(crate) probes: probe::ProbeState,
/// Metrics collector exposed through `metrics()`.
pub(crate) metrics: middleware::MetricsCollector,
/// Log-level state exposed through `log_levels()`.
pub(crate) log_levels: actuator::LogLevels,
/// Task registry exposed through `task_registry()`.
pub(crate) task_registry: actuator::TaskRegistry,
/// Job registry exposed through `job_registry()`.
pub(crate) job_registry: actuator::JobRegistry,
/// Configuration properties exposed through `config_props()`.
pub(crate) config_props: actuator::ConfigProperties,
/// Metrics-source registry exposed through `metrics_source_registry()`.
pub(crate) metrics_source_registry: actuator::MetricsSourceRegistry,
/// Health-indicator registry exposed through `health_indicator_registry()`.
pub(crate) health_indicator_registry: actuator::HealthIndicatorRegistry,
/// Channel registry; `broadcast()` obtains its handle from here.
#[cfg(feature = "ws")]
pub(crate) channels: Channels,
/// Presence tracker; `detached()` builds it from a clone of `channels`.
#[cfg(feature = "presence")]
pub(crate) presence: Presence,
/// Root cancellation token: `shutdown_token()` hands out child tokens and
/// `trigger_shutdown_for_test()` cancels it.
#[cfg(feature = "ws")]
pub(crate) shutdown: CancellationToken,
/// Registry queried by `policy()` and `scope()`.
pub(crate) policy_registry: PolicyRegistry,
/// Value returned by `forbidden_response()`.
/// NOTE(review): presumably selects how denied requests are answered — confirm.
pub(crate) forbidden_response: ForbiddenResponse,
/// Key returned by `auth_session_key()`; `detached()` initializes it to `"user_id"`.
pub(crate) auth_session_key: String,
/// Cache installed with `with_cache`; `cache()` falls back to it only when no
/// `GlobalCacheEntry` extension is registered.
pub(crate) shared_cache: Option<Arc<dyn Cache>>,
/// Clock source returned by `clock()`; `detached()` installs `SystemClock`.
pub(crate) clock: Arc<dyn ClockSource>,
}
impl AppState {
    /// Stores `value` as the runtime extension for type `T`.
    ///
    /// Entries are keyed by [`TypeId`], so inserting a second value of the
    /// same type replaces the earlier one.
    ///
    /// # Panics
    ///
    /// Panics if the extension lock is poisoned.
    pub fn insert_extension<T>(&self, value: T)
    where
        T: Any + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.extensions
            .write()
            .expect("app state extension lock poisoned")
            .insert(key, erased);
    }

    /// Looks up the runtime extension registered for type `T`.
    ///
    /// Returns `None` when nothing has been inserted for `T`.
    ///
    /// # Panics
    ///
    /// Panics if the extension lock is poisoned.
    #[must_use]
    pub fn extension<T>(&self) -> Option<Arc<T>>
    where
        T: Any + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        self.extensions
            .read()
            .expect("app state extension lock poisoned")
            .get(&key)
            .cloned()
            .and_then(|erased| erased.downcast::<T>().ok())
    }

    /// Returns the reporters held by the `RegisteredReporters` extension, or
    /// an empty list when that extension is absent.
    #[cfg(feature = "reporting")]
    #[must_use]
    pub(crate) fn error_reporters(
        &self,
    ) -> Vec<std::sync::Arc<dyn crate::reporting::ErrorReporter>> {
        self.extension::<crate::reporting::RegisteredReporters>()
            .map_or_else(Vec::new, |registered| registered.0.clone())
    }

    /// Returns the primary database pool, if one is attached.
    #[cfg(feature = "db")]
    #[must_use]
    pub const fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.pool.as_ref()
    }

    /// Returns the replica database pool, if one is attached.
    #[cfg(feature = "db")]
    #[must_use]
    pub const fn replica_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.replica_pool.as_ref()
    }

    /// Chooses the pool that read queries should use.
    ///
    /// * No replica attached: the primary pool (which may itself be `None`).
    /// * Replica attached and the probes route reads to it: the replica pool.
    /// * Replica attached and the probes allow falling back: the primary pool.
    /// * Replica attached but neither applies: `None`.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn read_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        // Without a replica the probes are never consulted.
        if self.replica_pool.is_none() {
            return self.pool.as_ref();
        }

        if self.probes.should_route_reads_to_replica() {
            self.replica_pool.as_ref()
        } else if self.probes.should_fallback_reads_to_primary() {
            self.pool.as_ref()
        } else {
            None
        }
    }

    /// Returns the metrics collector.
    #[must_use]
    pub const fn metrics(&self) -> &middleware::MetricsCollector {
        &self.metrics
    }

    /// Returns the log-level state.
    #[must_use]
    pub const fn log_levels(&self) -> &actuator::LogLevels {
        &self.log_levels
    }

    /// Returns the task registry.
    #[must_use]
    pub const fn task_registry(&self) -> &actuator::TaskRegistry {
        &self.task_registry
    }

    /// Returns the job registry.
    #[must_use]
    pub const fn job_registry(&self) -> &actuator::JobRegistry {
        &self.job_registry
    }

    /// Returns the configuration properties.
    #[must_use]
    pub const fn config_props(&self) -> &actuator::ConfigProperties {
        &self.config_props
    }

    /// Returns the metrics-source registry.
    #[must_use]
    pub const fn metrics_source_registry(&self) -> &actuator::MetricsSourceRegistry {
        &self.metrics_source_registry
    }

    /// Returns the health-indicator registry.
    #[must_use]
    pub const fn health_indicator_registry(&self) -> &actuator::HealthIndicatorRegistry {
        &self.health_indicator_registry
    }

    /// Returns a clone of the `AutumnConfig` extension, or
    /// `AutumnConfig::default()` when no such extension is registered.
    #[must_use]
    pub fn config(&self) -> crate::config::AutumnConfig {
        self.extension::<crate::config::AutumnConfig>()
            .map_or_else(crate::config::AutumnConfig::default, |stored| {
                (*stored).clone()
            })
    }

    /// Returns the probe state.
    #[must_use]
    pub const fn probes(&self) -> &probe::ProbeState {
        &self.probes
    }

    /// Forwards to `ProbeState::mark_startup_complete`.
    pub fn mark_startup_complete(&self) {
        self.probes.mark_startup_complete();
    }

    /// Forwards to `ProbeState::begin_shutdown`.
    pub fn begin_shutdown(&self) {
        self.probes.begin_shutdown();
    }

    /// Builder: attaches `pool` as the primary database pool.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn with_pool(
        mut self,
        pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
    ) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Builder: attaches `pool` as the replica database pool.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn with_replica_pool(
        mut self,
        pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
    ) -> Self {
        self.replica_pool = Some(pool);
        self
    }

    /// Builder form of [`insert_extension`](Self::insert_extension).
    ///
    /// # Panics
    ///
    /// Panics if the extension lock is poisoned.
    #[must_use]
    pub fn with_extension<T>(self, value: T) -> Self
    where
        T: Any + Send + Sync + 'static,
    {
        self.insert_extension(value);
        self
    }

    /// Returns the active cache.
    ///
    /// A [`GlobalCacheEntry`] extension (installed by
    /// [`set_cache`](Self::set_cache)) wins; otherwise the cache given to
    /// [`with_cache`](Self::with_cache) is returned, if any.
    #[must_use]
    pub fn cache(&self) -> Option<Arc<dyn Cache>> {
        self.extension::<GlobalCacheEntry>().map_or_else(
            || self.shared_cache.clone(),
            |entry| Some(Arc::clone(&entry.0)),
        )
    }

    /// Builder: sets the fallback cache returned by [`cache`](Self::cache).
    #[must_use]
    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
        self.shared_cache = Some(cache);
        self
    }

    /// Returns the clock source.
    #[must_use]
    pub fn clock(&self) -> &dyn ClockSource {
        self.clock.as_ref()
    }

    /// Builder: replaces the clock source.
    #[must_use]
    pub fn with_clock(mut self, clock: Arc<dyn ClockSource>) -> Self {
        self.clock = clock;
        self
    }

    /// Installs `cache` both through `crate::cache::set_global_cache` and as
    /// a [`GlobalCacheEntry`] extension on this state.
    pub fn set_cache(&self, cache: Arc<dyn Cache>) {
        crate::cache::set_global_cache(Arc::clone(&cache));
        self.insert_extension(GlobalCacheEntry(cache));
    }

    /// Builder: sets the profile name.
    #[must_use]
    pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
        self.profile = Some(profile.into());
        self
    }

    /// Returns the policy registry.
    #[must_use]
    pub const fn policy_registry(&self) -> &PolicyRegistry {
        &self.policy_registry
    }

    /// Looks up the policy registered for resource type `R`.
    #[must_use]
    pub fn policy<R: Send + Sync + 'static>(&self) -> Option<std::sync::Arc<dyn Policy<R>>> {
        self.policy_registry.policy::<R>()
    }

    /// Looks up the scope registered for resource type `R`.
    #[must_use]
    pub fn scope<R: Send + Sync + 'static>(&self) -> Option<std::sync::Arc<dyn Scope<R>>> {
        self.policy_registry.scope::<R>()
    }

    /// Returns the configured forbidden-response value.
    #[must_use]
    pub const fn forbidden_response(&self) -> ForbiddenResponse {
        self.forbidden_response
    }

    /// Returns the configured auth session key.
    #[must_use]
    pub fn auth_session_key(&self) -> &str {
        &self.auth_session_key
    }

    /// Builder: overrides the forbidden-response value.
    #[doc(hidden)]
    #[must_use]
    pub const fn with_forbidden_response(mut self, value: ForbiddenResponse) -> Self {
        self.forbidden_response = value;
        self
    }

    /// Builder: overrides the auth session key.
    #[doc(hidden)]
    #[must_use]
    pub fn with_auth_session_key(mut self, value: impl Into<String>) -> Self {
        self.auth_session_key = value.into();
        self
    }

    /// Builder: sets the startup-complete flag on the probes.
    #[doc(hidden)]
    #[must_use]
    pub fn with_startup_complete(self, startup_complete: bool) -> Self {
        self.probes.set_startup_complete(startup_complete);
        self
    }

    /// Builder: sets the draining flag on the probes.
    #[doc(hidden)]
    #[must_use]
    pub fn with_draining(self, draining: bool) -> Self {
        self.probes.set_draining(draining);
        self
    }

    /// Returns the profile name, or `"default"` when none was set.
    #[must_use]
    pub fn profile(&self) -> &str {
        self.profile.as_deref().unwrap_or("default")
    }

    /// Returns the time elapsed since this state was created.
    #[must_use]
    pub fn uptime(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Formats the uptime in whole seconds as `"{s}s"` (under a minute),
    /// `"{m}m {s}s"` (under an hour) or `"{h}h {m}m"` (an hour or more).
    #[must_use]
    pub fn uptime_display(&self) -> String {
        let elapsed = self.uptime().as_secs();
        match elapsed {
            0..=59 => format!("{elapsed}s"),
            60..=3599 => format!("{}m {}s", elapsed / 60, elapsed % 60),
            _ => {
                let hours = elapsed / 3600;
                let mins = (elapsed % 3600) / 60;
                format!("{hours}h {mins}m")
            }
        }
    }

    /// Returns the channel registry.
    #[cfg(feature = "ws")]
    #[must_use]
    pub const fn channels(&self) -> &Channels {
        &self.channels
    }

    /// Returns the presence tracker.
    #[cfg(feature = "presence")]
    #[must_use]
    pub const fn presence(&self) -> &Presence {
        &self.presence
    }

    /// Returns the broadcast handle produced by `Channels::broadcast`.
    #[cfg(feature = "ws")]
    #[must_use]
    pub fn broadcast(&self) -> crate::channels::Broadcast {
        self.channels.broadcast()
    }

    /// Returns a child of the state's shutdown token.
    #[cfg(feature = "ws")]
    #[must_use]
    pub fn shutdown_token(&self) -> CancellationToken {
        self.shutdown.child_token()
    }

    /// Test hook: begins shutdown on the probes, then cancels the shutdown token.
    #[cfg(feature = "ws")]
    #[doc(hidden)]
    pub fn trigger_shutdown_for_test(&self) {
        self.begin_shutdown();
        self.shutdown.cancel();
    }

    /// Test hook: sets the startup-complete flag on the probes.
    #[doc(hidden)]
    pub fn set_startup_complete_for_test(&self, startup_complete: bool) {
        self.probes.set_startup_complete(startup_complete);
    }

    /// Test hook: sets the draining flag on the probes.
    #[doc(hidden)]
    pub fn set_draining_for_test(&self, draining: bool) {
        self.probes.set_draining(draining);
    }

    /// Test hook: marks the probes as draining.
    #[doc(hidden)]
    pub fn begin_shutdown_for_test(&self) {
        self.probes.set_draining(true);
    }

    /// Builds a self-contained state: no pools, no profile, no shared cache,
    /// empty extensions, freshly constructed registries, probes from
    /// `ProbeState::ready_for_test()`, `"user_id"` as the auth session key and
    /// [`SystemClock`] as the clock.
    #[must_use]
    pub fn detached() -> Self {
        // NOTE(review): 32 is passed straight to `Channels::new`; presumably a
        // channel capacity — confirm against `Channels`.
        #[cfg(feature = "ws")]
        let channels = Channels::new(32);

        Self {
            extensions: Arc::new(std::sync::RwLock::new(HashMap::new())),
            #[cfg(feature = "db")]
            pool: None,
            #[cfg(feature = "db")]
            replica_pool: None,
            profile: None,
            started_at: std::time::Instant::now(),
            health_detailed: true,
            probes: probe::ProbeState::ready_for_test(),
            metrics: middleware::MetricsCollector::new(),
            log_levels: actuator::LogLevels::new("info"),
            task_registry: actuator::TaskRegistry::new(),
            job_registry: actuator::JobRegistry::new(),
            config_props: actuator::ConfigProperties::default(),
            metrics_source_registry: actuator::MetricsSourceRegistry::new(),
            health_indicator_registry: actuator::HealthIndicatorRegistry::new(),
            // `presence` must be listed before `channels`: it clones the
            // handle that the next field moves into the struct.
            #[cfg(feature = "presence")]
            presence: Presence::new(channels.clone()),
            #[cfg(feature = "ws")]
            channels,
            #[cfg(feature = "ws")]
            shutdown: CancellationToken::new(),
            policy_registry: PolicyRegistry::default(),
            forbidden_response: ForbiddenResponse::default(),
            auth_session_key: "user_id".to_owned(),
            shared_cache: None,
            clock: Arc::new(SystemClock),
        }
    }

    /// Alias for [`detached`](Self::detached).
    #[allow(dead_code)]
    #[must_use]
    pub fn for_test() -> Self {
        Self::detached()
    }
}
#[cfg(feature = "db")]
impl DbState for AppState {
    /// Always provides the state's metrics collector.
    fn metrics(&self) -> Option<&crate::middleware::MetricsCollector> {
        Some(&self.metrics)
    }

    /// Primary pool, if attached.
    fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.pool.as_ref()
    }

    /// Replica pool, if attached.
    fn replica_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.replica_pool.as_ref()
    }

    /// Delegates to the inherent `AppState::read_pool` routing logic.
    fn read_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        Self::read_pool(self)
    }

    /// Returns the single interceptor registered as an
    /// `Arc<dyn DbConnectionInterceptor>` extension, or an empty list.
    fn db_interceptors(
        &self,
    ) -> Vec<std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>> {
        self.extension::<Arc<dyn crate::interceptor::DbConnectionInterceptor>>()
            .map_or_else(Vec::new, |registered| vec![(*registered).clone()])
    }

    /// Statement timeout from the `AutumnConfig` extension; `None` when the
    /// extension is missing or the setting is unset.
    fn statement_timeout(&self) -> Option<std::time::Duration> {
        let cfg = self.extension::<crate::config::AutumnConfig>()?;
        cfg.database.statement_timeout
    }

    /// Slow-query threshold from the `AutumnConfig` extension, defaulting to
    /// 500 ms when the extension is missing.
    fn slow_query_threshold(&self) -> std::time::Duration {
        const FALLBACK_THRESHOLD: std::time::Duration = std::time::Duration::from_millis(500);

        self.extension::<crate::config::AutumnConfig>()
            .map_or(FALLBACK_THRESHOLD, |cfg| cfg.database.slow_query_threshold)
    }
}
impl crate::probe::ProvideProbeState for AppState {
    /// Probe state held by this application state.
    fn probes(&self) -> &crate::probe::ProbeState {
        &self.probes
    }

    /// Value of the `health_detailed` flag.
    fn health_detailed(&self) -> bool {
        self.health_detailed
    }

    /// Delegates to the inherent `AppState::profile`.
    fn profile(&self) -> &str {
        Self::profile(self)
    }

    /// Delegates to the inherent `AppState::uptime_display`.
    fn uptime_display(&self) -> String {
        Self::uptime_display(self)
    }

    /// Primary pool, if attached.
    #[cfg(feature = "db")]
    fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.pool.as_ref()
    }

    /// Replica pool, if attached.
    #[cfg(feature = "db")]
    fn replica_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.replica_pool.as_ref()
    }

    /// Always provides the state's health-indicator registry.
    fn health_indicator_registry(&self) -> Option<&crate::actuator::HealthIndicatorRegistry> {
        Some(&self.health_indicator_registry)
    }
}
impl crate::actuator::ProvideActuatorState for AppState {
    /// Metrics collector held by this state.
    fn metrics(&self) -> &crate::middleware::MetricsCollector {
        &self.metrics
    }

    /// Log-level state held by this state.
    fn log_levels(&self) -> &crate::actuator::LogLevels {
        &self.log_levels
    }

    /// Task registry held by this state.
    fn task_registry(&self) -> &crate::actuator::TaskRegistry {
        &self.task_registry
    }

    /// Job registry held by this state.
    fn job_registry(&self) -> &crate::actuator::JobRegistry {
        &self.job_registry
    }

    /// Configuration properties held by this state.
    fn config_props(&self) -> &crate::actuator::ConfigProperties {
        &self.config_props
    }

    /// Delegates to the inherent `AppState::profile`.
    fn profile(&self) -> &str {
        Self::profile(self)
    }

    /// Delegates to the inherent `AppState::uptime_display`.
    fn uptime_display(&self) -> String {
        Self::uptime_display(self)
    }

    /// Always provides the state's metrics-source registry.
    fn metrics_source_registry(&self) -> Option<&crate::actuator::MetricsSourceRegistry> {
        Some(&self.metrics_source_registry)
    }

    /// Always provides the state's health-indicator registry.
    fn health_indicator_registry(&self) -> Option<&crate::actuator::HealthIndicatorRegistry> {
        Some(&self.health_indicator_registry)
    }

    /// Value of the `health_detailed` flag.
    fn health_detailed(&self) -> bool {
        self.health_detailed
    }

    /// Version reported by the `CanaryState` extension, or
    /// `crate::canary::STABLE` when no such extension is registered.
    fn deploy_version(&self) -> String {
        self.extension::<crate::canary::CanaryState>().map_or_else(
            || crate::canary::STABLE.to_owned(),
            |canary| canary.version().to_owned(),
        )
    }

    /// Channel registry held by this state.
    #[cfg(feature = "ws")]
    fn channels(&self) -> &crate::channels::Channels {
        &self.channels
    }

    /// Delegates to the inherent `AppState::shutdown_token` (a child token).
    #[cfg(feature = "ws")]
    fn shutdown_token(&self) -> tokio_util::sync::CancellationToken {
        Self::shutdown_token(self)
    }

    /// Primary pool, if attached.
    #[cfg(feature = "db")]
    fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.pool.as_ref()
    }

    /// Clone of the `WebhookOutboundManager` extension, if registered.
    #[cfg(feature = "http-client")]
    fn webhook_outbound(&self) -> Option<crate::webhook_outbound::WebhookOutboundManager> {
        self.extension::<crate::webhook_outbound::WebhookOutboundManager>()
            .as_deref()
            .cloned()
    }

    /// Clone of the `LogBuffer` extension, if registered.
    fn log_buffer(&self) -> Option<crate::log::capture::LogBuffer> {
        self.extension::<crate::log::capture::LogBuffer>()
            .as_deref()
            .cloned()
    }
}
impl std::fmt::Debug for AppState {
    /// Writes a summary of the state: the pool size (with the `db` feature),
    /// the number of registered extensions, and a few scalar fields. The
    /// registries are shown only by name and the output is marked
    /// non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AppState");

        #[cfg(feature = "db")]
        {
            let pool_summary = self
                .pool
                .as_ref()
                .map(|pool| format!("Pool(max={})", pool.status().max_size));
            out.field("pool", &pool_summary);
        }

        // A poisoned lock is reported as zero extensions rather than panicking
        // inside a formatter.
        let extension_count = self
            .extensions
            .read()
            .map_or(0, |registry| registry.len());
        out.field("extensions", &extension_count);

        out.field("profile", &self.profile)
            .field("started_at", &self.started_at)
            .field("health_detailed", &self.health_detailed)
            .field("probes", &self.probes)
            .field("metrics", &"MetricsCollector")
            .field("log_levels", &"LogLevels")
            .field("task_registry", &"TaskRegistry")
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "db")]
    use crate::config;
    #[cfg(feature = "db")]
    use crate::db;

    /// Builds the fixture shared by the replica-routing tests: a test state
    /// with a primary pool of size 5 and a replica pool of size 2.
    #[cfg(feature = "db")]
    fn state_with_primary_and_replica() -> AppState {
        let primary_config = config::DatabaseConfig {
            url: Some("postgres://localhost/primary".into()),
            pool_size: 5,
            ..Default::default()
        };
        let replica_config = config::DatabaseConfig {
            url: Some("postgres://localhost/replica".into()),
            pool_size: 2,
            ..Default::default()
        };
        let primary = db::create_pool(&primary_config).unwrap().unwrap();
        let replica = db::create_pool(&replica_config).unwrap().unwrap();
        AppState::for_test()
            .with_pool(primary)
            .with_replica_pool(replica)
    }

    /// Compile-time check helper: only callable when `T: Clone`.
    fn require_clone<T: Clone>(t: &T) -> T {
        t.clone()
    }

    #[test]
    fn app_state_debug_without_pool() {
        let state = AppState::for_test().with_profile("dev");
        let rendered = format!("{state:?}");
        assert!(rendered.contains("AppState"));
        assert!(rendered.contains("dev"));
    }

    #[cfg(feature = "db")]
    #[test]
    fn app_state_debug_with_pool() {
        let pool_config = config::DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 5,
            ..Default::default()
        };
        let pool = db::create_pool(&pool_config).unwrap().unwrap();
        let state = AppState::for_test().with_pool(pool);
        let rendered = format!("{state:?}");
        assert!(rendered.contains("Pool(max=5)"));
    }

    #[cfg(feature = "db")]
    #[test]
    fn database_topology_state_exposes_replica_as_read_pool() {
        let state = state_with_primary_and_replica();

        assert_eq!(state.pool().expect("primary pool").status().max_size, 5);
        assert_eq!(
            state
                .replica_pool()
                .expect("replica pool")
                .status()
                .max_size,
            2
        );
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 2);
    }

    #[cfg(feature = "db")]
    #[test]
    fn read_pool_uses_primary_when_replica_is_unready_and_policy_allows_fallback() {
        let state = state_with_primary_and_replica();
        let probes = state.probes();
        probes.configure_replica_dependency(config::ReplicaFallback::Primary);
        probes.mark_replica_unready("replica migrations lag primary");

        // Both the inherent method and the trait method must pick the primary.
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 5);
        assert_eq!(
            db::DbState::read_pool(&state)
                .expect("trait read pool")
                .status()
                .max_size,
            5
        );
    }

    #[cfg(feature = "db")]
    #[test]
    fn read_pool_does_not_route_to_unready_replica_when_policy_fails_readiness() {
        let state = state_with_primary_and_replica();
        let probes = state.probes();
        probes.configure_replica_dependency(config::ReplicaFallback::FailReadiness);
        probes.mark_replica_unready("replica connection failed");

        assert!(state.read_pool().is_none());
    }

    #[cfg(feature = "db")]
    #[tokio::test]
    async fn readiness_fails_when_app_state_replica_is_unready_and_policy_is_fail_readiness() {
        let state = state_with_primary_and_replica();
        state
            .probes()
            .configure_replica_dependency(config::ReplicaFallback::FailReadiness);
        state
            .probes()
            .mark_replica_unready("replica migrations lag primary");

        let (status, _) = crate::probe::readiness_response(&state).await;
        assert_eq!(status, http::StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn detached_state_starts_without_profile() {
        assert_eq!(AppState::detached().profile(), "default");
    }

    #[test]
    fn app_state_is_clone() {
        let state = AppState::for_test();
        let _cloned = require_clone(&state);
    }

    #[test]
    fn app_state_profile_accessor() {
        let state = AppState::for_test().with_profile("staging");
        assert_eq!(state.profile(), "staging");
    }

    #[test]
    fn app_state_deploy_version_defaults_to_stable() {
        use crate::actuator::ProvideActuatorState;

        let state = AppState::for_test();
        assert_eq!(state.deploy_version(), crate::canary::STABLE);
    }

    #[test]
    fn app_state_deploy_version_reads_canary_extension() {
        use crate::actuator::ProvideActuatorState;

        let state = AppState::for_test();
        state.insert_extension(crate::canary::CanaryState::new(crate::canary::CANARY));
        assert_eq!(state.deploy_version(), crate::canary::CANARY);
    }

    #[test]
    fn app_state_profile_default() {
        let state = AppState::for_test();
        assert_eq!(state.profile(), "default");
    }

    #[test]
    fn app_state_uptime_display() {
        let display = AppState::for_test().uptime_display();
        assert!(
            display.contains('s'),
            "uptime should contain 's': {display}"
        );
    }

    #[test]
    fn app_state_accessors() {
        let state = AppState::for_test();
        let _metrics = state.metrics();
        let _log_levels = state.log_levels();
        let _task_registry = state.task_registry();
        let _config_props = state.config_props();
        #[cfg(feature = "db")]
        {
            let _pool = state.pool();
        }
        let _missing = state.extension::<String>();
    }

    #[test]
    fn app_state_runtime_extensions_round_trip() {
        let state = AppState::for_test();
        state.insert_extension(String::from("haunted"));

        let stored = state
            .extension::<String>()
            .expect("runtime extension should be installed");
        assert_eq!(stored.as_str(), "haunted");
    }
}