Skip to main content

atomr_agents_workflow/
runner.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use atomr_agents_core::{
6    AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
7};
8
9use crate::dag::{Dag, StepId};
10use crate::event::{Journal, WorkflowEvent};
11use crate::step::{JoinStrategy, Step};
12
13#[derive(Debug, Clone, Default)]
14pub struct WorkflowState {
15    pub completed: HashSet<StepId>,
16    pub outputs: HashMap<StepId, Value>,
17    pub branches: HashMap<StepId, StepId>,
18    pub terminated: Option<bool>,
19}
20
21impl WorkflowState {
22    pub fn fold(events: &[WorkflowEvent]) -> Self {
23        let mut s = WorkflowState::default();
24        for e in events {
25            match e {
26                WorkflowEvent::StepCompleted { step_id, output } => {
27                    s.completed.insert(step_id.clone());
28                    s.outputs.insert(step_id.clone(), output.clone());
29                }
30                WorkflowEvent::BranchTaken { step_id, chosen } => {
31                    s.branches.insert(step_id.clone(), chosen.clone());
32                }
33                WorkflowEvent::Terminated { ok } => {
34                    s.terminated = Some(*ok);
35                }
36                _ => {}
37            }
38        }
39        s
40    }
41}
42
43pub struct WorkflowRunner {
44    pub id: WorkflowId,
45    pub dag: Dag<Step>,
46    pub journal: Arc<dyn Journal>,
47}
48
49impl WorkflowRunner {
50    /// Construct a runner from its three pieces. Equivalent to the
51    /// struct literal but exposed as a stable constructor for callers
52    /// (notably the Python wrapper) that don't have access to crate-
53    /// private types.
54    pub fn new(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> Self {
55        Self { id, dag, journal }
56    }
57
58    pub async fn run(&self, input: Value) -> Result<Value> {
59        // Resume from journal if state exists.
60        let history = self.journal.replay(&self.id).await?;
61        let mut state = WorkflowState::fold(&history);
62
63        if let Some(true) = state.terminated {
64            // Already done: return last output if present.
65            return Ok(self.last_output(&state).unwrap_or(Value::Null));
66        }
67
68        let order = self.dag.topo_sort()?;
69        let mut current_input = input;
70        for step_id in order {
71            if state.completed.contains(&step_id) {
72                continue;
73            }
74            self.journal
75                .append(
76                    &self.id,
77                    WorkflowEvent::StepStarted {
78                        step_id: step_id.clone(),
79                        idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
80                    },
81                )
82                .await?;
83            let step = self
84                .dag
85                .steps
86                .get(&step_id)
87                .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
88            match self.exec_step(step, &current_input, &mut state).await {
89                Ok(out) => {
90                    self.journal
91                        .append(
92                            &self.id,
93                            WorkflowEvent::StepCompleted {
94                                step_id: step_id.clone(),
95                                output: out.clone(),
96                            },
97                        )
98                        .await?;
99                    state.completed.insert(step_id.clone());
100                    state.outputs.insert(step_id.clone(), out.clone());
101                    current_input = out;
102                }
103                Err(e) => {
104                    self.journal
105                        .append(
106                            &self.id,
107                            WorkflowEvent::StepFailed {
108                                step_id: step_id.clone(),
109                                error: e.to_string(),
110                            },
111                        )
112                        .await?;
113                    self.journal
114                        .append(&self.id, WorkflowEvent::Terminated { ok: false })
115                        .await?;
116                    return Err(e);
117                }
118            }
119        }
120        self.journal
121            .append(&self.id, WorkflowEvent::Terminated { ok: true })
122            .await?;
123        Ok(self.last_output(&state).unwrap_or(Value::Null))
124    }
125
126    fn last_output(&self, state: &WorkflowState) -> Option<Value> {
127        // Pick output of the topo-last completed step.
128        self.dag.topo_sort().ok().and_then(|order| {
129            order
130                .into_iter()
131                .rev()
132                .find_map(|id| state.outputs.get(&id).cloned())
133        })
134    }
135
136    async fn exec_step(&self, step: &Step, input: &Value, state: &mut WorkflowState) -> Result<Value> {
137        match step {
138            Step::Invoke { callable, mapping: _ } => {
139                let ctx = default_call_ctx();
140                callable.call(input.clone(), ctx).await
141            }
142            Step::Branch {
143                predicate,
144                if_true,
145                if_false,
146            } => {
147                let chosen = if predicate.evaluate(input) {
148                    if_true.clone()
149                } else {
150                    if_false.clone()
151                };
152                state.branches.insert(StepId::new("__branch__"), chosen.clone());
153                Ok(serde_json::json!({"branch": chosen.as_str()}))
154            }
155            Step::Parallel { steps, join } => {
156                let mut handles = Vec::new();
157                for sid in steps {
158                    let s =
159                        self.dag.steps.get(sid).ok_or_else(|| {
160                            AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
161                        })?;
162                    if let Step::Invoke { callable, .. } = s {
163                        let c = callable.clone();
164                        let inp = input.clone();
165                        handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
166                    } else {
167                        return Err(AgentError::Workflow(
168                            "parallel currently supports only Invoke children".into(),
169                        ));
170                    }
171                }
172                let mut outs = Vec::new();
173                let mut first_ok = None;
174                for h in handles {
175                    match h.await {
176                        Ok(Ok(v)) => {
177                            if first_ok.is_none() {
178                                first_ok = Some(v.clone());
179                            }
180                            outs.push(v);
181                        }
182                        Ok(Err(e)) => match join {
183                            JoinStrategy::All => return Err(e),
184                            JoinStrategy::Any => continue,
185                        },
186                        Err(e) => return Err(AgentError::Workflow(e.to_string())),
187                    }
188                }
189                match join {
190                    JoinStrategy::All => Ok(serde_json::json!(outs)),
191                    JoinStrategy::Any => Ok(first_ok.unwrap_or(Value::Null)),
192                }
193            }
194            Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
195                // v0: stub; full support lands in Phase 7 alongside
196                // harness integration. Returns the input unchanged
197                // so a pinned workflow that doesn't exercise these
198                // variants still runs.
199                Ok(input.clone())
200            }
201        }
202    }
203}
204
205fn default_call_ctx() -> CallCtx {
206    CallCtx {
207        agent_id: None,
208        tokens: TokenBudget::new(8192),
209        time: TimeBudget::new(Duration::from_secs(60)),
210        money: MoneyBudget::from_usd(1.0),
211        iterations: IterationBudget::new(16),
212        trace: vec![],
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::InMemoryJournal;
    use atomr_agents_callable::{Callable, FnCallable};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Callable that hands its input straight back.
    fn echo_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
    }

    /// Callable that ignores its input and returns the pre-increment value of
    /// the shared counter, bumping it by one on every call.
    fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
            let hits = state.clone();
            async move { Ok(serde_json::json!(hits.fetch_add(1, Ordering::SeqCst))) }
        }))
    }

    /// `WorkflowRunner::new` yields a runner that executes a one-step graph.
    #[tokio::test]
    async fn new_constructor_builds_runner() {
        let graph: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .build();
        let runner = WorkflowRunner::new(
            WorkflowId::from("wf-new"),
            graph,
            Arc::new(InMemoryJournal::new()),
        );

        let result = runner.run(serde_json::json!({"k": 7})).await.unwrap();

        assert_eq!(result, serde_json::json!({"k": 7}));
    }

    /// Two chained echo steps pass the input through unchanged.
    #[tokio::test]
    async fn happy_path_runs_topo_order() {
        let graph: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let runner = WorkflowRunner::new(
            WorkflowId::from("wf-1"),
            graph,
            Arc::new(InMemoryJournal::new()),
        );

        let result = runner.run(serde_json::json!({"x": 1})).await.unwrap();

        assert_eq!(result, serde_json::json!({"x": 1}));
    }

    /// A `Parallel` step with `JoinStrategy::All` produces some output.
    #[tokio::test]
    async fn parallel_all_collects_outputs() {
        let fan_out = Step::Parallel {
            steps: vec![StepId::new("a"), StepId::new("b")],
            join: JoinStrategy::All,
        };
        let graph: Dag<Step> = Dag::builder("p")
            .step("p", fan_out)
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let runner = WorkflowRunner::new(
            WorkflowId::from("wf-2"),
            graph,
            Arc::new(InMemoryJournal::new()),
        );

        let result = runner.run(serde_json::json!(5)).await.unwrap();

        // The children of `p` are also ordinary steps of the graph, so they
        // show up in the topological order as well; which step ends up
        // providing the final output is therefore not pinned here. Only
        // require that something non-null came out.
        assert!(!result.is_null());
    }

    /// A second run over a cleaned-up journal skips the step that already
    /// completed and only executes the one that failed before.
    #[tokio::test]
    async fn replay_resumes_after_partial_failure() {
        let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());
        let calls = Arc::new(AtomicU32::new(0));
        let id = WorkflowId::from("wf-resume");

        // First run: "a" succeeds, "b" fails.
        let failing = Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
            Err(AgentError::Workflow("first run b fails".into()))
        }));
        let first_graph: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(calls.clone())))
            .step("b", Step::invoke(failing))
            .edge("a", "b")
            .build();
        let first = WorkflowRunner::new(id.clone(), first_graph, journal.clone());
        assert!(first.run(serde_json::json!({})).await.is_err());

        // The first run ended with Terminated{ok: false}. Build a fresh
        // journal holding only the non-failure events, standing in for what
        // a retry policy would do in a real system, so the resume picks up
        // at "b".
        let history = journal.replay(&id).await.unwrap();
        let clean = InMemoryJournal::new();
        for event in &history {
            let is_failure = matches!(
                event,
                WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
            );
            if is_failure {
                continue;
            }
            clean.append(&id, event.clone()).await.unwrap();
        }

        // Second run: same shape, but "b" now succeeds.
        let second_graph: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(calls.clone())))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let second = WorkflowRunner::new(id, second_graph, Arc::new(clean));
        let result = second.run(serde_json::json!({"v": 1})).await.unwrap();

        // "a" was replayed as completed, so the counter was bumped exactly
        // once (by the first run) and not again.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(result, serde_json::json!({"v": 1}));
    }
}