use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::{error, info};
use super::phase::ShutdownPhase;
use super::{LoopHandle, LoopRegistry, ShutdownWatch};
use crate::control::metrics::SystemMetrics;
/// Maximum time the shutdown sequence waits for the tasks of a single phase to
/// report themselves drained before the stragglers are aborted.
pub const PHASE_BUDGET: Duration = Duration::from_millis(500);
/// Identifier of a task registered with the shutdown bus.
///
/// Ids are handed out sequentially by `BusState::alloc_id` and are used as the
/// key of the task table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TaskId(u64);
/// Bookkeeping record for one registered task.
struct TaskEntry {
    /// Human-readable task name; reported as `offender` when the task is aborted.
    name: &'static str,
    /// Phase during which the bus waits for this task to drain.
    phase: ShutdownPhase,
    /// Set once the task has reported itself drained, or once the bus has
    /// given up on it after the phase budget elapsed.
    drained: bool,
    /// Handle used to abort the task when it overruns the budget; `None` means
    /// the bus can only log the overrun.
    abort_handle: Option<tokio::task::AbortHandle>,
}
/// Mutable state shared between the bus, its drain guards and the shutdown
/// driver task. Always accessed through `lock_bus`.
#[derive(Default)]
struct BusState {
    /// Every task ever registered, keyed by its id.
    tasks: BTreeMap<TaskId, TaskEntry>,
    /// Raw value of the next `TaskId` to hand out.
    next_id: u64,
    /// `true` once `ShutdownBus::initiate` has started the shutdown sequence.
    initiated: bool,
    /// Optional metrics sink that receives the duration of each phase.
    metrics: Option<Arc<SystemMetrics>>,
}
impl BusState {
fn alloc_id(&mut self) -> TaskId {
let id = TaskId(self.next_id);
self.next_id += 1;
id
}
fn pending_for_phase(&self, phase: ShutdownPhase) -> Vec<(TaskId, &'static str)> {
self.tasks
.iter()
.filter(|(_, e)| e.phase == phase && !e.drained)
.map(|(id, e)| (*id, e.name))
.collect()
}
fn abort_pending_for_phase(&mut self, phase: ShutdownPhase) {
for entry in self.tasks.values_mut() {
if entry.phase == phase && !entry.drained {
if let Some(ref h) = entry.abort_handle {
h.abort();
}
error!(
target: "shutdown",
phase = %phase,
offender = entry.name,
"task exceeded 500ms drain budget — aborting"
);
entry.drained = true; }
}
}
}
/// Coordinator of the phased shutdown sequence.
///
/// Cloning is cheap: every field is behind an `Arc`, so all clones share the
/// same task table and phase channel.
#[derive(Clone)]
pub struct ShutdownBus {
    /// Task table and flags shared with drain guards and the driver task.
    state: Arc<Mutex<BusState>>,
    /// Sending side of the channel that publishes the current phase.
    phase_tx: Arc<watch::Sender<ShutdownPhase>>,
    /// Shared `ShutdownWatch` that is signaled once when shutdown is initiated.
    flat_watch: Arc<ShutdownWatch>,
}
/// Read-only view of the shutdown sequence for observers that only need to
/// wait for, or query, the current phase.
#[derive(Clone)]
pub struct ShutdownHandle {
    /// Receiving side of the phase channel.
    phase_rx: watch::Receiver<ShutdownPhase>,
    /// The same `ShutdownWatch` the bus signals on initiation.
    flat_watch: Arc<ShutdownWatch>,
}
/// Token handed to a registered task.
///
/// The task waits on it for its drain phase and then consumes it with
/// `report_drained`. Dropping it without reporting only logs a warning.
pub struct DrainGuard {
    /// Key of this task's entry in the bus task table.
    task_id: TaskId,
    /// Phase at which the task is expected to drain.
    phase: ShutdownPhase,
    /// Shared bus state, used to mark the entry drained.
    state: Arc<Mutex<BusState>>,
    /// Receiver used to observe phase transitions.
    phase_rx: watch::Receiver<ShutdownPhase>,
    /// `true` once `report_drained` has run; suppresses the drop warning.
    reported: bool,
    /// Task name, used in the drop warning.
    name: &'static str,
}
impl DrainGuard {
    /// Resolves once the published phase is at or past this guard's drain phase.
    ///
    /// Also resolves if the phase channel closes (its sender was dropped)
    /// before that phase is reached.
    pub async fn await_signal(&mut self) {
        loop {
            // The borrow is released at the end of this statement, so it is
            // never held across the await below.
            let reached = *self.phase_rx.borrow() >= self.phase;
            if reached {
                return;
            }
            if self.phase_rx.changed().await.is_err() {
                return;
            }
        }
    }

