celers_core/
dag.rs

1//! Directed Acyclic Graph (DAG) support for task dependencies
2//!
3//! This module provides functionality for managing task dependencies and validating
4//! that task graphs are acyclic.
5//!
6//! # Example
7//!
8//! ```
9//! use celers_core::dag::{TaskDag, DagNode};
10//! use uuid::Uuid;
11//!
12//! let mut dag = TaskDag::new();
13//! let task1 = Uuid::new_v4();
14//! let task2 = Uuid::new_v4();
15//! let task3 = Uuid::new_v4();
16//!
17//! // Create a simple DAG: task1 -> task2 -> task3
18//! dag.add_node(task1, "task1");
19//! dag.add_node(task2, "task2");
20//! dag.add_node(task3, "task3");
21//!
22//! dag.add_dependency(task2, task1).unwrap();
23//! dag.add_dependency(task3, task2).unwrap();
24//!
25//! // Validate the DAG (no cycles)
26//! assert!(dag.validate().is_ok());
27//!
28//! // Get execution order
29//! let order = dag.topological_sort().unwrap();
30//! assert_eq!(order, vec![task1, task2, task3]);
31//! ```
32
33use crate::{CelersError, Result, TaskId};
34use serde::{Deserialize, Serialize};
35use std::collections::{HashMap, HashSet, VecDeque};
36
37/// A node in the task DAG
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DagNode {
40    /// Task ID
41    pub task_id: TaskId,
42
43    /// Task name for debugging
44    pub task_name: String,
45
46    /// Tasks that this task depends on (must complete before this task)
47    pub dependencies: HashSet<TaskId>,
48
49    /// Tasks that depend on this task (will execute after this task)
50    pub dependents: HashSet<TaskId>,
51}
52
53impl DagNode {
54    /// Create a new DAG node
55    #[must_use]
56    pub fn new(task_id: TaskId, task_name: impl Into<String>) -> Self {
57        Self {
58            task_id,
59            task_name: task_name.into(),
60            dependencies: HashSet::new(),
61            dependents: HashSet::new(),
62        }
63    }
64
65    /// Check if this node has any dependencies
66    #[inline]
67    #[must_use]
68    pub fn has_dependencies(&self) -> bool {
69        !self.dependencies.is_empty()
70    }
71
72    /// Check if this node has any dependents
73    #[inline]
74    #[must_use]
75    pub fn has_dependents(&self) -> bool {
76        !self.dependents.is_empty()
77    }
78
79    /// Check if this is a root node (no dependencies)
80    #[inline]
81    #[must_use]
82    pub fn is_root(&self) -> bool {
83        self.dependencies.is_empty()
84    }
85
86    /// Check if this is a leaf node (no dependents)
87    #[inline]
88    #[must_use]
89    pub fn is_leaf(&self) -> bool {
90        self.dependents.is_empty()
91    }
92}
93
94/// Directed Acyclic Graph for task dependencies
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct TaskDag {
97    /// All nodes in the DAG
98    nodes: HashMap<TaskId, DagNode>,
99}
100
101impl TaskDag {
102    /// Create a new empty DAG
103    #[must_use]
104    pub fn new() -> Self {
105        Self {
106            nodes: HashMap::new(),
107        }
108    }
109
110    /// Add a node to the DAG
111    pub fn add_node(&mut self, task_id: TaskId, task_name: impl Into<String>) {
112        self.nodes
113            .entry(task_id)
114            .or_insert_with(|| DagNode::new(task_id, task_name));
115    }
116
117    /// Add a dependency relationship: `task_id` depends on `depends_on`
118    ///
119    /// This means `depends_on` must complete before `task_id` can execute.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if either task is not found in the DAG or if adding this dependency would create a cycle.
124    pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> Result<()> {
125        // Ensure both nodes exist
126        if !self.nodes.contains_key(&task_id) {
127            return Err(CelersError::Configuration(format!(
128                "Task {task_id} not found in DAG"
129            )));
130        }
131        if !self.nodes.contains_key(&depends_on) {
132            return Err(CelersError::Configuration(format!(
133                "Dependency task {depends_on} not found in DAG"
134            )));
135        }
136
137        // Add the dependency
138        if let Some(node) = self.nodes.get_mut(&task_id) {
139            node.dependencies.insert(depends_on);
140        }
141
142        // Add the dependent
143        if let Some(node) = self.nodes.get_mut(&depends_on) {
144            node.dependents.insert(task_id);
145        }
146
147        // Validate no cycles were introduced
148        self.validate()?;
149
150        Ok(())
151    }
152
153    /// Remove a dependency relationship
154    pub fn remove_dependency(&mut self, task_id: TaskId, depends_on: TaskId) {
155        if let Some(node) = self.nodes.get_mut(&task_id) {
156            node.dependencies.remove(&depends_on);
157        }
158        if let Some(node) = self.nodes.get_mut(&depends_on) {
159            node.dependents.remove(&task_id);
160        }
161    }
162
163    /// Get a node by task ID
164    #[inline]
165    #[must_use]
166    pub fn get_node(&self, task_id: &TaskId) -> Option<&DagNode> {
167        self.nodes.get(task_id)
168    }
169
170    /// Get all root nodes (nodes with no dependencies)
171    #[inline]
172    #[must_use]
173    pub fn get_roots(&self) -> Vec<TaskId> {
174        self.nodes
175            .values()
176            .filter(|node| node.is_root())
177            .map(|node| node.task_id)
178            .collect()
179    }
180
181    /// Get all leaf nodes (nodes with no dependents)
182    #[inline]
183    #[must_use]
184    pub fn get_leaves(&self) -> Vec<TaskId> {
185        self.nodes
186            .values()
187            .filter(|node| node.is_leaf())
188            .map(|node| node.task_id)
189            .collect()
190    }
191
192    /// Get the dependencies of a task
193    #[inline]
194    #[must_use]
195    pub fn get_dependencies(&self, task_id: &TaskId) -> Option<Vec<TaskId>> {
196        self.nodes
197            .get(task_id)
198            .map(|node| node.dependencies.iter().copied().collect())
199    }
200
201    /// Get the dependents of a task
202    #[inline]
203    #[must_use]
204    pub fn get_dependents(&self, task_id: &TaskId) -> Option<Vec<TaskId>> {
205        self.nodes
206            .get(task_id)
207            .map(|node| node.dependents.iter().copied().collect())
208    }
209
210    /// Check if the DAG contains a cycle
211    fn has_cycle(&self) -> bool {
212        let mut visited = HashSet::new();
213        let mut rec_stack = HashSet::new();
214
215        for node_id in self.nodes.keys() {
216            if self.has_cycle_util(*node_id, &mut visited, &mut rec_stack) {
217                return true;
218            }
219        }
220
221        false
222    }
223
224    /// Helper function for cycle detection using DFS
225    fn has_cycle_util(
226        &self,
227        node_id: TaskId,
228        visited: &mut HashSet<TaskId>,
229        rec_stack: &mut HashSet<TaskId>,
230    ) -> bool {
231        if rec_stack.contains(&node_id) {
232            return true; // Cycle detected
233        }
234
235        if visited.contains(&node_id) {
236            return false; // Already visited this path
237        }
238
239        visited.insert(node_id);
240        rec_stack.insert(node_id);
241
242        if let Some(node) = self.nodes.get(&node_id) {
243            for &dep_id in &node.dependencies {
244                if self.has_cycle_util(dep_id, visited, rec_stack) {
245                    return true;
246                }
247            }
248        }
249
250        rec_stack.remove(&node_id);
251        false
252    }
253
254    /// Validate the DAG structure (check for cycles)
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if the DAG contains a cycle.
259    pub fn validate(&self) -> Result<()> {
260        if self.has_cycle() {
261            return Err(CelersError::Configuration(
262                "Task DAG contains a cycle".to_string(),
263            ));
264        }
265        Ok(())
266    }
267
268    /// Perform topological sort on the DAG
269    ///
270    /// Returns tasks in execution order (dependencies before dependents).
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if the DAG contains a cycle.
275    pub fn topological_sort(&self) -> Result<Vec<TaskId>> {
276        self.validate()?;
277
278        let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
279        let mut result = Vec::new();
280        let mut queue = VecDeque::new();
281
282        // Calculate in-degrees
283        for node in self.nodes.values() {
284            in_degree.insert(node.task_id, node.dependencies.len());
285            if node.is_root() {
286                queue.push_back(node.task_id);
287            }
288        }
289
290        // Process nodes in topological order
291        while let Some(task_id) = queue.pop_front() {
292            result.push(task_id);
293
294            if let Some(node) = self.nodes.get(&task_id) {
295                for &dependent_id in &node.dependents {
296                    if let Some(degree) = in_degree.get_mut(&dependent_id) {
297                        *degree -= 1;
298                        if *degree == 0 {
299                            queue.push_back(dependent_id);
300                        }
301                    }
302                }
303            }
304        }
305
306        // If not all nodes were processed, there's a cycle
307        if result.len() != self.nodes.len() {
308            return Err(CelersError::Configuration(
309                "Task DAG contains a cycle".to_string(),
310            ));
311        }
312
313        Ok(result)
314    }
315
316    /// Get the number of nodes in the DAG
317    #[inline]
318    #[must_use]
319    pub fn node_count(&self) -> usize {
320        self.nodes.len()
321    }
322
323    /// Get the number of edges in the DAG
324    #[inline]
325    #[must_use]
326    pub fn edge_count(&self) -> usize {
327        self.nodes
328            .values()
329            .map(|node| node.dependencies.len())
330            .sum()
331    }
332
333    /// Check if the DAG is empty
334    #[inline]
335    #[must_use]
336    pub fn is_empty(&self) -> bool {
337        self.nodes.is_empty()
338    }
339
340    /// Clear all nodes and edges
341    pub fn clear(&mut self) {
342        self.nodes.clear();
343    }
344}
345
346impl Default for TaskDag {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a DAG with one node per entry of `names` (each with a fresh
    /// random ID) and returns it together with the IDs, in the same order.
    fn dag_from_names(names: &[&str]) -> (TaskDag, Vec<TaskId>) {
        let mut dag = TaskDag::new();
        let ids: Vec<TaskId> = names.iter().map(|_| TaskId::new_v4()).collect();
        for (&id, &name) in ids.iter().zip(names) {
            dag.add_node(id, name);
        }
        (dag, ids)
    }

    /// Makes every ID in `ids` depend on its predecessor, forming a linear chain.
    fn link_chain(dag: &mut TaskDag, ids: &[TaskId]) {
        for pair in ids.windows(2) {
            dag.add_dependency(pair[1], pair[0]).unwrap();
        }
    }

    /// Index of `id` inside `order`; panics if it is missing.
    fn position_of(order: &[TaskId], id: TaskId) -> usize {
        order.iter().position(|&t| t == id).unwrap()
    }

    #[test]
    fn test_dag_basic() {
        let (dag, _ids) = dag_from_names(&["task1", "task2"]);

        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_dag_dependencies() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2"]);
        let (task1, task2) = (ids[0], ids[1]);
        dag.add_dependency(task2, task1).unwrap();

        assert_eq!(dag.edge_count(), 1);

        let deps = dag.get_dependencies(&task2).unwrap();
        assert_eq!(deps.len(), 1);
        assert!(deps.contains(&task1));

        let dependents = dag.get_dependents(&task1).unwrap();
        assert_eq!(dependents.len(), 1);
        assert!(dependents.contains(&task2));
    }

    #[test]
    fn test_dag_cycle_detection() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2", "task3"]);
        link_chain(&mut dag, &ids);

        // Closing the loop (task1 depending on task3) must be rejected.
        assert!(dag.add_dependency(ids[0], ids[2]).is_err());
    }

    #[test]
    fn test_dag_topological_sort() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2", "task3"]);
        link_chain(&mut dag, &ids);

        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 3);

        let pos1 = position_of(&order, ids[0]);
        let pos2 = position_of(&order, ids[1]);
        let pos3 = position_of(&order, ids[2]);

        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    #[test]
    fn test_dag_roots_and_leaves() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2", "task3"]);
        link_chain(&mut dag, &ids);

        let roots = dag.get_roots();
        assert_eq!(roots.len(), 1);
        assert!(roots.contains(&ids[0]));

        let leaves = dag.get_leaves();
        assert_eq!(leaves.len(), 1);
        assert!(leaves.contains(&ids[2]));
    }

    #[test]
    fn test_dag_remove_dependency() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2"]);
        let (task1, task2) = (ids[0], ids[1]);
        dag.add_dependency(task2, task1).unwrap();

        assert_eq!(dag.edge_count(), 1);

        dag.remove_dependency(task2, task1);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_dag_complex() {
        let (mut dag, ids) = dag_from_names(&["task1", "task2", "task3", "task4"]);
        let (task1, task2, task3, task4) = (ids[0], ids[1], ids[2], ids[3]);

        // task1 and task2 are independent roots, task3 joins them,
        // and task4 follows task3.
        dag.add_dependency(task3, task1).unwrap();
        dag.add_dependency(task3, task2).unwrap();
        dag.add_dependency(task4, task3).unwrap();

        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 4);

        let pos1 = position_of(&order, task1);
        let pos2 = position_of(&order, task2);
        let pos3 = position_of(&order, task3);
        let pos4 = position_of(&order, task4);

        assert!(pos1 < pos3);
        assert!(pos2 < pos3);
        assert!(pos3 < pos4);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Builds a DAG of `count` edge-free nodes named `task_0`, `task_1`, ...
        fn indexed_dag(count: usize) -> (TaskDag, Vec<TaskId>) {
            let mut dag = TaskDag::new();
            let ids: Vec<TaskId> = (0..count).map(|_| TaskId::new_v4()).collect();
            for (i, &id) in ids.iter().enumerate() {
                dag.add_node(id, format!("task_{i}"));
            }
            (dag, ids)
        }

        /// Builds a DAG of `count` nodes linked into the chain ids[0] -> ids[1] -> ...
        fn chained_dag(count: usize) -> (TaskDag, Vec<TaskId>) {
            let (mut dag, ids) = indexed_dag(count);
            link_chain(&mut dag, &ids);
            (dag, ids)
        }

        proptest! {
            #[test]
            fn test_dag_node_count_matches_added_nodes(count in 1usize..20) {
                let (dag, _ids) = indexed_dag(count);
                prop_assert_eq!(dag.node_count(), count);
            }

            #[test]
            fn test_dag_linear_chain_sorts_correctly(count in 2usize..15) {
                let (dag, ids) = chained_dag(count);

                let sorted = dag.topological_sort().unwrap();
                prop_assert_eq!(sorted.len(), count);

                // Every task has to be scheduled after its predecessor in the chain.
                for pair in ids.windows(2) {
                    let pos_parent = position_of(&sorted, pair[0]);
                    let pos_child = position_of(&sorted, pair[1]);
                    prop_assert!(pos_parent < pos_child);
                }
            }

            #[test]
            fn test_dag_validate_always_succeeds_for_acyclic(count in 2usize..10) {
                let (dag, _ids) = chained_dag(count);
                prop_assert!(dag.validate().is_ok());
            }

            #[test]
            fn test_dag_roots_have_no_dependencies(count in 2usize..10) {
                let (dag, _ids) = chained_dag(count);

                for root in dag.get_roots() {
                    let deps = dag.get_dependencies(&root).unwrap();
                    prop_assert_eq!(deps.len(), 0);
                }
            }

            #[test]
            fn test_dag_leaves_have_no_dependents(count in 2usize..10) {
                let (dag, _ids) = chained_dag(count);

                for leaf in dag.get_leaves() {
                    let dependents = dag.get_dependents(&leaf).unwrap();
                    prop_assert_eq!(dependents.len(), 0);
                }
            }

            #[test]
            fn test_dag_edge_count_matches_added_dependencies(node_count in 2usize..10) {
                let (dag, _ids) = chained_dag(node_count);
                prop_assert_eq!(dag.edge_count(), node_count - 1);
            }

            #[test]
            fn test_dag_remove_dependency_decreases_edge_count(node_count in 2usize..10) {
                let (mut dag, ids) = chained_dag(node_count);
                let initial_count = dag.edge_count();

                dag.remove_dependency(ids[1], ids[0]);

                prop_assert_eq!(dag.edge_count(), initial_count - 1);
            }
        }
    }
}