use std::collections::HashMap;
pub use juncture_core::checkpoint::{
Checkpoint, CheckpointFilter, CheckpointMetadata, CheckpointPendingTask, CheckpointSource,
CheckpointTuple, DeltaCounters, DeltaOp, PendingWrite, PregelTaskInfo as PregelTaskInfoExport,
SerializedSend, StateSnapshot,
};
use crate::error::CheckpointError;
/// Alias that exposes the core crate's `PregelTaskInfo` under its original name.
///
/// The type is imported above as `PregelTaskInfoExport`, so this alias is what
/// makes the unsuffixed name available from this module.
pub type PregelTaskInfo = PregelTaskInfoExport;
/// A serializable list of per-channel deltas tied to a base checkpoint id.
///
/// NOTE(review): the visible code never constructs or reads this type, so the
/// field descriptions below are inferred from names and types only — confirm
/// against the producers/consumers.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DeltaSnapshot {
/// Id of the checkpoint the deltas refer to (presumably the last full snapshot — TODO confirm).
pub base_checkpoint_id: String,
/// The recorded deltas, one entry per channel operation.
pub deltas: Vec<ChannelDelta>,
}
/// One serializable delta entry: a channel name, an operation, and JSON values.
///
/// NOTE(review): how `op` and `values` are interpreted is defined by `DeltaOp`
/// in the core crate and is not visible here — confirm before relying on it.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ChannelDelta {
/// Name of the channel this delta belongs to.
pub channel: String,
/// Operation kind, as declared by the core crate's `DeltaOp`.
pub op: DeltaOp,
/// JSON payload associated with the operation.
pub values: Vec<serde_json::Value>,
}
/// Rebuilds the state at `target_checkpoint_id` from the nearest full snapshot
/// plus the pending writes recorded after it.
///
/// `checkpoints` must be ordered oldest-first. The slice order is the single
/// source of truth for "before" and "after": the target, the base snapshot and
/// the deltas to replay are all located by position. Checkpoint ids are only
/// used to find the target, never to order anything, so the result does not
/// depend on ids being lexicographically sortable.
///
/// Algorithm:
/// 1. Locate the target checkpoint; if it is absent, return `Ok(None)`.
/// 2. Walking backwards from the target, pick the closest checkpoint whose
///    `channel_values` is a non-empty JSON object (the "full snapshot").
/// 3. Clone that snapshot and replay every pending write of the checkpoints
///    that follow it, up to and including the target, in slice order:
///    * an array-valued write is appended to the channel's array (the channel
///      is created as an empty array when missing);
///    * any other write replaces the channel's value.
/// 4. Bump `channel_versions` of every touched channel by the number of writes
///    applied to it, store those per-channel counts in `new_versions`, and
///    clear `counters_since_delta_snapshot`.
///
/// The returned checkpoint is a modified clone of the base snapshot, so every
/// field not mentioned above (including `id`) is the base snapshot's.
///
/// # Errors
///
/// Returns a `CheckpointError` built with `deserialize_msg` when no checkpoint
/// at or before the target carries a full snapshot.
pub fn recover_from_deltas(
    checkpoints: &[CheckpointTuple],
    target_checkpoint_id: &str,
) -> Result<Option<Checkpoint>, CheckpointError> {
    let Some(target_idx) = checkpoints
        .iter()
        .position(|t| t.checkpoint.id == target_checkpoint_id)
    else {
        return Ok(None);
    };
    let relevant_checkpoints = &checkpoints[..=target_idx];

    // Index (not id) of the most recent full snapshot at or before the target.
    let base_idx = relevant_checkpoints
        .iter()
        .rposition(|t| {
            t.checkpoint
                .channel_values
                .as_object()
                .is_some_and(|obj| !obj.is_empty())
        })
        .ok_or_else(|| {
            CheckpointError::deserialize_msg(
                "No full snapshot found in checkpoint chain".to_string(),
            )
        })?;

    let mut reconstructed = relevant_checkpoints[base_idx].checkpoint.clone();
    let channel_values = reconstructed
        .channel_values
        .as_object_mut()
        .ok_or_else(|| {
            CheckpointError::deserialize_msg(
                "Base checkpoint channel_values is not an object".to_string(),
            )
        })?;

    // Replay the writes of every checkpoint strictly after the base, in slice
    // order. `base_idx + 1 <= len`, so the range is always valid (possibly empty).
    let mut modified_channels = HashMap::<String, u64>::new();
    let deltas = relevant_checkpoints[base_idx + 1..]
        .iter()
        .flat_map(|tuple| &tuple.pending_writes);
    for write in deltas {
        match &write.value {
            serde_json::Value::Array(values) => {
                let entry = channel_values
                    .entry(write.channel.clone())
                    .or_insert(serde_json::Value::Array(Vec::new()));
                // NOTE(review): when the channel already holds a non-array
                // value the array write is dropped (but still counted below).
                // Kept as-is; confirm whether it should replace or error.
                if let Some(arr) = entry.as_array_mut() {
                    arr.extend(values.iter().cloned());
                }
            }
            other => {
                channel_values.insert(write.channel.clone(), other.clone());
            }
        }
        *modified_channels.entry(write.channel.clone()).or_insert(0) += 1;
    }

    // Each applied write advances the channel's version by one.
    for (channel, delta_count) in &modified_channels {
        let current_version = reconstructed
            .channel_versions
            .get(channel)
            .copied()
            .unwrap_or(0);
        reconstructed
            .channel_versions
            .insert(channel.clone(), current_version + *delta_count);
    }
    reconstructed.new_versions = modified_channels;
    reconstructed.counters_since_delta_snapshot.clear();
    Ok(Some(reconstructed))
}
/// Time-to-live settings for stored checkpoints.
#[derive(Clone, Debug)]
pub struct TtlConfig {
/// Maximum age before `is_expired` reports a checkpoint as expired.
/// `None` means `is_expired` always returns `false`.
pub default_ttl: Option<std::time::Duration>,
/// Interval between sweeps; `disabled()` sets it to 3600 seconds.
/// NOTE(review): not read by any code visible here — presumably consumed by a sweeper elsewhere.
pub sweep_interval: std::time::Duration,
/// Optional cap on the number of checkpoints.
/// NOTE(review): not enforced by any code visible here — confirm where it is applied.
pub max_checkpoints: Option<usize>,
}
impl TtlConfig {
    /// Creates a configuration from explicit values.
    #[must_use]
    pub const fn new(
        default_ttl: Option<std::time::Duration>,
        sweep_interval: std::time::Duration,
        max_checkpoints: Option<usize>,
    ) -> Self {
        Self {
            default_ttl,
            sweep_interval,
            max_checkpoints,
        }
    }

