#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use crate::{
Error, HttpRequest,
from_request::{FromRequest, IntoHandlerError},
};
use std::{collections::BTreeMap, fmt, sync::Arc};
/// Errors that can occur while resolving application state.
///
/// The [`fmt::Display`] implementation renders one human-readable message per
/// variant; the type also implements [`std::error::Error`].
#[derive(Debug)]
pub enum StateError {
    /// No state was available for the requested type.
    NotFound {
        /// Name of the type that was requested.
        type_name: &'static str,
    },
    /// The state container has not been initialized for a backend.
    NotInitialized {
        /// Name of the backend, as shown in the rendered message.
        backend: String,
    },
    /// A state lookup did not yield the requested type.
    TypeMismatch {
        /// Name of the type that was requested.
        requested_type: &'static str,
        /// Name of the type that was found instead, when one is known.
        found_type: Option<String>,
    },
}

impl fmt::Display for StateError {
    /// Writes the message for this variant.
    ///
    /// `TypeMismatch` has two wordings, chosen by whether `found_type` is
    /// present; both are handled as separate match arms.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotFound { type_name } => {
                write!(f, "State of type '{type_name}' not found in application")
            }
            Self::NotInitialized { backend } => {
                write!(f, "State container not initialized for backend '{backend}'")
            }
            Self::TypeMismatch {
                requested_type,
                found_type: Some(found),
            } => write!(
                f,
                "State type mismatch: requested '{requested_type}', found '{found}'"
            ),
            Self::TypeMismatch {
                requested_type,
                found_type: None,
            } => write!(
                f,
                "State type mismatch: requested '{requested_type}', no matching type found"
            ),
        }
    }
}

impl std::error::Error for StateError {}
impl IntoHandlerError for StateError {
fn into_handler_error(self) -> Error {
Error::internal_server_error(self.to_string())
}
}
/// Handle to a shared value of type `T`, held behind an [`Arc`].
///
/// Cloning a `State` clones the `Arc` (a reference-count bump), never the
/// wrapped value, so `T` does not need to be `Clone`. The handle dereferences
/// to `T`.
#[derive(Debug)]
pub struct State<T>(pub Arc<T>);

impl<T> State<T> {
    /// Wraps an already shared value.
    #[must_use]
    pub const fn new(value: Arc<T>) -> Self {
        Self(value)
    }

    /// Consumes the handle and returns the underlying [`Arc`].
    #[must_use]
    pub fn into_inner(self) -> Arc<T> {
        let Self(shared) = self;
        shared
    }

    /// Borrows the wrapped value.
    #[must_use]
    pub fn get(&self) -> &T {
        &self.0
    }
}

impl<T> Clone for State<T> {
    /// Returns a second handle to the same allocation.
    fn clone(&self) -> Self {
        Self::new(Arc::clone(&self.0))
    }
}

impl<T> std::ops::Deref for State<T> {
    type Target = T;

