// Source: oxigdal_streaming/v2/checkpoint.rs

1//! Checkpoint-based stream recovery for fault tolerance.
2//!
3//! A checkpoint captures the processing state at a given sequence number.
4//! On restart, processing resumes from the last checkpoint, replaying only
5//! the events since that point (minimal-replay guarantee).
6//!
7//! # Components
8//!
9//! - [`CheckpointId`]: unique identity of a checkpoint (stream + sequence number).
10//! - [`CheckpointState`]: serialisable snapshot of all operator states, source
11//!   offsets, watermark, and event count.
12//! - [`InMemoryCheckpointStore`]: bounded in-memory store (useful for testing
13//!   and for single-process use; production typically uses disk or object storage).
14//! - [`CheckpointManager`]: drives periodic checkpointing and recovery.
15
16use std::collections::HashMap;
17use std::time::SystemTime;
18
19use crate::error::StreamingError;
20
21// ─── CheckpointId ─────────────────────────────────────────────────────────────
22
/// Identity of a single checkpoint taken on a named stream.
///
/// The derived comparison and hashing traits look at every field in
/// declaration order (`stream_id`, then `sequence`, then `created_at`), so two
/// ids for the same stream and sequence that were built at different instants
/// do **not** compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct CheckpointId {
    /// Name of the logical stream the checkpoint belongs to.
    pub stream_id: String,
    /// Sequence number of the last event covered by the checkpoint.
    pub sequence: u64,
    /// Wall-clock instant at which this id was built.
    pub created_at: SystemTime,
}

