nym_task/cancellation/tracker.rs
1// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::cancellation::token::ShutdownToken;
5use crate::spawn::{JoinHandle, spawn_named_future};
6use crate::spawn_future;
7use std::future::Future;
8use thiserror::Error;
9use tokio_util::task::TaskTracker;
10use tracing::{debug, trace};
11
12#[derive(Debug, Error)]
13#[error("task got cancelled")]
14pub struct Cancelled;
15
16/// Extracted [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken](ShutdownToken) to more easily allow tracking nested tasks
17/// without having to pass whole [ShutdownManager](ShutdownManager) around.
18#[derive(Clone, Default, Debug)]
19pub struct ShutdownTracker {
20 /// The root [ShutdownToken](ShutdownToken) that will trigger all derived tasks
21 /// to receive cancellation signal.
22 pub(crate) root_cancellation_token: ShutdownToken,
23
24 // Note: the reason we're not using a `JoinSet` is
25 // because it forces us to use futures with the same `::Output` type,
26 // which is not really a desirable property in this instance.
27 /// Tracker used for keeping track of all registered tasks
28 /// so that they could be stopped gracefully before ending the process.
29 pub(crate) tracker: TaskTracker,
30}
31
32#[cfg(not(target_arch = "wasm32"))]
33impl ShutdownTracker {
34 /// Spawn the provided future on the current Tokio runtime, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
35 #[track_caller]
36 pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
37 where
38 F: Future + Send + 'static,
39 F::Output: Send + 'static,
40 {
41 let tracked = self.tracker.track_future(task);
42 spawn_future(tracked)
43 }
44
45 /// Spawn the provided future on the current Tokio runtime,
46 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
47 /// Furthermore, attach a name to the spawned task to more easily track it within a [tokio console](https://github.com/tokio-rs/console)
48 ///
49 /// Note that is no different from [spawn](Self::spawn) if the underlying binary
50 /// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`
51 #[track_caller]
52 pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
53 where
54 F: Future + Send + 'static,
55 F::Output: Send + 'static,
56 {
57 trace!("attempting to spawn task {name}");
58 let tracked = self.tracker.track_future(task);
59 spawn_named_future(tracked, name)
60 }
61
62 /// Spawn the provided future on the provided Tokio runtime,
63 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
64 #[track_caller]
65 pub fn spawn_on<F>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<F::Output>
66 where
67 F: Future + Send + 'static,
68 F::Output: Send + 'static,
69 {
70 self.tracker.spawn_on(task, handle)
71 }
72
73 /// Spawn the provided future on the current [LocalSet](tokio::task::LocalSet),
74 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
75 #[track_caller]
76 pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
77 where
78 F: Future + 'static,
79 F::Output: 'static,
80 {
81 self.tracker.spawn_local(task)
82 }
83
84 /// Spawn the provided blocking task on the current Tokio runtime,
85 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
86 #[track_caller]
87 pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
88 where
89 F: FnOnce() -> T,
90 F: Send + 'static,
91 T: Send + 'static,
92 {
93 self.tracker.spawn_blocking(task)
94 }
95
96 /// Spawn the provided blocking task on the provided Tokio runtime,
97 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
98 #[track_caller]
99 pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<T>
100 where
101 F: FnOnce() -> T,
102 F: Send + 'static,
103 T: Send + 'static,
104 {
105 self.tracker.spawn_blocking_on(task, handle)
106 }
107
108 /// Spawn the provided future on the current Tokio runtime
109 /// that will get cancelled once a global shutdown signal is detected,
110 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
111 ///
112 /// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console),
113 /// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`
114 #[track_caller]
115 pub fn try_spawn_named_with_shutdown<F>(
116 &self,
117 task: F,
118 name: &str,
119 ) -> JoinHandle<Result<F::Output, Cancelled>>
120 where
121 F: Future + Send + 'static,
122 F::Output: Send + 'static,
123 {
124 trace!("attempting to spawn task {name} (with top-level cancellation)");
125
126 let caller = std::panic::Location::caller();
127 let shutdown_token = self.clone_shutdown_token();
128 let name_owned = name.to_string();
129 let tracked = self.tracker.track_future(async move {
130 match shutdown_token.run_until_cancelled_owned(task).await {
131 Some(result) => {
132 debug!("{name_owned} @ {caller}: task has finished execution");
133 Ok(result)
134 }
135 None => {
136 trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
137 Err(Cancelled)
138 }
139 }
140 });
141 spawn_named_future(tracked, name)
142 }
143
144 /// Spawn the provided future on the current Tokio runtime
145 /// that will get cancelled once a global shutdown signal is detected,
146 /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
147 #[track_caller]
148 pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
149 where
150 F: Future + Send + 'static,
151 F::Output: Send + 'static,
152 {
153 let caller = std::panic::Location::caller();
154 let shutdown_token = self.clone_shutdown_token();
155 self.tracker.spawn(async move {
156 match shutdown_token.run_until_cancelled_owned(task).await {
157 Some(result) => {
158 debug!("{caller}: task has finished execution");
159 Ok(result)
160 }
161 None => {
162 trace!("{caller}: shutdown signal received, shutting down");
163 Err(Cancelled)
164 }
165 }
166 })
167 }
168}
169
#[cfg(target_arch = "wasm32")]
impl ShutdownTracker {
    /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// The future is *not* wrapped with the shutdown token; use
    /// [spawn_with_shutdown](Self::spawn_with_shutdown) if it should be cancelled on shutdown.
    #[track_caller]
    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        let tracked = self.tracker.track_future(task);
        spawn_future(tracked)
    }

    /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    /// It is intended to behave like [spawn](Self::spawn) and mainly exists to provide
    /// the same interface as non-wasm32 targets.
    #[track_caller]
    pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        // Same diagnostic as the non-wasm32 implementation.
        trace!("attempting to spawn task {name}");
        let tracked = self.tracker.track_future(task);
        spawn_named_future(tracked, name)
    }

    /// Run the provided future on the current thread
    /// that will get cancelled once a global shutdown signal is detected,
    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// The returned handle resolves to `Ok(output)` if the task ran to completion,
    /// or to `Err(Cancelled)` if the shutdown token was cancelled first.
    ///
    /// Apart from including `name` in its log messages, it is intended to behave like
    /// [spawn_with_shutdown](Self::spawn_with_shutdown) and mainly exists to provide
    /// the same interface as non-wasm32 targets.
    #[track_caller]
    pub fn try_spawn_named_with_shutdown<F>(
        &self,
        task: F,
        name: &str,
    ) -> JoinHandle<Result<F::Output, Cancelled>>
    where
        // Any output type is accepted, matching the non-wasm32 signature
        // (previously restricted to `Output = ()`).
        F: Future + 'static,
    {
        trace!("attempting to spawn task {name} (with top-level cancellation)");

        // Captured here (thanks to `#[track_caller]`) so log lines point at the spawn site.
        let caller = std::panic::Location::caller();
        let shutdown_token = self.clone_shutdown_token();
        // The spawned future must be `'static`, so it needs its own copy of the name.
        let name_owned = name.to_string();
        let tracked = self.tracker.track_future(async move {
            match shutdown_token.run_until_cancelled_owned(task).await {
                Some(result) => {
                    debug!("{name_owned} @ {caller}: task has finished execution");
                    Ok(result)
                }
                None => {
                    trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
                    Err(Cancelled)
                }
            }
        });
        spawn_named_future(tracked, name)
    }

    /// Run the provided future on the current thread
    /// that will get cancelled once a global shutdown signal is detected,
    /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
    ///
    /// The returned handle resolves to `Ok(output)` if the task ran to completion,
    /// or to `Err(Cancelled)` if the shutdown token was cancelled first.
    #[track_caller]
    pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
    where
        // Any output type is accepted, matching the non-wasm32 signature
        // (previously restricted to `Output = ()`).
        F: Future + 'static,
    {
        // Captured here (thanks to `#[track_caller]`) so log lines point at the spawn site.
        let caller = std::panic::Location::caller();
        let shutdown_token = self.clone_shutdown_token();
        let tracked = self.tracker.track_future(async move {
            match shutdown_token.run_until_cancelled_owned(task).await {
                Some(result) => {
                    debug!("{caller}: task has finished execution");
                    Ok(result)
                }
                None => {
                    trace!("{caller}: shutdown signal received, shutting down");
                    Err(Cancelled)
                }
            }
        });
        spawn_future(tracked)
    }
}
250
251impl ShutdownTracker {
252 /// Create new instance of the ShutdownTracker using an external shutdown token.
253 /// This could be useful in situations where shutdown is being managed by an external entity
254 /// that is not [ShutdownManager](ShutdownManager), but interface requires providing a ShutdownTracker,
255 /// such as client-core tasks
256 pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self {
257 ShutdownTracker {
258 root_cancellation_token: shutdown_token,
259 tracker: Default::default(),
260 }
261 }
262
263 /// Waits until the underlying [TaskTracker](tokio_util::task::TaskTracker) is both closed and empty.
264 ///
265 /// If the underlying [TaskTracker](tokio_util::task::TaskTracker) is already closed and empty when this method is called, then it
266 /// returns immediately.
267 pub async fn wait_for_tracker(&self) {
268 self.tracker.wait().await;
269 }
270
271 /// Close the underlying [TaskTracker](tokio_util::task::TaskTracker).
272 ///
273 /// This allows [`wait_for_tracker`] futures to complete. It does not prevent you from spawning new tasks.
274 ///
275 /// Returns `true` if this closed the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already closed.
276 ///
277 /// [`wait_for_tracker`]: Self::wait_for_tracker
278 pub fn close_tracker(&self) -> bool {
279 self.tracker.close()
280 }
281
282 /// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker).
283 ///
284 /// This prevents [`wait_for_tracker`] futures from completing even if the underlying [TaskTracker](tokio_util::task::TaskTracker) is empty.
285 ///
286 /// Returns `true` if this reopened the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already open.
287 ///
288 /// [`wait_for_tracker`]: Self::wait_for_tracker
289 pub fn reopen_tracker(&self) -> bool {
290 self.tracker.reopen()
291 }
292
293 /// Returns `true` if the underlying [TaskTracker](tokio_util::task::TaskTracker) is [closed](Self::close_tracker).
294 pub fn is_tracker_closed(&self) -> bool {
295 self.tracker.is_closed()
296 }
297
298 /// Returns the number of tasks tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker).
299 pub fn tracked_tasks(&self) -> usize {
300 self.tracker.len()
301 }
302
303 /// Returns `true` if there are no tasks in the underlying [TaskTracker](tokio_util::task::TaskTracker).
304 pub fn is_tracker_empty(&self) -> bool {
305 self.tracker.is_empty()
306 }
307
308 /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token
309 pub fn child_shutdown_token(&self) -> ShutdownToken {
310 self.root_cancellation_token.child_token()
311 }
312
313 /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token
314 pub fn clone_shutdown_token(&self) -> ShutdownToken {
315 self.root_cancellation_token.clone()
316 }
317
318 /// Create a child ShutdownTracker that inherits cancellation from this tracker
319 /// but has its own TaskTracker for managing sub-tasks.
320 ///
321 /// This enables hierarchical task management where:
322 /// - Parent cancellation flows to all children
323 /// - Each level tracks its own tasks independently
324 /// - Components can wait for their specific sub-tasks to complete
325 pub fn child_tracker(&self) -> ShutdownTracker {
326 // Child token inherits cancellation from parent
327 let child_token = self.root_cancellation_token.child_token();
328
329 // New TaskTracker for this level's tasks
330 let child_task_tracker = TaskTracker::new();
331
332 ShutdownTracker {
333 root_cancellation_token: child_token,
334 tracker: child_task_tracker,
335 }
336 }
337
338 /// Convenience method to perform a complete shutdown sequence.
339 /// This method:
340 /// 1. Signals cancellation to all tasks
341 /// 2. Closes the tracker to prevent new tasks
342 /// 3. Waits for all existing tasks to complete
343 pub async fn shutdown(self) {
344 // Signal cancellation to all tasks
345 self.root_cancellation_token.cancel();
346
347 // Close the tracker to prevent new tasks from being spawned
348 self.tracker.close();
349
350 // Wait for all existing tasks to complete
351 self.tracker.wait().await;
352 }
353}