use std::{
collections::{HashMap, hash_map},
sync::Arc,
time::Duration,
};
use dashmap::DashMap;
use petgraph::{
Direction,
graph::{EdgeIndex, NodeIndex},
prelude::StableGraph,
visit::EdgeRef,
};
use thiserror::Error;
use tokio::sync::Mutex;
use crate::structs::agent::Agent;
/// A named directed acyclic graph of agents.
///
/// Every registered agent is represented by one node of the internal graph;
/// each edge carries a [`Flow`] that can transform or gate the data passed
/// from one agent to the next.
pub struct DAGWorkflow {
    /// Human-readable workflow name.
    pub name: String,
    /// Free-form description of the workflow.
    pub description: String,
    /// Execution graph: one node per agent, one edge per connection.
    workflow: StableGraph<AgentNode, Flow>,
    /// Lookup from agent name to its node in `workflow`.
    name_to_node: HashMap<String, NodeIndex>,
    /// Registered agents, keyed by the value their `name()` returned at registration.
    agents: DashMap<String, Box<dyn Agent>>,
}
impl DAGWorkflow {
pub fn new<S: Into<String>>(name: S, description: S) -> Self {
Self {
name: name.into(),
description: description.into(),
agents: DashMap::new(),
workflow: StableGraph::new(),
name_to_node: HashMap::new(),
}
}
pub fn register_agent(&mut self, agent: Box<dyn Agent>) {
let agent_name = agent.name();
self.agents.insert(agent_name.clone(), agent);
if let hash_map::Entry::Vacant(e) = self.name_to_node.entry(agent_name.clone()) {
let node_idx = self.workflow.add_node(AgentNode {
name: agent_name.clone(),
last_result: Mutex::new(None),
});
e.insert(node_idx);
}
}
pub fn connect_agents(
&mut self,
from: &str,
to: &str,
flow: Flow,
) -> Result<EdgeIndex, GraphWorkflowError> {
if !self.agents.contains_key(from) {
return Err(GraphWorkflowError::AgentNotFound(format!(
"Source agent '{}' not found",
from
)));
}
if !self.agents.contains_key(to) {
return Err(GraphWorkflowError::AgentNotFound(format!(
"Target agent '{}' not found",
to
)));
}
let from_entry = self.name_to_node.entry(from.to_string());
let from_idx = *from_entry.or_insert_with(|| {
self.workflow.add_node(AgentNode {
name: from.to_string(),
last_result: Mutex::new(None),
})
});
let to_entry = self.name_to_node.entry(to.to_string());
let to_idx = *to_entry.or_insert_with(|| {
self.workflow.add_node(AgentNode {
name: to.to_string(),
last_result: Mutex::new(None),
})
});
let edge_idx = self.workflow.add_edge(from_idx, to_idx, flow);
if self.has_cycle() {
self.workflow.remove_edge(edge_idx);
return Err(GraphWorkflowError::CycleDetected);
}
Ok(edge_idx)
}
fn has_cycle(&self) -> bool {
let mut visited = vec![false; self.workflow.node_count()];
let mut rec_stack = vec![false; self.workflow.node_count()];
for node in self.workflow.node_indices() {
if !visited[node.index()] && self.is_cyclic_util(node, &mut visited, &mut rec_stack) {
return true;
}
}
false
}
fn is_cyclic_util(
&self,
node: NodeIndex,
visited: &mut [bool],
rec_stack: &mut [bool],
) -> bool {
visited[node.index()] = true;
rec_stack[node.index()] = true;
for neighbor in self.workflow.neighbors_directed(node, Direction::Outgoing) {
if !visited[neighbor.index()] {
if self.is_cyclic_util(neighbor, visited, rec_stack) {
return true;
}
} else if rec_stack[neighbor.index()] {
return true;
}
}
rec_stack[node.index()] = false;
false
}
pub fn disconnect_agents(&mut self, from: &str, to: &str) -> Result<(), GraphWorkflowError> {
let from_idx = self.name_to_node.get(from).ok_or_else(|| {
GraphWorkflowError::AgentNotFound(format!("Source agent '{}' not found", from))
})?;
let to_idx = self.name_to_node.get(to).ok_or_else(|| {
GraphWorkflowError::AgentNotFound(format!("Target agent '{}' not found", to))
})?;
if let Some(edge) = self.workflow.find_edge(*from_idx, *to_idx) {
self.workflow.remove_edge(edge);
Ok(())
} else {
Err(GraphWorkflowError::AgentNotFound(format!(
"No connection from '{}' to '{}'",
from, to
)))
}
}
pub fn remove_agent(&mut self, name: &str) -> Result<(), GraphWorkflowError> {
if let Some(node_idx) = self.name_to_node.remove(name) {
self.workflow.remove_node(node_idx);
self.agents.remove(name);
Ok(())
} else {
Err(GraphWorkflowError::AgentNotFound(format!(
"Agent '{}' not found",
name
)))
}
}
pub async fn execute_agent(
&self,
name: &str,
input: String,
) -> Result<String, GraphWorkflowError> {
if let Some(agent) = self.agents.get(name) {
agent
.run(input)
.await
.map_err(|e| GraphWorkflowError::AgentError(e.to_string()))
} else {
Err(GraphWorkflowError::AgentNotFound(format!(
"Agent '{}' not found",
name
)))
}
}
pub async fn execute_workflow(
&mut self,
start_agent: &str,
input: impl Into<String>,
) -> Result<DashMap<String, Result<String, GraphWorkflowError>>, GraphWorkflowError> {
let input = input.into();
let start_idx = self.name_to_node.get(start_agent).ok_or_else(|| {
GraphWorkflowError::AgentNotFound(format!("Start agent '{}' not found", start_agent))
})?;
let node_idxs = self.workflow.node_indices().collect::<Vec<_>>();
for idx in node_idxs {
if let Some(node_weight) = self.workflow.node_weight_mut(idx) {
let mut last_result = node_weight.last_result.lock().await;
*last_result = None;
}
}
let results = Arc::new(DashMap::new());
let edge_tracker = Arc::new(DashMap::new());
let processed_nodes = Arc::new(DashMap::new());
self.execute_node(
*start_idx,
input,
Arc::clone(&results),
edge_tracker,
processed_nodes,
)
.await?;
Ok(Arc::into_inner(results).expect("Results should not be poisoned"))
}
async fn execute_node(
&self,
node_idx: NodeIndex,
input: String,
results: Arc<DashMap<String, Result<String, GraphWorkflowError>>>,
edge_tracker: Arc<DashMap<(NodeIndex, NodeIndex), bool>>,
processed_nodes: Arc<DashMap<NodeIndex, Vec<(NodeIndex, String)>>>,
) -> Result<String, GraphWorkflowError> {
let agent_name = &self
.workflow
.node_weight(node_idx)
.ok_or_else(|| {
GraphWorkflowError::AgentNotFound("Node not found in graph".to_string())
})?
.name;
if let Some(entry) = results.get(agent_name) {
return entry.value().clone();
}
let result = tokio::time::timeout(
Duration::from_secs(300), self.execute_agent(agent_name, input),
)
.await
.map_err(|_| GraphWorkflowError::Timeout(agent_name.clone()))?;
results.entry(agent_name.clone()).or_insert(result.clone());
if let Some(node_weight) = self.workflow.node_weight(node_idx) {
let mut last_result = node_weight.last_result.lock().await;
*last_result = Some(result.clone());
}
match &result {
Ok(output) => {
let valid_edges = self
.workflow
.edges_directed(node_idx, Direction::Outgoing)
.filter(|edge| {
edge.weight()
.condition
.as_ref()
.map(|cond| cond(output))
.unwrap_or(true) })
.collect::<Vec<_>>();
let mut futures = Vec::new();
for edge in valid_edges {
let source_node = node_idx;
let target_node = edge.target();
let flow = edge.weight().clone();
let results_clone = Arc::clone(&results);
let processed_nodes_clone = Arc::clone(&processed_nodes);
let edge_tracker_clone = Arc::clone(&edge_tracker);
let future = async move {
let next_input = flow
.transform
.as_ref()
.map_or_else(|| output.clone(), |transform| transform(output.clone()));
edge_tracker_clone.insert((source_node, target_node), true);
processed_nodes_clone
.entry(target_node)
.or_default()
.push((source_node, next_input));
let incoming_edges = self
.workflow
.edges_directed(target_node, Direction::Incoming)
.map(|e| (e.source(), target_node))
.collect::<Vec<_>>();
let all_processed = incoming_edges
.iter()
.all(|edge| edge_tracker_clone.contains_key(edge));
if all_processed {
let mut aggregated_input = String::new();
if let Some(inputs) = processed_nodes_clone.get(&target_node) {
for (source_idx, input) in inputs.value() {
let source_name =
&self.workflow.node_weight(*source_idx).unwrap().name;
aggregated_input
.push_str(&format!("[From {}] {}\n", source_name, input));
}
}
if let Err(e) = self
.execute_node(
target_node,
aggregated_input,
results_clone,
edge_tracker_clone,
processed_nodes_clone,
)
.await
{
tracing::error!("Failed to execute node: {:?}", e);
}
}
};
futures.push(future);
}
futures::future::join_all(futures).await; },
Err(e) => {
tracing::error!("Agent '{}' execution failed: {:?}", agent_name, e);
},
}
result
}
pub fn get_workflow_structure(&self) -> HashMap<String, Vec<(String, Option<String>)>> {
let mut structure = HashMap::new();
for node_idx in self.workflow.node_indices() {
if let Some(node) = self.workflow.node_weight(node_idx) {
let mut connections = Vec::new();
for edge in self.workflow.edges_directed(node_idx, Direction::Outgoing) {
if let Some(target) = self.workflow.node_weight(edge.target()) {
let edge_label = if edge.weight().transform.is_some() {
Some("transform".to_string())
} else {
None
};
connections.push((target.name.clone(), edge_label));
}
}
structure.insert(node.name.clone(), connections);
}
}
structure
}
pub fn export_workflow_dot(&self) -> String {
let mut dot = String::from("digraph {\n");
for node_idx in self.workflow.node_indices() {
if let Some(node) = self.workflow.node_weight(node_idx) {
dot.push_str(&format!(
" \"{}\" [label=\"{}\"];\n",
node.name, node.name
));
}
}
for edge in self.workflow.edge_indices() {
if let Some((source, target)) = self.workflow.edge_endpoints(edge) {
if let (Some(source_node), Some(target_node)) = (
self.workflow.node_weight(source),
self.workflow.node_weight(target),
) {
dot.push_str(&format!(
" \"{}\" -> \"{}\";\n",
source_node.name, target_node.name
));
}
}
}
dot.push_str("}\n");
dot
}
pub fn find_execution_paths(
&self,
start_agent: &str,
) -> Result<Vec<Vec<String>>, GraphWorkflowError> {
let start_idx = self.name_to_node.get(start_agent).ok_or_else(|| {
GraphWorkflowError::AgentNotFound(format!("Start agent '{}' not found", start_agent))
})?;
let mut paths = Vec::new();
let mut current_path = Vec::new();
self.dfs_paths(*start_idx, &mut current_path, &mut paths);
Ok(paths)
}
fn dfs_paths(
&self,
node_idx: NodeIndex,
current_path: &mut Vec<String>,
all_paths: &mut Vec<Vec<String>>,
) {
if let Some(node) = self.workflow.node_weight(node_idx) {
current_path.push(node.name.clone());
let has_outgoing = self
.workflow
.neighbors_directed(node_idx, Direction::Outgoing)
.count()
> 0;
if !has_outgoing {
all_paths.push(current_path.clone());
} else {
for neighbor in self
.workflow
.neighbors_directed(node_idx, Direction::Outgoing)
{
self.dfs_paths(neighbor, current_path, all_paths);
}
}
current_path.pop();
}
}
pub fn detect_potential_deadlocks(&self) -> Vec<Vec<String>> {
let mut dependency_graph = petgraph::Graph::<String, ()>::new();
let mut node_map = HashMap::new();
for name in self.name_to_node.keys() {
let idx = dependency_graph.add_node(name.clone());
node_map.insert(name.clone(), idx);
}
for node_idx in self.workflow.node_indices() {
if let Some(node) = self.workflow.node_weight(node_idx) {
let target_dep_idx = *node_map.get(&node.name).unwrap();
for source in self
.workflow
.neighbors_directed(node_idx, Direction::Incoming)
{
if let Some(source_node) = self.workflow.node_weight(source) {
let source_dep_idx = *node_map.get(&source_node.name).unwrap();
dependency_graph.add_edge(source_dep_idx, target_dep_idx, ());
}
}
}
}
let sccs = petgraph::algo::kosaraju_scc(&dependency_graph);
sccs.into_iter()
.filter(|scc| scc.len() > 1)
.map(|scc| {
scc.into_iter()
.map(|idx| dependency_graph[idx].clone())
.collect()
})
.collect()
}
}
/// Edge transform: maps the source agent's output to the payload passed to the target.
pub type TransformFn = Arc<dyn Fn(String) -> String + Send + Sync>;

