use std::collections::HashMap;
use std::sync::Arc;
use petgraph::Direction;
use petgraph::algo::is_cyclic_directed;
use petgraph::graph::{Graph, NodeIndex};
use petgraph::visit::EdgeRef;
use tokio::task::JoinSet;
use crate::error::Error;
use crate::llm::LlmProvider;
use crate::llm::types::TokenUsage;
use super::{AgentOutput, AgentRunner};
/// Predicate deciding whether an edge is followed, given the upstream node's output.
type EdgeCondition = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;
/// Rewrites the upstream node's output before it is handed to the downstream node.
type EdgeTransform = Box<dyn Fn(&str) -> String + Send + Sync + 'static>;
/// One vertex of the DAG: a named agent.
struct DagNode<P: LlmProvider> {
    /// Node name; unique within a graph, used by edge declarations and to order outputs.
    name: String,
    /// Reference-counted so a handle can be moved into a spawned task during execution.
    agent: Arc<AgentRunner<P>>,
}

/// Routing rule attached to a DAG edge.
struct DagEdge {
    /// When set, the edge is only followed if this returns `true` for the upstream output.
    condition: Option<EdgeCondition>,
    /// When set, rewrites the upstream output before it reaches the downstream node.
    transform: Option<EdgeTransform>,
}

impl std::fmt::Debug for DagEdge {
    // The boxed closures cannot be printed, so only whether each one is present is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_condition = self.condition.is_some();
        let has_transform = self.transform.is_some();
        f.debug_struct("DagEdge")
            .field("has_condition", &has_condition)
            .field("has_transform", &has_transform)
            .finish()
    }
}
/// A directed acyclic graph of agents, run by [`DagAgent::execute`].
///
/// Construct one with [`DagAgent::builder`].
pub struct DagAgent<P: LlmProvider + 'static> {
    /// Agents as nodes, routing rules as edges; acyclicity is checked when building.
    graph: Graph<DagNode<P>, DagEdge>,
}

impl<P: LlmProvider + 'static> std::fmt::Debug for DagAgent<P> {
    // Only the size of the graph is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (nodes, edges) = (self.graph.node_count(), self.graph.edge_count());
        f.debug_struct("DagAgent")
            .field("node_count", &nodes)
            .field("edge_count", &edges)
            .finish()
    }
}

