1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
38
39use tokio::sync::{broadcast, watch};
40
/// The reason a shutdown was initiated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownSignal {
    /// `SIGINT` on Unix, or Ctrl+C on other platforms.
    Interrupt,
    /// `SIGTERM` (only produced on Unix).
    Terminate,
    /// Shutdown requested programmatically rather than by the OS.
    Manual,
}

impl std::fmt::Display for ShutdownSignal {
    /// Writes a short label for the signal: `SIGINT`, `SIGTERM` or `Manual`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Interrupt => "SIGINT",
            Self::Terminate => "SIGTERM",
            Self::Manual => "Manual",
        };
        f.write_str(label)
    }
}
61
62#[derive(Clone)]
64pub struct ShutdownToken {
65 receiver: watch::Receiver<bool>,
66}
67
68impl ShutdownToken {
69 pub fn is_shutdown(&self) -> bool {
71 *self.receiver.borrow()
72 }
73
74 pub async fn cancelled(&mut self) {
76 let _ = self.receiver.wait_for(|v| *v).await;
78 }
79}
80
81pub struct GracefulShutdownBuilder {
83 timeout: Duration,
84 on_signal: Option<Box<dyn Fn(ShutdownSignal) + Send + Sync>>,
85}
86
87impl Default for GracefulShutdownBuilder {
88 fn default() -> Self {
89 Self {
90 timeout: Duration::from_secs(30),
91 on_signal: None,
92 }
93 }
94}
95
96impl GracefulShutdownBuilder {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn timeout(mut self, timeout: Duration) -> Self {
104 self.timeout = timeout;
105 self
106 }
107
108 pub fn on_signal<F>(mut self, callback: F) -> Self
110 where
111 F: Fn(ShutdownSignal) + Send + Sync + 'static,
112 {
113 self.on_signal = Some(Box::new(callback));
114 self
115 }
116
117 pub fn build(self) -> GracefulShutdown {
119 let on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>> = self
120 .on_signal
121 .map(|f| Arc::from(f) as Arc<dyn Fn(ShutdownSignal) + Send + Sync>);
122 GracefulShutdown {
123 timeout: self.timeout,
124 on_signal,
125 shutdown_tx: watch::channel(false).0,
126 signal_tx: broadcast::channel(1).0,
127 }
128 }
129}
130
131pub struct GracefulShutdown {
135 timeout: Duration,
136 on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>>,
137 shutdown_tx: watch::Sender<bool>,
138 signal_tx: broadcast::Sender<ShutdownSignal>,
139}
140
141impl GracefulShutdown {
142 pub fn new() -> Self {
144 GracefulShutdownBuilder::new().build()
145 }
146
147 pub fn builder() -> GracefulShutdownBuilder {
149 GracefulShutdownBuilder::new()
150 }
151
152 pub fn timeout(&self) -> Duration {
154 self.timeout
155 }
156
157 pub fn token(&self) -> ShutdownToken {
159 ShutdownToken {
160 receiver: self.shutdown_tx.subscribe(),
161 }
162 }
163
164 pub fn subscribe(&self) -> broadcast::Receiver<ShutdownSignal> {
166 self.signal_tx.subscribe()
167 }
168
169 pub fn shutdown(&self) {
171 let _ = self.shutdown_tx.send(true);
172 let _ = self.signal_tx.send(ShutdownSignal::Manual);
173 if let Some(ref callback) = self.on_signal {
174 callback(ShutdownSignal::Manual);
175 }
176 }
177
178 pub async fn wait(&self) -> ShutdownSignal {
182 let signal = wait_for_signal().await;
183
184 let _ = self.shutdown_tx.send(true);
186 let _ = self.signal_tx.send(signal);
187
188 if let Some(ref callback) = self.on_signal {
190 callback(signal);
191 }
192
193 signal
194 }
195
196 pub async fn wait_with_timeout(&self) -> Option<ShutdownSignal> {
200 tokio::select! {
201 signal = self.wait() => Some(signal),
202 _ = tokio::time::sleep(self.timeout) => None,
203 }
204 }
205
206 pub async fn run_until_shutdown<F, T>(&self, future: F) -> Option<T>
211 where
212 F: Future<Output = T>,
213 {
214 let mut token = self.token();
215 tokio::select! {
216 result = future => Some(result),
217 _ = token.cancelled() => None,
218 }
219 }
220
221 pub fn spawn<F>(&self, name: &str, future: F) -> tokio::task::JoinHandle<Option<()>>
223 where
224 F: Future<Output = ()> + Send + 'static,
225 {
226 let mut token = self.token();
227 let name = name.to_string();
228 tokio::spawn(async move {
229 tokio::select! {
230 _ = future => {
231 Some(())
232 }
233 _ = token.cancelled() => {
234 #[cfg(feature = "otel")]
235 tracing::info!(task = %name, "Task cancelled due to shutdown");
236 #[cfg(not(feature = "otel"))]
237 let _ = name;
238 None
239 }
240 }
241 })
242 }
243}
244
245impl Default for GracefulShutdown {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251async fn wait_for_signal() -> ShutdownSignal {
253 #[cfg(unix)]
254 {
255 use tokio::signal::unix::{signal, SignalKind};
256
257 let mut sigint =
258 signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
259 let mut sigterm =
260 signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
261
262 tokio::select! {
263 _ = sigint.recv() => ShutdownSignal::Interrupt,
264 _ = sigterm.recv() => ShutdownSignal::Terminate,
265 }
266 }
267
268 #[cfg(not(unix))]
269 {
270 tokio::signal::ctrl_c()
271 .await
272 .expect("Failed to register Ctrl+C handler");
273 ShutdownSignal::Interrupt
274 }
275}
276
277pub struct ShutdownGuard {
281 shutdown: Arc<GracefulShutdown>,
282}
283
284impl ShutdownGuard {
285 pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
287 Self { shutdown }
288 }
289}
290
291impl Drop for ShutdownGuard {
292 fn drop(&mut self) {
293 self.shutdown.shutdown();
294 }
295}
296
297pub trait ShutdownExt: Future + Sized {
299 fn with_shutdown(
301 self,
302 shutdown: &GracefulShutdown,
303 ) -> Pin<Box<dyn Future<Output = Option<Self::Output>> + Send + '_>>
304 where
305 Self: Send + 'static,
306 Self::Output: Send,
307 {
308 let future = self;
309 let mut token = shutdown.token();
310 Box::pin(async move {
311 tokio::select! {
312 result = future => Some(result),
313 _ = token.cancelled() => None,
314 }
315 })
316 }
317}
318
319impl<F: Future> ShutdownExt for F {}
320
321pub struct ShutdownAwareTaskSpawner {
350 shutdown: Arc<GracefulShutdown>,
351}
352
353impl ShutdownAwareTaskSpawner {
354 pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
356 Self { shutdown }
357 }
358
359 pub fn shutdown(&self) -> &Arc<GracefulShutdown> {
361 &self.shutdown
362 }
363
364 pub fn spawn<F, Fut>(&self, task_name: &str, future: F) -> tokio::task::JoinHandle<()>
368 where
369 F: FnOnce() -> Fut + Send + 'static,
370 Fut: Future<Output = ()> + Send,
371 {
372 let mut token = self.shutdown.token();
373 let task_name = task_name.to_string();
374
375 tokio::spawn(async move {
376 #[cfg(feature = "otel")]
377 tracing::info!(task = %task_name, "Starting task");
378
379 let task_future = future();
380
381 tokio::select! {
382 _ = task_future => {
383 #[cfg(feature = "otel")]
384 tracing::info!(task = %task_name, "Task completed normally");
385 }
386 _ = token.cancelled() => {
387 #[cfg(feature = "otel")]
388 tracing::info!(task = %task_name, "Task cancelled due to shutdown");
389 }
390 }
391
392 #[cfg(feature = "otel")]
393 tracing::info!(task = %task_name, "Task finished");
394
395 #[cfg(not(feature = "otel"))]
397 let _ = task_name;
398 })
399 }
400
401 pub fn spawn_background<F, Fut>(
407 &self,
408 task_name: &str,
409 future: F,
410 ) -> tokio::task::JoinHandle<()>
411 where
412 F: FnOnce() -> Fut + Send + 'static,
413 Fut: Future<Output = ()> + Send,
414 {
415 self.spawn(task_name, future)
416 }
417
418 pub fn spawn_with_result<F, Fut, T>(
420 &self,
421 task_name: &str,
422 future: F,
423 ) -> tokio::task::JoinHandle<Option<T>>
424 where
425 F: FnOnce() -> Fut + Send + 'static,
426 Fut: Future<Output = T> + Send,
427 T: Send + 'static,
428 {
429 let mut token = self.shutdown.token();
430 let task_name = task_name.to_string();
431
432 tokio::spawn(async move {
433 #[cfg(feature = "otel")]
434 tracing::info!(task = %task_name, "Starting task");
435
436 let task_future = future();
437
438 let result = tokio::select! {
439 result = task_future => {
440 #[cfg(feature = "otel")]
441 tracing::info!(task = %task_name, "Task completed normally");
442 Some(result)
443 }
444 _ = token.cancelled() => {
445 #[cfg(feature = "otel")]
446 tracing::info!(task = %task_name, "Task cancelled due to shutdown");
447 None
448 }
449 };
450
451 #[cfg(feature = "otel")]
452 tracing::info!(task = %task_name, "Task finished");
453
454 #[cfg(not(feature = "otel"))]
456 let _ = task_name;
457
458 result
459 })
460 }
461}
462
463impl Clone for ShutdownAwareTaskSpawner {
464 fn clone(&self) -> Self {
465 Self {
466 shutdown: self.shutdown.clone(),
467 }
468 }
469}
470
/// Extension trait for running a cleanup routine as part of shutting down.
pub trait GracefulShutdownExt {
    /// Runs `cleanup_fn` and resolves to the result of the future it returns.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the cleanup future, if any.
    fn perform_shutdown<F, Fut, E>(
        &self,
        cleanup_fn: F,
    ) -> impl Future<Output = Result<(), E>> + Send
    where
        F: Send + FnOnce() -> Fut,
        Fut: Send + Future<Output = Result<(), E>>,
        E: Send + std::fmt::Display;
}
509
510impl GracefulShutdownExt for GracefulShutdown {
511 async fn perform_shutdown<F, Fut, E>(&self, cleanup_fn: F) -> Result<(), E>
512 where
513 F: FnOnce() -> Fut + Send,
514 Fut: Future<Output = Result<(), E>> + Send,
515 E: std::fmt::Display + Send,
516 {
517 #[cfg(feature = "otel")]
518 tracing::info!("Starting graceful shutdown sequence");
519
520 #[cfg(feature = "otel")]
522 tracing::info!("Running cleanup functions");
523
524 let result = cleanup_fn().await;
525
526 if let Err(ref e) = result {
527 #[cfg(feature = "otel")]
528 tracing::error!(error = %e, "Cleanup function failed");
529
530 #[cfg(not(feature = "otel"))]
532 let _ = e;
533 }
534
535 #[cfg(feature = "otel")]
536 tracing::info!("Graceful shutdown completed");
537
538 result
539 }
540}
541
542impl GracefulShutdownExt for Arc<GracefulShutdown> {
544 async fn perform_shutdown<F, Fut, E>(&self, cleanup_fn: F) -> Result<(), E>
545 where
546 F: FnOnce() -> Fut + Send,
547 Fut: Future<Output = Result<(), E>> + Send,
548 E: std::fmt::Display + Send,
549 {
550 self.as_ref().perform_shutdown(cleanup_fn).await
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shutdown_signal_display() {
        assert_eq!(ShutdownSignal::Interrupt.to_string(), "SIGINT");
        assert_eq!(ShutdownSignal::Terminate.to_string(), "SIGTERM");
        assert_eq!(ShutdownSignal::Manual.to_string(), "Manual");
    }

    #[tokio::test]
    async fn test_shutdown_token() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        assert!(!token.is_shutdown());

        shutdown.shutdown();

        assert!(token.is_shutdown());
    }

    #[tokio::test]
    async fn test_manual_shutdown() {
        let shutdown = GracefulShutdown::new();
        let mut rx = shutdown.subscribe();

        shutdown.shutdown();

        let signal = rx.recv().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Manual);
    }

    #[tokio::test]
    async fn test_shutdown_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let shutdown = GracefulShutdown::builder()
            .on_signal(move |_| {
                called_clone.store(true, Ordering::SeqCst);
            })
            .build();

        shutdown.shutdown();

        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_run_until_shutdown() {
        let shutdown = GracefulShutdown::new();

        let result = shutdown.run_until_shutdown(async { 42 }).await;
        assert_eq!(result, Some(42));
    }

    #[tokio::test]
    async fn test_run_until_shutdown_cancelled() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        shutdown.shutdown();
        assert!(token.is_shutdown());

        // The future never completes, so only shutdown can end the race.
        let result = shutdown
            .run_until_shutdown(std::future::pending::<u32>())
            .await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_builder_timeout() {
        let shutdown = GracefulShutdown::builder()
            .timeout(Duration::from_secs(60))
            .build();

        assert_eq!(shutdown.timeout(), Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_spawn_task() {
        let shutdown = GracefulShutdown::new();
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = shutdown.spawn("test_task", async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });

        let result = handle.await.unwrap();
        assert_eq!(result, Some(()));
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_task_completes() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());

        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = spawner.spawn("test_task", move || async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });

        handle.await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_task_cancelled() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());

        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = spawner.spawn("long_task", move || async move {
            tokio::time::sleep(Duration::from_secs(60)).await;
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });

        shutdown.shutdown();

        handle.await.unwrap();

        // The task was cancelled before it could increment the counter.
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_with_result() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());

        let handle = spawner.spawn_with_result("compute_task", || async { 42 });

        let result = handle.await.unwrap();
        assert_eq!(result, Some(42));
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_with_result_cancelled() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());

        let handle = spawner.spawn_with_result("long_compute", || async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        });

        shutdown.shutdown();

        let result = handle.await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_clone() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let spawner2 = spawner.clone();

        assert!(Arc::ptr_eq(spawner.shutdown(), spawner2.shutdown()));
    }

    #[tokio::test]
    async fn test_graceful_shutdown_ext_success() {
        let shutdown = GracefulShutdown::new();

        let result: Result<(), &str> = shutdown.perform_shutdown(|| async { Ok(()) }).await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_ext_error() {
        let shutdown = GracefulShutdown::new();

        let result: Result<(), &str> = shutdown
            .perform_shutdown(|| async { Err("cleanup failed") })
            .await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "cleanup failed");
    }

    #[tokio::test]
    async fn test_graceful_shutdown_ext_with_arc() {
        let shutdown = Arc::new(GracefulShutdown::new());

        let result: Result<(), &str> = shutdown.perform_shutdown(|| async { Ok(()) }).await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_aware_spawner_background() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());

        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = spawner.spawn_background("bg_task", move || async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });

        handle.await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_shutdown_guard_triggers_on_drop() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let token = shutdown.token();
        let mut rx = shutdown.subscribe();

        {
            let _guard = ShutdownGuard::new(shutdown.clone());
            assert!(!token.is_shutdown());
        }

        assert!(token.is_shutdown());
        assert_eq!(rx.recv().await.unwrap(), ShutdownSignal::Manual);
    }

    #[tokio::test]
    async fn test_with_shutdown_completes() {
        let shutdown = GracefulShutdown::new();

        let result = async { 7u8 }.with_shutdown(&shutdown).await;

        assert_eq!(result, Some(7));
    }

    #[tokio::test]
    async fn test_with_shutdown_cancelled() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        shutdown.shutdown();
        assert!(token.is_shutdown());

        let result = std::future::pending::<u8>().with_shutdown(&shutdown).await;

        assert_eq!(result, None);
    }
}