hyper_util/server/
graceful.rs

//! Utility to gracefully shut down a server.
//!
//! This module provides a [`GracefulShutdown`] type,
//! which can be used to gracefully shut down a server.
//!
//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs>
//! for an example of how to use this.
8
9use std::{
10    fmt::{self, Debug},
11    future::Future,
12    pin::Pin,
13    task::{self, Poll},
14};
15
16use pin_project_lite::pin_project;
17use tokio::sync::watch;
18
/// A graceful shutdown utility.
///
/// Wrap each connection with [`GracefulShutdown::watch()`] (or hand a
/// [`Watcher`] from [`GracefulShutdown::watcher()`] to another task), then call
/// [`GracefulShutdown::shutdown()`] to signal every watched connection and wait
/// for all of them to finish.
// Purposefully not `Clone`, see `watcher()` method for why.
pub struct GracefulShutdown {
    // Sending on this channel is the shutdown signal; every `Watcher` holds a
    // receiver subscribed to it.
    tx: watch::Sender<()>,
}
24
/// A watcher side of the graceful shutdown.
///
/// This type can only watch a connection, it cannot trigger a shutdown.
/// It is an owned handle, so it can be moved to another task before
/// calling [`Watcher::watch()`].
///
/// [`GracefulShutdown::shutdown()`] waits until every receiver is gone. A
/// `Watcher` that is kept around without ever being used therefore keeps the
/// shutdown pending until it is dropped.
///
/// Call [`GracefulShutdown::watcher()`] to construct one of these.
pub struct Watcher {
    // Receiver subscribed to the `GracefulShutdown` sender. It stays alive for
    // as long as the future returned by `watch()` exists.
    rx: watch::Receiver<()>,
}
33
34impl GracefulShutdown {
35    /// Create a new graceful shutdown helper.
36    pub fn new() -> Self {
37        let (tx, _) = watch::channel(());
38        Self { tx }
39    }
40
41    /// Wrap a future for graceful shutdown watching.
42    pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
43        self.watcher().watch(conn)
44    }
45
46    /// Create an owned type that can watch a connection.
47    ///
48    /// This method allows created an owned type that can be sent onto another
49    /// task before calling [`Watcher::watch()`].
50    // Internal: this function exists because `Clone` allows footguns.
51    // If the `tx` were cloned (or the `rx`), race conditions can happens where
52    // one task starting a shutdown is scheduled and interwined with a task
53    // starting to watch a connection, and the "watch version" is one behind.
54    pub fn watcher(&self) -> Watcher {
55        let rx = self.tx.subscribe();
56        Watcher { rx }
57    }
58
59    /// Signal shutdown for all watched connections.
60    ///
61    /// This returns a `Future` which will complete once all watched
62    /// connections have shutdown.
63    pub async fn shutdown(self) {
64        let Self { tx } = self;
65
66        // signal all the watched futures about the change
67        let _ = tx.send(());
68        // and then wait for all of them to complete
69        tx.closed().await;
70    }
71}
72
73impl Debug for GracefulShutdown {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("GracefulShutdown").finish()
76    }
77}
78
79impl Default for GracefulShutdown {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Watcher {
86    /// Wrap a future for graceful shutdown watching.
87    pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> {
88        let Watcher { mut rx } = self;
89        GracefulConnectionFuture::new(conn, async move {
90            let _ = rx.changed().await;
91            // hold onto the rx until the watched future is completed
92            rx
93        })
94    }
95}
96
97impl Debug for Watcher {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        f.debug_struct("GracefulWatcher").finish()
100    }
101}
102
pin_project! {
    /// Drives a connection `conn` and starts its graceful shutdown as soon as
    /// the `cancel` future resolves.
    struct GracefulConnectionFuture<C, F: Future> {
        // The connection being driven; its output is this future's output.
        #[pin]
        conn: C,
        // Resolves when shutdown should begin. It is not polled again once it
        // has completed.
        #[pin]
        cancel: F,
        #[pin]
        // If cancelled, this is held until the inner conn is done.
        cancelled_guard: Option<F::Output>,
    }
}
114
115impl<C, F: Future> GracefulConnectionFuture<C, F> {
116    fn new(conn: C, cancel: F) -> Self {
117        Self {
118            conn,
119            cancel,
120            cancelled_guard: None,
121        }
122    }
123}
124
125impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("GracefulConnectionFuture").finish()
128    }
129}
130
131impl<C, F> Future for GracefulConnectionFuture<C, F>
132where
133    C: GracefulConnection,
134    F: Future,
135{
136    type Output = C::Output;
137
138    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
139        let mut this = self.project();
140        if this.cancelled_guard.is_none() {
141            if let Poll::Ready(guard) = this.cancel.poll(cx) {
142                this.cancelled_guard.set(Some(guard));
143                this.conn.as_mut().graceful_shutdown();
144            }
145        }
146        this.conn.poll(cx)
147    }
148}
149
/// An internal utility trait as an umbrella target for all (hyper) connection
/// types that the [`GracefulShutdown`] can watch.
///
/// The trait is sealed: its `Sealed` supertrait lives in a private module, so
/// it cannot be implemented outside this module.
pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
    /// The error type returned by the connection when used as a future.
    type Error;

    /// Start a graceful shutdown process for this connection.
    ///
    /// The future returned by [`GracefulShutdown::watch()`] keeps polling the
    /// connection after calling this, so that the shutdown can finish.
    fn graceful_shutdown(self: Pin<&mut Self>);
}
159
// HTTP/1 server connections. Shutdown is forwarded to hyper's inherent
// `graceful_shutdown`.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

