use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use crate::actuator;
#[cfg(feature = "ws")]
use crate::channels::Channels;
#[cfg(feature = "db")]
use crate::db::DbState;
use crate::middleware;
use crate::probe;
#[cfg(feature = "ws")]
use tokio_util::sync::CancellationToken;
/// Application-wide state exposed to the probe, actuator and (with the `db`
/// feature) database layers through the trait impls on this type.
///
/// `Clone` clones each field individually; the runtime extension map is held
/// behind an `Arc`, so every clone observes the same set of extensions.
#[derive(Clone)]
#[non_exhaustive]
pub struct AppState {
    /// Active profile name; `profile()` reports `"default"` when this is `None`.
    pub(crate) profile: Option<String>,
    /// Reference instant from which `uptime()` is measured.
    pub(crate) started_at: std::time::Instant,
    /// Value surfaced through `ProvideProbeState::health_detailed`.
    pub(crate) health_detailed: bool,
    /// Type-keyed runtime extensions managed by `insert_extension` / `extension`.
    pub(crate) extensions: Arc<std::sync::RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
    /// Probe state; startup-complete and draining flags are driven through it.
    pub(crate) probes: probe::ProbeState,
    /// Metrics collector exposed via `metrics()` and the actuator trait.
    pub(crate) metrics: middleware::MetricsCollector,
    /// Log-level registry exposed via `log_levels()`.
    pub(crate) log_levels: actuator::LogLevels,
    /// Task registry exposed via `task_registry()`.
    pub(crate) task_registry: actuator::TaskRegistry,
    /// Configuration properties exposed via `config_props()`.
    pub(crate) config_props: actuator::ConfigProperties,
    /// Optional Postgres connection pool, attached with `with_pool`.
    #[cfg(feature = "db")]
    pub(crate) pool:
        Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
    /// Channel registry exposed via `channels()`.
    #[cfg(feature = "ws")]
    pub(crate) channels: Channels,
    /// Root cancellation token; `shutdown_token()` hands out child tokens.
    #[cfg(feature = "ws")]
    pub(crate) shutdown: CancellationToken,
}
impl AppState {
    /// Registers `value` as the runtime extension for type `T`.
    ///
    /// Any value previously stored for the same `T` is replaced. The extension
    /// map sits behind an `Arc`, so the new value is visible to every clone of
    /// this `AppState`.
    pub fn insert_extension<T>(&self, value: T)
    where
        T: Any + Send + Sync + 'static,
    {
        let replaced = {
            let mut extensions = self
                .extensions
                .write()
                // Writes here are a single `HashMap::insert`, so a poisoned lock
                // still guards a usable map; recover instead of panicking.
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            extensions.insert(TypeId::of::<T>(), Arc::new(value))
        };
        // Drop the displaced value only after the write guard is released, so
        // its destructor can neither poison the lock nor deadlock (or panic)
        // by reading this state's extensions.
        drop(replaced);
    }

    /// Returns the runtime extension registered for type `T`, if any.
    #[must_use]
    pub fn extension<T>(&self) -> Option<Arc<T>>
    where
        T: Any + Send + Sync + 'static,
    {
        let erased = self
            .extensions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(&TypeId::of::<T>())
            .cloned()?;
        // `insert_extension` keys entries by `TypeId::of::<T>()`, so this
        // downcast is expected to succeed; `ok()` keeps the lookup panic-free.
        Arc::downcast::<T>(erased).ok()
    }

