async_task_tracker/
lib.rs

1#![allow(unknown_lints, unexpected_cfgs)]
2#![allow(clippy::needless_doctest_main)]
3#![warn(
4    missing_debug_implementations,
5    missing_docs,
6    rust_2018_idioms,
7    unreachable_pub
8)]
9#![doc(test(
10    no_crate_inject,
11    attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
12))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14#![cfg_attr(not(feature = "std"), no_std)]
15
//! A modified version of the [`TaskTracker`] from the `tokio-util` crate, without the tokio-specific features.
17
18extern crate alloc;
19
20use core::{
21    fmt,
22    future::Future,
23    pin::Pin,
24    sync::atomic::{AtomicUsize, Ordering},
25    task::{Context, Poll},
26};
27
28use alloc::sync::Arc;
29
30use event_listener::{Event, EventListener};
31use pin_project_lite::pin_project;
32
33/// A task tracker used for waiting until tasks exit.
34///
35/// This is usually used together with a cancellation token to implement [graceful shutdown]. The
36/// cancellation token is used to signal to tasks that they should shut down, and the
37/// `TaskTracker` is used to wait for them to finish shutting down. For tokio runtime, there is a
38/// [`CancellationToken`] in the `tokio-util` crate that can be used for this purpose. Otherwise,
39/// consider using an mpsc channel as a cancellation token.
40///
41/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case
42/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the
43/// [`wait`] method will wait until *both* of the following happen at the same time:
44///
45///  * The `TaskTracker` must be closed using the [`close`] method.
46///  * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited.
47///
48/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
49/// the destructor of the future has finished running.
50///
51/// # Examples
52///
53/// ## Spawn tasks and wait for them to exit
54///
55/// ```
56/// use async_task_tracker::TaskTracker;
57///
58/// async fn run() {
59///     let tracker = TaskTracker::new();
60///
61///     for i in 0..10 {
62///         my_runtime_spawn(tracker.track_future(async move {
63///             println!("Task {} is running!", i);
64///         }));
65///     }
66///     // Once we spawned everything, we close the tracker.
67///     tracker.close();
68///     // Wait for everything to finish.
69///     tracker.wait().await;
70///     println!("This is printed after all of the tasks.");
71/// }
72/// fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
73/// ```
74///
75/// [`CancellationToken`]: https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html
76/// [`close`]: Self::close
77/// [`wait`]: Self::wait
78/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown
79pub struct TaskTracker {
80    inner: Arc<TaskTrackerInner>,
81}
82
83/// Represents a task tracked by a [`TaskTracker`].
84#[must_use]
85#[derive(Debug)]
86pub struct TaskTrackerToken {
87    task_tracker: TaskTracker,
88}
89
90struct TaskTrackerInner {
91    /// Keeps track of the state.
92    ///
93    /// The lowest bit is whether the task tracker is closed.
94    ///
95    /// The rest of the bits count the number of tracked tasks.
96    state: AtomicUsize,
97    /// Used to notify when the last task exits.
98    on_last_exit: Event,
99}
100
101pin_project! {
102    /// A future that is tracked as a task by a [`TaskTracker`].
103    ///
104    /// The associated [`TaskTracker`] cannot complete until this future is dropped.
105    ///
106    /// This future is returned by [`TaskTracker::track_future`].
107    #[must_use = "futures do nothing unless polled"]
108    pub struct TrackedFuture<F> {
109        #[pin]
110        future: F,
111        token: TaskTrackerToken,
112    }
113}
114
115pin_project! {
116    /// A future that completes when the [`TaskTracker`] is empty and closed.
117    ///
118    /// This future is returned by [`TaskTracker::wait`].
119    #[must_use = "futures do nothing unless polled"]
120    pub struct TaskTrackerWaitFuture<'a> {
121        #[pin]
122        future: EventListener,
123        inner: Option<&'a TaskTrackerInner>,
124    }
125}
126
127impl TaskTrackerInner {
128    fn new() -> Self {
129        Self {
130            state: AtomicUsize::new(0),
131            on_last_exit: Event::new(),
132        }
133    }
134
135    fn is_closed_and_empty(&self) -> bool {
136        // If empty and closed bit set, then we are done.
137        //
138        // The acquire load will synchronize with the release store of any previous call to
139        // `set_closed` and `drop_task`.
140        self.state.load(Ordering::Acquire) == 1
141    }
142
143    fn set_closed(&self) -> bool {
144        // The AcqRel ordering makes the closed bit behave like a `Mutex<bool>` for synchronization
145        // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}`
146        // more meaningful for the user. Without these orderings, this assert could fail:
147        // ```
148        // // thread 1
149        // some_other_atomic.store(true, Relaxed);
150        // tracker.close();
151        //
152        // // thread 2
153        // if tracker.reopen() {
154        //     assert!(some_other_atomic.load(Relaxed));
155        // }
156        // ```
157        // However, with the AcqRel ordering, we establish a happens-before relationship from the
158        // call to `close` and the later call to `reopen` that returned true.
159        let state = self.state.fetch_or(1, Ordering::AcqRel);
160
161        // If there are no tasks, and if it was not already closed:
162        if state == 0 {
163            self.notify_now();
164        }
165
166        (state & 1) == 0
167    }
168
169    fn set_open(&self) -> bool {
170        // See `set_closed` regarding the AcqRel ordering.
171        let state = self.state.fetch_and(!1, Ordering::AcqRel);
172        (state & 1) == 1
173    }
174
175    fn add_task(&self) {
176        self.state.fetch_add(2, Ordering::Relaxed);
177    }
178
179    fn drop_task(&self) {
180        let state = self.state.fetch_sub(2, Ordering::Release);
181
182        // If this was the last task and we are closed:
183        if state == 3 {
184            self.notify_now();
185        }
186    }
187
188    #[cold]
189    fn notify_now(&self) {
190        // Insert an acquire fence. This matters for `drop_task` but doesn't matter for
191        // `set_closed` since it already uses AcqRel.
192        //
193        // This synchronizes with the release store of any other call to `drop_task`, and with the
194        // release store in the call to `set_closed`. That ensures that everything that happened
195        // before those other calls to `drop_task` or `set_closed` will be visible after this load,
196        // and those things will also be visible to anything woken by the call to `notify_waiters`.
197        self.state.load(Ordering::Acquire);
198
199        self.on_last_exit.notify(usize::MAX);
200    }
201}
202
203impl TaskTracker {
204    /// Creates a new `TaskTracker`.
205    ///
206    /// The `TaskTracker` will start out as open.
207    #[must_use]
208    pub fn new() -> Self {
209        Self {
210            inner: Arc::new(TaskTrackerInner::new()),
211        }
212    }
213
214    /// Waits until this `TaskTracker` is both closed and empty.
215    ///
216    /// If the `TaskTracker` is already closed and empty when this method is called, then it
217    /// returns immediately.
218    ///
219    /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker`
220    /// becomes both closed and empty for a short amount of time, then it is guarantee that all
221    /// `wait` futures that were created before the short time interval will trigger, even if they
222    /// are not polled during that short time interval.
223    ///
224    /// # Cancel safety
225    ///
226    /// This method is cancel safe.
227    ///
228    /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the
229    /// condition in a `tokio::select!` loop.
230    ///
231    /// [aba]: https://en.wikipedia.org/wiki/ABA_problem
232    #[inline]
233    pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
234        TaskTrackerWaitFuture {
235            future: self.inner.on_last_exit.listen(),
236            inner: if self.inner.is_closed_and_empty() {
237                None
238            } else {
239                Some(&self.inner)
240            },
241        }
242    }
243
244    /// Close this `TaskTracker`.
245    ///
246    /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks.
247    ///
248    /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed.
249    ///
250    /// [`wait`]: Self::wait
251    #[inline]
252    pub fn close(&self) -> bool {
253        self.inner.set_closed()
254    }
255
256    /// Reopen this `TaskTracker`.
257    ///
258    /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty.
259    ///
260    /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open.
261    ///
262    /// [`wait`]: Self::wait
263    #[inline]
264    pub fn reopen(&self) -> bool {
265        self.inner.set_open()
266    }
267
268    /// Returns `true` if this `TaskTracker` is [closed](Self::close).
269    #[inline]
270    #[must_use]
271    pub fn is_closed(&self) -> bool {
272        (self.inner.state.load(Ordering::Acquire) & 1) != 0
273    }
274
275    /// Returns the number of tasks tracked by this `TaskTracker`.
276    #[inline]
277    #[must_use]
278    pub fn len(&self) -> usize {
279        self.inner.state.load(Ordering::Acquire) >> 1
280    }
281
282    /// Returns `true` if there are no tasks in this `TaskTracker`.
283    #[inline]
284    #[must_use]
285    pub fn is_empty(&self) -> bool {
286        self.inner.state.load(Ordering::Acquire) <= 1
287    }
288
289    /// Track the provided future.
290    ///
291    /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will
292    /// prevent calls to [`wait`] from returning until the task is dropped.
293    ///
294    /// The task is removed from the collection when it is dropped, not when [`poll`] returns
295    /// [`Poll::Ready`].
296    ///
297    /// # Examples
298    ///
299    /// Track a spawned future.
300    ///
301    /// ```
302    /// # async fn my_async_fn() {}
303    /// use async_task_tracker::TaskTracker;
304    ///
305    /// # async fn run() {
306    /// let tracker = TaskTracker::new();
307    ///
308    /// my_runtime_spawn(tracker.track_future(my_async_fn()));
309    /// # fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
310    /// # }
311    /// ```
312    ///
313    /// [`Poll::Pending`]: std::task::Poll::Pending
314    /// [`poll`]: core::future::Future::poll
315    /// [`wait`]: Self::wait
316    #[inline]
317    pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
318        TrackedFuture {
319            future,
320            token: self.token(),
321        }
322    }
323
324    /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
325    ///
326    /// This token is a lower-level utility than the spawn methods. Each token is considered to
327    /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
328    /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
329    ///
330    /// Dropping the token indicates to the `TaskTracker` that the task has exited.
331    ///
332    /// [`len`]: TaskTracker::len
333    #[inline]
334    pub fn token(&self) -> TaskTrackerToken {
335        self.inner.add_task();
336        TaskTrackerToken {
337            task_tracker: self.clone(),
338        }
339    }
340
341    /// Returns `true` if both task trackers correspond to the same set of tasks.
342    ///
343    /// # Examples
344    ///
345    /// ```
346    /// use async_task_tracker::TaskTracker;
347    ///
348    /// let tracker_1 = TaskTracker::new();
349    /// let tracker_2 = TaskTracker::new();
350    /// let tracker_1_clone = tracker_1.clone();
351    ///
352    /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
353    /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
354    /// ```
355    #[inline]
356    #[must_use]
357    pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
358        Arc::ptr_eq(&left.inner, &right.inner)
359    }
360}
361
362impl Default for TaskTracker {
363    /// Creates a new `TaskTracker`.
364    ///
365    /// The `TaskTracker` will start out as open.
366    #[inline]
367    fn default() -> TaskTracker {
368        TaskTracker::new()
369    }
370}
371
372impl Clone for TaskTracker {
373    /// Returns a new `TaskTracker` that tracks the same set of tasks.
374    ///
375    /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
376    /// all other clones.
377    ///
378    /// # Examples
379    ///
380    /// ```
381    /// use async_task_tracker::TaskTracker;
382    ///
383    /// async fn run() {
384    ///     let tracker = TaskTracker::new();
385    ///     let cloned = tracker.clone();
386    ///
387    ///     // Spawns on `tracker` are visible in `cloned`.
388    ///     my_runtime_spawn(tracker.track_future(std::future::pending::<()>()));
389    ///     assert_eq!(cloned.len(), 1);
390    ///
391    ///     // Spawns on `cloned` are visible in `tracker`.
392    ///     my_runtime_spawn(tracker.track_future(std::future::pending::<()>()));
393    ///     assert_eq!(tracker.len(), 2);
394    ///
395    ///     // Calling `close` is visible to `cloned`.
396    ///     tracker.close();
397    ///     assert!(cloned.is_closed());
398    ///
399    ///     // Calling `reopen` is visible to `tracker`.
400    ///     cloned.reopen();
401    ///     assert!(!tracker.is_closed());
402    /// }
403    /// fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
404    /// ```
405    #[inline]
406    fn clone(&self) -> TaskTracker {
407        Self {
408            inner: self.inner.clone(),
409        }
410    }
411}
412
413fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414    let state = inner.state.load(Ordering::Acquire);
415    let is_closed = (state & 1) != 0;
416    let len = state >> 1;
417
418    f.debug_struct("TaskTracker")
419        .field("len", &len)
420        .field("is_closed", &is_closed)
421        .field("inner", &(inner as *const TaskTrackerInner))
422        .finish()
423}
424
425impl fmt::Debug for TaskTracker {
426    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427        debug_inner(&self.inner, f)
428    }
429}
430
431impl TaskTrackerToken {
432    /// Returns the [`TaskTracker`] that this token is associated with.
433    #[inline]
434    #[must_use]
435    pub fn task_tracker(&self) -> &TaskTracker {
436        &self.task_tracker
437    }
438}
439
440impl Clone for TaskTrackerToken {
441    /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
442    ///
443    /// This is equivalent to `token.task_tracker().token()`.
444    #[inline]
445    fn clone(&self) -> TaskTrackerToken {
446        self.task_tracker.token()
447    }
448}
449
450impl Drop for TaskTrackerToken {
451    /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
452    #[inline]
453    fn drop(&mut self) {
454        self.task_tracker.inner.drop_task();
455    }
456}
457
458impl<F: Future> Future for TrackedFuture<F> {
459    type Output = F::Output;
460
461    #[inline]
462    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
463        self.project().future.poll(cx)
464    }
465}
466
467impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
468    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469        f.debug_struct("TrackedFuture")
470            .field("future", &self.future)
471            .field("task_tracker", self.token.task_tracker())
472            .finish()
473    }
474}
475
476impl Future for TaskTrackerWaitFuture<'_> {
477    type Output = ();
478
479    #[inline]
480    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
481        let me = self.project();
482
483        let inner = match me.inner.as_ref() {
484            None => return Poll::Ready(()),
485            Some(inner) => inner,
486        };
487
488        let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
489        if ready {
490            *me.inner = None;
491            Poll::Ready(())
492        } else {
493            Poll::Pending
494        }
495    }
496}
497
498impl fmt::Debug for TaskTrackerWaitFuture<'_> {
499    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
500        struct Helper<'a>(&'a TaskTrackerInner);
501
502        impl fmt::Debug for Helper<'_> {
503            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504                debug_inner(self.0, f)
505            }
506        }
507
508        f.debug_struct("TaskTrackerWaitFuture")
509            .field("future", &self.future)
510            .field("task_tracker", &self.inner.map(Helper))
511            .finish()
512    }
513}