oxify_model/
checkpoint.rs

//! Checkpoint configuration and storage abstraction.
//!
//! This module provides checkpoint frequency configuration and storage traits
//! for saving and restoring workflow execution state.
5
6use crate::{ExecutionCheckpoint, ExecutionId, WorkflowId};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[cfg(feature = "openapi")]
12use utoipa::ToSchema;
13
14/// Configuration for checkpoint frequency and behavior
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[cfg_attr(feature = "openapi", derive(ToSchema))]
17pub struct CheckpointConfig {
18    /// Enable automatic checkpointing
19    pub enabled: bool,
20
21    /// Checkpoint frequency strategy
22    pub frequency: CheckpointFrequency,
23
24    /// Maximum number of checkpoints to retain per execution
25    pub max_checkpoints: usize,
26
27    /// Automatically checkpoint after long-running nodes (threshold in ms)
28    pub auto_checkpoint_threshold_ms: Option<u64>,
29
30    /// Compress checkpoint data
31    pub compress: bool,
32}
33
34impl Default for CheckpointConfig {
35    fn default() -> Self {
36        Self {
37            enabled: true,
38            frequency: CheckpointFrequency::EveryNNodes(5),
39            max_checkpoints: 10,
40            auto_checkpoint_threshold_ms: Some(60000), // 1 minute
41            compress: false,
42        }
43    }
44}
45
46/// Checkpoint frequency strategies
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48#[cfg_attr(feature = "openapi", derive(ToSchema))]
49pub enum CheckpointFrequency {
50    /// Checkpoint after every N nodes complete
51    EveryNNodes(usize),
52
53    /// Checkpoint at specific time intervals (in seconds)
54    TimeInterval(u64),
55
56    /// Checkpoint before specific node types (e.g., before expensive LLM calls)
57    BeforeNodeTypes(Vec<String>),
58
59    /// Manual checkpointing only (no automatic checkpoints)
60    Manual,
61
62    /// Checkpoint at every node completion (maximum safety, higher overhead)
63    Always,
64}
65
66/// Storage abstraction for checkpoints
67pub trait CheckpointStorage: Send + Sync {
68    /// Save a checkpoint
69    fn save_checkpoint(
70        &self,
71        execution_id: ExecutionId,
72        checkpoint: &ExecutionCheckpoint,
73    ) -> Result<CheckpointId, CheckpointError>;
74
75    /// Load the latest checkpoint for an execution
76    fn load_latest_checkpoint(
77        &self,
78        execution_id: ExecutionId,
79    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError>;
80
81    /// Load a specific checkpoint by ID
82    fn load_checkpoint(
83        &self,
84        checkpoint_id: CheckpointId,
85    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError>;
86
87    /// List all checkpoints for an execution
88    fn list_checkpoints(
89        &self,
90        execution_id: ExecutionId,
91    ) -> Result<Vec<CheckpointMetadata>, CheckpointError>;
92
93    /// Delete old checkpoints beyond retention limit
94    fn prune_checkpoints(
95        &self,
96        execution_id: ExecutionId,
97        keep_count: usize,
98    ) -> Result<usize, CheckpointError>;
99
100    /// Delete all checkpoints for an execution
101    fn delete_checkpoints(&self, execution_id: ExecutionId) -> Result<usize, CheckpointError>;
102}
103
104/// Unique identifier for a checkpoint
105pub type CheckpointId = uuid::Uuid;
106
107/// Metadata about a stored checkpoint
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[cfg_attr(feature = "openapi", derive(ToSchema))]
110pub struct CheckpointMetadata {
111    /// Unique checkpoint identifier
112    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
113    pub id: CheckpointId,
114
115    /// Execution this checkpoint belongs to
116    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
117    pub execution_id: ExecutionId,
118
119    /// Workflow being executed
120    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
121    pub workflow_id: WorkflowId,
122
123    /// When the checkpoint was created
124    pub created_at: DateTime<Utc>,
125
126    /// Number of nodes completed at checkpoint time
127    pub completed_node_count: usize,
128
129    /// Size of checkpoint data in bytes
130    pub size_bytes: usize,
131
132    /// Whether the checkpoint is compressed
133    pub compressed: bool,
134}
135
136/// Errors that can occur during checkpoint operations
137#[derive(Debug, thiserror::Error)]
138pub enum CheckpointError {
139    #[error("Checkpoint not found: {0}")]
140    NotFound(CheckpointId),
141
142    #[error("Storage error: {0}")]
143    StorageError(String),
144
145    #[error("Serialization error: {0}")]
146    SerializationError(String),
147
148    #[error("Decompression error: {0}")]
149    DecompressionError(String),
150
151    #[error("Invalid checkpoint data: {0}")]
152    InvalidData(String),
153}
154
155/// In-memory checkpoint storage implementation (for testing/development)
156#[derive(Debug, Default)]
157pub struct InMemoryCheckpointStorage {
158    checkpoints: std::sync::RwLock<HashMap<CheckpointId, (ExecutionId, ExecutionCheckpoint)>>,
159    metadata: std::sync::RwLock<HashMap<CheckpointId, CheckpointMetadata>>,
160}
161
162impl InMemoryCheckpointStorage {
163    pub fn new() -> Self {
164        Self::default()
165    }
166}
167
168impl CheckpointStorage for InMemoryCheckpointStorage {
169    fn save_checkpoint(
170        &self,
171        execution_id: ExecutionId,
172        checkpoint: &ExecutionCheckpoint,
173    ) -> Result<CheckpointId, CheckpointError> {
174        let checkpoint_id = uuid::Uuid::new_v4();
175
176        // Serialize to estimate size
177        let data = serde_json::to_vec(checkpoint)
178            .map_err(|e| CheckpointError::SerializationError(e.to_string()))?;
179
180        let metadata = CheckpointMetadata {
181            id: checkpoint_id,
182            execution_id,
183            workflow_id: uuid::Uuid::new_v4(), // Would be passed from context
184            created_at: checkpoint.timestamp,
185            completed_node_count: checkpoint.completed_nodes.len(),
186            size_bytes: data.len(),
187            compressed: false,
188        };
189
190        self.checkpoints
191            .write()
192            .unwrap()
193            .insert(checkpoint_id, (execution_id, checkpoint.clone()));
194        self.metadata
195            .write()
196            .unwrap()
197            .insert(checkpoint_id, metadata);
198
199        Ok(checkpoint_id)
200    }
201
202    fn load_latest_checkpoint(
203        &self,
204        execution_id: ExecutionId,
205    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError> {
206        let checkpoints = self.checkpoints.read().unwrap();
207
208        // Find latest checkpoint for this execution
209        let latest = checkpoints
210            .iter()
211            .filter(|(_, (exec_id, _))| *exec_id == execution_id)
212            .map(|(id, (_, checkpoint))| (*id, checkpoint))
213            .max_by_key(|(_, checkpoint)| checkpoint.timestamp);
214
215        Ok(latest.map(|(_, checkpoint)| checkpoint.clone()))
216    }
217
218    fn load_checkpoint(
219        &self,
220        checkpoint_id: CheckpointId,
221    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError> {
222        let checkpoints = self.checkpoints.read().unwrap();
223        Ok(checkpoints
224            .get(&checkpoint_id)
225            .map(|(_, checkpoint)| checkpoint.clone()))
226    }
227
228    fn list_checkpoints(
229        &self,
230        execution_id: ExecutionId,
231    ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
232        let metadata = self.metadata.read().unwrap();
233        let mut list: Vec<_> = metadata
234            .values()
235            .filter(|m| m.execution_id == execution_id)
236            .cloned()
237            .collect();
238
239        // Sort by creation time (newest first)
240        list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
241
242        Ok(list)
243    }
244
245    fn prune_checkpoints(
246        &self,
247        execution_id: ExecutionId,
248        keep_count: usize,
249    ) -> Result<usize, CheckpointError> {
250        let list = self.list_checkpoints(execution_id)?;
251
252        if list.len() <= keep_count {
253            return Ok(0);
254        }
255
256        let to_delete = &list[keep_count..];
257        let mut checkpoints = self.checkpoints.write().unwrap();
258        let mut metadata = self.metadata.write().unwrap();
259
260        for meta in to_delete {
261            checkpoints.remove(&meta.id);
262            metadata.remove(&meta.id);
263        }
264
265        Ok(to_delete.len())
266    }
267
268    fn delete_checkpoints(&self, execution_id: ExecutionId) -> Result<usize, CheckpointError> {
269        let list = self.list_checkpoints(execution_id)?;
270        let mut checkpoints = self.checkpoints.write().unwrap();
271        let mut metadata = self.metadata.write().unwrap();
272
273        for meta in &list {
274            checkpoints.remove(&meta.id);
275            metadata.remove(&meta.id);
276        }
277
278        Ok(list.len())
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionState;

    /// Builds a `Running` checkpoint stamped `timestamp` with `node_count`
    /// distinct completed node ids.
    fn checkpoint_at(timestamp: DateTime<Utc>, node_count: usize) -> ExecutionCheckpoint {
        ExecutionCheckpoint {
            timestamp,
            completed_nodes: (0..node_count).map(|_| uuid::Uuid::new_v4()).collect(),
            variables: HashMap::new(),
            state: ExecutionState::Running,
        }
    }

    /// Saves `count` checkpoints for `execution_id`, oldest first and one
    /// second apart; checkpoint `i` has `i + 1` completed nodes. Returns the
    /// ids in save order. Explicit spacing keeps ordering assertions
    /// independent of the system clock's resolution.
    fn save_series(
        storage: &InMemoryCheckpointStorage,
        execution_id: ExecutionId,
        count: usize,
    ) -> Vec<CheckpointId> {
        let base = Utc::now();
        (0..count)
            .map(|i| {
                let offset = chrono::Duration::seconds(i64::try_from(i).expect("small index"));
                let checkpoint = checkpoint_at(base + offset, i + 1);
                storage
                    .save_checkpoint(execution_id, &checkpoint)
                    .expect("in-memory save succeeds")
            })
            .collect()
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(config.enabled);
        assert_eq!(config.frequency, CheckpointFrequency::EveryNNodes(5));
        assert_eq!(config.max_checkpoints, 10);
        assert_eq!(config.auto_checkpoint_threshold_ms, Some(60_000));
        assert!(!config.compress);
    }

    #[test]
    fn test_checkpoint_frequency_variants() {
        let freq1 = CheckpointFrequency::EveryNNodes(10);
        let freq2 = CheckpointFrequency::TimeInterval(60);
        let freq3 = CheckpointFrequency::Manual;
        let freq4 = CheckpointFrequency::Always;
        let freq5 = CheckpointFrequency::BeforeNodeTypes(vec!["llm".to_string()]);

        assert_eq!(freq1, CheckpointFrequency::EveryNNodes(10));
        assert_ne!(freq2, freq3);
        assert_ne!(freq3, freq4);
        assert_eq!(
            freq5,
            CheckpointFrequency::BeforeNodeTypes(vec!["llm".to_string()])
        );
    }

    #[test]
    fn test_in_memory_storage_save_load() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let checkpoint = checkpoint_at(Utc::now(), 1);

        let checkpoint_id = storage.save_checkpoint(execution_id, &checkpoint).unwrap();

        let loaded = storage
            .load_checkpoint(checkpoint_id)
            .unwrap()
            .expect("checkpoint was saved");
        assert_eq!(loaded.completed_nodes, checkpoint.completed_nodes);

        let latest = storage
            .load_latest_checkpoint(execution_id)
            .unwrap()
            .expect("execution has a checkpoint");
        assert_eq!(latest.completed_nodes, checkpoint.completed_nodes);

        let list = storage.list_checkpoints(execution_id).unwrap();
        assert_eq!(list.len(), 1);
        let meta = &list[0];
        assert_eq!(meta.id, checkpoint_id);
        assert_eq!(meta.execution_id, execution_id);
        assert_eq!(meta.created_at, checkpoint.timestamp);
        assert_eq!(meta.completed_node_count, 1);
        assert!(meta.size_bytes > 0);
        assert!(!meta.compressed);
    }

    #[test]
    fn test_unknown_ids_return_none() {
        let storage = InMemoryCheckpointStorage::new();
        assert!(storage.load_checkpoint(uuid::Uuid::new_v4()).unwrap().is_none());
        assert!(storage
            .load_latest_checkpoint(uuid::Uuid::new_v4())
            .unwrap()
            .is_none());
        assert!(storage
            .list_checkpoints(uuid::Uuid::new_v4())
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_load_latest_prefers_newest_timestamp() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let base = Utc::now();
        let newer = checkpoint_at(base + chrono::Duration::seconds(10), 2);
        let older = checkpoint_at(base, 1);

        // Saved out of chronological order: recency follows the timestamp.
        storage.save_checkpoint(execution_id, &newer).unwrap();
        storage.save_checkpoint(execution_id, &older).unwrap();

        let latest = storage
            .load_latest_checkpoint(execution_id)
            .unwrap()
            .expect("execution has checkpoints");
        assert_eq!(latest.completed_nodes, newer.completed_nodes);
    }

    #[test]
    fn test_list_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let ids = save_series(&storage, execution_id, 3);

        let list = storage.list_checkpoints(execution_id).unwrap();
        assert_eq!(list.len(), 3);

        // Newest first.
        let listed_ids: Vec<CheckpointId> = list.iter().map(|m| m.id).collect();
        assert_eq!(listed_ids, vec![ids[2], ids[1], ids[0]]);
        let counts: Vec<usize> = list.iter().map(|m| m.completed_node_count).collect();
        assert_eq!(counts, vec![3, 2, 1]);
    }

    #[test]
    fn test_prune_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let ids = save_series(&storage, execution_id, 5);

        // Prune to keep only the 2 newest.
        let deleted = storage.prune_checkpoints(execution_id, 2).unwrap();
        assert_eq!(deleted, 3);

        let remaining: Vec<CheckpointId> = storage
            .list_checkpoints(execution_id)
            .unwrap()
            .iter()
            .map(|m| m.id)
            .collect();
        assert_eq!(remaining, vec![ids[4], ids[3]]);
        assert!(storage.load_checkpoint(ids[0]).unwrap().is_none());

        // Already within the limit: nothing further to prune.
        assert_eq!(storage.prune_checkpoints(execution_id, 2).unwrap(), 0);
    }

    #[test]
    fn test_delete_all_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let ids = save_series(&storage, execution_id, 3);

        assert_eq!(storage.delete_checkpoints(execution_id).unwrap(), 3);
        assert!(storage.list_checkpoints(execution_id).unwrap().is_empty());
        assert!(storage
            .load_latest_checkpoint(execution_id)
            .unwrap()
            .is_none());
        assert!(ids
            .iter()
            .all(|id| storage.load_checkpoint(*id).unwrap().is_none()));

        // Deleting again is a no-op.
        assert_eq!(storage.delete_checkpoints(execution_id).unwrap(), 0);
    }

    #[test]
    fn test_multiple_executions() {
        let storage = InMemoryCheckpointStorage::new();
        let exec1 = uuid::Uuid::new_v4();
        let exec2 = uuid::Uuid::new_v4();

        save_series(&storage, exec1, 2);
        let exec2_ids = save_series(&storage, exec2, 2);

        // Each execution should have 2 checkpoints.
        assert_eq!(storage.list_checkpoints(exec1).unwrap().len(), 2);
        assert_eq!(storage.list_checkpoints(exec2).unwrap().len(), 2);

        // Deleting one execution's checkpoints leaves the other untouched.
        assert_eq!(storage.delete_checkpoints(exec1).unwrap(), 2);
        assert_eq!(storage.list_checkpoints(exec1).unwrap().len(), 0);
        assert_eq!(storage.list_checkpoints(exec2).unwrap().len(), 2);
        assert!(exec2_ids
            .iter()
            .all(|id| storage.load_checkpoint(*id).unwrap().is_some()));
    }
}