// The tests for this module live in an external file (see the `#[path]`
// attribute) and are compiled only for test builds with the `tokio` feature.
#[cfg(all(test, feature = "tokio"))]
#[path = "../../tests/cache/map.rs"]
mod tests;
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::{Arc, Weak};
use rand::RngCore;
use crate::cache::common::{CacheError, Versioned};
use crate::utils::random::get_rng;
use crate::utils::sync::RwLock;
/// Shared backing store: a lock-guarded map from keys to versioned values.
pub(crate) type SharedState<K, V> = RwLock<HashMap<K, Versioned<V>>>;
/// A locally cached copy of a value, tagged with the shared version it was
/// cloned from so staleness can be detected by comparing versions.
struct LocalEntry<V> {
/// The cached value itself.
value: V,
/// `version` of the shared `Versioned` entry this copy was cloned from.
source_version: u64,
}
/// Owning side of the shared map.
///
/// Keeps an unversioned local copy of every value next to the versioned copy
/// held in the shared state; caches created from this map observe the shared
/// state through weak references.
pub(crate) struct SharedMap<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Versioned values, shared with the caches handed out by this map.
    state: Arc<SharedState<K, V>>,
    /// This handle's own copy of the values, readable without locking.
    local: HashMap<K, V>,
    /// Marker that opts the type out of `Sync`.
    _not_sync: PhantomData<UnsafeCell<()>>,
}
impl<K, V> SharedMap<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Creates an empty map with fresh, empty shared state.
    pub(crate) fn new() -> Self {
        let state = Arc::new(RwLock::new(HashMap::new()));
        Self {
            state,
            local: HashMap::new(),
            _not_sync: PhantomData,
        }
    }

    /// Stores `value` under `key`.
    ///
    /// The shared state receives a clone tagged with a version drawn from
    /// `get_rng()`; the local map then takes ownership of `key` and `value`.
    pub(crate) async fn insert(&mut self, key: K, value: V) {
        {
            // Scope the write guard so it is released before the local update.
            let mut shared = self.state.write().await;
            shared.insert(
                key.clone(),
                Versioned {
                    value: value.clone(),
                    version: get_rng().next_u64(),
                },
            );
        }
        self.local.insert(key, value);
    }

    /// Removes `key` from the local map first, then from the shared state.
    pub(crate) async fn remove(&mut self, key: &K) {
        self.local.remove(key);
        let mut shared = self.state.write().await;
        shared.remove(key);
    }

    /// Reports whether `key` is present in the local map.
    ///
    /// Only the local copy is consulted; the shared state is not locked.
    pub(crate) fn contains_key(&self, key: &K) -> bool {
        self.local.contains_key(key)
    }

    /// Applies `f` to the local value stored under `key` and publishes the
    /// result to the shared state under a newly drawn version.
    ///
    /// Does nothing when `key` is absent from the local map.
    pub(crate) async fn modify<F: FnOnce(&mut V)>(&mut self, key: &K, f: F) {
        let Some(current) = self.local.get_mut(key) else {
            return;
        };
        f(current);
        // Build the published copy before waiting for the write lock.
        let published = Versioned {
            value: current.clone(),
            version: get_rng().next_u64(),
        };
        self.state.write().await.insert(key.clone(), published);
    }

    /// Creates an initially empty cache that weakly references this map's
    /// shared state.
    pub(crate) fn create_cache(&self) -> CachedMap<K, V> {
        let source = Arc::downgrade(&self.state);
        CachedMap {
            source,
            local: HashMap::new(),
            _not_sync: PhantomData,
        }
    }

    /// Creates a template from which single-entry caches for `key` can be
    /// produced; the template weakly references this map's shared state.
    pub(crate) fn create_cache_for(&self, key: K) -> CachedMapEntryTemplate<K, V> {
        let source = Arc::downgrade(&self.state);
        CachedMapEntryTemplate { source, key }
    }
}
impl<K: Clone + Eq + Hash + Send + ToString, V: Clone + Send> Default for SharedMap<K, V> {
fn default() -> Self {
Self::new()
}
}
/// Reader-side cache over a shared map.
///
/// Holds only a weak reference to the shared state and keeps version-tagged
/// local copies of the values it has fetched.
pub(crate) struct CachedMap<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Weak handle to the shared state; upgrading fails once it is dropped.
    source: Weak<SharedState<K, V>>,
    /// Local copies, each tagged with the shared version it was cloned from.
    local: HashMap<K, LocalEntry<V>>,
    /// Marker that opts the type out of `Sync`.
    _not_sync: PhantomData<UnsafeCell<()>>,
}
impl<K, V> CachedMap<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Brings the local copy for `key` in line with the shared state and
    /// returns it.
    ///
    /// The local copy is replaced only when it is missing or its recorded
    /// version differs from the shared entry's version.
    ///
    /// # Errors
    ///
    /// * `CacheError::SourceDropped` if the shared state no longer exists.
    /// * `CacheError::KeyNotFound` if the shared state has no entry for
    ///   `key`; any local copy for that key is discarded first.
    async fn fetch(&mut self, key: &K) -> Result<&mut LocalEntry<V>, CacheError> {
        let Some(source) = self.source.upgrade() else {
            return Err(CacheError::SourceDropped);
        };
        let shared = source.read().await;
        let Some(published) = shared.get(key) else {
            self.local.remove(key);
            return Err(CacheError::KeyNotFound(key.to_string()));
        };

        let is_current = self
            .local
            .get(key)
            .is_some_and(|cached| cached.source_version == published.version);
        if !is_current {
            self.local.insert(
                key.clone(),
                LocalEntry {
                    value: published.value.clone(),
                    source_version: published.version,
                },
            );
        }

        // Release the read lock before handing out the local reference.
        drop(shared);
        // An entry for `key` is guaranteed here: it was either already
        // current or has just been inserted.
        Ok(self.local.get_mut(key).unwrap())
    }

    /// Returns a mutable reference to the local copy of the value for `key`,
    /// refreshing it from the shared state first when it is stale.
    ///
    /// Nothing in this type writes the local copy back to the shared state.
    ///
    /// # Errors
    ///
    /// Propagates the errors of the internal fetch: `CacheError::SourceDropped`
    /// or `CacheError::KeyNotFound`.
    pub(crate) async fn get_mut(&mut self, key: &K) -> Result<&mut V, CacheError> {
        let entry = self.fetch(key).await?;
        Ok(&mut entry.value)
    }
}
/// Reusable recipe for single-key caches: a weak reference to the shared
/// state together with the key each produced cache will track.
pub(crate) struct CachedMapEntryTemplate<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Weak handle to the shared state.
    source: Weak<SharedState<K, V>>,
    /// Key that entries created from this template look up.
    key: K,
}
impl<K, V> CachedMapEntryTemplate<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Builds a new single-key cache for this template's key.
    ///
    /// The cache starts without a local copy and shares the template's weak
    /// reference to the shared state.
    pub(crate) fn create_entry(&self) -> CachedMapEntry<K, V> {
        let source = self.source.clone();
        let key = self.key.clone();
        CachedMapEntry {
            source,
            key,
            local: None,
            _not_sync: PhantomData,
        }
    }
}
/// Reader-side cache for a single key of a shared map.
pub(crate) struct CachedMapEntry<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Weak handle to the shared state; upgrading fails once it is dropped.
    source: Weak<SharedState<K, V>>,
    /// The key this entry tracks.
    key: K,
    /// Local copy of the value, or `None` when nothing is cached.
    local: Option<LocalEntry<V>>,
    /// Marker that opts the type out of `Sync`.
    _not_sync: PhantomData<UnsafeCell<()>>,
}
impl<K, V> CachedMapEntry<K, V>
where
    K: Clone + Eq + Hash + Send + ToString,
    V: Clone + Send,
{
    /// Brings the local copy in line with the shared state and returns it.
    ///
    /// The local copy is replaced only when it is missing or its recorded
    /// version differs from the shared entry's version.
    ///
    /// # Errors
    ///
    /// * `CacheError::SourceDropped` if the shared state no longer exists.
    /// * `CacheError::KeyNotFound` if the shared state has no entry for this
    ///   entry's key; the local copy is cleared first.
    async fn fetch(&mut self) -> Result<&mut LocalEntry<V>, CacheError> {
        let Some(source) = self.source.upgrade() else {
            return Err(CacheError::SourceDropped);
        };
        let shared = source.read().await;
        let Some(published) = shared.get(&self.key) else {
            self.local = None;
            return Err(CacheError::KeyNotFound(self.key.to_string()));
        };

        let is_current = self
            .local
            .as_ref()
            .is_some_and(|cached| cached.source_version == published.version);
        if !is_current {
            self.local = Some(LocalEntry {
                value: published.value.clone(),
                source_version: published.version,
            });
        }

        // Release the read lock before handing out the local reference.
        drop(shared);
        // `local` is guaranteed to be `Some` here: it was either already
        // current or has just been assigned.
        Ok(self.local.as_mut().unwrap())
    }

    /// Returns a mutable reference to the local copy of the value, refreshing
    /// it from the shared state first when it is stale.
    ///
    /// Nothing in this type writes the local copy back to the shared state.
    ///
    /// # Errors
    ///
    /// Propagates the errors of the internal fetch: `CacheError::SourceDropped`
    /// or `CacheError::KeyNotFound`.
    pub(crate) async fn get_mut(&mut self) -> Result<&mut V, CacheError> {
        let entry = self.fetch().await?;
        Ok(&mut entry.value)
    }
}