use std::collections::HashMap;
use std::time::SystemTime;
use crate::error::StreamingError;
/// Uniquely identifies a checkpoint: which stream it belongs to, its position
/// within that stream, and when it was created.
///
/// The derived ordering compares `stream_id` first, then `sequence`, then
/// `created_at` (field declaration order).
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct CheckpointId {
    /// Stream this checkpoint belongs to.
    pub stream_id: String,
    /// Position of the checkpoint within the stream.
    pub sequence: u64,
    /// Wall-clock creation time, stamped by [`CheckpointId::new`].
    pub created_at: SystemTime,
}

impl CheckpointId {
    /// Builds an id for `stream_id` at `sequence`, stamping the current
    /// wall-clock time as `created_at`.
    pub fn new(stream_id: impl Into<String>, sequence: u64) -> Self {
        let stream_id = stream_id.into();
        let created_at = SystemTime::now();
        Self {
            stream_id,
            sequence,
            created_at,
        }
    }
}
/// A point-in-time snapshot of a stream's processing state, carrying
/// everything needed to resume from this checkpoint after a failure.
#[derive(Debug, Clone)]
pub struct CheckpointState {
    // Identity of this checkpoint (stream, sequence, creation time).
    pub id: CheckpointId,
    // Opaque serialized state blob, keyed by operator name.
    pub operator_states: HashMap<String, Vec<u8>>,
    // Replay offset, keyed by source name.
    pub source_offsets: HashMap<String, u64>,
    // Event-time watermark in nanoseconds at the time of the checkpoint.
    pub watermark_ns: u64,
    // Number of events processed up to this checkpoint.
    pub event_count: u64,
    // Free-form key/value annotations attached to the checkpoint.
    pub metadata: HashMap<String, String>,
}
impl CheckpointState {
    /// Creates an empty checkpoint for `id`: no operator states, no source
    /// offsets, no metadata, zero watermark and event count.
    pub fn new(id: CheckpointId) -> Self {
        Self {
            id,
            operator_states: HashMap::new(),
            source_offsets: HashMap::new(),
            watermark_ns: 0,
            event_count: 0,
            metadata: HashMap::new(),
        }
    }

    /// Stores the serialized state blob for `operator`, replacing any earlier blob.
    pub fn set_operator_state(&mut self, operator: impl Into<String>, state: Vec<u8>) {
        self.operator_states.insert(operator.into(), state);
    }

    /// Stores the replay offset for `source`, replacing any earlier offset.
    pub fn set_source_offset(&mut self, source: impl Into<String>, offset: u64) {
        self.source_offsets.insert(source.into(), offset);
    }

