// brainwires_agents/workflow.rs
1//! Workflow Graph Builder — Declarative DAG-based workflow pipelines
2//!
3//! Provides a [`WorkflowBuilder`] API for defining multi-step workflows as
4//! directed acyclic graphs (DAGs).  Workflows compile down to `TaskSpec`
5//! vectors and execute via the existing `TaskOrchestrator`.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use brainwires_agents::workflow::{WorkflowBuilder, WorkflowContext};
11//!
12//! let workflow = WorkflowBuilder::new("review-pipeline")
13//!     .node("fetch", |ctx| Box::pin(async move {
14//!         ctx.set("code", serde_json::json!("fn main() {}")).await;
15//!         Ok(serde_json::json!({"status": "fetched"}))
16//!     }))
17//!     .node("lint", |ctx| Box::pin(async move {
18//!         let code = ctx.get("code").await;
19//!         Ok(serde_json::json!({"lint": "passed"}))
20//!     }))
21//!     .node("review", |ctx| Box::pin(async move {
22//!         Ok(serde_json::json!({"review": "approved"}))
23//!     }))
24//!     .edge("fetch", "lint")
25//!     .edge("fetch", "review")   // lint + review run in parallel
26//!     .edge("lint", "summarize")
27//!     .edge("review", "summarize")
28//!     .node("summarize", |ctx| Box::pin(async move {
29//!         Ok(serde_json::json!({"summary": "all good"}))
30//!     }))
31//!     .build()
32//!     .unwrap();
33//!
34//! let results = workflow.run().await.unwrap();
35//! ```
36
37use std::collections::{HashMap, HashSet};
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41
42use anyhow::{Result, anyhow};
43use petgraph::algo::is_cyclic_directed;
44use petgraph::graph::{DiGraph, NodeIndex};
45use serde_json::Value;
46use tokio::sync::RwLock;
47
48// ── Shared workflow state ────────────────────────────────────────────────────
49
/// Shared state accessible to all workflow nodes during execution.
///
/// Nodes read and write values to this shared map to pass data between
/// pipeline stages.
///
/// Cloning is cheap: both maps live behind `Arc<RwLock<…>>`, so clones
/// share the same underlying state (this is what lets spawned node tasks
/// observe each other's writes).
#[derive(Clone)]
pub struct WorkflowContext {
    // Free-form key/value scratch space shared by all nodes.
    state: Arc<RwLock<HashMap<String, Value>>>,
    /// Per-node results, keyed by node name.
    results: Arc<RwLock<HashMap<String, Value>>>,
}
60
61impl WorkflowContext {
62    /// Create a new empty context.
63    pub fn new() -> Self {
64        Self {
65            state: Arc::new(RwLock::new(HashMap::new())),
66            results: Arc::new(RwLock::new(HashMap::new())),
67        }
68    }
69
70    /// Set a value in the shared state.
71    pub async fn set(&self, key: impl Into<String>, value: Value) {
72        self.state.write().await.insert(key.into(), value);
73    }
74
75    /// Get a value from the shared state.
76    pub async fn get(&self, key: &str) -> Option<Value> {
77        self.state.read().await.get(key).cloned()
78    }
79
80    /// Remove a value from the shared state.
81    pub async fn remove(&self, key: &str) -> Option<Value> {
82        self.state.write().await.remove(key)
83    }
84
85    /// Get the result of a previously completed node.
86    pub async fn node_result(&self, node_name: &str) -> Option<Value> {
87        self.results.read().await.get(node_name).cloned()
88    }
89
90    /// Store a node's result (called internally by the executor).
91    async fn store_result(&self, node_name: impl Into<String>, value: Value) {
92        self.results.write().await.insert(node_name.into(), value);
93    }
94
95    /// Get all results as a map.
96    pub async fn all_results(&self) -> HashMap<String, Value> {
97        self.results.read().await.clone()
98    }
99}
100
101impl Default for WorkflowContext {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107// ── Node function type ───────────────────────────────────────────────────────
108
/// A boxed async function that a workflow node executes.
///
/// The handler receives a clone of the shared [`WorkflowContext`] and yields
/// a `Result<Value>`; the future must be `Send` because nodes are driven on
/// spawned tasks.
pub type NodeFn = Box<
    dyn Fn(WorkflowContext) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync,
>;

/// A conditional edge function that returns the name of the next node(s)
/// to activate based on the current node's result.
pub type ConditionalFn = Box<dyn Fn(&Value) -> Vec<String> + Send + Sync>;
117
118// ── Internal node representation ─────────────────────────────────────────────
119
/// A single named node as registered with the builder (pre-compilation).
struct WorkflowNode {
    // Unique node name; used as the key for edges, results, and handlers.
    name: String,
    // The async work this node performs when scheduled.
    handler: NodeFn,
}
124
/// An edge declaration recorded by the builder; validated in `build()`.
enum EdgeType {
    /// Always-active edge from `from` to `to`.
    Direct { from: String, to: String },
    /// Conditional edge: `evaluator` returns which downstream nodes to activate.
    Conditional {
        from: String,
        evaluator: ConditionalFn,
    },
}
134
135// ── WorkflowBuilder ──────────────────────────────────────────────────────────
136
/// Builder for constructing workflow DAGs.
///
/// Add nodes with [`node`][Self::node], wire them with [`edge`][Self::edge]
/// or [`conditional`][Self::conditional], then call [`build`][Self::build].
pub struct WorkflowBuilder {
    // Workflow name, carried through to the compiled `Workflow`.
    name: String,
    // Nodes in registration order (order is irrelevant to scheduling).
    nodes: Vec<WorkflowNode>,
    // Set of registered names; becomes `Workflow::all_nodes`.
    node_names: HashSet<String>,
    // Raw edge declarations, validated during `build()`.
    edges: Vec<EdgeType>,
}
147
148impl WorkflowBuilder {
149    /// Create a new workflow builder with the given name.
150    pub fn new(name: impl Into<String>) -> Self {
151        Self {
152            name: name.into(),
153            nodes: Vec::new(),
154            node_names: HashSet::new(),
155            edges: Vec::new(),
156        }
157    }
158
159    /// Add a node to the workflow.
160    ///
161    /// The handler receives a [`WorkflowContext`] and returns `Result<Value>`.
162    /// Nodes with no incoming edges are considered entry points and run first.
163    pub fn node<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
164    where
165        F: Fn(WorkflowContext) -> Fut + Send + Sync + 'static,
166        Fut: Future<Output = Result<Value>> + Send + 'static,
167    {
168        let name = name.into();
169        self.node_names.insert(name.clone());
170        self.nodes.push(WorkflowNode {
171            name,
172            handler: Box::new(move |ctx| Box::pin(handler(ctx))),
173        });
174        self
175    }
176
177    /// Add a direct edge from one node to another.
178    ///
179    /// The `to` node will only execute after `from` completes successfully.
180    /// Multiple edges into the same node create a join (all predecessors must
181    /// complete).
182    pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
183        self.edges.push(EdgeType::Direct {
184            from: from.into(),
185            to: to.into(),
186        });
187        self
188    }
189
190    /// Add a conditional edge from a node.
191    ///
192    /// After `from` completes, `evaluator` is called with the node's result
193    /// value. It returns a list of downstream node names to activate. Nodes
194    /// not in the returned list are skipped (treated as completed with a
195    /// null result).
196    pub fn conditional<F>(mut self, from: impl Into<String>, evaluator: F) -> Self
197    where
198        F: Fn(&Value) -> Vec<String> + Send + Sync + 'static,
199    {
200        self.edges.push(EdgeType::Conditional {
201            from: from.into(),
202            evaluator: Box::new(evaluator),
203        });
204        self
205    }
206
207    /// Validate and build the workflow.
208    ///
209    /// Returns an error if:
210    /// - An edge references a node that does not exist
211    /// - The graph contains a cycle
212    /// - There are no nodes
213    pub fn build(self) -> Result<Workflow> {
214        if self.nodes.is_empty() {
215            return Err(anyhow!("Workflow '{}' has no nodes", self.name));
216        }
217
218        // Build petgraph for validation
219        let mut graph = DiGraph::<String, ()>::new();
220        let mut name_to_idx: HashMap<String, NodeIndex> = HashMap::new();
221
222        for node in &self.nodes {
223            let idx = graph.add_node(node.name.clone());
224            name_to_idx.insert(node.name.clone(), idx);
225        }
226
227        // Validate and collect edges
228        let mut direct_edges: Vec<(String, String)> = Vec::new();
229        let mut conditional_edges: Vec<(String, ConditionalFn)> = Vec::new();
230
231        for edge in self.edges {
232            match edge {
233                EdgeType::Direct { from, to } => {
234                    if !name_to_idx.contains_key(&from) {
235                        return Err(anyhow!("Edge references unknown source node '{}'", from));
236                    }
237                    if !name_to_idx.contains_key(&to) {
238                        return Err(anyhow!("Edge references unknown target node '{}'", to));
239                    }
240                    graph.add_edge(name_to_idx[&from], name_to_idx[&to], ());
241                    direct_edges.push((from, to));
242                }
243                EdgeType::Conditional { from, evaluator } => {
244                    if !name_to_idx.contains_key(&from) {
245                        return Err(anyhow!(
246                            "Conditional edge references unknown source node '{}'",
247                            from
248                        ));
249                    }
250                    conditional_edges.push((from, evaluator));
251                }
252            }
253        }
254
255        if is_cyclic_directed(&graph) {
256            return Err(anyhow!("Workflow '{}' contains a cycle", self.name));
257        }
258
259        // Identify entry nodes (no incoming direct edges)
260        let targets: HashSet<&str> = direct_edges.iter().map(|(_, t)| t.as_str()).collect();
261        let entry_nodes: Vec<String> = self
262            .nodes
263            .iter()
264            .map(|n| &n.name)
265            .filter(|n| !targets.contains(n.as_str()))
266            .cloned()
267            .collect();
268
269        if entry_nodes.is_empty() {
270            return Err(anyhow!(
271                "Workflow '{}' has no entry nodes (every node has an incoming edge)",
272                self.name
273            ));
274        }
275
276        // Build handler map
277        let mut handlers: HashMap<String, NodeFn> = HashMap::new();
278        for node in self.nodes {
279            handlers.insert(node.name, node.handler);
280        }
281
282        Ok(Workflow {
283            name: self.name,
284            handlers: Arc::new(handlers),
285            direct_edges,
286            conditional_edges: Arc::new(conditional_edges),
287            entry_nodes,
288            all_nodes: self.node_names,
289        })
290    }
291}
292
293// ── Compiled Workflow ────────────────────────────────────────────────────────
294
/// A compiled workflow ready for execution.
///
/// Created by [`WorkflowBuilder::build`]. Execute with [`run`][Self::run].
///
/// Note: `Debug` is implemented manually because the handler functions
/// are not `Debug`.
pub struct Workflow {
    // Human-readable workflow name (from the builder).
    name: String,
    // Node handlers keyed by name; `Arc` so spawned tasks can share them.
    handlers: Arc<HashMap<String, NodeFn>>,
    // Always-active `(from, to)` dependency pairs.
    direct_edges: Vec<(String, String)>,
    // Conditional routing functions, each paired with its source node.
    conditional_edges: Arc<Vec<(String, ConditionalFn)>>,
    // Nodes with no incoming direct edges; these run first.
    entry_nodes: Vec<String>,
    // Every node name registered with the builder.
    all_nodes: HashSet<String>,
}
309
310impl std::fmt::Debug for Workflow {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        f.debug_struct("Workflow")
313            .field("name", &self.name)
314            .field("entry_nodes", &self.entry_nodes)
315            .field("all_nodes", &self.all_nodes)
316            .field("direct_edges", &self.direct_edges)
317            .field("handlers", &format!("<{} handlers>", self.handlers.len()))
318            .finish()
319    }
320}
321
/// Result of a completed workflow execution.
#[derive(Debug, Clone)]
pub struct WorkflowResult {
    /// Workflow name.
    pub name: String,
    /// Whether all executed nodes succeeded.
    pub success: bool,
    /// Per-node results (only includes nodes that actually ran).
    pub node_results: HashMap<String, Value>,
    /// Nodes that were skipped (via conditional edges).
    // NOTE: also includes nodes skipped because a direct predecessor failed.
    pub skipped_nodes: Vec<String>,
    /// Nodes that failed, with their error messages.
    pub failed_nodes: HashMap<String, String>,
}
336
337impl Workflow {
338    /// Execute the workflow.
339    ///
340    /// Entry nodes (no incoming edges) run first. Downstream nodes run as
341    /// their dependencies complete. Nodes sharing the same set of completed
342    /// dependencies run concurrently via [`tokio::spawn`].
343    pub async fn run(&self) -> Result<WorkflowResult> {
344        self.run_with_context(WorkflowContext::new()).await
345    }
346
347    /// Execute the workflow with a pre-populated context.
348    pub async fn run_with_context(&self, ctx: WorkflowContext) -> Result<WorkflowResult> {
349        let completed: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
350        let failed: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));
351        let skipped: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
352
353        // Build dependency map: node -> set of predecessors
354        let mut deps: HashMap<String, HashSet<String>> = HashMap::new();
355        for node in &self.all_nodes {
356            deps.insert(node.clone(), HashSet::new());
357        }
358        for (from, to) in &self.direct_edges {
359            deps.entry(to.clone()).or_default().insert(from.clone());
360        }
361
362        loop {
363            // First, propagate failures: any pending node whose predecessor
364            // failed gets skipped.
365            {
366                let done = completed.read().await;
367                let fail = failed.read().await;
368                let skip = skipped.read().await;
369                let mut to_skip = Vec::new();
370                for (name, predecessors) in &deps {
371                    if done.contains(name) || fail.contains_key(name) || skip.contains(name) {
372                        continue;
373                    }
374                    if predecessors.iter().any(|p| fail.contains_key(p)) {
375                        to_skip.push(name.clone());
376                    }
377                }
378                drop(done);
379                drop(fail);
380                drop(skip);
381                if !to_skip.is_empty() {
382                    let mut skip_guard = skipped.write().await;
383                    for name in to_skip {
384                        skip_guard.insert(name);
385                    }
386                }
387            }
388
389            let ready: Vec<String> = {
390                let done = completed.read().await;
391                let fail = failed.read().await;
392                let skip = skipped.read().await;
393                deps.iter()
394                    .filter(|(name, predecessors)| {
395                        !done.contains(*name)
396                            && !fail.contains_key(*name)
397                            && !skip.contains(*name)
398                            && predecessors
399                                .iter()
400                                .all(|p| done.contains(p) || skip.contains(p))
401                    })
402                    .map(|(name, _)| name.clone())
403                    .collect()
404            };
405
406            if ready.is_empty() {
407                break;
408            }
409
410            // Spawn all ready nodes concurrently
411            let mut handles = Vec::new();
412            for name in ready {
413                let ctx = ctx.clone();
414                let handlers = Arc::clone(&self.handlers);
415                let completed = Arc::clone(&completed);
416                let failed = Arc::clone(&failed);
417                let conditional_edges = Arc::clone(&self.conditional_edges);
418                let node_name = name.clone();
419
420                let handle = tokio::spawn(async move {
421                    if let Some(handler) = handlers.get(&node_name) {
422                        match handler(ctx.clone()).await {
423                            Ok(result) => {
424                                ctx.store_result(&node_name, result.clone()).await;
425
426                                // Evaluate conditional edges from this node and store
427                                // the activated set for the main loop to process.
428                                for (from, evaluator) in conditional_edges.iter() {
429                                    if from == &node_name {
430                                        let activated = evaluator(&result);
431                                        ctx.set(
432                                            format!("__conditional_activated_{}", node_name),
433                                            serde_json::json!(activated),
434                                        )
435                                        .await;
436                                    }
437                                }
438
439                                completed.write().await.insert(node_name);
440                            }
441                            Err(e) => {
442                                failed.write().await.insert(node_name, e.to_string());
443                            }
444                        }
445                    } else {
446                        failed
447                            .write()
448                            .await
449                            .insert(node_name, "Handler not found".to_string());
450                    }
451                });
452                handles.push(handle);
453            }
454
455            // Wait for all concurrent nodes to complete
456            for handle in handles {
457                let _ = handle.await;
458            }
459
460            // Process conditional edge results: skip non-activated downstream nodes
461            {
462                let ctx_state = ctx.state.read().await;
463                let mut skip_guard = skipped.write().await;
464                for (from, _) in self.conditional_edges.iter() {
465                    let key = format!("__conditional_activated_{}", from);
466                    if let Some(activated_val) = ctx_state.get(&key)
467                        && let Some(activated) = activated_val.as_array()
468                    {
469                        let activated_set: HashSet<String> = activated
470                            .iter()
471                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
472                            .collect();
473                        // Find all direct-edge targets from this node
474                        for (edge_from, edge_to) in &self.direct_edges {
475                            if edge_from == from && !activated_set.contains(edge_to) {
476                                skip_guard.insert(edge_to.clone());
477                            }
478                        }
479                    }
480                }
481            }
482        }
483
484        let node_results = ctx.all_results().await;
485        let failed_map = failed.read().await.clone();
486        let skipped_vec: Vec<String> = skipped.read().await.iter().cloned().collect();
487        let success = failed_map.is_empty();
488
489        Ok(WorkflowResult {
490            name: self.name.clone(),
491            success,
492            node_results,
493            skipped_nodes: skipped_vec,
494            failed_nodes: failed_map,
495        })
496    }
497
498    /// Get the workflow name.
499    pub fn name(&self) -> &str {
500        &self.name
501    }
502
503    /// Get the entry node names.
504    pub fn entry_nodes(&self) -> &[String] {
505        &self.entry_nodes
506    }
507
508    /// Get all node names.
509    pub fn node_names(&self) -> &HashSet<String> {
510        &self.all_nodes
511    }
512}
513
514// ── Tests ────────────────────────────────────────────────────────────────────
515
#[cfg(test)]
mod tests {
    use super::*;

