1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use atomr_agents_core::{
6 AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
7};
8
9use crate::dag::{Dag, StepId};
10use crate::event::{Journal, WorkflowEvent};
11use crate::step::{JoinStrategy, Step};
12
13#[derive(Debug, Clone, Default)]
14pub struct WorkflowState {
15 pub completed: HashSet<StepId>,
16 pub outputs: HashMap<StepId, Value>,
17 pub branches: HashMap<StepId, StepId>,
18 pub terminated: Option<bool>,
19}
20
21impl WorkflowState {
22 pub fn fold(events: &[WorkflowEvent]) -> Self {
23 let mut s = WorkflowState::default();
24 for e in events {
25 match e {
26 WorkflowEvent::StepCompleted { step_id, output } => {
27 s.completed.insert(step_id.clone());
28 s.outputs.insert(step_id.clone(), output.clone());
29 }
30 WorkflowEvent::BranchTaken { step_id, chosen } => {
31 s.branches.insert(step_id.clone(), chosen.clone());
32 }
33 WorkflowEvent::Terminated { ok } => {
34 s.terminated = Some(*ok);
35 }
36 _ => {}
37 }
38 }
39 s
40 }
41}
42
43pub struct WorkflowRunner {
44 pub id: WorkflowId,
45 pub dag: Dag<Step>,
46 pub journal: Arc<dyn Journal>,
47}
48
49impl WorkflowRunner {
50 pub async fn run(&self, input: Value) -> Result<Value> {
51 let history = self.journal.replay(&self.id).await?;
53 let mut state = WorkflowState::fold(&history);
54
55 if let Some(true) = state.terminated {
56 return Ok(self.last_output(&state).unwrap_or(Value::Null));
58 }
59
60 let order = self.dag.topo_sort()?;
61 let mut current_input = input;
62 for step_id in order {
63 if state.completed.contains(&step_id) {
64 continue;
65 }
66 self.journal
67 .append(
68 &self.id,
69 WorkflowEvent::StepStarted {
70 step_id: step_id.clone(),
71 idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
72 },
73 )
74 .await?;
75 let step = self
76 .dag
77 .steps
78 .get(&step_id)
79 .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
80 match self.exec_step(step, ¤t_input, &mut state).await {
81 Ok(out) => {
82 self.journal
83 .append(
84 &self.id,
85 WorkflowEvent::StepCompleted {
86 step_id: step_id.clone(),
87 output: out.clone(),
88 },
89 )
90 .await?;
91 state.completed.insert(step_id.clone());
92 state.outputs.insert(step_id.clone(), out.clone());
93 current_input = out;
94 }
95 Err(e) => {
96 self.journal
97 .append(
98 &self.id,
99 WorkflowEvent::StepFailed {
100 step_id: step_id.clone(),
101 error: e.to_string(),
102 },
103 )
104 .await?;
105 self.journal
106 .append(&self.id, WorkflowEvent::Terminated { ok: false })
107 .await?;
108 return Err(e);
109 }
110 }
111 }
112 self.journal
113 .append(&self.id, WorkflowEvent::Terminated { ok: true })
114 .await?;
115 Ok(self.last_output(&state).unwrap_or(Value::Null))
116 }
117
118 fn last_output(&self, state: &WorkflowState) -> Option<Value> {
119 self.dag.topo_sort().ok().and_then(|order| {
121 order
122 .into_iter()
123 .rev()
124 .find_map(|id| state.outputs.get(&id).cloned())
125 })
126 }
127
128 async fn exec_step(&self, step: &Step, input: &Value, state: &mut WorkflowState) -> Result<Value> {
129 match step {
130 Step::Invoke { callable, mapping: _ } => {
131 let ctx = default_call_ctx();
132 callable.call(input.clone(), ctx).await
133 }
134 Step::Branch {
135 predicate,
136 if_true,
137 if_false,
138 } => {
139 let chosen = if predicate.evaluate(input) {
140 if_true.clone()
141 } else {
142 if_false.clone()
143 };
144 state.branches.insert(StepId::new("__branch__"), chosen.clone());
145 Ok(serde_json::json!({"branch": chosen.as_str()}))
146 }
147 Step::Parallel { steps, join } => {
148 let mut handles = Vec::new();
149 for sid in steps {
150 let s =
151 self.dag.steps.get(sid).ok_or_else(|| {
152 AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
153 })?;
154 if let Step::Invoke { callable, .. } = s {
155 let c = callable.clone();
156 let inp = input.clone();
157 handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
158 } else {
159 return Err(AgentError::Workflow(
160 "parallel currently supports only Invoke children".into(),
161 ));
162 }
163 }
164 let mut outs = Vec::new();
165 let mut first_ok = None;
166 for h in handles {
167 match h.await {
168 Ok(Ok(v)) => {
169 if first_ok.is_none() {
170 first_ok = Some(v.clone());
171 }
172 outs.push(v);
173 }
174 Ok(Err(e)) => match join {
175 JoinStrategy::All => return Err(e),
176 JoinStrategy::Any => continue,
177 },
178 Err(e) => return Err(AgentError::Workflow(e.to_string())),
179 }
180 }
181 match join {
182 JoinStrategy::All => Ok(serde_json::json!(outs)),
183 JoinStrategy::Any => Ok(first_ok.unwrap_or(Value::Null)),
184 }
185 }
186 Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
187 Ok(input.clone())
192 }
193 }
194 }
195}
196
197fn default_call_ctx() -> CallCtx {
198 CallCtx {
199 agent_id: None,
200 tokens: TokenBudget::new(8192),
201 time: TimeBudget::new(Duration::from_secs(60)),
202 money: MoneyBudget::from_usd(1.0),
203 iterations: IterationBudget::new(16),
204 trace: vec![],
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::InMemoryJournal;
    use atomr_agents_callable::{Callable, FnCallable};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Callable that returns its input unchanged.
    fn echo_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
    }

    /// Callable that ignores its input, bumps the shared counter, and returns
    /// the counter's value from before the increment.
    fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
            let hits = state.clone();
            async move {
                let previous = hits.fetch_add(1, Ordering::SeqCst);
                Ok(serde_json::json!(previous))
            }
        }))
    }

    /// Callable that always fails with a workflow error.
    fn failing_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
            Err(atomr_agents_core::AgentError::Workflow(
                "first run b fails".into(),
            ))
        }))
    }

    /// Assembles a runner from its three parts.
    fn runner(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> WorkflowRunner {
        WorkflowRunner { id, dag, journal }
    }

    #[tokio::test]
    async fn happy_path_runs_topo_order() {
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let wf = runner(WorkflowId::from("wf-1"), dag, Arc::new(InMemoryJournal::new()));

        // Two chained echo steps hand the input straight through.
        let result = wf.run(serde_json::json!({"x": 1})).await.unwrap();
        assert_eq!(result, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn parallel_all_collects_outputs() {
        let fan_out = Step::Parallel {
            steps: vec![StepId::new("a"), StepId::new("b")],
            join: JoinStrategy::All,
        };
        let dag: Dag<Step> = Dag::builder("p")
            .step("p", fan_out)
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let wf = runner(WorkflowId::from("wf-2"), dag, Arc::new(InMemoryJournal::new()));

        let result = wf.run(serde_json::json!(5)).await.unwrap();
        assert!(!result.is_null());
    }

    #[tokio::test]
    async fn replay_resumes_after_partial_failure() {
        let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());
        let counter = Arc::new(AtomicU32::new(0));
        let id = WorkflowId::from("wf-resume");

        // First run: `a` succeeds (counter goes 0 -> 1), then `b` fails.
        let failing_dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(failing_callable()))
            .edge("a", "b")
            .build();
        let first = runner(id.clone(), failing_dag, journal.clone());
        assert!(first.run(serde_json::json!({})).await.is_err());

        // Copy the history into a fresh journal, leaving out the failure
        // markers from the first run.
        let history = journal.replay(&id).await.unwrap();
        let clean = InMemoryJournal::new();
        for event in &history {
            let is_failure_marker = matches!(
                event,
                WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
            );
            if is_failure_marker {
                continue;
            }
            clean.append(&id, event.clone()).await.unwrap();
        }

        // Second run over the cleaned journal, with `b` replaced by an echo.
        let fixed_dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let second = runner(id, fixed_dag, Arc::new(clean));
        let result = second.run(serde_json::json!({"v": 1})).await.unwrap();

        // `a` was not executed again, and `b` echoed the second run's input.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(result, serde_json::json!({"v": 1}));
    }
}