a3s_code_core/orchestration/
checkpoint.rs1use super::executor::StepOutcome;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Schema version written into every new [`WorkflowCheckpoint`] by
/// [`WorkflowCheckpoint::from_completed`].
///
/// [`WorkflowCheckpoint::ensure_loadable`] rejects checkpoints whose `schema_version` is
/// greater than this value. Payloads that predate the version field deserialize as `0`.
pub const WORKFLOW_CHECKPOINT_SCHEMA_VERSION: u32 = 1;
16
/// One completed workflow step as persisted inside a [`WorkflowCheckpoint`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorkflowStepRecord {
    /// Id of the task this record belongs to; used as the map key by
    /// [`WorkflowCheckpoint::completed`].
    pub task_id: String,
    /// Outcome recorded for the task when the checkpoint was taken.
    pub outcome: StepOutcome,
}
26
/// Serializable snapshot of the steps a workflow run has completed, used to resume it.
///
/// After deserializing, call [`WorkflowCheckpoint::ensure_loadable`] to refuse checkpoints
/// written with a newer schema than this build supports.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorkflowCheckpoint {
    /// Format version of this checkpoint. `#[serde(default)]` makes payloads written before
    /// the field existed deserialize as `0`.
    #[serde(default)]
    pub schema_version: u32,
    /// Identifier of the workflow this checkpoint belongs to.
    pub workflow_id: String,
    /// Completed steps; see [`WorkflowCheckpoint::completed`] for the map form.
    pub steps: Vec<WorkflowStepRecord>,
    /// Caller-supplied time at which the checkpoint was taken, in milliseconds.
    /// NOTE(review): presumably Unix-epoch milliseconds — confirm against callers.
    pub checkpoint_ms: u64,
}
44
45impl WorkflowCheckpoint {
46 pub fn from_completed(
48 workflow_id: impl Into<String>,
49 completed: &HashMap<String, StepOutcome>,
50 checkpoint_ms: u64,
51 ) -> Self {
52 let steps = completed
53 .iter()
54 .map(|(task_id, outcome)| WorkflowStepRecord {
55 task_id: task_id.clone(),
56 outcome: outcome.clone(),
57 })
58 .collect();
59 Self {
60 schema_version: WORKFLOW_CHECKPOINT_SCHEMA_VERSION,
61 workflow_id: workflow_id.into(),
62 steps,
63 checkpoint_ms,
64 }
65 }
66
67 pub fn completed(&self) -> HashMap<String, StepOutcome> {
69 self.steps
70 .iter()
71 .map(|r| (r.task_id.clone(), r.outcome.clone()))
72 .collect()
73 }
74
75 pub fn ensure_loadable(&self) -> anyhow::Result<()> {
81 if self.schema_version > WORKFLOW_CHECKPOINT_SCHEMA_VERSION {
82 anyhow::bail!(
83 "workflow checkpoint {} has schema version {} but this build supports at \
84 most {}; refusing to resume from an incompatible future checkpoint",
85 self.workflow_id,
86 self.schema_version,
87 WORKFLOW_CHECKPOINT_SCHEMA_VERSION
88 );
89 }
90 Ok(())
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a successful `StepOutcome` for task `id` with placeholder agent/output.
    fn sample_outcome(id: &str) -> StepOutcome {
        StepOutcome {
            task_id: String::from(id),
            session_id: format!("task-run-{id}"),
            agent: String::from("a"),
            output: String::from("o"),
            success: true,
            structured: None,
        }
    }

    #[test]
    fn round_trips_and_exposes_completed_map() {
        let completed = HashMap::from([(String::from("t1"), sample_outcome("t1"))]);
        let original = WorkflowCheckpoint::from_completed("wf", &completed, 123);

        // Serialize and parse back; the checkpoint must survive JSON unchanged.
        let json = serde_json::to_string(&original).unwrap();
        let restored: WorkflowCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored, original);
        assert_eq!(restored.schema_version, WORKFLOW_CHECKPOINT_SCHEMA_VERSION);
        assert_eq!(restored.checkpoint_ms, 123);

        let by_task = restored.completed();
        assert_eq!(by_task["t1"].task_id, "t1");
    }

    #[test]
    fn ensure_loadable_rejects_only_future_versions() {
        let mut checkpoint = WorkflowCheckpoint::from_completed("wf", &HashMap::new(), 0);

        // Current version is accepted.
        checkpoint.ensure_loadable().expect("current version loadable");

        // Older (pre-versioning) checkpoints are accepted too.
        checkpoint.schema_version = 0;
        checkpoint.ensure_loadable().expect("pre-v1 loadable");

        // Anything newer than this build must be refused with a descriptive error.
        checkpoint.schema_version = WORKFLOW_CHECKPOINT_SCHEMA_VERSION + 1;
        let message = checkpoint.ensure_loadable().unwrap_err().to_string();
        assert!(message.contains("schema version"), "got: {message}");
    }

    #[test]
    fn pre_v1_payload_without_schema_version_loads() {
        // A payload written before `schema_version` existed falls back to the serde default.
        let legacy = r#"{"workflow_id":"wf","steps":[],"checkpoint_ms":0}"#;
        let parsed = serde_json::from_str::<WorkflowCheckpoint>(legacy).unwrap();
        assert_eq!(parsed.schema_version, 0);
    }
}