    /// Encodes this state into a little-endian binary layout:
    ///
    /// * 24-byte header: `sequence`, `watermark_ns`, `event_count` (u64 each)
    /// * operator-state section: u32 count, then per entry a u32-length-prefixed
    ///   name followed by a u32-length-prefixed state blob
    /// * source-offset section: u32 count, then per entry a u32-length-prefixed
    ///   name followed by a u64 offset
    /// * metadata section: u32 count, then per entry a u32-length-prefixed key
    ///   followed by a u32-length-prefixed value
    ///
    /// `id.stream_id` and `id.created_at` are intentionally not encoded; the
    /// caller re-supplies the stream id to [`Self::deserialize`].
    pub fn serialize(&self) -> Vec<u8> {
        // Exact size estimate so the buffer is allocated once.
        let cap = 24
            + 12
            + self
                .operator_states
                .iter()
                .map(|(k, v)| 8 + k.len() + v.len())
                .sum::<usize>()
            + self
                .source_offsets
                .keys()
                .map(|k| 12 + k.len())
                .sum::<usize>()
            + self
                .metadata
                .iter()
                .map(|(k, v)| 8 + k.len() + v.len())
                .sum::<usize>();
        let mut buf = Vec::with_capacity(cap);
        buf.extend_from_slice(&self.id.sequence.to_le_bytes());
        buf.extend_from_slice(&self.watermark_ns.to_le_bytes());
        buf.extend_from_slice(&self.event_count.to_le_bytes());
        buf.extend_from_slice(&(self.operator_states.len() as u32).to_le_bytes());
        for (name, state) in &self.operator_states {
            Self::write_bytes(&mut buf, name.as_bytes());
            Self::write_bytes(&mut buf, state);
        }
        buf.extend_from_slice(&(self.source_offsets.len() as u32).to_le_bytes());
        for (name, offset) in &self.source_offsets {
            Self::write_bytes(&mut buf, name.as_bytes());
            buf.extend_from_slice(&offset.to_le_bytes());
        }
        // Metadata trails the original layout so payloads written before this
        // section existed still decode (see `deserialize`). Previously the
        // metadata map was silently dropped on serialization.
        buf.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes());
        for (key, value) in &self.metadata {
            Self::write_bytes(&mut buf, key.as_bytes());
            Self::write_bytes(&mut buf, value.as_bytes());
        }
        buf
    }

    /// Appends a u32 little-endian length prefix followed by `bytes`.
    fn write_bytes(buf: &mut Vec<u8>, bytes: &[u8]) {
        buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(bytes);
    }

    /// Decodes a checkpoint previously produced by [`Self::serialize`].
    ///
    /// `stream_id` is supplied by the caller because it is not stored in the
    /// payload. The operator-state and source-offset sections are mandatory
    /// (both are always written by `serialize`); a payload truncated before
    /// either section is rejected rather than silently accepted. The trailing
    /// metadata section is optional for backward compatibility with payloads
    /// written before it existed.
    ///
    /// # Errors
    /// Returns [`StreamingError::DeserializationError`] on truncated data or
    /// invalid UTF-8 in any name, key, or value.
    pub fn deserialize(stream_id: &str, data: &[u8]) -> Result<Self, StreamingError> {
        const HEADER: usize = 24;
        if data.len() < HEADER {
            return Err(StreamingError::DeserializationError(
                "checkpoint data too short for header".into(),
            ));
        }
        let sequence = Self::read_u64(data, 0)?;
        let watermark_ns = Self::read_u64(data, 8)?;
        let event_count = Self::read_u64(data, 16)?;
        let id = CheckpointId::new(stream_id, sequence);
        let mut state = Self::new(id);
        state.watermark_ns = watermark_ns;
        state.event_count = event_count;
        let mut cursor = HEADER;
        let n_ops = Self::read_u32(data, cursor)? as usize;
        cursor += 4;
        for _ in 0..n_ops {
            let (name, advance) = Self::read_string(data, cursor)?;
            cursor += advance;
            let state_len = Self::read_u32(data, cursor)? as usize;
            cursor += 4;
            // checked_add guards against usize overflow on adversarial
            // length prefixes (plain `cursor + state_len` could wrap).
            let end = cursor
                .checked_add(state_len)
                .filter(|&e| e <= data.len())
                .ok_or_else(|| {
                    StreamingError::DeserializationError(
                        "truncated operator state bytes".into(),
                    )
                })?;
            state.operator_states.insert(name, data[cursor..end].to_vec());
            cursor = end;
        }
        // `serialize` always writes the source-offset count; treating its
        // absence as success would silently accept truncated checkpoints.
        let n_src = Self::read_u32(data, cursor)? as usize;
        cursor += 4;
        for _ in 0..n_src {
            let (name, advance) = Self::read_string(data, cursor)?;
            cursor += advance;
            let offset = Self::read_u64(data, cursor)?;
            cursor += 8;
            state.source_offsets.insert(name, offset);
        }
        // Optional trailing metadata section (absent in older payloads).
        if cursor < data.len() {
            let n_meta = Self::read_u32(data, cursor)? as usize;
            cursor += 4;
            for _ in 0..n_meta {
                let (key, key_adv) = Self::read_string(data, cursor)?;
                cursor += key_adv;
                let (value, val_adv) = Self::read_string(data, cursor)?;
                cursor += val_adv;
                state.metadata.insert(key, value);
            }
        }
        Ok(state)
    }

    /// Reads a little-endian u64 at `offset`, erroring on out-of-bounds.
    fn read_u64(data: &[u8], offset: usize) -> Result<u64, StreamingError> {
        offset
            .checked_add(8)
            .and_then(|end| data.get(offset..end))
            .and_then(|b| b.try_into().ok())
            .map(u64::from_le_bytes)
            .ok_or_else(|| {
                StreamingError::DeserializationError(format!("cannot read u64 at offset {offset}"))
            })
    }

    /// Reads a little-endian u32 at `offset`, erroring on out-of-bounds.
    fn read_u32(data: &[u8], offset: usize) -> Result<u32, StreamingError> {
        offset
            .checked_add(4)
            .and_then(|end| data.get(offset..end))
            .and_then(|b| b.try_into().ok())
            .map(u32::from_le_bytes)
            .ok_or_else(|| {
                StreamingError::DeserializationError(format!("cannot read u32 at offset {offset}"))
            })
    }

    /// Reads a u32-length-prefixed UTF-8 string at `cursor`; returns the
    /// string and the total number of bytes consumed (prefix + payload).
    fn read_string(data: &[u8], cursor: usize) -> Result<(String, usize), StreamingError> {
        let name_len = Self::read_u32(data, cursor)? as usize;
        let name_start = cursor + 4;
        // checked_add guards against usize overflow on adversarial lengths.
        let name_end = name_start
            .checked_add(name_len)
            .filter(|&e| e <= data.len())
            .ok_or_else(|| {
                StreamingError::DeserializationError("truncated string bytes".into())
            })?;
        let name = String::from_utf8(data[name_start..name_end].to_vec()).map_err(|e| {
            StreamingError::DeserializationError(format!("invalid UTF-8 in field name: {e}"))
        })?;
        Ok((name, 4 + name_len))
    }
}
/// Bounded, in-memory store of checkpoints, kept sorted per stream.
pub struct InMemoryCheckpointStore {
    // Checkpoints per stream id, always sorted ascending by `id.sequence`.
    checkpoints: HashMap<String, Vec<CheckpointState>>,
    // Maximum retained checkpoints per stream; oldest are evicted first.
    max_per_stream: usize,
}

