use std::{
borrow::Cow,
collections::HashMap,
future::Future,
sync::{Arc, Weak},
};
use parking_lot::Mutex as SyncMutex;
use tokio::sync::Mutex;
/// Key -> registered-call table shared between a `SingleFlight` and the waiters it hands out.
///
/// The table sits behind an `Arc` so each waiter can keep its own handle to it, and behind the
/// synchronous `parking_lot` mutex (`SyncMutex`) rather than the async tokio one.
type SharedMapping<T> = Arc<SyncMutex<HashMap<Cow<'static, str>, BroadcastOnce<T>>>>;
/// Groups asynchronous work by string key so that callers asking for the same key while a call
/// is registered share its result instead of each producing their own.
///
/// Every caller supplies its own closure. Only a waiter that finds the shared result slot still
/// empty runs its closure; the others receive a clone of the stored value.
pub struct SingleFlight<T> {
    /// Calls currently registered, indexed by key.
    mapping: SharedMapping<T>,
}

// Written by hand instead of derived: `#[derive(Debug)]` would require `T: Debug`, but nothing
// stored in `mapping` needs to format a `T`. The output is the same as the derived one.
impl<T> std::fmt::Debug for SingleFlight<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SingleFlight")
            .field("mapping", &self.mapping)
            .finish()
    }
}
impl<T> Default for SingleFlight<T> {
fn default() -> Self {
Self {
mapping: Default::default(),
}
}
}
/// A lookup key in one of two flavours.
///
/// `Static` already satisfies the `'static` requirement of the stored map key and is moved in
/// as-is; `MaybeBorrowed` may point at caller-owned text and is copied when converted.
enum Key<'a> {
    Static(Cow<'static, str>),
    MaybeBorrowed(Cow<'a, str>),
}

impl<'a> Key<'a> {
    /// Returns the key text without allocating, regardless of the variant.
    #[inline]
    fn as_str(&'a self) -> &'a str {
        match self {
            Key::Static(fixed) => &**fixed,
            Key::MaybeBorrowed(flexible) => &**flexible,
        }
    }
}

impl<'a> From<Key<'a>> for Cow<'static, str> {
    /// Converts the key into the form stored in the map.
    ///
    /// A `Static` key is returned untouched. A `MaybeBorrowed` key always becomes `Cow::Owned`,
    /// copying the text if it was borrowed.
    fn from(k: Key<'a>) -> Self {
        match k {
            Key::Static(already_static) => already_static,
            Key::MaybeBorrowed(flexible) => {
                let text: String = flexible.into_owned();
                Cow::Owned(text)
            }
        }
    }
}
/// State shared by every waiter attached to one registered call.
struct Shared<T> {
    /// `None` until a waiter stores the produced value. The async mutex lets a waiter keep the
    /// slot locked across `.await` points while its closure runs.
    slot: Mutex<Option<T>>,
}

impl<T> Default for Shared<T> {
    /// Starts with an empty slot; no `Default` bound on `T` is needed.
    fn default() -> Self {
        let empty: Option<T> = None;
        Self {
            slot: Mutex::new(empty),
        }
    }
}
/// Map-side handle of a registered call.
///
/// Only a weak reference is kept here, so the shared state is freed once no waiter holds it any
/// more, even though the map entry itself may still be present.
#[derive(Clone)]
struct BroadcastOnce<T> {
    shared: Weak<Shared<T>>,
}

impl<T> BroadcastOnce<T> {
    /// Allocates fresh shared state and returns the weak handle for the map together with the
    /// strong reference for the first waiter.
    fn new() -> (Self, Arc<Shared<T>>) {
        let strong = Arc::new(Shared::default());
        let handle = Self {
            shared: Arc::downgrade(&strong),
        };
        (handle, strong)
    }
}
/// One caller's stake in a registered call: what it needs to either pick up the shared result
/// or produce it itself.
struct BroadcastOnceWaiter<T, F> {
/// This caller's own closure; `wait` invokes it only when the shared slot is still empty.
func: F,
/// Strong reference that keeps the call's shared state alive while this waiter exists.
shared: Arc<Shared<T>>,
/// Key under which the call is registered; `wait` removes it from `mapping` after storing a value.
key: Cow<'static, str>,
/// Handle to the owning group's key table.
mapping: SharedMapping<T>,
}
impl<T> std::fmt::Debug for BroadcastOnce<T> {
    /// Prints only the type name, so `T` does not have to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BroadcastOnce")
    }
}
// NOTE(review): presumably silences the lint for the tuple-carrying `Result` returned by
// `try_waiter` — confirm before removing.
#[allow(clippy::type_complexity)]
impl<T> BroadcastOnce<T> {
    /// Tries to attach another waiter to this call.
    ///
    /// Succeeds while at least one existing waiter still keeps the shared state alive. When the
    /// state is gone, the arguments are handed back unchanged in `Err` so the caller can reuse
    /// them for a fresh call.
    fn try_waiter<F>(
        &self,
        func: F,
        key: Cow<'static, str>,
        mapping: SharedMapping<T>,
    ) -> Result<BroadcastOnceWaiter<T, F>, (F, Cow<'static, str>, SharedMapping<T>)> {
        match self.shared.upgrade() {
            Some(alive) => Ok(Self::waiter(alive, func, key, mapping)),
            None => Err((func, key, mapping)),
        }
    }

