use std::collections::{HashMap, HashSet};
use async_trait::async_trait;
use parking_lot::RwLock;
use thiserror::Error;
/// Errors returned by remember-entities storage operations.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RememberError {
    /// A failure described by a free-form message.
    ///
    /// Displayed as `backend error: <message>`.
    #[error("backend error: {0}")]
    Backend(String),
}
/// Storage interface for the set of entity ids kept per shard.
///
/// Implementors must be `Send + Sync + 'static`; [`RememberedEntities`] holds
/// one behind an `Arc<dyn RememberEntitiesStore>`.
#[async_trait]
pub trait RememberEntitiesStore: Send + Sync + 'static {
    /// Returns the entity ids held for `shard_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`RememberError`] if the store cannot serve the request.
    async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError>;

    /// Adds `entity_id` to the ids held for `shard_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`RememberError`] if the store cannot apply the change.
    async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;

    /// Removes `entity_id` from the ids held for `shard_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`RememberError`] if the store cannot apply the change.
    async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError>;
}
/// Per-shard entity id sets, written through to a [`RememberEntitiesStore`]
/// and mirrored in an in-process cache.
pub struct RememberedEntities {
    /// Backing store that every mutation is sent to before the cache changes.
    store: std::sync::Arc<dyn RememberEntitiesStore>,
    /// Shard id -> entity ids currently mirrored in memory.
    cache: RwLock<HashMap<String, HashSet<String>>>,
}
impl RememberedEntities {
    /// Creates an instance backed by `store` with an empty cache.
    pub fn new(store: std::sync::Arc<dyn RememberEntitiesStore>) -> Self {
        let cache = RwLock::new(HashMap::new());
        Self { store, cache }
    }

    /// Loads the ids for `shard_id` from the store and caches them.
    ///
    /// Any set already cached for the shard is replaced by the loaded one.
    ///
    /// # Errors
    ///
    /// Propagates the store's error; the cache is left untouched in that case.
    pub async fn warm(&self, shard_id: &str) -> Result<(), RememberError> {
        let loaded = self.store.load(shard_id).await?;
        // The lock is taken only after the await has completed, so the guard
        // is never held across a suspension point.
        let mut shards = self.cache.write();
        shards.insert(shard_id.to_owned(), loaded);
        Ok(())
    }

    /// Adds `entity_id` to `shard_id` in the store, then in the cache.
    ///
    /// A cache entry for the shard is created if none exists yet.
    ///
    /// # Errors
    ///
    /// Propagates the store's error; the cache is only updated after the store
    /// call succeeds.
    pub async fn record_active(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
        self.store.add(shard_id, entity_id).await?;
        let mut shards = self.cache.write();
        let members = shards.entry(shard_id.to_owned()).or_default();
        members.insert(entity_id.to_owned());
        Ok(())
    }

    /// Removes `entity_id` from `shard_id` in the store, then from the cache.
    ///
    /// A shard that is not cached is left uncached. A cached shard keeps its
    /// entry even when its set becomes empty.
    ///
    /// # Errors
    ///
    /// Propagates the store's error; the cache is only updated after the store
    /// call succeeds.
    pub async fn record_inactive(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
        self.store.remove(shard_id, entity_id).await?;
        let mut shards = self.cache.write();
        if let Some(members) = shards.get_mut(shard_id) {
            members.remove(entity_id);
        }
        Ok(())
    }

    /// Returns a copy of the cached ids for `shard_id`.
    ///
    /// Only the cache is consulted; a shard that is not cached yields an empty
    /// set.
    pub fn entities(&self, shard_id: &str) -> HashSet<String> {
        let shards = self.cache.read();
        match shards.get(shard_id) {
            Some(members) => members.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns the number of shard entries in the cache, including entries
    /// whose set is empty.
    pub fn shard_count(&self) -> usize {
        let shards = self.cache.read();
        shards.len()
    }
}
/// A [`RememberEntitiesStore`] that keeps every shard's ids in process memory.
#[derive(Default)]
pub struct InMemoryRememberStore {
    /// Shard id -> entity ids, guarded by a read/write lock.
    inner: RwLock<HashMap<String, HashSet<String>>>,
}
impl InMemoryRememberStore {
pub fn new() -> Self {
Self::default()
}
}
/// In-memory implementation; none of these methods returns an error.
#[async_trait]
impl RememberEntitiesStore for InMemoryRememberStore {
    /// Returns a copy of the ids held for `shard_id`, or an empty set when the
    /// shard has no entry.
    async fn load(&self, shard_id: &str) -> Result<HashSet<String>, RememberError> {
        let shards = self.inner.read();
        let snapshot = match shards.get(shard_id) {
            Some(members) => members.clone(),
            None => HashSet::new(),
        };
        Ok(snapshot)
    }

    /// Inserts `entity_id` into the shard's set, creating the set if needed.
    async fn add(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
        let mut shards = self.inner.write();
        let members = shards.entry(shard_id.to_owned()).or_default();
        members.insert(entity_id.to_owned());
        Ok(())
    }

    /// Removes `entity_id` from the shard's set if the shard has an entry.
    ///
    /// The shard entry itself is kept even when its set becomes empty.
    async fn remove(&self, shard_id: &str, entity_id: &str) -> Result<(), RememberError> {
        let mut shards = self.inner.write();
        if let Some(members) = shards.get_mut(shard_id) {
            members.remove(entity_id);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a fresh in-memory store behind the trait object the API expects.
    fn in_memory_store() -> Arc<dyn RememberEntitiesStore> {
        Arc::new(InMemoryRememberStore::new())
    }

    /// Records every `(shard, entity)` pair as active, panicking on failure.
    async fn activate_all(target: &RememberedEntities, pairs: &[(&str, &str)]) {
        for (shard, entity) in pairs {
            target.record_active(shard, entity).await.unwrap();
        }
    }

    /// Ids written through one instance are visible to a second instance that
    /// shares the store, once that instance warms the shard.
    #[tokio::test]
    async fn record_and_warm_round_trip() {
        let shared = in_memory_store();
        let writer = RememberedEntities::new(shared.clone());
        activate_all(&writer, &[("s1", "e1"), ("s1", "e2"), ("s2", "e3")]).await;

        let reader = RememberedEntities::new(shared);
        reader.warm("s1").await.unwrap();

        let warmed = reader.entities("s1");
        assert_eq!(warmed.len(), 2);
        assert!(warmed.contains("e1"));
        assert!(warmed.contains("e2"));
    }

    /// Marking an entity inactive removes it from the cached set.
    #[tokio::test]
    async fn record_inactive_drops_from_set() {
        let tracker = RememberedEntities::new(in_memory_store());
        activate_all(&tracker, &[("s1", "e1"), ("s1", "e2")]).await;

        tracker.record_inactive("s1", "e1").await.unwrap();

        let remaining = tracker.entities("s1");
        assert_eq!(remaining.len(), 1);
        assert!(remaining.contains("e2"));
    }

    /// Each distinct shard id that receives an entity counts once.
    #[tokio::test]
    async fn shard_count_tracks_distinct_shards() {
        let tracker = RememberedEntities::new(in_memory_store());
        activate_all(&tracker, &[("s1", "e1"), ("s2", "e2"), ("s3", "e3")]).await;

        assert_eq!(tracker.shard_count(), 3);
    }
}