allframe_core/
shutdown.rs1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
38
39use tokio::sync::{broadcast, watch};
40
/// Reason a shutdown was triggered.
///
/// The type is a plain `Copy` value so it can be passed to callbacks and
/// sent over channels without cloning.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownSignal {
    /// An interrupt request (rendered as `"SIGINT"` by `Display`).
    Interrupt,
    /// A termination request (rendered as `"SIGTERM"` by `Display`).
    Terminate,
    /// A shutdown requested programmatically (rendered as `"Manual"`).
    Manual,
}
51
52impl std::fmt::Display for ShutdownSignal {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 ShutdownSignal::Interrupt => write!(f, "SIGINT"),
56 ShutdownSignal::Terminate => write!(f, "SIGTERM"),
57 ShutdownSignal::Manual => write!(f, "Manual"),
58 }
59 }
60}
61
62#[derive(Clone)]
64pub struct ShutdownToken {
65 receiver: watch::Receiver<bool>,
66}
67
68impl ShutdownToken {
69 pub fn is_shutdown(&self) -> bool {
71 *self.receiver.borrow()
72 }
73
74 pub async fn cancelled(&mut self) {
76 let _ = self.receiver.wait_for(|v| *v).await;
78 }
79}
80
81pub struct GracefulShutdownBuilder {
83 timeout: Duration,
84 on_signal: Option<Box<dyn Fn(ShutdownSignal) + Send + Sync>>,
85}
86
87impl Default for GracefulShutdownBuilder {
88 fn default() -> Self {
89 Self {
90 timeout: Duration::from_secs(30),
91 on_signal: None,
92 }
93 }
94}
95
96impl GracefulShutdownBuilder {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn timeout(mut self, timeout: Duration) -> Self {
104 self.timeout = timeout;
105 self
106 }
107
108 pub fn on_signal<F>(mut self, callback: F) -> Self
110 where
111 F: Fn(ShutdownSignal) + Send + Sync + 'static,
112 {
113 self.on_signal = Some(Box::new(callback));
114 self
115 }
116
117 pub fn build(self) -> GracefulShutdown {
119 let on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>> = self
120 .on_signal
121 .map(|f| Arc::from(f) as Arc<dyn Fn(ShutdownSignal) + Send + Sync>);
122 GracefulShutdown {
123 timeout: self.timeout,
124 on_signal,
125 shutdown_tx: watch::channel(false).0,
126 signal_tx: broadcast::channel(1).0,
127 }
128 }
129}
130
131pub struct GracefulShutdown {
135 timeout: Duration,
136 on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>>,
137 shutdown_tx: watch::Sender<bool>,
138 signal_tx: broadcast::Sender<ShutdownSignal>,
139}
140
141impl GracefulShutdown {
142 pub fn new() -> Self {
144 GracefulShutdownBuilder::new().build()
145 }
146
147 pub fn builder() -> GracefulShutdownBuilder {
149 GracefulShutdownBuilder::new()
150 }
151
152 pub fn timeout(&self) -> Duration {
154 self.timeout
155 }
156
157 pub fn token(&self) -> ShutdownToken {
159 ShutdownToken {
160 receiver: self.shutdown_tx.subscribe(),
161 }
162 }
163
164 pub fn subscribe(&self) -> broadcast::Receiver<ShutdownSignal> {
166 self.signal_tx.subscribe()
167 }
168
169 pub fn shutdown(&self) {
171 let _ = self.shutdown_tx.send(true);
172 let _ = self.signal_tx.send(ShutdownSignal::Manual);
173 if let Some(ref callback) = self.on_signal {
174 callback(ShutdownSignal::Manual);
175 }
176 }
177
178 pub async fn wait(&self) -> ShutdownSignal {
182 let signal = wait_for_signal().await;
183
184 let _ = self.shutdown_tx.send(true);
186 let _ = self.signal_tx.send(signal);
187
188 if let Some(ref callback) = self.on_signal {
190 callback(signal);
191 }
192
193 signal
194 }
195
196 pub async fn wait_with_timeout(&self) -> Option<ShutdownSignal> {
200 tokio::select! {
201 signal = self.wait() => Some(signal),
202 _ = tokio::time::sleep(self.timeout) => None,
203 }
204 }
205
206 pub async fn run_until_shutdown<F, T>(&self, future: F) -> Option<T>
211 where
212 F: Future<Output = T>,
213 {
214 let mut token = self.token();
215 tokio::select! {
216 result = future => Some(result),
217 _ = token.cancelled() => None,
218 }
219 }
220
221 pub fn spawn<F>(&self, name: &str, future: F) -> tokio::task::JoinHandle<Option<()>>
223 where
224 F: Future<Output = ()> + Send + 'static,
225 {
226 let mut token = self.token();
227 let name = name.to_string();
228 tokio::spawn(async move {
229 tokio::select! {
230 _ = future => {
231 Some(())
232 }
233 _ = token.cancelled() => {
234 #[cfg(feature = "otel")]
235 tracing::info!(task = %name, "Task cancelled due to shutdown");
236 #[cfg(not(feature = "otel"))]
237 let _ = name;
238 None
239 }
240 }
241 })
242 }
243}
244
245impl Default for GracefulShutdown {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251async fn wait_for_signal() -> ShutdownSignal {
253 #[cfg(unix)]
254 {
255 use tokio::signal::unix::{signal, SignalKind};
256
257 let mut sigint =
258 signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
259 let mut sigterm =
260 signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
261
262 tokio::select! {
263 _ = sigint.recv() => ShutdownSignal::Interrupt,
264 _ = sigterm.recv() => ShutdownSignal::Terminate,
265 }
266 }
267
268 #[cfg(not(unix))]
269 {
270 tokio::signal::ctrl_c()
271 .await
272 .expect("Failed to register Ctrl+C handler");
273 ShutdownSignal::Interrupt
274 }
275}
276
277pub struct ShutdownGuard {
281 shutdown: Arc<GracefulShutdown>,
282}
283
284impl ShutdownGuard {
285 pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
287 Self { shutdown }
288 }
289}
290
291impl Drop for ShutdownGuard {
292 fn drop(&mut self) {
293 self.shutdown.shutdown();
294 }
295}
296
297pub trait ShutdownExt: Future + Sized {
299 fn with_shutdown(
301 self,
302 shutdown: &GracefulShutdown,
303 ) -> Pin<Box<dyn Future<Output = Option<Self::Output>> + Send + '_>>
304 where
305 Self: Send + 'static,
306 Self::Output: Send,
307 {
308 let future = self;
309 let mut token = shutdown.token();
310 Box::pin(async move {
311 tokio::select! {
312 result = future => Some(result),
313 _ = token.cancelled() => None,
314 }
315 })
316 }
317}
318
319impl<F: Future> ShutdownExt for F {}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders the conventional signal names.
    #[test]
    fn test_shutdown_signal_display() {
        assert_eq!(ShutdownSignal::Interrupt.to_string(), "SIGINT");
        assert_eq!(ShutdownSignal::Terminate.to_string(), "SIGTERM");
        assert_eq!(ShutdownSignal::Manual.to_string(), "Manual");
    }

    /// A token flips from "running" to "shutdown" after `shutdown()`.
    #[tokio::test]
    async fn test_shutdown_token() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        assert!(!token.is_shutdown());

        shutdown.shutdown();

        assert!(token.is_shutdown());
    }

    /// Subscribers receive `Manual` when shutdown is requested in code.
    #[tokio::test]
    async fn test_manual_shutdown() {
        let shutdown = GracefulShutdown::new();
        let mut rx = shutdown.subscribe();

        shutdown.shutdown();

        let signal = rx.recv().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Manual);
    }

    /// The `on_signal` callback runs as part of `shutdown()`.
    #[tokio::test]
    async fn test_shutdown_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let shutdown = GracefulShutdown::builder()
            .on_signal(move |_| {
                called_clone.store(true, Ordering::SeqCst);
            })
            .build();

        shutdown.shutdown();

        assert!(called.load(Ordering::SeqCst));
    }

    /// A future that completes before any shutdown yields its output.
    #[tokio::test]
    async fn test_run_until_shutdown() {
        let shutdown = GracefulShutdown::new();

        let result = shutdown.run_until_shutdown(async { 42 }).await;
        assert_eq!(result, Some(42));
    }

    /// Once shutdown has been requested, `run_until_shutdown` abandons a
    /// future that would otherwise never complete and returns `None`.
    #[tokio::test]
    async fn test_run_until_shutdown_cancelled() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();

        shutdown.shutdown();

        assert!(token.is_shutdown());

        // The outer timeout only guards against the test hanging forever if
        // cancellation were ever to stop working.
        let result = tokio::time::timeout(
            Duration::from_secs(1),
            shutdown.run_until_shutdown(std::future::pending::<()>()),
        )
        .await
        .expect("run_until_shutdown should return promptly once shutdown is requested");
        assert_eq!(result, None);
    }

    /// The builder's timeout is carried over to the coordinator.
    #[tokio::test]
    async fn test_builder_timeout() {
        let shutdown = GracefulShutdown::builder()
            .timeout(Duration::from_secs(60))
            .build();

        assert_eq!(shutdown.timeout(), Duration::from_secs(60));
    }

    /// A spawned task that finishes on its own resolves to `Some(())`.
    #[tokio::test]
    async fn test_spawn_task() {
        let shutdown = GracefulShutdown::new();
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let handle = shutdown.spawn("test_task", async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });

        let result = handle.await.unwrap();
        assert_eq!(result, Some(()));
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
}