rigs/
graph_workflow.rs

1//! Graph workflow implementation
2//!
3#![deny(missing_docs)]
4
5use std::{
6    collections::{HashMap, hash_map},
7    fmt::Debug,
8    sync::Arc,
9    time::Duration,
10};
11
12use dashmap::DashMap;
13use petgraph::{
14    Direction,
15    graph::{EdgeIndex, NodeIndex},
16    prelude::StableGraph,
17    visit::EdgeRef,
18};
19use thiserror::Error;
20use tokio::sync::Mutex;
21
22use crate::agent::Agent;
23
24/// The main orchestration structure
25pub struct DAGWorkflow {
26    /// The workflow name
27    pub name: String,
28    /// The workflow description
29    pub description: String,
30    /// Store all registered agents
31    agents: DashMap<String, Arc<dyn Agent>>,
32    /// The workflow graph
33    workflow: StableGraph<AgentNode, Flow>,
34    /// Map from agent name to node index for quick lookup
35    name_to_node: HashMap<String, NodeIndex>,
36}
37
38impl DAGWorkflow {
39    /// Create a new DAGWorkflow
40    pub fn new<S: Into<String>>(name: S, description: S) -> Self {
41        Self {
42            name: name.into(),
43            description: description.into(),
44            agents: DashMap::new(),
45            workflow: StableGraph::new(),
46            name_to_node: HashMap::new(),
47        }
48    }
49
50    /// Register an agent with the orchestrator
51    pub fn register_agent(&mut self, agent: Arc<dyn Agent>) {
52        let agent_name = agent.name();
53        self.agents.insert(agent_name.clone(), agent);
54
55        // If agent isn't already in the graph, add it
56        if let hash_map::Entry::Vacant(e) = self.name_to_node.entry(agent_name.clone()) {
57            let node_idx = self.workflow.add_node(AgentNode {
58                name: agent_name.clone(),
59                last_result: Mutex::new(None),
60            });
61            e.insert(node_idx);
62        }
63    }
64
65    /// Add a flow connection between two agents
66    pub fn connect_agents(
67        &mut self,
68        from: &str,
69        to: &str,
70        flow: Flow,
71    ) -> Result<EdgeIndex, GraphWorkflowError> {
72        // Ensure both agents exist
73        if !self.agents.contains_key(from) {
74            return Err(GraphWorkflowError::AgentNotFound(format!(
75                "Source agent '{from}' not found",
76            )));
77        }
78        if !self.agents.contains_key(to) {
79            return Err(GraphWorkflowError::AgentNotFound(format!(
80                "Target agent '{to}' not found",
81            )));
82        }
83
84        // Get node indices, creating nodes if necessary
85        let from_entry = self.name_to_node.entry(from.to_owned());
86        let from_idx = *from_entry.or_insert_with(|| {
87            self.workflow.add_node(AgentNode {
88                name: from.to_owned(),
89                last_result: Mutex::new(None),
90            })
91        });
92
93        let to_entry = self.name_to_node.entry(to.to_owned());
94        let to_idx = *to_entry.or_insert_with(|| {
95            self.workflow.add_node(AgentNode {
96                name: to.to_owned(),
97                last_result: Mutex::new(None),
98            })
99        });
100
101        // Add the edge
102        let edge_idx = self.workflow.add_edge(from_idx, to_idx, flow);
103
104        // Check for cycles
105        if self.has_cycle() {
106            // Remove the edge we just added
107            self.workflow.remove_edge(edge_idx);
108            return Err(GraphWorkflowError::CycleDetected);
109        }
110
111        Ok(edge_idx)
112    }
113
114    // Check if the workflow has a cycle
115    fn has_cycle(&self) -> bool {
116        // Implementation using DFS to detect cycles
117        let mut visited = vec![false; self.workflow.node_count()];
118        let mut rec_stack = vec![false; self.workflow.node_count()];
119
120        for node in self.workflow.node_indices() {
121            if !visited[node.index()] && self.is_cyclic_util(node, &mut visited, &mut rec_stack) {
122                return true;
123            }
124        }
125        false
126    }
127
128    fn is_cyclic_util(
129        &self,
130        node: NodeIndex,
131        visited: &mut [bool],
132        rec_stack: &mut [bool],
133    ) -> bool {
134        visited[node.index()] = true;
135        rec_stack[node.index()] = true;
136
137        for neighbor in self.workflow.neighbors_directed(node, Direction::Outgoing) {
138            if !visited[neighbor.index()] {
139                if self.is_cyclic_util(neighbor, visited, rec_stack) {
140                    return true;
141                }
142            } else if rec_stack[neighbor.index()] {
143                return true;
144            }
145        }
146
147        rec_stack[node.index()] = false;
148        false
149    }
150
151    /// Remove an agent connection
152    pub fn disconnect_agents(&mut self, from: &str, to: &str) -> Result<(), GraphWorkflowError> {
153        let from_idx = self.name_to_node.get(from).ok_or_else(|| {
154            GraphWorkflowError::AgentNotFound(format!("Source agent '{from}' not found"))
155        })?;
156        let to_idx = self.name_to_node.get(to).ok_or_else(|| {
157            GraphWorkflowError::AgentNotFound(format!("Target agent '{to}' not found"))
158        })?;
159
160        // Find and remove the edge
161        if let Some(edge) = self.workflow.find_edge(*from_idx, *to_idx) {
162            self.workflow.remove_edge(edge);
163            Ok(())
164        } else {
165            Err(GraphWorkflowError::AgentNotFound(format!(
166                "No connection from '{from}' to '{to}'"
167            )))
168        }
169    }
170
171    /// Remove an agent from the orchestrator
172    pub fn remove_agent(&mut self, name: &str) -> Result<(), GraphWorkflowError> {
173        if let Some(node_idx) = self.name_to_node.remove(name) {
174            self.workflow.remove_node(node_idx);
175            self.agents.remove(name);
176            Ok(())
177        } else {
178            Err(GraphWorkflowError::AgentNotFound(format!(
179                "Agent '{name}' not found"
180            )))
181        }
182    }
183
184    /// Execute a specific agent
185    pub async fn execute_agent(
186        &self,
187        name: &str,
188        input: String,
189    ) -> Result<String, GraphWorkflowError> {
190        if let Some(agent) = self.agents.get(name) {
191            agent
192                .run(input)
193                .await
194                .map_err(|e| GraphWorkflowError::AgentError(e.to_string()))
195        } else {
196            Err(GraphWorkflowError::AgentNotFound(format!(
197                "Agent '{name}' not found"
198            )))
199        }
200    }
201
202    /// Execute the entire workflow starting from a specific agent
203    ///
204    /// # Arguments
205    ///
206    /// * `start_agent`: The name of the agent to start the workflow from
207    /// * `input`: The input to the workflow
208    ///
209    /// # Returns
210    ///
211    /// * `Result<DashMap<String, Result<String, GraphWorkflowError>>, GraphWorkflowError>`: A map of agent names to their results
212    ///
213    pub async fn execute_workflow(
214        &mut self,
215        start_agents: &[&str],
216        input: impl Into<String>,
217    ) -> Result<DashMap<String, Result<String, GraphWorkflowError>>, GraphWorkflowError> {
218        let input = input.into();
219
220        let start_indices = start_agents
221            .iter()
222            .map(|agent| {
223                self.name_to_node
224                    .get(*agent)
225                    .ok_or_else(|| {
226                        GraphWorkflowError::AgentNotFound(format!(
227                            "Start agent '{agent}' not found"
228                        ))
229                    })
230                    .copied()
231            })
232            .collect::<Result<Vec<_>, _>>()?;
233
234        // Reset all results
235        let node_idxs = self.workflow.node_indices().collect::<Vec<_>>();
236        for idx in node_idxs {
237            if let Some(node_weight) = self.workflow.node_weight_mut(idx) {
238                let mut last_result = node_weight.last_result.lock().await;
239                *last_result = None;
240            }
241        }
242
243        // Create a shared results map for all agents to write to
244        let results = Arc::new(DashMap::new());
245        // Create a shared tracking state for the entire workflow
246        let edge_tracker = Arc::new(DashMap::new());
247        let processed_nodes = Arc::new(DashMap::new());
248        // Execute the workflow
249        let mut tasks = Vec::new();
250        for &start_idx in &start_indices {
251            let task = self.execute_node(
252                start_idx,
253                input.clone(),
254                Arc::clone(&results),
255                Arc::clone(&edge_tracker),
256                Arc::clone(&processed_nodes),
257            );
258            tasks.push(task);
259        }
260        futures::future::join_all(tasks)
261            .await
262            .into_iter()
263            .collect::<Result<Vec<_>, _>>()
264            .map_err(|e| GraphWorkflowError::ExecutionError(e.to_string()))?;
265        Ok(Arc::into_inner(results).expect("Results should not be poisoned"))
266    }
267
268    async fn execute_node(
269        &self,
270        node_idx: NodeIndex,
271        input: String,
272        results: Arc<DashMap<String, Result<String, GraphWorkflowError>>>,
273        edge_tracker: Arc<DashMap<(NodeIndex, NodeIndex), bool>>,
274        processed_nodes: Arc<DashMap<NodeIndex, Vec<(NodeIndex, String)>>>,
275    ) -> Result<String, GraphWorkflowError> {
276        // Get the agent name from the node
277        let agent_name = &self
278            .workflow
279            .node_weight(node_idx)
280            .ok_or_else(|| GraphWorkflowError::AgentNotFound("Node not found in graph".to_owned()))?
281            .name;
282
283        // Check if we already have a result for this node (avoid duplicate work)
284        if let Some(entry) = results.get(agent_name) {
285            return entry.value().clone();
286        }
287
288        // Execute the agent with timeout protection
289        let result = tokio::time::timeout(
290            Duration::from_secs(3600), // 60-minute timeout
291            self.execute_agent(agent_name, input),
292        )
293        .await
294        .map_err(|_| GraphWorkflowError::Timeout(agent_name.clone()))?;
295
296        // Store the result
297        results.insert(agent_name.clone(), result.clone());
298
299        // Update the node's last result
300        if let Some(node_weight) = self.workflow.node_weight(node_idx) {
301            let mut last_result = node_weight.last_result.lock().await;
302            *last_result = Some(result.clone());
303        }
304
305        // If successful, propagate to connected agents
306        match &result {
307            Ok(output) => {
308                // Find all outgoing edges that pass the condition (if any)
309                let valid_edges = self
310                    .workflow
311                    .edges_directed(node_idx, Direction::Outgoing)
312                    .filter(|edge| {
313                        // Evaluate condition with the current output
314                        let condition_result = edge
315                            .weight()
316                            .condition
317                            .as_ref()
318                            .map(|cond| {
319                                // Apply condition to the current output
320                                let result = cond(output);
321                                tracing::debug!(
322                                    "Condition for edge {:?} -> {:?}: {}",
323                                    node_idx,
324                                    edge.target(),
325                                    result
326                                );
327                                result
328                            })
329                            .unwrap_or(true); // if no condition, always execute
330
331                        condition_result
332                    })
333                    .collect::<Vec<_>>();
334
335                let mut futures = Vec::new();
336
337                for edge in valid_edges {
338                    let source_node = node_idx;
339                    let target_node = edge.target();
340                    let flow = edge.weight().clone();
341                    let results_clone = Arc::clone(&results);
342                    let processed_nodes_clone = Arc::clone(&processed_nodes);
343                    let edge_tracker_clone = Arc::clone(&edge_tracker);
344
345                    let future = async move {
346                        // Apply transformation if any
347                        let next_input = flow
348                            .transform
349                            .as_ref()
350                            .map_or_else(|| output.clone(), |transform| transform(output.clone()));
351
352                        // mark this edge as processed
353                        edge_tracker_clone.insert((source_node, target_node), true);
354
355                        // record the input for this node with proper synchronization
356                        // Use a scope to ensure the lock is released after the operation
357                        {
358                            processed_nodes_clone
359                                .entry(target_node)
360                                .and_modify(|v| v.push((source_node, next_input.clone())))
361                                .or_insert_with(|| vec![(source_node, next_input.clone())]);
362                        }
363
364                        // Get all input edges (including those from different starting nodes)
365                        let all_incoming_edges = self
366                            .workflow
367                            .edges_directed(target_node, Direction::Incoming)
368                            .map(|e| (e.source(), target_node))
369                            .collect::<Vec<_>>();
370
371                        // Check that all input edges have completed processing (from different paths).
372                        // For conditional flows, we need to check if the edge has a condition and if it evaluates to false
373                        let all_processed = all_incoming_edges.iter().all(|edge| {
374                            // Check if this edge is already processed
375                            let processed = edge_tracker_clone.contains_key(edge);
376
377                            // If not processed, check if it has a condition that evaluates to false
378                            // In that case, we should consider it as "processed" (skipped)
379                            let conditionally_skipped = if !processed {
380                                if let Some(edge_idx) = self.workflow.find_edge(edge.0, edge.1) {
381                                    let edge_weight = self.workflow.edge_weight(edge_idx).unwrap();
382                                    if let Some(cond) = &edge_weight.condition {
383                                        // If we can find the source node's result, check the condition
384                                        if let Some(source_name) =
385                                            self.workflow.node_weight(edge.0).map(|n| &n.name)
386                                        {
387                                            if let Some(source_result) =
388                                                results_clone.get(source_name)
389                                            {
390                                                if let Ok(output) = source_result.as_ref() {
391                                                    // If condition is false, this edge is conditionally skipped
392                                                    let condition_result = !cond(output);
393                                                    if condition_result {
394                                                        // Mark this edge as processed (skipped due to condition)
395                                                        edge_tracker_clone
396                                                            .insert((edge.0, edge.1), true);
397                                                    }
398                                                    condition_result
399                                                } else {
400                                                    // Source node execution failed, consider edge as processed
401                                                    edge_tracker_clone
402                                                        .insert((edge.0, edge.1), true);
403                                                    true
404                                                }
405                                            } else {
406                                                false
407                                            }
408                                        } else {
409                                            false
410                                        }
411                                    } else {
412                                        false
413                                    }
414                                } else {
415                                    false
416                                }
417                            } else {
418                                false
419                            };
420
421                            tracing::debug!(
422                                "Edge {:?} processed: {}, conditionally skipped: {}",
423                                edge,
424                                processed,
425                                conditionally_skipped
426                            );
427                            processed || conditionally_skipped
428                        });
429
430                        // only execute if all incoming edges have been processed
431                        if all_processed {
432                            // Aggregate all inputs from different paths
433                            let aggregated_input = processed_nodes_clone
434                                .get(&target_node)
435                                .map(|inputs| {
436                                    // Sort inputs by source node to ensure consistent ordering
437                                    let mut sorted_inputs = inputs.value().clone();
438                                    sorted_inputs.sort_by_key(|(source_idx, _)| *source_idx);
439
440                                    // Log the number of inputs for debugging
441                                    tracing::debug!(
442                                        "Node {:?} has {} inputs",
443                                        target_node,
444                                        sorted_inputs.len()
445                                    );
446
447                                    // Format each input with its source agent name
448                                    let formatted_inputs = sorted_inputs
449                                        .iter()
450                                        .map(|(source_idx, input)| {
451                                            let source_name = &self
452                                                .workflow
453                                                .node_weight(*source_idx)
454                                                .unwrap()
455                                                .name;
456                                            format!("[From {source_name}] {input}")
457                                        })
458                                        .collect::<Vec<_>>();
459
460                                    // Join all inputs with a clear separator
461                                    let result = formatted_inputs.join("\n\n---\n\n");
462                                    tracing::debug!(
463                                        "Aggregated input for node {:?}: {}",
464                                        target_node,
465                                        result
466                                    );
467                                    result
468                                })
469                                .unwrap_or_default();
470
471                            tracing::debug!(
472                                "Executing node {:?} with aggregated input",
473                                target_node
474                            );
475
476                            // execute the target node with the aggregated input
477                            if let Err(e) = self
478                                .execute_node(
479                                    target_node,
480                                    aggregated_input,
481                                    results_clone,
482                                    edge_tracker_clone,
483                                    processed_nodes_clone,
484                                )
485                                .await
486                            {
487                                tracing::error!("Failed to execute node: {:?}", e);
488                            }
489                        }
490                    };
491
492                    futures.push(future);
493                }
494
495                // Execute connected agents concurrently
496                futures::future::join_all(futures).await; // TODO: may use another way which can handle errors
497            }
498            Err(e) => {
499                tracing::error!("Agent '{}' execution failed: {:?}", agent_name, e);
500                // TODO: maybe we need to propagate the error to the caller?
501            }
502        }
503
504        result
505    }
506
507    /// Get the current workflow as a visualization-friendly format
508    pub fn get_workflow_structure(&self) -> HashMap<String, Vec<(String, Option<String>)>> {
509        let mut structure = HashMap::new();
510
511        for node_idx in self.workflow.node_indices() {
512            if let Some(node) = self.workflow.node_weight(node_idx) {
513                let mut connections = Vec::new();
514
515                for edge in self.workflow.edges_directed(node_idx, Direction::Outgoing) {
516                    if let Some(target) = self.workflow.node_weight(edge.target()) {
517                        // TODO: can add more edge metadata here if needed
518                        let edge_label = if edge.weight().transform.is_some() {
519                            Some("transform".to_owned())
520                        } else {
521                            None
522                        };
523
524                        connections.push((target.name.clone(), edge_label));
525                    }
526                }
527
528                structure.insert(node.name.clone(), connections);
529            }
530        }
531
532        structure
533    }
534
535    /// Export the workflow to a format that can be visualized (e.g., DOT format for Graphviz)
536    pub fn export_workflow_dot(&self) -> String {
537        // TODO: can use petgraph's built-in dot
538        // let dot = Dot::with_config(&self.workflow, &[dot::Config::EdgeNoLabel]);
539
540        let mut dot = String::from("digraph {\n");
541
542        // Add nodes
543        for node_idx in self.workflow.node_indices() {
544            if let Some(node) = self.workflow.node_weight(node_idx) {
545                dot.push_str(&format!(
546                    "    \"{}\" [label=\"{}\"];\n",
547                    node.name, node.name
548                ));
549            }
550        }
551
552        // Add edges
553        for edge in self.workflow.edge_indices() {
554            if let Some((source, target)) = self.workflow.edge_endpoints(edge) {
555                if let (Some(source_node), Some(target_node)) = (
556                    self.workflow.node_weight(source),
557                    self.workflow.node_weight(target),
558                ) {
559                    dot.push_str(&format!(
560                        "    \"{}\" -> \"{}\";\n",
561                        source_node.name, target_node.name
562                    ));
563                }
564            }
565        }
566
567        dot.push_str("}\n");
568        dot
569    }
570
571    /// Helper method to find all possible execution paths
572    pub fn find_execution_paths(
573        &self,
574        start_agents: &[&str],
575    ) -> Result<Vec<Vec<String>>, GraphWorkflowError> {
576        let start_indices = start_agents
577            .iter()
578            .map(|agent| {
579                self.name_to_node
580                    .get(*agent)
581                    .ok_or_else(|| {
582                        GraphWorkflowError::AgentNotFound(format!(
583                            "Start agent '{agent}' not found"
584                        ))
585                    })
586                    .copied()
587            })
588            .collect::<Result<Vec<_>, _>>()?;
589
590        let mut paths = Vec::new();
591        let mut current_path = Vec::new();
592
593        for start_idx in &start_indices {
594            current_path.clear();
595            self.dfs_paths(*start_idx, &mut current_path, &mut paths);
596        }
597
598        Ok(paths)
599    }
600
601    fn dfs_paths(
602        &self,
603        node_idx: NodeIndex,
604        current_path: &mut Vec<String>,
605        all_paths: &mut Vec<Vec<String>>,
606    ) {
607        if let Some(node) = self.workflow.node_weight(node_idx) {
608            // Add current node to path
609            current_path.push(node.name.clone());
610
611            // Check if this is a leaf node (no outgoing edges)
612            let has_outgoing = self
613                .workflow
614                .neighbors_directed(node_idx, Direction::Outgoing)
615                .count()
616                > 0;
617
618            if !has_outgoing {
619                // We've reached a leaf node, save this path
620                all_paths.push(current_path.clone());
621            } else {
622                // Continue DFS for all neighbors
623                for neighbor in self
624                    .workflow
625                    .neighbors_directed(node_idx, Direction::Outgoing)
626                {
627                    self.dfs_paths(neighbor, current_path, all_paths);
628                }
629            }
630
631            // Backtrack
632            current_path.pop();
633        }
634    }
635
636    /// Detect potential deadlocks in the workflow. Whether there will actually be a deadlock depends on the flow at execution time.
637    ///
638    /// ## Info
639    ///
640    /// Maybe we need a monitor to detect deadlocks instead of this function.
641    ///
642    /// ## Returns
643    ///
644    /// Returns a vector of cycles (each cycle is a vector of agent names).
645    ///
646    /// Example: vec![vec!["A", "B", "C"], vec!["X", "Y"]]
647    pub fn detect_potential_deadlocks(&self) -> Vec<Vec<String>> {
648        // Build a dependency graph where an edge A→B means B depends on A
649        let mut dependency_graph = petgraph::Graph::<String, ()>::new();
650        let mut node_map = HashMap::new();
651
652        // Create nodes
653        for name in self.name_to_node.keys() {
654            let idx = dependency_graph.add_node(name.clone());
655            node_map.insert(name.clone(), idx);
656        }
657
658        // Add dependencies
659        for node_idx in self.workflow.node_indices() {
660            if let Some(node) = self.workflow.node_weight(node_idx) {
661                let target_dep_idx = *node_map.get(&node.name).unwrap();
662
663                // Add an edge for each incoming connection
664                for source in self
665                    .workflow
666                    .neighbors_directed(node_idx, Direction::Incoming)
667                {
668                    if let Some(source_node) = self.workflow.node_weight(source) {
669                        let source_dep_idx = *node_map.get(&source_node.name).unwrap();
670                        dependency_graph.add_edge(source_dep_idx, target_dep_idx, ());
671                    }
672                }
673            }
674        }
675
676        // Find strongly connected components (cycles in the dependency graph)
677        let sccs = petgraph::algo::kosaraju_scc(&dependency_graph);
678
679        // Return only the non-trivial SCCs (size > 1)
680        sccs.into_iter()
681            .filter(|scc| scc.len() > 1)
682            .map(|scc| {
683                scc.into_iter()
684                    .map(|idx| dependency_graph[idx].clone())
685                    .collect()
686            })
687            .collect()
688    }
689}
690
/// Edge weight describing how data moves from one agent to the next.
///
/// Both hooks are optional. The `Default` value has neither, so the edge is always taken
/// and forwards the upstream output unchanged.
// The inline `Arc<dyn Fn ...>` field types trip clippy's `type_complexity`; they are kept
// spelled out so the public field types read exactly as callers construct them.
#[allow(clippy::type_complexity)]
#[derive(Clone, Default)]
pub struct Flow {
    /// Optional mapping applied to the upstream output before it is handed to the target.
    pub transform: Option<Arc<dyn Fn(String) -> String + Send + Sync>>,
    /// Optional predicate over the upstream output; the edge is taken only when it holds.
    pub condition: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
}
700
701/// Node weight for the graph
702#[derive(Debug)]
703pub struct AgentNode {
704    /// Name of the agent
705    pub name: String,
706    /// Cache for execution results
707    pub last_result: Mutex<Option<Result<String, GraphWorkflowError>>>,
708}
709
710/// Error type for the graph workflow
711#[allow(missing_docs)]
712#[derive(Clone, Debug, Error)]
713pub enum GraphWorkflowError {
714    #[error("Agent Error: {0}")]
715    AgentError(String),
716    #[error("Agent not found: {0}")]
717    AgentNotFound(String),
718    #[error("Cycle detected in workflow")]
719    CycleDetected,
720    #[error("Execution error: {0}")]
721    ExecutionError(String),
722    #[error("Timeout executing agent: {0}")]
723    Timeout(String),
724    #[error("Deadlock detected in workflow execution")]
725    Deadlock,
726    #[error("Workflow execution canceled")]
727    Canceled,
728}
729
730impl Debug for Flow {
731    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
732        f.debug_struct("Flow")
733            .field("transform", &self.transform.is_some())
734            .field("condition", &self.condition.is_some())
735            .finish()
736    }
737}
738
739#[cfg(test)]
740mod tests {
741    use super::*;
742
743    use futures::future::{self, BoxFuture};
744    use mockall::mock;
745
746    use crate::agent::AgentError;
747
    // `mockall`-generated test double for the `Agent` trait (exposed as `MockAgent`).
    // Tests configure each method through the generated `expect_*` methods, as the helper
    // constructors below do.
    mock! {
        #[derive(Debug)]
        pub Agent{}

        impl Agent for Agent {
            fn run(&self, task: String) -> BoxFuture<'static, Result<String, AgentError>> {
                Box::pin(future::ready(Ok(String::new())))
            }
            fn run_multiple_tasks(&mut self, tasks: Vec<String>) -> BoxFuture<'static, Result<Vec<String>, AgentError>> {
                Box::pin(future::ready(Ok(vec![])))
            }
            fn id(&self) -> String {
                String::new()
            }
            fn name(&self) -> String {
                String::new()
            }
            fn description(&self) -> String {
                String::new()
            }
        }
    }
