1use crate::error::{DagError, Result};
4use petgraph::Direction;
5use petgraph::graph::{DiGraph, NodeIndex};
6use petgraph::visit::EdgeRef;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::hash::{Hash, Hasher};
10
/// A single unit of work stored as a node in a [`WorkflowDag`].
///
/// Identity is defined solely by `id`: the `PartialEq` and `Hash` impls for this type compare
/// and hash only that field, so two tasks with the same id are considered equal even if every
/// other field differs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
    /// Unique task identifier. `WorkflowDag` indexes tasks by this value and rejects a second
    /// task with the same id.
    pub id: String,
    /// Human-readable name; not used for identity or lookup.
    pub name: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Task-specific configuration as arbitrary JSON; not interpreted by the DAG logic in
    /// this file.
    pub config: serde_json::Value,
    /// Retry settings for this task; not read by the DAG logic in this file.
    pub retry: RetryPolicy,
    /// Timeout in seconds. NOTE(review): `None` presumably means "no timeout" — confirm with
    /// whatever executes tasks.
    pub timeout_secs: Option<u64>,
    /// Resources requested by the task; not read by the DAG logic in this file.
    pub resources: ResourceRequirements,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
}
31
/// Retry configuration attached to a [`TaskNode`].
///
/// Nothing in this file applies the policy. The field descriptions below follow the field
/// names and should be confirmed against the executor that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of attempts. NOTE(review): unclear from here whether this counts the
    /// initial attempt — confirm with the executor.
    pub max_attempts: u32,
    /// Initial delay between attempts, in milliseconds.
    pub delay_ms: u64,
    /// Factor the delay is presumably multiplied by after each failed attempt — confirm.
    pub backoff_multiplier: f64,
    /// Upper bound on the delay, in milliseconds (presumably caps the backoff growth).
    pub max_delay_ms: u64,
}
44
45impl Default for RetryPolicy {
46 fn default() -> Self {
47 Self {
48 max_attempts: 3,
49 delay_ms: 1000,
50 backoff_multiplier: 2.0,
51 max_delay_ms: 60000,
52 }
53 }
54}
55
/// Resources a [`TaskNode`] asks for.
///
/// Nothing in this file checks or reserves these values; units follow the field-name
/// suffixes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequirements {
    /// Number of CPU cores; an `f64`, so fractional values are representable.
    pub cpu_cores: f64,
    /// Memory in megabytes.
    pub memory_mb: u64,
    /// Whether the task needs a GPU.
    pub gpu: bool,
    /// Disk space in megabytes.
    pub disk_mb: u64,
    /// Additional named resources with numeric amounts. NOTE(review): units are whatever the
    /// consumer defines — confirm.
    pub custom: HashMap<String, f64>,
}
70
71impl Default for ResourceRequirements {
72 fn default() -> Self {
73 Self {
74 cpu_cores: 1.0,
75 memory_mb: 1024,
76 gpu: false,
77 disk_mb: 1024,
78 custom: HashMap::new(),
79 }
80 }
81}
82
83impl PartialEq for TaskNode {
84 fn eq(&self, other: &Self) -> bool {
85 self.id == other.id
86 }
87}
88
89impl Eq for TaskNode {}
90
91impl Hash for TaskNode {
92 fn hash<H: Hasher>(&self, state: &mut H) {
93 self.id.hash(state);
94 }
95}
96
/// Metadata stored on a dependency edge `from -> to` of a [`WorkflowDag`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskEdge {
    /// Kind of dependency this edge represents.
    pub edge_type: EdgeType,
    /// Optional condition expression, stored verbatim and not evaluated in this file.
    /// NOTE(review): presumably only meaningful for `EdgeType::Conditional` — confirm.
    pub condition: Option<String>,
}
105
/// Classification of a dependency edge.
///
/// `WorkflowDag::summary` counts edges per variant. Beyond that, this file does not act on
/// the variant, so the meanings below are inferred from the names and should be confirmed
/// against the executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeType {
    /// Presumably: the downstream task consumes the upstream task's output.
    Data,
    /// Ordering-only dependency; this is the type used by `TaskEdge::default()`.
    Control,
    /// Presumably: the dependency is gated by `TaskEdge::condition`.
    Conditional,
}
116
117impl Default for TaskEdge {
118 fn default() -> Self {
119 Self {
120 edge_type: EdgeType::Control,
121 condition: None,
122 }
123 }
124}
125
/// A workflow modeled as a directed graph of [`TaskNode`]s connected by [`TaskEdge`]s.
///
/// An edge `a -> b` means `b` depends on `a`. Acyclicity is not enforced when edges are
/// added; call `validate` to check it.
#[derive(Debug, Clone, Serialize, Deserialize)]