    /// Creates a configuration with no TTL and no checkpoint cap.
    ///
    /// `sweep_interval` is still set (to one hour) because the field is not optional.
    #[must_use]
    pub const fn disabled() -> Self {
        Self {
            default_ttl: None,
            sweep_interval: std::time::Duration::from_secs(3600),
            max_checkpoints: None,
        }
    }

    /// Returns `true` when a checkpoint created at `created_at_str` is strictly
    /// older than `default_ttl`, measured against the current UTC time.
    ///
    /// `created_at_str` must be an RFC 3339 timestamp. The check fails safe and
    /// returns `false` (not expired) when:
    /// * no `default_ttl` is configured;
    /// * the timestamp cannot be parsed;
    /// * the timestamp lies in the future (e.g. clock skew between writers).
    #[must_use]
    pub fn is_expired(&self, created_at_str: &str) -> bool {
        let Some(ttl) = self.default_ttl else {
            return false;
        };
        let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(created_at_str) else {
            return false;
        };
        let created_at = parsed.with_timezone(&chrono::Utc);
        let age = chrono::Utc::now().signed_duration_since(created_at);
        // `to_std` only fails for a negative duration, i.e. a creation time in
        // the future. Such a checkpoint has no age yet, so it cannot be expired.
        age.to_std().is_ok_and(|age| age > ttl)
    }
}
/// The default configuration is the disabled one: no TTL and no checkpoint cap.
impl Default for TtlConfig {
// Delegates to `disabled()` so both entry points always agree.
fn default() -> Self {
Self::disabled()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use juncture_core::config::RunnableConfig;

    /// Builds a full-snapshot tuple whose `messages` channel holds `["hello"]` at version 1.
    fn create_test_tuple(id: &str, step: i64) -> CheckpointTuple {
        let checkpoint = Checkpoint {
            id: id.to_string(),
            channel_values: serde_json::json!({
                "messages": ["hello"]
            }),
            channel_versions: HashMap::from([("messages".to_string(), 1)]),
            versions_seen: HashMap::new(),
            pending_tasks: vec![],
            pending_sends: vec![],
            pending_interrupts: vec![],
            schema_version: 1,
            created_at: chrono::Utc::now().to_rfc3339(),
            v: 1,
            new_versions: HashMap::new(),
            counters_since_delta_snapshot: HashMap::new(),
        };
        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
            run_id: "test-run".to_string(),
        };
        CheckpointTuple {
            config: RunnableConfig::default(),
            checkpoint,
            metadata,
            pending_writes: vec![],
            parent_config: None,
        }
    }

    /// Builds a delta-only tuple: empty `channel_values` plus the given pending writes.
    fn delta_tuple(id: &str, step: i64, writes: Vec<PendingWrite>) -> CheckpointTuple {
        let mut tuple = create_test_tuple(id, step);
        tuple.checkpoint.channel_values = serde_json::json!({});
        tuple.pending_writes = writes;
        tuple
    }

