1use crate::error::StreamError;
22use std::collections::{HashMap, HashSet};
23use std::time::{Duration, Instant};
24
/// Lifecycle state of the checkpoint a coordinator is currently tracking.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointPhase {
    /// No checkpoint is being tracked.
    Idle,
    /// A checkpoint has been started and operator reports are being collected.
    InProgress {
        /// Id of the running checkpoint.
        checkpoint_id: u64,
        /// Instant at which the checkpoint was started.
        started_at: Instant,
        /// Ids of the operators that have reported so far.
        acked_operators: HashSet<String>,
    },
    /// The checkpoint finished.
    Completed {
        /// Id of the finished checkpoint.
        checkpoint_id: u64,
        /// Instant at which the checkpoint was finalized.
        completed_at: Instant,
        /// Size in bytes recorded for the checkpoint.
        size_bytes: usize,
    },
    /// The checkpoint was abandoned.
    Failed {
        /// Id of the abandoned checkpoint.
        checkpoint_id: u64,
        /// Human-readable explanation of the failure.
        reason: String,
    },
}

impl CheckpointPhase {
    /// Returns the checkpoint id carried by this phase.
    ///
    /// `Idle` carries no id and yields `None`; every other variant yields
    /// `Some(id)`.
    pub fn checkpoint_id(&self) -> Option<u64> {
        match self {
            Self::Idle => None,
            Self::InProgress {
                checkpoint_id: id, ..
            }
            | Self::Completed {
                checkpoint_id: id, ..
            }
            | Self::Failed {
                checkpoint_id: id, ..
            } => Some(*id),
        }
    }
}
59
/// State captured from a single operator for one checkpoint.
#[derive(Debug, Clone)]
pub struct OperatorSnapshot {
    /// Id of the operator that produced this snapshot.
    pub operator_id: String,
    /// Id of the checkpoint this snapshot belongs to.
    pub checkpoint_id: u64,
    /// Serialized operator state.
    pub state_bytes: Vec<u8>,
    /// Messages captured alongside the state, one byte buffer per message.
    pub in_flight_messages: Vec<Vec<u8>>,
    /// Instant at which the snapshot value was constructed.
    pub created_at: Instant,
    /// `state_bytes.len()` plus the length of every in-flight message, as
    /// computed by [`OperatorSnapshot::new`].
    pub size_bytes: usize,
}

