use std::process::ExitCode;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tracing::{error, info, warn};
use crate::ServerState;
use crate::error::ServerError;
use crate::worker::LostWorkerReport;
/// How the graceful-shutdown sequence ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownOutcome {
    /// Every in-flight activity finished before the drain timeout elapsed.
    Clean,
    /// The drain timeout elapsed; the remaining in-flight workers were failed
    /// through the heartbeat tracker.
    TimedOut,
    /// A second shutdown signal arrived before the drain finished.
    Forced,
}

impl ShutdownOutcome {
    /// Exit status used for a forced shutdown.
    ///
    /// 130 is 128 + 2, the conventional shell status for a process ended by
    /// SIGINT (Ctrl-C). NOTE(review): presumably chosen because the second
    /// signal is usually Ctrl-C; confirm.
    const FORCED_EXIT_STATUS: u8 = 130;

    /// Maps the shutdown outcome to the process exit status.
    ///
    /// * `Clean` maps to [`ExitCode::SUCCESS`].
    /// * `TimedOut` maps to [`ExitCode::FAILURE`].
    /// * `Forced` maps to exit status 130.
    #[must_use]
    pub fn exit_code(self) -> ExitCode {
        use ShutdownOutcome::{Clean, Forced, TimedOut};

        match self {
            Forced => ExitCode::from(Self::FORCED_EXIT_STATUS),
            TimedOut => ExitCode::FAILURE,
            Clean => ExitCode::SUCCESS,
        }
    }
}
/// Cloneable handle to the server's drain flag and its "activities drained"
/// notifier.
///
/// The handle wraps an `Arc`, so every clone refers to the same inner state.
/// A drain started through one clone is observed by all of them.
#[derive(Debug, Clone, Default)]
pub struct DrainState {
    inner: Arc<DrainStateInner>,
}

/// Shared state behind [`DrainState`].
#[derive(Default, Debug)]
struct DrainStateInner {
    // Set once draining begins. This module only ever writes `true` to it.
    draining: AtomicBool,
    // Wakes tasks waiting for the in-flight activity count to reach zero.
    empty: Notify,
}
impl DrainState {
    /// Reports whether a graceful drain has been started on this shared state.
    #[must_use]
    pub fn is_draining(&self) -> bool {
        let flag = &self.inner.draining;
        flag.load(Ordering::Acquire)
    }

    /// Marks the server as draining.
    ///
    /// Returns `true` only for the call that actually flipped the flag. Every
    /// later call, through this handle or any clone, returns `false`.
    #[must_use]
    pub fn begin(&self) -> bool {
        let already_draining = self.inner.draining.swap(true, Ordering::AcqRel);
        !already_draining
    }

    /// Admission check for a new activity task.
    ///
    /// # Errors
    ///
    /// Once draining has begun, returns a `ServerError::worker_dispatch` error
    /// carrying `namespace` and `activity_type`.
    pub fn ensure_accepting(
        &self,
        namespace: &str,
        activity_type: &str,
    ) -> Result<(), ServerError> {
        if !self.is_draining() {
            return Ok(());
        }
        Err(ServerError::worker_dispatch(
            namespace.to_owned(),
            activity_type.to_owned(),
            "server is draining and not accepting new activity tasks",
        ))
    }

    /// Wakes every task currently waiting in `wait_for_empty` so it re-reads
    /// the in-flight count.
    ///
    /// This uses `Notify::notify_waiters`, which stores no permit, so calling
    /// it with no waiters has no effect.
    pub fn notify_activity_drained(&self) {
        let waiters = &self.inner.empty;
        waiters.notify_waiters();
    }

