1#![doc = include_str!("../README.md")]
2
3use std::{
4 collections::{HashMap, hash_map::RandomState},
5 fmt,
6 future::Future,
7 hash::{BuildHasher, Hash},
8 sync::{
9 Arc, Mutex, Weak,
10 atomic::{AtomicBool, AtomicUsize, Ordering},
11 },
12};
13
14use tokio::sync::broadcast;
15
/// Outcome handle shared between a leader and every subscriber of its call.
type SharedOutcome<T, E> = Arc<Outcome<T, E>>;
/// Table of in-flight calls, keyed by the caller-supplied key. Entries are weak
/// pointers, so the table never keeps a call alive on its own; the strong
/// reference lives in the call's `Leader`.
type Calls<K, T, E, S> = HashMap<K, Weak<Call<K, T, E, S>>, S>;
18
/// Final result of a deduplicated call, as seen by the leader and by every
/// subscriber of that call.
#[derive(Debug)]
pub enum Outcome<T, E> {
    /// The leader finished and produced `result`.
    ///
    /// `shared` is `true` when at least one subscriber had joined the call
    /// before the result was published.
    Complete { result: Result<T, E>, shared: bool },
    /// The leader was dropped without completing, so there is no result.
    Canceled,
}

impl<T, E> Outcome<T, E> {
    /// Returns `true` if at least one subscriber joined the call; always
    /// `false` for [`Outcome::Canceled`].
    pub fn is_shared(&self) -> bool {
        match self {
            Self::Complete { shared, .. } => *shared,
            Self::Canceled => false,
        }
    }

    /// Borrows the leader's result, or returns `None` when the call was
    /// canceled.
    pub fn result(&self) -> Option<&Result<T, E>> {
        if let Self::Complete { result, .. } = self {
            Some(result)
        } else {
            None
        }
    }
}
40
/// Reason a subscriber could not obtain an outcome from its leader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
    /// The result channel closed before any outcome was received.
    Closed,
    /// The subscriber fell behind and missed the given number of messages.
    ///
    /// Not expected in practice: each call publishes at most one outcome.
    Lagged(u64),
}

impl fmt::Display for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Lagged(skipped) => {
                write!(f, "singleflight subscriber lagged by {skipped} messages")
            }
            Self::Closed => f.write_str("singleflight result channel closed"),
        }
    }
}