    /// Consumes the guard and marks the task as drained in the bus task table.
    pub fn report_drained(mut self) {
        // Set before the guard is dropped at the end of this function so that
        // `Drop` stays silent.
        self.reported = true;
        if let Some(entry) = lock_bus(&self.state).tasks.get_mut(&self.task_id) {
            entry.drained = true;
        }
    }
}
impl Drop for DrainGuard {
    /// Warns when the guard is dropped without `report_drained` having run.
    ///
    /// The task entry is left untouched: an unreported task still counts as
    /// pending for its phase.
    fn drop(&mut self) {
        if self.reported {
            return;
        }
        tracing::warn!(
            target: "shutdown",
            phase = %self.phase,
            offender = self.name,
            "DrainGuard dropped without report_drained — task may be a shutdown offender"
        );
    }
}
/// Locks the bus state, recovering the guard if the mutex is poisoned.
///
/// A poisoned mutex is logged and its inner guard is used anyway, so a panic
/// in another lock holder does not stall the shutdown sequence.
fn lock_bus(state: &Mutex<BusState>) -> std::sync::MutexGuard<'_, BusState> {
    state.lock().unwrap_or_else(|poisoned| {
        error!(target: "shutdown", "ShutdownBus mutex poisoned — recovering");
        poisoned.into_inner()
    })
}
impl ShutdownBus {
    /// Creates a bus whose published phase starts at `ShutdownPhase::Running`,
    /// together with a first [`ShutdownHandle`] observing it.
    ///
    /// `flat_watch` is shared with the returned handle and is signaled once
    /// when [`ShutdownBus::initiate`] starts the sequence.
    pub fn new(flat_watch: Arc<ShutdownWatch>) -> (Self, ShutdownHandle) {
        let (phase_tx, phase_rx) = watch::channel(ShutdownPhase::Running);
        let phase_tx = Arc::new(phase_tx);
        let bus = Self {
            state: Arc::new(Mutex::new(BusState::default())),
            phase_tx,
            flat_watch: Arc::clone(&flat_watch),
        };
        let handle = ShutdownHandle {
            phase_rx,
            flat_watch,
        };
        (bus, handle)
    }

    /// Registers a task that must drain during `drain_at` and returns the
    /// guard the task uses to wait for that phase and to report completion.
    ///
    /// `abort_handle`, when provided, is used to abort the task if it has not
    /// reported drained once the phase budget has elapsed; with `None` the
    /// overrun is only logged.
    pub fn register_task(
        &self,
        drain_at: ShutdownPhase,
        name: &'static str,
        abort_handle: Option<tokio::task::AbortHandle>,
    ) -> DrainGuard {
        let mut guard = lock_bus(&self.state);
        let id = guard.alloc_id();
        guard.tasks.insert(
            id,
            TaskEntry {
                name,
                phase: drain_at,
                drained: false,
                abort_handle,
            },
        );
        let phase_rx = self.phase_tx.subscribe();
        DrainGuard {
            task_id: id,
            phase: drain_at,
            state: Arc::clone(&self.state),
            phase_rx,
            reported: false,
            name,
        }
    }

