1use crate::{ExecutionCheckpoint, ExecutionId, WorkflowId};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[cfg(feature = "openapi")]
12use utoipa::ToSchema;
13
/// Settings controlling when execution checkpoints are taken and how many are kept.
///
/// NOTE(review): nothing in this file reads these settings; how the executor
/// applies them should be confirmed at the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct CheckpointConfig {
    /// Master switch for checkpointing (`true` in the `Default` impl).
    pub enabled: bool,

    /// Policy deciding when a checkpoint is written; see [`CheckpointFrequency`].
    pub frequency: CheckpointFrequency,

    /// Upper bound on retained checkpoints — presumably passed as `keep_count`
    /// to [`CheckpointStorage::prune_checkpoints`]; TODO confirm at call sites.
    pub max_checkpoints: usize,

    /// Optional threshold in milliseconds (per the `_ms` suffix) for automatic
    /// checkpointing. NOTE(review): trigger semantics and the meaning of `None`
    /// are defined by the caller, not here — confirm.
    pub auto_checkpoint_threshold_ms: Option<u64>,

    /// Whether stored checkpoint payloads should be compressed.
    /// `InMemoryCheckpointStorage` does not compress and always records
    /// `compressed: false` in its metadata.
    pub compress: bool,
}
33
34impl Default for CheckpointConfig {
35 fn default() -> Self {
36 Self {
37 enabled: true,
38 frequency: CheckpointFrequency::EveryNNodes(5),
39 max_checkpoints: 10,
40 auto_checkpoint_threshold_ms: Some(60000), compress: false,
42 }
43 }
44}
45
/// Policy describing when a checkpoint should be written during execution.
///
/// NOTE(review): the scheduler that interprets these variants is not in this
/// file; the per-variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub enum CheckpointFrequency {
    /// Checkpoint after every N nodes (presumably N *completed* nodes).
    EveryNNodes(usize),

    /// Checkpoint on a fixed time interval.
    /// NOTE(review): the unit is not established in this file — confirm
    /// whether this is seconds or milliseconds.
    TimeInterval(u64),

    /// Checkpoint before executing a node whose type name appears in this list (presumed).
    BeforeNodeTypes(Vec<String>),

    /// Checkpoint only when explicitly requested (presumed).
    Manual,

    /// Checkpoint at every opportunity, presumably after each node.
    Always,
}
65
/// Backend for persisting and retrieving [`ExecutionCheckpoint`]s, grouped by execution.
///
/// Implementors must be `Send + Sync` so one store can be shared across threads.
pub trait CheckpointStorage: Send + Sync {
    /// Stores `checkpoint` under `execution_id` and returns the id assigned to it.
    ///
    /// # Errors
    /// Returns a [`CheckpointError`] if the checkpoint cannot be stored — e.g.
    /// [`CheckpointError::SerializationError`] when it cannot be serialized.
    fn save_checkpoint(
        &self,
        execution_id: ExecutionId,
        checkpoint: &ExecutionCheckpoint,
    ) -> Result<CheckpointId, CheckpointError>;

    /// Returns the most recent checkpoint of `execution_id`, or `Ok(None)` if it
    /// has none. (The in-memory backend picks the greatest `timestamp`.)
    fn load_latest_checkpoint(
        &self,
        execution_id: ExecutionId,
    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError>;

    /// Returns the checkpoint with `checkpoint_id`, or `Ok(None)` if no such
    /// checkpoint is stored.
    fn load_checkpoint(
        &self,
        checkpoint_id: CheckpointId,
    ) -> Result<Option<ExecutionCheckpoint>, CheckpointError>;

    /// Returns metadata for every checkpoint of `execution_id`.
    /// The in-memory backend orders the list newest first.
    fn list_checkpoints(
        &self,
        execution_id: ExecutionId,
    ) -> Result<Vec<CheckpointMetadata>, CheckpointError>;

    /// Removes all but the `keep_count` most recent checkpoints of
    /// `execution_id` and returns how many were removed.
    fn prune_checkpoints(
        &self,
        execution_id: ExecutionId,
        keep_count: usize,
    ) -> Result<usize, CheckpointError>;

    /// Removes every checkpoint of `execution_id` and returns how many were removed.
    fn delete_checkpoints(&self, execution_id: ExecutionId) -> Result<usize, CheckpointError>;
}
103
/// Identifier assigned to a stored checkpoint (a UUID).
pub type CheckpointId = uuid::Uuid;
106
/// Summary information about a stored checkpoint, as returned by
/// [`CheckpointStorage::list_checkpoints`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct CheckpointMetadata {
    /// Id of the checkpoint this metadata describes.
    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
    pub id: CheckpointId,