// HTTP/1 server connections with upgrade support (`with_upgrades()`).
// `private::Sealed` is already implemented for this type, but without this impl
// such connections could not be watched at all.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::UpgradeableConnection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    // `Send` mirrors the bound on the `auto::UpgradeableConnection` impl;
    // hyper's `Future` impl for upgradeable connections requires it.
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(self);
    }
}
175
// HTTP/2 server connections. Shutdown is forwarded to hyper's inherent
// `graceful_shutdown`.
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    /// Delegates to the inherent `http2::Connection::graceful_shutdown`.
    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http2::Connection::graceful_shutdown(self);
    }
}
192
// Connections built by this crate's `auto` server builder. Their error type
// is boxed.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    /// Delegates to the inherent `auto::Connection::graceful_shutdown`.
    fn graceful_shutdown(self: Pin<&mut Self>) {
        crate::server::conn::auto::Connection::graceful_shutdown(self);
    }
}
210
// Upgrade-capable connections from this crate's `auto` server builder. The
// I/O type must additionally be `Send`.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    /// Delegates to the inherent `auto::UpgradeableConnection::graceful_shutdown`.
    fn graceful_shutdown(self: Pin<&mut Self>) {
        crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
    }
}
229
mod private {
    /// Supertrait that seals [`super::GracefulConnection`].
    ///
    /// This module is not exported, so only code in this file (and its child
    /// modules) can name, and therefore implement, `Sealed`.
    pub trait Sealed {}

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
299
#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    pin_project! {
        // Test double for a connection: it completes when `future` completes
        // and counts how often `graceful_shutdown` was called on it.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    // `private::Sealed` is reachable here because this is a child module.
    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        // Only records the call; the wrapped future is left untouched.
        fn graceful_shutdown(self: Pin<&mut Self>) {
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        // Resolves to `Ok(())` once the wrapped future completes, discarding
        // its output.
        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            match self.project().future.poll(cx) {
                Poll::Ready(_) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    // Three connections that each finish after a short sleep plus a broadcast
    // message. `shutdown()` must signal all three and complete within 100 ms.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
                let _ = dummy_rx.recv().await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Connections take 50/100/150 ms to finish. `shutdown()` must wait for the
    // slowest one and still complete within the 200 ms timeout.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Six watched connections (1 + 2 + 3) spread over three spawned tasks, each
    // task joining its own batch. Every connection must be signalled.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let mut futures = Vec::new();
            for u in 1..=i {
                let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50));
                let dummy_conn = DummyConnection {
                    future,
                    shutdown_counter: shutdown_counter.clone(),
                };
                let conn = graceful.watch(dummy_conn);
                futures.push(conn);
            }
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    // The first connection never completes, so `shutdown()` must not resolve.
    // All three connections should still have been signalled by the time the
    // timeout fires.
    // NOTE(review): connections 2 and 3 are ready on their first poll. The
    // count of 3 presumably relies on `#[tokio::test]`'s default current-thread
    // runtime, which first polls the spawned tasks only after `shutdown()` has
    // sent the signal. TODO confirm before switching this test to a
    // multi-thread runtime.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                if i == 1 {
                    std::future::pending::<()>().await
                } else {
                    std::future::ready(()).await
                }
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}