1use std::{
22 any::{Any, TypeId},
23 collections::HashMap,
24};
25
/// One stored value together with a human-readable name for its concrete type.
struct Slot {
    /// Name from [`std::any::type_name`], captured at insertion time.
    ///
    /// Only used for `Debug` output; `type_name` is not guaranteed to be unique,
    /// so lookups always go through the `TypeId` key instead.
    type_name: &'static str,
    /// The type-erased value. Its concrete type is the one whose `TypeId` keys this slot.
    value: Box<dyn Any + Send + Sync>,
}

/// A type-keyed container holding at most one value per concrete type.
///
/// Values are stored and retrieved by their type rather than by a name:
/// `state.insert(Config { .. })` followed by `state.get::<Config>()`.
/// Every stored value must be `Send + Sync`, so `AppState` itself is `Send + Sync`.
#[derive(Default)]
pub struct AppState {
    map: HashMap<TypeId, Slot>,
}

impl std::fmt::Debug for AppState {
    /// Lists the names of the stored types (not their values, which need not be `Debug`).
    ///
    /// Names are sorted so the output does not depend on `HashMap` iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut type_names: Vec<&str> = self.map.values().map(|slot| slot.type_name).collect();
        type_names.sort_unstable();
        f.debug_struct("AppState").field("types", &type_names).finish()
    }
}

impl AppState {
    /// Creates an empty `AppState`; equivalent to `AppState::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self { map: HashMap::new() }
    }

    /// Stores `val`, replacing (and dropping) any value of the same type already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let slot = Slot {
            type_name: std::any::type_name::<T>(),
            value: Box::new(val),
        };
        self.map.insert(TypeId::of::<T>(), slot);
    }

    /// Returns a shared reference to the stored `T`, or `None` if no `T` was inserted.
    #[must_use]
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|slot| slot.value.downcast_ref::<T>())
    }

    /// Returns a mutable reference to the stored `T`, or `None` if no `T` was inserted.
    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
        self.map
            .get_mut(&TypeId::of::<T>())
            .and_then(|slot| slot.value.downcast_mut::<T>())
    }

    /// Removes and returns the stored `T`, or `None` if no `T` was inserted.
    pub fn remove<T: 'static>(&mut self) -> Option<T> {
        self.map
            .remove(&TypeId::of::<T>())
            // Cannot fail: `insert` keys each slot by the `TypeId` of the value it boxes.
            .and_then(|slot| slot.value.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Returns `true` if a value of type `T` is currently stored.
    #[must_use]
    pub fn contains<T: 'static>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }

    /// Number of distinct types currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Removes (and drops) every stored value.
    pub fn clear(&mut self) {
        self.map.clear();
    }
}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct DbPool(String);

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Config {
        max_connections: u32,
    }

    #[test]
    fn insert_and_get() {
        let mut state = AppState::new();
        state.insert(DbPool("postgres://localhost".into()));

        let pool = state.get::<DbPool>().unwrap();
        assert_eq!(pool.0, "postgres://localhost");
    }

    #[test]
    fn different_types_coexist() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));
        state.insert(Config { max_connections: 10 });

        assert_eq!(state.get::<DbPool>().unwrap().0, "pg");
        assert_eq!(state.get::<Config>().unwrap().max_connections, 10);
    }

    #[test]
    fn insert_replaces_previous() {
        let mut state = AppState::new();
        state.insert(DbPool("old".into()));
        state.insert(DbPool("new".into()));

        assert_eq!(state.get::<DbPool>().unwrap().0, "new");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn get_missing_returns_none() {
        let state = AppState::new();
        assert!(state.get::<DbPool>().is_none());
    }

    #[test]
    fn get_mut_missing_returns_none() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));
        assert!(state.get_mut::<Config>().is_none());
    }

    #[test]
    fn remove_returns_value() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));

        let removed = state.remove::<DbPool>().unwrap();
        assert_eq!(removed.0, "pg");
        assert!(state.get::<DbPool>().is_none());
    }

    #[test]
    fn remove_missing_returns_none() {
        let mut state = AppState::new();
        assert!(state.remove::<DbPool>().is_none());
    }

    #[test]
    fn remove_leaves_other_types_intact() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));
        state.insert(Config { max_connections: 3 });

        assert!(state.remove::<DbPool>().is_some());
        assert_eq!(state.len(), 1);
        assert_eq!(state.get::<Config>(), Some(&Config { max_connections: 3 }));
    }

    #[test]
    fn contains_check() {
        let mut state = AppState::new();
        assert!(!state.contains::<DbPool>());

        state.insert(DbPool("pg".into()));
        assert!(state.contains::<DbPool>());
    }

    #[test]
    fn len_and_is_empty() {
        let mut state = AppState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);

        state.insert(DbPool("pg".into()));
        state.insert(Config { max_connections: 5 });
        assert_eq!(state.len(), 2);
        assert!(!state.is_empty());
    }

    #[test]
    fn default_is_empty_like_new() {
        let state = AppState::default();
        assert!(state.is_empty());
        assert_eq!(state.len(), AppState::new().len());
    }

    #[test]
    fn clear_removes_all() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));
        state.insert(Config { max_connections: 5 });

        state.clear();
        assert!(state.is_empty());
        assert!(!state.contains::<DbPool>());
    }

    #[test]
    fn get_mut_allows_mutation() {
        let mut state = AppState::new();
        state.insert(Config { max_connections: 5 });

        state.get_mut::<Config>().unwrap().max_connections = 20;
        assert_eq!(state.get::<Config>().unwrap().max_connections, 20);
    }

    #[test]
    fn debug_names_the_struct() {
        let mut state = AppState::new();
        state.insert(DbPool("pg".into()));
        assert!(format!("{state:?}").starts_with("AppState"));
    }

    // Compile-time check: `AppState` must remain shareable across threads.
    const _: () = {
        const fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<AppState>();
    };
}