use std::default::Default;
use std::fmt;
use std::sync::{Arc, Weak};
use tokio::sync::{mpsc, watch};
/// Handle whose lifetime keeps a [`ShutdownCoordinator::shutdown`] call waiting.
///
/// Obtained (shared behind an `Arc`) from [`ShutdownCoordinator::handle`], or built detached
/// via `Default`. A [`ShutdownSignal`] can be derived from it with `ShutdownSignal::from(&handle)`.
pub struct ShutdownHandle {
/// Template signal; `From<&ShutdownHandle>` hands out clones of it.
cancellation_rx: ShutdownSignal,
/// Not sent on anywhere in this module. Its only purpose is to exist: the coordinator's
/// `shutdown_rx.recv()` returns `None` only after every sender, and so every handle, is dropped.
_shutdown_tx: mpsc::Sender<()>,
}
impl Default for ShutdownHandle {
    /// Builds a detached handle that belongs to no [`ShutdownCoordinator`].
    ///
    /// Its signal is `ShutdownSignal::default()`, so awaiting it never completes. The receiving
    /// half of its channel is discarded at once, so nothing ever waits on this handle.
    fn default() -> Self {
        let cancellation_rx = ShutdownSignal::default();
        // Keep only the sender; the receiver half of the temporary tuple is dropped right here.
        let sender = mpsc::channel::<()>(1).0;
        Self {
            cancellation_rx,
            _shutdown_tx: sender,
        }
    }
}
impl fmt::Debug for ShutdownHandle {
    /// Prints just the type name; the channel internals are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("ShutdownHandle")
    }
}
/// Awaitable shutdown notification, usually derived from a [`ShutdownHandle`] via `From`.
///
/// Each value fires at most once. After [`ShutdownSignal::on_shutdown`] completes, the value
/// becomes `Signalled` and later calls never complete. Clones carry their own state, so each
/// clone fires independently of the others.
#[derive(Clone)]
pub enum ShutdownSignal {
/// Not fired yet. Holds a `watch` receiver, typically the one created by
/// `ShutdownCoordinator::new`.
WaitingForSignal(watch::Receiver<bool>),
/// Already fired, or never armed (this is the `Default` value); `on_shutdown` pends forever.
Signalled,
}
impl ShutdownSignal {
    /// Resolves once shutdown has been signalled, and never again after that.
    ///
    /// While `WaitingForSignal`, this waits on `watch::Receiver::changed`. An error from
    /// `changed()` means the sending side was dropped, for example when the coordinator is
    /// dropped without calling `shutdown`; that is treated as a shutdown too. On completion
    /// `self` becomes `Signalled`. If the future is dropped before completing, `self` is left
    /// unchanged.
    ///
    /// While `Signalled` (including the `Default` value), the returned future never completes.
    pub async fn on_shutdown(&mut self) {
        if let ShutdownSignal::WaitingForSignal(receiver) = self {
            // Either outcome (value changed, or sender gone) counts as "shut down".
            let _ = receiver.changed().await;
            *self = ShutdownSignal::Signalled;
            return;
        }
        // Already fired (or never armed): park forever so the signal cannot fire twice.
        futures::future::pending::<()>().await;
    }
}
impl Default for ShutdownSignal {
fn default() -> Self {
ShutdownSignal::Signalled
}
}
impl From<&ShutdownHandle> for ShutdownSignal {
fn from(handle: &ShutdownHandle) -> Self {
handle.cancellation_rx.clone()
}
}
/// Owns the sending side of the shutdown protocol and hands out [`ShutdownHandle`]s.
///
/// [`ShutdownCoordinator::shutdown`] broadcasts cancellation to every [`ShutdownSignal`]
/// derived from one of its handles, then waits until every handle has been dropped.
pub struct ShutdownCoordinator {
/// The coordinator's own strong reference. It is dropped at the start of `shutdown`, so
/// only externally held handles keep the wait pending.
shutdown_handle: Arc<ShutdownHandle>,
/// Sends `true` to every receiver wrapped in `ShutdownSignal::WaitingForSignal`.
cancellation_tx: watch::Sender<bool>,
/// `recv()` yields `None` once every handle's sender has been dropped.
shutdown_rx: mpsc::Receiver<()>,
}
impl ShutdownCoordinator {
    /// Creates a coordinator that is not shutting down yet.
    ///
    /// Two channels are created:
    /// - a `watch` channel (initially `false`), whose receiver becomes the handle's
    ///   [`ShutdownSignal`];
    /// - an `mpsc` channel, whose only sender lives inside that same [`ShutdownHandle`].
    ///
    /// The coordinator keeps one `Arc` to the handle itself.
    pub fn new() -> Self {
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
        let (cancellation_tx, cancellation_rx) = watch::channel(false);
        let handle = ShutdownHandle {
            cancellation_rx: ShutdownSignal::WaitingForSignal(cancellation_rx),
            _shutdown_tx: shutdown_tx,
        };
        Self {
            shutdown_handle: Arc::new(handle),
            cancellation_tx,
            shutdown_rx,
        }
    }

