Skip to main content

pe_core/
phase_store.rs

//! Phase state store -- TypeId-keyed storage for node phase enums.
//!
//! Nodes that use the interrupt system (via the `node!` DSL or `#[node]` Mode 2)
//! store their current phase in this store. The store lives inside checkpoint
//! data, so phases survive serialization across interrupt/resume boundaries.
//!
//! Phase enums are type-erased via bincode serialization: each entry is stored
//! as a `(u64 type hash, Vec<u8> serialized bytes)` pair, so the store itself is
//! `Serialize + Deserialize` without knowing the concrete phase types.
10
11use serde::{Deserialize, Serialize};
12use std::any::TypeId;
13use std::collections::HashMap;
14
15/// Stores serialized phase state for nodes, keyed by phase enum TypeId.
16///
17/// Phase enums generated by macros implement `Serialize + Deserialize`.
18/// The store serializes them to bytes on `set()` and deserializes on `get()`,
19/// so the store itself is serializable without generic parameters.
20///
21/// # Example
22///
23/// ```ignore
24/// let mut store = PhaseStateStore::new();
25/// store.set::<MyPhase>(&MyPhase::GatherMore { position })?;
26/// let phase: MyPhase = store.get::<MyPhase>()?.unwrap();
27/// ```
28#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct PhaseStateStore {
30    /// TypeId is not serializable, so we use a u64 hash as the key.
31    /// This is stable within a single compilation (same types = same hash).
32    entries: HashMap<u64, Vec<u8>>,
33}
34
35impl PhaseStateStore {
36    /// Create an empty phase state store.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Store a phase value, serializing it to bytes.
42    ///
43    /// Overwrites any existing phase of the same type.
44    #[must_use = "this returns a Result that must be checked"]
45    pub fn set<T>(&mut self, value: &T) -> Result<(), PhaseStoreError>
46    where
47        T: Serialize + 'static,
48    {
49        let key = type_key::<T>();
50        let bytes = bincode::serialize(value).map_err(|e| PhaseStoreError::Serialize {
51            details: e.to_string(),
52        })?;
53        self.entries.insert(key, bytes);
54        Ok(())
55    }
56
57    /// Retrieve and deserialize a phase value by type.
58    ///
59    /// Returns `None` if no phase of this type is stored.
60    #[must_use = "this returns a Result that must be checked"]
61    pub fn get<T>(&self) -> Result<Option<T>, PhaseStoreError>
62    where
63        T: for<'de> Deserialize<'de> + 'static,
64    {
65        let key = type_key::<T>();
66        match self.entries.get(&key) {
67            Some(bytes) => {
68                let val =
69                    bincode::deserialize(bytes).map_err(|e| PhaseStoreError::Deserialize {
70                        details: e.to_string(),
71                    })?;
72                Ok(Some(val))
73            }
74            None => Ok(None),
75        }
76    }
77
78    /// Remove the phase state for a given type (used on `complete`).
79    pub fn clear<T: 'static>(&mut self) {
80        let key = type_key::<T>();
81        self.entries.remove(&key);
82    }
83
84    /// Whether the store contains a phase of the given type.
85    pub fn contains<T: 'static>(&self) -> bool {
86        self.entries.contains_key(&type_key::<T>())
87    }
88
89    /// Whether the store is empty (no phases stored).
90    pub fn is_empty(&self) -> bool {
91        self.entries.is_empty()
92    }
93}
94
95/// Errors from phase state serialization/deserialization.
96#[derive(Debug, Clone, thiserror::Error)]
97#[non_exhaustive]
98pub enum PhaseStoreError {
99    /// Serialization of a phase value failed.
100    #[error("Phase serialization failed: {details}")]
101    Serialize {
102        /// Details about the serialization failure.
103        details: String,
104    },
105
106    /// Deserialization of a phase value failed.
107    #[error("Phase deserialization failed: {details}")]
108    Deserialize {
109        /// Details about the deserialization failure.
110        details: String,
111    },
112}
113
/// Derives the `u64` map key for phase type `T` by hashing its `TypeId`.
///
/// The same `T` maps to the same key throughout a single build of the program.
///
/// NOTE(review): the standard library does not promise that `TypeId` hashes or
/// the `DefaultHasher` algorithm stay the same across Rust releases, so keys
/// may differ between toolchains. Changing the derivation here would orphan
/// phases already stored in existing checkpoints, so it is intentionally kept.
fn type_key<T: 'static>() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let type_id = TypeId::of::<T>();
    let mut state = DefaultHasher::new();
    type_id.hash(&mut state);
    state.finish()
}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum TestPhase {
        Start,
        Middle { value: String },
        End,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum OtherPhase {
        Only,
    }

    #[test]
    fn test_set_and_get_round_trip() {
        let mut store = PhaseStateStore::new();
        let phase = TestPhase::Middle {
            value: "hello".into(),
        };
        store.set(&phase).unwrap();
        let retrieved: TestPhase = store.get::<TestPhase>().unwrap().unwrap();
        assert_eq!(retrieved, phase);
    }

    #[test]
    fn test_get_missing_type_returns_none() {
        let store = PhaseStateStore::new();
        let result = store.get::<OtherPhase>().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_different_types_independent() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&OtherPhase::Only).unwrap();

        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::Start));
        assert_eq!(store.get::<OtherPhase>().unwrap(), Some(OtherPhase::Only));
    }

    #[test]
    fn test_overwrite_same_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&TestPhase::End).unwrap();
        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::End));
    }

    #[test]
    fn test_clear_removes_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        assert!(store.contains::<TestPhase>());
        store.clear::<TestPhase>();
        assert!(!store.contains::<TestPhase>());
        assert!(store.get::<TestPhase>().unwrap().is_none());
    }

    #[test]
    fn test_clear_leaves_other_types_intact() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::End).unwrap();
        store.set(&OtherPhase::Only).unwrap();
        store.clear::<TestPhase>();
        assert!(!store.contains::<TestPhase>());
        assert_eq!(store.get::<OtherPhase>().unwrap(), Some(OtherPhase::Only));
    }

    #[test]
    fn test_clear_missing_type_is_noop() {
        let mut store = PhaseStateStore::new();
        store.clear::<TestPhase>();
        assert!(store.is_empty());
    }

    #[test]
    fn test_empty_store() {
        let store = PhaseStateStore::new();
        assert!(store.is_empty());
    }

    #[test]
    fn test_is_empty_tracks_set_and_clear() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        assert!(!store.is_empty());
        store.clear::<TestPhase>();
        assert!(store.is_empty());
    }

    #[test]
    fn test_get_corrupt_bytes_returns_deserialize_error() {
        let mut store = PhaseStateStore::new();
        // Bypass `set` to plant bytes that cannot decode as a `TestPhase`:
        // an empty buffer does not even contain the enum's variant tag.
        store.entries.insert(type_key::<TestPhase>(), Vec::new());
        let err = store.get::<TestPhase>().unwrap_err();
        assert!(matches!(err, PhaseStoreError::Deserialize { .. }));
    }

    #[test]
    fn test_store_serialization_round_trip() {
        let mut store = PhaseStateStore::new();
        store
            .set(&TestPhase::Middle {
                value: "ser".into(),
            })
            .unwrap();

        let bytes = bincode::serialize(&store).unwrap();
        let restored: PhaseStateStore = bincode::deserialize(&bytes).unwrap();

        let phase: TestPhase = restored.get::<TestPhase>().unwrap().unwrap();
        assert_eq!(
            phase,
            TestPhase::Middle {
                value: "ser".into(),
            }
        );
    }
}