    /// Resolves once the heartbeat tracker reports no in-flight activities.
    ///
    /// The `Notified` future is created *before* the second count check. Tokio
    /// guarantees that a `Notified` receives `notify_waiters` wake-ups from the
    /// moment it is created, so a `notify_activity_drained` call that races
    /// with the check cannot be lost.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `in_flight_count()`.
    async fn wait_for_empty(&self, state: &ServerState) -> Result<(), ServerError> {
        // A single tracker read. `?` converts the tracker's error into
        // `ServerError`.
        let nothing_in_flight = || -> Result<bool, ServerError> {
            Ok(state.heartbeat_tracker().in_flight_count()? == 0)
        };

        while !nothing_in_flight()? {
            // Register interest first, then re-check, so a wake-up issued
            // between the two still reaches `wakeup`.
            let wakeup = self.inner.empty.notified();
            if nothing_in_flight()? {
                break;
            }
            wakeup.await;
        }
        Ok(())
    }
}
/// Runs the graceful-drain sequence after the first shutdown signal.
///
/// The sequence is:
///
/// 1. Mark the server as draining. The "beginning" message is logged only if
///    this call started the drain.
/// 2. Broadcast a drain request to connected workers.
/// 3. Race the activity drain, bounded by `runtime_config().drain_timeout`,
///    against `second_signal`:
///    * If `second_signal` resolves first, return [`ShutdownOutcome::Forced`]
///      without calling `state.shutdown()`.
///    * Otherwise, call `state.shutdown()` and then return
///      [`ShutdownOutcome::Clean`] or [`ShutdownOutcome::TimedOut`].
///
/// # Errors
///
/// Propagates errors from:
///
/// * `broadcast_drain`
/// * waiting on, or failing, the in-flight activities
/// * `state.shutdown()`
pub async fn drain_after_first_signal(
    state: ServerState,
    second_signal: impl std::future::Future<Output = ()>,
) -> Result<ShutdownOutcome, ServerError> {
    let drain = state.drain_state().clone();
    if drain.begin() {
        info!("shutdown signal received; beginning graceful drain");
    }

    // The identifier doubles as the tracing field name; keep `delivered_workers`.
    let delivered_workers = state.worker_registry().broadcast_drain()?;
    info!(delivered_workers, "sent drain request to connected workers");

    let drain_timeout = state.runtime_config().drain_timeout;
    tokio::pin!(second_signal);
    let outcome = tokio::select! {
        () = &mut second_signal => {
            warn!("second shutdown signal received; forcing immediate exit");
            ShutdownOutcome::Forced
        }
        drained = wait_for_drain_or_timeout(&state, &drain, drain_timeout) => drained?,
    };

    match outcome {
        // A forced exit returns immediately, without running `state.shutdown()`.
        ShutdownOutcome::Forced => Ok(outcome),
        ShutdownOutcome::Clean | ShutdownOutcome::TimedOut => {
            state.shutdown()?;
            Ok(outcome)
        }
    }
}
/// Waits up to `timeout` for every in-flight activity to finish.
///
/// * If the tracker reaches zero in time, logs the clean drain and returns
///   [`ShutdownOutcome::Clean`].
/// * If the deadline elapses, passes the worker registry and pending
///   activities to `fail_all_in_flight_workers`, logs the resulting reports,
///   and returns [`ShutdownOutcome::TimedOut`].
///
/// # Errors
///
/// Propagates errors from the wait itself and from `fail_all_in_flight_workers`.
async fn wait_for_drain_or_timeout(
    state: &ServerState,
    drain: &DrainState,
    timeout: Duration,
) -> Result<ShutdownOutcome, ServerError> {
    let waited = tokio::time::timeout(timeout, drain.wait_for_empty(state)).await;
    if let Ok(drained) = waited {
        drained?;
        info!("activity drain completed cleanly");
        return Ok(ShutdownOutcome::Clean);
    }

    // The deadline elapsed with work still in flight: fail what remains and
    // report it.
    let reports = state
        .heartbeat_tracker()
        .fail_all_in_flight_workers(state.worker_registry(), state.pending_activities())?;
    log_lost_workers(&reports);
    Ok(ShutdownOutcome::TimedOut)
}
fn log_lost_workers(reports: &[LostWorkerReport]) {
let failed_tasks: usize = reports.iter().map(|report| report.tasks.len()).sum();
if failed_tasks == 0 {
info!("activity drain timed out with no tracked in-flight activity failures");
} else {
error!(
failed_workers = reports.len(),
failed_tasks,
"activity drain timed out; remaining activities surfaced as retryable lost-worker failures"
);
}
}
#[cfg(test)]
mod tests {
    use super::DrainState;

    #[test]
    fn begin_is_idempotent_and_sets_draining() {
        let drain = DrainState::default();
        assert!(!drain.is_draining());
        assert!(drain.begin());
        assert!(drain.is_draining());
        assert!(!drain.begin());
    }

    /// `drain_after_first_signal` clones the server's `DrainState` before
    /// calling `begin`, so the flag must be shared between clones, not copied.
    #[test]
    fn clones_share_draining_flag() {
        let original = DrainState::default();
        let handle = original.clone();
        assert!(!original.is_draining());
        assert!(handle.begin());
        assert!(original.is_draining());
        assert!(!original.begin());
    }

    /// New activity tasks are admitted until draining begins, then rejected.
    #[test]
    fn ensure_accepting_rejects_once_draining() {
        let drain = DrainState::default();
        assert!(drain.ensure_accepting("default", "send_email").is_ok());
        assert!(drain.begin());
        assert!(drain.ensure_accepting("default", "send_email").is_err());
    }
}