async_singleflight/
lib.rs

1//! A singleflight implementation for tokio.
2//!
3//! Inspired by [singleflight](https://crates.io/crates/singleflight).
4//!
5//! # Examples
6//!
7//! ```no_run
8//! use futures::future::join_all;
9//! use std::sync::Arc;
10//! use std::time::Duration;
11//!
12//! use async_singleflight::Group;
13//!
14//! const RES: usize = 7;
15//!
16//! async fn expensive_fn() -> Result<usize, ()> {
17//!     tokio::time::sleep(Duration::new(1, 500)).await;
18//!     Ok(RES)
19//! }
20//!
21//! #[tokio::main]
22//! async fn main() {
23//!     let g = Arc::new(Group::<_, ()>::new());
24//!     let mut handlers = Vec::new();
25//!     for _ in 0..10 {
26//!         let g = g.clone();
27//!         handlers.push(tokio::spawn(async move {
28//!             let res = g.work("key", expensive_fn()).await.0;
29//!             let r = res.unwrap();
30//!             println!("{}", r);
31//!         }));
32//!     }
33//!
34//!     join_all(handlers).await;
35//! }
36//! ```
37//!
38
39use std::fmt::{self, Debug};
40use std::future::Future;
41use std::marker::PhantomData;
42use std::pin::Pin;
43use std::task::{Context, Poll};
44
45use futures::future::BoxFuture;
46use hashbrown::HashMap;
47use parking_lot::Mutex;
48use pin_project::{pin_project, pinned_drop};
49use tokio::sync::watch;
50
51/// Group represents a class of work and creates a space in which units of work
52/// can be executed with duplicate suppression.
53pub struct Group<T, E>
54where
55    T: Clone,
56{
57    m: Mutex<HashMap<String, watch::Receiver<State<T>>>>,
58    _marker: PhantomData<fn(E)>,
59}
60
61impl<T, E> Debug for Group<T, E>
62where
63    T: Clone,
64{
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("Group").finish()
67    }
68}
69
70impl<T, E> Default for Group<T, E>
71where
72    T: Clone,
73{
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
/// Progress of the leader call for one key, as published on its watch channel.
#[derive(Clone)]
enum State<T>
where
    T: Clone,
{
    /// The leader has been registered and has not produced a result yet.
    Starting,
    /// The leader was dropped before it produced a result.
    LeaderDropped,
    /// The leader finished; holds `Some(value)` on success and `None` when the
    /// leader's future returned an error.
    Done(Option<T>),
}
85
86impl<T, E> Group<T, E>
87where
88    T: Clone,
89{
90    /// Create a new Group to do work with.
91    #[must_use]
92    pub fn new() -> Group<T, E> {
93        Self {
94            m: Mutex::new(HashMap::new()),
95            _marker: PhantomData,
96        }
97    }
98
99    /// Execute and return the value for a given function, making sure that only one
100    /// operation is in-flight at a given moment. If a duplicate call comes in, that caller will
101    /// wait until the original call completes and return the same value.
102    /// Only owner call returns error if exists.
103    /// The third return value indicates whether the call is the owner.
104    pub async fn work(
105        &self,
106        key: &str,
107        fut: impl Future<Output = Result<T, E>>,
108    ) -> (Option<T>, Option<E>, bool) {
109        use hashbrown::hash_map::EntryRef;
110
111        let tx_or_rx = match self.m.lock().entry_ref(key) {
112            EntryRef::Occupied(mut entry) => {
113                let state = entry.get().borrow().clone();
114                match state {
115                    State::Starting => Err(entry.get().clone()),
116                    State::LeaderDropped => {
117                        // switch into leader if leader dropped
118                        let (tx, rx) = watch::channel(State::Starting);
119                        entry.insert(rx);
120                        Ok(tx)
121                    }
122                    State::Done(val) => return (val, None, false),
123                }
124            }
125            EntryRef::Vacant(entry) => {
126                let (tx, rx) = watch::channel(State::Starting);
127                entry.insert(rx);
128                Ok(tx)
129            }
130        };
131
132        match tx_or_rx {
133            Ok(tx) => {
134                let fut = Leader { fut, tx };
135                let result = fut.await;
136                self.m.lock().remove(key);
137                match result {
138                    Ok(val) => (Some(val), None, true),
139                    Err(err) => (None, Some(err), true),
140                }
141            }
142            Err(mut rx) => {
143                let mut state = rx.borrow_and_update().clone();
144                if matches!(state, State::Starting) {
145                    let _changed = rx.changed().await;
146                    state = rx.borrow().clone();
147                }
148                match state {
149                    State::Starting => (None, None, false), // unreachable
150                    State::LeaderDropped => {
151                        self.m.lock().remove(key);
152                        (None, None, false)
153                    }
154                    State::Done(val) => (val, None, false),
155                }
156            }
157        }
158    }
159}
160
161#[pin_project(PinnedDrop)]
162struct Leader<T: Clone, F> {
163    #[pin]
164    fut: F,
165    tx: watch::Sender<State<T>>,
166}
167
168impl<T, E, F> Future for Leader<T, F>
169where
170    T: Clone,
171    F: Future<Output = Result<T, E>>,
172{
173    type Output = Result<T, E>;
174
175    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176        let this = self.project();
177        let result = this.fut.poll(cx);
178        if let Poll::Ready(val) = &result {
179            let _send = this.tx.send(State::Done(val.as_ref().ok().cloned()));
180        }
181        result
182    }
183}
184
185#[pinned_drop]
186impl<T, F> PinnedDrop for Leader<T, F>
187where
188    T: Clone,
189{
190    fn drop(self: Pin<&mut Self>) {
191        let this = self.project();
192        let _ = this.tx.send_if_modified(|s| {
193            if matches!(s, State::Starting) {
194                *s = State::LeaderDropped;
195                true
196            } else {
197                false
198            }
199        });
200    }
201}
202
203/// UnaryGroup represents a class of work and creates a space in which units of work
204/// can be executed with duplicate suppression.
205pub struct UnaryGroup<T>
206where
207    T: Clone,
208{
209    m: Mutex<HashMap<String, watch::Receiver<UnaryState<T>>>>,
210}
211
212impl<T> Debug for UnaryGroup<T>
213where
214    T: Clone,
215{
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        f.debug_struct("UnaryGroup").finish()
218    }
219}
220
221impl<T> Default for UnaryGroup<T>
222where
223    T: Clone + Send + Sync,
224{
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
/// Progress of the leader call for one key, as published on its watch channel.
#[derive(Clone)]
enum UnaryState<T>
where
    T: Clone,
{
    /// The leader has been registered and has not produced a value yet.
    Starting,
    /// The leader was dropped before it produced a value.
    LeaderDropped,
    /// The leader finished with this value.
    Done(T),
}
236
237impl<T> UnaryGroup<T>
238where
239    T: Clone + Send + Sync,
240{
241    /// Create a new Group to do work with.
242    #[must_use]
243    pub fn new() -> UnaryGroup<T> {
244        Self {
245            m: Mutex::new(HashMap::new()),
246        }
247    }
248
249    /// Execute and return the value for a given function, making sure that only one
250    /// operation is in-flight at a given moment. If a duplicate call comes in, that caller will
251    /// wait until the original call completes and return the same value.
252    ///
253    /// The third return value indicates whether the call is the owner.
254    pub fn work<'s>(
255        &'s self,
256        key: &'s str,
257        fut: impl Future<Output = T> + Send + 's,
258    ) -> BoxFuture<'s, (T, bool)> {
259        use hashbrown::hash_map::EntryRef;
260        Box::pin(async move {
261            let tx_or_rx = match self.m.lock().entry_ref(key) {
262                EntryRef::Occupied(mut entry) => {
263                    let state = entry.get().borrow().clone();
264                    match state {
265                        UnaryState::Starting => Err(entry.get().clone()),
266                        UnaryState::LeaderDropped => {
267                            // switch into leader if leader dropped
268                            let (tx, rx) = watch::channel(UnaryState::Starting);
269                            entry.insert(rx);
270                            Ok(tx)
271                        }
272                        UnaryState::Done(val) => return (val, false),
273                    }
274                }
275                EntryRef::Vacant(entry) => {
276                    let (tx, rx) = watch::channel(UnaryState::Starting);
277                    entry.insert(rx);
278                    Ok(tx)
279                }
280            };
281
282            match tx_or_rx {
283                Ok(tx) => {
284                    let fut = UnaryLeader { fut, tx };
285                    let result = fut.await;
286                    self.m.lock().remove(key);
287                    (result, true)
288                }
289                Err(mut rx) => {
290                    let mut state = rx.borrow_and_update().clone();
291                    if matches!(state, UnaryState::Starting) {
292                        let _changed = rx.changed().await;
293                        state = rx.borrow().clone();
294                    }
295                    match state {
296                        UnaryState::Starting => unreachable!(), // unreachable
297                        UnaryState::LeaderDropped => {
298                            self.m.lock().remove(key);
299                            // the leader dropped, so we need to retry
300                            self.work(key, fut).await
301                        }
302                        UnaryState::Done(val) => (val, false),
303                    }
304                }
305            }
306        })
307    }
308}
309
310#[pin_project(PinnedDrop)]
311struct UnaryLeader<T: Clone, F> {
312    #[pin]
313    fut: F,
314    tx: watch::Sender<UnaryState<T>>,
315}
316
317impl<T, F> Future for UnaryLeader<T, F>
318where
319    T: Clone + Send + Sync,
320    F: Future<Output = T>,
321{
322    type Output = T;
323
324    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
325        let this = self.project();
326        let result = this.fut.poll(cx);
327        if let Poll::Ready(val) = &result {
328            let _send = this.tx.send(UnaryState::Done(val.clone()));
329        }
330        result
331    }
332}
333
334#[pinned_drop]
335impl<T, F> PinnedDrop for UnaryLeader<T, F>
336where
337    T: Clone,
338{
339    fn drop(self: Pin<&mut Self>) {
340        let this = self.project();
341        let _ = this.tx.send_if_modified(|s| {
342            if matches!(s, UnaryState::Starting) {
343                *s = UnaryState::LeaderDropped;
344                true
345            } else {
346                false
347            }
348        });
349    }
350}
351
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{Group, UnaryGroup};

    const RES: usize = 7;

    /// Completes immediately with `RES`.
    async fn return_res() -> Result<usize, ()> {
        Ok(RES)
    }

    /// Completes with `RES` after a 500 ms delay, long enough for concurrent
    /// callers to pile up behind the leader.
    async fn expensive_fn() -> Result<usize, ()> {
        tokio::time::sleep(Duration::from_millis(500)).await;
        Ok(RES)
    }

    #[tokio::test]
    async fn test_simple() {
        let g = Group::new();
        let res = g.work("key", return_res()).await.0;
        let r = res.unwrap();
        assert_eq!(r, RES);
    }

    #[tokio::test]
    async fn test_multiple_threads() {
        use std::sync::Arc;

        use futures::future::join_all;

        let g = Arc::new(Group::new());
        let mut handlers = Vec::new();
        for _ in 0..10 {
            let g = g.clone();
            handlers.push(tokio::spawn(async move {
                g.work("key", expensive_fn()).await
            }));
        }

        // Every caller must see the value, and only one of them may have
        // actually executed the expensive future.
        let mut owners = 0;
        for joined in join_all(handlers).await {
            let (val, err, is_owner) = joined.expect("worker task should not panic");
            assert_eq!(val, Some(RES));
            assert_eq!(err, None);
            owners += usize::from(is_owner);
        }
        assert_eq!(owners, 1, "exactly one caller should own the execution");
    }

    #[tokio::test]
    async fn test_drop_leader() {
        let g = Group::new();
        // The timeout fires before `expensive_fn` finishes, dropping the leader.
        tokio::time::timeout(Duration::from_millis(50), g.work("key", expensive_fn()))
            .await
            .expect_err("owner should be running and cancelled");
        // The next caller must take over as leader rather than wait forever.
        assert_eq!(
            tokio::time::timeout(Duration::from_secs(1), g.work("key", expensive_fn())).await,
            Ok((Some(RES), None, true)),
        );
    }

    #[tokio::test]
    async fn test_unary_simple() {
        let g = UnaryGroup::new();
        let (val, is_owner) = g.work("key", async { RES }).await;
        assert_eq!(val, RES);
        assert!(is_owner);
    }
}