// rs_singleflight/lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    collections::{HashMap, hash_map::RandomState},
5    fmt,
6    future::Future,
7    hash::{BuildHasher, Hash},
8    sync::{
9        Arc, Mutex, Weak,
10        atomic::{AtomicBool, AtomicUsize, Ordering},
11    },
12};
13
14use tokio::sync::broadcast;
15
16type SharedOutcome<T, E> = Arc<Outcome<T, E>>;
17type Calls<K, T, E, S> = HashMap<K, Weak<Call<K, T, E, S>>, S>;
18
/// Final result of a single-flight computation, shared by the leader and all
/// of its subscribers.
#[derive(Debug)]
pub enum Outcome<T, E> {
    /// The leader ran to completion.
    Complete {
        /// Value or error produced by the leader's computation.
        result: Result<T, E>,
        /// `true` when at least one subscriber had joined the call by the time
        /// the leader completed.
        shared: bool,
    },
    /// The leader was dropped before it completed.
    Canceled,
}

impl<T, E> Outcome<T, E> {
    /// Returns `true` for a completed outcome that had subscribers attached.
    /// A canceled outcome is never considered shared.
    pub fn is_shared(&self) -> bool {
        match self {
            Self::Complete { shared, .. } => *shared,
            Self::Canceled => false,
        }
    }

    /// Borrows the leader's result, or returns `None` if the call was canceled.
    pub fn result(&self) -> Option<&Result<T, E>> {
        if let Self::Complete { result, .. } = self {
            Some(result)
        } else {
            None
        }
    }
}
40
/// Reason a subscriber could not obtain the leader's outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
    /// The broadcast channel closed without delivering an outcome.
    Closed,
    /// The receiver fell behind the broadcast channel and missed this many
    /// messages.
    Lagged(u64),
}

impl fmt::Display for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            WaitError::Closed => f.write_str("singleflight result channel closed"),
            WaitError::Lagged(missed) => {
                write!(f, "singleflight subscriber lagged by {missed} messages")
            }
        }
    }
}

