use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use atomr_agents_core::{
AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
};
use crate::dag::{Dag, StepId};
use crate::event::{Journal, WorkflowEvent};
use crate::step::{JoinStrategy, Step};
/// Materialized view of a workflow's journaled history.
///
/// Rebuilt from an ordered event slice via [`WorkflowState::fold`]; the runner
/// uses it to skip already-completed steps when resuming a crashed run.
#[derive(Debug, Clone, Default)]
pub struct WorkflowState {
/// Step ids that have a `StepCompleted` event recorded.
pub completed: HashSet<StepId>,
/// Last recorded output per completed step.
pub outputs: HashMap<StepId, Value>,
/// Branch decisions from `BranchTaken` events: branch step -> chosen successor.
pub branches: HashMap<StepId, StepId>,
/// `Some(ok)` once a `Terminated { ok }` event was seen; `None` while in flight.
pub terminated: Option<bool>,
}
impl WorkflowState {
    /// Rebuild the state by folding an ordered event history.
    ///
    /// Only `StepCompleted`, `BranchTaken`, and `Terminated` events affect the
    /// state; all other event kinds are ignored. A later event for the same
    /// step id overwrites the earlier entry.
    pub fn fold(events: &[WorkflowEvent]) -> Self {
        events.iter().fold(WorkflowState::default(), |mut acc, event| {
            match event {
                WorkflowEvent::StepCompleted { step_id, output } => {
                    acc.completed.insert(step_id.clone());
                    acc.outputs.insert(step_id.clone(), output.clone());
                }
                WorkflowEvent::BranchTaken { step_id, chosen } => {
                    acc.branches.insert(step_id.clone(), chosen.clone());
                }
                WorkflowEvent::Terminated { ok } => {
                    acc.terminated = Some(*ok);
                }
                _ => {}
            }
            acc
        })
    }
}
/// Drives a `Dag<Step>` to completion in topological order, journaling every
/// transition so a failed or interrupted run can be resumed by replay.
pub struct WorkflowRunner {
/// Identifier used as the journal key for this workflow instance.
pub id: WorkflowId,
/// The step graph to execute.
pub dag: Dag<Step>,
/// Append-only event log used both for recording progress and for replay.
pub journal: Arc<dyn Journal>,
}
impl WorkflowRunner {
    /// Create a runner for workflow `id` over `dag`, journaling to `journal`.
    pub fn new(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> Self {
        Self { id, dag, journal }
    }

    /// Execute the workflow to completion, resuming from journaled history.
    ///
    /// Replays the journal first: a workflow that previously terminated with
    /// `ok: true` short-circuits and returns its last recorded output, and
    /// steps with a journaled `StepCompleted` event are skipped. Remaining
    /// steps run in topological order, each step's output becoming the next
    /// step's input. On step failure a `StepFailed` and a
    /// `Terminated { ok: false }` event are appended and the error is returned.
    ///
    /// # Errors
    /// Propagates journal I/O errors, topological-sort errors (e.g. cycles),
    /// unknown step ids, and step execution failures.
    pub async fn run(&self, input: Value) -> Result<Value> {
        let history = self.journal.replay(&self.id).await?;
        let mut state = WorkflowState::fold(&history);
        if let Some(true) = state.terminated {
            // Already finished successfully on a prior run: re-emit the result.
            return Ok(self.last_output(&state).unwrap_or(Value::Null));
        }
        let order = self.dag.topo_sort()?;
        let mut current_input = input;
        for step_id in order {
            if state.completed.contains(&step_id) {
                // Resume support: skip steps whose completion is journaled.
                // NOTE(review): the skipped step's recorded output is NOT fed
                // into `current_input`, so the first resumed step sees the
                // caller's fresh input rather than its predecessor's output.
                // The resume test pins this behavior — confirm it is intended.
                continue;
            }
            self.journal
                .append(
                    &self.id,
                    WorkflowEvent::StepStarted {
                        step_id: step_id.clone(),
                        idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
                    },
                )
                .await?;
            let step = self
                .dag
                .steps
                .get(&step_id)
                .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
            // BUG FIX: this line previously read `¤t_input` — an encoding
            // mangle of `&current_input` (`&curr` collapsed into the `¤`
            // currency sign) — which does not compile.
            match self.exec_step(step, &current_input, &mut state).await {
                Ok(out) => {
                    self.journal
                        .append(
                            &self.id,
                            WorkflowEvent::StepCompleted {
                                step_id: step_id.clone(),
                                output: out.clone(),
                            },
                        )
                        .await?;
                    state.completed.insert(step_id.clone());
                    state.outputs.insert(step_id.clone(), out.clone());
                    current_input = out;
                }
                Err(e) => {
                    self.journal
                        .append(
                            &self.id,
                            WorkflowEvent::StepFailed {
                                step_id: step_id.clone(),
                                error: e.to_string(),
                            },
                        )
                        .await?;
                    self.journal
                        .append(&self.id, WorkflowEvent::Terminated { ok: false })
                        .await?;
                    return Err(e);
                }
            }
        }
        self.journal
            .append(&self.id, WorkflowEvent::Terminated { ok: true })
            .await?;
        Ok(self.last_output(&state).unwrap_or(Value::Null))
    }

    /// Output of the last step (in topological order) that recorded one, if any.
    fn last_output(&self, state: &WorkflowState) -> Option<Value> {
        self.dag.topo_sort().ok().and_then(|order| {
            order
                .into_iter()
                .rev()
                .find_map(|id| state.outputs.get(&id).cloned())
        })
    }

    /// Execute a single step against `input`, mutating `state` for branch
    /// bookkeeping. Returns the step's output value.
    async fn exec_step(&self, step: &Step, input: &Value, state: &mut WorkflowState) -> Result<Value> {
        match step {
            Step::Invoke { callable, mapping: _ } => {
                let ctx = default_call_ctx();
                callable.call(input.clone(), ctx).await
            }
            Step::Branch {
                predicate,
                if_true,
                if_false,
            } => {
                let chosen = if predicate.evaluate(input) {
                    if_true.clone()
                } else {
                    if_false.clone()
                };
                // NOTE(review): the decision is recorded in-memory under a
                // synthetic "__branch__" key and no `BranchTaken` event is
                // journaled, so branch choices are lost on replay — confirm.
                state.branches.insert(StepId::new("__branch__"), chosen.clone());
                Ok(serde_json::json!({"branch": chosen.as_str()}))
            }
            Step::Parallel { steps, join } => {
                // Spawn every child concurrently; only Invoke children are supported.
                let mut handles = Vec::new();
                for sid in steps {
                    let s =
                        self.dag.steps.get(sid).ok_or_else(|| {
                            AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
                        })?;
                    if let Step::Invoke { callable, .. } = s {
                        let c = callable.clone();
                        let inp = input.clone();
                        handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
                    } else {
                        return Err(AgentError::Workflow(
                            "parallel currently supports only Invoke children".into(),
                        ));
                    }
                }
                let mut outs = Vec::new();
                let mut first_ok = None;
                for h in handles {
                    match h.await {
                        Ok(Ok(v)) => {
                            if first_ok.is_none() {
                                first_ok = Some(v.clone());
                            }
                            outs.push(v);
                        }
                        // Child error: `All` fails fast (remaining spawned
                        // tasks are left running detached); `Any` tolerates it.
                        Ok(Err(e)) => match join {
                            JoinStrategy::All => return Err(e),
                            JoinStrategy::Any => continue,
                        },
                        // Join error (panic or cancellation) is always fatal.
                        Err(e) => return Err(AgentError::Workflow(e.to_string())),
                    }
                }
                match join {
                    JoinStrategy::All => Ok(serde_json::json!(outs)),
                    // NOTE(review): if every child fails under `Any`, this
                    // silently yields `Null` instead of an error — confirm.
                    JoinStrategy::Any => Ok(first_ok.unwrap_or(Value::Null)),
                }
            }
            Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
                // Not yet implemented: these step kinds pass input through unchanged.
                Ok(input.clone())
            }
        }
    }
}
/// Default per-step execution context: 8192-token, 60-second, $1.00, and
/// 16-iteration budgets, with no agent id and an empty trace.
fn default_call_ctx() -> CallCtx {
    let tokens = TokenBudget::new(8192);
    let time = TimeBudget::new(Duration::from_secs(60));
    let money = MoneyBudget::from_usd(1.0);
    let iterations = IterationBudget::new(16);
    CallCtx {
        agent_id: None,
        tokens,
        time,
        money,
        iterations,
        trace: Vec::new(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::event::InMemoryJournal;
use atomr_agents_callable::{Callable, FnCallable};
use std::sync::atomic::{AtomicU32, Ordering};
// Callable that returns its input unchanged.
fn echo_callable() -> Arc<dyn Callable> {
Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
}
// Callable that returns the shared counter's pre-increment value; the counter
// lets tests assert how many times the step actually executed.
fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
let s = state.clone();
async move { Ok(serde_json::json!(s.fetch_add(1, Ordering::SeqCst))) }
}))
}
// `WorkflowRunner::new` wires fields correctly and a single echo step
// round-trips the input.
#[tokio::test]
async fn new_constructor_builds_runner() {
let dag: Dag<Step> = Dag::builder("a").step("a", Step::invoke(echo_callable())).build();
let r = WorkflowRunner::new(WorkflowId::from("wf-new"), dag, Arc::new(InMemoryJournal::new()));
let out = r.run(serde_json::json!({"k": 7})).await.unwrap();
assert_eq!(out, serde_json::json!({"k": 7}));
}
// Two chained echo steps run in topological order; the final output is the
// original input passed through both.
#[tokio::test]
async fn happy_path_runs_topo_order() {
let dag: Dag<Step> = Dag::builder("a")
.step("a", Step::invoke(echo_callable()))
.step("b", Step::invoke(echo_callable()))
.edge("a", "b")
.build();
let r = WorkflowRunner {
id: WorkflowId::from("wf-1"),
dag,
journal: Arc::new(InMemoryJournal::new()),
};
let out = r.run(serde_json::json!({"x": 1})).await.unwrap();
assert_eq!(out, serde_json::json!({"x": 1}));
}
// A Parallel step with JoinStrategy::All produces a non-null (array) result
// from its Invoke children.
#[tokio::test]
async fn parallel_all_collects_outputs() {
let dag: Dag<Step> = Dag::builder("p")
.step(
"p",
Step::Parallel {
steps: vec![StepId::new("a"), StepId::new("b")],
join: JoinStrategy::All,
},
)
.step("a", Step::invoke(echo_callable()))
.step("b", Step::invoke(echo_callable()))
.build();
let r = WorkflowRunner {
id: WorkflowId::from("wf-2"),
dag,
journal: Arc::new(InMemoryJournal::new()),
};
let out = r.run(serde_json::json!(5)).await.unwrap();
assert!(!out.is_null());
}
// Resume semantics: run once with step "b" failing, then rebuild a journal
// containing only the non-failure events and run again with a fixed "b".
// Step "a" must not re-execute (counter stays at 1) and the second run's
// output is the fresh input echoed by "b".
#[tokio::test]
async fn replay_resumes_after_partial_failure() {
let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());
let counter = Arc::new(AtomicU32::new(0));
let id = WorkflowId::from("wf-resume");
let dag1: Dag<Step> = Dag::builder("a")
.step("a", Step::invoke(counter_callable(counter.clone())))
.step(
"b",
Step::invoke(Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
Err(atomr_agents_core::AgentError::Workflow(
"first run b fails".into(),
))
}))),
)
.edge("a", "b")
.build();
let r1 = WorkflowRunner {
id: id.clone(),
dag: dag1,
journal: journal.clone(),
};
assert!(r1.run(serde_json::json!({})).await.is_err());
let history = journal.replay(&id).await.unwrap();
// Copy the history into a fresh journal, dropping the failure markers so
// the second run does not short-circuit on `Terminated { ok: false }`.
let clean = InMemoryJournal::new();
for e in &history {
if !matches!(
e,
WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
) {
clean.append(&id, e.clone()).await.unwrap();
}
}
let dag2: Dag<Step> = Dag::builder("a")
.step("a", Step::invoke(counter_callable(counter.clone())))
.step("b", Step::invoke(echo_callable()))
.edge("a", "b")
.build();
let r2 = WorkflowRunner {
id,
dag: dag2,
journal: Arc::new(clean),
};
let out = r2.run(serde_json::json!({"v": 1})).await.unwrap();
// "a" completed in run 1 and is skipped in run 2, so it executed once.
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert_eq!(out, serde_json::json!({"v": 1}));
}
}