    /// Starts the shutdown sequence and returns the handle of the task that
    /// drives it.
    ///
    /// Only the first call starts the sequence. Later calls return the handle
    /// of an empty task that finishes immediately and does not track the
    /// running sequence.
    ///
    /// The driver walks the phases starting at `Running`. Each phase that has
    /// a successor is published, then its pending tasks are polled every 10 ms
    /// until all have drained or [`PHASE_BUDGET`] has elapsed, at which point
    /// the remaining tasks are aborted. `Closed` is published last.
    pub fn initiate(&self) -> JoinHandle<()> {
        {
            let mut guard = lock_bus(&self.state);
            if guard.initiated {
                return tokio::spawn(async {});
            }
            guard.initiated = true;
        }
        info!(target: "shutdown", "shutdown initiated");
        self.flat_watch.signal();
        let state = Arc::clone(&self.state);
        let phase_tx = Arc::clone(&self.phase_tx);
        tokio::spawn(async move {
            let mut current = ShutdownPhase::Running;
            while let Some(next) = current.next() {
                phase_tx.send_replace(current);
                let phase_start = std::time::Instant::now();
                let deadline = tokio::time::Instant::now() + PHASE_BUDGET;
                loop {
                    // The lock is released at the end of this statement, never
                    // held across the sleep below.
                    let pending = lock_bus(&state).pending_for_phase(current);
                    if pending.is_empty() {
                        break;
                    }
                    if tokio::time::Instant::now() >= deadline {
                        lock_bus(&state).abort_pending_for_phase(current);
                        break;
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                // Saturate rather than silently truncate the u128 millisecond count.
                let phase_ms =
                    u64::try_from(phase_start.elapsed().as_millis()).unwrap_or(u64::MAX);
                // Clone the sink out so the bus lock is not held while the
                // metrics callback runs.
                let metrics = lock_bus(&state).metrics.clone();
                if let Some(m) = metrics {
                    m.record_shutdown_phase_duration(&current.to_string(), phase_ms);
                }
                info!(
                    target: "shutdown",
                    phase = %current,
                    next_phase = %next,
                    duration_ms = phase_ms,
                    "shutdown phase complete"
                );
                current = next;
            }
            phase_tx.send_replace(ShutdownPhase::Closed);
            info!(target: "shutdown", "shutdown complete");
        })
    }

    /// Returns the most recently published phase.
    pub fn current_phase(&self) -> ShutdownPhase {
        *self.phase_tx.borrow()
    }

    /// Installs the metrics sink that receives the duration of each phase.
    pub fn set_metrics(&self, metrics: Arc<SystemMetrics>) {
        let mut guard = lock_bus(&self.state);
        guard.metrics = Some(metrics);
    }

    /// Creates a new observer handle subscribed to the phase channel.
    pub fn handle(&self) -> ShutdownHandle {
        ShutdownHandle {
            phase_rx: self.phase_tx.subscribe(),
            flat_watch: Arc::clone(&self.flat_watch),
        }
    }
}
impl ShutdownHandle {
    /// Resolves once the published phase is at or past `phase`.
    ///
    /// Returns immediately when that is already the case, and also returns if
    /// the phase channel closes before `phase` is reached.
    pub async fn await_phase(&mut self, phase: ShutdownPhase) {
        loop {
            // The borrow ends with this statement and is not held across the await.
            let reached = *self.phase_rx.borrow() >= phase;
            if reached {
                return;
            }
            if self.phase_rx.changed().await.is_err() {
                return;
            }
        }
    }

    /// Reports whether the published phase has moved past `Running`.
    pub fn is_shutting_down(&self) -> bool {
        let observed = *self.phase_rx.borrow();
        observed > ShutdownPhase::Running
    }

    /// Gives access to the shared `ShutdownWatch` held by this handle.
    pub fn flat_watch(&self) -> &Arc<ShutdownWatch> {
        &self.flat_watch
    }
}
pub fn spawn_drainable<F, Fut>(
registry: &LoopRegistry,
bus: &ShutdownBus,
drain_at: ShutdownPhase,
name: &'static str,
body: F,
) where
F: FnOnce(super::ShutdownReceiver, DrainGuard) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let rx = bus.flat_watch.subscribe();
let guard = bus.register_task(drain_at, name, None);
let handle = tokio::spawn(async move { body(rx, guard).await });
let abort = handle.abort_handle();
if let Err(e) = registry.register(name, LoopHandle::Async(handle)) {
tracing::warn!(
error = %e,
"spawn_drainable after registry close — task will run to completion \
but shutdown_all will not wait for it"
);
}
drop(abort); }
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Calling `initiate` twice still ends with the bus in `Closed`.
    #[tokio::test]
    async fn initiate_is_idempotent() {
        let flat = Arc::new(ShutdownWatch::new());
        let (bus, mut observer) = ShutdownBus::new(Arc::clone(&flat));

        bus.initiate();
        bus.initiate();
        observer.await_phase(ShutdownPhase::Closed).await;

        assert_eq!(bus.current_phase(), ShutdownPhase::Closed);
    }