impl CheckpointId {
    /// Build an id for `stream_id` at `sequence`, stamping `created_at` with
    /// the current system time.
    pub fn new(stream_id: impl Into<String>, sequence: u64) -> Self {
        let stream_id = stream_id.into();
        let created_at = SystemTime::now();
        Self {
            stream_id,
            sequence,
            created_at,
        }
    }
}
44
45// ─── CheckpointState ─────────────────────────────────────────────────────────
46
47/// Serialisable snapshot captured at a checkpoint.
48///
49/// # Binary format
50///
51/// ```text
52/// [8 bytes]  sequence          (little-endian u64)
53/// [8 bytes]  watermark_ns      (little-endian u64)
54/// [8 bytes]  event_count       (little-endian u64)
55/// [4 bytes]  n_operators       (little-endian u32)
56/// for each operator:
57///   [4 bytes] name_len         (little-endian u32)
58///   [name_len bytes] name      (UTF-8)
59///   [4 bytes] state_len        (little-endian u32)
60///   [state_len bytes] state    (opaque bytes)
61/// [4 bytes]  n_sources         (little-endian u32)
62/// for each source:
63///   [4 bytes] name_len         (little-endian u32)
64///   [name_len bytes] name      (UTF-8)
65///   [8 bytes] offset           (little-endian u64)
66/// ```
67#[derive(Debug, Clone)]
68pub struct CheckpointState {
69    /// Identity of this checkpoint.
70    pub id: CheckpointId,
71    /// Per-operator opaque state blobs.  Key: operator name.
72    pub operator_states: HashMap<String, Vec<u8>>,
73    /// Per-source byte/record offsets.  Key: source identifier.
74    pub source_offsets: HashMap<String, u64>,
75    /// Maximum processed event time expressed as nanoseconds since the Unix epoch.
76    pub watermark_ns: u64,
77    /// Number of events processed up to and including this checkpoint.
78    pub event_count: u64,
79    /// Arbitrary key→value metadata.
80    pub metadata: HashMap<String, String>,
81}
82
83impl CheckpointState {
84    /// Create an empty state for the given checkpoint ID.
85    pub fn new(id: CheckpointId) -> Self {
86        Self {
87            id,
88            operator_states: HashMap::new(),
89            source_offsets: HashMap::new(),
90            watermark_ns: 0,
91            event_count: 0,
92            metadata: HashMap::new(),
93        }
94    }
95
96    /// Store an operator's serialised state.
97    pub fn set_operator_state(&mut self, operator: impl Into<String>, state: Vec<u8>) {
98        self.operator_states.insert(operator.into(), state);
99    }
100
101    /// Record the byte/record offset for a source.
102    pub fn set_source_offset(&mut self, source: impl Into<String>, offset: u64) {
103        self.source_offsets.insert(source.into(), offset);
104    }
105
106    /// Serialise this state to a compact binary representation.
107    pub fn serialize(&self) -> Vec<u8> {
108        let mut buf = Vec::new();
109
110        buf.extend_from_slice(&self.id.sequence.to_le_bytes());
111        buf.extend_from_slice(&self.watermark_ns.to_le_bytes());
112        buf.extend_from_slice(&self.event_count.to_le_bytes());
113
114        // Operator states
115        buf.extend_from_slice(&(self.operator_states.len() as u32).to_le_bytes());
116        for (name, state) in &self.operator_states {
117            let name_bytes = name.as_bytes();
118            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
119            buf.extend_from_slice(name_bytes);
120            buf.extend_from_slice(&(state.len() as u32).to_le_bytes());
121            buf.extend_from_slice(state);
122        }
123
124        // Source offsets
125        buf.extend_from_slice(&(self.source_offsets.len() as u32).to_le_bytes());
126        for (name, offset) in &self.source_offsets {
127            let name_bytes = name.as_bytes();
128            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
129            buf.extend_from_slice(name_bytes);
130            buf.extend_from_slice(&offset.to_le_bytes());
131        }
132
133        buf
134    }
135
136    /// Deserialise a `CheckpointState` from bytes previously produced by [`Self::serialize`].
137    ///
138    /// Returns [`StreamingError::DeserializationError`] if the data is truncated
139    /// or otherwise malformed.
140    pub fn deserialize(stream_id: &str, data: &[u8]) -> Result<Self, StreamingError> {
141        const HEADER: usize = 24; // sequence(8) + watermark_ns(8) + event_count(8)
142        if data.len() < HEADER {
143            return Err(StreamingError::DeserializationError(
144                "checkpoint data too short for header".into(),
145            ));
146        }
147
148        let sequence = Self::read_u64(data, 0)?;
149        let watermark_ns = Self::read_u64(data, 8)?;
150        let event_count = Self::read_u64(data, 16)?;
151
152        let id = CheckpointId::new(stream_id, sequence);
153        let mut state = Self::new(id);
154        state.watermark_ns = watermark_ns;
155        state.event_count = event_count;
156
157        let mut cursor = HEADER;
158
159        // ── operator states ──
160        let n_ops = Self::read_u32(data, cursor)? as usize;
161        cursor += 4;
162        for _ in 0..n_ops {
163            let (name, advance) = Self::read_string(data, cursor)?;
164            cursor += advance;
165            let state_len = Self::read_u32(data, cursor)? as usize;
166            cursor += 4;
167            if cursor + state_len > data.len() {
168                return Err(StreamingError::DeserializationError(
169                    "truncated operator state bytes".into(),
170                ));
171            }
172            let op_state = data[cursor..cursor + state_len].to_vec();
173            cursor += state_len;
174            state.operator_states.insert(name, op_state);
175        }
176
177        // ── source offsets ──
178        if cursor + 4 > data.len() {
179            // No source-offsets section present — treat as empty.
180            return Ok(state);
181        }
182        let n_src = Self::read_u32(data, cursor)? as usize;
183        cursor += 4;
184        for _ in 0..n_src {
185            let (name, advance) = Self::read_string(data, cursor)?;
186            cursor += advance;
187            let offset = Self::read_u64(data, cursor)?;
188            cursor += 8;
189            state.source_offsets.insert(name, offset);
190        }
191
192        Ok(state)
193    }
194
195    // ── byte-reading helpers ─────────────────────────────────────────────────
196
197    fn read_u64(data: &[u8], offset: usize) -> Result<u64, StreamingError> {
198        data.get(offset..offset + 8)
199            .and_then(|b| b.try_into().ok())
200            .map(u64::from_le_bytes)
201            .ok_or_else(|| {
202                StreamingError::DeserializationError(format!("cannot read u64 at offset {offset}"))
203            })
204    }
205
206    fn read_u32(data: &[u8], offset: usize) -> Result<u32, StreamingError> {
207        data.get(offset..offset + 4)
208            .and_then(|b| b.try_into().ok())
209            .map(u32::from_le_bytes)
210            .ok_or_else(|| {
211                StreamingError::DeserializationError(format!("cannot read u32 at offset {offset}"))
212            })
213    }
214
215    /// Read a length-prefixed UTF-8 string from `data[cursor..]`.
216    ///
217    /// Returns `(string, bytes_consumed)` where `bytes_consumed` includes the
218    /// 4-byte length prefix.
219    fn read_string(data: &[u8], cursor: usize) -> Result<(String, usize), StreamingError> {
220        let name_len = Self::read_u32(data, cursor)? as usize;
221        let name_start = cursor + 4;
222        let name_end = name_start + name_len;
223        if name_end > data.len() {
224            return Err(StreamingError::DeserializationError(
225                "truncated string bytes".into(),
226            ));
227        }
228        let name = String::from_utf8(data[name_start..name_end].to_vec()).map_err(|e| {
229            StreamingError::DeserializationError(format!("invalid UTF-8 in field name: {e}"))
230        })?;
231        Ok((name, 4 + name_len))
232    }
233}
234
235// ─── InMemoryCheckpointStore ──────────────────────────────────────────────────
236
237/// A bounded in-memory checkpoint store.
238///
239/// Each stream maintains its own sorted list of checkpoints.  When the number
240/// of checkpoints for a stream exceeds `max_per_stream`, the **oldest** ones are
241/// evicted automatically.
242pub struct InMemoryCheckpointStore {
243    /// stream_id → list of checkpoints, sorted ascending by sequence number.
244    checkpoints: HashMap<String, Vec<CheckpointState>>,
245    /// Maximum checkpoints retained per stream.
246    max_per_stream: usize,
247}
248
249impl InMemoryCheckpointStore {
250    /// Create a store that retains at most `max_per_stream` checkpoints per stream.
251    pub fn new(max_per_stream: usize) -> Self {
252        assert!(max_per_stream > 0, "max_per_stream must be at least 1");
253        Self {
254            checkpoints: HashMap::new(),
255            max_per_stream,
256        }
257    }
258
259    /// Save a checkpoint state.  The list is kept sorted by sequence number.
260    pub fn save(&mut self, state: CheckpointState) -> Result<(), StreamingError> {
261        let stream_id = state.id.stream_id.clone();
262        let entry = self.checkpoints.entry(stream_id).or_default();
263        entry.push(state);
264        entry.sort_by_key(|s| s.id.sequence);
265        // Trim to max_per_stream, evicting oldest (lowest sequence)
266        if entry.len() > self.max_per_stream {
267            let excess = entry.len() - self.max_per_stream;
268            entry.drain(0..excess);
269        }
270        Ok(())
271    }
272
273    /// Return the most recent checkpoint for the given stream, or `None`.
274    pub fn latest(&self, stream_id: &str) -> Option<&CheckpointState> {
275        self.checkpoints.get(stream_id)?.last()
276    }
277
278    /// Return all checkpoints for the given stream, sorted ascending by sequence.
279    pub fn list(&self, stream_id: &str) -> Vec<&CheckpointState> {
280        self.checkpoints
281            .get(stream_id)
282            .map(|v| v.iter().collect())
283            .unwrap_or_default()
284    }
285
286    /// Remove all checkpoints for `stream_id` with sequence number **less than** `sequence`.
287    pub fn delete_before(&mut self, stream_id: &str, sequence: u64) {
288        if let Some(entry) = self.checkpoints.get_mut(stream_id) {
289            entry.retain(|s| s.id.sequence >= sequence);
290        }
291    }
292
293    /// Number of checkpoints currently stored for the given stream.
294    pub fn checkpoint_count(&self, stream_id: &str) -> usize {
295        self.checkpoints
296            .get(stream_id)
297            .map(|v| v.len())
298            .unwrap_or(0)
299    }
300}
301
302// ─── CheckpointManager ────────────────────────────────────────────────────────
303
304/// Drives periodic checkpointing and provides recovery support.
305///
306/// Call [`Self::on_event`] after processing each event.  When the cumulative sequence
307/// number reaches the next scheduled checkpoint, a new [`CheckpointState`] is
308/// automatically saved to the underlying store.
309pub struct CheckpointManager {
310    store: InMemoryCheckpointStore,
311    /// Checkpoint every `checkpoint_interval` events.
312    checkpoint_interval: u64,
313    next_checkpoint_at: u64,
314    total_checkpoints: u64,
315}
316
317impl CheckpointManager {
318    /// Create a manager with the given store and interval.
319    pub fn new(store: InMemoryCheckpointStore, checkpoint_interval: u64) -> Self {
320        assert!(
321            checkpoint_interval > 0,
322            "checkpoint_interval must be positive"
323        );
324        Self {
325            store,
326            checkpoint_interval,
327            next_checkpoint_at: checkpoint_interval,
328            total_checkpoints: 0,
329        }
330    }
331
332    /// Called after each processed event.
333    ///
334    /// Returns `Ok(true)` if a checkpoint was taken, `Ok(false)` otherwise.
335    pub fn on_event(
336        &mut self,
337        stream_id: &str,
338        sequence: u64,
339        watermark_ns: u64,
340    ) -> Result<bool, StreamingError> {
341        if sequence >= self.next_checkpoint_at {
342            let id = CheckpointId::new(stream_id, sequence);
343            let mut state = CheckpointState::new(id);
344            state.watermark_ns = watermark_ns;
345            state.event_count = sequence;
346            self.store.save(state)?;
347            self.next_checkpoint_at = sequence + self.checkpoint_interval;
348            self.total_checkpoints += 1;
349            return Ok(true);
350        }
351        Ok(false)
352    }
353
354    /// Return the sequence number from which to resume, or `None` if no
355    /// checkpoint exists for the stream.
356    pub fn recover(&self, stream_id: &str) -> Option<u64> {
357        self.store.latest(stream_id).map(|s| s.id.sequence)
358    }
359
360    /// Total checkpoints taken since this manager was created.
361    pub fn total_checkpoints(&self) -> u64 {
362        self.total_checkpoints
363    }
364
365    /// Read-only access to the underlying store (for inspection / testing).
366    pub fn store(&self) -> &InMemoryCheckpointStore {
367        &self.store
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an otherwise-empty checkpoint state for `stream` at `sequence`.
    fn state_at(stream: &str, sequence: u64) -> CheckpointState {
        CheckpointState::new(CheckpointId::new(stream, sequence))
    }

    /// Sequence numbers currently stored for `stream`, in store order.
    fn stored_sequences(store: &InMemoryCheckpointStore, stream: &str) -> Vec<u64> {
        store.list(stream).iter().map(|c| c.id.sequence).collect()
    }

    // ── CheckpointState serialisation ────────────────────────────────────────

    #[test]
    fn test_serialize_deserialize_round_trip_empty() {
        let mut state = state_at("stream-a", 42);
        state.watermark_ns = 999_000_000;
        state.event_count = 42;

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("stream-a", &bytes)
            .expect("deserialization should succeed");

        assert_eq!(decoded.id.stream_id, "stream-a");
        assert_eq!(decoded.id.sequence, 42);
        assert_eq!(decoded.watermark_ns, 999_000_000);
        assert_eq!(decoded.event_count, 42);
        assert!(decoded.operator_states.is_empty());
        assert!(decoded.source_offsets.is_empty());
    }

    #[test]
    fn test_serialize_deserialize_with_operator_states() {
        let mut state = state_at("s", 1);
        state.set_operator_state("agg_op", vec![1, 2, 3, 4]);
        state.set_operator_state("filter_op", vec![9, 8]);

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");

        assert_eq!(decoded.operator_states.len(), 2);
        assert_eq!(
            decoded.operator_states.get("agg_op"),
            Some(&vec![1, 2, 3, 4])
        );
        assert_eq!(decoded.operator_states.get("filter_op"), Some(&vec![9, 8]));
    }

    #[test]
    fn test_serialize_deserialize_with_source_offsets() {
        let mut state = state_at("s", 7);
        state.set_source_offset("kafka-topic-0", 1_234_567);
        state.set_source_offset("file-source", 4_096);

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");

        assert_eq!(decoded.source_offsets.len(), 2);
        assert_eq!(
            decoded.source_offsets.get("kafka-topic-0"),
            Some(&1_234_567)
        );
        assert_eq!(decoded.source_offsets.get("file-source"), Some(&4_096));
    }

    #[test]
    fn test_deserialize_truncated_data_returns_error() {
        // Shorter than the 24-byte header.
        let result = CheckpointState::deserialize("s", &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_empty_slice_returns_error() {
        let result = CheckpointState::deserialize("s", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_truncated_operator_state_returns_error() {
        let mut state = state_at("s", 1);
        state.set_operator_state("op", vec![1, 2, 3, 4]);
        let bytes = state.serialize();

        // Drop the 4-byte source count plus the last two bytes of the blob,
        // so the operator's declared state length overruns the buffer.
        let truncated = &bytes[..bytes.len() - 6];
        assert!(CheckpointState::deserialize("s", truncated).is_err());
    }

    // ── InMemoryCheckpointStore ───────────────────────────────────────────────

    #[test]
    fn test_store_save_and_latest() {
        let mut store = InMemoryCheckpointStore::new(5);
        store
            .save(state_at("stream-x", 10))
            .expect("save should succeed");

        let latest = store.latest("stream-x").expect("should be present");
        assert_eq!(latest.id.sequence, 10);
        assert_eq!(store.checkpoint_count("stream-x"), 1);
    }

    #[test]
    fn test_store_latest_none_when_empty() {
        let store = InMemoryCheckpointStore::new(5);
        assert!(store.latest("unknown").is_none());
        assert!(store.list("unknown").is_empty());
        assert_eq!(store.checkpoint_count("unknown"), 0);
    }

    #[test]
    fn test_store_save_keeps_ascending_order() {
        let mut store = InMemoryCheckpointStore::new(5);
        for seq in [30u64, 10, 20] {
            store.save(state_at("s", seq)).expect("save ok");
        }
        assert_eq!(stored_sequences(&store, "s"), vec![10, 20, 30]);
    }

    #[test]
    fn test_store_trims_to_max_per_stream() {
        let mut store = InMemoryCheckpointStore::new(3);
        for seq in 0u64..6 {
            store.save(state_at("s", seq)).expect("save ok");
        }

        assert_eq!(store.checkpoint_count("s"), 3);
        // The three oldest were evicted; the three newest survive in order.
        assert_eq!(stored_sequences(&store, "s"), vec![3, 4, 5]);
        assert_eq!(
            store
                .latest("s")
                .expect("latest checkpoint for stream 's'")
                .id
                .sequence,
            5
        );
    }

    #[test]
    fn test_store_delete_before() {
        let mut store = InMemoryCheckpointStore::new(10);
        for i in 0u64..5 {
            store.save(state_at("s", i * 10)).expect("save ok");
        }

        // Delete all checkpoints with sequence < 20; the boundary value stays.
        store.delete_before("s", 20);
        assert_eq!(stored_sequences(&store, "s"), vec![20, 30, 40]);
    }

    #[test]
    fn test_store_multiple_streams_independent() {
        let mut store = InMemoryCheckpointStore::new(5);
        for seq in [1u64, 2, 3] {
            store.save(state_at("stream-a", seq)).expect("ok");
            store.save(state_at("stream-b", seq * 10)).expect("ok");
        }

        assert_eq!(stored_sequences(&store, "stream-a"), vec![1, 2, 3]);
        assert_eq!(stored_sequences(&store, "stream-b"), vec![10, 20, 30]);
        assert_eq!(
            store
                .latest("stream-a")
                .expect("latest checkpoint for stream-a")
                .id
                .sequence,
            3
        );
        assert_eq!(
            store
                .latest("stream-b")
                .expect("latest checkpoint for stream-b")
                .id
                .sequence,
            30
        );
    }

    // ── CheckpointManager ────────────────────────────────────────────────────

    #[test]
    fn test_manager_triggers_checkpoint_at_interval() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 100);

        // Events 0-99: no checkpoint yet.
        for seq in 0u64..100 {
            let triggered = mgr.on_event("s", seq, 0).expect("on_event ok");
            assert!(!triggered, "unexpected checkpoint at sequence {seq}");
        }
        assert_eq!(mgr.total_checkpoints(), 0);

        // Event 100: checkpoint fires.
        let triggered = mgr.on_event("s", 100, 0).expect("on_event ok");
        assert!(triggered);
        assert_eq!(mgr.total_checkpoints(), 1);
        assert_eq!(mgr.store().checkpoint_count("s"), 1);
    }

    #[test]
    fn test_manager_recover_returns_last_sequence() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 50);
        assert!(mgr.on_event("s", 50, 0).expect("ok"));
        assert!(mgr.on_event("s", 100, 0).expect("ok"));

        let seq = mgr.recover("s").expect("should recover");
        assert_eq!(seq, 100);
    }

    #[test]
    fn test_manager_recover_none_before_first_checkpoint() {
        let store = InMemoryCheckpointStore::new(5);
        let mgr = CheckpointManager::new(store, 100);
        assert!(mgr.recover("s").is_none());
    }

    #[test]
    fn test_manager_total_checkpoints_counter() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 10);
        for seq in 0u64..=50 {
            mgr.on_event("s", seq, 0).expect("ok");
        }

        // Checkpoints at seq 10, 20, 30, 40, 50 = 5
        assert_eq!(mgr.total_checkpoints(), 5);
        assert_eq!(
            stored_sequences(mgr.store(), "s"),
            vec![10, 20, 30, 40, 50]
        );
    }

    #[test]
    fn test_checkpoint_state_full_round_trip() {
        let mut state = state_at("full-test", 77);
        state.watermark_ns = 1_700_000_000_000_000_000;
        state.event_count = 77;
        state.set_operator_state("window_op", b"window_state_data".to_vec());
        state.set_source_offset("source-0", 8192);
        // Metadata is populated to make sure it does not disturb encoding;
        // it is not written by `serialize`, so nothing is asserted about it.
        state.metadata.insert("app_version".into(), "1.2.3".into());

        let bytes = state.serialize();
        let decoded =
            CheckpointState::deserialize("full-test", &bytes).expect("round-trip should succeed");

        assert_eq!(decoded.id.stream_id, "full-test");
        assert_eq!(decoded.id.sequence, 77);
        assert_eq!(decoded.watermark_ns, 1_700_000_000_000_000_000);
        assert_eq!(decoded.event_count, 77);
        assert_eq!(
            decoded.operator_states.get("window_op"),
            Some(&b"window_state_data".to_vec())
        );
        assert_eq!(decoded.source_offsets.get("source-0"), Some(&8192u64));
    }
}