eshanized_polaris_core/
dag.rs1use crate::errors::{PolarisError, PolarisResult};
4use crate::task::{Task, TaskId, TaskStatus};
5use std::collections::{HashMap, HashSet, VecDeque};
6
7#[derive(Debug)]
9pub struct DagExecutor {
10 tasks: HashMap<TaskId, Task>,
11 dependencies: HashMap<TaskId, Vec<TaskId>>,
12 dependents: HashMap<TaskId, Vec<TaskId>>,
13}
14
15impl DagExecutor {
16 pub fn new() -> Self {
18 Self {
19 tasks: HashMap::new(),
20 dependencies: HashMap::new(),
21 dependents: HashMap::new(),
22 }
23 }
24
25 pub fn add_task(&mut self, task: Task) -> PolarisResult<()> {
27 let task_id = task.id;
28
29 if !task.dependencies.is_empty() {
31 self.dependencies
32 .insert(task_id, task.dependencies.clone());
33
34 for dep_id in &task.dependencies {
36 self.dependents
37 .entry(*dep_id)
38 .or_insert_with(Vec::new)
39 .push(task_id);
40 }
41 }
42
43 self.tasks.insert(task_id, task);
44 Ok(())
45 }
46
47 pub fn validate(&self) -> PolarisResult<()> {
49 let mut visited = HashSet::new();
50 let mut rec_stack = HashSet::new();
51
52 for task_id in self.tasks.keys() {
53 if !visited.contains(task_id) {
54 if self.has_cycle(*task_id, &mut visited, &mut rec_stack)? {
55 return Err(PolarisError::DagCycleDetected);
56 }
57 }
58 }
59
60 Ok(())
61 }
62
63 fn has_cycle(
65 &self,
66 task_id: TaskId,
67 visited: &mut HashSet<TaskId>,
68 rec_stack: &mut HashSet<TaskId>,
69 ) -> PolarisResult<bool> {
70 visited.insert(task_id);
71 rec_stack.insert(task_id);
72
73 if let Some(deps) = self.dependencies.get(&task_id) {
74 for dep_id in deps {
75 if !visited.contains(dep_id) {
76 if self.has_cycle(*dep_id, visited, rec_stack)? {
77 return Ok(true);
78 }
79 } else if rec_stack.contains(dep_id) {
80 return Ok(true);
81 }
82 }
83 }
84
85 rec_stack.remove(&task_id);
86 Ok(false)
87 }
88
89 pub fn get_ready_tasks(&self) -> Vec<TaskId> {
91 self.tasks
92 .iter()
93 .filter(|(task_id, task)| {
94 task.status == TaskStatus::Pending && self.are_dependencies_complete(**task_id)
95 })
96 .map(|(task_id, _)| *task_id)
97 .collect()
98 }
99
100 fn are_dependencies_complete(&self, task_id: TaskId) -> bool {
102 if let Some(deps) = self.dependencies.get(&task_id) {
103 deps.iter().all(|dep_id| {
104 self.tasks
105 .get(dep_id)
106 .map(|t| t.status == TaskStatus::Completed)
107 .unwrap_or(false)
108 })
109 } else {
110 true }
112 }
113
114 pub fn topological_sort(&self) -> PolarisResult<Vec<TaskId>> {
116 let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
117 let mut queue = VecDeque::new();
118 let mut result = Vec::new();
119
120 for task_id in self.tasks.keys() {
122 in_degree.insert(*task_id, 0);
123 }
124
125 for deps in self.dependencies.values() {
126 for dep_id in deps {
127 *in_degree.get_mut(dep_id).unwrap() += 1;
128 }
129 }
130
131 for (task_id, °ree) in &in_degree {
133 if degree == 0 {
134 queue.push_back(*task_id);
135 }
136 }
137
138 while let Some(task_id) = queue.pop_front() {
140 result.push(task_id);
141
142 if let Some(dependents) = self.dependents.get(&task_id) {
143 for dependent_id in dependents {
144 let degree = in_degree.get_mut(dependent_id).unwrap();
145 *degree -= 1;
146 if *degree == 0 {
147 queue.push_back(*dependent_id);
148 }
149 }
150 }
151 }
152
153 if result.len() != self.tasks.len() {
155 return Err(PolarisError::DagCycleDetected);
156 }
157
158 Ok(result)
159 }
160
161 pub fn update_task_status(&mut self, task_id: TaskId, status: TaskStatus) {
163 if let Some(task) = self.tasks.get_mut(&task_id) {
164 task.status = status;
165 }
166 }
167
168 pub fn get_task(&self, task_id: TaskId) -> Option<&Task> {
170 self.tasks.get(&task_id)
171 }
172}
173
174impl Default for DagExecutor {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn test_dag_simple() {
        let mut dag = DagExecutor::new();

        let task1 = Task::new("task1", Bytes::new());
        let task2 = Task::new("task2", Bytes::new()).with_dependency(task1.id);

        dag.add_task(task1).unwrap();
        dag.add_task(task2).unwrap();

        assert!(dag.validate().is_ok());
    }

    #[test]
    fn test_dag_cycle_detection() {
        let mut dag = DagExecutor::new();

        let mut task1 = Task::new("task1", Bytes::new());
        let mut task2 = Task::new("task2", Bytes::new());

        // Build task1 -> task2 -> task1 by editing the dependency lists directly.
        task1.dependencies.push(task2.id);
        task2.dependencies.push(task1.id);

        dag.add_task(task1).unwrap();
        dag.add_task(task2).unwrap();

        assert!(dag.validate().is_err());
        assert!(dag.topological_sort().is_err());
    }

    #[test]
    fn test_dag_self_dependency_is_cycle() {
        let mut dag = DagExecutor::new();

        let mut task = Task::new("task", Bytes::new());
        let own_id = task.id;
        task.dependencies.push(own_id);

        dag.add_task(task).unwrap();

        assert!(dag.validate().is_err());
    }

    #[test]
    fn test_dag_ready_tasks() {
        let mut dag = DagExecutor::new();

        let task1 = Task::new("task1", Bytes::new());
        let task2 = Task::new("task2", Bytes::new()).with_dependency(task1.id);

        dag.add_task(task1.clone()).unwrap();
        dag.add_task(task2).unwrap();

        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], task1.id);
    }

    #[test]
    fn test_dag_ready_tasks_after_completion() {
        let mut dag = DagExecutor::new();

        let task1 = Task::new("task1", Bytes::new());
        let task2 = Task::new("task2", Bytes::new()).with_dependency(task1.id);
        let (id1, id2) = (task1.id, task2.id);

        dag.add_task(task1).unwrap();
        dag.add_task(task2).unwrap();

        dag.update_task_status(id1, TaskStatus::Completed);

        // task1 is no longer pending; task2's only dependency is now complete.
        assert_eq!(dag.get_ready_tasks(), vec![id2]);
        assert!(dag
            .get_task(id1)
            .map_or(false, |t| t.status == TaskStatus::Completed));
    }

    #[test]
    fn test_dag_topological_sort() {
        let mut dag = DagExecutor::new();

        let task1 = Task::new("task1", Bytes::new());
        let task2 = Task::new("task2", Bytes::new()).with_dependency(task1.id);
        let task3 = Task::new("task3", Bytes::new()).with_dependency(task2.id);

        dag.add_task(task1.clone()).unwrap();
        dag.add_task(task2.clone()).unwrap();
        dag.add_task(task3.clone()).unwrap();

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let pos1 = sorted.iter().position(|&id| id == task1.id).unwrap();
        let pos2 = sorted.iter().position(|&id| id == task2.id).unwrap();
        let pos3 = sorted.iter().position(|&id| id == task3.id).unwrap();

        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    #[test]
    fn test_dag_complex_dependencies() {
        let mut dag = DagExecutor::new();

        let task1 = Task::new("task1", Bytes::new());
        let task2 = Task::new("task2", Bytes::new());
        let task3 = Task::new("task3", Bytes::new())
            .with_dependency(task1.id)
            .with_dependency(task2.id);
        let (id1, id2, id3) = (task1.id, task2.id, task3.id);

        dag.add_task(task1).unwrap();
        dag.add_task(task2).unwrap();
        dag.add_task(task3).unwrap();

        assert!(dag.validate().is_ok());

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        // task3 must come after both of its dependencies.
        let position = |id: TaskId| sorted.iter().position(|&x| x == id).unwrap();
        assert!(position(id1) < position(id3));
        assert!(position(id2) < position(id3));
    }
}