Skip to main content

echo_orchestration/workflow/
dag.rs

//! DAG workflow: agents are organized as a directed acyclic graph and executed in topological order, with independent nodes running concurrently.
2
3use super::{SharedAgent, StepOutput, Workflow, WorkflowOutput, shared_agent};
4use echo_core::agent::Agent;
5use echo_core::error::{AgentError, ReactError, Result};
6use futures::future::BoxFuture;
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::Instant;
9use tracing::{debug, info};
10
11/// A node in the DAG
12pub struct DagNode {
13    pub id: String,
14    pub agent: SharedAgent,
15}
16
/// Directed edge in the DAG: `from -> to` means `to` depends on the output of `from`.
///
/// Edges compare and hash by their endpoint ids, so they can be checked for equality
/// or collected into a set to detect duplicates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DagEdge {
    /// Id of the node whose output is consumed.
    pub from: String,
    /// Id of the node that receives `from`'s output as (part of) its input.
    pub to: String,
}
23
24/// DAG workflow: nodes execute in topological order; independent nodes run concurrently.
25///
26/// Each node receives the concatenated outputs of all its predecessors (joined by newlines) as input;
27/// root nodes with zero in-degree receive the workflow's original input.
28///
29/// # Example
30///
31/// ```rust,no_run
32/// use echo_core::agent::{Agent, AgentEvent};
33/// use echo_core::error::Result;
34/// use echo_orchestration::workflow::{DagWorkflow, Workflow};
35/// use futures::future::BoxFuture;
36/// use futures::stream::{self, BoxStream};
37///
38/// # struct DummyAgent {
39/// #     name: String,
40/// # }
41/// #
42/// # impl DummyAgent {
43/// #     fn new(name: impl Into<String>) -> Self {
44/// #         Self { name: name.into() }
45/// #     }
46/// # }
47/// #
48/// # impl Agent for DummyAgent {
49/// #     fn name(&self) -> &str { &self.name }
50/// #     fn model_name(&self) -> &str { "mock-model" }
51/// #     fn system_prompt(&self) -> &str { "You are a mock agent" }
52/// #     fn execute<'a>(&'a self, task: &'a str) -> BoxFuture<'a, Result<String>> {
53/// #         Box::pin(async move { Ok(format!("{}: {task}", self.name)) })
54/// #     }
55/// #     fn execute_stream<'a>(&'a self, _task: &'a str) -> BoxFuture<'a, Result<BoxStream<'a, Result<AgentEvent>>>> {
56/// #         Box::pin(async move {
57/// #             let s: BoxStream<'a, Result<AgentEvent>> = Box::pin(stream::empty());
58/// #             Ok(s)
59/// #         })
60/// #     }
61/// # }
62///
63/// # async fn example() -> Result<()> {
64/// let researcher = DummyAgent::new("researcher");
65/// let analyst = DummyAgent::new("analyst");
66/// let writer = DummyAgent::new("writer");
67///
68/// let mut wf = DagWorkflow::builder()
69///     .node("research", researcher)
70///     .node("analyze", analyst)
71///     .node("write", writer)
72///     .edge("research", "write")
73///     .edge("analyze", "write")
74///     .build()?;
75///
76/// let output = wf.run("Analyze the 2025 AI Agent ecosystem").await?;
77/// println!("{}", output.result);
78/// # Ok(())
79/// # }
80/// ```
81pub struct DagWorkflow {
82    nodes: HashMap<String, SharedAgent>,
83    edges: Vec<DagEdge>,
84    node_order: Vec<String>,
85}
86
87impl DagWorkflow {
88    pub fn builder() -> DagWorkflowBuilder {
89        DagWorkflowBuilder {
90            nodes: Vec::new(),
91            edges: Vec::new(),
92        }
93    }
94}
95
96impl Workflow for DagWorkflow {
97    fn run<'a>(&'a mut self, input: &'a str) -> BoxFuture<'a, Result<WorkflowOutput>> {
98        Box::pin(async move {
99            let total_start = Instant::now();
100            let mut step_outputs: Vec<StepOutput> = Vec::new();
101            let mut node_results: HashMap<String, String> = HashMap::new();
102
103            let predecessors = build_predecessors(&self.edges);
104            let successors = build_successors(&self.edges);
105            let in_degree = compute_in_degree(&self.node_order, &self.edges);
106
107            let mut remaining_in_degree = in_degree.clone();
108            let mut ready: VecDeque<String> = VecDeque::new();
109
110            for node_id in &self.node_order {
111                if remaining_in_degree[node_id.as_str()] == 0 {
112                    ready.push_back(node_id.clone());
113                }
114            }
115
116            info!(
117                workflow = "dag",
118                nodes = self.node_order.len(),
119                edges = self.edges.len(),
120                roots = ready.len(),
121                "🔀 DAG workflow started"
122            );
123
124            while !ready.is_empty() {
125                let batch: Vec<String> = ready.drain(..).collect();
126
127                debug!(
128                    workflow = "dag",
129                    batch = ?batch,
130                    "âš¡ Executing {} nodes concurrently",
131                    batch.len()
132                );
133
134                let mut handles = Vec::with_capacity(batch.len());
135
136                for node_id in &batch {
137                    let agent_handle = self.nodes[node_id].clone();
138                    let preds = predecessors
139                        .get(node_id.as_str())
140                        .cloned()
141                        .unwrap_or_default();
142
143                    let node_input = if preds.is_empty() {
144                        input.to_string()
145                    } else {
146                        preds
147                            .iter()
148                            .filter_map(|p| node_results.get(p.as_str()))
149                            .cloned()
150                            .collect::<Vec<_>>()
151                            .join("\n\n")
152                    };
153
154                    let nid = node_id.clone();
155                    handles.push(tokio::spawn(async move {
156                        let step_start = Instant::now();
157                        let agent = agent_handle.lock().await;
158                        let agent_name = agent.name().to_string();
159                        let result = agent.execute(&node_input).await;
160                        let elapsed = step_start.elapsed();
161                        (nid, agent_name, node_input, result, elapsed)
162                    }));
163                }
164
165                for handle in handles {
166                    let (node_id, agent_name, node_input, result, elapsed) = handle
167                        .await
168                        .map_err(|e| ReactError::Other(format!("task join error: {e}")))?;
169
170                    let output = result?;
171
172                    info!(
173                        workflow = "dag",
174                        node = %node_id,
175                        agent = %agent_name,
176                        elapsed_ms = elapsed.as_millis(),
177                        "✓ Node completed"
178                    );
179
180                    step_outputs.push(StepOutput {
181                        agent_name,
182                        input: node_input,
183                        output: output.clone(),
184                        elapsed,
185                    });
186
187                    node_results.insert(node_id.clone(), output);
188
189                    if let Some(succs) = successors.get(node_id.as_str()) {
190                        for succ in succs {
191                            if let Some(deg) = remaining_in_degree.get_mut(succ.as_str()) {
192                                *deg -= 1;
193                                if *deg == 0 {
194                                    ready.push_back(succ.clone());
195                                }
196                            }
197                        }
198                    }
199                }
200            }
201
202            // Final result: collect output from all leaf nodes (out-degree = 0)
203            let leaf_nodes: Vec<&str> = self
204                .node_order
205                .iter()
206                .filter(|id| successors.get(id.as_str()).is_none_or(|s| s.is_empty()))
207                .map(|s| s.as_str())
208                .collect();
209
210            let final_result = leaf_nodes
211                .iter()
212                .filter_map(|id| node_results.get(*id))
213                .cloned()
214                .collect::<Vec<_>>()
215                .join("\n\n");
216
217            Ok(WorkflowOutput {
218                result: final_result,
219                steps: step_outputs,
220                elapsed: total_start.elapsed(),
221            })
222        })
223    }
224}
225
226/// [`DagWorkflow`] builder
227pub struct DagWorkflowBuilder {
228    nodes: Vec<(String, SharedAgent)>,
229    edges: Vec<DagEdge>,
230}
231
232impl DagWorkflowBuilder {
233    /// Register a named node
234    pub fn node(mut self, id: impl Into<String>, agent: impl Agent + 'static) -> Self {
235        self.nodes.push((id.into(), shared_agent(agent)));
236        self
237    }
238
239    /// Register a named node (using an already-wrapped SharedAgent)
240    pub fn node_shared(mut self, id: impl Into<String>, agent: SharedAgent) -> Self {
241        self.nodes.push((id.into(), agent));
242        self
243    }
244
245    /// Add a directed edge: `from`'s output will flow into `to`'s input
246    pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
247        self.edges.push(DagEdge {
248            from: from.into(),
249            to: to.into(),
250        });
251        self
252    }
253
254    /// Build the DAG workflow, validate acyclicity, and compute topological order
255    pub fn build(self) -> Result<DagWorkflow> {
256        let node_ids: HashSet<&str> = self.nodes.iter().map(|(id, _)| id.as_str()).collect();
257
258        for edge in &self.edges {
259            if !node_ids.contains(edge.from.as_str()) {
260                return Err(ReactError::Agent(AgentError::InitializationFailed(
261                    format!("DAG edge references unknown node: '{}'", edge.from),
262                )));
263            }
264            if !node_ids.contains(edge.to.as_str()) {
265                return Err(ReactError::Agent(AgentError::InitializationFailed(
266                    format!("DAG edge references unknown node: '{}'", edge.to),
267                )));
268            }
269        }
270
271        let node_list: Vec<String> = self.nodes.iter().map(|(id, _)| id.clone()).collect();
272        if let Some(cycle) = detect_cycle(&node_list, &self.edges) {
273            return Err(ReactError::Agent(AgentError::InitializationFailed(
274                format!("DAG contains cycle: {}", cycle.join(" → ")),
275            )));
276        }
277
278        let topo_order = topological_sort(&node_list, &self.edges)?;
279
280        let nodes: HashMap<String, SharedAgent> = self.nodes.into_iter().collect();
281
282        Ok(DagWorkflow {
283            nodes,
284            edges: self.edges,
285            node_order: topo_order,
286        })
287    }
288}
289
// ── DAG Algorithms ──────────────────────────────────────────────────────────────────
291
292fn build_predecessors(edges: &[DagEdge]) -> HashMap<&str, Vec<String>> {
293    let mut preds: HashMap<&str, Vec<String>> = HashMap::new();
294    for edge in edges {
295        preds
296            .entry(edge.to.as_str())
297            .or_default()
298            .push(edge.from.clone());
299    }
300    preds
301}
302
303fn build_successors(edges: &[DagEdge]) -> HashMap<&str, Vec<String>> {
304    let mut succs: HashMap<&str, Vec<String>> = HashMap::new();
305    for edge in edges {
306        succs
307            .entry(edge.from.as_str())
308            .or_default()
309            .push(edge.to.clone());
310    }
311    succs
312}
313
314fn compute_in_degree<'a>(nodes: &'a [String], edges: &[DagEdge]) -> HashMap<&'a str, usize> {
315    let mut deg: HashMap<&str, usize> = nodes.iter().map(|id| (id.as_str(), 0)).collect();
316    for edge in edges {
317        if let Some(d) = deg.get_mut(edge.to.as_str()) {
318            *d += 1;
319        }
320    }
321    deg
322}
323
324/// Kahn's algorithm for topological sort
325fn topological_sort(nodes: &[String], edges: &[DagEdge]) -> Result<Vec<String>> {
326    let mut in_deg = compute_in_degree(nodes, edges);
327    let succs = build_successors(edges);
328
329    let mut queue: VecDeque<String> = nodes
330        .iter()
331        .filter(|id| in_deg[id.as_str()] == 0)
332        .cloned()
333        .collect();
334
335    let mut order = Vec::with_capacity(nodes.len());
336
337    while let Some(node) = queue.pop_front() {
338        order.push(node.clone());
339        if let Some(neighbors) = succs.get(node.as_str()) {
340            for neighbor in neighbors {
341                if let Some(d) = in_deg.get_mut(neighbor.as_str()) {
342                    *d -= 1;
343                    if *d == 0 {
344                        queue.push_back(neighbor.clone());
345                    }
346                }
347            }
348        }
349    }
350
351    if order.len() != nodes.len() {
352        return Err(ReactError::Agent(AgentError::InitializationFailed(
353            "DAG contains a cycle (topological sort incomplete)".to_string(),
354        )));
355    }
356
357    Ok(order)
358}
359
360/// DFS-based cycle detection; returns the cycle path if found
361fn detect_cycle(nodes: &[String], edges: &[DagEdge]) -> Option<Vec<String>> {
362    let succs: HashMap<String, Vec<String>> = {
363        let mut map: HashMap<String, Vec<String>> = HashMap::new();
364        for edge in edges {
365            map.entry(edge.from.clone())
366                .or_default()
367                .push(edge.to.clone());
368        }
369        map
370    };
371
372    #[derive(Clone, Copy, PartialEq)]
373    enum Color {
374        White,
375        Gray,
376        Black,
377    }
378
379    let mut color: HashMap<String, Color> =
380        nodes.iter().map(|id| (id.clone(), Color::White)).collect();
381    let mut path: Vec<String> = Vec::new();
382
383    fn dfs(
384        node: &str,
385        succs: &HashMap<String, Vec<String>>,
386        color: &mut HashMap<String, Color>,
387        path: &mut Vec<String>,
388    ) -> bool {
389        color.insert(node.to_string(), Color::Gray);
390        path.push(node.to_string());
391
392        if let Some(neighbors) = succs.get(node) {
393            for neighbor in neighbors {
394                match color.get(neighbor.as_str()).copied() {
395                    Some(Color::Gray) => {
396                        path.push(neighbor.clone());
397                        return true;
398                    }
399                    Some(Color::White) | None if dfs(neighbor, succs, color, path) => {
400                        return true;
401                    }
402                    Some(Color::White) | None => {}
403                    _ => {}
404                }
405            }
406        }
407
408        path.pop();
409        color.insert(node.to_string(), Color::Black);
410        false
411    }
412
413    for node in nodes {
414        if color[node.as_str()] == Color::White && dfs(node, &succs, &mut color, &mut path) {
415            return Some(path);
416        }
417    }
418
419    None
420}
421
// ── Unit Tests ──────────────────────────────────────────────────────────────────
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned node ids from string literals.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds edges from `(from, to)` pairs.
    fn edge_list(pairs: &[(&str, &str)]) -> Vec<DagEdge> {
        pairs
            .iter()
            .map(|(from, to)| DagEdge {
                from: (*from).into(),
                to: (*to).into(),
            })
            .collect()
    }

    /// A linear chain has exactly one valid order.
    #[test]
    fn test_topological_sort_simple() {
        let nodes = ids(&["a", "b", "c"]);
        let edges = edge_list(&[("a", "b"), ("b", "c")]);

        let order = topological_sort(&nodes, &edges).unwrap();

        assert_eq!(order, vec!["a", "b", "c"]);
    }

    /// In a diamond the source comes first, the sink last, and both branches appear.
    #[test]
    fn test_topological_sort_diamond() {
        let nodes = ids(&["a", "b", "c", "d"]);
        let edges = edge_list(&[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]);

        let order = topological_sort(&nodes, &edges).unwrap();

        assert_eq!(order[0], "a");
        assert_eq!(order[3], "d");
        assert!(order.contains(&"b".to_string()));
        assert!(order.contains(&"c".to_string()));
    }

    /// A three-node loop is reported as a cycle.
    #[test]
    fn test_cycle_detection() {
        let nodes = ids(&["a", "b", "c"]);
        let edges = edge_list(&[("a", "b"), ("b", "c"), ("c", "a")]);

        assert!(detect_cycle(&nodes, &edges).is_some());
    }

    /// A simple fan-out contains no cycle.
    #[test]
    fn test_no_cycle() {
        let nodes = ids(&["a", "b", "c"]);
        let edges = edge_list(&[("a", "b"), ("a", "c")]);

        assert!(detect_cycle(&nodes, &edges).is_none());
    }
}
513}