770
771    fn create_mock_agent(id: &str, name: &str, desc: &str, response: &str) -> Arc<MockAgent> {
772        let mut agent = MockAgent::new();
773
774        let id_str = id.to_owned();
775        agent.expect_id().return_const(id_str);
776
777        let name_str = name.to_owned();
778        agent.expect_name().return_const(name_str);
779
780        let desc_str = desc.to_owned();
781        agent.expect_description().return_const(desc_str);
782
783        let response_str = response.to_owned();
784        let response_str_clone = response_str.clone();
785        agent.expect_run().returning(move |_| {
786            let res = response_str_clone.clone();
787            Box::pin(future::ready(Ok(res)))
788        });
789
790        let response_str_clone = response_str.clone();
791        agent.expect_run_multiple_tasks().returning(move |tasks| {
792            let responses = tasks.iter().map(|_| response_str_clone.clone()).collect();
793            Box::pin(future::ready(Ok(responses)))
794        });
795
796        Arc::new(agent)
797    }
798
799    fn create_failing_agent(id: &str, name: &str, error_msg: &str) -> Arc<MockAgent> {
800        let mut agent = MockAgent::new();
801
802        let id_str = id.to_owned();
803        agent.expect_id().return_const(id_str);
804
805        let name_str = name.to_owned();
806        agent.expect_name().return_const(name_str);
807
808        agent
809            .expect_description()
810            .return_const("Failing agent".to_owned());
811
812        let error_str = error_msg.to_owned();
813        let error_str_for_run = error_str.clone();
814        agent.expect_run().returning(move |_| {
815            let err = AgentError::TestError(error_str_for_run.clone());
816            Box::pin(future::ready(Err(err)))
817        });
818
819        agent.expect_run_multiple_tasks().returning(move |_| {
820            let err = AgentError::TestError(error_str.clone());
821            Box::pin(future::ready(Err(err)))
822        });
823
824        Arc::new(agent)
825    }
826
827    #[test]
828    fn test_dag_creation() {
829        let workflow = DAGWorkflow::new("test", "Test workflow");
830        assert_eq!(workflow.name, "test");
831        assert_eq!(workflow.description, "Test workflow");
832        assert_eq!(workflow.agents.len(), 0);
833        assert_eq!(workflow.workflow.node_count(), 0);
834        assert_eq!(workflow.workflow.edge_count(), 0);
835    }
836
837    #[test]
838    fn test_agent_registration() {
839        let mut workflow = DAGWorkflow::new("test", "Test workflow");
840        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
841
842        assert_eq!(workflow.agents.len(), 1);
843        assert_eq!(workflow.workflow.node_count(), 1);
844        assert!(workflow.name_to_node.contains_key("agent1"));
845    }
846
847    #[test]
848    fn test_agent_connection() {
849        let mut workflow = DAGWorkflow::new("test", "Test workflow");
850        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
851        workflow.register_agent(create_mock_agent(
852            "2",
853            "agent2",
854            "Second agent",
855            "response2",
856        ));
857
858        let result = workflow.connect_agents("agent1", "agent2", Flow::default());
859        assert!(result.is_ok());
860        assert_eq!(workflow.workflow.edge_count(), 1);
861    }
862
863    #[test]
864    fn test_agent_connection_failure_nonexistent_agent() {
865        let mut workflow = DAGWorkflow::new("test", "Test workflow");
866        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
867
868        let result = workflow.connect_agents("agent1", "nonexistent", Flow::default());
869        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
870
871        let result = workflow.connect_agents("nonexistent", "agent1", Flow::default());
872        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
873    }
874
875    #[test]
876    fn test_cycle_detection() {
877        let mut workflow = DAGWorkflow::new("test", "Test workflow");
878        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
879        workflow.register_agent(create_mock_agent(
880            "2",
881            "agent2",
882            "Second agent",
883            "response2",
884        ));
885        workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
886
887        // agent1 -> agent2 -> agent3
888        let result1 = workflow.connect_agents("agent1", "agent2", Flow::default());
889        assert!(result1.is_ok());
890        let result2 = workflow.connect_agents("agent2", "agent3", Flow::default());
891        assert!(result2.is_ok());
892
893        // cycle it: agent3 -> agent1
894        let result3 = workflow.connect_agents("agent3", "agent1", Flow::default());
895        assert!(matches!(result3, Err(GraphWorkflowError::CycleDetected)));
896
897        // edge should not be added
898        assert_eq!(workflow.workflow.edge_count(), 2);
899    }
900
901    #[test]
902    fn test_agent_disconnection() {
903        let mut workflow = DAGWorkflow::new("test", "Test workflow");
904        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
905        workflow.register_agent(create_mock_agent(
906            "2",
907            "agent2",
908            "Second agent",
909            "response2",
910        ));
911
912        workflow
913            .connect_agents("agent1", "agent2", Flow::default())
914            .unwrap();
915        assert_eq!(workflow.workflow.edge_count(), 1);
916
917        let result = workflow.disconnect_agents("agent1", "agent2");
918        assert!(result.is_ok());
919        assert_eq!(workflow.workflow.edge_count(), 0);
920    }
921
922    #[test]
923    fn test_agent_disconnection_failure() {
924        let mut workflow = DAGWorkflow::new("test", "Test workflow");
925        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
926        workflow.register_agent(create_mock_agent(
927            "2",
928            "agent2",
929            "Second agent",
930            "response2",
931        ));
932
933        // try to disconnect non-existent edge
934        let result = workflow.disconnect_agents("agent1", "agent2");
935        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
936
937        // try to disconnect non-existent agent
938        let result = workflow.disconnect_agents("nonexistent", "agent2");
939        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
940    }
941
942    #[test]
943    fn test_agent_removal() {
944        let mut workflow = DAGWorkflow::new("test", "Test workflow");
945        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
946        workflow.register_agent(create_mock_agent(
947            "2",
948            "agent2",
949            "Second agent",
950            "response2",
951        ));
952
953        workflow
954            .connect_agents("agent1", "agent2", Flow::default())
955            .unwrap();
956        assert_eq!(workflow.agents.len(), 2);
957        assert_eq!(workflow.workflow.node_count(), 2);
958
959        let result = workflow.remove_agent("agent1");
960        assert!(result.is_ok());
961        assert_eq!(workflow.agents.len(), 1);
962        assert_eq!(workflow.workflow.node_count(), 1);
963        assert!(!workflow.name_to_node.contains_key("agent1"));
964
965        assert_eq!(workflow.workflow.edge_count(), 0);
966    }
967
968    #[test]
969    fn test_agent_removal_nonexistent() {
970        let mut workflow = DAGWorkflow::new("test", "Test workflow");
971
972        let result = workflow.remove_agent("nonexistent");
973        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
974    }
975
976    #[tokio::test]
977    async fn test_execute_single_agent() {
978        let mut workflow = DAGWorkflow::new("test", "Test workflow");
979        workflow.register_agent(create_mock_agent("1", "agent1", "Test agent", "response1"));
980
981        let result = workflow.execute_agent("agent1", "input".to_owned()).await;
982        assert!(result.is_ok());
983        assert_eq!(result.unwrap(), "response1");
984    }
985
986    #[tokio::test]
987    async fn test_execute_single_agent_failure() {
988        let mut workflow = DAGWorkflow::new("test", "Test workflow");
989        workflow.register_agent(create_failing_agent("1", "agent1", "test error"));
990
991        let result = workflow.execute_agent("agent1", "input".to_owned()).await;
992        assert!(matches!(result, Err(GraphWorkflowError::AgentError(_))));
993    }
994
995    #[tokio::test]
996    async fn test_execute_single_agent_not_found() {
997        let workflow = DAGWorkflow::new("test", "Test workflow");
998
999        let result = workflow
1000            .execute_agent("nonexistent", "input".to_owned())
1001            .await;
1002        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
1003    }
1004
1005    #[tokio::test]
1006    async fn test_execute_workflow_linear() {
1007        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1008        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1009        workflow.register_agent(create_mock_agent(
1010            "2",
1011            "agent2",
1012            "Second agent",
1013            "response2",
1014        ));
1015
1016        workflow
1017            .connect_agents("agent1", "agent2", Flow::default())
1018            .unwrap();
1019
1020        let results = workflow
1021            .execute_workflow(&["agent1"], "input")
1022            .await
1023            .unwrap();
1024        assert_eq!(results.len(), 2);
1025        assert_eq!(
1026            results.get("agent1").unwrap().as_ref().unwrap(),
1027            "response1"
1028        );
1029        assert_eq!(
1030            results.get("agent2").unwrap().as_ref().unwrap(),
1031            "response2"
1032        );
1033    }
1034
1035    #[tokio::test]
1036    async fn test_execute_workflow_branching() {
1037        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1038        workflow.register_agent(create_mock_agent("1", "agent1", "Root agent", "response1"));
1039        workflow.register_agent(create_mock_agent("2", "agent2", "Branch 1", "response2"));
1040        workflow.register_agent(create_mock_agent("3", "agent3", "Branch 2", "response3"));
1041
1042        workflow
1043            .connect_agents("agent1", "agent2", Flow::default())
1044            .unwrap();
1045        workflow
1046            .connect_agents("agent1", "agent3", Flow::default())
1047            .unwrap();
1048
1049        let results = workflow
1050            .execute_workflow(&["agent1"], "input")
1051            .await
1052            .unwrap();
1053        assert_eq!(results.len(), 3);
1054        assert_eq!(
1055            results.get("agent1").unwrap().as_ref().unwrap(),
1056            "response1"
1057        );
1058        assert_eq!(
1059            results.get("agent2").unwrap().as_ref().unwrap(),
1060            "response2"
1061        );
1062        assert_eq!(
1063            results.get("agent3").unwrap().as_ref().unwrap(),
1064            "response3"
1065        );
1066    }
1067
1068    #[tokio::test]
1069    async fn test_execute_workflow_with_transformation() {
1070        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1071        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1072        workflow.register_agent(create_mock_agent(
1073            "2",
1074            "agent2",
1075            "Second agent",
1076            "response2",
1077        ));
1078
1079        let transform_fn = Arc::new(|input: String| format!("transformed: {input}"));
1080        let flow = Flow {
1081            transform: Some(transform_fn),
1082            condition: None,
1083        };
1084
1085        workflow.connect_agents("agent1", "agent2", flow).unwrap();
1086
1087        let results = workflow
1088            .execute_workflow(&["agent1"], "input")
1089            .await
1090            .unwrap();
1091        assert_eq!(results.len(), 2);
1092
1093        let structure = workflow.get_workflow_structure();
1094        let agent1_connections = &structure["agent1"];
1095        assert_eq!(agent1_connections.len(), 1);
1096        assert_eq!(agent1_connections[0].0, "agent2");
1097        assert_eq!(agent1_connections[0].1, Some("transform".to_owned()));
1098    }
1099
1100    #[tokio::test]
1101    async fn test_execute_workflow_with_condition_true() {
1102        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1103        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "true"));
1104        workflow.register_agent(create_mock_agent("2", "agent2", "Second agent", "executed"));
1105
1106        let true_condition = Arc::new(|output: &str| output.contains("true"));
1107
1108        workflow
1109            .connect_agents(
1110                "agent1",
1111                "agent2",
1112                Flow {
1113                    transform: None,
1114                    condition: Some(true_condition),
1115                },
1116            )
1117            .unwrap();
1118
1119        let results = workflow
1120            .execute_workflow(&["agent1"], "input")
1121            .await
1122            .unwrap();
1123        assert_eq!(results.len(), 2);
1124        assert!(results.contains_key("agent1"));
1125        assert!(results.contains_key("agent2"));
1126    }
1127
1128    #[tokio::test]
1129    async fn test_execute_workflow_with_condition_false() {
1130        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1131        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1132        workflow.register_agent(create_mock_agent(
1133            "2",
1134            "agent2",
1135            "Second agent",
1136            "not executed",
1137        ));
1138
1139        let false_condition = Arc::new(|output: &str| output.contains("nonexistent"));
1140
1141        workflow
1142            .connect_agents(
1143                "agent1",
1144                "agent2",
1145                Flow {
1146                    transform: None,
1147                    condition: Some(false_condition),
1148                },
1149            )
1150            .unwrap();
1151
1152        let results = workflow
1153            .execute_workflow(&["agent1"], "input")
1154            .await
1155            .unwrap();
1156        assert_eq!(results.len(), 1);
1157        assert!(results.contains_key("agent1"));
1158        assert!(!results.contains_key("agent2"));
1159    }
1160
1161    #[tokio::test]
1162    async fn test_workflow_execution_start_agent_not_found() {
1163        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1164        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1165
1166        let result = workflow.execute_workflow(&["nonexistent"], "input").await;
1167        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
1168    }
1169
1170    #[tokio::test]
1171    async fn test_workflow_execution_with_failing_agent() {
1172        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1173        workflow.register_agent(create_mock_agent("1", "agent1", "First agent", "response1"));
1174        workflow.register_agent(create_failing_agent("2", "agent2", "fail error"));
1175        workflow.register_agent(create_mock_agent("3", "agent3", "Third agent", "response3"));
1176
1177        // agent1 -> agent2 -> agent3
1178        workflow
1179            .connect_agents("agent1", "agent2", Flow::default())
1180            .unwrap();
1181        workflow
1182            .connect_agents("agent2", "agent3", Flow::default())
1183            .unwrap();
1184
1185        let results = workflow
1186            .execute_workflow(&["agent1"], "input")
1187            .await
1188            .unwrap();
1189        assert_eq!(results.len(), 2);
1190        assert!(results.contains_key("agent1"));
1191        assert!(results.contains_key("agent2"));
1192        assert!(!results.contains_key("agent3"));
1193
1194        let agent2_result = results.get("agent2").unwrap();
1195        assert!(agent2_result.is_err());
1196    }
1197
1198    #[tokio::test]
1199    async fn test_independent_multiple_starts() {
1200        let mut workflow = DAGWorkflow::new("test", "");
1201
1202        let agent_a = create_mock_agent("1", "A", "A", "A_result");
1203        let agent_b = create_mock_agent("2", "B", "B", "B_result");
1204        let agent_c = create_mock_agent("3", "C", "C", "C_result");
1205        let agent_d = create_mock_agent("4", "D", "D", "D_result");
1206
1207        workflow.register_agent(agent_a);
1208        workflow.register_agent(agent_b);
1209        workflow.register_agent(agent_c);
1210        workflow.register_agent(agent_d);
1211
1212        workflow.connect_agents("A", "C", Flow::default()).unwrap();
1213        workflow.connect_agents("B", "D", Flow::default()).unwrap();
1214
1215        let results = workflow
1216            .execute_workflow(&["A", "B"], "input")
1217            .await
1218            .unwrap();
1219
1220        assert_eq!(results.get("A").unwrap().as_ref().unwrap(), "A_result");
1221        assert_eq!(results.get("B").unwrap().as_ref().unwrap(), "B_result");
1222        assert_eq!(results.get("C").unwrap().as_ref().unwrap(), "C_result");
1223        assert_eq!(results.get("D").unwrap().as_ref().unwrap(), "D_result");
1224    }
1225
1226    /// FIXME: This test fails
1227    #[tokio::test]
1228    async fn test_converging_multiple_starts() {
1229        let mut workflow = DAGWorkflow::new("test", "");
1230
1231        let agent_a = create_mock_agent("1", "A", "A", "A_result");
1232        let agent_b = create_mock_agent("2", "B", "B", "B_result");
1233        let agent_c = create_mock_agent("3", "C", "C", "C_result");
1234
1235        workflow.register_agent(agent_a);
1236        workflow.register_agent(agent_b);
1237        workflow.register_agent(agent_c);
1238
1239        workflow.connect_agents("A", "C", Flow::default()).unwrap();
1240        workflow.connect_agents("B", "C", Flow::default()).unwrap();
1241
1242        let _results = workflow
1243            .execute_workflow(&["A", "B"], "input")
1244            .await
1245            .unwrap();
1246
1247        let c_node = workflow.name_to_node.get("C").unwrap();
1248        let node_data = workflow.workflow.node_weight(*c_node).unwrap();
1249        let last_result = node_data.last_result.lock().await;
1250        assert!(last_result.is_some());
1251        assert!(
1252            last_result
1253                .as_ref()
1254                .unwrap()
1255                .as_ref()
1256                .unwrap()
1257                .contains("A_result")
1258        );
1259        assert!(
1260            last_result
1261                .as_ref()
1262                .unwrap()
1263                .as_ref()
1264                .unwrap()
1265                .contains("B_result")
1266        );
1267    }
1268
1269    /// FIXME: This test fails
1270    #[tokio::test]
1271    async fn test_conditional_branches() {
1272        let mut workflow = DAGWorkflow::new("test", "");
1273
1274        let agent_a = create_mock_agent("1", "A", "A", "A_trigger");
1275        let agent_b = create_mock_agent("2", "B", "B", "B_result");
1276        let agent_c = create_mock_agent("3", "C", "C", "C_result");
1277
1278        workflow.register_agent(agent_a);
1279        workflow.register_agent(agent_b);
1280        workflow.register_agent(agent_c);
1281
1282        let conditional_flow = Flow {
1283            condition: Some(Arc::new(|output: &str| output.contains("trigger"))),
1284            transform: None,
1285        };
1286
1287        workflow.connect_agents("A", "B", conditional_flow).unwrap();
1288        workflow.connect_agents("A", "C", Flow::default()).unwrap();
1289
1290        let results = workflow.execute_workflow(&["A"], "input").await.unwrap();
1291
1292        assert!(results.get("B").is_none());
1293        assert_eq!(results.get("C").unwrap().as_ref().unwrap(), "C_result");
1294    }
1295
1296    #[test]
1297    fn test_find_execution_paths() {
1298        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1299        workflow.register_agent(create_mock_agent("0", "start", "Starting point", "start"));
1300        workflow.register_agent(create_mock_agent("1", "a", "Path A", "a"));
1301        workflow.register_agent(create_mock_agent("2", "b", "Path B", "b"));
1302        workflow.register_agent(create_mock_agent("3", "c", "End of A", "c"));
1303        workflow.register_agent(create_mock_agent("4", "d", "End of B", "d"));
1304
1305        workflow
1306            .connect_agents("start", "a", Flow::default())
1307            .unwrap();
1308        workflow
1309            .connect_agents("start", "b", Flow::default())
1310            .unwrap();
1311        workflow.connect_agents("a", "c", Flow::default()).unwrap();
1312        workflow.connect_agents("b", "d", Flow::default()).unwrap();
1313
1314        let paths = workflow.find_execution_paths(&["start"]).unwrap();
1315        assert_eq!(paths.len(), 2);
1316
1317        // path should be [start, a, c] or [start, b, d]
1318        let has_path1 = paths
1319            .iter()
1320            .any(|p| p == &vec!["start".to_owned(), "a".to_owned(), "c".to_owned()]);
1321        let has_path2 = paths
1322            .iter()
1323            .any(|p| p == &vec!["start".to_owned(), "b".to_owned(), "d".to_owned()]);
1324
1325        assert!(has_path1);
1326        assert!(has_path2);
1327    }
1328
1329    #[test]
1330    fn test_find_execution_paths_start_agent_not_found() {
1331        let workflow = DAGWorkflow::new("test", "Test workflow");
1332
1333        let result = workflow.find_execution_paths(&["nonexistent"]);
1334        assert!(matches!(result, Err(GraphWorkflowError::AgentNotFound(_))));
1335    }
1336
1337    #[test]
1338    fn test_find_execution_paths_diamond_pattern() {
1339        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1340        workflow.register_agent(create_mock_agent("0", "start", "Start", "start"));
1341        workflow.register_agent(create_mock_agent("1", "a", "Middle A", "a"));
1342        workflow.register_agent(create_mock_agent("2", "b", "Middle B", "b"));
1343        workflow.register_agent(create_mock_agent("3", "end", "End", "end"));
1344
1345        //            start -> a -> end
1346        //                 \-> b -/
1347        workflow
1348            .connect_agents("start", "a", Flow::default())
1349            .unwrap();
1350        workflow
1351            .connect_agents("start", "b", Flow::default())
1352            .unwrap();
1353        workflow
1354            .connect_agents("a", "end", Flow::default())
1355            .unwrap();
1356        workflow
1357            .connect_agents("b", "end", Flow::default())
1358            .unwrap();
1359
1360        let paths = workflow.find_execution_paths(&["start"]).unwrap();
1361        assert_eq!(paths.len(), 2);
1362
1363        // path should be [start, a, end] or [start, b, end]
1364        let has_path1 = paths
1365            .iter()
1366            .any(|p| p == &vec!["start".to_owned(), "a".to_owned(), "end".to_owned()]);
1367        let has_path2 = paths
1368            .iter()
1369            .any(|p| p == &vec!["start".to_owned(), "b".to_owned(), "end".to_owned()]);
1370
1371        assert!(has_path1);
1372        assert!(has_path2);
1373    }
1374
1375    #[test]
1376    fn test_detect_potential_deadlocks() {
1377        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1378        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
1379        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
1380        workflow.register_agent(create_mock_agent("3", "c", "Agent C", "c"));
1381
1382        // a -> b -> c
1383        workflow.connect_agents("a", "b", Flow::default()).unwrap();
1384        workflow.connect_agents("b", "c", Flow::default()).unwrap();
1385
1386        // no cycles, should return empty vector
1387        let deadlocks = workflow.detect_potential_deadlocks();
1388        assert_eq!(deadlocks.len(), 0);
1389
1390        // try to add c -> a, which should fail because has_cycle prevents it
1391        let result = workflow.connect_agents("c", "a", Flow::default());
1392        assert!(matches!(result, Err(GraphWorkflowError::CycleDetected)));
1393    }
1394
1395    #[test]
1396    fn test_get_workflow_structure() {
1397        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1398        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
1399        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
1400        workflow.register_agent(create_mock_agent("3", "c", "Agent C", "c"));
1401
1402        workflow.connect_agents("a", "b", Flow::default()).unwrap();
1403
1404        let transform_fn = Arc::new(|input: String| format!("transformed: {input}"));
1405        let flow = Flow {
1406            transform: Some(transform_fn),
1407            condition: None,
1408        };
1409
1410        workflow.connect_agents("b", "c", flow).unwrap();
1411
1412        let structure = workflow.get_workflow_structure();
1413        assert_eq!(structure.len(), 3);
1414
1415        assert_eq!(structure["a"].len(), 1);
1416        assert_eq!(structure["a"][0].0, "b");
1417        assert_eq!(structure["a"][0].1, None);
1418
1419        assert_eq!(structure["b"].len(), 1);
1420        assert_eq!(structure["b"][0].0, "c");
1421        assert_eq!(structure["b"][0].1, Some("transform".to_owned())); // has transform
1422
1423        assert_eq!(structure["c"].len(), 0); // c is a leaf node
1424    }
1425
1426    #[test]
1427    fn test_export_workflow_dot() {
1428        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1429        workflow.register_agent(create_mock_agent("1", "a", "Agent A", "a"));
1430        workflow.register_agent(create_mock_agent("2", "b", "Agent B", "b"));
1431
1432        workflow.connect_agents("a", "b", Flow::default()).unwrap();
1433
1434        let dot = workflow.export_workflow_dot();
1435
1436        assert!(dot.contains("digraph {"));
1437        assert!(dot.contains("\"a\" [label=\"a\"]"));
1438        assert!(dot.contains("\"b\" [label=\"b\"]"));
1439        assert!(dot.contains("\"a\" -> \"b\""));
1440        assert!(dot.contains("}"));
1441    }
1442
1443    #[tokio::test]
1444    async fn test_caching_execution_results() {
1445        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1446
1447        // mock counter agent
1448        let mut agent = MockAgent::new();
1449        let agent_name = "counter".to_owned();
1450        agent.expect_name().return_const(agent_name.clone());
1451        agent.expect_id().return_const("1".to_owned());
1452        agent
1453            .expect_description()
1454            .return_const("Counter Agent".to_owned());
1455
1456        let mut count = 0;
1457        agent.expect_run().returning(move |_| {
1458            count += 1;
1459            Box::pin(future::ready(Ok(format!("Called {count} times"))))
1460        });
1461
1462        agent
1463            .expect_run_multiple_tasks()
1464            .returning(|_| Box::pin(future::ready(Ok(vec![]))));
1465
1466        workflow.register_agent(Arc::new(agent));
1467
1468        // first execution
1469        let results1 = workflow
1470            .execute_workflow(&["counter"], "input1")
1471            .await
1472            .unwrap();
1473        assert_eq!(
1474            results1.get("counter").unwrap().as_ref().unwrap(),
1475            "Called 1 times"
1476        );
1477
1478        // second execution (should reset and call again)
1479        let results2 = workflow
1480            .execute_workflow(&["counter"], "input2")
1481            .await
1482            .unwrap();
1483        assert_eq!(
1484            results2.get("counter").unwrap().as_ref().unwrap(),
1485            "Called 2 times"
1486        );
1487
1488        // call execute_agent directly (should not use cache)
1489        let result3 = workflow
1490            .execute_agent("counter", "input3".to_owned())
1491            .await
1492            .unwrap();
1493        assert_eq!(result3, "Called 3 times");
1494    }
1495
1496    #[tokio::test]
1497    async fn test_execute_node_result_caching() {
1498        let mut workflow = DAGWorkflow::new("test", "Test workflow");
1499
1500        // Create a mock agent that records the number of calls
1501        let mut agent1 = MockAgent::new();
1502        agent1.expect_name().return_const("agent1".to_owned());
1503        agent1.expect_id().return_const("1".to_owned());
1504        agent1
1505            .expect_description()
1506            .return_const("First agent".to_owned());
1507
1508        // Set a counter to verify that the run method was called only once
1509        let mut run_count = 0;
1510        agent1.expect_run().returning(move |input| {
1511            run_count += 1;
1512            Box::pin(future::ready(Ok(format!(
1513                "response for '{input}' (call #{run_count})"
1514            ))))
1515        });
1516
1517        agent1
1518            .expect_run_multiple_tasks()
1519            .returning(|_| Box::pin(future::ready(Ok(vec![]))));
1520
1521        workflow.register_agent(Arc::new(agent1));
1522
1523        // Create a normal second proxy
1524        workflow.register_agent(create_mock_agent(
1525            "2",
1526            "agent2",
1527            "Second agent",
1528            "response2",
1529        ));
1530
1531        // connect the two agents
1532        workflow
1533            .connect_agents("agent1", "agent2", Flow::default())
1534            .unwrap();
1535
1536        let agent1_idx = *workflow.name_to_node.get("agent1").unwrap();
1537
1538        // create shared data structures
1539        let results = Arc::new(DashMap::new());
1540        let edge_tracker = Arc::new(DashMap::new());
1541        let processed_nodes = Arc::new(DashMap::new());
1542
1543        // first execution of agent1
1544        let result1 = workflow
1545            .execute_node(
1546                agent1_idx,
1547                "input1".to_owned(),
1548                Arc::clone(&results),
1549                Arc::clone(&edge_tracker),
1550                Arc::clone(&processed_nodes),
1551            )
1552            .await
1553            .unwrap();
1554
1555        assert_eq!(result1, "response for 'input1' (call #1)");
1556        assert!(results.contains_key("agent1"));
1557        assert!(results.contains_key("agent2")); // agent2 also executed
1558
1559        // second execution of agent1 with a different input
1560        let result2 = workflow
1561            .execute_node(
1562                agent1_idx,
1563                "input2".to_owned(),
1564                Arc::clone(&results),
1565                Arc::clone(&edge_tracker),
1566                Arc::clone(&processed_nodes),
1567            )
1568            .await
1569            .unwrap();
1570
1571        // the results should be the same, indicating that the agent was not executed again
1572        assert_eq!(result2, "response for 'input1' (call #1)"); // not "response for 'input2' (call #1)"
1573
1574        // clear the results map
1575        results.clear();
1576
1577        // third execution of agent1
1578        let result3 = workflow
1579            .execute_node(
1580                agent1_idx,
1581                "input3".to_owned(),
1582                Arc::clone(&results),
1583                Arc::clone(&edge_tracker),
1584                Arc::clone(&processed_nodes),
1585            )
1586            .await
1587            .unwrap();
1588
1589        // the results should contain the new call count, indicating that the agent was re-executed
1590        assert_eq!(result3, "response for 'input3' (call #2)");
1591    }
1592}