use std::{
collections::HashMap,
error::Error,
future::Future,
pin::Pin,
sync::{Arc, Mutex, Weak},
task::{Context, Poll},
};
use futures::{
channel::oneshot::{self, Receiver, Sender},
ready, FutureExt,
};
use hyper::{
body::{Body, Incoming},
rt::{Read, Write},
service::{HttpService, Service},
Request, Response,
};
use hyper_util::server::conn::auto::{HttpServerConnExec, UpgradeableConnection};
/// Coordinates graceful shutdown of a set of registered tasks (e.g. connections).
///
/// Each task obtains a [`ShutdownRegistration`] from [`GracefulShutdown::register_task`] and
/// drives it alongside its own work. Calling [`GracefulShutdown::shutdown`] signals every task
/// that is still registered and then waits until each of them reports completion.
pub struct GracefulShutdown {
    // The only strong reference to the context; registrations hold `Weak`s so they can
    // unregister themselves without keeping the context alive.
    context: Arc<Mutex<GracefulShutdownContext>>,
}

impl GracefulShutdown {
    /// Creates a coordinator with no registered tasks.
    pub fn new() -> Self {
        Self {
            context: Arc::new(Mutex::new(GracefulShutdownContext::new())),
        }
    }

    /// Registers a new task for graceful shutdown.
    ///
    /// The returned [`ShutdownRegistration`] resolves once shutdown is requested, yielding a
    /// [`FinishShutdown`] the task uses to report that it is done. Dropping the registration
    /// unregisters the task.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    #[must_use = "dropping the registration immediately unregisters the task"]
    pub fn register_task(&self) -> ShutdownRegistration {
        let (start_shutdown_tx, start_shutdown_rx) = oneshot::channel();
        let (shutdown_finished_tx, shutdown_finished_rx) = oneshot::channel();
        let handle = TaskHandle {
            start_shutdown_tx,
            shutdown_finished_rx,
        };
        let id = self.context.lock().unwrap().register_task(handle);
        let registration = TaskRegistration {
            context: Arc::downgrade(&self.context),
            id,
        };
        let finish_shutdown = FinishShutdown {
            _registration: registration,
            shutdown_finished_tx,
        };
        ShutdownRegistration {
            start_shutdown_rx,
            finish_shutdown: Some(finish_shutdown),
        }
    }

    /// Signals every registered task to shut down and waits until each one has finished
    /// (called `finish()` or dropped its [`FinishShutdown`]).
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub async fn shutdown(self) {
        let f = self.context.lock().unwrap().graceful_shutdown();
        // Release the only strong reference: tasks that complete from now on fail to upgrade
        // their `Weak` and skip unregistering from a context that no longer tracks them.
        drop(self.context);
        f.await
    }
}

impl Default for GracefulShutdown {
    fn default() -> Self {
        Self::new()
    }
}
/// Shared bookkeeping behind [`GracefulShutdown`]: the set of tasks still registered.
struct GracefulShutdownContext {
    /// Registered tasks that have neither unregistered nor been signalled, keyed by id.
    tasks: HashMap<u64, TaskHandle>,
    /// Id handed to the next registered task.
    next_id: u64,
}

impl GracefulShutdownContext {
    fn new() -> Self {
        Self {
            tasks: HashMap::new(),
            next_id: 0,
        }
    }

    /// Stores `handle` under a fresh id and returns that id.
    fn register_task(&mut self, handle: TaskHandle) -> u64 {
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(1);
        self.tasks.insert(id, handle);
        id
    }

    /// Forgets the task registered under `id`; unknown ids are ignored.
    fn remove_task(&mut self, id: u64) {
        self.tasks.remove(&id);
    }

    /// Takes every registered task out of the context and returns a future that, when polled,
    /// signals all of them to start shutting down and then resolves once each one has finished.
    ///
    /// Every start signal is sent before any completion is awaited, so tasks drain concurrently
    /// instead of one after another. The signals are sent from the returned future rather than
    /// here, where the caller may still be holding the context's mutex.
    fn graceful_shutdown(&mut self) -> impl Future<Output = ()> {
        let handles: Vec<TaskHandle> = self.tasks.drain().map(|(_, handle)| handle).collect();
        async move {
            let pending: Vec<Receiver<()>> = handles
                .into_iter()
                .map(TaskHandle::start_shutdown)
                .collect();
            for finished in pending {
                // An error means the task dropped its completion sender without calling
                // `finish()`; either way it is no longer running.
                let _ = finished.await;
            }
        }
    }
}

/// The context's side of a single task's registration.
struct TaskHandle {
    /// Fired to tell the task to begin shutting down.
    start_shutdown_tx: Sender<()>,
    /// Completes when the task reports completion or drops its sender.
    shutdown_finished_rx: Receiver<()>,
}

impl TaskHandle {
    /// Tells the task to begin shutting down and returns the receiver that completes once the
    /// task has finished. A failed send (the task's registration is already gone) is ignored.
    fn start_shutdown(self) -> Receiver<()> {
        let _ = self.start_shutdown_tx.send(());
        self.shutdown_finished_rx
    }
}
/// Unregisters a task from the shared [`GracefulShutdownContext`] when dropped.
///
/// Holds only a `Weak` reference, so an outstanding registration never keeps the context
/// alive after [`GracefulShutdown`] has released it.
struct TaskRegistration {
    context: Weak<Mutex<GracefulShutdownContext>>,
    id: u64,
}

