use std::sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
};
use tokio::sync::Notify;
/// Shared handle for coordinating a graceful drain (optionally as part of a
/// restart) and for counting in-flight units of work.
///
/// Cloning is cheap: all clones share the same underlying state through an
/// [`Arc`].
#[derive(Clone, Debug, Default)]
pub struct ShutdownCoordinator {
    inner: Arc<ShutdownInner>,
}

/// State shared by every clone of a [`ShutdownCoordinator`] and by every
/// outstanding `InflightGuard`.
#[derive(Debug, Default)]
struct ShutdownInner {
    /// Set to `true` once a drain has begun.
    draining: AtomicBool,
    /// Wakes tasks parked in `ShutdownCoordinator::notified` when a drain begins.
    notify: Notify,
    /// Number of in-flight work guards currently alive.
    inflight: AtomicUsize,
    /// Set to `true` when the drain was requested as part of a restart.
    restart_requested: AtomicBool,
}
impl ShutdownCoordinator {
pub fn new() -> Self {
Self::default()
}
pub fn is_draining(&self) -> bool {
self.inner.draining.load(Ordering::Acquire)
}
pub fn begin_drain(&self) {
self.inner.draining.store(true, Ordering::Release);
self.inner.notify.notify_waiters();
}
pub async fn notified(&self) {
if self.is_draining() {
return;
}
let waiter = self.inner.notify.notified();
if self.is_draining() {
return;
}
waiter.await;
}
pub fn begin_work(&self) -> InflightGuard {
self.inner.inflight.fetch_add(1, Ordering::AcqRel);
InflightGuard {
inner: Arc::clone(&self.inner),
}
}
pub fn inflight(&self) -> usize {
self.inner.inflight.load(Ordering::Acquire)
}
pub fn request_restart(&self) {
self.inner.restart_requested.store(true, Ordering::Release);
self.begin_drain();
}
pub fn is_restart_requested(&self) -> bool {
self.inner.restart_requested.load(Ordering::Acquire)
}
}
/// RAII guard for one unit of in-flight work, returned by
/// [`ShutdownCoordinator::begin_work`].
///
/// The shared in-flight counter is decremented when this guard is dropped, so
/// it must be kept alive for as long as the work is running. `#[must_use]`
/// turns an accidental `coord.begin_work();` (which would count nothing) into
/// a compiler warning.
#[must_use = "the in-flight count is decremented as soon as the guard is dropped"]
pub struct InflightGuard {
    inner: Arc<ShutdownInner>,
}

// Hand-written so this impl does not require `ShutdownInner: Debug`.
impl std::fmt::Debug for InflightGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InflightGuard").finish_non_exhaustive()
    }
}

impl Drop for InflightGuard {
    fn drop(&mut self) {
        // Give back the slot taken in `ShutdownCoordinator::begin_work`.
        self.inner.inflight.fetch_sub(1, Ordering::AcqRel);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on every wait in these tests, so a lost wakeup fails the
    /// test instead of hanging the suite.
    const TIMEOUT: Duration = Duration::from_millis(100);

    #[tokio::test]
    async fn begin_drain_wakes_notified() {
        let coord = ShutdownCoordinator::new();
        let coord_clone = coord.clone();
        let waiter = tokio::spawn(async move { coord_clone.notified().await });
        // Give the spawned task a chance to start waiting before draining.
        tokio::task::yield_now().await;
        assert!(!coord.is_draining());
        coord.begin_drain();
        tokio::time::timeout(TIMEOUT, waiter)
            .await
            .expect("waiter woken by begin_drain")
            .expect("waiter ok");
        assert!(coord.is_draining());
    }

    #[tokio::test]
    async fn begin_drain_wakes_every_waiter() {
        let coord = ShutdownCoordinator::new();
        let waiters: Vec<_> = (0..3)
            .map(|_| {
                let c = coord.clone();
                tokio::spawn(async move { c.notified().await })
            })
            .collect();
        tokio::task::yield_now().await;
        coord.begin_drain();
        for waiter in waiters {
            tokio::time::timeout(TIMEOUT, waiter)
                .await
                .expect("every waiter woken by a single begin_drain")
                .expect("waiter ok");
        }
    }

    #[tokio::test]
    async fn notified_returns_immediately_if_already_draining() {
        let coord = ShutdownCoordinator::new();
        coord.begin_drain();
        tokio::time::timeout(TIMEOUT, coord.notified())
            .await
            .expect("notified returned");
    }

    #[tokio::test]
    async fn request_restart_sets_flag_and_begins_drain() {
        let coord = ShutdownCoordinator::new();
        assert!(!coord.is_draining());
        assert!(!coord.is_restart_requested());
        coord.request_restart();
        assert!(coord.is_draining(), "request_restart should also drain");
        assert!(coord.is_restart_requested());
        coord.request_restart();
        assert!(coord.is_restart_requested());
        tokio::time::timeout(TIMEOUT, coord.notified())
            .await
            .expect("notified after request_restart");
    }

    #[test]
    fn begin_drain_alone_does_not_set_restart_flag() {
        let coord = ShutdownCoordinator::new();
        coord.begin_drain();
        assert!(coord.is_draining());
        assert!(
            !coord.is_restart_requested(),
            "drain without restart must not set the restart flag"
        );
    }

    #[test]
    fn inflight_guard_decrements_on_drop() {
        let coord = ShutdownCoordinator::new();
        assert_eq!(coord.inflight(), 0);
        let g1 = coord.begin_work();
        let g2 = coord.begin_work();
        assert_eq!(coord.inflight(), 2);
        drop(g1);
        assert_eq!(coord.inflight(), 1);
        drop(g2);
        assert_eq!(coord.inflight(), 0);
    }

    #[test]
    fn inflight_is_shared_across_clones() {
        let coord = ShutdownCoordinator::new();
        let other = coord.clone();
        let guard = other.begin_work();
        assert_eq!(coord.inflight(), 1);
        drop(guard);
        assert_eq!(coord.inflight(), 0);
    }
}