impl InMemoryCheckpointStore {
    /// Creates a store retaining at most `max_per_stream` checkpoints per stream.
    ///
    /// # Panics
    /// Panics if `max_per_stream` is zero.
    pub fn new(max_per_stream: usize) -> Self {
        assert!(max_per_stream > 0, "max_per_stream must be at least 1");
        Self {
            checkpoints: HashMap::new(),
            max_per_stream,
        }
    }

    /// Inserts `state` into its stream's sorted list, then evicts the oldest
    /// entries beyond `max_per_stream`.
    ///
    /// # Errors
    /// Currently infallible; `Result` is kept for interface compatibility
    /// with fallible store implementations.
    pub fn save(&mut self, state: CheckpointState) -> Result<(), StreamingError> {
        let stream_id = state.id.stream_id.clone();
        let entry = self.checkpoints.entry(stream_id).or_default();
        // Binary-search insertion keeps the vector sorted in O(log n + n)
        // instead of re-sorting the whole list (O(n log n)) on every save.
        // `<=` places a duplicate sequence after existing equals, matching
        // the placement the previous stable sort produced.
        let pos = entry.partition_point(|s| s.id.sequence <= state.id.sequence);
        entry.insert(pos, state);
        if entry.len() > self.max_per_stream {
            let excess = entry.len() - self.max_per_stream;
            entry.drain(0..excess);
        }
        Ok(())
    }

    /// Returns the checkpoint with the highest sequence for `stream_id`, if any.
    pub fn latest(&self, stream_id: &str) -> Option<&CheckpointState> {
        self.checkpoints.get(stream_id)?.last()
    }

    /// Lists all retained checkpoints for `stream_id`, ascending by sequence.
    pub fn list(&self, stream_id: &str) -> Vec<&CheckpointState> {
        self.checkpoints
            .get(stream_id)
            .map(|v| v.iter().collect())
            .unwrap_or_default()
    }

    /// Removes every checkpoint of `stream_id` with a sequence below `sequence`.
    pub fn delete_before(&mut self, stream_id: &str, sequence: u64) {
        if let Some(entry) = self.checkpoints.get_mut(stream_id) {
            entry.retain(|s| s.id.sequence >= sequence);
        }
    }

    /// Number of checkpoints currently retained for `stream_id`.
    pub fn checkpoint_count(&self, stream_id: &str) -> usize {
        self.checkpoints
            .get(stream_id)
            .map(|v| v.len())
            .unwrap_or(0)
    }
}
/// Drives periodic checkpointing on top of an [`InMemoryCheckpointStore`].
pub struct CheckpointManager {
    // Backing store that receives triggered checkpoints.
    store: InMemoryCheckpointStore,
    // Number of sequence units between checkpoints.
    checkpoint_interval: u64,
    // Sequence at which the next checkpoint fires.
    // NOTE(review): this threshold is shared across all stream ids passed to
    // `on_event`; interleaving several streams through one manager would skew
    // per-stream checkpoint spacing — confirm single-stream usage.
    next_checkpoint_at: u64,
    // Running count of checkpoints written by this manager.
    total_checkpoints: u64,
}

impl CheckpointManager {
    /// Creates a manager that checkpoints every `checkpoint_interval`
    /// sequence units, starting at `checkpoint_interval`.
    ///
    /// # Panics
    /// Panics if `checkpoint_interval` is zero.
    pub fn new(store: InMemoryCheckpointStore, checkpoint_interval: u64) -> Self {
        assert!(
            checkpoint_interval > 0,
            "checkpoint_interval must be positive"
        );
        Self {
            store,
            checkpoint_interval,
            next_checkpoint_at: checkpoint_interval,
            total_checkpoints: 0,
        }
    }

