use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::timeout;
/// Coordinates a graceful shutdown.
///
/// It tracks in-flight requests and rejects new ones once shutdown has been
/// initiated. It also lets a caller wait, bounded by a timeout, for the
/// in-flight count to drain to zero.
#[derive(Debug)]
pub struct GracefulShutdown {
    /// Set by [`GracefulShutdown::initiate`]; never cleared.
    shutting_down: AtomicBool,
    /// Number of requests currently holding a [`RequestGuard`].
    ///
    /// This can also include, transiently, a `start_request` call that is
    /// about to back out.
    in_flight: AtomicUsize,
    /// Woken via `notify_waiters` when the in-flight count drains to zero
    /// during shutdown.
    ///
    /// Only ever borrowed through `&self`, so it is stored inline rather
    /// than behind an `Arc`.
    all_complete: Notify,
    /// Upper bound on how long `wait_for_completion` waits for draining.
    shutdown_timeout: Duration,
    /// Set when `wait_for_completion` returns, whether clean or timed out.
    completed: AtomicBool,
}

impl GracefulShutdown {
    /// Creates a coordinator that waits at most `shutdown_timeout` for
    /// in-flight requests to finish once shutdown begins.
    pub fn new(shutdown_timeout: Duration) -> Self {
        Self {
            shutting_down: AtomicBool::new(false),
            in_flight: AtomicUsize::new(0),
            all_complete: Notify::new(),
            shutdown_timeout,
            completed: AtomicBool::new(false),
        }
    }

    /// Creates a coordinator with a 30-second drain timeout.
    pub fn default_timeout() -> Self {
        Self::new(Duration::from_secs(30))
    }

    /// Returns `true` once shutdown has been initiated.
    pub fn is_shutting_down(&self) -> bool {
        self.shutting_down.load(Ordering::SeqCst)
    }

    /// Returns the current number of in-flight requests.
    pub fn in_flight_count(&self) -> usize {
        self.in_flight.load(Ordering::SeqCst)
    }

    /// Registers a new in-flight request.
    ///
    /// Returns `None` if shutdown has already been initiated. Otherwise it
    /// returns a guard that marks the request finished when dropped.
    #[must_use = "dropping the guard immediately marks the request as finished"]
    pub fn start_request(&self) -> Option<RequestGuard<'_>> {
        if self.is_shutting_down() {
            return None;
        }
        self.in_flight.fetch_add(1, Ordering::SeqCst);
        // Re-check after incrementing. If `initiate` ran between the first
        // check and the increment, a concurrent `wait_for_completion` may
        // already have observed a zero count, so this request must back out.
        if self.is_shutting_down() {
            self.finish_request();
            return None;
        }
        Some(RequestGuard { shutdown: self })
    }

    /// Decrements the in-flight count.
    ///
    /// If this was the last request and shutdown is in progress, it also
    /// wakes everyone waiting in `wait_for_completion`.
    fn finish_request(&self) {
        let prev = self.in_flight.fetch_sub(1, Ordering::SeqCst);
        if prev == 1 && self.is_shutting_down() {
            self.all_complete.notify_waiters();
        }
    }

    /// Marks shutdown as started, so that `start_request` rejects new
    /// requests.
    ///
    /// If nothing is in flight, current waiters are woken immediately.
    /// Calling this more than once is harmless.
    pub fn initiate(&self) {
        self.shutting_down.store(true, Ordering::SeqCst);
        if self.in_flight.load(Ordering::SeqCst) == 0 {
            self.all_complete.notify_waiters();
        }
    }

    /// Initiates shutdown (if not already started) and waits up to the
    /// configured timeout for all in-flight requests to finish.
    ///
    /// Returns [`ShutdownResult::Clean`] if the in-flight count reaches zero.
    /// Otherwise it returns [`ShutdownResult::Timeout`] with the number of
    /// requests still running. Either way, `is_completed` is `true`
    /// afterwards.
    pub async fn wait_for_completion(&self) -> ShutdownResult {
        if !self.is_shutting_down() {
            self.initiate();
        }
        // Create the `Notified` future *before* reading the count. Tokio
        // guarantees it observes every `notify_waiters` call made after its
        // creation, even before it is first polled. Creating it after the
        // check would lose the wake-up if the last request finished in
        // between. A clean drain would then turn into a full-timeout wait.
        let drained = self.all_complete.notified();

        let result = if self.in_flight.load(Ordering::SeqCst) == 0 {
            ShutdownResult::Clean
        } else {
            match timeout(self.shutdown_timeout, drained).await {
                Ok(()) => ShutdownResult::Clean,
                // Re-read the count: requests that finished right at the
                // deadline should not be reported as a timeout with zero
                // remaining.
                Err(_) => match self.in_flight.load(Ordering::SeqCst) {
                    0 => ShutdownResult::Clean,
                    remaining => ShutdownResult::Timeout {
                        remaining_requests: remaining,
                    },
                },
            }
        };
        self.completed.store(true, Ordering::SeqCst);
        result
    }

    /// Returns `true` once `wait_for_completion` has returned.
    pub fn is_completed(&self) -> bool {
        self.completed.load(Ordering::SeqCst)
    }

    /// Returns a snapshot of the current shutdown state.
    ///
    /// The three fields are read separately, not as one atomic snapshot.
    pub fn status(&self) -> ShutdownStatus {
        ShutdownStatus {
            shutting_down: self.is_shutting_down(),
            in_flight: self.in_flight_count(),
            completed: self.is_completed(),
        }
    }
}
/// Outcome of waiting for in-flight requests to drain during shutdown.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShutdownResult {
    /// The in-flight count was observed to reach zero.
    Clean,
    /// The wait gave up after its deadline.
    Timeout {
        /// In-flight count read when the wait gave up.
        remaining_requests: usize,
    },
}

