shutdown_handler/
lib.rs

1//! A graceful shutdown handler that allows all parts of an application to trigger a shutdown.
2//!
3//! # Why?
4//!
5//! An application I was maintaining was in charge of 3 different services.
6//! * A RabbitMQ processing service
7//! * A gRPC Server
8//! * An HTTP metrics server.
9//!
10//! Our RabbitMQ node was restarted, so our connections dropped and our service went into shutdown mode.
11//! However, due to a bug in our application layer, we didn't acknowledge the failure immediately and
//! continued handling the gRPC and HTTP traffic. Thankfully, our alerts flagged that the queue was backing up,
//! and we manually restarted the application without any real impact.
14//!
15//! Understandably, I wanted a way to not have this happen ever again. We fixed the bug in the application, and then
16//! tackled the root cause: **Other services were oblivious that a shutdown happened**.
17//!
//! Using this library, we've enforced that all service libraries take in a `ShutdownHandler` instance and use it to shut down
//! gracefully. If any of them are about to crash, they will immediately raise a shutdown signal. The other services
//! will then see that signal, finish whatever work they had started, then shut down.
21//!
22//! # Example
23//!
24//! ```
25//! use std::pin::pin;
26//! use std::sync::Arc;
27//! use shutdown_handler::{ShutdownHandler, SignalOrComplete};
28//!
29//! # #[tokio::main] async fn main() {
30//! // Create the shutdown handler
31//! let shutdown = Arc::new(ShutdownHandler::new());
32//!
33//! // Shutdown on SIGTERM
34//! shutdown.spawn_sigterm_handler().unwrap();
35//!
36//! // Spawn a few service workers
37//! let mut workers = tokio::task::JoinSet::new();
38//! for port in 0..4 {
39//!     workers.spawn(service(Arc::clone(&shutdown), port));
40//! }
41//!
42//! // await all workers and collect the errors
43//! let mut errors = vec![];
44//! while let Some(result) = workers.join_next().await {
45//!     // unwrap any JoinErrors that happen if the tokio task panicked
46//!     let result = result.unwrap();
47//!
48//!     // did our service error?
49//!     if let Err(e) = result {
50//!         errors.push(e);
51//!     }
52//! }
53//!
54//! assert_eq!(errors, ["port closed"]);
55//! # }
56//!
57//! // Define our services to loop on work and shutdown gracefully
58//!
59//! async fn service(shutdown: Arc<ShutdownHandler>, port: u16) -> Result<(), &'static str> {
60//!     // a work loop that handles events
61//!     for request in 0.. {
62//!         let handle = pin!(handle_request(port, request));
63//!
64//!         match shutdown.wait_for_signal_or_future(handle).await {
//!             // We finished handling the request without any interruptions. Continue
66//!             SignalOrComplete::Completed(Ok(_)) => {}
67//!
68//!             // There was an error handling the request, let's shutdown
69//!             SignalOrComplete::Completed(Err(e)) => {
70//!                 shutdown.shutdown();
71//!                 return Err(e);
72//!             }
73//!
74//!             // There was a shutdown signal raised while handling this request
75//!             SignalOrComplete::ShutdownSignal(handle) => {
76//!                 // We will finish handling the request but then exit
77//!                 return handle.await;
78//!             }
79//!         }
80//!     }
81//!     Ok(())
82//! }
83//!
84//! async fn handle_request(port: u16, request: usize) -> Result<(), &'static str> {
85//!     // simulate some work being done
86//!     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
87//!     
88//!     // simulate an error
89//!     if port == 3 && request > 12 {
90//!         Err("port closed")
91//!     } else {
92//!         Ok(())
93//!     }
94//! }
95//! ```
96
97use pin_project_lite::pin_project;
98use std::{
99    future::Future,
100    pin::{pin, Pin},
101    sync::{atomic::AtomicBool, Arc},
102    task::Poll,
103};
104use tokio::{
105    signal::unix::{signal, SignalKind},
106    sync::{futures::Notified, Notify},
107};
108
109/// A graceful shutdown handler that allows all parts of an application to trigger a shutdown.
110///
111/// # Example
112/// ```
113/// use std::pin::pin;
114/// use std::sync::Arc;
115/// use shutdown_handler::{ShutdownHandler, SignalOrComplete};
116///
117/// # #[tokio::main] async fn main() {
118/// // Create the shutdown handler
119/// let shutdown = Arc::new(ShutdownHandler::new());
120///
121/// // Shutdown on SIGTERM
122/// shutdown.spawn_sigterm_handler().unwrap();
123///
124/// // Spawn a few service workers
125/// let mut workers = tokio::task::JoinSet::new();
126/// for port in 0..4 {
127///     workers.spawn(service(Arc::clone(&shutdown), port));
128/// }
129///
130/// // await all workers and collect the errors
131/// let mut errors = vec![];
132/// while let Some(result) = workers.join_next().await {
133///     // unwrap any JoinErrors that happen if the tokio task panicked
134///     let result = result.unwrap();
135///
136///     // did our service error?
137///     if let Err(e) = result {
138///         errors.push(e);
139///     }
140/// }
141///
142/// assert_eq!(errors, ["port closed"]);
143/// # }
144///
145/// // Define our services to loop on work and shutdown gracefully
146///
147/// async fn service(shutdown: Arc<ShutdownHandler>, port: u16) -> Result<(), &'static str> {
148///     // a work loop that handles events
149///     for request in 0.. {
150///         let handle = pin!(handle_request(port, request));
151///
152///         match shutdown.wait_for_signal_or_future(handle).await {
///             // We finished handling the request without any interruptions. Continue
154///             SignalOrComplete::Completed(Ok(_)) => {}
155///
156///             // There was an error handling the request, let's shutdown
157///             SignalOrComplete::Completed(Err(e)) => {
158///                 shutdown.shutdown();
159///                 return Err(e);
160///             }
161///
162///             // There was a shutdown signal raised while handling this request
163///             SignalOrComplete::ShutdownSignal(handle) => {
164///                 // We will finish handling the request but then exit
165///                 return handle.await;
166///             }
167///         }
168///     }
169///     Ok(())
170/// }
171///
172/// async fn handle_request(port: u16, request: usize) -> Result<(), &'static str> {
173///     // simulate some work being done
174///     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
175///     
176///     // simulate an error
177///     if port == 3 && request > 12 {
178///         Err("port closed")
179///     } else {
180///         Ok(())
181///     }
182/// }
183/// ```
#[derive(Debug, Default)]
pub struct ShutdownHandler {
    /// Wakes the futures that are currently waiting for the shutdown signal.
    notifier: Notify,
    /// Set to `true` once a shutdown has been requested, so that waiters polled
    /// afterwards complete immediately.
    shutdown: AtomicBool,
}
189
190impl ShutdownHandler {
191    pub fn new() -> Self {
192        Self::default()
193    }
194
195    /// Creates a new `ShutdownHandler` and registers the sigterm handler
196    pub fn sigterm() -> std::io::Result<Arc<Self>> {
197        let this = Arc::new(Self::new());
198        this.spawn_sigterm_handler()?;
199        Ok(this)
200    }
201
202    /// Registers the signal event `SIGTERM` to trigger an application shutdown
203    pub fn spawn_sigterm_handler(self: &Arc<Self>) -> std::io::Result<()> {
204        self.spawn_signal_handler(SignalKind::terminate())
205    }
206
207    /// Registers a signal event to trigger an application shutdown
208    pub fn spawn_signal_handler(self: &Arc<Self>, signal_kind: SignalKind) -> std::io::Result<()> {
209        let mut signal = signal(signal_kind)?;
210
211        let shutdown = self.clone();
212        tokio::spawn(async move {
213            signal.recv().await;
214            shutdown.shutdown();
215        });
216        Ok(())
217    }
218
219    /// Sends the shutdown signal to all the current and future waiters
220    pub fn shutdown(&self) {
221        self.shutdown
222            .store(true, std::sync::atomic::Ordering::Release);
223        self.notifier.notify_waiters();
224    }
225
226    /// Returns a future that waits for the shutdown signal.
227    ///
228    /// You can use this like an async function.
229    pub fn wait_for_signal(&self) -> ShutdownSignal<'_> {
230        ShutdownSignal {
231            shutdown: &self.shutdown,
232            notified: self.notifier.notified(),
233        }
234    }
235
236    /// This method will try to complete the given future, but will give up if the shutdown signal is raised.
237    /// The unfinished future is returned in case it is not cancel safe and you need to complete it
238    ///
239    /// ```
240    /// use std::sync::Arc;
241    /// use std::pin::pin;
242    /// use shutdown_handler::{ShutdownHandler, SignalOrComplete};
243    ///
244    /// # #[tokio::main] async fn main() {
245    /// async fn important_work() -> i32 {
246    ///     tokio::time::sleep(std::time::Duration::from_secs(2)).await;
247    ///     42
248    /// }
249    ///
250    /// let shutdown = Arc::new(ShutdownHandler::new());
251    ///
252    /// // another part of the application signals a shutdown
253    /// let shutdown2 = Arc::clone(&shutdown);
254    /// let handle = tokio::spawn(async move {
255    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
256    ///     shutdown2.shutdown();
257    /// });
258    ///
259    /// let work = pin!(important_work());
260    ///
261    /// match shutdown.wait_for_signal_or_future(work).await {
262    ///     SignalOrComplete::Completed(res) => println!("important work completed without interuption: {res}"),
263    ///     SignalOrComplete::ShutdownSignal(work) => {
264    ///         println!("shutdown signal recieved");
265    ///         let res = work.await;
266    ///         println!("important work completed: {res}");
267    ///     },
268    /// }
269    /// # }
270    /// ```
271    pub async fn wait_for_signal_or_future<F: Future + Unpin>(&self, f: F) -> SignalOrComplete<F> {
272        let mut handle = pin!(self.wait_for_signal());
273        let mut f = Some(f);
274
275        std::future::poll_fn(|cx| {
276            if let Poll::Ready(_signal) = handle.as_mut().poll(cx) {
277                return Poll::Ready(SignalOrComplete::ShutdownSignal(f.take().unwrap()));
278            }
279
280            if let Poll::Ready(res) = Pin::new(f.as_mut().unwrap()).poll(cx) {
281                return Poll::Ready(SignalOrComplete::Completed(res));
282            }
283
284            Poll::Pending
285        })
286        .await
287    }
288}
289
/// Reports whether a future managed to complete without interruption, or if there was a shutdown signal
#[derive(Debug)]
pub enum SignalOrComplete<F: Future> {
    /// The shutdown signal was raised first; carries the unfinished future so the
    /// caller can still drive it to completion.
    ShutdownSignal(F),
    /// The future completed; carries its output.
    Completed(F::Output),
}
296
pin_project!(
    /// A Future that waits for a shutdown signal. Returned by [`ShutdownHandler::wait_for_signal`]
    pub struct ShutdownSignal<'a> {
        // Flag borrowed from the handler; checked at the start of every poll.
        shutdown: &'a AtomicBool,
        // Pinned projection so the inner `Notified` future can be polled in place.
        #[pin]
        notified: Notified<'a>,
    }
);
305
306impl std::future::Future for ShutdownSignal<'_> {
307    type Output = ();
308
309    fn poll(
310        self: std::pin::Pin<&mut Self>,
311        cx: &mut std::task::Context<'_>,
312    ) -> std::task::Poll<Self::Output> {
313        let this = self.project();
314        if this.shutdown.load(std::sync::atomic::Ordering::Acquire) {
315            std::task::Poll::Ready(())
316        } else {
317            this.notified.poll(cx)
318        }
319    }
320}
321
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use nix::sys::signal::{raise, Signal};
    use tokio::{signal::unix::SignalKind, sync::oneshot, time::timeout};

    use crate::ShutdownHandler;

    /// Spawns a task that waits for the shutdown signal and then reports `true`
    /// on the returned channel.
    fn spawn_waiter(shutdown: Arc<ShutdownHandler>) -> oneshot::Receiver<bool> {
        let (tx, rx) = oneshot::channel();
        tokio::spawn(async move {
            shutdown.wait_for_signal().await;
            tx.send(true).unwrap();
        });
        rx
    }

    /// Asserts that the waiter reported the shutdown signal within one second.
    ///
    /// Checking only for "no timeout" is not enough: if the waiter task died and
    /// dropped its sender, the receiver would resolve with an error well within
    /// the timeout and the test would pass for the wrong reason.
    async fn assert_signalled(rx: oneshot::Receiver<bool>) {
        let received = timeout(Duration::from_secs(1), rx)
            .await
            .expect("Shutdown handler took longer than 1 second!");
        assert!(
            matches!(received, Ok(true)),
            "Waiter task ended without reporting the shutdown signal!"
        );
    }

    #[tokio::test]
    async fn shutdown_sigterm() {
        let shutdown = Arc::new(ShutdownHandler::new());
        shutdown.spawn_sigterm_handler().unwrap();

        let rx = spawn_waiter(shutdown);

        raise(Signal::SIGTERM).unwrap();

        assert_signalled(rx).await;
    }

    #[tokio::test]
    async fn shutdown_custom_signal() {
        let shutdown = Arc::new(ShutdownHandler::new());
        shutdown.spawn_signal_handler(SignalKind::hangup()).unwrap();

        let rx = spawn_waiter(shutdown);

        raise(Signal::SIGHUP).unwrap();

        assert_signalled(rx).await;
    }

    #[tokio::test]
    async fn shutdown() {
        let shutdown = Arc::new(ShutdownHandler::new());

        let rx = spawn_waiter(Arc::clone(&shutdown));

        // Raise the shutdown from a different task than the waiter.
        tokio::spawn(async move {
            shutdown.shutdown();
        });

        assert_signalled(rx).await;
    }

    #[tokio::test]
    async fn no_notification() {
        let shutdown = Arc::new(ShutdownHandler::new());

        let rx = spawn_waiter(shutdown);

        // Without a signal the waiter must still be pending when the timeout elapses.
        assert!(
            (timeout(Duration::from_secs(1), rx).await).is_err(),
            "Shutdown handler ran without a signal!"
        );
    }
}