use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use atrg_db::DbPool;
use crate::config::Config;
use atrg_identity::IdentityResolver;
/// A type-keyed map that holds at most one value per concrete type.
///
/// Values are stored type-erased and retrieved by naming their type again.
/// Every stored value must be `Send + Sync + 'static`.
#[derive(Default)]
pub struct Extensions {
    entries: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under its type `T`, replacing any previous `T`.
    ///
    /// Returns the previously stored `T`, or `None` if there was none.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let previous = self.entries.insert(key, Box::new(value))?;
        match previous.downcast::<T>() {
            Ok(old) => Some(*old),
            Err(_) => None,
        }
    }

    /// Returns a shared reference to the stored `T`, or `None` if absent.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.entries.get(&TypeId::of::<T>())?;
        erased.downcast_ref::<T>()
    }

    /// Returns `true` if a value of type `T` has been stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.entries.contains_key(&key)
    }

    /// Number of distinct types currently stored.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if nothing has been stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}

impl std::fmt::Debug for Extensions {
    // Stored values are only required to be `Send + Sync + 'static`, not
    // `Debug`, so only the entry count is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Extensions")
            .field("len", &self.len())
            .finish_non_exhaustive()
    }
}
/// Shared application state handed to request handlers.
///
/// `config`, `identity` and `extensions` are reference-counted, so cloning
/// the state shares them rather than copying them.
#[derive(Clone)]
pub struct AppState {
    /// Application configuration.
    pub config: Arc<Config>,
    /// Database pool.
    pub db: DbPool,
    /// Shared HTTP client.
    pub http: reqwest::Client,
    /// Identity resolver from `atrg_identity`.
    pub identity: Arc<IdentityResolver>,
    /// User-registered values, looked up by type via [`AppState::extension`].
    pub extensions: Arc<Extensions>,
}

impl AppState {
    /// Returns the registered extension of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if no value of type `T` was registered. Because of
    /// `#[track_caller]`, the reported panic location is the call site, not
    /// this method. Use [`AppState::try_extension`] for a non-panicking lookup.
    #[track_caller]
    pub fn extension<T: Send + Sync + 'static>(&self) -> &T {
        // Panic in this body rather than inside an `unwrap_or_else` closure:
        // closures do not inherit `#[track_caller]`, so a panic raised there
        // would point at this file instead of the caller.
        match self.extensions.get::<T>() {
            Some(value) => value,
            None => panic!(
                "AppState::extension::<{}>() called but no value of that type was registered. \
                 Did you forget to call `AtrgApp::with_extension(value)` during app setup?",
                std::any::type_name::<T>()
            ),
        }
    }

    /// Returns the registered extension of type `T`, or `None` if absent.
    pub fn try_extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.extensions.get::<T>()
    }

    /// Returns `true` if an extension of type `T` was registered.
    pub fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
        self.extensions.contains::<T>()
    }
}
// Sub-state extraction: each impl lets axum hand a single component of
// `AppState` to an extractor, by cloning that field out of the state.

/// Extracts the database pool from [`AppState`].
impl axum::extract::FromRef<AppState> for DbPool {
    fn from_ref(app: &AppState) -> Self {
        Clone::clone(&app.db)
    }
}

/// Extracts the shared configuration from [`AppState`].
impl axum::extract::FromRef<AppState> for Arc<Config> {
    fn from_ref(app: &AppState) -> Self {
        Arc::clone(&app.config)
    }
}

/// Extracts the shared identity resolver from [`AppState`].
impl axum::extract::FromRef<AppState> for Arc<IdentityResolver> {
    fn from_ref(app: &AppState) -> Self {
        Arc::clone(&app.identity)
    }
}

