1use std::time::{Duration, Instant};
10
11use crate::actor::{ActorId, ActorSupervisor};
12
/// Stage of a coordinated shutdown, as tracked by `ShutdownCoordinator`.
///
/// A coordinator starts in `Running`, switches to `Draining` when shutdown is
/// initiated and ends in `Terminated`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ShutdownPhase {
    /// No shutdown has been initiated yet.
    Running,
    /// Shutdown was initiated; waiting for actors to report that they drained.
    Draining,
    /// The drain timeout expired with actors still outstanding; a tick reports
    /// this phase together with the actors that must be force-killed.
    ForceKilling,
    /// Shutdown is finished; further ticks do nothing.
    Terminated,
}
25
26#[derive(Debug, Clone)]
28pub struct ShutdownConfig {
29 pub drain_timeout: Duration,
31 pub checkpoint_on_drain: bool,
33 pub process_in_flight: bool,
35 pub ordering: ShutdownOrdering,
37}
38
39impl Default for ShutdownConfig {
40 fn default() -> Self {
41 Self {
42 drain_timeout: Duration::from_secs(30),
43 checkpoint_on_drain: true,
44 process_in_flight: true,
45 ordering: ShutdownOrdering::LeafFirst,
46 }
47 }
48}
49
/// Strategy used to order live actors when shutdown is initiated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ShutdownOrdering {
    /// Actors with the greatest supervisor depth come first, so descendants
    /// are listed before their ancestors.
    LeafFirst,
    /// Actors with the smallest supervisor depth come first, so ancestors
    /// are listed before their descendants.
    ParentFirst,
    /// No depth-based sorting; actors keep the order of the supervisor's
    /// entries.
    Parallel,
}
61
62pub struct ShutdownCoordinator {
64 phase: ShutdownPhase,
66 config: ShutdownConfig,
68 started_at: Option<Instant>,
70 draining_actors: Vec<ActorId>,
72 drained_actors: Vec<ActorId>,
74 force_killed: Vec<ActorId>,
76}
77
78impl ShutdownCoordinator {
79 pub fn new(config: ShutdownConfig) -> Self {
81 Self {
82 phase: ShutdownPhase::Running,
83 config,
84 started_at: None,
85 draining_actors: Vec::new(),
86 drained_actors: Vec::new(),
87 force_killed: Vec::new(),
88 }
89 }
90
91 pub fn initiate(&mut self, supervisor: &ActorSupervisor) -> Vec<ActorId> {
95 self.phase = ShutdownPhase::Draining;
96 self.started_at = Some(Instant::now());
97
98 let actors = self.compute_shutdown_order(supervisor);
100 self.draining_actors = actors.clone();
101
102 tracing::info!(
103 phase = "draining",
104 actors = actors.len(),
105 timeout_secs = self.config.drain_timeout.as_secs(),
106 "Initiating graceful shutdown"
107 );
108
109 actors
110 }
111
112 pub fn mark_drained(&mut self, actor: ActorId) {
114 self.drained_actors.push(actor);
115 }
116
117 pub fn is_timeout_expired(&self) -> bool {
119 self.started_at
120 .map(|start| start.elapsed() >= self.config.drain_timeout)
121 .unwrap_or(false)
122 }
123
124 pub fn tick(&mut self) -> (ShutdownPhase, Vec<ActorId>) {
128 match self.phase {
129 ShutdownPhase::Running => (self.phase, Vec::new()),
130
131 ShutdownPhase::Draining => {
132 let all_drained = self
134 .draining_actors
135 .iter()
136 .all(|a| self.drained_actors.contains(a));
137
138 if all_drained {
139 self.phase = ShutdownPhase::Terminated;
140 tracing::info!("All actors drained, shutdown complete");
141 (self.phase, Vec::new())
142 } else if self.is_timeout_expired() {
143 self.phase = ShutdownPhase::ForceKilling;
145 let remaining: Vec<ActorId> = self
146 .draining_actors
147 .iter()
148 .filter(|a| !self.drained_actors.contains(a))
149 .copied()
150 .collect();
151
152 tracing::warn!(
153 remaining = remaining.len(),
154 "Drain timeout expired, force-killing remaining actors"
155 );
156
157 self.force_killed = remaining.clone();
158 self.phase = ShutdownPhase::Terminated;
159 (ShutdownPhase::ForceKilling, remaining)
160 } else {
161 (self.phase, Vec::new())
162 }
163 }
164
165 ShutdownPhase::ForceKilling => {
166 self.phase = ShutdownPhase::Terminated;
167 (self.phase, Vec::new())
168 }
169
170 ShutdownPhase::Terminated => (self.phase, Vec::new()),
171 }
172 }
173
174 pub fn phase(&self) -> ShutdownPhase {
176 self.phase
177 }
178
179 pub fn elapsed(&self) -> Option<Duration> {
181 self.started_at.map(|s| s.elapsed())
182 }
183
184 pub fn report(&self) -> ShutdownReport {
186 ShutdownReport {
187 phase: self.phase,
188 total_actors: self.draining_actors.len(),
189 drained: self.drained_actors.len(),
190 force_killed: self.force_killed.len(),
191 elapsed: self.elapsed(),
192 checkpoint_enabled: self.config.checkpoint_on_drain,
193 }
194 }
195
196 fn compute_shutdown_order(&self, supervisor: &ActorSupervisor) -> Vec<ActorId> {
198 let mut order: Vec<ActorId> = supervisor
199 .entries()
200 .iter()
201 .filter(|e| e.actor_state().is_alive())
202 .map(|e| ActorId(e.actor_id))
203 .collect();
204
205 match self.config.ordering {
206 ShutdownOrdering::LeafFirst => {
207 order.sort_by(|a, b| {
209 let da = supervisor.depth(*a);
210 let db = supervisor.depth(*b);
211 db.cmp(&da) });
213 }
214 ShutdownOrdering::ParentFirst => {
215 order.sort_by(|a, b| {
217 let da = supervisor.depth(*a);
218 let db = supervisor.depth(*b);
219 da.cmp(&db)
220 });
221 }
222 ShutdownOrdering::Parallel => {
223 }
225 }
226
227 order
228 }
229}
230
231#[derive(Debug, Clone)]
233pub struct ShutdownReport {
234 pub phase: ShutdownPhase,
236 pub total_actors: usize,
238 pub drained: usize,
240 pub force_killed: usize,
242 pub elapsed: Option<Duration>,
244 pub checkpoint_enabled: bool,
246}
247
248impl std::fmt::Display for ShutdownReport {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 write!(
251 f,
252 "Shutdown: {} actors, {} drained, {} force-killed, {:?} elapsed",
253 self.total_actors,
254 self.drained,
255 self.force_killed,
256 self.elapsed.unwrap_or_default()
257 )
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::actor::ActorConfig;

    /// Creates an actor under `parent`, activates it and returns its id.
    fn spawn_active(
        supervisor: &mut ActorSupervisor,
        config: &ActorConfig,
        parent: Option<ActorId>,
    ) -> ActorId {
        let id = supervisor.create_actor(config, parent).unwrap();
        supervisor.activate_actor(id).unwrap();
        id
    }

    /// When every actor reports drained, one tick terminates the shutdown
    /// without force-killing anything.
    #[test]
    fn test_graceful_shutdown_all_drain() {
        let mut supervisor = ActorSupervisor::new(8);
        let worker = ActorConfig::named("worker");
        let first = spawn_active(&mut supervisor, &worker, None);
        let second = spawn_active(&mut supervisor, &worker, None);

        let mut coordinator = ShutdownCoordinator::new(ShutdownConfig::default());
        let to_drain = coordinator.initiate(&supervisor);
        assert_eq!(to_drain.len(), 2);
        assert_eq!(coordinator.phase(), ShutdownPhase::Draining);

        coordinator.mark_drained(first);
        coordinator.mark_drained(second);

        let (phase, to_kill) = coordinator.tick();
        assert_eq!(phase, ShutdownPhase::Terminated);
        assert!(to_kill.is_empty());

        let summary = coordinator.report();
        assert_eq!(summary.drained, 2);
        assert_eq!(summary.force_killed, 0);
    }

    /// An actor that never drains is handed back for force-killing once the
    /// drain timeout has passed.
    #[test]
    fn test_shutdown_timeout_force_kill() {
        let mut supervisor = ActorSupervisor::new(8);
        let worker = ActorConfig::named("worker");
        let stuck = spawn_active(&mut supervisor, &worker, None);

        let short_timeout = ShutdownConfig {
            drain_timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let mut coordinator = ShutdownCoordinator::new(short_timeout);

        coordinator.initiate(&supervisor);
        // Wait past the 1 ms drain window so the next tick sees it expired.
        std::thread::sleep(Duration::from_millis(5));

        let (phase, to_kill) = coordinator.tick();
        assert_eq!(phase, ShutdownPhase::ForceKilling);
        assert_eq!(to_kill.len(), 1);
        assert_eq!(to_kill[0], stuck);
    }

    /// Leaf-first ordering lists the deepest actor first and the root last.
    #[test]
    fn test_leaf_first_ordering() {
        let mut supervisor = ActorSupervisor::new(8);
        let node = ActorConfig::named("node");

        let root = spawn_active(&mut supervisor, &node, None);
        let child = spawn_active(&mut supervisor, &node, Some(root));
        let grandchild = spawn_active(&mut supervisor, &node, Some(child));

        let leaf_first = ShutdownConfig {
            ordering: ShutdownOrdering::LeafFirst,
            ..Default::default()
        };
        let coordinator = ShutdownCoordinator::new(leaf_first);

        let order = coordinator.compute_shutdown_order(&supervisor);
        assert_eq!(order[0], grandchild);
        assert_eq!(*order.last().unwrap(), root);
    }

    /// The rendered report mentions each of the counts.
    #[test]
    fn test_shutdown_report_display() {
        let rendered = ShutdownReport {
            phase: ShutdownPhase::Terminated,
            total_actors: 5,
            drained: 4,
            force_killed: 1,
            elapsed: Some(Duration::from_secs(2)),
            checkpoint_enabled: true,
        }
        .to_string();

        assert!(rendered.contains("5 actors"));
        assert!(rendered.contains("4 drained"));
        assert!(rendered.contains("1 force-killed"));
    }
}