1#![allow(clippy::cast_possible_truncation)]
4
5use std::collections::HashMap;
6
7use super::error::SinkError;
8use super::traits::TransactionId;
9
/// Offset recorded for one partition in a [`SinkCheckpoint`].
///
/// Three representations are supported so that different sinks can store
/// whatever position token their target system uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SinkOffset {
    /// Numeric offset. Encoded as tag `0` followed by a little-endian `u64`.
    Numeric(u64),
    /// String offset. Encoded as tag `1`, a little-endian `u32` byte length,
    /// and the UTF-8 bytes.
    String(String),
    /// Opaque binary offset. Encoded as tag `2`, a little-endian `u32` byte
    /// length, and the raw bytes.
    Binary(Vec<u8>),
}

impl SinkOffset {
    /// Type tag that prefixes a [`SinkOffset::Numeric`] encoding.
    const TAG_NUMERIC: u8 = 0;
    /// Type tag that prefixes a [`SinkOffset::String`] encoding.
    const TAG_STRING: u8 = 1;
    /// Type tag that prefixes a [`SinkOffset::Binary`] encoding.
    const TAG_BINARY: u8 = 2;

    /// Serializes the offset into its tagged binary form.
    ///
    /// The result can be decoded with [`SinkOffset::from_bytes`].
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        match self {
            Self::Numeric(n) => {
                let mut bytes = Vec::with_capacity(1 + 8);
                bytes.push(Self::TAG_NUMERIC);
                bytes.extend_from_slice(&n.to_le_bytes());
                bytes
            }
            Self::String(s) => Self::encode_prefixed(Self::TAG_STRING, s.as_bytes()),
            Self::Binary(b) => Self::encode_prefixed(Self::TAG_BINARY, b),
        }
    }

    /// Decodes one offset from the start of `bytes`.
    ///
    /// Returns the offset and the number of bytes it occupied. Any bytes after
    /// that are ignored, so the caller can continue parsing from there.
    ///
    /// Returns `None` in these cases:
    /// - the input is empty or truncated;
    /// - the tag is unknown;
    /// - a [`SinkOffset::String`] payload is not valid UTF-8. Malformed data
    ///   is reported rather than silently repaired with replacement
    ///   characters.
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Option<(Self, usize)> {
        let (&tag, rest) = bytes.split_first()?;
        match tag {
            Self::TAG_NUMERIC => {
                let raw: [u8; 8] = rest.get(..8)?.try_into().ok()?;
                Some((Self::Numeric(u64::from_le_bytes(raw)), 1 + 8))
            }
            Self::TAG_STRING => {
                let (payload, consumed) = Self::decode_prefixed(rest)?;
                let s = std::str::from_utf8(payload).ok()?.to_owned();
                Some((Self::String(s), consumed))
            }
            Self::TAG_BINARY => {
                let (payload, consumed) = Self::decode_prefixed(rest)?;
                Some((Self::Binary(payload.to_vec()), consumed))
            }
            _ => None,
        }
    }

    /// Builds `tag`, then the little-endian `u32` length of `payload`, then
    /// `payload` itself.
    fn encode_prefixed(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(1 + 4 + payload.len());
        bytes.push(tag);
        // The length prefix is 32 bits wide. Payloads longer than `u32::MAX`
        // bytes cannot be represented, and this cast would truncate them.
        bytes.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        bytes.extend_from_slice(payload);
        bytes
    }

    /// Reads a little-endian `u32` length prefix from `rest` (the input after
    /// the tag byte), followed by that many payload bytes.
    ///
    /// Returns the payload and the total number of bytes consumed, including
    /// the tag byte. Returns `None` if `rest` is too short. Checked arithmetic
    /// keeps an oversized length from wrapping on 32-bit targets.
    fn decode_prefixed(rest: &[u8]) -> Option<(&[u8], usize)> {
        let len_raw: [u8; 4] = rest.get(..4)?.try_into().ok()?;
        let len = usize::try_from(u32::from_le_bytes(len_raw)).ok()?;
        let end = 4usize.checked_add(len)?;
        let payload = rest.get(4..end)?;
        // `end <= rest.len() < bytes.len()`, so `1 + end` cannot overflow.
        Some((payload, 1 + end))
    }
}
90
91#[derive(Debug, Clone)]
99pub struct SinkCheckpoint {
100 sink_id: String,
102
103 offsets: HashMap<String, SinkOffset>,
105
106 pending_transaction: Option<TransactionId>,
108
109 epoch: u64,
111
112 timestamp: u64,
114
115 metadata: HashMap<String, Vec<u8>>,
117}
118
119impl SinkCheckpoint {
120 #[must_use]
122 pub fn new(sink_id: impl Into<String>) -> Self {
123 Self {
124 sink_id: sink_id.into(),
125 offsets: HashMap::new(),
126 pending_transaction: None,
127 epoch: 0,
128 timestamp: std::time::SystemTime::now()
129 .duration_since(std::time::UNIX_EPOCH)
130 .unwrap_or_default()
131 .as_millis() as u64,
132 metadata: HashMap::new(),
133 }
134 }
135
136 #[must_use]
138 pub fn sink_id(&self) -> &str {
139 &self.sink_id
140 }
141
142 #[must_use]
144 pub fn epoch(&self) -> u64 {
145 self.epoch
146 }
147
148 pub fn set_epoch(&mut self, epoch: u64) {
150 self.epoch = epoch;
151 }
152
153 #[must_use]
155 pub fn timestamp(&self) -> u64 {
156 self.timestamp
157 }
158
159 pub fn set_offset(&mut self, partition: impl Into<String>, offset: SinkOffset) {
161 self.offsets.insert(partition.into(), offset);
162 }
163
164 #[must_use]
166 pub fn get_offset(&self, partition: &str) -> Option<&SinkOffset> {
167 self.offsets.get(partition)
168 }
169
170 #[must_use]
172 pub fn offsets(&self) -> &HashMap<String, SinkOffset> {
173 &self.offsets
174 }
175
176 pub fn set_transaction_id(&mut self, tx_id: Option<TransactionId>) {
178 self.pending_transaction = tx_id;
179 }
180
181 #[must_use]
183 pub fn pending_transaction_id(&self) -> Option<&TransactionId> {
184 self.pending_transaction.as_ref()
185 }
186
187 pub fn set_metadata(&mut self, key: impl Into<String>, value: Vec<u8>) {
189 self.metadata.insert(key.into(), value);
190 }
191
192 #[must_use]
194 pub fn get_metadata(&self, key: &str) -> Option<&[u8]> {
195 self.metadata.get(key).map(Vec::as_slice)
196 }
197
198 #[must_use]
200 pub fn to_bytes(&self) -> Vec<u8> {
201 let mut bytes = Vec::new();
208
209 bytes.push(1u8);
211
212 bytes.extend_from_slice(&(self.sink_id.len() as u32).to_le_bytes());
214 bytes.extend_from_slice(self.sink_id.as_bytes());
215
216 bytes.extend_from_slice(&self.epoch.to_le_bytes());
218 bytes.extend_from_slice(&self.timestamp.to_le_bytes());
219
220 if let Some(ref tx) = self.pending_transaction {
222 bytes.push(1u8);
223 let tx_bytes = tx.to_bytes();
224 bytes.extend_from_slice(&(tx_bytes.len() as u32).to_le_bytes());
225 bytes.extend_from_slice(&tx_bytes);
226 } else {
227 bytes.push(0u8);
228 }
229
230 bytes.extend_from_slice(&(self.offsets.len() as u32).to_le_bytes());
232 for (partition, offset) in &self.offsets {
233 bytes.extend_from_slice(&(partition.len() as u32).to_le_bytes());
234 bytes.extend_from_slice(partition.as_bytes());
235 let offset_bytes = offset.to_bytes();
236 bytes.extend_from_slice(&(offset_bytes.len() as u32).to_le_bytes());
237 bytes.extend_from_slice(&offset_bytes);
238 }
239
240 bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes());
242 for (key, value) in &self.metadata {
243 bytes.extend_from_slice(&(key.len() as u32).to_le_bytes());
244 bytes.extend_from_slice(key.as_bytes());
245 bytes.extend_from_slice(&(value.len() as u32).to_le_bytes());
246 bytes.extend_from_slice(value);
247 }
248
249 bytes
250 }
251
252 #[allow(clippy::missing_panics_doc, clippy::too_many_lines)]
262 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SinkError> {
263 if bytes.is_empty() {
264 return Err(SinkError::CheckpointError(
265 "Empty checkpoint data".to_string(),
266 ));
267 }
268
269 let mut pos = 0;
270
271 let version = bytes[pos];
273 pos += 1;
274 if version != 1 {
275 return Err(SinkError::CheckpointError(format!(
276 "Unsupported checkpoint version: {version}"
277 )));
278 }
279
280 let read_u32 = |pos: &mut usize| -> Result<u32, SinkError> {
282 if *pos + 4 > bytes.len() {
283 return Err(SinkError::CheckpointError(
284 "Unexpected end of data".to_string(),
285 ));
286 }
287 let val = u32::from_le_bytes(bytes[*pos..*pos + 4].try_into().unwrap());
288 *pos += 4;
289 Ok(val)
290 };
291
292 let read_u64 = |pos: &mut usize| -> Result<u64, SinkError> {
294 if *pos + 8 > bytes.len() {
295 return Err(SinkError::CheckpointError(
296 "Unexpected end of data".to_string(),
297 ));
298 }
299 let val = u64::from_le_bytes(bytes[*pos..*pos + 8].try_into().unwrap());
300 *pos += 8;
301 Ok(val)
302 };
303
304 let sink_id_len = read_u32(&mut pos)? as usize;
306 if pos + sink_id_len > bytes.len() {
307 return Err(SinkError::CheckpointError(
308 "Invalid sink_id length".to_string(),
309 ));
310 }
311 let sink_id = String::from_utf8_lossy(&bytes[pos..pos + sink_id_len]).to_string();
312 pos += sink_id_len;
313
314 let epoch = read_u64(&mut pos)?;
316 let timestamp = read_u64(&mut pos)?;
317
318 if pos >= bytes.len() {
320 return Err(SinkError::CheckpointError(
321 "Unexpected end of data".to_string(),
322 ));
323 }
324 let has_tx = bytes[pos] == 1;
325 pos += 1;
326
327 let pending_transaction = if has_tx {
328 let tx_len = read_u32(&mut pos)? as usize;
329 if pos + tx_len > bytes.len() {
330 return Err(SinkError::CheckpointError(
331 "Invalid transaction length".to_string(),
332 ));
333 }
334 let tx = TransactionId::from_bytes(&bytes[pos..pos + tx_len]).ok_or_else(|| {
335 SinkError::CheckpointError("Invalid transaction data".to_string())
336 })?;
337 pos += tx_len;
338 Some(tx)
339 } else {
340 None
341 };
342
343 let num_offsets = read_u32(&mut pos)?;
345 let mut offsets = HashMap::new();
346 for _ in 0..num_offsets {
347 let partition_len = read_u32(&mut pos)? as usize;
348 if pos + partition_len > bytes.len() {
349 return Err(SinkError::CheckpointError(
350 "Invalid partition length".to_string(),
351 ));
352 }
353 let partition = String::from_utf8_lossy(&bytes[pos..pos + partition_len]).to_string();
354 pos += partition_len;
355
356 let offset_len = read_u32(&mut pos)? as usize;
357 if pos + offset_len > bytes.len() {
358 return Err(SinkError::CheckpointError(
359 "Invalid offset length".to_string(),
360 ));
361 }
362 let (offset, _) = SinkOffset::from_bytes(&bytes[pos..pos + offset_len])
363 .ok_or_else(|| SinkError::CheckpointError("Invalid offset data".to_string()))?;
364 pos += offset_len;
365
366 offsets.insert(partition, offset);
367 }
368
369 let num_metadata = read_u32(&mut pos)?;
371 let mut metadata = HashMap::new();
372 for _ in 0..num_metadata {
373 let key_len = read_u32(&mut pos)? as usize;
374 if pos + key_len > bytes.len() {
375 return Err(SinkError::CheckpointError(
376 "Invalid metadata key length".to_string(),
377 ));
378 }
379 let key = String::from_utf8_lossy(&bytes[pos..pos + key_len]).to_string();
380 pos += key_len;
381
382 let value_len = read_u32(&mut pos)? as usize;
383 if pos + value_len > bytes.len() {
384 return Err(SinkError::CheckpointError(
385 "Invalid metadata value length".to_string(),
386 ));
387 }
388 let value = bytes[pos..pos + value_len].to_vec();
389 pos += value_len;
390
391 metadata.insert(key, value);
392 }
393
394 Ok(Self {
395 sink_id,
396 offsets,
397 pending_transaction,
398 epoch,
399 timestamp,
400 metadata,
401 })
402 }
403}
404
405pub struct SinkCheckpointManager {
410 checkpoints: HashMap<String, SinkCheckpoint>,
412
413 current_epoch: u64,
415}
416
417impl SinkCheckpointManager {
418 #[must_use]
420 pub fn new() -> Self {
421 Self {
422 checkpoints: HashMap::new(),
423 current_epoch: 0,
424 }
425 }
426
427 pub fn register(&mut self, checkpoint: SinkCheckpoint) {
429 let sink_id = checkpoint.sink_id.clone();
430 self.checkpoints.insert(sink_id, checkpoint);
431 }
432
433 #[must_use]
435 pub fn get(&self, sink_id: &str) -> Option<&SinkCheckpoint> {
436 self.checkpoints.get(sink_id)
437 }
438
439 pub fn get_mut(&mut self, sink_id: &str) -> Option<&mut SinkCheckpoint> {
441 self.checkpoints.get_mut(sink_id)
442 }
443
444 pub fn advance_epoch(&mut self) -> u64 {
446 self.current_epoch += 1;
447 for checkpoint in self.checkpoints.values_mut() {
448 checkpoint.set_epoch(self.current_epoch);
449 }
450 self.current_epoch
451 }
452
453 #[must_use]
455 pub fn current_epoch(&self) -> u64 {
456 self.current_epoch
457 }
458
459 #[must_use]
461 pub fn to_bytes(&self) -> Vec<u8> {
462 let mut bytes = Vec::new();
463
464 bytes.extend_from_slice(&self.current_epoch.to_le_bytes());
466
467 bytes.extend_from_slice(&(self.checkpoints.len() as u32).to_le_bytes());
469
470 for checkpoint in self.checkpoints.values() {
472 let cp_bytes = checkpoint.to_bytes();
473 bytes.extend_from_slice(&(cp_bytes.len() as u32).to_le_bytes());
474 bytes.extend_from_slice(&cp_bytes);
475 }
476
477 bytes
478 }
479
480 #[allow(clippy::missing_panics_doc)]
490 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SinkError> {
491 if bytes.len() < 12 {
492 return Err(SinkError::CheckpointError(
493 "Checkpoint data too short".to_string(),
494 ));
495 }
496
497 let mut pos = 0;
498
499 let current_epoch = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
501 pos += 8;
502
503 let num_checkpoints = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
505 pos += 4;
506
507 let mut checkpoints = HashMap::new();
508
509 for _ in 0..num_checkpoints {
510 if pos + 4 > bytes.len() {
511 return Err(SinkError::CheckpointError(
512 "Unexpected end of data".to_string(),
513 ));
514 }
515 let cp_len = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
516 pos += 4;
517
518 if pos + cp_len > bytes.len() {
519 return Err(SinkError::CheckpointError(
520 "Invalid checkpoint length".to_string(),
521 ));
522 }
523 let checkpoint = SinkCheckpoint::from_bytes(&bytes[pos..pos + cp_len])?;
524 pos += cp_len;
525
526 checkpoints.insert(checkpoint.sink_id.clone(), checkpoint);
527 }
528
529 Ok(Self {
530 checkpoints,
531 current_epoch,
532 })
533 }
534}
535
536impl Default for SinkCheckpointManager {
537 fn default() -> Self {
538 Self::new()
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sink_offset_numeric() {
        let offset = SinkOffset::Numeric(12345);
        let bytes = offset.to_bytes();
        let (restored, _) = SinkOffset::from_bytes(&bytes).unwrap();
        assert_eq!(offset, restored);
    }

    #[test]
    fn test_sink_offset_string() {
        let offset = SinkOffset::String("offset-abc-123".to_string());
        let bytes = offset.to_bytes();
        let (restored, _) = SinkOffset::from_bytes(&bytes).unwrap();
        assert_eq!(offset, restored);
    }

    #[test]
    fn test_sink_offset_binary() {
        let offset = SinkOffset::Binary(vec![1, 2, 3, 4, 5]);
        let bytes = offset.to_bytes();
        let (restored, _) = SinkOffset::from_bytes(&bytes).unwrap();
        assert_eq!(offset, restored);
    }

    /// `from_bytes` reports exactly how many bytes one offset occupied and
    /// leaves any trailing data untouched.
    #[test]
    fn test_sink_offset_reports_consumed_length() {
        let mut bytes = SinkOffset::Numeric(7).to_bytes();
        assert_eq!(bytes.len(), 9);
        bytes.extend_from_slice(&[0xAA, 0xBB]);
        assert_eq!(
            SinkOffset::from_bytes(&bytes),
            Some((SinkOffset::Numeric(7), 9))
        );
    }

    /// Empty, truncated and unknown-tag inputs are rejected without panicking.
    #[test]
    fn test_sink_offset_rejects_malformed_input() {
        assert_eq!(SinkOffset::from_bytes(&[]), None);
        assert_eq!(SinkOffset::from_bytes(&[0, 1, 2, 3]), None);
        assert_eq!(SinkOffset::from_bytes(&[1, 0, 0]), None);
        assert_eq!(SinkOffset::from_bytes(&[1, 10, 0, 0, 0, b'a']), None);
        assert_eq!(SinkOffset::from_bytes(&[2, 0xFF, 0xFF, 0xFF, 0xFF]), None);
        assert_eq!(SinkOffset::from_bytes(&[42]), None);
    }

    #[test]
    fn test_sink_checkpoint_new() {
        let checkpoint = SinkCheckpoint::new("my-sink");
        assert_eq!(checkpoint.sink_id(), "my-sink");
        assert_eq!(checkpoint.epoch(), 0);
        assert!(checkpoint.pending_transaction_id().is_none());
    }

    #[test]
    fn test_sink_checkpoint_with_offsets() {
        let mut checkpoint = SinkCheckpoint::new("kafka-sink");
        checkpoint.set_offset("topic-0", SinkOffset::Numeric(100));
        checkpoint.set_offset("topic-1", SinkOffset::Numeric(200));

        assert_eq!(
            checkpoint.get_offset("topic-0"),
            Some(&SinkOffset::Numeric(100))
        );
        assert_eq!(
            checkpoint.get_offset("topic-1"),
            Some(&SinkOffset::Numeric(200))
        );
        assert_eq!(checkpoint.get_offset("topic-2"), None);
    }

    #[test]
    fn test_sink_checkpoint_serialization() {
        let mut checkpoint = SinkCheckpoint::new("test-sink");
        checkpoint.set_epoch(42);
        checkpoint.set_offset("partition-0", SinkOffset::Numeric(1000));
        checkpoint.set_offset("partition-1", SinkOffset::String("abc".to_string()));
        checkpoint.set_transaction_id(Some(TransactionId::new(999)));
        checkpoint.set_metadata("custom-key", b"custom-value".to_vec());

        let bytes = checkpoint.to_bytes();
        let restored = SinkCheckpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.sink_id(), "test-sink");
        assert_eq!(restored.epoch(), 42);
        assert_eq!(restored.timestamp(), checkpoint.timestamp());
        assert_eq!(
            restored.get_offset("partition-0"),
            Some(&SinkOffset::Numeric(1000))
        );
        assert_eq!(
            restored.get_offset("partition-1"),
            Some(&SinkOffset::String("abc".to_string()))
        );
        assert!(restored.pending_transaction_id().is_some());
        assert_eq!(
            restored.get_metadata("custom-key"),
            Some(b"custom-value".as_ref())
        );
    }

    /// Empty input and an unknown version byte are reported as errors.
    #[test]
    fn test_sink_checkpoint_rejects_empty_and_unknown_version() {
        assert!(SinkCheckpoint::from_bytes(&[]).is_err());
        assert!(SinkCheckpoint::from_bytes(&[99]).is_err());
    }

    /// Every strict prefix of a valid checkpoint must fail cleanly with an
    /// error, and never parse successfully or panic.
    #[test]
    fn test_sink_checkpoint_rejects_every_truncation() {
        let mut checkpoint = SinkCheckpoint::new("truncated");
        checkpoint.set_offset("p0", SinkOffset::Binary(vec![9, 8, 7]));
        checkpoint.set_transaction_id(Some(TransactionId::new(5)));
        checkpoint.set_metadata("k", b"v".to_vec());

        let bytes = checkpoint.to_bytes();
        for cut in 0..bytes.len() {
            assert!(
                SinkCheckpoint::from_bytes(&bytes[..cut]).is_err(),
                "prefix of {cut} bytes unexpectedly parsed"
            );
        }
    }

    #[test]
    fn test_checkpoint_manager() {
        let mut manager = SinkCheckpointManager::new();

        let mut cp1 = SinkCheckpoint::new("sink-1");
        cp1.set_offset("p0", SinkOffset::Numeric(100));

        let mut cp2 = SinkCheckpoint::new("sink-2");
        cp2.set_offset("p0", SinkOffset::Numeric(200));

        manager.register(cp1);
        manager.register(cp2);

        assert_eq!(manager.current_epoch(), 0);
        manager.advance_epoch();
        assert_eq!(manager.current_epoch(), 1);

        let cp = manager.get("sink-1").unwrap();
        assert_eq!(cp.epoch(), 1);
    }

    #[test]
    fn test_checkpoint_manager_serialization() {
        let mut manager = SinkCheckpointManager::new();

        let mut cp = SinkCheckpoint::new("sink-1");
        cp.set_offset("p0", SinkOffset::Numeric(100));
        manager.register(cp);

        manager.advance_epoch();
        manager.advance_epoch();

        let bytes = manager.to_bytes();
        let restored = SinkCheckpointManager::from_bytes(&bytes).unwrap();

        assert_eq!(restored.current_epoch(), 2);
        assert!(restored.get("sink-1").is_some());
    }

    /// Several sinks survive a round trip with their offsets and epochs intact.
    #[test]
    fn test_checkpoint_manager_round_trips_multiple_sinks() {
        let mut manager = SinkCheckpointManager::new();
        for (id, pos) in [("sink-a", 10u64), ("sink-b", 20), ("sink-c", 30)] {
            let mut cp = SinkCheckpoint::new(id);
            cp.set_offset("p0", SinkOffset::Numeric(pos));
            manager.register(cp);
        }
        manager.advance_epoch();

        let restored = SinkCheckpointManager::from_bytes(&manager.to_bytes()).unwrap();
        assert_eq!(restored.current_epoch(), 1);
        for (id, pos) in [("sink-a", 10u64), ("sink-b", 20), ("sink-c", 30)] {
            let cp = restored.get(id).unwrap();
            assert_eq!(cp.epoch(), 1);
            assert_eq!(cp.get_offset("p0"), Some(&SinkOffset::Numeric(pos)));
        }
    }

    /// Short headers and truncated checkpoint records are reported as errors.
    #[test]
    fn test_checkpoint_manager_rejects_short_or_truncated_data() {
        assert!(SinkCheckpointManager::from_bytes(&[0; 11]).is_err());

        let mut manager = SinkCheckpointManager::new();
        manager.register(SinkCheckpoint::new("sink-1"));
        let bytes = manager.to_bytes();
        assert!(SinkCheckpointManager::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }
}