use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
use super::UpstreamPool;
/// Lifecycle phase reported for an upstream backend in drain-tracking logs.
///
/// The [`std::fmt::Display`] form of each variant is its lowercase name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendState {
    /// Presumably a backend that is still serving traffic normally; this
    /// variant is not emitted by the drain tracker itself.
    Active,
    /// The backend has been removed and is being polled until its active
    /// request count reaches zero.
    Draining,
    /// The backend reported zero active requests and is about to be shut down.
    Drained,
}
impl std::fmt::Display for BackendState {
    /// Writes the lowercase label of the state (`active`, `draining`, `drained`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            BackendState::Active => "active",
            BackendState::Draining => "draining",
            BackendState::Drained => "drained",
        };
        f.write_str(label)
    }
}
/// Polls removed upstream pools until their in-flight requests finish, then
/// shuts them down.
///
/// Both settings are fixed at construction time.
#[derive(Debug, Clone)]
pub struct DrainTracker {
    /// Upper bound on how long pools are polled before the ones that are
    /// still busy are shut down anyway.
    max_drain_time: Duration,
    /// Delay between successive checks of each pool's active request count.
    poll_interval: Duration,
}
impl DrainTracker {
    /// Creates a tracker that polls every `poll_interval` and gives up waiting
    /// after `max_drain_time`.
    pub fn new(max_drain_time: Duration, poll_interval: Duration) -> Self {
        Self {
            max_drain_time,
            poll_interval,
        }
    }

    /// Waits for the given pools to finish their in-flight requests and shuts
    /// each one down.
    ///
    /// Every pool is logged as entering the drain state, then all pools are
    /// polled via `active_request_count()`:
    ///
    /// * a pool reporting zero active requests is logged as drained and shut
    ///   down right away;
    /// * once `max_drain_time` has elapsed, every pool that is still busy is
    ///   logged with a `drain_timeout` warning and shut down regardless.
    ///
    /// The first check happens immediately, so pools that are already idle do
    /// not wait for a poll interval. Returns without logging when `pools` is
    /// empty.
    pub async fn track_pools(&self, pools: HashMap<String, Arc<UpstreamPool>>) {
        if pools.is_empty() {
            return;
        }

        let pool_count = pools.len();
        info!(
            pool_count = pool_count,
            "Starting drain tracking for removed upstream pools"
        );

        for (name, pool) in &pools {
            let active = pool.active_request_count();
            info!(
                upstream_id = %name,
                active_requests = active,
                state = %BackendState::Draining,
                "Backend entering drain state"
            );
        }

        let start = Instant::now();
        let mut pending: HashMap<String, Arc<UpstreamPool>> = pools;

        loop {
            // Check before sleeping: an already-idle pool must not be held for
            // a whole poll interval, nor be reported as timed out when the
            // drain budget is zero.
            let mut newly_drained = Vec::new();
            for (name, pool) in &pending {
                let active = pool.active_request_count();
                if active == 0 {
                    let drain_duration = start.elapsed();
                    info!(
                        upstream_id = %name,
                        drain_duration_ms = drain_duration.as_millis(),
                        drain_duration_secs = drain_duration.as_secs_f64(),
                        state = %BackendState::Drained,
                        "Backend fully drained, safe to terminate"
                    );
                    newly_drained.push(name.clone());
                } else {
                    debug!(
                        upstream_id = %name,
                        active_requests = active,
                        elapsed_ms = start.elapsed().as_millis(),
                        state = %BackendState::Draining,
                        "Backend still draining"
                    );
                }
            }

            // Shutdown is awaited outside the borrowing loop above because it
            // needs the pool removed from `pending`.
            for name in newly_drained {
                if let Some(pool) = pending.remove(&name) {
                    pool.shutdown().await;
                }
            }

            if pending.is_empty() {
                break;
            }

            // Stop once the drain budget is used up. `checked_sub` yields
            // `None` when the elapsed time already exceeds the budget.
            let remaining = match self.max_drain_time.checked_sub(start.elapsed()) {
                Some(remaining) if !remaining.is_zero() => remaining,
                _ => break,
            };

            // Clamp the sleep so the wait never overshoots `max_drain_time`
            // by up to a full poll interval.
            tokio::time::sleep(self.poll_interval.min(remaining)).await;
        }

        // Whatever is left did not drain in time: shut it down regardless.
        for (name, pool) in &pending {
            let active = pool.active_request_count();
            warn!(
                upstream_id = %name,
                active_requests = active,
                max_drain_time_secs = self.max_drain_time.as_secs(),
                state = "drain_timeout",
                "Backend drain timeout exceeded, force shutting down"
            );
            pool.shutdown().await;
        }

        info!(
            pool_count = pool_count,
            total_duration_ms = start.elapsed().as_millis(),
            "Drain tracking complete for all removed pools"
        );
    }
}
impl Default for DrainTracker {
    /// Builds a tracker with a 60-second drain budget that polls once per second.
    fn default() -> Self {
        let max_drain_time = Duration::from_secs(60);
        let poll_interval = Duration::from_secs(1);
        Self {
            max_drain_time,
            poll_interval,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every state must render as its lowercase label.
    #[test]
    fn backend_state_display() {
        let expectations = [
            (BackendState::Active, "active"),
            (BackendState::Draining, "draining"),
            (BackendState::Drained, "drained"),
        ];
        for &(state, label) in &expectations {
            assert_eq!(state.to_string(), label);
        }
    }
}