1use serde::{Deserialize, Serialize};
12use std::any::TypeId;
13use std::collections::HashMap;
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct PhaseStateStore {
30 entries: HashMap<u64, Vec<u8>>,
33}
34
35impl PhaseStateStore {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 #[must_use = "this returns a Result that must be checked"]
45 pub fn set<T>(&mut self, value: &T) -> Result<(), PhaseStoreError>
46 where
47 T: Serialize + 'static,
48 {
49 let key = type_key::<T>();
50 let bytes = bincode::serialize(value).map_err(|e| PhaseStoreError::Serialize {
51 details: e.to_string(),
52 })?;
53 self.entries.insert(key, bytes);
54 Ok(())
55 }
56
57 #[must_use = "this returns a Result that must be checked"]
61 pub fn get<T>(&self) -> Result<Option<T>, PhaseStoreError>
62 where
63 T: for<'de> Deserialize<'de> + 'static,
64 {
65 let key = type_key::<T>();
66 match self.entries.get(&key) {
67 Some(bytes) => {
68 let val =
69 bincode::deserialize(bytes).map_err(|e| PhaseStoreError::Deserialize {
70 details: e.to_string(),
71 })?;
72 Ok(Some(val))
73 }
74 None => Ok(None),
75 }
76 }
77
78 pub fn clear<T: 'static>(&mut self) {
80 let key = type_key::<T>();
81 self.entries.remove(&key);
82 }
83
84 pub fn contains<T: 'static>(&self) -> bool {
86 self.entries.contains_key(&type_key::<T>())
87 }
88
89 pub fn is_empty(&self) -> bool {
91 self.entries.is_empty()
92 }
93}
94
95#[derive(Debug, Clone, thiserror::Error)]
97#[non_exhaustive]
98pub enum PhaseStoreError {
99 #[error("Phase serialization failed: {details}")]
101 Serialize {
102 details: String,
104 },
105
106 #[error("Phase deserialization failed: {details}")]
108 Deserialize {
109 details: String,
111 },
112}
113
/// Derives the `u64` map key for the type `T`.
///
/// The key is the `DefaultHasher` digest of `TypeId::of::<T>()`, so repeated
/// calls for the same type within one build return the same value.
///
/// NOTE(review): the standard library documents both `TypeId` hashes and the
/// `DefaultHasher` algorithm as subject to change between Rust releases, so
/// these keys are presumably not stable across differently built binaries.
/// Confirm whether serialized stores are ever read back by another build.
fn type_key<T: 'static>() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let identity = TypeId::of::<T>();
    let mut state = DefaultHasher::new();
    identity.hash(&mut state);
    state.finish()
}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with several variants, one of which carries data.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum TestPhase {
        Start,
        Middle { value: String },
        End,
    }

    /// Second fixture type, used to check that entries are kept per type.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum OtherPhase {
        Only,
    }

    /// Builds the data-carrying fixture variant with the given payload.
    fn middle(value: &str) -> TestPhase {
        TestPhase::Middle {
            value: value.into(),
        }
    }

    /// A stored value comes back equal to what was put in.
    #[test]
    fn test_set_and_get_round_trip() {
        let original = middle("hello");
        let mut store = PhaseStateStore::new();
        store.set(&original).unwrap();

        let fetched = store.get::<TestPhase>().unwrap();
        assert_eq!(fetched, Some(original));
    }

    /// Asking for a type that was never stored yields `None`, not an error.
    #[test]
    fn test_get_missing_type_returns_none() {
        let store = PhaseStateStore::new();
        assert_eq!(store.get::<OtherPhase>().unwrap(), None);
    }

    /// Values of different types do not overwrite each other.
    #[test]
    fn test_different_types_independent() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&OtherPhase::Only).unwrap();

        let first = store.get::<TestPhase>().unwrap();
        let second = store.get::<OtherPhase>().unwrap();
        assert_eq!(first, Some(TestPhase::Start));
        assert_eq!(second, Some(OtherPhase::Only));
    }

    /// Storing the same type twice keeps only the latest value.
    #[test]
    fn test_overwrite_same_type() {
        let mut store = PhaseStateStore::new();
        for phase in [TestPhase::Start, TestPhase::End] {
            store.set(&phase).unwrap();
        }
        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::End));
    }

    /// Clearing a type removes its entry from both `contains` and `get`.
    #[test]
    fn test_clear_removes_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        assert!(store.contains::<TestPhase>());

        store.clear::<TestPhase>();
        assert!(!store.contains::<TestPhase>());
        assert_eq!(store.get::<TestPhase>().unwrap(), None);
    }

    /// A freshly created store reports itself as empty.
    #[test]
    fn test_empty_store() {
        assert!(PhaseStateStore::new().is_empty());
    }

    /// The store itself survives a bincode round trip within one process.
    #[test]
    fn test_store_serialization_round_trip() {
        let mut store = PhaseStateStore::new();
        store.set(&middle("ser")).unwrap();

        let encoded = bincode::serialize(&store).unwrap();
        let restored: PhaseStateStore = bincode::deserialize(&encoded).unwrap();

        let recovered = restored.get::<TestPhase>().unwrap();
        assert_eq!(recovered, Some(middle("ser")));
    }
}