Skip to main content

laminar_core/sink/
checkpoint.rs

1//! Sink checkpoint management for exactly-once recovery
2
3#![allow(clippy::cast_possible_truncation)]
4
5use std::collections::HashMap;
6
7use super::error::SinkError;
8use super::traits::TransactionId;
9
/// Offset tracking for sink partitions
///
/// Each variant is serialized by [`SinkOffset::to_bytes`] behind a one-byte
/// type tag: `0` for numeric, `1` for string and `2` for binary offsets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SinkOffset {
    /// Numeric offset (Kafka-style)
    Numeric(u64),
    /// String offset (some systems use string identifiers)
    String(String),
    /// Binary offset (opaque bytes)
    Binary(Vec<u8>),
}

impl SinkOffset {
    /// Type tag written in front of a numeric offset.
    const TAG_NUMERIC: u8 = 0;
    /// Type tag written in front of a string offset.
    const TAG_STRING: u8 = 1;
    /// Type tag written in front of a binary offset.
    const TAG_BINARY: u8 = 2;

    /// Serialize to bytes
    ///
    /// Layout (all integers little-endian):
    /// - `Numeric`: `[0][value: 8]`
    /// - `String`:  `[1][len: 4][UTF-8 bytes]`
    /// - `Binary`:  `[2][len: 4][bytes]`
    ///
    /// The length prefix is a `u32`, so payloads longer than `u32::MAX`
    /// bytes cannot be represented faithfully by this format.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        match self {
            Self::Numeric(n) => {
                let mut bytes = Vec::with_capacity(1 + 8);
                bytes.push(Self::TAG_NUMERIC);
                bytes.extend_from_slice(&n.to_le_bytes());
                bytes
            }
            Self::String(s) => Self::encode_tagged(Self::TAG_STRING, s.as_bytes()),
            Self::Binary(b) => Self::encode_tagged(Self::TAG_BINARY, b),
        }
    }

    /// Deserialize from bytes, returns `(offset, bytes_consumed)`
    ///
    /// Bytes after the encoded offset are ignored; `bytes_consumed` tells the
    /// caller where the offset ended. Returns `None` when the input is empty,
    /// truncated, carries an unknown type tag, or holds a string payload that
    /// is not valid UTF-8 (a corrupted offset is rejected rather than being
    /// silently rewritten with replacement characters).
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Option<(Self, usize)> {
        let (&tag, rest) = bytes.split_first()?;

        match tag {
            Self::TAG_NUMERIC => {
                let mut raw = [0u8; 8];
                raw.copy_from_slice(rest.get(..8)?);
                Some((Self::Numeric(u64::from_le_bytes(raw)), 1 + 8))
            }
            Self::TAG_STRING => {
                let (payload, consumed) = Self::decode_payload(rest)?;
                let text = std::str::from_utf8(payload).ok()?;
                Some((Self::String(text.to_owned()), consumed))
            }
            Self::TAG_BINARY => {
                let (payload, consumed) = Self::decode_payload(rest)?;
                Some((Self::Binary(payload.to_vec()), consumed))
            }
            _ => None,
        }
    }

    /// Build `[tag][len: 4][payload]`, the framing shared by the string and
    /// binary variants.
    fn encode_tagged(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(1 + 4 + payload.len());
        bytes.push(tag);
        bytes.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        bytes.extend_from_slice(payload);
        bytes
    }

    /// Read a `[len: 4][payload]` frame from `rest` (the input without its
    /// type tag). Returns the payload and the total number of bytes consumed
    /// including the tag byte, or `None` if the frame is truncated.
    fn decode_payload(rest: &[u8]) -> Option<(&[u8], usize)> {
        let mut raw_len = [0u8; 4];
        raw_len.copy_from_slice(rest.get(..4)?);
        let len = u32::from_le_bytes(raw_len) as usize;

        // `checked_add` keeps the bounds computation from wrapping on
        // targets where `usize` is only 32 bits wide.
        let end = len.checked_add(4)?;
        let payload = rest.get(4..end)?;
        Some((payload, end + 1))
    }
}
90
91/// Checkpoint data for exactly-once sinks
92///
93/// This captures the complete state needed to recover a sink after failure:
94/// - Sink identification
95/// - Per-partition offsets (for resumption)
96/// - Pending transaction ID (for rollback on recovery)
97/// - Custom metadata
98#[derive(Debug, Clone)]
99pub struct SinkCheckpoint {
100    /// Unique identifier for the sink
101    sink_id: String,
102
103    /// Offsets per partition/topic
104    offsets: HashMap<String, SinkOffset>,
105
106    /// Pending transaction ID (if any) - needs rollback on recovery
107    pending_transaction: Option<TransactionId>,
108
109    /// Epoch number for this checkpoint
110    epoch: u64,
111
112    /// Timestamp when checkpoint was created
113    timestamp: u64,
114
115    /// Custom metadata for sink-specific state
116    metadata: HashMap<String, Vec<u8>>,
117}
118
119impl SinkCheckpoint {
120    /// Create a new sink checkpoint
121    #[must_use]
122    pub fn new(sink_id: impl Into<String>) -> Self {
123        Self {
124            sink_id: sink_id.into(),
125            offsets: HashMap::new(),
126            pending_transaction: None,
127            epoch: 0,
128            timestamp: std::time::SystemTime::now()
129                .duration_since(std::time::UNIX_EPOCH)
130                .unwrap_or_default()
131                .as_millis() as u64,
132            metadata: HashMap::new(),
133        }
134    }
135
136    /// Get the sink ID
137    #[must_use]
138    pub fn sink_id(&self) -> &str {
139        &self.sink_id
140    }
141
142    /// Get the epoch number
143    #[must_use]
144    pub fn epoch(&self) -> u64 {
145        self.epoch
146    }
147
148    /// Set the epoch number
149    pub fn set_epoch(&mut self, epoch: u64) {
150        self.epoch = epoch;
151    }
152
153    /// Get the timestamp
154    #[must_use]
155    pub fn timestamp(&self) -> u64 {
156        self.timestamp
157    }
158
159    /// Set an offset for a partition
160    pub fn set_offset(&mut self, partition: impl Into<String>, offset: SinkOffset) {
161        self.offsets.insert(partition.into(), offset);
162    }
163
164    /// Get an offset for a partition
165    #[must_use]
166    pub fn get_offset(&self, partition: &str) -> Option<&SinkOffset> {
167        self.offsets.get(partition)
168    }
169
170    /// Get all offsets
171    #[must_use]
172    pub fn offsets(&self) -> &HashMap<String, SinkOffset> {
173        &self.offsets
174    }
175
176    /// Set the pending transaction ID
177    pub fn set_transaction_id(&mut self, tx_id: Option<TransactionId>) {
178        self.pending_transaction = tx_id;
179    }
180
181    /// Get the pending transaction ID
182    #[must_use]
183    pub fn pending_transaction_id(&self) -> Option<&TransactionId> {
184        self.pending_transaction.as_ref()
185    }
186
187    /// Set custom metadata
188    pub fn set_metadata(&mut self, key: impl Into<String>, value: Vec<u8>) {
189        self.metadata.insert(key.into(), value);
190    }
191
192    /// Get custom metadata
193    #[must_use]
194    pub fn get_metadata(&self, key: &str) -> Option<&[u8]> {
195        self.metadata.get(key).map(Vec::as_slice)
196    }
197
198    /// Serialize checkpoint to bytes for persistence
199    #[must_use]
200    pub fn to_bytes(&self) -> Vec<u8> {
201        // Format:
202        // [version: 1][sink_id_len: 4][sink_id][epoch: 8][timestamp: 8]
203        // [has_tx: 1][tx_bytes_len: 4][tx_bytes]?
204        // [num_offsets: 4][offset entries...]
205        // [num_metadata: 4][metadata entries...]
206
207        let mut bytes = Vec::new();
208
209        // Version
210        bytes.push(1u8);
211
212        // Sink ID
213        bytes.extend_from_slice(&(self.sink_id.len() as u32).to_le_bytes());
214        bytes.extend_from_slice(self.sink_id.as_bytes());
215
216        // Epoch and timestamp
217        bytes.extend_from_slice(&self.epoch.to_le_bytes());
218        bytes.extend_from_slice(&self.timestamp.to_le_bytes());
219
220        // Pending transaction
221        if let Some(ref tx) = self.pending_transaction {
222            bytes.push(1u8);
223            let tx_bytes = tx.to_bytes();
224            bytes.extend_from_slice(&(tx_bytes.len() as u32).to_le_bytes());
225            bytes.extend_from_slice(&tx_bytes);
226        } else {
227            bytes.push(0u8);
228        }
229
230        // Offsets
231        bytes.extend_from_slice(&(self.offsets.len() as u32).to_le_bytes());
232        for (partition, offset) in &self.offsets {
233            bytes.extend_from_slice(&(partition.len() as u32).to_le_bytes());
234            bytes.extend_from_slice(partition.as_bytes());
235            let offset_bytes = offset.to_bytes();
236            bytes.extend_from_slice(&(offset_bytes.len() as u32).to_le_bytes());
237            bytes.extend_from_slice(&offset_bytes);
238        }
239
240        // Metadata
241        bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes());
242        for (key, value) in &self.metadata {
243            bytes.extend_from_slice(&(key.len() as u32).to_le_bytes());
244            bytes.extend_from_slice(key.as_bytes());
245            bytes.extend_from_slice(&(value.len() as u32).to_le_bytes());
246            bytes.extend_from_slice(value);
247        }
248
249        bytes
250    }
251
252    /// Deserialize checkpoint from bytes
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if the bytes are malformed or the version is unsupported.
257    ///
258    /// # Panics
259    ///
260    /// Will not panic - all array conversions are bounds-checked before unwrapping.
261    #[allow(clippy::missing_panics_doc, clippy::too_many_lines)]
262    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SinkError> {
263        if bytes.is_empty() {
264            return Err(SinkError::CheckpointError(
265                "Empty checkpoint data".to_string(),
266            ));
267        }
268
269        let mut pos = 0;
270
271        // Version
272        let version = bytes[pos];
273        pos += 1;
274        if version != 1 {
275            return Err(SinkError::CheckpointError(format!(
276                "Unsupported checkpoint version: {version}"
277            )));
278        }
279
280        // Helper to read u32 length
281        let read_u32 = |pos: &mut usize| -> Result<u32, SinkError> {
282            if *pos + 4 > bytes.len() {
283                return Err(SinkError::CheckpointError(
284                    "Unexpected end of data".to_string(),
285                ));
286            }
287            let val = u32::from_le_bytes(bytes[*pos..*pos + 4].try_into().unwrap());
288            *pos += 4;
289            Ok(val)
290        };
291
292        // Helper to read u64
293        let read_u64 = |pos: &mut usize| -> Result<u64, SinkError> {
294            if *pos + 8 > bytes.len() {
295                return Err(SinkError::CheckpointError(
296                    "Unexpected end of data".to_string(),
297                ));
298            }
299            let val = u64::from_le_bytes(bytes[*pos..*pos + 8].try_into().unwrap());
300            *pos += 8;
301            Ok(val)
302        };
303
304        // Sink ID
305        let sink_id_len = read_u32(&mut pos)? as usize;
306        if pos + sink_id_len > bytes.len() {
307            return Err(SinkError::CheckpointError(
308                "Invalid sink_id length".to_string(),
309            ));
310        }
311        let sink_id = String::from_utf8_lossy(&bytes[pos..pos + sink_id_len]).to_string();
312        pos += sink_id_len;
313
314        // Epoch and timestamp
315        let epoch = read_u64(&mut pos)?;
316        let timestamp = read_u64(&mut pos)?;
317
318        // Pending transaction
319        if pos >= bytes.len() {
320            return Err(SinkError::CheckpointError(
321                "Unexpected end of data".to_string(),
322            ));
323        }
324        let has_tx = bytes[pos] == 1;
325        pos += 1;
326
327        let pending_transaction = if has_tx {
328            let tx_len = read_u32(&mut pos)? as usize;
329            if pos + tx_len > bytes.len() {
330                return Err(SinkError::CheckpointError(
331                    "Invalid transaction length".to_string(),
332                ));
333            }
334            let tx = TransactionId::from_bytes(&bytes[pos..pos + tx_len]).ok_or_else(|| {
335                SinkError::CheckpointError("Invalid transaction data".to_string())
336            })?;
337            pos += tx_len;
338            Some(tx)
339        } else {
340            None
341        };
342
343        // Offsets
344        let num_offsets = read_u32(&mut pos)?;
345        let mut offsets = HashMap::new();
346        for _ in 0..num_offsets {
347            let partition_len = read_u32(&mut pos)? as usize;
348            if pos + partition_len > bytes.len() {
349                return Err(SinkError::CheckpointError(
350                    "Invalid partition length".to_string(),
351                ));
352            }
353            let partition = String::from_utf8_lossy(&bytes[pos..pos + partition_len]).to_string();
354            pos += partition_len;
355
356            let offset_len = read_u32(&mut pos)? as usize;
357            if pos + offset_len > bytes.len() {
358                return Err(SinkError::CheckpointError(
359                    "Invalid offset length".to_string(),
360                ));
361            }
362            let (offset, _) = SinkOffset::from_bytes(&bytes[pos..pos + offset_len])
363                .ok_or_else(|| SinkError::CheckpointError("Invalid offset data".to_string()))?;
364            pos += offset_len;
365
366            offsets.insert(partition, offset);
367        }
368
369        // Metadata
370        let num_metadata = read_u32(&mut pos)?;
371        let mut metadata = HashMap::new();
372        for _ in 0..num_metadata {
373            let key_len = read_u32(&mut pos)? as usize;
374            if pos + key_len > bytes.len() {
375                return Err(SinkError::CheckpointError(
376                    "Invalid metadata key length".to_string(),
377                ));
378            }
379            let key = String::from_utf8_lossy(&bytes[pos..pos + key_len]).to_string();
380            pos += key_len;
381
382            let value_len = read_u32(&mut pos)? as usize;
383            if pos + value_len > bytes.len() {
384                return Err(SinkError::CheckpointError(
385                    "Invalid metadata value length".to_string(),
386                ));
387            }
388            let value = bytes[pos..pos + value_len].to_vec();
389            pos += value_len;
390
391            metadata.insert(key, value);
392        }
393
394        Ok(Self {
395            sink_id,
396            offsets,
397            pending_transaction,
398            epoch,
399            timestamp,
400            metadata,
401        })
402    }
403}
404
405/// Manager for sink checkpoints
406///
407/// Coordinates checkpointing across multiple sinks and integrates
408/// with the main checkpoint system.
409pub struct SinkCheckpointManager {
410    /// Sink ID to checkpoint mapping
411    checkpoints: HashMap<String, SinkCheckpoint>,
412
413    /// Current epoch
414    current_epoch: u64,
415}
416
417impl SinkCheckpointManager {
418    /// Create a new checkpoint manager
419    #[must_use]
420    pub fn new() -> Self {
421        Self {
422            checkpoints: HashMap::new(),
423            current_epoch: 0,
424        }
425    }
426
427    /// Register a sink checkpoint
428    pub fn register(&mut self, checkpoint: SinkCheckpoint) {
429        let sink_id = checkpoint.sink_id.clone();
430        self.checkpoints.insert(sink_id, checkpoint);
431    }
432
433    /// Get a sink's checkpoint
434    #[must_use]
435    pub fn get(&self, sink_id: &str) -> Option<&SinkCheckpoint> {
436        self.checkpoints.get(sink_id)
437    }
438
439    /// Get a mutable reference to a sink's checkpoint
440    pub fn get_mut(&mut self, sink_id: &str) -> Option<&mut SinkCheckpoint> {
441        self.checkpoints.get_mut(sink_id)
442    }
443
444    /// Advance the epoch for all sinks
445    pub fn advance_epoch(&mut self) -> u64 {
446        self.current_epoch += 1;
447        for checkpoint in self.checkpoints.values_mut() {
448            checkpoint.set_epoch(self.current_epoch);
449        }
450        self.current_epoch
451    }
452
453    /// Get the current epoch
454    #[must_use]
455    pub fn current_epoch(&self) -> u64 {
456        self.current_epoch
457    }
458
459    /// Serialize all checkpoints to bytes
460    #[must_use]
461    pub fn to_bytes(&self) -> Vec<u8> {
462        let mut bytes = Vec::new();
463
464        // Epoch
465        bytes.extend_from_slice(&self.current_epoch.to_le_bytes());
466
467        // Number of checkpoints
468        bytes.extend_from_slice(&(self.checkpoints.len() as u32).to_le_bytes());
469
470        // Each checkpoint
471        for checkpoint in self.checkpoints.values() {
472            let cp_bytes = checkpoint.to_bytes();
473            bytes.extend_from_slice(&(cp_bytes.len() as u32).to_le_bytes());
474            bytes.extend_from_slice(&cp_bytes);
475        }
476
477        bytes
478    }
479
480    /// Deserialize from bytes
481    ///
482    /// # Errors
483    ///
484    /// Returns an error if the bytes are malformed.
485    ///
486    /// # Panics
487    ///
488    /// Will not panic - all array conversions are bounds-checked before unwrapping.
489    #[allow(clippy::missing_panics_doc)]
490    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SinkError> {
491        if bytes.len() < 12 {
492            return Err(SinkError::CheckpointError(
493                "Checkpoint data too short".to_string(),
494            ));
495        }
496
497        let mut pos = 0;
498
499        // Epoch
500        let current_epoch = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
501        pos += 8;
502
503        // Number of checkpoints
504        let num_checkpoints = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
505        pos += 4;
506
507        let mut checkpoints = HashMap::new();
508
509        for _ in 0..num_checkpoints {
510            if pos + 4 > bytes.len() {
511                return Err(SinkError::CheckpointError(
512                    "Unexpected end of data".to_string(),
513                ));
514            }
515            let cp_len = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
516            pos += 4;
517
518            if pos + cp_len > bytes.len() {
519                return Err(SinkError::CheckpointError(
520                    "Invalid checkpoint length".to_string(),
521                ));
522            }
523            let checkpoint = SinkCheckpoint::from_bytes(&bytes[pos..pos + cp_len])?;
524            pos += cp_len;
525
526            checkpoints.insert(checkpoint.sink_id.clone(), checkpoint);
527        }
528
529        Ok(Self {
530            checkpoints,
531            current_epoch,
532        })
533    }
534}
535
536impl Default for SinkCheckpointManager {
537    fn default() -> Self {
538        Self::new()
539    }
540}
541
#[cfg(test)]
mod tests {
    //! Unit tests for offset, checkpoint and manager serialization.
    //!
    //! Besides the happy-path round trips, these cover the reported
    //! `bytes_consumed` value and rejection of malformed or truncated input,
    //! since checkpoint bytes are read back after a failure and may be damaged.

