use std::time::{Duration, Instant};
use crate::actor::{ActorId, ActorSupervisor};
/// Stage of a coordinated shutdown, as tracked by `ShutdownCoordinator`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownPhase {
    /// No shutdown has been initiated; this is the coordinator's initial phase.
    Running,
    /// Shutdown has been initiated and actors are being given time to drain.
    Draining,
    /// The drain timeout expired and the remaining actors must be force-killed.
    ForceKilling,
    /// Shutdown has finished; further ticks are no-ops.
    Terminated,
}
/// Tunables for a coordinated shutdown.
#[derive(Clone, Debug)]
pub struct ShutdownConfig {
    /// How long actors may take to drain before the coordinator reports the
    /// stragglers for force-killing.
    pub drain_timeout: Duration,
    /// Surfaced in the shutdown report as `checkpoint_enabled`.
    /// NOTE(review): presumably also tells the draining code to checkpoint
    /// actor state; that logic is not in this file, so confirm.
    pub checkpoint_on_drain: bool,
    /// NOTE(review): not read anywhere in this file; presumably controls
    /// whether in-flight work is processed while draining. Confirm against
    /// the code that consumes this config.
    pub process_in_flight: bool,
    /// Order in which live actors are returned when shutdown is initiated.
    pub ordering: ShutdownOrdering,
}
impl Default for ShutdownConfig {
fn default() -> Self {
Self {
drain_timeout: Duration::from_secs(30),
checkpoint_on_drain: true,
process_in_flight: true,
ordering: ShutdownOrdering::LeafFirst,
}
}
}
/// Strategy used to order actors when a shutdown is initiated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownOrdering {
    /// Actors with the greatest supervisor depth come first.
    LeafFirst,
    /// Actors with the smallest supervisor depth come first.
    ParentFirst,
    /// No depth-based sorting; actors keep the supervisor's entry order.
    Parallel,
}
/// State machine that drives a shutdown: it tracks which actors were asked
/// to drain, which have confirmed, and which had to be force-killed.
pub struct ShutdownCoordinator {
    /// Current stage of the shutdown.
    phase: ShutdownPhase,
    /// Settings supplied at construction time.
    config: ShutdownConfig,
    /// Moment the shutdown was initiated; `None` until then.
    started_at: Option<Instant>,
    /// Actors selected for draining, in shutdown order.
    draining_actors: Vec<ActorId>,
    /// Actors that have been reported as drained.
    drained_actors: Vec<ActorId>,
    /// Actors that had not drained when the timeout expired.
    force_killed: Vec<ActorId>,
}
impl ShutdownCoordinator {
    /// Creates a coordinator in the [`ShutdownPhase::Running`] phase with no
    /// shutdown in progress.
    pub fn new(config: ShutdownConfig) -> Self {
        Self {
            phase: ShutdownPhase::Running,
            config,
            started_at: None,
            draining_actors: Vec::new(),
            drained_actors: Vec::new(),
            force_killed: Vec::new(),
        }
    }

    /// Enters the draining phase and returns the live actors in the order
    /// given by the configured [`ShutdownOrdering`].
    ///
    /// The drain timer starts now. Calling this again restarts the timer and
    /// recomputes the actor list; previously recorded drained and
    /// force-killed actors are left untouched.
    pub fn initiate(&mut self, supervisor: &ActorSupervisor) -> Vec<ActorId> {
        self.phase = ShutdownPhase::Draining;
        self.started_at = Some(Instant::now());
        let actors = self.compute_shutdown_order(supervisor);
        // One copy is kept for bookkeeping, the other is handed to the caller.
        self.draining_actors = actors.clone();
        tracing::info!(
            phase = "draining",
            actors = actors.len(),
            timeout_secs = self.config.drain_timeout.as_secs(),
            "Initiating graceful shutdown"
        );
        actors
    }

    /// Records that `actor` has finished draining.
    ///
    /// Repeated notifications for the same actor are ignored so that the
    /// `drained` count in [`ShutdownReport`] counts each actor at most once.
    /// The actor is recorded whether or not it is in the current drain set.
    pub fn mark_drained(&mut self, actor: ActorId) {
        if !self.drained_actors.contains(&actor) {
            self.drained_actors.push(actor);
        }
    }

    /// Returns `true` once the configured drain timeout has elapsed since
    /// shutdown was initiated; always `false` before initiation.
    pub fn is_timeout_expired(&self) -> bool {
        self.started_at
            .map(|start| start.elapsed() >= self.config.drain_timeout)
            .unwrap_or(false)
    }

