Skip to main content

atomr_cluster_sharding/
remember_entities.rs

1//! Remember-entities — persist active entity ids so they restart on
2//! shard re-allocation.
3//!
4//! [`RememberedEntities`] is the in-memory book-keeping layer
5//! (per-shard entity-id sets). [`RememberEntitiesStore`] is the
6//! pluggable trait the shard region calls to persist / load the
7//! set across restarts. The default [`InMemoryRememberStore`] is
8//! suitable for tests; production shard regions wire a journal- or
9//! ddata-backed implementation.
10
11use std::collections::{HashMap, HashSet};
12
13use async_trait::async_trait;
14use parking_lot::RwLock;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum RememberError {
20    #[error("backend error: {0}")]
21    Backend(String),
22}
23
24/// Pluggable persistence store for remembered entities.
25#[async_trait]
26pub trait RememberEntitiesStore: Send + Sync + 'static {
27    /// Load the full entity-id set for `shard_id`.
28    async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError>;
29
30    /// Persist that `entity_id` is now active in `shard_id`.
31    async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;
32
33    /// Persist that `entity_id` is no longer active in `shard_id`
34    /// (typically after passivation).
35    async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;
36}
37
38/// In-process registry of remembered entity ids. Wraps a
39/// [`RememberEntitiesStore`] and serves quick lookups from a local
40/// snapshot.
41pub struct RememberedEntities {
42    store: std::sync::Arc<dyn RememberEntitiesStore>,
43    cache: RwLock<HashMap<String, HashSet<String>>>, // shard_id -> ids
44}
45
46impl RememberedEntities {
47    pub fn new(store: std::sync::Arc<dyn RememberEntitiesStore>) -> Self {
48        Self { store, cache: RwLock::new(HashMap::new()) }
49    }
50
51    /// Refresh cache from the backing store. Idempotent.
52    pub async fn warm(&self, shard_id: &str) -> Result<(), RememberError> {
53        let ids = self.store.load(shard_id).await?;
54        self.cache.write().insert(shard_id.into(), ids);
55        Ok(())
56    }
57
58    /// Mark `entity_id` active. Updates the cache and the store.
59    pub async fn record_active(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
60        self.store.add(shard_id, entity_id).await?;
61        self.cache.write().entry(shard_id.into()).or_default().insert(entity_id.into());
62        Ok(())
63    }
64
65    /// Mark `entity_id` inactive (passivated/stopped).
66    pub async fn record_inactive(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
67        self.store.remove(shard_id, entity_id).await?;
68        if let Some(set) = self.cache.write().get_mut(shard_id) {
69            set.remove(entity_id);
70        }
71        Ok(())
72    }
73
74    /// Snapshot of currently-known entity ids for `shard_id`.
75    pub fn entities(&self, shard_id: &str) -> HashSet<String> {
76        self.cache.read().get(shard_id).cloned().unwrap_or_default()
77    }
78
79    pub fn shard_count(&self) -> usize {
80        self.cache.read().len()
81    }
82}
83
84/// In-memory store — for tests and as a reference implementation.
85#[derive(Default)]
86pub struct InMemoryRememberStore {
87    inner: RwLock<HashMap<String, HashSet<String>>>,
88}
89
90impl InMemoryRememberStore {
91    pub fn new() -> Self {
92        Self::default()
93    }
94}
95
96#[async_trait]
97impl RememberEntitiesStore for InMemoryRememberStore {
98    async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError> {
99        Ok(self.inner.read().get(shard_id).cloned().unwrap_or_default())
100    }
101
102    async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
103        self.inner.write().entry(shard_id.into()).or_default().insert(entity_id.into());
104        Ok(())
105    }
106
107    async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
108        if let Some(set) = self.inner.write().get_mut(shard_id) {
109            set.remove(entity_id);
110        }
111        Ok(())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Fresh in-memory store behind the trait object the registry expects.
    fn fresh_store() -> Arc<dyn RememberEntitiesStore> {
        Arc::new(InMemoryRememberStore::new())
    }

    #[tokio::test]
    async fn record_and_warm_round_trip() {
        let store = fresh_store();
        let writer = RememberedEntities::new(Arc::clone(&store));
        for (shard, entity) in [("s1", "e1"), ("s1", "e2"), ("s2", "e3")] {
            writer.record_active(shard, entity).await.unwrap();
        }

        // A second registry over the same store must recover the ids
        // once it has been warmed.
        let reader = RememberedEntities::new(store);
        reader.warm("s1").await.unwrap();
        let recovered = reader.entities("s1");
        assert_eq!(recovered.len(), 2);
        assert!(recovered.contains("e1"));
        assert!(recovered.contains("e2"));
    }

    #[tokio::test]
    async fn record_inactive_drops_from_set() {
        let registry = RememberedEntities::new(fresh_store());
        for entity in ["e1", "e2"] {
            registry.record_active("s1", entity).await.unwrap();
        }
        registry.record_inactive("s1", "e1").await.unwrap();

        let remaining = registry.entities("s1");
        assert_eq!(remaining.len(), 1);
        assert!(remaining.contains("e2"));
    }

    #[tokio::test]
    async fn shard_count_tracks_distinct_shards() {
        let registry = RememberedEntities::new(fresh_store());
        for (shard, entity) in [("s1", "e1"), ("s2", "e2"), ("s3", "e3")] {
            registry.record_active(shard, entity).await.unwrap();
        }
        assert_eq!(registry.shard_count(), 3);
    }
}