    use super::*;

    /// Round-trip `offset` and check that the whole encoding was consumed.
    fn assert_offset_round_trip(offset: &SinkOffset) {
        let bytes = offset.to_bytes();
        let (restored, consumed) = SinkOffset::from_bytes(&bytes).unwrap();
        assert_eq!(&restored, offset);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn test_sink_offset_numeric() {
        assert_offset_round_trip(&SinkOffset::Numeric(12345));
        assert_offset_round_trip(&SinkOffset::Numeric(0));
        assert_offset_round_trip(&SinkOffset::Numeric(u64::MAX));
    }

    #[test]
    fn test_sink_offset_string() {
        assert_offset_round_trip(&SinkOffset::String("offset-abc-123".to_string()));
        assert_offset_round_trip(&SinkOffset::String(String::new()));
    }

    #[test]
    fn test_sink_offset_binary() {
        assert_offset_round_trip(&SinkOffset::Binary(vec![1, 2, 3, 4, 5]));
        assert_offset_round_trip(&SinkOffset::Binary(Vec::new()));
    }

    #[test]
    fn test_sink_offset_ignores_trailing_bytes() {
        let mut bytes = SinkOffset::Numeric(7).to_bytes();
        bytes.push(0xAA);

        let (restored, consumed) = SinkOffset::from_bytes(&bytes).unwrap();
        assert_eq!(restored, SinkOffset::Numeric(7));
        assert_eq!(consumed, 9);
    }

    #[test]
    fn test_sink_offset_rejects_malformed() {
        // Empty input
        assert!(SinkOffset::from_bytes(&[]).is_none());
        // Unknown type tag
        assert!(SinkOffset::from_bytes(&[9u8]).is_none());
        // Truncated numeric payload
        assert!(SinkOffset::from_bytes(&[0u8, 1, 2]).is_none());
        // String declares 10 bytes but only 1 follows
        assert!(SinkOffset::from_bytes(&[1u8, 10, 0, 0, 0, b'a']).is_none());
        // Binary with an incomplete length prefix
        assert!(SinkOffset::from_bytes(&[2u8, 1, 0]).is_none());
    }

    #[test]
    fn test_sink_checkpoint_new() {
        let checkpoint = SinkCheckpoint::new("my-sink");
        assert_eq!(checkpoint.sink_id(), "my-sink");
        assert_eq!(checkpoint.epoch(), 0);
        assert!(checkpoint.pending_transaction_id().is_none());
        assert!(checkpoint.offsets().is_empty());
        assert_eq!(checkpoint.get_metadata("missing"), None);
    }

    #[test]
    fn test_sink_checkpoint_with_offsets() {
        let mut checkpoint = SinkCheckpoint::new("kafka-sink");
        checkpoint.set_offset("topic-0", SinkOffset::Numeric(100));
        checkpoint.set_offset("topic-1", SinkOffset::Numeric(200));

        assert_eq!(
            checkpoint.get_offset("topic-0"),
            Some(&SinkOffset::Numeric(100))
        );
        assert_eq!(
            checkpoint.get_offset("topic-1"),
            Some(&SinkOffset::Numeric(200))
        );
        assert_eq!(checkpoint.get_offset("topic-2"), None);

        // Setting an offset again replaces the previous value.
        checkpoint.set_offset("topic-0", SinkOffset::Numeric(101));
        assert_eq!(
            checkpoint.get_offset("topic-0"),
            Some(&SinkOffset::Numeric(101))
        );
        assert_eq!(checkpoint.offsets().len(), 2);
    }

    #[test]
    fn test_sink_checkpoint_serialization() {
        let mut checkpoint = SinkCheckpoint::new("test-sink");
        checkpoint.set_epoch(42);
        checkpoint.set_offset("partition-0", SinkOffset::Numeric(1000));
        checkpoint.set_offset("partition-1", SinkOffset::String("abc".to_string()));
        checkpoint.set_transaction_id(Some(TransactionId::new(999)));
        checkpoint.set_metadata("custom-key", b"custom-value".to_vec());

        let bytes = checkpoint.to_bytes();
        let restored = SinkCheckpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.sink_id(), "test-sink");
        assert_eq!(restored.epoch(), 42);
        assert_eq!(restored.timestamp(), checkpoint.timestamp());
        assert_eq!(restored.offsets().len(), 2);
        assert_eq!(
            restored.get_offset("partition-0"),
            Some(&SinkOffset::Numeric(1000))
        );
        assert_eq!(
            restored.get_offset("partition-1"),
            Some(&SinkOffset::String("abc".to_string()))
        );
        assert!(restored.pending_transaction_id().is_some());
        assert_eq!(
            restored.get_metadata("custom-key"),
            Some(b"custom-value".as_ref())
        );
    }