    /// Advances the state machine by one step.
    ///
    /// Returns the phase to act on plus the actors that must be force-killed
    /// (empty unless the drain timeout has just expired).
    ///
    /// * `Running` / `Terminated`: no-op.
    /// * `Draining`, every actor drained: moves to `Terminated`.
    /// * `Draining`, timeout expired: returns `ForceKilling` together with the
    ///   undrained actors exactly once. Internally the coordinator is already
    ///   `Terminated`, so [`phase`](Self::phase) and later ticks report
    ///   `Terminated`.
    /// * `Draining`, otherwise: stays in `Draining`.
    /// * `ForceKilling`: moves to `Terminated`.
    pub fn tick(&mut self) -> (ShutdownPhase, Vec<ActorId>) {
        match self.phase {
            ShutdownPhase::Running | ShutdownPhase::Terminated => (self.phase, Vec::new()),
            ShutdownPhase::Draining => {
                // Computed once and reused for both the "all drained" check
                // and the force-kill list.
                let remaining: Vec<ActorId> = self
                    .draining_actors
                    .iter()
                    .filter(|a| !self.drained_actors.contains(a))
                    .copied()
                    .collect();
                if remaining.is_empty() {
                    self.phase = ShutdownPhase::Terminated;
                    tracing::info!("All actors drained, shutdown complete");
                    (self.phase, Vec::new())
                } else if self.is_timeout_expired() {
                    tracing::warn!(
                        remaining = remaining.len(),
                        "Drain timeout expired, force-killing remaining actors"
                    );
                    self.force_killed = remaining.clone();
                    // The caller gets `ForceKilling` as a one-shot signal; the
                    // stored phase goes straight to `Terminated`.
                    self.phase = ShutdownPhase::Terminated;
                    (ShutdownPhase::ForceKilling, remaining)
                } else {
                    (self.phase, Vec::new())
                }
            }
            // `tick` never leaves the stored phase at `ForceKilling`; this arm
            // keeps the match exhaustive and finishes the shutdown if that
            // state is ever reached.
            ShutdownPhase::ForceKilling => {
                self.phase = ShutdownPhase::Terminated;
                (self.phase, Vec::new())
            }
        }
    }

    /// Returns the coordinator's current phase.
    pub fn phase(&self) -> ShutdownPhase {
        self.phase
    }

    /// Time since shutdown was initiated, or `None` if it never was.
    pub fn elapsed(&self) -> Option<Duration> {
        self.started_at.map(|s| s.elapsed())
    }

    /// Snapshots the current shutdown progress.
    pub fn report(&self) -> ShutdownReport {
        ShutdownReport {
            phase: self.phase,
            total_actors: self.draining_actors.len(),
            drained: self.drained_actors.len(),
            force_killed: self.force_killed.len(),
            elapsed: self.elapsed(),
            checkpoint_enabled: self.config.checkpoint_on_drain,
        }
    }