impl std::error::Error for WaitError {}
60
61/// Namespace for duplicate suppression.
62///
63/// For a given key, only the leader computes. Duplicate callers subscribe to
64/// the leader's broadcast and receive the same [`Outcome`].
65pub struct Group<K, T, E, S = RandomState> {
66    inner: Arc<Inner<K, T, E, S>>,
67}
68
69impl<K, T, E> Group<K, T, E, RandomState> {
70    pub fn new() -> Self {
71        Self::with_hasher(RandomState::new())
72    }
73}
74
75impl<K, T, E, S> Group<K, T, E, S> {
76    pub fn with_hasher(hasher: S) -> Self {
77        Self {
78            inner: Arc::new(Inner {
79                calls: Mutex::new(HashMap::with_hasher(hasher)),
80            }),
81        }
82    }
83}
84
85impl<K, T, E, S> Clone for Group<K, T, E, S> {
86    fn clone(&self) -> Self {
87        Self {
88            inner: Arc::clone(&self.inner),
89        }
90    }
91}
92
93impl<K, T, E> Default for Group<K, T, E, RandomState> {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl<K, T, E, S> Group<K, T, E, S>
100where
101    K: Eq + Hash,
102    S: BuildHasher,
103{
104    /// Returns a leader for a new key, or a subscriber for an in-flight key.
105    pub fn entry(&self, key: K) -> Entry<K, T, E, S> {
106        let mut calls = self
107            .inner
108            .calls
109            .lock()
110            .expect("singleflight mutex poisoned");
111
112        if let Some(call) = calls.get(&key).and_then(Weak::upgrade) {
113            return Entry::Subscriber(call.subscribe());
114        }
115
116        let call = Arc::new(Call::new(Arc::downgrade(&self.inner)));
117        calls.insert(key, Arc::downgrade(&call));
118        Entry::Leader(Leader { call: Some(call) })
119    }
120
121    /// Executes `f` once per key while an earlier call is in flight.
122    pub async fn run<F, Fut>(&self, key: K, f: F) -> SharedOutcome<T, E>
123    where
124        F: FnOnce() -> Fut,
125        Fut: Future<Output = Result<T, E>>,
126    {
127        match self.entry(key) {
128            Entry::Leader(leader) => {
129                let result = f().await;
130                leader.complete(result)
131            }
132            Entry::Subscriber(subscriber) => subscriber
133                .recv()
134                .await
135                .unwrap_or_else(|_| Arc::new(Outcome::Canceled)),
136        }
137    }
138
139    /// Forgets a key so the next [`entry`](Self::entry) or [`run`](Self::run)
140    /// starts a fresh leader instead of joining the current call.
141    pub fn forget<Q>(&self, key: &Q)
142    where
143        K: std::borrow::Borrow<Q>,
144        Q: Hash + Eq + ?Sized,
145    {
146        self.inner
147            .calls
148            .lock()
149            .expect("singleflight mutex poisoned")
150            .remove(key);
151    }
152
153    pub fn in_flight(&self) -> usize {
154        self.inner
155            .calls
156            .lock()
157            .expect("singleflight mutex poisoned")
158            .len()
159    }
160}
161
162/// Returned by [`Group::entry`].
163pub enum Entry<K, T, E, S = RandomState> {
164    Leader(Leader<K, T, E, S>),
165    Subscriber(Subscriber<T, E>),
166}
167
168/// Owner of the single computation for a key.
169///
170/// Dropping a leader before calling [`complete`](Self::complete) publishes
171/// [`Outcome::Canceled`] to subscribers and removes the key from the group.
172pub struct Leader<K, T, E, S = RandomState> {
173    call: Option<Arc<Call<K, T, E, S>>>,
174}
175
176impl<K, T, E, S> Leader<K, T, E, S>
177where
178    K: Eq + Hash,
179    S: BuildHasher,
180{
181    pub fn complete(mut self, result: Result<T, E>) -> SharedOutcome<T, E> {
182        let call = self.call.take().expect("leader completed twice");
183        call.cleanup();
184        let shared = call.waiters.load(Ordering::SeqCst) > 0;
185        let outcome = Arc::new(Outcome::Complete { result, shared });
186        call.publish(Arc::clone(&outcome));
187        outcome
188    }
189
190    pub fn subscribe(&self) -> Subscriber<T, E> {
191        self.call
192            .as_ref()
193            .expect("leader already completed")
194            .subscribe()
195    }
196
197    pub fn duplicate_count(&self) -> usize {
198        self.call
199            .as_ref()
200            .map(|call| call.waiters.load(Ordering::SeqCst))
201            .unwrap_or(0)
202    }
203}
204
205impl<K, T, E, S> Drop for Leader<K, T, E, S> {
206    fn drop(&mut self) {
207        if let Some(call) = self.call.take() {
208            call.cancel();
209        }
210    }
211}
212
213/// Receiver for a duplicate caller.
214pub struct Subscriber<T, E> {
215    rx: broadcast::Receiver<SharedOutcome<T, E>>,
216}
217
218impl<T, E> Subscriber<T, E> {
219    pub async fn recv(mut self) -> Result<SharedOutcome<T, E>, WaitError> {
220        match self.rx.recv().await {
221            Ok(outcome) => Ok(outcome),
222            Err(broadcast::error::RecvError::Closed) => Err(WaitError::Closed),
223            Err(broadcast::error::RecvError::Lagged(n)) => Err(WaitError::Lagged(n)),
224        }
225    }
226}
227
228struct Inner<K, T, E, S> {
229    calls: Mutex<Calls<K, T, E, S>>,
230}
231
232struct Call<K, T, E, S> {
233    group: Weak<Inner<K, T, E, S>>,
234    tx: broadcast::Sender<SharedOutcome<T, E>>,
235    waiters: AtomicUsize,
236    finished: AtomicBool,
237}
238
239impl<K, T, E, S> Call<K, T, E, S> {
240    fn new(group: Weak<Inner<K, T, E, S>>) -> Self {
241        let (tx, _) = broadcast::channel(1);
242        Self {
243            group,
244            tx,
245            waiters: AtomicUsize::new(0),
246            finished: AtomicBool::new(false),
247        }
248    }
249
250    fn subscribe(&self) -> Subscriber<T, E> {
251        self.waiters.fetch_add(1, Ordering::SeqCst);
252        Subscriber {
253            rx: self.tx.subscribe(),
254        }
255    }
256
257    fn publish(&self, outcome: SharedOutcome<T, E>) {
258        if !self.finished.swap(true, Ordering::SeqCst) {
259            let _ = self.tx.send(outcome);
260        }
261    }
262
263    fn cancel(&self) {
264        self.cleanup();
265        self.publish(Arc::new(Outcome::Canceled));
266    }
267
268    fn cleanup(&self) {
269        let Some(group) = self.group.upgrade() else {
270            return;
271        };
272
273        let mut calls = group.calls.lock().expect("singleflight mutex poisoned");
274        calls.retain(|_, existing| {
275            existing
276                .upgrade()
277                .is_some_and(|call| !std::ptr::eq(call.as_ref(), self))
278        });
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::{
        sync::{Barrier, oneshot},
        time::{Duration, sleep, timeout},
    };

    /// Unwraps a leader entry, failing the test with `msg` otherwise.
    fn expect_leader<K, T, E>(entry: Entry<K, T, E>, msg: &str) -> Leader<K, T, E> {
        match entry {
            Entry::Leader(leader) => leader,
            Entry::Subscriber(_) => panic!("{msg}"),
        }
    }

    /// Unwraps a subscriber entry, failing the test with `msg` otherwise.
    fn expect_subscriber<K, T, E>(entry: Entry<K, T, E>, msg: &str) -> Subscriber<T, E> {
        match entry {
            Entry::Subscriber(subscriber) => subscriber,
            Entry::Leader(_) => panic!("{msg}"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn suppresses_duplicate_calls() {
        const CALLERS: usize = 12;
        let group = Arc::new(Group::<String, String, ()>::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(Barrier::new(CALLERS));

        // Every task is spawned before any is awaited, so all of them can meet
        // at the barrier and race for the same key.
        let tasks: Vec<_> = (0..CALLERS)
            .map(|_| {
                let group = Arc::clone(&group);
                let calls = Arc::clone(&calls);
                let barrier = Arc::clone(&barrier);
                tokio::spawn(async move {
                    barrier.wait().await;
                    group
                        .run("key".to_owned(), || async {
                            calls.fetch_add(1, Ordering::SeqCst);
                            sleep(Duration::from_millis(20)).await;
                            Ok("value".to_owned())
                        })
                        .await
                })
            })
            .collect();

        let mut any_shared = false;
        for task in tasks {
            let outcome = task.await.expect("task panicked");
            let Outcome::Complete { result, shared } = outcome.as_ref() else {
                panic!("leader should complete");
            };
            assert_eq!(result.as_ref().unwrap(), "value");
            any_shared |= *shared;
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(any_shared);
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn subscribers_receive_cancellation_when_leader_is_dropped() {
        let group = Group::<&'static str, usize, ()>::new();
        let leader = expect_leader(group.entry("key"), "first entry must lead");
        let subscriber = expect_subscriber(group.entry("key"), "duplicate entry must subscribe");

        drop(leader);

        let outcome = timeout(Duration::from_secs(1), subscriber.recv())
            .await
            .expect("subscriber hung")
            .expect("subscriber closed");
        assert!(matches!(*outcome, Outcome::Canceled));
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn forget_starts_a_new_leader_without_breaking_old_one() {
        let group = Group::<&'static str, usize, ()>::new();
        let first = expect_leader(group.entry("key"), "first entry must lead");

        group.forget("key");

        let second = expect_leader(
            group.entry("key"),
            "forgotten key should create a new leader",
        );
        let third = expect_subscriber(
            group.entry("key"),
            "third entry should subscribe to second leader",
        );

        first.complete(Ok(1));
        let published = second.complete(Ok(2));
        let received = third.recv().await.expect("third subscriber closed");

        // The new leader's outcome must reach both its caller and its subscriber.
        for outcome in [&published, &received] {
            assert!(matches!(
                outcome.as_ref(),
                Outcome::Complete {
                    result: Ok(2),
                    shared: true
                }
            ));
        }
        assert_eq!(group.in_flight(), 0);
    }

    #[tokio::test]
    async fn custom_entry_api_allows_external_compute_placement() {
        let group = Group::<&'static str, usize, ()>::new();
        let (release_tx, release_rx) = oneshot::channel();

        let leader = expect_leader(group.entry("key"), "first entry must lead");
        let duplicate = expect_subscriber(group.entry("key"), "duplicate entry must subscribe");

        let task = tokio::spawn(async move {
            release_rx.await.expect("release dropped");
            leader.complete(Ok(42))
        });

        release_tx.send(()).expect("leader task dropped");
        let received = duplicate.recv().await.unwrap();
        assert!(matches!(
            received.as_ref(),
            Outcome::Complete {
                result: Ok(42),
                shared: true
            }
        ));
        assert!(task.await.unwrap().is_shared());
    }
}