    // Two nodes chained by one edge: verifies sequential execution and that
    // shared state written by `a` is visible to `b`.
    #[tokio::test]
    async fn test_simple_linear_workflow() {
        let workflow = WorkflowBuilder::new("linear")
            .node("a", |ctx| {
                Box::pin(async move {
                    ctx.set("counter", serde_json::json!(1)).await;
                    Ok(serde_json::json!({"step": "a"}))
                })
            })
            .node("b", |ctx| {
                Box::pin(async move {
                    let val = ctx.get("counter").await.unwrap();
                    let n = val.as_i64().unwrap();
                    ctx.set("counter", serde_json::json!(n + 1)).await;
                    Ok(serde_json::json!({"step": "b"}))
                })
            })
            .edge("a", "b")
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(result.success);
        assert_eq!(result.node_results.len(), 2);
        assert!(result.failed_nodes.is_empty());
    }

    // Fan-out then fan-in: branch_a/branch_b run in the same wave after
    // `start`, and `join` sees both branch results.
    #[tokio::test]
    async fn test_parallel_workflow() {
        let workflow = WorkflowBuilder::new("parallel")
            .node("start", |_ctx| {
                Box::pin(async move { Ok(serde_json::json!("started")) })
            })
            .node("branch_a", |_ctx| {
                Box::pin(async move { Ok(serde_json::json!("a_done")) })
            })
            .node("branch_b", |_ctx| {
                Box::pin(async move { Ok(serde_json::json!("b_done")) })
            })
            .node("join", |ctx| {
                Box::pin(async move {
                    let a = ctx.node_result("branch_a").await;
                    let b = ctx.node_result("branch_b").await;
                    Ok(serde_json::json!({"a": a, "b": b}))
                })
            })
            .edge("start", "branch_a")
            .edge("start", "branch_b")
            .edge("branch_a", "join")
            .edge("branch_b", "join")
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(result.success);
        assert_eq!(result.node_results.len(), 4);
    }

