1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use atomr_agents_core::{
6 AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
7};
8
9use crate::dag::{Dag, StepId};
10use crate::event::{Journal, WorkflowEvent};
11use crate::step::{JoinStrategy, Step};
12
13#[derive(Debug, Clone, Default)]
14pub struct WorkflowState {
15 pub completed: HashSet<StepId>,
16 pub outputs: HashMap<StepId, Value>,
17 pub branches: HashMap<StepId, StepId>,
18 pub terminated: Option<bool>,
19}
20
21impl WorkflowState {
22 pub fn fold(events: &[WorkflowEvent]) -> Self {
23 let mut s = WorkflowState::default();
24 for e in events {
25 match e {
26 WorkflowEvent::StepCompleted { step_id, output } => {
27 s.completed.insert(step_id.clone());
28 s.outputs.insert(step_id.clone(), output.clone());
29 }
30 WorkflowEvent::BranchTaken { step_id, chosen } => {
31 s.branches.insert(step_id.clone(), chosen.clone());
32 }
33 WorkflowEvent::Terminated { ok } => {
34 s.terminated = Some(*ok);
35 }
36 _ => {}
37 }
38 }
39 s
40 }
41}
42
43pub struct WorkflowRunner {
44 pub id: WorkflowId,
45 pub dag: Dag<Step>,
46 pub journal: Arc<dyn Journal>,
47}
48
49impl WorkflowRunner {
50 pub fn new(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> Self {
55 Self { id, dag, journal }
56 }
57
58 pub async fn run(&self, input: Value) -> Result<Value> {
59 let history = self.journal.replay(&self.id).await?;
61 let mut state = WorkflowState::fold(&history);
62
63 if let Some(true) = state.terminated {
64 return Ok(self.last_output(&state).unwrap_or(Value::Null));
66 }
67
68 let order = self.dag.topo_sort()?;
69 let mut current_input = input;
70 for step_id in order {
71 if state.completed.contains(&step_id) {
72 continue;
73 }
74 self.journal
75 .append(
76 &self.id,
77 WorkflowEvent::StepStarted {
78 step_id: step_id.clone(),
79 idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
80 },
81 )
82 .await?;
83 let step = self
84 .dag
85 .steps
86 .get(&step_id)
87 .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
88 match self.exec_step(step, ¤t_input, &mut state).await {
89 Ok(out) => {
90 self.journal
91 .append(
92 &self.id,
93 WorkflowEvent::StepCompleted {
94 step_id: step_id.clone(),
95 output: out.clone(),
96 },
97 )
98 .await?;
99 state.completed.insert(step_id.clone());
100 state.outputs.insert(step_id.clone(), out.clone());
101 current_input = out;
102 }
103 Err(e) => {
104 self.journal
105 .append(
106 &self.id,
107 WorkflowEvent::StepFailed {
108 step_id: step_id.clone(),
109 error: e.to_string(),
110 },
111 )
112 .await?;
113 self.journal
114 .append(&self.id, WorkflowEvent::Terminated { ok: false })
115 .await?;
116 return Err(e);
117 }
118 }
119 }
120 self.journal
121 .append(&self.id, WorkflowEvent::Terminated { ok: true })
122 .await?;
123 Ok(self.last_output(&state).unwrap_or(Value::Null))
124 }
125
126 fn last_output(&self, state: &WorkflowState) -> Option<Value> {
127 self.dag.topo_sort().ok().and_then(|order| {
129 order
130 .into_iter()
131 .rev()
132 .find_map(|id| state.outputs.get(&id).cloned())
133 })
134 }
135
136 async fn exec_step(&self, step: &Step, input: &Value, state: &mut WorkflowState) -> Result<Value> {
137 match step {
138 Step::Invoke { callable, mapping: _ } => {
139 let ctx = default_call_ctx();
140 callable.call(input.clone(), ctx).await
141 }
142 Step::Branch {
143 predicate,
144 if_true,
145 if_false,
146 } => {
147 let chosen = if predicate.evaluate(input) {
148 if_true.clone()
149 } else {
150 if_false.clone()
151 };
152 state.branches.insert(StepId::new("__branch__"), chosen.clone());
153 Ok(serde_json::json!({"branch": chosen.as_str()}))
154 }
155 Step::Parallel { steps, join } => {
156 let mut handles = Vec::new();
157 for sid in steps {
158 let s =
159 self.dag.steps.get(sid).ok_or_else(|| {
160 AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
161 })?;
162 if let Step::Invoke { callable, .. } = s {
163 let c = callable.clone();
164 let inp = input.clone();
165 handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
166 } else {
167 return Err(AgentError::Workflow(
168 "parallel currently supports only Invoke children".into(),
169 ));
170 }
171 }
172 let mut outs = Vec::new();
173 let mut first_ok = None;
174 for h in handles {
175 match h.await {
176 Ok(Ok(v)) => {
177 if first_ok.is_none() {
178 first_ok = Some(v.clone());
179 }
180 outs.push(v);
181 }
182 Ok(Err(e)) => match join {
183 JoinStrategy::All => return Err(e),
184 JoinStrategy::Any => continue,
185 },
186 Err(e) => return Err(AgentError::Workflow(e.to_string())),
187 }
188 }
189 match join {
190 JoinStrategy::All => Ok(serde_json::json!(outs)),
191 JoinStrategy::Any => Ok(first_ok.unwrap_or(Value::Null)),
192 }
193 }
194 Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
195 Ok(input.clone())
200 }
201 }
202 }
203}
204
205fn default_call_ctx() -> CallCtx {
206 CallCtx {
207 agent_id: None,
208 tokens: TokenBudget::new(8192),
209 time: TimeBudget::new(Duration::from_secs(60)),
210 money: MoneyBudget::from_usd(1.0),
211 iterations: IterationBudget::new(16),
212 trace: vec![],
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::InMemoryJournal;
    use atomr_agents_callable::{Callable, FnCallable};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Callable that returns its input unchanged.
    fn echo_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
    }

    /// Callable that ignores its input and returns the number of times it has
    /// been called before (0 on the first call), tracked in `state`.
    fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
            let hits = state.clone();
            async move {
                let previous = hits.fetch_add(1, Ordering::SeqCst);
                Ok(serde_json::json!(previous))
            }
        }))
    }

    /// Runner over `dag` backed by a fresh in-memory journal.
    fn fresh_runner(id: &'static str, dag: Dag<Step>) -> WorkflowRunner {
        WorkflowRunner {
            id: WorkflowId::from(id),
            dag,
            journal: Arc::new(InMemoryJournal::new()),
        }
    }

    #[tokio::test]
    async fn new_constructor_builds_runner() {
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .build();
        let runner = WorkflowRunner::new(
            WorkflowId::from("wf-new"),
            dag,
            Arc::new(InMemoryJournal::new()),
        );

        let result = runner.run(serde_json::json!({"k": 7})).await.unwrap();

        assert_eq!(result, serde_json::json!({"k": 7}));
    }

    #[tokio::test]
    async fn happy_path_runs_topo_order() {
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let runner = fresh_runner("wf-1", dag);

        let result = runner.run(serde_json::json!({"x": 1})).await.unwrap();

        assert_eq!(result, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn parallel_all_collects_outputs() {
        let fan_out = Step::Parallel {
            steps: vec![StepId::new("a"), StepId::new("b")],
            join: JoinStrategy::All,
        };
        let dag: Dag<Step> = Dag::builder("p")
            .step("p", fan_out)
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let runner = fresh_runner("wf-2", dag);

        let result = runner.run(serde_json::json!(5)).await.unwrap();

        // Only checks that the workflow produced some non-null output.
        assert!(!result.is_null());
    }

    #[tokio::test]
    async fn replay_resumes_after_partial_failure() {
        let id = WorkflowId::from("wf-resume");
        let counter = Arc::new(AtomicU32::new(0));
        let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());

        // First run: "a" succeeds, "b" always fails.
        let boom: Arc<dyn Callable> =
            Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
                Err(AgentError::Workflow("first run b fails".into()))
            }));
        let failing_dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(boom))
            .edge("a", "b")
            .build();
        let first = WorkflowRunner {
            id: id.clone(),
            dag: failing_dag,
            journal: journal.clone(),
        };
        assert!(first.run(serde_json::json!({})).await.is_err());

        // Copy the history into a new journal, dropping the failure markers.
        let history = journal.replay(&id).await.unwrap();
        let clean = InMemoryJournal::new();
        for event in &history {
            let is_failure_marker = matches!(
                event,
                WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
            );
            if is_failure_marker {
                continue;
            }
            clean.append(&id, event.clone()).await.unwrap();
        }

        // Second run: same "a", but "b" now echoes.
        let healthy_dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let second = WorkflowRunner {
            id,
            dag: healthy_dag,
            journal: Arc::new(clean),
        };
        let result = second.run(serde_json::json!({"v": 1})).await.unwrap();

        // "a" was not re-executed, and "b" echoed the input of the second run.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(result, serde_json::json!({"v": 1}));
    }
}