    /// Bundles an already-held strong reference with the caller's closure, key and map handle.
    #[inline]
    const fn waiter<F>(
        shared: Arc<Shared<T>>,
        func: F,
        key: Cow<'static, str>,
        mapping: SharedMapping<T>,
    ) -> BroadcastOnceWaiter<T, F> {
        BroadcastOnceWaiter {
            key,
            mapping,
            shared,
            func,
        }
    }
}
impl<T, F, Fut> BroadcastOnceWaiter<T, F>
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = T>,
    T: Clone,
{
    /// Resolves to the call's value, producing it with this waiter's closure if nobody has yet.
    ///
    /// The slot lock is held for the whole body, including while the closure runs, so other
    /// waiters of the same call queue up behind it. If this future is dropped mid-way the slot
    /// stays empty and the next waiter to get the lock runs its own closure instead.
    ///
    /// NOTE(review): the key is removed only on the completion path below. If every waiter is
    /// dropped before completing, the map entry remains until the same key is requested again.
    async fn wait(self) -> T {
        let mut cached = self.shared.slot.lock().await;

        // Someone else already produced the value: hand out a copy and leave the map alone.
        if let Some(ready) = &*cached {
            return ready.clone();
        }

        let produced = (self.func)().await;
        *cached = Some(produced.clone());

        // Unregister the key so later callers start a fresh call. Waiters already attached
        // still hold the shared state and will read the value stored above.
        self.mapping.lock().remove(&self.key);

        produced
    }
}
impl<T> SingleFlight<T> {
    /// Creates an empty group; same as `SingleFlight::default()`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Like [`SingleFlight::work`], but takes a key that is already `'static`, so it can be
    /// stored without copying the text.
    pub fn work_with_owned_key<F, Fut>(
        &self,
        key: Cow<'static, str>,
        func: F,
    ) -> impl Future<Output = T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
        T: Clone,
    {
        let key = Key::Static(key);
        self.work_inner(key, func)
    }

    /// Registers interest in `key` and returns a future for its value.
    ///
    /// The key is looked up (and registered if absent) immediately, before this method returns.
    /// `func` runs only when the returned future is polled and finds no value stored yet;
    /// otherwise the future yields a clone of the value another caller produced.
    pub fn work<F, Fut>(&self, key: &str, func: F) -> impl Future<Output = T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
        T: Clone,
    {
        let key = Key::MaybeBorrowed(Cow::Borrowed(key));
        self.work_inner(key, func)
    }

