1use std::collections::HashSet;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_agents_core::{AgentError, Result, RunId, Value, WorkflowId};
15use atomr_agents_state::{CheckpointKey, Checkpointer, RunState, Snapshot, StateSchema};
16
17use crate::dag::{Dag, StepId};
18
/// A unit of work in a stateful workflow.
///
/// A step is handed a shared reference to the current [`RunState`] and
/// returns a list of `(String, Value)` writes. In this file the only consumer
/// is [`StatefulRunner::run`], which gathers the writes of every step in a
/// layer and passes them to `RunState::merge_writes`.
///
/// NOTE(review): the `String` half of each write is presumably the name of a
/// field registered on the `StateSchema` (the tests in this file use
/// `"messages"` and `"config"`) — confirm against `merge_writes`.
#[async_trait]
pub trait StatefulStep: Send + Sync + 'static {
    /// Computes this step's writes from a read-only view of `state`.
    ///
    /// Returning `Err` makes [`StatefulRunner::run`] return that error.
    async fn run(&self, state: &RunState) -> Result<Vec<(String, Value)>>;
}
23
/// Executes a [`Dag`] of [`StatefulStep`]s layer by layer, saving a
/// checkpoint after every layer.
pub struct StatefulRunner {
    /// Workflow identifier; used to look up the latest checkpoint and as part
    /// of the key of every checkpoint that is saved.
    pub workflow_id: WorkflowId,
    /// Run identifier; used together with `workflow_id` for checkpoint lookup
    /// and checkpoint keys.
    pub run_id: RunId,
    /// The steps to execute and the edges between them.
    pub dag: Dag<Arc<dyn StatefulStep>>,
    /// Schema handed to every `RunState` the runner creates or restores.
    pub schema: Arc<StateSchema>,
    /// Storage queried for the latest snapshot on start and written to after
    /// each executed layer.
    pub checkpointer: Arc<dyn Checkpointer>,
}
31
32impl StatefulRunner {
33 pub async fn run(&self) -> Result<RunState> {
34 let mut state = match self.checkpointer.latest(&self.workflow_id, &self.run_id).await? {
36 Some(snap) => RunState::from_snapshot(self.schema.clone(), snap.values, snap.key.super_step),
37 None => RunState::new(self.schema.clone()),
38 };
39 let order = self.dag.topo_sort()?;
40 let layers = self.layered(&order);
42 let resume_at = state.super_step();
43 let mut completed: HashSet<StepId> = HashSet::new();
44 for (layer_idx, layer) in layers.iter().enumerate() {
45 let super_step = layer_idx as u64 + 1;
46 if super_step <= resume_at {
47 for sid in layer {
48 completed.insert(sid.clone());
49 }
50 continue;
51 }
52 let mut handles = Vec::new();
54 for sid in layer {
55 let step = self
56 .dag
57 .steps
58 .get(sid)
59 .ok_or_else(|| AgentError::Workflow(format!("missing step {}", sid.as_str())))?;
60 let step = step.clone();
61 let st = state.clone();
62 handles.push(tokio::spawn(async move { step.run(&st).await }));
63 }
64 let mut all_writes: Vec<(String, Value)> = Vec::new();
65 for h in handles {
66 let writes = h.await.map_err(|e| AgentError::Internal(e.to_string()))??;
67 all_writes.extend(writes);
68 }
69 state.merge_writes(all_writes)?;
70 state.advance();
71 for sid in layer {
72 completed.insert(sid.clone());
73 }
74 self.checkpointer
75 .save(Snapshot {
76 key: CheckpointKey {
77 workflow_id: self.workflow_id.clone(),
78 run_id: self.run_id.clone(),
79 super_step,
80 },
81 values: state.snapshot(),
82 label: format!("layer:{super_step}"),
83 timestamp_ms: now_ms(),
84 })
85 .await?;
86 }
87 Ok(state)
88 }
89
90 fn layered(&self, order: &[StepId]) -> Vec<Vec<StepId>> {
91 use std::collections::HashMap;
93 let mut depth: HashMap<StepId, usize> = HashMap::new();
94 for s in order {
95 depth.insert(s.clone(), 0);
96 }
97 for s in order {
98 if let Some(succs) = self.dag.edges.get(s) {
99 let cur = depth[s];
100 for n in succs {
101 let next = (cur + 1).max(*depth.get(n).unwrap_or(&0));
102 depth.insert(n.clone(), next);
103 }
104 }
105 }
106 let max_d = depth.values().copied().max().unwrap_or(0);
107 let mut layers: Vec<Vec<StepId>> = vec![Vec::new(); max_d + 1];
108 for s in order {
109 layers[depth[s]].push(s.clone());
110 }
111 layers
112 }
113}
114
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch. The
/// `u128` millisecond count is converted with a checked conversion and
/// saturates at `i64::MAX` instead of silently wrapping.
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX)
        })
}
122
/// Newtype around a closure so that it can be used as a [`StatefulStep`].
///
/// The wrapped closure is the public field `0`; the `StatefulStep`
/// implementation in this file calls it with the run state and awaits the
/// future it returns.
pub struct FnStatefulStep<F>(pub F);
125
126#[async_trait]
127impl<F, Fut> StatefulStep for FnStatefulStep<F>
128where
129 F: Fn(&RunState) -> Fut + Send + Sync + 'static,
130 Fut: std::future::Future<Output = Result<Vec<(String, Value)>>> + Send + 'static,
131{
132 async fn run(&self, state: &RunState) -> Result<Vec<(String, Value)>> {
133 (self.0)(state).await
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::Dag;
    use atomr_agents_state::{AppendMessages, InMemoryCheckpointer, MergeMap, StateSchema};
    use serde_json::json;

    /// Schema shared by all tests: a `messages` field registered with
    /// `AppendMessages` and a `config` field registered with `MergeMap`.
    fn schema() -> Arc<StateSchema> {
        Arc::new(
            StateSchema::builder()
                .add("messages", AppendMessages)
                .add("config", MergeMap)
                .build(),
        )
    }

    /// Wraps a synchronous closure producing writes into a `StatefulStep`.
    ///
    /// The writes are computed while the state is borrowed and then moved
    /// into the returned future, so the future itself does not borrow the
    /// state.
    fn step_writing<F>(write: F) -> Arc<dyn StatefulStep>
    where
        F: Fn(&RunState) -> Vec<(String, Value)> + Send + Sync + 'static,
    {
        Arc::new(FnStatefulStep(move |s: &RunState| {
            let writes = write(s);
            async move { Ok(writes) }
        }))
    }

    /// In a linear `a -> b` DAG, step `b` must observe the writes that step
    /// `a` made in the previous super-step.
    #[tokio::test]
    async fn linear_dag_writes_per_super_step() {
        let dag: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step(
                "a",
                step_writing(|_| vec![("messages".into(), json!([{"id": "m1", "text": "hi"}]))]),
            )
            .step(
                "b",
                step_writing(|s| {
                    let n = s.read("messages").as_array().map(|v| v.len()).unwrap_or(0);
                    vec![("config".into(), json!({"seen": n}))]
                }),
            )
            .edge("a", "b")
            .build();
        let runner = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer: Arc::new(InMemoryCheckpointer::new()),
        };
        let final_state = runner.run().await.unwrap();
        assert_eq!(final_state.read("messages").as_array().unwrap().len(), 1);
        assert_eq!(final_state.read("config")["seen"], 1);
    }

    /// A run that fails in its second layer leaves one checkpoint behind; a
    /// second run with the same ids resumes after it and does not execute the
    /// first layer again.
    #[tokio::test]
    async fn resume_from_checkpoint_skips_completed_layers() {
        let cpt = Arc::new(InMemoryCheckpointer::new());

        // First attempt: step "b" always fails.
        let bad: Arc<dyn StatefulStep> = Arc::new(FnStatefulStep(|_s: &RunState| async {
            Err::<Vec<(String, Value)>, _>(AgentError::Workflow("first run dies on b".into()))
        }));
        let dag1: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step(
                "a",
                step_writing(|_| vec![("messages".into(), json!([{"id": "m1"}]))]),
            )
            .step("b", bad)
            .edge("a", "b")
            .build();
        let r1 = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag: dag1,
            schema: schema(),
            checkpointer: cpt.clone(),
        };
        // The failure must surface to the caller rather than being swallowed.
        assert!(r1.run().await.is_err());

        // Only the layer containing "a" completed, so exactly one checkpoint
        // (super-step 1) exists.
        let metas = cpt
            .list(&WorkflowId::from("wf"), &RunId::from("r"))
            .await
            .unwrap();
        assert_eq!(metas.len(), 1);
        assert_eq!(metas[0].super_step, 1);

        // Second attempt: count how many times "a" executes.
        use std::sync::atomic::{AtomicU32, Ordering};
        let a_runs = Arc::new(AtomicU32::new(0));
        let a_runs2 = a_runs.clone();
        let counted_a: Arc<dyn StatefulStep> = Arc::new(FnStatefulStep(move |_s: &RunState| {
            let c = a_runs2.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok(vec![("messages".into(), json!([{"id": "m1"}]))])
            }
        }));
        let dag2: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step("a", counted_a)
            .step(
                "b",
                step_writing(|_| vec![("config".into(), json!({"ok": true}))]),
            )
            .edge("a", "b")
            .build();
        let r2 = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag: dag2,
            schema: schema(),
            checkpointer: cpt.clone(),
        };
        let final_state = r2.run().await.unwrap();
        // "a" was skipped thanks to the checkpoint; only "b" ran.
        assert_eq!(a_runs.load(Ordering::SeqCst), 0);
        assert_eq!(final_state.read("config")["ok"], true);
    }
}