tokio_task_tracker/lib.rs
1//! tokio-task-tracker is a simple graceful shutdown solution for tokio.
2//!
3//! The basic idea is to use a `TaskSpawner` to create `TaskTracker` objects, and hold
4//! on to them in spawned tasks. Inside the task, you can await `tracker.cancelled()`
5//! to wait until the task has been cancelled.
6//!
7//! The `TaskWaiter` can be used to wait for an interrupt and then wait for all
8//! `TaskTracker`s to be dropped.
9//!
10//! # Examples
11//!
12//! ```no_run
13//! # use std::time::Duration;
14//! #
15//! #[tokio::main(flavor = "current_thread")]
16//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! let (spawner, waiter) = tokio_task_tracker::new();
18//!
19//! // Start a task
20//! spawner.spawn(|tracker| async move {
21//! tokio::select! {
22//! _ = tracker.cancelled() => {
23//! // The token was cancelled, task should shut down.
24//! }
25//! _ = tokio::time::sleep(Duration::from_secs(9999)) => {
26//! // Long work has completed
27//! }
28//! }
29//! });
30//!
31//! // Wait for all tasks to complete, or for someone to hit ctrl-c.
32//! // If tasks don't complete within 5 seconds, we'll quit anyway.
33//! waiter.wait_for_shutdown(Duration::from_secs(5)).await?;
34//!
35//! Ok(())
36//! }
37//! ```
38//!
39//! If you do not wish to allow a task to be aborted, you still need to make sure
40//! the task captures the tracker, because TaskWaiter will wait for all trackers to be dropped:
41//!
42//! ```no_run
43//! # use std::time::Duration;
44//! #
45//! # #[tokio::main(flavor = "current_thread")]
46//! # async fn main() {
47//! # let (spawner, waiter) = tokio_task_tracker::new();
48//! #
49//! // Start a task
50//! spawner.spawn(|tracker| async move {
51//! // Move the tracker into the task.
52//! let _tracker = tracker;
53//!
54//! // Do some work that we don't want to abort.
55//! tokio::time::sleep(Duration::from_secs(9999)).await;
56//! });
57//!
58//! # }
59//! ```
60//!
61//! You can also create a tracker via the `task` method:
62//!
63//! ```no_run
64//! # use std::time::Duration;
65//! #
66//! # #[tokio::main(flavor = "current_thread")]
67//! # async fn main() {
68//! # let (spawner, waiter) = tokio_task_tracker::new();
69//! #
70//! // Start a task
71//! let tracker = spawner.task();
72//! tokio::task::spawn(async move {
73//! // Move the tracker into the task.
74//! let _tracker = tracker;
75//!
76//! // ...
77//! });
78//!
79//! # }
80//! ```
81//!
82//! Trackers can be used to spawn subtasks via `tracker.subtask()` or
83//! `tracker.spawn()`.
84
85use std::{
86 future::Future,
87 sync::{Arc, Mutex},
88 time::Duration,
89};
90
91use shutdown::wait_for_shutdown_signal;
92use tokio::{select, sync::mpsc, task::JoinHandle};
93use tokio_util::sync::CancellationToken;
94
95mod shutdown;
96
97/// Builder is used to create a TaskSpawner and TaskWaiter.
98pub struct Builder {
99 token: Option<CancellationToken>,
100}
101
102/// TaskSpawner is used to spawn new task trackers.
103#[derive(Clone)]
104pub struct TaskSpawner {
105 token: CancellationToken,
106 stop_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
107}
108
109/// TaskWaiter is used to wait until all task trackers have been dropped.
110pub struct TaskWaiter {
111 token: CancellationToken,
112 /// Shared stop_tx is shared between all TaskSpawners and the TaskWaiter, so that
113 /// when we call TaskWaiter::wait() we can drop the tx from all spawners.
114 stop_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
115 stop_rx: mpsc::Receiver<()>,
116}
117
118/// A TaskTracker is used both as a token to keep track of active tasks, and
119/// as a cancellation token to check to see if the current task should quit.
120#[derive(Clone)]
121pub struct TaskTracker {
122 token: CancellationToken,
123 // Hang on to an instance of tx. We do this so we can know when all tasks
124 // have been completed.
125 _stop_tx: Option<mpsc::Sender<()>>,
126}
127
/// Errors returned by the waiting methods on `TaskWaiter`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Error {
    /// Returned when we timeout waiting for all tasks to shut down.
    Timeout,
    /// Returned when we cannot bind to the interrupt/terminate signals.
    CouldNotBindInterrupt,
    /// Returned when we were waiting for graceful shutdown, but received a
    /// second interrupt signal.
    ShutdownEarly,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Writes a fixed, human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a new message.
        let message = match self {
            Error::Timeout => "Not all tasks finished before timeout",
            Error::CouldNotBindInterrupt => "Could not bind interrupt handler",
            Error::ShutdownEarly => "Skipping graceful shutdown due to second interrupt",
        };
        f.write_str(message)
    }
}
150
151/// Create a new TaskSpawner and TaskWaiter.
152pub fn new() -> (TaskSpawner, TaskWaiter) {
153 Builder::default().build()
154}
155
156impl Builder {
157 /// Create a new Builder.
158 pub fn new() -> Self {
159 Builder { token: None }
160 }
161
162 /// Use an existing CancellationToken for the returned TaskWaiter and TaskSpawner.
163 /// If the given token is cancelled, all associated TaskTrackers will be cancelled
164 /// as well.
165 pub fn set_cancellation_token(mut self, token: CancellationToken) -> Self {
166 self.token = Some(token);
167 self
168 }
169
170 /// Create a new TaskSpawner and TaskWaiter.
171 pub fn build(self) -> (TaskSpawner, TaskWaiter) {
172 let (stop_tx, stop_rx) = mpsc::channel(1);
173 let stop_tx = Arc::new(Mutex::new(Some(stop_tx)));
174 let token = self.token.unwrap_or(CancellationToken::new());
175
176 (
177 TaskSpawner {
178 token: token.clone(),
179 stop_tx: stop_tx.clone(),
180 },
181 TaskWaiter {
182 token,
183 stop_tx,
184 stop_rx,
185 },
186 )
187 }
188}
189
190impl Default for Builder {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196impl TaskSpawner {
197 /// Create a new TaskTracker.
198 pub fn task(&self) -> TaskTracker {
199 TaskTracker {
200 token: self.token.clone(),
201 _stop_tx: self.stop_tx.lock().unwrap().as_ref().cloned(),
202 }
203 }
204
205 /// Spawn a task.
206 ///
207 /// The given closure will be called, passing in a task tracker.
208 pub fn spawn<T, F: FnOnce(TaskTracker) -> T>(&self, f: F) -> JoinHandle<T::Output>
209 where
210 T: Future + Send + 'static,
211 T::Output: Send + 'static,
212 {
213 let tracker = self.task();
214 tokio::task::spawn(f(tracker))
215 }
216
217 /// Notify all tasks created by this TaskSpawner that they should abort.
218 pub fn cancel(&self) {
219 self.token.cancel();
220 }
221}
222
223impl TaskWaiter {
224 /// Notify all tasks this TaskWaiter is waiting on that they should abort.
225 pub fn cancel(&self) {
226 self.token.cancel();
227 }
228
229 /// Wait for the application to be interrupted, and then gracefully shutdown
230 /// allowing a timeout for all tasks to quit. A second interrupt will cause
231 /// an immediate shutdown.
232 ///
233 /// On Unix systems, "interrupt" means a SIGINT or SIGTERM. On all other
234 /// platforms the current implementation uses `tokio::signal::ctrl_c()`
235 /// to wait for an interrupt.
236 pub async fn wait_for_shutdown(self, timeout: Duration) -> Result<(), Error> {
237 // Wait for the ctrl-c.
238 match wait_for_shutdown_signal().await {
239 Ok(()) => {
240 // time to shut down...
241 }
242 Err(_) => return Err(Error::CouldNotBindInterrupt),
243 }
244
245 // Let tasks know they should shut down.
246 self.token.cancel();
247
248 // Wait for everything to finish.
249 select! {
250 res = self.wait_with_timeout(timeout) => res,
251 _ = wait_for_shutdown_signal() => Err(Error::ShutdownEarly),
252 }
253 }
254
255 /// Wait for all tasks to finish. If tasks do not finish before the timeout,
256 /// `Error::Timeout` will be returned.
257 pub async fn wait_with_timeout(self, timeout: Duration) -> Result<(), Error> {
258 // Wait for all tasks to be dropped.
259 tokio::time::timeout(timeout, self.wait())
260 .await
261 .map_err(|_| Error::Timeout {})?;
262
263 Ok(())
264 }
265
266 /// Wait for all tasks to finish.
267 pub async fn wait(mut self) {
268 // Drop the tx half of the channel.
269 drop(self.stop_tx.lock().unwrap().take());
270
271 // Wait for all tasks to be dropped.
272 let _ = self.stop_rx.recv().await;
273 }
274}
275
276impl TaskTracker {
277 /// Create a new subtask from this TaskTracker.
278 pub fn subtask(&self) -> Self {
279 self.clone()
280 }
281
282 /// Spawn a subtask.
283 ///
284 /// The given closure will be called, passing in a task tracker.
285 pub fn spawn<T, F: FnOnce(TaskTracker) -> T>(&self, f: F) -> JoinHandle<T::Output>
286 where
287 T: Future + Send + 'static,
288 T::Output: Send + 'static,
289 {
290 let tracker = self.subtask();
291 tokio::task::spawn(f(tracker))
292 }
293
294 /// Check to see if this task has been cancelled.
295 pub async fn cancelled(&self) {
296 self.token.cancelled().await;
297 }
298
299 /// Returns true if this token has been cancelled.
300 pub fn is_cancelled(&self) -> bool {
301 self.token.is_cancelled()
302 }
303}
304
#[cfg(test)]
mod tests {
    //! Unit tests for the spawner/waiter/tracker trio. The interrupt tests use
    //! helpers from the `shutdown` module to simulate shutdown signals.

    use super::*;
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    /// Cancelling through the waiter must be visible on an existing tracker.
    #[tokio::test]
    async fn tracker_should_be_cancelled() {
        let (spawner, waiter) = super::new();

        let task = spawner.task();
        waiter.cancel();
        assert!(task.is_cancelled());
    }

    /// A token supplied through the Builder must drive the trackers.
    #[tokio::test]
    async fn should_work_with_existing_cancellation_token() {
        let token = CancellationToken::new();
        let (spawner, _) = super::Builder::new()
            .set_cancellation_token(token.clone())
            .build();
        let task = spawner.task();

        // Cancelling the token should cancel the task.
        token.cancel();
        assert!(task.is_cancelled());
    }

    /// `wait()` must not return until a short, uncancelled task has finished.
    #[tokio::test]
    async fn should_wait_for_tasks_to_complete() -> Result<(), Box<dyn std::error::Error>> {
        let (spawner, waiter) = super::new();

        // Set to true only if the task's sleep branch runs to completion.
        let done = Arc::new(AtomicBool::new(false));

        // Start a task
        {
            let done = done.clone();
            spawner.spawn(|tracker| async move {
                tokio::select! {
                    _ = tracker.cancelled() => {
                        // The token was cancelled, task should shut down.
                    }
                    _ = tokio::time::sleep(Duration::from_millis(100)) => {
                        // Short task has completed.
                        done.store(true, Ordering::SeqCst);
                    }
                }
            });
        }

        // Wait for all tasks to complete.
        waiter.wait().await;

        // Should have completed.
        assert!(done.load(Ordering::SeqCst));

        Ok(())
    }

    /// Cancelling must end a long-running task before its work completes.
    #[tokio::test]
    async fn should_cancel_tasks() -> Result<(), Box<dyn std::error::Error>> {
        let (spawner, waiter) = super::new();

        // Set to true only if the task's sleep branch runs to completion.
        let done = Arc::new(AtomicBool::new(false));

        // Start a task
        {
            let done = done.clone();
            spawner.spawn(|tracker| async move {
                tokio::select! {
                    _ = tracker.cancelled() => {
                        // The token was cancelled, task should shut down.
                    }
                    _ = tokio::time::sleep(Duration::from_secs(9999)) => {
                        // Long work has completed
                        done.store(true, Ordering::SeqCst);
                    }
                }
            });
        }

        // Cancel the task after a short while.
        tokio::time::sleep(Duration::from_millis(100)).await;
        waiter.cancel();

        // Wait for all tasks to complete.
        waiter.wait().await;

        // The task was cancelled, so its long sleep branch must not have run.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }

    /// Runs the two interrupt scenarios one after the other inside a single
    /// test function.
    #[tokio::test]
    async fn interrupt_tests() -> Result<(), Box<dyn std::error::Error>> {
        // Interrupt tests rely on global state in shutdown.rs to simulate
        // SIGINT. Need to run these serially.
        should_wait_for_tasks_on_interrupt().await?;
        should_stop_immediately_on_second_interrupt().await?;

        Ok(())
    }

    /// One simulated shutdown signal should cancel a cooperative task and let
    /// `wait_for_shutdown` return `Ok`.
    async fn should_wait_for_tasks_on_interrupt() -> Result<(), Box<dyn std::error::Error>> {
        shutdown::reset_before_test();

        let (spawner, waiter) = super::new();

        // Set to true only if the task's sleep branch runs to completion.
        let done = Arc::new(AtomicBool::new(false));

        // Start a task
        {
            let done = done.clone();
            spawner.spawn(|tracker| async move {
                tokio::select! {
                    _ = tracker.cancelled() => {
                        // The token was cancelled, task should shut down.
                    }
                    _ = tokio::time::sleep(Duration::from_secs(9999)) => {
                        // Long running task...
                        done.store(true, Ordering::SeqCst);
                    }
                }
            });
        }

        // Send a fake shutdown signal.
        tokio::spawn(async {
            shutdown::send_shutdown().await;
        });

        // Wait for all tasks to complete.
        waiter.wait_for_shutdown(Duration::from_secs(10)).await?;

        // The task took its cancellation branch, so the flag was never set.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }

    /// Two simulated shutdown signals should make `wait_for_shutdown` return
    /// `Error::ShutdownEarly` without waiting for a task that ignores
    /// cancellation.
    async fn should_stop_immediately_on_second_interrupt() -> Result<(), Box<dyn std::error::Error>>
    {
        shutdown::reset_before_test();

        let (spawner, waiter) = super::new();

        // Set to true only if the task's sleep runs to completion.
        let done = Arc::new(AtomicBool::new(false));

        // Start a task
        {
            let done = done.clone();
            spawner.spawn(|tracker| async move {
                // Hold the tracker so the waiter keeps waiting on this task.
                let _tracker = tracker;

                // Long running task that can't be cancelled.
                tokio::time::sleep(Duration::from_secs(99)).await;
                done.store(true, Ordering::SeqCst);
            });
        }

        // Send two shutdown signals. The second should cause us to die immediately.
        tokio::spawn(async move {
            shutdown::send_shutdown().await;
            shutdown::send_shutdown().await;
        });

        // We shouldn't wait here, because of the second interrupt.
        let err = waiter
            .wait_for_shutdown(Duration::from_secs(99))
            .await
            .unwrap_err();
        assert_eq!(err, Error::ShutdownEarly);

        // The task never checks for cancellation, so it has not been stopped;
        // we returned early and its 99 second sleep has not finished, which is
        // why the flag is still unset.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }
}