    /// Observes one event. When `sequence` has reached the pending threshold,
    /// persists a checkpoint (carrying `watermark_ns` and `sequence` as the
    /// event count), advances the threshold, and returns `Ok(true)`;
    /// otherwise returns `Ok(false)`.
    ///
    /// # Errors
    /// Propagates any error from the backing store's `save`.
    pub fn on_event(
        &mut self,
        stream_id: &str,
        sequence: u64,
        watermark_ns: u64,
    ) -> Result<bool, StreamingError> {
        if sequence < self.next_checkpoint_at {
            return Ok(false);
        }
        let mut snapshot = CheckpointState::new(CheckpointId::new(stream_id, sequence));
        snapshot.watermark_ns = watermark_ns;
        snapshot.event_count = sequence;
        self.store.save(snapshot)?;
        self.next_checkpoint_at = sequence + self.checkpoint_interval;
        self.total_checkpoints += 1;
        Ok(true)
    }

    /// Returns the last checkpointed sequence for `stream_id`, if any.
    pub fn recover(&self, stream_id: &str) -> Option<u64> {
        self.store.latest(stream_id).map(|s| s.id.sequence)
    }

    /// Total number of checkpoints written since construction.
    pub fn total_checkpoints(&self) -> u64 {
        self.total_checkpoints
    }

