Skip to main content

rs_singleflight/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    collections::{HashMap, hash_map::RandomState},
5    fmt,
6    future::Future,
7    hash::{BuildHasher, Hash},
8    sync::{
9        Arc, Mutex, Weak,
10        atomic::{AtomicBool, AtomicUsize, Ordering},
11    },
12};
13
14use tokio::sync::broadcast;
15
16type SharedOutcome<T, E> = Arc<Outcome<T, E>>;
17type Calls<K, T, E, S> = HashMap<K, Weak<Call<K, T, E, S>>, S>;
18
/// Final result of a single-flight call, as observed by the leader and by
/// every subscriber that joined it.
#[derive(Debug)]
pub enum Outcome<T, E> {
    /// The leader ran its computation to the end and reported `result`.
    /// `shared` is `true` when at least one subscriber had joined the call.
    Complete { result: Result<T, E>, shared: bool },
    /// The leader was dropped before it reported a result.
    Canceled,
}

impl<T, E> Outcome<T, E> {
    /// Returns `true` only for a completed call whose result was shared with
    /// at least one subscriber; a canceled call is never shared.
    pub fn is_shared(&self) -> bool {
        match self {
            Self::Complete { shared, .. } => *shared,
            Self::Canceled => false,
        }
    }

    /// Borrows the leader's result, or returns `None` for a canceled call.
    pub fn result(&self) -> Option<&Result<T, E>> {
        if let Self::Complete { result, .. } = self {
            Some(result)
        } else {
            None
        }
    }
}
40
/// Reason a subscriber could not obtain the leader's outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
    /// The result channel closed before any outcome was sent on it.
    Closed,
    /// The subscriber fell behind the result channel by this many messages.
    Lagged(u64),
}

impl fmt::Display for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Lagged(n) => write!(f, "singleflight subscriber lagged by {n} messages"),
            Self::Closed => f.write_str("singleflight result channel closed"),
        }
    }
}