    // Diamond shape a -> {b, c} -> d: `d` combines its predecessors' numeric
    // results (2 + 3 = 5).
    #[tokio::test]
    async fn test_diamond_workflow() {
        let workflow = WorkflowBuilder::new("diamond")
            .node("a", |_| Box::pin(async { Ok(serde_json::json!(1)) }))
            .node("b", |_| Box::pin(async { Ok(serde_json::json!(2)) }))
            .node("c", |_| Box::pin(async { Ok(serde_json::json!(3)) }))
            .node("d", |ctx| {
                Box::pin(async move {
                    let b = ctx.node_result("b").await.unwrap();
                    let c = ctx.node_result("c").await.unwrap();
                    Ok(serde_json::json!(b.as_i64().unwrap() + c.as_i64().unwrap()))
                })
            })
            .edge("a", "b")
            .edge("a", "c")
            .edge("b", "d")
            .edge("c", "d")
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(result.success);
        assert_eq!(result.node_results["d"], serde_json::json!(5));
    }

    // Conditional routing: `check` activates only "fast_path", so
    // "slow_path" must land in `skipped_nodes` and never run.
    #[tokio::test]
    async fn test_conditional_workflow() {
        let workflow = WorkflowBuilder::new("conditional")
            .node("check", |_| {
                Box::pin(async { Ok(serde_json::json!({"route": "fast"})) })
            })
            .node("fast_path", |_| {
                Box::pin(async { Ok(serde_json::json!("fast_done")) })
            })
            .node("slow_path", |_| {
                Box::pin(async { Ok(serde_json::json!("slow_done")) })
            })
            .edge("check", "fast_path")
            .edge("check", "slow_path")
            .conditional("check", |result| {
                let route = result
                    .get("route")
                    .and_then(|v| v.as_str())
                    .unwrap_or("fast");
                if route == "fast" {
                    vec!["fast_path".to_string()]
                } else {
                    vec!["slow_path".to_string()]
                }
            })
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(result.success);
        assert!(result.node_results.contains_key("fast_path"));
        assert!(result.skipped_nodes.contains(&"slow_path".to_string()));
    }

