nym_task/cancellation/
manager.rs1use crate::ShutdownToken;
5use crate::cancellation::tracker::{Cancelled, ShutdownTracker};
6use crate::spawn::JoinHandle;
7use futures::StreamExt;
8use futures::stream::FuturesUnordered;
9use log::error;
10use std::future::Future;
11use std::mem;
12use std::pin::Pin;
13use std::time::Duration;
14use tracing::info;
15
16#[cfg(not(target_arch = "wasm32"))]
17use tokio::time::sleep;
18
19#[cfg(target_arch = "wasm32")]
20use wasmtimer::tokio::sleep;
21
22#[cfg(unix)]
23use tokio::signal::unix::{SignalKind, signal};
24use tokio::task::JoinSet;
25
26#[allow(deprecated)]
29pub struct ShutdownManager {
30 pub(crate) legacy_task_manager: Option<crate::TaskManager>,
33
34 pub(crate) shutdown_signals: ShutdownSignals,
36
37 pub(crate) tracker: ShutdownTracker,
40
41 pub(crate) max_shutdown_duration: Duration,
44}
45
46#[derive(Default)]
48pub struct ShutdownSignals(JoinSet<()>);
49
50impl ShutdownSignals {
51 pub async fn wait_for_signal(&mut self) {
53 self.0.join_next().await;
54 }
55}
56
57#[cfg(not(target_arch = "wasm32"))]
60impl Default for ShutdownManager {
61 fn default() -> Self {
62 ShutdownManager::new_without_signals()
63 .with_interrupt_signal()
64 .with_cancel_on_panic()
65 }
66}
67
68#[cfg(not(target_arch = "wasm32"))]
69impl ShutdownManager {
70 pub fn build_new_default() -> std::io::Result<Self> {
74 Ok(ShutdownManager::new_without_signals()
75 .with_default_shutdown_signals()?
76 .with_cancel_on_panic())
77 }
78
79 #[must_use]
81 #[track_caller]
82 pub fn with_shutdown<F>(mut self, shutdown: F) -> Self
83 where
84 F: Future<Output = ()>,
85 F: Send + 'static,
86 {
87 let shutdown_token = self.tracker.clone_shutdown_token();
88 self.shutdown_signals.0.spawn(async move {
89 shutdown.await;
90
91 info!("sending cancellation after receiving shutdown signal");
92 shutdown_token.cancel();
93 });
94 self
95 }
96
97 #[allow(deprecated)]
100 pub fn with_legacy_task_manager(mut self) -> Self {
101 let mut legacy_manager = crate::TaskManager::default().named("legacy-task-manager");
102 let mut legacy_error_rx = legacy_manager.task_return_error_rx();
103 let mut legacy_drop_rx = legacy_manager.task_drop_rx();
104
105 self.legacy_task_manager = Some(legacy_manager);
106
107 self.with_shutdown(async move {
109 tokio::select! {
110 _ = legacy_error_rx.recv() => (),
111 _ = legacy_drop_rx.recv() => (),
112 }
113
114 info!("received legacy shutdown signal");
115 })
116 }
117
118 #[cfg(unix)]
121 #[track_caller]
122 pub fn with_shutdown_signal(self, signal_kind: SignalKind) -> std::io::Result<Self> {
123 let mut sig = signal(signal_kind)?;
124 Ok(self.with_shutdown(async move {
125 sig.recv().await;
126 }))
127 }
128
129 #[cfg(unix)]
132 #[track_caller]
133 pub fn with_terminate_signal(self) -> std::io::Result<Self> {
134 self.with_shutdown_signal(SignalKind::terminate())
135 }
136
137 #[cfg(unix)]
140 #[track_caller]
141 pub fn with_quit_signal(self) -> std::io::Result<Self> {
142 self.with_shutdown_signal(SignalKind::quit())
143 }
144
145 pub fn with_default_shutdown_signals(self) -> std::io::Result<Self> {
149 cfg_if::cfg_if! {
150 if #[cfg(unix)] {
151 self.with_interrupt_signal()
152 .with_terminate_signal()?
153 .with_quit_signal()
154 } else {
155 Ok(self.with_interrupt_signal())
156 }
157 }
158 }
159
160 #[track_caller]
163 pub fn with_interrupt_signal(self) -> Self {
164 self.with_shutdown(async move {
165 let _ = tokio::signal::ctrl_c().await;
166 })
167 }
168
169 #[track_caller]
171 pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
172 where
173 F: Future + Send + 'static,
174 F::Output: Send + 'static,
175 {
176 self.tracker.spawn(task)
177 }
178
179 #[track_caller]
186 pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
187 where
188 F: Future + Send + 'static,
189 F::Output: Send + 'static,
190 {
191 self.tracker.try_spawn_named(task, name)
192 }
193
194 #[track_caller]
197 pub fn spawn_on<F>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<F::Output>
198 where
199 F: Future + Send + 'static,
200 F::Output: Send + 'static,
201 {
202 self.tracker.spawn_on(task, handle)
203 }
204
205 #[track_caller]
208 pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
209 where
210 F: Future + 'static,
211 F::Output: 'static,
212 {
213 self.tracker.spawn_local(task)
214 }
215
216 #[track_caller]
219 pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
220 where
221 F: FnOnce() -> T,
222 F: Send + 'static,
223 T: Send + 'static,
224 {
225 self.tracker.spawn_blocking(task)
226 }
227
228 #[track_caller]
231 pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<T>
232 where
233 F: FnOnce() -> T,
234 F: Send + 'static,
235 T: Send + 'static,
236 {
237 self.tracker.spawn_blocking_on(task, handle)
238 }
239
240 #[track_caller]
247 pub fn try_spawn_named_with_shutdown<F>(
248 &self,
249 task: F,
250 name: &str,
251 ) -> JoinHandle<Result<F::Output, Cancelled>>
252 where
253 F: Future + Send + 'static,
254 F::Output: Send + 'static,
255 {
256 self.tracker.try_spawn_named_with_shutdown(task, name)
257 }
258
259 #[track_caller]
263 pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
264 where
265 F: Future + Send + 'static,
266 F::Output: Send + 'static,
267 {
268 self.tracker.spawn_with_shutdown(task)
269 }
270}
271
#[cfg(target_arch = "wasm32")]
impl ShutdownManager {
    /// Spawns `task` on the shutdown tracker.
    ///
    /// On wasm the future is not required to be `Send`.
    #[track_caller]
    pub fn spawn<F: Future + 'static>(&self, task: F) -> JoinHandle<F::Output> {
        self.tracker.spawn(task)
    }

    /// Spawns `task` on the shutdown tracker, attempting to attach `name` to it.
    /// The naming behaviour itself is defined by `ShutdownTracker::try_spawn_named`.
    #[track_caller]
    pub fn try_spawn_named<F: Future + 'static>(&self, task: F, name: &str) -> JoinHandle<F::Output> {
        self.tracker.try_spawn_named(task, name)
    }

    /// Named variant of [`ShutdownManager::spawn_with_shutdown`].
    #[track_caller]
    pub fn try_spawn_named_with_shutdown<F: Future<Output = ()> + Send + 'static>(
        &self,
        task: F,
        name: &str,
    ) -> JoinHandle<Result<F::Output, Cancelled>> {
        self.tracker.try_spawn_named_with_shutdown(task, name)
    }

    /// Spawns a unit-returning `task` whose join handle yields either `()` or [`Cancelled`].
    /// The cancellation semantics are defined by `ShutdownTracker::spawn_with_shutdown`.
    #[track_caller]
    pub fn spawn_with_shutdown<F: Future<Output = ()> + Send + 'static>(
        &self,
        task: F,
    ) -> JoinHandle<Result<F::Output, Cancelled>> {
        self.tracker.spawn_with_shutdown(task)
    }
}
322
323impl ShutdownManager {
324 pub fn new_without_signals() -> Self {
327 Self::new_from_external_shutdown_token(ShutdownToken::new())
328 }
329
330 pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self {
338 let manager = ShutdownManager {
339 legacy_task_manager: None,
340 shutdown_signals: Default::default(),
341 tracker: ShutdownTracker::new_from_external_shutdown_token(shutdown_token),
342 max_shutdown_duration: Duration::from_secs(10),
343 };
344
345 cfg_if::cfg_if! {if #[cfg(not(target_arch = "wasm32"))] {
348 let cancel_watcher = manager.tracker.clone_shutdown_token();
349 manager.with_shutdown(async move { cancel_watcher.cancelled().await })
350 } else {
351 manager
352 }}
353 }
354
355 pub fn empty_mock() -> Self {
357 ShutdownManager {
358 legacy_task_manager: None,
359 shutdown_signals: Default::default(),
360 tracker: Default::default(),
361 max_shutdown_duration: Default::default(),
362 }
363 }
364
365 #[must_use]
370 pub fn with_cancel_on_panic(self) -> Self {
371 let current_hook = std::panic::take_hook();
372
373 let shutdown_token = self.clone_shutdown_token();
374 std::panic::set_hook(Box::new(move |panic_info| {
375 current_hook(panic_info);
377
378 let location = panic_info
379 .location()
380 .map(|l| l.to_string())
381 .unwrap_or_else(|| "<unknown>".to_string());
382
383 let payload = if let Some(payload) = panic_info.payload().downcast_ref::<&str>() {
384 payload
385 } else {
386 ""
387 };
388
389 error!("panicked at {location}: {payload}. issuing global cancellation");
391 shutdown_token.cancel();
392 }));
393 self
394 }
395
396 #[must_use]
399 pub fn with_shutdown_duration(mut self, duration: Duration) -> Self {
400 self.max_shutdown_duration = duration;
401 self
402 }
403
404 pub fn is_cancelled(&self) -> bool {
406 self.tracker.root_cancellation_token.is_cancelled()
407 }
408
409 pub fn shutdown_tracker(&self) -> &ShutdownTracker {
411 &self.tracker
412 }
413
414 pub fn shutdown_tracker_owned(&self) -> ShutdownTracker {
416 self.tracker.clone()
417 }
418
419 pub async fn wait_for_tracker(&self) {
424 self.tracker.wait_for_tracker().await;
425 }
426
427 pub fn close_tracker(&self) -> bool {
435 self.tracker.close_tracker()
436 }
437
438 pub fn reopen_tracker(&self) -> bool {
446 self.tracker.reopen_tracker()
447 }
448
449 pub fn is_tracker_closed(&self) -> bool {
451 self.tracker.is_tracker_closed()
452 }
453
454 pub fn tracked_tasks(&self) -> usize {
456 self.tracker.tracked_tasks()
457 }
458
459 pub fn is_tracker_empty(&self) -> bool {
461 self.tracker.is_tracker_empty()
462 }
463
464 pub fn child_shutdown_token(&self) -> ShutdownToken {
466 self.tracker.root_cancellation_token.child_token()
467 }
468
469 pub fn clone_shutdown_token(&self) -> ShutdownToken {
471 self.tracker.root_cancellation_token.clone()
472 }
473
474 #[must_use]
478 #[deprecated]
479 #[allow(deprecated)]
480 pub fn subscribe_legacy<S: Into<String>>(&self, child_suffix: S) -> crate::TaskClient {
481 #[allow(clippy::expect_used)]
484 self.legacy_task_manager
485 .as_ref()
486 .expect("did not enable legacy shutdown support")
487 .subscribe_named(child_suffix)
488 }
489
490 async fn finish_shutdown(&mut self) {
495 let mut wait_futures = FuturesUnordered::<Pin<Box<dyn Future<Output = ()> + Send>>>::new();
496
497 wait_futures.push(Box::pin(async move {
499 #[cfg(not(target_arch = "wasm32"))]
500 let interrupt_future = tokio::signal::ctrl_c();
501
502 #[cfg(target_arch = "wasm32")]
503 let interrupt_future = futures::future::pending::<()>();
504
505 let _ = interrupt_future.await;
506 info!("received interrupt - forcing shutdown");
507 }));
508
509 let max_shutdown = self.max_shutdown_duration;
511 wait_futures.push(Box::pin(async move {
512 sleep(max_shutdown).await;
513 info!("timeout reached - forcing shutdown");
514 }));
515
516 let tracker = self.tracker.clone();
518 wait_futures.push(Box::pin(async move {
519 tracker.wait_for_tracker().await;
520 info!("all tracked tasks successfully shutdown");
521 if let Some(legacy) = self.legacy_task_manager.as_mut() {
522 legacy.wait_for_graceful_shutdown().await;
523 info!("all legacy tasks successfully shutdown");
524 }
525
526 info!("all registered tasks successfully shutdown")
527 }));
528
529 wait_futures.next().await;
530 }
531
532 pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals {
538 mem::take(&mut self.shutdown_signals)
539 }
540
541 pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) {
544 self.shutdown_signals = signals;
545 }
546
547 pub fn send_cancellation(&self) {
550 if let Some(legacy_manager) = self.legacy_task_manager.as_ref() {
551 info!("attempting to shutdown legacy tasks");
552 let _ = legacy_manager.signal_shutdown();
553 }
554 self.tracker.root_cancellation_token.cancel();
555 }
556
557 pub async fn wait_for_shutdown_signal(&mut self) {
560 #[cfg(not(target_arch = "wasm32"))]
561 self.shutdown_signals.0.join_next().await;
562
563 #[cfg(target_arch = "wasm32")]
564 self.tracker.root_cancellation_token.cancelled().await;
565 }
566
567 pub async fn perform_shutdown(&mut self) {
572 self.send_cancellation();
573
574 info!("waiting for tasks to finish... (press ctrl-c to force)");
575 self.finish_shutdown().await;
576 }
577
578 pub async fn run_until_shutdown(&mut self) {
580 self.close_tracker();
581 self.wait_for_shutdown_signal().await;
582
583 self.perform_shutdown().await;
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use nym_test_utils::traits::{ElapsedExt, Timeboxed};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Deadline given to `run_until_shutdown` in the deadline-bounded tests.
    const RUN_DEADLINE: Duration = Duration::from_millis(200);

    /// Spawns a tracked task that raises the returned flag once the root token is cancelled.
    fn spawn_cancellation_observer(manager: &ShutdownManager) -> Arc<AtomicBool> {
        let observed = Arc::new(AtomicBool::new(false));
        let task_observed = Arc::clone(&observed);
        let token = manager.clone_shutdown_token();
        manager.spawn(async move {
            token.cancelled().await;
            task_observed.store(true, Ordering::Relaxed);
        });
        observed
    }

    #[tokio::test]
    async fn shutdown_with_no_tracked_tasks_and_signals() -> anyhow::Result<()> {
        // nothing ever cancels the token, so the run has to hit the timebox
        let mut idle_manager = ShutdownManager::new_without_signals();
        let outcome = idle_manager.run_until_shutdown().timeboxed().await;
        assert!(outcome.has_elapsed());

        // a pre-cancelled token makes the run return on its own
        let mut cancelled_manager = ShutdownManager::new_without_signals();
        let token = cancelled_manager.clone_shutdown_token();
        token.cancel();
        let outcome = cancelled_manager.run_until_shutdown().timeboxed().await;
        assert!(!outcome.has_elapsed());

        Ok(())
    }

    #[tokio::test]
    async fn shutdown_signal() -> anyhow::Result<()> {
        // custom signal fires after 100ms, comfortably within the deadline
        let signal = sleep(Duration::from_millis(100));
        let mut manager = ShutdownManager::new_without_signals().with_shutdown(signal);

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;
        assert!(!outcome.has_elapsed());

        Ok(())
    }

    #[tokio::test]
    async fn panic_hook() -> anyhow::Result<()> {
        let mut manager = ShutdownManager::new_without_signals().with_cancel_on_panic();

        // long-running task: only finishes in time if it gets cancelled
        manager.spawn_with_shutdown(async move {
            sleep(Duration::from_millis(10000)).await;
        });
        // this task's panic must trigger global cancellation through the hook
        manager.spawn_with_shutdown(async move {
            sleep(Duration::from_millis(10)).await;
            panic!("panicking");
        });

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;
        assert!(!outcome.has_elapsed());

        Ok(())
    }

    #[tokio::test]
    async fn task_cancellation() -> anyhow::Result<()> {
        let signal = sleep(Duration::from_millis(100));
        let mut manager = ShutdownManager::new_without_signals().with_shutdown(signal);

        let first_observed = spawn_cancellation_observer(&manager);
        let second_observed = spawn_cancellation_observer(&manager);

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;

        assert!(!outcome.has_elapsed());
        assert!(first_observed.load(Ordering::Relaxed));
        assert!(second_observed.load(Ordering::Relaxed));
        Ok(())
    }

    #[tokio::test]
    async fn cancellation_within_task() -> anyhow::Result<()> {
        let mut manager = ShutdownManager::new_without_signals();

        let observed = spawn_cancellation_observer(&manager);

        // a tracked task that cancels the root token itself shortly after starting
        let token = manager.clone_shutdown_token();
        manager.spawn(async move {
            sleep(Duration::from_millis(10)).await;
            token.cancel();
        });

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;

        assert!(!outcome.has_elapsed());
        assert!(observed.load(Ordering::Relaxed));
        Ok(())
    }

    #[tokio::test]
    async fn shutdown_timeout() -> anyhow::Result<()> {
        // grace period (1s) exceeds the deadline: the run is still waiting on the stuck task
        let signal = sleep(Duration::from_millis(50));
        let mut manager = ShutdownManager::new_without_signals()
            .with_shutdown(signal)
            .with_shutdown_duration(Duration::from_millis(1000));

        manager.spawn(async move {
            sleep(Duration::from_millis(1000)).await;
        });

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;

        assert!(outcome.has_elapsed());

        // grace period (100ms) fits in the deadline: the shutdown is forced in time
        let signal = sleep(Duration::from_millis(50));
        let mut manager = ShutdownManager::new_without_signals()
            .with_shutdown(signal)
            .with_shutdown_duration(Duration::from_millis(100));

        manager.spawn(async move {
            sleep(Duration::from_millis(1000)).await;
        });

        let outcome = manager
            .run_until_shutdown()
            .execute_with_deadline(RUN_DEADLINE)
            .await;

        assert!(!outcome.has_elapsed());
        Ok(())
    }
}