shutdown_handler/lib.rs
1//! A graceful shutdown handler that allows all parts of an application to trigger a shutdown.
2//!
3//! # Why?
4//!
5//! An application I was maintaining was in charge of 3 different services.
6//! * A RabbitMQ processing service
7//! * A gRPC Server
8//! * An HTTP metrics server.
9//!
10//! Our RabbitMQ node was restarted, so our connections dropped and our service went into shutdown mode.
11//! However, due to a bug in our application layer, we didn't acknowledge the failure immediately and
12//! continued handling the gRPC and HTTP traffic. Thankfully our alerts triggered that the queue was backing up
13//! and we manually restarted the application without any real impact.
14//!
//! Understandably, I wanted to make sure this never happened again. We fixed the bug in the application and then
//! tackled the root cause: **other services were oblivious that a shutdown had happened**.
17//!
//! Using this library, we've enforced that all service libraries take in a `ShutdownHandler` instance and use it to
//! shut down gracefully. If any of them are about to crash, they immediately raise a shutdown signal. The other
//! services then see that signal, finish whatever work they had started, and shut down.
21//!
22//! # Example
23//!
24//! ```
25//! use std::pin::pin;
26//! use std::sync::Arc;
27//! use shutdown_handler::{ShutdownHandler, SignalOrComplete};
28//!
29//! # #[tokio::main] async fn main() {
30//! // Create the shutdown handler
31//! let shutdown = Arc::new(ShutdownHandler::new());
32//!
33//! // Shutdown on SIGTERM
34//! shutdown.spawn_sigterm_handler().unwrap();
35//!
36//! // Spawn a few service workers
37//! let mut workers = tokio::task::JoinSet::new();
38//! for port in 0..4 {
39//! workers.spawn(service(Arc::clone(&shutdown), port));
40//! }
41//!
42//! // await all workers and collect the errors
43//! let mut errors = vec![];
44//! while let Some(result) = workers.join_next().await {
45//! // unwrap any JoinErrors that happen if the tokio task panicked
46//! let result = result.unwrap();
47//!
48//! // did our service error?
49//! if let Err(e) = result {
50//! errors.push(e);
51//! }
52//! }
53//!
54//! assert_eq!(errors, ["port closed"]);
55//! # }
56//!
57//! // Define our services to loop on work and shutdown gracefully
58//!
59//! async fn service(shutdown: Arc<ShutdownHandler>, port: u16) -> Result<(), &'static str> {
60//! // a work loop that handles events
61//! for request in 0.. {
62//! let handle = pin!(handle_request(port, request));
63//!
64//! match shutdown.wait_for_signal_or_future(handle).await {
//!             // We finished handling the request without any interruptions. Continue
66//! SignalOrComplete::Completed(Ok(_)) => {}
67//!
68//! // There was an error handling the request, let's shutdown
69//! SignalOrComplete::Completed(Err(e)) => {
70//! shutdown.shutdown();
71//! return Err(e);
72//! }
73//!
74//! // There was a shutdown signal raised while handling this request
75//! SignalOrComplete::ShutdownSignal(handle) => {
76//! // We will finish handling the request but then exit
77//! return handle.await;
78//! }
79//! }
80//! }
81//! Ok(())
82//! }
83//!
84//! async fn handle_request(port: u16, request: usize) -> Result<(), &'static str> {
85//! // simulate some work being done
86//! tokio::time::sleep(std::time::Duration::from_millis(10)).await;
87//!
88//! // simulate an error
89//! if port == 3 && request > 12 {
90//! Err("port closed")
91//! } else {
92//! Ok(())
93//! }
94//! }
95//! ```
96
97use pin_project_lite::pin_project;
98use std::{
99 future::Future,
100 pin::{pin, Pin},
101 sync::{atomic::AtomicBool, Arc},
102 task::Poll,
103};
104use tokio::{
105 signal::unix::{signal, SignalKind},
106 sync::{futures::Notified, Notify},
107};
108
/// A graceful shutdown handler that allows all parts of an application to trigger a shutdown.
///
/// Share one instance between services (e.g. behind an [`Arc`]). Any of them can raise the
/// signal with [`ShutdownHandler::shutdown`]. All of them can await it with
/// [`ShutdownHandler::wait_for_signal`] or race their own work against it with
/// [`ShutdownHandler::wait_for_signal_or_future`].
///
/// # Example
/// ```
/// use std::pin::pin;
/// use std::sync::Arc;
/// use shutdown_handler::{ShutdownHandler, SignalOrComplete};
///
/// # #[tokio::main] async fn main() {
/// // Create the shutdown handler
/// let shutdown = Arc::new(ShutdownHandler::new());
///
/// // Shutdown on SIGTERM
/// shutdown.spawn_sigterm_handler().unwrap();
///
/// // Spawn a few service workers
/// let mut workers = tokio::task::JoinSet::new();
/// for port in 0..4 {
///     workers.spawn(service(Arc::clone(&shutdown), port));
/// }
///
/// // await all workers and collect the errors
/// let mut errors = vec![];
/// while let Some(result) = workers.join_next().await {
///     // unwrap any JoinErrors that happen if the tokio task panicked
///     let result = result.unwrap();
///
///     // did our service error?
///     if let Err(e) = result {
///         errors.push(e);
///     }
/// }
///
/// assert_eq!(errors, ["port closed"]);
/// # }
///
/// // Define our services to loop on work and shutdown gracefully
///
/// async fn service(shutdown: Arc<ShutdownHandler>, port: u16) -> Result<(), &'static str> {
///     // a work loop that handles events
///     for request in 0.. {
///         let handle = pin!(handle_request(port, request));
///
///         match shutdown.wait_for_signal_or_future(handle).await {
///             // We finished handling the request without any interruptions. Continue
///             SignalOrComplete::Completed(Ok(_)) => {}
///
///             // There was an error handling the request, let's shutdown
///             SignalOrComplete::Completed(Err(e)) => {
///                 shutdown.shutdown();
///                 return Err(e);
///             }
///
///             // There was a shutdown signal raised while handling this request
///             SignalOrComplete::ShutdownSignal(handle) => {
///                 // We will finish handling the request but then exit
///                 return handle.await;
///             }
///         }
///     }
///     Ok(())
/// }
///
/// async fn handle_request(port: u16, request: usize) -> Result<(), &'static str> {
///     // simulate some work being done
///     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
///
///     // simulate an error
///     if port == 3 && request > 12 {
///         Err("port closed")
///     } else {
///         Ok(())
///     }
/// }
/// ```
#[derive(Debug, Default)]
pub struct ShutdownHandler {
    // Wakes the `Notified` futures handed out by `wait_for_signal` when `shutdown()` is called.
    notifier: Notify,
    // Set to `true` by `shutdown()` and never reset, so waiters created after the signal
    // still observe it.
    shutdown: AtomicBool,
}
188}
189
190impl ShutdownHandler {
191 pub fn new() -> Self {
192 Self::default()
193 }
194
195 /// Creates a new `ShutdownHandler` and registers the sigterm handler
196 pub fn sigterm() -> std::io::Result<Arc<Self>> {
197 let this = Arc::new(Self::new());
198 this.spawn_sigterm_handler()?;
199 Ok(this)
200 }
201
202 /// Registers the signal event `SIGTERM` to trigger an application shutdown
203 pub fn spawn_sigterm_handler(self: &Arc<Self>) -> std::io::Result<()> {
204 self.spawn_signal_handler(SignalKind::terminate())
205 }
206
207 /// Registers a signal event to trigger an application shutdown
208 pub fn spawn_signal_handler(self: &Arc<Self>, signal_kind: SignalKind) -> std::io::Result<()> {
209 let mut signal = signal(signal_kind)?;
210
211 let shutdown = self.clone();
212 tokio::spawn(async move {
213 signal.recv().await;
214 shutdown.shutdown();
215 });
216 Ok(())
217 }
218
219 /// Sends the shutdown signal to all the current and future waiters
220 pub fn shutdown(&self) {
221 self.shutdown
222 .store(true, std::sync::atomic::Ordering::Release);
223 self.notifier.notify_waiters();
224 }
225
226 /// Returns a future that waits for the shutdown signal.
227 ///
228 /// You can use this like an async function.
229 pub fn wait_for_signal(&self) -> ShutdownSignal<'_> {
230 ShutdownSignal {
231 shutdown: &self.shutdown,
232 notified: self.notifier.notified(),
233 }
234 }
235
236 /// This method will try to complete the given future, but will give up if the shutdown signal is raised.
237 /// The unfinished future is returned in case it is not cancel safe and you need to complete it
238 ///
239 /// ```
240 /// use std::sync::Arc;
241 /// use std::pin::pin;
242 /// use shutdown_handler::{ShutdownHandler, SignalOrComplete};
243 ///
244 /// # #[tokio::main] async fn main() {
245 /// async fn important_work() -> i32 {
246 /// tokio::time::sleep(std::time::Duration::from_secs(2)).await;
247 /// 42
248 /// }
249 ///
250 /// let shutdown = Arc::new(ShutdownHandler::new());
251 ///
252 /// // another part of the application signals a shutdown
253 /// let shutdown2 = Arc::clone(&shutdown);
254 /// let handle = tokio::spawn(async move {
255 /// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
256 /// shutdown2.shutdown();
257 /// });
258 ///
259 /// let work = pin!(important_work());
260 ///
261 /// match shutdown.wait_for_signal_or_future(work).await {
262 /// SignalOrComplete::Completed(res) => println!("important work completed without interuption: {res}"),
263 /// SignalOrComplete::ShutdownSignal(work) => {
264 /// println!("shutdown signal recieved");
265 /// let res = work.await;
266 /// println!("important work completed: {res}");
267 /// },
268 /// }
269 /// # }
270 /// ```
271 pub async fn wait_for_signal_or_future<F: Future + Unpin>(&self, f: F) -> SignalOrComplete<F> {
272 let mut handle = pin!(self.wait_for_signal());
273 let mut f = Some(f);
274
275 std::future::poll_fn(|cx| {
276 if let Poll::Ready(_signal) = handle.as_mut().poll(cx) {
277 return Poll::Ready(SignalOrComplete::ShutdownSignal(f.take().unwrap()));
278 }
279
280 if let Poll::Ready(res) = Pin::new(f.as_mut().unwrap()).poll(cx) {
281 return Poll::Ready(SignalOrComplete::Completed(res));
282 }
283
284 Poll::Pending
285 })
286 .await
287 }
288}
289
290#[derive(Debug)]
291/// Reports whether a future managed to complete without interuption, or if there was a shutdown signal
292pub enum SignalOrComplete<F: Future> {
293 ShutdownSignal(F),
294 Completed(F::Output),
295}
296
297pin_project!(
298 /// A Future that waits for a shutdown signal. Returned by [`ShutdownHandler::shutdown`]
299 pub struct ShutdownSignal<'a> {
300 shutdown: &'a AtomicBool,
301 #[pin]
302 notified: Notified<'a>,
303 }
304);
305
306impl std::future::Future for ShutdownSignal<'_> {
307 type Output = ();
308
309 fn poll(
310 self: std::pin::Pin<&mut Self>,
311 cx: &mut std::task::Context<'_>,
312 ) -> std::task::Poll<Self::Output> {
313 let this = self.project();
314 if this.shutdown.load(std::sync::atomic::Ordering::Acquire) {
315 std::task::Poll::Ready(())
316 } else {
317 this.notified.poll(cx)
318 }
319 }
320}
321
322#[cfg(test)]
323mod test {
324 use std::{sync::Arc, time::Duration};
325
326 use nix::sys::signal::{raise, Signal};
327 use tokio::{signal::unix::SignalKind, sync::oneshot, time::timeout};
328
329 use crate::ShutdownHandler;
330
331 #[tokio::test]
332 async fn shutdown_sigterm() {
333 let shutdown = Arc::new(ShutdownHandler::new());
334 shutdown.spawn_sigterm_handler().unwrap();
335
336 let (tx, rx) = oneshot::channel();
337 tokio::spawn(async move {
338 shutdown.wait_for_signal().await;
339 tx.send(true).unwrap();
340 });
341
342 raise(Signal::SIGTERM).unwrap();
343
344 assert!(
345 (timeout(Duration::from_secs(1), rx).await).is_ok(),
346 "Shutdown handler took longer than 1 second!"
347 );
348 }
349
350 #[tokio::test]
351 async fn shutdown_custom_signal() {
352 let shutdown = Arc::new(ShutdownHandler::new());
353 shutdown.spawn_signal_handler(SignalKind::hangup()).unwrap();
354
355 let (tx, rx) = oneshot::channel();
356 tokio::spawn(async move {
357 shutdown.wait_for_signal().await;
358 tx.send(true).unwrap();
359 });
360
361 raise(Signal::SIGHUP).unwrap();
362
363 assert!(
364 (timeout(Duration::from_secs(1), rx).await).is_ok(),
365 "Shutdown handler took longer than 1 second!"
366 );
367 }
368
369 #[tokio::test]
370 async fn shutdown() {
371 let shutdown = Arc::new(ShutdownHandler::new());
372
373 let (tx, rx) = oneshot::channel();
374 let channel_shutdown = shutdown.clone();
375 tokio::spawn(async move {
376 channel_shutdown.wait_for_signal().await;
377 tx.send(true).unwrap();
378 });
379
380 tokio::spawn(async move {
381 shutdown.shutdown();
382 });
383
384 assert!(
385 (timeout(Duration::from_secs(1), rx).await).is_ok(),
386 "Shutdown handler took longer than 1 second!"
387 );
388 }
389
390 #[tokio::test]
391 async fn no_notification() {
392 let shutdown = Arc::new(ShutdownHandler::new());
393
394 let (tx, rx) = oneshot::channel();
395 tokio::spawn(async move {
396 shutdown.wait_for_signal().await;
397 tx.send(true).unwrap();
398 });
399
400 assert!(
401 (timeout(Duration::from_secs(1), rx).await).is_err(),
402 "Shutdown handler ran without a signal!"
403 );
404 }
405}