hyper_util/server/
graceful.rs

//! Utility to gracefully shut down a server.
//!
//! This module provides a [`GracefulShutdown`] type,
//! which can be used to gracefully shut down a server.
5//!
6//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs>
7//! for an example of how to use this.
8
9use std::{
10    fmt::{self, Debug},
11    future::Future,
12    pin::Pin,
13    task::{self, Poll},
14};
15
16use pin_project_lite::pin_project;
17use tokio::sync::watch;
18
/// A graceful shutdown utility.
///
/// Wrap connections with [`GracefulShutdown::watch()`] (or hand out
/// [`Watcher`]s via [`GracefulShutdown::watcher()`]), then call
/// [`GracefulShutdown::shutdown()`] to signal every watched connection and
/// wait for all of them to finish.
// Purposefully not `Clone`, see `watcher()` method for why.
pub struct GracefulShutdown {
    // Sending on this notifies every subscribed `Watcher`; `shutdown()` then
    // waits until all of those receivers have been dropped.
    tx: watch::Sender<()>,
}
24
/// A watcher side of the graceful shutdown.
///
/// This type can only watch a connection; it cannot trigger a shutdown.
///
/// Call [`GracefulShutdown::watcher()`] to construct one of these.
// Holds a receiver subscribed to the `GracefulShutdown` sender. While this
// receiver (or the future `watch()` moves it into) is alive,
// `GracefulShutdown::shutdown()` keeps waiting.
pub struct Watcher {
    rx: watch::Receiver<()>,
}
33
34impl GracefulShutdown {
35    /// Create a new graceful shutdown helper.
36    pub fn new() -> Self {
37        let (tx, _) = watch::channel(());
38        Self { tx }
39    }
40
41    /// Wrap a future for graceful shutdown watching.
42    pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
43        self.watcher().watch(conn)
44    }
45
46    /// Create an owned type that can watch a connection.
47    ///
48    /// This method allows created an owned type that can be sent onto another
49    /// task before calling [`Watcher::watch()`].
50    // Internal: this function exists because `Clone` allows footguns.
51    // If the `tx` were cloned (or the `rx`), race conditions can happens where
52    // one task starting a shutdown is scheduled and interwined with a task
53    // starting to watch a connection, and the "watch version" is one behind.
54    pub fn watcher(&self) -> Watcher {
55        let rx = self.tx.subscribe();
56        Watcher { rx }
57    }
58
59    /// Signal shutdown for all watched connections.
60    ///
61    /// This returns a `Future` which will complete once all watched
62    /// connections have shutdown.
63    pub async fn shutdown(self) {
64        let Self { tx } = self;
65
66        // signal all the watched futures about the change
67        let _ = tx.send(());
68        // and then wait for all of them to complete
69        tx.closed().await;
70    }
71
72    /// Returns the number of the watching connections.
73    pub fn count(&self) -> usize {
74        self.tx.receiver_count()
75    }
76}
77
78impl Debug for GracefulShutdown {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("GracefulShutdown").finish()
81    }
82}
83
84impl Default for GracefulShutdown {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl Watcher {
91    /// Wrap a future for graceful shutdown watching.
92    pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> {
93        let Watcher { mut rx } = self;
94        GracefulConnectionFuture::new(conn, async move {
95            let _ = rx.changed().await;
96            // hold onto the rx until the watched future is completed
97            rx
98        })
99    }
100}
101
102impl Debug for Watcher {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.debug_struct("GracefulWatcher").finish()
105    }
106}
107
pin_project! {
    // The future built by `Watcher::watch()`. It drives `conn` to completion
    // while also polling `cancel`, the shutdown signal.
    //
    // Fields drop in declaration order, so whatever `cancel` or
    // `cancelled_guard` holds is released only after `conn` has been dropped.
    struct GracefulConnectionFuture<C, F: Future> {
        // The wrapped connection future.
        #[pin]
        conn: C,
        // Shutdown signal; `poll` stops polling it once it has completed.
        #[pin]
        cancel: F,
        #[pin]
        // If cancelled, this is held until the inner conn is done.
        cancelled_guard: Option<F::Output>,
    }
}
119
120impl<C, F: Future> GracefulConnectionFuture<C, F> {
121    fn new(conn: C, cancel: F) -> Self {
122        Self {
123            conn,
124            cancel,
125            cancelled_guard: None,
126        }
127    }
128}
129
130impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("GracefulConnectionFuture").finish()
133    }
134}
135
136impl<C, F> Future for GracefulConnectionFuture<C, F>
137where
138    C: GracefulConnection,
139    F: Future,
140{
141    type Output = C::Output;
142
143    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
144        let mut this = self.project();
145        if this.cancelled_guard.is_none() {
146            if let Poll::Ready(guard) = this.cancel.poll(cx) {
147                this.cancelled_guard.set(Some(guard));
148                this.conn.as_mut().graceful_shutdown();
149            }
150        }
151        this.conn.poll(cx)
152    }
153}
154
155/// An internal utility trait as an umbrella target for all (hyper) connection
156/// types that the [`GracefulShutdown`] can watch.
157pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
158    /// The error type returned by the connection when used as a future.
159    type Error;
160
161    /// Start a graceful shutdown process for this connection.
162    fn graceful_shutdown(self: Pin<&mut Self>);
163}
164
/// Lets [`GracefulShutdown`] watch HTTP/1 server connections.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Calls hyper's inherent method of the same name; inherent methods take
        // precedence over this trait's method.
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

