1use crate::error::{AgentError, Result};
4use crate::models::AgentTask;
5use std::collections::{HashMap, HashSet};
6
/// The result of scheduling: tasks grouped into ordered execution phases.
///
/// Phases are listed in the order they must run. As built by
/// `AgentScheduler::schedule`, every task in a phase depends only on tasks
/// that appear in earlier phases.
#[derive(Debug, Clone)]
pub struct ExecutionSchedule {
    /// Phases in execution order (empty when there were no tasks).
    pub phases: Vec<ExecutionPhase>,
}
13
/// One step of an [`ExecutionSchedule`].
///
/// When produced by `AgentScheduler`, no task in a phase depends on another
/// task in the same phase; all of its dependencies were placed in earlier
/// phases.
#[derive(Debug, Clone)]
pub struct ExecutionPhase {
    /// The tasks scheduled in this phase.
    pub tasks: Vec<AgentTask>,
}
20
/// A dependency declaration for a single task.
///
/// Presumably `task_id` depends on every ID in `depends_on`, mirroring the
/// argument order of `TaskDAG::add_dependency` — TODO confirm.
///
/// NOTE(review): nothing in this file constructs or reads this type; the DAG
/// receives edges only through `TaskDAG::add_dependency`. Verify whether it is
/// used elsewhere or should be wired into `AgentScheduler::resolve_dependencies`.
#[derive(Debug, Clone)]
pub struct TaskDependency {
    /// ID of the task that declares the dependencies.
    pub task_id: String,
    /// IDs of the tasks that `task_id` depends on.
    pub depends_on: Vec<String>,
}
29
30#[derive(Debug, Clone)]
32pub struct TaskDAG {
33 pub dependencies: HashMap<String, Vec<String>>,
35 pub dependents: HashMap<String, Vec<String>>,
37 pub tasks: HashMap<String, AgentTask>,
39}
40
41impl TaskDAG {
42 pub fn new() -> Self {
44 Self {
45 dependencies: HashMap::new(),
46 dependents: HashMap::new(),
47 tasks: HashMap::new(),
48 }
49 }
50
51 pub fn add_task(&mut self, task: AgentTask) {
53 let task_id = task.id.clone();
54 self.tasks.insert(task_id.clone(), task);
55 self.dependencies.entry(task_id.clone()).or_default();
56 self.dependents.entry(task_id).or_default();
57 }
58
59 pub fn add_dependency(&mut self, task_id: String, depends_on: String) {
61 self.dependencies
62 .entry(task_id.clone())
63 .or_default()
64 .push(depends_on.clone());
65
66 self.dependents.entry(depends_on).or_default().push(task_id);
67 }
68
69 pub fn get_root_tasks(&self) -> Vec<String> {
71 self.dependencies
72 .iter()
73 .filter(|(_, deps)| deps.is_empty())
74 .map(|(id, _)| id.clone())
75 .collect()
76 }
77
78 pub fn get_dependents(&self, task_id: &str) -> Vec<String> {
80 self.dependents.get(task_id).cloned().unwrap_or_default()
81 }
82
83 pub fn get_dependencies(&self, task_id: &str) -> Vec<String> {
85 self.dependencies.get(task_id).cloned().unwrap_or_default()
86 }
87}
88
89impl Default for TaskDAG {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95pub struct AgentScheduler;
97
98impl AgentScheduler {
99 pub fn new() -> Self {
101 Self
102 }
103
104 pub fn schedule(&self, tasks: &[AgentTask]) -> Result<ExecutionSchedule> {
106 let mut dag = TaskDAG::new();
108 for task in tasks {
109 dag.add_task(task.clone());
110 }
111
112 self.detect_circular_dependencies_in_dag(&dag)?;
114
115 let phases = self.create_execution_phases(&dag)?;
117
118 Ok(ExecutionSchedule { phases })
119 }
120
121 pub fn resolve_dependencies(&self, tasks: &[AgentTask]) -> Result<TaskDAG> {
123 let mut dag = TaskDAG::new();
124
125 for task in tasks {
127 dag.add_task(task.clone());
128 }
129
130 Ok(dag)
134 }
135
136 pub fn detect_circular_dependencies(&self, tasks: &[AgentTask]) -> Result<()> {
138 let dag = self.resolve_dependencies(tasks)?;
139 self.detect_circular_dependencies_in_dag(&dag)
140 }
141
142 fn detect_circular_dependencies_in_dag(&self, dag: &TaskDAG) -> Result<()> {
144 let mut visited = HashSet::new();
145 let mut rec_stack = HashSet::new();
146
147 for task_id in dag.tasks.keys() {
148 if !visited.contains(task_id) {
149 self.dfs_detect_cycle(task_id, dag, &mut visited, &mut rec_stack)?;
150 }
151 }
152
153 Ok(())
154 }
155
156 #[allow(clippy::only_used_in_recursion)]
158 fn dfs_detect_cycle(
159 &self,
160 task_id: &str,
161 dag: &TaskDAG,
162 visited: &mut HashSet<String>,
163 rec_stack: &mut HashSet<String>,
164 ) -> Result<()> {
165 visited.insert(task_id.to_string());
166 rec_stack.insert(task_id.to_string());
167
168 let dependencies = dag.get_dependencies(task_id);
169 for dep_id in dependencies {
170 if !visited.contains(&dep_id) {
171 self.dfs_detect_cycle(&dep_id, dag, visited, rec_stack)?;
172 } else if rec_stack.contains(&dep_id) {
173 return Err(AgentError::invalid_input(format!(
174 "Circular dependency detected: {} -> {}",
175 task_id, dep_id
176 )));
177 }
178 }
179
180 rec_stack.remove(task_id);
181 Ok(())
182 }
183
184 fn create_execution_phases(&self, dag: &TaskDAG) -> Result<Vec<ExecutionPhase>> {
186 let mut phases = Vec::new();
187 let mut completed = HashSet::new();
188 let mut remaining: HashSet<String> = dag.tasks.keys().cloned().collect();
189
190 while !remaining.is_empty() {
191 let mut phase_tasks = Vec::new();
193
194 for task_id in remaining.iter() {
195 let dependencies = dag.get_dependencies(task_id);
196 if dependencies.iter().all(|dep| completed.contains(dep)) {
197 phase_tasks.push(task_id.clone());
198 }
199 }
200
201 if phase_tasks.is_empty() {
202 return Err(AgentError::invalid_input(
204 "Unable to create execution phases: no executable tasks found".to_string(),
205 ));
206 }
207
208 let phase = ExecutionPhase {
210 tasks: phase_tasks
211 .iter()
212 .filter_map(|id| dag.tasks.get(id).cloned())
213 .collect(),
214 };
215
216 phases.push(phase);
217
218 for task_id in phase_tasks {
220 completed.insert(task_id.clone());
221 remaining.remove(&task_id);
222 }
223 }
224
225 Ok(phases)
226 }
227}
228
229impl Default for AgentScheduler {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{TaskOptions, TaskScope, TaskTarget, TaskType};
    use std::path::PathBuf;

    /// Builds a minimal single-file code-review task with the given ID.
    fn create_test_task(id: &str) -> AgentTask {
        AgentTask {
            id: id.to_string(),
            task_type: TaskType::CodeReview,
            target: TaskTarget {
                files: vec![PathBuf::from("test.rs")],
                scope: TaskScope::File,
            },
            options: TaskOptions::default(),
        }
    }

    #[test]
    fn test_schedule_single_task() {
        let scheduler = AgentScheduler::new();
        let tasks = vec![create_test_task("task1")];

        let schedule = scheduler.schedule(&tasks).unwrap();
        assert_eq!(schedule.phases.len(), 1);
        assert_eq!(schedule.phases[0].tasks.len(), 1);
        assert_eq!(schedule.phases[0].tasks[0].id, "task1");
    }

    #[test]
    fn test_schedule_multiple_tasks() {
        let scheduler = AgentScheduler::new();
        let tasks = vec![
            create_test_task("task1"),
            create_test_task("task2"),
            create_test_task("task3"),
        ];

        let schedule = scheduler.schedule(&tasks).unwrap();
        assert_eq!(schedule.phases.len(), 1);
        assert_eq!(schedule.phases[0].tasks.len(), 3);
    }

    #[test]
    fn test_resolve_dependencies() {
        let scheduler = AgentScheduler::new();
        let tasks = vec![create_test_task("task1"), create_test_task("task2")];

        let dag = scheduler.resolve_dependencies(&tasks).unwrap();
        assert_eq!(dag.tasks.len(), 2);
        assert!(dag.tasks.contains_key("task1"));
        assert!(dag.tasks.contains_key("task2"));
    }

    #[test]
    fn test_detect_circular_dependencies() {
        let scheduler = AgentScheduler::new();
        let tasks = vec![create_test_task("task1")];

        let result = scheduler.detect_circular_dependencies(&tasks);
        assert!(result.is_ok());
    }

    #[test]
    fn test_task_dag_add_task() {
        let mut dag = TaskDAG::new();
        let task = create_test_task("task1");

        dag.add_task(task.clone());

        assert_eq!(dag.tasks.len(), 1);
        assert!(dag.tasks.contains_key("task1"));
        assert!(dag.dependencies.contains_key("task1"));
        assert!(dag.dependents.contains_key("task1"));
    }

    #[test]
    fn test_task_dag_add_dependency() {
        let mut dag = TaskDAG::new();
        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));

        dag.add_dependency("task2".to_string(), "task1".to_string());

        assert_eq!(dag.get_dependencies("task2"), vec!["task1"]);
        assert_eq!(dag.get_dependents("task1"), vec!["task2"]);
    }

    #[test]
    fn test_task_dag_get_root_tasks() {
        let mut dag = TaskDAG::new();
        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));

        dag.add_dependency("task2".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task1".to_string());

        let root_tasks = dag.get_root_tasks();
        assert_eq!(root_tasks.len(), 1);
        assert_eq!(root_tasks[0], "task1");
    }

    #[test]
    fn test_task_dag_multiple_root_tasks() {
        let mut dag = TaskDAG::new();
        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));

        dag.add_dependency("task3".to_string(), "task1".to_string());

        let root_tasks = dag.get_root_tasks();
        assert_eq!(root_tasks.len(), 2);
        assert!(root_tasks.contains(&"task1".to_string()));
        assert!(root_tasks.contains(&"task2".to_string()));
    }

    #[test]
    fn test_task_dag_lookups_for_unknown_task_are_empty() {
        let dag = TaskDAG::new();
        assert!(dag.get_dependencies("nope").is_empty());
        assert!(dag.get_dependents("nope").is_empty());
        assert!(dag.get_root_tasks().is_empty());
    }

    #[test]
    fn test_create_execution_phases_linear_dependency() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));

        dag.add_dependency("task2".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task2".to_string());

        let phases = scheduler.create_execution_phases(&dag).unwrap();

        assert_eq!(phases.len(), 3);
        assert_eq!(phases[0].tasks.len(), 1);
        assert_eq!(phases[0].tasks[0].id, "task1");
        assert_eq!(phases[1].tasks.len(), 1);
        assert_eq!(phases[1].tasks[0].id, "task2");
        assert_eq!(phases[2].tasks.len(), 1);
        assert_eq!(phases[2].tasks[0].id, "task3");
    }

    #[test]
    fn test_create_execution_phases_parallel_tasks() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));

        dag.add_dependency("task3".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task2".to_string());

        let phases = scheduler.create_execution_phases(&dag).unwrap();

        assert_eq!(phases.len(), 2);
        assert_eq!(phases[0].tasks.len(), 2);
        assert_eq!(phases[1].tasks.len(), 1);
        assert_eq!(phases[1].tasks[0].id, "task3");
    }

    #[test]
    fn test_create_execution_phases_diamond_dependency() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));
        dag.add_task(create_test_task("task4"));

        dag.add_dependency("task2".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task1".to_string());
        dag.add_dependency("task4".to_string(), "task2".to_string());
        dag.add_dependency("task4".to_string(), "task3".to_string());

        let phases = scheduler.create_execution_phases(&dag).unwrap();

        assert_eq!(phases.len(), 3);
        assert_eq!(phases[0].tasks.len(), 1);
        assert_eq!(phases[0].tasks[0].id, "task1");

        // Order within a phase is not asserted here; only membership.
        let middle: Vec<&str> = phases[1].tasks.iter().map(|t| t.id.as_str()).collect();
        assert_eq!(middle.len(), 2);
        assert!(middle.contains(&"task2"));
        assert!(middle.contains(&"task3"));

        assert_eq!(phases[2].tasks.len(), 1);
        assert_eq!(phases[2].tasks[0].id, "task4");
    }

    #[test]
    fn test_create_execution_phases_unknown_dependency_fails() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_dependency("task1".to_string(), "missing".to_string());

        assert!(scheduler.create_execution_phases(&dag).is_err());
    }

    #[test]
    fn test_detect_circular_dependency_simple() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));

        dag.add_dependency("task1".to_string(), "task2".to_string());
        dag.add_dependency("task2".to_string(), "task1".to_string());

        let result = scheduler.detect_circular_dependencies_in_dag(&dag);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Circular dependency"));
    }

    #[test]
    fn test_detect_circular_dependency_self_loop() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_dependency("task1".to_string(), "task1".to_string());

        let result = scheduler.detect_circular_dependencies_in_dag(&dag);
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_circular_dependency_complex() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));
        dag.add_task(create_test_task("task4"));

        dag.add_dependency("task2".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task2".to_string());
        dag.add_dependency("task1".to_string(), "task3".to_string());

        let result = scheduler.detect_circular_dependencies_in_dag(&dag);
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_circular_dependency_none_in_chain() {
        let scheduler = AgentScheduler::new();
        let mut dag = TaskDAG::new();

        dag.add_task(create_test_task("task1"));
        dag.add_task(create_test_task("task2"));
        dag.add_task(create_test_task("task3"));

        dag.add_dependency("task2".to_string(), "task1".to_string());
        dag.add_dependency("task3".to_string(), "task2".to_string());

        assert!(scheduler.detect_circular_dependencies_in_dag(&dag).is_ok());
    }

    #[test]
    fn test_schedule_with_no_tasks() {
        let scheduler = AgentScheduler::new();
        let tasks: Vec<AgentTask> = vec![];

        let schedule = scheduler.schedule(&tasks).unwrap();
        assert_eq!(schedule.phases.len(), 0);
    }

    #[test]
    fn test_task_dag_default() {
        let dag = TaskDAG::default();
        assert!(dag.tasks.is_empty());
        assert!(dag.dependencies.is_empty());
        assert!(dag.dependents.is_empty());
    }

    #[test]
    fn test_scheduler_default() {
        let _scheduler = AgentScheduler::default();
    }
}