use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
/// Grace period that [`ShutdownCoordinator::shutdown`] gives tasks to finish on
/// their own after the close-start token is cancelled, before still-running
/// registered tasks are aborted.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(500);

/// Upper bound on how long shutdown waits for each aborted task to finish
/// before moving on to the next one.
pub const TASK_ABORT_TIMEOUT: Duration = Duration::from_millis(100);
/// Coordinates a staged shutdown of background tasks.
///
/// Shutdown cancels `close_start`, then waits (bounded by a timeout) for the
/// active-task count to reach zero. It then aborts any registered tasks that
/// are still running, and finally cancels `close_complete`.
pub struct ShutdownCoordinator {
    /// Cancelled when shutdown begins.
    close_start: CancellationToken,
    /// Cancelled once shutdown has finished.
    close_complete: CancellationToken,
    /// Set by the first shutdown call; later calls observe it and return early.
    shutdown_initiated: AtomicBool,
    /// Number of tasks shutdown waits for. It is shared through an `Arc` with
    /// tasks started by `spawn_tracked` so they can decrement it when they finish.
    active_tasks: Arc<AtomicUsize>,
    /// Woken via `notify_waiters` when `active_tasks` drops to zero.
    tasks_complete: Arc<Notify>,
    /// Handles passed to `register_task`; drained and aborted during shutdown.
    task_handles: Mutex<Vec<JoinHandle<()>>>,
}
impl std::fmt::Debug for ShutdownCoordinator {
    /// Renders the shutdown flag and the current task count.
    ///
    /// Tokens, the notifier and the stored join handles are omitted and shown
    /// as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Relaxed` matches what the std `Debug` impls for atomics use
        // internally, so the rendered text is unchanged.
        let shutdown_initiated = self.shutdown_initiated.load(Ordering::Relaxed);
        let active_tasks = self.active_tasks.load(Ordering::Relaxed);

