#![deny(missing_docs)]
use std::{
fmt,
hash::{BuildHasher, Hash, RandomState},
sync::{Arc, Weak},
};
use dashmap::{DashMap, mapref::entry::Entry};
use tokio::sync::broadcast;
/// Error returned by [`CoalescedMap::get_or_try_init`].
#[derive(Debug)]
pub enum CoalescedGetError<E> {
    /// This call ran its own initializer and the initializer returned this error.
    Init(E),
    /// This call waited on an initialization started by another caller, and that
    /// initialization ended without broadcasting a value (for example it returned an
    /// error, panicked, or its future was dropped). The other caller's error value is
    /// not available to waiters.
    CoalescedRequestFailed,
}
impl<E: fmt::Display> fmt::Display for CoalescedGetError<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CoalescedGetError::Init(e) => write!(f, "initializer failed: {e}"),
CoalescedGetError::CoalescedRequestFailed => {
write!(f, "a coalesced request failed")
}
}
}
}
impl<E: fmt::Debug + fmt::Display> std::error::Error for CoalescedGetError<E> {}
/// A concurrent map whose values are produced by async initializers, with concurrent
/// initializations of the same key coalesced into one.
///
/// The first caller of [`get_or_try_init`](Self::get_or_try_init) for a missing key runs its
/// initializer; callers that arrive while that initializer is still running wait for its
/// result instead of running their own. Successfully produced values are cached and
/// returned (cloned) to later callers.
#[derive(Clone)]
pub struct CoalescedMap<K, V, S = RandomState>
where
    K: Eq + Hash,
    V: Clone,
    S: BuildHasher + Clone,
{
    // Each key maps either to a cached value or to a marker for an initialization in flight.
    map: DashMap<K, PendingOrFetched<V>, S>,
}
impl<K, V> Default for CoalescedMap<K, V, RandomState>
where
K: Eq + Hash,
V: Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<K, V> CoalescedMap<K, V, RandomState>
where
    K: Eq + Hash,
    V: Clone,
{
    /// Creates an empty map that hashes keys with the standard library's [`RandomState`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: DashMap::new(),
        }
    }

    /// Creates an empty map that uses [`RandomState`], forwarding `capacity` to
    /// [`DashMap::with_capacity`] so storage can be preallocated.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            map: DashMap::with_capacity(capacity),
        }
    }
}
impl<K, V, S> CoalescedMap<K, V, S>
where
    K: Eq + Hash,
    V: Clone,
    S: BuildHasher + Clone,
{
    /// Creates an empty map that hashes keys with `hasher`.
    #[must_use]
    pub fn with_hasher(hasher: S) -> Self {
        Self {
            map: DashMap::with_hasher(hasher),
        }
    }

    /// Creates an empty map that hashes keys with `hasher`, forwarding `capacity` to
    /// [`DashMap::with_capacity_and_hasher`] so storage can be preallocated.
    #[must_use]
    pub fn with_capacity_and_hasher(capacity: usize, hasher: S) -> Self {
        Self {
            map: DashMap::with_capacity_and_hasher(capacity, hasher),
        }
    }

    /// Returns the number of entries in the map.
    ///
    /// Both cached values and `Pending` markers for initializations are counted, so this can
    /// exceed the number of keys for which [`get`](Self::get) returns `Some`. Other threads may
    /// modify the map concurrently, so treat the result as a point-in-time snapshot.
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if the map holds no entries, counting `Pending` markers as entries
    /// (the same notion of "entry" as [`len`](Self::len)).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}
impl<K, V, S> CoalescedMap<K, V, S>
where
K: Eq + Hash + Clone,
V: Clone + Send + Sync + 'static,
S: BuildHasher + Clone,
{
pub async fn get_or_try_init<E, Fut, F>(
&self,
key: K,
init: F,
) -> Result<V, CoalescedGetError<E>>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<V, E>>,
{
let sender = match self.map.entry(key.clone()) {
Entry::Vacant(entry) => {
let (tx, _) = broadcast::channel(1);
let tx = Arc::new(tx);
entry.insert(PendingOrFetched::Pending(Arc::downgrade(&tx)));
tx
}
Entry::Occupied(mut entry) => match entry.get() {
PendingOrFetched::Fetched(v) => return Ok(v.clone()),
PendingOrFetched::Pending(weak_tx) => {
if let Some(tx) = weak_tx.upgrade() {
let mut rx = tx.subscribe();
drop(tx);
drop(entry);
return rx
.recv()
.await
.map_err(|_err| CoalescedGetError::CoalescedRequestFailed);
}
let (tx, _) = broadcast::channel(1);
let tx = Arc::new(tx);
entry.insert(PendingOrFetched::Pending(Arc::downgrade(&tx)));
tx
}
},
};
match init().await {
Ok(value) => {
self.map
.insert(key, PendingOrFetched::Fetched(value.clone()));
let _ = sender.send(value.clone());
Ok(value)
}
Err(err) => Err(CoalescedGetError::Init(err)),
}
}
pub fn get(&self, key: &K) -> Option<V> {
self.map.get(key).and_then(|g| match g.value() {
PendingOrFetched::Fetched(v) => Some(v.clone()),
PendingOrFetched::Pending(_) => None,
})
}
pub fn retain<F>(&self, mut f: F)
where
F: FnMut(&K, &PendingOrFetched<V>) -> bool,
{
self.map.retain(|k, v| f(k, v));
}
pub fn clear(&self) {
self.map
.retain(|_, v| matches!(v, PendingOrFetched::Pending(_)));
}
}
/// The state stored for one key of a [`CoalescedMap`].
#[derive(Clone, Debug)]
pub enum PendingOrFetched<T> {
    /// An initialization has been started. The weak handle refers to the broadcast sender
    /// owned by the initializing call; it no longer upgrades once that call has dropped it.
    Pending(Weak<broadcast::Sender<T>>),
    /// The value has been produced and cached.
    Fetched(T),
}
// Unit tests for `CoalescedMap`.
// NOTE(review): several concurrency tests assume `#[tokio::test]`'s default single-threaded
// runtime, where spawned tasks run in spawn order; confirm before switching any of them to
// `flavor = "multi_thread"`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::pending,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };
    use tokio::task::JoinHandle;

    // A second call for an already-cached key returns the cached value; its own
    // initializer's value ("should_not_be_called") must not be observed.
    #[tokio::test]
    async fn test_basic_get_or_try_init() {
        let map: CoalescedMap<String, String> = CoalescedMap::new();
        let result = map
            .get_or_try_init("key1".to_string(), || async {
                Ok::<_, &str>("value1".to_string())
            })
            .await
            .unwrap();
        assert_eq!(result, "value1");
        let result2 = map
            .get_or_try_init("key1".to_string(), || async {
                Ok::<_, &str>("should_not_be_called".to_string())
            })
            .await
            .unwrap();
        assert_eq!(result2, "value1");
    }

    // `get` returns `None` before initialization and the cached value afterwards.
    #[tokio::test]
    async fn test_get_if_fetched() {
        let map: CoalescedMap<String, String> = CoalescedMap::new();
        assert_eq!(map.get(&"key1".to_string()), None);
        map.get_or_try_init("key1".to_string(), || async {
            Ok::<_, &str>("value1".to_string())
        })
        .await
        .unwrap();
        assert_eq!(map.get(&"key1".to_string()), Some("value1".to_string()));
    }

    // Ten tasks released together by a barrier request the same key: the initializer must
    // run exactly once and every task must receive the very same `Arc` (pointer equality).
    #[tokio::test]
    async fn test_concurrent_initialization() {
        let map: Arc<CoalescedMap<String, Arc<String>>> = Arc::new(CoalescedMap::new());
        let call_count = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(10));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let map = map.clone();
                let call_count = call_count.clone();
                let barrier = barrier.clone();
                tokio::spawn(async move {
                    barrier.wait().await;
                    map.get_or_try_init("shared_key".to_string(), || {
                        let call_count = call_count.clone();
                        async move {
                            call_count.fetch_add(1, Ordering::SeqCst);
                            Ok::<_, &str>(Arc::new(format!("value_from_task_{i}")))
                        }
                    })
                    .await
                })
            })
            .collect();
        let results: Vec<_> = futures::future::try_join_all(handles)
            .await
            .unwrap()
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        let first_result = &results[0];
        for result in &results {
            assert!(Arc::ptr_eq(first_result, result));
        }
    }

    // A failing initializer surfaces as `Init`, nothing is cached, and a later call for the
    // same key can succeed.
    #[tokio::test]
    async fn test_error_handling() {
        let map: CoalescedMap<String, String> = CoalescedMap::new();
        let result = map
            .get_or_try_init("error_key".to_string(), || async {
                Err("initialization failed")
            })
            .await;
        match result {
            Err(CoalescedGetError::Init(err)) => assert_eq!(err, "initialization failed"),
            _ => panic!("Expected Init error"),
        }
        assert_eq!(map.get(&"error_key".to_string()), None);
        let success_result = map
            .get_or_try_init("error_key".to_string(), || async {
                Ok::<_, &str>("success_value".to_string())
            })
            .await
            .unwrap();
        assert_eq!(success_result, "success_value");
    }

    // Only task 0's initializer succeeds (after meeting the main task at a 2-party barrier);
    // every other task's initializer would fail. The test expects exactly one initializer
    // call and every task to receive "success_0".
    // NOTE(review): this relies on task 0 reaching `get_or_try_init` first (spawn order under
    // the single-threaded test runtime); if another task claimed the key first its error
    // would be returned and the assertions would fail.
    #[tokio::test]
    async fn test_concurrent_error_handling() {
        let map = Arc::new(CoalescedMap::new());
        let init_calls = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(2));
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let map = map.clone();
                let init_calls = init_calls.clone();
                let barrier = barrier.clone();
                tokio::spawn(async move {
                    map.get_or_try_init("fail_key".to_string(), || {
                        let init_calls = init_calls.clone();
                        async move {
                            init_calls.fetch_add(1, Ordering::SeqCst);
                            if i == 0 {
                                barrier.wait().await;
                                Ok(format!("success_{i}"))
                            } else {
                                Err(format!("error_{i}"))
                            }
                        }
                    })
                    .await
                })
            })
            .collect();
        barrier.wait().await;
        let results: Vec<_> = futures::future::join_all(handles).await;
        assert_eq!(init_calls.load(Ordering::SeqCst), 1);
        for result in results {
            let value = result.unwrap().unwrap();
            assert_eq!(value, "success_0");
        }
    }

    // Distinct keys are initialized independently and each ends up cached.
    #[tokio::test]
    async fn test_different_keys() {
        let map = Arc::new(CoalescedMap::new());
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let map = map.clone();
                tokio::spawn(async move {
                    let key = format!("key_{i}");
                    let value = format!("value_{i}");
                    map.get_or_try_init(key.clone(), || async move { Ok::<_, &str>(value) })
                        .await
                        .map(|v| (key, v))
                })
            })
            .collect();
        let results: Vec<_> = futures::future::try_join_all(handles)
            .await
            .unwrap()
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        for (i, (key, value)) in results.into_iter().enumerate() {
            assert_eq!(key, format!("key_{i}"));
            assert_eq!(value, format!("value_{i}"));
        }
        for i in 0..5 {
            let key = format!("key_{i}");
            let expected_value = format!("value_{i}");
            assert_eq!(map.get(&key), Some(expected_value));
        }
    }

    // `retain` keeps only entries whose predicate returns true (here: even-numbered keys).
    #[tokio::test]
    async fn test_retain_functionality() {
        let map: CoalescedMap<String, String> = CoalescedMap::new();
        for i in 0..5 {
            let key = format!("key_{i}");
            let value = format!("value_{i}");
            map.get_or_try_init(key, || async move { Ok::<_, &str>(value) })
                .await
                .unwrap();
        }
        map.retain(|key, _| {
            if let Some(num_str) = key.strip_prefix("key_")
                && let Ok(num) = num_str.parse::<i32>()
            {
                return num % 2 == 0;
            }
            false
        });
        assert_eq!(map.get(&"key_0".to_string()), Some("value_0".to_string()));
        assert_eq!(map.get(&"key_1".to_string()), None);
        assert_eq!(map.get(&"key_2".to_string()), Some("value_2".to_string()));
        assert_eq!(map.get(&"key_3".to_string()), None);
        assert_eq!(map.get(&"key_4".to_string()), Some("value_4".to_string()));
    }

    // Aborting the task that owns an in-flight initialization makes a caller that was
    // waiting on it fail with `CoalescedRequestFailed` (its own initializer is not run).
    // The 3-party barrier ensures task 1 has inserted its pending marker (it reaches the
    // barrier from inside its initializer) before task 2 looks up the key.
    // NOTE(review): also assumes task 2 subscribes before main calls `abort()`, which holds
    // under the single-threaded test runtime — confirm if the runtime flavor changes.
    #[tokio::test]
    async fn test_coalesced_request_failed_error() {
        let map = Arc::new(CoalescedMap::new());
        let barrier = Arc::new(tokio::sync::Barrier::new(3));
        let map1 = map.clone();
        let barrier1 = barrier.clone();
        let handle1 = tokio::spawn(async move {
            map1.get_or_try_init("test_key".to_string(), || async move {
                barrier1.wait().await;
                // `pending()` never resolves; the `()` pattern pins its output type.
                // NOTE(review): unlike the same pattern in the clear test below, the
                // unreachable `Ok` here has no `#[allow(unreachable_code)]` and may warn.
                let () = pending().await;
                Ok::<_, &str>("value".to_string())
            })
            .await
        });
        let map2 = map.clone();
        let barrier2 = barrier.clone();
        let handle2 = tokio::spawn(async move {
            barrier2.wait().await;
            map2.get_or_try_init("test_key".to_string(), || async move {
                Ok::<_, &str>("should_not_be_called".to_string())
            })
            .await
        });
        barrier.wait().await;
        handle1.abort();
        let result = handle2.await.unwrap();
        match result {
            Err(CoalescedGetError::CoalescedRequestFailed) => {
                // Expected outcome.
            }
            _ => panic!("Expected CoalescedRequestFailed error, got {result:?}"),
        }
    }

    // Same scenario as above, but the owning initializer panics instead of hanging; the
    // waiting caller must still get `CoalescedRequestFailed`.
    #[tokio::test]
    async fn test_coalesced_request_failed_panic() {
        let map = Arc::new(CoalescedMap::new());
        let barrier = Arc::new(tokio::sync::Barrier::new(3));
        let map1 = map.clone();
        let barrier1 = barrier.clone();
        // Explicit type: the panicking initializer gives the compiler nothing to infer from.
        let handle1: JoinHandle<Result<String, CoalescedGetError<&'static str>>> =
            tokio::spawn(async move {
                map1.get_or_try_init("test_key".to_string(), || async move {
                    barrier1.wait().await;
                    panic!();
                })
                .await
            });
        let map2 = map.clone();
        let barrier2 = barrier.clone();
        let handle2 = tokio::spawn(async move {
            barrier2.wait().await;
            map2.get_or_try_init("test_key".to_string(), || async move {
                Ok::<_, &str>("should_not_be_called".to_string())
            })
            .await
        });
        barrier.wait().await;
        handle1.abort();
        let result = handle2.await.unwrap();
        match result {
            Err(CoalescedGetError::CoalescedRequestFailed) => {
                // Expected outcome.
            }
            _ => panic!("Expected CoalescedRequestFailed error, got {result:?}"),
        }
    }

    // `len`/`is_empty`/`get` on a fresh map, then after one successful insertion.
    #[tokio::test]
    async fn test_empty_map() {
        let map: CoalescedMap<String, String> = CoalescedMap::new();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        assert_eq!(map.get(&"nonexistent".to_string()), None);
        map.get_or_try_init("key".to_string(), || async {
            Ok::<_, &str>("value".to_string())
        })
        .await
        .unwrap();
        assert_eq!(map.len(), 1);
        assert!(!map.is_empty());
    }

    // The `with_hasher` and `with_capacity_and_hasher` constructors produce working maps.
    #[tokio::test]
    async fn test_custom_hasher() {
        use std::collections::hash_map::RandomState;
        let hasher = RandomState::new();
        let map: CoalescedMap<String, String, RandomState> = CoalescedMap::with_hasher(hasher);
        let result = map
            .get_or_try_init("key1".to_string(), || async {
                Ok::<_, &str>("value1".to_string())
            })
            .await
            .unwrap();
        assert_eq!(result, "value1");
        let hasher2 = RandomState::new();
        let map2: CoalescedMap<String, String, RandomState> =
            CoalescedMap::with_capacity_and_hasher(10, hasher2);
        let result2 = map2
            .get_or_try_init("key2".to_string(), || async {
                Ok::<_, &str>("value2".to_string())
            })
            .await
            .unwrap();
        assert_eq!(result2, "value2");
    }

    // `clear` drops cached values but keeps the marker of an initialization that is still
    // running (the "pending" task is parked forever inside its initializer until aborted).
    #[tokio::test]
    async fn test_clear_removes_fetched_keeps_pending() {
        let map = Arc::new(CoalescedMap::new());
        let barrier = Arc::new(tokio::sync::Barrier::new(2));
        let map1 = map.clone();
        let barrier1 = barrier.clone();
        let pending_handle = tokio::spawn(async move {
            map1.get_or_try_init("pending".to_string(), || {
                let barrier = barrier1.clone();
                async move {
                    barrier.wait().await;
                    let () = pending().await;
                    #[allow(unreachable_code)]
                    Ok::<_, &str>("never".to_string())
                }
            })
            .await
        });
        // Returns once the pending task is inside its initializer, i.e. after it has
        // inserted its marker for "pending".
        barrier.wait().await;
        map.get_or_try_init("fetched".to_string(), || async {
            Ok::<_, &str>("value".to_string())
        })
        .await
        .unwrap();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get(&"fetched".to_string()), Some("value".to_string()));
        assert_eq!(map.get(&"pending".to_string()), None);
        map.clear();
        assert_eq!(map.len(), 1, "only the pending entry should remain");
        assert_eq!(map.get(&"fetched".to_string()), None);
        assert_eq!(map.get(&"pending".to_string()), None);
        pending_handle.abort();
        let _ = pending_handle.await;
    }
}