Skip to main content

oxirs_stream/checkpoint/
coordinator.rs

//! # Chandy-Lamport Checkpoint Coordinator
//!
//! Distributed checkpoint coordinator inspired by the Chandy-Lamport algorithm
//! for consistent global snapshots. Enables failure recovery by periodically
//! snapshotting all operator states.
//!
//! ## How it works
//!
//! 1. The coordinator decides it is time for a new checkpoint.
//! 2. It broadcasts a *barrier marker* to every registered operator (modelled
//!    here as the `initiate()` call returning a `checkpoint_id` that is
//!    forwarded to operators out-of-band).
//! 3. Each operator:
//!    a. Drains its in-flight channel messages.
//!    b. Serialises its state via its `StateBackend`.
//!    c. Calls `coordinator.operator_reported(snapshot)`.
//! 4. Once all operators have acknowledged, the coordinator declares the
//!    checkpoint *complete* and stores the `GlobalCheckpoint`.
//! 5. On recovery, the job restores from the latest `GlobalCheckpoint`, which
//!    the coordinator exposes via `latest_checkpoint()`.
20
21use crate::error::StreamError;
22use std::collections::{HashMap, HashSet};
23use std::time::{Duration, Instant};
24
25// ─── Types ────────────────────────────────────────────────────────────────────
26
/// State machine describing where the coordinator is in a checkpoint cycle.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointPhase {
    /// Nothing is running; a new checkpoint may be started.
    Idle,
    /// A barrier has been handed out and acknowledgements are being collected.
    InProgress {
        /// Identifier of the checkpoint being assembled.
        checkpoint_id: u64,
        /// Moment the checkpoint was started.
        started_at: Instant,
        /// IDs of the operators that have reported so far.
        acked_operators: HashSet<String>,
    },
    /// Every required operator reported and the checkpoint was stored.
    Completed {
        /// Identifier of the finished checkpoint.
        checkpoint_id: u64,
        /// Moment the checkpoint was finalised.
        completed_at: Instant,
        /// Accumulated byte size of the stored checkpoint.
        size_bytes: usize,
    },
    /// The checkpoint was given up on.
    Failed {
        /// Identifier of the abandoned checkpoint.
        checkpoint_id: u64,
        /// Human-readable explanation of why it was abandoned.
        reason: String,
    },
}

impl CheckpointPhase {
    /// Checkpoint ID carried by this phase, or `None` for [`CheckpointPhase::Idle`].
    pub fn checkpoint_id(&self) -> Option<u64> {
        match self {
            Self::Idle => None,
            Self::InProgress { checkpoint_id, .. }
            | Self::Completed { checkpoint_id, .. }
            | Self::Failed { checkpoint_id, .. } => Some(*checkpoint_id),
        }
    }
}
59
/// State captured from one operator for a particular checkpoint.
#[derive(Debug, Clone)]
pub struct OperatorSnapshot {
    /// Identifier of the operator the state was taken from.
    pub operator_id: String,
    /// Checkpoint this snapshot was taken for.
    pub checkpoint_id: u64,
    /// Serialised operator state; the encoding is up to the operator.
    pub state_bytes: Vec<u8>,
    /// Messages still sitting in the operator's channel when the barrier arrived.
    pub in_flight_messages: Vec<Vec<u8>>,
    /// Monotonic timestamp taken when the snapshot value was constructed.
    pub created_at: Instant,
    /// Combined length of `state_bytes` and all in-flight messages.
    pub size_bytes: usize,
}