pub struct WorkflowDag {
    /// Underlying petgraph graph: node weights are tasks, edge weights are edge metadata.
    pub(crate) graph: DiGraph<TaskNode, TaskEdge>,
    /// Index from task id to the task's node in `graph`. `add_task` populates it and
    /// `remove_task` removes the removed task's entry.
    /// NOTE(review): petgraph's `Graph::remove_node` moves the last node into the vacated
    /// index, so any code that removes nodes from `graph` must repoint that node's entry here.
    pub(crate) task_map: HashMap<String, NodeIndex>,
}
135
136impl WorkflowDag {
137 pub fn new() -> Self {
139 Self {
140 graph: DiGraph::new(),
141 task_map: HashMap::new(),
142 }
143 }
144
145 pub fn add_task(&mut self, task: TaskNode) -> Result<NodeIndex> {
147 if self.task_map.contains_key(&task.id) {
148 return Err(
149 DagError::InvalidNode(format!("Task '{}' already exists in DAG", task.id)).into(),
150 );
151 }
152
153 let node_index = self.graph.add_node(task.clone());
154 self.task_map.insert(task.id.clone(), node_index);
155 Ok(node_index)
156 }
157
158 pub fn add_dependency(
160 &mut self,
161 from_task_id: &str,
162 to_task_id: &str,
163 edge: TaskEdge,
164 ) -> Result<()> {
165 let from_idx = self
166 .task_map
167 .get(from_task_id)
168 .ok_or_else(|| DagError::invalid_node(from_task_id))?;
169
170 let to_idx = self
171 .task_map
172 .get(to_task_id)
173 .ok_or_else(|| DagError::invalid_node(to_task_id))?;
174
175 self.graph.add_edge(*from_idx, *to_idx, edge);
176 Ok(())
177 }
178
179 pub fn get_task(&self, task_id: &str) -> Option<&TaskNode> {
181 self.task_map
182 .get(task_id)
183 .and_then(|idx| self.graph.node_weight(*idx))
184 }
185
186 pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut TaskNode> {
188 self.task_map
189 .get(task_id)
190 .and_then(|idx| self.graph.node_weight_mut(*idx))
191 }
192
193 pub fn get_dependencies(&self, task_id: &str) -> Vec<String> {
195 if let Some(&idx) = self.task_map.get(task_id) {
196 self.graph
197 .edges_directed(idx, Direction::Incoming)
198 .filter_map(|edge| {
199 self.graph
200 .node_weight(edge.source())
201 .map(|task| task.id.clone())
202 })
203 .collect()
204 } else {
205 Vec::new()
206 }
207 }
208
209 pub fn get_dependents(&self, task_id: &str) -> Vec<String> {
211 if let Some(&idx) = self.task_map.get(task_id) {
212 self.graph
213 .edges_directed(idx, Direction::Outgoing)
214 .filter_map(|edge| {
215 self.graph
216 .node_weight(edge.target())
217 .map(|task| task.id.clone())
218 })
219 .collect()
220 } else {
221 Vec::new()
222 }
223 }
224
225 pub fn validate(&self) -> Result<()> {
227 if self.graph.node_count() == 0 {
229 return Err(DagError::EmptyDag.into());
230 }
231
232 self.check_cycles()?;
234
235 self.check_reachability()?;
237
238 Ok(())
239 }
240
241 fn check_cycles(&self) -> Result<()> {
243 let mut visited = HashSet::new();
244 let mut rec_stack = HashSet::new();
245
246 for node_idx in self.graph.node_indices() {
247 if !visited.contains(&node_idx) {
248 if let Some(cycle_path) =
249 self.dfs_cycle_check(node_idx, &mut visited, &mut rec_stack)
250 {
251 return Err(DagError::cycle(cycle_path).into());
252 }
253 }
254 }
255
256 Ok(())
257 }
258
259 fn dfs_cycle_check(
261 &self,
262 node: NodeIndex,
263 visited: &mut HashSet<NodeIndex>,
264 rec_stack: &mut HashSet<NodeIndex>,
265 ) -> Option<String> {
266 visited.insert(node);
267 rec_stack.insert(node);
268
269 for neighbor in self.graph.neighbors(node) {
270 if !visited.contains(&neighbor) {
271 if let Some(path) = self.dfs_cycle_check(neighbor, visited, rec_stack) {
272 return Some(path);
273 }
274 } else if rec_stack.contains(&neighbor) {
275 let current_task = self.graph.node_weight(node).map(|t| &t.id)?;
277 let next_task = self.graph.node_weight(neighbor).map(|t| &t.id)?;
278 return Some(format!("{} -> {}", current_task, next_task));
279 }
280 }
281
282 rec_stack.remove(&node);
283 None
284 }
285
286 fn check_reachability(&self) -> Result<()> {
288 let root_nodes: Vec<NodeIndex> = self
290 .graph
291 .node_indices()
292 .filter(|&idx| self.graph.edges_directed(idx, Direction::Incoming).count() == 0)
293 .collect();
294
295 if root_nodes.is_empty() {
296 return Ok(());
298 }
299
300 let mut reachable = HashSet::new();
302 let mut queue = VecDeque::from(root_nodes);
303
304 while let Some(node) = queue.pop_front() {
305 if reachable.insert(node) {
306 for neighbor in self.graph.neighbors(node) {
307 if !reachable.contains(&neighbor) {
308 queue.push_back(neighbor);
309 }
310 }
311 }
312 }
313
314 for node_idx in self.graph.node_indices() {
316 if !reachable.contains(&node_idx) {
317 if let Some(task) = self.graph.node_weight(node_idx) {
318 return Err(DagError::UnreachableNode(task.id.clone()).into());
319 }
320 }
321 }
322
323 Ok(())
324 }
325
326 pub fn tasks(&self) -> Vec<&TaskNode> {
328 self.graph
329 .node_indices()
330 .filter_map(|idx| self.graph.node_weight(idx))
331 .collect()
332 }
333
334 pub fn task_count(&self) -> usize {
336 self.graph.node_count()
337 }
338
339 pub fn dependency_count(&self) -> usize {
341 self.graph.edge_count()
342 }
343
344 pub fn root_tasks(&self) -> Vec<&TaskNode> {
346 self.graph
347 .node_indices()
348 .filter(|&idx| self.graph.edges_directed(idx, Direction::Incoming).count() == 0)
349 .filter_map(|idx| self.graph.node_weight(idx))
350 .collect()
351 }
352
353 pub fn leaf_tasks(&self) -> Vec<&TaskNode> {
355 self.graph
356 .node_indices()
357 .filter(|&idx| self.graph.edges_directed(idx, Direction::Outgoing).count() == 0)
358 .filter_map(|idx| self.graph.node_weight(idx))
359 .collect()
360 }
361
362 pub fn edges(&self) -> Vec<(&str, &str, &TaskEdge)> {
367 self.graph
368 .edge_indices()
369 .filter_map(|edge_idx| {
370 let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
371 let from_node = self.graph.node_weight(from_idx)?;
372 let to_node = self.graph.node_weight(to_idx)?;
373 let edge = self.graph.edge_weight(edge_idx)?;
374 Some((from_node.id.as_str(), to_node.id.as_str(), edge))
375 })
376 .collect()
377 }
378
379 pub fn edge_pairs(&self) -> Vec<(String, String)> {
383 self.graph
384 .edge_indices()
385 .filter_map(|edge_idx| {
386 let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
387 let from_node = self.graph.node_weight(from_idx)?;
388 let to_node = self.graph.node_weight(to_idx)?;
389 Some((from_node.id.clone(), to_node.id.clone()))
390 })
391 .collect()
392 }
393
394 pub fn get_dependencies_with_edges(&self, task_id: &str) -> Vec<(String, &TaskEdge)> {
399 if let Some(&idx) = self.task_map.get(task_id) {
400 self.graph
401 .edges_directed(idx, Direction::Incoming)
402 .filter_map(|edge| {
403 let source_node = self.graph.node_weight(edge.source())?;
404 Some((source_node.id.clone(), edge.weight()))
405 })
406 .collect()
407 } else {
408 Vec::new()
409 }
410 }
411
412 pub fn get_dependents_with_edges(&self, task_id: &str) -> Vec<(String, &TaskEdge)> {
417 if let Some(&idx) = self.task_map.get(task_id) {
418 self.graph
419 .edges_directed(idx, Direction::Outgoing)
420 .filter_map(|edge| {
421 let target_node = self.graph.node_weight(edge.target())?;
422 Some((target_node.id.clone(), edge.weight()))
423 })
424 .collect()
425 } else {
426 Vec::new()
427 }
428 }
429
430 pub fn get_edge_between(&self, from_task_id: &str, to_task_id: &str) -> Option<&TaskEdge> {
434 let from_idx = self.task_map.get(from_task_id)?;
435 let to_idx = self.task_map.get(to_task_id)?;
436 self.graph
437 .find_edge(*from_idx, *to_idx)
438 .and_then(|edge_idx| self.graph.edge_weight(edge_idx))
439 }
440
441 pub fn has_dependency(&self, from_task_id: &str, to_task_id: &str) -> bool {
445 self.get_edge_between(from_task_id, to_task_id).is_some()
446 }
447
448 pub fn has_dependencies(&self, task_id: &str) -> bool {
450 if let Some(&idx) = self.task_map.get(task_id) {
451 self.graph.edges_directed(idx, Direction::Incoming).count() > 0
452 } else {
453 false
454 }
455 }
456
457 pub fn has_dependents(&self, task_id: &str) -> bool {
459 if let Some(&idx) = self.task_map.get(task_id) {
460 self.graph.edges_directed(idx, Direction::Outgoing).count() > 0
461 } else {
462 false
463 }
464 }
465
466 pub fn in_degree(&self, task_id: &str) -> usize {
468 if let Some(&idx) = self.task_map.get(task_id) {
469 self.graph.edges_directed(idx, Direction::Incoming).count()
470 } else {
471 0
472 }
473 }
474
475 pub fn out_degree(&self, task_id: &str) -> usize {
477 if let Some(&idx) = self.task_map.get(task_id) {
478 self.graph.edges_directed(idx, Direction::Outgoing).count()
479 } else {
480 0
481 }
482 }
483
484 pub fn task_ids(&self) -> Vec<String> {
486 self.task_map.keys().cloned().collect()
487 }
488
489 pub fn contains_task(&self, task_id: &str) -> bool {
491 self.task_map.contains_key(task_id)
492 }
493
494 pub fn remove_task(&mut self, task_id: &str) -> Option<TaskNode> {
498 let node_idx = self.task_map.remove(task_id)?;
499 self.graph.remove_node(node_idx)
500 }
501
502 pub fn edges_by_type(&self, edge_type: EdgeType) -> Vec<(&str, &str, &TaskEdge)> {
504 self.graph
505 .edge_indices()
506 .filter_map(|edge_idx| {
507 let edge = self.graph.edge_weight(edge_idx)?;
508 if edge.edge_type != edge_type {
509 return None;
510 }
511 let (from_idx, to_idx) = self.graph.edge_endpoints(edge_idx)?;
512 let from_node = self.graph.node_weight(from_idx)?;
513 let to_node = self.graph.node_weight(to_idx)?;
514 Some((from_node.id.as_str(), to_node.id.as_str(), edge))
515 })
516 .collect()
517 }
518
519 pub fn subgraph(&self, task_ids: &[&str]) -> WorkflowDag {
523 let mut sub = WorkflowDag::new();
524 let id_set: HashSet<&str> = task_ids.iter().copied().collect();
525
526 for task_id in task_ids {
528 if let Some(task) = self.get_task(task_id) {
529 let _ = sub.add_task(task.clone());
531 }
532 }
533
534 for (from_id, to_id, edge) in self.edges() {
536 if id_set.contains(from_id) && id_set.contains(to_id) {
537 let _ = sub.add_dependency(from_id, to_id, edge.clone());
538 }
539 }
540
541 sub
542 }
543
544 pub fn transitive_dependencies(&self, task_id: &str) -> Vec<String> {
549 let mut visited = HashSet::new();
550 let mut queue = VecDeque::new();
551
552 for dep in self.get_dependencies(task_id) {
554 if visited.insert(dep.clone()) {
555 queue.push_back(dep);
556 }
557 }
558
559 while let Some(current) = queue.pop_front() {
560 for dep in self.get_dependencies(¤t) {
561 if visited.insert(dep.clone()) {
562 queue.push_back(dep);
563 }
564 }
565 }
566
567 visited.into_iter().collect()
568 }
569
570 pub fn transitive_dependents(&self, task_id: &str) -> Vec<String> {
574 let mut visited = HashSet::new();
575 let mut queue = VecDeque::new();
576
577 for dep in self.get_dependents(task_id) {
579 if visited.insert(dep.clone()) {
580 queue.push_back(dep);
581 }
582 }
583
584 while let Some(current) = queue.pop_front() {
585 for dep in self.get_dependents(¤t) {
586 if visited.insert(dep.clone()) {
587 queue.push_back(dep);
588 }
589 }
590 }
591
592 visited.into_iter().collect()
593 }
594
595 pub fn summary(&self) -> DagSummary {
597 let node_count = self.graph.node_count();
598 let edge_count = self.graph.edge_count();
599 let root_count = self.root_tasks().len();
600 let leaf_count = self.leaf_tasks().len();
601
602 let max_in_degree = self
603 .graph
604 .node_indices()
605 .map(|idx| self.graph.edges_directed(idx, Direction::Incoming).count())
606 .max()
607 .unwrap_or(0);
608
609 let max_out_degree = self
610 .graph
611 .node_indices()
612 .map(|idx| self.graph.edges_directed(idx, Direction::Outgoing).count())
613 .max()
614 .unwrap_or(0);
615
616 let data_edges = self.edges_by_type(EdgeType::Data).len();
617 let control_edges = self.edges_by_type(EdgeType::Control).len();
618 let conditional_edges = self.edges_by_type(EdgeType::Conditional).len();
619
620 DagSummary {
621 node_count,
622 edge_count,
623 root_count,
624 leaf_count,
625 max_in_degree,
626 max_out_degree,
627 data_edge_count: data_edges,
628 control_edge_count: control_edges,
629 conditional_edge_count: conditional_edges,
630 }
631 }
632}
633
/// Aggregate statistics about a [`WorkflowDag`], produced by `WorkflowDag::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagSummary {
    /// Number of tasks.
    pub node_count: usize,
    /// Number of dependency edges (parallel edges counted separately).
    pub edge_count: usize,
    /// Number of tasks with no incoming edges.
    pub root_count: usize,
    /// Number of tasks with no outgoing edges.
    pub leaf_count: usize,
    /// Largest number of incoming edges on any single task (0 for an empty DAG).
    pub max_in_degree: usize,
    /// Largest number of outgoing edges on any single task (0 for an empty DAG).
    pub max_out_degree: usize,
    /// Number of `EdgeType::Data` edges.
    pub data_edge_count: usize,
    /// Number of `EdgeType::Control` edges.
    pub control_edge_count: usize,
    /// Number of `EdgeType::Conditional` edges.
    pub conditional_edge_count: usize,
}
656
657impl Default for WorkflowDag {
658 fn default() -> Self {
659 Self::new()
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task with the given id and name and default settings for every other field.
    fn create_test_task(id: &str, name: &str) -> TaskNode {
        TaskNode {
            id: id.to_string(),
            name: name.to_string(),
            description: None,
            config: serde_json::json!({}),
            retry: RetryPolicy::default(),
            timeout_secs: Some(60),
            resources: ResourceRequirements::default(),
            metadata: HashMap::new(),
        }
    }

    /// An unconditional `Data` edge.
    fn data_edge() -> TaskEdge {
        TaskEdge {
            edge_type: EdgeType::Data,
            condition: None,
        }
    }

    /// Adds one dependency, failing the test immediately if the DAG rejects it.
    fn add_dep(dag: &mut WorkflowDag, from: &str, to: &str, edge: TaskEdge) {
        assert!(
            dag.add_dependency(from, to, edge).is_ok(),
            "fixture: failed to add dependency {} -> {}",
            from,
            to
        );
    }

    /// Builds a DAG containing `task_ids` (each task named after its id) plus a default
    /// `Control` edge for every `(from, to)` pair in `deps`.
    ///
    /// Every setup step is asserted, so a broken fixture fails at the faulty step instead of
    /// being silently ignored.
    fn build_dag(task_ids: &[&str], deps: &[(&str, &str)]) -> WorkflowDag {
        let mut dag = WorkflowDag::new();
        for &id in task_ids {
            assert!(
                dag.add_task(create_test_task(id, id)).is_ok(),
                "fixture: failed to add task {}",
                id
            );
        }
        for &(from, to) in deps {
            add_dep(&mut dag, from, to, TaskEdge::default());
        }
        dag
    }

    #[test]
    fn test_add_task() {
        let mut dag = WorkflowDag::new();
        let result = dag.add_task(create_test_task("task1", "Task 1"));
        assert!(result.is_ok());
        assert_eq!(dag.task_count(), 1);
        assert!(dag.contains_task("task1"));
    }

    #[test]
    fn test_duplicate_task() {
        let mut dag = WorkflowDag::new();
        assert!(dag.add_task(create_test_task("task1", "Task 1")).is_ok());

        let result = dag.add_task(create_test_task("task1", "Task 1 Duplicate"));
        assert!(result.is_err());
        // The rejected duplicate must leave the original task untouched.
        assert_eq!(dag.task_count(), 1);
        assert_eq!(
            dag.get_task("task1").map(|t| t.name.as_str()),
            Some("Task 1")
        );
    }

    #[test]
    fn test_add_dependency() {
        let mut dag = build_dag(&["task1", "task2"], &[]);

        let result = dag.add_dependency("task1", "task2", TaskEdge::default());
        assert!(result.is_ok());
        assert_eq!(dag.dependency_count(), 1);

        // Unknown endpoints are rejected without adding an edge.
        assert!(dag
            .add_dependency("task1", "missing", TaskEdge::default())
            .is_err());
        assert!(dag
            .add_dependency("missing", "task2", TaskEdge::default())
            .is_err());
        assert_eq!(dag.dependency_count(), 1);
    }

    #[test]
    fn test_cycle_detection() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task1", "task2"), ("task2", "task3"), ("task3", "task1")],
        );
        assert!(dag.validate().is_err());
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let dag = build_dag(&["task1"], &[("task1", "task1")]);
        assert!(dag.validate().is_err());
    }

    #[test]
    fn test_empty_dag_is_invalid() {
        assert!(WorkflowDag::new().validate().is_err());
    }

    #[test]
    fn test_valid_dag() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task1", "task2"), ("task1", "task3")],
        );
        assert!(dag.validate().is_ok());
    }

    #[test]
    fn test_diamond_is_valid() {
        // `t4` is reached along two paths; seeing it a second time is not a cycle.
        let dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t2"), ("t1", "t3"), ("t2", "t4"), ("t3", "t4")],
        );
        assert!(dag.validate().is_ok());
    }

    #[test]
    fn test_root_and_leaf_tasks() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task1", "task2"), ("task2", "task3")],
        );

        let roots = dag.root_tasks();
        assert_eq!(roots.len(), 1);
        assert_eq!(roots[0].id, "task1");

        let leaves = dag.leaf_tasks();
        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0].id, "task3");
    }

    #[test]
    fn test_edges() {
        let mut dag = build_dag(&["t1", "t2", "t3"], &[("t1", "t2")]);
        add_dep(&mut dag, "t2", "t3", data_edge());

        let edges = dag.edges();
        assert_eq!(edges.len(), 2);

        // With no removals, edges come back in insertion order.
        let (from, to, edge) = &edges[0];
        assert_eq!(*from, "t1");
        assert_eq!(*to, "t2");
        assert_eq!(edge.edge_type, EdgeType::Control);

        let (from, to, edge) = &edges[1];
        assert_eq!(*from, "t2");
        assert_eq!(*to, "t3");
        assert_eq!(edge.edge_type, EdgeType::Data);
    }

    #[test]
    fn test_get_dependencies_with_edges() {
        let mut dag = build_dag(&["t1", "t2", "t3"], &[]);
        add_dep(&mut dag, "t1", "t3", data_edge());
        add_dep(&mut dag, "t2", "t3", TaskEdge::default());

        let deps = dag.get_dependencies_with_edges("t3");
        assert_eq!(deps.len(), 2);
        for (id, edge) in &deps {
            let expected = if id == "t1" {
                EdgeType::Data
            } else {
                EdgeType::Control
            };
            assert_eq!(edge.edge_type, expected, "edge type for dependency {}", id);
        }

        let dep_ids: Vec<&str> = deps.iter().map(|(id, _)| id.as_str()).collect();
        assert!(dep_ids.contains(&"t1"));
        assert!(dep_ids.contains(&"t2"));

        // A root task has no dependencies.
        assert!(dag.get_dependencies_with_edges("t1").is_empty());
        // An unknown task yields an empty list rather than an error.
        assert!(dag.get_dependencies_with_edges("nonexistent").is_empty());
    }

    #[test]
    fn test_get_dependents_with_edges() {
        let dag = build_dag(&["t1", "t2", "t3"], &[("t1", "t2"), ("t1", "t3")]);

        let dependents = dag.get_dependents_with_edges("t1");
        assert_eq!(dependents.len(), 2);

        let dep_ids: Vec<&str> = dependents.iter().map(|(id, _)| id.as_str()).collect();
        assert!(dep_ids.contains(&"t2"));
        assert!(dep_ids.contains(&"t3"));
    }

    #[test]
    fn test_get_edge_between() {
        let mut dag = build_dag(&["t1", "t2", "t3"], &[]);
        add_dep(
            &mut dag,
            "t1",
            "t2",
            TaskEdge {
                edge_type: EdgeType::Data,
                condition: Some("output.ready".to_string()),
            },
        );

        let edge = dag.get_edge_between("t1", "t2");
        assert!(edge.is_some());
        let edge = edge.expect("Edge should exist");
        assert_eq!(edge.edge_type, EdgeType::Data);
        assert_eq!(edge.condition.as_deref(), Some("output.ready"));

        // Edges are directed.
        assert!(dag.get_edge_between("t2", "t1").is_none());
        // Tasks that are not directly connected have no edge.
        assert!(dag.get_edge_between("t1", "t3").is_none());
    }

    #[test]
    fn test_has_dependency() {
        let dag = build_dag(&["t1", "t2"], &[("t1", "t2")]);

        assert!(dag.has_dependency("t1", "t2"));
        assert!(!dag.has_dependency("t2", "t1"));
        assert!(!dag.has_dependency("t1", "nonexistent"));
    }

    #[test]
    fn test_has_dependencies_and_dependents() {
        let dag = build_dag(&["t1", "t2", "t3"], &[("t1", "t2"), ("t2", "t3")]);

        // Root: dependents only.
        assert!(!dag.has_dependencies("t1"));
        assert!(dag.has_dependents("t1"));

        // Middle: both.
        assert!(dag.has_dependencies("t2"));
        assert!(dag.has_dependents("t2"));

        // Leaf: dependencies only.
        assert!(dag.has_dependencies("t3"));
        assert!(!dag.has_dependents("t3"));
    }

    #[test]
    fn test_in_out_degree() {
        // t1 and t2 both feed t3, which feeds t4.
        let dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t3"), ("t2", "t3"), ("t3", "t4")],
        );

        assert_eq!(dag.in_degree("t1"), 0);
        assert_eq!(dag.out_degree("t1"), 1);
        assert_eq!(dag.in_degree("t3"), 2);
        assert_eq!(dag.out_degree("t3"), 1);
        assert_eq!(dag.in_degree("t4"), 1);
        assert_eq!(dag.out_degree("t4"), 0);
        // Unknown tasks report zero.
        assert_eq!(dag.in_degree("nonexistent"), 0);
    }

    #[test]
    fn test_task_ids_and_contains() {
        let dag = build_dag(&["t1", "t2"], &[]);

        let mut ids = dag.task_ids();
        ids.sort();
        assert_eq!(ids, vec!["t1".to_string(), "t2".to_string()]);
        assert!(dag.contains_task("t1"));
        assert!(dag.contains_task("t2"));
        assert!(!dag.contains_task("t3"));
    }

    #[test]
    fn test_remove_task() {
        let mut dag = build_dag(&["t1", "t2"], &[("t1", "t2")]);

        assert_eq!(dag.task_count(), 2);
        assert_eq!(dag.dependency_count(), 1);

        let removed = dag.remove_task("t1");
        assert!(removed.is_some());
        assert_eq!(removed.as_ref().map(|t| t.id.as_str()), Some("t1"));
        assert!(!dag.contains_task("t1"));
        // Removing a task also drops its incident edges.
        assert_eq!(dag.task_count(), 1);
        assert_eq!(dag.dependency_count(), 0);

        // Removing an unknown task is a no-op.
        assert!(dag.remove_task("nonexistent").is_none());
    }

    #[test]
    fn test_edges_by_type() {
        let mut dag = build_dag(&["t1", "t2", "t3"], &[]);
        add_dep(&mut dag, "t1", "t2", data_edge());
        add_dep(&mut dag, "t1", "t3", TaskEdge::default());

        let data_edges = dag.edges_by_type(EdgeType::Data);
        assert_eq!(data_edges.len(), 1);
        assert_eq!(data_edges[0].0, "t1");
        assert_eq!(data_edges[0].1, "t2");

        let control_edges = dag.edges_by_type(EdgeType::Control);
        assert_eq!(control_edges.len(), 1);
        assert_eq!(control_edges[0].0, "t1");
        assert_eq!(control_edges[0].1, "t3");

        assert!(dag.edges_by_type(EdgeType::Conditional).is_empty());
    }

    #[test]
    fn test_subgraph() {
        let dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t2"), ("t2", "t3"), ("t3", "t4")],
        );

        // Only the edge with both endpoints inside the selection survives.
        let sub = dag.subgraph(&["t2", "t3"]);
        assert_eq!(sub.task_count(), 2);
        assert_eq!(sub.dependency_count(), 1);
        assert!(sub.has_dependency("t2", "t3"));
        assert!(sub.contains_task("t2"));
        assert!(sub.contains_task("t3"));
        assert!(!sub.contains_task("t1"));
        assert!(!sub.contains_task("t4"));
    }

    #[test]
    fn test_transitive_dependencies() {
        let dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t2"), ("t2", "t3"), ("t3", "t4")],
        );

        let trans_deps = dag.transitive_dependencies("t4");
        assert_eq!(trans_deps.len(), 3);
        assert!(trans_deps.contains(&"t1".to_string()));
        assert!(trans_deps.contains(&"t2".to_string()));
        assert!(trans_deps.contains(&"t3".to_string()));

        // A root task has no dependencies at all.
        assert!(dag.transitive_dependencies("t1").is_empty());
        assert!(dag.transitive_dependencies("nonexistent").is_empty());
    }

    #[test]
    fn test_transitive_dependents() {
        let dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t2"), ("t2", "t3"), ("t3", "t4")],
        );

        let trans_dependents = dag.transitive_dependents("t1");
        assert_eq!(trans_dependents.len(), 3);
        assert!(trans_dependents.contains(&"t2".to_string()));
        assert!(trans_dependents.contains(&"t3".to_string()));
        assert!(trans_dependents.contains(&"t4".to_string()));

        // A leaf task has no dependents at all.
        assert!(dag.transitive_dependents("t4").is_empty());
    }

    #[test]
    fn test_summary() {
        let mut dag = build_dag(
            &["t1", "t2", "t3", "t4"],
            &[("t1", "t3"), ("t2", "t4"), ("t3", "t4")],
        );
        add_dep(&mut dag, "t1", "t2", data_edge());

        let summary = dag.summary();
        assert_eq!(summary.node_count, 4);
        assert_eq!(summary.edge_count, 4);
        assert_eq!(summary.root_count, 1);
        assert_eq!(summary.leaf_count, 1);
        // t4 has two incoming edges.
        assert_eq!(summary.max_in_degree, 2);
        // t1 has two outgoing edges.
        assert_eq!(summary.max_out_degree, 2);
        assert_eq!(summary.data_edge_count, 1);
        assert_eq!(summary.control_edge_count, 3);
        assert_eq!(summary.conditional_edge_count, 0);
    }

    #[test]
    fn test_edge_pairs() {
        let dag = build_dag(&["t1", "t2"], &[("t1", "t2")]);

        let pairs = dag.edge_pairs();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], ("t1".to_string(), "t2".to_string()));
    }

    #[test]
    fn test_get_dependencies_and_dependents() {
        let dag = build_dag(&["t1", "t2", "t3"], &[("t1", "t3"), ("t2", "t3")]);

        let deps = dag.get_dependencies("t3");
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&"t1".to_string()));
        assert!(deps.contains(&"t2".to_string()));

        let dependents = dag.get_dependents("t1");
        assert_eq!(dependents.len(), 1);
        assert!(dependents.contains(&"t3".to_string()));
    }
}