nym_task/cancellation/tracker.rs

1// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::cancellation::token::ShutdownToken;
5use crate::spawn::{JoinHandle, spawn_named_future};
6use crate::spawn_future;
7use std::future::Future;
8use thiserror::Error;
9use tokio_util::task::TaskTracker;
10use tracing::{debug, trace};
11
12#[derive(Debug, Error)]
13#[error("task got cancelled")]
14pub struct Cancelled;
15
16/// Extracted [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken](ShutdownToken) to more easily allow tracking nested tasks
17/// without having to pass whole [ShutdownManager](ShutdownManager) around.
18#[derive(Clone, Default, Debug)]
19pub struct ShutdownTracker {
20    /// The root [ShutdownToken](ShutdownToken) that will trigger all derived tasks
21    /// to receive cancellation signal.
22    pub(crate) root_cancellation_token: ShutdownToken,
23
24    // Note: the reason we're not using a `JoinSet` is
25    // because it forces us to use futures with the same `::Output` type,
26    // which is not really a desirable property in this instance.
27    /// Tracker used for keeping track of all registered tasks
28    /// so that they could be stopped gracefully before ending the process.
29    pub(crate) tracker: TaskTracker,
30}
31
32#[cfg(not(target_arch = "wasm32"))]
33impl ShutdownTracker {
34    /// Spawn the provided future on the current Tokio runtime, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
35    #[track_caller]
36    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
37    where
38        F: Future + Send + 'static,
39        F::Output: Send + 'static,
40    {
41        let tracked = self.tracker.track_future(task);
42        spawn_future(tracked)
43    }
44
45    /// Spawn the provided future on the current Tokio runtime,
46    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
47    /// Furthermore, attach a name to the spawned task to more easily track it within a [tokio console](https://github.com/tokio-rs/console)
48    ///
49    /// Note that is no different from [spawn](Self::spawn) if the underlying binary
50    /// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`
51    #[track_caller]
52    pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
53    where
54        F: Future + Send + 'static,
55        F::Output: Send + 'static,
56    {
57        trace!("attempting to spawn task {name}");
58        let tracked = self.tracker.track_future(task);
59        spawn_named_future(tracked, name)
60    }
61
62    /// Spawn the provided future on the provided Tokio runtime,
63    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
64    #[track_caller]
65    pub fn spawn_on<F>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<F::Output>
66    where
67        F: Future + Send + 'static,
68        F::Output: Send + 'static,
69    {
70        self.tracker.spawn_on(task, handle)
71    }
72
73    /// Spawn the provided future on the current [LocalSet](tokio::task::LocalSet),
74    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
75    #[track_caller]
76    pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
77    where
78        F: Future + 'static,
79        F::Output: 'static,
80    {
81        self.tracker.spawn_local(task)
82    }
83
84    /// Spawn the provided blocking task on the current Tokio runtime,
85    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
86    #[track_caller]
87    pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
88    where
89        F: FnOnce() -> T,
90        F: Send + 'static,
91        T: Send + 'static,
92    {
93        self.tracker.spawn_blocking(task)
94    }
95
96    /// Spawn the provided blocking task on the provided Tokio runtime,
97    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
98    #[track_caller]
99    pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<T>
100    where
101        F: FnOnce() -> T,
102        F: Send + 'static,
103        T: Send + 'static,
104    {
105        self.tracker.spawn_blocking_on(task, handle)
106    }
107
108    /// Spawn the provided future on the current Tokio runtime
109    /// that will get cancelled once a global shutdown signal is detected,
110    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
111    ///
112    /// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console),
113    /// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`
114    #[track_caller]
115    pub fn try_spawn_named_with_shutdown<F>(
116        &self,
117        task: F,
118        name: &str,
119    ) -> JoinHandle<Result<F::Output, Cancelled>>
120    where
121        F: Future + Send + 'static,
122        F::Output: Send + 'static,
123    {
124        trace!("attempting to spawn task {name} (with top-level cancellation)");
125
126        let caller = std::panic::Location::caller();
127        let shutdown_token = self.clone_shutdown_token();
128        let name_owned = name.to_string();
129        let tracked = self.tracker.track_future(async move {
130            match shutdown_token.run_until_cancelled_owned(task).await {
131                Some(result) => {
132                    debug!("{name_owned} @ {caller}: task has finished execution");
133                    Ok(result)
134                }
135                None => {
136                    trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
137                    Err(Cancelled)
138                }
139            }
140        });
141        spawn_named_future(tracked, name)
142    }
143
144    /// Spawn the provided future on the current Tokio runtime
145    /// that will get cancelled once a global shutdown signal is detected,
146    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
147    #[track_caller]
148    pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
149    where
150        F: Future + Send + 'static,
151        F::Output: Send + 'static,
152    {
153        let caller = std::panic::Location::caller();
154        let shutdown_token = self.clone_shutdown_token();
155        self.tracker.spawn(async move {
156            match shutdown_token.run_until_cancelled_owned(task).await {
157                Some(result) => {
158                    debug!("{caller}: task has finished execution");
159                    Ok(result)
160                }
161                None => {
162                    trace!("{caller}: shutdown signal received, shutting down");
163                    Err(Cancelled)
164                }
165            }
166        })
167    }
168}
169
#[cfg(target_arch = "wasm32")]
impl ShutdownTracker {
    /// Run the provided future on the current thread, and track it in the underlying
    /// [TaskTracker](tokio_util::task::TaskTracker).
    #[track_caller]
    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        let tracked = self.tracker.track_future(task);
        spawn_future(tracked)
    }

    /// Run the provided future on the current thread, and track it in the underlying
    /// [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// Like [spawn](Self::spawn), but hands `name` to `spawn_named_future`; it mainly exists
    /// to provide the same interface as non-wasm32 targets.
    #[track_caller]
    pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        trace!("attempting to spawn task {name}");
        let tracked = self.tracker.track_future(task);
        spawn_named_future(tracked, name)
    }

    /// Run the provided future on the current thread
    /// that will get cancelled once a global shutdown signal is detected,
    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// The returned handle resolves to `Ok(output)` if `task` finished before the shutdown
    /// token was cancelled, and to `Err(Cancelled)` otherwise.
    ///
    /// Like [spawn_with_shutdown](Self::spawn_with_shutdown), but hands `name` to
    /// `spawn_named_future` and includes it in log messages; it mainly exists to provide
    /// the same interface as non-wasm32 targets.
    #[track_caller]
    pub fn try_spawn_named_with_shutdown<F>(
        &self,
        task: F,
        name: &str,
    ) -> JoinHandle<Result<F::Output, Cancelled>>
    where
        // Any output type is accepted (previously restricted to `()`), matching the
        // non-wasm32 signature; `spawn_named_future` is already used with arbitrary outputs.
        F: Future + 'static,
    {
        trace!("attempting to spawn task {name} (with top-level cancellation)");

        // Captured here (thanks to `#[track_caller]`) so log lines point at the spawn site.
        let caller = std::panic::Location::caller();
        let shutdown_token = self.clone_shutdown_token();
        // The task outlives `name`, so an owned copy is needed for logging inside it.
        let name_owned = name.to_string();
        let tracked = self.tracker.track_future(async move {
            match shutdown_token.run_until_cancelled_owned(task).await {
                Some(result) => {
                    debug!("{name_owned} @ {caller}: task has finished execution");
                    Ok(result)
                }
                None => {
                    trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
                    Err(Cancelled)
                }
            }
        });
        spawn_named_future(tracked, name)
    }

    /// Run the provided future on the current thread
    /// that will get cancelled once a global shutdown signal is detected,
    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// The returned handle resolves to `Ok(output)` if `task` finished before the shutdown
    /// token was cancelled, and to `Err(Cancelled)` otherwise.
    #[track_caller]
    pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
    where
        // Any output type is accepted (previously restricted to `()`), matching the
        // non-wasm32 signature; `spawn_future` is already used with arbitrary outputs.
        F: Future + 'static,
    {
        // Captured here (thanks to `#[track_caller]`) so log lines point at the spawn site.
        let caller = std::panic::Location::caller();
        let shutdown_token = self.clone_shutdown_token();
        let tracked = self.tracker.track_future(async move {
            match shutdown_token.run_until_cancelled_owned(task).await {
                Some(result) => {
                    debug!("{caller}: task has finished execution");
                    Ok(result)
                }
                None => {
                    trace!("{caller}: shutdown signal received, shutting down");
                    Err(Cancelled)
                }
            }
        });
        spawn_future(tracked)
    }
}
250
251impl ShutdownTracker {
252    /// Create new instance of the ShutdownTracker using an external shutdown token.
253    /// This could be useful in situations where shutdown is being managed by an external entity
254    /// that is not [ShutdownManager](ShutdownManager), but interface requires providing a ShutdownTracker,
255    /// such as client-core tasks
256    pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self {
257        ShutdownTracker {
258            root_cancellation_token: shutdown_token,
259            tracker: Default::default(),
260        }
261    }
262
263    /// Waits until the underlying [TaskTracker](tokio_util::task::TaskTracker) is both closed and empty.
264    ///
265    /// If the underlying [TaskTracker](tokio_util::task::TaskTracker) is already closed and empty when this method is called, then it
266    /// returns immediately.
267    pub async fn wait_for_tracker(&self) {
268        self.tracker.wait().await;
269    }
270
271    /// Close the underlying [TaskTracker](tokio_util::task::TaskTracker).
272    ///
273    /// This allows [`wait_for_tracker`] futures to complete. It does not prevent you from spawning new tasks.
274    ///
275    /// Returns `true` if this closed the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already closed.
276    ///
277    /// [`wait_for_tracker`]: Self::wait_for_tracker
278    pub fn close_tracker(&self) -> bool {
279        self.tracker.close()
280    }
281
282    /// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker).
283    ///
284    /// This prevents [`wait_for_tracker`] futures from completing even if the underlying [TaskTracker](tokio_util::task::TaskTracker) is empty.
285    ///
286    /// Returns `true` if this reopened the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already open.
287    ///
288    /// [`wait_for_tracker`]: Self::wait_for_tracker
289    pub fn reopen_tracker(&self) -> bool {
290        self.tracker.reopen()
291    }
292
293    /// Returns `true` if the underlying [TaskTracker](tokio_util::task::TaskTracker) is [closed](Self::close_tracker).
294    pub fn is_tracker_closed(&self) -> bool {
295        self.tracker.is_closed()
296    }
297
298    /// Returns the number of tasks tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker).
299    pub fn tracked_tasks(&self) -> usize {
300        self.tracker.len()
301    }
302
303    /// Returns `true` if there are no tasks in the underlying [TaskTracker](tokio_util::task::TaskTracker).
304    pub fn is_tracker_empty(&self) -> bool {
305        self.tracker.is_empty()
306    }
307
308    /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token
309    pub fn child_shutdown_token(&self) -> ShutdownToken {
310        self.root_cancellation_token.child_token()
311    }
312
313    /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token
314    pub fn clone_shutdown_token(&self) -> ShutdownToken {
315        self.root_cancellation_token.clone()
316    }
317
318    /// Create a child ShutdownTracker that inherits cancellation from this tracker
319    /// but has its own TaskTracker for managing sub-tasks.
320    ///
321    /// This enables hierarchical task management where:
322    /// - Parent cancellation flows to all children
323    /// - Each level tracks its own tasks independently
324    /// - Components can wait for their specific sub-tasks to complete
325    pub fn child_tracker(&self) -> ShutdownTracker {
326        // Child token inherits cancellation from parent
327        let child_token = self.root_cancellation_token.child_token();
328
329        // New TaskTracker for this level's tasks
330        let child_task_tracker = TaskTracker::new();
331
332        ShutdownTracker {
333            root_cancellation_token: child_token,
334            tracker: child_task_tracker,
335        }
336    }
337
338    /// Convenience method to perform a complete shutdown sequence.
339    /// This method:
340    /// 1. Signals cancellation to all tasks
341    /// 2. Closes the tracker to prevent new tasks
342    /// 3. Waits for all existing tasks to complete
343    pub async fn shutdown(self) {
344        // Signal cancellation to all tasks
345        self.root_cancellation_token.cancel();
346
347        // Close the tracker to prevent new tasks from being spawned
348        self.tracker.close();
349
350        // Wait for all existing tasks to complete
351        self.tracker.wait().await;
352    }
353}