1use std::{
10 fmt::{self, Debug},
11 future::Future,
12 pin::Pin,
13 task::{self, Poll},
14};
15
16use pin_project_lite::pin_project;
17use tokio::sync::watch;
18
19pub struct GracefulShutdown {
22 tx: watch::Sender<()>,
23}
24
25pub struct Watcher {
31 rx: watch::Receiver<()>,
32}
33
34impl GracefulShutdown {
35 pub fn new() -> Self {
37 let (tx, _) = watch::channel(());
38 Self { tx }
39 }
40
41 pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
43 self.watcher().watch(conn)
44 }
45
46 pub fn watcher(&self) -> Watcher {
55 let rx = self.tx.subscribe();
56 Watcher { rx }
57 }
58
59 pub async fn shutdown(self) {
64 let Self { tx } = self;
65
66 let _ = tx.send(());
68 tx.closed().await;
70 }
71}
72
73impl Debug for GracefulShutdown {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.debug_struct("GracefulShutdown").finish()
76 }
77}
78
79impl Default for GracefulShutdown {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl Watcher {
86 pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> {
88 let Watcher { mut rx } = self;
89 GracefulConnectionFuture::new(conn, async move {
90 let _ = rx.changed().await;
91 rx
93 })
94 }
95}
96
97impl Debug for Watcher {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 f.debug_struct("GracefulWatcher").finish()
100 }
101}
102
// Future produced by `Watcher::watch`: drives the connection and, once the
// shutdown signal (`cancel`) completes, asks the connection to shut down
// gracefully while continuing to drive it to completion.
pin_project! {
    struct GracefulConnectionFuture<C, F: Future> {
        // The watched connection.
        #[pin]
        conn: C,
        // Resolves when shutdown is signalled; polled only until it completes once.
        #[pin]
        cancel: F,
        // Output of `cancel` once it has fired (`None` before that). In
        // `Watcher::watch` this is the watch receiver, so storing it here keeps
        // `GracefulShutdown::shutdown` waiting until this future is dropped.
        // Fields drop in declaration order, so `conn` is dropped before this guard.
        #[pin]
        cancelled_guard: Option<F::Output>,
    }
}
114
115impl<C, F: Future> GracefulConnectionFuture<C, F> {
116 fn new(conn: C, cancel: F) -> Self {
117 Self {
118 conn,
119 cancel,
120 cancelled_guard: None,
121 }
122 }
123}
124
125impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("GracefulConnectionFuture").finish()
128 }
129}
130
131impl<C, F> Future for GracefulConnectionFuture<C, F>
132where
133 C: GracefulConnection,
134 F: Future,
135{
136 type Output = C::Output;
137
138 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
139 let mut this = self.project();
140 if this.cancelled_guard.is_none() {
141 if let Poll::Ready(guard) = this.cancel.poll(cx) {
142 this.cancelled_guard.set(Some(guard));
143 this.conn.as_mut().graceful_shutdown();
144 }
145 }
146 this.conn.poll(cx)
147 }
148}
149
150pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
153 type Error;
155
156 fn graceful_shutdown(self: Pin<&mut Self>);
158}
159
/// Graceful shutdown for HTTP/1 server connections, delegating to hyper's own
/// `graceful_shutdown`.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Resolves to hyper's inherent method (inherent items take precedence
        // over this trait method), so this does not recurse.
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

/// Graceful shutdown for upgrade-capable HTTP/1 server connections (the type
/// hyper's `http1::Connection::with_upgrades` returns).
// `private::Sealed` is already implemented for this type, but without this impl
// it could not be passed to `GracefulShutdown::watch`. The extra `I: Send` bound
// mirrors hyper's `Future` impl for `UpgradeableConnection`, which requires it.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::UpgradeableConnection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(self);
    }
}
175
/// Graceful shutdown for HTTP/2 server connections, delegating to hyper's own
/// `graceful_shutdown`.
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Resolves to hyper's inherent method, not this trait method.
        hyper::server::conn::http2::Connection::graceful_shutdown(self);
    }
}
192
/// Graceful shutdown for this crate's protocol-detecting (`auto`) server
/// connection; errors are boxed trait objects.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Resolves to the inherent `auto::Connection::graceful_shutdown`.
        crate::server::conn::auto::Connection::graceful_shutdown(self);
    }
}
210
/// Graceful shutdown for this crate's upgrade-capable `auto` server connection;
/// the IO type must additionally be `Send`.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Resolves to the inherent `auto::UpgradeableConnection::graceful_shutdown`.
        crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
    }
}
229
// Sealing support: `GracefulConnection` has `private::Sealed` as a supertrait,
// and since this module is not public, only this crate can add impls.
mod private {
    /// Marker supertrait of `GracefulConnection`.
    pub trait Sealed {}

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
299
#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    pin_project! {
        // Test double: completes when `future` completes and counts how many
        // times a graceful shutdown was requested.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        fn graceful_shutdown(self: Pin<&mut Self>) {
            let counter = &self.shutdown_counter;
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // The inner output is discarded; the dummy always succeeds.
            self.project().future.poll(cx).map(|_| Ok(()))
        }
    }

    /// Race `graceful.shutdown()` against a `limit_ms` timer, requiring the
    /// shutdown to win and `expected` graceful-shutdown calls to be recorded.
    #[cfg(not(miri))]
    async fn assert_shutdown_within(
        graceful: GracefulShutdown,
        limit_ms: u64,
        counter: &AtomicUsize,
        expected: usize,
    ) {
        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(limit_ms)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(counter.load(Ordering::SeqCst), expected);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3u64 {
            // Each connection finishes only after its delay AND the broadcast.
            let mut dummy_rx = dummy_tx.subscribe();
            let conn = graceful.watch(DummyConnection {
                future: async move {
                    tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
                    let _ = dummy_rx.recv().await;
                },
                shutdown_counter: Arc::clone(&shutdown_counter),
            });
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        assert_shutdown_within(graceful, 100, &shutdown_counter, 3).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3u64 {
            // Each connection keeps running for a different time after the signal.
            let conn = graceful.watch(DummyConnection {
                future: async move {
                    tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
                },
                shutdown_counter: Arc::clone(&shutdown_counter),
            });
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        assert_shutdown_within(graceful, 200, &shutdown_counter, 3).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3u64 {
            // Task `i` drives `i` watched connections at once (6 in total).
            let batch: Vec<_> = (1..=i)
                .map(|u| {
                    graceful.watch(DummyConnection {
                        future: tokio::time::sleep(std::time::Duration::from_millis(u * 50)),
                        shutdown_counter: Arc::clone(&shutdown_counter),
                    })
                })
                .collect();
            tokio::spawn(async move {
                futures_util::future::join_all(batch).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        assert_shutdown_within(graceful, 200, &shutdown_counter, 6).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            // The first connection never completes, so shutdown can never finish.
            let never_finishes = i == 1;
            let conn = graceful.watch(DummyConnection {
                future: async move {
                    if never_finishes {
                        std::future::pending::<()>().await
                    } else {
                        std::future::ready(()).await
                    }
                },
                shutdown_counter: Arc::clone(&shutdown_counter),
            });
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}