1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use atomr_agents_core::{
6 AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
7};
8
9use crate::dag::{Dag, StepId};
10use crate::event::{Journal, WorkflowEvent};
11use crate::step::{JoinStrategy, Step};
12
/// Materialized view of a workflow's journal, rebuilt deterministically by
/// folding the event history (see [`WorkflowState::fold`]).
#[derive(Debug, Clone, Default)]
pub struct WorkflowState {
    /// Steps that have recorded a `StepCompleted` event.
    pub completed: HashSet<StepId>,
    /// Latest journaled output for each completed step.
    pub outputs: HashMap<StepId, Value>,
    /// Branch decisions recorded so far: deciding step -> chosen successor.
    pub branches: HashMap<StepId, StepId>,
    /// Set once a `Terminated { ok }` event is seen; `Some(true)` means the
    /// run finished successfully, `None` means it never terminated.
    pub terminated: Option<bool>,
}
20
21impl WorkflowState {
22 pub fn fold(events: &[WorkflowEvent]) -> Self {
23 let mut s = WorkflowState::default();
24 for e in events {
25 match e {
26 WorkflowEvent::StepCompleted { step_id, output } => {
27 s.completed.insert(step_id.clone());
28 s.outputs.insert(step_id.clone(), output.clone());
29 }
30 WorkflowEvent::BranchTaken { step_id, chosen } => {
31 s.branches.insert(step_id.clone(), chosen.clone());
32 }
33 WorkflowEvent::Terminated { ok } => {
34 s.terminated = Some(*ok);
35 }
36 _ => {}
37 }
38 }
39 s
40 }
41}
42
/// Drives a `Dag<Step>` to completion, appending every transition to a
/// journal so that an interrupted run can be resumed by replaying history.
pub struct WorkflowRunner {
    /// Identifier used as the journal key for this workflow instance.
    pub id: WorkflowId,
    /// The step graph; steps execute in topological order.
    pub dag: Dag<Step>,
    /// Append-only event log shared with the replay machinery.
    pub journal: Arc<dyn Journal>,
}
48
49impl WorkflowRunner {
50 pub fn new(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> Self {
55 Self { id, dag, journal }
56 }
57
58 pub async fn run(&self, input: Value) -> Result<Value> {
59 let history = self.journal.replay(&self.id).await?;
61 let mut state = WorkflowState::fold(&history);
62
63 if let Some(true) = state.terminated {
64 return Ok(self.last_output(&state).unwrap_or(Value::Null));
66 }
67
68 let order = self.dag.topo_sort()?;
69 let mut current_input = input;
70 for step_id in order {
71 if state.completed.contains(&step_id) {
72 continue;
73 }
74 self.journal
75 .append(
76 &self.id,
77 WorkflowEvent::StepStarted {
78 step_id: step_id.clone(),
79 idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
80 },
81 )
82 .await?;
83 let step = self
84 .dag
85 .steps
86 .get(&step_id)
87 .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
88 match self.exec_step(step, ¤t_input, &mut state).await {
89 Ok(out) => {
90 self.journal
91 .append(
92 &self.id,
93 WorkflowEvent::StepCompleted {
94 step_id: step_id.clone(),
95 output: out.clone(),
96 },
97 )
98 .await?;
99 state.completed.insert(step_id.clone());
100 state.outputs.insert(step_id.clone(), out.clone());
101 current_input = out;
102 }
103 Err(e) => {
104 self.journal
105 .append(
106 &self.id,
107 WorkflowEvent::StepFailed {
108 step_id: step_id.clone(),
109 error: e.to_string(),
110 },
111 )
112 .await?;
113 self.journal
114 .append(&self.id, WorkflowEvent::Terminated { ok: false })
115 .await?;
116 return Err(e);
117 }
118 }
119 }
120 self.journal
121 .append(&self.id, WorkflowEvent::Terminated { ok: true })
122 .await?;
123 Ok(self.last_output(&state).unwrap_or(Value::Null))
124 }
125
126 fn last_output(&self, state: &WorkflowState) -> Option<Value> {
127 self.dag.topo_sort().ok().and_then(|order| {
129 order
130 .into_iter()
131 .rev()
132 .find_map(|id| state.outputs.get(&id).cloned())
133 })
134 }
135
136 async fn exec_step(&self, step: &Step, input: &Value, state: &mut WorkflowState) -> Result<Value> {
137 match step {
138 Step::Invoke { callable, mapping: _ } => {
139 let ctx = default_call_ctx();
140 callable.call(input.clone(), ctx).await
141 }
142 Step::Branch {
143 predicate,
144 if_true,
145 if_false,
146 } => {
147 let chosen = if predicate.evaluate(input) {
148 if_true.clone()
149 } else {
150 if_false.clone()
151 };
152 state.branches.insert(StepId::new("__branch__"), chosen.clone());
153 Ok(serde_json::json!({"branch": chosen.as_str()}))
154 }
155 Step::Parallel { steps, join } => {
156 let mut handles = Vec::new();
157 for sid in steps {
158 let s =
159 self.dag.steps.get(sid).ok_or_else(|| {
160 AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
161 })?;
162 if let Step::Invoke { callable, .. } = s {
163 let c = callable.clone();
164 let inp = input.clone();
165 handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
166 } else {
167 return Err(AgentError::Workflow(
168 "parallel currently supports only Invoke children".into(),
169 ));
170 }
171 }
172 let mut outs = Vec::new();
173 let mut first_ok = None;
174 for h in handles {
175 match h.await {
176 Ok(Ok(v)) => {
177 if first_ok.is_none() {
178 first_ok = Some(v.clone());
179 }
180 outs.push(v);
181 }
182 Ok(Err(e)) => match join {
183 JoinStrategy::All => return Err(e),
184 JoinStrategy::Any => continue,
185 },
186 Err(e) => return Err(AgentError::Workflow(e.to_string())),
187 }
188 }
189 match join {
190 JoinStrategy::All => Ok(serde_json::json!(outs)),
191 JoinStrategy::Any => Ok(first_ok.unwrap_or(Value::Null)),
192 }
193 }
194 Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
195 Ok(input.clone())
200 }
201 }
202 }
203}
204
205fn default_call_ctx() -> CallCtx {
206 CallCtx {
207 agent_id: None,
208 tokens: TokenBudget::new(8192),
209 time: TimeBudget::new(Duration::from_secs(60)),
210 money: MoneyBudget::from_usd(1.0),
211 iterations: IterationBudget::new(16),
212 trace: vec![],
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::InMemoryJournal;
    use atomr_agents_callable::{Callable, FnCallable};
    use std::sync::atomic::{AtomicU32, Ordering};

    // Callable that returns its input unchanged.
    fn echo_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
    }

    // Callable that returns the pre-increment value of a shared counter.
    // The counter reveals how many times the step actually executed, which
    // the resume test uses to prove a completed step is not re-run.
    fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
            let s = state.clone();
            async move { Ok(serde_json::json!(s.fetch_add(1, Ordering::SeqCst))) }
        }))
    }

    // Single echo step built via the `new` constructor: output == input.
    #[tokio::test]
    async fn new_constructor_builds_runner() {
        let dag: Dag<Step> = Dag::builder("a").step("a", Step::invoke(echo_callable())).build();
        let r = WorkflowRunner::new(WorkflowId::from("wf-new"), dag, Arc::new(InMemoryJournal::new()));
        let out = r.run(serde_json::json!({"k": 7})).await.unwrap();
        assert_eq!(out, serde_json::json!({"k": 7}));
    }

    // a -> b, both echo: the input threads through both steps unchanged.
    #[tokio::test]
    async fn happy_path_runs_topo_order() {
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let r = WorkflowRunner {
            id: WorkflowId::from("wf-1"),
            dag,
            journal: Arc::new(InMemoryJournal::new()),
        };
        let out = r.run(serde_json::json!({"x": 1})).await.unwrap();
        assert_eq!(out, serde_json::json!({"x": 1}));
    }

    // Parallel step with JoinStrategy::All over two echo children: the run
    // must produce some non-null aggregate output.
    #[tokio::test]
    async fn parallel_all_collects_outputs() {
        let dag: Dag<Step> = Dag::builder("p")
            .step(
                "p",
                Step::Parallel {
                    steps: vec![StepId::new("a"), StepId::new("b")],
                    join: JoinStrategy::All,
                },
            )
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let r = WorkflowRunner {
            id: WorkflowId::from("wf-2"),
            dag,
            journal: Arc::new(InMemoryJournal::new()),
        };
        let out = r.run(serde_json::json!(5)).await.unwrap();
        assert!(!out.is_null());
    }

    // First run: step "a" (counter) succeeds, step "b" ("boom") fails, so the
    // journal records a's completion plus the failure markers. The test then
    // rebuilds a "clean" journal with the failure/termination events stripped
    // and resumes under the same workflow id with a fixed DAG.
    #[tokio::test]
    async fn replay_resumes_after_partial_failure() {
        let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());
        let counter = Arc::new(AtomicU32::new(0));
        let id = WorkflowId::from("wf-resume");

        let dag1: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step(
                "b",
                Step::invoke(Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
                    Err(atomr_agents_core::AgentError::Workflow(
                        "first run b fails".into(),
                    ))
                }))),
            )
            .edge("a", "b")
            .build();
        let r1 = WorkflowRunner {
            id: id.clone(),
            dag: dag1,
            journal: journal.clone(),
        };
        assert!(r1.run(serde_json::json!({})).await.is_err());

        // Copy history into a fresh journal, dropping the StepFailed and
        // Terminated{ok:false} events — simulating a journal in which the
        // run stopped mid-flight rather than recording a failed terminal state.
        let history = journal.replay(&id).await.unwrap();
        let clean = InMemoryJournal::new();
        for e in &history {
            if !matches!(
                e,
                WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
            ) {
                clean.append(&id, e.clone()).await.unwrap();
            }
        }
        let dag2: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let r2 = WorkflowRunner {
            id,
            dag: dag2,
            journal: Arc::new(clean),
        };
        let out = r2.run(serde_json::json!({"v": 1})).await.unwrap();
        // counter == 1 proves step "a" executed only in the first run and was
        // skipped on resume; out echoes the resume input because the runner
        // does not restore `current_input` from skipped steps' outputs.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(out, serde_json::json!({"v": 1}));
    }
}