use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use petgraph::Direction;
use petgraph::algo::is_cyclic_directed;
use petgraph::graph::{Graph, NodeIndex};
use petgraph::visit::EdgeRef;
use tokio::task::JoinSet;

use crate::error::Error;
use crate::llm::LlmProvider;
use crate::llm::types::TokenUsage;

use super::{AgentOutput, AgentRunner};
20
/// Guard evaluated on the source node's output; the edge is followed only when it returns `true`.
type EdgeCondition = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;

/// Mapping applied to the source node's output to produce the text handed to the target node.
type EdgeTransform = Box<dyn Fn(&str) -> String + Send + Sync + 'static>;
27
28struct DagNode<P: LlmProvider> {
30 name: String,
31 agent: Arc<AgentRunner<P>>,
32}
33
34struct DagEdge {
36 condition: Option<EdgeCondition>,
37 transform: Option<EdgeTransform>,
38}
39
40impl std::fmt::Debug for DagEdge {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("DagEdge")
43 .field("has_condition", &self.condition.is_some())
44 .field("has_transform", &self.transform.is_some())
45 .finish()
46 }
47}
48
49pub struct DagAgent<P: LlmProvider + 'static> {
51 graph: Graph<DagNode<P>, DagEdge>,
52}
53
54impl<P: LlmProvider + 'static> std::fmt::Debug for DagAgent<P> {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("DagAgent")
57 .field("node_count", &self.graph.node_count())
58 .field("edge_count", &self.graph.edge_count())
59 .finish()
60 }
61}
62
63pub struct DagAgentBuilder<P: LlmProvider + 'static> {
65 nodes: Vec<(String, AgentRunner<P>)>,
66 edges: Vec<(String, String, DagEdge)>,
67}
68
69impl<P: LlmProvider + 'static> DagAgent<P> {
70 pub fn builder() -> DagAgentBuilder<P> {
113 DagAgentBuilder {
114 nodes: Vec::new(),
115 edges: Vec::new(),
116 }
117 }
118
119 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
122 let mut completed: HashMap<NodeIndex, String> = HashMap::new();
124 let mut total_usage = TokenUsage::default();
125 let mut total_tool_calls = 0usize;
126 let mut total_cost: Option<f64> = None;
127
128 let roots: Vec<NodeIndex> = self
130 .graph
131 .node_indices()
132 .filter(|&idx| {
133 self.graph
134 .neighbors_directed(idx, Direction::Incoming)
135 .next()
136 .is_none()
137 })
138 .collect();
139
140 let root_results = self.execute_nodes(&roots, task).await;
142 match root_results {
143 Ok(results) => {
144 for (idx, output) in results {
145 output.accumulate_into(
146 &mut total_usage,
147 &mut total_tool_calls,
148 &mut total_cost,
149 );
150 completed.insert(idx, output.result);
151 }
152 }
153 Err(e) => {
154 return Err(e.accumulate_usage(total_usage));
155 }
156 }
157
158 loop {
160 let ready = self.find_ready_nodes(&completed);
161 if ready.is_empty() {
162 break;
163 }
164
165 let mut node_inputs: Vec<(NodeIndex, String)> = Vec::with_capacity(ready.len());
167 for &idx in &ready {
168 let input = self.build_node_input(idx, &completed);
169 node_inputs.push((idx, input));
170 }
171
172 let mut set = JoinSet::new();
174 for (idx, input) in node_inputs {
175 let agent = Arc::clone(&self.graph[idx].agent);
176 set.spawn(async move {
177 let result = agent.execute(&input).await;
178 (idx, result)
179 });
180 }
181
182 while let Some(join_result) = set.join_next().await {
183 let (idx, agent_result) = join_result
184 .map_err(|e| Error::Agent(format!("DAG agent task panicked: {e}")))?;
185 let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
186 output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
187 completed.insert(idx, output.result);
188 }
189 }
190
191 let terminals: Vec<NodeIndex> = self
194 .graph
195 .node_indices()
196 .filter(|&idx| {
197 if !completed.contains_key(&idx) {
198 return false;
199 }
200 let has_completed_successor = self
202 .graph
203 .neighbors_directed(idx, Direction::Outgoing)
204 .any(|succ| completed.contains_key(&succ));
205 !has_completed_successor
206 })
207 .collect();
208
209 let mut terminal_names: Vec<(String, String)> = terminals
210 .iter()
211 .map(|&idx| {
212 let name = self.graph[idx].name.clone();
213 let text = completed.get(&idx).cloned().unwrap_or_default();
214 (name, text)
215 })
216 .collect();
217 terminal_names.sort_by(|a, b| a.0.cmp(&b.0));
218
219 let merged_text = if terminal_names.len() == 1 {
220 terminal_names
221 .into_iter()
222 .next()
223 .map(|(_, t)| t)
224 .unwrap_or_default()
225 } else {
226 terminal_names
227 .iter()
228 .map(|(name, text)| format!("## {name}\n{text}"))
229 .collect::<Vec<_>>()
230 .join("\n\n")
231 };
232
233 Ok(AgentOutput {
234 result: merged_text,
235 tool_calls_made: total_tool_calls,
236 tokens_used: total_usage,
237 structured: None,
238 estimated_cost_usd: total_cost,
239 model_name: None,
240 })
241 }
242
243 async fn execute_nodes(
248 &self,
249 nodes: &[NodeIndex],
250 input: &str,
251 ) -> Result<Vec<(NodeIndex, AgentOutput)>, Error> {
252 if nodes.len() == 1 {
253 let idx = nodes[0];
254 let output = self.graph[idx].agent.execute(input).await?;
255 return Ok(vec![(idx, output)]);
256 }
257
258 let mut set = JoinSet::new();
259 for &idx in nodes {
260 let agent = Arc::clone(&self.graph[idx].agent);
261 let task = input.to_string();
262 set.spawn(async move {
263 let result = agent.execute(&task).await;
264 (idx, result)
265 });
266 }
267
268 let mut results = Vec::with_capacity(nodes.len());
269 let mut partial_usage = TokenUsage::default();
270 while let Some(join_result) = set.join_next().await {
271 let (idx, agent_result) =
272 join_result.map_err(|e| Error::Agent(format!("DAG agent task panicked: {e}")))?;
273 let output = agent_result.map_err(|e| e.accumulate_usage(partial_usage))?;
274 partial_usage += output.tokens_used;
275 results.push((idx, output));
276 }
277 Ok(results)
278 }
279
280 fn find_ready_nodes(&self, completed: &HashMap<NodeIndex, String>) -> Vec<NodeIndex> {
283 self.graph
284 .node_indices()
285 .filter(|&idx| {
286 if completed.contains_key(&idx) {
287 return false;
288 }
289 let mut has_any_active_incoming = false;
291 for pred in self.graph.neighbors_directed(idx, Direction::Incoming) {
292 if let Some(pred_output) = completed.get(&pred) {
293 let edge_idx = self.graph.find_edge(pred, idx);
295 let active = edge_idx
296 .map(|eidx| &self.graph[eidx])
297 .and_then(|edge| edge.condition.as_ref())
298 .is_none_or(|cond| cond(pred_output));
299 if active {
300 has_any_active_incoming = true;
301 }
302 } else {
303 return false;
307 }
308 }
309 has_any_active_incoming
310 })
311 .collect()
312 }
313
314 fn build_node_input(&self, idx: NodeIndex, completed: &HashMap<NodeIndex, String>) -> String {
316 let mut inputs: Vec<(String, String)> = Vec::new();
317 for pred in self.graph.neighbors_directed(idx, Direction::Incoming) {
318 if let Some(pred_output) = completed.get(&pred) {
319 let edge_idx = self.graph.find_edge(pred, idx);
320 let active = edge_idx
321 .map(|eidx| &self.graph[eidx])
322 .and_then(|edge| edge.condition.as_ref())
323 .is_none_or(|cond| cond(pred_output));
324 if active {
325 let text = edge_idx
326 .and_then(|eidx| {
327 self.graph[eidx].transform.as_ref().map(|t| t(pred_output))
328 })
329 .unwrap_or_else(|| pred_output.clone());
330 let pred_name = self.graph[pred].name.clone();
331 inputs.push((pred_name, text));
332 }
333 }
334 }
335 inputs.sort_by(|a, b| a.0.cmp(&b.0));
337
338 if inputs.len() == 1 {
339 inputs
340 .into_iter()
341 .next()
342 .map(|(_, t)| t)
343 .unwrap_or_default()
344 } else {
345 inputs
346 .into_iter()
347 .map(|(_, text)| text)
348 .collect::<Vec<_>>()
349 .join("\n")
350 }
351 }
352}
353
354impl<P: LlmProvider + 'static> DagAgentBuilder<P> {
355 pub fn node(mut self, name: impl Into<String>, agent: AgentRunner<P>) -> Self {
357 self.nodes.push((name.into(), agent));
358 self
359 }
360
361 pub fn edge(mut self, from: &str, to: &str) -> Self {
363 self.edges.push((
364 from.to_string(),
365 to.to_string(),
366 DagEdge {
367 condition: None,
368 transform: None,
369 },
370 ));
371 self
372 }
373
374 pub fn conditional_edge(
376 mut self,
377 from: &str,
378 to: &str,
379 condition: impl Fn(&str) -> bool + Send + Sync + 'static,
380 ) -> Self {
381 self.edges.push((
382 from.to_string(),
383 to.to_string(),
384 DagEdge {
385 condition: Some(Box::new(condition)),
386 transform: None,
387 },
388 ));
389 self
390 }
391
392 pub fn edge_with_transform(
394 mut self,
395 from: &str,
396 to: &str,
397 transform: impl Fn(&str) -> String + Send + Sync + 'static,
398 ) -> Self {
399 self.edges.push((
400 from.to_string(),
401 to.to_string(),
402 DagEdge {
403 condition: None,
404 transform: Some(Box::new(transform)),
405 },
406 ));
407 self
408 }
409
410 pub fn build(self) -> Result<DagAgent<P>, Error> {
413 if self.nodes.is_empty() {
414 return Err(Error::Config("DagAgent requires at least one node".into()));
415 }
416
417 let mut seen = std::collections::HashSet::new();
419 for (name, _) in &self.nodes {
420 if !seen.insert(name.as_str()) {
421 return Err(Error::Config(format!(
422 "DagAgent has duplicate node name: {name}"
423 )));
424 }
425 }
426
427 let mut graph = Graph::new();
428 let mut node_indices = HashMap::new();
429
430 for (name, agent) in self.nodes {
431 let idx = graph.add_node(DagNode {
432 name: name.clone(),
433 agent: Arc::new(agent),
434 });
435 node_indices.insert(name, idx);
436 }
437
438 for (from, to, edge) in self.edges {
439 let from_idx = node_indices.get(&from).ok_or_else(|| {
440 Error::Config(format!("DagAgent edge references unknown node: {from}"))
441 })?;
442 let to_idx = node_indices.get(&to).ok_or_else(|| {
443 Error::Config(format!("DagAgent edge references unknown node: {to}"))
444 })?;
445 graph.add_edge(*from_idx, *to_idx, edge);
446 }
447
448 if is_cyclic_directed(&graph) {
449 return Err(Error::Config("DagAgent graph contains a cycle".into()));
450 }
451
452 Ok(DagAgent { graph })
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    // Judging by the usage assertions below, `MockProvider::text_response(text, a, b)`
    // scripts a reply `text` reporting `a` input and `b` output tokens.

    // A builder with no nodes is rejected.
    #[test]
    fn dag_builder_rejects_empty_graph() {
        let result = DagAgent::<MockProvider>::builder().build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least one node")
        );
    }

    // Two nodes with the same name are rejected.
    #[test]
    fn dag_builder_rejects_duplicate_names() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let result = DagAgent::builder()
            .node("same", make_agent(p1, "same"))
            .node("same", make_agent(p2, "same"))
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("duplicate node name")
        );
    }

    // An edge to an undeclared node is rejected.
    #[test]
    fn dag_builder_rejects_missing_edge_endpoint() {
        let p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let result = DagAgent::builder()
            .node("A", make_agent(p, "A"))
            .edge("A", "B")
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unknown node"));
    }

    // A two-node cycle A -> B -> A is rejected.
    #[test]
    fn dag_builder_rejects_cycle() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let result = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .edge("A", "B")
            .edge("B", "A")
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cycle"));
    }

    // A lone node with no edges is a valid graph.
    #[test]
    fn dag_builder_accepts_single_node() {
        let p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 1, 1,
        )]));
        let result = DagAgent::builder().node("A", make_agent(p, "A")).build();
        assert!(result.is_ok());
    }

    // A single node's output and usage are returned as-is.
    #[tokio::test]
    async fn dag_single_node() {
        let p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "hello", 10, 5,
        )]));
        let dag = DagAgent::builder()
            .node("A", make_agent(p, "A"))
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "hello");
        assert_eq!(output.tokens_used.input_tokens, 10);
        assert_eq!(output.tokens_used.output_tokens, 5);
    }

    // In a chain A -> B -> C, only the last node's output is returned; usage sums all three.
    #[tokio::test]
    async fn dag_linear_a_b_c() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-a", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-b", 20, 10,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-c", 30, 15,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .edge("A", "B")
            .edge("B", "C")
            .build()
            .unwrap();

        let output = dag.execute("start").await.unwrap();
        assert_eq!(output.result, "out-c");
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    // With two terminals (B and C), both outputs appear in the merged result.
    #[tokio::test]
    async fn dag_fan_out() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "root-out", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-b", 20, 10,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-c", 30, 15,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .edge("A", "B")
            .edge("A", "C")
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    // Two roots (A, B) feeding one join node C: only C's output is returned.
    #[tokio::test]
    async fn dag_fan_in() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "from-a", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "from-b", 20, 10,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "merged", 30, 15,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .edge("A", "C")
            .edge("B", "C")
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "merged");
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    // Diamond A -> {B, C} -> D: D runs once and is the only terminal.
    #[tokio::test]
    async fn dag_diamond() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "root", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "left", 10, 5,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "right", 10, 5,
        )]));
        let pd = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "diamond-end",
            10,
            5,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .node("D", make_agent(pd, "D"))
            .edge("A", "B")
            .edge("A", "C")
            .edge("B", "D")
            .edge("C", "D")
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "diamond-end");
        assert_eq!(output.tokens_used.input_tokens, 40);
        assert_eq!(output.tokens_used.output_tokens, 20);
    }

    // A's output lacks "yes", so the conditional edge A -> C is inactive and C never runs
    // (its usage is absent from the totals).
    #[tokio::test]
    async fn dag_conditional_edge() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "no", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-b", 10, 5,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-c", 10, 5,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .edge("A", "B")
            .conditional_edge("A", "C", |output| output.contains("yes"))
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(!output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 20);
        assert_eq!(output.tokens_used.output_tokens, 10);
    }

    // Same graph, but A's output contains "yes", so both B and C run.
    #[tokio::test]
    async fn dag_conditional_edge_passes() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "yes", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-b", 10, 5,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "branch-c", 10, 5,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .edge("A", "B")
            .conditional_edge("A", "C", |output| output.contains("yes"))
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 30);
    }

    // A transformed edge still lets B run; only B's reply and the summed usage are checked,
    // not the transformed text B received.
    #[tokio::test]
    async fn dag_edge_with_transform() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "hello", 10, 5,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "got-it", 10, 5,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .edge_with_transform("A", "B", |text| text.to_uppercase())
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "got-it");
        assert_eq!(output.tokens_used.input_tokens, 20);
    }

    // Usage from every node of a diamond is summed exactly once.
    #[tokio::test]
    async fn dag_token_accumulation() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 100, 50,
        )]));
        let pb = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 200, 100,
        )]));
        let pc = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "c", 300, 150,
        )]));
        let pd = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "d", 400, 200,
        )]));

        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .node("C", make_agent(pc, "C"))
            .node("D", make_agent(pd, "D"))
            .edge("A", "B")
            .edge("A", "C")
            .edge("B", "D")
            .edge("C", "D")
            .build()
            .unwrap();

        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.tokens_used.input_tokens, 1000);
        assert_eq!(output.tokens_used.output_tokens, 500);
    }

    // B's provider has no scripted responses, so (presumably) B fails; the error must still
    // report A's usage.
    #[tokio::test]
    async fn dag_error_carries_partial_usage() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        let pb = Arc::new(MockProvider::new(vec![]));
        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .edge("A", "B")
            .build()
            .unwrap();

        let err = dag.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        assert!(partial.input_tokens >= 100);
    }

    // Two parallel roots, one of which (presumably) fails.
    // NOTE(review): `partial` is deliberately not asserted; whether A's usage is included
    // depends on whether A's spawned task finishes before B's error is observed.
    #[tokio::test]
    async fn dag_parallel_roots_error_carries_sibling_usage() {
        let pa = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 200, 100,
        )]));
        let pb = Arc::new(MockProvider::new(vec![]));
        let dag = DagAgent::builder()
            .node("A", make_agent(pa, "A"))
            .node("B", make_agent(pb, "B"))
            .build()
            .unwrap();

        let err = dag.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        let _ = partial;
    }

    // The Debug output names the type and reports the node count.
    #[test]
    fn dag_debug_impl() {
        let p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let dag = DagAgent::builder()
            .node("A", make_agent(p, "A"))
            .build()
            .unwrap();
        let debug = format!("{dag:?}");
        assert!(debug.contains("DagAgent"));
        assert!(debug.contains("node_count"));
    }
}