atomr_cluster_sharding/
remember_entities.rs1use std::collections::{HashMap, HashSet};
12
13use async_trait::async_trait;
14use parking_lot::RwLock;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum RememberError {
20 #[error("backend error: {0}")]
21 Backend(String),
22}
23
24#[async_trait]
26pub trait RememberEntitiesStore: Send + Sync + 'static {
27 async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError>;
29
30 async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;
32
33 async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;
36}
37
38pub struct RememberedEntities {
42 store: std::sync::Arc<dyn RememberEntitiesStore>,
43 cache: RwLock<HashMap<String, HashSet<String>>>, }
45
46impl RememberedEntities {
47 pub fn new(store: std::sync::Arc<dyn RememberEntitiesStore>) -> Self {
48 Self { store, cache: RwLock::new(HashMap::new()) }
49 }
50
51 pub async fn warm(&self, shard_id: &str) -> Result<(), RememberError> {
53 let ids = self.store.load(shard_id).await?;
54 self.cache.write().insert(shard_id.into(), ids);
55 Ok(())
56 }
57
58 pub async fn record_active(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
60 self.store.add(shard_id, entity_id).await?;
61 self.cache.write().entry(shard_id.into()).or_default().insert(entity_id.into());
62 Ok(())
63 }
64
65 pub async fn record_inactive(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
67 self.store.remove(shard_id, entity_id).await?;
68 if let Some(set) = self.cache.write().get_mut(shard_id) {
69 set.remove(entity_id);
70 }
71 Ok(())
72 }
73
74 pub fn entities(&self, shard_id: &str) -> HashSet<String> {
76 self.cache.read().get(shard_id).cloned().unwrap_or_default()
77 }
78
79 pub fn shard_count(&self) -> usize {
80 self.cache.read().len()
81 }
82}
83
84#[derive(Default)]
86pub struct InMemoryRememberStore {
87 inner: RwLock<HashMap<String, HashSet<String>>>,
88}
89
90impl InMemoryRememberStore {
91 pub fn new() -> Self {
92 Self::default()
93 }
94}
95
96#[async_trait]
97impl RememberEntitiesStore for InMemoryRememberStore {
98 async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError> {
99 Ok(self.inner.read().get(shard_id).cloned().unwrap_or_default())
100 }
101
102 async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
103 self.inner.write().entry(shard_id.into()).or_default().insert(entity_id.into());
104 Ok(())
105 }
106
107 async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
108 if let Some(set) = self.inner.write().get_mut(shard_id) {
109 set.remove(entity_id);
110 }
111 Ok(())
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[tokio::test]
    async fn record_and_warm_round_trip() {
        let store: Arc<dyn RememberEntitiesStore> = Arc::new(InMemoryRememberStore::new());
        let r = RememberedEntities::new(Arc::clone(&store));

        r.record_active("s1", "e1").await.unwrap();
        r.record_active("s1", "e2").await.unwrap();
        r.record_active("s2", "e3").await.unwrap();

        // A second view over the same store only sees what it warms.
        let r2 = RememberedEntities::new(store);
        r2.warm("s1").await.unwrap();
        let ids = r2.entities("s1");
        assert_eq!(ids.len(), 2);
        assert!(ids.contains("e1") && ids.contains("e2"));
        assert!(r2.entities("s2").is_empty());
        assert_eq!(r2.shard_count(), 1);
    }

    #[tokio::test]
    async fn record_inactive_drops_from_set() {
        let store: Arc<dyn RememberEntitiesStore> = Arc::new(InMemoryRememberStore::new());
        let r = RememberedEntities::new(store);
        r.record_active("s1", "e1").await.unwrap();
        r.record_active("s1", "e2").await.unwrap();
        r.record_inactive("s1", "e1").await.unwrap();
        let ids = r.entities("s1");
        assert_eq!(ids.len(), 1);
        assert!(ids.contains("e2"));
    }

    #[tokio::test]
    async fn shard_count_tracks_distinct_shards() {
        let store: Arc<dyn RememberEntitiesStore> = Arc::new(InMemoryRememberStore::new());
        let r = RememberedEntities::new(store);
        r.record_active("s1", "e1").await.unwrap();
        r.record_active("s2", "e2").await.unwrap();
        r.record_active("s3", "e3").await.unwrap();
        assert_eq!(r.shard_count(), 3);
    }

    #[tokio::test]
    async fn warm_unknown_shard_yields_empty_view() {
        let r = RememberedEntities::new(Arc::new(InMemoryRememberStore::new()));
        r.warm("missing").await.unwrap();
        assert!(r.entities("missing").is_empty());
    }

    #[tokio::test]
    async fn record_inactive_on_unknown_shard_is_noop() {
        let r = RememberedEntities::new(Arc::new(InMemoryRememberStore::new()));
        r.record_inactive("s1", "e1").await.unwrap();
        assert!(r.entities("s1").is_empty());
        assert_eq!(r.shard_count(), 0);
    }

    #[tokio::test]
    async fn removing_last_entity_empties_cache_and_store() {
        let store: Arc<dyn RememberEntitiesStore> = Arc::new(InMemoryRememberStore::new());
        let r = RememberedEntities::new(Arc::clone(&store));
        r.record_active("s1", "e1").await.unwrap();
        r.record_inactive("s1", "e1").await.unwrap();
        assert!(r.entities("s1").is_empty());
        assert!(store.load("s1").await.unwrap().is_empty());
    }
}