    /// Read-only access to the backing store.
    pub fn store(&self) -> &InMemoryCheckpointStore {
        &self.store
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Header-only round trip: counts, watermark, and empty sections survive.
    #[test]
    fn test_serialize_deserialize_round_trip_empty() {
        let id = CheckpointId::new("stream-a", 42);
        let mut state = CheckpointState::new(id);
        state.watermark_ns = 999_000_000;
        state.event_count = 42;
        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("stream-a", &bytes)
            .expect("deserialization should succeed");
        assert_eq!(decoded.id.sequence, 42);
        assert_eq!(decoded.watermark_ns, 999_000_000);
        assert_eq!(decoded.event_count, 42);
        assert!(decoded.operator_states.is_empty());
        assert!(decoded.source_offsets.is_empty());
    }

    // Operator-state blobs round-trip byte-for-byte, keyed by name.
    #[test]
    fn test_serialize_deserialize_with_operator_states() {
        let id = CheckpointId::new("s", 1);
        let mut state = CheckpointState::new(id);
        state.set_operator_state("agg_op", vec![1, 2, 3, 4]);
        state.set_operator_state("filter_op", vec![9, 8]);
        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");
        assert_eq!(
            decoded.operator_states.get("agg_op"),
            Some(&vec![1, 2, 3, 4])
        );
        assert_eq!(decoded.operator_states.get("filter_op"), Some(&vec![9, 8]));
    }

    // Source offsets round-trip, keyed by source name.
    #[test]
    fn test_serialize_deserialize_with_source_offsets() {
        let id = CheckpointId::new("s", 7);
        let mut state = CheckpointState::new(id);
        state.set_source_offset("kafka-topic-0", 1_234_567);
        state.set_source_offset("file-source", 4_096);
        let bytes = state.serialize();
        let decoded = CheckpointState::deserialize("s", &bytes).expect("should succeed");
        assert_eq!(
            decoded.source_offsets.get("kafka-topic-0"),
            Some(&1_234_567)
        );
        assert_eq!(decoded.source_offsets.get("file-source"), Some(&4_096));
    }

    // 10 bytes is shorter than the 24-byte header, so decoding must fail.
    #[test]
    fn test_deserialize_truncated_data_returns_error() {
        let result = CheckpointState::deserialize("s", &[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_empty_slice_returns_error() {
        let result = CheckpointState::deserialize("s", &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_store_save_and_latest() {
        let mut store = InMemoryCheckpointStore::new(5);
        let id = CheckpointId::new("stream-x", 10);
        let state = CheckpointState::new(id);
        store.save(state).expect("save should succeed");
        let latest = store.latest("stream-x").expect("should be present");
        assert_eq!(latest.id.sequence, 10);
    }

    #[test]
    fn test_store_latest_none_when_empty() {
        let store = InMemoryCheckpointStore::new(5);
        assert!(store.latest("unknown").is_none());
    }

    // Saving 6 checkpoints with a cap of 3 keeps only the newest 3.
    #[test]
    fn test_store_trims_to_max_per_stream() {
        let mut store = InMemoryCheckpointStore::new(3);
        for i in 0u64..6 {
            let id = CheckpointId::new("s", i);
            store.save(CheckpointState::new(id)).expect("save ok");
        }
        assert_eq!(store.checkpoint_count("s"), 3);
        assert_eq!(
            store
                .latest("s")
                .expect("latest checkpoint for stream 's'")
                .id
                .sequence,
            5
        );
    }

    // delete_before removes strictly older sequences and keeps the rest.
    #[test]
    fn test_store_delete_before() {
        let mut store = InMemoryCheckpointStore::new(10);
        for i in 0u64..5 {
            let id = CheckpointId::new("s", i * 10);
            store.save(CheckpointState::new(id)).expect("save ok");
        }
        store.delete_before("s", 20);
        let remaining = store.list("s");
        assert!(remaining.iter().all(|c| c.id.sequence >= 20));
    }

    // Counts and latest sequences are tracked per stream, not globally.
    #[test]
    fn test_store_multiple_streams_independent() {
        let mut store = InMemoryCheckpointStore::new(5);
        for seq in [1u64, 2, 3] {
            store
                .save(CheckpointState::new(CheckpointId::new("stream-a", seq)))
                .expect("ok");
            store
                .save(CheckpointState::new(CheckpointId::new(
                    "stream-b",
                    seq * 10,
                )))
                .expect("ok");
        }
        assert_eq!(store.checkpoint_count("stream-a"), 3);
        assert_eq!(store.checkpoint_count("stream-b"), 3);
        assert_eq!(
            store
                .latest("stream-a")
                .expect("latest checkpoint for stream-a")
                .id
                .sequence,
            3
        );
        assert_eq!(
            store
                .latest("stream-b")
                .expect("latest checkpoint for stream-b")
                .id
                .sequence,
            30
        );
    }

    // With interval 100, sequences 0..=99 must not trigger; 100 must.
    #[test]
    fn test_manager_triggers_checkpoint_at_interval() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 100);
        for seq in 0u64..99 {
            let triggered = mgr.on_event("s", seq, 0).expect("on_event ok");
            assert!(!triggered);
        }
        let triggered = mgr.on_event("s", 100, 0).expect("on_event ok");
        assert!(triggered);
        assert_eq!(mgr.total_checkpoints(), 1);
    }

    #[test]
    fn test_manager_recover_returns_last_sequence() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 50);
        mgr.on_event("s", 50, 0).expect("ok");
        mgr.on_event("s", 100, 0).expect("ok");
        let seq = mgr.recover("s").expect("should recover");
        assert_eq!(seq, 100);
    }

    #[test]
    fn test_manager_recover_none_before_first_checkpoint() {
        let store = InMemoryCheckpointStore::new(5);
        let mgr = CheckpointManager::new(store, 100);
        assert!(mgr.recover("s").is_none());
    }

    // Interval 10 over sequences 0..=50 fires at 10, 20, 30, 40, 50 = 5 total.
    #[test]
    fn test_manager_total_checkpoints_counter() {
        let store = InMemoryCheckpointStore::new(10);
        let mut mgr = CheckpointManager::new(store, 10);
        // Plain inclusive range; the original `.step_by(1)` was a no-op.
        for seq in 0u64..=50 {
            mgr.on_event("s", seq, 0).expect("ok");
        }
        assert_eq!(mgr.total_checkpoints(), 5);
    }

    // Full round trip with operator state and a source offset populated.
    // (Metadata is intentionally not asserted on the decoded state here.)
    #[test]
    fn test_checkpoint_state_full_round_trip() {
        let id = CheckpointId::new("full-test", 77);
        let mut state = CheckpointState::new(id);
        state.watermark_ns = 1_700_000_000_000_000_000;
        state.event_count = 77;
        state.set_operator_state("window_op", b"window_state_data".to_vec());
        state.set_source_offset("source-0", 8192);
        state.metadata.insert("app_version".into(), "1.2.3".into());
        let bytes = state.serialize();
        let decoded =
            CheckpointState::deserialize("full-test", &bytes).expect("round-trip should succeed");
        assert_eq!(decoded.id.sequence, 77);
        assert_eq!(decoded.watermark_ns, 1_700_000_000_000_000_000);
        assert_eq!(decoded.event_count, 77);
        assert_eq!(
            decoded.operator_states.get("window_op"),
            Some(&b"window_state_data".to_vec())
        );
        assert_eq!(decoded.source_offsets.get("source-0"), Some(&8192u64));
    }
}