1use std::any::{Any, TypeId};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use crate::cache::Cache;
17
/// Extension-map entry wrapping a cache installed at runtime via [`AppState::set_cache`].
///
/// [`AppState::cache`] looks this entry up first and only falls back to the cache supplied
/// through [`AppState::with_cache`] when no such entry is present.
pub struct GlobalCacheEntry(pub Arc<dyn Cache>);
21
22use crate::actuator;
23use crate::authorization::{ForbiddenResponse, Policy, PolicyRegistry, Scope};
24#[cfg(feature = "ws")]
25use crate::channels::Channels;
26#[cfg(feature = "db")]
27use crate::db::DbState;
28use crate::middleware;
29use crate::probe;
30#[cfg(feature = "ws")]
31use tokio_util::sync::CancellationToken;
32
/// Application-wide state handed to handlers, probes and actuator endpoints.
///
/// `Clone` is derived. The `extensions` map sits behind an `Arc`, so every clone shares it:
/// an extension inserted through one clone is visible through all of them.
#[derive(Clone)]
#[non_exhaustive]
pub struct AppState {
    /// Runtime extensions keyed by the concrete type's `TypeId` (one value per type); see
    /// `insert_extension` / `extension`.
    pub(crate) extensions: Arc<std::sync::RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,

    /// Primary database pool, if one has been attached with `with_pool`.
    #[cfg(feature = "db")]
    pub(crate) pool:
        Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,

    /// Optional read-replica pool; `read_pool` decides (via the probes) whether reads use it.
    #[cfg(feature = "db")]
    pub(crate) replica_pool:
        Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,

    /// Active configuration profile; `None` is reported as `"default"` by `profile()`.
    pub(crate) profile: Option<String>,

    /// Reference instant for `uptime` / `uptime_display`.
    pub(crate) started_at: std::time::Instant,

    /// Exposed through `ProvideProbeState::health_detailed`; presumably toggles detailed
    /// health output — confirm in the probe module.
    pub(crate) health_detailed: bool,

    /// Startup / readiness / draining / replica probe state.
    pub(crate) probes: probe::ProbeState,

    /// Metrics collector exposed to the actuator.
    pub(crate) metrics: middleware::MetricsCollector,

    /// Runtime log-level state exposed to the actuator.
    pub(crate) log_levels: actuator::LogLevels,

    /// Background task and job registries exposed to the actuator.
    pub(crate) task_registry: actuator::TaskRegistry,
    pub(crate) job_registry: actuator::JobRegistry,

    /// Configuration properties exposed to the actuator.
    pub(crate) config_props: actuator::ConfigProperties,

    /// Channel hub (only with the `ws` feature); `broadcast()` is derived from it.
    #[cfg(feature = "ws")]
    pub(crate) channels: Channels,

    /// Root cancellation token; `shutdown_token()` hands out child tokens of it.
    #[cfg(feature = "ws")]
    pub(crate) shutdown: CancellationToken,

    /// Authorization policies and scopes, looked up per resource type.
    pub(crate) policy_registry: PolicyRegistry,

    /// Setting returned by `forbidden_response()`; presumably controls how authorization
    /// denials are rendered — confirm in the authorization module.
    pub(crate) forbidden_response: ForbiddenResponse,

    /// Session key name used by authentication (`"user_id"` in `detached()`).
    pub(crate) auth_session_key: String,

