// atomr_agents_workflow/state_runner.rs
//! Stateful runner — DAG over a `StateSchema` + `Checkpointer`.
//!
//! Each step is a `StatefulStep` that takes the current `RunState`,
//! returns a list of writes (channel-key, value), and the runner
//! applies them via the channel reducers. After every super-step
//! the snapshot is persisted via the `Checkpointer`. On resume,
//! `StatefulRunner::run` loads the latest snapshot for the
//! (workflow, run) pair and skips the super-steps it already recorded.
10use std::collections::HashSet;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_agents_core::{AgentError, Result, RunId, Value, WorkflowId};
15use atomr_agents_state::{CheckpointKey, Checkpointer, RunState, Snapshot, StateSchema};
16
17use crate::dag::{Dag, StepId};
18
/// A unit of work in the stateful DAG.
///
/// Implementations read (but never mutate) the current `RunState` and
/// return the writes they want applied, as `(channel-key, value)`
/// pairs; the runner merges them through the schema's reducers.
#[async_trait]
pub trait StatefulStep: Send + Sync + 'static {
    /// Produce this step's writes for the given state.
    ///
    /// # Errors
    /// An error aborts the run: `StatefulRunner` propagates it before
    /// merging any writes from the containing super-step.
    async fn run(&self, state: &RunState) -> Result<Vec<(String, Value)>>;
}
23
/// Runs a DAG of `StatefulStep`s over a reducer-based `RunState`,
/// persisting a snapshot after every super-step.
pub struct StatefulRunner {
    /// Identifies the workflow in checkpoint keys.
    pub workflow_id: WorkflowId,
    /// Identifies this particular run in checkpoint keys.
    pub run_id: RunId,
    /// Steps and edges; executed layer-by-layer in topological order.
    pub dag: Dag<Arc<dyn StatefulStep>>,
    /// Channel reducers used to merge step writes into the state.
    pub schema: Arc<StateSchema>,
    /// Persistence for per-super-step snapshots; also consulted on resume.
    pub checkpointer: Arc<dyn Checkpointer>,
}
31
32impl StatefulRunner {
33    pub async fn run(&self) -> Result<RunState> {
34        // Resume from latest checkpoint if present.
35        let mut state = match self.checkpointer.latest(&self.workflow_id, &self.run_id).await? {
36            Some(snap) => RunState::from_snapshot(self.schema.clone(), snap.values, snap.key.super_step),
37            None => RunState::new(self.schema.clone()),
38        };
39        let order = self.dag.topo_sort()?;
40        // Group steps into super-steps by topological layer (level).
41        let layers = self.layered(&order);
42        let resume_at = state.super_step();
43        let mut completed: HashSet<StepId> = HashSet::new();
44        for (layer_idx, layer) in layers.iter().enumerate() {
45            let super_step = layer_idx as u64 + 1;
46            if super_step <= resume_at {
47                for sid in layer {
48                    completed.insert(sid.clone());
49                }
50                continue;
51            }
52            // Run the layer concurrently; collect all writes.
53            let mut handles = Vec::new();
54            for sid in layer {
55                let step = self
56                    .dag
57                    .steps
58                    .get(sid)
59                    .ok_or_else(|| AgentError::Workflow(format!("missing step {}", sid.as_str())))?;
60                let step = step.clone();
61                let st = state.clone();
62                handles.push(tokio::spawn(async move { step.run(&st).await }));
63            }
64            let mut all_writes: Vec<(String, Value)> = Vec::new();
65            for h in handles {
66                let writes = h.await.map_err(|e| AgentError::Internal(e.to_string()))??;
67                all_writes.extend(writes);
68            }
69            state.merge_writes(all_writes)?;
70            state.advance();
71            for sid in layer {
72                completed.insert(sid.clone());
73            }
74            self.checkpointer
75                .save(Snapshot {
76                    key: CheckpointKey {
77                        workflow_id: self.workflow_id.clone(),
78                        run_id: self.run_id.clone(),
79                        super_step,
80                    },
81                    values: state.snapshot(),
82                    label: format!("layer:{super_step}"),
83                    timestamp_ms: now_ms(),
84                })
85                .await?;
86        }
87        Ok(state)
88    }
89
90    fn layered(&self, order: &[StepId]) -> Vec<Vec<StepId>> {
91        // Compute depth from edges; same depth = same super-step.
92        use std::collections::HashMap;
93        let mut depth: HashMap<StepId, usize> = HashMap::new();
94        for s in order {
95            depth.insert(s.clone(), 0);
96        }
97        for s in order {
98            if let Some(succs) = self.dag.edges.get(s) {
99                let cur = depth[s];
100                for n in succs {
101                    let next = (cur + 1).max(*depth.get(n).unwrap_or(&0));
102                    depth.insert(n.clone(), next);
103                }
104            }
105        }
106        let max_d = depth.values().copied().max().unwrap_or(0);
107        let mut layers: Vec<Vec<StepId>> = vec![Vec::new(); max_d + 1];
108        for s in order {
109            layers[depth[s]].push(s.clone());
110        }
111        layers
112    }
113}
114
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Best-effort: if the system clock reports a time before the epoch,
/// returns 0 instead of panicking (timestamps are only snapshot
/// metadata).
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
122
/// Convenience adapter: turns a `Fn(&RunState) -> Future` closure into
/// a `StatefulStep`, so simple steps don't need a dedicated type.
pub struct FnStatefulStep<F>(pub F);

