1use crate::error::{Result, StreamingError};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::sleep;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CheckpointMetadata {
16 pub id: u64,
18
19 pub timestamp: DateTime<Utc>,
21
22 pub size_bytes: usize,
24
25 pub operator_states: HashMap<String, Vec<u8>>,
27
28 pub success: bool,
30
31 pub duration: Duration,
33}
34
35#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
37pub struct CheckpointBarrier {
38 pub id: u64,
40
41 pub timestamp: DateTime<Utc>,
43}
44
45impl CheckpointBarrier {
46 pub fn new(id: u64) -> Self {
48 Self {
49 id,
50 timestamp: Utc::now(),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct CheckpointConfig {
58 pub interval: Duration,
60
61 pub min_pause: Duration,
63
64 pub max_concurrent: usize,
66
67 pub unaligned: bool,
69
70 pub timeout: Duration,
72
73 pub storage_path: Option<PathBuf>,
75}
76
77impl Default for CheckpointConfig {
78 fn default() -> Self {
79 Self {
80 interval: Duration::from_secs(60),
81 min_pause: Duration::from_secs(10),
82 max_concurrent: 1,
83 unaligned: false,
84 timeout: Duration::from_secs(300),
85 storage_path: None,
86 }
87 }
88}
89
90pub trait CheckpointStorage: Send + Sync {
92 fn store(&self, checkpoint: &Checkpoint) -> Result<()>;
94
95 fn load(&self, checkpoint_id: u64) -> Result<Option<Checkpoint>>;
97
98 fn delete(&self, checkpoint_id: u64) -> Result<()>;
100
101 fn list(&self) -> Result<Vec<u64>>;
103
104 fn latest(&self) -> Result<Option<u64>>;
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct Checkpoint {
111 pub metadata: CheckpointMetadata,
113
114 pub data: Vec<u8>,
116}
117
118impl Checkpoint {
119 pub fn new(id: u64, data: Vec<u8>) -> Self {
121 let size_bytes = data.len();
122 Self {
123 metadata: CheckpointMetadata {
124 id,
125 timestamp: Utc::now(),
126 size_bytes,
127 operator_states: HashMap::new(),
128 success: true,
129 duration: Duration::ZERO,
130 },
131 data,
132 }
133 }
134
135 pub fn id(&self) -> u64 {
137 self.metadata.id
138 }
139
140 pub fn size(&self) -> usize {
142 self.metadata.size_bytes
143 }
144}
145
146pub struct CheckpointCoordinator {
148 config: CheckpointConfig,
149 next_checkpoint_id: Arc<RwLock<u64>>,
150 active_checkpoints: Arc<RwLock<HashMap<u64, CheckpointMetadata>>>,
151 completed_checkpoints: Arc<RwLock<Vec<u64>>>,
152 last_checkpoint_time: Arc<RwLock<Option<DateTime<Utc>>>>,
153}
154
155impl CheckpointCoordinator {
156 pub fn new(config: CheckpointConfig) -> Self {
158 Self {
159 config,
160 next_checkpoint_id: Arc::new(RwLock::new(0)),
161 active_checkpoints: Arc::new(RwLock::new(HashMap::new())),
162 completed_checkpoints: Arc::new(RwLock::new(Vec::new())),
163 last_checkpoint_time: Arc::new(RwLock::new(None)),
164 }
165 }
166
167 pub async fn trigger_checkpoint(&self) -> Result<u64> {
169 let now = Utc::now();
170 let last_time = *self.last_checkpoint_time.read().await;
171
172 if let Some(last) = last_time {
173 let min_pause_chrono = match chrono::Duration::from_std(self.config.min_pause) {
174 Ok(duration) => duration,
175 Err(_) => chrono::Duration::zero(),
176 };
177
178 if now - last < min_pause_chrono {
179 return Err(StreamingError::CheckpointError(
180 "Minimum pause not elapsed".to_string(),
181 ));
182 }
183 }
184
185 let active_count = self.active_checkpoints.read().await.len();
186 if active_count >= self.config.max_concurrent {
187 return Err(StreamingError::CheckpointError(
188 "Too many concurrent checkpoints".to_string(),
189 ));
190 }
191
192 let mut next_id = self.next_checkpoint_id.write().await;
193 let checkpoint_id = *next_id;
194 *next_id += 1;
195
196 let metadata = CheckpointMetadata {
197 id: checkpoint_id,
198 timestamp: now,
199 size_bytes: 0,
200 operator_states: HashMap::new(),
201 success: false,
202 duration: Duration::ZERO,
203 };
204
205 self.active_checkpoints
206 .write()
207 .await
208 .insert(checkpoint_id, metadata);
209
210 *self.last_checkpoint_time.write().await = Some(now);
211
212 Ok(checkpoint_id)
213 }
214
215 pub async fn complete_checkpoint(&self, checkpoint_id: u64, success: bool) -> Result<()> {
217 let mut active = self.active_checkpoints.write().await;
218
219 if let Some(mut metadata) = active.remove(&checkpoint_id) {
220 metadata.success = success;
221 metadata.duration = match (Utc::now() - metadata.timestamp).to_std() {
222 Ok(duration) => duration,
223 Err(_) => Duration::ZERO,
224 };
225
226 if success {
227 self.completed_checkpoints.write().await.push(checkpoint_id);
228 }
229
230 Ok(())
231 } else {
232 Err(StreamingError::CheckpointError(format!(
233 "Checkpoint {} not found",
234 checkpoint_id
235 )))
236 }
237 }
238
239 pub async fn active_count(&self) -> usize {
241 self.active_checkpoints.read().await.len()
242 }
243
244 pub async fn completed_count(&self) -> usize {
246 self.completed_checkpoints.read().await.len()
247 }
248
249 pub async fn latest_checkpoint(&self) -> Option<u64> {
251 self.completed_checkpoints.read().await.last().copied()
252 }
253
254 pub async fn clear_old_checkpoints(&self, keep_count: usize) {
256 let mut completed = self.completed_checkpoints.write().await;
257
258 if completed.len() > keep_count {
259 let to_remove = completed.len() - keep_count;
260 completed.drain(0..to_remove);
261 }
262 }
263
264 pub async fn start_periodic_checkpointing(self: Arc<Self>) {
266 let interval = self.config.interval;
267
268 tokio::spawn(async move {
269 loop {
270 sleep(interval).await;
271
272 match self.trigger_checkpoint().await {
273 Ok(id) => {
274 tracing::info!("Triggered checkpoint {}", id);
275
276 tokio::spawn({
277 let coordinator = self.clone();
278 async move {
279 sleep(Duration::from_secs(1)).await;
280 if let Err(e) = coordinator.complete_checkpoint(id, true).await {
281 tracing::error!("Failed to complete checkpoint {}: {}", id, e);
282 }
283 }
284 });
285 }
286 Err(e) => {
287 tracing::warn!("Failed to trigger checkpoint: {}", e);
288 }
289 }
290 }
291 });
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a coordinator without a minimum pause so tests can trigger
    /// checkpoints back to back.
    fn coordinator_without_pause(max_concurrent: usize) -> CheckpointCoordinator {
        CheckpointCoordinator::new(CheckpointConfig {
            min_pause: Duration::ZERO,
            max_concurrent,
            ..Default::default()
        })
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_checkpoint_creation() {
        let data = vec![1, 2, 3, 4];
        let checkpoint = Checkpoint::new(1, data.clone());

        assert_eq!(checkpoint.id(), 1);
        assert_eq!(checkpoint.size(), 4);
        assert_eq!(checkpoint.data, data);
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_checkpoint_barrier() {
        let barrier = CheckpointBarrier::new(1);
        assert_eq!(barrier.id, 1);
    }

    #[tokio::test]
    async fn test_checkpoint_coordinator() {
        let coordinator = coordinator_without_pause(2);

        let id1 = coordinator
            .trigger_checkpoint()
            .await
            .expect("First checkpoint trigger should succeed");
        assert_eq!(id1, 0);

        let id2 = coordinator
            .trigger_checkpoint()
            .await
            .expect("Second checkpoint trigger should succeed");
        assert_eq!(id2, 1);

        assert_eq!(coordinator.active_count().await, 2);
        assert_eq!(coordinator.latest_checkpoint().await, None);

        coordinator
            .complete_checkpoint(id1, true)
            .await
            .expect("Checkpoint completion should succeed");
        assert_eq!(coordinator.active_count().await, 1);
        assert_eq!(coordinator.completed_count().await, 1);
        assert_eq!(coordinator.latest_checkpoint().await, Some(id1));
    }

    #[tokio::test]
    async fn test_checkpoint_min_pause() {
        let config = CheckpointConfig {
            min_pause: Duration::from_secs(60),
            ..Default::default()
        };

        let coordinator = CheckpointCoordinator::new(config);

        coordinator
            .trigger_checkpoint()
            .await
            .expect("First checkpoint should trigger successfully");
        let result = coordinator.trigger_checkpoint().await;

        assert!(result.is_err());
        // The rejected trigger must not register a second checkpoint.
        assert_eq!(coordinator.active_count().await, 1);
    }

    #[tokio::test]
    async fn test_max_concurrent_limit() {
        let coordinator = coordinator_without_pause(1);

        let first = coordinator
            .trigger_checkpoint()
            .await
            .expect("First checkpoint should trigger successfully");

        // The single slot is taken, so a second trigger is rejected.
        assert!(coordinator.trigger_checkpoint().await.is_err());
        assert_eq!(coordinator.active_count().await, 1);

        coordinator
            .complete_checkpoint(first, true)
            .await
            .expect("Checkpoint completion should succeed");

        // The slot is free again and the rejected attempt consumed no id.
        let second = coordinator
            .trigger_checkpoint()
            .await
            .expect("Trigger should succeed once the slot is free");
        assert_eq!(second, first + 1);
    }

    #[tokio::test]
    async fn test_complete_unknown_checkpoint() {
        let coordinator = coordinator_without_pause(1);

        assert!(coordinator.complete_checkpoint(42, true).await.is_err());
        assert_eq!(coordinator.completed_count().await, 0);
    }

    #[tokio::test]
    async fn test_failed_checkpoint_not_recorded() {
        let coordinator = coordinator_without_pause(1);

        let id = coordinator
            .trigger_checkpoint()
            .await
            .expect("Checkpoint trigger should succeed");
        coordinator
            .complete_checkpoint(id, false)
            .await
            .expect("Completing a known checkpoint should succeed");

        assert_eq!(coordinator.active_count().await, 0);
        assert_eq!(coordinator.completed_count().await, 0);
        assert_eq!(coordinator.latest_checkpoint().await, None);
    }

    #[tokio::test]
    async fn test_clear_old_checkpoints() {
        let coordinator = coordinator_without_pause(1);

        for _ in 0..5 {
            let id = coordinator
                .trigger_checkpoint()
                .await
                .expect("Checkpoint trigger should succeed in loop");
            coordinator
                .complete_checkpoint(id, true)
                .await
                .expect("Checkpoint completion should succeed in loop");
        }

        assert_eq!(coordinator.completed_count().await, 5);

        coordinator.clear_old_checkpoints(2).await;
        assert_eq!(coordinator.completed_count().await, 2);
        // The newest ids survive the trim.
        assert_eq!(coordinator.latest_checkpoint().await, Some(4));

        // Asking to keep more than is tracked changes nothing.
        coordinator.clear_old_checkpoints(10).await;
        assert_eq!(coordinator.completed_count().await, 2);
    }
}