/// Extracts the shared extension map from [`AppState`].
impl axum::extract::FromRef<AppState> for Arc<Extensions> {
    fn from_ref(app: &AppState) -> Self {
        Arc::clone(&app.extensions)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is `Send + Sync + Clone`.
    fn _assert_send_sync_clone<T: Send + Sync + Clone>() {}

    /// Minimal development-mode configuration shared by the `AppState` tests.
    fn test_config() -> crate::config::Config {
        crate::config::Config {
            app: crate::config::AppConfig {
                name: "test".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "secret".into(),
                cors_origins: vec![],
                environment: "development".into(),
                admin_dids: vec![],
            },
            auth: crate::config::AuthConfig {
                client_id: "http://localhost/client-metadata.json".into(),
                redirect_uri: "http://localhost/auth/callback".into(),
                scope: "atproto transition:generic".into(),
                post_login_redirect: "/".into(),
            },
            database: crate::config::DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        }
    }

    /// Builds an `AppState` over a `sqlite::memory:` database with the given
    /// extensions, so each test only states what it actually varies.
    async fn test_state(extensions: Extensions) -> AppState {
        let db = atrg_db::connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite database should connect");
        AppState {
            config: Arc::new(test_config()),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
            extensions: Arc::new(extensions),
        }
    }

    #[test]
    fn app_state_is_send_sync_clone() {
        _assert_send_sync_clone::<AppState>();
    }

    #[test]
    fn extensions_insert_and_get() {
        struct Foo(u32);
        struct Bar(String);
        let mut ext = Extensions::new();
        ext.insert(Foo(42));
        ext.insert(Bar("hello".into()));
        assert_eq!(ext.get::<Foo>().unwrap().0, 42);
        assert_eq!(ext.get::<Bar>().unwrap().0, "hello");
    }

    #[test]
    fn extensions_get_missing_returns_none() {
        let ext = Extensions::new();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn extensions_insert_replaces_and_returns_old() {
        // Named `Setting` so it does not shadow `super::Config` from the glob import.
        struct Setting(String);
        let mut ext = Extensions::new();
        let old = ext.insert(Setting("v1".into()));
        assert!(old.is_none());
        let old = ext.insert(Setting("v2".into()));
        assert_eq!(old.unwrap().0, "v1");
        assert_eq!(ext.get::<Setting>().unwrap().0, "v2");
    }

    #[test]
    fn extensions_contains() {
        struct Present;
        let mut ext = Extensions::new();
        assert!(!ext.contains::<Present>());
        ext.insert(Present);
        assert!(ext.contains::<Present>());
    }

    #[test]
    fn extensions_len_and_is_empty() {
        struct A;
        struct B;
        let mut ext = Extensions::new();
        assert!(ext.is_empty());
        assert_eq!(ext.len(), 0);
        ext.insert(A);
        assert!(!ext.is_empty());
        assert_eq!(ext.len(), 1);
        ext.insert(B);
        assert_eq!(ext.len(), 2);
    }

    #[test]
    fn extensions_debug_shows_len() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        let dbg = format!("{:?}", ext);
        assert!(dbg.contains("Extensions"));
        assert!(dbg.contains("len"));
    }

    #[tokio::test]
    async fn app_state_extension_returns_value() {
        struct MyService {
            name: String,
        }
        let mut ext = Extensions::new();
        ext.insert(MyService {
            name: "test".into(),
        });
        let state = test_state(ext).await;
        assert!(state.has_extension::<MyService>());
        assert_eq!(state.extension::<MyService>().name, "test");
    }

    #[tokio::test]
    async fn app_state_try_extension_returns_none_when_missing() {
        struct NotRegistered;
        let state = test_state(Extensions::new()).await;
        assert!(state.try_extension::<NotRegistered>().is_none());
        assert!(!state.has_extension::<NotRegistered>());
    }

    #[tokio::test]
    #[should_panic(expected = "no value of that type was registered")]
    async fn app_state_extension_panics_when_missing() {
        struct NotRegistered;
        let state = test_state(Extensions::new()).await;
        let _ = state.extension::<NotRegistered>();
    }
}