msg_common/task.rs
use futures::{Future, future::poll_fn};
use std::{
    collections::{HashMap, HashSet},
    task::{Context, Poll, ready},
};
use tokio::task::{Id, JoinError, JoinSet};
7
8/// A collection of keyed tasks spawned on a Tokio runtime.
9/// Hacky implementation of a join set that allows for a key to be associated with each task by
10/// having the task return a tuple of (key, value).
11#[derive(Debug, Default)]
12pub struct JoinMap<K, V> {
13 keys: HashSet<K>,
14 joinset: JoinSet<(K, V)>,
15}
16
17impl<K, V> JoinMap<K, V> {
18 /// Create a new `JoinSet`.
19 pub fn new() -> Self {
20 Self { keys: HashSet::new(), joinset: JoinSet::new() }
21 }
22
23 /// Returns the number of tasks currently in the `JoinSet`.
24 pub fn len(&self) -> usize {
25 self.joinset.len()
26 }
27
28 /// Returns whether the `JoinSet` is empty.
29 pub fn is_empty(&self) -> bool {
30 self.joinset.is_empty()
31 }
32}
33
34impl<K, V> JoinMap<K, V>
35where
36 K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
37 V: 'static,
38{
39 /// Spawns a task onto the Tokio runtime that will execute the given future ONLY IF
40 /// there is not already a task in the set with the same key.
41 pub fn spawn<F>(&mut self, key: K, future: F)
42 where
43 F: Future<Output = (K, V)> + Send + 'static,
44 V: Send,
45 {
46 if self.keys.insert(key) {
47 self.joinset.spawn(future);
48 }
49 }
50
51 /// Returns `true` if the `JoinSet` contains a task for the given key.
52 pub fn contains_key(&self, key: &K) -> bool {
53 self.keys.contains(key)
54 }
55
56 /// Waits until one of the tasks in the set completes and returns its output.
57 ///
58 /// Returns `None` if the set is empty.
59 ///
60 /// # Cancel Safety
61 ///
62 /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
63 /// statement and some other branch completes first, it is guaranteed that no tasks were
64 /// removed from this `JoinSet`.
65 pub async fn join_next(&mut self) -> Option<Result<(K, V), JoinError>> {
66 poll_fn(|cx| self.poll_join_next(cx)).await
67 }
68
69 /// Polls for one of the tasks in the set to complete.
70 ///
71 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the
72 /// set.
73 ///
74 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
75 /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
76 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
77 /// scheduled to receive a wakeup.
78 ///
79 /// # Returns
80 ///
81 /// This function returns:
82 ///
83 /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
84 /// available right now.
85 /// * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed. The
86 /// `value` is the return value of one of the tasks that completed.
87 /// * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
88 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
89 /// * `Poll::Ready(None)` if the `JoinSet` is empty.
90 ///
91 /// Note that this method may return `Poll::Pending` even if one of the tasks has completed.
92 /// This can happen if the [coop budget] is reached.
93 pub fn poll_join_next(
94 &mut self,
95 cx: &mut Context<'_>,
96 ) -> Poll<Option<Result<(K, V), JoinError>>> {
97 match ready!(self.joinset.poll_join_next(cx)) {
98 Some(Ok((key, value))) => {
99 self.keys.remove(&key);
100 Poll::Ready(Some(Ok((key, value))))
101 }
102 Some(Err(err)) => Poll::Ready(Some(Err(err))),
103 None => Poll::Ready(None),
104 }
105 }
106}