use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, Notify};
/// Cloneable, read-side view of a [`ShutdownHandle`]'s shutdown state.
///
/// Tokens are produced by [`ShutdownHandle::token`] and share the handle's broadcast
/// sender and shutdown flag through `Arc`, so every token observes the same state.
#[derive(Clone, Debug)]
pub struct ShutdownToken {
    /// Shared broadcast sender; the token only uses it to hand out fresh receivers.
    shutdown_tx: Arc<broadcast::Sender<()>>,
    /// Shared flag that the owning handle sets when shutdown is signalled.
    is_shutdown: Arc<AtomicBool>,
}
impl ShutdownToken {
    /// Returns `true` once the owning [`ShutdownHandle`] has signalled shutdown.
    ///
    /// The flag is shared with the handle, so all tokens see the same value.
    pub fn is_shutdown(&self) -> bool {
        self.is_shutdown.load(Ordering::SeqCst)
    }

    /// Returns a new receiver on the shutdown broadcast channel.
    ///
    /// A broadcast receiver only sees messages sent *after* it was created. A receiver
    /// obtained after shutdown was already signalled will therefore never yield anything.
    /// Check [`Self::is_shutdown`] after subscribing, or use [`Self::cancelled`], which
    /// does exactly that.
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Completes once shutdown has been signalled, including when it already had been
    /// before this call.
    ///
    /// Unlike awaiting a receiver from [`Self::subscribe`] directly, this cannot miss
    /// the signal: it subscribes *before* reading the flag. The handle sets the flag
    /// before broadcasting, so either the flag is already visible here or the broadcast
    /// arrives after the subscription exists.
    pub async fn cancelled(&self) {
        let mut rx = self.shutdown_tx.subscribe();
        if self.is_shutdown() {
            return;
        }
        // Every outcome means shutdown was signalled. `Ok(())` is the signal itself.
        // `Lagged` means signals were sent but overwritten in the capacity-limited buffer.
        // `Closed` cannot occur while this token still holds a sender clone.
        let _ = rx.recv().await;
    }
}
/// Owning side of the shutdown signal: triggers shutdown and hands out [`ShutdownToken`]s.
///
/// All state lives behind `Arc`, so clones of a handle drive the same shared signal.
#[derive(Clone, Debug)]
pub struct ShutdownHandle {
    /// Broadcast channel that carries the `()` shutdown signal to subscribers.
    shutdown_tx: Arc<broadcast::Sender<()>>,
    /// `Notify` exposed through `notify()`; fired when shutdown is signalled.
    notify: Arc<Notify>,
    /// Flag set to `true` on shutdown; shared with every token created from this handle.
    is_shutdown: Arc<AtomicBool>,
}
impl GracefulShutdown for ShutdownHandle {
    /// Signals shutdown to all tokens and to tasks currently waiting on [`Self::notify`].
    fn shutdown(&self) {
        // Set the flag before broadcasting. A task that subscribes and then checks the
        // flag therefore either sees `true` or is subscribed in time to get the message.
        self.is_shutdown.store(true, Ordering::SeqCst);
        // `send` only errors when no receiver is currently subscribed. Nobody is
        // listening in that case, so the error carries no information worth acting on.
        let _ = self.shutdown_tx.send(());
        // `notify_waiters` wakes only tasks already awaiting and stores no permit.
        // Late waiters must consult the shutdown flag instead.
        self.notify.notify_waiters();
    }

    /// Reports whether [`GracefulShutdown::shutdown`] has been called on this handle
    /// or any of its clones.
    ///
    /// This override is required. The trait's provided default always returns `false`,
    /// which would make the handle (and `dyn GracefulShutdown` views of it) claim it is
    /// still running after shutdown.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown.load(Ordering::SeqCst)
    }
}
impl ShutdownHandle {
    /// Creates a handle whose shutdown flag starts cleared.
    ///
    /// The broadcast channel gets capacity 1 because only the unit `()` signal is ever
    /// sent. Repeated `shutdown()` calls overwrite that single slot. The channel's
    /// initial receiver is discarded; listeners obtain their own through
    /// [`ShutdownToken::subscribe`].
    pub fn new() -> Self {
        let (sender, _) = broadcast::channel(1);
        let flag = AtomicBool::new(false);
        Self {
            is_shutdown: Arc::new(flag),
            notify: Arc::new(Notify::new()),
            shutdown_tx: Arc::new(sender),
        }
    }

    /// Creates a [`ShutdownToken`] sharing this handle's broadcast sender and flag.
    pub fn token(&self) -> ShutdownToken {
        let shutdown_tx = Arc::clone(&self.shutdown_tx);
        let is_shutdown = Arc::clone(&self.is_shutdown);
        ShutdownToken { shutdown_tx, is_shutdown }
    }

    /// Returns a shared pointer to the [`Notify`] that `shutdown()` fires.
    ///
    /// Shutdown uses `notify_waiters`, which wakes only tasks already waiting and stores
    /// no permit. A task that starts waiting after shutdown will not be woken, so also
    /// check the shutdown flag (for example via a token).
    pub fn notify(&self) -> Arc<Notify> {
        self.notify.clone()
    }
}
impl Default for ShutdownHandle {
fn default() -> Self {
Self::new()
}
}
/// A component that can be told to shut down; implementors must be shareable across threads.
pub trait GracefulShutdown: Sync + Send {
    /// Signals shutdown.
    fn shutdown(&self);

    /// Reports whether shutdown has been signalled.
    ///
    /// The provided default unconditionally answers `false` and is unaffected by
    /// `shutdown()`. Implementors that track shutdown state must override it.
    fn is_shutdown(&self) -> bool {
        false
    }
}
/// Lifecycle phase enum, presumably for a component that supports graceful shutdown.
///
/// NOTE(review): nothing in this file reads or writes `ShutdownState`. The per-variant
/// meanings below are inferred from the names; confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation. This is the `Default` value.
    #[default]
    Running,
    /// Presumably: shutdown requested, teardown still in progress.
    ShuttingDown,
    /// Presumably: teardown finished.
    Shutdown,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_shutdown_token() {
        let handle = ShutdownHandle::new();

        let before = handle.token();
        assert!(
            !before.is_shutdown(),
            "should not be shutdown before calling shutdown()"
        );

        handle.shutdown();

        // Tokens created after the signal must observe it immediately and keep observing it.
        let after = handle.token();
        assert!(after.is_shutdown(), "should be shutdown after calling shutdown()");
        let later = handle.token();
        assert!(later.is_shutdown(), "should remain shutdown");
    }

    #[tokio::test]
    async fn test_multiple_tokens() {
        let handle = ShutdownHandle::new();
        let tokens = [handle.token(), handle.token()];

        assert!(tokens.iter().all(|t| !t.is_shutdown()));
        handle.shutdown();
        assert!(tokens.iter().all(ShutdownToken::is_shutdown));
    }
}