Skip to main content

atomr_agents_state/
checkpointer.rs

1//! Pluggable checkpointer.
2//!
3//! `InMemoryCheckpointer` is the default. SQLite + Postgres backends
4//! land in Phase R17 behind feature flags.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_agents_core::{Result, RunId, Value, WorkflowId};
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct CheckpointKey {
16    pub workflow_id: WorkflowId,
17    pub run_id: RunId,
18    pub super_step: u64,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CheckpointMeta {
23    pub workflow_id: WorkflowId,
24    pub run_id: RunId,
25    pub super_step: u64,
26    pub timestamp_ms: i64,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct Snapshot {
31    pub key: CheckpointKey,
32    pub values: HashMap<String, Value>,
33    /// Optional human label describing what produced this snapshot
34    /// (step name, "interrupt", etc.).
35    #[serde(default)]
36    pub label: String,
37    pub timestamp_ms: i64,
38}
39
40#[async_trait]
41pub trait Checkpointer: Send + Sync + 'static {
42    async fn save(&self, snapshot: Snapshot) -> Result<()>;
43    async fn load(&self, key: &CheckpointKey) -> Result<Option<Snapshot>>;
44    /// Returns the latest snapshot for a `(workflow, run)` pair.
45    async fn latest(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Option<Snapshot>>;
46    async fn list(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Vec<CheckpointMeta>>;
47    /// Create a new run that diverges from an existing checkpoint
48    /// with an optional set of state edits applied at the fork point.
49    async fn fork(&self, from: &CheckpointKey, edits: Vec<(String, Value)>) -> Result<RunId>;
50}
51
52#[derive(Default, Clone)]
53pub struct InMemoryCheckpointer {
54    inner: Arc<RwLock<Vec<Snapshot>>>,
55}
56
57impl InMemoryCheckpointer {
58    pub fn new() -> Self {
59        Self::default()
60    }
61
62    pub fn len(&self) -> usize {
63        self.inner.read().len()
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.inner.read().is_empty()
68    }
69}
70
71#[async_trait]
72impl Checkpointer for InMemoryCheckpointer {
73    async fn save(&self, snapshot: Snapshot) -> Result<()> {
74        self.inner.write().push(snapshot);
75        Ok(())
76    }
77
78    async fn load(&self, key: &CheckpointKey) -> Result<Option<Snapshot>> {
79        Ok(self
80            .inner
81            .read()
82            .iter()
83            .find(|s| {
84                s.key.workflow_id.as_str() == key.workflow_id.as_str()
85                    && s.key.run_id.as_str() == key.run_id.as_str()
86                    && s.key.super_step == key.super_step
87            })
88            .cloned())
89    }
90
91    async fn latest(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Option<Snapshot>> {
92        let g = self.inner.read();
93        Ok(g.iter()
94            .filter(|s| {
95                s.key.workflow_id.as_str() == workflow_id.as_str() && s.key.run_id.as_str() == run_id.as_str()
96            })
97            .max_by_key(|s| s.key.super_step)
98            .cloned())
99    }
100
101    async fn list(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Vec<CheckpointMeta>> {
102        Ok(self
103            .inner
104            .read()
105            .iter()
106            .filter(|s| {
107                s.key.workflow_id.as_str() == workflow_id.as_str() && s.key.run_id.as_str() == run_id.as_str()
108            })
109            .map(|s| CheckpointMeta {
110                workflow_id: s.key.workflow_id.clone(),
111                run_id: s.key.run_id.clone(),
112                super_step: s.key.super_step,
113                timestamp_ms: s.timestamp_ms,
114            })
115            .collect())
116    }
117
118    async fn fork(&self, from: &CheckpointKey, edits: Vec<(String, Value)>) -> Result<RunId> {
119        let snap = self.load(from).await?.ok_or_else(|| {
120            atomr_agents_core::AgentError::Internal(format!(
121                "fork: source checkpoint {}#{} not found",
122                from.run_id.as_str(),
123                from.super_step
124            ))
125        })?;
126        let new_run = RunId::new();
127        let mut values = snap.values.clone();
128        for (k, v) in edits {
129            values.insert(k, v);
130        }
131        self.save(Snapshot {
132            key: CheckpointKey {
133                workflow_id: snap.key.workflow_id.clone(),
134                run_id: new_run.clone(),
135                super_step: snap.key.super_step,
136            },
137            values,
138            label: format!("fork-of:{}", from.run_id.as_str()),
139            timestamp_ms: chrono_now_ms(),
140        })
141        .await?;
142        Ok(new_run)
143    }
144}
145
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. If the
/// millisecond count does not fit in an `i64`, the result saturates at
/// `i64::MAX` rather than silently wrapping as an `as` cast from `u128` would.
fn chrono_now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a snapshot for `(wf, run, step)` holding the given key/value pairs.
    fn snap(wf: &str, run: &str, step: u64, label: &str, kvs: Vec<(&str, Value)>) -> Snapshot {
        let mut values = HashMap::new();
        for (k, v) in kvs {
            values.insert(k.into(), v);
        }
        Snapshot {
            key: key(wf, run, step),
            values,
            label: label.into(),
            timestamp_ms: chrono_now_ms(),
        }
    }

    /// Builds the checkpoint key for `(wf, run, step)`.
    fn key(wf: &str, run: &str, step: u64) -> CheckpointKey {
        CheckpointKey {
            workflow_id: WorkflowId::from(wf),
            run_id: RunId::from(run),
            super_step: step,
        }
    }

    #[tokio::test]
    async fn save_and_latest() {
        let c = InMemoryCheckpointer::new();
        c.save(snap("wf", "r", 0, "init", vec![("messages", json!([]))]))
            .await
            .unwrap();
        c.save(snap(
            "wf",
            "r",
            2,
            "after",
            vec![("messages", json!([{"id": "m1"}]))],
        ))
        .await
        .unwrap();
        let latest = c
            .latest(&WorkflowId::from("wf"), &RunId::from("r"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(latest.key.super_step, 2);
        assert_eq!(latest.values["messages"][0]["id"], "m1");
    }

    #[tokio::test]
    async fn load_matches_exact_key_only() {
        let c = InMemoryCheckpointer::new();
        c.save(snap("wf", "r", 0, "zero", vec![])).await.unwrap();
        c.save(snap("wf", "r", 2, "two", vec![])).await.unwrap();
        let hit = c.load(&key("wf", "r", 2)).await.unwrap().unwrap();
        assert_eq!(hit.label, "two");
        assert!(c.load(&key("wf", "r", 1)).await.unwrap().is_none());
        assert!(c.load(&key("wf", "other", 2)).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn fork_creates_new_run_with_edits() {
        let c = InMemoryCheckpointer::new();
        c.save(snap(
            "wf",
            "main",
            1,
            "before-fork",
            vec![("a", json!(1)), ("b", json!("keep"))],
        ))
        .await
        .unwrap();
        let new_run = c
            .fork(&key("wf", "main", 1), vec![("a".into(), json!(99))])
            .await
            .unwrap();
        assert_ne!(new_run.as_str(), "main");
        let forked = c
            .latest(&WorkflowId::from("wf"), &new_run)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(forked.key.super_step, 1);
        assert_eq!(forked.values["a"], json!(99));
        // Keys not mentioned in `edits` are carried over unchanged.
        assert_eq!(forked.values["b"], json!("keep"));
        assert!(forked.label.starts_with("fork-of:main"));
        // The source run must not be affected by the fork.
        let source = c.load(&key("wf", "main", 1)).await.unwrap().unwrap();
        assert_eq!(source.values["a"], json!(1));
    }

    #[tokio::test]
    async fn fork_of_missing_checkpoint_errors() {
        let c = InMemoryCheckpointer::new();
        assert!(c.fork(&key("wf", "nope", 0), vec![]).await.is_err());
        assert!(c.is_empty());
    }

    #[tokio::test]
    async fn list_returns_meta_in_order() {
        let c = InMemoryCheckpointer::new();
        for step in [0u64, 1, 2, 3] {
            c.save(snap("wf", "r", step, "step", vec![])).await.unwrap();
        }
        // A snapshot from a different run must not leak into the listing.
        c.save(snap("wf", "other", 9, "step", vec![])).await.unwrap();
        let metas = c.list(&WorkflowId::from("wf"), &RunId::from("r")).await.unwrap();
        let steps: Vec<u64> = metas.iter().map(|m| m.super_step).collect();
        assert_eq!(steps, vec![0, 1, 2, 3]);
    }
}