    /// Same borrow as [`State::get`], enabling `state.field` access.
    fn deref(&self) -> &Self::Target {
        self.get()
    }
}
#[derive(Debug, Default)]
pub struct StateContainer {
states: BTreeMap<std::any::TypeId, crate::request::ErasedState>,
type_names: BTreeMap<std::any::TypeId, &'static str>,
}
impl StateContainer {
#[must_use]
pub fn new() -> Self {
Self {
states: BTreeMap::new(),
type_names: BTreeMap::new(),
}
}
pub fn insert<T: Send + Sync + 'static>(&mut self, state: T) {
let type_id = std::any::TypeId::of::<T>();
let type_name = std::any::type_name::<T>();
self.states.insert(type_id, Arc::new(state));
self.type_names.insert(type_id, type_name);
}
#[must_use]
pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
let type_id = std::any::TypeId::of::<T>();
self.states
.get(&type_id)
.and_then(|arc| Arc::clone(arc).downcast::<T>().ok())
}
#[must_use]
pub fn get_any(&self, type_id: std::any::TypeId) -> Option<crate::request::ErasedState> {
self.states.get(&type_id).cloned()
}
#[must_use]
pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
let type_id = std::any::TypeId::of::<T>();
self.states.contains_key(&type_id)
}
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<Arc<T>> {
let type_id = std::any::TypeId::of::<T>();
self.type_names.remove(&type_id);
self.states
.remove(&type_id)
.and_then(|arc| arc.downcast::<T>().ok())
}
#[must_use]
pub fn type_names(&self) -> Vec<&'static str> {
self.type_names.values().copied().collect()
}
pub fn clear(&mut self) {
self.states.clear();
self.type_names.clear();
}
}
/// Extracts [`State<T>`] from the application state carried by a request.
impl<T: Send + Sync + 'static> FromRequest for State<T> {
    type Error = StateError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Looks up `T` through `HttpRequest::app_state`.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::NotFound`], carrying `T`'s type name, when the
    /// request has no state of type `T`.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        let shared = req.app_state::<T>().ok_or_else(|| StateError::NotFound {
            type_name: std::any::type_name::<T>(),
        })?;
        Ok(Self::new(shared))
    }

    /// Runs the synchronous lookup and wraps its result in a ready future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        let outcome = Self::from_request_sync(&req);
        std::future::ready(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Primary fixture type stored in the container and wrapped by `State`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestConfig {
        name: String,
        value: u32,
    }

    /// Second fixture type, used to check that distinct types do not collide.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct DatabaseConfig {
        url: String,
        max_connections: usize,
    }

    /// Builds the `TestConfig` value shared by most tests.
    fn sample_config() -> TestConfig {
        TestConfig {
            name: "test".to_string(),
            value: 123,
        }
    }

    /// Builds the `DatabaseConfig` value shared by the multi-type tests.
    fn sample_db_config() -> DatabaseConfig {
        DatabaseConfig {
            url: "postgresql://localhost/test".to_string(),
            max_connections: 5,
        }
    }

    #[test]
    fn test_state_container_insert_and_get() {
        let mut container = StateContainer::new();
        container.insert(sample_config());

        let retrieved = container.get::<TestConfig>();
        assert_eq!(retrieved.as_deref(), Some(&sample_config()));
    }

    #[test]
    fn test_state_container_insert_replaces_existing() {
        let mut container = StateContainer::new();
        container.insert(sample_config());

        let replacement = TestConfig {
            name: "replacement".to_string(),
            value: 456,
        };
        container.insert(replacement.clone());

        // One slot per type: the second insert overwrites the first.
        assert_eq!(container.get::<TestConfig>().as_deref(), Some(&replacement));
        assert_eq!(container.type_names().len(), 1);
    }

    #[test]
    fn test_state_container_multiple_types() {
        let mut container = StateContainer::new();
        container.insert(sample_config());
        container.insert(sample_db_config());

        let retrieved_test = container.get::<TestConfig>();
        let retrieved_db = container.get::<DatabaseConfig>();
        assert_eq!(retrieved_test.as_deref(), Some(&sample_config()));
        assert_eq!(retrieved_db.as_deref(), Some(&sample_db_config()));
    }

    #[test]
    fn test_state_container_not_found() {
        let container = StateContainer::new();
        assert!(container.get::<TestConfig>().is_none());
    }

    #[test]
    fn test_state_container_get_any() {
        let mut container = StateContainer::new();
        container.insert(sample_config());

        assert!(container
            .get_any(std::any::TypeId::of::<TestConfig>())
            .is_some());
        assert!(container
            .get_any(std::any::TypeId::of::<DatabaseConfig>())
            .is_none());
    }

    #[test]
    fn test_state_container_contains() {
        let mut container = StateContainer::new();
        assert!(!container.contains::<TestConfig>());

        container.insert(sample_config());
        assert!(container.contains::<TestConfig>());
        assert!(!container.contains::<DatabaseConfig>());
    }

    #[test]
    fn test_state_container_remove() {
        let mut container = StateContainer::new();
        container.insert(sample_config());
        assert!(container.contains::<TestConfig>());

        let removed = container.remove::<TestConfig>();
        assert_eq!(removed.as_deref(), Some(&sample_config()));
        assert!(!container.contains::<TestConfig>());
        // The recorded type name must go away together with the value.
        assert!(container.type_names().is_empty());
        // Removing a type that is no longer stored yields nothing.
        assert!(container.remove::<TestConfig>().is_none());
    }

    #[test]
    fn test_state_container_clear() {
        let mut container = StateContainer::new();
        container.insert(sample_config());
        container.insert(sample_db_config());
        assert!(container.contains::<TestConfig>());
        assert!(container.contains::<DatabaseConfig>());

        container.clear();
        assert!(!container.contains::<TestConfig>());
        assert!(!container.contains::<DatabaseConfig>());
        assert!(container.type_names().is_empty());
    }

    #[test]
    fn test_state_container_type_names() {
        let mut container = StateContainer::new();
        container.insert(sample_config());
        container.insert(sample_db_config());

        let type_names = container.type_names();
        assert_eq!(type_names.len(), 2);
        assert!(type_names.contains(&std::any::type_name::<TestConfig>()));
        assert!(type_names.contains(&std::any::type_name::<DatabaseConfig>()));
    }

    #[test]
    fn test_state_error_display() {
        let error = StateError::NotFound {
            type_name: "TestConfig",
        };
        assert_eq!(
            error.to_string(),
            "State of type 'TestConfig' not found in application"
        );

        let error = StateError::NotInitialized {
            backend: "simulator".to_string(),
        };
        assert_eq!(
            error.to_string(),
            "State container not initialized for backend 'simulator'"
        );

        // Both wordings of `TypeMismatch` are pinned exactly.
        let error = StateError::TypeMismatch {
            requested_type: "TestConfig",
            found_type: Some("DatabaseConfig".to_string()),
        };
        assert_eq!(
            error.to_string(),
            "State type mismatch: requested 'TestConfig', found 'DatabaseConfig'"
        );

        let error = StateError::TypeMismatch {
            requested_type: "TestConfig",
            found_type: None,
        };
        assert_eq!(
            error.to_string(),
            "State type mismatch: requested 'TestConfig', no matching type found"
        );
    }

    #[test]
    fn test_state_new_and_into_inner() {
        let config = sample_config();
        let arc_config = Arc::new(config.clone());
        let state = State::new(Arc::clone(&arc_config));
        assert_eq!(state.get(), &config);

        // `into_inner` hands back the very same allocation that was wrapped.
        let inner = state.into_inner();
        assert!(Arc::ptr_eq(&inner, &arc_config));
        assert_eq!(*inner, config);
    }

    #[test]
    fn test_state_clone() {
        let config = sample_config();
        let state1 = State::new(Arc::new(config.clone()));
        let state2 = state1.clone();

        // A clone shares the allocation rather than copying the value.
        assert!(Arc::ptr_eq(&state1.0, &state2.0));
        assert_eq!(state1.get(), state2.get());
        assert_eq!(*state1.get(), config);
        assert_eq!(*state2.get(), config);
    }

    #[test]
    fn test_state_deref() {
        let config = sample_config();
        let state = State::new(Arc::new(config.clone()));
        assert_eq!(state.name, config.name);
        assert_eq!(state.value, config.value);
    }
}