#[async_trait]
impl<F, Fut> StatefulStep for FnStatefulStep<F>
where
    F: Fn(&RunState) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<Vec<(String, Value)>>> + Send + 'static,
{
    /// Delegates directly to the wrapped closure.
    async fn run(&self, state: &RunState) -> Result<Vec<(String, Value)>> {
        (self.0)(state).await
    }
}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::Dag;
    use atomr_agents_state::{AppendMessages, InMemoryCheckpointer, MergeMap, StateSchema};
    use serde_json::json;

    /// Two-channel test schema: `messages` uses the AppendMessages
    /// reducer, `config` uses MergeMap.
    fn schema() -> Arc<StateSchema> {
        Arc::new(
            StateSchema::builder()
                .add("messages", AppendMessages)
                .add("config", MergeMap)
                .build(),
        )
    }

    /// Wrap a synchronous `state -> writes` closure as a `StatefulStep`.
    fn step_writing<F>(write: F) -> Arc<dyn StatefulStep>
    where
        F: Fn(&RunState) -> Vec<(String, Value)> + Send + Sync + 'static,
    {
        Arc::new(FnStatefulStep(move |s: &RunState| {
            // Evaluate eagerly so the future captures only `writes`,
            // not the borrowed state reference.
            let writes = write(s);
            async move { Ok(writes) }
        }))
    }

    #[tokio::test]
    async fn linear_dag_writes_per_super_step() {
        // a -> b: b runs in a later super-step and must observe a's
        // already-merged write.
        let dag: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step(
                "a",
                step_writing(|_| vec![("messages".into(), json!([{"id": "m1", "text": "hi"}]))]),
            )
            .step(
                "b",
                step_writing(|s| {
                    // Count the messages merged by the previous layer.
                    let n = s.read("messages").as_array().map(|v| v.len()).unwrap_or(0);
                    vec![("config".into(), json!({"seen": n}))]
                }),
            )
            .edge("a", "b")
            .build();
        let runner = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer: Arc::new(InMemoryCheckpointer::new()),
        };
        let final_state = runner.run().await.unwrap();
        assert_eq!(final_state.read("messages").as_array().unwrap().len(), 1);
        // b saw exactly the one message a wrote in the earlier layer.
        assert_eq!(final_state.read("config")["seen"], 1);
    }

    #[tokio::test]
    async fn resume_from_checkpoint_skips_completed_layers() {
        // First run completes layer 1; then we run again with the
        // same checkpointer and a step that would corrupt state if
        // re-executed. Resume must skip layer 1.
        let cpt = Arc::new(InMemoryCheckpointer::new());
        // Step b always fails, so run 1 checkpoints layer 1 then aborts.
        let bad: Arc<dyn StatefulStep> = Arc::new(FnStatefulStep(|_s: &RunState| async {
            Err::<Vec<(String, Value)>, _>(AgentError::Workflow("first run dies on b".into()))
        }));
        let dag1: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step(
                "a",
                step_writing(|_| vec![("messages".into(), json!([{"id": "m1"}]))]),
            )
            .step("b", bad)
            .edge("a", "b")
            .build();
        let r1 = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag: dag1,
            schema: schema(),
            checkpointer: cpt.clone(),
        };
        let _ = r1.run().await; // expected to fail on step b
        // Layer 1 should be checkpointed despite the failure.
        let metas = cpt
            .list(&WorkflowId::from("wf"), &RunId::from("r"))
            .await
            .unwrap();
        assert_eq!(metas.len(), 1);
        assert_eq!(metas[0].super_step, 1);

        // Second run: replace b with a benign step. a must NOT re-run
        // (it would dedupe via AppendMessages, so we'll detect by
        // counting how often a runs via a side-channel).
        use std::sync::atomic::{AtomicU32, Ordering};
        let a_runs = Arc::new(AtomicU32::new(0));
        let a_runs2 = a_runs.clone();
        let counted_a: Arc<dyn StatefulStep> = Arc::new(FnStatefulStep(move |_s: &RunState| {
            let c = a_runs2.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok(vec![("messages".into(), json!([{"id": "m1"}]))])
            }
        }));
        let dag2: Dag<Arc<dyn StatefulStep>> = Dag::builder("a")
            .step("a", counted_a)
            .step(
                "b",
                step_writing(|_| vec![("config".into(), json!({"ok": true}))]),
            )
            .edge("a", "b")
            .build();
        let r2 = StatefulRunner {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag: dag2,
            schema: schema(),
            checkpointer: cpt.clone(),
        };
        let final_state = r2.run().await.unwrap();
        // Zero executions of `a` proves layer 1 was skipped on resume.
        assert_eq!(a_runs.load(Ordering::SeqCst), 0);
        assert_eq!(final_state.read("config")["ok"], true);
    }
}
256}