use std::{
collections::HashMap,
future::Future,
hash::Hash,
sync::{Arc, Weak},
};
use parking_lot::Mutex as SyncMutex;
use tokio::sync::Mutex;
type SharedMapping<K, T> = Arc<SyncMutex<HashMap<K, BroadcastOnce<T>>>>;
#[derive(Debug)]
pub struct SingleFlight<K, T> {
mapping: SharedMapping<K, T>,
}
impl<K, T> Default for SingleFlight<K, T> {
fn default() -> Self {
Self {
mapping: Default::default(),
}
}
}
/// State shared by every waiter of one flight: the slot that receives the
/// flight's result.
///
/// An async mutex is used because the waiter running the computation holds
/// the lock across `.await` points while it runs.
struct Shared<T> {
    slot: Mutex<Option<T>>,
}

impl<T> Default for Shared<T> {
    /// An empty slot: no result has been produced yet.
    fn default() -> Self {
        let empty: Option<T> = None;
        Shared {
            slot: Mutex::new(empty),
        }
    }
}
#[derive(Clone)]
struct BroadcastOnce<T> {
shared: Weak<Shared<T>>,
}
impl<T> BroadcastOnce<T> {
fn new() -> (Self, Arc<Shared<T>>) {
let shared = Arc::new(Shared::default());
(
Self {
shared: Arc::downgrade(&shared),
},
shared,
)
}
}
struct BroadcastOnceWaiter<K, T, F> {
func: F,
shared: Arc<Shared<T>>,
key: K,
mapping: SharedMapping<K, T>,
}
impl<T> std::fmt::Debug for BroadcastOnce<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "BroadcastOnce")
}
}
#[allow(clippy::type_complexity)]
impl<T> BroadcastOnce<T> {
fn try_waiter<K, F>(
&self,
func: F,
key: K,
mapping: SharedMapping<K, T>,
) -> Result<BroadcastOnceWaiter<K, T, F>, (F, K, SharedMapping<K, T>)> {
let Some(upgraded) = self.shared.upgrade() else {
return Err((func, key, mapping));
};
Ok(BroadcastOnceWaiter {
func,
shared: upgraded,
key,
mapping,
})
}
#[inline]
const fn waiter<K, F>(
shared: Arc<Shared<T>>,
func: F,
key: K,
mapping: SharedMapping<K, T>,
) -> BroadcastOnceWaiter<K, T, F> {
BroadcastOnceWaiter {
func,
shared,
key,
mapping,
}
}
}
impl<K, T, F, Fut> BroadcastOnceWaiter<K, T, F>
where
K: Hash + Eq,
F: FnOnce() -> Fut,
Fut: Future<Output = T>,
T: Clone,
{
async fn wait(self) -> T {
let mut slot = self.shared.slot.lock().await;
if let Some(value) = (*slot).as_ref() {
return value.clone();
}
let value = (self.func)().await;
*slot = Some(value.clone());
self.mapping.lock().remove(&self.key);
value
}
}
impl<K, T> SingleFlight<K, T>
where
K: Hash + Eq + Clone,
{
#[inline]
pub fn new() -> Self {
Self::default()
}
pub fn work<F, Fut>(&self, key: K, func: F) -> impl Future<Output = T>
where
F: FnOnce() -> Fut,
Fut: Future<Output = T>,
T: Clone,
{
let owned_mapping = self.mapping.clone();
let mut mapping = self.mapping.lock();
let val = mapping.get_mut(&key);
match val {
Some(call) => {
let (func, key, owned_mapping) = match call.try_waiter(func, key, owned_mapping) {
Ok(waiter) => return waiter.wait(),
Err(fm) => fm,
};
let (new_call, shared) = BroadcastOnce::new();
*call = new_call;
let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
waiter.wait()
}
None => {
let (call, shared) = BroadcastOnce::new();
mapping.insert(key.clone(), call);
let waiter = BroadcastOnce::waiter(shared, func, key, owned_mapping);
waiter.wait()
}
}
}
}
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{
            AtomicUsize,
            Ordering::{AcqRel, Acquire},
        },
        time::Duration,
    };

    use futures_util::{stream::FuturesUnordered, StreamExt};

    use super::*;

    #[tokio::test]
    async fn direct_call() {
        let group = SingleFlight::new();
        let result = group
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn parallel_call() {
        let call_counter = AtomicUsize::default();
        let group = SingleFlight::new();
        let futures = FuturesUnordered::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }
        assert!(futures.all(|out| async move { out == "Result" }).await);
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    #[tokio::test]
    async fn parallel_call_seq_await() {
        let call_counter = AtomicUsize::default();
        let group = SingleFlight::new();
        let mut futures = Vec::new();
        for _ in 0..10 {
            futures.push(group.work("key", || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                call_counter.fetch_add(1, AcqRel);
                "Result".to_string()
            }));
        }
        for fut in futures.into_iter() {
            assert_eq!(fut.await, "Result");
        }
        assert_eq!(
            call_counter.load(Acquire),
            1,
            "future should only be executed once"
        );
    }

    /// A finished flight is unregistered, so a later call runs its own `func`.
    #[tokio::test]
    async fn sequential_calls_run_again() {
        let call_counter = AtomicUsize::default();
        let group = SingleFlight::new();
        for _ in 0..2 {
            let result = group
                .work("key", || async {
                    call_counter.fetch_add(1, AcqRel);
                    "Result".to_string()
                })
                .await;
            assert_eq!(result, "Result");
        }
        assert_eq!(
            call_counter.load(Acquire),
            2,
            "a finished flight must not be reused"
        );
    }

    #[tokio::test]
    async fn call_with_static_str_key() {
        // Uses a `&'static str` key; the `String` case is covered below.
        let group = SingleFlight::new();
        let result = group
            .work("key", || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn call_with_static_string_key() {
        let group = SingleFlight::new();
        let result = group
            .work("key".to_string(), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn call_with_custom_key() {
        #[derive(Clone, PartialEq, Eq, Hash)]
        struct K(i32);
        let group = SingleFlight::new();
        let result = group
            .work(K(1), || async {
                tokio::time::sleep(Duration::from_millis(1)).await;
                "Result".to_string()
            })
            .await;
        assert_eq!(result, "Result");
    }

    #[tokio::test]
    async fn late_wait() {
        let group = SingleFlight::new();
        let fut_early = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            "Result".to_string()
        });
        let fut_late = group.work("key".into(), || async { panic!("unexpected") });
        assert_eq!(fut_early.await, "Result");
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(fut_late.await, "Result");
    }

    #[tokio::test]
    async fn cancel() {
        let group = SingleFlight::new();
        let fut_cancel = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let _ = tokio::time::timeout(Duration::from_millis(10), fut_cancel).await;
        let fut_late = group.work("key".to_string(), || async { "Result2".to_string() });
        assert_eq!(fut_late.await, "Result2");
        let begin = tokio::time::Instant::now();
        let fut_1 = group.work("key".to_string(), || async {
            tokio::time::sleep(Duration::from_millis(2000)).await;
            "Result1".to_string()
        });
        let fut_2 = group.work("key".to_string(), || async { panic!() });
        let (v1, v2) = tokio::join!(fut_1, fut_2);
        assert_eq!(v1, "Result1");
        assert_eq!(v2, "Result1");
        assert!(begin.elapsed() > Duration::from_millis(1500));
    }
}