Skip to main content

laminar_core/streaming/
checkpoint.rs

1//! Streaming checkpoint support.
2//!
3//! Provides optional, zero-overhead checkpointing for the streaming API.
4//! When disabled (the default), no runtime cost is incurred. When enabled,
5//! captures source sequences, watermarks, and persists checkpoint snapshots.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! Ring 0 (Hot Path): Source.push() -> increment sequence (AtomicU64 Relaxed ~1ns)
11//! Ring 1 (Background): StreamCheckpointManager.trigger() -> capture atomics -> store
12//! Ring 2 (Control):    LaminarDB.checkpoint() -> manual trigger
13//! ```
14
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
18use std::sync::Arc;
19
20// Configuration
21
/// Durability mode used for write-ahead-log (WAL) writes of checkpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalMode {
    /// WAL writes are asynchronous: faster, but the last few entries may be
    /// lost on a crash.
    Async,
    /// WAL writes are synchronous: slower, but durable.
    Sync,
}
30
31/// Configuration for streaming checkpoints.
32///
33/// All fields default to `None`/disabled. Checkpointing is opt-in.
34#[derive(Debug, Clone)]
35pub struct StreamCheckpointConfig {
36    /// Checkpoint interval in milliseconds. `None` = manual only.
37    pub interval_ms: Option<u64>,
38    /// WAL mode. Requires `data_dir` to be set.
39    pub wal_mode: Option<WalMode>,
40    /// Directory for persisting checkpoints/WAL. `None` = in-memory only.
41    pub data_dir: Option<std::path::PathBuf>,
42    /// Changelog buffer capacity. `None` = no changelog buffer.
43    pub changelog_capacity: Option<usize>,
44    /// Maximum number of retained checkpoints. `None` = unlimited.
45    pub max_retained: Option<usize>,
46    /// Overflow policy for the changelog buffer.
47    pub overflow_policy: OverflowPolicy,
48}
49
50impl Default for StreamCheckpointConfig {
51    fn default() -> Self {
52        Self {
53            interval_ms: None,
54            wal_mode: None,
55            data_dir: None,
56            changelog_capacity: None,
57            max_retained: None,
58            overflow_policy: OverflowPolicy::DropNew,
59        }
60    }
61}
62
63impl StreamCheckpointConfig {
64    /// Validates the configuration, returning an error if invalid.
65    ///
66    /// # Errors
67    ///
68    /// Returns `CheckpointError::InvalidConfig` if WAL mode is set without
69    /// `data_dir`, or if `changelog_capacity` is zero.
70    pub fn validate(&self) -> Result<(), CheckpointError> {
71        if self.wal_mode.is_some() && self.data_dir.is_none() {
72            return Err(CheckpointError::InvalidConfig(
73                "WAL mode requires data_dir to be set".into(),
74            ));
75        }
76        if let Some(cap) = self.changelog_capacity {
77            if cap == 0 {
78                return Err(CheckpointError::InvalidConfig(
79                    "changelog_capacity must be > 0".into(),
80                ));
81            }
82        }
83        Ok(())
84    }
85}
86
87// Errors
88
/// Failure modes of checkpoint operations.
///
/// Payloads are plain `String`s so the type can derive `Clone`, `PartialEq`
/// and `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckpointError {
    /// Checkpointing is not enabled.
    Disabled,
    /// The operation needs a data directory, but none is configured.
    DataDirRequired,
    /// WAL mode was requested without checkpointing being enabled.
    WalRequiresCheckpoint,
    /// There is no checkpoint to restore from.
    NoCheckpoint,
    /// The operation timed out.
    Timeout,
    /// The configuration is invalid; the payload explains why.
    InvalidConfig(String),
    /// An I/O error, kept as its message so the enum stays `Clone`/`PartialEq`.
    IoError(String),
}
107
108impl fmt::Display for CheckpointError {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        match self {
111            Self::Disabled => write!(f, "checkpointing is disabled"),
112            Self::DataDirRequired => write!(f, "data directory is required"),
113            Self::WalRequiresCheckpoint => {
114                write!(f, "WAL mode requires checkpointing")
115            }
116            Self::NoCheckpoint => write!(f, "no checkpoint available"),
117            Self::Timeout => write!(f, "checkpoint operation timed out"),
118            Self::InvalidConfig(msg) => {
119                write!(f, "invalid checkpoint config: {msg}")
120            }
121            Self::IoError(msg) => write!(f, "checkpoint I/O error: {msg}"),
122        }
123    }
124}
125
126impl std::error::Error for CheckpointError {}
127
128// Overflow policy
129
/// What the changelog buffer does with a new entry once it is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverflowPolicy {
    /// Reject the new entry; existing entries are kept.
    DropNew,
    /// Discard the oldest entry to make room for the new one.
    OverwriteOldest,
}
138
139// Changelog entry (24 bytes, repr(C))
140
/// Kind of operation recorded in a changelog entry.
///
/// The discriminants are fixed (`repr(u8)`) because the value is stored as a
/// raw byte inside the entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum StreamChangeOp {
    /// A record was pushed.
    Push = 0,
    /// A watermark was emitted.
    Watermark = 1,
    /// A checkpoint barrier.
    Barrier = 2,
}
152
153impl StreamChangeOp {
154    fn from_u8(v: u8) -> Option<Self> {
155        match v {
156            0 => Some(Self::Push),
157            1 => Some(Self::Watermark),
158            2 => Some(Self::Barrier),
159            _ => None,
160        }
161    }
162}
163
/// One changelog record: a fixed 24-byte, `Copy` value with no heap data.
///
/// Layout (repr(C)):
/// ```text
/// [source_id: u16][op: u8][padding: u8][reserved: u32][sequence: u64][watermark: i64]
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct StreamChangelogEntry {
    /// Compact identifier of the source.
    pub source_id: u16,
    /// Operation type, stored as the raw `StreamChangeOp` discriminant.
    pub op: u8,
    /// Explicit padding byte; always written as zero by this module.
    _padding: u8,
    /// Reserved for future use; always written as zero by this module.
    _reserved: u32,
    /// Sequence number at the time of the operation.
    pub sequence: u64,
    /// Watermark value at the time of the operation.
    pub watermark: i64,
}
186
187impl StreamChangelogEntry {
188    /// Creates a new changelog entry.
189    #[must_use]
190    pub fn new(source_id: u16, op: StreamChangeOp, sequence: u64, watermark: i64) -> Self {
191        Self {
192            source_id,
193            op: op as u8,
194            _padding: 0,
195            _reserved: 0,
196            sequence,
197            watermark,
198        }
199    }
200
201    /// Returns the operation type.
202    #[must_use]
203    pub fn op_type(&self) -> Option<StreamChangeOp> {
204        StreamChangeOp::from_u8(self.op)
205    }
206}
207
208// Changelog buffer (pre-allocated ring buffer, zero-alloc after init)
209
210/// A pre-allocated ring buffer for changelog entries.
211///
212/// Uses a simple write/read index scheme. Not thread-safe on its own —
213/// intended to be used behind the `StreamCheckpointManager` mutex.
214pub struct StreamChangelogBuffer {
215    entries: Vec<StreamChangelogEntry>,
216    capacity: usize,
217    write_idx: usize,
218    read_idx: usize,
219    count: usize,
220    overflow_count: u64,
221    policy: OverflowPolicy,
222}
223
224impl StreamChangelogBuffer {
225    /// Creates a new changelog buffer with the given capacity.
226    #[must_use]
227    pub fn new(capacity: usize, policy: OverflowPolicy) -> Self {
228        let zeroed = StreamChangelogEntry {
229            source_id: 0,
230            op: 0,
231            _padding: 0,
232            _reserved: 0,
233            sequence: 0,
234            watermark: 0,
235        };
236        Self {
237            entries: vec![zeroed; capacity],
238            capacity,
239            write_idx: 0,
240            read_idx: 0,
241            count: 0,
242            overflow_count: 0,
243            policy,
244        }
245    }
246
247    /// Pushes an entry into the buffer.
248    ///
249    /// Returns `true` if the entry was stored, `false` if dropped due to
250    /// overflow policy.
251    pub fn push(&mut self, entry: StreamChangelogEntry) -> bool {
252        if self.count == self.capacity {
253            self.overflow_count += 1;
254            match self.policy {
255                OverflowPolicy::DropNew => return false,
256                OverflowPolicy::OverwriteOldest => {
257                    // Advance read pointer, discarding oldest
258                    self.read_idx = (self.read_idx + 1) % self.capacity;
259                    self.count -= 1;
260                }
261            }
262        }
263        self.entries[self.write_idx] = entry;
264        self.write_idx = (self.write_idx + 1) % self.capacity;
265        self.count += 1;
266        true
267    }
268
269    /// Pops the oldest entry from the buffer.
270    pub fn pop(&mut self) -> Option<StreamChangelogEntry> {
271        if self.count == 0 {
272            return None;
273        }
274        let entry = self.entries[self.read_idx];
275        self.read_idx = (self.read_idx + 1) % self.capacity;
276        self.count -= 1;
277        Some(entry)
278    }
279
280    /// Drains up to `max` entries into the provided vector.
281    pub fn drain(&mut self, max: usize, out: &mut Vec<StreamChangelogEntry>) {
282        let n = max.min(self.count);
283        for _ in 0..n {
284            if let Some(entry) = self.pop() {
285                out.push(entry);
286            }
287        }
288    }
289
290    /// Drains all entries into the provided vector.
291    pub fn drain_all(&mut self, out: &mut Vec<StreamChangelogEntry>) {
292        let n = self.count;
293        self.drain(n, out);
294    }
295
296    /// Returns the number of entries in the buffer.
297    #[must_use]
298    pub fn len(&self) -> usize {
299        self.count
300    }
301
302    /// Returns `true` if the buffer is empty.
303    #[must_use]
304    pub fn is_empty(&self) -> bool {
305        self.count == 0
306    }
307
308    /// Returns `true` if the buffer is full.
309    #[must_use]
310    pub fn is_full(&self) -> bool {
311        self.count == self.capacity
312    }
313
314    /// Returns the buffer capacity.
315    #[must_use]
316    pub fn capacity(&self) -> usize {
317        self.capacity
318    }
319
320    /// Returns the total number of overflows since creation.
321    #[must_use]
322    pub fn overflow_count(&self) -> u64 {
323        self.overflow_count
324    }
325}
326
327impl fmt::Debug for StreamChangelogBuffer {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        f.debug_struct("StreamChangelogBuffer")
330            .field("capacity", &self.capacity)
331            .field("len", &self.count)
332            .field("overflow_count", &self.overflow_count)
333            .finish_non_exhaustive()
334    }
335}
336
337// Checkpoint snapshot
338
/// Snapshot of streaming pipeline state taken at one point in time.
#[derive(Debug, Clone)]
pub struct StreamCheckpoint {
    /// Unique, monotonically increasing checkpoint identifier.
    pub id: u64,
    /// Epoch number the checkpoint belongs to.
    pub epoch: u64,
    /// Sequence number of each source at checkpoint time, keyed by source name.
    pub source_sequences: HashMap<String, u64>,
    /// Position of each sink at checkpoint time, keyed by sink name.
    pub sink_positions: HashMap<String, u64>,
    /// Watermark of each source at checkpoint time, keyed by source name.
    pub watermarks: HashMap<String, i64>,
    /// Opaque state bytes of each operator, keyed by operator name.
    pub operator_states: HashMap<String, Vec<u8>>,
    /// Creation time in milliseconds since the Unix epoch.
    pub created_at: u64,
}
357
358impl StreamCheckpoint {
359    /// Serializes the checkpoint to bytes.
360    ///
361    /// Format:
362    /// ```text
363    /// [version: 1][id: 8][epoch: 8][created_at: 8]
364    /// [num_sources: 4][ [name_len:4][name][seq:8] ... ]
365    /// [num_sinks: 4][ [name_len:4][name][pos:8] ... ]
366    /// [num_watermarks: 4][ [name_len:4][name][wm:8] ... ]
367    /// [num_ops: 4][ [name_len:4][name][data_len:4][data] ... ]
368    /// ```
369    #[must_use]
370    #[allow(clippy::cast_possible_truncation)] // Wire format uses u32 for collection lengths
371    pub fn to_bytes(&self) -> Vec<u8> {
372        let mut buf = Vec::with_capacity(256);
373
374        // Version
375        buf.push(1u8);
376
377        // Header
378        buf.extend_from_slice(&self.id.to_le_bytes());
379        buf.extend_from_slice(&self.epoch.to_le_bytes());
380        buf.extend_from_slice(&self.created_at.to_le_bytes());
381
382        // Source sequences
383        buf.extend_from_slice(&(self.source_sequences.len() as u32).to_le_bytes());
384        for (name, seq) in &self.source_sequences {
385            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
386            buf.extend_from_slice(name.as_bytes());
387            buf.extend_from_slice(&seq.to_le_bytes());
388        }
389
390        // Sink positions
391        buf.extend_from_slice(&(self.sink_positions.len() as u32).to_le_bytes());
392        for (name, pos) in &self.sink_positions {
393            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
394            buf.extend_from_slice(name.as_bytes());
395            buf.extend_from_slice(&pos.to_le_bytes());
396        }
397
398        // Watermarks
399        buf.extend_from_slice(&(self.watermarks.len() as u32).to_le_bytes());
400        for (name, wm) in &self.watermarks {
401            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
402            buf.extend_from_slice(name.as_bytes());
403            buf.extend_from_slice(&wm.to_le_bytes());
404        }
405
406        // Operator states
407        buf.extend_from_slice(&(self.operator_states.len() as u32).to_le_bytes());
408        for (name, data) in &self.operator_states {
409            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
410            buf.extend_from_slice(name.as_bytes());
411            buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
412            buf.extend_from_slice(data);
413        }
414
415        buf
416    }
417
418    /// Deserializes a checkpoint from bytes.
419    ///
420    /// # Errors
421    ///
422    /// Returns `CheckpointError::IoError` if the data is truncated, corrupted,
423    /// or uses an unsupported version.
424    #[allow(clippy::similar_names, clippy::too_many_lines)]
425    pub fn from_bytes(data: &[u8]) -> Result<Self, CheckpointError> {
426        let mut pos = 0;
427
428        let read_u32 = |p: &mut usize| -> Result<u32, CheckpointError> {
429            if *p + 4 > data.len() {
430                return Err(CheckpointError::IoError("truncated u32".into()));
431            }
432            let val = u32::from_le_bytes(
433                data[*p..*p + 4]
434                    .try_into()
435                    .map_err(|_| CheckpointError::IoError("bad u32".into()))?,
436            );
437            *p += 4;
438            Ok(val)
439        };
440
441        let read_u64_val = |p: &mut usize| -> Result<u64, CheckpointError> {
442            if *p + 8 > data.len() {
443                return Err(CheckpointError::IoError("truncated u64".into()));
444            }
445            let val = u64::from_le_bytes(
446                data[*p..*p + 8]
447                    .try_into()
448                    .map_err(|_| CheckpointError::IoError("bad u64".into()))?,
449            );
450            *p += 8;
451            Ok(val)
452        };
453
454        let read_i64_val = |p: &mut usize| -> Result<i64, CheckpointError> {
455            if *p + 8 > data.len() {
456                return Err(CheckpointError::IoError("truncated i64".into()));
457            }
458            let val = i64::from_le_bytes(
459                data[*p..*p + 8]
460                    .try_into()
461                    .map_err(|_| CheckpointError::IoError("bad i64".into()))?,
462            );
463            *p += 8;
464            Ok(val)
465        };
466
467        let read_string = |p: &mut usize| -> Result<String, CheckpointError> {
468            let slen = read_u32(p)? as usize;
469            if *p + slen > data.len() {
470                return Err(CheckpointError::IoError("truncated string".into()));
471            }
472            let s = std::str::from_utf8(&data[*p..*p + slen])
473                .map_err(|_| CheckpointError::IoError("invalid utf8".into()))?
474                .to_string();
475            *p += slen;
476            Ok(s)
477        };
478
479        // Version
480        if pos >= data.len() {
481            return Err(CheckpointError::IoError("empty checkpoint data".into()));
482        }
483        let version = data[pos];
484        pos += 1;
485        if version != 1 {
486            return Err(CheckpointError::IoError(format!(
487                "unsupported checkpoint version: {version}"
488            )));
489        }
490
491        // Header
492        let id = read_u64_val(&mut pos)?;
493        let epoch = read_u64_val(&mut pos)?;
494        let created_at = read_u64_val(&mut pos)?;
495
496        // Source sequences
497        let num_sources = read_u32(&mut pos)? as usize;
498        let mut source_sequences = HashMap::with_capacity(num_sources);
499        for _ in 0..num_sources {
500            let name = read_string(&mut pos)?;
501            let seq = read_u64_val(&mut pos)?;
502            source_sequences.insert(name, seq);
503        }
504
505        // Sink positions
506        let num_sinks = read_u32(&mut pos)? as usize;
507        let mut sink_positions = HashMap::with_capacity(num_sinks);
508        for _ in 0..num_sinks {
509            let name = read_string(&mut pos)?;
510            let sink_pos = read_u64_val(&mut pos)?;
511            sink_positions.insert(name, sink_pos);
512        }
513
514        // Watermarks
515        let num_watermarks = read_u32(&mut pos)? as usize;
516        let mut watermarks = HashMap::with_capacity(num_watermarks);
517        for _ in 0..num_watermarks {
518            let name = read_string(&mut pos)?;
519            let wm = read_i64_val(&mut pos)?;
520            watermarks.insert(name, wm);
521        }
522
523        // Operator states
524        let num_ops = read_u32(&mut pos)? as usize;
525        let mut operator_states = HashMap::with_capacity(num_ops);
526        for _ in 0..num_ops {
527            let name = read_string(&mut pos)?;
528            let data_len = read_u32(&mut pos)? as usize;
529            if pos + data_len > data.len() {
530                return Err(CheckpointError::IoError("truncated operator state".into()));
531            }
532            let state_data = data[pos..pos + data_len].to_vec();
533            pos += data_len;
534            operator_states.insert(name, state_data);
535        }
536
537        Ok(Self {
538            id,
539            epoch,
540            source_sequences,
541            sink_positions,
542            watermarks,
543            operator_states,
544            created_at,
545        })
546    }
547}
548
549// Registered source info (held by the manager)
550
/// Per-source handles the checkpoint manager keeps for a registered source.
struct RegisteredSource {
    /// Shared sequence counter; the manager reads it when a checkpoint is
    /// triggered.
    sequence: Arc<AtomicU64>,
    /// Shared watermark value; the manager reads it when a checkpoint is
    /// triggered.
    watermark: Arc<AtomicI64>,
}
558
559// Checkpoint manager
560
561/// Coordinates checkpoint lifecycle for streaming sources and sinks.
562///
563/// Disabled by default. When enabled via [`StreamCheckpointConfig`], the
564/// manager captures atomic counters from registered sources to produce
565/// consistent [`StreamCheckpoint`] snapshots.
566pub struct StreamCheckpointManager {
567    config: StreamCheckpointConfig,
568    enabled: bool,
569    sources: HashMap<String, RegisteredSource>,
570    sinks: HashMap<String, u64>,
571    checkpoints: Vec<StreamCheckpoint>,
572    next_id: u64,
573    epoch: u64,
574    changelog: Option<StreamChangelogBuffer>,
575}
576
577impl StreamCheckpointManager {
578    /// Creates a new checkpoint manager.
579    ///
580    /// If `config` validation fails, the manager is created in disabled state.
581    #[must_use]
582    pub fn new(config: StreamCheckpointConfig) -> Self {
583        let enabled = config.validate().is_ok();
584        let changelog = config
585            .changelog_capacity
586            .filter(|_| enabled)
587            .map(|cap| StreamChangelogBuffer::new(cap, config.overflow_policy));
588        Self {
589            config,
590            enabled,
591            sources: HashMap::new(),
592            sinks: HashMap::new(),
593            checkpoints: Vec::new(),
594            next_id: 1,
595            epoch: 0,
596            changelog,
597        }
598    }
599
600    /// Creates a disabled (no-op) manager.
601    #[must_use]
602    pub fn disabled() -> Self {
603        Self {
604            config: StreamCheckpointConfig::default(),
605            enabled: false,
606            sources: HashMap::new(),
607            sinks: HashMap::new(),
608            checkpoints: Vec::new(),
609            next_id: 1,
610            epoch: 0,
611            changelog: None,
612        }
613    }
614
615    /// Returns whether checkpointing is enabled.
616    #[must_use]
617    pub fn is_enabled(&self) -> bool {
618        self.enabled
619    }
620
621    /// Registers a source for checkpoint tracking.
622    ///
623    /// The `sequence` and `watermark` atomics are shared with the live
624    /// [`Source`](super::Source) — reading them is lock-free.
625    pub fn register_source(
626        &mut self,
627        name: &str,
628        sequence: Arc<AtomicU64>,
629        watermark: Arc<AtomicI64>,
630    ) {
631        self.sources.insert(
632            name.to_string(),
633            RegisteredSource {
634                sequence,
635                watermark,
636            },
637        );
638    }
639
640    /// Registers a sink for checkpoint tracking.
641    pub fn register_sink(&mut self, name: &str, position: u64) {
642        self.sinks.insert(name.to_string(), position);
643    }
644
645    /// Triggers a checkpoint, capturing current source/sink state.
646    ///
647    /// Returns the checkpoint ID, or `None` if checkpointing is disabled.
648    #[allow(clippy::cast_possible_truncation)] // Timestamp ms fits i64 for ~292 years from epoch
649    pub fn trigger(&mut self) -> Option<u64> {
650        if !self.enabled {
651            return None;
652        }
653
654        self.epoch += 1;
655        let id = self.next_id;
656        self.next_id += 1;
657
658        // Capture source sequences and watermarks atomically
659        let mut source_sequences = HashMap::with_capacity(self.sources.len());
660        let mut watermarks = HashMap::with_capacity(self.sources.len());
661        for (name, src) in &self.sources {
662            source_sequences.insert(name.clone(), src.sequence.load(Ordering::Acquire));
663            watermarks.insert(name.clone(), src.watermark.load(Ordering::Acquire));
664        }
665
666        // Capture sink positions
667        let sink_positions = self.sinks.clone();
668
669        let now = std::time::SystemTime::now()
670            .duration_since(std::time::UNIX_EPOCH)
671            .map(|d| d.as_millis() as u64)
672            .unwrap_or(0);
673
674        let checkpoint = StreamCheckpoint {
675            id,
676            epoch: self.epoch,
677            source_sequences,
678            sink_positions,
679            watermarks,
680            operator_states: HashMap::new(),
681            created_at: now,
682        };
683
684        self.checkpoints.push(checkpoint);
685
686        // Prune old checkpoints if max_retained is set
687        if let Some(max) = self.config.max_retained {
688            while self.checkpoints.len() > max {
689                self.checkpoints.remove(0);
690            }
691        }
692
693        Some(id)
694    }
695
696    /// Creates a checkpoint and returns the checkpoint ID.
697    ///
698    /// # Errors
699    ///
700    /// Returns `CheckpointError::Disabled` if checkpointing is not enabled.
701    pub fn checkpoint(&mut self) -> Result<Option<u64>, CheckpointError> {
702        if !self.enabled {
703            return Err(CheckpointError::Disabled);
704        }
705        Ok(self.trigger())
706    }
707
708    /// Returns the most recent checkpoint for restore.
709    ///
710    /// # Errors
711    ///
712    /// Returns `CheckpointError::Disabled` if checkpointing is not enabled,
713    /// or `CheckpointError::NoCheckpoint` if no checkpoint exists.
714    pub fn restore(&self) -> Result<&StreamCheckpoint, CheckpointError> {
715        if !self.enabled {
716            return Err(CheckpointError::Disabled);
717        }
718        self.checkpoints.last().ok_or(CheckpointError::NoCheckpoint)
719    }
720
721    /// Returns a checkpoint by ID.
722    #[must_use]
723    pub fn get_checkpoint(&self, id: u64) -> Option<&StreamCheckpoint> {
724        self.checkpoints.iter().find(|cp| cp.id == id)
725    }
726
727    /// Returns the ID of the most recent checkpoint.
728    #[must_use]
729    pub fn last_checkpoint_id(&self) -> Option<u64> {
730        self.checkpoints.last().map(|cp| cp.id)
731    }
732
733    /// Returns a reference to the changelog buffer, if configured.
734    #[must_use]
735    pub fn changelog(&self) -> Option<&StreamChangelogBuffer> {
736        self.changelog.as_ref()
737    }
738
739    /// Returns a mutable reference to the changelog buffer.
740    pub fn changelog_mut(&mut self) -> Option<&mut StreamChangelogBuffer> {
741        self.changelog.as_mut()
742    }
743
744    /// Returns the current epoch.
745    #[must_use]
746    pub fn epoch(&self) -> u64 {
747        self.epoch
748    }
749}
750
751impl fmt::Debug for StreamCheckpointManager {
752    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
753        f.debug_struct("StreamCheckpointManager")
754            .field("enabled", &self.enabled)
755            .field("sources", &self.sources.len())
756            .field("sinks", &self.sinks.len())
757            .field("checkpoints", &self.checkpoints.len())
758            .field("epoch", &self.epoch)
759            .finish_non_exhaustive()
760    }
761}
762
763// Tests
764
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid config with a periodic interval; everything else at defaults.
    fn enabled_config() -> StreamCheckpointConfig {
        StreamCheckpointConfig {
            interval_ms: Some(1000),
            ..Default::default()
        }
    }

    /// Enabled manager with one registered source whose atomics start at the
    /// given values. Returns the manager and the shared sequence counter.
    fn manager_with_source(
        name: &str,
        sequence: u64,
        watermark: i64,
    ) -> (StreamCheckpointManager, Arc<AtomicU64>) {
        let mut mgr = StreamCheckpointManager::new(enabled_config());
        let seq = Arc::new(AtomicU64::new(sequence));
        let wm = Arc::new(AtomicI64::new(watermark));
        mgr.register_source(name, Arc::clone(&seq), Arc::clone(&wm));
        (mgr, seq)
    }

    /// `Push` entry for source 0 with the given sequence and no watermark.
    fn push_entry(sequence: u64) -> StreamChangelogEntry {
        StreamChangelogEntry::new(0, StreamChangeOp::Push, sequence, i64::MIN)
    }

    // -- Config / disabled tests --

    #[test]
    fn test_checkpoint_disabled_by_default() {
        // The default config has no interval but passes validate(), so a
        // manager built from it is "enabled"; only a failing validate() or
        // `disabled()` produces a disabled manager.
        let from_default = StreamCheckpointManager::new(StreamCheckpointConfig::default());
        assert!(from_default.is_enabled());

        let explicitly_disabled = StreamCheckpointManager::disabled();
        assert!(!explicitly_disabled.is_enabled());
    }

    #[test]
    fn test_checkpoint_no_op_when_disabled() {
        let mgr = StreamCheckpointManager::disabled();
        assert!(!mgr.is_enabled());
        assert_eq!(mgr.last_checkpoint_id(), None);
    }

    #[test]
    fn test_checkpoint_config_requires_data_dir() {
        let without_dir = StreamCheckpointConfig {
            wal_mode: Some(WalMode::Sync),
            data_dir: None,
            ..Default::default()
        };
        assert!(without_dir.validate().is_err());

        // Supplying a data_dir makes the same WAL setting valid.
        let with_dir = StreamCheckpointConfig {
            wal_mode: Some(WalMode::Sync),
            data_dir: Some(std::path::PathBuf::from("/tmp/test")),
            ..Default::default()
        };
        assert!(with_dir.validate().is_ok());
    }

    #[test]
    fn test_wal_requires_checkpoint() {
        let config = StreamCheckpointConfig {
            wal_mode: Some(WalMode::Async),
            data_dir: None, // missing
            ..Default::default()
        };
        assert!(matches!(
            config.validate(),
            Err(CheckpointError::InvalidConfig(_))
        ));
    }

    // -- Source registration --

    #[test]
    fn test_register_source() {
        let (mgr, _seq) = manager_with_source("trades", 0, i64::MIN);
        assert!(mgr.is_enabled());
    }

    // -- Trigger / capture --

    #[test]
    fn test_trigger_checkpoint() {
        let mut mgr = StreamCheckpointManager::new(enabled_config());
        assert_eq!(mgr.trigger(), Some(1));
        assert_eq!(mgr.trigger(), Some(2));
    }

    #[test]
    fn test_checkpoint_captures_sequences() {
        let (mut mgr, seq) = manager_with_source("src1", 0, i64::MIN);

        // Simulate pushes
        seq.store(42, Ordering::Release);

        let id = mgr.trigger().unwrap();
        let cp = mgr.get_checkpoint(id).unwrap();
        assert_eq!(cp.source_sequences.get("src1"), Some(&42));
    }

    #[test]
    fn test_checkpoint_captures_watermarks() {
        let (mut mgr, _seq) = manager_with_source("src1", 0, 5000);

        let id = mgr.trigger().unwrap();
        let cp = mgr.get_checkpoint(id).unwrap();
        assert_eq!(cp.watermarks.get("src1"), Some(&5000));
    }

    #[test]
    fn test_restore_from_checkpoint() {
        let (mut mgr, _seq) = manager_with_source("src1", 10, 1000);

        mgr.trigger();
        let restored = mgr.restore().unwrap();
        assert_eq!(restored.source_sequences.get("src1"), Some(&10));
        assert_eq!(restored.watermarks.get("src1"), Some(&1000));
    }

    // -- Changelog buffer --

    #[test]
    fn test_changelog_buffer_push_pop() {
        let mut buf = StreamChangelogBuffer::new(4, OverflowPolicy::DropNew);
        assert!(buf.is_empty());

        assert!(buf.push(push_entry(1)));
        assert_eq!(buf.len(), 1);

        let popped = buf.pop().unwrap();
        assert_eq!(popped.sequence, 1);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_changelog_buffer_overflow() {
        let (e1, e2, e3) = (push_entry(1), push_entry(2), push_entry(3));

        // DropNew policy
        let mut dropping = StreamChangelogBuffer::new(2, OverflowPolicy::DropNew);
        assert!(dropping.push(e1));
        assert!(dropping.push(e2));
        assert!(dropping.is_full());
        assert!(!dropping.push(e3)); // dropped
        assert_eq!(dropping.overflow_count(), 1);

        // The oldest entry is still there
        assert_eq!(dropping.pop().unwrap().sequence, 1);

        // OverwriteOldest policy
        let mut overwriting = StreamChangelogBuffer::new(2, OverflowPolicy::OverwriteOldest);
        assert!(overwriting.push(e1));
        assert!(overwriting.push(e2));
        assert!(overwriting.push(e3)); // overwrites e1
        assert_eq!(overwriting.overflow_count(), 1);
        assert_eq!(overwriting.pop().unwrap().sequence, 2); // e1 was overwritten
    }

    // -- Prune --

    #[test]
    fn test_checkpoint_prune_old() {
        let config = StreamCheckpointConfig {
            max_retained: Some(2),
            ..enabled_config()
        };
        let mut mgr = StreamCheckpointManager::new(config);

        mgr.trigger(); // id=1
        mgr.trigger(); // id=2
        mgr.trigger(); // id=3 — should prune id=1

        assert_eq!(mgr.checkpoints.len(), 2);
        assert!(mgr.get_checkpoint(1).is_none());
        assert!(mgr.get_checkpoint(2).is_some());
        assert!(mgr.get_checkpoint(3).is_some());
    }

    // -- Serialization --

    #[test]
    fn test_checkpoint_serialization() {
        let cp = StreamCheckpoint {
            id: 42,
            epoch: 7,
            source_sequences: HashMap::from([
                ("src1".to_string(), 100u64),
                ("src2".to_string(), 200u64),
            ]),
            sink_positions: HashMap::from([("sink1".to_string(), 50u64)]),
            watermarks: HashMap::from([
                ("src1".to_string(), 5000i64),
                ("src2".to_string(), 6000i64),
            ]),
            operator_states: HashMap::from([("op1".to_string(), vec![1, 2, 3, 4])]),
            created_at: 1_706_400_000_000,
        };

        let bytes = cp.to_bytes();
        let restored = StreamCheckpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.id, 42);
        assert_eq!(restored.epoch, 7);
        assert_eq!(restored.created_at, 1_706_400_000_000);
        assert_eq!(restored.source_sequences.get("src1"), Some(&100));
        assert_eq!(restored.source_sequences.get("src2"), Some(&200));
        assert_eq!(restored.sink_positions.get("sink1"), Some(&50));
        assert_eq!(restored.watermarks.get("src1"), Some(&5000));
        assert_eq!(restored.watermarks.get("src2"), Some(&6000));
        assert_eq!(restored.operator_states.get("op1"), Some(&vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_changelog_entry_size() {
        assert_eq!(
            std::mem::size_of::<StreamChangelogEntry>(),
            24,
            "StreamChangelogEntry must be exactly 24 bytes"
        );
    }

    // -- Source sequence counter tests --

    #[test]
    fn test_source_sequence_counter() {
        let seq = Arc::new(AtomicU64::new(0));
        assert_eq!(seq.load(Ordering::Acquire), 0);

        for _ in 0..3 {
            seq.fetch_add(1, Ordering::Relaxed);
        }
        assert_eq!(seq.load(Ordering::Acquire), 3);
    }

    #[test]
    fn test_source_clone_shares_sequence() {
        let original = Arc::new(AtomicU64::new(0));
        let shared = Arc::clone(&original);

        original.fetch_add(1, Ordering::Relaxed);
        assert_eq!(shared.load(Ordering::Acquire), 1);

        shared.fetch_add(5, Ordering::Relaxed);
        assert_eq!(original.load(Ordering::Acquire), 6);
    }
}