1#![doc = include_str!("../README.md")]
2
3use std::{
4 collections::{HashMap, hash_map::RandomState},
5 fmt,
6 future::Future,
7 hash::{BuildHasher, Hash},
8 sync::{
9 Arc, Mutex, Weak,
10 atomic::{AtomicBool, AtomicUsize, Ordering},
11 },
12};
13
14use tokio::sync::broadcast;
15
16type SharedOutcome<T, E> = Arc<Outcome<T, E>>;
17type Calls<K, T, E, S> = HashMap<K, Weak<Call<K, T, E, S>>, S>;
18
/// Final result of a single-flight call, as seen by the leader and by every
/// subscriber.
#[derive(Debug)]
pub enum Outcome<T, E> {
    /// The leader finished and reported `result`.
    ///
    /// `shared` is `true` when at least one [`Subscriber`] had been created for
    /// the call before the result was published.
    Complete { result: Result<T, E>, shared: bool },
    /// No result was produced: the leader was dropped without calling
    /// [`Leader::complete`], or a subscriber's wait failed inside
    /// [`Group::run`].
    Canceled,
}

impl<T, E> Outcome<T, E> {
    /// Returns `true` only for a completed call whose result was shared with at
    /// least one subscriber; canceled outcomes are never shared.
    pub fn is_shared(&self) -> bool {
        if let Self::Complete { shared, .. } = self {
            *shared
        } else {
            false
        }
    }

    /// Borrows the leader's result, or returns `None` for a canceled call.
    pub fn result(&self) -> Option<&Result<T, E>> {
        if let Self::Complete { result, .. } = self {
            Some(result)
        } else {
            None
        }
    }
}
40
/// Reasons a [`Subscriber`] can fail to receive its leader's outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
    /// The result channel closed without delivering an outcome.
    Closed,
    /// The receiver fell behind the channel by the contained number of messages.
    Lagged(u64),
}

impl fmt::Display for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            WaitError::Lagged(skipped) => {
                write!(f, "singleflight subscriber lagged by {skipped} messages")
            }
            WaitError::Closed => f.write_str("singleflight result channel closed"),
        }
    }
}

