1use crate::{CelersError, Result, TaskId};
34use serde::{Deserialize, Serialize};
35use std::collections::{HashMap, HashSet, VecDeque};
36
/// A single task in a [`TaskDag`], together with the edges that touch it.
///
/// Edges are stored in both directions: an edge "B depends on A" appears as
/// `A` in B's `dependencies` and as `B` in A's `dependents`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagNode {
    /// Identifier of the task this node represents.
    pub task_id: TaskId,

    /// Human-readable name given when the node was created.
    pub task_name: String,

    /// Tasks this node depends on; [`TaskDag::topological_sort`] orders them
    /// before this node.
    pub dependencies: HashSet<TaskId>,

    /// Tasks that list this node in their own `dependencies`.
    pub dependents: HashSet<TaskId>,
}
52
53impl DagNode {
54 #[must_use]
56 pub fn new(task_id: TaskId, task_name: impl Into<String>) -> Self {
57 Self {
58 task_id,
59 task_name: task_name.into(),
60 dependencies: HashSet::new(),
61 dependents: HashSet::new(),
62 }
63 }
64
65 #[inline]
67 #[must_use]
68 pub fn has_dependencies(&self) -> bool {
69 !self.dependencies.is_empty()
70 }
71
72 #[inline]
74 #[must_use]
75 pub fn has_dependents(&self) -> bool {
76 !self.dependents.is_empty()
77 }
78
79 #[inline]
81 #[must_use]
82 pub fn is_root(&self) -> bool {
83 self.dependencies.is_empty()
84 }
85
86 #[inline]
88 #[must_use]
89 pub fn is_leaf(&self) -> bool {
90 self.dependents.is_empty()
91 }
92}
93
/// A directed graph of task dependencies, keyed by [`TaskId`].
///
/// Edges added through [`TaskDag::add_dependency`] are checked for cycles.
/// The derived `Deserialize` performs no such check, so a deserialized graph
/// should be checked with [`TaskDag::validate`] before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskDag {
    /// Every node in the graph. Edges live on the nodes themselves, in both
    /// directions (`dependencies` and `dependents`).
    nodes: HashMap<TaskId, DagNode>,
}
100
101impl TaskDag {
102 #[must_use]
104 pub fn new() -> Self {
105 Self {
106 nodes: HashMap::new(),
107 }
108 }
109
110 pub fn add_node(&mut self, task_id: TaskId, task_name: impl Into<String>) {
112 self.nodes
113 .entry(task_id)
114 .or_insert_with(|| DagNode::new(task_id, task_name));
115 }
116
117 pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> Result<()> {
125 if !self.nodes.contains_key(&task_id) {
127 return Err(CelersError::Configuration(format!(
128 "Task {task_id} not found in DAG"
129 )));
130 }
131 if !self.nodes.contains_key(&depends_on) {
132 return Err(CelersError::Configuration(format!(
133 "Dependency task {depends_on} not found in DAG"
134 )));
135 }
136
137 if let Some(node) = self.nodes.get_mut(&task_id) {
139 node.dependencies.insert(depends_on);
140 }
141
142 if let Some(node) = self.nodes.get_mut(&depends_on) {
144 node.dependents.insert(task_id);
145 }
146
147 self.validate()?;
149
150 Ok(())
151 }
152
153 pub fn remove_dependency(&mut self, task_id: TaskId, depends_on: TaskId) {
155 if let Some(node) = self.nodes.get_mut(&task_id) {
156 node.dependencies.remove(&depends_on);
157 }
158 if let Some(node) = self.nodes.get_mut(&depends_on) {
159 node.dependents.remove(&task_id);
160 }
161 }
162
163 #[inline]
165 #[must_use]
166 pub fn get_node(&self, task_id: &TaskId) -> Option<&DagNode> {
167 self.nodes.get(task_id)
168 }
169
170 #[inline]
172 #[must_use]
173 pub fn get_roots(&self) -> Vec<TaskId> {
174 self.nodes
175 .values()
176 .filter(|node| node.is_root())
177 .map(|node| node.task_id)
178 .collect()
179 }
180
181 #[inline]
183 #[must_use]
184 pub fn get_leaves(&self) -> Vec<TaskId> {
185 self.nodes
186 .values()
187 .filter(|node| node.is_leaf())
188 .map(|node| node.task_id)
189 .collect()
190 }
191
192 #[inline]
194 #[must_use]
195 pub fn get_dependencies(&self, task_id: &TaskId) -> Option<Vec<TaskId>> {
196 self.nodes
197 .get(task_id)
198 .map(|node| node.dependencies.iter().copied().collect())
199 }
200
201 #[inline]
203 #[must_use]
204 pub fn get_dependents(&self, task_id: &TaskId) -> Option<Vec<TaskId>> {
205 self.nodes
206 .get(task_id)
207 .map(|node| node.dependents.iter().copied().collect())
208 }
209
210 fn has_cycle(&self) -> bool {
212 let mut visited = HashSet::new();
213 let mut rec_stack = HashSet::new();
214
215 for node_id in self.nodes.keys() {
216 if self.has_cycle_util(*node_id, &mut visited, &mut rec_stack) {
217 return true;
218 }
219 }
220
221 false
222 }
223
224 fn has_cycle_util(
226 &self,
227 node_id: TaskId,
228 visited: &mut HashSet<TaskId>,
229 rec_stack: &mut HashSet<TaskId>,
230 ) -> bool {
231 if rec_stack.contains(&node_id) {
232 return true; }
234
235 if visited.contains(&node_id) {
236 return false; }
238
239 visited.insert(node_id);
240 rec_stack.insert(node_id);
241
242 if let Some(node) = self.nodes.get(&node_id) {
243 for &dep_id in &node.dependencies {
244 if self.has_cycle_util(dep_id, visited, rec_stack) {
245 return true;
246 }
247 }
248 }
249
250 rec_stack.remove(&node_id);
251 false
252 }
253
254 pub fn validate(&self) -> Result<()> {
260 if self.has_cycle() {
261 return Err(CelersError::Configuration(
262 "Task DAG contains a cycle".to_string(),
263 ));
264 }
265 Ok(())
266 }
267
268 pub fn topological_sort(&self) -> Result<Vec<TaskId>> {
276 self.validate()?;
277
278 let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
279 let mut result = Vec::new();
280 let mut queue = VecDeque::new();
281
282 for node in self.nodes.values() {
284 in_degree.insert(node.task_id, node.dependencies.len());
285 if node.is_root() {
286 queue.push_back(node.task_id);
287 }
288 }
289
290 while let Some(task_id) = queue.pop_front() {
292 result.push(task_id);
293
294 if let Some(node) = self.nodes.get(&task_id) {
295 for &dependent_id in &node.dependents {
296 if let Some(degree) = in_degree.get_mut(&dependent_id) {
297 *degree -= 1;
298 if *degree == 0 {
299 queue.push_back(dependent_id);
300 }
301 }
302 }
303 }
304 }
305
306 if result.len() != self.nodes.len() {
308 return Err(CelersError::Configuration(
309 "Task DAG contains a cycle".to_string(),
310 ));
311 }
312
313 Ok(result)
314 }
315
316 #[inline]
318 #[must_use]
319 pub fn node_count(&self) -> usize {
320 self.nodes.len()
321 }
322
323 #[inline]
325 #[must_use]
326 pub fn edge_count(&self) -> usize {
327 self.nodes
328 .values()
329 .map(|node| node.dependencies.len())
330 .sum()
331 }
332
333 #[inline]
335 #[must_use]
336 pub fn is_empty(&self) -> bool {
337 self.nodes.is_empty()
338 }
339
340 pub fn clear(&mut self) {
342 self.nodes.clear();
343 }
344}
345
346impl Default for TaskDag {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a DAG with `count` fresh nodes named `task_{i}` and no edges.
    fn dag_with_nodes(count: usize) -> (TaskDag, Vec<TaskId>) {
        let mut dag = TaskDag::new();
        let ids: Vec<TaskId> = (0..count).map(|_| TaskId::new_v4()).collect();
        for (i, id) in ids.iter().enumerate() {
            dag.add_node(*id, format!("task_{i}"));
        }
        (dag, ids)
    }

    /// Builds a linear chain in which `ids[i]` depends on `ids[i - 1]`.
    fn linear_chain(count: usize) -> (TaskDag, Vec<TaskId>) {
        let (mut dag, ids) = dag_with_nodes(count);
        for pair in ids.windows(2) {
            dag.add_dependency(pair[1], pair[0]).unwrap();
        }
        (dag, ids)
    }

    /// Index of `id` within `order`; panics if the id is missing.
    fn position_of(order: &[TaskId], id: TaskId) -> usize {
        order
            .iter()
            .position(|&t| t == id)
            .expect("task missing from topological order")
    }

    #[test]
    fn test_dag_basic() {
        let (dag, _ids) = dag_with_nodes(2);

        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_dag_empty() {
        let mut dag = TaskDag::default();
        assert!(dag.is_empty());
        assert!(dag.validate().is_ok());
        assert!(dag.topological_sort().unwrap().is_empty());

        dag.add_node(TaskId::new_v4(), "task");
        assert!(!dag.is_empty());
        dag.clear();
        assert!(dag.is_empty());
    }

    #[test]
    fn test_dag_add_node_keeps_existing() {
        let mut dag = TaskDag::new();
        let id = TaskId::new_v4();

        dag.add_node(id, "first");
        dag.add_node(id, "second");

        assert_eq!(dag.node_count(), 1);
        assert_eq!(dag.get_node(&id).unwrap().task_name, "first");
    }

    #[test]
    fn test_dag_dependencies() {
        let (mut dag, ids) = dag_with_nodes(2);
        let (task1, task2) = (ids[0], ids[1]);
        dag.add_dependency(task2, task1).unwrap();

        assert_eq!(dag.edge_count(), 1);

        let deps = dag.get_dependencies(&task2).unwrap();
        assert_eq!(deps, vec![task1]);

        let dependents = dag.get_dependents(&task1).unwrap();
        assert_eq!(dependents, vec![task2]);
    }

    #[test]
    fn test_dag_unknown_task_rejected() {
        let (mut dag, ids) = dag_with_nodes(1);
        let unknown = TaskId::new_v4();

        assert!(dag.add_dependency(ids[0], unknown).is_err());
        assert!(dag.add_dependency(unknown, ids[0]).is_err());
        assert_eq!(dag.edge_count(), 0);
        assert!(dag.get_dependencies(&unknown).is_none());
    }

    #[test]
    fn test_dag_cycle_detection() {
        let (mut dag, ids) = linear_chain(3);

        let result = dag.add_dependency(ids[0], ids[2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_dag_self_dependency_rejected() {
        let (mut dag, ids) = dag_with_nodes(1);

        assert!(dag.add_dependency(ids[0], ids[0]).is_err());
    }

    #[test]
    fn test_dag_topological_sort() {
        let (dag, ids) = linear_chain(3);

        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 3);

        assert!(position_of(&order, ids[0]) < position_of(&order, ids[1]));
        assert!(position_of(&order, ids[1]) < position_of(&order, ids[2]));
    }

    #[test]
    fn test_dag_roots_and_leaves() {
        let (dag, ids) = linear_chain(3);

        assert_eq!(dag.get_roots(), vec![ids[0]]);
        assert_eq!(dag.get_leaves(), vec![ids[2]]);
    }

    #[test]
    fn test_dag_remove_dependency() {
        let (mut dag, ids) = linear_chain(2);
        assert_eq!(dag.edge_count(), 1);

        dag.remove_dependency(ids[1], ids[0]);
        assert_eq!(dag.edge_count(), 0);
        assert!(dag.get_dependents(&ids[0]).unwrap().is_empty());
    }

    #[test]
    fn test_dag_complex() {
        let (mut dag, ids) = dag_with_nodes(4);
        let (task1, task2, task3, task4) = (ids[0], ids[1], ids[2], ids[3]);

        // task3 waits on both task1 and task2; task4 waits on task3.
        dag.add_dependency(task3, task1).unwrap();
        dag.add_dependency(task3, task2).unwrap();
        dag.add_dependency(task4, task3).unwrap();

        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 4);

        assert!(position_of(&order, task1) < position_of(&order, task3));
        assert!(position_of(&order, task2) < position_of(&order, task3));
        assert!(position_of(&order, task3) < position_of(&order, task4));
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn test_dag_node_count_matches_added_nodes(count in 1usize..20) {
                let (dag, _ids) = dag_with_nodes(count);
                prop_assert_eq!(dag.node_count(), count);
            }

            #[test]
            fn test_dag_linear_chain_sorts_correctly(count in 2usize..15) {
                let (dag, ids) = linear_chain(count);

                let sorted = dag.topological_sort().unwrap();
                prop_assert_eq!(sorted.len(), count);

                for pair in ids.windows(2) {
                    prop_assert!(position_of(&sorted, pair[0]) < position_of(&sorted, pair[1]));
                }
            }

            #[test]
            fn test_dag_validate_always_succeeds_for_acyclic(count in 2usize..10) {
                let (dag, _ids) = linear_chain(count);
                prop_assert!(dag.validate().is_ok());
            }

            #[test]
            fn test_dag_roots_have_no_dependencies(count in 2usize..10) {
                let (dag, ids) = linear_chain(count);

                let roots = dag.get_roots();
                prop_assert_eq!(&roots, &vec![ids[0]]);
                for root in roots {
                    prop_assert!(dag.get_dependencies(&root).unwrap().is_empty());
                }
            }

            #[test]
            fn test_dag_leaves_have_no_dependents(count in 2usize..10) {
                let (dag, ids) = linear_chain(count);

                let leaves = dag.get_leaves();
                prop_assert_eq!(&leaves, &vec![ids[count - 1]]);
                for leaf in leaves {
                    prop_assert!(dag.get_dependents(&leaf).unwrap().is_empty());
                }
            }

            #[test]
            fn test_dag_edge_count_matches_added_dependencies(node_count in 2usize..10) {
                let (dag, _ids) = linear_chain(node_count);
                prop_assert_eq!(dag.edge_count(), node_count - 1);
            }

            #[test]
            fn test_dag_remove_dependency_decreases_edge_count(node_count in 2usize..10) {
                let (mut dag, ids) = linear_chain(node_count);
                let initial_count = dag.edge_count();

                dag.remove_dependency(ids[1], ids[0]);

                prop_assert_eq!(dag.edge_count(), initial_count - 1);
            }
        }
    }
}