1use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use atrg_db::DbPool;
8
9use crate::config::Config;
10use atrg_identity::IdentityResolver;
11
/// A type-keyed map that holds at most one value per concrete type.
///
/// Values are type-erased into `Box<dyn Any + Send + Sync>` and keyed by the
/// [`TypeId`] of the type they were inserted as, so lookups are by type
/// rather than by name.
#[derive(Default)]
pub struct Extensions {
    /// Boxed values, keyed by the `TypeId` of the type that was inserted.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty `Extensions` map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value`, replacing any value of the same type `T` already present.
    ///
    /// Returns the replaced value, or `None` if no `T` was stored before.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let previous = self.map.insert(key, Box::new(value))?;
        // The slot is keyed by `TypeId::of::<T>()`, so the previous entry is a `T`;
        // `.ok()` is simply the non-panicking way to unwrap the downcast.
        previous.downcast::<T>().ok().map(|old| *old)
    }

    /// Borrows the stored value of type `T`, or returns `None` if none was inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let boxed = self.map.get(&TypeId::of::<T>())?;
        // Deref through the `Box` explicitly so it is obvious the downcast is
        // applied to the stored `dyn Any` value.
        (**boxed).downcast_ref::<T>()
    }

    /// Returns `true` if a value of type `T` has been inserted.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of distinct types currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no values are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the entry count: stored values are type-erased and are not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        f.debug_struct("Extensions")
            .field("len", &count)
            .finish_non_exhaustive()
    }
}
91
92#[derive(Clone)]
105pub struct AppState {
106 pub config: Arc<Config>,
108 pub db: DbPool,
112 pub http: reqwest::Client,
114 pub identity: Arc<IdentityResolver>,
116 pub extensions: Arc<Extensions>,
120}
121
122impl AppState {
123 pub fn extension<T: Send + Sync + 'static>(&self) -> &T {
142 self.extensions.get::<T>().unwrap_or_else(|| {
143 panic!(
144 "AppState::extension::<{}>() called but no value of that type was registered. \
145 Did you forget to call `AtrgApp::with_extension(value)` during app setup?",
146 std::any::type_name::<T>()
147 )
148 })
149 }
150
151 pub fn try_extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
162 self.extensions.get::<T>()
163 }
164
165 pub fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
167 self.extensions.contains::<T>()
168 }
169}
170
171impl axum::extract::FromRef<AppState> for DbPool {
177 fn from_ref(state: &AppState) -> Self {
178 state.db.clone()
179 }
180}
181
182impl axum::extract::FromRef<AppState> for Arc<Config> {
183 fn from_ref(state: &AppState) -> Self {
184 state.config.clone()
185 }
186}
187
188impl axum::extract::FromRef<AppState> for Arc<IdentityResolver> {
189 fn from_ref(state: &AppState) -> Self {
190 state.identity.clone()
191 }
192}
193
194impl axum::extract::FromRef<AppState> for Arc<Extensions> {
195 fn from_ref(state: &AppState) -> Self {
196 state.extensions.clone()
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is `Send + Sync + Clone`; used as a static assertion.
    fn _assert_send_sync_clone<T: Send + Sync + Clone>() {}

    /// A minimal configuration for building an `AppState` in tests.
    fn test_config() -> crate::config::Config {
        crate::config::Config {
            app: crate::config::AppConfig {
                name: "test".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "secret".into(),
                cors_origins: vec![],
                environment: "development".into(),
                admin_dids: vec![],
            },
            auth: crate::config::AuthConfig {
                client_id: "http://localhost/client-metadata.json".into(),
                redirect_uri: "http://localhost/auth/callback".into(),
                scope: "atproto transition:generic".into(),
                post_login_redirect: "/".into(),
            },
            database: crate::config::DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        }
    }

    /// Builds an `AppState` over an in-memory SQLite database carrying the
    /// given extensions, so each `AppState` test only states what it varies.
    async fn test_state(extensions: Extensions) -> AppState {
        let db = atrg_db::connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite connection should succeed");
        AppState {
            config: Arc::new(test_config()),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
            extensions: Arc::new(extensions),
        }
    }

    #[test]
    fn app_state_is_send_sync_clone() {
        _assert_send_sync_clone::<AppState>();
    }

    #[test]
    fn extensions_insert_and_get() {
        struct Foo(u32);
        struct Bar(String);

        let mut ext = Extensions::new();
        ext.insert(Foo(42));
        ext.insert(Bar("hello".into()));

        assert_eq!(ext.get::<Foo>().unwrap().0, 42);
        assert_eq!(ext.get::<Bar>().unwrap().0, "hello");
    }

    #[test]
    fn extensions_get_missing_returns_none() {
        let ext = Extensions::new();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn extensions_insert_replaces_and_returns_old() {
        // Deliberately shadows the crate's `Config` inside this test only.
        struct Config(String);

        let mut ext = Extensions::new();
        let old = ext.insert(Config("v1".into()));
        assert!(old.is_none());

        let old = ext.insert(Config("v2".into()));
        assert_eq!(old.unwrap().0, "v1");
        assert_eq!(ext.get::<Config>().unwrap().0, "v2");
    }

    #[test]
    fn extensions_contains() {
        struct Present;

        let mut ext = Extensions::new();
        assert!(!ext.contains::<Present>());
        ext.insert(Present);
        assert!(ext.contains::<Present>());
    }

    #[test]
    fn extensions_len_and_is_empty() {
        struct A;
        struct B;

        let mut ext = Extensions::new();
        assert!(ext.is_empty());
        assert_eq!(ext.len(), 0);

        ext.insert(A);
        assert!(!ext.is_empty());
        assert_eq!(ext.len(), 1);

        ext.insert(B);
        assert_eq!(ext.len(), 2);
    }

    #[test]
    fn extensions_debug_shows_len() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        let dbg = format!("{:?}", ext);
        assert!(dbg.contains("Extensions"));
        assert!(dbg.contains("len"));
    }

    #[tokio::test]
    async fn app_state_extension_returns_value() {
        struct MyService {
            name: String,
        }

        let mut ext = Extensions::new();
        ext.insert(MyService {
            name: "test".into(),
        });

        let state = test_state(ext).await;
        assert_eq!(state.extension::<MyService>().name, "test");
    }

    #[tokio::test]
    async fn app_state_try_extension_returns_none_when_missing() {
        struct NotRegistered;

        let state = test_state(Extensions::new()).await;

        assert!(state.try_extension::<NotRegistered>().is_none());
        assert!(!state.has_extension::<NotRegistered>());
    }

    #[tokio::test]
    #[should_panic(expected = "no value of that type was registered")]
    async fn app_state_extension_panics_when_missing() {
        struct NotRegistered;

        let state = test_state(Extensions::new()).await;

        let _ = state.extension::<NotRegistered>();
    }
}