async_task_tracker/lib.rs
1#![allow(unknown_lints, unexpected_cfgs)]
2#![allow(clippy::needless_doctest_main)]
3#![warn(
4 missing_debug_implementations,
5 missing_docs,
6 rust_2018_idioms,
7 unreachable_pub
8)]
9#![doc(test(
10 no_crate_inject,
11 attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
12))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14#![cfg_attr(not(feature = "std"), no_std)]
15
16//! Modified version of [`TaskTracker`] from the `tokio-util` crate without tokio-specific features.
17
18extern crate alloc;
19
20use core::{
21 fmt,
22 future::Future,
23 pin::Pin,
24 sync::atomic::{AtomicUsize, Ordering},
25 task::{Context, Poll},
26};
27
28use alloc::sync::Arc;
29
30use event_listener::{Event, EventListener};
31use pin_project_lite::pin_project;
32
/// A task tracker used for waiting until tasks exit.
///
/// This is usually used together with a cancellation token to implement [graceful shutdown]. The
/// cancellation token is used to signal to tasks that they should shut down, and the
/// `TaskTracker` is used to wait for them to finish shutting down. For tokio runtime, there is a
/// [`CancellationToken`] in the `tokio-util` crate that can be used for this purpose. Otherwise,
/// consider using an mpsc channel as a cancellation token.
///
/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case
/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the
/// [`wait`] method will wait until *both* of the following happen at the same time:
///
///  * The `TaskTracker` must be closed using the [`close`] method.
///  * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited.
///
/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
/// the destructor of the future has finished running.
///
/// Cloning a `TaskTracker` yields another handle to the same set of tasks.
///
/// # Examples
///
/// ## Spawn tasks and wait for them to exit
///
/// ```
/// use async_task_tracker::TaskTracker;
///
/// async fn run() {
///     let tracker = TaskTracker::new();
///
///     for i in 0..10 {
///         my_runtime_spawn(tracker.track_future(async move {
///             println!("Task {} is running!", i);
///         }));
///     }
///     // Once we spawned everything, we close the tracker.
///     tracker.close();
///     // Wait for everything to finish.
///     tracker.wait().await;
///     println!("This is printed after all of the tasks.");
/// }
/// fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
/// ```
///
/// [`CancellationToken`]: https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html
/// [`close`]: Self::close
/// [`wait`]: Self::wait
/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown
pub struct TaskTracker {
    // Shared state (closed flag, task count and the exit event); every clone of the tracker
    // and every token points at the same allocation.
    inner: Arc<TaskTrackerInner>,
}
82
/// Represents a task tracked by a [`TaskTracker`].
///
/// A token is obtained from [`TaskTracker::token`], which counts it as one tracked task.
/// Dropping the token removes that task from the count again.
#[must_use]
#[derive(Debug)]
pub struct TaskTrackerToken {
    // Handle to the tracker whose task count this token contributes to.
    task_tracker: TaskTracker,
}
89
/// State shared between all clones of a [`TaskTracker`] and the tokens created from it.
struct TaskTrackerInner {
    /// Keeps track of the state.
    ///
    /// The lowest bit is whether the task tracker is closed.
    ///
    /// The rest of the bits count the number of tracked tasks, so each tracked task
    /// contributes `2` to the stored value.
    state: AtomicUsize,
    /// Used to notify when the last task exits.
    ///
    /// It is also notified when the tracker is closed while it has no tasks.
    on_last_exit: Event,
}
100
pin_project! {
    /// A future that is tracked as a task by a [`TaskTracker`].
    ///
    /// The associated [`TaskTracker`] cannot complete until this future is dropped.
    ///
    /// This future is returned by [`TaskTracker::track_future`].
    #[must_use = "futures do nothing unless polled"]
    pub struct TrackedFuture<F> {
        // The wrapped future; polling the `TrackedFuture` polls this future.
        #[pin]
        future: F,
        // Counts as one tracked task for as long as it lives; it is released when the
        // `TrackedFuture` is dropped, not when the wrapped future completes.
        token: TaskTrackerToken,
    }
}
114
pin_project! {
    /// A future that completes when the [`TaskTracker`] is empty and closed.
    ///
    /// This future is returned by [`TaskTracker::wait`].
    #[must_use = "futures do nothing unless polled"]
    pub struct TaskTrackerWaitFuture<'a> {
        // Listener on the tracker's exit event, created when `wait` is called.
        #[pin]
        future: EventListener,
        // `None` once completion has been observed (either already in `wait` or during a
        // poll); in that state polling returns `Poll::Ready` without touching `future`.
        inner: Option<&'a TaskTrackerInner>,
    }
}
126
127impl TaskTrackerInner {
128 fn new() -> Self {
129 Self {
130 state: AtomicUsize::new(0),
131 on_last_exit: Event::new(),
132 }
133 }
134
135 fn is_closed_and_empty(&self) -> bool {
136 // If empty and closed bit set, then we are done.
137 //
138 // The acquire load will synchronize with the release store of any previous call to
139 // `set_closed` and `drop_task`.
140 self.state.load(Ordering::Acquire) == 1
141 }
142
143 fn set_closed(&self) -> bool {
144 // The AcqRel ordering makes the closed bit behave like a `Mutex<bool>` for synchronization
145 // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}`
146 // more meaningful for the user. Without these orderings, this assert could fail:
147 // ```
148 // // thread 1
149 // some_other_atomic.store(true, Relaxed);
150 // tracker.close();
151 //
152 // // thread 2
153 // if tracker.reopen() {
154 // assert!(some_other_atomic.load(Relaxed));
155 // }
156 // ```
157 // However, with the AcqRel ordering, we establish a happens-before relationship from the
158 // call to `close` and the later call to `reopen` that returned true.
159 let state = self.state.fetch_or(1, Ordering::AcqRel);
160
161 // If there are no tasks, and if it was not already closed:
162 if state == 0 {
163 self.notify_now();
164 }
165
166 (state & 1) == 0
167 }
168
169 fn set_open(&self) -> bool {
170 // See `set_closed` regarding the AcqRel ordering.
171 let state = self.state.fetch_and(!1, Ordering::AcqRel);
172 (state & 1) == 1
173 }
174
175 fn add_task(&self) {
176 self.state.fetch_add(2, Ordering::Relaxed);
177 }
178
179 fn drop_task(&self) {
180 let state = self.state.fetch_sub(2, Ordering::Release);
181
182 // If this was the last task and we are closed:
183 if state == 3 {
184 self.notify_now();
185 }
186 }
187
188 #[cold]
189 fn notify_now(&self) {
190 // Insert an acquire fence. This matters for `drop_task` but doesn't matter for
191 // `set_closed` since it already uses AcqRel.
192 //
193 // This synchronizes with the release store of any other call to `drop_task`, and with the
194 // release store in the call to `set_closed`. That ensures that everything that happened
195 // before those other calls to `drop_task` or `set_closed` will be visible after this load,
196 // and those things will also be visible to anything woken by the call to `notify_waiters`.
197 self.state.load(Ordering::Acquire);
198
199 self.on_last_exit.notify(usize::MAX);
200 }
201}
202
203impl TaskTracker {
204 /// Creates a new `TaskTracker`.
205 ///
206 /// The `TaskTracker` will start out as open.
207 #[must_use]
208 pub fn new() -> Self {
209 Self {
210 inner: Arc::new(TaskTrackerInner::new()),
211 }
212 }
213
214 /// Waits until this `TaskTracker` is both closed and empty.
215 ///
216 /// If the `TaskTracker` is already closed and empty when this method is called, then it
217 /// returns immediately.
218 ///
219 /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker`
220 /// becomes both closed and empty for a short amount of time, then it is guarantee that all
221 /// `wait` futures that were created before the short time interval will trigger, even if they
222 /// are not polled during that short time interval.
223 ///
224 /// # Cancel safety
225 ///
226 /// This method is cancel safe.
227 ///
228 /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the
229 /// condition in a `tokio::select!` loop.
230 ///
231 /// [aba]: https://en.wikipedia.org/wiki/ABA_problem
232 #[inline]
233 pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
234 TaskTrackerWaitFuture {
235 future: self.inner.on_last_exit.listen(),
236 inner: if self.inner.is_closed_and_empty() {
237 None
238 } else {
239 Some(&self.inner)
240 },
241 }
242 }
243
244 /// Close this `TaskTracker`.
245 ///
246 /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks.
247 ///
248 /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed.
249 ///
250 /// [`wait`]: Self::wait
251 #[inline]
252 pub fn close(&self) -> bool {
253 self.inner.set_closed()
254 }
255
256 /// Reopen this `TaskTracker`.
257 ///
258 /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty.
259 ///
260 /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open.
261 ///
262 /// [`wait`]: Self::wait
263 #[inline]
264 pub fn reopen(&self) -> bool {
265 self.inner.set_open()
266 }
267
268 /// Returns `true` if this `TaskTracker` is [closed](Self::close).
269 #[inline]
270 #[must_use]
271 pub fn is_closed(&self) -> bool {
272 (self.inner.state.load(Ordering::Acquire) & 1) != 0
273 }
274
275 /// Returns the number of tasks tracked by this `TaskTracker`.
276 #[inline]
277 #[must_use]
278 pub fn len(&self) -> usize {
279 self.inner.state.load(Ordering::Acquire) >> 1
280 }
281
282 /// Returns `true` if there are no tasks in this `TaskTracker`.
283 #[inline]
284 #[must_use]
285 pub fn is_empty(&self) -> bool {
286 self.inner.state.load(Ordering::Acquire) <= 1
287 }
288
289 /// Track the provided future.
290 ///
291 /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will
292 /// prevent calls to [`wait`] from returning until the task is dropped.
293 ///
294 /// The task is removed from the collection when it is dropped, not when [`poll`] returns
295 /// [`Poll::Ready`].
296 ///
297 /// # Examples
298 ///
299 /// Track a spawned future.
300 ///
301 /// ```
302 /// # async fn my_async_fn() {}
303 /// use async_task_tracker::TaskTracker;
304 ///
305 /// # async fn run() {
306 /// let tracker = TaskTracker::new();
307 ///
308 /// my_runtime_spawn(tracker.track_future(my_async_fn()));
309 /// # fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
310 /// # }
311 /// ```
312 ///
313 /// [`Poll::Pending`]: std::task::Poll::Pending
314 /// [`poll`]: core::future::Future::poll
315 /// [`wait`]: Self::wait
316 #[inline]
317 pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
318 TrackedFuture {
319 future,
320 token: self.token(),
321 }
322 }
323
324 /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
325 ///
326 /// This token is a lower-level utility than the spawn methods. Each token is considered to
327 /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
328 /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
329 ///
330 /// Dropping the token indicates to the `TaskTracker` that the task has exited.
331 ///
332 /// [`len`]: TaskTracker::len
333 #[inline]
334 pub fn token(&self) -> TaskTrackerToken {
335 self.inner.add_task();
336 TaskTrackerToken {
337 task_tracker: self.clone(),
338 }
339 }
340
341 /// Returns `true` if both task trackers correspond to the same set of tasks.
342 ///
343 /// # Examples
344 ///
345 /// ```
346 /// use async_task_tracker::TaskTracker;
347 ///
348 /// let tracker_1 = TaskTracker::new();
349 /// let tracker_2 = TaskTracker::new();
350 /// let tracker_1_clone = tracker_1.clone();
351 ///
352 /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
353 /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
354 /// ```
355 #[inline]
356 #[must_use]
357 pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
358 Arc::ptr_eq(&left.inner, &right.inner)
359 }
360}
361
362impl Default for TaskTracker {
363 /// Creates a new `TaskTracker`.
364 ///
365 /// The `TaskTracker` will start out as open.
366 #[inline]
367 fn default() -> TaskTracker {
368 TaskTracker::new()
369 }
370}
371
372impl Clone for TaskTracker {
373 /// Returns a new `TaskTracker` that tracks the same set of tasks.
374 ///
375 /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
376 /// all other clones.
377 ///
378 /// # Examples
379 ///
380 /// ```
381 /// use async_task_tracker::TaskTracker;
382 ///
383 /// async fn run() {
384 /// let tracker = TaskTracker::new();
385 /// let cloned = tracker.clone();
386 ///
387 /// // Spawns on `tracker` are visible in `cloned`.
388 /// my_runtime_spawn(tracker.track_future(std::future::pending::<()>()));
389 /// assert_eq!(cloned.len(), 1);
390 ///
391 /// // Spawns on `cloned` are visible in `tracker`.
392 /// my_runtime_spawn(tracker.track_future(std::future::pending::<()>()));
393 /// assert_eq!(tracker.len(), 2);
394 ///
395 /// // Calling `close` is visible to `cloned`.
396 /// tracker.close();
397 /// assert!(cloned.is_closed());
398 ///
399 /// // Calling `reopen` is visible to `tracker`.
400 /// cloned.reopen();
401 /// assert!(!tracker.is_closed());
402 /// }
403 /// fn my_runtime_spawn(_fut: impl std::future::Future<Output = ()> + 'static) {}
404 /// ```
405 #[inline]
406 fn clone(&self) -> TaskTracker {
407 Self {
408 inner: self.inner.clone(),
409 }
410 }
411}
412
413fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414 let state = inner.state.load(Ordering::Acquire);
415 let is_closed = (state & 1) != 0;
416 let len = state >> 1;
417
418 f.debug_struct("TaskTracker")
419 .field("len", &len)
420 .field("is_closed", &is_closed)
421 .field("inner", &(inner as *const TaskTrackerInner))
422 .finish()
423}
424
425impl fmt::Debug for TaskTracker {
426 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427 debug_inner(&self.inner, f)
428 }
429}
430
impl TaskTrackerToken {
    /// Returns the [`TaskTracker`] that this token is associated with.
    ///
    /// The returned reference borrows from the token; clone it to obtain an independent
    /// handle to the same tracker.
    #[inline]
    #[must_use]
    pub fn task_tracker(&self) -> &TaskTracker {
        &self.task_tracker
    }
}
439
440impl Clone for TaskTrackerToken {
441 /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
442 ///
443 /// This is equivalent to `token.task_tracker().token()`.
444 #[inline]
445 fn clone(&self) -> TaskTrackerToken {
446 self.task_tracker.token()
447 }
448}
449
450impl Drop for TaskTrackerToken {
451 /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
452 #[inline]
453 fn drop(&mut self) {
454 self.task_tracker.inner.drop_task();
455 }
456}
457
458impl<F: Future> Future for TrackedFuture<F> {
459 type Output = F::Output;
460
461 #[inline]
462 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
463 self.project().future.poll(cx)
464 }
465}
466
467impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
468 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469 f.debug_struct("TrackedFuture")
470 .field("future", &self.future)
471 .field("task_tracker", self.token.task_tracker())
472 .finish()
473 }
474}
475
476impl Future for TaskTrackerWaitFuture<'_> {
477 type Output = ();
478
479 #[inline]
480 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
481 let me = self.project();
482
483 let inner = match me.inner.as_ref() {
484 None => return Poll::Ready(()),
485 Some(inner) => inner,
486 };
487
488 let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
489 if ready {
490 *me.inner = None;
491 Poll::Ready(())
492 } else {
493 Poll::Pending
494 }
495 }
496}
497
498impl fmt::Debug for TaskTrackerWaitFuture<'_> {
499 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
500 struct Helper<'a>(&'a TaskTrackerInner);
501
502 impl fmt::Debug for Helper<'_> {
503 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504 debug_inner(self.0, f)
505 }
506 }
507
508 f.debug_struct("TaskTrackerWaitFuture")
509 .field("future", &self.future)
510 .field("task_tracker", &self.inner.map(Helper))
511 .finish()
512 }
513}