/// Collects nodes and edges for a [`DagAgent`]; validation happens in
/// [`DagAgentBuilder::build`].
pub struct DagAgentBuilder<P: LlmProvider + 'static> {
    /// `(name, agent)` pairs in insertion order.
    nodes: Vec<(String, AgentRunner<P>)>,
    /// `(from, to, edge)` triples that refer to nodes by name.
    edges: Vec<(String, String, DagEdge)>,
}
impl<P: LlmProvider + 'static> DagAgent<P> {
pub fn builder() -> DagAgentBuilder<P> {
DagAgentBuilder {
nodes: Vec::new(),
edges: Vec::new(),
}
}
pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
let mut completed: HashMap<NodeIndex, String> = HashMap::new();
let mut total_usage = TokenUsage::default();
let mut total_tool_calls = 0usize;
let mut total_cost: Option<f64> = None;
let roots: Vec<NodeIndex> = self
.graph
.node_indices()
.filter(|&idx| {
self.graph
.neighbors_directed(idx, Direction::Incoming)
.next()
.is_none()
})
.collect();
let root_results = self.execute_nodes(&roots, task).await;
match root_results {
Ok(results) => {
for (idx, output) in results {
output.accumulate_into(
&mut total_usage,
&mut total_tool_calls,
&mut total_cost,
);
completed.insert(idx, output.result);
}
}
Err(e) => {
return Err(e.accumulate_usage(total_usage));
}
}
loop {
let ready = self.find_ready_nodes(&completed);
if ready.is_empty() {
break;
}
let mut node_inputs: Vec<(NodeIndex, String)> = Vec::with_capacity(ready.len());
for &idx in &ready {
let input = self.build_node_input(idx, &completed);
node_inputs.push((idx, input));
}
let mut set = JoinSet::new();
for (idx, input) in node_inputs {
let agent = Arc::clone(&self.graph[idx].agent);
set.spawn(async move {
let result = agent.execute(&input).await;
(idx, result)
});
}
while let Some(join_result) = set.join_next().await {
let (idx, agent_result) = join_result
.map_err(|e| Error::Agent(format!("DAG agent task panicked: {e}")))?;
let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
completed.insert(idx, output.result);
}
}
let terminals: Vec<NodeIndex> = self
.graph
.node_indices()
.filter(|&idx| {
if !completed.contains_key(&idx) {
return false;
}
let has_completed_successor = self
.graph
.neighbors_directed(idx, Direction::Outgoing)
.any(|succ| completed.contains_key(&succ));
!has_completed_successor
})
.collect();
let mut terminal_names: Vec<(String, String)> = terminals
.iter()
.map(|&idx| {
let name = self.graph[idx].name.clone();
let text = completed.get(&idx).cloned().unwrap_or_default();
(name, text)
})
.collect();
terminal_names.sort_by(|a, b| a.0.cmp(&b.0));
let merged_text = if terminal_names.len() == 1 {
terminal_names
.into_iter()
.next()
.map(|(_, t)| t)
.unwrap_or_default()
} else {
terminal_names
.iter()
.map(|(name, text)| format!("## {name}\n{text}"))
.collect::<Vec<_>>()
.join("\n\n")
};
Ok(AgentOutput {
result: merged_text,
tool_calls_made: total_tool_calls,
tokens_used: total_usage,
structured: None,
estimated_cost_usd: total_cost,
model_name: None,
})
}
async fn execute_nodes(
&self,
nodes: &[NodeIndex],
input: &str,
) -> Result<Vec<(NodeIndex, AgentOutput)>, Error> {
if nodes.len() == 1 {
let idx = nodes[0];
let output = self.graph[idx].agent.execute(input).await?;
return Ok(vec![(idx, output)]);
}
let mut set = JoinSet::new();
for &idx in nodes {
let agent = Arc::clone(&self.graph[idx].agent);
let task = input.to_string();
set.spawn(async move {
let result = agent.execute(&task).await;
(idx, result)
});
}
let mut results = Vec::with_capacity(nodes.len());
let mut partial_usage = TokenUsage::default();
while let Some(join_result) = set.join_next().await {
let (idx, agent_result) =
join_result.map_err(|e| Error::Agent(format!("DAG agent task panicked: {e}")))?;
let output = agent_result.map_err(|e| e.accumulate_usage(partial_usage))?;
partial_usage += output.tokens_used;
results.push((idx, output));
}
Ok(results)
}
fn find_ready_nodes(&self, completed: &HashMap<NodeIndex, String>) -> Vec<NodeIndex> {
self.graph
.node_indices()
.filter(|&idx| {
if completed.contains_key(&idx) {
return false;
}
let mut has_any_active_incoming = false;
for pred in self.graph.neighbors_directed(idx, Direction::Incoming) {
if let Some(pred_output) = completed.get(&pred) {
let edge_idx = self.graph.find_edge(pred, idx);
let active = edge_idx
.map(|eidx| &self.graph[eidx])
.and_then(|edge| edge.condition.as_ref())
.is_none_or(|cond| cond(pred_output));
if active {
has_any_active_incoming = true;
}
} else {
return false;
}
}
has_any_active_incoming
})
.collect()
}
fn build_node_input(&self, idx: NodeIndex, completed: &HashMap<NodeIndex, String>) -> String {
let mut inputs: Vec<(String, String)> = Vec::new();
for pred in self.graph.neighbors_directed(idx, Direction::Incoming) {
if let Some(pred_output) = completed.get(&pred) {
let edge_idx = self.graph.find_edge(pred, idx);
let active = edge_idx
.map(|eidx| &self.graph[eidx])
.and_then(|edge| edge.condition.as_ref())
.is_none_or(|cond| cond(pred_output));
if active {
let text = edge_idx
.and_then(|eidx| {
self.graph[eidx].transform.as_ref().map(|t| t(pred_output))
})
.unwrap_or_else(|| pred_output.clone());
let pred_name = self.graph[pred].name.clone();
inputs.push((pred_name, text));
}
}
}
inputs.sort_by(|a, b| a.0.cmp(&b.0));
if inputs.len() == 1 {
inputs
.into_iter()
.next()
.map(|(_, t)| t)
.unwrap_or_default()
} else {
inputs
.into_iter()
.map(|(_, text)| text)
.collect::<Vec<_>>()
.join("\n")
}
}
}
impl<P: LlmProvider + 'static> DagAgentBuilder<P> {
pub fn node(mut self, name: impl Into<String>, agent: AgentRunner<P>) -> Self {
self.nodes.push((name.into(), agent));
self
}
pub fn edge(mut self, from: &str, to: &str) -> Self {
self.edges.push((
from.to_string(),
to.to_string(),
DagEdge {
condition: None,
transform: None,
},
));
self
}
pub fn conditional_edge(
mut self,
from: &str,
to: &str,
condition: impl Fn(&str) -> bool + Send + Sync + 'static,
) -> Self {
self.edges.push((
from.to_string(),
to.to_string(),
DagEdge {
condition: Some(Box::new(condition)),
transform: None,
},
));
self
}
pub fn edge_with_transform(
mut self,
from: &str,
to: &str,
transform: impl Fn(&str) -> String + Send + Sync + 'static,
) -> Self {
self.edges.push((
from.to_string(),
to.to_string(),
DagEdge {
condition: None,
transform: Some(Box::new(transform)),
},
));
self
}
pub fn build(self) -> Result<DagAgent<P>, Error> {
if self.nodes.is_empty() {
return Err(Error::Config("DagAgent requires at least one node".into()));
}
let mut seen = std::collections::HashSet::new();
for (name, _) in &self.nodes {
if !seen.insert(name.as_str()) {
return Err(Error::Config(format!(
"DagAgent has duplicate node name: {name}"
)));
}
}
let mut graph = Graph::new();
let mut node_indices = HashMap::new();
for (name, agent) in self.nodes {
let idx = graph.add_node(DagNode {
name: name.clone(),
agent: Arc::new(agent),
});
node_indices.insert(name, idx);
}
for (from, to, edge) in self.edges {
let from_idx = node_indices.get(&from).ok_or_else(|| {
Error::Config(format!("DagAgent edge references unknown node: {from}"))
})?;
let to_idx = node_indices.get(&to).ok_or_else(|| {
Error::Config(format!("DagAgent edge references unknown node: {to}"))
})?;
graph.add_edge(*from_idx, *to_idx, edge);
}
if is_cyclic_directed(&graph) {
return Err(Error::Config("DagAgent graph contains a cycle".into()));
}
Ok(DagAgent { graph })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Shorthand for a shared mock provider. With `(text, input_tokens, output_tokens)` it
    /// scripts exactly one text response; with no arguments its response script is empty.
    macro_rules! mock_provider {
        () => {
            Arc::new(MockProvider::new(vec![]))
        };
        ($text:expr, $input:expr, $output:expr) => {
            Arc::new(MockProvider::new(vec![MockProvider::text_response(
                $text, $input, $output,
            )]))
        };
    }

    #[test]
    fn dag_builder_rejects_empty_graph() {
        let err = DagAgent::<MockProvider>::builder()
            .build()
            .expect_err("an empty DAG must be rejected");
        assert!(err.to_string().contains("at least one node"));
    }

    #[test]
    fn dag_builder_rejects_duplicate_names() {
        let first = mock_provider!("a", 1, 1);
        let second = mock_provider!("b", 1, 1);
        let err = DagAgent::builder()
            .node("same", make_agent(first, "same"))
            .node("same", make_agent(second, "same"))
            .build()
            .expect_err("duplicate node names must be rejected");
        assert!(err.to_string().contains("duplicate node name"));
    }

    #[test]
    fn dag_builder_rejects_missing_edge_endpoint() {
        let provider = mock_provider!("a", 1, 1);
        let err = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .edge("A", "B")
            .build()
            .expect_err("edges to unknown nodes must be rejected");
        assert!(err.to_string().contains("unknown node"));
    }

    #[test]
    fn dag_builder_rejects_cycle() {
        let provider_a = mock_provider!("a", 1, 1);
        let provider_b = mock_provider!("b", 1, 1);
        let err = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .edge("A", "B")
            .edge("B", "A")
            .build()
            .expect_err("cyclic graphs must be rejected");
        assert!(err.to_string().contains("cycle"));
    }

    #[test]
    fn dag_builder_accepts_single_node() {
        let provider = mock_provider!("ok", 1, 1);
        let built = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .build();
        assert!(built.is_ok());
    }

    #[tokio::test]
    async fn dag_single_node() {
        let provider = mock_provider!("hello", 10, 5);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "hello");
        assert_eq!(output.tokens_used.input_tokens, 10);
        assert_eq!(output.tokens_used.output_tokens, 5);
    }

    #[tokio::test]
    async fn dag_linear_a_b_c() {
        let provider_a = mock_provider!("out-a", 10, 5);
        let provider_b = mock_provider!("out-b", 20, 10);
        let provider_c = mock_provider!("out-c", 30, 15);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .edge("A", "B")
            .edge("B", "C")
            .build()
            .unwrap();
        let output = dag.execute("start").await.unwrap();
        assert_eq!(output.result, "out-c");
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    #[tokio::test]
    async fn dag_fan_out() {
        let provider_a = mock_provider!("root-out", 10, 5);
        let provider_b = mock_provider!("branch-b", 20, 10);
        let provider_c = mock_provider!("branch-c", 30, 15);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .edge("A", "B")
            .edge("A", "C")
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    #[tokio::test]
    async fn dag_fan_in() {
        let provider_a = mock_provider!("from-a", 10, 5);
        let provider_b = mock_provider!("from-b", 20, 10);
        let provider_c = mock_provider!("merged", 30, 15);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .edge("A", "C")
            .edge("B", "C")
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "merged");
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    #[tokio::test]
    async fn dag_diamond() {
        let provider_a = mock_provider!("root", 10, 5);
        let provider_b = mock_provider!("left", 10, 5);
        let provider_c = mock_provider!("right", 10, 5);
        let provider_d = mock_provider!("diamond-end", 10, 5);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .node("D", make_agent(provider_d, "D"))
            .edge("A", "B")
            .edge("A", "C")
            .edge("B", "D")
            .edge("C", "D")
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "diamond-end");
        assert_eq!(output.tokens_used.input_tokens, 40);
        assert_eq!(output.tokens_used.output_tokens, 20);
    }

    #[tokio::test]
    async fn dag_conditional_edge() {
        let provider_a = mock_provider!("no", 10, 5);
        let provider_b = mock_provider!("branch-b", 10, 5);
        let provider_c = mock_provider!("branch-c", 10, 5);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .edge("A", "B")
            .conditional_edge("A", "C", |output| output.contains("yes"))
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(!output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 20);
        assert_eq!(output.tokens_used.output_tokens, 10);
    }

    #[tokio::test]
    async fn dag_conditional_edge_passes() {
        let provider_a = mock_provider!("yes", 10, 5);
        let provider_b = mock_provider!("branch-b", 10, 5);
        let provider_c = mock_provider!("branch-c", 10, 5);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .edge("A", "B")
            .conditional_edge("A", "C", |output| output.contains("yes"))
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert!(output.result.contains("branch-b"));
        assert!(output.result.contains("branch-c"));
        assert_eq!(output.tokens_used.input_tokens, 30);
    }

    #[tokio::test]
    async fn dag_edge_with_transform() {
        let provider_a = mock_provider!("hello", 10, 5);
        let provider_b = mock_provider!("got-it", 10, 5);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .edge_with_transform("A", "B", |text| text.to_uppercase())
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.result, "got-it");
        assert_eq!(output.tokens_used.input_tokens, 20);
    }

    #[tokio::test]
    async fn dag_token_accumulation() {
        let provider_a = mock_provider!("a", 100, 50);
        let provider_b = mock_provider!("b", 200, 100);
        let provider_c = mock_provider!("c", 300, 150);
        let provider_d = mock_provider!("d", 400, 200);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .node("C", make_agent(provider_c, "C"))
            .node("D", make_agent(provider_d, "D"))
            .edge("A", "B")
            .edge("A", "C")
            .edge("B", "D")
            .edge("C", "D")
            .build()
            .unwrap();
        let output = dag.execute("task").await.unwrap();
        assert_eq!(output.tokens_used.input_tokens, 1000);
        assert_eq!(output.tokens_used.output_tokens, 500);
    }

    #[tokio::test]
    async fn dag_error_carries_partial_usage() {
        let provider_a = mock_provider!("ok", 100, 50);
        let provider_b = mock_provider!();
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .edge("A", "B")
            .build()
            .unwrap();
        let err = dag.execute("task").await.unwrap_err();
        assert!(err.partial_usage().input_tokens >= 100);
    }

    #[tokio::test]
    async fn dag_parallel_roots_error_carries_sibling_usage() {
        let provider_a = mock_provider!("ok", 200, 100);
        let provider_b = mock_provider!();
        let dag = DagAgent::builder()
            .node("A", make_agent(provider_a, "A"))
            .node("B", make_agent(provider_b, "B"))
            .build()
            .unwrap();
        let err = dag.execute("task").await.unwrap_err();
        // The two roots run concurrently and their completion order is not fixed, so the
        // amount is not asserted; this only exercises the accessor on the returned error.
        let _ = err.partial_usage();
    }

    #[test]
    fn dag_debug_impl() {
        let provider = mock_provider!("a", 1, 1);
        let dag = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .build()
            .unwrap();
        let rendered = format!("{dag:?}");
        assert!(rendered.contains("DagAgent"));
        assert!(rendered.contains("node_count"));
    }
}