impl std::error::Error for WaitError {}
60
61pub struct Group<K, T, E, F, S = RandomState> {
66 inner: Arc<Inner<K, T, E, S>>,
67 op: Arc<F>,
68}
69
70impl<K, T, E, F, Fut> Group<K, T, E, F, RandomState>
71where
72 F: Fn(K) -> Fut,
73 Fut: Future<Output = Result<T, E>>,
74{
75 pub fn new(op: F) -> Self {
76 Self::with_hasher(op, RandomState::new())
77 }
78}
79
80impl<K, T, E, F, S> Group<K, T, E, F, S> {
81 pub fn with_hasher(op: F, hasher: S) -> Self {
82 Self {
83 inner: Arc::new(Inner {
84 calls: Mutex::new(HashMap::with_hasher(hasher)),
85 }),
86 op: Arc::new(op),
87 }
88 }
89}
90
91impl<K, T, E, F, S> Clone for Group<K, T, E, F, S> {
92 fn clone(&self) -> Self {
93 Self {
94 inner: Arc::clone(&self.inner),
95 op: Arc::clone(&self.op),
96 }
97 }
98}
99
100impl<K, T, E, F, S> Group<K, T, E, F, S>
101where
102 K: Eq + Hash,
103 S: BuildHasher,
104{
105 pub fn entry(&self, key: K) -> Entry<K, T, E, S> {
107 let mut calls = self
108 .inner
109 .calls
110 .lock()
111 .expect("singleflight mutex poisoned");
112
113 if let Some(call) = calls.get(&key).and_then(Weak::upgrade) {
114 return Entry::Subscriber(call.subscribe());
115 }
116
117 let call = Arc::new(Call::new(Arc::downgrade(&self.inner)));
118 calls.insert(key, Arc::downgrade(&call));
119 Entry::Leader(Leader { call: Some(call) })
120 }
121
122 pub async fn run<Fut>(&self, key: K) -> SharedOutcome<T, E>
124 where
125 K: Clone,
126 F: Fn(K) -> Fut,
127 Fut: Future<Output = Result<T, E>>,
128 {
129 match self.entry(key.clone()) {
130 Entry::Leader(leader) => {
131 let result = (self.op)(key).await;
132 leader.complete(result)
133 }
134 Entry::Subscriber(subscriber) => subscriber
135 .recv()
136 .await
137 .unwrap_or_else(|_| Arc::new(Outcome::Canceled)),
138 }
139 }
140
141 pub fn forget<Q>(&self, key: &Q)
144 where
145 K: std::borrow::Borrow<Q>,
146 Q: Hash + Eq + ?Sized,
147 {
148 self.inner
149 .calls
150 .lock()
151 .expect("singleflight mutex poisoned")
152 .remove(key);
153 }
154
155 pub fn in_flight(&self) -> usize {
156 self.inner
157 .calls
158 .lock()
159 .expect("singleflight mutex poisoned")
160 .len()
161 }
162}
163
164pub enum Entry<K, T, E, S = RandomState> {
166 Leader(Leader<K, T, E, S>),
167 Subscriber(Subscriber<T, E>),
168}
169
170pub struct Leader<K, T, E, S = RandomState> {
175 call: Option<Arc<Call<K, T, E, S>>>,
176}
177
178impl<K, T, E, S> Leader<K, T, E, S>
179where
180 K: Eq + Hash,
181 S: BuildHasher,
182{
183 pub fn complete(mut self, result: Result<T, E>) -> SharedOutcome<T, E> {
184 let call = self.call.take().expect("leader completed twice");
185 call.cleanup();
186 let shared = call.waiters.load(Ordering::SeqCst) > 0;
187 let outcome = Arc::new(Outcome::Complete { result, shared });
188 call.publish(Arc::clone(&outcome));
189 outcome
190 }
191
192 pub fn subscribe(&self) -> Subscriber<T, E> {
193 self.call
194 .as_ref()
195 .expect("leader already completed")
196 .subscribe()
197 }
198
199 pub fn duplicate_count(&self) -> usize {
200 self.call
201 .as_ref()
202 .map(|call| call.waiters.load(Ordering::SeqCst))
203 .unwrap_or(0)
204 }
205}
206
207impl<K, T, E, S> Drop for Leader<K, T, E, S> {
208 fn drop(&mut self) {
209 if let Some(call) = self.call.take() {
210 call.cancel();
211 }
212 }
213}
214
215pub struct Subscriber<T, E> {
217 rx: broadcast::Receiver<SharedOutcome<T, E>>,
218}
219
220impl<T, E> Subscriber<T, E> {
221 pub async fn recv(mut self) -> Result<SharedOutcome<T, E>, WaitError> {
222 match self.rx.recv().await {
223 Ok(outcome) => Ok(outcome),
224 Err(broadcast::error::RecvError::Closed) => Err(WaitError::Closed),
225 Err(broadcast::error::RecvError::Lagged(n)) => Err(WaitError::Lagged(n)),
226 }
227 }
228}
229
230struct Inner<K, T, E, S> {
231 calls: Mutex<Calls<K, T, E, S>>,
232}
233
234struct Call<K, T, E, S> {
235 group: Weak<Inner<K, T, E, S>>,
236 tx: broadcast::Sender<SharedOutcome<T, E>>,
237 waiters: AtomicUsize,
238 finished: AtomicBool,
239}
240
241impl<K, T, E, S> Call<K, T, E, S> {
242 fn new(group: Weak<Inner<K, T, E, S>>) -> Self {
243 let (tx, _) = broadcast::channel(1);
244 Self {
245 group,
246 tx,
247 waiters: AtomicUsize::new(0),
248 finished: AtomicBool::new(false),
249 }
250 }
251
252 fn subscribe(&self) -> Subscriber<T, E> {
253 self.waiters.fetch_add(1, Ordering::SeqCst);
254 Subscriber {
255 rx: self.tx.subscribe(),
256 }
257 }
258
259 fn publish(&self, outcome: SharedOutcome<T, E>) {
260 if !self.finished.swap(true, Ordering::SeqCst) {
261 let _ = self.tx.send(outcome);
262 }
263 }
264
265 fn cancel(&self) {
266 self.cleanup();
267 self.publish(Arc::new(Outcome::Canceled));
268 }
269
270 fn cleanup(&self) {
271 let Some(group) = self.group.upgrade() else {
272 return;
273 };
274
275 let mut calls = group.calls.lock().expect("singleflight mutex poisoned");
276 calls.retain(|_, existing| {
277 existing
278 .upgrade()
279 .is_some_and(|call| !std::ptr::eq(call.as_ref(), self))
280 });
281 }
282}
283
284#[cfg(test)]
285mod tests {
286 use super::*;
287 use std::{
288 future::{Ready, ready},
289 sync::{
290 Arc,
291 atomic::{AtomicUsize, Ordering},
292 },
293 };
294 use tokio::{
295 sync::{Barrier, oneshot},
296 time::{Duration, sleep, timeout},
297 };
298
299 type EntryGroup = Group<&'static str, usize, (), fn(&'static str) -> Ready<Result<usize, ()>>>;
300
301 fn entry_op(_: &'static str) -> Ready<Result<usize, ()>> {
302 ready(Ok(0))
303 }
304
305 fn entry_group() -> EntryGroup {
306 Group::new(entry_op)
307 }
308
309 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
310 async fn suppresses_duplicate_calls() {
311 let calls = Arc::new(AtomicUsize::new(0));
312 let calls_for_op = Arc::clone(&calls);
313 let group = Arc::new(Group::new(move |key: String| {
314 let calls = Arc::clone(&calls_for_op);
315 async move {
316 assert_eq!(key, "key");
317 calls.fetch_add(1, Ordering::SeqCst);
318 sleep(Duration::from_millis(20)).await;
319 Ok::<String, ()>("value".to_owned())
320 }
321 }));
322 let barrier = Arc::new(Barrier::new(12));
323 let mut tasks = Vec::new();
324
325 for _ in 0..12 {
326 let group = Arc::clone(&group);
327 let barrier = Arc::clone(&barrier);
328 tasks.push(tokio::spawn(async move {
329 barrier.wait().await;
330 group.run("key".to_owned()).await
331 }));
332 }
333
334 let mut shared = false;
335 for task in tasks {
336 let outcome = task.await.expect("task panicked");
337 match outcome.as_ref() {
338 Outcome::Complete { result, shared: s } => {
339 assert_eq!(result.as_ref().unwrap(), "value");
340 shared |= *s;
341 }
342 Outcome::Canceled => panic!("leader should complete"),
343 }
344 }
345
346 assert_eq!(calls.load(Ordering::SeqCst), 1);
347 assert!(shared);
348 assert_eq!(group.in_flight(), 0);
349 }
350
351 #[tokio::test]
352 async fn subscribers_receive_cancellation_when_leader_is_dropped() {
353 let group = entry_group();
354 let leader = match group.entry("key") {
355 Entry::Leader(leader) => leader,
356 Entry::Subscriber(_) => panic!("first entry must lead"),
357 };
358 let subscriber = match group.entry("key") {
359 Entry::Subscriber(subscriber) => subscriber,
360 Entry::Leader(_) => panic!("duplicate entry must subscribe"),
361 };
362
363 drop(leader);
364
365 let outcome = timeout(Duration::from_secs(1), subscriber.recv())
366 .await
367 .expect("subscriber hung")
368 .expect("subscriber closed");
369 assert!(matches!(outcome.as_ref(), Outcome::Canceled));
370 assert_eq!(group.in_flight(), 0);
371 }
372
373 #[tokio::test]
374 async fn forget_starts_a_new_leader_without_breaking_old_one() {
375 let group = entry_group();
376 let first = match group.entry("key") {
377 Entry::Leader(leader) => leader,
378 Entry::Subscriber(_) => panic!("first entry must lead"),
379 };
380
381 group.forget("key");
382
383 let second = match group.entry("key") {
384 Entry::Leader(leader) => leader,
385 Entry::Subscriber(_) => panic!("forgotten key should create a new leader"),
386 };
387 let third = match group.entry("key") {
388 Entry::Subscriber(subscriber) => subscriber,
389 Entry::Leader(_) => panic!("third entry should subscribe to second leader"),
390 };
391
392 first.complete(Ok(1));
393 let published = second.complete(Ok(2));
394 assert!(matches!(
395 published.as_ref(),
396 Outcome::Complete {
397 result: Ok(2),
398 shared: true
399 }
400 ));
401
402 let received = third.recv().await.expect("third subscriber closed");
403 assert!(matches!(
404 received.as_ref(),
405 Outcome::Complete {
406 result: Ok(2),
407 shared: true
408 }
409 ));
410 assert_eq!(group.in_flight(), 0);
411 }
412
413 #[tokio::test]
414 async fn custom_entry_api_allows_external_compute_placement() {
415 let group = entry_group();
416 let (release_tx, release_rx) = oneshot::channel();
417
418 let leader = match group.entry("key") {
419 Entry::Leader(leader) => leader,
420 Entry::Subscriber(_) => panic!("first entry must lead"),
421 };
422 let duplicate = match group.entry("key") {
423 Entry::Subscriber(subscriber) => subscriber,
424 Entry::Leader(_) => panic!("duplicate entry must subscribe"),
425 };
426
427 let task = tokio::spawn(async move {
428 release_rx.await.expect("release dropped");
429 leader.complete(Ok(42))
430 });
431
432 release_tx.send(()).expect("leader task dropped");
433 assert!(matches!(
434 duplicate.recv().await.unwrap().as_ref(),
435 Outcome::Complete {
436 result: Ok(42),
437 shared: true
438 }
439 ));
440 assert!(task.await.unwrap().is_shared());
441 }
442}