use serde::{Deserialize, Serialize};
use std::any::TypeId;
use std::collections::HashMap;
/// Store of serialized values indexed by a `u64` key, holding at most one
/// byte payload per key.
///
/// The methods in the accompanying `impl` derive the key from the value's
/// Rust type via `type_key::<T>()`, so each type occupies a single slot.
/// The store itself derives `Serialize`/`Deserialize`, so the whole map can
/// be encoded and decoded as a unit.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PhaseStateStore {
// Serialized payloads, keyed by the `u64` produced by `type_key::<T>()`.
entries: HashMap<u64, Vec<u8>>,
}
impl PhaseStateStore {
    /// Creates a store with no entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes `value` with bincode and stores the bytes in the slot for `T`.
    ///
    /// Any bytes previously stored for `T` are replaced.
    ///
    /// # Errors
    ///
    /// Returns [`PhaseStoreError::Serialize`] when bincode cannot encode `value`;
    /// in that case the store is left untouched.
    #[must_use = "this returns a Result that must be checked"]
    pub fn set<T>(&mut self, value: &T) -> Result<(), PhaseStoreError>
    where
        T: Serialize + 'static,
    {
        let encoded = match bincode::serialize(value) {
            Ok(buf) => buf,
            Err(err) => {
                return Err(PhaseStoreError::Serialize {
                    details: err.to_string(),
                });
            }
        };
        self.entries.insert(type_key::<T>(), encoded);
        Ok(())
    }

    /// Decodes and returns the value stored in the slot for `T`.
    ///
    /// Yields `Ok(None)` when nothing is stored for `T`.
    ///
    /// # Errors
    ///
    /// Returns [`PhaseStoreError::Deserialize`] when the stored bytes cannot be
    /// decoded as a `T`.
    #[must_use = "this returns a Result that must be checked"]
    pub fn get<T>(&self) -> Result<Option<T>, PhaseStoreError>
    where
        T: for<'de> Deserialize<'de> + 'static,
    {
        // Decode only if a slot exists, then flip Option<Result<..>> into
        // Result<Option<..>> so a missing entry is not an error.
        self.entries
            .get(&type_key::<T>())
            .map(|raw| {
                bincode::deserialize(raw).map_err(|err| PhaseStoreError::Deserialize {
                    details: err.to_string(),
                })
            })
            .transpose()
    }

    /// Removes the entry stored for `T`, if there is one.
    pub fn clear<T: 'static>(&mut self) {
        self.entries.remove(&type_key::<T>());
    }

    /// Reports whether an entry is currently stored for `T`.
    pub fn contains<T: 'static>(&self) -> bool {
        let slot = type_key::<T>();
        self.entries.contains_key(&slot)
    }

    /// Reports whether the store holds no entries at all.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Errors produced by the fallible `PhaseStateStore` operations (`set` and `get`).
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum PhaseStoreError {
/// Encoding a value to bytes failed; returned by `PhaseStateStore::set`.
#[error("Phase serialization failed: {details}")]
Serialize {
/// Display text of the underlying encoder error.
details: String,
},
/// Decoding stored bytes into the requested type failed; returned by
/// `PhaseStateStore::get`.
#[error("Phase deserialization failed: {details}")]
Deserialize {
/// Display text of the underlying decoder error.
details: String,
},
}
/// Computes the `u64` map key for type `T`.
///
/// The key is the result of feeding `TypeId::of::<T>()` through a freshly
/// constructed `DefaultHasher`, so repeated calls for the same `T` within one
/// build yield the same value.
///
/// NOTE(review): the standard library documents that `TypeId` hashes and
/// `DefaultHasher` output may change between Rust releases, so these keys are
/// presumably only meaningful within a single build — confirm before relying on
/// a serialized store that outlives the binary which wrote it.
fn type_key<T: 'static>() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(&TypeId::of::<T>(), &mut state);
    Hasher::finish(&state)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture type with several variants, one of which carries data.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum TestPhase {
        Start,
        Middle { value: String },
        End,
    }

    /// Second fixture type, used to check that slots are per-type.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum OtherPhase {
        Only,
    }

    /// A stored value comes back equal to what was stored.
    #[test]
    fn test_set_and_get_round_trip() {
        let original = TestPhase::Middle {
            value: "hello".into(),
        };
        let mut store = PhaseStateStore::new();
        store.set(&original).unwrap();

        let loaded = store.get::<TestPhase>().unwrap();
        assert_eq!(loaded, Some(original));
    }

    /// Looking up a type that was never stored is `Ok(None)`, not an error.
    #[test]
    fn test_get_missing_type_returns_none() {
        let store = PhaseStateStore::new();
        assert_eq!(store.get::<OtherPhase>().unwrap(), None);
    }

    /// Two different types occupy separate slots.
    #[test]
    fn test_different_types_independent() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&OtherPhase::Only).unwrap();

        let first = store.get::<TestPhase>().unwrap();
        let second = store.get::<OtherPhase>().unwrap();
        assert_eq!(first, Some(TestPhase::Start));
        assert_eq!(second, Some(OtherPhase::Only));
    }

    /// Storing the same type twice keeps only the latest value.
    #[test]
    fn test_overwrite_same_type() {
        let mut store = PhaseStateStore::new();
        for phase in [TestPhase::Start, TestPhase::End] {
            store.set(&phase).unwrap();
        }
        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::End));
    }

    /// `clear` drops the entry so both `contains` and `get` see nothing.
    #[test]
    fn test_clear_removes_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        assert!(store.contains::<TestPhase>());

        store.clear::<TestPhase>();
        assert!(!store.contains::<TestPhase>());
        assert_eq!(store.get::<TestPhase>().unwrap(), None);
    }

    /// A freshly created store reports itself as empty.
    #[test]
    fn test_empty_store() {
        assert!(PhaseStateStore::new().is_empty());
    }

    /// The whole store survives a bincode encode/decode cycle.
    #[test]
    fn test_store_serialization_round_trip() {
        let expected = TestPhase::Middle {
            value: "ser".into(),
        };
        let mut store = PhaseStateStore::new();
        store.set(&expected).unwrap();

        let encoded = bincode::serialize(&store).unwrap();
        let decoded: PhaseStateStore = bincode::deserialize(&encoded).unwrap();

        assert_eq!(decoded.get::<TestPhase>().unwrap(), Some(expected));
    }
}