    /// The flat watch reports shutdown once `initiate` has been called.
    #[tokio::test]
    async fn flat_watch_signaled_on_initiate() {
        let flat = Arc::new(ShutdownWatch::new());
        let (bus, _) = ShutdownBus::new(Arc::clone(&flat));
        assert!(!flat.is_shutdown());

        bus.initiate();
        tokio::task::yield_now().await;

        assert!(flat.is_shutdown());
    }

    /// A registered task is woken at its drain phase and can report drained.
    #[tokio::test]
    async fn registered_task_receives_drain_signal() {
        let flat = Arc::new(ShutdownWatch::new());
        let (bus, mut observer) = ShutdownBus::new(Arc::clone(&flat));
        let saw_signal = Arc::new(AtomicBool::new(false));
        let saw_signal_in_task = Arc::clone(&saw_signal);
        let mut drain_guard =
            bus.register_task(ShutdownPhase::DrainingListeners, "test_task", None);

        tokio::spawn(async move {
            drain_guard.await_signal().await;
            saw_signal_in_task.store(true, Ordering::SeqCst);
            drain_guard.report_drained();
        });

        bus.initiate();
        observer.await_phase(ShutdownPhase::Closed).await;

        assert!(saw_signal.load(Ordering::SeqCst), "task did not drain");
    }

    /// A task that never reports drained does not stop shutdown from finishing.
    #[tokio::test]
    async fn offender_aborted_after_budget() {
        let flat = Arc::new(ShutdownWatch::new());
        let (bus, mut observer) = ShutdownBus::new(Arc::clone(&flat));
        // Held until the end of the test and never reported.
        let _guard = bus.register_task(ShutdownPhase::DrainingListeners, "offender_task", None);

        let started = tokio::time::Instant::now();
        bus.initiate();
        observer.await_phase(ShutdownPhase::Closed).await;
        let elapsed = started.elapsed();

        assert!(
            elapsed < Duration::from_secs(10),
            "shutdown did not terminate: {elapsed:?}"
        );
    }

    /// Waiting for a phase that has already been passed returns immediately.
    #[tokio::test]
    async fn await_phase_returns_immediately_if_already_past() {
        let flat = Arc::new(ShutdownWatch::new());
        let (bus, _) = ShutdownBus::new(Arc::clone(&flat));
        bus.initiate();

        let mut first = bus.handle();
        first.await_phase(ShutdownPhase::Closed).await;

        let mut late = bus.handle();
        let outcome = tokio::time::timeout(
            Duration::from_millis(10),
            late.await_phase(ShutdownPhase::Running),
        )
        .await;
        outcome.expect("await_phase(Running) should be immediate when already Closed");
    }
}