impl std::error::Error for WaitError {}
60
61pub struct Group<K, T, E, S = RandomState> {
66 inner: Arc<Inner<K, T, E, S>>,
67}
68
69impl<K, T, E> Group<K, T, E, RandomState> {
70 pub fn new() -> Self {
71 Self::with_hasher(RandomState::new())
72 }
73}
74
75impl<K, T, E, S> Group<K, T, E, S> {
76 pub fn with_hasher(hasher: S) -> Self {
77 Self {
78 inner: Arc::new(Inner {
79 calls: Mutex::new(HashMap::with_hasher(hasher)),
80 }),
81 }
82 }
83}
84
85impl<K, T, E, S> Clone for Group<K, T, E, S> {
86 fn clone(&self) -> Self {
87 Self {
88 inner: Arc::clone(&self.inner),
89 }
90 }
91}
92
93impl<K, T, E> Default for Group<K, T, E, RandomState> {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl<K, T, E, S> Group<K, T, E, S>
100where
101 K: Eq + Hash,
102 S: BuildHasher,
103{
104 pub fn entry(&self, key: K) -> Entry<K, T, E, S> {
106 let mut calls = self
107 .inner
108 .calls
109 .lock()
110 .expect("singleflight mutex poisoned");
111
112 if let Some(call) = calls.get(&key).and_then(Weak::upgrade) {
113 return Entry::Subscriber(call.subscribe());
114 }
115
116 let call = Arc::new(Call::new(Arc::downgrade(&self.inner)));
117 calls.insert(key, Arc::downgrade(&call));
118 Entry::Leader(Leader { call: Some(call) })
119 }
120
121 pub async fn run<F, Fut>(&self, key: K, f: F) -> SharedOutcome<T, E>
123 where
124 F: FnOnce() -> Fut,
125 Fut: Future<Output = Result<T, E>>,
126 {
127 match self.entry(key) {
128 Entry::Leader(leader) => {
129 let result = f().await;
130 leader.complete(result)
131 }
132 Entry::Subscriber(subscriber) => subscriber
133 .recv()
134 .await
135 .unwrap_or_else(|_| Arc::new(Outcome::Canceled)),
136 }
137 }
138
139 pub fn forget<Q>(&self, key: &Q)
142 where
143 K: std::borrow::Borrow<Q>,
144 Q: Hash + Eq + ?Sized,
145 {
146 self.inner
147 .calls
148 .lock()
149 .expect("singleflight mutex poisoned")
150 .remove(key);
151 }
152
153 pub fn in_flight(&self) -> usize {
154 self.inner
155 .calls
156 .lock()
157 .expect("singleflight mutex poisoned")
158 .len()
159 }
160}
161
162pub enum Entry<K, T, E, S = RandomState> {
164 Leader(Leader<K, T, E, S>),
165 Subscriber(Subscriber<T, E>),
166}
167
168pub struct Leader<K, T, E, S = RandomState> {
173 call: Option<Arc<Call<K, T, E, S>>>,
174}
175
176impl<K, T, E, S> Leader<K, T, E, S>
177where
178 K: Eq + Hash,
179 S: BuildHasher,
180{
181 pub fn complete(mut self, result: Result<T, E>) -> SharedOutcome<T, E> {
182 let call = self.call.take().expect("leader completed twice");
183 call.cleanup();
184 let shared = call.waiters.load(Ordering::SeqCst) > 0;
185 let outcome = Arc::new(Outcome::Complete { result, shared });
186 call.publish(Arc::clone(&outcome));
187 outcome
188 }
189
190 pub fn subscribe(&self) -> Subscriber<T, E> {
191 self.call
192 .as_ref()
193 .expect("leader already completed")
194 .subscribe()
195 }
196
197 pub fn duplicate_count(&self) -> usize {
198 self.call
199 .as_ref()
200 .map(|call| call.waiters.load(Ordering::SeqCst))
201 .unwrap_or(0)
202 }
203}
204
205impl<K, T, E, S> Drop for Leader<K, T, E, S> {
206 fn drop(&mut self) {
207 if let Some(call) = self.call.take() {
208 call.cancel();
209 }
210 }
211}
212
213pub struct Subscriber<T, E> {
215 rx: broadcast::Receiver<SharedOutcome<T, E>>,
216}
217
218impl<T, E> Subscriber<T, E> {
219 pub async fn recv(mut self) -> Result<SharedOutcome<T, E>, WaitError> {
220 match self.rx.recv().await {
221 Ok(outcome) => Ok(outcome),
222 Err(broadcast::error::RecvError::Closed) => Err(WaitError::Closed),
223 Err(broadcast::error::RecvError::Lagged(n)) => Err(WaitError::Lagged(n)),
224 }
225 }
226}
227
/// State shared by every clone of a `Group`.
struct Inner<K, T, E, S> {
    /// In-flight calls, keyed by the caller-supplied key.
    calls: Mutex<Calls<K, T, E, S>>,
}

/// One in-flight call: owned (strongly) by its `Leader`, referenced weakly
/// from the group's table.
struct Call<K, T, E, S> {
    /// Back-reference used to unregister the call. It is weak, so a call does
    /// not keep its group alive.
    group: Weak<Inner<K, T, E, S>>,
    /// Broadcast channel carrying the call's single outcome to subscribers.
    tx: broadcast::Sender<SharedOutcome<T, E>>,
    /// Number of subscribers attached so far. Non-zero at completion makes the
    /// outcome `shared`.
    waiters: AtomicUsize,
    /// Set by the first `publish`; later publishes are ignored.
    finished: AtomicBool,
}
238
239impl<K, T, E, S> Call<K, T, E, S> {
240 fn new(group: Weak<Inner<K, T, E, S>>) -> Self {
241 let (tx, _) = broadcast::channel(1);
242 Self {
243 group,
244 tx,
245 waiters: AtomicUsize::new(0),
246 finished: AtomicBool::new(false),
247 }
248 }
249
250 fn subscribe(&self) -> Subscriber<T, E> {
251 self.waiters.fetch_add(1, Ordering::SeqCst);
252 Subscriber {
253 rx: self.tx.subscribe(),
254 }
255 }
256
257 fn publish(&self, outcome: SharedOutcome<T, E>) {
258 if !self.finished.swap(true, Ordering::SeqCst) {
259 let _ = self.tx.send(outcome);
260 }
261 }
262
263 fn cancel(&self) {
264 self.cleanup();
265 self.publish(Arc::new(Outcome::Canceled));
266 }
267
268 fn cleanup(&self) {
269 let Some(group) = self.group.upgrade() else {
270 return;
271 };
272
273 let mut calls = group.calls.lock().expect("singleflight mutex poisoned");
274 calls.retain(|_, existing| {
275 existing
276 .upgrade()
277 .is_some_and(|call| !std::ptr::eq(call.as_ref(), self))
278 });
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::{
        sync::{Barrier, oneshot},
        time::{Duration, sleep, timeout},
    };

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn suppresses_duplicate_calls() {
        let group = Arc::new(Group::<String, String, ()>::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(Barrier::new(12));
        let mut tasks = Vec::new();

        for _ in 0..12 {
            let group = Arc::clone(&group);
            let calls = Arc::clone(&calls);
            let barrier = Arc::clone(&barrier);
            tasks.push(tokio::spawn(async move {
                barrier.wait().await;
                group
                    .run("key".to_owned(), || async {
                        calls.fetch_add(1, Ordering::SeqCst);
                        sleep(Duration::from_millis(20)).await;
                        Ok("value".to_owned())
                    })
                    .await
            }));
        }

        let mut shared = false;
        for task in tasks {
            let outcome = task.await.expect("task panicked");
            match outcome.as_ref() {
                Outcome::Complete { result, shared: s } => {
                    assert_eq!(result.as_ref().unwrap(), "value");
                    shared |= *s;
                }
                Outcome::Canceled => panic!("leader should complete"),
            }
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(shared);
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn subscribers_receive_cancellation_when_leader_is_dropped() {
        let group = Group::<&'static str, usize, ()>::new();
        let leader = match group.entry("key") {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("first entry must lead"),
        };
        let subscriber = match group.entry("key") {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("duplicate entry must subscribe"),
        };

        drop(leader);

        let outcome = timeout(Duration::from_secs(1), subscriber.recv())
            .await
            .expect("subscriber hung")
            .expect("subscriber closed");
        assert!(matches!(outcome.as_ref(), Outcome::Canceled));
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn forget_starts_a_new_leader_without_breaking_old_one() {
        let group = Group::<&'static str, usize, ()>::new();
        let first = match group.entry("key") {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("first entry must lead"),
        };

        group.forget("key");

        let second = match group.entry("key") {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("forgotten key should create a new leader"),
        };
        let third = match group.entry("key") {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("third entry should subscribe to second leader"),
        };

        first.complete(Ok(1));
        let published = second.complete(Ok(2));
        assert!(matches!(
            published.as_ref(),
            Outcome::Complete {
                result: Ok(2),
                shared: true
            }
        ));

        let received = third.recv().await.expect("third subscriber closed");
        assert!(matches!(
            received.as_ref(),
            Outcome::Complete {
                result: Ok(2),
                shared: true
            }
        ));
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn custom_entry_api_allows_external_compute_placement() {
        let group = Group::<&'static str, usize, ()>::new();
        let (release_tx, release_rx) = oneshot::channel();

        let leader = match group.entry("key") {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("first entry must lead"),
        };
        let duplicate = match group.entry("key") {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("duplicate entry must subscribe"),
        };

        let task = tokio::spawn(async move {
            release_rx.await.expect("release dropped");
            leader.complete(Ok(42))
        });

        release_tx.send(()).expect("leader task dropped");
        assert!(matches!(
            duplicate.recv().await.unwrap().as_ref(),
            Outcome::Complete {
                result: Ok(42),
                shared: true
            }
        ));
        assert!(task.await.unwrap().is_shared());
    }

    /// A leader with no subscribers publishes a non-shared outcome, including
    /// error results.
    #[tokio::test]
    async fn lone_leader_outcome_is_not_shared() {
        let group = Group::<&'static str, usize, &'static str>::new();
        let outcome = group.run("key", || async { Err("boom") }).await;
        assert!(!outcome.is_shared());
        assert_eq!(outcome.result(), Some(&Err("boom")));
        assert_eq!(group.in_flight(), 0);
    }

    /// Subscribers from both `Group::entry` and `Leader::subscribe` are counted
    /// and all receive the published outcome.
    #[tokio::test]
    async fn leader_tracks_duplicates_from_entry_and_subscribe() {
        let group = Group::<&'static str, usize, ()>::new();
        let leader = match group.entry("key") {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("first entry must lead"),
        };
        assert_eq!(leader.duplicate_count(), 0);

        let joined = match group.entry("key") {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("duplicate entry must subscribe"),
        };
        let attached = leader.subscribe();
        assert_eq!(leader.duplicate_count(), 2);
        assert_eq!(group.in_flight(), 1);

        let published = leader.complete(Ok(7));
        assert!(published.is_shared());
        for subscriber in [joined, attached] {
            let outcome = subscriber.recv().await.expect("subscriber closed");
            assert_eq!(outcome.result(), Some(&Ok(7)));
        }
        assert_eq!(group.in_flight(), 0);
    }

    /// Dropping a `run` future mid-flight cancels its waiters and unregisters
    /// the key.
    #[tokio::test]
    async fn dropping_the_run_future_cancels_waiters() {
        let group = Group::<&'static str, usize, ()>::new();
        let mut leading = Box::pin(group.run("key", std::future::pending::<Result<usize, ()>>));
        // Poll once so the run registers itself as leader, then give up on it.
        assert!(
            timeout(Duration::from_millis(1), leading.as_mut())
                .await
                .is_err()
        );

        let waiter = match group.entry("key") {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("key should still be led by the pending run"),
        };
        drop(leading);

        let outcome = timeout(Duration::from_secs(1), waiter.recv())
            .await
            .expect("waiter hung")
            .expect("waiter closed");
        assert!(matches!(outcome.as_ref(), Outcome::Canceled));
        assert_eq!(group.in_flight(), 0);
    }

    #[test]
    fn wait_error_messages_are_descriptive() {
        assert_eq!(
            WaitError::Closed.to_string(),
            "singleflight result channel closed"
        );
        assert_eq!(
            WaitError::Lagged(3).to_string(),
            "singleflight subscriber lagged by 3 messages"
        );
    }
}