1use std::collections::HashMap;
17use std::time::SystemTime;
18
19use crate::error::StreamingError;
20
/// Identifies one checkpoint of one stream.
///
/// The derived comparisons use the fields in declaration order: `stream_id`, then
/// `sequence`, then `created_at`. Because `created_at` takes part in equality and
/// hashing, two ids for the same stream and sequence built at different instants
/// are not equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct CheckpointId {
    /// Name of the stream the checkpoint belongs to.
    pub stream_id: String,
    /// Sequence number of the checkpoint within its stream.
    pub sequence: u64,
    /// Wall-clock time captured when the id was constructed.
    pub created_at: SystemTime,
}

impl CheckpointId {
    /// Builds an id for `stream_id` at `sequence`, stamped with `SystemTime::now()`.
    pub fn new(stream_id: impl Into<String>, sequence: u64) -> Self {
        let stream_id: String = stream_id.into();
        let created_at = SystemTime::now();
        CheckpointId {
            stream_id,
            sequence,
            created_at,
        }
    }
}
44
45#[derive(Debug, Clone)]
68pub struct CheckpointState {
69 pub id: CheckpointId,
71 pub operator_states: HashMap<String, Vec<u8>>,
73 pub source_offsets: HashMap<String, u64>,
75 pub watermark_ns: u64,
77 pub event_count: u64,
79 pub metadata: HashMap<String, String>,
81}
82
83impl CheckpointState {
84 pub fn new(id: CheckpointId) -> Self {
86 Self {
87 id,
88 operator_states: HashMap::new(),
89 source_offsets: HashMap::new(),
90 watermark_ns: 0,
91 event_count: 0,
92 metadata: HashMap::new(),
93 }
94 }
95
96 pub fn set_operator_state(&mut self, operator: impl Into<String>, state: Vec<u8>) {
98 self.operator_states.insert(operator.into(), state);
99 }
100
101 pub fn set_source_offset(&mut self, source: impl Into<String>, offset: u64) {
103 self.source_offsets.insert(source.into(), offset);
104 }
105
106 pub fn serialize(&self) -> Vec<u8> {
108 let mut buf = Vec::new();
109
110 buf.extend_from_slice(&self.id.sequence.to_le_bytes());
111 buf.extend_from_slice(&self.watermark_ns.to_le_bytes());
112 buf.extend_from_slice(&self.event_count.to_le_bytes());
113
114 buf.extend_from_slice(&(self.operator_states.len() as u32).to_le_bytes());
116 for (name, state) in &self.operator_states {
117 let name_bytes = name.as_bytes();
118 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
119 buf.extend_from_slice(name_bytes);
120 buf.extend_from_slice(&(state.len() as u32).to_le_bytes());
121 buf.extend_from_slice(state);
122 }
123
124 buf.extend_from_slice(&(self.source_offsets.len() as u32).to_le_bytes());
126 for (name, offset) in &self.source_offsets {
127 let name_bytes = name.as_bytes();
128 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
129 buf.extend_from_slice(name_bytes);
130 buf.extend_from_slice(&offset.to_le_bytes());
131 }
132
133 buf
134 }
135
136 pub fn deserialize(stream_id: &str, data: &[u8]) -> Result<Self, StreamingError> {
141 const HEADER: usize = 24; if data.len() < HEADER {
143 return Err(StreamingError::DeserializationError(
144 "checkpoint data too short for header".into(),
145 ));
146 }
147
148 let sequence = Self::read_u64(data, 0)?;
149 let watermark_ns = Self::read_u64(data, 8)?;
150 let event_count = Self::read_u64(data, 16)?;
151
152 let id = CheckpointId::new(stream_id, sequence);
153 let mut state = Self::new(id);
154 state.watermark_ns = watermark_ns;
155 state.event_count = event_count;
156
157 let mut cursor = HEADER;
158
159 let n_ops = Self::read_u32(data, cursor)? as usize;
161 cursor += 4;
162 for _ in 0..n_ops {
163 let (name, advance) = Self::read_string(data, cursor)?;
164 cursor += advance;
165 let state_len = Self::read_u32(data, cursor)? as usize;
166 cursor += 4;
167 if cursor + state_len > data.len() {
168 return Err(StreamingError::DeserializationError(
169 "truncated operator state bytes".into(),
170 ));
171 }
172 let op_state = data[cursor..cursor + state_len].to_vec();
173 cursor += state_len;
174 state.operator_states.insert(name, op_state);
175 }
176
177 if cursor + 4 > data.len() {
179 return Ok(state);
181 }
182 let n_src = Self::read_u32(data, cursor)? as usize;
183 cursor += 4;
184 for _ in 0..n_src {
185 let (name, advance) = Self::read_string(data, cursor)?;
186 cursor += advance;
187 let offset = Self::read_u64(data, cursor)?;
188 cursor += 8;
189 state.source_offsets.insert(name, offset);
190 }
191
192 Ok(state)
193 }
194
195 fn read_u64(data: &[u8], offset: usize) -> Result<u64, StreamingError> {
198 data.get(offset..offset + 8)
199 .and_then(|b| b.try_into().ok())
200 .map(u64::from_le_bytes)
201 .ok_or_else(|| {
202 StreamingError::DeserializationError(format!("cannot read u64 at offset {offset}"))
203 })
204 }
205
206 fn read_u32(data: &[u8], offset: usize) -> Result<u32, StreamingError> {
207 data.get(offset..offset + 4)
208 .and_then(|b| b.try_into().ok())
209 .map(u32::from_le_bytes)
210 .ok_or_else(|| {
211 StreamingError::DeserializationError(format!("cannot read u32 at offset {offset}"))
212 })
213 }
214
215 fn read_string(data: &[u8], cursor: usize) -> Result<(String, usize), StreamingError> {
220 let name_len = Self::read_u32(data, cursor)? as usize;
221 let name_start = cursor + 4;
222 let name_end = name_start + name_len;
223 if name_end > data.len() {
224 return Err(StreamingError::DeserializationError(
225 "truncated string bytes".into(),
226 ));
227 }
228 let name = String::from_utf8(data[name_start..name_end].to_vec()).map_err(|e| {
229 StreamingError::DeserializationError(format!("invalid UTF-8 in field name: {e}"))
230 })?;
231 Ok((name, 4 + name_len))
232 }
233}
234
235pub struct InMemoryCheckpointStore {
243 checkpoints: HashMap<String, Vec<CheckpointState>>,
245 max_per_stream: usize,
247}
248
249impl InMemoryCheckpointStore {
250 pub fn new(max_per_stream: usize) -> Self {
252 assert!(max_per_stream > 0, "max_per_stream must be at least 1");
253 Self {
254 checkpoints: HashMap::new(),
255 max_per_stream,
256 }
257 }
258
259 pub fn save(&mut self, state: CheckpointState) -> Result<(), StreamingError> {
261 let stream_id = state.id.stream_id.clone();
262 let entry = self.checkpoints.entry(stream_id).or_default();
263 entry.push(state);
264 entry.sort_by_key(|s| s.id.sequence);
265 if entry.len() > self.max_per_stream {
267 let excess = entry.len() - self.max_per_stream;
268 entry.drain(0..excess);
269 }
270 Ok(())
271 }
272
273 pub fn latest(&self, stream_id: &str) -> Option<&CheckpointState> {
275 self.checkpoints.get(stream_id)?.last()
276 }
277
278 pub fn list(&self, stream_id: &str) -> Vec<&CheckpointState> {
280 self.checkpoints
281 .get(stream_id)
282 .map(|v| v.iter().collect())
283 .unwrap_or_default()
284 }
285
286 pub fn delete_before(&mut self, stream_id: &str, sequence: u64) {
288 if let Some(entry) = self.checkpoints.get_mut(stream_id) {
289 entry.retain(|s| s.id.sequence >= sequence);
290 }
291 }
292
293 pub fn checkpoint_count(&self, stream_id: &str) -> usize {
295 self.checkpoints
296 .get(stream_id)
297 .map(|v| v.len())
298 .unwrap_or(0)
299 }
300}
301
302pub struct CheckpointManager {
310 store: InMemoryCheckpointStore,
311 checkpoint_interval: u64,
313 next_checkpoint_at: u64,
314 total_checkpoints: u64,
315}
316
317impl CheckpointManager {
318 pub fn new(store: InMemoryCheckpointStore, checkpoint_interval: u64) -> Self {
320 assert!(
321 checkpoint_interval > 0,
322 "checkpoint_interval must be positive"
323 );
324 Self {
325 store,
326 checkpoint_interval,
327 next_checkpoint_at: checkpoint_interval,
328 total_checkpoints: 0,
329 }
330 }
331
332 pub fn on_event(
336 &mut self,
337 stream_id: &str,
338 sequence: u64,
339 watermark_ns: u64,
340 ) -> Result<bool, StreamingError> {
341 if sequence >= self.next_checkpoint_at {
342 let id = CheckpointId::new(stream_id, sequence);
343 let mut state = CheckpointState::new(id);
344 state.watermark_ns = watermark_ns;
345 state.event_count = sequence;
346 self.store.save(state)?;
347 self.next_checkpoint_at = sequence + self.checkpoint_interval;
348 self.total_checkpoints += 1;
349 return Ok(true);
350 }
351 Ok(false)
352 }
353
354 pub fn recover(&self, stream_id: &str) -> Option<u64> {
357 self.store.latest(stream_id).map(|s| s.id.sequence)
358 }
359
360 pub fn total_checkpoints(&self) -> u64 {
362 self.total_checkpoints
363 }
364
365 pub fn store(&self) -> &InMemoryCheckpointStore {
367 &self.store
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a zeroed 24-byte checkpoint header followed by `body`, for building
    /// hand-crafted decoder inputs.
    fn header_then(body: &[u8]) -> Vec<u8> {
        let mut bytes = vec![0u8; 24];
        bytes.extend_from_slice(body);
        bytes
    }

    /// Sequence numbers currently stored for `stream_id`, in store order.
    fn sequences(store: &InMemoryCheckpointStore, stream_id: &str) -> Vec<u64> {
        store.list(stream_id).iter().map(|c| c.id.sequence).collect()
    }

    #[test]
    fn test_serialize_deserialize_round_trip_empty() {
        let id = CheckpointId::new("stream-a", 42);
        let mut state = CheckpointState::new(id);
        state.watermark_ns = 999_000_000;
        state.event_count = 42;

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("stream-a", &bytes)
            .expect("deserialization should succeed");

        assert_eq!(decoded.id.stream_id, "stream-a");
        assert_eq!(decoded.id.sequence, 42);
        assert_eq!(decoded.watermark_ns, 999_000_000);
        assert_eq!(decoded.event_count, 42);
        assert!(decoded.operator_states.is_empty());
        assert!(decoded.source_offsets.is_empty());
    }

    #[test]
    fn test_serialize_deserialize_with_operator_states() {
        let id = CheckpointId::new("s", 1);
        let mut state = CheckpointState::new(id);
        state.set_operator_state("agg_op", vec![1, 2, 3, 4]);
        state.set_operator_state("filter_op", vec![9, 8]);

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");
        assert_eq!(decoded.operator_states.len(), 2);
        assert_eq!(
            decoded.operator_states.get("agg_op"),
            Some(&vec![1, 2, 3, 4])
        );
        assert_eq!(decoded.operator_states.get("filter_op"), Some(&vec![9, 8]));
    }

    #[test]
    fn test_serialize_deserialize_with_source_offsets() {
        let id = CheckpointId::new("s", 7);
        let mut state = CheckpointState::new(id);
        state.set_source_offset("kafka-topic-0", 1_234_567);
        state.set_source_offset("file-source", 4_096);

        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");
        assert_eq!(decoded.source_offsets.len(), 2);
        assert_eq!(
            decoded.source_offsets.get("kafka-topic-0"),
            Some(&1_234_567)
        );
        assert_eq!(decoded.source_offsets.get("file-source"), Some(&4_096));
    }

    #[test]
    fn test_deserialize_truncated_data_returns_error() {
        let result = CheckpointState::deserialize("s", &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_empty_slice_returns_error() {
        let result = CheckpointState::deserialize("s", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_header_only_returns_error() {
        // The operator-state count that follows the header is mandatory.
        assert!(CheckpointState::deserialize("s", &[0u8; 24]).is_err());
    }

    #[test]
    fn test_deserialize_truncated_operator_state_returns_error() {
        let mut body = Vec::new();
        body.extend_from_slice(&1u32.to_le_bytes()); // one operator
        body.extend_from_slice(&1u32.to_le_bytes()); // name length
        body.push(b'a');
        body.extend_from_slice(&10u32.to_le_bytes()); // claims 10 state bytes...
        body.extend_from_slice(&[1, 2]); // ...but only 2 follow
        assert!(CheckpointState::deserialize("s", &header_then(&body)).is_err());
    }

    #[test]
    fn test_deserialize_invalid_utf8_name_returns_error() {
        let mut body = Vec::new();
        body.extend_from_slice(&1u32.to_le_bytes());
        body.extend_from_slice(&1u32.to_le_bytes());
        body.push(0xFF);
        body.extend_from_slice(&0u32.to_le_bytes());
        assert!(CheckpointState::deserialize("s", &header_then(&body)).is_err());
    }

    #[test]
    fn test_deserialize_missing_source_section_is_tolerated() {
        let decoded = CheckpointState::deserialize("s", &header_then(&0u32.to_le_bytes()))
            .expect("a missing source-offset section should be accepted");
        assert!(decoded.operator_states.is_empty());
        assert!(decoded.source_offsets.is_empty());
    }

    #[test]
    fn test_deserialize_truncated_source_offset_returns_error() {
        let mut state = CheckpointState::new(CheckpointId::new("s", 3));
        state.set_source_offset("src", 42);
        let bytes = state.serialize();
        assert!(CheckpointState::deserialize("s", &bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_store_save_and_latest() {
        let mut store = InMemoryCheckpointStore::new(5);
        let id = CheckpointId::new("stream-x", 10);
        let state = CheckpointState::new(id);
        store.save(state).expect("save should succeed");
        let latest = store.latest("stream-x").expect("should be present");
        assert_eq!(latest.id.sequence, 10);
    }

    #[test]
    fn test_store_latest_none_when_empty() {
        let store = InMemoryCheckpointStore::new(5);
        assert!(store.latest("unknown").is_none());
        assert!(store.list("unknown").is_empty());
        assert_eq!(store.checkpoint_count("unknown"), 0);
    }

    #[test]
    #[should_panic(expected = "max_per_stream must be at least 1")]
    fn test_store_rejects_zero_capacity() {
        let _ = InMemoryCheckpointStore::new(0);
    }

    #[test]
    fn test_store_trims_to_max_per_stream() {
        let mut store = InMemoryCheckpointStore::new(3);
        for i in 0u64..6 {
            let id = CheckpointId::new("s", i);
            store.save(CheckpointState::new(id)).expect("save ok");
        }
        assert_eq!(store.checkpoint_count("s"), 3);
        assert_eq!(sequences(&store, "s"), vec![3, 4, 5]);
        assert_eq!(
            store
                .latest("s")
                .expect("latest checkpoint for stream 's'")
                .id
                .sequence,
            5
        );
    }

    #[test]
    fn test_store_keeps_highest_sequences_when_saved_out_of_order() {
        let mut store = InMemoryCheckpointStore::new(3);
        for seq in [5u64, 1, 4, 2, 3, 6] {
            store
                .save(CheckpointState::new(CheckpointId::new("s", seq)))
                .expect("save ok");
        }
        assert_eq!(sequences(&store, "s"), vec![4, 5, 6]);
    }

    #[test]
    fn test_store_delete_before() {
        let mut store = InMemoryCheckpointStore::new(10);
        for i in 0u64..5 {
            let id = CheckpointId::new("s", i * 10);
            store.save(CheckpointState::new(id)).expect("save ok");
        }
        store.delete_before("s", 20);
        // Exact contents: an `all(..)` check alone would pass on an empty list.
        assert_eq!(sequences(&store, "s"), vec![20, 30, 40]);
        // Deleting from an unknown stream is a no-op.
        store.delete_before("unknown", 20);
    }

    #[test]
    fn test_store_multiple_streams_independent() {
        let mut store = InMemoryCheckpointStore::new(5);
        for seq in [1u64, 2, 3] {
            store
                .save(CheckpointState::new(CheckpointId::new("stream-a", seq)))
                .expect("ok");
            store
                .save(CheckpointState::new(CheckpointId::new(
                    "stream-b",
                    seq * 10,
                )))
                .expect("ok");
        }
        assert_eq!(store.checkpoint_count("stream-a"), 3);
        assert_eq!(store.checkpoint_count("stream-b"), 3);
        assert_eq!(
            store
                .latest("stream-a")
                .expect("latest checkpoint for stream-a")
                .id
                .sequence,
            3
        );
        assert_eq!(
            store
                .latest("stream-b")
                .expect("latest checkpoint for stream-b")
                .id
                .sequence,
            30
        );
    }

    #[test]
    fn test_manager_triggers_checkpoint_at_interval() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 100);
        // Every sequence strictly below the first threshold must be ignored.
        for seq in 0u64..100 {
            let triggered = mgr.on_event("s", seq, 0).expect("on_event ok");
            assert!(!triggered, "unexpected checkpoint at sequence {seq}");
        }
        let triggered = mgr.on_event("s", 100, 0).expect("on_event ok");
        assert!(triggered);
        assert_eq!(mgr.total_checkpoints(), 1);
    }

    #[test]
    fn test_manager_checkpoint_records_watermark_and_event_count() {
        let mut mgr = CheckpointManager::new(InMemoryCheckpointStore::new(4), 5);
        assert!(mgr.on_event("s", 7, 1234).expect("ok"));
        let cp = mgr.store().latest("s").expect("checkpoint stored");
        assert_eq!(cp.id.sequence, 7);
        assert_eq!(cp.watermark_ns, 1234);
        assert_eq!(cp.event_count, 7);
    }

    #[test]
    #[should_panic(expected = "checkpoint_interval must be positive")]
    fn test_manager_rejects_zero_interval() {
        let _ = CheckpointManager::new(InMemoryCheckpointStore::new(1), 0);
    }

    #[test]
    fn test_manager_recover_returns_last_sequence() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 50);
        mgr.on_event("s", 50, 0).expect("ok");
        mgr.on_event("s", 100, 0).expect("ok");
        let seq = mgr.recover("s").expect("should recover");
        assert_eq!(seq, 100);
    }

    #[test]
    fn test_manager_recover_none_before_first_checkpoint() {
        let store = InMemoryCheckpointStore::new(5);
        let mgr = CheckpointManager::new(store, 100);
        assert!(mgr.recover("s").is_none());
    }

    #[test]
    fn test_manager_total_checkpoints_counter() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 10);
        for seq in 0u64..=50 {
            mgr.on_event("s", seq, 0).expect("ok");
        }
        // Checkpoints fire at 10, 20, 30, 40 and 50.
        assert_eq!(mgr.total_checkpoints(), 5);
        assert_eq!(mgr.store().checkpoint_count("s"), 5);
    }

    #[test]
    fn test_checkpoint_state_full_round_trip() {
        let id = CheckpointId::new("full-test", 77);
        let mut state = CheckpointState::new(id);
        state.watermark_ns = 1_700_000_000_000_000_000;
        state.event_count = 77;
        state.set_operator_state("window_op", b"window_state_data".to_vec());
        state.set_source_offset("source-0", 8192);
        state.metadata.insert("app_version".into(), "1.2.3".into());

        let bytes = state.serialize();
        let decoded =
            CheckpointState::deserialize("full-test", &bytes).expect("round-trip should succeed");

        assert_eq!(decoded.id.sequence, 77);
        assert_eq!(decoded.watermark_ns, 1_700_000_000_000_000_000);
        assert_eq!(decoded.event_count, 77);
        assert_eq!(
            decoded.operator_states.get("window_op"),
            Some(&b"window_state_data".to_vec())
        );
        assert_eq!(decoded.source_offsets.get("source-0"), Some(&8192u64));
    }
}