Skip to main content

juncture_checkpoint/
types.rs

//! Checkpoint data structures
//!
//! This module re-exports checkpoint types from juncture-core for convenience and
//! defines the delta-snapshot types, delta recovery, and checkpoint TTL configuration.
4
5use std::collections::HashMap;
6
7pub use juncture_core::checkpoint::{
8    Checkpoint, CheckpointFilter, CheckpointMetadata, CheckpointPendingTask, CheckpointSource,
9    CheckpointTuple, DeltaCounters, DeltaOp, PendingWrite, PregelTaskInfo as PregelTaskInfoExport,
10    SerializedSend, StateSnapshot,
11};
12
13// Import CheckpointError from this crate's error module
14use crate::error::CheckpointError;
15
16/// Pregel task information (re-exported from juncture-core)
17pub type PregelTaskInfo = PregelTaskInfoExport;
18
19/// Delta snapshot for incremental checkpointing
20///
21/// Stores only the changes from a base checkpoint, enabling efficient storage.
22#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23pub struct DeltaSnapshot {
24    /// Base checkpoint ID (full snapshot)
25    pub base_checkpoint_id: String,
26
27    /// Ordered list of channel deltas
28    pub deltas: Vec<ChannelDelta>,
29}
30
31/// Delta for a single channel
32///
33/// Represents incremental changes to a specific channel.
34#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
35pub struct ChannelDelta {
36    /// Channel name
37    pub channel: String,
38
39    /// Operation type
40    pub op: DeltaOp,
41
42    /// Values to apply
43    pub values: Vec<serde_json::Value>,
44}
45
46/// Recover a full checkpoint from delta snapshots using ancestor walk algorithm
47///
48/// This function implements the delta recovery algorithm specified in design doc §1.4 and §8.3.
49/// Given a list of checkpoint tuples sorted by step (ascending), it finds the latest full snapshot
50/// and replays all subsequent delta writes to reconstruct the complete checkpoint state.
51///
52/// # Algorithm
53///
54/// 1. Find the nearest full snapshot (latest checkpoint with `pending_writes` empty or minimal)
55/// 2. Walk forward through all delta writes from checkpoints after the snapshot
56/// 3. Replay delta writes to the snapshot:
57///    - Append channels: `snapshot[channel].extend(delta.values)`
58///    - Replace channels: `snapshot[channel] = delta.values`
59/// 4. Generate complete Checkpoint object with updated versions
60///
61/// # Arguments
62///
63/// * `checkpoints` - Slice of checkpoint tuples sorted by step (ascending order)
64/// * `target_checkpoint_id` - ID of the target checkpoint to recover (must be in the list)
65///
66/// # Returns
67///
68/// * `Ok(Some(Checkpoint))` - The reconstructed full checkpoint
69/// * `Ok(None)` - Target checkpoint not found in the list
70/// * `Err(CheckpointError)` - Recovery failure (invalid data, missing base, etc.)
71///
72/// # Errors
73///
74/// Returns [`CheckpointError`] if:
75/// - Target checkpoint ID is not found in the list
76/// - Base checkpoint for delta snapshot is missing
77/// - Channel data cannot be merged (type mismatch)
78/// - Checkpoint list is not properly sorted
79///
80/// # Examples
81///
82/// ```ignore
83/// use juncture_checkpoint::types::recover_from_deltas;
84/// use juncture_core::checkpoint::CheckpointTuple;
85///
86/// let checkpoints = vec![
87///     full_snapshot_checkpoint,   // Step 0 - full snapshot
88///     delta_checkpoint_1,          // Step 1 - deltas only
89///     delta_checkpoint_2,          // Step 2 - deltas only
90/// ];
91///
92/// let recovered = recover_from_deltas(&checkpoints, "cp2").await?;
93/// assert!(recovered.is_some());
94/// ```
95pub fn recover_from_deltas(
96    checkpoints: &[CheckpointTuple],
97    target_checkpoint_id: &str,
98) -> Result<Option<Checkpoint>, CheckpointError> {
99    // Validate input: find target checkpoint
100    let target_index = checkpoints
101        .iter()
102        .position(|t| t.checkpoint.id == target_checkpoint_id);
103
104    let Some(target_idx) = target_index else {
105        return Ok(None);
106    };
107
108    // Consider checkpoints up to and including the target
109    let relevant_checkpoints = &checkpoints[..=target_idx];
110
111    // Step 1: Find the nearest full snapshot
112    // A full snapshot is one that contains complete channel_values
113    // We iterate backwards from target to find the most recent full checkpoint
114    let base_snapshot = relevant_checkpoints
115        .iter()
116        .rev()
117        .find(|t| {
118            !t.checkpoint.channel_values.is_null()
119                && t.checkpoint
120                    .channel_values
121                    .as_object()
122                    .is_some_and(|obj| !obj.is_empty())
123        })
124        .ok_or_else(|| {
125            CheckpointError::deserialize_msg(
126                "No full snapshot found in checkpoint chain".to_string(),
127            )
128        })?;
129
130    // Clone the base checkpoint as our starting point
131    let mut reconstructed = base_snapshot.checkpoint.clone();
132
133    // Collect all pending writes from checkpoints after the base snapshot
134    let mut all_deltas: Vec<(&String, PendingWrite)> = Vec::new();
135
136    // Step 2: Walk forward collecting all delta writes
137    for tuple in relevant_checkpoints {
138        // Skip checkpoints that are before or at the base snapshot
139        if tuple.checkpoint.id <= base_snapshot.checkpoint.id {
140            continue;
141        }
142
143        // Collect pending writes from this checkpoint
144        for write in &tuple.pending_writes {
145            all_deltas.push((&tuple.checkpoint.id, write.clone()));
146        }
147    }
148
149    // Sort deltas by checkpoint ID to ensure correct order
150    all_deltas.sort_by(|a, b| a.0.cmp(b.0));
151
152    // Step 3: Replay delta writes to the snapshot
153    let channel_values = reconstructed
154        .channel_values
155        .as_object_mut()
156        .ok_or_else(|| {
157            CheckpointError::deserialize_msg(
158                "Base checkpoint channel_values is not an object".to_string(),
159            )
160        })?;
161
162    // Track which channels were modified
163    let mut modified_channels = HashMap::<String, u64>::new();
164
165    for (_checkpoint_id, write) in all_deltas {
166        let channel = &write.channel;
167
168        // Delta channels use Append semantics
169        // In a full implementation, the operation type would be determined
170        // by the channel's reducer type configuration
171        if let serde_json::Value::Array(values) = &write.value {
172            // Append array values to existing channel data
173            let entry = channel_values
174                .entry(channel.clone())
175                .or_insert(serde_json::Value::Array(vec![]));
176
177            if let Some(arr) = entry.as_array_mut() {
178                arr.extend(values.clone().into_iter());
179            }
180        } else {
181            // Non-array values use Replace semantics
182            channel_values.insert(channel.clone(), write.value.clone());
183        }
184
185        // Update version counter (common to both branches)
186        *modified_channels.entry(channel.clone()).or_insert(0) += 1;
187    }
188
189    // Step 4: Update checkpoint metadata
190    // Update channel_versions for modified channels
191    for (channel, delta_count) in &modified_channels {
192        let current_version = reconstructed
193            .channel_versions
194            .get(channel)
195            .copied()
196            .unwrap_or(0);
197        reconstructed
198            .channel_versions
199            .insert(channel.clone(), current_version + delta_count);
200    }
201
202    // Update new_versions to reflect the channels modified during recovery
203    reconstructed.new_versions = modified_channels;
204
205    // Clear delta counters since we now have a full snapshot
206    reconstructed.counters_since_delta_snapshot.clear();
207
208    Ok(Some(reconstructed))
209}
210
/// Time-to-live configuration for checkpoint expiration
///
/// Configures automatic cleanup of old checkpoints per design spec §5.7.
///
/// Two configurations compare equal when all three fields are equal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TtlConfig {
    /// Default TTL for checkpoints (None = no expiration)
    pub default_ttl: Option<std::time::Duration>,

    /// Interval between cleanup sweeps for active background cleanup
    pub sweep_interval: std::time::Duration,

    /// Maximum number of checkpoints to retain per thread/namespace (None = unlimited)
    pub max_checkpoints: Option<usize>,
}
225
226impl TtlConfig {
227    /// Create a new TTL configuration
228    ///
229    /// # Arguments
230    ///
231    /// * `default_ttl` - Default time-to-live for checkpoints (None = no expiration)
232    /// * `sweep_interval` - Interval between active cleanup sweeps
233    /// * `max_checkpoints` - Maximum checkpoints to retain (None = unlimited)
234    #[must_use]
235    pub const fn new(
236        default_ttl: Option<std::time::Duration>,
237        sweep_interval: std::time::Duration,
238        max_checkpoints: Option<usize>,
239    ) -> Self {
240        Self {
241            default_ttl,
242            sweep_interval,
243            max_checkpoints,
244        }
245    }
246
247    /// Create a TTL configuration with no expiration (default)
248    #[must_use]
249    pub const fn disabled() -> Self {
250        Self {
251            default_ttl: None,
252            sweep_interval: std::time::Duration::from_secs(3600),
253            max_checkpoints: None,
254        }
255    }
256
257    /// Check if a checkpoint has expired based on its creation time
258    ///
259    /// # Arguments
260    ///
261    /// * `created_at_str` - ISO 8601 timestamp string from checkpoint
262    ///
263    /// # Returns
264    ///
265    /// * `true` if checkpoint is expired and should be cleaned up
266    /// * `false` if checkpoint is still valid
267    #[must_use]
268    pub fn is_expired(&self, created_at_str: &str) -> bool {
269        let Some(ttl) = self.default_ttl else {
270            return false; // No expiration configured
271        };
272
273        // Parse ISO 8601 timestamp
274        let created_at = match chrono::DateTime::parse_from_rfc3339(created_at_str) {
275            Ok(dt) => dt.with_timezone(&chrono::Utc),
276            Err(_) => return false, // Invalid timestamp, don't expire
277        };
278
279        let now = chrono::Utc::now();
280        let age = now.signed_duration_since(created_at);
281
282        age.to_std().unwrap_or(std::time::Duration::MAX) > ttl
283    }
284}
285
286impl Default for TtlConfig {
287    fn default() -> Self {
288        Self::disabled()
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use juncture_core::config::RunnableConfig;

    /// Builds a full-snapshot tuple whose only channel is `messages = ["hello"]` at version 1.
    fn create_test_tuple(id: &str, step: i64) -> CheckpointTuple {
        let checkpoint = Checkpoint {
            id: id.to_string(),
            channel_values: serde_json::json!({
                "messages": ["hello"]
            }),
            channel_versions: HashMap::from([("messages".to_string(), 1)]),
            versions_seen: HashMap::new(),
            pending_tasks: vec![],
            pending_sends: vec![],
            pending_interrupts: vec![],
            schema_version: 1,
            created_at: chrono::Utc::now().to_rfc3339(),
            v: 1,
            new_versions: HashMap::new(),
            counters_since_delta_snapshot: HashMap::new(),
        };

        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
            run_id: "test-run".to_string(),
        };

        CheckpointTuple {
            config: RunnableConfig::default(),
            checkpoint,
            metadata,
            pending_writes: vec![],
            parent_config: None,
        }
    }

    /// Builds a pending write that appends a single string to the `messages` channel.
    fn append_message(task_id: &str, message: &str) -> PendingWrite {
        PendingWrite {
            task_id: task_id.to_string(),
            channel: "messages".to_string(),
            // Arrays trigger append semantics during recovery.
            value: serde_json::json!([message]),
        }
    }

    /// Builds a delta-only tuple: empty `channel_values` plus the given pending writes.
    fn delta_only_tuple(id: &str, step: i64, writes: Vec<PendingWrite>) -> CheckpointTuple {
        let mut tuple = create_test_tuple(id, step);
        tuple.checkpoint.channel_values = serde_json::json!({});
        tuple.pending_writes = writes;
        tuple
    }

    #[test]
    fn test_checkpoint_metadata_serialization() {
        let original = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step: 5,
            writes: HashMap::new(),
            parents: HashMap::new(),
            run_id: "run-123".to_string(),
        };

        // Round-trip through a JSON value.
        let json = serde_json::to_value(&original).unwrap();
        let round_tripped: CheckpointMetadata = serde_json::from_value(json).unwrap();

        assert!(matches!(round_tripped.source, CheckpointSource::Loop));
        assert_eq!(round_tripped.step, 5);
        assert_eq!(round_tripped.run_id, "run-123");
    }

    #[test]
    fn test_delta_counters_default() {
        let DeltaCounters {
            updates,
            supersteps,
            ..
        } = DeltaCounters::default();
        assert_eq!(updates, 0);
        assert_eq!(supersteps, 0);
    }

    #[test]
    fn test_checkpoint_filter_default() {
        let filter = CheckpointFilter::default();
        assert!(filter.source.is_none());
        assert!(filter.step_gte.is_none());
        assert!(filter.step_lte.is_none());
        assert!(filter.before.is_none());
        assert!(filter.after.is_none());
        assert!(filter.limit.is_none());
    }

    #[test]
    fn test_ttl_config_default() {
        let TtlConfig {
            default_ttl,
            max_checkpoints,
            ..
        } = TtlConfig::default();
        assert!(default_ttl.is_none());
        assert!(max_checkpoints.is_none());
    }

    #[test]
    fn test_ttl_config_expiration() {
        use std::time::Duration;

        // 60 second TTL.
        let config = TtlConfig::new(
            Some(Duration::from_secs(60)),
            Duration::from_secs(3600),
            Some(100),
        );

        // A checkpoint created right now is still valid.
        let fresh = chrono::Utc::now().to_rfc3339();
        assert!(!config.is_expired(&fresh));

        // A checkpoint created two minutes ago has outlived the TTL.
        let stale = (chrono::Utc::now() - chrono::Duration::seconds(120)).to_rfc3339();
        assert!(config.is_expired(&stale));
    }

    #[test]
    fn test_recover_from_deltas_empty_list() {
        let checkpoints: Vec<CheckpointTuple> = Vec::new();
        let outcome = recover_from_deltas(&checkpoints, "cp1");
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_recover_from_deltas_target_not_found() {
        let checkpoints = [create_test_tuple("cp1", 0)];
        let outcome = recover_from_deltas(&checkpoints, "cp2");
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_recover_from_deltas_single_full_checkpoint() {
        let checkpoints = [create_test_tuple("cp1", 0)];

        let recovered = recover_from_deltas(&checkpoints, "cp1")
            .expect("recovery should succeed")
            .expect("target should be found");

        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello"])
        );
    }

    #[test]
    fn test_recover_from_deltas_with_pending_writes() {
        let checkpoints = [
            create_test_tuple("cp1", 0),
            delta_only_tuple(
                "cp2",
                1,
                vec![
                    append_message("task1", "world"),
                    append_message("task2", "test"),
                ],
            ),
        ];

        let recovered = recover_from_deltas(&checkpoints, "cp2")
            .expect("recovery should succeed")
            .expect("target should be found");

        // The result is a clone of the base snapshot, so it carries the base ID.
        assert_eq!(recovered.id, "cp1");

        // Both writes were appended after the base value, in order.
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "world", "test"])
        );

        // Base version 1 plus two replayed writes.
        assert_eq!(recovered.channel_versions.get("messages"), Some(&3));
    }

    #[test]
    fn test_recover_from_deltas_no_full_snapshot() {
        // The only checkpoint has empty channel_values, so there is no base to start from.
        let checkpoints = [delta_only_tuple("cp1", 0, vec![])];

        let outcome = recover_from_deltas(&checkpoints, "cp1");
        assert!(matches!(outcome, Err(CheckpointError::Deserialize(_))));
    }

    #[test]
    fn test_recover_from_deltas_multiple_deltas() {
        let checkpoints = [
            create_test_tuple("cp1", 0),
            delta_only_tuple("cp2", 1, vec![append_message("task1", "delta1")]),
            delta_only_tuple(
                "cp3",
                2,
                vec![
                    append_message("task2", "delta2a"),
                    append_message("task3", "delta2b"),
                ],
            ),
        ];

        let recovered = recover_from_deltas(&checkpoints, "cp3")
            .expect("recovery should succeed")
            .expect("target should be found");

        // The result is a clone of the base snapshot, so it carries the base ID.
        assert_eq!(recovered.id, "cp1");

        // Writes from both delta checkpoints were appended in step order.
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "delta1", "delta2a", "delta2b"])
        );
    }
}
523
524// Rust guideline compliant 2026-05-23