impl ShutdownResult {
    /// Returns `true` for [`ShutdownResult::Clean`] and `false` for a
    /// timeout.
    pub fn is_clean(&self) -> bool {
        match self {
            Self::Clean => true,
            Self::Timeout { .. } => false,
        }
    }
}
pub struct RequestGuard<'a> {
shutdown: &'a GracefulShutdown,
}
impl Drop for RequestGuard<'_> {
fn drop(&mut self) {
self.shutdown.finish_request();
}
}
/// Snapshot of a shutdown coordinator's state, as returned by
/// `GracefulShutdown::status`.
///
/// Derives `PartialEq`/`Eq` (like `ShutdownResult`) so snapshots can be
/// compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShutdownStatus {
    /// Whether shutdown has been initiated.
    pub shutting_down: bool,
    /// Number of requests in flight when the snapshot was taken.
    pub in_flight: usize,
    /// Whether the wait for in-flight requests has returned.
    pub completed: bool,
}
/// Waits until the process receives SIGTERM or SIGINT, and logs which one
/// arrived.
///
/// # Panics
/// Panics if either signal handler cannot be installed.
#[cfg(unix)]
pub async fn wait_for_shutdown_signal() {
    use tokio::signal::unix::{signal, SignalKind};

    let mut terminate =
        signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
    let mut interrupt =
        signal(SignalKind::interrupt()).expect("Failed to install SIGINT handler");

    tokio::select! {
        _ = terminate.recv() => tracing::info!("Received SIGTERM, initiating graceful shutdown"),
        _ = interrupt.recv() => tracing::info!("Received SIGINT, initiating graceful shutdown"),
    }
}
/// Waits until the process receives Ctrl+C, then logs it.
///
/// # Panics
/// Panics if listening for Ctrl+C fails.
#[cfg(windows)]
pub async fn wait_for_shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("Failed to listen for Ctrl+C");
    tracing::info!("Received Ctrl+C, initiating graceful shutdown");
}
pub async fn run_with_graceful_shutdown<F, Fut>(
shutdown: Arc<GracefulShutdown>,
main_task: F,
) -> ShutdownResult
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = ()>,
{
let shutdown_clone = shutdown.clone();
tokio::spawn(async move {
wait_for_shutdown_signal().await;
shutdown_clone.initiate();
});
main_task().await;
shutdown.wait_for_completion().await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_clean_shutdown() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        let guard1 = shutdown.start_request().unwrap();
        let guard2 = shutdown.start_request().unwrap();
        assert_eq!(shutdown.in_flight_count(), 2);
        shutdown.initiate();
        assert!(shutdown.start_request().is_none());
        drop(guard1);
        assert_eq!(shutdown.in_flight_count(), 1);
        drop(guard2);
        assert_eq!(shutdown.in_flight_count(), 0);
        let result = shutdown.wait_for_completion().await;
        assert_eq!(result, ShutdownResult::Clean);
    }

    #[tokio::test]
    async fn test_shutdown_timeout() {
        let shutdown = GracefulShutdown::new(Duration::from_millis(100));
        let _guard = shutdown.start_request().unwrap();
        shutdown.initiate();
        let result = shutdown.wait_for_completion().await;
        assert!(matches!(result, ShutdownResult::Timeout { remaining_requests: 1 }));
    }

    #[tokio::test]
    async fn test_empty_shutdown() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        let result = shutdown.wait_for_completion().await;
        assert_eq!(result, ShutdownResult::Clean);
    }

    /// Covers the notification path of `wait_for_completion`.
    ///
    /// The waiter is parked while a request is still in flight, and the last
    /// guard is dropped afterwards. The waiter must then wake with a clean
    /// result instead of running into the 5 s timeout.
    #[tokio::test]
    async fn test_wakes_when_last_request_finishes_during_wait() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        let guard = shutdown.start_request().unwrap();
        let finish_later = async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            drop(guard);
        };
        let started = std::time::Instant::now();
        let (result, ()) = tokio::join!(shutdown.wait_for_completion(), finish_later);
        assert_eq!(result, ShutdownResult::Clean);
        assert!(started.elapsed() < Duration::from_secs(5));
        assert_eq!(shutdown.in_flight_count(), 0);
        assert!(shutdown.is_completed());
    }

    #[tokio::test]
    async fn test_status_after_completion() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        assert_eq!(shutdown.wait_for_completion().await, ShutdownResult::Clean);
        let status = shutdown.status();
        assert!(status.shutting_down);
        assert_eq!(status.in_flight, 0);
        assert!(status.completed);
    }

    #[test]
    fn test_status() {
        let shutdown = GracefulShutdown::default_timeout();
        let status = shutdown.status();
        assert!(!status.shutting_down);
        assert_eq!(status.in_flight, 0);
        assert!(!status.completed);
    }
}