echo_orchestration/workflow/
dag.rs1use super::{SharedAgent, StepOutput, Workflow, WorkflowOutput, shared_agent};
4use echo_core::agent::Agent;
5use echo_core::error::{AgentError, ReactError, Result};
6use futures::future::BoxFuture;
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::Instant;
9use tracing::{debug, info};
10
11pub struct DagNode {
13 pub id: String,
14 pub agent: SharedAgent,
15}
16
/// A directed edge between two nodes, referenced by their ids.
///
/// During execution the output of `from` becomes (part of) the input of `to`.
#[derive(Clone, Debug)]
pub struct DagEdge {
    /// Id of the node the edge starts at (the predecessor).
    pub from: String,
    /// Id of the node the edge points to (the successor).
    pub to: String,
}
23
24pub struct DagWorkflow {
82 nodes: HashMap<String, SharedAgent>,
83 edges: Vec<DagEdge>,
84 node_order: Vec<String>,
85}
86
87impl DagWorkflow {
88 pub fn builder() -> DagWorkflowBuilder {
89 DagWorkflowBuilder {
90 nodes: Vec::new(),
91 edges: Vec::new(),
92 }
93 }
94}
95
96impl Workflow for DagWorkflow {
97 fn run<'a>(&'a mut self, input: &'a str) -> BoxFuture<'a, Result<WorkflowOutput>> {
98 Box::pin(async move {
99 let total_start = Instant::now();
100 let mut step_outputs: Vec<StepOutput> = Vec::new();
101 let mut node_results: HashMap<String, String> = HashMap::new();
102
103 let predecessors = build_predecessors(&self.edges);
104 let successors = build_successors(&self.edges);
105 let in_degree = compute_in_degree(&self.node_order, &self.edges);
106
107 let mut remaining_in_degree = in_degree.clone();
108 let mut ready: VecDeque<String> = VecDeque::new();
109
110 for node_id in &self.node_order {
111 if remaining_in_degree[node_id.as_str()] == 0 {
112 ready.push_back(node_id.clone());
113 }
114 }
115
116 info!(
117 workflow = "dag",
118 nodes = self.node_order.len(),
119 edges = self.edges.len(),
120 roots = ready.len(),
121 "🔀 DAG workflow started"
122 );
123
124 while !ready.is_empty() {
125 let batch: Vec<String> = ready.drain(..).collect();
126
127 debug!(
128 workflow = "dag",
129 batch = ?batch,
130 "âš¡ Executing {} nodes concurrently",
131 batch.len()
132 );
133
134 let mut handles = Vec::with_capacity(batch.len());
135
136 for node_id in &batch {
137 let agent_handle = self.nodes[node_id].clone();
138 let preds = predecessors
139 .get(node_id.as_str())
140 .cloned()
141 .unwrap_or_default();
142
143 let node_input = if preds.is_empty() {
144 input.to_string()
145 } else {
146 preds
147 .iter()
148 .filter_map(|p| node_results.get(p.as_str()))
149 .cloned()
150 .collect::<Vec<_>>()
151 .join("\n\n")
152 };
153
154 let nid = node_id.clone();
155 handles.push(tokio::spawn(async move {
156 let step_start = Instant::now();
157 let agent = agent_handle.lock().await;
158 let agent_name = agent.name().to_string();
159 let result = agent.execute(&node_input).await;
160 let elapsed = step_start.elapsed();
161 (nid, agent_name, node_input, result, elapsed)
162 }));
163 }
164
165 for handle in handles {
166 let (node_id, agent_name, node_input, result, elapsed) = handle
167 .await
168 .map_err(|e| ReactError::Other(format!("task join error: {e}")))?;
169
170 let output = result?;
171
172 info!(
173 workflow = "dag",
174 node = %node_id,
175 agent = %agent_name,
176 elapsed_ms = elapsed.as_millis(),
177 "✓ Node completed"
178 );
179
180 step_outputs.push(StepOutput {
181 agent_name,
182 input: node_input,
183 output: output.clone(),
184 elapsed,
185 });
186
187 node_results.insert(node_id.clone(), output);
188
189 if let Some(succs) = successors.get(node_id.as_str()) {
190 for succ in succs {
191 if let Some(deg) = remaining_in_degree.get_mut(succ.as_str()) {
192 *deg -= 1;
193 if *deg == 0 {
194 ready.push_back(succ.clone());
195 }
196 }
197 }
198 }
199 }
200 }
201
202 let leaf_nodes: Vec<&str> = self
204 .node_order
205 .iter()
206 .filter(|id| successors.get(id.as_str()).is_none_or(|s| s.is_empty()))
207 .map(|s| s.as_str())
208 .collect();
209
210 let final_result = leaf_nodes
211 .iter()
212 .filter_map(|id| node_results.get(*id))
213 .cloned()
214 .collect::<Vec<_>>()
215 .join("\n\n");
216
217 Ok(WorkflowOutput {
218 result: final_result,
219 steps: step_outputs,
220 elapsed: total_start.elapsed(),
221 })
222 })
223 }
224}
225
226pub struct DagWorkflowBuilder {
228 nodes: Vec<(String, SharedAgent)>,
229 edges: Vec<DagEdge>,
230}
231
232impl DagWorkflowBuilder {
233 pub fn node(mut self, id: impl Into<String>, agent: impl Agent + 'static) -> Self {
235 self.nodes.push((id.into(), shared_agent(agent)));
236 self
237 }
238
239 pub fn node_shared(mut self, id: impl Into<String>, agent: SharedAgent) -> Self {
241 self.nodes.push((id.into(), agent));
242 self
243 }
244
245 pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
247 self.edges.push(DagEdge {
248 from: from.into(),
249 to: to.into(),
250 });
251 self
252 }
253
254 pub fn build(self) -> Result<DagWorkflow> {
256 let node_ids: HashSet<&str> = self.nodes.iter().map(|(id, _)| id.as_str()).collect();
257
258 for edge in &self.edges {
259 if !node_ids.contains(edge.from.as_str()) {
260 return Err(ReactError::Agent(AgentError::InitializationFailed(
261 format!("DAG edge references unknown node: '{}'", edge.from),
262 )));
263 }
264 if !node_ids.contains(edge.to.as_str()) {
265 return Err(ReactError::Agent(AgentError::InitializationFailed(
266 format!("DAG edge references unknown node: '{}'", edge.to),
267 )));
268 }
269 }
270
271 let node_list: Vec<String> = self.nodes.iter().map(|(id, _)| id.clone()).collect();
272 if let Some(cycle) = detect_cycle(&node_list, &self.edges) {
273 return Err(ReactError::Agent(AgentError::InitializationFailed(
274 format!("DAG contains cycle: {}", cycle.join(" → ")),
275 )));
276 }
277
278 let topo_order = topological_sort(&node_list, &self.edges)?;
279
280 let nodes: HashMap<String, SharedAgent> = self.nodes.into_iter().collect();
281
282 Ok(DagWorkflow {
283 nodes,
284 edges: self.edges,
285 node_order: topo_order,
286 })
287 }
288}
289
290fn build_predecessors(edges: &[DagEdge]) -> HashMap<&str, Vec<String>> {
293 let mut preds: HashMap<&str, Vec<String>> = HashMap::new();
294 for edge in edges {
295 preds
296 .entry(edge.to.as_str())
297 .or_default()
298 .push(edge.from.clone());
299 }
300 preds
301}
302
303fn build_successors(edges: &[DagEdge]) -> HashMap<&str, Vec<String>> {
304 let mut succs: HashMap<&str, Vec<String>> = HashMap::new();
305 for edge in edges {
306 succs
307 .entry(edge.from.as_str())
308 .or_default()
309 .push(edge.to.clone());
310 }
311 succs
312}
313
314fn compute_in_degree<'a>(nodes: &'a [String], edges: &[DagEdge]) -> HashMap<&'a str, usize> {
315 let mut deg: HashMap<&str, usize> = nodes.iter().map(|id| (id.as_str(), 0)).collect();
316 for edge in edges {
317 if let Some(d) = deg.get_mut(edge.to.as_str()) {
318 *d += 1;
319 }
320 }
321 deg
322}
323
324fn topological_sort(nodes: &[String], edges: &[DagEdge]) -> Result<Vec<String>> {
326 let mut in_deg = compute_in_degree(nodes, edges);
327 let succs = build_successors(edges);
328
329 let mut queue: VecDeque<String> = nodes
330 .iter()
331 .filter(|id| in_deg[id.as_str()] == 0)
332 .cloned()
333 .collect();
334
335 let mut order = Vec::with_capacity(nodes.len());
336
337 while let Some(node) = queue.pop_front() {
338 order.push(node.clone());
339 if let Some(neighbors) = succs.get(node.as_str()) {
340 for neighbor in neighbors {
341 if let Some(d) = in_deg.get_mut(neighbor.as_str()) {
342 *d -= 1;
343 if *d == 0 {
344 queue.push_back(neighbor.clone());
345 }
346 }
347 }
348 }
349 }
350
351 if order.len() != nodes.len() {
352 return Err(ReactError::Agent(AgentError::InitializationFailed(
353 "DAG contains a cycle (topological sort incomplete)".to_string(),
354 )));
355 }
356
357 Ok(order)
358}
359
360fn detect_cycle(nodes: &[String], edges: &[DagEdge]) -> Option<Vec<String>> {
362 let succs: HashMap<String, Vec<String>> = {
363 let mut map: HashMap<String, Vec<String>> = HashMap::new();
364 for edge in edges {
365 map.entry(edge.from.clone())
366 .or_default()
367 .push(edge.to.clone());
368 }
369 map
370 };
371
372 #[derive(Clone, Copy, PartialEq)]
373 enum Color {
374 White,
375 Gray,
376 Black,
377 }
378
379 let mut color: HashMap<String, Color> =
380 nodes.iter().map(|id| (id.clone(), Color::White)).collect();
381 let mut path: Vec<String> = Vec::new();
382
383 fn dfs(
384 node: &str,
385 succs: &HashMap<String, Vec<String>>,
386 color: &mut HashMap<String, Color>,
387 path: &mut Vec<String>,
388 ) -> bool {
389 color.insert(node.to_string(), Color::Gray);
390 path.push(node.to_string());
391
392 if let Some(neighbors) = succs.get(node) {
393 for neighbor in neighbors {
394 match color.get(neighbor.as_str()).copied() {
395 Some(Color::Gray) => {
396 path.push(neighbor.clone());
397 return true;
398 }
399 Some(Color::White) | None if dfs(neighbor, succs, color, path) => {
400 return true;
401 }
402 Some(Color::White) | None => {}
403 _ => {}
404 }
405 }
406 }
407
408 path.pop();
409 color.insert(node.to_string(), Color::Black);
410 false
411 }
412
413 for node in nodes {
414 if color[node.as_str()] == Color::White && dfs(node, &succs, &mut color, &mut path) {
415 return Some(path);
416 }
417 }
418
419 None
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned node ids from string literals.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds edges from `(from, to)` pairs.
    fn edges(pairs: &[(&str, &str)]) -> Vec<DagEdge> {
        pairs
            .iter()
            .map(|&(from, to)| DagEdge {
                from: from.into(),
                to: to.into(),
            })
            .collect()
    }

    /// A straight chain is sorted in chain order.
    #[test]
    fn test_topological_sort_simple() {
        let nodes = ids(&["a", "b", "c"]);
        let graph = edges(&[("a", "b"), ("b", "c")]);

        let order = topological_sort(&nodes, &graph).unwrap();

        assert_eq!(order, vec!["a", "b", "c"]);
    }

    /// In a diamond the source comes first, the sink last, both branches in between.
    #[test]
    fn test_topological_sort_diamond() {
        let nodes = ids(&["a", "b", "c", "d"]);
        let graph = edges(&[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]);

        let order = topological_sort(&nodes, &graph).unwrap();

        assert_eq!(order[0], "a");
        assert_eq!(order[3], "d");
        assert!(order.contains(&"b".to_string()));
        assert!(order.contains(&"c".to_string()));
    }

    /// A three-node ring is reported as a cycle.
    #[test]
    fn test_cycle_detection() {
        let nodes = ids(&["a", "b", "c"]);
        let graph = edges(&[("a", "b"), ("b", "c"), ("c", "a")]);

        assert!(detect_cycle(&nodes, &graph).is_some());
    }

    /// A simple fan-out has no cycle.
    #[test]
    fn test_no_cycle() {
        let nodes = ids(&["a", "b", "c"]);
        let graph = edges(&[("a", "b"), ("a", "c")]);

        assert!(detect_cycle(&nodes, &graph).is_none());
    }
}