1use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_agents_core::{Result, RunId, Value, WorkflowId};
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13
/// Identifies a single stored snapshot.
///
/// A lookup in this file matches on all three fields: workflow, run and super-step.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CheckpointKey {
    /// Workflow the snapshot belongs to.
    pub workflow_id: WorkflowId,
    /// Run of that workflow the snapshot belongs to.
    pub run_id: RunId,
    /// Step counter within the run; `Checkpointer::latest` in this file picks the highest value.
    // NOTE(review): presumably the index of the executed super-step — confirm with the scheduler.
    pub super_step: u64,
}
20
/// Lightweight description of a stored snapshot, as returned by `Checkpointer::list`.
///
/// Carries the identifying fields and the timestamp of a snapshot, but not its values or label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMeta {
    /// Workflow the snapshot belongs to.
    pub workflow_id: WorkflowId,
    /// Run the snapshot belongs to.
    pub run_id: RunId,
    /// Step counter of the snapshot within its run.
    pub super_step: u64,
    /// Timestamp copied from the snapshot's `timestamp_ms` field.
    pub timestamp_ms: i64,
}
28
/// A stored checkpoint: the key it is filed under plus the captured values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
    /// Workflow, run and super-step this snapshot is stored under.
    pub key: CheckpointKey,
    /// Captured values, keyed by name.
    pub values: HashMap<String, Value>,
    /// Free-form label; deserializes to an empty string when the field is absent.
    #[serde(default)]
    pub label: String,
    /// Timestamp supplied by whoever builds the snapshot.
    // NOTE(review): code in this file fills it with milliseconds since the Unix epoch —
    // confirm other producers do the same.
    pub timestamp_ms: i64,
}
39
/// Asynchronous storage interface for workflow snapshots.
///
/// Implementors must be `Send + Sync + 'static`. Every operation returns
/// `atomr_agents_core::Result`.
#[async_trait]
pub trait Checkpointer: Send + Sync + 'static {
    /// Stores `snapshot`.
    async fn save(&self, snapshot: Snapshot) -> Result<()>;
    /// Looks up the snapshot filed under `key`; `Ok(None)` means nothing matched.
    async fn load(&self, key: &CheckpointKey) -> Result<Option<Snapshot>>;
    /// Returns the most advanced snapshot of the given run, or `Ok(None)` if the run has none.
    ///
    /// `InMemoryCheckpointer` resolves this to the snapshot with the highest `super_step`.
    async fn latest(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Option<Snapshot>>;
    /// Returns one `CheckpointMeta` per stored snapshot of the given run.
    async fn list(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Vec<CheckpointMeta>>;
    /// Starts a new run from the snapshot `from` with `edits` applied, returning the new run id.
    ///
    /// In `InMemoryCheckpointer` each `(name, value)` edit inserts or overwrites one entry of
    /// the copied values, and a missing source snapshot is reported as an error.
    async fn fork(&self, from: &CheckpointKey, edits: Vec<(String, Value)>) -> Result<RunId>;
}
51
/// `Checkpointer` that keeps every snapshot in a process-local `Vec`.
///
/// The storage sits behind an `Arc`, so clones of this value share the same snapshots.
#[derive(Default, Clone)]
pub struct InMemoryCheckpointer {
    /// Shared list of stored snapshots, guarded by a read/write lock.
    inner: Arc<RwLock<Vec<Snapshot>>>,
}
56
57impl InMemoryCheckpointer {
58 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn len(&self) -> usize {
63 self.inner.read().len()
64 }
65
66 pub fn is_empty(&self) -> bool {
67 self.inner.read().is_empty()
68 }
69}
70
71#[async_trait]
72impl Checkpointer for InMemoryCheckpointer {
73 async fn save(&self, snapshot: Snapshot) -> Result<()> {
74 self.inner.write().push(snapshot);
75 Ok(())
76 }
77
78 async fn load(&self, key: &CheckpointKey) -> Result<Option<Snapshot>> {
79 Ok(self
80 .inner
81 .read()
82 .iter()
83 .find(|s| {
84 s.key.workflow_id.as_str() == key.workflow_id.as_str()
85 && s.key.run_id.as_str() == key.run_id.as_str()
86 && s.key.super_step == key.super_step
87 })
88 .cloned())
89 }
90
91 async fn latest(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Option<Snapshot>> {
92 let g = self.inner.read();
93 Ok(g.iter()
94 .filter(|s| {
95 s.key.workflow_id.as_str() == workflow_id.as_str() && s.key.run_id.as_str() == run_id.as_str()
96 })
97 .max_by_key(|s| s.key.super_step)
98 .cloned())
99 }
100
101 async fn list(&self, workflow_id: &WorkflowId, run_id: &RunId) -> Result<Vec<CheckpointMeta>> {
102 Ok(self
103 .inner
104 .read()
105 .iter()
106 .filter(|s| {
107 s.key.workflow_id.as_str() == workflow_id.as_str() && s.key.run_id.as_str() == run_id.as_str()
108 })
109 .map(|s| CheckpointMeta {
110 workflow_id: s.key.workflow_id.clone(),
111 run_id: s.key.run_id.clone(),
112 super_step: s.key.super_step,
113 timestamp_ms: s.timestamp_ms,
114 })
115 .collect())
116 }
117
118 async fn fork(&self, from: &CheckpointKey, edits: Vec<(String, Value)>) -> Result<RunId> {
119 let snap = self.load(from).await?.ok_or_else(|| {
120 atomr_agents_core::AgentError::Internal(format!(
121 "fork: source checkpoint {}#{} not found",
122 from.run_id.as_str(),
123 from.super_step
124 ))
125 })?;
126 let new_run = RunId::new();
127 let mut values = snap.values.clone();
128 for (k, v) in edits {
129 values.insert(k, v);
130 }
131 self.save(Snapshot {
132 key: CheckpointKey {
133 workflow_id: snap.key.workflow_id.clone(),
134 run_id: new_run.clone(),
135 super_step: snap.key.super_step,
136 },
137 values,
138 label: format!("fork-of:{}", from.run_id.as_str()),
139 timestamp_ms: chrono_now_ms(),
140 })
141 .await?;
142 Ok(new_run)
143 }
144}
145
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Despite its name, this uses only `std::time`. A system clock set before the epoch is
/// reported as `0` rather than as an error, and a value too large for `i64` saturates at
/// `i64::MAX` instead of wrapping.
fn chrono_now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // `as_millis` yields a `u128`; clamp first so the cast below can never truncate.
        .map(|elapsed| elapsed.as_millis().min(i64::MAX as u128) as i64)
        .unwrap_or(0)
}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a snapshot for workflow `wf` / run `run` at `step` holding the given values.
    fn snap(wf: &str, run: &str, step: u64, label: &str, kvs: Vec<(&str, Value)>) -> Snapshot {
        let values: HashMap<String, Value> = kvs
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect();
        Snapshot {
            key: CheckpointKey {
                workflow_id: WorkflowId::from(wf),
                run_id: RunId::from(run),
                super_step: step,
            },
            values,
            label: label.into(),
            timestamp_ms: chrono_now_ms(),
        }
    }

    #[tokio::test]
    async fn save_and_latest() {
        let c = InMemoryCheckpointer::new();
        c.save(snap("wf", "r", 0, "init", vec![("messages", json!([]))]))
            .await
            .unwrap();
        c.save(snap(
            "wf",
            "r",
            2,
            "after",
            vec![("messages", json!([{"id": "m1"}]))],
        ))
        .await
        .unwrap();
        let latest = c
            .latest(&WorkflowId::from("wf"), &RunId::from("r"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(latest.key.super_step, 2);
        assert_eq!(latest.values["messages"][0]["id"], "m1");
    }

    #[tokio::test]
    async fn fork_creates_new_run_with_edits() {
        let c = InMemoryCheckpointer::new();
        c.save(snap("wf", "main", 1, "before-fork", vec![("a", json!(1))]))
            .await
            .unwrap();
        let source = CheckpointKey {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("main"),
            super_step: 1,
        };
        let new_run = c
            .fork(&source, vec![("a".into(), json!(99))])
            .await
            .unwrap();
        let forked = c
            .latest(&WorkflowId::from("wf"), &new_run)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(forked.values["a"], json!(99));
        assert_eq!(forked.key.super_step, 1);
        assert!(forked.label.starts_with("fork-of:main"));

        // Forking must not modify the snapshot it was taken from.
        let original = c.load(&source).await.unwrap().unwrap();
        assert_eq!(original.values["a"], json!(1));
        assert_eq!(c.len(), 2);
    }

    #[tokio::test]
    async fn list_returns_meta_in_order() {
        let c = InMemoryCheckpointer::new();
        for step in [0u64, 1, 2, 3] {
            c.save(snap("wf", "r", step, "step", vec![])).await.unwrap();
        }
        let metas = c.list(&WorkflowId::from("wf"), &RunId::from("r")).await.unwrap();
        assert_eq!(metas.len(), 4);
        // The test name promises ordering, so check the steps and not only the count.
        let steps: Vec<u64> = metas.iter().map(|m| m.super_step).collect();
        assert_eq!(steps, vec![0, 1, 2, 3]);
    }

    #[tokio::test]
    async fn load_missing_returns_none() {
        let c = InMemoryCheckpointer::new();
        c.save(snap("wf", "r", 0, "init", vec![])).await.unwrap();
        let missing = CheckpointKey {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            super_step: 7,
        };
        assert!(c.load(&missing).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn fork_of_missing_checkpoint_fails() {
        let c = InMemoryCheckpointer::new();
        let missing = CheckpointKey {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("nope"),
            super_step: 0,
        };
        assert!(c.fork(&missing, vec![]).await.is_err());
        // A failed fork must not leave a partial snapshot behind.
        assert!(c.is_empty());
    }
}