/// Edge guard: the edge is followed only if this returns `true` for the source output.
pub type ConditionFn = Arc<dyn Fn(&str) -> bool + Send + Sync>;

/// Configuration of a single workflow edge.
///
/// `Flow::default()` is an unconditional pass-through edge.
#[derive(Clone, Default)]
pub struct Flow {
    /// Optional rewrite of the payload; `None` forwards the output unchanged.
    pub transform: Option<TransformFn>,
    /// Optional gate on the source output; `None` always follows the edge.
    pub condition: Option<ConditionFn>,
}

// Closures are not `Debug`, so report only whether each callback is set.
impl std::fmt::Debug for Flow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Flow")
            .field("transform", &self.transform.as_ref().map(|_| "<fn>"))
            .field("condition", &self.condition.as_ref().map(|_| "<fn>"))
            .finish()
    }
}
#[derive(Debug)]
pub struct AgentNode {
pub name: String,
pub last_result: Mutex<Option<Result<String, GraphWorkflowError>>>,
}
#[derive(Clone, Debug, Error)]
pub enum GraphWorkflowError {
#[error("Agent Error: {0}")]
AgentError(String),
#[error("Agent not found: {0}")]
AgentNotFound(String),
#[error("Cycle detected in workflow")]
CycleDetected,
#[error("Timeout executing agent: {0}")]
Timeout(String),
#[error("Deadlock detected in workflow execution")]
Deadlock,
#[error("Workflow execution canceled")]
Canceled,
}
#[cfg(test)]
mod tests {
    // `future::ready` builds the already-resolved futures returned by the hand-rolled mocks below.
    use futures::future;
    use crate::structs::test_utils::{MockAgent, create_failing_agent, create_mock_agent};
    use super::*;