    /// Shared implementation: attaches to the call registered for `key`, or registers a new one.
    ///
    /// The map lock is held only for the synchronous part of this function; the returned future
    /// has not been polled yet when the lock is released.
    #[inline]
    fn work_inner<'a, 'b: 'a, F, Fut>(&'a self, key: Key<'b>, func: F) -> impl Future<Output = T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
        T: Clone,
    {
        let handle = self.mapping.clone();
        let mut registry = self.mapping.lock();

        if let Some(entry) = registry.get_mut(key.as_str()) {
            let owned_key: Cow<'static, str> = key.into();
            return match entry.try_waiter(func, owned_key, handle) {
                // A live call exists: join it.
                Ok(joined) => joined.wait(),
                // The entry is stale (all of its waiters were dropped): replace it in place
                // and become the first waiter of the replacement.
                Err((func, owned_key, handle)) => {
                    let (fresh, shared) = BroadcastOnce::new();
                    *entry = fresh;
                    BroadcastOnce::waiter(shared, func, owned_key, handle).wait()
                }
            };
        }

        // Nothing registered under this key yet: register a new call. One copy of the key is
        // stored in the map, the other travels with the waiter for the later removal.
        let owned_key: Cow<'static, str> = key.into();
        let (call, shared) = BroadcastOnce::new();
        registry.insert(owned_key.clone(), call);
        BroadcastOnce::waiter(shared, func, owned_key, handle).wait()
    }
}
// Behavioural tests for `SingleFlight`. They run on the tokio test runtime and use real
// (wall-clock) sleeps, so several of them are timing-sensitive.
#[cfg(test)]
mod tests {
use std::{
sync::atomic::{
AtomicUsize,
Ordering::{AcqRel, Acquire},
},
time::Duration,
};
use futures_util::{stream::FuturesUnordered, StreamExt};
use super::*;
// A single call with a borrowed key simply yields the closure's result.
#[tokio::test]
async fn direct_call() {
let group = SingleFlight::new();
let result = group
.work("key", || async {
tokio::time::sleep(Duration::from_millis(10)).await;
"Result".to_string()
})
.await;
assert_eq!(result, "Result");
}
// Ten futures for the same key, driven concurrently through `FuturesUnordered`: all must
// yield the same value and the counter proves only one closure body ran.
#[tokio::test]
async fn parallel_call() {
let call_counter = AtomicUsize::default();
let group = SingleFlight::new();
let futures = FuturesUnordered::new();
for _ in 0..10 {
futures.push(group.work("key", || async {
tokio::time::sleep(Duration::from_millis(100)).await;
call_counter.fetch_add(1, AcqRel);
"Result".to_string()
}));
}
assert!(futures.all(|out| async move { out == "Result" }).await);
assert_eq!(
call_counter.load(Acquire),
1,
"future should only be executed once"
);
}
// Same setup as `parallel_call`, but the futures are all created first and then awaited one
// after another; still only one closure body may run.
#[tokio::test]
async fn parallel_call_seq_await() {
let call_counter = AtomicUsize::default();
let group = SingleFlight::new();
let mut futures = Vec::new();
for _ in 0..10 {
futures.push(group.work("key", || async {
tokio::time::sleep(Duration::from_millis(100)).await;
call_counter.fetch_add(1, AcqRel);
"Result".to_string()
}));
}
for fut in futures.into_iter() {
assert_eq!(fut.await, "Result");
}
assert_eq!(
call_counter.load(Acquire),
1,
"future should only be executed once"
);
}
// `work_with_owned_key` with a key converted from a `&'static str` literal.
#[tokio::test]
async fn call_with_static_str_key() {
let group = SingleFlight::new();
let result = group
.work_with_owned_key("key".into(), || async {
tokio::time::sleep(Duration::from_millis(1)).await;
"Result".to_string()
})
.await;
assert_eq!(result, "Result");
}
// `work_with_owned_key` with a key converted from an owned `String`.
#[tokio::test]
async fn call_with_static_string_key() {
let group = SingleFlight::new();
let result = group
.work_with_owned_key("key".to_string().into(), || async {
tokio::time::sleep(Duration::from_millis(1)).await;
"Result".to_string()
})
.await;
assert_eq!(result, "Result");
}
// A future created while a call is registered, but first awaited after that call finished,
// must still receive the finished call's value; its own closure would panic if it ran.
#[tokio::test]
async fn late_wait() {
let group = SingleFlight::new();
let fut_early = group.work_with_owned_key("key".into(), || async {
tokio::time::sleep(Duration::from_millis(20)).await;
"Result".to_string()
});
let fut_late = group.work_with_owned_key("key".into(), || async { panic!("unexpected") });
assert_eq!(fut_early.await, "Result");
tokio::time::sleep(Duration::from_millis(50)).await;
assert_eq!(fut_late.await, "Result");
}
// Cancellation: the first future is dropped by the 10 ms timeout long before its 2 s sleep
// ends, so a later call for the same key must run its own closure ("Result2"). Afterwards two
// joined futures share one run: the second closure would panic if invoked, and the elapsed
// time shows the join waited for the first closure's sleep.
#[tokio::test]
async fn cancel() {
let group = SingleFlight::new();
let fut_cancel = group.work_with_owned_key("key".into(), || async {
tokio::time::sleep(Duration::from_millis(2000)).await;
"Result1".to_string()
});
let _ = tokio::time::timeout(Duration::from_millis(10), fut_cancel).await;
let fut_late = group.work_with_owned_key("key".into(), || async { "Result2".to_string() });
assert_eq!(fut_late.await, "Result2");
let begin = tokio::time::Instant::now();
let fut_1 = group.work_with_owned_key("key".into(), || async {
tokio::time::sleep(Duration::from_millis(2000)).await;
"Result1".to_string()
});
let fut_2 = group.work_with_owned_key("key".into(), || async { panic!() });
let (v1, v2) = tokio::join!(fut_1, fut_2);
assert_eq!(v1, "Result1");
assert_eq!(v2, "Result1");
assert!(begin.elapsed() > Duration::from_millis(1500));
}
}