eshanized_polaris_core/
dag.rs

1//! Directed Acyclic Graph (DAG) support for task dependencies.
2
3use crate::errors::{PolarisError, PolarisResult};
4use crate::task::{Task, TaskId, TaskStatus};
5use std::collections::{HashMap, HashSet, VecDeque};
6
7/// DAG executor for managing task dependencies
8#[derive(Debug)]
9pub struct DagExecutor {
10    tasks: HashMap<TaskId, Task>,
11    dependencies: HashMap<TaskId, Vec<TaskId>>,
12    dependents: HashMap<TaskId, Vec<TaskId>>,
13}
14
15impl DagExecutor {
16    /// Create a new DAG executor
17    pub fn new() -> Self {
18        Self {
19            tasks: HashMap::new(),
20            dependencies: HashMap::new(),
21            dependents: HashMap::new(),
22        }
23    }
24
25    /// Add a task to the DAG
26    pub fn add_task(&mut self, task: Task) -> PolarisResult<()> {
27        let task_id = task.id;
28
29        // Store dependencies
30        if !task.dependencies.is_empty() {
31            self.dependencies
32                .insert(task_id, task.dependencies.clone());
33
34            // Build reverse dependency map
35            for dep_id in &task.dependencies {
36                self.dependents
37                    .entry(*dep_id)
38                    .or_insert_with(Vec::new)
39                    .push(task_id);
40            }
41        }
42
43        self.tasks.insert(task_id, task);
44        Ok(())
45    }
46
47    /// Validate the DAG for cycles
48    pub fn validate(&self) -> PolarisResult<()> {
49        let mut visited = HashSet::new();
50        let mut rec_stack = HashSet::new();
51
52        for task_id in self.tasks.keys() {
53            if !visited.contains(task_id) {
54                if self.has_cycle(*task_id, &mut visited, &mut rec_stack)? {
55                    return Err(PolarisError::DagCycleDetected);
56                }
57            }
58        }
59
60        Ok(())
61    }
62
63    /// Check for cycles using DFS
64    fn has_cycle(
65        &self,
66        task_id: TaskId,
67        visited: &mut HashSet<TaskId>,
68        rec_stack: &mut HashSet<TaskId>,
69    ) -> PolarisResult<bool> {
70        visited.insert(task_id);
71        rec_stack.insert(task_id);
72
73        if let Some(deps) = self.dependencies.get(&task_id) {
74            for dep_id in deps {
75                if !visited.contains(dep_id) {
76                    if self.has_cycle(*dep_id, visited, rec_stack)? {
77                        return Ok(true);
78                    }
79                } else if rec_stack.contains(dep_id) {
80                    return Ok(true);
81                }
82            }
83        }
84
85        rec_stack.remove(&task_id);
86        Ok(false)
87    }
88
89    /// Get tasks ready to execute (no pending dependencies)
90    pub fn get_ready_tasks(&self) -> Vec<TaskId> {
91        self.tasks
92            .iter()
93            .filter(|(task_id, task)| {
94                task.status == TaskStatus::Pending && self.are_dependencies_complete(**task_id)
95            })
96            .map(|(task_id, _)| *task_id)
97            .collect()
98    }
99
100    /// Check if all dependencies are complete
101    fn are_dependencies_complete(&self, task_id: TaskId) -> bool {
102        if let Some(deps) = self.dependencies.get(&task_id) {
103            deps.iter().all(|dep_id| {
104                self.tasks
105                    .get(dep_id)
106                    .map(|t| t.status == TaskStatus::Completed)
107                    .unwrap_or(false)
108            })
109        } else {
110            true // No dependencies
111        }
112    }
113
114    /// Get topological ordering of tasks
115    pub fn topological_sort(&self) -> PolarisResult<Vec<TaskId>> {
116        let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
117        let mut queue = VecDeque::new();
118        let mut result = Vec::new();
119
120        // Calculate in-degrees
121        for task_id in self.tasks.keys() {
122            in_degree.insert(*task_id, 0);
123        }
124
125        for deps in self.dependencies.values() {
126            for dep_id in deps {
127                *in_degree.get_mut(dep_id).unwrap() += 1;
128            }
129        }
130
131        // Start with nodes that have no dependencies
132        for (task_id, &degree) in &in_degree {
133            if degree == 0 {
134                queue.push_back(*task_id);
135            }
136        }
137
138        // Process queue
139        while let Some(task_id) = queue.pop_front() {
140            result.push(task_id);
141
142            if let Some(dependents) = self.dependents.get(&task_id) {
143                for dependent_id in dependents {
144                    let degree = in_degree.get_mut(dependent_id).unwrap();
145                    *degree -= 1;
146                    if *degree == 0 {
147                        queue.push_back(*dependent_id);
148                    }
149                }
150            }
151        }
152
153        // If we didn't process all tasks, there's a cycle
154        if result.len() != self.tasks.len() {
155            return Err(PolarisError::DagCycleDetected);
156        }
157
158        Ok(result)
159    }
160
161    /// Update task status
162    pub fn update_task_status(&mut self, task_id: TaskId, status: TaskStatus) {
163        if let Some(task) = self.tasks.get_mut(&task_id) {
164            task.status = status;
165        }
166    }
167
168    /// Get task by ID
169    pub fn get_task(&self, task_id: TaskId) -> Option<&Task> {
170        self.tasks.get(&task_id)
171    }
172}
173
174impl Default for DagExecutor {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// A two-task chain must pass cycle validation.
    #[test]
    fn test_dag_simple() {
        let first = Task::new("task1", Bytes::new());
        let second = Task::new("task2", Bytes::new()).with_dependency(first.id);

        let mut dag = DagExecutor::new();
        dag.add_task(first).unwrap();
        dag.add_task(second).unwrap();

        assert!(dag.validate().is_ok());
    }

    /// Two tasks that depend on each other must be rejected by validation.
    #[test]
    fn test_dag_cycle_detection() {
        let mut first = Task::new("task1", Bytes::new());
        let mut second = Task::new("task2", Bytes::new());

        // Wire the tasks into a two-node cycle.
        let (first_id, second_id) = (first.id, second.id);
        first.dependencies.push(second_id);
        second.dependencies.push(first_id);

        let mut dag = DagExecutor::new();
        dag.add_task(first).unwrap();
        dag.add_task(second).unwrap();

        assert!(dag.validate().is_err());
    }

    /// Only the task without unmet dependencies is reported as ready.
    #[test]
    fn test_dag_ready_tasks() {
        let first = Task::new("task1", Bytes::new());
        let first_id = first.id;
        let second = Task::new("task2", Bytes::new()).with_dependency(first_id);

        let mut dag = DagExecutor::new();
        dag.add_task(first).unwrap();
        dag.add_task(second).unwrap();

        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], first_id);
    }

    /// A three-task chain must be ordered dependencies-first.
    #[test]
    fn test_dag_topological_sort() {
        let first = Task::new("task1", Bytes::new());
        let second = Task::new("task2", Bytes::new()).with_dependency(first.id);
        let third = Task::new("task3", Bytes::new()).with_dependency(second.id);
        let ids = [first.id, second.id, third.id];

        let mut dag = DagExecutor::new();
        for task in [first, second, third] {
            dag.add_task(task).unwrap();
        }

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        // task1 should come before task2, task2 before task3
        let index_of = |wanted| sorted.iter().position(|&id| id == wanted).unwrap();
        let pos1 = index_of(ids[0]);
        let pos2 = index_of(ids[1]);
        let pos3 = index_of(ids[2]);

        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    /// A task with two independent dependencies validates and sorts.
    #[test]
    fn test_dag_complex_dependencies() {
        let left = Task::new("task1", Bytes::new());
        let right = Task::new("task2", Bytes::new());
        let joined = Task::new("task3", Bytes::new())
            .with_dependency(left.id)
            .with_dependency(right.id);

        let mut dag = DagExecutor::new();
        dag.add_task(left).unwrap();
        dag.add_task(right).unwrap();
        dag.add_task(joined).unwrap();

        assert!(dag.validate().is_ok());

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);
    }
}