    /// Cache supplied via `with_cache`; shadowed by a `GlobalCacheEntry` extension when present.
    pub(crate) shared_cache: Option<Arc<dyn Cache>>,
}
130
131impl AppState {
132 pub fn insert_extension<T>(&self, value: T)
141 where
142 T: Any + Send + Sync + 'static,
143 {
144 self.extensions
145 .write()
146 .expect("app state extension lock poisoned")
147 .insert(TypeId::of::<T>(), Arc::new(value));
148 }
149
150 #[must_use]
159 pub fn extension<T>(&self) -> Option<Arc<T>>
160 where
161 T: Any + Send + Sync + 'static,
162 {
163 self.extensions
164 .read()
165 .expect("app state extension lock poisoned")
166 .get(&TypeId::of::<T>())
167 .cloned()
168 .and_then(|value| Arc::downcast::<T>(value).ok())
169 }
170
171 #[cfg(feature = "db")]
173 #[must_use]
174 pub const fn pool(
175 &self,
176 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
177 {
178 self.pool.as_ref()
179 }
180
181 #[cfg(feature = "db")]
183 #[must_use]
184 pub const fn replica_pool(
185 &self,
186 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
187 {
188 self.replica_pool.as_ref()
189 }
190
191 #[cfg(feature = "db")]
193 #[must_use]
194 pub fn read_pool(
195 &self,
196 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
197 {
198 if self.replica_pool.is_some() && self.probes.should_route_reads_to_replica() {
199 self.replica_pool.as_ref()
200 } else if self.replica_pool.is_some() && self.probes.should_fallback_reads_to_primary() {
201 self.pool.as_ref()
202 } else if self.replica_pool.is_some() {
203 None
204 } else {
205 self.pool.as_ref()
206 }
207 }
208
209 #[must_use]
211 pub const fn metrics(&self) -> &middleware::MetricsCollector {
212 &self.metrics
213 }
214
215 #[must_use]
217 pub const fn log_levels(&self) -> &actuator::LogLevels {
218 &self.log_levels
219 }
220
221 #[must_use]
223 pub const fn task_registry(&self) -> &actuator::TaskRegistry {
224 &self.task_registry
225 }
226
227 #[must_use]
229 pub const fn job_registry(&self) -> &actuator::JobRegistry {
230 &self.job_registry
231 }
232
233 #[must_use]
235 pub const fn config_props(&self) -> &actuator::ConfigProperties {
236 &self.config_props
237 }
238
239 #[must_use]
241 pub const fn probes(&self) -> &probe::ProbeState {
242 &self.probes
243 }
244
245 pub fn mark_startup_complete(&self) {
247 self.probes.mark_startup_complete();
248 }
249
250 pub fn begin_shutdown(&self) {
252 self.probes.begin_shutdown();
253 }
254
255 #[cfg(feature = "db")]
257 #[must_use]
258 pub fn with_pool(
259 mut self,
260 pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
261 ) -> Self {
262 self.pool = Some(pool);
263 self
264 }
265
266 #[cfg(feature = "db")]
268 #[must_use]
269 pub fn with_replica_pool(
270 mut self,
271 pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
272 ) -> Self {
273 self.replica_pool = Some(pool);
274 self
275 }
276
277 #[must_use]
279 pub fn with_extension<T>(self, value: T) -> Self
280 where
281 T: Any + Send + Sync + 'static,
282 {
283 self.insert_extension(value);
284 self
285 }
286
287 #[must_use]
294 pub fn cache(&self) -> Option<Arc<dyn Cache>> {
295 self.extension::<GlobalCacheEntry>()
296 .map(|e| e.0.clone())
297 .or_else(|| self.shared_cache.clone())
298 }
299
300 #[must_use]
302 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
303 self.shared_cache = Some(cache);
304 self
305 }
306
307 pub fn set_cache(&self, cache: Arc<dyn Cache>) {
312 crate::cache::set_global_cache(cache.clone());
313 self.insert_extension(GlobalCacheEntry(cache));
314 }
315
316 #[must_use]
318 pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
319 self.profile = Some(profile.into());
320 self
321 }
322
323 #[must_use]
325 pub const fn policy_registry(&self) -> &PolicyRegistry {
326 &self.policy_registry
327 }
328
329 #[must_use]
331 pub fn policy<R: Send + Sync + 'static>(&self) -> Option<std::sync::Arc<dyn Policy<R>>> {
332 self.policy_registry.policy::<R>()
333 }
334
335 #[must_use]
337 pub fn scope<R: Send + Sync + 'static>(&self) -> Option<std::sync::Arc<dyn Scope<R>>> {
338 self.policy_registry.scope::<R>()
339 }
340
341 #[must_use]
345 pub const fn forbidden_response(&self) -> ForbiddenResponse {
346 self.forbidden_response
347 }
348
349 #[must_use]
352 pub fn auth_session_key(&self) -> &str {
353 &self.auth_session_key
354 }
355
356 #[doc(hidden)]
358 #[must_use]
359 pub const fn with_forbidden_response(mut self, value: ForbiddenResponse) -> Self {
360 self.forbidden_response = value;
361 self
362 }
363
364 #[doc(hidden)]
366 #[must_use]
367 pub fn with_auth_session_key(mut self, value: impl Into<String>) -> Self {
368 self.auth_session_key = value.into();
369 self
370 }
371
372 #[doc(hidden)]
374 #[must_use]
375 pub fn with_startup_complete(self, startup_complete: bool) -> Self {
376 self.probes.set_startup_complete(startup_complete);
377 self
378 }
379
380 #[doc(hidden)]
382 #[must_use]
383 pub fn with_draining(self, draining: bool) -> Self {
384 self.probes.set_draining(draining);
385 self
386 }
387
388 #[must_use]
390 pub fn profile(&self) -> &str {
391 self.profile.as_deref().unwrap_or("default")
392 }
393
394 #[must_use]
396 pub fn uptime(&self) -> std::time::Duration {
397 self.started_at.elapsed()
398 }
399
400 #[must_use]
402 pub fn uptime_display(&self) -> String {
403 let secs = self.started_at.elapsed().as_secs();
404 if secs < 60 {
405 format!("{secs}s")
406 } else if secs < 3600 {
407 format!("{}m {}s", secs / 60, secs % 60)
408 } else {
409 let hours = secs / 3600;
410 let mins = (secs % 3600) / 60;
411 format!("{hours}h {mins}m")
412 }
413 }
414
415 #[cfg(feature = "ws")]
419 #[must_use]
420 pub const fn channels(&self) -> &Channels {
421 &self.channels
422 }
423
424 #[cfg(feature = "ws")]
426 #[must_use]
427 pub fn broadcast(&self) -> crate::channels::Broadcast {
428 self.channels.broadcast()
429 }
430
431 #[cfg(feature = "ws")]
436 #[must_use]
437 pub fn shutdown_token(&self) -> CancellationToken {
438 self.shutdown.child_token()
439 }
440
441 #[cfg(feature = "ws")]
443 #[doc(hidden)]
444 pub fn trigger_shutdown_for_test(&self) {
445 self.begin_shutdown();
446 self.shutdown.cancel();
447 }
448
449 #[doc(hidden)]
451 pub fn set_startup_complete_for_test(&self, startup_complete: bool) {
452 self.probes.set_startup_complete(startup_complete);
453 }
454
455 #[doc(hidden)]
457 pub fn set_draining_for_test(&self, draining: bool) {
458 self.probes.set_draining(draining);
459 }
460
461 #[doc(hidden)]
463 pub fn begin_shutdown_for_test(&self) {
464 self.set_draining_for_test(true);
465 }
466
467 #[must_use]
473 pub fn detached() -> Self {
474 Self {
475 extensions: Arc::new(std::sync::RwLock::new(HashMap::new())),
476 #[cfg(feature = "db")]
477 pool: None,
478 #[cfg(feature = "db")]
479 replica_pool: None,
480 profile: None,
481 started_at: std::time::Instant::now(),
482 health_detailed: true,
483 probes: probe::ProbeState::ready_for_test(),
484 metrics: middleware::MetricsCollector::new(),
485 log_levels: actuator::LogLevels::new("info"),
486 task_registry: actuator::TaskRegistry::new(),
487 job_registry: actuator::JobRegistry::new(),
488 config_props: actuator::ConfigProperties::default(),
489 #[cfg(feature = "ws")]
490 channels: Channels::new(32),
491 #[cfg(feature = "ws")]
492 shutdown: CancellationToken::new(),
493 policy_registry: PolicyRegistry::default(),
494 forbidden_response: ForbiddenResponse::default(),
495 auth_session_key: "user_id".to_owned(),
496 shared_cache: None,
497 }
498 }
499
500 #[allow(dead_code)]
503 #[must_use]
504 pub fn for_test() -> Self {
505 Self::detached()
506 }
507}
508
#[cfg(feature = "db")]
impl DbState for AppState {
    // Each method forwards to the inherent accessor of the same name. `AppState::name`
    // resolves to the inherent item first, so the trait view and the inherent API stay in
    // lock-step.

