Skip to main content

a3s_code_core/orchestration/
checkpoint.rs

1//! Workflow-level checkpoints: journal completed steps so an interrupted
2//! orchestration resumes from the longest completed prefix — on this node or,
3//! because the checkpoint is serializable and the executor is pluggable, on
4//! another one (host-driven migration).
5//!
6//! This is the step-boundary analogue of [`LoopCheckpoint`](crate::loop_checkpoint::LoopCheckpoint),
7//! which checkpoints at tool-round boundaries one level down.
8
9use super::executor::StepOutcome;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Current on-disk format version of [`WorkflowCheckpoint`].
///
/// Increment this on any incompatible format change.
/// [`WorkflowCheckpoint::ensure_loadable`] refuses checkpoints whose version is
/// greater than this value; older versions (including a missing field, read as
/// `0`) are accepted.
pub const WORKFLOW_CHECKPOINT_SCHEMA_VERSION: u32 = 1;
16
/// One completed step within a workflow.
///
/// A [`WorkflowCheckpoint`] stores these as a list; use
/// [`WorkflowCheckpoint::completed`] for a `task_id`-keyed view.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorkflowStepRecord {
    /// Matches the [`AgentStepSpec::task_id`](super::AgentStepSpec) of the
    /// step that produced this outcome.
    pub task_id: String,
    /// The completed step's result.
    pub outcome: StepOutcome,
}
26
/// Snapshot of a workflow's completed steps at a step boundary.
///
/// Serialized via serde. Before resuming from a stored checkpoint, callers
/// should check [`WorkflowCheckpoint::ensure_loadable`] to reject payloads
/// written by a newer, incompatible build.
///
/// (`StepOutcome` contains a `serde_json::Value`, which is not `Eq`, so this
/// derives `PartialEq` only.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorkflowCheckpoint {
    /// Schema version — see [`WORKFLOW_CHECKPOINT_SCHEMA_VERSION`].
    ///
    /// `#[serde(default)]` makes a missing field deserialize as `0`, which is
    /// how payloads written before versioning existed are recognized.
    #[serde(default)]
    pub schema_version: u32,
    /// Logical workflow identifier the checkpoint is keyed by.
    pub workflow_id: String,
    /// The steps completed so far. A resuming run skips these and re-dispatches
    /// only the rest.
    pub steps: Vec<WorkflowStepRecord>,
    /// Wall-clock timestamp when the checkpoint was written (Unix epoch ms).
    pub checkpoint_ms: u64,
}
44
45impl WorkflowCheckpoint {
46    /// Build a checkpoint from a map of completed `task_id -> outcome`.
47    pub fn from_completed(
48        workflow_id: impl Into<String>,
49        completed: &HashMap<String, StepOutcome>,
50        checkpoint_ms: u64,
51    ) -> Self {
52        let steps = completed
53            .iter()
54            .map(|(task_id, outcome)| WorkflowStepRecord {
55                task_id: task_id.clone(),
56                outcome: outcome.clone(),
57            })
58            .collect();
59        Self {
60            schema_version: WORKFLOW_CHECKPOINT_SCHEMA_VERSION,
61            workflow_id: workflow_id.into(),
62            steps,
63            checkpoint_ms,
64        }
65    }
66
67    /// The completed steps as a `task_id -> outcome` map.
68    pub fn completed(&self) -> HashMap<String, StepOutcome> {
69        self.steps
70            .iter()
71            .map(|r| (r.task_id.clone(), r.outcome.clone()))
72            .collect()
73    }
74
75    /// Reject a checkpoint written by a *newer*, incompatible schema version
76    /// than this build understands — mirrors
77    /// [`LoopCheckpoint::ensure_loadable`](crate::loop_checkpoint::LoopCheckpoint::ensure_loadable).
78    /// Field additions are absorbed by `#[serde(default)]`, so older (incl.
79    /// pre-v1 `0`) checkpoints always remain loadable.
80    pub fn ensure_loadable(&self) -> anyhow::Result<()> {
81        if self.schema_version > WORKFLOW_CHECKPOINT_SCHEMA_VERSION {
82            anyhow::bail!(
83                "workflow checkpoint {} has schema version {} but this build supports at \
84                 most {}; refusing to resume from an incompatible future checkpoint",
85                self.workflow_id,
86                self.schema_version,
87                WORKFLOW_CHECKPOINT_SCHEMA_VERSION
88            );
89        }
90        Ok(())
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal successful outcome for task `id`.
    fn outcome(id: &str) -> StepOutcome {
        StepOutcome {
            task_id: id.to_owned(),
            session_id: format!("task-run-{id}"),
            agent: String::from("a"),
            output: String::from("o"),
            success: true,
            structured: None,
        }
    }

    #[test]
    fn round_trips_and_exposes_completed_map() {
        let completed = HashMap::from([("t1".to_owned(), outcome("t1"))]);
        let cp = WorkflowCheckpoint::from_completed("wf", &completed, 123);

        let json = serde_json::to_string(&cp).expect("serialize checkpoint");
        let back: WorkflowCheckpoint = serde_json::from_str(&json).expect("parse checkpoint");

        assert_eq!(back, cp);
        assert_eq!(back.schema_version, WORKFLOW_CHECKPOINT_SCHEMA_VERSION);
        assert_eq!(back.checkpoint_ms, 123);
        let restored = back.completed();
        assert_eq!(restored["t1"].task_id, "t1");
    }

    #[test]
    fn ensure_loadable_rejects_only_future_versions() {
        let mut cp = WorkflowCheckpoint::from_completed("wf", &HashMap::new(), 0);
        cp.ensure_loadable().expect("current version loadable");

        cp.schema_version = 0;
        cp.ensure_loadable().expect("pre-v1 loadable");

        cp.schema_version = WORKFLOW_CHECKPOINT_SCHEMA_VERSION + 1;
        let message = match cp.ensure_loadable() {
            Ok(()) => panic!("a future schema version must be rejected"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("schema version"), "got: {message}");
    }

    #[test]
    fn pre_v1_payload_without_schema_version_loads() {
        let legacy = r#"{"workflow_id":"wf","steps":[],"checkpoint_ms":0}"#;
        let cp: WorkflowCheckpoint = serde_json::from_str(legacy).expect("legacy payload parses");
        assert_eq!(cp.schema_version, 0);
    }
}