// msg_common/task.rs

use futures::{Future, future::poll_fn};
use std::{
    collections::{HashMap, HashSet},
    task::{Context, Poll, ready},
};
use tokio::task::{Id, JoinError, JoinSet};

8/// A collection of keyed tasks spawned on a Tokio runtime.
9/// Hacky implementation of a join set that allows for a key to be associated with each task by
10/// having the task return a tuple of (key, value).
11#[derive(Debug, Default)]
12pub struct JoinMap<K, V> {
13    keys: HashSet<K>,
14    joinset: JoinSet<(K, V)>,
15}
16
17impl<K, V> JoinMap<K, V> {
18    /// Create a new `JoinSet`.
19    pub fn new() -> Self {
20        Self { keys: HashSet::new(), joinset: JoinSet::new() }
21    }
22
23    /// Returns the number of tasks currently in the `JoinSet`.
24    pub fn len(&self) -> usize {
25        self.joinset.len()
26    }
27
28    /// Returns whether the `JoinSet` is empty.
29    pub fn is_empty(&self) -> bool {
30        self.joinset.is_empty()
31    }
32}
33
34impl<K, V> JoinMap<K, V>
35where
36    K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
37    V: 'static,
38{
39    /// Spawns a task onto the Tokio runtime that will execute the given future ONLY IF
40    /// there is not already a task in the set with the same key.
41    pub fn spawn<F>(&mut self, key: K, future: F)
42    where
43        F: Future<Output = (K, V)> + Send + 'static,
44        V: Send,
45    {
46        if self.keys.insert(key) {
47            self.joinset.spawn(future);
48        }
49    }
50
51    /// Returns `true` if the `JoinSet` contains a task for the given key.
52    pub fn contains_key(&self, key: &K) -> bool {
53        self.keys.contains(key)
54    }
55
56    /// Waits until one of the tasks in the set completes and returns its output.
57    ///
58    /// Returns `None` if the set is empty.
59    ///
60    /// # Cancel Safety
61    ///
62    /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
63    /// statement and some other branch completes first, it is guaranteed that no tasks were
64    /// removed from this `JoinSet`.
65    pub async fn join_next(&mut self) -> Option<Result<(K, V), JoinError>> {
66        poll_fn(|cx| self.poll_join_next(cx)).await
67    }
68
69    /// Polls for one of the tasks in the set to complete.
70    ///
71    /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the
72    /// set.
73    ///
74    /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
75    /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
76    /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
77    /// scheduled to receive a wakeup.
78    ///
79    /// # Returns
80    ///
81    /// This function returns:
82    ///
83    ///  * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
84    ///    available right now.
85    ///  * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed. The
86    ///    `value` is the return value of one of the tasks that completed.
87    ///  * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
88    ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
89    ///  * `Poll::Ready(None)` if the `JoinSet` is empty.
90    ///
91    /// Note that this method may return `Poll::Pending` even if one of the tasks has completed.
92    /// This can happen if the [coop budget] is reached.
93    pub fn poll_join_next(
94        &mut self,
95        cx: &mut Context<'_>,
96    ) -> Poll<Option<Result<(K, V), JoinError>>> {
97        match ready!(self.joinset.poll_join_next(cx)) {
98            Some(Ok((key, value))) => {
99                self.keys.remove(&key);
100                Poll::Ready(Some(Ok((key, value))))
101            }
102            Some(Err(err)) => Poll::Ready(Some(Err(err))),
103            None => Poll::Ready(None),
104        }
105    }
106}