use std::any::{Any, TypeId};
use std::collections::HashMap;
/// Trait for values that can be stored in a [`States`] registry.
///
/// The `Send + Sync + 'static` bounds allow an implementor to be boxed as
/// `dyn Any + Send + Sync` and identified by its [`TypeId`].
pub trait State: Send + Sync + 'static {
    /// Returns the Rust type name of the implementing type.
    ///
    /// The default implementation defers to `type_name`. The resulting string
    /// is intended for diagnostics; its exact format is not guaranteed to be
    /// stable across compiler versions.
    fn state_type(&self) -> &'static str {
        core::any::type_name::<Self>()
    }
}
/// A registry that holds at most one value per concrete type.
///
/// Values are stored type-erased as `Box<dyn Any + Send + Sync>` and keyed by
/// their [`TypeId`], so entries are looked up by type rather than by name.
pub struct States {
    data: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

// The stored values are type-erased and need not implement `Debug`, so only
// the number of registered entries is shown.
impl std::fmt::Debug for States {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("States")
            .field("len", &self.data.len())
            .finish_non_exhaustive()
    }
}
impl States {
pub(crate) fn new() -> Self {
Self {
data: HashMap::new(),
}
}
pub fn register<T: State>(&mut self, state: T) {
self.data.insert(TypeId::of::<T>(), Box::new(state));
}
pub fn get<T: State>(&self) -> Option<&T> {
self.data
.get(&TypeId::of::<T>())
.and_then(|boxed| boxed.downcast_ref::<T>())
}
pub fn get_mut<T: State>(&mut self) -> Option<&mut T> {
self.data
.get_mut(&TypeId::of::<T>())
.and_then(|boxed| boxed.downcast_mut::<T>())
}
pub fn contains<T: State>(&self) -> bool {
self.data.contains_key(&TypeId::of::<T>())
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[allow(dead_code)]
pub(crate) fn remove<T: State>(&mut self) -> Option<T> {
self.data
.remove(&TypeId::of::<T>())
.and_then(|boxed| boxed.downcast::<T>().ok())
.map(|boxed| *boxed)
}
#[allow(dead_code)]
pub(crate) fn clear(&mut self) {
self.data.clear();
}
}
impl Default for States {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer-backed state used by most tests.
    #[derive(Debug, Clone, PartialEq)]
    struct TestState {
        value: i32,
    }
    impl State for TestState {}

    /// A second, unrelated state type, used to check that entries are kept
    /// separate per type.
    #[derive(Debug, Clone, PartialEq)]
    struct AnotherState {
        name: String,
    }
    impl State for AnotherState {}

    #[test]
    fn test_states_new() {
        let states = States::new();
        assert_eq!(states.len(), 0);
        assert!(states.is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        let states = States::default();
        assert_eq!(states.len(), 0);
        assert!(states.is_empty());
    }

    #[test]
    fn test_register_and_get() {
        let mut states = States::new();
        states.register(TestState { value: 42 });
        assert_eq!(states.get::<TestState>(), Some(&TestState { value: 42 }));
    }

    #[test]
    fn test_get_nonexistent() {
        let states = States::new();
        assert!(states.get::<TestState>().is_none());
    }

    #[test]
    fn test_get_mut() {
        let mut states = States::new();
        states.register(TestState { value: 10 });
        // `expect` (rather than `if let`) makes a missing entry fail here
        // instead of silently skipping the mutation.
        let state = states
            .get_mut::<TestState>()
            .expect("TestState was just registered");
        state.value = 20;
        assert_eq!(states.get::<TestState>(), Some(&TestState { value: 20 }));
    }

    #[test]
    fn test_get_mut_nonexistent() {
        let mut states = States::new();
        assert!(states.get_mut::<TestState>().is_none());
    }

    #[test]
    fn test_multiple_types() {
        let mut states = States::new();
        states.register(TestState { value: 10 });
        states.register(AnotherState {
            name: "Test".to_string(),
        });
        assert_eq!(states.get::<TestState>(), Some(&TestState { value: 10 }));
        assert_eq!(
            states.get::<AnotherState>(),
            Some(&AnotherState {
                name: "Test".to_string()
            })
        );
    }

    #[test]
    fn test_replace() {
        let mut states = States::new();
        states.register(TestState { value: 1 });
        states.register(TestState { value: 2 });
        // Re-registering the same type replaces the entry rather than adding one.
        assert_eq!(states.len(), 1);
        assert_eq!(states.get::<TestState>(), Some(&TestState { value: 2 }));
    }

    #[test]
    fn test_contains() {
        let mut states = States::new();
        states.register(TestState { value: 5 });
        assert!(states.contains::<TestState>());
        assert!(!states.contains::<AnotherState>());
    }

    #[test]
    fn test_remove() {
        let mut states = States::new();
        states.register(TestState { value: 99 });
        assert_eq!(states.remove::<TestState>(), Some(TestState { value: 99 }));
        assert!(!states.contains::<TestState>());
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut states = States::new();
        assert_eq!(states.remove::<TestState>(), None);
    }

    #[test]
    fn test_len_and_clear() {
        let mut states = States::new();
        assert_eq!(states.len(), 0);
        assert!(states.is_empty());
        states.register(TestState { value: 1 });
        states.register(AnotherState {
            name: "DB".to_string(),
        });
        assert_eq!(states.len(), 2);
        assert!(!states.is_empty());
        states.clear();
        assert_eq!(states.len(), 0);
        assert!(states.is_empty());
    }

    #[test]
    fn test_state_type_name() {
        let state = TestState { value: 42 };
        assert!(state.state_type().contains("TestState"));
    }
}