/// Lets [`GracefulShutdown`] watch HTTP/1 server connections that support
/// upgrades.
///
/// This type already implements `private::Sealed`; without this impl it could
/// not actually be watched.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::UpgradeableConnection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    // `Send` matches the IO bound on the `auto::UpgradeableConnection` impl in
    // this module; upgraded IO is handed off, which requires it.
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(self);
    }
}
180
/// Lets [`GracefulShutdown`] watch HTTP/2 server connections.
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Forward to hyper's inherent method of the same name.
        hyper::server::conn::http2::Connection::graceful_shutdown(self);
    }
}
197
/// Lets [`GracefulShutdown`] watch auto-detecting (HTTP/1 or HTTP/2) server
/// connections.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Forward to the inherent method of the same name.
        crate::server::conn::auto::Connection::graceful_shutdown(self);
    }
}
215
/// Lets [`GracefulShutdown`] watch auto-detecting server connections that
/// support upgrades.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Forward to the inherent method of the same name.
        crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
    }
}
234
// Sealing machinery for `GracefulConnection`.
//
// `private` is not `pub`, so only its parent module (and that module's
// descendants) can name `Sealed`. Code elsewhere therefore cannot implement
// `GracefulConnection`.
mod private {
    pub trait Sealed {}

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
304
#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    // A stand-in connection. It completes when `future` does and counts how
    // many times `graceful_shutdown` was called on it.
    pin_project! {
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        fn graceful_shutdown(self: Pin<&mut Self>) {
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // The inner output is discarded; completion always means success.
            self.project().future.poll(cx).map(|_| Ok(()))
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(Duration::from_millis(i * 10)).await;
                let _ = dummy_rx.recv().await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(Duration::from_millis(i * 50)).await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let mut futures = Vec::new();
            for u in 1..=i {
                let future = tokio::time::sleep(Duration::from_millis(u * 50));
                let dummy_conn = DummyConnection {
                    future,
                    shutdown_counter: shutdown_counter.clone(),
                };
                let conn = graceful.watch(dummy_conn);
                futures.push(conn);
            }
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                // 1 + 2 + 3 watched connections in total.
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            // The first connection never completes, so `shutdown()` must not
            // resolve even though all three were signaled.
            let future = async move {
                if i == 1 {
                    std::future::pending::<()>().await
                } else {
                    std::future::ready(()).await
                }
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}