// allframe_core/shutdown.rs
use std::{future::Future, pin::Pin, sync::Arc, time::Duration};

use tokio::sync::{broadcast, watch};

/// The reason a graceful shutdown was initiated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownSignal {
    /// `SIGINT` was received (on non-Unix platforms: Ctrl+C).
    Interrupt,
    /// `SIGTERM` was received (only detected on Unix).
    Terminate,
    /// Shutdown was requested from code, e.g. `GracefulShutdown::shutdown`
    /// or dropping a `ShutdownGuard`.
    Manual,
}

57impl std::fmt::Display for ShutdownSignal {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 ShutdownSignal::Interrupt => write!(f, "SIGINT"),
61 ShutdownSignal::Terminate => write!(f, "SIGTERM"),
62 ShutdownSignal::Manual => write!(f, "Manual"),
63 }
64 }
65}
66
67#[derive(Clone)]
69pub struct ShutdownToken {
70 receiver: watch::Receiver<bool>,
71}
72
73impl ShutdownToken {
74 pub fn is_shutdown(&self) -> bool {
76 *self.receiver.borrow()
77 }
78
79 pub async fn cancelled(&mut self) {
81 let _ = self.receiver.wait_for(|v| *v).await;
83 }
84}
85
86pub struct GracefulShutdownBuilder {
88 timeout: Duration,
89 on_signal: Option<Box<dyn Fn(ShutdownSignal) + Send + Sync>>,
90}
91
92impl Default for GracefulShutdownBuilder {
93 fn default() -> Self {
94 Self {
95 timeout: Duration::from_secs(30),
96 on_signal: None,
97 }
98 }
99}
100
101impl GracefulShutdownBuilder {
102 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn timeout(mut self, timeout: Duration) -> Self {
109 self.timeout = timeout;
110 self
111 }
112
113 pub fn on_signal<F>(mut self, callback: F) -> Self
115 where
116 F: Fn(ShutdownSignal) + Send + Sync + 'static,
117 {
118 self.on_signal = Some(Box::new(callback));
119 self
120 }
121
122 pub fn build(self) -> GracefulShutdown {
124 let on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>> =
125 self.on_signal.map(|f| Arc::from(f) as Arc<dyn Fn(ShutdownSignal) + Send + Sync>);
126 GracefulShutdown {
127 timeout: self.timeout,
128 on_signal,
129 shutdown_tx: watch::channel(false).0,
130 signal_tx: broadcast::channel(1).0,
131 }
132 }
133}
134
135pub struct GracefulShutdown {
139 timeout: Duration,
140 on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>>,
141 shutdown_tx: watch::Sender<bool>,
142 signal_tx: broadcast::Sender<ShutdownSignal>,
143}
144
145impl GracefulShutdown {
146 pub fn new() -> Self {
148 GracefulShutdownBuilder::new().build()
149 }
150
151 pub fn builder() -> GracefulShutdownBuilder {
153 GracefulShutdownBuilder::new()
154 }
155
156 pub fn timeout(&self) -> Duration {
158 self.timeout
159 }
160
161 pub fn token(&self) -> ShutdownToken {
163 ShutdownToken {
164 receiver: self.shutdown_tx.subscribe(),
165 }
166 }
167
168 pub fn subscribe(&self) -> broadcast::Receiver<ShutdownSignal> {
170 self.signal_tx.subscribe()
171 }
172
173 pub fn shutdown(&self) {
175 let _ = self.shutdown_tx.send(true);
176 let _ = self.signal_tx.send(ShutdownSignal::Manual);
177 if let Some(ref callback) = self.on_signal {
178 callback(ShutdownSignal::Manual);
179 }
180 }
181
182 pub async fn wait(&self) -> ShutdownSignal {
186 let signal = wait_for_signal().await;
187
188 let _ = self.shutdown_tx.send(true);
190 let _ = self.signal_tx.send(signal);
191
192 if let Some(ref callback) = self.on_signal {
194 callback(signal);
195 }
196
197 signal
198 }
199
200 pub async fn wait_with_timeout(&self) -> Option<ShutdownSignal> {
204 tokio::select! {
205 signal = self.wait() => Some(signal),
206 _ = tokio::time::sleep(self.timeout) => None,
207 }
208 }
209
210 pub async fn run_until_shutdown<F, T>(&self, future: F) -> Option<T>
215 where
216 F: Future<Output = T>,
217 {
218 let mut token = self.token();
219 tokio::select! {
220 result = future => Some(result),
221 _ = token.cancelled() => None,
222 }
223 }
224
225 pub fn spawn<F>(&self, name: &str, future: F) -> tokio::task::JoinHandle<Option<()>>
227 where
228 F: Future<Output = ()> + Send + 'static,
229 {
230 let mut token = self.token();
231 let name = name.to_string();
232 tokio::spawn(async move {
233 tokio::select! {
234 _ = future => {
235 Some(())
236 }
237 _ = token.cancelled() => {
238 #[cfg(feature = "otel")]
239 tracing::info!(task = %name, "Task cancelled due to shutdown");
240 #[cfg(not(feature = "otel"))]
241 let _ = name;
242 None
243 }
244 }
245 })
246 }
247}
248
249impl Default for GracefulShutdown {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255async fn wait_for_signal() -> ShutdownSignal {
257 #[cfg(unix)]
258 {
259 use tokio::signal::unix::{signal, SignalKind};
260
261 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
262 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
263
264 tokio::select! {
265 _ = sigint.recv() => ShutdownSignal::Interrupt,
266 _ = sigterm.recv() => ShutdownSignal::Terminate,
267 }
268 }
269
270 #[cfg(not(unix))]
271 {
272 tokio::signal::ctrl_c()
273 .await
274 .expect("Failed to register Ctrl+C handler");
275 ShutdownSignal::Interrupt
276 }
277}
278
279pub struct ShutdownGuard {
283 shutdown: Arc<GracefulShutdown>,
284}
285
286impl ShutdownGuard {
287 pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
289 Self { shutdown }
290 }
291}
292
293impl Drop for ShutdownGuard {
294 fn drop(&mut self) {
295 self.shutdown.shutdown();
296 }
297}
298
299pub trait ShutdownExt: Future + Sized {
301 fn with_shutdown(
303 self,
304 shutdown: &GracefulShutdown,
305 ) -> Pin<Box<dyn Future<Output = Option<Self::Output>> + Send + '_>>
306 where
307 Self: Send + 'static,
308 Self::Output: Send,
309 {
310 let future = self;
311 let mut token = shutdown.token();
312 Box::pin(async move {
313 tokio::select! {
314 result = future => Some(result),
315 _ = token.cancelled() => None,
316 }
317 })
318 }
319}
320
321impl<F: Future> ShutdownExt for F {}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    #[test]
    fn test_shutdown_signal_display() {
        assert_eq!(ShutdownSignal::Interrupt.to_string(), "SIGINT");
        assert_eq!(ShutdownSignal::Terminate.to_string(), "SIGTERM");
        assert_eq!(ShutdownSignal::Manual.to_string(), "Manual");
    }

    #[tokio::test]
    async fn test_shutdown_token() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        assert!(!token.is_shutdown());

        shutdown.shutdown();

        assert!(token.is_shutdown());
    }

    #[tokio::test]
    async fn test_manual_shutdown() {
        let shutdown = GracefulShutdown::new();
        let mut rx = shutdown.subscribe();

        shutdown.shutdown();

        let signal = rx.recv().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Manual);
    }

    #[tokio::test]
    async fn test_shutdown_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let shutdown = GracefulShutdown::builder()
            .on_signal(move |_| {
                called_clone.store(true, Ordering::SeqCst);
            })
            .build();

        shutdown.shutdown();

        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_run_until_shutdown() {
        let shutdown = GracefulShutdown::new();

        let result = shutdown.run_until_shutdown(async { 42 }).await;
        assert_eq!(result, Some(42));
    }

    #[tokio::test]
    async fn test_run_until_shutdown_cancelled() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        shutdown.shutdown();
        assert!(token.is_shutdown());

        // A future that never completes is abandoned because shutdown was triggered.
        let result = shutdown
            .run_until_shutdown(std::future::pending::<i32>())
            .await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_builder_timeout() {
        let shutdown = GracefulShutdown::builder()
            .timeout(Duration::from_secs(60))
            .build();

        assert_eq!(shutdown.timeout(), Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_spawn_task() {
        let shutdown = GracefulShutdown::new();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = shutdown.spawn("test_task", async move {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });

        let result = handle.await.unwrap();
        assert_eq!(result, Some(()));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_spawn_task_cancelled() {
        let shutdown = GracefulShutdown::new();
        let handle = shutdown.spawn("never_finishes", std::future::pending::<()>());

        shutdown.shutdown();

        assert_eq!(handle.await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_shutdown_guard_triggers_on_drop() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let token = shutdown.token();

        {
            let _guard = ShutdownGuard::new(Arc::clone(&shutdown));
            assert!(!token.is_shutdown());
        }

        assert!(token.is_shutdown());
    }

    #[tokio::test]
    async fn test_with_shutdown() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        assert_eq!(async { 7 }.with_shutdown(&shutdown).await, Some(7));

        shutdown.shutdown();
        assert!(token.is_shutdown());
        assert_eq!(
            std::future::pending::<u8>().with_shutdown(&shutdown).await,
            None
        );
    }
}