pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Phase state store -- TypeId-keyed storage for node phase enums.
//!
//! Nodes that use the interrupt system (via `node!` DSL or `#[node]` Mode 2)
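//! store their current phase in this store. The store lives inside checkpoint
//! data so phases survive serialization across interrupt/resume boundaries.
//! Entry keys are derived from a hash of the phase type's `TypeId`, which the
//! standard library does not guarantee to be identical across Rust releases or
//! rebuilds, so a checkpoint resumed by a different build may not find its
//! stored phases.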
//!
//! Phase enums are type-erased via bincode serialization: each entry is stored
//! as `(u64 type hash, Vec<u8> serialized bytes)` so the store itself is
//! `Serialize + Deserialize` without knowing concrete phase types.

use serde::{Deserialize, Serialize};
use std::any::TypeId;
use std::collections::HashMap;

/// Type-erased store of serialized node phases, holding one entry per phase type.
///
/// `set()` encodes a phase value with bincode and `get()` decodes it again.
/// Only bytes are held, so the store needs no generic parameters and can
/// itself be serialized as part of a checkpoint.
///
/// # Example
///
/// ```ignore
/// let mut store = PhaseStateStore::new();
/// store.set::<MyPhase>(&MyPhase::GatherMore { position })?;
/// let phase: MyPhase = store.get::<MyPhase>()?.unwrap();
/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PhaseStateStore {
    /// Serialized phase bytes, keyed by a `u64` hash of the phase's `TypeId`
    /// (`TypeId` itself cannot be serialized).
    ///
    /// The hash is deterministic within one build. `TypeId` hashes are not
    /// guaranteed to match across Rust releases, so entries written by a
    /// different build may not be found.
    entries: HashMap<u64, Vec<u8>>,
}

impl PhaseStateStore {
    /// Create an empty phase state store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Store a phase value, serializing it to bytes.
    ///
    /// Overwrites any existing phase of the same type.
    #[must_use = "this returns a Result that must be checked"]
    pub fn set<T>(&mut self, value: &T) -> Result<(), PhaseStoreError>
    where
        T: Serialize + 'static,
    {
        let key = type_key::<T>();
        let bytes = bincode::serialize(value).map_err(|e| PhaseStoreError::Serialize {
            details: e.to_string(),
        })?;
        self.entries.insert(key, bytes);
        Ok(())
    }

    /// Retrieve and deserialize a phase value by type.
    ///
    /// Returns `None` if no phase of this type is stored.
    #[must_use = "this returns a Result that must be checked"]
    pub fn get<T>(&self) -> Result<Option<T>, PhaseStoreError>
    where
        T: for<'de> Deserialize<'de> + 'static,
    {
        let key = type_key::<T>();
        match self.entries.get(&key) {
            Some(bytes) => {
                let val =
                    bincode::deserialize(bytes).map_err(|e| PhaseStoreError::Deserialize {
                        details: e.to_string(),
                    })?;
                Ok(Some(val))
            }
            None => Ok(None),
        }
    }

    /// Remove the phase state for a given type (used on `complete`).
    pub fn clear<T: 'static>(&mut self) {
        let key = type_key::<T>();
        self.entries.remove(&key);
    }

    /// Whether the store contains a phase of the given type.
    pub fn contains<T: 'static>(&self) -> bool {
        self.entries.contains_key(&type_key::<T>())
    }

    /// Whether the store is empty (no phases stored).
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}

/// Errors from phase state serialization/deserialization.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum PhaseStoreError {
    /// Serialization of a phase value failed.
    #[error("Phase serialization failed: {details}")]
    Serialize {
        /// Details about the serialization failure.
        details: String,
    },

    /// Deserialization of a phase value failed.
    #[error("Phase deserialization failed: {details}")]
    Deserialize {
        /// Details about the deserialization failure.
        details: String,
    },
}

/// Map a type to the `u64` key used in the phase store's entry table.
///
/// The key is the `DefaultHasher` digest of `TypeId::of::<T>()`. It is
/// deterministic within one build. The standard library does not promise that
/// `TypeId` hashes or `DefaultHasher` output stay the same across Rust
/// releases, so keys should not be assumed portable between builds.
fn type_key<T: 'static>() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let type_id = TypeId::of::<T>();
    let mut state = DefaultHasher::new();
    type_id.hash(&mut state);
    state.finish()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum TestPhase {
        Start,
        Middle { value: String },
        End,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    enum OtherPhase {
        Only,
    }

    /// A value whose `Serialize` impl always fails, used to exercise the
    /// serialization-error path of `set`.
    struct Unserializable;

    impl Serialize for Unserializable {
        fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            Err(serde::ser::Error::custom("refusing to serialize"))
        }
    }

    #[test]
    fn test_set_and_get_round_trip() {
        let mut store = PhaseStateStore::new();
        let phase = TestPhase::Middle {
            value: "hello".into(),
        };
        store.set(&phase).unwrap();
        let retrieved: TestPhase = store.get::<TestPhase>().unwrap().unwrap();
        assert_eq!(retrieved, phase);
    }

    #[test]
    fn test_get_missing_type_returns_none() {
        let store = PhaseStateStore::new();
        let result = store.get::<OtherPhase>().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_different_types_independent() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&OtherPhase::Only).unwrap();

        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::Start));
        assert_eq!(store.get::<OtherPhase>().unwrap(), Some(OtherPhase::Only));
    }

    #[test]
    fn test_overwrite_same_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.set(&TestPhase::End).unwrap();
        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::End));
    }

    #[test]
    fn test_clear_removes_type() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        assert!(store.contains::<TestPhase>());
        store.clear::<TestPhase>();
        assert!(!store.contains::<TestPhase>());
        assert!(store.get::<TestPhase>().unwrap().is_none());
    }

    #[test]
    fn test_clear_missing_type_is_noop() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::Start).unwrap();
        store.clear::<OtherPhase>();
        assert_eq!(store.get::<TestPhase>().unwrap(), Some(TestPhase::Start));
    }

    #[test]
    fn test_empty_store() {
        let store = PhaseStateStore::new();
        assert!(store.is_empty());
    }

    #[test]
    fn test_is_empty_tracks_set_and_clear() {
        let mut store = PhaseStateStore::new();
        store.set(&TestPhase::End).unwrap();
        assert!(!store.is_empty());
        store.clear::<TestPhase>();
        assert!(store.is_empty());
    }

    #[test]
    fn test_set_serialize_failure_leaves_store_untouched() {
        let mut store = PhaseStateStore::new();
        let err = store.set(&Unserializable).unwrap_err();
        assert!(matches!(err, PhaseStoreError::Serialize { .. }));
        assert!(!store.contains::<Unserializable>());
        assert!(store.is_empty());
    }

    #[test]
    fn test_get_corrupt_bytes_returns_deserialize_error() {
        let mut store = PhaseStateStore::new();
        // Empty bytes cannot hold even the enum discriminant, so decoding fails.
        store.entries.insert(type_key::<TestPhase>(), Vec::new());
        let err = store.get::<TestPhase>().unwrap_err();
        assert!(matches!(err, PhaseStoreError::Deserialize { .. }));
    }

    #[test]
    fn test_store_serialization_round_trip() {
        let mut store = PhaseStateStore::new();
        store
            .set(&TestPhase::Middle {
                value: "ser".into(),
            })
            .unwrap();

        let bytes = bincode::serialize(&store).unwrap();
        let restored: PhaseStateStore = bincode::deserialize(&bytes).unwrap();

        let phase: TestPhase = restored.get::<TestPhase>().unwrap().unwrap();
        assert_eq!(
            phase,
            TestPhase::Middle {
                value: "ser".into(),
            }
        );
    }
}