        let mut out = f.debug_struct("ShutdownCoordinator");
        out.field("shutdown_initiated", &shutdown_initiated);
        out.field("active_tasks", &active_tasks);
        out.finish_non_exhaustive()
    }
}
impl ShutdownCoordinator {
pub fn new() -> Arc<Self> {
Arc::new(Self {
close_start: CancellationToken::new(),
close_complete: CancellationToken::new(),
shutdown_initiated: AtomicBool::new(false),
active_tasks: Arc::new(AtomicUsize::new(0)),
tasks_complete: Arc::new(Notify::new()),
task_handles: Mutex::new(Vec::new()),
})
}
pub fn close_start_token(&self) -> CancellationToken {
self.close_start.clone()
}
pub fn close_complete_token(&self) -> CancellationToken {
self.close_complete.clone()
}
pub fn is_shutting_down(&self) -> bool {
self.shutdown_initiated.load(Ordering::SeqCst)
}
pub fn register_task(&self, handle: JoinHandle<()>) {
self.active_tasks.fetch_add(1, Ordering::SeqCst);
if let Ok(mut handles) = self.task_handles.lock() {
handles.push(handle);
}
}
pub fn spawn_tracked<F>(self: &Arc<Self>, future: F) -> JoinHandle<()>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let tasks_complete = Arc::clone(&self.tasks_complete);
let task_counter = Arc::clone(&self.active_tasks);
self.active_tasks.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
future.await;
if task_counter.fetch_sub(1, Ordering::SeqCst) == 1 {
tasks_complete.notify_waiters();
}
})
}
pub fn active_task_count(&self) -> usize {
self.active_tasks.load(Ordering::SeqCst)
}
pub async fn shutdown(&self) {
if self.shutdown_initiated.swap(true, Ordering::SeqCst) {
debug!("Shutdown already in progress");
return;
}
info!("Starting coordinated shutdown");
debug!("Stage 1: Signaling close start");
self.close_start.cancel();
debug!("Stage 2: Waiting for tasks to complete");
let wait_result = timeout(DEFAULT_SHUTDOWN_TIMEOUT, self.wait_for_tasks()).await;
if wait_result.is_err() {
warn!("Shutdown timeout - aborting remaining tasks");
}
debug!("Stage 3: Aborting remaining tasks");
self.abort_remaining_tasks().await;
debug!("Stage 4: Signaling close complete");
self.close_complete.cancel();
info!("Shutdown complete");
}
async fn wait_for_tasks(&self) {
while self.active_tasks.load(Ordering::SeqCst) > 0 {
self.tasks_complete.notified().await;
}
}
async fn abort_remaining_tasks(&self) {
let handles: Vec<_> = if let Ok(mut guard) = self.task_handles.lock() {
guard.drain(..).collect()
} else {
Vec::new()
};
for handle in handles {
if !handle.is_finished() {
handle.abort();
let _ = timeout(TASK_ABORT_TIMEOUT, async {
let _ = handle.await;
})
.await;
}
}
self.active_tasks.store(0, Ordering::SeqCst);
}
}
impl Default for ShutdownCoordinator {
    /// Builds a coordinator with fresh tokens, a zero task count, no stored
    /// handles, and no shutdown in progress.
    fn default() -> Self {
        let close_start = CancellationToken::new();
        let close_complete = CancellationToken::new();
        Self {
            close_start,
            close_complete,
            // `AtomicBool::default()` is `false`; `AtomicUsize::default()` is 0.
            shutdown_initiated: AtomicBool::default(),
            active_tasks: Arc::default(),
            tasks_complete: Arc::new(Notify::new()),
            task_handles: Mutex::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Creates a fresh coordinator and checks that the token returned by
    /// `pick` is live before shutdown and cancelled after it.
    async fn assert_token_cancelled_by_shutdown(
        pick: fn(&ShutdownCoordinator) -> CancellationToken,
    ) {
        let coord = ShutdownCoordinator::new();
        let token = pick(&*coord);
        assert!(!token.is_cancelled());
        coord.shutdown().await;
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_shutdown_completes_within_timeout() {
        let coord = ShutdownCoordinator::new();
        let began = Instant::now();
        coord.shutdown().await;
        let took = began.elapsed();
        assert!(took < DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let coord = ShutdownCoordinator::new();
        // Every call after the first must simply return.
        for _ in 0..3 {
            coord.shutdown().await;
        }
    }

    #[tokio::test]
    async fn test_is_shutting_down_flag() {
        let coord = ShutdownCoordinator::new();
        assert!(!coord.is_shutting_down());
        coord.shutdown().await;
        assert!(coord.is_shutting_down());
    }

    #[tokio::test]
    async fn test_close_start_token_cancelled() {
        assert_token_cancelled_by_shutdown(ShutdownCoordinator::close_start_token).await;
    }

    #[tokio::test]
    async fn test_close_complete_token_cancelled() {
        assert_token_cancelled_by_shutdown(ShutdownCoordinator::close_complete_token).await;
    }

    #[tokio::test]
    async fn test_spawn_tracked_increments_count() {
        let coord = ShutdownCoordinator::new();
        assert_eq!(coord.active_task_count(), 0);
        let _sleeper = coord.spawn_tracked(async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });
        assert!(coord.active_task_count() >= 1);
        coord.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_with_long_running_tasks() {
        let coord = ShutdownCoordinator::new();
        let close_start = coord.close_start_token();
        // The task only exits once shutdown cancels the close-start token.
        let _waiter = coord.spawn_tracked(async move { close_start.cancelled().await });
        let began = Instant::now();
        coord.shutdown().await;
        assert!(began.elapsed() < DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(200));
    }

    #[tokio::test]
    async fn test_task_completes_before_shutdown() {
        let coord = ShutdownCoordinator::new();
        let quick = coord.spawn_tracked(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        let _ = quick.await;
        let began = Instant::now();
        coord.shutdown().await;
        assert!(began.elapsed() < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_multiple_tracked_tasks() {
        let coord = ShutdownCoordinator::new();
        let close_start = coord.close_start_token();
        for _ in 0..5 {
            let child = close_start.clone();
            let _ = coord.spawn_tracked(async move { child.cancelled().await });
        }
        assert!(coord.active_task_count() >= 5);
        coord.shutdown().await;
    }

    #[tokio::test]
    async fn test_task_decrements_on_completion() {
        let coord = ShutdownCoordinator::new();
        let noop = coord.spawn_tracked(async {});
        let _ = noop.await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(coord.active_task_count(), 0);
    }
}