//! atomr-agents-workflow 0.13.0
//!
//! Workflow actor: a DAG of `Step`s executed over event-sourced state.
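//!
//! A minimal usage sketch (illustrative only; `my_callable` below is a
//! hypothetical stand-in for any `Arc<dyn Callable>`, and the builder
//! calls mirror this crate's own tests):
//!
//! ```ignore
//! let dag: Dag<Step> = Dag::builder("a")
//!     .step("a", Step::invoke(my_callable))
//!     .build();
//! let runner = WorkflowRunner::new(WorkflowId::from("wf"), dag, Arc::new(InMemoryJournal::new()));
//! let out = runner.run(serde_json::json!({"x": 1})).await?;
//! ```
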
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

use atomr_agents_core::{
    AgentError, CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value, WorkflowId,
};

use crate::dag::{Dag, StepId};
use crate::event::{Journal, WorkflowEvent};
use crate::step::{JoinStrategy, Step};

/// Materialized view of a workflow's journal: completed steps, their
/// outputs, branch decisions, and the terminal flag (if any).
#[derive(Debug, Clone, Default)]
pub struct WorkflowState {
    pub completed: HashSet<StepId>,
    pub outputs: HashMap<StepId, Value>,
    pub branches: HashMap<StepId, StepId>,
    pub terminated: Option<bool>,
}

impl WorkflowState {
    pub fn fold(events: &[WorkflowEvent]) -> Self {
        let mut s = WorkflowState::default();
        for e in events {
            match e {
                WorkflowEvent::StepCompleted { step_id, output } => {
                    s.completed.insert(step_id.clone());
                    s.outputs.insert(step_id.clone(), output.clone());
                }
                WorkflowEvent::BranchTaken { step_id, chosen } => {
                    s.branches.insert(step_id.clone(), chosen.clone());
                }
                WorkflowEvent::Terminated { ok } => {
                    s.terminated = Some(*ok);
                }
                _ => {}
            }
        }
        s
    }
}

/// Drives a `Dag<Step>` to completion, journaling every transition so an
/// interrupted run can resume by replaying its events.
pub struct WorkflowRunner {
    pub id: WorkflowId,
    pub dag: Dag<Step>,
    pub journal: Arc<dyn Journal>,
}

impl WorkflowRunner {
    /// Construct a runner from its three pieces. Equivalent to the
    /// struct literal but exposed as a stable constructor for callers
    /// (notably the Python wrapper) that don't have access to crate-
    /// private types.
    pub fn new(id: WorkflowId, dag: Dag<Step>, journal: Arc<dyn Journal>) -> Self {
        Self { id, dag, journal }
    }

    /// Execute the DAG in topological order, resuming from the journal:
    /// already-completed steps are skipped (with their recorded outputs
    /// restored as the flowing input) and every transition is journaled.
    pub async fn run(&self, input: Value) -> Result<Value> {
        // Resume from the journal if prior state exists.
        let history = self.journal.replay(&self.id).await?;
        let mut state = WorkflowState::fold(&history);

        if let Some(true) = state.terminated {
            // Already done: return last output if present.
            return Ok(self.last_output(&state).unwrap_or(Value::Null));
        }

        let order = self.dag.topo_sort()?;
        let mut current_input = input;
        for step_id in order {
            if state.completed.contains(&step_id) {
                // Replayed-as-completed: restore the recorded output so
                // downstream steps see the same dataflow as the first run.
                if let Some(out) = state.outputs.get(&step_id) {
                    current_input = out.clone();
                }
                continue;
            }
            self.journal
                .append(
                    &self.id,
                    WorkflowEvent::StepStarted {
                        step_id: step_id.clone(),
                        idempotency_key: format!("{}/{}", self.id.as_str(), step_id.as_str()),
                    },
                )
                .await?;
            let step = self
                .dag
                .steps
                .get(&step_id)
                .ok_or_else(|| AgentError::Workflow(format!("unknown step {}", step_id.as_str())))?;
            match self.exec_step(&step_id, step, &current_input, &mut state).await {
                Ok(out) => {
                    self.journal
                        .append(
                            &self.id,
                            WorkflowEvent::StepCompleted {
                                step_id: step_id.clone(),
                                output: out.clone(),
                            },
                        )
                        .await?;
                    state.completed.insert(step_id.clone());
                    state.outputs.insert(step_id.clone(), out.clone());
                    current_input = out;
                }
                Err(e) => {
                    self.journal
                        .append(
                            &self.id,
                            WorkflowEvent::StepFailed {
                                step_id: step_id.clone(),
                                error: e.to_string(),
                            },
                        )
                        .await?;
                    self.journal
                        .append(&self.id, WorkflowEvent::Terminated { ok: false })
                        .await?;
                    return Err(e);
                }
            }
        }
        self.journal
            .append(&self.id, WorkflowEvent::Terminated { ok: true })
            .await?;
        Ok(self.last_output(&state).unwrap_or(Value::Null))
    }

    fn last_output(&self, state: &WorkflowState) -> Option<Value> {
        // Pick output of the topo-last completed step.
        self.dag.topo_sort().ok().and_then(|order| {
            order
                .into_iter()
                .rev()
                .find_map(|id| state.outputs.get(&id).cloned())
        })
    }

    async fn exec_step(
        &self,
        step_id: &StepId,
        step: &Step,
        input: &Value,
        state: &mut WorkflowState,
    ) -> Result<Value> {
        match step {
            Step::Invoke { callable, mapping: _ } => {
                let ctx = default_call_ctx();
                callable.call(input.clone(), ctx).await
            }
            Step::Branch {
                predicate,
                if_true,
                if_false,
            } => {
                let chosen = if predicate.evaluate(input) {
                    if_true.clone()
                } else {
                    if_false.clone()
                };
                // Record the decision under this step's own id (a shared
                // "__branch__" key would let branch steps clobber each
                // other) and journal it so fold() can replay the choice.
                self.journal
                    .append(&self.id, WorkflowEvent::BranchTaken { step_id: step_id.clone(), chosen: chosen.clone() })
                    .await?;
                state.branches.insert(step_id.clone(), chosen.clone());
                Ok(serde_json::json!({"branch": chosen.as_str()}))
            }
            Step::Parallel { steps, join } => {
                let mut handles = Vec::new();
                for sid in steps {
                    let s =
                        self.dag.steps.get(sid).ok_or_else(|| {
                            AgentError::Workflow(format!("parallel: unknown {}", sid.as_str()))
                        })?;
                    if let Step::Invoke { callable, .. } = s {
                        let c = callable.clone();
                        let inp = input.clone();
                        handles.push(tokio::spawn(async move { c.call(inp, default_call_ctx()).await }));
                    } else {
                        return Err(AgentError::Workflow(
                            "parallel currently supports only Invoke children".into(),
                        ));
                    }
                }
                let mut outs = Vec::new();
                let mut first_ok = None;
                for h in handles {
                    match h.await {
                        Ok(Ok(v)) => {
                            if first_ok.is_none() {
                                first_ok = Some(v.clone());
                            }
                            outs.push(v);
                        }
                        Ok(Err(e)) => match join {
                            JoinStrategy::All => return Err(e),
                            JoinStrategy::Any => continue,
                        },
                        Err(e) => return Err(AgentError::Workflow(e.to_string())),
                    }
                }
                match join {
                    JoinStrategy::All => Ok(serde_json::json!(outs)),
                    // Any requires at least one successful child; yielding
                    // Null here would silently mask total failure.
                    JoinStrategy::Any => first_ok.ok_or_else(|| {
                        AgentError::Workflow("parallel: no child succeeded under Any join".into())
                    }),
                }
            }
            Step::Loop { .. } | Step::Map { .. } | Step::Human { .. } => {
                // v0: stub; full support lands in Phase 7 alongside
                // harness integration. Returns the input unchanged
                // so a pinned workflow that doesn't exercise these
                // variants still runs.
                Ok(input.clone())
            }
        }
    }
}

/// Default per-call budgets for steps the runner invokes. The limits are
/// conservative placeholders, not tuned values.
fn default_call_ctx() -> CallCtx {
    CallCtx {
        agent_id: None,
        tokens: TokenBudget::new(8192),
        time: TimeBudget::new(Duration::from_secs(60)),
        money: MoneyBudget::from_usd(1.0),
        iterations: IterationBudget::new(16),
        trace: vec![],
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::InMemoryJournal;
    use atomr_agents_callable::{Callable, FnCallable};
    use std::sync::atomic::{AtomicU32, Ordering};

    fn echo_callable() -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) }))
    }

    fn counter_callable(state: Arc<AtomicU32>) -> Arc<dyn Callable> {
        Arc::new(FnCallable::labeled("counter", move |_v: Value, _ctx| {
            let s = state.clone();
            async move { Ok(serde_json::json!(s.fetch_add(1, Ordering::SeqCst))) }
        }))
    }
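
    // A small sketch of fold()'s replay semantics: state is rebuilt
    // purely from an event slice, with Terminated stored as a flag.
    #[test]
    fn fold_rebuilds_state_from_events() {
        let a = StepId::new("a");
        let events = vec![
            WorkflowEvent::StepCompleted {
                step_id: a.clone(),
                output: serde_json::json!(1),
            },
            WorkflowEvent::Terminated { ok: true },
        ];
        let s = WorkflowState::fold(&events);
        assert!(s.completed.contains(&a));
        assert_eq!(s.outputs.get(&a), Some(&serde_json::json!(1)));
        assert_eq!(s.terminated, Some(true));
    }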

    #[tokio::test]
    async fn new_constructor_builds_runner() {
        let dag: Dag<Step> = Dag::builder("a").step("a", Step::invoke(echo_callable())).build();
        let r = WorkflowRunner::new(WorkflowId::from("wf-new"), dag, Arc::new(InMemoryJournal::new()));
        let out = r.run(serde_json::json!({"k": 7})).await.unwrap();
        assert_eq!(out, serde_json::json!({"k": 7}));
    }
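
    // Sketch of the early-return path in run(): once a workflow has
    // journaled Terminated { ok: true }, a second run() on the same
    // journal returns the cached output without re-executing any step.
    #[tokio::test]
    async fn terminated_workflow_returns_cached_output() {
        let counter = Arc::new(AtomicU32::new(0));
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .build();
        let r = WorkflowRunner::new(WorkflowId::from("wf-done"), dag, Arc::new(InMemoryJournal::new()));
        let first = r.run(serde_json::json!(null)).await.unwrap();
        let second = r.run(serde_json::json!(null)).await.unwrap();
        assert_eq!(first, second);
        // The callable ran exactly once; the second run replayed.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }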

    #[tokio::test]
    async fn happy_path_runs_topo_order() {
        let dag: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let r = WorkflowRunner {
            id: WorkflowId::from("wf-1"),
            dag,
            journal: Arc::new(InMemoryJournal::new()),
        };
        let out = r.run(serde_json::json!({"x": 1})).await.unwrap();
        assert_eq!(out, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn parallel_all_collects_outputs() {
        let dag: Dag<Step> = Dag::builder("p")
            .step(
                "p",
                Step::Parallel {
                    steps: vec![StepId::new("a"), StepId::new("b")],
                    join: JoinStrategy::All,
                },
            )
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let r = WorkflowRunner {
            id: WorkflowId::from("wf-2"),
            dag,
            journal: Arc::new(InMemoryJournal::new()),
        };
        let out = r.run(serde_json::json!(5)).await.unwrap();
        // Parallel happens only when the parent step `p` is the one
        // executed; the other steps still appear in topo order. We
        // accept any non-null output to keep the test scope tight.
        assert!(!out.is_null());
    }
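
    // Companion sketch for JoinStrategy::Any: with two succeeding
    // children the join yields the first success. As above, we only
    // assert a non-null result, since child steps also run standalone
    // in topo order and the final output depends on that order.
    #[tokio::test]
    async fn parallel_any_yields_a_success() {
        let dag: Dag<Step> = Dag::builder("p")
            .step(
                "p",
                Step::Parallel {
                    steps: vec![StepId::new("a"), StepId::new("b")],
                    join: JoinStrategy::Any,
                },
            )
            .step("a", Step::invoke(echo_callable()))
            .step("b", Step::invoke(echo_callable()))
            .build();
        let r = WorkflowRunner::new(WorkflowId::from("wf-3"), dag, Arc::new(InMemoryJournal::new()));
        let out = r.run(serde_json::json!(3)).await.unwrap();
        assert!(!out.is_null());
    }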

    #[tokio::test]
    async fn replay_resumes_after_partial_failure() {
        // First run: step "a" succeeds, "b" fails.
        let journal: Arc<dyn Journal> = Arc::new(InMemoryJournal::new());
        let counter = Arc::new(AtomicU32::new(0));
        let id = WorkflowId::from("wf-resume");

        let dag1: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step(
                "b",
                Step::invoke(Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async {
                    Err(atomr_agents_core::AgentError::Workflow(
                        "first run b fails".into(),
                    ))
                }))),
            )
            .edge("a", "b")
            .build();
        let r1 = WorkflowRunner {
            id: id.clone(),
            dag: dag1,
            journal: journal.clone(),
        };
        assert!(r1.run(serde_json::json!({})).await.is_err());

        // The first run terminated with ok=false. For replay-resume we
        // only treat the workflow as "done" when terminated=true; since
        // it's false we DO want to retry. Adjust journal: drop the
        // Terminated{false} so resume re-runs from b.
        // (In a real system, retry policy would do this filtering;
        // here we assert by replaying *only* the successful events.)
        let history = journal.replay(&id).await.unwrap();
        let clean = InMemoryJournal::new();
        for e in &history {
            if !matches!(
                e,
                WorkflowEvent::Terminated { ok: false } | WorkflowEvent::StepFailed { .. }
            ) {
                clean.append(&id, e.clone()).await.unwrap();
            }
        }
        // Second run with new dag where b succeeds.
        let dag2: Dag<Step> = Dag::builder("a")
            .step("a", Step::invoke(counter_callable(counter.clone())))
            .step("b", Step::invoke(echo_callable()))
            .edge("a", "b")
            .build();
        let r2 = WorkflowRunner {
            id,
            dag: dag2,
            journal: Arc::new(clean),
        };
        let out = r2.run(serde_json::json!({"v": 1})).await.unwrap();
        // Counter stays at 1, *not* incremented again, because step "a"
        // was replayed-as-completed; its recorded output (0, the value
        // fetch_add returned on the first run) is restored and fed to
        // "b", so the echo yields 0 rather than the fresh input.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(out, serde_json::json!(0));
    }
}
}