    #[test]
    fn test_sink_checkpoint_serialization_without_transaction() {
        let mut checkpoint = SinkCheckpoint::new("plain-sink");
        checkpoint.set_epoch(3);

        let restored = SinkCheckpoint::from_bytes(&checkpoint.to_bytes()).unwrap();

        assert_eq!(restored.sink_id(), "plain-sink");
        assert_eq!(restored.epoch(), 3);
        assert!(restored.pending_transaction_id().is_none());
        assert!(restored.offsets().is_empty());
    }

    #[test]
    fn test_sink_checkpoint_rejects_malformed() {
        // Empty input
        assert!(SinkCheckpoint::from_bytes(&[]).is_err());
        // Unsupported version byte
        assert!(SinkCheckpoint::from_bytes(&[2u8]).is_err());

        // Every strict prefix of a valid encoding is missing data somewhere.
        let mut checkpoint = SinkCheckpoint::new("test-sink");
        checkpoint.set_offset("partition-0", SinkOffset::Numeric(1));
        checkpoint.set_transaction_id(Some(TransactionId::new(5)));
        checkpoint.set_metadata("k", vec![1, 2, 3]);
        let bytes = checkpoint.to_bytes();
        for cut in 0..bytes.len() {
            assert!(
                SinkCheckpoint::from_bytes(&bytes[..cut]).is_err(),
                "prefix of length {cut} should be rejected"
            );
        }
    }

