1use std::{
10 fmt::{self, Debug},
11 future::Future,
12 pin::Pin,
13 task::{self, Poll},
14};
15
16use pin_project_lite::pin_project;
17use tokio::sync::watch;
18
/// A graceful shutdown utility.
///
/// Hands out watchers (or watches connections directly) and, when `shutdown` is called,
/// signals all of them and then waits until every one has been dropped.
pub struct GracefulShutdown {
    // Sending half of the shutdown signal. Every watcher holds a receiver subscribed to
    // it, so `receiver_count` and `closed` on this sender track the live watchers.
    tx: watch::Sender<()>,
}
24
/// The watching side of a `GracefulShutdown`.
///
/// Obtained from `GracefulShutdown::watcher` and consumed by `Watcher::watch`. While it
/// (or the future produced from it) is alive, it is counted by `GracefulShutdown::count`
/// and keeps `GracefulShutdown::shutdown` from completing.
pub struct Watcher {
    // Receiving half of the shutdown signal, subscribed from the owning sender.
    rx: watch::Receiver<()>,
}
33
34impl GracefulShutdown {
35 pub fn new() -> Self {
37 let (tx, _) = watch::channel(());
38 Self { tx }
39 }
40
41 pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
43 self.watcher().watch(conn)
44 }
45
46 pub fn watcher(&self) -> Watcher {
55 let rx = self.tx.subscribe();
56 Watcher { rx }
57 }
58
59 pub async fn shutdown(self) {
64 let Self { tx } = self;
65
66 let _ = tx.send(());
68 tx.closed().await;
70 }
71
72 pub fn count(&self) -> usize {
74 self.tx.receiver_count()
75 }
76}
77
78impl Debug for GracefulShutdown {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("GracefulShutdown").finish()
81 }
82}
83
84impl Default for GracefulShutdown {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl Watcher {
91 pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> {
93 let Watcher { mut rx } = self;
94 GracefulConnectionFuture::new(conn, async move {
95 let _ = rx.changed().await;
96 rx
98 })
99 }
100}
101
102impl Debug for Watcher {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.debug_struct("GracefulWatcher").finish()
105 }
106}
107
// Backs the futures returned by `Watcher::watch`. It drives `conn` and, the first time
// `cancel` resolves, calls `GracefulConnection::graceful_shutdown` on it exactly once.
pin_project! {
    struct GracefulConnectionFuture<C, F: Future> {
        // The connection being driven to completion.
        #[pin]
        conn: C,
        // The shutdown signal; it is only polled until it first resolves.
        #[pin]
        cancel: F,
        // The output of `cancel` once it has resolved, held for as long as this future
        // lives. `None` means shutdown has not been requested yet.
        #[pin]
        cancelled_guard: Option<F::Output>,
    }
}
119
120impl<C, F: Future> GracefulConnectionFuture<C, F> {
121 fn new(conn: C, cancel: F) -> Self {
122 Self {
123 conn,
124 cancel,
125 cancelled_guard: None,
126 }
127 }
128}
129
130impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("GracefulConnectionFuture").finish()
133 }
134}
135
136impl<C, F> Future for GracefulConnectionFuture<C, F>
137where
138 C: GracefulConnection,
139 F: Future,
140{
141 type Output = C::Output;
142
143 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
144 let mut this = self.project();
145 if this.cancelled_guard.is_none() {
146 if let Poll::Ready(guard) = this.cancel.poll(cx) {
147 this.cancelled_guard.set(Some(guard));
148 this.conn.as_mut().graceful_shutdown();
149 }
150 }
151 this.conn.poll(cx)
152 }
153}
154
/// A connection type that `GracefulShutdown` / `Watcher` can watch.
///
/// Implementors are futures resolving to `Result<(), Self::Error>` that can be asked to
/// shut down gracefully. The trait is sealed: its `private::Sealed` supertrait lives in a
/// private module, so it cannot be implemented outside this module.
pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
    /// The error produced by the connection future.
    type Error;

    /// Starts a graceful shutdown of this connection.
    ///
    /// Called at most once per watched connection, after the shutdown signal fires. The
    /// connection keeps being polled afterwards until it resolves.
    fn graceful_shutdown(self: Pin<&mut Self>);
}
164
/// Lets hyper's HTTP/1 server `Connection` be watched by `GracefulShutdown`.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // The type-qualified path resolves to hyper's inherent `graceful_shutdown`
        // (inherent items take precedence over trait items), so this forwards rather
        // than recursing into this trait method.
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

