use asupersync::{Budget, CancelReason};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ShutdownPhase {
Running = 0,
StopAccepting = 1,
ShutdownFlagged = 2,
GracePeriod = 3,
Cancelling = 4,
RunningHooks = 5,
Stopped = 6,
}
impl ShutdownPhase {
#[must_use]
pub fn should_reject_connections(self) -> bool {
self as u8 >= Self::StopAccepting as u8
}
#[must_use]
pub fn should_reject_requests(self) -> bool {
self as u8 >= Self::ShutdownFlagged as u8
}
#[must_use]
pub fn is_shutting_down(self) -> bool {
self as u8 >= Self::StopAccepting as u8
}
#[must_use]
pub fn is_stopped(self) -> bool {
self == Self::Stopped
}
}
impl From<u8> for ShutdownPhase {
fn from(value: u8) -> Self {
match value {
0 => Self::Running,
1 => Self::StopAccepting,
2 => Self::ShutdownFlagged,
3 => Self::GracePeriod,
4 => Self::Cancelling,
5 => Self::RunningHooks,
_ => Self::Stopped,
}
}
}
struct ShutdownState {
phase: AtomicU8,
forced: AtomicBool,
wakers: parking_lot::Mutex<Vec<Waker>>,
hooks: parking_lot::Mutex<Vec<ShutdownHook>>,
in_flight: std::sync::atomic::AtomicUsize,
}
impl ShutdownState {
fn new() -> Self {
Self {
phase: AtomicU8::new(ShutdownPhase::Running as u8),
forced: AtomicBool::new(false),
wakers: parking_lot::Mutex::new(Vec::new()),
hooks: parking_lot::Mutex::new(Vec::new()),
in_flight: std::sync::atomic::AtomicUsize::new(0),
}
}
fn phase(&self) -> ShutdownPhase {
ShutdownPhase::from(self.phase.load(Ordering::Acquire))
}
fn set_phase(&self, phase: ShutdownPhase) {
self.phase.store(phase as u8, Ordering::Release);
self.wake_all();
}
fn try_advance_phase(&self, from: ShutdownPhase, to: ShutdownPhase) -> bool {
self.phase
.compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
}
fn is_forced(&self) -> bool {
self.forced.load(Ordering::Acquire)
}
fn set_forced(&self) {
self.forced.store(true, Ordering::Release);
self.wake_all();
}
fn wake_all(&self) {
let wakers = std::mem::take(&mut *self.wakers.lock());
for waker in wakers {
waker.wake();
}
}
fn register_waker(&self, waker: &Waker) {
let mut wakers = self.wakers.lock();
if !wakers.iter().any(|w| w.will_wake(waker)) {
wakers.push(waker.clone());
}
}
fn increment_in_flight(&self) -> usize {
self.in_flight.fetch_add(1, Ordering::AcqRel) + 1
}
fn decrement_in_flight(&self) -> usize {
self.in_flight.fetch_sub(1, Ordering::AcqRel) - 1
}
fn in_flight_count(&self) -> usize {
self.in_flight.load(Ordering::Acquire)
}
}
#[derive(Clone)]
pub struct ShutdownController {
state: Arc<ShutdownState>,
}
impl ShutdownController {
#[must_use]
pub fn new() -> Self {
Self {
state: Arc::new(ShutdownState::new()),
}
}
#[must_use]
pub fn subscribe(&self) -> ShutdownReceiver {
ShutdownReceiver {
state: Arc::clone(&self.state),
}
}
#[must_use]
pub fn phase(&self) -> ShutdownPhase {
self.state.phase()
}
#[must_use]
pub fn is_shutting_down(&self) -> bool {
self.state.phase().is_shutting_down()
}
#[must_use]
pub fn is_forced(&self) -> bool {
self.state.is_forced()
}
pub fn shutdown(&self) {
let current = self.state.phase();
if current == ShutdownPhase::Running {
self.state.set_phase(ShutdownPhase::StopAccepting);
} else if !self.state.is_forced() {
self.state.set_forced();
}
}
pub fn force_shutdown(&self) {
self.state.set_forced();
self.state.set_phase(ShutdownPhase::Cancelling);
}
pub fn advance_phase(&self) -> bool {
let current = self.state.phase();
let next = match current {
ShutdownPhase::Running => ShutdownPhase::StopAccepting,
ShutdownPhase::StopAccepting => ShutdownPhase::ShutdownFlagged,
ShutdownPhase::ShutdownFlagged => ShutdownPhase::GracePeriod,
ShutdownPhase::GracePeriod => ShutdownPhase::Cancelling,
ShutdownPhase::Cancelling => ShutdownPhase::RunningHooks,
ShutdownPhase::RunningHooks => ShutdownPhase::Stopped,
ShutdownPhase::Stopped => return false,
};
self.state.try_advance_phase(current, next)
}
pub fn register_hook<F>(&self, hook: F)
where
F: FnOnce() + Send + 'static,
{
let mut hooks = self.state.hooks.lock();
hooks.push(ShutdownHook::Sync(Box::new(hook)));
}
pub fn register_async_hook<F, Fut>(&self, hook: F)
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let mut hooks = self.state.hooks.lock();
hooks.push(ShutdownHook::AsyncFactory(Box::new(move || {
Box::pin(hook())
})));
}
pub fn pop_hook(&self) -> Option<ShutdownHook> {
let mut hooks = self.state.hooks.lock();
hooks.pop()
}
#[must_use]
pub fn hook_count(&self) -> usize {
self.state.hooks.lock().len()
}
#[must_use]
pub fn track_request(&self) -> InFlightGuard {
self.state.increment_in_flight();
InFlightGuard {
state: Arc::clone(&self.state),
}
}
#[must_use]
pub fn in_flight_count(&self) -> usize {
self.state.in_flight_count()
}
}
impl Default for ShutdownController {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct ShutdownReceiver {
state: Arc<ShutdownState>,
}
impl ShutdownReceiver {
pub async fn wait(&self) {
ShutdownWaitFuture { state: &self.state }.await
}
#[must_use]
pub fn phase(&self) -> ShutdownPhase {
self.state.phase()
}
#[must_use]
pub fn is_shutting_down(&self) -> bool {
self.state.phase().is_shutting_down()
}
#[must_use]
pub fn is_forced(&self) -> bool {
self.state.is_forced()
}
}
struct ShutdownWaitFuture<'a> {
state: &'a ShutdownState,
}
impl Future for ShutdownWaitFuture<'_> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.state.phase().is_shutting_down() {
Poll::Ready(())
} else {
self.state.register_waker(cx.waker());
if self.state.phase().is_shutting_down() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
}
pub struct InFlightGuard {
state: Arc<ShutdownState>,
}
impl Drop for InFlightGuard {
fn drop(&mut self) {
self.state.decrement_in_flight();
}
}
pub enum ShutdownHook {
Sync(Box<dyn FnOnce() + Send>),
AsyncFactory(Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>),
}
impl ShutdownHook {
pub fn run(self) -> Option<Pin<Box<dyn Future<Output = ()> + Send>>> {
match self {
Self::Sync(f) => {
f();
None
}
Self::AsyncFactory(f) => Some(f()),
}
}
}
#[derive(Clone)]
pub struct GracefulConfig {
pub grace_period: Duration,
pub cleanup_budget: Budget,
pub log_events: bool,
}
impl Default for GracefulConfig {
fn default() -> Self {
Self {
grace_period: Duration::from_secs(30),
cleanup_budget: Budget::new()
.with_poll_quota(500)
.with_deadline(asupersync::Time::from_secs(5)),
log_events: true,
}
}
}
pub struct GracefulShutdown {
receiver: ShutdownReceiver,
config: GracefulConfig,
}
impl GracefulShutdown {
#[must_use]
pub fn new(receiver: ShutdownReceiver) -> Self {
Self {
receiver,
config: GracefulConfig::default(),
}
}
#[must_use]
pub fn grace_period(mut self, duration: Duration) -> Self {
self.config.grace_period = duration;
self
}
#[must_use]
pub fn cleanup_budget(mut self, budget: Budget) -> Self {
self.config.cleanup_budget = budget;
self
}
#[must_use]
pub fn log_events(mut self, enabled: bool) -> Self {
self.config.log_events = enabled;
self
}
pub async fn run<F, T>(self, fut: F) -> ShutdownOutcome<T>
where
F: Future<Output = T>,
{
use std::pin::pin;
use std::task::Poll;
let mut fut = pin!(fut);
std::future::poll_fn(|cx| {
if let Poll::Ready(v) = fut.as_mut().poll(cx) {
return Poll::Ready(ShutdownOutcome::Completed(v));
}
if self.receiver.state.is_forced() {
return Poll::Ready(ShutdownOutcome::ForcedShutdown);
}
if self.receiver.state.phase().is_shutting_down() {
return Poll::Ready(ShutdownOutcome::GracefulShutdown);
}
self.receiver.state.register_waker(cx.waker());
if self.receiver.state.is_forced() {
Poll::Ready(ShutdownOutcome::ForcedShutdown)
} else if self.receiver.state.phase().is_shutting_down() {
Poll::Ready(ShutdownOutcome::GracefulShutdown)
} else {
Poll::Pending
}
})
.await
}
#[must_use]
pub fn config(&self) -> &GracefulConfig {
&self.config
}
}
#[derive(Debug)]
pub enum ShutdownOutcome<T> {
Completed(T),
GracefulShutdown,
ForcedShutdown,
}
impl<T> ShutdownOutcome<T> {
#[must_use]
pub fn is_completed(&self) -> bool {
matches!(self, Self::Completed(_))
}
#[must_use]
pub fn is_shutdown(&self) -> bool {
matches!(self, Self::GracefulShutdown | Self::ForcedShutdown)
}
#[must_use]
pub fn is_forced(&self) -> bool {
matches!(self, Self::ForcedShutdown)
}
#[must_use]
pub fn into_completed(self) -> Option<T> {
match self {
Self::Completed(v) => Some(v),
_ => None,
}
}
}
#[must_use]
pub fn subdivide_grace_budget(
grace_remaining: Duration,
in_flight_count: usize,
original_budget: Option<Budget>,
) -> Budget {
use asupersync::Time;
let count = in_flight_count.max(1);
let per_request = grace_remaining / count as u32;
let deadline_nanos = per_request.as_nanos() as u64;
let grace_budget = Budget::new().with_deadline(Time::from_nanos(deadline_nanos));
match original_budget {
Some(original) => original.meet(grace_budget),
None => grace_budget,
}
}
#[must_use]
pub fn shutdown_cancel_reason() -> CancelReason {
CancelReason::shutdown()
}
#[must_use]
pub fn grace_expired_cancel_reason() -> CancelReason {
CancelReason::timeout()
}
pub trait ShutdownAware {
fn is_shutting_down(&self) -> bool;
fn shutdown_phase(&self) -> Option<ShutdownPhase>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn shutdown_phase_transitions() {
assert!(!ShutdownPhase::Running.should_reject_connections());
assert!(ShutdownPhase::StopAccepting.should_reject_connections());
assert!(ShutdownPhase::ShutdownFlagged.should_reject_requests());
assert!(ShutdownPhase::GracePeriod.is_shutting_down());
assert!(ShutdownPhase::Stopped.is_stopped());
}
#[test]
fn controller_basic() {
let controller = ShutdownController::new();
assert_eq!(controller.phase(), ShutdownPhase::Running);
assert!(!controller.is_shutting_down());
controller.shutdown();
assert_eq!(controller.phase(), ShutdownPhase::StopAccepting);
assert!(controller.is_shutting_down());
}
#[test]
fn controller_double_shutdown_forces() {
let controller = ShutdownController::new();
controller.shutdown();
assert!(!controller.is_forced());
controller.shutdown();
assert!(controller.is_forced());
}
#[test]
fn controller_advance_phase() {
let controller = ShutdownController::new();
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::StopAccepting);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::ShutdownFlagged);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::GracePeriod);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::Cancelling);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::RunningHooks);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), ShutdownPhase::Stopped);
assert!(!controller.advance_phase());
}
#[test]
fn in_flight_tracking() {
let controller = ShutdownController::new();
assert_eq!(controller.in_flight_count(), 0);
let guard1 = controller.track_request();
assert_eq!(controller.in_flight_count(), 1);
let guard2 = controller.track_request();
assert_eq!(controller.in_flight_count(), 2);
drop(guard1);
assert_eq!(controller.in_flight_count(), 1);
drop(guard2);
assert_eq!(controller.in_flight_count(), 0);
}
#[test]
fn shutdown_hooks_lifo() {
let controller = ShutdownController::new();
let order: Arc<parking_lot::Mutex<Vec<i32>>> =
Arc::new(parking_lot::Mutex::new(Vec::new()));
let order1 = Arc::clone(&order);
controller.register_hook(move || order1.lock().push(1));
let order2 = Arc::clone(&order);
controller.register_hook(move || order2.lock().push(2));
let order3 = Arc::clone(&order);
controller.register_hook(move || order3.lock().push(3));
assert_eq!(controller.hook_count(), 3);
while let Some(hook) = controller.pop_hook() {
hook.run();
}
assert_eq!(*order.lock(), vec![3, 2, 1]);
}
#[test]
fn subdivide_grace_budget_basic() {
let grace = Duration::from_secs(30);
let budget = subdivide_grace_budget(grace, 3, None);
assert!(budget.deadline.is_some());
}
#[test]
fn subdivide_grace_budget_respects_original() {
use asupersync::Time;
let grace = Duration::from_secs(30);
let original = Budget::new().with_deadline(Time::from_secs(5));
let budget = subdivide_grace_budget(grace, 3, Some(original));
assert!(budget.deadline.is_some());
}
#[test]
fn receiver_is_shutting_down() {
let controller = ShutdownController::new();
let receiver = controller.subscribe();
assert!(!receiver.is_shutting_down());
controller.shutdown();
assert!(receiver.is_shutting_down());
}
#[test]
fn graceful_shutdown_run_completed() {
let controller = ShutdownController::new();
let shutdown = GracefulShutdown::new(controller.subscribe());
let rt = asupersync::runtime::RuntimeBuilder::current_thread()
.build()
.expect("runtime must build");
let out = rt.block_on(async { shutdown.run(async { 42i32 }).await });
assert!(matches!(out, ShutdownOutcome::Completed(42)));
}
#[test]
fn graceful_shutdown_run_graceful_shutdown() {
let controller = ShutdownController::new();
controller.shutdown();
let shutdown = GracefulShutdown::new(controller.subscribe());
let rt = asupersync::runtime::RuntimeBuilder::current_thread()
.build()
.expect("runtime must build");
let out = rt.block_on(async { shutdown.run(std::future::pending::<i32>()).await });
assert!(matches!(out, ShutdownOutcome::GracefulShutdown));
}
#[test]
fn graceful_shutdown_run_forced_shutdown() {
let controller = ShutdownController::new();
controller.force_shutdown();
let shutdown = GracefulShutdown::new(controller.subscribe());
let rt = asupersync::runtime::RuntimeBuilder::current_thread()
.build()
.expect("runtime must build");
let out = rt.block_on(async { shutdown.run(std::future::pending::<i32>()).await });
assert!(matches!(out, ShutdownOutcome::ForcedShutdown));
}
#[test]
fn shutdown_outcome_accessors() {
let completed: ShutdownOutcome<i32> = ShutdownOutcome::Completed(42);
assert!(completed.is_completed());
assert!(!completed.is_shutdown());
assert_eq!(completed.into_completed(), Some(42));
let graceful: ShutdownOutcome<i32> = ShutdownOutcome::GracefulShutdown;
assert!(!graceful.is_completed());
assert!(graceful.is_shutdown());
assert!(!graceful.is_forced());
let forced: ShutdownOutcome<i32> = ShutdownOutcome::ForcedShutdown;
assert!(forced.is_shutdown());
assert!(forced.is_forced());
}
#[test]
fn cancel_reasons() {
let shutdown = shutdown_cancel_reason();
assert!(shutdown.is_shutdown());
let grace = grace_expired_cancel_reason();
assert!(!grace.is_shutdown());
}
}