Skip to main content

oxigdal_workflow/dag/
graph.rs

1//! DAG construction and validation.
2
3use crate::error::{DagError, Result};
4use petgraph::Direction;
5use petgraph::graph::{DiGraph, NodeIndex};
6use petgraph::visit::EdgeRef;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::hash::{Hash, Hasher};
10
11/// A task node in the workflow DAG.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct TaskNode {
14    /// Unique task identifier.
15    pub id: String,
16    /// Task name.
17    pub name: String,
18    /// Task description.
19    pub description: Option<String>,
20    /// Task configuration as JSON.
21    pub config: serde_json::Value,
22    /// Retry policy.
23    pub retry: RetryPolicy,
24    /// Timeout in seconds.
25    pub timeout_secs: Option<u64>,
26    /// Resource requirements.
27    pub resources: ResourceRequirements,
28    /// Custom metadata.
29    pub metadata: HashMap<String, String>,
30}
31
32/// Retry policy for task execution.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RetryPolicy {
35    /// Maximum number of retry attempts.
36    pub max_attempts: u32,
37    /// Delay between retries in milliseconds.
38    pub delay_ms: u64,
39    /// Backoff multiplier for exponential backoff.
40    pub backoff_multiplier: f64,
41    /// Maximum delay in milliseconds.
42    pub max_delay_ms: u64,
43}
44
45impl Default for RetryPolicy {
46    fn default() -> Self {
47        Self {
48            max_attempts: 3,
49            delay_ms: 1000,
50            backoff_multiplier: 2.0,
51            max_delay_ms: 60000,
52        }
53    }
54}
55
56/// Resource requirements for a task.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ResourceRequirements {
59    /// CPU cores required (can be fractional).
60    pub cpu_cores: f64,
61    /// Memory required in MB.
62    pub memory_mb: u64,
63    /// GPU required.
64    pub gpu: bool,
65    /// Disk space required in MB.
66    pub disk_mb: u64,
67    /// Custom resource requirements.
68    pub custom: HashMap<String, f64>,
69}
70
71impl Default for ResourceRequirements {
72    fn default() -> Self {
73        Self {
74            cpu_cores: 1.0,
75            memory_mb: 1024,
76            gpu: false,
77            disk_mb: 1024,
78            custom: HashMap::new(),
79        }
80    }
81}
82
83impl PartialEq for TaskNode {
84    fn eq(&self, other: &Self) -> bool {
85        self.id == other.id
86    }
87}
88
89impl Eq for TaskNode {}
90
91impl Hash for TaskNode {
92    fn hash<H: Hasher>(&self, state: &mut H) {
93        self.id.hash(state);
94    }
95}
96
97/// An edge representing a dependency between tasks.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct TaskEdge {
100    /// Edge type (data dependency, control dependency, etc.).
101    pub edge_type: EdgeType,
102    /// Condition for edge activation.
103    pub condition: Option<String>,
104}
105
106/// Type of dependency edge.
107#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
108pub enum EdgeType {
109    /// Data dependency - output of one task is input to another.
110    Data,
111    /// Control dependency - one task must complete before another starts.
112    Control,
113    /// Conditional - edge is only followed if condition is met.
114    Conditional,
115}
116
117impl Default for TaskEdge {
118    fn default() -> Self {
119        Self {
120            edge_type: EdgeType::Control,
121            condition: None,
122        }
123    }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127
128/// Workflow DAG structure.
129pub struct WorkflowDag {
130    /// Underlying directed graph.
131    pub(crate) graph: DiGraph<TaskNode, TaskEdge>,
132    /// Mapping from task ID to node index.
133    pub(crate) task_map: HashMap<String, NodeIndex>,
134}
135
136impl WorkflowDag {
137    /// Create a new empty workflow DAG.
138    pub fn new() -> Self {
139        Self {
140            graph: DiGraph::new(),
141            task_map: HashMap::new(),
142        }
143    }
144
145    /// Add a task to the DAG.
146    pub fn add_task(&mut self, task: TaskNode) -> Result<NodeIndex> {
147        if self.task_map.contains_key(&task.id) {
148            return Err(
149                DagError::InvalidNode(format!("Task '{}' already exists in DAG", task.id)).into(),
150            );
151        }
152
153        let node_index = self.graph.add_node(task.clone());
154        self.task_map.insert(task.id.clone(), node_index);
155        Ok(node_index)
156    }
157
158    /// Add a dependency edge between two tasks.
159    pub fn add_dependency(
160        &mut self,
161        from_task_id: &str,
162        to_task_id: &str,
163        edge: TaskEdge,
164    ) -> Result<()> {
165        let from_idx = self
166            .task_map
167            .get(from_task_id)
168            .ok_or_else(|| DagError::invalid_node(from_task_id))?;
169
170        let to_idx = self
171            .task_map
172            .get(to_task_id)
173            .ok_or_else(|| DagError::invalid_node(to_task_id))?;
174
175        self.graph.add_edge(*from_idx, *to_idx, edge);
176        Ok(())
177    }
178
179    /// Get a task by ID.
180    pub fn get_task(&self, task_id: &str) -> Option<&TaskNode> {
181        self.task_map
182            .get(task_id)
183            .and_then(|idx| self.graph.node_weight(*idx))
184    }
185
186    /// Get a task by ID (mutable).
187    pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut TaskNode> {
188        self.task_map
189            .get(task_id)
190            .and_then(|idx| self.graph.node_weight_mut(*idx))
191    }
192
193    /// Get task dependencies (tasks that must complete before this task).
194    pub fn get_dependencies(&self, task_id: &str) -> Vec<String> {
195        if let Some(&idx) = self.task_map.get(task_id) {
196            self.graph
197                .edges_directed(idx, Direction::Incoming)
198                .filter_map(|edge| {
199                    self.graph
200                        .node_weight(edge.source())
201                        .map(|task| task.id.clone())
202                })
203                .collect()
204        } else {
205            Vec::new()
206        }
207    }
208
209    /// Get task dependents (tasks that depend on this task).
210    pub fn get_dependents(&self, task_id: &str) -> Vec<String> {
211        if let Some(&idx) = self.task_map.get(task_id) {
212            self.graph
213                .edges_directed(idx, Direction::Outgoing)
214                .filter_map(|edge| {
215                    self.graph
216                        .node_weight(edge.target())
217                        .map(|task| task.id.clone())
218                })
219                .collect()
220        } else {
221            Vec::new()
222        }
223    }
224
225    /// Validate the DAG structure.
226    pub fn validate(&self) -> Result<()> {
227        // Check if DAG is empty
228        if self.graph.node_count() == 0 {
229            return Err(DagError::EmptyDag.into());
230        }
231
232        // Check for cycles
233        self.check_cycles()?;
234
235        // Check for unreachable nodes
236        self.check_reachability()?;
237
238        Ok(())
239    }
240
241    /// Check for cycles in the DAG using DFS.
242    fn check_cycles(&self) -> Result<()> {
243        let mut visited = HashSet::new();
244        let mut rec_stack = HashSet::new();
245
246        for node_idx in self.graph.node_indices() {
247            if !visited.contains(&node_idx) {
248                if let Some(cycle_path) =
249                    self.dfs_cycle_check(node_idx, &mut visited, &mut rec_stack)
250                {
251                    return Err(DagError::cycle(cycle_path).into());
252                }
253            }
254        }
255
256        Ok(())
257    }
258
259    /// DFS-based cycle detection.
260    fn dfs_cycle_check(
261        &self,
262        node: NodeIndex,
263        visited: &mut HashSet<NodeIndex>,
264        rec_stack: &mut HashSet<NodeIndex>,
265    ) -> Option<String> {
266        visited.insert(node);
267        rec_stack.insert(node);
268
269        for neighbor in self.graph.neighbors(node) {
270            if !visited.contains(&neighbor) {
271                if let Some(path) = self.dfs_cycle_check(neighbor, visited, rec_stack) {
272                    return Some(path);
273                }
274            } else if rec_stack.contains(&neighbor) {
275                // Cycle detected, construct path
276                let current_task = self.graph.node_weight(node).map(|t| &t.id)?;
277                let next_task = self.graph.node_weight(neighbor).map(|t| &t.id)?;
278                return Some(format!("{} -> {}", current_task, next_task));
279            }
280        }
281
282        rec_stack.remove(&node);
283        None
284    }
285
286    /// Check if all nodes are reachable from root nodes.
287    fn check_reachability(&self) -> Result<()> {
288        // Find root nodes (nodes with no incoming edges)
289        let root_nodes: Vec<NodeIndex> = self
290            .graph
291            .node_indices()
292            .filter(|&idx| self.graph.edges_directed(idx, Direction::Incoming).count() == 0)
293            .collect();
294
295        if root_nodes.is_empty() {
296            // If no root nodes, check if the graph has cycles (all nodes have incoming edges)
297            return Ok(());
298        }
299
300        // BFS from all root nodes to find reachable nodes
301        let mut reachable = HashSet::new();
302        let mut queue = VecDeque::from(root_nodes);
303
304        while let Some(node) = queue.pop_front() {
305            if reachable.insert(node) {
306                for neighbor in self.graph.neighbors(node) {
307                    if !reachable.contains(&neighbor) {
308                        queue.push_back(neighbor);
309                    }
310                }
311            }
312        }
313
314        // Check if all nodes are reachable
315        for node_idx in self.graph.node_indices() {
316            if !reachable.contains(&node_idx) {
317                if let Some(task) = self.graph.node_weight(node_idx) {
318                    return Err(DagError::UnreachableNode(task.id.clone()).into());
319                }
320            }
321        }
322
323        Ok(())
324    }
325
326    /// Get all tasks in the DAG.
327    pub fn tasks(&self) -> Vec<&TaskNode> {
328        self.graph
329            .node_indices()
330            .filter_map(|idx| self.graph.node_weight(idx))
331            .collect()
332    }
333
334    /// Get the number of tasks in the DAG.
335    pub fn task_count(&self) -> usize {
336        self.graph.node_count()
337    }
338
339    /// Get the number of dependencies in the DAG.
340    pub fn dependency_count(&self) -> usize {
341        self.graph.edge_count()
342    }
343
344    /// Get root tasks (tasks with no dependencies).
345    pub fn root_tasks(&self) -> Vec<&TaskNode> {
346        self.graph
347            .node_indices()
348            .filter(|&idx| self.graph.edges_directed(idx, Direction::Incoming).count() == 0)
349            .filter_map(|idx| self.graph.node_weight(idx))
350            .collect()
351    }
352
353    /// Get leaf tasks (tasks with no dependents).
354    pub fn leaf_tasks(&self) -> Vec<&TaskNode> {
355        self.graph
356            .node_indices()
357            .filter(|&idx| self.graph.edges_directed(idx, Direction::Outgoing).count() == 0)
358            .filter_map(|idx| self.graph.node_weight(idx))
359            .collect()
360    }
361
362    /// Get all edges in the DAG as (from_task_id, to_task_id, edge_data) tuples.
363    ///
364    /// This method is useful for visualization and serialization purposes.
365    /// Returns edges in the order they are stored in the graph.
366    pub fn edges(&self) -> Vec<(&str, &str, &TaskEdge)> {
367        self.graph
368            .edge_indices()
369            .filter_map(|edge_idx| {
370                let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
371                let from_node = self.graph.node_weight(from_idx)?;
372                let to_node = self.graph.node_weight(to_idx)?;
373                let edge = self.graph.edge_weight(edge_idx)?;
374                Some((from_node.id.as_str(), to_node.id.as_str(), edge))
375            })
376            .collect()
377    }
378
379    /// Get all edges with their edge types as (from_task_id, to_task_id, edge_type) tuples.
380    ///
381    /// A simplified version of `edges()` that only returns edge types.
382    pub fn edge_pairs(&self) -> Vec<(String, String)> {
383        self.graph
384            .edge_indices()
385            .filter_map(|edge_idx| {
386                let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
387                let from_node = self.graph.node_weight(from_idx)?;
388                let to_node = self.graph.node_weight(to_idx)?;
389                Some((from_node.id.clone(), to_node.id.clone()))
390            })
391            .collect()
392    }
393
394    /// Get task dependencies along with their edge data.
395    ///
396    /// Returns a vector of (dependency_task_id, edge_data) tuples for the given task.
397    /// Dependencies are tasks that must complete before the specified task can start.
398    pub fn get_dependencies_with_edges(&self, task_id: &str) -> Vec<(String, &TaskEdge)> {
399        if let Some(&idx) = self.task_map.get(task_id) {
400            self.graph
401                .edges_directed(idx, Direction::Incoming)
402                .filter_map(|edge| {
403                    let source_node = self.graph.node_weight(edge.source())?;
404                    Some((source_node.id.clone(), edge.weight()))
405                })
406                .collect()
407        } else {
408            Vec::new()
409        }
410    }
411
412    /// Get task dependents along with their edge data.
413    ///
414    /// Returns a vector of (dependent_task_id, edge_data) tuples for the given task.
415    /// Dependents are tasks that wait for the specified task to complete.
416    pub fn get_dependents_with_edges(&self, task_id: &str) -> Vec<(String, &TaskEdge)> {
417        if let Some(&idx) = self.task_map.get(task_id) {
418            self.graph
419                .edges_directed(idx, Direction::Outgoing)
420                .filter_map(|edge| {
421                    let target_node = self.graph.node_weight(edge.target())?;
422                    Some((target_node.id.clone(), edge.weight()))
423                })
424                .collect()
425        } else {
426            Vec::new()
427        }
428    }
429
430    /// Get the edge data between two specific tasks, if it exists.
431    ///
432    /// Returns `None` if either task does not exist or no edge connects them.
433    pub fn get_edge_between(&self, from_task_id: &str, to_task_id: &str) -> Option<&TaskEdge> {
434        let from_idx = self.task_map.get(from_task_id)?;
435        let to_idx = self.task_map.get(to_task_id)?;
436        self.graph
437            .find_edge(*from_idx, *to_idx)
438            .and_then(|edge_idx| self.graph.edge_weight(edge_idx))
439    }
440
441    /// Check if a dependency exists between two tasks.
442    ///
443    /// Returns `true` if `from_task_id` has a direct edge to `to_task_id`.
444    pub fn has_dependency(&self, from_task_id: &str, to_task_id: &str) -> bool {
445        self.get_edge_between(from_task_id, to_task_id).is_some()
446    }
447
448    /// Check if a task has any dependencies (incoming edges).
449    pub fn has_dependencies(&self, task_id: &str) -> bool {
450        if let Some(&idx) = self.task_map.get(task_id) {
451            self.graph.edges_directed(idx, Direction::Incoming).count() > 0
452        } else {
453            false
454        }
455    }
456
457    /// Check if a task has any dependents (outgoing edges).
458    pub fn has_dependents(&self, task_id: &str) -> bool {
459        if let Some(&idx) = self.task_map.get(task_id) {
460            self.graph.edges_directed(idx, Direction::Outgoing).count() > 0
461        } else {
462            false
463        }
464    }
465
466    /// Get the in-degree of a task (number of dependencies).
467    pub fn in_degree(&self, task_id: &str) -> usize {
468        if let Some(&idx) = self.task_map.get(task_id) {
469            self.graph.edges_directed(idx, Direction::Incoming).count()
470        } else {
471            0
472        }
473    }
474
475    /// Get the out-degree of a task (number of dependents).
476    pub fn out_degree(&self, task_id: &str) -> usize {
477        if let Some(&idx) = self.task_map.get(task_id) {
478            self.graph.edges_directed(idx, Direction::Outgoing).count()
479        } else {
480            0
481        }
482    }
483
484    /// Get all task IDs in the DAG.
485    pub fn task_ids(&self) -> Vec<String> {
486        self.task_map.keys().cloned().collect()
487    }
488
489    /// Check if a task exists in the DAG.
490    pub fn contains_task(&self, task_id: &str) -> bool {
491        self.task_map.contains_key(task_id)
492    }
493
494    /// Remove a task from the DAG along with all its edges.
495    ///
496    /// Returns the removed task, or `None` if the task did not exist.
497    pub fn remove_task(&mut self, task_id: &str) -> Option<TaskNode> {
498        let node_idx = self.task_map.remove(task_id)?;
499        self.graph.remove_node(node_idx)
500    }
501
502    /// Get edges filtered by edge type.
503    pub fn edges_by_type(&self, edge_type: EdgeType) -> Vec<(&str, &str, &TaskEdge)> {
504        self.graph
505            .edge_indices()
506            .filter_map(|edge_idx| {
507                let edge = self.graph.edge_weight(edge_idx)?;
508                if edge.edge_type != edge_type {
509                    return None;
510                }
511                let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
512                let from_node = self.graph.node_weight(from_idx)?;
513                let to_node = self.graph.node_weight(to_idx)?;
514                Some((from_node.id.as_str(), to_node.id.as_str(), edge))
515            })
516            .collect()
517    }
518
519    /// Get a subgraph containing only the specified tasks and edges between them.
520    ///
521    /// Tasks not present in the original DAG are silently ignored.
522    pub fn subgraph(&self, task_ids: &[&str]) -> WorkflowDag {
523        let mut sub = WorkflowDag::new();
524        let id_set: HashSet<&str> = task_ids.iter().copied().collect();
525
526        // Add matching nodes
527        for task_id in task_ids {
528            if let Some(task) = self.get_task(task_id) {
529                // Ignore errors from duplicate insertions if task_ids has duplicates
530                let _ = sub.add_task(task.clone());
531            }
532        }
533
534        // Add edges that connect nodes within the subgraph
535        for (from_id, to_id, edge) in self.edges() {
536            if id_set.contains(from_id) && id_set.contains(to_id) {
537                let _ = sub.add_dependency(from_id, to_id, edge.clone());
538            }
539        }
540
541        sub
542    }
543
544    /// Compute the transitive closure of dependencies for a task.
545    ///
546    /// Returns all tasks that must complete (directly or transitively) before
547    /// the given task can execute.
548    pub fn transitive_dependencies(&self, task_id: &str) -> Vec<String> {
549        let mut visited = HashSet::new();
550        let mut queue = VecDeque::new();
551
552        // Seed with direct dependencies
553        for dep in self.get_dependencies(task_id) {
554            if visited.insert(dep.clone()) {
555                queue.push_back(dep);
556            }
557        }
558
559        while let Some(current) = queue.pop_front() {
560            for dep in self.get_dependencies(&current) {
561                if visited.insert(dep.clone()) {
562                    queue.push_back(dep);
563                }
564            }
565        }
566
567        visited.into_iter().collect()
568    }
569
570    /// Compute the transitive closure of dependents for a task.
571    ///
572    /// Returns all tasks that (directly or transitively) depend on the given task.
573    pub fn transitive_dependents(&self, task_id: &str) -> Vec<String> {
574        let mut visited = HashSet::new();
575        let mut queue = VecDeque::new();
576
577        // Seed with direct dependents
578        for dep in self.get_dependents(task_id) {
579            if visited.insert(dep.clone()) {
580                queue.push_back(dep);
581            }
582        }
583
584        while let Some(current) = queue.pop_front() {
585            for dep in self.get_dependents(&current) {
586                if visited.insert(dep.clone()) {
587                    queue.push_back(dep);
588                }
589            }
590        }
591
592        visited.into_iter().collect()
593    }
594
595    /// Get summary statistics about the DAG structure.
596    pub fn summary(&self) -> DagSummary {
597        let node_count = self.graph.node_count();
598        let edge_count = self.graph.edge_count();
599        let root_count = self.root_tasks().len();
600        let leaf_count = self.leaf_tasks().len();
601
602        let max_in_degree = self
603            .graph
604            .node_indices()
605            .map(|idx| self.graph.edges_directed(idx, Direction::Incoming).count())
606            .max()
607            .unwrap_or(0);
608
609        let max_out_degree = self
610            .graph
611            .node_indices()
612            .map(|idx| self.graph.edges_directed(idx, Direction::Outgoing).count())
613            .max()
614            .unwrap_or(0);
615
616        let data_edges = self.edges_by_type(EdgeType::Data).len();
617        let control_edges = self.edges_by_type(EdgeType::Control).len();
618        let conditional_edges = self.edges_by_type(EdgeType::Conditional).len();
619
620        DagSummary {
621            node_count,
622            edge_count,
623            root_count,
624            leaf_count,
625            max_in_degree,
626            max_out_degree,
627            data_edge_count: data_edges,
628            control_edge_count: control_edges,
629            conditional_edge_count: conditional_edges,
630        }
631    }
632}
633
634/// Summary statistics for a DAG.
635#[derive(Debug, Clone, Serialize, Deserialize)]
636pub struct DagSummary {
637    /// Number of task nodes.
638    pub node_count: usize,
639    /// Number of dependency edges.
640    pub edge_count: usize,
641    /// Number of root tasks (no dependencies).
642    pub root_count: usize,
643    /// Number of leaf tasks (no dependents).
644    pub leaf_count: usize,
645    /// Maximum number of dependencies any single task has.
646    pub max_in_degree: usize,
647    /// Maximum number of dependents any single task has.
648    pub max_out_degree: usize,
649    /// Number of data dependency edges.
650    pub data_edge_count: usize,
651    /// Number of control dependency edges.
652    pub control_edge_count: usize,
653    /// Number of conditional dependency edges.
654    pub conditional_edge_count: usize,
655}
656
657impl Default for WorkflowDag {
658    fn default() -> Self {
659        Self::new()
660    }
661}
662
663#[cfg(test)]
664mod tests {
665    use super::*;
666
667    fn create_test_task(id: &str, name: &str) -> TaskNode {
668        TaskNode {
669            id: id.to_string(),
670            name: name.to_string(),
671            description: None,
672            config: serde_json::json!({}),
673            retry: RetryPolicy::default(),
674            timeout_secs: Some(60),
675            resources: ResourceRequirements::default(),
676            metadata: HashMap::new(),
677        }
678    }
679
680    #[test]
681    fn test_add_task() {
682        let mut dag = WorkflowDag::new();
683        let task = create_test_task("task1", "Task 1");
684        let result = dag.add_task(task);
685        assert!(result.is_ok());
686        assert_eq!(dag.task_count(), 1);
687    }
688
689    #[test]
690    fn test_duplicate_task() {
691        let mut dag = WorkflowDag::new();
692        let task1 = create_test_task("task1", "Task 1");
693        let task2 = create_test_task("task1", "Task 1 Duplicate");
694
695        dag.add_task(task1).ok();
696        let result = dag.add_task(task2);
697        assert!(result.is_err());
698    }
699
700    #[test]
701    fn test_add_dependency() {
702        let mut dag = WorkflowDag::new();
703        dag.add_task(create_test_task("task1", "Task 1")).ok();
704        dag.add_task(create_test_task("task2", "Task 2")).ok();
705
706        let result = dag.add_dependency("task1", "task2", TaskEdge::default());
707        assert!(result.is_ok());
708        assert_eq!(dag.dependency_count(), 1);
709    }
710
711    #[test]
712    fn test_cycle_detection() {
713        let mut dag = WorkflowDag::new();
714        dag.add_task(create_test_task("task1", "Task 1")).ok();
715        dag.add_task(create_test_task("task2", "Task 2")).ok();
716        dag.add_task(create_test_task("task3", "Task 3")).ok();
717
718        // Create a cycle: task1 -> task2 -> task3 -> task1
719        dag.add_dependency("task1", "task2", TaskEdge::default())
720            .ok();
721        dag.add_dependency("task2", "task3", TaskEdge::default())
722            .ok();
723        dag.add_dependency("task3", "task1", TaskEdge::default())
724            .ok();
725
726        let result = dag.validate();
727        assert!(result.is_err());
728    }
729
730    #[test]
731    fn test_valid_dag() {
732        let mut dag = WorkflowDag::new();
733        dag.add_task(create_test_task("task1", "Task 1")).ok();
734        dag.add_task(create_test_task("task2", "Task 2")).ok();
735        dag.add_task(create_test_task("task3", "Task 3")).ok();
736
737        // Create a valid DAG: task1 -> task2, task1 -> task3
738        dag.add_dependency("task1", "task2", TaskEdge::default())
739            .ok();
740        dag.add_dependency("task1", "task3", TaskEdge::default())
741            .ok();
742
743        let result = dag.validate();
744        assert!(result.is_ok());
745    }
746
747    #[test]
748    fn test_root_and_leaf_tasks() {
749        let mut dag = WorkflowDag::new();
750        dag.add_task(create_test_task("task1", "Task 1")).ok();
751        dag.add_task(create_test_task("task2", "Task 2")).ok();
752        dag.add_task(create_test_task("task3", "Task 3")).ok();
753
754        dag.add_dependency("task1", "task2", TaskEdge::default())
755            .ok();
756        dag.add_dependency("task2", "task3", TaskEdge::default())
757            .ok();
758
759        let roots = dag.root_tasks();
760        assert_eq!(roots.len(), 1);
761        assert_eq!(roots[0].id, "task1");
762
763        let leaves = dag.leaf_tasks();
764        assert_eq!(leaves.len(), 1);
765        assert_eq!(leaves[0].id, "task3");
766    }
767
768    #[test]
769    fn test_edges() {
770        let mut dag = WorkflowDag::new();
771        dag.add_task(create_test_task("t1", "Task 1")).ok();
772        dag.add_task(create_test_task("t2", "Task 2")).ok();
773        dag.add_task(create_test_task("t3", "Task 3")).ok();
774
775        dag.add_dependency("t1", "t2", TaskEdge::default()).ok();
776        dag.add_dependency(
777            "t2",
778            "t3",
779            TaskEdge {
780                edge_type: EdgeType::Data,
781                condition: None,
782            },
783        )
784        .ok();
785
786        let edges = dag.edges();
787        assert_eq!(edges.len(), 2);
788
789        // Check first edge
790        let (from, to, edge) = &edges[0];
791        assert_eq!(*from, "t1");
792        assert_eq!(*to, "t2");
793        assert_eq!(edge.edge_type, EdgeType::Control);
794
795        // Check second edge
796        let (from, to, edge) = &edges[1];
797        assert_eq!(*from, "t2");
798        assert_eq!(*to, "t3");
799        assert_eq!(edge.edge_type, EdgeType::Data);
800    }
801
802    #[test]
803    fn test_get_dependencies_with_edges() {
804        let mut dag = WorkflowDag::new();
805        dag.add_task(create_test_task("t1", "Task 1")).ok();
806        dag.add_task(create_test_task("t2", "Task 2")).ok();
807        dag.add_task(create_test_task("t3", "Task 3")).ok();
808
809        dag.add_dependency(
810            "t1",
811            "t3",
812            TaskEdge {
813                edge_type: EdgeType::Data,
814                condition: None,
815            },
816        )
817        .ok();
818        dag.add_dependency("t2", "t3", TaskEdge::default()).ok();
819
820        let deps = dag.get_dependencies_with_edges("t3");
821        assert_eq!(deps.len(), 2);
822
823        // Both t1 and t2 should be dependencies of t3
824        let dep_ids: Vec<&str> = deps.iter().map(|(id, _)| id.as_str()).collect();
825        assert!(dep_ids.contains(&"t1"));
826        assert!(dep_ids.contains(&"t2"));
827
828        // No dependencies for root task
829        let root_deps = dag.get_dependencies_with_edges("t1");
830        assert!(root_deps.is_empty());
831
832        // Non-existent task returns empty
833        let missing_deps = dag.get_dependencies_with_edges("nonexistent");
834        assert!(missing_deps.is_empty());
835    }
836
837    #[test]
838    fn test_get_dependents_with_edges() {
839        let mut dag = WorkflowDag::new();
840        dag.add_task(create_test_task("t1", "Task 1")).ok();
841        dag.add_task(create_test_task("t2", "Task 2")).ok();
842        dag.add_task(create_test_task("t3", "Task 3")).ok();
843
844        dag.add_dependency("t1", "t2", TaskEdge::default()).ok();
845        dag.add_dependency("t1", "t3", TaskEdge::default()).ok();
846
847        let dependents = dag.get_dependents_with_edges("t1");
848        assert_eq!(dependents.len(), 2);
849
850        let dep_ids: Vec<&str> = dependents.iter().map(|(id, _)| id.as_str()).collect();
851        assert!(dep_ids.contains(&"t2"));
852        assert!(dep_ids.contains(&"t3"));
853    }
854
855    #[test]
856    fn test_get_edge_between() {
857        let mut dag = WorkflowDag::new();
858        dag.add_task(create_test_task("t1", "Task 1")).ok();
859        dag.add_task(create_test_task("t2", "Task 2")).ok();
860        dag.add_task(create_test_task("t3", "Task 3")).ok();
861
862        dag.add_dependency(
863            "t1",
864            "t2",
865            TaskEdge {
866                edge_type: EdgeType::Data,
867                condition: Some("output.ready".to_string()),
868            },
869        )
870        .ok();
871
872        let edge = dag.get_edge_between("t1", "t2");
873        assert!(edge.is_some());
874        let edge = edge.expect("Edge should exist");
875        assert_eq!(edge.edge_type, EdgeType::Data);
876        assert_eq!(edge.condition.as_deref(), Some("output.ready"));
877
878        // Reverse direction should not exist
879        assert!(dag.get_edge_between("t2", "t1").is_none());
880        // Non-connected nodes
881        assert!(dag.get_edge_between("t1", "t3").is_none());
882    }
883
884    #[test]
885    fn test_has_dependency() {
886        let mut dag = WorkflowDag::new();
887        dag.add_task(create_test_task("t1", "Task 1")).ok();
888        dag.add_task(create_test_task("t2", "Task 2")).ok();
889
890        dag.add_dependency("t1", "t2", TaskEdge::default()).ok();
891
892        assert!(dag.has_dependency("t1", "t2"));
893        assert!(!dag.has_dependency("t2", "t1"));
894        assert!(!dag.has_dependency("t1", "nonexistent"));
895    }
896
/// Checks `has_dependencies` / `has_dependents` on a three-task chain
/// (`t1 -> t2 -> t3`): the root, the middle task, and the leaf.
#[test]
fn test_has_dependencies_and_dependents() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2"), ("t3", "Task 3")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    for (from, to) in [("t1", "t2"), ("t2", "t3")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("chain edge should be accepted as a dependency");
    }

    // t1: root, has dependents but no dependencies
    assert!(!dag.has_dependencies("t1"));
    assert!(dag.has_dependents("t1"));

    // t2: middle, has both
    assert!(dag.has_dependencies("t2"));
    assert!(dag.has_dependents("t2"));

    // t3: leaf, has dependencies but no dependents
    assert!(dag.has_dependencies("t3"));
    assert!(!dag.has_dependents("t3"));
}
919
/// Checks `in_degree` / `out_degree` on a fan-in graph
/// (`t1 -> t3`, `t2 -> t3`, `t3 -> t4`) and `in_degree` for an unknown id.
#[test]
fn test_in_out_degree() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [
        ("t1", "Task 1"),
        ("t2", "Task 2"),
        ("t3", "Task 3"),
        ("t4", "Task 4"),
    ] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    // t1 -> t3, t2 -> t3, t3 -> t4
    for (from, to) in [("t1", "t3"), ("t2", "t3"), ("t3", "t4")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("acyclic edge should be accepted as a dependency");
    }

    assert_eq!(dag.in_degree("t1"), 0);
    assert_eq!(dag.out_degree("t1"), 1);
    assert_eq!(dag.in_degree("t3"), 2);
    assert_eq!(dag.out_degree("t3"), 1);
    assert_eq!(dag.in_degree("t4"), 1);
    assert_eq!(dag.out_degree("t4"), 0);
    // Non-existent
    assert_eq!(dag.in_degree("nonexistent"), 0);
}
942
/// Checks that `task_ids` reports one entry per added task and that
/// `contains_task` distinguishes added ids from an id that was never added.
#[test]
fn test_task_ids_and_contains() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    let ids = dag.task_ids();
    assert_eq!(ids.len(), 2);
    assert!(dag.contains_task("t1"));
    assert!(dag.contains_task("t2"));
    assert!(!dag.contains_task("t3"));
}
955
/// Checks that `remove_task` returns the removed task, that the task is no
/// longer reported by `contains_task`, and that removing an unknown id
/// returns `None`.
#[test]
fn test_remove_task() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }
    dag.add_dependency("t1", "t2", TaskEdge::default())
        .expect("t1 -> t2 should be accepted as a dependency");

    assert_eq!(dag.task_count(), 2);
    assert_eq!(dag.dependency_count(), 1);

    let removed = dag.remove_task("t1");
    assert_eq!(removed.as_ref().map(|t| t.id.as_str()), Some("t1"));
    assert!(!dag.contains_task("t1"));

    // The counts captured above are only meaningful if they are re-checked
    // after the removal: the other task must survive and the edge that
    // touched the removed task must be gone.
    // NOTE(review): assumes `task_count` / `dependency_count` reflect the
    // underlying graph after removal — confirm against `remove_task`.
    assert!(dag.contains_task("t2"));
    assert_eq!(dag.task_count(), 1);
    assert_eq!(dag.dependency_count(), 0);

    // Removing non-existent should return None
    assert!(dag.remove_task("nonexistent").is_none());
}
974
/// Checks that `edges_by_type` separates one explicit `Data` edge
/// (`t1 -> t2`) from one default edge (`t1 -> t3`), which this test expects
/// to be reported as `Control`.
#[test]
fn test_edges_by_type() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2"), ("t3", "Task 3")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    dag.add_dependency(
        "t1",
        "t2",
        TaskEdge {
            edge_type: EdgeType::Data,
            condition: None,
        },
    )
    .expect("t1 -> t2 should be accepted as a dependency");
    dag.add_dependency("t1", "t3", TaskEdge::default())
        .expect("t1 -> t3 should be accepted as a dependency");

    let data_edges = dag.edges_by_type(EdgeType::Data);
    assert_eq!(data_edges.len(), 1);
    assert_eq!(data_edges[0].0, "t1");
    assert_eq!(data_edges[0].1, "t2");

    let control_edges = dag.edges_by_type(EdgeType::Control);
    assert_eq!(control_edges.len(), 1);
    assert_eq!(control_edges[0].0, "t1");
    assert_eq!(control_edges[0].1, "t3");
}
1003
/// Checks that `subgraph` over `t2` and `t3` of the chain
/// `t1 -> t2 -> t3 -> t4` keeps exactly those two tasks and the single edge
/// between them.
#[test]
fn test_subgraph() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [
        ("t1", "Task 1"),
        ("t2", "Task 2"),
        ("t3", "Task 3"),
        ("t4", "Task 4"),
    ] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    for (from, to) in [("t1", "t2"), ("t2", "t3"), ("t3", "t4")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("chain edge should be accepted as a dependency");
    }

    // Extract subgraph with only t2 and t3
    let sub = dag.subgraph(&["t2", "t3"]);
    assert_eq!(sub.task_count(), 2);
    assert_eq!(sub.dependency_count(), 1);
    assert!(sub.contains_task("t2"));
    assert!(sub.contains_task("t3"));
    assert!(!sub.contains_task("t1"));
    assert!(!sub.contains_task("t4"));
}
1025
/// Checks that `transitive_dependencies` on the chain
/// `t1 -> t2 -> t3 -> t4` returns every upstream task for the leaf and an
/// empty collection for the root.
#[test]
fn test_transitive_dependencies() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [
        ("t1", "Task 1"),
        ("t2", "Task 2"),
        ("t3", "Task 3"),
        ("t4", "Task 4"),
    ] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    for (from, to) in [("t1", "t2"), ("t2", "t3"), ("t3", "t4")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("chain edge should be accepted as a dependency");
    }

    let trans_deps = dag.transitive_dependencies("t4");
    assert_eq!(trans_deps.len(), 3);
    assert!(trans_deps.contains(&"t1".to_string()));
    assert!(trans_deps.contains(&"t2".to_string()));
    assert!(trans_deps.contains(&"t3".to_string()));

    // Root has no transitive dependencies
    let root_deps = dag.transitive_dependencies("t1");
    assert!(root_deps.is_empty());
}
1048
/// Checks that `transitive_dependents` on the chain
/// `t1 -> t2 -> t3 -> t4` returns every downstream task for the root and an
/// empty collection for the leaf.
#[test]
fn test_transitive_dependents() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [
        ("t1", "Task 1"),
        ("t2", "Task 2"),
        ("t3", "Task 3"),
        ("t4", "Task 4"),
    ] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    for (from, to) in [("t1", "t2"), ("t2", "t3"), ("t3", "t4")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("chain edge should be accepted as a dependency");
    }

    let trans_dependents = dag.transitive_dependents("t1");
    assert_eq!(trans_dependents.len(), 3);
    assert!(trans_dependents.contains(&"t2".to_string()));
    assert!(trans_dependents.contains(&"t3".to_string()));
    assert!(trans_dependents.contains(&"t4".to_string()));

    // Leaf has no transitive dependents
    let leaf_deps = dag.transitive_dependents("t4");
    assert!(leaf_deps.is_empty());
}
1071
/// Checks every field of `summary()` asserted below on a diamond graph
/// (`t1 -> t2`, `t1 -> t3`, `t2 -> t4`, `t3 -> t4`) with one `Data` edge and
/// three default edges.
#[test]
fn test_summary() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [
        ("t1", "Task 1"),
        ("t2", "Task 2"),
        ("t3", "Task 3"),
        ("t4", "Task 4"),
    ] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    dag.add_dependency(
        "t1",
        "t2",
        TaskEdge {
            edge_type: EdgeType::Data,
            condition: None,
        },
    )
    .expect("t1 -> t2 should be accepted as a dependency");
    for (from, to) in [("t1", "t3"), ("t2", "t4"), ("t3", "t4")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("diamond edge should be accepted as a dependency");
    }

    let summary = dag.summary();
    assert_eq!(summary.node_count, 4);
    assert_eq!(summary.edge_count, 4);
    assert_eq!(summary.root_count, 1);
    assert_eq!(summary.leaf_count, 1);
    assert_eq!(summary.max_in_degree, 2); // t4 has 2 incoming
    assert_eq!(summary.max_out_degree, 2); // t1 has 2 outgoing
    assert_eq!(summary.data_edge_count, 1);
    assert_eq!(summary.control_edge_count, 3);
    assert_eq!(summary.conditional_edge_count, 0);
}
1104
/// Checks that `edge_pairs` reports the single `t1 -> t2` edge as an owned
/// `(from, to)` pair of task ids.
#[test]
fn test_edge_pairs() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }
    dag.add_dependency("t1", "t2", TaskEdge::default())
        .expect("t1 -> t2 should be accepted as a dependency");

    let pairs = dag.edge_pairs();
    assert_eq!(pairs.len(), 1);
    assert_eq!(pairs[0], ("t1".to_string(), "t2".to_string()));
}
1116
/// Checks `get_dependencies` and `get_dependents` on a fan-in graph
/// (`t1 -> t3`, `t2 -> t3`): `t3` depends on both roots, and `t1` has `t3` as
/// its only dependent.
#[test]
fn test_get_dependencies_and_dependents() {
    let mut dag = WorkflowDag::new();
    // Surface fixture failures immediately rather than ignoring them.
    for (id, name) in [("t1", "Task 1"), ("t2", "Task 2"), ("t3", "Task 3")] {
        dag.add_task(create_test_task(id, name))
            .expect("adding a uniquely named task should succeed");
    }

    for (from, to) in [("t1", "t3"), ("t2", "t3")] {
        dag.add_dependency(from, to, TaskEdge::default())
            .expect("fan-in edge should be accepted as a dependency");
    }

    let deps = dag.get_dependencies("t3");
    assert_eq!(deps.len(), 2);
    assert!(deps.contains(&"t1".to_string()));
    assert!(deps.contains(&"t2".to_string()));

    let dependents = dag.get_dependents("t1");
    assert_eq!(dependents.len(), 1);
    assert!(dependents.contains(&"t3".to_string()));
}
1136}