impl std::error::Error for WaitError {}
60
61/// Namespace for duplicate suppression.
62///
63/// For a given key, only the leader computes. Duplicate callers subscribe to
64/// the leader's broadcast and receive the same [`Outcome`].
65pub struct Group<K, T, E, F, S = RandomState> {
66    inner: Arc<Inner<K, T, E, S>>,
67    op: Arc<F>,
68}
69
70impl<K, T, E, F, Fut> Group<K, T, E, F, RandomState>
71where
72    F: Fn(K) -> Fut,
73    Fut: Future<Output = Result<T, E>>,
74{
75    pub fn new(op: F) -> Self {
76        Self::with_hasher(op, RandomState::new())
77    }
78}
79
80impl<K, T, E, F, S> Group<K, T, E, F, S> {
81    pub fn with_hasher(op: F, hasher: S) -> Self {
82        Self {
83            inner: Arc::new(Inner {
84                calls: Mutex::new(HashMap::with_hasher(hasher)),
85            }),
86            op: Arc::new(op),
87        }
88    }
89}
90
91impl<K, T, E, F, S> Clone for Group<K, T, E, F, S> {
92    fn clone(&self) -> Self {
93        Self {
94            inner: Arc::clone(&self.inner),
95            op: Arc::clone(&self.op),
96        }
97    }
98}
99
100impl<K, T, E, F, S> Group<K, T, E, F, S>
101where
102    K: Eq + Hash,
103    S: BuildHasher,
104{
105    /// Returns a leader for a new key, or a subscriber for an in-flight key.
106    pub fn entry(&self, key: K) -> Entry<K, T, E, S> {
107        let mut calls = self
108            .inner
109            .calls
110            .lock()
111            .expect("singleflight mutex poisoned");
112
113        if let Some(call) = calls.get(&key).and_then(Weak::upgrade) {
114            return Entry::Subscriber(call.subscribe());
115        }
116
117        let call = Arc::new(Call::new(Arc::downgrade(&self.inner)));
118        calls.insert(key, Arc::downgrade(&call));
119        Entry::Leader(Leader { call: Some(call) })
120    }
121
122    /// Executes this group's operation once per key while an earlier call is in flight.
123    pub async fn run<Fut>(&self, key: K) -> SharedOutcome<T, E>
124    where
125        K: Clone,
126        F: Fn(K) -> Fut,
127        Fut: Future<Output = Result<T, E>>,
128    {
129        match self.entry(key.clone()) {
130            Entry::Leader(leader) => {
131                let result = (self.op)(key).await;
132                leader.complete(result)
133            }
134            Entry::Subscriber(subscriber) => subscriber
135                .recv()
136                .await
137                .unwrap_or_else(|_| Arc::new(Outcome::Canceled)),
138        }
139    }
140
141    /// Forgets a key so the next [`entry`](Self::entry) or [`run`](Self::run)
142    /// starts a fresh leader instead of joining the current call.
143    pub fn forget<Q>(&self, key: &Q)
144    where
145        K: std::borrow::Borrow<Q>,
146        Q: Hash + Eq + ?Sized,
147    {
148        self.inner
149            .calls
150            .lock()
151            .expect("singleflight mutex poisoned")
152            .remove(key);
153    }
154
155    pub fn in_flight(&self) -> usize {
156        self.inner
157            .calls
158            .lock()
159            .expect("singleflight mutex poisoned")
160            .len()
161    }
162}
163
164/// Returned by [`Group::entry`].
165pub enum Entry<K, T, E, S = RandomState> {
166    Leader(Leader<K, T, E, S>),
167    Subscriber(Subscriber<T, E>),
168}
169
170/// Owner of the single computation for a key.
171///
172/// Dropping a leader before calling [`complete`](Self::complete) publishes
173/// [`Outcome::Canceled`] to subscribers and removes the key from the group.
174pub struct Leader<K, T, E, S = RandomState> {
175    call: Option<Arc<Call<K, T, E, S>>>,
176}
177
178impl<K, T, E, S> Leader<K, T, E, S>
179where
180    K: Eq + Hash,
181    S: BuildHasher,
182{
183    pub fn complete(mut self, result: Result<T, E>) -> SharedOutcome<T, E> {
184        let call = self.call.take().expect("leader completed twice");
185        call.cleanup();
186        let shared = call.waiters.load(Ordering::SeqCst) > 0;
187        let outcome = Arc::new(Outcome::Complete { result, shared });
188        call.publish(Arc::clone(&outcome));
189        outcome
190    }
191
192    pub fn subscribe(&self) -> Subscriber<T, E> {
193        self.call
194            .as_ref()
195            .expect("leader already completed")
196            .subscribe()
197    }
198
199    pub fn duplicate_count(&self) -> usize {
200        self.call
201            .as_ref()
202            .map(|call| call.waiters.load(Ordering::SeqCst))
203            .unwrap_or(0)
204    }
205}
206
207impl<K, T, E, S> Drop for Leader<K, T, E, S> {
208    fn drop(&mut self) {
209        if let Some(call) = self.call.take() {
210            call.cancel();
211        }
212    }
213}
214
215/// Receiver for a duplicate caller.
216pub struct Subscriber<T, E> {
217    rx: broadcast::Receiver<SharedOutcome<T, E>>,
218}
219
220impl<T, E> Subscriber<T, E> {
221    pub async fn recv(mut self) -> Result<SharedOutcome<T, E>, WaitError> {
222        match self.rx.recv().await {
223            Ok(outcome) => Ok(outcome),
224            Err(broadcast::error::RecvError::Closed) => Err(WaitError::Closed),
225            Err(broadcast::error::RecvError::Lagged(n)) => Err(WaitError::Lagged(n)),
226        }
227    }
228}
229
230struct Inner<K, T, E, S> {
231    calls: Mutex<Calls<K, T, E, S>>,
232}
233
234struct Call<K, T, E, S> {
235    group: Weak<Inner<K, T, E, S>>,
236    tx: broadcast::Sender<SharedOutcome<T, E>>,
237    waiters: AtomicUsize,
238    finished: AtomicBool,
239}
240
241impl<K, T, E, S> Call<K, T, E, S> {
242    fn new(group: Weak<Inner<K, T, E, S>>) -> Self {
243        let (tx, _) = broadcast::channel(1);
244        Self {
245            group,
246            tx,
247            waiters: AtomicUsize::new(0),
248            finished: AtomicBool::new(false),
249        }
250    }
251
252    fn subscribe(&self) -> Subscriber<T, E> {
253        self.waiters.fetch_add(1, Ordering::SeqCst);
254        Subscriber {
255            rx: self.tx.subscribe(),
256        }
257    }
258
259    fn publish(&self, outcome: SharedOutcome<T, E>) {
260        if !self.finished.swap(true, Ordering::SeqCst) {
261            let _ = self.tx.send(outcome);
262        }
263    }
264
265    fn cancel(&self) {
266        self.cleanup();
267        self.publish(Arc::new(Outcome::Canceled));
268    }
269
270    fn cleanup(&self) {
271        let Some(group) = self.group.upgrade() else {
272            return;
273        };
274
275        let mut calls = group.calls.lock().expect("singleflight mutex poisoned");
276        calls.retain(|_, existing| {
277            existing
278                .upgrade()
279                .is_some_and(|call| !std::ptr::eq(call.as_ref(), self))
280        });
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::{Ready, ready},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };
    use tokio::{
        sync::{Barrier, oneshot},
        time::{Duration, sleep, timeout},
    };

    type EntryGroup = Group<&'static str, usize, (), fn(&'static str) -> Ready<Result<usize, ()>>>;

    /// Operation for groups that are driven only through `entry`; these tests
    /// never await it.
    fn entry_op(_: &'static str) -> Ready<Result<usize, ()>> {
        ready(Ok(0))
    }

    fn entry_group() -> EntryGroup {
        Group::new(entry_op)
    }

    /// Unwraps a leader entry, panicking with `context` on a subscriber.
    fn expect_leader<K, T, E, S>(
        entry: Entry<K, T, E, S>,
        context: &str,
    ) -> Leader<K, T, E, S> {
        match entry {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("{context}"),
        }
    }

    /// Unwraps a subscriber entry, panicking with `context` on a leader.
    fn expect_subscriber<K, T, E, S>(entry: Entry<K, T, E, S>, context: &str) -> Subscriber<T, E> {
        match entry {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("{context}"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn suppresses_duplicate_calls() {
        const CALLERS: usize = 12;

        let invocations = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&invocations);
        let group = Arc::new(Group::new(move |key: String| {
            let counter = Arc::clone(&counter);
            async move {
                assert_eq!(key, "key");
                counter.fetch_add(1, Ordering::SeqCst);
                // Keep the leader in flight long enough for every caller to join.
                sleep(Duration::from_millis(20)).await;
                Ok::<String, ()>("value".to_owned())
            }
        }));
        let start = Arc::new(Barrier::new(CALLERS));

        let tasks: Vec<_> = (0..CALLERS)
            .map(|_| {
                let group = Arc::clone(&group);
                let start = Arc::clone(&start);
                tokio::spawn(async move {
                    start.wait().await;
                    group.run("key".to_owned()).await
                })
            })
            .collect();

        let mut any_shared = false;
        for task in tasks {
            let outcome = task.await.expect("task panicked");
            let Outcome::Complete { result, shared } = outcome.as_ref() else {
                panic!("leader should complete");
            };
            assert_eq!(result.as_ref().unwrap(), "value");
            any_shared |= *shared;
        }

        assert_eq!(invocations.load(Ordering::SeqCst), 1);
        assert!(any_shared);
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn subscribers_receive_cancellation_when_leader_is_dropped() {
        let group = entry_group();
        let leader = expect_leader(group.entry("key"), "first entry must lead");
        let subscriber = expect_subscriber(group.entry("key"), "duplicate entry must subscribe");

        // Abandon the call without completing it.
        drop(leader);

        let outcome = timeout(Duration::from_secs(1), subscriber.recv())
            .await
            .expect("subscriber hung")
            .expect("subscriber closed");
        assert!(matches!(*outcome, Outcome::Canceled));
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn forget_starts_a_new_leader_without_breaking_old_one() {
        let group = entry_group();
        let first = expect_leader(group.entry("key"), "first entry must lead");

        // Detach the first call so the key is free for a new leader.
        group.forget("key");

        let second = expect_leader(group.entry("key"), "forgotten key should create a new leader");
        let third = expect_subscriber(
            group.entry("key"),
            "third entry should subscribe to second leader",
        );

        // Completing the detached leader must not disturb the second call.
        first.complete(Ok(1));
        let published = second.complete(Ok(2));
        assert!(matches!(
            *published,
            Outcome::Complete {
                result: Ok(2),
                shared: true
            }
        ));

        let received = third.recv().await.expect("third subscriber closed");
        assert!(matches!(
            *received,
            Outcome::Complete {
                result: Ok(2),
                shared: true
            }
        ));
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn custom_entry_api_allows_external_compute_placement() {
        let group = entry_group();
        let (release_tx, release_rx) = oneshot::channel();

        let leader = expect_leader(group.entry("key"), "first entry must lead");
        let duplicate = expect_subscriber(group.entry("key"), "duplicate entry must subscribe");

        // The leader completes on another task once it is released.
        let task = tokio::spawn(async move {
            release_rx.await.expect("release dropped");
            leader.complete(Ok(42))
        });

        release_tx.send(()).expect("leader task dropped");
        let observed = duplicate.recv().await.unwrap();
        assert!(matches!(
            *observed,
            Outcome::Complete {
                result: Ok(42),
                shared: true
            }
        ));
        assert!(task.await.unwrap().is_shared());
    }
}