impl Drop for TaskRegistration {
    fn drop(&mut self) {
        // If the context is already gone (e.g. `shutdown` has dropped it), there is nothing
        // to unregister from.
        if let Some(context) = self.context.upgrade() {
            // Tolerate a poisoned mutex: panicking inside `drop` could happen while the
            // thread is already unwinding and would abort the process. Removing one entry
            // is best-effort cleanup.
            context
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .remove_task(self.id);
        }
    }
}
/// Given to a task once shutdown has been requested; the task reports completion through it.
///
/// Dropping it without calling [`FinishShutdown::finish`] also releases the waiting side,
/// which treats a closed channel the same as a completion notice.
pub struct FinishShutdown {
    _registration: TaskRegistration,
    shutdown_finished_tx: Sender<()>,
}

impl FinishShutdown {
    /// Reports that the task has finished shutting down.
    ///
    /// Best effort: if nothing is waiting any more, the notification is silently discarded.
    pub fn finish(self) {
        let Self {
            shutdown_finished_tx,
            _registration,
        } = self;
        shutdown_finished_tx.send(()).ok();
        // `_registration` goes out of scope here, after the notification has been sent.
    }
}
/// A task's handle on a future shutdown request.
///
/// Resolves when the task is signalled to shut down. It also resolves if the signalling side
/// is dropped without sending, e.g. when the owning [`GracefulShutdown`] is dropped without
/// calling `shutdown`. It yields the [`FinishShutdown`] the task uses to report completion.
/// Dropping it before it resolves unregisters the task.
#[must_use = "dropping a ShutdownRegistration unregisters the task from graceful shutdown"]
pub struct ShutdownRegistration {
    start_shutdown_rx: Receiver<()>,
    // `Some` until the future resolves; moved out exactly once by `poll`.
    finish_shutdown: Option<FinishShutdown>,
}

impl Future for ShutdownRegistration {
    type Output = FinishShutdown;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Both an explicit signal and a dropped sender (`Canceled`) mean "start shutting down".
        let _ = ready!(self.start_shutdown_rx.poll_unpin(cx));
        let finish = self
            .finish_shutdown
            .take()
            .expect("ShutdownRegistration polled after completion");
        Poll::Ready(finish)
    }
}
/// Extension trait that attaches a [`ShutdownRegistration`] to a hyper-util
/// [`UpgradeableConnection`].
pub trait UpgradeableConnectionExt<'a, I, S, E>
where
    S: HttpService<Incoming>,
{
    /// Wraps the connection so that it starts a graceful shutdown as soon as `shutdown`
    /// resolves, and reports completion once the connection future finishes.
    fn with_graceful_shutdown(
        self,
        shutdown: ShutdownRegistration,
    ) -> UpgradeableConnectionWithShutdown<'a, I, S, E>;
}

impl<'a, I, S, E> UpgradeableConnectionExt<'a, I, S, E> for UpgradeableConnection<'a, I, S, E>
where
    S: HttpService<Incoming>,
{
    fn with_graceful_shutdown(
        self,
        registration: ShutdownRegistration,
    ) -> UpgradeableConnectionWithShutdown<'a, I, S, E> {
        // Shutdown has not been signalled yet, so there is no completion handle to hold.
        let finish_shutdown = None;
        let start_shutdown = Some(registration);
        UpgradeableConnectionWithShutdown {
            start_shutdown,
            finish_shutdown,
            connection: self,
        }
    }
}
pin_project_lite::pin_project! {
    /// An [`UpgradeableConnection`] paired with a graceful-shutdown registration.
    ///
    /// Built by [`UpgradeableConnectionExt::with_graceful_shutdown`].
    pub struct UpgradeableConnectionWithShutdown<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        // The wrapped connection; structurally pinned so it can be polled in place.
        #[pin]
        connection: UpgradeableConnection<'a, I, S, E>,
        // Pending shutdown signal; set to `None` once it fires or the connection completes.
        start_shutdown: Option<ShutdownRegistration>,
        // Completion handle obtained when the signal fired; consumed when the connection ends.
        finish_shutdown: Option<FinishShutdown>,
    }
}
impl<'a, I, S, E, B> Future for UpgradeableConnectionWithShutdown<'a, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn Error + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn Error + Send + Sync>>,
    I: Read + Write + Unpin + Send + 'static,
    E: HttpServerConnExec<S::Future, B>,
{
    type Output = Result<(), Box<dyn Error + Send + Sync>>;

    /// Drives the connection, switching it into graceful shutdown once the registration fires.
    ///
    /// When the connection completes, the task reports completion if shutdown had been
    /// requested. Any still-pending registration is then dropped, which unregisters the task.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Has a shutdown been requested since the previous poll?
        let requested = this
            .start_shutdown
            .as_mut()
            .and_then(|registration| match registration.poll_unpin(cx) {
                Poll::Ready(finish) => Some(finish),
                Poll::Pending => None,
            });

        if let Some(finish) = requested {
            this.connection.as_mut().graceful_shutdown();
            *this.start_shutdown = None;
            *this.finish_shutdown = Some(finish);
        }

        match this.connection.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(outcome) => {
                if let Some(finish) = this.finish_shutdown.take() {
                    finish.finish();
                }
                *this.start_shutdown = None;
                Poll::Ready(outcome)
            }
        }
    }
}