    /// Returns the Postgres connection pool, if one has been attached.
    #[cfg(feature = "db")]
    #[must_use]
    pub const fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        self.pool.as_ref()
    }

    /// Returns the metrics collector.
    #[must_use]
    pub const fn metrics(&self) -> &middleware::MetricsCollector {
        &self.metrics
    }

    /// Returns the log-level registry.
    #[must_use]
    pub const fn log_levels(&self) -> &actuator::LogLevels {
        &self.log_levels
    }

    /// Returns the task registry.
    #[must_use]
    pub const fn task_registry(&self) -> &actuator::TaskRegistry {
        &self.task_registry
    }

    /// Returns the configuration properties.
    #[must_use]
    pub const fn config_props(&self) -> &actuator::ConfigProperties {
        &self.config_props
    }

    /// Returns the probe state.
    #[must_use]
    pub const fn probes(&self) -> &probe::ProbeState {
        &self.probes
    }

    /// Delegates to `ProbeState::mark_startup_complete`.
    pub fn mark_startup_complete(&self) {
        self.probes.mark_startup_complete();
    }

    /// Delegates to `ProbeState::begin_shutdown`.
    pub fn begin_shutdown(&self) {
        self.probes.begin_shutdown();
    }

    /// Builder: attaches a Postgres connection pool, replacing any previous one.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn with_pool(
        mut self,
        pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
    ) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Builder form of [`Self::insert_extension`].
    ///
    /// The extension map is shared between clones, so this also installs the
    /// value for any clone taken before the call.
    #[must_use]
    pub fn with_extension<T>(self, value: T) -> Self
    where
        T: Any + Send + Sync + 'static,
    {
        self.insert_extension(value);
        self
    }

    /// Builder: sets the active profile name.
    #[must_use]
    pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
        self.profile = Some(profile.into());
        self
    }

    /// Builder: sets the probe's startup-complete flag.
    #[doc(hidden)]
    #[must_use]
    pub fn with_startup_complete(self, startup_complete: bool) -> Self {
        self.probes.set_startup_complete(startup_complete);
        self
    }

    /// Builder: sets the probe's draining flag.
    #[doc(hidden)]
    #[must_use]
    pub fn with_draining(self, draining: bool) -> Self {
        self.probes.set_draining(draining);
        self
    }

    /// Returns the active profile name, or `"default"` when none was set.
    #[must_use]
    pub fn profile(&self) -> &str {
        self.profile.as_deref().unwrap_or("default")
    }

    /// Time elapsed since `started_at`.
    #[must_use]
    pub fn uptime(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Formats [`Self::uptime`] compactly: `"42s"` under a minute, `"3m 5s"`
    /// under an hour, and `"2h 7m"` beyond that (seconds dropped at hour scale).
    #[must_use]
    pub fn uptime_display(&self) -> String {
        let secs = self.uptime().as_secs();
        if secs < 60 {
            format!("{secs}s")
        } else if secs < 3600 {
            format!("{}m {}s", secs / 60, secs % 60)
        } else {
            let hours = secs / 3600;
            let mins = (secs % 3600) / 60;
            format!("{hours}h {mins}m")
        }
    }

    /// Returns the channel registry.
    #[cfg(feature = "ws")]
    #[must_use]
    pub const fn channels(&self) -> &Channels {
        &self.channels
    }

    /// Returns a child of the root shutdown token.
    ///
    /// The child is cancelled when the root is cancelled; cancelling the child
    /// does not affect the root.
    #[cfg(feature = "ws")]
    #[must_use]
    pub fn shutdown_token(&self) -> CancellationToken {
        self.shutdown.child_token()
    }

    /// Test hook: calls [`Self::begin_shutdown`], then cancels the root
    /// shutdown token.
    #[cfg(feature = "ws")]
    #[doc(hidden)]
    pub fn trigger_shutdown_for_test(&self) {
        self.begin_shutdown();
        self.shutdown.cancel();
    }

    /// Test hook: sets the probe's startup-complete flag.
    #[doc(hidden)]
    pub fn set_startup_complete_for_test(&self, startup_complete: bool) {
        self.probes.set_startup_complete(startup_complete);
    }

    /// Test hook: sets the probe's draining flag.
    #[doc(hidden)]
    pub fn set_draining_for_test(&self, draining: bool) {
        self.probes.set_draining(draining);
    }

    /// Test hook: only sets the draining flag. Unlike [`Self::begin_shutdown`],
    /// it does not call `ProbeState::begin_shutdown`.
    #[doc(hidden)]
    pub fn begin_shutdown_for_test(&self) {
        self.set_draining_for_test(true);
    }

    /// Builds a standalone state that is not wired to any server.
    ///
    /// It has no pool and no profile, detailed health enabled, probes from
    /// `ProbeState::ready_for_test()`, log level `"info"`, `Channels::new(32)`
    /// and a fresh root shutdown token.
    #[must_use]
    pub fn detached() -> Self {
        Self {
            extensions: Arc::new(std::sync::RwLock::new(HashMap::new())),
            #[cfg(feature = "db")]
            pool: None,
            profile: None,
            started_at: std::time::Instant::now(),
            health_detailed: true,
            probes: probe::ProbeState::ready_for_test(),
            metrics: middleware::MetricsCollector::new(),
            log_levels: actuator::LogLevels::new("info"),
            task_registry: actuator::TaskRegistry::new(),
            config_props: actuator::ConfigProperties::default(),
            #[cfg(feature = "ws")]
            channels: Channels::new(32),
            #[cfg(feature = "ws")]
            shutdown: CancellationToken::new(),
        }
    }

    /// Alias of [`Self::detached`] for test code.
    // NOTE(review): `allow(dead_code)` presumably covers builds where no caller
    // uses this alias — confirm whether it is still needed.
    #[allow(dead_code)]
    #[must_use]
    pub fn for_test() -> Self {
        Self::detached()
    }
}
#[cfg(feature = "db")]
impl DbState for AppState {
    /// Exposes the optional connection pool to database helpers.
    fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        // Forward to the inherent accessor (inherent items win path
        // resolution), so the trait view and the direct API stay in sync.
        AppState::pool(self)
    }
}
impl crate::probe::ProvideProbeState for AppState {
fn probes(&self) -> &crate::probe::ProbeState {
&self.probes
}
fn health_detailed(&self) -> bool {
self.health_detailed
}
fn profile(&self) -> &str {
self.profile()
}
fn uptime_display(&self) -> String {
self.uptime_display()
}
#[cfg(feature = "db")]
fn pool(
&self,
) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
{
self.pool.as_ref()
}
}
impl crate::actuator::ProvideActuatorState for AppState {
fn metrics(&self) -> &crate::middleware::MetricsCollector {
&self.metrics
}
fn log_levels(&self) -> &crate::actuator::LogLevels {
&self.log_levels
}
fn task_registry(&self) -> &crate::actuator::TaskRegistry {
&self.task_registry
}
fn config_props(&self) -> &crate::actuator::ConfigProperties {
&self.config_props
}
fn profile(&self) -> &str {
self.profile()
}
fn uptime_display(&self) -> String {
self.uptime_display()
}
#[cfg(feature = "ws")]
fn channels(&self) -> &crate::channels::Channels {
&self.channels
}
#[cfg(feature = "ws")]
fn shutdown_token(&self) -> tokio_util::sync::CancellationToken {
self.shutdown_token()
}
#[cfg(feature = "db")]
fn pool(
&self,
) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
{
self.pool.as_ref()
}
}
impl std::fmt::Debug for AppState {
    /// Formats a summary of the state.
    ///
    /// Opaque collaborators are printed as type-name placeholders, the pool
    /// (with `db`) as its max size, and the extension map as an entry count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("AppState");
        #[cfg(feature = "db")]
        s.field(
            "pool",
            &self
                .pool
                .as_ref()
                .map(|p| format!("Pool(max={})", p.status().max_size)),
        );
        // A poisoned lock still yields a readable map. Report its real size
        // instead of a misleading `0`.
        let extension_count = self
            .extensions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len();
        s.field("extensions", &extension_count);
        s.field("profile", &self.profile)
            .field("started_at", &self.started_at)
            .field("health_detailed", &self.health_detailed)
            .field("probes", &self.probes)
            .field("metrics", &"MetricsCollector")
            .field("log_levels", &"LogLevels")
            .field("task_registry", &"TaskRegistry")
            .field("config_props", &"ConfigProperties")
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "db")]
    use crate::config;
    #[cfg(feature = "db")]
    use crate::db;
    #[test]
    fn app_state_debug_without_pool() {
        let state = AppState::for_test().with_profile("dev");
        let debug = format!("{state:?}");
        assert!(debug.contains("AppState"));
        assert!(debug.contains("dev"));
    }
    #[cfg(feature = "db")]
    #[test]
    fn app_state_debug_with_pool() {
        let config = config::DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 5,
            ..Default::default()
        };
        let pool = db::create_pool(&config).unwrap().unwrap();
        let state = AppState::for_test().with_pool(pool);
        let debug = format!("{state:?}");
        assert!(debug.contains("Pool(max=5)"));
    }
    #[test]
    fn detached_state_starts_without_profile() {
        let state = AppState::detached();
        assert_eq!(state.profile(), "default");
    }
    fn require_clone<T: Clone>(t: &T) -> T {
        t.clone()
    }
    #[test]
    fn app_state_is_clone() {
        let state = AppState::for_test();
        let _cloned = require_clone(&state);
    }
    #[test]
    fn app_state_profile_accessor() {
        let state = AppState::for_test().with_profile("staging");
        assert_eq!(state.profile(), "staging");
    }
    #[test]
    fn app_state_profile_default() {
        let state = AppState::for_test();
        assert_eq!(state.profile(), "default");
    }
    #[test]
    fn app_state_uptime_display() {
        let state = AppState::for_test();
        let display = state.uptime_display();
        assert!(
            display.contains('s'),
            "uptime should contain 's': {display}"
        );
    }
    #[test]
    fn app_state_accessors() {
        let state = AppState::for_test();
        let _metrics = state.metrics();
        let _log_levels = state.log_levels();
        let _task_registry = state.task_registry();
        let _config_props = state.config_props();
        #[cfg(feature = "db")]
        {
            let _pool = state.pool();
        }
        let _missing = state.extension::<String>();
    }
    #[test]
    fn app_state_runtime_extensions_round_trip() {
        let state = AppState::for_test();
        state.insert_extension(String::from("haunted"));
        let stored = state
            .extension::<String>()
            .expect("runtime extension should be installed");
        assert_eq!(stored.as_str(), "haunted");
    }
    // The builder form must install the value exactly like `insert_extension`.
    #[test]
    fn app_state_with_extension_builder_installs_value() {
        let state = AppState::for_test().with_extension(42_u32);
        assert_eq!(state.extension::<u32>().as_deref(), Some(&42));
    }
    // A second insert for the same type replaces the first.
    #[test]
    fn app_state_extension_reinsert_replaces_previous_value() {
        let state = AppState::for_test();
        state.insert_extension(String::from("first"));
        state.insert_extension(String::from("second"));
        let stored = state
            .extension::<String>()
            .expect("runtime extension should be installed");
        assert_eq!(stored.as_str(), "second");
    }
    // Lookups are keyed by exact type; a different type never matches.
    #[test]
    fn app_state_extensions_are_keyed_by_type() {
        let state = AppState::for_test();
        state.insert_extension(7_u64);
        assert!(state.extension::<u32>().is_none());
        assert_eq!(state.extension::<u64>().as_deref(), Some(&7));
    }
    // The extension map is behind an `Arc`, so clones observe later inserts.
    #[test]
    fn app_state_clones_share_extensions() {
        let state = AppState::for_test();
        let clone = state.clone();
        state.insert_extension(String::from("shared"));
        assert_eq!(
            clone.extension::<String>().as_deref().map(String::as_str),
            Some("shared")
        );
    }
}