1use std::collections::HashMap;
6
7pub use juncture_core::checkpoint::{
8 Checkpoint, CheckpointFilter, CheckpointMetadata, CheckpointPendingTask, CheckpointSource,
9 CheckpointTuple, DeltaCounters, DeltaOp, PendingWrite, PregelTaskInfo as PregelTaskInfoExport,
10 SerializedSend, StateSnapshot,
11};
12
13use crate::error::CheckpointError;
15
/// Alias for `juncture_core::checkpoint::PregelTaskInfo`, which this module imports under
/// the name `PregelTaskInfoExport`.
pub type PregelTaskInfo = PregelTaskInfoExport;
18
19#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23pub struct DeltaSnapshot {
24 pub base_checkpoint_id: String,
26
27 pub deltas: Vec<ChannelDelta>,
29}
30
31#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
35pub struct ChannelDelta {
36 pub channel: String,
38
39 pub op: DeltaOp,
41
42 pub values: Vec<serde_json::Value>,
44}
45
46pub fn recover_from_deltas(
96 checkpoints: &[CheckpointTuple],
97 target_checkpoint_id: &str,
98) -> Result<Option<Checkpoint>, CheckpointError> {
99 let target_index = checkpoints
101 .iter()
102 .position(|t| t.checkpoint.id == target_checkpoint_id);
103
104 let Some(target_idx) = target_index else {
105 return Ok(None);
106 };
107
108 let relevant_checkpoints = &checkpoints[..=target_idx];
110
111 let base_snapshot = relevant_checkpoints
115 .iter()
116 .rev()
117 .find(|t| {
118 !t.checkpoint.channel_values.is_null()
119 && t.checkpoint
120 .channel_values
121 .as_object()
122 .is_some_and(|obj| !obj.is_empty())
123 })
124 .ok_or_else(|| {
125 CheckpointError::deserialize_msg(
126 "No full snapshot found in checkpoint chain".to_string(),
127 )
128 })?;
129
130 let mut reconstructed = base_snapshot.checkpoint.clone();
132
133 let mut all_deltas: Vec<(&String, PendingWrite)> = Vec::new();
135
136 for tuple in relevant_checkpoints {
138 if tuple.checkpoint.id <= base_snapshot.checkpoint.id {
140 continue;
141 }
142
143 for write in &tuple.pending_writes {
145 all_deltas.push((&tuple.checkpoint.id, write.clone()));
146 }
147 }
148
149 all_deltas.sort_by(|a, b| a.0.cmp(b.0));
151
152 let channel_values = reconstructed
154 .channel_values
155 .as_object_mut()
156 .ok_or_else(|| {
157 CheckpointError::deserialize_msg(
158 "Base checkpoint channel_values is not an object".to_string(),
159 )
160 })?;
161
162 let mut modified_channels = HashMap::<String, u64>::new();
164
165 for (_checkpoint_id, write) in all_deltas {
166 let channel = &write.channel;
167
168 if let serde_json::Value::Array(values) = &write.value {
172 let entry = channel_values
174 .entry(channel.clone())
175 .or_insert(serde_json::Value::Array(vec![]));
176
177 if let Some(arr) = entry.as_array_mut() {
178 arr.extend(values.clone().into_iter());
179 }
180 } else {
181 channel_values.insert(channel.clone(), write.value.clone());
183 }
184
185 *modified_channels.entry(channel.clone()).or_insert(0) += 1;
187 }
188
189 for (channel, delta_count) in &modified_channels {
192 let current_version = reconstructed
193 .channel_versions
194 .get(channel)
195 .copied()
196 .unwrap_or(0);
197 reconstructed
198 .channel_versions
199 .insert(channel.clone(), current_version + delta_count);
200 }
201
202 reconstructed.new_versions = modified_channels;
204
205 reconstructed.counters_since_delta_snapshot.clear();
207
208 Ok(Some(reconstructed))
209}
210
/// Retention settings for stored checkpoints.
#[derive(Debug, Clone)]
pub struct TtlConfig {
    /// Maximum checkpoint age; `None` means checkpoints never expire by age.
    pub default_ttl: Option<std::time::Duration>,

    /// Interval between sweeps (presumably read by a sweeper outside this file — confirm).
    pub sweep_interval: std::time::Duration,

