#![doc = include_str!("../README.md")]
use std::{
collections::{HashMap, hash_map::RandomState},
fmt,
future::Future,
hash::{BuildHasher, Hash},
sync::{
Arc, Mutex, Weak,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
};
use tokio::sync::broadcast;
type SharedOutcome<T, E> = Arc<Outcome<T, E>>;
type Calls<K, T, E, S> = HashMap<K, Weak<Call<K, T, E, S>>, S>;
#[derive(Debug)]
pub enum Outcome<T, E> {
Complete { result: Result<T, E>, shared: bool },
Canceled,
}
impl<T, E> Outcome<T, E> {
pub fn is_shared(&self) -> bool {
matches!(self, Self::Complete { shared: true, .. })
}
pub fn result(&self) -> Option<&Result<T, E>> {
match self {
Self::Complete { result, .. } => Some(result),
Self::Canceled => None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
Closed,
Lagged(u64),
}
impl fmt::Display for WaitError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Closed => f.write_str("singleflight result channel closed"),
Self::Lagged(n) => write!(f, "singleflight subscriber lagged by {n} messages"),
}
}
}
impl std::error::Error for WaitError {}
pub struct Group<K, T, E, F, S = RandomState> {
inner: Arc<Inner<K, T, E, S>>,
op: Arc<F>,
}
impl<K, T, E, F, Fut> Group<K, T, E, F, RandomState>
where
F: Fn(K) -> Fut,
Fut: Future<Output = Result<T, E>>,
{
pub fn new(op: F) -> Self {
Self::with_hasher(op, RandomState::new())
}
}
impl<K, T, E, F, S> Group<K, T, E, F, S> {
pub fn with_hasher(op: F, hasher: S) -> Self {
Self {
inner: Arc::new(Inner {
calls: Mutex::new(HashMap::with_hasher(hasher)),
}),
op: Arc::new(op),
}
}
}
impl<K, T, E, F, S> Clone for Group<K, T, E, F, S> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
op: Arc::clone(&self.op),
}
}
}
impl<K, T, E, F, S> Group<K, T, E, F, S>
where
K: Eq + Hash,
S: BuildHasher,
{
pub fn entry(&self, key: K) -> Entry<K, T, E, S> {
let mut calls = self
.inner
.calls
.lock()
.expect("singleflight mutex poisoned");
if let Some(call) = calls.get(&key).and_then(Weak::upgrade) {
return Entry::Subscriber(call.subscribe());
}
let call = Arc::new(Call::new(Arc::downgrade(&self.inner)));
calls.insert(key, Arc::downgrade(&call));
Entry::Leader(Leader { call: Some(call) })
}
pub async fn run<Fut>(&self, key: K) -> SharedOutcome<T, E>
where
K: Clone,
F: Fn(K) -> Fut,
Fut: Future<Output = Result<T, E>>,
{
match self.entry(key.clone()) {
Entry::Leader(leader) => {
let result = (self.op)(key).await;
leader.complete(result)
}
Entry::Subscriber(subscriber) => subscriber
.recv()
.await
.unwrap_or_else(|_| Arc::new(Outcome::Canceled)),
}
}
pub fn forget<Q>(&self, key: &Q)
where
K: std::borrow::Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.inner
.calls
.lock()
.expect("singleflight mutex poisoned")
.remove(key);
}
pub fn in_flight(&self) -> usize {
self.inner
.calls
.lock()
.expect("singleflight mutex poisoned")
.len()
}
}
pub enum Entry<K, T, E, S = RandomState> {
Leader(Leader<K, T, E, S>),
Subscriber(Subscriber<T, E>),
}
pub struct Leader<K, T, E, S = RandomState> {
call: Option<Arc<Call<K, T, E, S>>>,
}
impl<K, T, E, S> Leader<K, T, E, S>
where
K: Eq + Hash,
S: BuildHasher,
{
pub fn complete(mut self, result: Result<T, E>) -> SharedOutcome<T, E> {
let call = self.call.take().expect("leader completed twice");
call.cleanup();
let shared = call.waiters.load(Ordering::SeqCst) > 0;
let outcome = Arc::new(Outcome::Complete { result, shared });
call.publish(Arc::clone(&outcome));
outcome
}
pub fn subscribe(&self) -> Subscriber<T, E> {
self.call
.as_ref()
.expect("leader already completed")
.subscribe()
}
pub fn duplicate_count(&self) -> usize {
self.call
.as_ref()
.map(|call| call.waiters.load(Ordering::SeqCst))
.unwrap_or(0)
}
}
impl<K, T, E, S> Drop for Leader<K, T, E, S> {
fn drop(&mut self) {
if let Some(call) = self.call.take() {
call.cancel();
}
}
}
pub struct Subscriber<T, E> {
rx: broadcast::Receiver<SharedOutcome<T, E>>,
}
impl<T, E> Subscriber<T, E> {
pub async fn recv(mut self) -> Result<SharedOutcome<T, E>, WaitError> {
match self.rx.recv().await {
Ok(outcome) => Ok(outcome),
Err(broadcast::error::RecvError::Closed) => Err(WaitError::Closed),
Err(broadcast::error::RecvError::Lagged(n)) => Err(WaitError::Lagged(n)),
}
}
}
struct Inner<K, T, E, S> {
calls: Mutex<Calls<K, T, E, S>>,
}
struct Call<K, T, E, S> {
group: Weak<Inner<K, T, E, S>>,
tx: broadcast::Sender<SharedOutcome<T, E>>,
waiters: AtomicUsize,
finished: AtomicBool,
}
impl<K, T, E, S> Call<K, T, E, S> {
fn new(group: Weak<Inner<K, T, E, S>>) -> Self {
let (tx, _) = broadcast::channel(1);
Self {
group,
tx,
waiters: AtomicUsize::new(0),
finished: AtomicBool::new(false),
}
}
fn subscribe(&self) -> Subscriber<T, E> {
self.waiters.fetch_add(1, Ordering::SeqCst);
Subscriber {
rx: self.tx.subscribe(),
}
}
fn publish(&self, outcome: SharedOutcome<T, E>) {
if !self.finished.swap(true, Ordering::SeqCst) {
let _ = self.tx.send(outcome);
}
}
fn cancel(&self) {
self.cleanup();
self.publish(Arc::new(Outcome::Canceled));
}
fn cleanup(&self) {
let Some(group) = self.group.upgrade() else {
return;
};
let mut calls = group.calls.lock().expect("singleflight mutex poisoned");
calls.retain(|_, existing| {
existing
.upgrade()
.is_some_and(|call| !std::ptr::eq(call.as_ref(), self))
});
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
future::{Ready, ready},
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use tokio::{
sync::{Barrier, oneshot},
time::{Duration, sleep, timeout},
};
type EntryGroup = Group<&'static str, usize, (), fn(&'static str) -> Ready<Result<usize, ()>>>;
fn entry_op(_: &'static str) -> Ready<Result<usize, ()>> {
ready(Ok(0))
}
fn entry_group() -> EntryGroup {
Group::new(entry_op)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn suppresses_duplicate_calls() {
let calls = Arc::new(AtomicUsize::new(0));
let calls_for_op = Arc::clone(&calls);
let group = Arc::new(Group::new(move |key: String| {
let calls = Arc::clone(&calls_for_op);
async move {
assert_eq!(key, "key");
calls.fetch_add(1, Ordering::SeqCst);
sleep(Duration::from_millis(20)).await;
Ok::<String, ()>("value".to_owned())
}
}));
let barrier = Arc::new(Barrier::new(12));
let mut tasks = Vec::new();
for _ in 0..12 {
let group = Arc::clone(&group);
let barrier = Arc::clone(&barrier);
tasks.push(tokio::spawn(async move {
barrier.wait().await;
group.run("key".to_owned()).await
}));
}
let mut shared = false;
for task in tasks {
let outcome = task.await.expect("task panicked");
match outcome.as_ref() {
Outcome::Complete { result, shared: s } => {
assert_eq!(result.as_ref().unwrap(), "value");
shared |= *s;
}
Outcome::Canceled => panic!("leader should complete"),
}
}
assert_eq!(calls.load(Ordering::SeqCst), 1);
assert!(shared);
assert_eq!(group.in_flight(), 0);
}
#[tokio::test]
async fn subscribers_receive_cancellation_when_leader_is_dropped() {
let group = entry_group();
let leader = match group.entry("key") {
Entry::Leader(leader) => leader,
Entry::Subscriber(_) => panic!("first entry must lead"),
};
let subscriber = match group.entry("key") {
Entry::Subscriber(subscriber) => subscriber,
Entry::Leader(_) => panic!("duplicate entry must subscribe"),
};
drop(leader);
let outcome = timeout(Duration::from_secs(1), subscriber.recv())
.await
.expect("subscriber hung")
.expect("subscriber closed");
assert!(matches!(outcome.as_ref(), Outcome::Canceled));
assert_eq!(group.in_flight(), 0);
}
#[tokio::test]
async fn forget_starts_a_new_leader_without_breaking_old_one() {
let group = entry_group();
let first = match group.entry("key") {
Entry::Leader(leader) => leader,
Entry::Subscriber(_) => panic!("first entry must lead"),
};
group.forget("key");
let second = match group.entry("key") {
Entry::Leader(leader) => leader,
Entry::Subscriber(_) => panic!("forgotten key should create a new leader"),
};
let third = match group.entry("key") {
Entry::Subscriber(subscriber) => subscriber,
Entry::Leader(_) => panic!("third entry should subscribe to second leader"),
};
first.complete(Ok(1));
let published = second.complete(Ok(2));
assert!(matches!(
published.as_ref(),
Outcome::Complete {
result: Ok(2),
shared: true
}
));
let received = third.recv().await.expect("third subscriber closed");
assert!(matches!(
received.as_ref(),
Outcome::Complete {
result: Ok(2),
shared: true
}
));
assert_eq!(group.in_flight(), 0);
}
#[tokio::test]
async fn custom_entry_api_allows_external_compute_placement() {
let group = entry_group();
let (release_tx, release_rx) = oneshot::channel();
let leader = match group.entry("key") {
Entry::Leader(leader) => leader,
Entry::Subscriber(_) => panic!("first entry must lead"),
};
let duplicate = match group.entry("key") {
Entry::Subscriber(subscriber) => subscriber,
Entry::Leader(_) => panic!("duplicate entry must subscribe"),
};
let task = tokio::spawn(async move {
release_rx.await.expect("release dropped");
leader.complete(Ok(42))
});
release_tx.send(()).expect("leader task dropped");
assert!(matches!(
duplicate.recv().await.unwrap().as_ref(),
Outcome::Complete {
result: Ok(42),
shared: true
}
));
assert!(task.await.unwrap().is_shared());
}
}