use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, trace, warn};
/// Coordinates a graceful reload: counts in-flight requests, waits (for a
/// bounded time) for them to drain, and carries a shutdown-requested flag.
#[derive(Debug)]
pub struct GracefulReloadCoordinator {
    /// Number of requests currently in flight.
    active_requests: Arc<AtomicUsize>,
    /// Upper bound on how long a drain waits before giving up.
    max_drain_time: Duration,
    /// Set once a shutdown has been requested.
    shutdown_requested: Arc<AtomicBool>,
}
impl GracefulReloadCoordinator {
    /// Creates a coordinator with zero active requests and no pending shutdown.
    ///
    /// `max_drain_time` bounds how long [`wait_for_drain`](Self::wait_for_drain)
    /// waits for in-flight requests to finish.
    pub fn new(max_drain_time: Duration) -> Self {
        debug!(
            max_drain_time_secs = max_drain_time.as_secs(),
            "Creating graceful reload coordinator"
        );
        Self {
            active_requests: Arc::new(AtomicUsize::new(0)),
            max_drain_time,
            shutdown_requested: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Records that a request has started.
    pub fn inc_requests(&self) {
        let count = self.active_requests.fetch_add(1, Ordering::Relaxed) + 1;
        trace!(active_requests = count, "Request started");
    }

    /// Records that a request has finished.
    ///
    /// A call with no active requests is logged and ignored. Letting it through
    /// would wrap the counter to `usize::MAX`, which panics in debug builds and
    /// makes every later drain time out in release builds.
    pub fn dec_requests(&self) {
        // Release pairs with the Acquire loads in `wait_for_drain`: once the drain
        // observes zero, everything a request did before completing is visible.
        match self
            .active_requests
            .fetch_update(Ordering::Release, Ordering::Relaxed, |n| n.checked_sub(1))
        {
            Ok(previous) => trace!(active_requests = previous - 1, "Request completed"),
            Err(_) => warn!("dec_requests called with no active requests; ignoring"),
        }
    }

    /// Waits until no requests are in flight, polling every 100 ms.
    ///
    /// Returns `true` if the active count reached zero, or `false` if
    /// `max_drain_time` elapsed first. Each sleep is capped at the time left
    /// before the deadline, so the timeout is not overshot by a whole poll
    /// interval.
    ///
    /// NOTE(review): this does not consult `shutdown_requested` and does not stop
    /// new `inc_requests` calls; presumably the caller stops admitting requests
    /// before draining.
    pub async fn wait_for_drain(&self) -> bool {
        const POLL_INTERVAL: Duration = Duration::from_millis(100);

        let start = Instant::now();
        let initial_count = self.active_requests.load(Ordering::Acquire);
        info!(
            active_requests = initial_count,
            max_drain_time_secs = self.max_drain_time.as_secs(),
            "Starting request drain"
        );
        let mut last_logged_count = initial_count;
        loop {
            // Sample once per iteration so the emptiness check, the timeout report
            // and the progress log all agree on the same value.
            let current_count = self.active_requests.load(Ordering::Acquire);
            if current_count == 0 {
                break;
            }
            let elapsed = start.elapsed();
            if elapsed > self.max_drain_time {
                warn!(
                    remaining_requests = current_count,
                    elapsed_secs = elapsed.as_secs(),
                    "Drain timeout reached, requests still active"
                );
                return false;
            }
            if current_count != last_logged_count {
                debug!(
                    remaining_requests = current_count,
                    elapsed_ms = elapsed.as_millis(),
                    "Draining requests"
                );
                last_logged_count = current_count;
            }
            let until_deadline = self.max_drain_time.saturating_sub(elapsed);
            tokio::time::sleep(POLL_INTERVAL.min(until_deadline)).await;
        }
        info!(
            elapsed_ms = start.elapsed().as_millis(),
            initial_requests = initial_count,
            "All requests drained successfully"
        );
        true
    }

    /// Returns a snapshot of the number of in-flight requests.
    pub fn active_count(&self) -> usize {
        self.active_requests.load(Ordering::Relaxed)
    }

    /// Sets the shutdown flag, logging the in-flight request count at that moment.
    pub fn request_shutdown(&self) {
        info!(
            active_requests = self.active_requests.load(Ordering::Relaxed),
            "Shutdown requested"
        );
        self.shutdown_requested.store(true, Ordering::SeqCst);
    }

    /// Returns whether [`request_shutdown`](Self::request_shutdown) has been called.
    pub fn is_shutdown_requested(&self) -> bool {
        self.shutdown_requested.load(Ordering::SeqCst)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Balanced increments and decrements return the counter to zero, after
    /// which a drain completes successfully.
    #[tokio::test]
    async fn test_graceful_coordinator() {
        let coord = GracefulReloadCoordinator::new(Duration::from_secs(1));

        (0..2).for_each(|_| coord.inc_requests());
        assert_eq!(coord.active_count(), 2);

        for expected in [1, 0] {
            coord.dec_requests();
            assert_eq!(coord.active_count(), expected);
        }

        let drained = coord.wait_for_drain().await;
        assert!(drained, "drain should succeed with no active requests");
    }

    /// The shutdown flag starts cleared and reads as set after a request.
    #[tokio::test]
    async fn test_graceful_coordinator_shutdown_flag() {
        let coord = GracefulReloadCoordinator::new(Duration::from_secs(1));
        let before = coord.is_shutdown_requested();
        coord.request_shutdown();
        let after = coord.is_shutdown_requested();
        assert_eq!((before, after), (false, true));
    }
}