    // a -> b -> a forms a cycle; build() must reject it.
    #[tokio::test]
    async fn test_cycle_detection() {
        let result = WorkflowBuilder::new("cyclic")
            .node("a", |_| Box::pin(async { Ok(serde_json::json!(1)) }))
            .node("b", |_| Box::pin(async { Ok(serde_json::json!(2)) }))
            .edge("a", "b")
            .edge("b", "a")
            .build();

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cycle"));
    }

    // Edge to a node that was never registered; build() must reject it.
    #[tokio::test]
    async fn test_unknown_node_in_edge() {
        let result = WorkflowBuilder::new("bad")
            .node("a", |_| Box::pin(async { Ok(serde_json::json!(1)) }))
            .edge("a", "nonexistent")
            .build();

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unknown target"));
    }

    // A builder with zero nodes is invalid.
    #[tokio::test]
    async fn test_empty_workflow() {
        let result = WorkflowBuilder::new("empty").build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("no nodes"));
    }

    // Degenerate one-node graph: the node is its own entry point.
    #[tokio::test]
    async fn test_single_node_workflow() {
        let workflow = WorkflowBuilder::new("single")
            .node("only", |_| {
                Box::pin(async { Ok(serde_json::json!("done")) })
            })
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(result.success);
        assert_eq!(result.node_results.len(), 1);
    }

    // Failure propagation: when `a` errors, its dependent `b` is skipped
    // rather than run, and the workflow reports failure.
    #[tokio::test]
    async fn test_node_failure_skips_dependents() {
        let workflow = WorkflowBuilder::new("fail")
            .node("a", |_| Box::pin(async { Err(anyhow::anyhow!("boom")) }))
            .node("b", |_| {
                Box::pin(async { Ok(serde_json::json!("should not run")) })
            })
            .edge("a", "b")
            .build()
            .unwrap();

        let result = workflow.run().await.unwrap();
        assert!(!result.success);
        assert!(result.failed_nodes.contains_key("a"));
        assert!(result.skipped_nodes.contains(&"b".to_string()));
    }

    // run_with_context: values seeded into the context before execution are
    // visible to node handlers.
    #[tokio::test]
    async fn test_pre_populated_context() {
        let ctx = WorkflowContext::new();
        ctx.set("input", serde_json::json!("hello")).await;

        let workflow = WorkflowBuilder::new("with-ctx")
            .node("use_input", |ctx| {
                Box::pin(async move {
                    let input = ctx.get("input").await.unwrap();
                    Ok(serde_json::json!({"received": input}))
                })
            })
            .build()
            .unwrap();

        let result = workflow.run_with_context(ctx).await.unwrap();
        assert!(result.success);
        assert_eq!(
            result.node_results["use_input"],
            serde_json::json!({"received": "hello"})
        );
    }
}