    /// Primary pool, as returned by the inherent `AppState::pool`.
    fn pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        AppState::pool(self)
    }

    /// Replica pool, as returned by the inherent `AppState::replica_pool`.
    fn replica_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        AppState::replica_pool(self)
    }

    /// Read pool chosen by the inherent `AppState::read_pool` (probe-aware routing).
    fn read_pool(
        &self,
    ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
    {
        AppState::read_pool(self)
    }
}
532
533impl crate::probe::ProvideProbeState for AppState {
534 fn probes(&self) -> &crate::probe::ProbeState {
535 &self.probes
536 }
537
538 fn health_detailed(&self) -> bool {
539 self.health_detailed
540 }
541
542 fn profile(&self) -> &str {
543 self.profile()
544 }
545
546 fn uptime_display(&self) -> String {
547 self.uptime_display()
548 }
549
550 #[cfg(feature = "db")]
551 fn pool(
552 &self,
553 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
554 {
555 self.pool.as_ref()
556 }
557
558 #[cfg(feature = "db")]
559 fn replica_pool(
560 &self,
561 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
562 {
563 self.replica_pool.as_ref()
564 }
565}
566
567impl crate::actuator::ProvideActuatorState for AppState {
568 fn metrics(&self) -> &crate::middleware::MetricsCollector {
569 &self.metrics
570 }
571
572 fn log_levels(&self) -> &crate::actuator::LogLevels {
573 &self.log_levels
574 }
575
576 fn task_registry(&self) -> &crate::actuator::TaskRegistry {
577 &self.task_registry
578 }
579
580 fn job_registry(&self) -> &crate::actuator::JobRegistry {
581 &self.job_registry
582 }
583
584 fn config_props(&self) -> &crate::actuator::ConfigProperties {
585 &self.config_props
586 }
587
588 fn profile(&self) -> &str {
589 self.profile()
590 }
591
592 fn uptime_display(&self) -> String {
593 self.uptime_display()
594 }
595
596 #[cfg(feature = "ws")]
597 fn channels(&self) -> &crate::channels::Channels {
598 &self.channels
599 }
600
601 #[cfg(feature = "ws")]
602 fn shutdown_token(&self) -> tokio_util::sync::CancellationToken {
603 self.shutdown_token()
604 }
605
606 #[cfg(feature = "db")]
607 fn pool(
608 &self,
609 ) -> Option<&diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>
610 {
611 self.pool.as_ref()
612 }
613 }
619
620impl std::fmt::Debug for AppState {
621 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
622 let mut s = f.debug_struct("AppState");
623 #[cfg(feature = "db")]
624 s.field(
625 "pool",
626 &self
627 .pool
628 .as_ref()
629 .map(|p| format!("Pool(max={})", p.status().max_size)),
630 );
631 s.field(
632 "extensions",
633 &self
634 .extensions
635 .read()
636 .map_or(0, |extensions| extensions.len()),
637 );
638 s.field("profile", &self.profile)
639 .field("started_at", &self.started_at)
640 .field("health_detailed", &self.health_detailed)
641 .field("probes", &self.probes)
642 .field("metrics", &"MetricsCollector")
643 .field("log_levels", &"LogLevels")
644 .field("task_registry", &"TaskRegistry")
645 .finish_non_exhaustive()
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "db")]
    use crate::config;
    #[cfg(feature = "db")]
    use crate::db;

    /// Pool type accepted by `AppState::with_pool` / `with_replica_pool`.
    #[cfg(feature = "db")]
    type PgPool =
        diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;

    /// Creates a pool through `db::create_pool`, failing the test if none is produced.
    #[cfg(feature = "db")]
    fn pool_from(config: &config::DatabaseConfig) -> PgPool {
        db::create_pool(config).unwrap().unwrap()
    }

    /// Primary pool shared by the topology tests (`max_size` 5).
    #[cfg(feature = "db")]
    fn make_primary_pool() -> PgPool {
        pool_from(&config::DatabaseConfig {
            url: Some("postgres://localhost/primary".into()),
            pool_size: 5,
            ..Default::default()
        })
    }

    /// Replica pool shared by the topology tests (`max_size` 2).
    #[cfg(feature = "db")]
    fn make_replica_pool() -> PgPool {
        pool_from(&config::DatabaseConfig {
            url: Some("postgres://localhost/replica".into()),
            pool_size: 2,
            ..Default::default()
        })
    }

    /// Test state with both the primary and the replica pool attached.
    #[cfg(feature = "db")]
    fn state_with_replica() -> AppState {
        AppState::for_test()
            .with_pool(make_primary_pool())
            .with_replica_pool(make_replica_pool())
    }

    #[test]
    fn app_state_debug_without_pool() {
        let state = AppState::for_test().with_profile("dev");
        let debug = format!("{state:?}");
        assert!(debug.contains("AppState"));
        assert!(debug.contains("dev"));
    }

    #[cfg(feature = "db")]
    #[test]
    fn app_state_debug_with_pool() {
        let pool = pool_from(&config::DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 5,
            ..Default::default()
        });
        let state = AppState::for_test().with_pool(pool);
        let debug = format!("{state:?}");
        assert!(debug.contains("Pool(max=5)"));
    }

    #[cfg(feature = "db")]
    #[test]
    fn database_topology_state_exposes_replica_as_read_pool() {
        let state = state_with_replica();

        assert_eq!(state.pool().expect("primary pool").status().max_size, 5);
        assert_eq!(
            state
                .replica_pool()
                .expect("replica pool")
                .status()
                .max_size,
            2
        );
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 2);
    }

    #[cfg(feature = "db")]
    #[test]
    fn read_pool_uses_primary_when_replica_is_unready_and_policy_allows_fallback() {
        let state = state_with_replica();
        state
            .probes()
            .configure_replica_dependency(config::ReplicaFallback::Primary);
        state
            .probes()
            .mark_replica_unready("replica migrations lag primary");

        assert_eq!(state.read_pool().expect("read pool").status().max_size, 5);
        assert_eq!(
            db::DbState::read_pool(&state)
                .expect("trait read pool")
                .status()
                .max_size,
            5
        );
    }

    #[cfg(feature = "db")]
    #[test]
    fn read_pool_does_not_route_to_unready_replica_when_policy_fails_readiness() {
        let state = state_with_replica();
        state
            .probes()
            .configure_replica_dependency(config::ReplicaFallback::FailReadiness);
        state
            .probes()
            .mark_replica_unready("replica connection failed");

        assert!(state.read_pool().is_none());
    }

    #[cfg(feature = "db")]
    #[tokio::test]
    async fn readiness_fails_when_app_state_replica_is_unready_and_policy_is_fail_readiness() {
        let state = state_with_replica();
        state
            .probes()
            .configure_replica_dependency(config::ReplicaFallback::FailReadiness);
        state
            .probes()
            .mark_replica_unready("replica migrations lag primary");

        let (status, _) = crate::probe::readiness_response(&state).await;

        assert_eq!(status, http::StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn detached_state_starts_without_profile() {
        let state = AppState::detached();

        assert_eq!(state.profile(), "default");
    }

    /// Compile-time check helper: only accepts `Clone` types.
    fn require_clone<T: Clone>(t: &T) -> T {
        t.clone()
    }

    #[test]
    fn app_state_is_clone() {
        let state = AppState::for_test();
        let _cloned = require_clone(&state);
    }

    #[test]
    fn app_state_profile_accessor() {
        let state = AppState::for_test().with_profile("staging");
        assert_eq!(state.profile(), "staging");
    }

    #[test]
    fn app_state_profile_default() {
        let state = AppState::for_test();
        assert_eq!(state.profile(), "default");
    }

    #[test]
    fn app_state_uptime_display() {
        let state = AppState::for_test();
        let display = state.uptime_display();
        assert!(
            display.contains('s'),
            "uptime should contain 's': {display}"
        );
    }

    #[test]
    fn app_state_accessors() {
        let state = AppState::for_test();

        let _metrics = state.metrics();
        let _log_levels = state.log_levels();
        let _task_registry = state.task_registry();
        let _job_registry = state.job_registry();
        let _config_props = state.config_props();

        #[cfg(feature = "db")]
        {
            let _pool = state.pool();
        }
        let _missing = state.extension::<String>();
    }

    #[test]
    fn app_state_runtime_extensions_round_trip() {
        let state = AppState::for_test();
        state.insert_extension(String::from("haunted"));

        let stored = state
            .extension::<String>()
            .expect("runtime extension should be installed");

        assert_eq!(stored.as_str(), "haunted");
    }

    #[test]
    fn insert_extension_replaces_value_of_same_type() {
        let state = AppState::for_test();
        state.insert_extension(String::from("first"));
        state.insert_extension(String::from("second"));

        let stored = state
            .extension::<String>()
            .expect("runtime extension should be installed");

        assert_eq!(stored.as_str(), "second");
    }

    #[test]
    fn with_extension_installs_only_the_given_type() {
        let state = AppState::for_test().with_extension(7_i64);

        assert_eq!(state.extension::<i64>().as_deref(), Some(&7));
        assert!(state.extension::<u8>().is_none());
    }

    #[test]
    fn extensions_are_shared_between_clones() {
        let state = AppState::for_test();
        let clone = state.clone();
        clone.insert_extension(42_u32);

        assert_eq!(state.extension::<u32>().as_deref(), Some(&42));
    }
}