Skip to main content

async_singleflight/
lib.rs

1//! A singleflight implementation for tokio.
2//!
3//! Inspired by [singleflight](https://crates.io/crates/singleflight).
4//!
5//! # Examples
6//!
7//! ```no_run
8//! use futures::future::join_all;
9//! use std::sync::Arc;
10//! use std::time::Duration;
11//!
12//! use async_singleflight::DefaultGroup;
13//!
14//! const RES: usize = 7;
15//!
16//! async fn expensive_fn() -> Result<usize, ()> {
17//!     tokio::time::sleep(Duration::new(1, 500)).await;
18//!     Ok(RES)
19//! }
20//!
21//! #[tokio::main]
22//! async fn main() {
23//!     let g = Arc::new(DefaultGroup::<usize>::new());
24//!     let mut handlers = Vec::new();
25//!     for _ in 0..10 {
26//!         let g = g.clone();
27//!         handlers.push(tokio::spawn(async move {
28//!             let res = g.work("key", expensive_fn()).await;
29//!             let r = res.unwrap();
30//!             println!("{}", r);
31//!         }));
32//!     }
33//!
34//!     join_all(handlers).await;
35//! }
36//! ```
37//!
38
39use std::fmt::{self, Debug};
40use std::future::Future;
41use std::hash::BuildHasher;
42use std::marker::PhantomData;
43use std::pin::Pin;
44use std::task::{Context, Poll};
45
46mod group;
47mod unary;
48
49pub use group::*;
50pub use unary::*;
51
52use pin_project::{pin_project, pinned_drop};
53use std::collections::HashMap;
54use std::hash::Hash;
55use std::hash::RandomState;
56use tokio::sync::{watch, Mutex};
57
/// Progress of one in-flight call, as carried by the `watch` channel that a
/// [`Leader`] publishes on.
#[derive(Clone)]
enum State<T> {
    /// No outcome has been published yet. This is the only state the leader's
    /// drop handler will overwrite.
    Starting,
    /// The leader was dropped while the state was still [`State::Starting`].
    LeaderDropped,
    /// The leader's future resolved to an `Err`.
    LeaderFailed,
    /// The leader's future resolved to a value; holds a clone of that value.
    Success(T),
}
65
66enum ChannelHandler<T> {
67    Sender(watch::Sender<State<T>>),
68    Receiver(watch::Receiver<State<T>>),
69}
70
71#[pin_project(PinnedDrop)]
72struct Leader<T, F, Output>
73where
74    T: Clone,
75    F: Future<Output = Output>,
76{
77    #[pin]
78    fut: F,
79    tx: watch::Sender<State<T>>,
80}
81
82impl<T, F, Output> Leader<T, F, Output>
83where
84    T: Clone,
85    F: Future<Output = Output>,
86{
87    fn new(fut: F, tx: watch::Sender<State<T>>) -> Self {
88        Self { fut, tx }
89    }
90}
91
92#[pinned_drop]
93impl<T, F, Output> PinnedDrop for Leader<T, F, Output>
94where
95    T: Clone,
96    F: Future<Output = Output>,
97{
98    fn drop(self: Pin<&mut Self>) {
99        let this = self.project();
100        let _ = this.tx.send_if_modified(|s| {
101            if matches!(s, State::Starting) {
102                *s = State::LeaderDropped;
103                true
104            } else {
105                false
106            }
107        });
108    }
109}
110
111impl<T, E, F> Future for Leader<T, F, Result<T, E>>
112where
113    T: Clone,
114    F: Future<Output = Result<T, E>>,
115{
116    type Output = Result<T, E>;
117
118    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119        let this = self.project();
120        let result = this.fut.poll(cx);
121        if let Poll::Ready(val) = &result {
122            let _send = match val {
123                Ok(v) => this.tx.send(State::Success(v.clone())),
124                Err(_) => this.tx.send(State::LeaderFailed),
125            };
126        }
127        result
128    }
129}
130
131impl<T, F> Future for Leader<T, F, T>
132where
133    T: Clone + Send + Sync,
134    F: Future<Output = T>,
135{
136    type Output = T;
137
138    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139        let this = self.project();
140        let result = this.fut.poll(cx);
141        if let Poll::Ready(val) = &result {
142            let _send = this.tx.send(State::Success(val.clone()));
143        }
144        result
145    }
146}
147
148#[cfg(test)]
149mod tests {
150    use super::*;
151    use std::sync::Arc;
152    use std::time::Duration;
153    use tokio::sync::oneshot;
154
155    async fn return_res() -> Result<usize, ()> {
156        Ok(7)
157    }
158
159    async fn expensive_fn<const RES: usize>(delay: u64) -> Result<usize, ()> {
160        tokio::time::sleep(Duration::from_millis(delay)).await;
161        Ok(RES)
162    }
163
164    async fn expensive_unary_fn<const RES: usize>(delay: u64) -> usize {
165        tokio::time::sleep(Duration::from_millis(delay)).await;
166        RES
167    }
168
169    #[tokio::test]
170    async fn test_simple() {
171        let g = DefaultGroup::new();
172        let res = g.work("key", return_res()).await;
173        let r = res.unwrap();
174        assert_eq!(r, 7);
175    }
176
177    #[tokio::test]
178    async fn test_multiple_threads() {
179        use std::sync::Arc;
180
181        use futures::future::join_all;
182
183        let g = Arc::new(DefaultGroup::new());
184        let mut handlers = Vec::with_capacity(10);
185        for _ in 0..10 {
186            let g = g.clone();
187            handlers.push(tokio::spawn(async move {
188                let res = g.work("key", expensive_fn::<7>(300)).await;
189                let r = res.unwrap();
190                println!("{}", r);
191            }));
192        }
193
194        join_all(handlers).await;
195    }
196
197    #[tokio::test]
198    async fn test_multiple_threads_custom_type() {
199        use std::sync::Arc;
200
201        use futures::future::join_all;
202
203        let g = Arc::new(Group::<u64, usize, ()>::new());
204        let mut handlers = Vec::with_capacity(10);
205        for _ in 0..10 {
206            let g = g.clone();
207            handlers.push(tokio::spawn(async move {
208                let res = g.work(&42, expensive_fn::<8>(300)).await;
209                let r = res.unwrap();
210                println!("{}", r);
211            }));
212        }
213
214        join_all(handlers).await;
215    }
216
217    #[tokio::test]
218    async fn test_multiple_threads_unary() {
219        use std::sync::Arc;
220
221        use futures::future::join_all;
222
223        let g = Arc::new(UnaryGroup::<u64, usize>::new());
224        let mut handlers = Vec::with_capacity(10);
225        for _ in 0..10 {
226            let g = g.clone();
227            handlers.push(tokio::spawn(async move {
228                let res = g.work(&42, expensive_unary_fn::<8>(300)).await;
229                assert_eq!(res, 8);
230            }));
231        }
232
233        join_all(handlers).await;
234    }
235
236    #[tokio::test]
237    async fn test_drop_leader() {
238        let group = Arc::new(DefaultGroup::new());
239
240        // Signal when the leader's inner future gets polled (implies map entry inserted).
241        let (ready_tx, ready_rx) = oneshot::channel::<()>();
242
243        let leader_owned = group.clone();
244        let leader = tokio::spawn(async move {
245            // The inner future signals on first poll, then sleeps long.
246            let fut = async move {
247                let _ = ready_tx.send(());
248                tokio::time::sleep(Duration::from_millis(500)).await;
249                Ok::<usize, ()>(7)
250            };
251            // We expect this task to be aborted before completion.
252            let _ = leader_owned.work("key", fut).await;
253        });
254
255        // Wait until the leader's future has been polled once (map entry is in place).
256        let _ = ready_rx.await;
257
258        // Spawn a follower that will wait on the existing key and should observe LeaderDropped.
259        let follower_owned = group.clone();
260        let follower = tokio::spawn(async move {
261            follower_owned
262                .work("key", async { Ok::<usize, ()>(42) })
263                .await
264        });
265
266        // Give the follower a chance to attach to the receiver.
267        tokio::task::yield_now().await;
268
269        // Abort the leader to trigger LeaderDropped notification to all followers.
270        leader.abort();
271
272        // The follower should return LeaderDropped.
273        let res = tokio::time::timeout(Duration::from_secs(1), follower)
274            .await
275            .expect("follower should finish in time")
276            .expect("follower task should not panic");
277
278        assert_eq!(res, Ok(42));
279    }
280
281    /// Regression test for issue #12: when the leader is dropped, only ONE follower
282    /// should become the new leader. Due to a race condition in the LeaderDropped
283    /// handler (which calls `map.remove(key)` and can delete a newly-inserted entry
284    /// from a follower that already became the new leader), multiple followers can
285    /// each independently become leaders and execute the work function.
286    ///
287    /// The race scenario:
288    /// 1. Leader is dropped; followers A and B both see LeaderDropped
289    /// 2. Follower A acquires the lock and calls remove(key), then loops back
290    ///    into work_inner where it inserts a NEW entry and becomes the new leader
291    /// 3. Follower B acquires the lock for its remove(key) call -- but by now
292    ///    the map entry belongs to the new leader (A). B removes it anyway!
293    /// 4. Follower B loops back into work_inner, sees no entry, inserts its own,
294    ///    and becomes a SECOND independent leader.
295    ///
296    /// We use a multi-threaded runtime and many iterations to trigger this race.
297    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
298    async fn test_leader_drop_single_new_leader() {
299        use std::sync::atomic::{AtomicUsize, Ordering};
300        use tokio::sync::Barrier;
301
302        const NUM_FOLLOWERS: usize = 5;
303
304        // Run the test many times to increase the chance of hitting the race.
305        for iteration in 0..200 {
306            let group = Arc::new(DefaultGroup::new());
307
308            // Counts how many times the actual work function body executes.
309            let execute_count = Arc::new(AtomicUsize::new(0));
310
311            // Signal when the leader's inner future gets polled (map entry inserted).
312            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
313
314            // Barrier: all followers + main thread wait here to sync up before
315            // we abort the leader, ensuring followers are subscribed.
316            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));
317
318            // Spawn the leader task.
319            let leader_group = group.clone();
320            let leader = tokio::spawn(async move {
321                let fut = async move {
322                    let _ = leader_ready_tx.send(());
323                    tokio::time::sleep(Duration::from_secs(60)).await;
324                    Ok::<usize, ()>(999)
325                };
326                let _ = leader_group.work("key", fut).await;
327            });
328
329            // Wait for the leader's future to be polled (entry in the map).
330            let _ = leader_ready_rx.await;
331
332            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);
333
334            for _ in 0..NUM_FOLLOWERS {
335                let g = group.clone();
336                let cnt = execute_count.clone();
337                let b = barrier.clone();
338                follower_handles.push(tokio::spawn(async move {
339                    // Strategy: each follower signals readiness via the barrier,
340                    // then calls work() which subscribes to the leader's channel.
341                    // When the leader is aborted, followers see LeaderDropped and
342                    // retry via the work() loop. Only one should become the new
343                    // leader; the rest should subscribe to the new leader's channel.
344                    b.wait().await;
345
346                    g.work("key", async move {
347                        cnt.fetch_add(1, Ordering::SeqCst);
348                        // Yield to give other followers a chance to also become
349                        // leaders if the race condition is triggered.
350                        tokio::task::yield_now().await;
351                        Ok::<usize, ()>(42)
352                    })
353                    .await
354                }));
355            }
356
357            // Wait for all followers to be ready to enter work.
358            barrier.wait().await;
359
360            // Give followers time to actually enter work_inner and subscribe
361            // as receivers on the watch channel.
362            tokio::time::sleep(Duration::from_millis(5)).await;
363
364            // Abort the leader. This triggers LeaderDropped to all followers.
365            leader.abort();
366
367            // Wait for all followers to complete.
368            for handle in follower_handles {
369                let res = tokio::time::timeout(Duration::from_secs(5), handle)
370                    .await
371                    .expect("follower should finish in time")
372                    .expect("follower task should not panic");
373                assert_eq!(res, Ok(42), "follower should get the correct result");
374            }
375
376            // The critical assertion: the work function should have executed exactly
377            // once (by the single new leader). If the bug is present, multiple
378            // followers become independent leaders and execute_count will be > 1.
379            let count = execute_count.load(Ordering::SeqCst);
380            assert_eq!(
381                count, 1,
382                "Iteration {}: Expected exactly 1 work execution after leader drop, \
383                 but got {}. This indicates multiple followers became leaders (issue #12).",
384                iteration, count
385            );
386        }
387    }
388
389    #[tokio::test]
390    async fn test_drop_leader_no_retry() {
391        let group = Arc::new(DefaultGroup::<usize>::new());
392
393        // Signal when the leader's inner future gets polled (implies map entry inserted).
394        let (ready_tx, ready_rx) = oneshot::channel::<()>();
395
396        let leader_owned = group.clone();
397        let leader = tokio::spawn(async move {
398            // The inner future signals on first poll, then sleeps long.
399            let fut = async move {
400                let _ = ready_tx.send(());
401                tokio::time::sleep(Duration::from_millis(500)).await;
402                Ok::<usize, ()>(7)
403            };
404            // We expect this task to be aborted before completion.
405            let _ = leader_owned.work("key", fut).await;
406        });
407
408        // Wait until the leader's future has been polled once (map entry is in place).
409        let _ = ready_rx.await;
410
411        // Spawn a follower that will wait on the existing key and should observe LeaderDropped.
412        let follower_owned = group.clone();
413        let follower = tokio::spawn(async move {
414            follower_owned
415                .work_no_retry("key", async { Ok::<usize, ()>(42) })
416                .await
417        });
418
419        // Give the follower a chance to attach to the receiver.
420        tokio::task::yield_now().await;
421
422        // Abort the leader to trigger LeaderDropped notification to all followers.
423        leader.abort();
424
425        // The follower should return LeaderDropped.
426        let res = tokio::time::timeout(Duration::from_secs(1), follower)
427            .await
428            .expect("follower should finish in time")
429            .expect("follower task should not panic");
430
431        assert_eq!(res, Err(GroupWorkError::LeaderDropped));
432    }
433
434    /// Same as test_leader_drop_single_new_leader but for UnaryGroup.
435    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
436    async fn test_leader_drop_single_new_leader_unary() {
437        use std::sync::atomic::{AtomicUsize, Ordering};
438        use tokio::sync::Barrier;
439
440        const NUM_FOLLOWERS: usize = 5;
441
442        for iteration in 0..200 {
443            let group = Arc::new(DefaultUnaryGroup::new());
444            let execute_count = Arc::new(AtomicUsize::new(0));
445            let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
446            let barrier = Arc::new(Barrier::new(NUM_FOLLOWERS + 1));
447
448            let leader_group = group.clone();
449            let leader = tokio::spawn(async move {
450                let fut = async move {
451                    let _ = leader_ready_tx.send(());
452                    tokio::time::sleep(Duration::from_secs(60)).await;
453                    999_usize
454                };
455                leader_group.work("key", fut).await
456            });
457
458            let _ = leader_ready_rx.await;
459
460            let mut follower_handles = Vec::with_capacity(NUM_FOLLOWERS);
461            for _ in 0..NUM_FOLLOWERS {
462                let g = group.clone();
463                let cnt = execute_count.clone();
464                let b = barrier.clone();
465                follower_handles.push(tokio::spawn(async move {
466                    b.wait().await;
467                    g.work("key", async move {
468                        cnt.fetch_add(1, Ordering::SeqCst);
469                        tokio::task::yield_now().await;
470                        42_usize
471                    })
472                    .await
473                }));
474            }
475
476            barrier.wait().await;
477            tokio::time::sleep(Duration::from_millis(5)).await;
478            leader.abort();
479
480            for handle in follower_handles {
481                let res = tokio::time::timeout(Duration::from_secs(5), handle)
482                    .await
483                    .expect("follower should finish in time")
484                    .expect("follower task should not panic");
485                assert_eq!(res, 42, "follower should get the correct result");
486            }
487
488            let count = execute_count.load(Ordering::SeqCst);
489            assert_eq!(
490                count, 1,
491                "Iteration {}: Expected exactly 1 work execution after leader drop, \
492                 but got {}. This indicates multiple followers became leaders (issue #12).",
493                iteration, count
494            );
495        }
496    }
497
498    /// After a promoted leader completes, its result is cached in the map.
499    /// A fresh caller (not a retrier) should start new work, not return
500    /// the stale cached result.
501    #[tokio::test]
502    async fn test_fresh_caller_replaces_stale_entry() {
503        let group = Arc::new(DefaultGroup::new());
504
505        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
506        let leader_group = group.clone();
507        let leader = tokio::spawn(async move {
508            let _ = leader_group
509                .work("key", async move {
510                    let _ = leader_ready_tx.send(());
511                    tokio::time::sleep(Duration::from_secs(60)).await;
512                    Ok::<usize, ()>(999)
513                })
514                .await;
515        });
516        let _ = leader_ready_rx.await;
517
518        // Spawn a follower that will recover after leader drop.
519        let follower_group = group.clone();
520        let follower = tokio::spawn(async move {
521            follower_group
522                .work("key", async { Ok::<usize, ()>(42) })
523                .await
524        });
525        tokio::task::yield_now().await;
526
527        leader.abort();
528        let res = follower.await.unwrap();
529        assert_eq!(res, Ok(42));
530
531        // Now the map has a stale Success(42) entry from the promoted leader.
532        // A fresh caller should start new work and get 99, not the stale 42.
533        let res = group.work("key", async { Ok::<usize, ()>(99) }).await;
534        assert_eq!(res, Ok(99));
535    }
536
537    /// Verify that purge_stale removes completed entries and that
538    /// subsequent calls create fresh leaders.
539    #[tokio::test]
540    async fn test_purge_stale() {
541        let group = Arc::new(DefaultGroup::new());
542
543        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
544        let leader_group = group.clone();
545        let leader = tokio::spawn(async move {
546            let _ = leader_group
547                .work("key", async move {
548                    let _ = leader_ready_tx.send(());
549                    tokio::time::sleep(Duration::from_secs(60)).await;
550                    Ok::<usize, ()>(999)
551                })
552                .await;
553        });
554        let _ = leader_ready_rx.await;
555
556        let follower_group = group.clone();
557        let follower = tokio::spawn(async move {
558            follower_group
559                .work("key", async { Ok::<usize, ()>(42) })
560                .await
561        });
562        tokio::task::yield_now().await;
563
564        leader.abort();
565        let res = follower.await.unwrap();
566        assert_eq!(res, Ok(42));
567
568        // Stale entry exists; purge it.
569        group.purge_stale().await;
570
571        // After purge, a new call should work normally.
572        let res = group.work("key", async { Ok::<usize, ()>(77) }).await;
573        assert_eq!(res, Ok(77));
574    }
575
576    /// Verify purge_stale works for UnaryGroup.
577    #[tokio::test]
578    async fn test_purge_stale_unary() {
579        let group = Arc::new(DefaultUnaryGroup::new());
580
581        let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<()>();
582        let leader_group = group.clone();
583        let leader = tokio::spawn(async move {
584            let fut = async move {
585                let _ = leader_ready_tx.send(());
586                tokio::time::sleep(Duration::from_secs(60)).await;
587                999_usize
588            };
589            leader_group.work("key", fut).await
590        });
591        let _ = leader_ready_rx.await;
592
593        let follower_group = group.clone();
594        let follower =
595            tokio::spawn(async move { follower_group.work("key", async { 42_usize }).await });
596        tokio::task::yield_now().await;
597
598        leader.abort();
599        let res = follower.await.unwrap();
600        assert_eq!(res, 42);
601
602        group.purge_stale().await;
603
604        let res = group.work("key", async { 77_usize }).await;
605        assert_eq!(res, 77);
606    }
607}