    #[test]
    fn test_checkpoint_manager() {
        let mut manager = SinkCheckpointManager::new();

        let mut cp1 = SinkCheckpoint::new("sink-1");
        cp1.set_offset("p0", SinkOffset::Numeric(100));

        let mut cp2 = SinkCheckpoint::new("sink-2");
        cp2.set_offset("p0", SinkOffset::Numeric(200));

        manager.register(cp1);
        manager.register(cp2);

        assert_eq!(manager.current_epoch(), 0);
        assert_eq!(manager.advance_epoch(), 1);
        assert_eq!(manager.current_epoch(), 1);

        // The new epoch is stamped on every registered checkpoint.
        assert_eq!(manager.get("sink-1").unwrap().epoch(), 1);
        assert_eq!(manager.get("sink-2").unwrap().epoch(), 1);
        assert!(manager.get("sink-3").is_none());
    }

    #[test]
    fn test_checkpoint_manager_serialization() {
        let mut manager = SinkCheckpointManager::new();

        let mut cp = SinkCheckpoint::new("sink-1");
        cp.set_offset("p0", SinkOffset::Numeric(100));
        manager.register(cp);

        manager.advance_epoch();
        manager.advance_epoch();

        let bytes = manager.to_bytes();
        let restored = SinkCheckpointManager::from_bytes(&bytes).unwrap();

        assert_eq!(restored.current_epoch(), 2);
        let restored_cp = restored.get("sink-1").unwrap();
        assert_eq!(restored_cp.epoch(), 2);
        assert_eq!(
            restored_cp.get_offset("p0"),
            Some(&SinkOffset::Numeric(100))
        );
    }

    #[test]
    fn test_checkpoint_manager_rejects_malformed() {
        // Shorter than the fixed 12-byte header
        assert!(SinkCheckpointManager::from_bytes(&[0u8; 11]).is_err());

        // Header announces one checkpoint but no entry follows
        let mut bytes = 0u64.to_le_bytes().to_vec();
        bytes.extend_from_slice(&1u32.to_le_bytes());
        assert!(SinkCheckpointManager::from_bytes(&bytes).is_err());

        // Entry declares more bytes than are present
        bytes.extend_from_slice(&100u32.to_le_bytes());
        bytes.extend_from_slice(&[1, 2, 3]);
        assert!(SinkCheckpointManager::from_bytes(&bytes).is_err());
    }
}