use dashmap::DashMap;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tokio_util::sync::CancellationToken;
use crate::error::ClusterError;
/// Boxed, sendable future that resolves to a resource or a [`ClusterError`].
type BoxedResourceFuture<V> = Pin<Box<dyn Future<Output = Result<V, ClusterError>> + Send>>;

/// Type-erased, thread-safe async constructor: given a key, it returns the
/// future that builds the resource for that key.
type ResourceFactory<K, V> = Box<dyn Fn(K) -> BoxedResourceFuture<V> + Send + Sync>;
/// Concurrent map from keys to lazily created, shared resources.
///
/// Each key maps to an [`ResourceEntry`] behind an `Arc`; the entry's value is
/// produced on demand by calling `factory` with the key.
pub struct ResourceMap<K, V> {
// Per-key entries. Stored behind `Arc` so callers can keep an entry alive
// independently of the map.
map: DashMap<K, Arc<ResourceEntry<V>>>,
// Boxed async constructor invoked by `get` to initialise an entry's value.
factory: ResourceFactory<K, V>,
}
/// A single slot of a [`ResourceMap`]: a write-once value paired with a
/// cancellation token.
pub struct ResourceEntry<V> {
// Empty until the factory has produced the resource; set at most once.
value: OnceCell<V>,
// Cancelled by `ResourceMap::remove` / `ResourceMap::clear` when the entry
// is taken out of the map.
cancel: CancellationToken,
}
impl<V> ResourceEntry<V> {
    /// Returns the resource if it has already been initialised, `None` otherwise.
    pub fn value(&self) -> Option<&V> {
        let Self { value, .. } = self;
        value.get()
    }

    /// Returns the token that `ResourceMap::remove` and `ResourceMap::clear`
    /// cancel when this entry is taken out of the map.
    pub fn cancel(&self) -> &CancellationToken {
        let Self { cancel, .. } = self;
        cancel
    }
}
impl<K, V> ResourceMap<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub fn new<F, Fut>(factory: F) -> Self
where
F: Fn(K) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<V, ClusterError>> + Send + 'static,
{
Self {
map: DashMap::new(),
factory: Box::new(move |key| Box::pin(factory(key))),
}
}
pub async fn get(&self, key: &K) -> Result<Arc<ResourceEntry<V>>, ClusterError> {
let entry = self
.map
.entry(key.clone())
.or_insert_with(|| {
Arc::new(ResourceEntry {
value: OnceCell::new(),
cancel: CancellationToken::new(),
})
})
.clone();
entry
.value
.get_or_try_init(|| (self.factory)(key.clone()))
.await?;
Ok(entry)
}
pub fn remove(&self, key: &K) {
if let Some((_, entry)) = self.map.remove(key) {
entry.cancel.cancel();
}
}
pub fn clear(&self) {
let keys: Vec<K> = self.map.iter().map(|e| e.key().clone()).collect();
for key in keys {
if let Some((_, entry)) = self.map.remove(&key) {
entry.cancel.cancel();
}
}
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
pub fn contains_key(&self, key: &K) -> bool {
self.map.contains_key(key)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};

    /// Builds a map whose factory yields `value-<key>` and counts how often it ran.
    fn counting_map() -> (ResourceMap<String, String>, Arc<AtomicI32>) {
        let calls = Arc::new(AtomicI32::new(0));
        let factory_calls = Arc::clone(&calls);
        let map = ResourceMap::new(move |key: String| {
            let factory_calls = Arc::clone(&factory_calls);
            async move {
                factory_calls.fetch_add(1, Ordering::Relaxed);
                Ok(format!("value-{key}"))
            }
        });
        (map, calls)
    }

    #[tokio::test]
    async fn lazily_creates_resource_on_first_access() {
        let (map, calls) = counting_map();
        assert_eq!(map.len(), 0);

        let first = map.get(&"k1".to_string()).await.unwrap();

        assert_eq!(first.value().unwrap(), "value-k1");
        assert_eq!(calls.load(Ordering::Relaxed), 1);
        assert_eq!(map.len(), 1);
    }

    #[tokio::test]
    async fn concurrent_get_same_key_calls_factory_once() {
        let calls = Arc::new(AtomicI32::new(0));
        let factory_calls = Arc::clone(&calls);
        let map = Arc::new(ResourceMap::new(move |key: String| {
            let factory_calls = Arc::clone(&factory_calls);
            async move {
                // Keep the factory in flight long enough for all tasks to queue up.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                factory_calls.fetch_add(1, Ordering::Relaxed);
                Ok(format!("value-{key}"))
            }
        }));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let shared = Arc::clone(&map);
                let key = "k1".to_string();
                tokio::spawn(async move { shared.get(&key).await })
            })
            .collect();

        for task in tasks {
            let entry = task.await.unwrap().unwrap();
            assert_eq!(entry.value().unwrap(), "value-k1");
        }
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn remove_allows_recreation() {
        let (map, calls) = counting_map();
        let key = String::from("k1");

        map.get(&key).await.unwrap();
        assert_eq!(calls.load(Ordering::Relaxed), 1);

        map.remove(&key);
        assert_eq!(map.len(), 0);

        map.get(&key).await.unwrap();
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn remove_cancels_token() {
        let map = ResourceMap::new(|_key: String| async { Ok("value") });
        let key = String::from("k1");

        let token = map.get(&key).await.unwrap().cancel().clone();
        assert!(!token.is_cancelled());

        map.remove(&key);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn factory_error_allows_retry() {
        let calls = Arc::new(AtomicI32::new(0));
        let attempts = Arc::clone(&calls);
        let map = ResourceMap::new(move |_key: String| {
            let attempts = Arc::clone(&attempts);
            async move {
                // Only the very first invocation fails.
                match attempts.fetch_add(1, Ordering::Relaxed) {
                    0 => Err(ClusterError::PersistenceError {
                        reason: "transient".into(),
                        source: None,
                    }),
                    _ => Ok("recovered".to_string()),
                }
            }
        });
        let key = String::from("k1");

        assert!(map.get(&key).await.is_err());

        let retried = map.get(&key).await.unwrap();
        assert_eq!(retried.value().unwrap(), "recovered");
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn clear_removes_all_and_cancels() {
        let map = ResourceMap::new(|key: String| async move { Ok(key) });
        let token_a = map.get(&"a".to_string()).await.unwrap().cancel().clone();
        let token_b = map.get(&"b".to_string()).await.unwrap().cancel().clone();
        assert_eq!(map.len(), 2);

        map.clear();

        assert_eq!(map.len(), 0);
        assert!(token_a.is_cancelled());
        assert!(token_b.is_cancelled());
    }

    #[tokio::test]
    async fn contains_key_and_is_empty() {
        let map = ResourceMap::new(|key: String| async move { Ok(key) });
        let key = String::from("k");
        assert!(map.is_empty());
        assert!(!map.contains_key(&key));

        map.get(&key).await.unwrap();

        assert!(!map.is_empty());
        assert!(map.contains_key(&key));
    }
}