use super::AsyncLazy;
use async_lock::RwLock;
use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};
/// Async, per-key memoizing cache.
///
/// Every key maps to a shared `AsyncLazy<V>` handle. The map itself sits behind
/// an async read/write lock; the value computation for an entry is driven
/// through its `AsyncLazy` after that lock has been released.
#[derive(Debug)]
pub(crate) struct AsyncCache<K, V> {
    /// Key -> lazily initialized value. Entries are reference counted so a
    /// caller can keep driving or awaiting an entry without holding the map lock.
    map: RwLock<HashMap<K, Arc<AsyncLazy<V>>>>,
}
impl<K, V> AsyncCache<K, V>
where
K: Eq + Hash + Clone,
{
pub(crate) fn new() -> Self {
Self {
map: RwLock::new(HashMap::new()),
}
}
pub(crate) async fn get_or_insert_with<F, Fut>(&self, key: K, factory: F) -> Arc<V>
where
F: FnOnce() -> Fut,
Fut: Future<Output = V>,
{
{
let read_guard = self.map.read().await;
if let Some(lazy) = read_guard.get(&key) {
let lazy_clone = lazy.clone();
drop(read_guard); return lazy_clone.get_or_init(factory).await;
}
}
let lazy = {
let mut write_guard = self.map.write().await;
if let Some(existing) = write_guard.get(&key) {
existing.clone()
} else {
let new_lazy = Arc::new(AsyncLazy::new());
write_guard.insert(key, new_lazy.clone());
new_lazy
}
};
lazy.get_or_init(factory).await
}
pub(crate) async fn get(&self, key: &K) -> Option<Arc<V>> {
let read_guard = self.map.read().await;
read_guard.get(key).and_then(|lazy| lazy.try_get())
}
pub(crate) async fn invalidate(&self, key: &K) -> Option<Arc<V>> {
let mut write_guard = self.map.write().await;
write_guard.remove(key).and_then(|lazy| lazy.try_get())
}
#[cfg(test)]
pub(crate) async fn clear(&self) {
let mut write_guard = self.map.write().await;
write_guard.clear();
}
#[allow(dead_code)] pub(crate) async fn get_or_refresh_with<F, Fut, P>(
&self,
key: K,
should_force_refresh: P,
factory: F,
) -> Option<Arc<V>>
where
F: FnOnce() -> Fut,
Fut: Future<Output = V>,
P: FnOnce(Option<&V>) -> bool,
{
let (initial_lazy, existing_value, is_initializing) = {
let read_guard = self.map.read().await;
match read_guard.get(&key) {
Some(lazy) => {
let lazy_clone = lazy.clone();
let current = lazy.try_get();
let initializing = current.is_none();
(Some(lazy_clone), current, initializing)
}
None => (None, None, false),
}
};
if is_initializing {
if let Some(lazy) = initial_lazy {
return Some(lazy.get().await);
}
}
let needs_refresh = should_force_refresh(existing_value.as_ref().map(|v| v.as_ref()));
if !needs_refresh {
return existing_value;
}
let new_lazy = {
let mut write_guard = self.map.write().await;
match (initial_lazy.as_ref(), write_guard.get(&key)) {
(None, None) => {
let lazy = Arc::new(AsyncLazy::new());
write_guard.insert(key, lazy.clone());
lazy
}
(None, Some(current)) => {
let current_clone = current.clone();
drop(write_guard);
return Some(current_clone.get().await);
}
(Some(initial), Some(current)) => {
if Arc::ptr_eq(initial, current) {
let lazy = Arc::new(AsyncLazy::new());
write_guard.insert(key, lazy.clone());
lazy
} else {
let current_clone = current.clone();
drop(write_guard);
return Some(current_clone.get().await);
}
}
(Some(_), None) => {
let lazy = Arc::new(AsyncLazy::new());
write_guard.insert(key, lazy.clone());
lazy
}
}
};
Some(new_lazy.get_or_init(factory).await)
}
}
impl<K, V> Default for AsyncCache<K, V>
where
K: Eq + Hash + Clone,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration};

    /// Number of entries in the map, initialized or not.
    async fn cache_len<K, V>(cache: &AsyncCache<K, V>) -> usize
    where
        K: Eq + Hash + Clone,
    {
        cache.map.read().await.len()
    }

    /// Whether the map holds no entries at all.
    async fn cache_is_empty<K, V>(cache: &AsyncCache<K, V>) -> bool
    where
        K: Eq + Hash + Clone,
    {
        cache.map.read().await.is_empty()
    }

    #[tokio::test]
    async fn get_or_insert_caches_value() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let counter_clone = counter.clone();
        let value = cache
            .get_or_insert_with("key1".to_string(), || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                42
            })
            .await;
        assert_eq!(*value, 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // A second call for the same key must not run its factory.
        let counter_clone = counter.clone();
        let value2 = cache
            .get_or_insert_with("key1".to_string(), || async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                100
            })
            .await;
        assert_eq!(*value2, 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn different_keys_different_values() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        let v1 = cache
            .get_or_insert_with("key1".to_string(), || async { 1 })
            .await;
        let v2 = cache
            .get_or_insert_with("key2".to_string(), || async { 2 })
            .await;
        assert_eq!(*v1, 1);
        assert_eq!(*v2, 2);
        assert_eq!(cache_len(&cache).await, 2);
    }

    #[tokio::test]
    async fn concurrent_same_key_single_init() {
        let cache = Arc::new(AsyncCache::<String, String>::new());
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let cache_clone = cache.clone();
            let counter_clone = counter.clone();
            handles.push(tokio::spawn(async move {
                cache_clone
                    .get_or_insert_with("shared_key".to_string(), || async move {
                        counter_clone.fetch_add(1, Ordering::SeqCst);
                        sleep(Duration::from_millis(50)).await;
                        "result".to_string()
                    })
                    .await
            }));
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert_eq!(*result, "result");
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn get_returns_none_before_insert() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        assert!(cache.get(&"key".to_string()).await.is_none());
    }

    #[tokio::test]
    async fn get_returns_value_after_insert() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key".to_string(), || async { 42 })
            .await;
        assert_eq!(*cache.get(&"key".to_string()).await.unwrap(), 42);
    }

    #[tokio::test]
    async fn invalidate_removes_entry() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key".to_string(), || async { 42 })
            .await;
        let removed = cache.invalidate(&"key".to_string()).await;
        assert_eq!(*removed.unwrap(), 42);
        assert!(cache.get(&"key".to_string()).await.is_none());
        assert!(cache_is_empty(&cache).await);
    }

    #[tokio::test]
    async fn invalidate_missing_key_returns_none() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        assert!(cache.invalidate(&"absent".to_string()).await.is_none());
        assert!(cache_is_empty(&cache).await);
    }

    #[tokio::test]
    async fn clear_removes_all_entries() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key1".to_string(), || async { 1 })
            .await;
        cache
            .get_or_insert_with("key2".to_string(), || async { 2 })
            .await;
        assert_eq!(cache_len(&cache).await, 2);
        // Exercise the real `AsyncCache::clear` rather than reaching into `map`.
        cache.clear().await;
        assert!(cache_is_empty(&cache).await);
    }

    #[tokio::test]
    async fn refresh_when_key_missing_and_predicate_true() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let result = cache
            .get_or_refresh_with(
                "key".to_string(),
                |existing| {
                    assert!(existing.is_none());
                    true
                },
                || async move {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                    42
                },
            )
            .await;
        assert_eq!(*result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_refresh_when_key_missing_and_predicate_false() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let result = cache
            .get_or_refresh_with(
                "key".to_string(),
                |existing| {
                    assert!(existing.is_none());
                    false
                },
                || async move {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                    42
                },
            )
            .await;
        assert!(result.is_none());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(cache_is_empty(&cache).await);
    }

    #[tokio::test]
    async fn no_refresh_when_value_not_stale() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key".to_string(), || async { 42 })
            .await;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let result = cache
            .get_or_refresh_with(
                "key".to_string(),
                |existing| {
                    assert_eq!(*existing.unwrap(), 42);
                    false
                },
                || async move {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                    100
                },
            )
            .await;
        assert_eq!(*result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn refresh_when_value_is_stale() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key".to_string(), || async { 42 })
            .await;
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let result = cache
            .get_or_refresh_with(
                "key".to_string(),
                |existing| {
                    assert_eq!(*existing.unwrap(), 42);
                    true
                },
                || async move {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                    100
                },
            )
            .await;
        assert_eq!(*result.unwrap(), 100);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn concurrent_refresh_single_factory_call() {
        let cache = Arc::new(AsyncCache::<String, String>::new());
        cache
            .get_or_insert_with("key".to_string(), || async { "stale".to_string() })
            .await;
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let cache_clone = cache.clone();
            let counter_clone = counter.clone();
            handles.push(tokio::spawn(async move {
                cache_clone
                    .get_or_refresh_with(
                        "key".to_string(),
                        |existing| existing.map(|v| v.as_str() == "stale").unwrap_or(true),
                        || async move {
                            counter_clone.fetch_add(1, Ordering::SeqCst);
                            sleep(Duration::from_millis(50)).await;
                            "fresh".to_string()
                        },
                    )
                    .await
            }));
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert_eq!(*result.unwrap(), "fresh");
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn refresh_replaces_old_lazy_atomically() {
        let cache: AsyncCache<String, i32> = AsyncCache::new();
        cache
            .get_or_insert_with("key".to_string(), || async { 1 })
            .await;

        cache
            .get_or_refresh_with("key".to_string(), |_| true, || async { 2 })
            .await;
        let value = cache.get(&"key".to_string()).await.unwrap();
        assert_eq!(*value, 2);

        cache
            .get_or_refresh_with("key".to_string(), |_| true, || async { 3 })
            .await;
        let value = cache.get(&"key".to_string()).await.unwrap();
        assert_eq!(*value, 3);
    }

    #[tokio::test]
    async fn fast_path_does_not_panic_on_uninitialized_lazy() {
        let cache = Arc::new(AsyncCache::<String, String>::new());
        let mut handles = vec![];
        for i in 0..20 {
            let cache_clone = cache.clone();
            handles.push(tokio::spawn(async move {
                cache_clone
                    .get_or_insert_with("race_key".to_string(), || async move {
                        sleep(Duration::from_millis(10)).await;
                        format!("value-{}", i)
                    })
                    .await
            }));
        }
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        // Whichever factory won, every caller must observe the same value.
        let first = &*results[0];
        for result in &results {
            assert_eq!(&**result, first);
        }
    }
}