    // A freshly constructed workflow keeps its metadata and starts with no agents, nodes or edges.
    #[test]
    fn test_dag_creation() {
        let workflow = DAGWorkflow::new("test", "Test workflow");
        assert_eq!(workflow.name, "test");
        assert_eq!(workflow.description, "Test workflow");
        assert_eq!(workflow.agents.len(), 0);
        assert_eq!(workflow.workflow.node_count(), 0);
        assert_eq!(workflow.workflow.edge_count(), 0);
    }

    // Registering an agent stores it, creates its graph node, and indexes it by name.
    #[test]
    fn test_agent_registration() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
        assert_eq!(workflow.agents.len(), 1);
        assert_eq!(workflow.workflow.node_count(), 1);
        assert!(workflow.name_to_node.contains_key("agent1"));
    }

    // Connecting two registered agents adds exactly one edge.
    #[test]
    fn test_agent_connection() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        let result = workflow.connect_agents("agent1", "agent2", Flow::default());
        assert!(result.is_ok());
        assert_eq!(workflow.workflow.edge_count(), 1);
    }

    // An unknown agent on either end of a connection yields AgentNotFound.
    #[test]
    fn test_agent_connection_failure_nonexistent_agent() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
        let result = workflow.connect_agents("agent1", "nonexistent", Flow::default());
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
        let result = workflow.connect_agents("nonexistent", "agent1", Flow::default());
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // Closing the 1 -> 2 -> 3 chain back to 1 is rejected and leaves only the two valid edges.
    #[test]
    fn test_cycle_detection() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
        let result1 = workflow.connect_agents("agent1", "agent2", Flow::default());
        assert!(result1.is_ok());
        let result2 = workflow.connect_agents("agent2", "agent3", Flow::default());
        assert!(result2.is_ok());
        let result3 = workflow.connect_agents("agent3", "agent1", Flow::default());
        assert!(matches!(result3, Err(GraphWorkflowError::CycleDetected)));
        assert_eq!(workflow.workflow.edge_count(), 2);
    }

    // Disconnecting an existing edge removes it.
    #[test]
    fn test_agent_disconnection() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        assert_eq!(workflow.workflow.edge_count(), 1);
        let result = workflow.disconnect_agents("agent1", "agent2");
        assert!(result.is_ok());
        assert_eq!(workflow.workflow.edge_count(), 0);
    }

    // Both a missing edge and an unknown agent are reported as AgentNotFound.
    #[test]
    fn test_agent_disconnection_failure() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        let result = workflow.disconnect_agents("agent1", "agent2");
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
        let result = workflow.disconnect_agents("nonexistent", "agent2");
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // Removing an agent drops it from the registry, the graph, the name index, and removes its edges.
    #[test]
    fn test_agent_removal() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        assert_eq!(workflow.agents.len(), 2);
        assert_eq!(workflow.workflow.node_count(), 2);
        let result = workflow.remove_agent("agent1");
        assert!(result.is_ok());
        assert_eq!(workflow.agents.len(), 1);
        assert_eq!(workflow.workflow.node_count(), 1);
        assert!(!workflow.name_to_node.contains_key("agent1"));
        assert_eq!(workflow.workflow.edge_count(), 0);
    }

    // Removing an unknown agent yields AgentNotFound.
    #[test]
    fn test_agent_removal_nonexistent() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        let result = workflow.remove_agent("nonexistent");
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // execute_agent returns the mock's canned response.
    #[tokio::test]
    async fn test_execute_single_agent() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
        let result = workflow.execute_agent("agent1", "input".to_string()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "response1");
    }

    // An agent's own failure surfaces as GraphWorkflowError::AgentError.
    #[tokio::test]
    async fn test_execute_single_agent_failure() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_failing_agent("1", "agent1", "test error"));
        let result = workflow.execute_agent("agent1", "input".to_string()).await;
        assert!(matches!(result, Err(GraphWorkflowError::AgentError(_))));
    }

    // Executing an unregistered agent yields AgentNotFound.
    #[tokio::test]
    async fn test_execute_single_agent_not_found() {
        let workflow = DAGWorkflow::new("test", "Test workflow");
        let result = workflow
            .execute_agent("nonexistent", "input".to_string())
            .await;
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // A two-node chain runs both agents and records both responses.
    #[tokio::test]
    async fn test_execute_workflow_linear() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(
            results.get("agent1").unwrap().as_ref().unwrap(),
            "response1"
        );
        assert_eq!(
            results.get("agent2").unwrap().as_ref().unwrap(),
            "response2"
        );
    }

    // A root with two children runs all three agents.
    #[tokio::test]
    async fn test_execute_workflow_branching() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "Root agent", "response1"));
        workflow.register_agent(create_mock_agent("2", "agent2", "Branch 1", "response2"));
        workflow.register_agent(create_mock_agent("3", "agent3", "Branch 2", "response3"));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        workflow
            .connect_agents("agent1", "agent3", Flow::default())
            .unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(
            results.get("agent1").unwrap().as_ref().unwrap(),
            "response1"
        );
        assert_eq!(
            results.get("agent2").unwrap().as_ref().unwrap(),
            "response2"
        );
        assert_eq!(
            results.get("agent3").unwrap().as_ref().unwrap(),
            "response3"
        );
    }

    // An edge with a transform still lets execution reach the target, and is labelled in the structure.
    #[tokio::test]
    async fn test_execute_workflow_with_transformation() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        let transform_fn = Arc::new(|input: String| format!("transformed: {}", input));
        let flow = Flow {
            transform: Some(transform_fn),
            condition: None,
        };
        workflow.connect_agents("agent1", "agent2", flow).unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 2);
        let structure = workflow.get_workflow_structure();
        let agent1_connections = &structure["agent1"];
        assert_eq!(agent1_connections.len(), 1);
        assert_eq!(agent1_connections[0].0, "agent2");
        assert_eq!(agent1_connections[0].1, Some("transform".to_string()));
    }

    // A condition that accepts the source output lets the target run.
    #[tokio::test]
    async fn test_execute_workflow_with_condition_true() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "true"));
        workflow.register_agent(create_mock_agent("2", "agent2", "Second agent", "executed"));
        let true_condition = Arc::new(|output: &str| output.contains("true"));
        workflow
            .connect_agents(
                "agent1",
                "agent2",
                Flow {
                    transform: None,
                    condition: Some(true_condition),
                },
            )
            .unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("agent1"));
        assert!(results.contains_key("agent2"));
    }

    // A condition that rejects the source output prevents the target from running.
    #[tokio::test]
    async fn test_execute_workflow_with_condition_false() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "not executed",
        ));
        let false_condition = Arc::new(|output: &str| output.contains("nonexistent"));
        workflow
            .connect_agents(
                "agent1",
                "agent2",
                Flow {
                    transform: None,
                    condition: Some(false_condition),
                },
            )
            .unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results.contains_key("agent1"));
        assert!(!results.contains_key("agent2"));
    }

    // Starting from an unknown agent yields AgentNotFound.
    #[tokio::test]
    async fn test_workflow_execution_start_agent_not_found() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        let result = workflow.execute_workflow("nonexistent", "input").await;
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // A failing middle agent is recorded as an error and stops propagation to its successor.
    #[tokio::test]
    async fn test_workflow_execution_with_failing_agent() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
        workflow.register_agent(create_failing_agent("2", "agent2", "fail error"));
        workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        workflow
            .connect_agents("agent2", "agent3", Flow::default())
            .unwrap();
        let results = workflow.execute_workflow("agent1", "input").await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("agent1"));
        assert!(results.contains_key("agent2"));
        assert!(!results.contains_key("agent3"));
        let agent2_result = results.get("agent2").unwrap();
        assert!(agent2_result.is_err());
    }

    // Two disjoint branches produce two root-to-leaf paths.
    #[test]
    fn test_find_execution_paths() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("0", "start", "Starting point", "start"));
        workflow.register_agent(create_mock_agent("1", "a", "Path A", "a"));
        workflow.register_agent(create_mock_agent("2", "b", "Path B", "b"));
        workflow.register_agent(create_mock_agent("3", "c", "End of A", "c"));
        workflow.register_agent(create_mock_agent("4", "d", "End of B", "d"));
        workflow
            .connect_agents("start", "a", Flow::default())
            .unwrap();
        workflow
            .connect_agents("start", "b", Flow::default())
            .unwrap();
        workflow.connect_agents("a", "c", Flow::default()).unwrap();
        workflow.connect_agents("b", "d", Flow::default()).unwrap();
        let paths = workflow.find_execution_paths("start").unwrap();
        assert_eq!(paths.len(), 2);
        let has_path1 = paths
            .iter()
            .any(|p| p == &vec!["start".to_string(), "a".to_string(), "c".to_string()]);
        let has_path2 = paths
            .iter()
            .any(|p| p == &vec!["start".to_string(), "b".to_string(), "d".to_string()]);
        assert!(has_path1);
        assert!(has_path2);
    }

    // Path enumeration from an unknown agent yields AgentNotFound.
    #[test]
    fn test_find_execution_paths_start_agent_not_found() {
        let workflow = DAGWorkflow::new("test", "Test workflow");
        let result = workflow.find_execution_paths("nonexistent");
        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
    }

    // A diamond (start -> a|b -> end) yields one path per branch, both ending at the shared node.
    #[test]
    fn test_find_execution_paths_diamond_pattern() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("0", "start", "Start", "start"));
        workflow.register_agent(create_mock_agent("1", "a", "Middle A", "a"));
        workflow.register_agent(create_mock_agent("2", "b", "Middle B", "b"));
        workflow.register_agent(create_mock_agent("3", "end", "End", "end"));
        workflow
            .connect_agents("start", "a", Flow::default())
            .unwrap();
        workflow
            .connect_agents("start", "b", Flow::default())
            .unwrap();
        workflow
            .connect_agents("a", "end", Flow::default())
            .unwrap();
        workflow
            .connect_agents("b", "end", Flow::default())
            .unwrap();
        let paths = workflow.find_execution_paths("start").unwrap();
        assert_eq!(paths.len(), 2);
        let has_path1 = paths
            .iter()
            .any(|p| p == &vec!["start".to_string(), "a".to_string(), "end".to_string()]);
        let has_path2 = paths
            .iter()
            .any(|p| p == &vec!["start".to_string(), "b".to_string(), "end".to_string()]);
        assert!(has_path1);
        assert!(has_path2);
    }

    // An acyclic chain has no deadlock groups, and the edge that would create one is rejected.
    #[test]
    fn test_detect_potential_deadlocks() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
        workflow.register_agent(create_mock_agent("3", "c", "Agent C", "c"));
        workflow.connect_agents("a", "b", Flow::default()).unwrap();
        workflow.connect_agents("b", "c", Flow::default()).unwrap();
        let deadlocks = workflow.detect_potential_deadlocks();
        assert_eq!(deadlocks.len(), 0);
        let result = workflow.connect_agents("c", "a", Flow::default());
        assert!(matches!(result, Err(GraphWorkflowError::CycleDetected)));
    }

    // The structure lists every node (leaves with no connections) and labels only transform edges.
    #[test]
    fn test_get_workflow_structure() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
        workflow.register_agent(create_mock_agent("3", "c", "Agent C", "c"));
        workflow.connect_agents("a", "b", Flow::default()).unwrap();
        let transform_fn = Arc::new(|input: String| format!("transformed: {}", input));
        let flow = Flow {
            transform: Some(transform_fn),
            condition: None,
        };
        workflow.connect_agents("b", "c", flow).unwrap();
        let structure = workflow.get_workflow_structure();
        assert_eq!(structure.len(), 3);
        assert_eq!(structure["a"].len(), 1);
        assert_eq!(structure["a"][0].0, "b");
        assert_eq!(structure["a"][0].1, None);
        assert_eq!(structure["b"].len(), 1);
        assert_eq!(structure["b"][0].0, "c");
        assert_eq!(structure["b"][0].1, Some("transform".to_string()));
        assert_eq!(structure["c"].len(), 0);
    }

    // The DOT export declares both nodes with labels and the edge between them.
    #[test]
    fn test_export_workflow_dot() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
        workflow.connect_agents("a", "b", Flow::default()).unwrap();
        let dot = workflow.export_workflow_dot();
        assert!(dot.contains("digraph {"));
        assert!(dot.contains("\"a\" [label=\"a\"]"));
        assert!(dot.contains("\"b\" [label=\"b\"]"));
        assert!(dot.contains("\"a\" -> \"b\""));
        assert!(dot.contains("}"));
    }

    // Despite its name, this checks that results are NOT cached across runs: each
    // execute_workflow call starts with a fresh results map, so the counting mock runs every time.
    #[tokio::test]
    async fn test_caching_execution_results() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        let mut agent = Box::new(MockAgent::new());
        let agent_name = "counter".to_string();
        agent.expect_name().return_const(agent_name.clone());
        agent.expect_id().return_const("1".to_string());
        agent
            .expect_description()
            .return_const("Counter Agent".to_string());
        // The closure owns `count`, so each run reports how many times it has been called.
        let mut count = 0;
        agent.expect_run().returning(move |_| {
            count += 1;
            Box::pin(future::ready(Ok(format!("Called {} times", count))))
        });
        agent.expect_is_response_complete().returning(|_| true);
        agent
            .expect_run_multiple_tasks()
            .returning(|_| Box::pin(future::ready(Ok(vec![]))));
        agent
            .expect_plan()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        agent
            .expect_query_long_term_memory()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        agent
            .expect_save_task_state()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        workflow.register_agent(agent);
        let results1 = workflow
            .execute_workflow("counter", "input1")
            .await
            .unwrap();
        assert_eq!(
            results1.get("counter").unwrap().as_ref().unwrap(),
            "Called 1 times"
        );
        let results2 = workflow
            .execute_workflow("counter", "input2")
            .await
            .unwrap();
        assert_eq!(
            results2.get("counter").unwrap().as_ref().unwrap(),
            "Called 2 times"
        );
        let result3 = workflow
            .execute_agent("counter", "input3".to_string())
            .await
            .unwrap();
        assert_eq!(result3, "Called 3 times");
    }

    // Within a single shared results map, execute_node returns the cached outcome (call #1) even
    // for a different input; clearing the map forces a real re-run (call #2).
    #[tokio::test]
    async fn test_execute_node_result_caching() {
        let mut workflow = DAGWorkflow::new("test", "Test workflow");
        let mut agent1 = Box::new(MockAgent::new());
        agent1.expect_name().return_const("agent1".to_string());
        agent1.expect_id().return_const("1".to_string());
        agent1
            .expect_description()
            .return_const("First agent".to_string());
        // The response embeds both the input and a call counter so cache hits are observable.
        let mut run_count = 0;
        agent1.expect_run().returning(move |input| {
            run_count += 1;
            Box::pin(future::ready(Ok(format!(
                "response for '{}' (call #{})",
                input, run_count
            ))))
        });
        agent1.expect_is_response_complete().returning(|_| true);
        agent1
            .expect_run_multiple_tasks()
            .returning(|_| Box::pin(future::ready(Ok(vec![]))));
        agent1
            .expect_plan()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        agent1
            .expect_query_long_term_memory()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        agent1
            .expect_save_task_state()
            .returning(|_| Box::pin(future::ready(Ok(()))));
        workflow.register_agent(agent1);
        workflow.register_agent(create_mock_agent(
            "2",
            "agent2",
            "Second agent",
            "response2",
        ));
        workflow
            .connect_agents("agent1", "agent2", Flow::default())
            .unwrap();
        let agent1_idx = *workflow.name_to_node.get("agent1").unwrap();
        let results = Arc::new(DashMap::new());
        let edge_tracker = Arc::new(DashMap::new());
        let processed_nodes = Arc::new(DashMap::new());
        let result1 = workflow
            .execute_node(
                agent1_idx,
                "input1".to_string(),
                Arc::clone(&results),
                Arc::clone(&edge_tracker),
                Arc::clone(&processed_nodes),
            )
            .await
            .unwrap();
        assert_eq!(result1, "response for 'input1' (call #1)");
        assert!(results.contains_key("agent1"));
        assert!(results.contains_key("agent2"));
        let result2 = workflow
            .execute_node(
                agent1_idx,
                "input2".to_string(),
                Arc::clone(&results),
                Arc::clone(&edge_tracker),
                Arc::clone(&processed_nodes),
            )
            .await
            .unwrap();
        assert_eq!(result2, "response for 'input1' (call #1)");
        results.clear();
        let result3 = workflow
            .execute_node(
                agent1_idx,
                "input3".to_string(),
                Arc::clone(&results),
                Arc::clone(&edge_tracker),
                Arc::clone(&processed_nodes),
            )
            .await
            .unwrap();
        assert_eq!(result3, "response for 'input3' (call #2)");
    }
}