    /// Collects the supervisor's live actors and orders them according to
    /// the configured [`ShutdownOrdering`].
    ///
    /// Both depth-based orderings use a stable sort, so actors at the same
    /// depth keep the supervisor's entry order. `Parallel` applies no sort.
    fn compute_shutdown_order(&self, supervisor: &ActorSupervisor) -> Vec<ActorId> {
        let mut order: Vec<ActorId> = supervisor
            .entries()
            .iter()
            .filter(|e| e.actor_state().is_alive())
            .map(|e| ActorId(e.actor_id))
            .collect();
        // `sort_by_cached_key` evaluates `depth` once per actor instead of
        // twice per comparison, and is stable like `sort_by`.
        match self.config.ordering {
            ShutdownOrdering::LeafFirst => {
                // Greatest depth first.
                order.sort_by_cached_key(|a| std::cmp::Reverse(supervisor.depth(*a)));
            }
            ShutdownOrdering::ParentFirst => {
                // Smallest depth first.
                order.sort_by_cached_key(|a| supervisor.depth(*a));
            }
            ShutdownOrdering::Parallel => {}
        }
        order
    }
}
/// Point-in-time summary of a shutdown, produced by
/// `ShutdownCoordinator::report`.
#[derive(Clone, Debug)]
pub struct ShutdownReport {
    /// Coordinator phase at the time the report was taken.
    pub phase: ShutdownPhase,
    /// Number of actors selected for draining.
    pub total_actors: usize,
    /// Number of actors recorded as drained.
    pub drained: usize,
    /// Number of actors that were still undrained when the timeout expired.
    pub force_killed: usize,
    /// Time since shutdown was initiated; `None` if it never was.
    pub elapsed: Option<Duration>,
    /// Mirrors the `checkpoint_on_drain` setting of the shutdown config.
    pub checkpoint_enabled: bool,
}
impl std::fmt::Display for ShutdownReport {
    /// Renders a one-line summary: actor counts followed by the elapsed time
    /// in `Duration`'s debug format.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A missing elapsed time is shown as a zero duration.
        let elapsed = self.elapsed.unwrap_or_default();
        let (total, drained, killed) = (self.total_actors, self.drained, self.force_killed);
        write!(
            f,
            "Shutdown: {} actors, {} drained, {} force-killed, {:?} elapsed",
            total, drained, killed, elapsed
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::actor::ActorConfig;

    /// When every actor reports drained, the next tick terminates cleanly
    /// and nothing is force-killed.
    #[test]
    fn test_graceful_shutdown_all_drain() {
        let mut supervisor = ActorSupervisor::new(8);
        let worker_cfg = ActorConfig::named("worker");
        let first = supervisor.create_actor(&worker_cfg, None).unwrap();
        supervisor.activate_actor(first).unwrap();
        let second = supervisor.create_actor(&worker_cfg, None).unwrap();
        supervisor.activate_actor(second).unwrap();

        let mut coordinator = ShutdownCoordinator::new(ShutdownConfig::default());
        let to_drain = coordinator.initiate(&supervisor);
        assert_eq!(to_drain.len(), 2);
        assert_eq!(coordinator.phase(), ShutdownPhase::Draining);

        for id in [first, second] {
            coordinator.mark_drained(id);
        }

        let (phase, to_kill) = coordinator.tick();
        assert_eq!(phase, ShutdownPhase::Terminated);
        assert!(to_kill.is_empty());

        let ShutdownReport {
            drained,
            force_killed,
            ..
        } = coordinator.report();
        assert_eq!(drained, 2);
        assert_eq!(force_killed, 0);
    }

    /// An actor that never drains is handed back for force-killing once the
    /// drain timeout has passed.
    #[test]
    fn test_shutdown_timeout_force_kill() {
        let mut supervisor = ActorSupervisor::new(8);
        let worker_cfg = ActorConfig::named("worker");
        let straggler = supervisor.create_actor(&worker_cfg, None).unwrap();
        supervisor.activate_actor(straggler).unwrap();

        let short_timeout = ShutdownConfig {
            drain_timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let mut coordinator = ShutdownCoordinator::new(short_timeout);
        coordinator.initiate(&supervisor);

        // Sleep past the 1 ms timeout so the next tick sees it as expired.
        std::thread::sleep(Duration::from_millis(5));

        let (phase, to_kill) = coordinator.tick();
        assert_eq!(phase, ShutdownPhase::ForceKilling);
        assert_eq!(to_kill.len(), 1);
        assert_eq!(to_kill[0], straggler);
    }

    /// Leaf-first ordering puts the deepest actor first and the root last.
    #[test]
    fn test_leaf_first_ordering() {
        let mut supervisor = ActorSupervisor::new(8);
        let node_cfg = ActorConfig::named("node");
        let root = supervisor.create_actor(&node_cfg, None).unwrap();
        supervisor.activate_actor(root).unwrap();
        let child = supervisor.create_actor(&node_cfg, Some(root)).unwrap();
        supervisor.activate_actor(child).unwrap();
        let grandchild = supervisor.create_actor(&node_cfg, Some(child)).unwrap();
        supervisor.activate_actor(grandchild).unwrap();

        let coordinator = ShutdownCoordinator::new(ShutdownConfig {
            ordering: ShutdownOrdering::LeafFirst,
            ..Default::default()
        });
        let order = coordinator.compute_shutdown_order(&supervisor);

        assert_eq!(order.first().copied(), Some(grandchild));
        assert_eq!(order.last().copied(), Some(root));
    }

    /// The `Display` output contains each count with its label.
    #[test]
    fn test_shutdown_report_display() {
        let rendered = ShutdownReport {
            phase: ShutdownPhase::Terminated,
            total_actors: 5,
            drained: 4,
            force_killed: 1,
            elapsed: Some(Duration::from_secs(2)),
            checkpoint_enabled: true,
        }
        .to_string();

        for fragment in ["5 actors", "4 drained", "1 force-killed"] {
            assert!(rendered.contains(fragment));
        }
    }
}