use hyper_util::server::graceful::GracefulConnection;
use pin_project_lite::pin_project;
use std::{
fmt::{self, Debug},
future::Future,
pin::Pin,
task::{self, Poll},
};
use tokio::sync::watch;
/// Coordinates a graceful shutdown across a set of connections.
///
/// Wrap each connection with [`GracefulShutdown::watch`]. Calling
/// [`GracefulShutdown::shutdown`] then signals every watched connection to
/// begin a graceful shutdown, and waits until all of them (and every receiver
/// handed out by [`GracefulShutdown::subscribe`]) have been dropped.
///
/// Clones share the same underlying channel.
#[derive(Clone, Debug)]
pub struct GracefulShutdown {
    // Sending on this channel is the shutdown signal; every watched connection
    // holds a receiver subscribed from this sender.
    tx: watch::Sender<()>,
}
impl Default for GracefulShutdown {
    /// Equivalent to [`GracefulShutdown::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl GracefulShutdown {
    /// Creates a new shutdown coordinator with no connections attached.
    pub fn new() -> Self {
        // The initial receiver is discarded; connections obtain their own
        // receivers through `subscribe`.
        let (tx, _) = watch::channel(());
        Self { tx }
    }

    /// Wraps `conn` so that it is asked to shut down gracefully once the
    /// shutdown signal is observed.
    ///
    /// The returned future resolves with the connection's own output. It
    /// holds a receiver for its whole lifetime, so [`GracefulShutdown::shutdown`]
    /// keeps waiting until this future is dropped.
    ///
    /// If the sender side of the channel is dropped without `shutdown` being
    /// called, `changed()` errors out. That error is deliberately treated the
    /// same as an explicit signal.
    ///
    /// NOTE: a receiver subscribed after the signal was already sent treats
    /// that value as seen. A connection watched after `shutdown` started is
    /// therefore not signalled until every sender is dropped.
    pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
        let mut rx = self.tx.subscribe();
        GracefulConnectionFuture::new(conn, async move {
            // Ok (signal sent) and Err (sender dropped) both mean "shut down".
            let _ = rx.changed().await;
            // Hand the receiver back so the connection future keeps it alive;
            // `shutdown` only completes once all receivers are dropped.
            rx
        })
    }

    /// Returns a new receiver for the shutdown signal.
    ///
    /// While the receiver is alive, [`GracefulShutdown::shutdown`] keeps
    /// waiting. Values sent before this call are considered already seen.
    pub fn subscribe(&self) -> watch::Receiver<()> {
        self.tx.subscribe()
    }

    /// Signals every watched connection to shut down gracefully. It then
    /// waits until every receiver (watched connections and `subscribe`
    /// handles) has been dropped.
    ///
    /// If there are no receivers, the send error is ignored and the wait
    /// completes immediately.
    pub async fn shutdown(self) {
        let Self { tx } = self;
        let _ = tx.send(());
        tx.closed().await;
    }
}
pin_project! {
    /// Future returned by `GracefulShutdown::watch`. It drives `conn` to
    /// completion while watching `cancel` for the shutdown signal.
    struct GracefulConnectionFuture<Conn, Cancel: Future> {
        // The connection being served.
        #[pin]
        conn: Conn,
        // Resolves once shutdown has been requested.
        #[pin]
        cancel: Cancel,
        // `None` until `cancel` completes; afterwards it stores `cancel`'s
        // output so it lives exactly as long as this future.
        #[pin]
        cancelled_guard: Option<Cancel::Output>,
    }
}
impl<Conn, Cancel: Future> GracefulConnectionFuture<Conn, Cancel> {
    /// Pairs a connection with its cancellation future. No shutdown has been
    /// observed yet, so the guard slot starts out empty.
    fn new(conn: Conn, cancel: Cancel) -> Self {
        let cancelled_guard = None;
        Self {
            conn,
            cancel,
            cancelled_guard,
        }
    }
}
impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
    /// Formats as `GracefulConnectionFuture { .. }`.
    ///
    /// The fields are deliberately omitted because `C` and `F` are not
    /// required to implement `Debug`. `finish_non_exhaustive` makes that
    /// omission visible instead of implying the struct is empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GracefulConnectionFuture")
            .finish_non_exhaustive()
    }
}
impl<Conn, Sig> Future for GracefulConnectionFuture<Conn, Sig>
where
    Conn: GracefulConnection,
    Sig: Future,
{
    type Output = Conn::Output;

    /// Polls the cancellation signal (only until it first completes), then
    /// the connection.
    ///
    /// When `cancel` resolves, its output is stored in `cancelled_guard`. The
    /// connection is then asked to start a graceful shutdown and is polled in
    /// the same call, so it can react immediately.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let mut conn = projected.conn;
        let mut guard_slot = projected.cancelled_guard;

        // A filled slot means the signal was already observed; a completed
        // `cancel` future must not be polled again.
        if guard_slot.is_none() {
            if let Poll::Ready(guard) = projected.cancel.poll(cx) {
                guard_slot.set(Some(guard));
                conn.as_mut().graceful_shutdown();
            }
        }

        conn.poll(cx)
    }
}