    /// Execution the checkpoint belongs to.
    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
    pub execution_id: ExecutionId,

    /// Workflow of that execution.
    /// NOTE(review): `InMemoryCheckpointStorage` is never given the workflow id
    /// and fills this with a placeholder value — do not rely on it there.
    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
    pub workflow_id: WorkflowId,

    /// When the checkpoint was taken; the in-memory backend copies the
    /// checkpoint's own `timestamp`.
    pub created_at: DateTime<Utc>,

    /// Number of entries in the checkpoint's `completed_nodes`.
    pub completed_node_count: usize,

    /// Size of the serialized checkpoint in bytes; the in-memory backend
    /// measures its `serde_json` encoding.
    pub size_bytes: usize,

    /// Whether the stored payload is compressed.
    pub compressed: bool,
}
135
/// Errors returned by [`CheckpointStorage`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    /// No checkpoint exists with the given id.
    /// NOTE(review): `InMemoryCheckpointStorage` reports missing checkpoints as
    /// `Ok(None)` and does not construct this variant.
    #[error("Checkpoint not found: {0}")]
    NotFound(CheckpointId),

    /// The storage backend failed; the message describes the cause.
    #[error("Storage error: {0}")]
    StorageError(String),

    /// The checkpoint could not be serialized (or, presumably, deserialized).
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// A compressed checkpoint payload could not be decompressed.
    #[error("Decompression error: {0}")]
    DecompressionError(String),

    /// Stored checkpoint data was malformed or inconsistent.
    #[error("Invalid checkpoint data: {0}")]
    InvalidData(String),
}
154
/// Process-local [`CheckpointStorage`] backed by two hash maps, each behind its
/// own `RwLock`. Nothing is persisted: the contents live only as long as this
/// value (presumably intended for tests and single-process use).
#[derive(Debug, Default)]
pub struct InMemoryCheckpointStorage {
    // Checkpoint id -> (owning execution id, full checkpoint).
    checkpoints: std::sync::RwLock<HashMap<CheckpointId, (ExecutionId, ExecutionCheckpoint)>>,
    // Checkpoint id -> summary metadata; inserted and removed together with `checkpoints`.
    metadata: std::sync::RwLock<HashMap<CheckpointId, CheckpointMetadata>>,
}
161
162impl InMemoryCheckpointStorage {
163 pub fn new() -> Self {
164 Self::default()
165 }
166}
167
168impl CheckpointStorage for InMemoryCheckpointStorage {
169 fn save_checkpoint(
170 &self,
171 execution_id: ExecutionId,
172 checkpoint: &ExecutionCheckpoint,
173 ) -> Result<CheckpointId, CheckpointError> {
174 let checkpoint_id = uuid::Uuid::new_v4();
175
176 let data = serde_json::to_vec(checkpoint)
178 .map_err(|e| CheckpointError::SerializationError(e.to_string()))?;
179
180 let metadata = CheckpointMetadata {
181 id: checkpoint_id,
182 execution_id,
183 workflow_id: uuid::Uuid::new_v4(), created_at: checkpoint.timestamp,
185 completed_node_count: checkpoint.completed_nodes.len(),
186 size_bytes: data.len(),
187 compressed: false,
188 };
189
190 self.checkpoints
191 .write()
192 .unwrap()
193 .insert(checkpoint_id, (execution_id, checkpoint.clone()));
194 self.metadata
195 .write()
196 .unwrap()
197 .insert(checkpoint_id, metadata);
198
199 Ok(checkpoint_id)
200 }
201
202 fn load_latest_checkpoint(
203 &self,
204 execution_id: ExecutionId,
205 ) -> Result<Option<ExecutionCheckpoint>, CheckpointError> {
206 let checkpoints = self.checkpoints.read().unwrap();
207
208 let latest = checkpoints
210 .iter()
211 .filter(|(_, (exec_id, _))| *exec_id == execution_id)
212 .map(|(id, (_, checkpoint))| (*id, checkpoint))
213 .max_by_key(|(_, checkpoint)| checkpoint.timestamp);
214
215 Ok(latest.map(|(_, checkpoint)| checkpoint.clone()))
216 }
217
218 fn load_checkpoint(
219 &self,
220 checkpoint_id: CheckpointId,
221 ) -> Result<Option<ExecutionCheckpoint>, CheckpointError> {
222 let checkpoints = self.checkpoints.read().unwrap();
223 Ok(checkpoints
224 .get(&checkpoint_id)
225 .map(|(_, checkpoint)| checkpoint.clone()))
226 }
227
228 fn list_checkpoints(
229 &self,
230 execution_id: ExecutionId,
231 ) -> Result<Vec<CheckpointMetadata>, CheckpointError> {
232 let metadata = self.metadata.read().unwrap();
233 let mut list: Vec<_> = metadata
234 .values()
235 .filter(|m| m.execution_id == execution_id)
236 .cloned()
237 .collect();
238
239 list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
241
242 Ok(list)
243 }
244
245 fn prune_checkpoints(
246 &self,
247 execution_id: ExecutionId,
248 keep_count: usize,
249 ) -> Result<usize, CheckpointError> {
250 let list = self.list_checkpoints(execution_id)?;
251
252 if list.len() <= keep_count {
253 return Ok(0);
254 }
255
256 let to_delete = &list[keep_count..];
257 let mut checkpoints = self.checkpoints.write().unwrap();
258 let mut metadata = self.metadata.write().unwrap();
259
260 for meta in to_delete {
261 checkpoints.remove(&meta.id);
262 metadata.remove(&meta.id);
263 }
264
265 Ok(to_delete.len())
266 }
267
268 fn delete_checkpoints(&self, execution_id: ExecutionId) -> Result<usize, CheckpointError> {
269 let list = self.list_checkpoints(execution_id)?;
270 let mut checkpoints = self.checkpoints.write().unwrap();
271 let mut metadata = self.metadata.write().unwrap();
272
273 for meta in &list {
274 checkpoints.remove(&meta.id);
275 metadata.remove(&meta.id);
276 }
277
278 Ok(list.len())
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionState;

    /// Builds a running checkpoint stamped `timestamp` with `node_count`
    /// distinct completed nodes.
    fn checkpoint_at(timestamp: DateTime<Utc>, node_count: usize) -> ExecutionCheckpoint {
        ExecutionCheckpoint {
            timestamp,
            completed_nodes: (0..node_count).map(|_| uuid::Uuid::new_v4()).collect(),
            variables: HashMap::new(),
            state: ExecutionState::Running,
        }
    }

    /// Returns `base` shifted `secs` seconds into the past, for deterministic ordering.
    fn seconds_before(base: DateTime<Utc>, secs: i64) -> DateTime<Utc> {
        base - chrono::Duration::seconds(secs)
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(config.enabled);
        assert_eq!(config.frequency, CheckpointFrequency::EveryNNodes(5));
        assert_eq!(config.max_checkpoints, 10);
        assert_eq!(config.auto_checkpoint_threshold_ms, Some(60000));
        assert!(!config.compress);
    }

    #[test]
    fn test_checkpoint_frequency_variants() {
        let freq1 = CheckpointFrequency::EveryNNodes(10);
        let freq2 = CheckpointFrequency::TimeInterval(60);
        let freq3 = CheckpointFrequency::Manual;
        let freq4 = CheckpointFrequency::Always;
        let freq5 = CheckpointFrequency::BeforeNodeTypes(vec!["http".to_string()]);

        assert_eq!(freq1, CheckpointFrequency::EveryNNodes(10));
        assert_ne!(freq2, freq3);
        assert_ne!(freq3, freq4);
        assert_ne!(freq5, CheckpointFrequency::BeforeNodeTypes(Vec::new()));
    }

    #[test]
    fn test_in_memory_storage_save_load() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let checkpoint = checkpoint_at(Utc::now(), 1);

        let checkpoint_id = storage.save_checkpoint(execution_id, &checkpoint).unwrap();

        let loaded = storage.load_checkpoint(checkpoint_id).unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().completed_nodes, checkpoint.completed_nodes);

        let latest = storage.load_latest_checkpoint(execution_id).unwrap();
        assert!(latest.is_some());
        assert_eq!(latest.unwrap().completed_nodes, checkpoint.completed_nodes);
    }

    #[test]
    fn test_load_unknown_returns_none() {
        let storage = InMemoryCheckpointStorage::new();
        assert!(storage.load_checkpoint(uuid::Uuid::new_v4()).unwrap().is_none());
        assert!(storage
            .load_latest_checkpoint(uuid::Uuid::new_v4())
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_load_latest_returns_newest() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let base = Utc::now();

        // Deliberately saved out of chronological order.
        for age in [2, 0, 1] {
            storage
                .save_checkpoint(execution_id, &checkpoint_at(seconds_before(base, age), 1))
                .unwrap();
        }

        let latest = storage.load_latest_checkpoint(execution_id).unwrap().unwrap();
        assert_eq!(latest.timestamp, base);
    }

    #[test]
    fn test_list_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let base = Utc::now();

        for i in 0..3 {
            let checkpoint = checkpoint_at(seconds_before(base, i as i64), i + 1);
            storage.save_checkpoint(execution_id, &checkpoint).unwrap();
        }

        let list = storage.list_checkpoints(execution_id).unwrap();
        assert_eq!(list.len(), 3);
        assert!(list.windows(2).all(|w| w[0].created_at >= w[1].created_at));
        assert_eq!(list[0].created_at, base);
        assert_eq!(list[0].completed_node_count, 1);
        assert_eq!(list[2].completed_node_count, 3);
        assert!(list.iter().all(|m| m.execution_id == execution_id && m.size_bytes > 0));
        assert!(list.iter().all(|m| !m.compressed));
    }

    #[test]
    fn test_prune_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();
        let base = Utc::now();

        for age in 0..5 {
            storage
                .save_checkpoint(execution_id, &checkpoint_at(seconds_before(base, age), 1))
                .unwrap();
        }

        let deleted = storage.prune_checkpoints(execution_id, 2).unwrap();
        assert_eq!(deleted, 3);

        let remaining = storage.list_checkpoints(execution_id).unwrap();
        assert_eq!(remaining.len(), 2);
        // The two newest checkpoints must be the ones kept.
        assert_eq!(remaining[0].created_at, base);
        assert_eq!(remaining[1].created_at, seconds_before(base, 1));
        for meta in &remaining {
            assert!(storage.load_checkpoint(meta.id).unwrap().is_some());
        }
    }

    #[test]
    fn test_prune_under_limit_is_noop() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();

        for _ in 0..2 {
            storage
                .save_checkpoint(execution_id, &checkpoint_at(Utc::now(), 1))
                .unwrap();
        }

        assert_eq!(storage.prune_checkpoints(execution_id, 2).unwrap(), 0);
        assert_eq!(storage.prune_checkpoints(execution_id, 5).unwrap(), 0);
        assert_eq!(storage.list_checkpoints(execution_id).unwrap().len(), 2);
    }

    #[test]
    fn test_delete_all_checkpoints() {
        let storage = InMemoryCheckpointStorage::new();
        let execution_id = uuid::Uuid::new_v4();

        let mut ids = Vec::new();
        for _ in 0..3 {
            ids.push(
                storage
                    .save_checkpoint(execution_id, &checkpoint_at(Utc::now(), 1))
                    .unwrap(),
            );
        }

        let deleted = storage.delete_checkpoints(execution_id).unwrap();
        assert_eq!(deleted, 3);

        let remaining = storage.list_checkpoints(execution_id).unwrap();
        assert_eq!(remaining.len(), 0);
        for id in ids {
            assert!(storage.load_checkpoint(id).unwrap().is_none());
        }
    }

    #[test]
    fn test_delete_unknown_execution_returns_zero() {
        let storage = InMemoryCheckpointStorage::new();
        assert_eq!(storage.delete_checkpoints(uuid::Uuid::new_v4()).unwrap(), 0);
    }

    #[test]
    fn test_multiple_executions() {
        let storage = InMemoryCheckpointStorage::new();
        let exec1 = uuid::Uuid::new_v4();
        let exec2 = uuid::Uuid::new_v4();

        for exec_id in [exec1, exec2] {
            for _ in 0..2 {
                storage
                    .save_checkpoint(exec_id, &checkpoint_at(Utc::now(), 1))
                    .unwrap();
            }
        }

        assert_eq!(storage.list_checkpoints(exec1).unwrap().len(), 2);
        assert_eq!(storage.list_checkpoints(exec2).unwrap().len(), 2);

        storage.delete_checkpoints(exec1).unwrap();

        assert_eq!(storage.list_checkpoints(exec1).unwrap().len(), 0);
        assert_eq!(storage.list_checkpoints(exec2).unwrap().len(), 2);
        assert!(storage.load_latest_checkpoint(exec1).unwrap().is_none());
        assert!(storage.load_latest_checkpoint(exec2).unwrap().is_some());
    }
}