    /// Optional cap on the number of checkpoints (enforcement is not shown in this file).
    pub max_checkpoints: Option<usize>,
}
225
226impl TtlConfig {
227 #[must_use]
235 pub const fn new(
236 default_ttl: Option<std::time::Duration>,
237 sweep_interval: std::time::Duration,
238 max_checkpoints: Option<usize>,
239 ) -> Self {
240 Self {
241 default_ttl,
242 sweep_interval,
243 max_checkpoints,
244 }
245 }
246
247 #[must_use]
249 pub const fn disabled() -> Self {
250 Self {
251 default_ttl: None,
252 sweep_interval: std::time::Duration::from_secs(3600),
253 max_checkpoints: None,
254 }
255 }
256
257 #[must_use]
268 pub fn is_expired(&self, created_at_str: &str) -> bool {
269 let Some(ttl) = self.default_ttl else {
270 return false; };
272
273 let created_at = match chrono::DateTime::parse_from_rfc3339(created_at_str) {
275 Ok(dt) => dt.with_timezone(&chrono::Utc),
276 Err(_) => return false, };
278
279 let now = chrono::Utc::now();
280 let age = now.signed_duration_since(created_at);
281
282 age.to_std().unwrap_or(std::time::Duration::MAX) > ttl
283 }
284}
285
286impl Default for TtlConfig {
287 fn default() -> Self {
288 Self::disabled()
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use juncture_core::config::RunnableConfig;
    use std::time::Duration;

    /// Runs recovery, failing the test with a descriptive message if it errors or finds nothing.
    fn recover(checkpoints: &[CheckpointTuple], target: &str) -> Checkpoint {
        recover_from_deltas(checkpoints, target)
            .expect("recovery should not error")
            .expect("target checkpoint should be found")
    }

    /// Builds a `PendingWrite` from borrowed strings.
    fn pending_write(task_id: &str, channel: &str, value: serde_json::Value) -> PendingWrite {
        PendingWrite {
            task_id: task_id.to_string(),
            channel: channel.to_string(),
            value,
        }
    }

    /// Builds a delta checkpoint: empty `channel_values` plus the given pending writes.
    fn delta_tuple(id: &str, step: i64, writes: Vec<PendingWrite>) -> CheckpointTuple {
        let mut tuple = create_test_tuple(id, step);
        tuple.checkpoint.channel_values = serde_json::json!({});
        tuple.pending_writes = writes;
        tuple
    }

    #[test]
    fn test_checkpoint_metadata_serialization() {
        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step: 5,
            writes: std::collections::HashMap::new(),
            parents: std::collections::HashMap::new(),
            run_id: "run-123".to_string(),
        };

        let serialized = serde_json::to_value(&metadata).expect("metadata should serialize");
        let deserialized: CheckpointMetadata =
            serde_json::from_value(serialized).expect("metadata should deserialize");

        assert!(matches!(deserialized.source, CheckpointSource::Loop));
        assert_eq!(deserialized.step, 5);
        assert_eq!(deserialized.run_id, "run-123");
    }

    #[test]
    fn test_delta_counters_default() {
        let counters = DeltaCounters::default();
        assert_eq!(counters.updates, 0);
        assert_eq!(counters.supersteps, 0);
    }

    #[test]
    fn test_checkpoint_filter_default() {
        let filter = CheckpointFilter::default();
        assert!(filter.source.is_none());
        assert!(filter.step_gte.is_none());
        assert!(filter.step_lte.is_none());
        assert!(filter.before.is_none());
        assert!(filter.after.is_none());
        assert!(filter.limit.is_none());
    }

    #[test]
    fn test_ttl_config_default() {
        let config = TtlConfig::default();
        assert!(config.default_ttl.is_none());
        assert!(config.max_checkpoints.is_none());
        assert_eq!(config.sweep_interval, Duration::from_secs(3600));
    }

    #[test]
    fn test_ttl_config_expiration() {
        let config = TtlConfig::new(
            Some(Duration::from_secs(60)),
            Duration::from_secs(3600),
            Some(100),
        );

        let now = chrono::Utc::now().to_rfc3339();
        assert!(!config.is_expired(&now));

        let past = (chrono::Utc::now() - chrono::Duration::seconds(120)).to_rfc3339();
        assert!(config.is_expired(&past));
    }

    #[test]
    fn test_ttl_disabled_never_expires() {
        assert!(!TtlConfig::disabled().is_expired("1970-01-01T00:00:00Z"));
    }

    #[test]
    fn test_ttl_unparseable_timestamp_not_expired() {
        let config = TtlConfig::new(Some(Duration::from_secs(60)), Duration::from_secs(3600), None);
        assert!(!config.is_expired("not-a-timestamp"));
    }

    #[test]
    fn test_recover_from_deltas_empty_list() {
        let checkpoints: Vec<CheckpointTuple> = Vec::new();
        let result =
            recover_from_deltas(&checkpoints, "cp1").expect("empty chain should not error");
        assert!(result.is_none());
    }

    #[test]
    fn test_recover_from_deltas_target_not_found() {
        let checkpoints = vec![create_test_tuple("cp1", 0)];
        let result =
            recover_from_deltas(&checkpoints, "cp2").expect("missing target should not error");
        assert!(result.is_none());
    }

    #[test]
    fn test_recover_from_deltas_single_full_checkpoint() {
        let checkpoints = vec![create_test_tuple("cp1", 0)];
        let recovered = recover(&checkpoints, "cp1");

        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello"])
        );
    }

    #[test]
    fn test_recover_from_deltas_with_pending_writes() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple(
                "cp2",
                1,
                vec![
                    pending_write("task1", "messages", serde_json::json!(["world"])),
                    pending_write("task2", "messages", serde_json::json!(["test"])),
                ],
            ),
        ];

        let recovered = recover(&checkpoints, "cp2");

        // The reconstructed checkpoint keeps the base snapshot's identity.
        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "world", "test"])
        );
        // Base version 1 plus one bump per replayed write.
        assert_eq!(recovered.channel_versions.get("messages"), Some(&3));
    }

    #[test]
    fn test_recover_from_deltas_no_full_snapshot() {
        let checkpoints = vec![delta_tuple("cp1", 0, Vec::new())];
        let result = recover_from_deltas(&checkpoints, "cp1");
        assert!(matches!(result, Err(CheckpointError::Deserialize(_))));
    }

    #[test]
    fn test_recover_from_deltas_multiple_deltas() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple(
                "cp2",
                1,
                vec![pending_write("task1", "messages", serde_json::json!(["delta1"]))],
            ),
            delta_tuple(
                "cp3",
                2,
                vec![
                    pending_write("task2", "messages", serde_json::json!(["delta2a"])),
                    pending_write("task3", "messages", serde_json::json!(["delta2b"])),
                ],
            ),
        ];

        let recovered = recover(&checkpoints, "cp3");

        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "delta1", "delta2a", "delta2b"])
        );
    }

    #[test]
    fn test_recover_from_deltas_scalar_write_replaces() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple(
                "cp2",
                1,
                vec![pending_write("task1", "status", serde_json::json!("done"))],
            ),
        ];

        let recovered = recover(&checkpoints, "cp2");

        assert_eq!(recovered.channel_values["status"], serde_json::json!("done"));
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello"])
        );
        // A channel with no prior version starts from 0 and gets one bump.
        assert_eq!(recovered.channel_versions.get("status"), Some(&1));
    }

    #[test]
    fn test_recover_from_deltas_uses_latest_full_snapshot() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple(
                "cp2",
                1,
                vec![pending_write("task1", "messages", serde_json::json!(["dropped"]))],
            ),
            create_test_tuple("cp3", 2),
        ];

        let recovered = recover(&checkpoints, "cp3");

        assert_eq!(recovered.id, "cp3");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello"])
        );
        assert_eq!(recovered.channel_versions.get("messages"), Some(&1));
    }

    /// Builds a full-snapshot tuple whose `messages` channel holds `["hello"]` at version 1.
    fn create_test_tuple(id: &str, step: i64) -> CheckpointTuple {
        CheckpointTuple {
            config: RunnableConfig::default(),
            checkpoint: Checkpoint {
                id: id.to_string(),
                channel_values: serde_json::json!({
                    "messages": ["hello"]
                }),
                channel_versions: HashMap::from([("messages".to_string(), 1)]),
                versions_seen: HashMap::new(),
                pending_tasks: vec![],
                pending_sends: vec![],
                pending_interrupts: vec![],
                schema_version: 1,
                created_at: chrono::Utc::now().to_rfc3339(),
                v: 1,
                new_versions: HashMap::new(),
                counters_since_delta_snapshot: HashMap::new(),
            },
            metadata: CheckpointMetadata {
                source: CheckpointSource::Loop,
                step,
                writes: HashMap::new(),
                parents: HashMap::new(),
                run_id: "test-run".to_string(),
            },
            pending_writes: vec![],
            parent_config: None,
        }
    }
}
523
524