/// Lets hyper's HTTP/1 server `UpgradeableConnection` be watched by `GracefulShutdown`.
///
/// `private::Sealed` is implemented for this type, and this impl makes it usable with
/// `GracefulShutdown::watch` / `Watcher::watch`, mirroring the `server-auto`
/// `UpgradeableConnection` impl. The IO type must additionally be `Send`, which hyper
/// requires for this connection to be polled as a `Future`.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::UpgradeableConnection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Forwards to hyper's inherent method (see the `Connection` impl above).
        hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(self);
    }
}
180
/// Lets hyper's HTTP/2 server `Connection` be watched by `GracefulShutdown`.
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // The type-qualified path resolves to hyper's inherent `graceful_shutdown`
        // (inherent items take precedence over trait items), so this forwards rather
        // than recursing into this trait method.
        hyper::server::conn::http2::Connection::graceful_shutdown(self);
    }
}
197
/// Lets this crate's `server::conn::auto::Connection` be watched by `GracefulShutdown`.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    // Must match the error in the connection future's `Result<(), _>` output, as required
    // by the `GracefulConnection` supertrait bound.
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Type-qualified path: calls the inherent method, not this trait method.
        crate::server::conn::auto::Connection::graceful_shutdown(self);
    }
}
215
/// Lets this crate's `server::conn::auto::UpgradeableConnection` be watched by
/// `GracefulShutdown`.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    // Unlike `auto::Connection`, the IO type must also be `Send` here — presumably
    // required by the upgradeable connection's own `Future` impl; confirm there.
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    // Must match the error in the connection future's `Result<(), _>` output, as required
    // by the `GracefulConnection` supertrait bound.
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Type-qualified path: calls the inherent method, not this trait method.
        crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
    }
}
234
/// Sealing module for `GracefulConnection`.
///
/// `Sealed` is a supertrait of `GracefulConnection`. Because this module is private, the
/// trait can only be named (and implemented) from the enclosing module and its children,
/// such as `test`.
mod private {
    /// Marker supertrait restricting which types may implement `GracefulConnection`.
    pub trait Sealed {}

    // The impls below use the same bounds as the corresponding `GracefulConnection` impls
    // (where one exists), so each watchable connection type is also `Sealed`.

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
304
#[cfg(test)]
mod test {
    // Tests for `GracefulShutdown` using a fake connection that counts how often
    // `graceful_shutdown` is invoked.
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    pin_project! {
        // Fake connection: resolves to `Ok(())` when `future` completes and increments
        // `shutdown_counter` each time `graceful_shutdown` is called.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    // Allowed here because `test` is a child of the module that owns `private`.
    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        fn graceful_shutdown(self: Pin<&mut Self>) {
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // The inner future's output is discarded; completion maps to `Ok(())`.
            match self.project().future.poll(cx) {
                Poll::Ready(_) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    // Three connections finish after a short sleep plus a broadcast message; shutdown
    // must notify all three and complete well within the 100ms timeout.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
                let _ = dummy_rx.recv().await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        // No connection may be told to shut down before `shutdown()` is called.
        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Connections that only finish after a delay (50/100/150ms); shutdown must wait for
    // all of them and still finish before the 200ms timeout.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                // Completes on its own after `i * 50` ms, independent of the shutdown signal.
                tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Several watched connections are driven from a single spawned task (1 + 2 + 3 = 6
    // in total); every one of them must receive `graceful_shutdown`.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let mut futures = Vec::new();
            for u in 1..=i {
                let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50));
                let dummy_conn = DummyConnection {
                    future,
                    shutdown_counter: shutdown_counter.clone(),
                };
                let conn = graceful.watch(dummy_conn);
                futures.push(conn);
            }
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    // The first connection never completes, so `shutdown()` can never resolve. The
    // timeout branch must win, after every connection has still been notified.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                if i == 1 {
                    std::future::pending::<()>().await
                } else {
                    std::future::ready(()).await
                }
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}