// oxigdal_streaming/state/checkpoint.rs

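//! Checkpointing for fault tolerance.
//!
//! Defines checkpoint metadata, barriers, configuration, a storage backend trait,
//! and a coordinator that tracks in-flight and completed checkpoints.
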
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio::time::sleep;

use crate::error::{Result, StreamingError};

13/// Checkpoint metadata.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CheckpointMetadata {
16    /// Checkpoint ID
17    pub id: u64,
18
19    /// Checkpoint timestamp
20    pub timestamp: DateTime<Utc>,
21
22    /// Checkpoint size in bytes
23    pub size_bytes: usize,
24
25    /// State of operators
26    pub operator_states: HashMap<String, Vec<u8>>,
27
28    /// Success status
29    pub success: bool,
30
31    /// Duration to complete
32    pub duration: Duration,
33}
34
35/// Checkpoint barrier.
36#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
37pub struct CheckpointBarrier {
38    /// Checkpoint ID
39    pub id: u64,
40
41    /// Timestamp
42    pub timestamp: DateTime<Utc>,
43}
44
45impl CheckpointBarrier {
46    /// Create a new checkpoint barrier.
47    pub fn new(id: u64) -> Self {
48        Self {
49            id,
50            timestamp: Utc::now(),
51        }
52    }
53}
54
55/// Checkpoint configuration.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct CheckpointConfig {
58    /// Checkpoint interval
59    pub interval: Duration,
60
61    /// Minimum pause between checkpoints
62    pub min_pause: Duration,
63
64    /// Maximum concurrent checkpoints
65    pub max_concurrent: usize,
66
67    /// Enable unaligned checkpoints
68    pub unaligned: bool,
69
70    /// Checkpoint timeout
71    pub timeout: Duration,
72
73    /// Storage path
74    pub storage_path: Option<PathBuf>,
75}
76
77impl Default for CheckpointConfig {
78    fn default() -> Self {
79        Self {
80            interval: Duration::from_secs(60),
81            min_pause: Duration::from_secs(10),
82            max_concurrent: 1,
83            unaligned: false,
84            timeout: Duration::from_secs(300),
85            storage_path: None,
86        }
87    }
88}
89
90/// Checkpoint storage.
91pub trait CheckpointStorage: Send + Sync {
92    /// Store a checkpoint.
93    fn store(&self, checkpoint: &Checkpoint) -> Result<()>;
94
95    /// Load a checkpoint.
96    fn load(&self, checkpoint_id: u64) -> Result<Option<Checkpoint>>;
97
98    /// Delete a checkpoint.
99    fn delete(&self, checkpoint_id: u64) -> Result<()>;
100
101    /// List all checkpoints.
102    fn list(&self) -> Result<Vec<u64>>;
103
104    /// Get the latest checkpoint ID.
105    fn latest(&self) -> Result<Option<u64>>;
106}
107
108/// In-memory checkpoint implementation.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct Checkpoint {
111    /// Metadata
112    pub metadata: CheckpointMetadata,
113
114    /// Actual checkpoint data
115    pub data: Vec<u8>,
116}
117
118impl Checkpoint {
119    /// Create a new checkpoint.
120    pub fn new(id: u64, data: Vec<u8>) -> Self {
121        let size_bytes = data.len();
122        Self {
123            metadata: CheckpointMetadata {
124                id,
125                timestamp: Utc::now(),
126                size_bytes,
127                operator_states: HashMap::new(),
128                success: true,
129                duration: Duration::ZERO,
130            },
131            data,
132        }
133    }
134
135    /// Get the checkpoint ID.
136    pub fn id(&self) -> u64 {
137        self.metadata.id
138    }
139
140    /// Get the checkpoint size.
141    pub fn size(&self) -> usize {
142        self.metadata.size_bytes
143    }
144}
145
146/// Checkpoint coordinator.
147pub struct CheckpointCoordinator {
148    config: CheckpointConfig,
149    next_checkpoint_id: Arc<RwLock<u64>>,
150    active_checkpoints: Arc<RwLock<HashMap<u64, CheckpointMetadata>>>,
151    completed_checkpoints: Arc<RwLock<Vec<u64>>>,
152    last_checkpoint_time: Arc<RwLock<Option<DateTime<Utc>>>>,
153}
154
155impl CheckpointCoordinator {
156    /// Create a new checkpoint coordinator.
157    pub fn new(config: CheckpointConfig) -> Self {
158        Self {
159            config,
160            next_checkpoint_id: Arc::new(RwLock::new(0)),
161            active_checkpoints: Arc::new(RwLock::new(HashMap::new())),
162            completed_checkpoints: Arc::new(RwLock::new(Vec::new())),
163            last_checkpoint_time: Arc::new(RwLock::new(None)),
164        }
165    }
166
167    /// Trigger a new checkpoint.
168    pub async fn trigger_checkpoint(&self) -> Result<u64> {
169        let now = Utc::now();
170        let last_time = *self.last_checkpoint_time.read().await;
171
172        if let Some(last) = last_time {
173            let min_pause_chrono = match chrono::Duration::from_std(self.config.min_pause) {
174                Ok(duration) => duration,
175                Err(_) => chrono::Duration::zero(),
176            };
177
178            if now - last < min_pause_chrono {
179                return Err(StreamingError::CheckpointError(
180                    "Minimum pause not elapsed".to_string(),
181                ));
182            }
183        }
184
185        let active_count = self.active_checkpoints.read().await.len();
186        if active_count >= self.config.max_concurrent {
187            return Err(StreamingError::CheckpointError(
188                "Too many concurrent checkpoints".to_string(),
189            ));
190        }
191
192        let mut next_id = self.next_checkpoint_id.write().await;
193        let checkpoint_id = *next_id;
194        *next_id += 1;
195
196        let metadata = CheckpointMetadata {
197            id: checkpoint_id,
198            timestamp: now,
199            size_bytes: 0,
200            operator_states: HashMap::new(),
201            success: false,
202            duration: Duration::ZERO,
203        };
204
205        self.active_checkpoints
206            .write()
207            .await
208            .insert(checkpoint_id, metadata);
209
210        *self.last_checkpoint_time.write().await = Some(now);
211
212        Ok(checkpoint_id)
213    }
214
215    /// Complete a checkpoint.
216    pub async fn complete_checkpoint(&self, checkpoint_id: u64, success: bool) -> Result<()> {
217        let mut active = self.active_checkpoints.write().await;
218
219        if let Some(mut metadata) = active.remove(&checkpoint_id) {
220            metadata.success = success;
221            metadata.duration = match (Utc::now() - metadata.timestamp).to_std() {
222                Ok(duration) => duration,
223                Err(_) => Duration::ZERO,
224            };
225
226            if success {
227                self.completed_checkpoints.write().await.push(checkpoint_id);
228            }
229
230            Ok(())
231        } else {
232            Err(StreamingError::CheckpointError(format!(
233                "Checkpoint {} not found",
234                checkpoint_id
235            )))
236        }
237    }
238
239    /// Get active checkpoint count.
240    pub async fn active_count(&self) -> usize {
241        self.active_checkpoints.read().await.len()
242    }
243
244    /// Get completed checkpoint count.
245    pub async fn completed_count(&self) -> usize {
246        self.completed_checkpoints.read().await.len()
247    }
248
249    /// Get the latest completed checkpoint ID.
250    pub async fn latest_checkpoint(&self) -> Option<u64> {
251        self.completed_checkpoints.read().await.last().copied()
252    }
253
254    /// Clear old checkpoints.
255    pub async fn clear_old_checkpoints(&self, keep_count: usize) {
256        let mut completed = self.completed_checkpoints.write().await;
257
258        if completed.len() > keep_count {
259            let to_remove = completed.len() - keep_count;
260            completed.drain(0..to_remove);
261        }
262    }
263
264    /// Start periodic checkpointing.
265    pub async fn start_periodic_checkpointing(self: Arc<Self>) {
266        let interval = self.config.interval;
267
268        tokio::spawn(async move {
269            loop {
270                sleep(interval).await;
271
272                match self.trigger_checkpoint().await {
273                    Ok(id) => {
274                        tracing::info!("Triggered checkpoint {}", id);
275
276                        tokio::spawn({
277                            let coordinator = self.clone();
278                            async move {
279                                sleep(Duration::from_secs(1)).await;
280                                if let Err(e) = coordinator.complete_checkpoint(id, true).await {
281                                    tracing::error!("Failed to complete checkpoint {}: {}", id, e);
282                                }
283                            }
284                        });
285                    }
286                    Err(e) => {
287                        tracing::warn!("Failed to trigger checkpoint: {}", e);
288                    }
289                }
290            }
291        });
292    }
293}
294
295#[cfg(test)]
296mod tests {
297    use super::*;
298
299    #[tokio::test]
300    async fn test_checkpoint_creation() {
301        let data = vec![1, 2, 3, 4];
302        let checkpoint = Checkpoint::new(1, data.clone());
303
304        assert_eq!(checkpoint.id(), 1);
305        assert_eq!(checkpoint.size(), 4);
306        assert_eq!(checkpoint.data, data);
307    }
308
309    #[tokio::test]
310    async fn test_checkpoint_barrier() {
311        let barrier = CheckpointBarrier::new(1);
312        assert_eq!(barrier.id, 1);
313    }
314
315    #[tokio::test]
316    async fn test_checkpoint_coordinator() {
317        let config = CheckpointConfig {
318            min_pause: Duration::ZERO, // Allow immediate consecutive checkpoints
319            max_concurrent: 2,         // Allow 2 concurrent checkpoints
320            ..Default::default()
321        };
322        let coordinator = CheckpointCoordinator::new(config);
323
324        let id1 = coordinator
325            .trigger_checkpoint()
326            .await
327            .expect("First checkpoint trigger should succeed");
328        assert_eq!(id1, 0);
329
330        let id2 = coordinator
331            .trigger_checkpoint()
332            .await
333            .expect("Second checkpoint trigger should succeed");
334        assert_eq!(id2, 1);
335
336        assert_eq!(coordinator.active_count().await, 2);
337
338        coordinator
339            .complete_checkpoint(id1, true)
340            .await
341            .expect("Checkpoint completion should succeed");
342        assert_eq!(coordinator.active_count().await, 1);
343        assert_eq!(coordinator.completed_count().await, 1);
344    }
345
346    #[tokio::test]
347    async fn test_checkpoint_min_pause() {
348        let config = CheckpointConfig {
349            min_pause: Duration::from_secs(60),
350            ..Default::default()
351        };
352
353        let coordinator = CheckpointCoordinator::new(config);
354
355        coordinator
356            .trigger_checkpoint()
357            .await
358            .expect("First checkpoint should trigger successfully");
359        let result = coordinator.trigger_checkpoint().await;
360
361        assert!(result.is_err());
362    }
363
364    #[tokio::test]
365    async fn test_clear_old_checkpoints() {
366        let config = CheckpointConfig {
367            min_pause: Duration::ZERO, // Allow rapid consecutive checkpoints
368            ..Default::default()
369        };
370        let coordinator = CheckpointCoordinator::new(config);
371
372        for _ in 0..5 {
373            let id = coordinator
374                .trigger_checkpoint()
375                .await
376                .expect("Checkpoint trigger should succeed in loop");
377            coordinator
378                .complete_checkpoint(id, true)
379                .await
380                .expect("Checkpoint completion should succeed in loop");
381        }
382
383        assert_eq!(coordinator.completed_count().await, 5);
384
385        coordinator.clear_old_checkpoints(2).await;
386        assert_eq!(coordinator.completed_count().await, 2);
387    }
388}