impl OperatorSnapshot {
    /// Build a snapshot and derive `size_bytes` from the supplied payloads.
    ///
    /// `size_bytes` is the length of `state_bytes` plus the length of every
    /// entry in `in_flight_messages`; `created_at` is set to the current instant.
    pub fn new(
        operator_id: impl Into<String>,
        checkpoint_id: u64,
        state_bytes: Vec<u8>,
        in_flight_messages: Vec<Vec<u8>>,
    ) -> Self {
        // Start from the state size and add each buffered message on top.
        let size_bytes = in_flight_messages
            .iter()
            .fold(state_bytes.len(), |total, message| total + message.len());

        Self {
            operator_id: operator_id.into(),
            checkpoint_id,
            state_bytes,
            in_flight_messages,
            created_at: Instant::now(),
            size_bytes,
        }
    }
}
97
98/// Consistent global snapshot across all operators in the streaming job.
99#[derive(Debug, Clone)]
100pub struct GlobalCheckpoint {
101    pub checkpoint_id: u64,
102    pub operator_snapshots: HashMap<String, OperatorSnapshot>,
103    pub created_at: Instant,
104    pub total_size_bytes: usize,
105    /// Stream read offsets at the time of the checkpoint.
106    /// Maps `stream_id → offset`, allowing the job to replay exactly from here.
107    pub stream_positions: HashMap<String, u64>,
108}
109
110impl GlobalCheckpoint {
111    /// Create an empty global checkpoint container.
112    pub fn new(checkpoint_id: u64) -> Self {
113        Self {
114            checkpoint_id,
115            operator_snapshots: HashMap::new(),
116            created_at: Instant::now(),
117            total_size_bytes: 0,
118            stream_positions: HashMap::new(),
119        }
120    }
121
122    /// Add an operator snapshot to this global checkpoint.
123    pub fn add_operator_snapshot(&mut self, snapshot: OperatorSnapshot) {
124        self.total_size_bytes += snapshot.size_bytes;
125        self.operator_snapshots
126            .insert(snapshot.operator_id.clone(), snapshot);
127    }
128
129    /// Set the committed read offset for a stream.
130    pub fn set_stream_position(&mut self, stream_id: impl Into<String>, offset: u64) {
131        self.stream_positions.insert(stream_id.into(), offset);
132    }
133
134    /// Returns `true` when every expected operator has contributed a snapshot.
135    pub fn is_complete(&self, expected_operators: &[String]) -> bool {
136        expected_operators
137            .iter()
138            .all(|op_id| self.operator_snapshots.contains_key(op_id))
139    }
140
141    /// Total byte size of all operator state in this checkpoint.
142    pub fn total_bytes(&self) -> usize {
143        self.total_size_bytes
144    }
145}
146
147// ─── Coordinator ─────────────────────────────────────────────────────────────
148
149/// Central checkpoint coordinator.
150///
151/// Manages the lifecycle of periodic checkpoints: scheduling, barrier
152/// initiation, acknowledgement collection, and storage of completed snapshots.
153pub struct CheckpointCoordinator {
154    current_phase: CheckpointPhase,
155    /// How often to trigger a checkpoint.
156    checkpoint_interval: Duration,
157    /// Wall-clock time when the last checkpoint completed (or was initiated).
158    last_checkpoint: Option<Instant>,
159    /// Completed checkpoints, newest last.
160    completed_checkpoints: Vec<GlobalCheckpoint>,
161    /// Maximum number of completed checkpoints to retain in memory.
162    max_retained_checkpoints: usize,
163    /// Operators that must acknowledge before a checkpoint is complete.
164    registered_operators: Vec<String>,
165    /// Monotonically increasing checkpoint counter.
166    next_checkpoint_id: u64,
167    /// In-progress global checkpoint being assembled.
168    in_progress_checkpoint: Option<GlobalCheckpoint>,
169    /// Timeout for operators to acknowledge before a checkpoint is aborted.
170    operator_timeout: Duration,
171}
172
173impl CheckpointCoordinator {
174    /// Create a coordinator with the given checkpoint interval.
175    pub fn new(interval: Duration) -> Self {
176        Self {
177            current_phase: CheckpointPhase::Idle,
178            checkpoint_interval: interval,
179            last_checkpoint: None,
180            completed_checkpoints: Vec::new(),
181            max_retained_checkpoints: 10,
182            registered_operators: Vec::new(),
183            next_checkpoint_id: 1,
184            in_progress_checkpoint: None,
185            operator_timeout: Duration::from_secs(60),
186        }
187    }
188
189    /// Override the maximum number of retained completed checkpoints.
190    pub fn with_max_retained(mut self, n: usize) -> Self {
191        self.max_retained_checkpoints = n;
192        self
193    }
194
195    /// Override the per-operator acknowledgement timeout.
196    pub fn with_operator_timeout(mut self, timeout: Duration) -> Self {
197        self.operator_timeout = timeout;
198        self
199    }
200
201    /// Register an operator that must participate in every checkpoint.
202    pub fn register_operator(&mut self, operator_id: String) {
203        if !self.registered_operators.contains(&operator_id) {
204            self.registered_operators.push(operator_id);
205        }
206    }
207
208    /// Register multiple operators at once.
209    pub fn register_operators(&mut self, operator_ids: impl IntoIterator<Item = String>) {
210        for id in operator_ids {
211            self.register_operator(id);
212        }
213    }
214
215    /// Returns `true` if enough time has elapsed and no checkpoint is in
216    /// progress.
217    pub fn should_checkpoint(&self) -> bool {
218        if !matches!(self.current_phase, CheckpointPhase::Idle) {
219            return false;
220        }
221
222        match self.last_checkpoint {
223            None => true,
224            Some(last) => last.elapsed() >= self.checkpoint_interval,
225        }
226    }
227
228    /// Initiate a new checkpoint.
229    ///
230    /// The returned `checkpoint_id` should be forwarded to all registered
231    /// operators as a barrier token.
232    ///
233    /// Returns an error if a checkpoint is already in progress.
234    pub fn initiate(&mut self) -> Result<u64, StreamError> {
235        if !matches!(self.current_phase, CheckpointPhase::Idle) {
236            return Err(StreamError::InvalidOperation(format!(
237                "cannot initiate checkpoint while in phase {:?}",
238                self.current_phase.checkpoint_id()
239            )));
240        }
241
242        let checkpoint_id = self.next_checkpoint_id;
243        self.next_checkpoint_id += 1;
244
245        self.current_phase = CheckpointPhase::InProgress {
246            checkpoint_id,
247            started_at: Instant::now(),
248            acked_operators: HashSet::new(),
249        };
250
251        self.in_progress_checkpoint = Some(GlobalCheckpoint::new(checkpoint_id));
252        self.last_checkpoint = Some(Instant::now());
253
254        Ok(checkpoint_id)
255    }
256
257    /// Called by an operator after it has snapshotted its state.
258    ///
259    /// Returns `true` when all operators have acknowledged (checkpoint
260    /// complete).
261    pub fn operator_reported(&mut self, snapshot: OperatorSnapshot) -> Result<bool, StreamError> {
262        let (checkpoint_id, started_at) = match &self.current_phase {
263            CheckpointPhase::InProgress {
264                checkpoint_id,
265                started_at,
266                ..
267            } => (*checkpoint_id, *started_at),
268            other => {
269                return Err(StreamError::InvalidOperation(format!(
270                    "operator_reported called but coordinator is in {:?} phase",
271                    other.checkpoint_id()
272                )));
273            }
274        };
275
276        if snapshot.checkpoint_id != checkpoint_id {
277            return Err(StreamError::InvalidInput(format!(
278                "snapshot checkpoint_id {} does not match in-progress {}",
279                snapshot.checkpoint_id, checkpoint_id
280            )));
281        }
282
283        // Check operator timeout
284        if started_at.elapsed() > self.operator_timeout {
285            let reason = format!(
286                "operator {} timed out after {:?}",
287                snapshot.operator_id,
288                started_at.elapsed()
289            );
290            self.abort(&reason);
291            return Err(StreamError::Timeout(reason));
292        }
293
294        // Record acknowledgement
295        let operator_id = snapshot.operator_id.clone();
296        if let CheckpointPhase::InProgress {
297            ref mut acked_operators,
298            ..
299        } = self.current_phase
300        {
301            acked_operators.insert(operator_id.clone());
302        }
303
304        // Accumulate into the global snapshot
305        if let Some(ref mut global) = self.in_progress_checkpoint {
306            global.add_operator_snapshot(snapshot);
307        }
308
309        // Check if all operators have reported
310        let all_done = if let CheckpointPhase::InProgress {
311            ref acked_operators,
312            ..
313        } = self.current_phase
314        {
315            self.registered_operators
316                .iter()
317                .all(|op| acked_operators.contains(op))
318        } else {
319            false
320        };
321
322        if all_done {
323            self.finalize_checkpoint(checkpoint_id)?;
324            return Ok(true);
325        }
326
327        Ok(false)
328    }
329
330    fn finalize_checkpoint(&mut self, checkpoint_id: u64) -> Result<(), StreamError> {
331        let global = self.in_progress_checkpoint.take().ok_or_else(|| {
332            StreamError::Other("in_progress_checkpoint missing at finalize".into())
333        })?;
334
335        let size_bytes = global.total_size_bytes;
336
337        self.completed_checkpoints.push(global);
338
339        // Trim retained checkpoints
340        while self.completed_checkpoints.len() > self.max_retained_checkpoints {
341            self.completed_checkpoints.remove(0);
342        }
343
344        self.current_phase = CheckpointPhase::Completed {
345            checkpoint_id,
346            completed_at: Instant::now(),
347            size_bytes,
348        };
349
350        Ok(())
351    }
352
353    /// Abort the in-progress checkpoint, transitioning back to `Idle`.
354    pub fn abort(&mut self, reason: &str) {
355        let checkpoint_id = self.current_phase.checkpoint_id().unwrap_or(0);
356        self.in_progress_checkpoint = None;
357        self.current_phase = CheckpointPhase::Failed {
358            checkpoint_id,
359            reason: reason.to_string(),
360        };
361    }
362
363    /// Reset from a Failed or Completed phase back to Idle.
364    pub fn reset_to_idle(&mut self) {
365        match self.current_phase {
366            CheckpointPhase::Completed { .. } | CheckpointPhase::Failed { .. } => {
367                self.current_phase = CheckpointPhase::Idle;
368            }
369            _ => {}
370        }
371    }
372
373    /// Return a reference to the most recent completed checkpoint.
374    pub fn latest_checkpoint(&self) -> Option<&GlobalCheckpoint> {
375        self.completed_checkpoints.last()
376    }
377
378    /// Return a reference to a completed checkpoint by ID.
379    pub fn get_checkpoint(&self, id: u64) -> Option<&GlobalCheckpoint> {
380        self.completed_checkpoints
381            .iter()
382            .find(|cp| cp.checkpoint_id == id)
383    }
384
385    /// Number of retained completed checkpoints.
386    pub fn completed_count(&self) -> usize {
387        self.completed_checkpoints.len()
388    }
389
390    /// The checkpoint ID currently in progress, if any.
391    pub fn current_checkpoint_id(&self) -> Option<u64> {
392        match &self.current_phase {
393            CheckpointPhase::InProgress { checkpoint_id, .. } => Some(*checkpoint_id),
394            _ => None,
395        }
396    }
397
398    /// Current phase (for diagnostics).
399    pub fn phase(&self) -> &CheckpointPhase {
400        &self.current_phase
401    }
402
403    /// Pending operator acknowledgements for the in-progress checkpoint.
404    ///
405    /// Returns `None` if no checkpoint is in progress.
406    pub fn pending_operators(&self) -> Option<Vec<String>> {
407        if let CheckpointPhase::InProgress {
408            ref acked_operators,
409            ..
410        } = self.current_phase
411        {
412            let pending: Vec<String> = self
413                .registered_operators
414                .iter()
415                .filter(|op| !acked_operators.contains(*op))
416                .cloned()
417                .collect();
418            Some(pending)
419        } else {
420            None
421        }
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot fixture carrying `state` and no in-flight messages.
    fn make_snapshot(operator_id: &str, checkpoint_id: u64, state: &[u8]) -> OperatorSnapshot {
        OperatorSnapshot::new(operator_id, checkpoint_id, Vec::from(state), Vec::new())
    }

    /// Drive `rounds` full checkpoints for the single operator "op", returning
    /// the coordinator to idle after each one.
    fn run_single_operator_rounds(coordinator: &mut CheckpointCoordinator, rounds: usize) {
        for _ in 0..rounds {
            coordinator.initiate().unwrap();
            let id = coordinator.current_checkpoint_id().unwrap();
            coordinator
                .operator_reported(make_snapshot("op", id, b"s"))
                .unwrap();
            coordinator.reset_to_idle();
        }
    }

    // ── CheckpointPhase helpers ───────────────────────────────────────────────

    #[test]
    fn test_phase_checkpoint_id() {
        assert_eq!(CheckpointPhase::Idle.checkpoint_id(), None);

        let running = CheckpointPhase::InProgress {
            checkpoint_id: 7,
            started_at: Instant::now(),
            acked_operators: HashSet::new(),
        };
        assert_eq!(running.checkpoint_id(), Some(7));
    }

    // ── GlobalCheckpoint ─────────────────────────────────────────────────────

    #[test]
    fn test_global_checkpoint_completeness() {
        let expected = vec![String::from("op_a"), String::from("op_b")];
        let mut checkpoint = GlobalCheckpoint::new(1);

        // Nothing contributed yet.
        assert!(!checkpoint.is_complete(&expected));

        // One of two contributed.
        checkpoint.add_operator_snapshot(make_snapshot("op_a", 1, b"state_a"));
        assert!(!checkpoint.is_complete(&expected));

        // Both contributed.
        checkpoint.add_operator_snapshot(make_snapshot("op_b", 1, b"state_b"));
        assert!(checkpoint.is_complete(&expected));
    }

    #[test]
    fn test_global_checkpoint_bytes() {
        let mut checkpoint = GlobalCheckpoint::new(1);
        let small = [0u8; 100];
        let large = [0u8; 200];
        checkpoint.add_operator_snapshot(make_snapshot("op_a", 1, &small));
        checkpoint.add_operator_snapshot(make_snapshot("op_b", 1, &large));
        assert_eq!(checkpoint.total_bytes(), 300);
    }

    // ── CheckpointCoordinator ─────────────────────────────────────────────────

    #[test]
    fn test_should_checkpoint_when_no_last() {
        // A fresh coordinator has never checkpointed, so one is due at once.
        let coordinator = CheckpointCoordinator::new(Duration::from_secs(60));
        assert!(coordinator.should_checkpoint());
    }

    #[test]
    fn test_should_not_checkpoint_when_in_progress() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(60));
        coordinator.register_operator(String::from("op1"));
        coordinator.initiate().unwrap();
        assert!(!coordinator.should_checkpoint());
    }

    #[test]
    fn test_initiate_returns_incrementing_ids() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_millis(0));
        coordinator.register_operator(String::from("op1"));

        let first = coordinator.initiate().unwrap();
        assert_eq!(first, 1);
        assert_eq!(coordinator.current_checkpoint_id(), Some(1));

        // Complete the first checkpoint so that a second one may start.
        coordinator
            .operator_reported(make_snapshot("op1", 1, b"state"))
            .unwrap();
        coordinator.reset_to_idle();

        let second = coordinator.initiate().unwrap();
        assert_eq!(second, 2);
    }

    #[test]
    fn test_single_operator_full_lifecycle() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(300));
        coordinator.register_operator(String::from("worker"));

        assert!(coordinator.should_checkpoint());

        let id = coordinator.initiate().unwrap();
        assert_eq!(id, 1);

        let finished = coordinator
            .operator_reported(make_snapshot("worker", id, b"my_state_data"))
            .unwrap();
        assert!(finished);

        assert_eq!(coordinator.completed_count(), 1);
        let stored = coordinator.latest_checkpoint().unwrap();
        assert_eq!(stored.checkpoint_id, 1);
        assert!(stored.operator_snapshots.contains_key("worker"));
        // "my_state_data" is 13 bytes long.
        assert_eq!(stored.total_bytes(), 13);

        // Back to idle, but the 300 s interval has not elapsed yet.
        coordinator.reset_to_idle();
        assert!(!coordinator.should_checkpoint());
    }

    #[test]
    fn test_multi_operator_checkpoint() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(300));
        coordinator.register_operators(vec![
            String::from("op_a"),
            String::from("op_b"),
            String::from("op_c"),
        ]);

        let id = coordinator.initiate().unwrap();

        // The first two reports leave the checkpoint incomplete.
        let after_a = coordinator
            .operator_reported(make_snapshot("op_a", id, b"state_a"))
            .unwrap();
        assert!(!after_a);

        let after_b = coordinator
            .operator_reported(make_snapshot("op_b", id, b"state_b"))
            .unwrap();
        assert!(!after_b);

        // Only "op_c" is still outstanding.
        let outstanding = coordinator.pending_operators().unwrap();
        assert_eq!(outstanding, vec![String::from("op_c")]);

        // The final report completes the checkpoint.
        let after_c = coordinator
            .operator_reported(make_snapshot("op_c", id, b"state_c"))
            .unwrap();
        assert!(after_c);

        let stored = coordinator.get_checkpoint(id).unwrap();
        assert_eq!(stored.operator_snapshots.len(), 3);
    }

    #[test]
    fn test_abort_checkpoint() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(300));
        coordinator.register_operator(String::from("op"));

        coordinator.initiate().unwrap();
        coordinator.abort("operator crashed");

        assert!(matches!(coordinator.phase(), CheckpointPhase::Failed { .. }));
        assert_eq!(coordinator.completed_count(), 0);

        coordinator.reset_to_idle();
        assert!(matches!(coordinator.phase(), CheckpointPhase::Idle));
    }

    #[test]
    fn test_max_retained_checkpoints() {
        let mut coordinator =
            CheckpointCoordinator::new(Duration::from_millis(0)).with_max_retained(3);
        coordinator.register_operator(String::from("op"));

        run_single_operator_rounds(&mut coordinator, 5);

        assert_eq!(coordinator.completed_count(), 3);
        // Checkpoint 5 is the newest one retained.
        assert_eq!(coordinator.latest_checkpoint().unwrap().checkpoint_id, 5);
    }

    #[test]
    fn test_wrong_checkpoint_id_rejected() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(300));
        coordinator.register_operator(String::from("op"));

        // The in-progress checkpoint gets ID 1.
        coordinator.initiate().unwrap();

        // A snapshot for some other checkpoint must be refused.
        let outcome = coordinator.operator_reported(make_snapshot("op", 999, b"state"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_duplicate_initiate_fails() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_secs(300));
        coordinator.register_operator(String::from("op"));

        coordinator.initiate().unwrap();
        assert!(coordinator.initiate().is_err());
    }

    #[test]
    fn test_get_checkpoint_by_id() {
        let mut coordinator = CheckpointCoordinator::new(Duration::from_millis(0));
        coordinator.register_operator(String::from("op"));

        run_single_operator_rounds(&mut coordinator, 3);

        assert!(coordinator.get_checkpoint(1).is_some());
        assert!(coordinator.get_checkpoint(2).is_some());
        assert!(coordinator.get_checkpoint(3).is_some());
        assert!(coordinator.get_checkpoint(99).is_none());
    }

    #[test]
    fn test_operator_snapshot_size() {
        let buffered = vec![vec![0u8; 100], vec![0u8; 50]];
        let snapshot = OperatorSnapshot::new("op", 1, vec![0u8; 500], buffered);
        assert_eq!(snapshot.size_bytes, 650);
    }

    #[test]
    fn test_stream_positions() {
        let mut checkpoint = GlobalCheckpoint::new(1);
        checkpoint.set_stream_position("topic-A", 1024);
        checkpoint.set_stream_position("topic-B", 2048);

        assert_eq!(checkpoint.stream_positions.get("topic-A"), Some(&1024));
        assert_eq!(checkpoint.stream_positions.get("topic-B"), Some(&2048));
    }
}