impl OperatorSnapshot {
    /// Builds a snapshot, stamping it with the current time and deriving
    /// `size_bytes` from the supplied state and in-flight messages.
    pub fn new(
        operator_id: impl Into<String>,
        checkpoint_id: u64,
        state_bytes: Vec<u8>,
        in_flight_messages: Vec<Vec<u8>>,
    ) -> Self {
        // Start from the state size and add each buffered message on top.
        let size_bytes = in_flight_messages
            .iter()
            .fold(state_bytes.len(), |total, message| total + message.len());

        Self {
            operator_id: operator_id.into(),
            checkpoint_id,
            state_bytes,
            in_flight_messages,
            created_at: Instant::now(),
            size_bytes,
        }
    }
}
97
98#[derive(Debug, Clone)]
100pub struct GlobalCheckpoint {
101 pub checkpoint_id: u64,
102 pub operator_snapshots: HashMap<String, OperatorSnapshot>,
103 pub created_at: Instant,
104 pub total_size_bytes: usize,
105 pub stream_positions: HashMap<String, u64>,
108}
109
110impl GlobalCheckpoint {
111 pub fn new(checkpoint_id: u64) -> Self {
113 Self {
114 checkpoint_id,
115 operator_snapshots: HashMap::new(),
116 created_at: Instant::now(),
117 total_size_bytes: 0,
118 stream_positions: HashMap::new(),
119 }
120 }
121
122 pub fn add_operator_snapshot(&mut self, snapshot: OperatorSnapshot) {
124 self.total_size_bytes += snapshot.size_bytes;
125 self.operator_snapshots
126 .insert(snapshot.operator_id.clone(), snapshot);
127 }
128
129 pub fn set_stream_position(&mut self, stream_id: impl Into<String>, offset: u64) {
131 self.stream_positions.insert(stream_id.into(), offset);
132 }
133
134 pub fn is_complete(&self, expected_operators: &[String]) -> bool {
136 expected_operators
137 .iter()
138 .all(|op_id| self.operator_snapshots.contains_key(op_id))
139 }
140
141 pub fn total_bytes(&self) -> usize {
143 self.total_size_bytes
144 }
145}
146
147pub struct CheckpointCoordinator {
154 current_phase: CheckpointPhase,
155 checkpoint_interval: Duration,
157 last_checkpoint: Option<Instant>,
159 completed_checkpoints: Vec<GlobalCheckpoint>,
161 max_retained_checkpoints: usize,
163 registered_operators: Vec<String>,
165 next_checkpoint_id: u64,
167 in_progress_checkpoint: Option<GlobalCheckpoint>,
169 operator_timeout: Duration,
171}
172
173impl CheckpointCoordinator {
174 pub fn new(interval: Duration) -> Self {
176 Self {
177 current_phase: CheckpointPhase::Idle,
178 checkpoint_interval: interval,
179 last_checkpoint: None,
180 completed_checkpoints: Vec::new(),
181 max_retained_checkpoints: 10,
182 registered_operators: Vec::new(),
183 next_checkpoint_id: 1,
184 in_progress_checkpoint: None,
185 operator_timeout: Duration::from_secs(60),
186 }
187 }
188
189 pub fn with_max_retained(mut self, n: usize) -> Self {
191 self.max_retained_checkpoints = n;
192 self
193 }
194
195 pub fn with_operator_timeout(mut self, timeout: Duration) -> Self {
197 self.operator_timeout = timeout;
198 self
199 }
200
201 pub fn register_operator(&mut self, operator_id: String) {
203 if !self.registered_operators.contains(&operator_id) {
204 self.registered_operators.push(operator_id);
205 }
206 }
207
208 pub fn register_operators(&mut self, operator_ids: impl IntoIterator<Item = String>) {
210 for id in operator_ids {
211 self.register_operator(id);
212 }
213 }
214
215 pub fn should_checkpoint(&self) -> bool {
218 if !matches!(self.current_phase, CheckpointPhase::Idle) {
219 return false;
220 }
221
222 match self.last_checkpoint {
223 None => true,
224 Some(last) => last.elapsed() >= self.checkpoint_interval,
225 }
226 }
227
228 pub fn initiate(&mut self) -> Result<u64, StreamError> {
235 if !matches!(self.current_phase, CheckpointPhase::Idle) {
236 return Err(StreamError::InvalidOperation(format!(
237 "cannot initiate checkpoint while in phase {:?}",
238 self.current_phase.checkpoint_id()
239 )));
240 }
241
242 let checkpoint_id = self.next_checkpoint_id;
243 self.next_checkpoint_id += 1;
244
245 self.current_phase = CheckpointPhase::InProgress {
246 checkpoint_id,
247 started_at: Instant::now(),
248 acked_operators: HashSet::new(),
249 };
250
251 self.in_progress_checkpoint = Some(GlobalCheckpoint::new(checkpoint_id));
252 self.last_checkpoint = Some(Instant::now());
253
254 Ok(checkpoint_id)
255 }
256
257 pub fn operator_reported(&mut self, snapshot: OperatorSnapshot) -> Result<bool, StreamError> {
262 let (checkpoint_id, started_at) = match &self.current_phase {
263 CheckpointPhase::InProgress {
264 checkpoint_id,
265 started_at,
266 ..
267 } => (*checkpoint_id, *started_at),
268 other => {
269 return Err(StreamError::InvalidOperation(format!(
270 "operator_reported called but coordinator is in {:?} phase",
271 other.checkpoint_id()
272 )));
273 }
274 };
275
276 if snapshot.checkpoint_id != checkpoint_id {
277 return Err(StreamError::InvalidInput(format!(
278 "snapshot checkpoint_id {} does not match in-progress {}",
279 snapshot.checkpoint_id, checkpoint_id
280 )));
281 }
282
283 if started_at.elapsed() > self.operator_timeout {
285 let reason = format!(
286 "operator {} timed out after {:?}",
287 snapshot.operator_id,
288 started_at.elapsed()
289 );
290 self.abort(&reason);
291 return Err(StreamError::Timeout(reason));
292 }
293
294 let operator_id = snapshot.operator_id.clone();
296 if let CheckpointPhase::InProgress {
297 ref mut acked_operators,
298 ..
299 } = self.current_phase
300 {
301 acked_operators.insert(operator_id.clone());
302 }
303
304 if let Some(ref mut global) = self.in_progress_checkpoint {
306 global.add_operator_snapshot(snapshot);
307 }
308
309 let all_done = if let CheckpointPhase::InProgress {
311 ref acked_operators,
312 ..
313 } = self.current_phase
314 {
315 self.registered_operators
316 .iter()
317 .all(|op| acked_operators.contains(op))
318 } else {
319 false
320 };
321
322 if all_done {
323 self.finalize_checkpoint(checkpoint_id)?;
324 return Ok(true);
325 }
326
327 Ok(false)
328 }
329
330 fn finalize_checkpoint(&mut self, checkpoint_id: u64) -> Result<(), StreamError> {
331 let global = self.in_progress_checkpoint.take().ok_or_else(|| {
332 StreamError::Other("in_progress_checkpoint missing at finalize".into())
333 })?;
334
335 let size_bytes = global.total_size_bytes;
336
337 self.completed_checkpoints.push(global);
338
339 while self.completed_checkpoints.len() > self.max_retained_checkpoints {
341 self.completed_checkpoints.remove(0);
342 }
343
344 self.current_phase = CheckpointPhase::Completed {
345 checkpoint_id,
346 completed_at: Instant::now(),
347 size_bytes,
348 };
349
350 Ok(())
351 }
352
353 pub fn abort(&mut self, reason: &str) {
355 let checkpoint_id = self.current_phase.checkpoint_id().unwrap_or(0);
356 self.in_progress_checkpoint = None;
357 self.current_phase = CheckpointPhase::Failed {
358 checkpoint_id,
359 reason: reason.to_string(),
360 };
361 }
362
363 pub fn reset_to_idle(&mut self) {
365 match self.current_phase {
366 CheckpointPhase::Completed { .. } | CheckpointPhase::Failed { .. } => {
367 self.current_phase = CheckpointPhase::Idle;
368 }
369 _ => {}
370 }
371 }
372
373 pub fn latest_checkpoint(&self) -> Option<&GlobalCheckpoint> {
375 self.completed_checkpoints.last()
376 }
377
378 pub fn get_checkpoint(&self, id: u64) -> Option<&GlobalCheckpoint> {
380 self.completed_checkpoints
381 .iter()
382 .find(|cp| cp.checkpoint_id == id)
383 }
384
385 pub fn completed_count(&self) -> usize {
387 self.completed_checkpoints.len()
388 }
389
390 pub fn current_checkpoint_id(&self) -> Option<u64> {
392 match &self.current_phase {
393 CheckpointPhase::InProgress { checkpoint_id, .. } => Some(*checkpoint_id),
394 _ => None,
395 }
396 }
397
398 pub fn phase(&self) -> &CheckpointPhase {
400 &self.current_phase
401 }
402
403 pub fn pending_operators(&self) -> Option<Vec<String>> {
407 if let CheckpointPhase::InProgress {
408 ref acked_operators,
409 ..
410 } = self.current_phase
411 {
412 let pending: Vec<String> = self
413 .registered_operators
414 .iter()
415 .filter(|op| !acked_operators.contains(*op))
416 .cloned()
417 .collect();
418 Some(pending)
419 } else {
420 None
421 }
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot carrying `state` and no in-flight messages.
    fn make_snapshot(operator_id: &str, checkpoint_id: u64, state: &[u8]) -> OperatorSnapshot {
        OperatorSnapshot::new(operator_id, checkpoint_id, Vec::from(state), Vec::new())
    }

    /// Runs `rounds` full checkpoints for the single operator `"op"`,
    /// returning the coordinator to idle after each one.
    fn run_single_operator_rounds(coord: &mut CheckpointCoordinator, rounds: usize) {
        for _ in 0..rounds {
            coord.initiate().unwrap();
            let running = coord.current_checkpoint_id().unwrap();
            coord
                .operator_reported(make_snapshot("op", running, b"s"))
                .unwrap();
            coord.reset_to_idle();
        }
    }

    // The idle phase has no id; an in-progress phase exposes its own.
    #[test]
    fn test_phase_checkpoint_id() {
        assert_eq!(CheckpointPhase::Idle.checkpoint_id(), None);

        let running = CheckpointPhase::InProgress {
            checkpoint_id: 7,
            started_at: Instant::now(),
            acked_operators: HashSet::new(),
        };
        assert_eq!(running.checkpoint_id(), Some(7));
    }

    // A global checkpoint is complete only once every expected operator is in.
    #[test]
    fn test_global_checkpoint_completeness() {
        let expected = ["op_a".to_string(), "op_b".to_string()];
        let mut checkpoint = GlobalCheckpoint::new(1);
        assert!(!checkpoint.is_complete(&expected));

        checkpoint.add_operator_snapshot(make_snapshot("op_a", 1, b"state_a"));
        assert!(!checkpoint.is_complete(&expected));

        checkpoint.add_operator_snapshot(make_snapshot("op_b", 1, b"state_b"));
        assert!(checkpoint.is_complete(&expected));
    }

    #[test]
    fn test_global_checkpoint_bytes() {
        let small = vec![0u8; 100];
        let large = vec![0u8; 200];

        let mut checkpoint = GlobalCheckpoint::new(1);
        checkpoint.add_operator_snapshot(make_snapshot("op_a", 1, &small));
        checkpoint.add_operator_snapshot(make_snapshot("op_b", 1, &large));

        assert_eq!(checkpoint.total_bytes(), 300);
    }

    // A fresh coordinator has never checkpointed, so it is immediately due.
    #[test]
    fn test_should_checkpoint_when_no_last() {
        let fresh = CheckpointCoordinator::new(Duration::from_secs(60));
        assert!(fresh.should_checkpoint());
    }

    #[test]
    fn test_should_not_checkpoint_when_in_progress() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(60));
        coord.register_operator(String::from("op1"));
        coord.initiate().unwrap();

        assert!(!coord.should_checkpoint());
    }

    #[test]
    fn test_initiate_returns_incrementing_ids() {
        let mut coord = CheckpointCoordinator::new(Duration::from_millis(0));
        coord.register_operator(String::from("op1"));

        let first = coord.initiate().unwrap();
        assert_eq!(first, 1);
        assert_eq!(coord.current_checkpoint_id(), Some(1));

        // Finish the first checkpoint so that a second one may start.
        coord
            .operator_reported(make_snapshot("op1", 1, b"state"))
            .unwrap();
        coord.reset_to_idle();

        let second = coord.initiate().unwrap();
        assert_eq!(second, 2);
    }

    #[test]
    fn test_single_operator_full_lifecycle() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(300));
        coord.register_operator(String::from("worker"));
        assert!(coord.should_checkpoint());

        let started = coord.initiate().unwrap();
        assert_eq!(started, 1);

        let finished = coord
            .operator_reported(make_snapshot("worker", started, b"my_state_data"))
            .unwrap();
        assert!(finished);
        assert_eq!(coord.completed_count(), 1);

        let newest = coord.latest_checkpoint().unwrap();
        assert_eq!(newest.checkpoint_id, 1);
        assert!(newest.operator_snapshots.contains_key("worker"));
        // "my_state_data" is 13 bytes long.
        assert_eq!(newest.total_bytes(), 13);

        coord.reset_to_idle();
        // The 300 second interval has not elapsed since initiation.
        assert!(!coord.should_checkpoint());
    }

    #[test]
    fn test_multi_operator_checkpoint() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(300));
        coord.register_operators(["op_a", "op_b", "op_c"].iter().map(|id| id.to_string()));

        let running = coord.initiate().unwrap();

        let after_first = coord
            .operator_reported(make_snapshot("op_a", running, b"state_a"))
            .unwrap();
        assert!(!after_first);

        let after_second = coord
            .operator_reported(make_snapshot("op_b", running, b"state_b"))
            .unwrap();
        assert!(!after_second);

        // Only the third operator is still outstanding.
        let outstanding = coord.pending_operators().unwrap();
        assert_eq!(outstanding, vec![String::from("op_c")]);

        let after_third = coord
            .operator_reported(make_snapshot("op_c", running, b"state_c"))
            .unwrap();
        assert!(after_third);

        let stored = coord.get_checkpoint(running).unwrap();
        assert_eq!(stored.operator_snapshots.len(), 3);
    }

    #[test]
    fn test_abort_checkpoint() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(300));
        coord.register_operator(String::from("op"));
        coord.initiate().unwrap();

        coord.abort("operator crashed");
        assert!(matches!(coord.phase(), CheckpointPhase::Failed { .. }));
        assert_eq!(coord.completed_count(), 0);

        coord.reset_to_idle();
        assert!(matches!(coord.phase(), CheckpointPhase::Idle));
    }

    #[test]
    fn test_max_retained_checkpoints() {
        let mut coord = CheckpointCoordinator::new(Duration::from_millis(0)).with_max_retained(3);
        coord.register_operator(String::from("op"));

        run_single_operator_rounds(&mut coord, 5);

        // Five were completed but only the three newest are kept.
        assert_eq!(coord.completed_count(), 3);
        assert_eq!(coord.latest_checkpoint().unwrap().checkpoint_id, 5);
    }

    #[test]
    fn test_wrong_checkpoint_id_rejected() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(300));
        coord.register_operator(String::from("op"));
        coord.initiate().unwrap();

        // The running checkpoint has id 1, so id 999 must be refused.
        let outcome = coord.operator_reported(make_snapshot("op", 999, b"state"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_duplicate_initiate_fails() {
        let mut coord = CheckpointCoordinator::new(Duration::from_secs(300));
        coord.register_operator(String::from("op"));
        coord.initiate().unwrap();

        let second_attempt = coord.initiate();
        assert!(second_attempt.is_err());
    }

    #[test]
    fn test_get_checkpoint_by_id() {
        let mut coord = CheckpointCoordinator::new(Duration::from_millis(0));
        coord.register_operator(String::from("op"));

        run_single_operator_rounds(&mut coord, 3);

        assert!(coord.get_checkpoint(1).is_some());
        assert!(coord.get_checkpoint(2).is_some());
        assert!(coord.get_checkpoint(3).is_some());
        assert!(coord.get_checkpoint(99).is_none());
    }

    #[test]
    fn test_operator_snapshot_size() {
        let buffered = vec![vec![0u8; 100], vec![0u8; 50]];
        let snapshot = OperatorSnapshot::new("op", 1, vec![0u8; 500], buffered);

        // 500 bytes of state plus 100 + 50 bytes of in-flight messages.
        assert_eq!(snapshot.size_bytes, 650);
    }

    #[test]
    fn test_stream_positions() {
        let mut checkpoint = GlobalCheckpoint::new(1);
        checkpoint.set_stream_position("topic-A", 1024);
        checkpoint.set_stream_position("topic-B", 2048);

        assert_eq!(checkpoint.stream_positions.get("topic-A"), Some(&1024));
        assert_eq!(checkpoint.stream_positions.get("topic-B"), Some(&2048));
    }
}