use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::warn;
use crate::error::{Error, Result};
use crate::extract::{FromRequest, RequestContext};
/// Type-indexed container for application state values.
///
/// Every value is stored behind an `Arc` and keyed by the `TypeId` of its
/// concrete type, so the map holds at most one value per type.
#[derive(Default)]
pub struct StateMap {
/// Type-erased values, keyed by the `TypeId` of the type they were inserted as.
entries: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl StateMap {
pub fn new() -> Self {
Self::default()
}
pub fn insert<S: Send + Sync + 'static>(&mut self, value: S) {
if self.entries.contains_key(&TypeId::of::<S>()) {
warn!(
target: "tork",
"state value of type `{}` is being silently replaced",
std::any::type_name::<S>(),
);
}
self.entries.insert(TypeId::of::<S>(), Arc::new(value));
}
pub fn get<S: Send + Sync + 'static>(&self) -> Option<Arc<S>> {
self.entries
.get(&TypeId::of::<S>())
.and_then(|entry| entry.clone().downcast::<S>().ok())
}
pub fn contains<S: Send + Sync + 'static>(&self) -> bool {
self.entries.contains_key(&TypeId::of::<S>())
}
pub fn remove<S: Send + Sync + 'static>(&mut self) {
self.entries.remove(&TypeId::of::<S>());
}
}
/// Reference-counted handle to a [`StateMap`].
pub type AppStateRef = Arc<StateMap>;
/// Newtype wrapper carrying a value of type `S`.
///
/// Its `FromRequest` implementation builds it by cloning the `S` found through
/// `RequestContext::state`.
pub struct State<S>(pub S);
impl<S> FromRequest for State<S>
where
S: Clone + Send + Sync + 'static,
{
fn from_request(
ctx: &RequestContext,
) -> impl std::future::Future<Output = Result<Self>> + Send {
let resolved = match ctx.state().get::<S>() {
Some(value) => Ok(State((*value).clone())),
None => Err(Error::internal(format!(
"application state `{}` was not configured",
std::any::type_name::<S>()
))),
};
async move { resolved }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state type used as a fixture by the tests below.
    #[derive(Clone)]
    struct Config {
        name: String,
    }

    /// A stored value is found by its type; other types stay absent.
    #[test]
    fn insert_and_get_by_type() {
        let mut map = StateMap::new();
        map.insert(Config {
            name: "tork".to_owned(),
        });
        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "tork");
        assert!(map.get::<u32>().is_none());
        assert!(map.contains::<Config>());
    }

    /// Inserting a second value of the same type overwrites the first.
    #[test]
    fn insert_replaces_existing_value_of_same_type() {
        let mut map = StateMap::new();
        map.insert(Config {
            name: "first".to_owned(),
        });
        map.insert(Config {
            name: "second".to_owned(),
        });
        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "second");
    }

    /// `remove` drops only the requested type and tolerates a missing entry.
    #[test]
    fn remove_drops_only_the_requested_type() {
        let mut map = StateMap::new();
        map.insert(Config {
            name: "tork".to_owned(),
        });
        map.insert(7_u32);

        map.remove::<Config>();
        assert!(!map.contains::<Config>());
        assert!(map.get::<Config>().is_none());
        assert_eq!(map.get::<u32>().as_deref(), Some(&7));

        // Removing a type that is no longer present must be a no-op.
        map.remove::<Config>();
        assert!(map.contains::<u32>());
    }
}