1#![deny(missing_docs)]
4
5use std::{
6 collections::{HashMap, hash_map},
7 fmt::Debug,
8 sync::Arc,
9 time::Duration,
10};
11
12use dashmap::DashMap;
13use petgraph::{
14 Direction,
15 graph::{EdgeIndex, NodeIndex},
16 prelude::StableGraph,
17 visit::EdgeRef,
18};
19use thiserror::Error;
20use tokio::sync::Mutex;
21
22use crate::agent::Agent;
23
24pub struct DAGWorkflow {
26 pub name: String,
28 pub description: String,
30 agents: DashMap<String, Arc<dyn Agent>>,
32 workflow: StableGraph<AgentNode, Flow>,
34 name_to_node: HashMap<String, NodeIndex>,
36}
37
38impl DAGWorkflow {
39 pub fn new<S: Into<String>>(name: S, description: S) -> Self {
41 Self {
42 name: name.into(),
43 description: description.into(),
44 agents: DashMap::new(),
45 workflow: StableGraph::new(),
46 name_to_node: HashMap::new(),
47 }
48 }
49
50 pub fn register_agent(&mut self, agent: Arc<dyn Agent>) {
52 let agent_name = agent.name();
53 self.agents.insert(agent_name.clone(), agent);
54
55 if let hash_map::Entry::Vacant(e) = self.name_to_node.entry(agent_name.clone()) {
57 let node_idx = self.workflow.add_node(AgentNode {
58 name: agent_name.clone(),
59 last_result: Mutex::new(None),
60 });
61 e.insert(node_idx);
62 }
63 }
64
65 pub fn connect_agents(
67 &mut self,
68 from: &str,
69 to: &str,
70 flow: Flow,
71 ) -> Result<EdgeIndex, GraphWorkflowError> {
72 if !self.agents.contains_key(from) {
74 return Err(GraphWorkflowError::AgentNotFound(format!(
75 "Source agent '{from}' not found",
76 )));
77 }
78 if !self.agents.contains_key(to) {
79 return Err(GraphWorkflowError::AgentNotFound(format!(
80 "Target agent '{to}' not found",
81 )));
82 }
83
84 let from_entry = self.name_to_node.entry(from.to_owned());
86 let from_idx = *from_entry.or_insert_with(|| {
87 self.workflow.add_node(AgentNode {
88 name: from.to_owned(),
89 last_result: Mutex::new(None),
90 })
91 });
92
93 let to_entry = self.name_to_node.entry(to.to_owned());
94 let to_idx = *to_entry.or_insert_with(|| {
95 self.workflow.add_node(AgentNode {
96 name: to.to_owned(),
97 last_result: Mutex::new(None),
98 })
99 });
100
101 let edge_idx = self.workflow.add_edge(from_idx, to_idx, flow);
103
104 if self.has_cycle() {
106 self.workflow.remove_edge(edge_idx);
108 return Err(GraphWorkflowError::CycleDetected);
109 }
110
111 Ok(edge_idx)
112 }
113
114 fn has_cycle(&self) -> bool {
116 let mut visited = vec![false; self.workflow.node_count()];
118 let mut rec_stack = vec![false; self.workflow.node_count()];
119
120 for node in self.workflow.node_indices() {
121 if !visited[node.index()] && self.is_cyclic_util(node, &mut visited, &mut rec_stack) {
122 return true;
123 }
124 }
125 false
126 }
127
128 fn is_cyclic_util(
129 &self,
130 node: NodeIndex,
131 visited: &mut [bool],
132 rec_stack: &mut [bool],
133 ) -> bool {
134 visited[node.index()] = true;
135 rec_stack[node.index()] = true;
136
137 for neighbor in self.workflow.neighbors_directed(node, Direction::Outgoing) {
138 if !visited[neighbor.index()] {
139 if self.is_cyclic_util(neighbor, visited, rec_stack) {
140 return true;
141 }
142 } else if rec_stack[neighbor.index()] {
143 return true;
144 }
145 }
146
147 rec_stack[node.index()] = false;
148 false
149 }
150
151 pub fn disconnect_agents(&mut self, from: &str, to: &str) -> Result<(), GraphWorkflowError> {
153 let from_idx = self.name_to_node.get(from).ok_or_else(|| {
154 GraphWorkflowError::AgentNotFound(format!("Source agent '{from}' not found"))
155 })?;
156 let to_idx = self.name_to_node.get(to).ok_or_else(|| {
157 GraphWorkflowError::AgentNotFound(format!("Target agent '{to}' not found"))
158 })?;
159
160 if let Some(edge) = self.workflow.find_edge(*from_idx, *to_idx) {
162 self.workflow.remove_edge(edge);
163 Ok(())
164 } else {
165 Err(GraphWorkflowError::AgentNotFound(format!(
166 "No connection from '{from}' to '{to}'"
167 )))
168 }
169 }
170
171 pub fn remove_agent(&mut self, name: &str) -> Result<(), GraphWorkflowError> {
173 if let Some(node_idx) = self.name_to_node.remove(name) {
174 self.workflow.remove_node(node_idx);
175 self.agents.remove(name);
176 Ok(())
177 } else {
178 Err(GraphWorkflowError::AgentNotFound(format!(
179 "Agent '{name}' not found"
180 )))
181 }
182 }
183
184 pub async fn execute_agent(
186 &self,
187 name: &str,
188 input: String,
189 ) -> Result<String, GraphWorkflowError> {
190 if let Some(agent) = self.agents.get(name) {
191 agent
192 .run(input)
193 .await
194 .map_err(|e| GraphWorkflowError::AgentError(e.to_string()))
195 } else {
196 Err(GraphWorkflowError::AgentNotFound(format!(
197 "Agent '{name}' not found"
198 )))
199 }
200 }
201
202 pub async fn execute_workflow(
214 &mut self,
215 start_agents: &[&str],
216 input: impl Into<String>,
217 ) -> Result<DashMap<String, Result<String, GraphWorkflowError>>, GraphWorkflowError> {
218 let input = input.into();
219
220 let start_indices = start_agents
221 .iter()
222 .map(|agent| {
223 self.name_to_node
224 .get(*agent)
225 .ok_or_else(|| {
226 GraphWorkflowError::AgentNotFound(format!(
227 "Start agent '{agent}' not found"
228 ))
229 })
230 .copied()
231 })
232 .collect::<Result<Vec<_>, _>>()?;
233
234 let node_idxs = self.workflow.node_indices().collect::<Vec<_>>();
236 for idx in node_idxs {
237 if let Some(node_weight) = self.workflow.node_weight_mut(idx) {
238 let mut last_result = node_weight.last_result.lock().await;
239 *last_result = None;
240 }
241 }
242
243 let results = Arc::new(DashMap::new());
245 let edge_tracker = Arc::new(DashMap::new());
247 let processed_nodes = Arc::new(DashMap::new());
248 let mut tasks = Vec::new();
250 for &start_idx in &start_indices {
251 let task = self.execute_node(
252 start_idx,
253 input.clone(),
254 Arc::clone(&results),
255 Arc::clone(&edge_tracker),
256 Arc::clone(&processed_nodes),
257 );
258 tasks.push(task);
259 }
260 futures::future::join_all(tasks)
261 .await
262 .into_iter()
263 .collect::<Result<Vec<_>, _>>()
264 .map_err(|e| GraphWorkflowError::ExecutionError(e.to_string()))?;
265 Ok(Arc::into_inner(results).expect("Results should not be poisoned"))
266 }
267
268 async fn execute_node(
269 &self,
270 node_idx: NodeIndex,
271 input: String,
272 results: Arc<DashMap<String, Result<String, GraphWorkflowError>>>,
273 edge_tracker: Arc<DashMap<(NodeIndex, NodeIndex), bool>>,
274 processed_nodes: Arc<DashMap<NodeIndex, Vec<(NodeIndex, String)>>>,
275 ) -> Result<String, GraphWorkflowError> {
276 let agent_name = &self
278 .workflow
279 .node_weight(node_idx)
280 .ok_or_else(|| GraphWorkflowError::AgentNotFound("Node not found in graph".to_owned()))?
281 .name;
282
283 if let Some(entry) = results.get(agent_name) {
285 return entry.value().clone();
286 }
287
288 let result = tokio::time::timeout(
290 Duration::from_secs(3600), self.execute_agent(agent_name, input),
292 )
293 .await
294 .map_err(|_| GraphWorkflowError::Timeout(agent_name.clone()))?;
295
296 results.insert(agent_name.clone(), result.clone());
298
299 if let Some(node_weight) = self.workflow.node_weight(node_idx) {
301 let mut last_result = node_weight.last_result.lock().await;
302 *last_result = Some(result.clone());
303 }
304
305 match &result {
307 Ok(output) => {
308 let valid_edges = self
310 .workflow
311 .edges_directed(node_idx, Direction::Outgoing)
312 .filter(|edge| {
313 let condition_result = edge
315 .weight()
316 .condition
317 .as_ref()
318 .map(|cond| {
319 let result = cond(output);
321 tracing::debug!(
322 "Condition for edge {:?} -> {:?}: {}",
323 node_idx,
324 edge.target(),
325 result
326 );
327 result
328 })
329 .unwrap_or(true); condition_result
332 })
333 .collect::<Vec<_>>();
334
335 let mut futures = Vec::new();
336
337 for edge in valid_edges {
338 let source_node = node_idx;
339 let target_node = edge.target();
340 let flow = edge.weight().clone();
341 let results_clone = Arc::clone(&results);
342 let processed_nodes_clone = Arc::clone(&processed_nodes);
343 let edge_tracker_clone = Arc::clone(&edge_tracker);
344
345 let future = async move {
346 let next_input = flow
348 .transform
349 .as_ref()
350 .map_or_else(|| output.clone(), |transform| transform(output.clone()));
351
352 edge_tracker_clone.insert((source_node, target_node), true);
354
355 {
358 processed_nodes_clone
359 .entry(target_node)
360 .and_modify(|v| v.push((source_node, next_input.clone())))
361 .or_insert_with(|| vec![(source_node, next_input.clone())]);
362 }
363
364 let all_incoming_edges = self
366 .workflow
367 .edges_directed(target_node, Direction::Incoming)
368 .map(|e| (e.source(), target_node))
369 .collect::<Vec<_>>();
370
371 let all_processed = all_incoming_edges.iter().all(|edge| {
374 let processed = edge_tracker_clone.contains_key(edge);
376
377 let conditionally_skipped = if !processed {
380 if let Some(edge_idx) = self.workflow.find_edge(edge.0, edge.1) {
381 let edge_weight = self.workflow.edge_weight(edge_idx).unwrap();
382 if let Some(cond) = &edge_weight.condition {
383 if let Some(source_name) =
385 self.workflow.node_weight(edge.0).map(|n| &n.name)
386 {
387 if let Some(source_result) =
388 results_clone.get(source_name)
389 {
390 if let Ok(output) = source_result.as_ref() {
391 let condition_result = !cond(output);
393 if condition_result {
394 edge_tracker_clone
396 .insert((edge.0, edge.1), true);
397 }
398 condition_result
399 } else {
400 edge_tracker_clone
402 .insert((edge.0, edge.1), true);
403 true
404 }
405 } else {
406 false
407 }
408 } else {
409 false
410 }
411 } else {
412 false
413 }
414 } else {
415 false
416 }
417 } else {
418 false
419 };
420
421 tracing::debug!(
422 "Edge {:?} processed: {}, conditionally skipped: {}",
423 edge,
424 processed,
425 conditionally_skipped
426 );
427 processed || conditionally_skipped
428 });
429
430 if all_processed {
432 let aggregated_input = processed_nodes_clone
434 .get(&target_node)
435 .map(|inputs| {
436 let mut sorted_inputs = inputs.value().clone();
438 sorted_inputs.sort_by_key(|(source_idx, _)| *source_idx);
439
440 tracing::debug!(
442 "Node {:?} has {} inputs",
443 target_node,
444 sorted_inputs.len()
445 );
446
447 let formatted_inputs = sorted_inputs
449 .iter()
450 .map(|(source_idx, input)| {
451 let source_name = &self
452 .workflow
453 .node_weight(*source_idx)
454 .unwrap()
455 .name;
456 format!("[From {source_name}] {input}")
457 })
458 .collect::<Vec<_>>();
459
460 let result = formatted_inputs.join("\n\n---\n\n");
462 tracing::debug!(
463 "Aggregated input for node {:?}: {}",
464 target_node,
465 result
466 );
467 result
468 })
469 .unwrap_or_default();
470
471 tracing::debug!(
472 "Executing node {:?} with aggregated input",
473 target_node
474 );
475
476 if let Err(e) = self
478 .execute_node(
479 target_node,
480 aggregated_input,
481 results_clone,
482 edge_tracker_clone,
483 processed_nodes_clone,
484 )
485 .await
486 {
487 tracing::error!("Failed to execute node: {:?}", e);
488 }
489 }
490 };
491
492 futures.push(future);
493 }
494
495 futures::future::join_all(futures).await; }
498 Err(e) => {
499 tracing::error!("Agent '{}' execution failed: {:?}", agent_name, e);
500 }
502 }
503
504 result
505 }
506
507 pub fn get_workflow_structure(&self) -> HashMap<String, Vec<(String, Option<String>)>> {
509 let mut structure = HashMap::new();
510
511 for node_idx in self.workflow.node_indices() {
512 if let Some(node) = self.workflow.node_weight(node_idx) {
513 let mut connections = Vec::new();
514
515 for edge in self.workflow.edges_directed(node_idx, Direction::Outgoing) {
516 if let Some(target) = self.workflow.node_weight(edge.target()) {
517 let edge_label = if edge.weight().transform.is_some() {
519 Some("transform".to_owned())
520 } else {
521 None
522 };
523
524 connections.push((target.name.clone(), edge_label));
525 }
526 }
527
528 structure.insert(node.name.clone(), connections);
529 }
530 }
531
532 structure
533 }
534
535 pub fn export_workflow_dot(&self) -> String {
537 let mut dot = String::from("digraph {\n");
541
542 for node_idx in self.workflow.node_indices() {
544 if let Some(node) = self.workflow.node_weight(node_idx) {
545 dot.push_str(&format!(
546 " \"{}\" [label=\"{}\"];\n",
547 node.name, node.name
548 ));
549 }
550 }
551
552 for edge in self.workflow.edge_indices() {
554 if let Some((source, target)) = self.workflow.edge_endpoints(edge) {
555 if let (Some(source_node), Some(target_node)) = (
556 self.workflow.node_weight(source),
557 self.workflow.node_weight(target),
558 ) {
559 dot.push_str(&format!(
560 " \"{}\" -> \"{}\";\n",
561 source_node.name, target_node.name
562 ));
563 }
564 }
565 }
566
567 dot.push_str("}\n");
568 dot
569 }
570
571 pub fn find_execution_paths(
573 &self,
574 start_agents: &[&str],
575 ) -> Result<Vec<Vec<String>>, GraphWorkflowError> {
576 let start_indices = start_agents
577 .iter()
578 .map(|agent| {
579 self.name_to_node
580 .get(*agent)
581 .ok_or_else(|| {
582 GraphWorkflowError::AgentNotFound(format!(
583 "Start agent '{agent}' not found"
584 ))
585 })
586 .copied()
587 })
588 .collect::<Result<Vec<_>, _>>()?;
589
590 let mut paths = Vec::new();
591 let mut current_path = Vec::new();
592
593 for start_idx in &start_indices {
594 current_path.clear();
595 self.dfs_paths(*start_idx, &mut current_path, &mut paths);
596 }
597
598 Ok(paths)
599 }
600
601 fn dfs_paths(
602 &self,
603 node_idx: NodeIndex,
604 current_path: &mut Vec<String>,
605 all_paths: &mut Vec<Vec<String>>,
606 ) {
607 if let Some(node) = self.workflow.node_weight(node_idx) {
608 current_path.push(node.name.clone());
610
611 let has_outgoing = self
613 .workflow
614 .neighbors_directed(node_idx, Direction::Outgoing)
615 .count()
616 > 0;
617
618 if !has_outgoing {
619 all_paths.push(current_path.clone());
621 } else {
622 for neighbor in self
624 .workflow
625 .neighbors_directed(node_idx, Direction::Outgoing)
626 {
627 self.dfs_paths(neighbor, current_path, all_paths);
628 }
629 }
630
631 current_path.pop();
633 }
634 }
635
636 pub fn detect_potential_deadlocks(&self) -> Vec<Vec<String>> {
648 let mut dependency_graph = petgraph::Graph::<String, ()>::new();
650 let mut node_map = HashMap::new();
651
652 for name in self.name_to_node.keys() {
654 let idx = dependency_graph.add_node(name.clone());
655 node_map.insert(name.clone(), idx);
656 }
657
658 for node_idx in self.workflow.node_indices() {
660 if let Some(node) = self.workflow.node_weight(node_idx) {
661 let target_dep_idx = *node_map.get(&node.name).unwrap();
662
663 for source in self
665 .workflow
666 .neighbors_directed(node_idx, Direction::Incoming)
667 {
668 if let Some(source_node) = self.workflow.node_weight(source) {
669 let source_dep_idx = *node_map.get(&source_node.name).unwrap();
670 dependency_graph.add_edge(source_dep_idx, target_dep_idx, ());
671 }
672 }
673 }
674 }
675
676 let sccs = petgraph::algo::kosaraju_scc(&dependency_graph);
678
679 sccs.into_iter()
681 .filter(|scc| scc.len() > 1)
682 .map(|scc| {
683 scc.into_iter()
684 .map(|idx| dependency_graph[idx].clone())
685 .collect()
686 })
687 .collect()
688 }
689}
690
/// Connection between two agents: how the source agent's output reaches the target.
///
/// Both parts are optional. [`Flow::default`] forwards the output unchanged and
/// unconditionally.
// The boxed closure types trip `clippy::type_complexity`; they are spelled out
// rather than aliased so the public field types remain self-describing.
#[allow(clippy::type_complexity)]
#[derive(Clone, Default)]
pub struct Flow {
    /// Optional rewrite applied to the source output before the target receives it.
    pub transform: Option<Arc<dyn Fn(String) -> String + Send + Sync>>,
    /// Optional predicate on the source output. When it returns `false`, the edge is
    /// not followed.
    pub condition: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
}
700
701#[derive(Debug)]
703pub struct AgentNode {
704 pub name: String,
706 pub last_result: Mutex<Option<Result<String, GraphWorkflowError>>>,
708}
709
710#[allow(missing_docs)]
712#[derive(Clone, Debug, Error)]
713pub enum GraphWorkflowError {
714 #[error("Agent Error: {0}")]
715 AgentError(String),
716 #[error("Agent not found: {0}")]
717 AgentNotFound(String),
718 #[error("Cycle detected in workflow")]
719 CycleDetected,
720 #[error("Execution error: {0}")]
721 ExecutionError(String),
722 #[error("Timeout executing agent: {0}")]
723 Timeout(String),
724 #[error("Deadlock detected in workflow execution")]
725 Deadlock,
726 #[error("Workflow execution canceled")]
727 Canceled,
728}
729
730impl Debug for Flow {
731 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
732 f.debug_struct("Flow")
733 .field("transform", &self.transform.is_some())
734 .field("condition", &self.condition.is_some())
735 .finish()
736 }
737}
738
739#[cfg(test)]
740mod tests {
741 use super::*;
742
743 use futures::future::{self, BoxFuture};
744 use mockall::mock;
745
746 use crate::agent::AgentError;
747
748 mock! {
749 #[derive(Debug)]
750 pub Agent{}
751
752 impl Agent for Agent {
753 fn run(&self, task: String) -> BoxFuture<'static, Result<String, AgentError>> {
754 Box::pin(future::ready(Ok(String::new())))
755 }
756 fn run_multiple_tasks(&mut self, tasks: Vec<String>) -> BoxFuture<'static, Result<Vec<String>, AgentError>> {
757 Box::pin(future::ready(Ok(vec![])))
758 }
759 fn id(&self) -> String {
760 String::new()
761 }
762 fn name(&self) -> String {
763 String::new()
764 }
765 fn description(&self) -> String {
766 String::new()
767 }
768 }
769 }
770
771 fn create_mock_agent(id: &str, name: &str, desc: &str, response: &str) -> Arc<MockAgent> {
772 let mut agent = MockAgent::new();
773
774 let id_str = id.to_owned();
775 agent.expect_id().return_const(id_str);
776
777 let name_str = name.to_owned();
778 agent.expect_name().return_const(name_str);
779
780 let desc_str = desc.to_owned();
781 agent.expect_description().return_const(desc_str);
782
783 let response_str = response.to_owned();
784 let response_str_clone = response_str.clone();
785 agent.expect_run().returning(move |_| {
786 let res = response_str_clone.clone();
787 Box::pin(future::ready(Ok(res)))
788 });
789
790 let response_str_clone = response_str.clone();
791 agent.expect_run_multiple_tasks().returning(move |tasks| {
792 let responses = tasks.iter().map(|_| response_str_clone.clone()).collect();
793 Box::pin(future::ready(Ok(responses)))
794 });
795
796 Arc::new(agent)
797 }
798
799 fn create_failing_agent(id: &str, name: &str, error_msg: &str) -> Arc<MockAgent> {
800 let mut agent = MockAgent::new();
801
802 let id_str = id.to_owned();
803 agent.expect_id().return_const(id_str);
804
805 let name_str = name.to_owned();
806 agent.expect_name().return_const(name_str);
807
808 agent
809 .expect_description()
810 .return_const("Failing agent".to_owned());
811
812 let error_str = error_msg.to_owned();
813 let error_str_for_run = error_str.clone();
814 agent.expect_run().returning(move |_| {
815 let err = AgentError::TestError(error_str_for_run.clone());
816 Box::pin(future::ready(Err(err)))
817 });
818
819 agent.expect_run_multiple_tasks().returning(move |_| {
820 let err = AgentError::TestError(error_str.clone());
821 Box::pin(future::ready(Err(err)))
822 });
823
824 Arc::new(agent)
825 }
826
827 #[test]
828 fn test_dag_creation() {
829 let workflow = DAGWorkflow::new("test", "Test workflow");
830 assert_eq!(workflow.name, "test");
831 assert_eq!(workflow.description, "Test workflow");
832 assert_eq!(workflow.agents.len(), 0);
833 assert_eq!(workflow.workflow.node_count(), 0);
834 assert_eq!(workflow.workflow.edge_count(), 0);
835 }
836
837 #[test]
838 fn test_agent_registration() {
839 let mut workflow = DAGWorkflow::new("test", "Test workflow");
840 workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
841
842 assert_eq!(workflow.agents.len(), 1);
843 assert_eq!(workflow.workflow.node_count(), 1);
844 assert!(workflow.name_to_node.contains_key("agent1"));
845 }
846
847 #[test]
848 fn test_agent_connection() {
849 let mut workflow = DAGWorkflow::new("test", "Test workflow");
850 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
851 workflow.register_agent(create_mock_agent(
852 "2",
853 "agent2",
854 "Second agent",
855 "response2",
856 ));
857
858 let result = workflow.connect_agents("agent1", "agent2", Flow::default());
859 assert!(result.is_ok());
860 assert_eq!(workflow.workflow.edge_count(), 1);
861 }
862
863 #[test]
864 fn test_agent_connection_failure_nonexistent_agent() {
865 let mut workflow = DAGWorkflow::new("test", "Test workflow");
866 workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
867
868 let result = workflow.connect_agents("agent1", "nonexistent", Flow::default());
869 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
870
871 let result = workflow.connect_agents("nonexistent", "agent1", Flow::default());
872 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
873 }
874
875 #[test]
876 fn test_cycle_detection() {
877 let mut workflow = DAGWorkflow::new("test", "Test workflow");
878 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
879 workflow.register_agent(create_mock_agent(
880 "2",
881 "agent2",
882 "Second agent",
883 "response2",
884 ));
885 workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
886
887 let result1 = workflow.connect_agents("agent1", "agent2", Flow::default());
889 assert!(result1.is_ok());
890 let result2 = workflow.connect_agents("agent2", "agent3", Flow::default());
891 assert!(result2.is_ok());
892
893 let result3 = workflow.connect_agents("agent3", "agent1", Flow::default());
895 assert!(matches!(result3, Err(GraphWorkflowError::CycleDetected)));
896
897 assert_eq!(workflow.workflow.edge_count(), 2);
899 }
900
901 #[test]
902 fn test_agent_disconnection() {
903 let mut workflow = DAGWorkflow::new("test", "Test workflow");
904 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
905 workflow.register_agent(create_mock_agent(
906 "2",
907 "agent2",
908 "Second agent",
909 "response2",
910 ));
911
912 workflow
913 .connect_agents("agent1", "agent2", Flow::default())
914 .unwrap();
915 assert_eq!(workflow.workflow.edge_count(), 1);
916
917 let result = workflow.disconnect_agents("agent1", "agent2");
918 assert!(result.is_ok());
919 assert_eq!(workflow.workflow.edge_count(), 0);
920 }
921
922 #[test]
923 fn test_agent_disconnection_failure() {
924 let mut workflow = DAGWorkflow::new("test", "Test workflow");
925 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
926 workflow.register_agent(create_mock_agent(
927 "2",
928 "agent2",
929 "Second agent",
930 "response2",
931 ));
932
933 let result = workflow.disconnect_agents("agent1", "agent2");
935 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
936
937 let result = workflow.disconnect_agents("nonexistent", "agent2");
939 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
940 }
941
942 #[test]
943 fn test_agent_removal() {
944 let mut workflow = DAGWorkflow::new("test", "Test workflow");
945 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
946 workflow.register_agent(create_mock_agent(
947 "2",
948 "agent2",
949 "Second agent",
950 "response2",
951 ));
952
953 workflow
954 .connect_agents("agent1", "agent2", Flow::default())
955 .unwrap();
956 assert_eq!(workflow.agents.len(), 2);
957 assert_eq!(workflow.workflow.node_count(), 2);
958
959 let result = workflow.remove_agent("agent1");
960 assert!(result.is_ok());
961 assert_eq!(workflow.agents.len(), 1);
962 assert_eq!(workflow.workflow.node_count(), 1);
963 assert!(!workflow.name_to_node.contains_key("agent1"));
964
965 assert_eq!(workflow.workflow.edge_count(), 0);
966 }
967
968 #[test]
969 fn test_agent_removal_nonexistent() {
970 let mut workflow = DAGWorkflow::new("test", "Test workflow");
971
972 let result = workflow.remove_agent("nonexistent");
973 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
974 }
975
976 #[tokio::test]
977 async fn test_execute_single_agent() {
978 let mut workflow = DAGWorkflow::new("test", "Test workflow");
979 workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
980
981 let result = workflow.execute_agent("agent1", "input".to_owned()).await;
982 assert!(result.is_ok());
983 assert_eq!(result.unwrap(), "response1");
984 }
985
986 #[tokio::test]
987 async fn test_execute_single_agent_failure() {
988 let mut workflow = DAGWorkflow::new("test", "Test workflow");
989 workflow.register_agent(create_failing_agent("1", "agent1", "test error"));
990
991 let result = workflow.execute_agent("agent1", "input".to_owned()).await;
992 assert!(matches!(result, Err(GraphWorkflowError::AgentError(_))));
993 }
994
995 #[tokio::test]
996 async fn test_execute_single_agent_not_found() {
997 let workflow = DAGWorkflow::new("test", "Test workflow");
998
999 let result = workflow
1000 .execute_agent("nonexistent", "input".to_owned())
1001 .await;
1002 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
1003 }
1004
1005 #[tokio::test]
1006 async fn test_execute_workflow_linear() {
1007 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1008 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1009 workflow.register_agent(create_mock_agent(
1010 "2",
1011 "agent2",
1012 "Second agent",
1013 "response2",
1014 ));
1015
1016 workflow
1017 .connect_agents("agent1", "agent2", Flow::default())
1018 .unwrap();
1019
1020 let results = workflow
1021 .execute_workflow(&["agent1"], "input")
1022 .await
1023 .unwrap();
1024 assert_eq!(results.len(), 2);
1025 assert_eq!(
1026 results.get("agent1").unwrap().as_ref().unwrap(),
1027 "response1"
1028 );
1029 assert_eq!(
1030 results.get("agent2").unwrap().as_ref().unwrap(),
1031 "response2"
1032 );
1033 }
1034
1035 #[tokio::test]
1036 async fn test_execute_workflow_branching() {
1037 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1038 workflow.register_agent(create_mock_agent("1", "agent1", "Root agent", "response1"));
1039 workflow.register_agent(create_mock_agent("2", "agent2", "Branch 1", "response2"));
1040 workflow.register_agent(create_mock_agent("3", "agent3", "Branch 2", "response3"));
1041
1042 workflow
1043 .connect_agents("agent1", "agent2", Flow::default())
1044 .unwrap();
1045 workflow
1046 .connect_agents("agent1", "agent3", Flow::default())
1047 .unwrap();
1048
1049 let results = workflow
1050 .execute_workflow(&["agent1"], "input")
1051 .await
1052 .unwrap();
1053 assert_eq!(results.len(), 3);
1054 assert_eq!(
1055 results.get("agent1").unwrap().as_ref().unwrap(),
1056 "response1"
1057 );
1058 assert_eq!(
1059 results.get("agent2").unwrap().as_ref().unwrap(),
1060 "response2"
1061 );
1062 assert_eq!(
1063 results.get("agent3").unwrap().as_ref().unwrap(),
1064 "response3"
1065 );
1066 }
1067
1068 #[tokio::test]
1069 async fn test_execute_workflow_with_transformation() {
1070 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1071 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1072 workflow.register_agent(create_mock_agent(
1073 "2",
1074 "agent2",
1075 "Second agent",
1076 "response2",
1077 ));
1078
1079 let transform_fn = Arc::new(|input: String| format!("transformed: {input}"));
1080 let flow = Flow {
1081 transform: Some(transform_fn),
1082 condition: None,
1083 };
1084
1085 workflow.connect_agents("agent1", "agent2", flow).unwrap();
1086
1087 let results = workflow
1088 .execute_workflow(&["agent1"], "input")
1089 .await
1090 .unwrap();
1091 assert_eq!(results.len(), 2);
1092
1093 let structure = workflow.get_workflow_structure();
1094 let agent1_connections = &structure["agent1"];
1095 assert_eq!(agent1_connections.len(), 1);
1096 assert_eq!(agent1_connections[0].0, "agent2");
1097 assert_eq!(agent1_connections[0].1, Some("transform".to_owned()));
1098 }
1099
1100 #[tokio::test]
1101 async fn test_execute_workflow_with_condition_true() {
1102 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1103 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "true"));
1104 workflow.register_agent(create_mock_agent("2", "agent2", "Second agent", "executed"));
1105
1106 let true_condition = Arc::new(|output: &str| output.contains("true"));
1107
1108 workflow
1109 .connect_agents(
1110 "agent1",
1111 "agent2",
1112 Flow {
1113 transform: None,
1114 condition: Some(true_condition),
1115 },
1116 )
1117 .unwrap();
1118
1119 let results = workflow
1120 .execute_workflow(&["agent1"], "input")
1121 .await
1122 .unwrap();
1123 assert_eq!(results.len(), 2);
1124 assert!(results.contains_key("agent1"));
1125 assert!(results.contains_key("agent2"));
1126 }
1127
1128 #[tokio::test]
1129 async fn test_execute_workflow_with_condition_false() {
1130 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1131 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1132 workflow.register_agent(create_mock_agent(
1133 "2",
1134 "agent2",
1135 "Second agent",
1136 "not executed",
1137 ));
1138
1139 let false_condition = Arc::new(|output: &str| output.contains("nonexistent"));
1140
1141 workflow
1142 .connect_agents(
1143 "agent1",
1144 "agent2",
1145 Flow {
1146 transform: None,
1147 condition: Some(false_condition),
1148 },
1149 )
1150 .unwrap();
1151
1152 let results = workflow
1153 .execute_workflow(&["agent1"], "input")
1154 .await
1155 .unwrap();
1156 assert_eq!(results.len(), 1);
1157 assert!(results.contains_key("agent1"));
1158 assert!(!results.contains_key("agent2"));
1159 }
1160
1161 #[tokio::test]
1162 async fn test_workflow_execution_start_agent_not_found() {
1163 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1164 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1165
1166 let result = workflow.execute_workflow(&["nonexistent"], "input").await;
1167 assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
1168 }
1169
1170 #[tokio::test]
1171 async fn test_workflow_execution_with_failing_agent() {
1172 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1173 workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1174 workflow.register_agent(create_failing_agent("2", "agent2", "fail error"));
1175 workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
1176
1177 workflow
1179 .connect_agents("agent1", "agent2", Flow::default())
1180 .unwrap();
1181 workflow
1182 .connect_agents("agent2", "agent3", Flow::default())
1183 .unwrap();
1184
1185 let results = workflow
1186 .execute_workflow(&["agent1"], "input")
1187 .await
1188 .unwrap();
1189 assert_eq!(results.len(), 2);
1190 assert!(results.contains_key("agent1"));
1191 assert!(results.contains_key("agent2"));
1192 assert!(!results.contains_key("agent3"));
1193
1194 let agent2_result = results.get("agent2").unwrap();
1195 assert!(agent2_result.is_err());
1196 }
1197
1198 #[tokio::test]
1199 async fn test_independent_multiple_starts() {
1200 let mut workflow = DAGWorkflow::new("test", "");
1201
1202 let agent_a = create_mock_agent("1", "A", "A", "A_result");
1203 let agent_b = create_mock_agent("2", "B", "B", "B_result");
1204 let agent_c = create_mock_agent("3", "C", "C", "C_result");
1205 let agent_d = create_mock_agent("4", "D", "D", "D_result");
1206
1207 workflow.register_agent(agent_a);
1208 workflow.register_agent(agent_b);
1209 workflow.register_agent(agent_c);
1210 workflow.register_agent(agent_d);
1211
1212 workflow.connect_agents("A", "C", Flow::default()).unwrap();
1213 workflow.connect_agents("B", "D", Flow::default()).unwrap();
1214
1215 let results = workflow
1216 .execute_workflow(&["A", "B"], "input")
1217 .await
1218 .unwrap();
1219
1220 assert_eq!(results.get("A").unwrap().as_ref().unwrap(), "A_result");
1221 assert_eq!(results.get("B").unwrap().as_ref().unwrap(), "B_result");
1222 assert_eq!(results.get("C").unwrap().as_ref().unwrap(), "C_result");
1223 assert_eq!(results.get("D").unwrap().as_ref().unwrap(), "D_result");
1224 }
1225
1226 #[tokio::test]
1228 async fn test_converging_multiple_starts() {
1229 let mut workflow = DAGWorkflow::new("test", "");
1230
1231 let agent_a = create_mock_agent("1", "A", "A", "A_result");
1232 let agent_b = create_mock_agent("2", "B", "B", "B_result");
1233 let agent_c = create_mock_agent("3", "C", "C", "C_result");
1234
1235 workflow.register_agent(agent_a);
1236 workflow.register_agent(agent_b);
1237 workflow.register_agent(agent_c);
1238
1239 workflow.connect_agents("A", "C", Flow::default()).unwrap();
1240 workflow.connect_agents("B", "C", Flow::default()).unwrap();
1241
1242 let _results = workflow
1243 .execute_workflow(&["A", "B"], "input")
1244 .await
1245 .unwrap();
1246
1247 let c_node = workflow.name_to_node.get("C").unwrap();
1248 let node_data = workflow.workflow.node_weight(*c_node).unwrap();
1249 let last_result = node_data.last_result.lock().await;
1250 assert!(last_result.is_some());
1251 assert!(
1252 last_result
1253 .as_ref()
1254 .unwrap()
1255 .as_ref()
1256 .unwrap()
1257 .contains("A_result")
1258 );
1259 assert!(
1260 last_result
1261 .as_ref()
1262 .unwrap()
1263 .as_ref()
1264 .unwrap()
1265 .contains("B_result")
1266 );
1267 }
1268
1269 #[tokio::test]
1271 async fn test_conditional_branches() {
1272 let mut workflow = DAGWorkflow::new("test", "");
1273
1274 let agent_a = create_mock_agent("1", "A", "A", "A_trigger");
1275 let agent_b = create_mock_agent("2", "B", "B", "B_result");
1276 let agent_c = create_mock_agent("3", "C", "C", "C_result");
1277
1278 workflow.register_agent(agent_a);
1279 workflow.register_agent(agent_b);
1280 workflow.register_agent(agent_c);
1281
1282 let conditional_flow = Flow {
1283 condition: Some(Arc::new(|output: &str| output.contains("trigger"))),
1284 transform: None,
1285 };
1286
1287 workflow.connect_agents("A", "B", conditional_flow).unwrap();
1288 workflow.connect_agents("A", "C", Flow::default()).unwrap();
1289
1290 let results = workflow.execute_workflow(&["A"], "input").await.unwrap();
1291
1292 assert!(results.get("B").is_none());
1293 assert_eq!(results.get("C").unwrap().as_ref().unwrap(), "C_result");
1294 }
1295
/// A root fanning out into two independent chains yields exactly two
/// root-to-leaf paths: `start -> a -> c` and `start -> b -> d`.
#[test]
fn test_find_execution_paths() {
    let mut workflow = DAGWorkflow::new("test", "Test workflow");

    // Each mock's canned response is its own name.
    for (id, name, description) in [
        ("0", "start", "Starting point"),
        ("1", "a", "Path A"),
        ("2", "b", "Path B"),
        ("3", "c", "End of A"),
        ("4", "d", "End of B"),
    ] {
        workflow.register_agent(create_mock_agent(id, name, description, name));
    }

    for (from, to) in [("start", "a"), ("start", "b"), ("a", "c"), ("b", "d")] {
        workflow.connect_agents(from, to, Flow::default()).unwrap();
    }

    let paths = workflow.find_execution_paths(&["start"]).unwrap();
    assert_eq!(paths.len(), 2);

    let via_a = vec!["start".to_owned(), "a".to_owned(), "c".to_owned()];
    let via_b = vec!["start".to_owned(), "b".to_owned(), "d".to_owned()];
    assert!(paths.iter().any(|path| path == &via_a));
    assert!(paths.iter().any(|path| path == &via_b));
}
1328
/// Requesting paths from an agent that was never registered must fail with
/// `GraphWorkflowError::AgentNotFound`.
#[test]
fn test_find_execution_paths_start_agent_not_found() {
    let empty = DAGWorkflow::new("test", "Test workflow");

    assert!(matches!(
        empty.find_execution_paths(&["nonexistent"]),
        Err(GraphWorkflowError::AgentNotFound(_))
    ));
}
1336
/// Diamond `start -> {a, b} -> end`: the shared sink is reachable along two
/// distinct branches, so exactly two paths are expected.
#[test]
fn test_find_execution_paths_diamond_pattern() {
    let mut workflow = DAGWorkflow::new("test", "Test workflow");

    // Each mock's canned response is its own name.
    for (id, name, description) in [
        ("0", "start", "Start"),
        ("1", "a", "Middle A"),
        ("2", "b", "Middle B"),
        ("3", "end", "End"),
    ] {
        workflow.register_agent(create_mock_agent(id, name, description, name));
    }

    for (from, to) in [("start", "a"), ("start", "b"), ("a", "end"), ("b", "end")] {
        workflow.connect_agents(from, to, Flow::default()).unwrap();
    }

    let paths = workflow.find_execution_paths(&["start"]).unwrap();
    assert_eq!(paths.len(), 2);

    let via_a = vec!["start".to_owned(), "a".to_owned(), "end".to_owned()];
    let via_b = vec!["start".to_owned(), "b".to_owned(), "end".to_owned()];
    assert!(paths.iter().any(|path| path == &via_a));
    assert!(paths.iter().any(|path| path == &via_b));
}
1374
/// A linear chain `a -> b -> c` reports no potential deadlocks, and closing the
/// loop with `c -> a` is rejected by `connect_agents` as a cycle.
#[test]
fn test_detect_potential_deadlocks() {
    let mut workflow = DAGWorkflow::new("test", "Test workflow");

    for (id, name, description) in [("1", "a", "Agent A"), ("2", "b", "Agent B"), ("3", "c", "Agent C")] {
        workflow.register_agent(create_mock_agent(id, name, description, name));
    }

    workflow.connect_agents("a", "b", Flow::default()).unwrap();
    workflow.connect_agents("b", "c", Flow::default()).unwrap();

    assert_eq!(workflow.detect_potential_deadlocks().len(), 0);

    let closing_edge = workflow.connect_agents("c", "a", Flow::default());
    assert!(matches!(closing_edge, Err(GraphWorkflowError::CycleDetected)));
}
1394
/// `get_workflow_structure` lists every agent, including the sink `c` with no
/// outgoing edges. The edge label is `None` for a plain flow and
/// `Some("transform")` for a flow carrying a transform.
#[test]
fn test_get_workflow_structure() {
    let mut workflow = DAGWorkflow::new("test", "Test workflow");

    for (id, name, description) in [("1", "a", "Agent A"), ("2", "b", "Agent B"), ("3", "c", "Agent C")] {
        workflow.register_agent(create_mock_agent(id, name, description, name));
    }

    workflow.connect_agents("a", "b", Flow::default()).unwrap();
    workflow
        .connect_agents(
            "b",
            "c",
            Flow {
                transform: Some(Arc::new(|input: String| format!("transformed: {input}"))),
                condition: None,
            },
        )
        .unwrap();

    let structure = workflow.get_workflow_structure();
    assert_eq!(structure.len(), 3);

    let a_edges = &structure["a"];
    assert_eq!(a_edges.len(), 1);
    assert_eq!(a_edges[0].0, "b");
    assert_eq!(a_edges[0].1, None);

    let b_edges = &structure["b"];
    assert_eq!(b_edges.len(), 1);
    assert_eq!(b_edges[0].0, "c");
    assert_eq!(b_edges[0].1, Some("transform".to_owned()));

    assert_eq!(structure["c"].len(), 0);
}
1425
/// The DOT export of a single edge `a -> b` contains the `digraph` header, both
/// node declarations, the edge, and a closing brace.
#[test]
fn test_export_workflow_dot() {
    let mut workflow = DAGWorkflow::new("test", "Test workflow");
    workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
    workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
    workflow.connect_agents("a", "b", Flow::default()).unwrap();

    let dot = workflow.export_workflow_dot();

    let expected_fragments = [
        "digraph {",
        r#""a" [label="a"]"#,
        r#""b" [label="b"]"#,
        r#""a" -> "b""#,
        "}",
    ];
    for fragment in expected_fragments {
        assert!(dot.contains(fragment));
    }
}
1442
1443 #[tokio::test]
1444 async fn test_caching_execution_results() {
1445 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1446
1447 let mut agent = MockAgent::new();
1449 let agent_name = "counter".to_owned();
1450 agent.expect_name().return_const(agent_name.clone());
1451 agent.expect_id().return_const("1".to_owned());
1452 agent
1453 .expect_description()
1454 .return_const("Counter Agent".to_owned());
1455
1456 let mut count = 0;
1457 agent.expect_run().returning(move |_| {
1458 count += 1;
1459 Box::pin(future::ready(Ok(format!("Called {count} times"))))
1460 });
1461
1462 agent
1463 .expect_run_multiple_tasks()
1464 .returning(|_| Box::pin(future::ready(Ok(vec![]))));
1465
1466 workflow.register_agent(Arc::new(agent));
1467
1468 let results1 = workflow
1470 .execute_workflow(&["counter"], "input1")
1471 .await
1472 .unwrap();
1473 assert_eq!(
1474 results1.get("counter").unwrap().as_ref().unwrap(),
1475 "Called 1 times"
1476 );
1477
1478 let results2 = workflow
1480 .execute_workflow(&["counter"], "input2")
1481 .await
1482 .unwrap();
1483 assert_eq!(
1484 results2.get("counter").unwrap().as_ref().unwrap(),
1485 "Called 2 times"
1486 );
1487
1488 let result3 = workflow
1490 .execute_agent("counter", "input3".to_owned())
1491 .await
1492 .unwrap();
1493 assert_eq!(result3, "Called 3 times");
1494 }
1495
1496 #[tokio::test]
1497 async fn test_execute_node_result_caching() {
1498 let mut workflow = DAGWorkflow::new("test", "Test workflow");
1499
1500 let mut agent1 = MockAgent::new();
1502 agent1.expect_name().return_const("agent1".to_owned());
1503 agent1.expect_id().return_const("1".to_owned());
1504 agent1
1505 .expect_description()
1506 .return_const("First agent".to_owned());
1507
1508 let mut run_count = 0;
1510 agent1.expect_run().returning(move |input| {
1511 run_count += 1;
1512 Box::pin(future::ready(Ok(format!(
1513 "response for '{input}' (call #{run_count})"
1514 ))))
1515 });
1516
1517 agent1
1518 .expect_run_multiple_tasks()
1519 .returning(|_| Box::pin(future::ready(Ok(vec![]))));
1520
1521 workflow.register_agent(Arc::new(agent1));
1522
1523 workflow.register_agent(create_mock_agent(
1525 "2",
1526 "agent2",
1527 "Second agent",
1528 "response2",
1529 ));
1530
1531 workflow
1533 .connect_agents("agent1", "agent2", Flow::default())
1534 .unwrap();
1535
1536 let agent1_idx = *workflow.name_to_node.get("agent1").unwrap();
1537
1538 let results = Arc::new(DashMap::new());
1540 let edge_tracker = Arc::new(DashMap::new());
1541 let processed_nodes = Arc::new(DashMap::new());
1542
1543 let result1 = workflow
1545 .execute_node(
1546 agent1_idx,
1547 "input1".to_owned(),
1548 Arc::clone(&results),
1549 Arc::clone(&edge_tracker),
1550 Arc::clone(&processed_nodes),
1551 )
1552 .await
1553 .unwrap();
1554
1555 assert_eq!(result1, "response for 'input1' (call #1)");
1556 assert!(results.contains_key("agent1"));
1557 assert!(results.contains_key("agent2")); let result2 = workflow
1561 .execute_node(
1562 agent1_idx,
1563 "input2".to_owned(),
1564 Arc::clone(&results),
1565 Arc::clone(&edge_tracker),
1566 Arc::clone(&processed_nodes),
1567 )
1568 .await
1569 .unwrap();
1570
1571 assert_eq!(result2, "response for 'input1' (call #1)"); results.clear();
1576
1577 let result3 = workflow
1579 .execute_node(
1580 agent1_idx,
1581 "input3".to_owned(),
1582 Arc::clone(&results),
1583 Arc::clone(&edge_tracker),
1584 Arc::clone(&processed_nodes),
1585 )
1586 .await
1587 .unwrap();
1588
1589 assert_eq!(result3, "response for 'input3' (call #2)");
1591 }
1592}