    /// Returns another strong reference to the shared handle.
    ///
    /// [`ShutdownCoordinator::shutdown`] does not complete while any such reference is alive.
    pub fn handle(&self) -> Arc<ShutdownHandle> {
        Arc::clone(&self.shutdown_handle)
    }

    /// Returns a weak reference to the shared handle.
    ///
    /// A weak reference does not keep the handle alive, so it does not hold up
    /// [`ShutdownCoordinator::shutdown`] unless it is upgraded.
    pub fn handle_weak(&self) -> Weak<ShutdownHandle> {
        Arc::downgrade(&self.shutdown_handle)
    }

    /// Signals shutdown, then waits until every [`ShutdownHandle`] has been dropped.
    ///
    /// The cancellation is broadcast before any waiting. The coordinator then releases its
    /// own handle and waits for the `mpsc` channel to close, which happens when the last
    /// handle, and with it the last sender, goes away.
    pub async fn shutdown(mut self) {
        // An error only means no receivers are left, i.e. there is nobody to notify.
        let _ = self.cancellation_tx.send(true);
        // Without this, the coordinator's own sender would keep the channel open forever
        // and the wait below could never finish.
        drop(self.shutdown_handle);
        // Nothing in this module sends on the channel. `recv` resolves with `None` once all
        // senders have been dropped.
        let _ = self.shutdown_rx.recv().await;
    }

    /// Like [`ShutdownCoordinator::shutdown`], but stops waiting after `timeout` **seconds**.
    ///
    /// The caller is not told whether the wait finished or timed out.
    pub async fn shutdown_with_timeout(self, timeout: u64) {
        let limit = tokio::time::Duration::from_secs(timeout);
        let _ = tokio::time::timeout(limit, self.shutdown()).await;
    }
}
impl Default for ShutdownCoordinator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{pin_mut, FutureExt};

    /// `shutdown` finishes immediately when no handles are outstanding; otherwise it waits
    /// until the last handle is dropped.
    #[tokio::test]
    async fn test_shutdown_coordinator() {
        assert!(ShutdownCoordinator::new()
            .shutdown()
            .now_or_never()
            .is_some());
        let sc = ShutdownCoordinator::new();
        let handle = sc.handle();
        let shutdown_fut = sc.shutdown();
        pin_mut!(shutdown_fut);
        assert!(shutdown_fut.as_mut().now_or_never().is_none());
        drop(handle);
        assert!(shutdown_fut.now_or_never().is_some());
    }

    /// A signal derived from a live handle fires once shutdown starts, and never again.
    #[tokio::test]
    async fn test_signal_fires_once_on_shutdown() {
        let sc = ShutdownCoordinator::new();
        let handle = sc.handle();
        let mut signal = ShutdownSignal::from(handle.as_ref());
        // Nothing has been broadcast yet.
        assert!(signal.on_shutdown().now_or_never().is_none());

        let shutdown_fut = sc.shutdown();
        pin_mut!(shutdown_fut);
        // First poll broadcasts cancellation, then blocks on the outstanding handle.
        assert!(shutdown_fut.as_mut().now_or_never().is_none());

        assert!(signal.on_shutdown().now_or_never().is_some());
        // Fire-once: the signal is now `Signalled` and stays quiet.
        assert!(signal.on_shutdown().now_or_never().is_none());

        drop(handle);
        assert!(shutdown_fut.now_or_never().is_some());
    }

    /// Dropping the coordinator without calling `shutdown` also releases waiting signals,
    /// because `on_shutdown` treats a closed watch channel as a shutdown.
    #[tokio::test]
    async fn test_signal_fires_when_coordinator_dropped() {
        let sc = ShutdownCoordinator::new();
        let handle = sc.handle();
        let mut signal = ShutdownSignal::from(handle.as_ref());
        drop(sc);
        assert!(signal.on_shutdown().now_or_never().is_some());
    }

    /// A detached (default) handle's signal never fires.
    #[tokio::test]
    async fn test_default_shutdown_handle() {
        let handle = ShutdownHandle::default();
        let mut signal = ShutdownSignal::from(&handle);
        assert!(signal.on_shutdown().now_or_never().is_none());
    }
}