    /// Builds a pending write that appends a single string to the `messages` channel.
    fn message_write(task_id: &str, item: &str) -> PendingWrite {
        PendingWrite {
            task_id: task_id.to_string(),
            channel: "messages".to_string(),
            value: serde_json::json!([item]),
        }
    }

    #[test]
    fn test_checkpoint_metadata_serialization() {
        let original = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step: 5,
            writes: HashMap::new(),
            parents: HashMap::new(),
            run_id: "run-123".to_string(),
        };

        let as_json = serde_json::to_value(&original).unwrap();
        let round_tripped: CheckpointMetadata = serde_json::from_value(as_json).unwrap();

        assert!(matches!(round_tripped.source, CheckpointSource::Loop));
        assert_eq!(round_tripped.step, 5);
        assert_eq!(round_tripped.run_id, "run-123");
    }

    #[test]
    fn test_delta_counters_default() {
        let DeltaCounters {
            updates,
            supersteps,
            ..
        } = DeltaCounters::default();
        assert_eq!(updates, 0);
        assert_eq!(supersteps, 0);
    }

    #[test]
    fn test_checkpoint_filter_default() {
        let filter = CheckpointFilter::default();
        assert!(filter.source.is_none());
        assert!(filter.step_gte.is_none());
        assert!(filter.step_lte.is_none());
        assert!(filter.before.is_none());
        assert!(filter.after.is_none());
        assert!(filter.limit.is_none());
    }

    #[test]
    fn test_ttl_config_default() {
        let config = TtlConfig::default();
        assert!(config.default_ttl.is_none());
        assert!(config.max_checkpoints.is_none());
    }

    #[test]
    fn test_ttl_config_expiration() {
        use std::time::Duration;

        let one_minute_ttl = TtlConfig::new(
            Some(Duration::from_secs(60)),
            Duration::from_secs(3600),
            Some(100),
        );

        // A checkpoint created right now is within the TTL.
        let fresh = chrono::Utc::now().to_rfc3339();
        assert!(!one_minute_ttl.is_expired(&fresh));

        // A checkpoint created two minutes ago is past the one-minute TTL.
        let stale = (chrono::Utc::now() - chrono::Duration::seconds(120)).to_rfc3339();
        assert!(one_minute_ttl.is_expired(&stale));
    }

    #[test]
    fn test_recover_from_deltas_empty_list() {
        let checkpoints: Vec<CheckpointTuple> = Vec::new();
        assert!(matches!(
            recover_from_deltas(&checkpoints, "cp1"),
            Ok(None)
        ));
    }

    #[test]
    fn test_recover_from_deltas_target_not_found() {
        let checkpoints = vec![create_test_tuple("cp1", 0)];
        assert!(matches!(
            recover_from_deltas(&checkpoints, "cp2"),
            Ok(None)
        ));
    }

    #[test]
    fn test_recover_from_deltas_single_full_checkpoint() {
        let checkpoints = vec![create_test_tuple("cp1", 0)];

        let recovered = recover_from_deltas(&checkpoints, "cp1")
            .expect("recovery should succeed")
            .expect("target checkpoint should be found");

        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello"])
        );
    }

    #[test]
    fn test_recover_from_deltas_with_pending_writes() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple(
                "cp2",
                1,
                vec![
                    message_write("task1", "world"),
                    message_write("task2", "test"),
                ],
            ),
        ];

        let recovered = recover_from_deltas(&checkpoints, "cp2")
            .expect("recovery should succeed")
            .expect("target checkpoint should be found");

        // The result is built on top of the base snapshot, so it keeps the base id.
        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "world", "test"])
        );
        // Base version 1 plus two applied writes.
        assert_eq!(recovered.channel_versions.get("messages"), Some(&3));
    }

    #[test]
    fn test_recover_from_deltas_no_full_snapshot() {
        let checkpoints = vec![delta_tuple("cp1", 0, vec![])];

        let result = recover_from_deltas(&checkpoints, "cp1");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, CheckpointError::Deserialize(_)));
    }

    #[test]
    fn test_recover_from_deltas_multiple_deltas() {
        let checkpoints = vec![
            create_test_tuple("cp1", 0),
            delta_tuple("cp2", 1, vec![message_write("task1", "delta1")]),
            delta_tuple(
                "cp3",
                2,
                vec![
                    message_write("task2", "delta2a"),
                    message_write("task3", "delta2b"),
                ],
            ),
        ];

        let recovered = recover_from_deltas(&checkpoints, "cp3")
            .expect("recovery should succeed")
            .expect("target checkpoint should be found");

        assert_eq!(recovered.id, "cp1");
        assert_eq!(
            recovered.channel_values["messages"],
            serde_json::json!(["hello", "delta1", "delta2a", "delta2b"])
        );
    }
}