ricecoder_agents/
scheduler.rs

1//! Agent scheduler for managing execution order and parallelism
2
3use crate::error::{AgentError, Result};
4use crate::models::AgentTask;
5use std::collections::{HashMap, HashSet};
6
7/// Execution schedule for agents
8#[derive(Debug, Clone)]
9pub struct ExecutionSchedule {
10    /// Ordered list of execution phases
11    pub phases: Vec<ExecutionPhase>,
12}
13
14/// A phase of execution (can contain parallel tasks)
15#[derive(Debug, Clone)]
16pub struct ExecutionPhase {
17    /// Tasks to execute in parallel in this phase
18    pub tasks: Vec<AgentTask>,
19}
20
/// Dependency record for one task: its ID and the IDs it depends on.
#[derive(Debug, Clone)]
pub struct TaskDependency {
    /// ID of the task this record describes.
    pub task_id: String,
    /// IDs of the tasks that this task depends on.
    pub depends_on: Vec<String>,
}
29
30/// Directed Acyclic Graph (DAG) for task execution
31#[derive(Debug, Clone)]
32pub struct TaskDAG {
33    /// Map of task ID to its dependencies
34    pub dependencies: HashMap<String, Vec<String>>,
35    /// Map of task ID to tasks that depend on it
36    pub dependents: HashMap<String, Vec<String>>,
37    /// All tasks in the DAG
38    pub tasks: HashMap<String, AgentTask>,
39}
40
41impl TaskDAG {
42    /// Create a new empty DAG
43    pub fn new() -> Self {
44        Self {
45            dependencies: HashMap::new(),
46            dependents: HashMap::new(),
47            tasks: HashMap::new(),
48        }
49    }
50
51    /// Add a task to the DAG
52    pub fn add_task(&mut self, task: AgentTask) {
53        let task_id = task.id.clone();
54        self.tasks.insert(task_id.clone(), task);
55        self.dependencies.entry(task_id.clone()).or_default();
56        self.dependents.entry(task_id).or_default();
57    }
58
59    /// Add a dependency between tasks
60    pub fn add_dependency(&mut self, task_id: String, depends_on: String) {
61        self.dependencies
62            .entry(task_id.clone())
63            .or_default()
64            .push(depends_on.clone());
65
66        self.dependents.entry(depends_on).or_default().push(task_id);
67    }
68
69    /// Get tasks with no dependencies (can execute immediately)
70    pub fn get_root_tasks(&self) -> Vec<String> {
71        self.dependencies
72            .iter()
73            .filter(|(_, deps)| deps.is_empty())
74            .map(|(id, _)| id.clone())
75            .collect()
76    }
77
78    /// Get tasks that depend on a given task
79    pub fn get_dependents(&self, task_id: &str) -> Vec<String> {
80        self.dependents.get(task_id).cloned().unwrap_or_default()
81    }
82
83    /// Get dependencies for a task
84    pub fn get_dependencies(&self, task_id: &str) -> Vec<String> {
85        self.dependencies.get(task_id).cloned().unwrap_or_default()
86    }
87}
88
89impl Default for TaskDAG {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
/// Agent scheduler for managing execution order and parallelism.
///
/// The scheduler holds no state; it turns a list of agent tasks into a
/// phased [`ExecutionSchedule`].
#[derive(Debug, Clone)]
pub struct AgentScheduler;
97
98impl AgentScheduler {
99    /// Create a new agent scheduler
100    pub fn new() -> Self {
101        Self
102    }
103
104    /// Create an execution schedule from tasks
105    pub fn schedule(&self, tasks: &[AgentTask]) -> Result<ExecutionSchedule> {
106        // Build a DAG from tasks (currently no explicit dependencies)
107        let mut dag = TaskDAG::new();
108        for task in tasks {
109            dag.add_task(task.clone());
110        }
111
112        // Detect circular dependencies
113        self.detect_circular_dependencies_in_dag(&dag)?;
114
115        // Create execution phases based on dependencies
116        let phases = self.create_execution_phases(&dag)?;
117
118        Ok(ExecutionSchedule { phases })
119    }
120
121    /// Resolve task dependencies and create a DAG
122    pub fn resolve_dependencies(&self, tasks: &[AgentTask]) -> Result<TaskDAG> {
123        let mut dag = TaskDAG::new();
124
125        // Add all tasks to the DAG
126        for task in tasks {
127            dag.add_task(task.clone());
128        }
129
130        // Currently, tasks have no explicit dependencies
131        // This can be enhanced to parse dependencies from task options
132
133        Ok(dag)
134    }
135
136    /// Detect circular dependencies in the DAG
137    pub fn detect_circular_dependencies(&self, tasks: &[AgentTask]) -> Result<()> {
138        let dag = self.resolve_dependencies(tasks)?;
139        self.detect_circular_dependencies_in_dag(&dag)
140    }
141
142    /// Detect circular dependencies in a DAG using DFS
143    fn detect_circular_dependencies_in_dag(&self, dag: &TaskDAG) -> Result<()> {
144        let mut visited = HashSet::new();
145        let mut rec_stack = HashSet::new();
146
147        for task_id in dag.tasks.keys() {
148            if !visited.contains(task_id) {
149                self.dfs_detect_cycle(task_id, dag, &mut visited, &mut rec_stack)?;
150            }
151        }
152
153        Ok(())
154    }
155
156    /// DFS helper to detect cycles
157    #[allow(clippy::only_used_in_recursion)]
158    fn dfs_detect_cycle(
159        &self,
160        task_id: &str,
161        dag: &TaskDAG,
162        visited: &mut HashSet<String>,
163        rec_stack: &mut HashSet<String>,
164    ) -> Result<()> {
165        visited.insert(task_id.to_string());
166        rec_stack.insert(task_id.to_string());
167
168        let dependencies = dag.get_dependencies(task_id);
169        for dep_id in dependencies {
170            if !visited.contains(&dep_id) {
171                self.dfs_detect_cycle(&dep_id, dag, visited, rec_stack)?;
172            } else if rec_stack.contains(&dep_id) {
173                return Err(AgentError::invalid_input(format!(
174                    "Circular dependency detected: {} -> {}",
175                    task_id, dep_id
176                )));
177            }
178        }
179
180        rec_stack.remove(task_id);
181        Ok(())
182    }
183
184    /// Create execution phases from a DAG
185    fn create_execution_phases(&self, dag: &TaskDAG) -> Result<Vec<ExecutionPhase>> {
186        let mut phases = Vec::new();
187        let mut completed = HashSet::new();
188        let mut remaining: HashSet<String> = dag.tasks.keys().cloned().collect();
189
190        while !remaining.is_empty() {
191            // Find tasks that can execute in this phase (all dependencies completed)
192            let mut phase_tasks = Vec::new();
193
194            for task_id in remaining.iter() {
195                let dependencies = dag.get_dependencies(task_id);
196                if dependencies.iter().all(|dep| completed.contains(dep)) {
197                    phase_tasks.push(task_id.clone());
198                }
199            }
200
201            if phase_tasks.is_empty() {
202                // This shouldn't happen if circular dependency detection worked
203                return Err(AgentError::invalid_input(
204                    "Unable to create execution phases: no executable tasks found".to_string(),
205                ));
206            }
207
208            // Create phase with tasks that can execute in parallel
209            let phase = ExecutionPhase {
210                tasks: phase_tasks
211                    .iter()
212                    .filter_map(|id| dag.tasks.get(id).cloned())
213                    .collect(),
214            };
215
216            phases.push(phase);
217
218            // Mark tasks as completed
219            for task_id in phase_tasks {
220                completed.insert(task_id.clone());
221                remaining.remove(&task_id);
222            }
223        }
224
225        Ok(phases)
226    }
227}
228
229impl Default for AgentScheduler {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{TaskOptions, TaskScope, TaskTarget, TaskType};
    use std::path::PathBuf;

    /// Minimal code-review task over a single file, identified by `id`.
    fn create_test_task(id: &str) -> AgentTask {
        let target = TaskTarget {
            files: vec![PathBuf::from("test.rs")],
            scope: TaskScope::File,
        };
        AgentTask {
            id: id.to_string(),
            task_type: TaskType::CodeReview,
            target,
            options: TaskOptions::default(),
        }
    }

    /// One test task per ID, in the given order.
    fn tasks_named(ids: &[&str]) -> Vec<AgentTask> {
        ids.iter().map(|id| create_test_task(id)).collect()
    }

    /// DAG with `ids` as nodes and each `(task, depends_on)` pair as an edge.
    fn build_dag(ids: &[&str], edges: &[(&str, &str)]) -> TaskDAG {
        let mut dag = TaskDAG::new();
        for id in ids {
            dag.add_task(create_test_task(id));
        }
        for &(task, depends_on) in edges {
            dag.add_dependency(task.to_string(), depends_on.to_string());
        }
        dag
    }

    #[test]
    fn test_schedule_single_task() {
        let schedule = AgentScheduler::new()
            .schedule(&tasks_named(&["task1"]))
            .unwrap();

        assert_eq!(schedule.phases.len(), 1);
        let only_phase = &schedule.phases[0];
        assert_eq!(only_phase.tasks.len(), 1);
        assert_eq!(only_phase.tasks[0].id, "task1");
    }

    #[test]
    fn test_schedule_multiple_tasks() {
        let schedule = AgentScheduler::new()
            .schedule(&tasks_named(&["task1", "task2", "task3"]))
            .unwrap();

        assert_eq!(schedule.phases.len(), 1);
        assert_eq!(schedule.phases[0].tasks.len(), 3);
    }

    #[test]
    fn test_resolve_dependencies() {
        let dag = AgentScheduler::new()
            .resolve_dependencies(&tasks_named(&["task1", "task2"]))
            .unwrap();

        assert_eq!(dag.tasks.len(), 2);
        for id in ["task1", "task2"] {
            assert!(dag.tasks.contains_key(id));
        }
    }

    #[test]
    fn test_detect_circular_dependencies() {
        let outcome =
            AgentScheduler::new().detect_circular_dependencies(&tasks_named(&["task1"]));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_task_dag_add_task() {
        let mut dag = TaskDAG::new();
        dag.add_task(create_test_task("task1"));

        assert_eq!(dag.tasks.len(), 1);
        assert!(dag.tasks.contains_key("task1"));
        assert!(dag.dependencies.contains_key("task1"));
        assert!(dag.dependents.contains_key("task1"));
    }

    #[test]
    fn test_task_dag_add_dependency() {
        let dag = build_dag(&["task1", "task2"], &[("task2", "task1")]);

        assert_eq!(dag.get_dependencies("task2"), vec!["task1"]);
        assert_eq!(dag.get_dependents("task1"), vec!["task2"]);
    }

    #[test]
    fn test_task_dag_get_root_tasks() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task2", "task1"), ("task3", "task1")],
        );

        assert_eq!(dag.get_root_tasks(), vec!["task1".to_string()]);
    }

    #[test]
    fn test_task_dag_multiple_root_tasks() {
        let dag = build_dag(&["task1", "task2", "task3"], &[("task3", "task1")]);

        let mut roots = dag.get_root_tasks();
        roots.sort();
        assert_eq!(roots, vec!["task1".to_string(), "task2".to_string()]);
    }

    #[test]
    fn test_create_execution_phases_linear_dependency() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task2", "task1"), ("task3", "task2")],
        );

        let phases = AgentScheduler::new().create_execution_phases(&dag).unwrap();

        let layout: Vec<Vec<&str>> = phases
            .iter()
            .map(|phase| phase.tasks.iter().map(|t| t.id.as_str()).collect())
            .collect();
        assert_eq!(layout, vec![vec!["task1"], vec!["task2"], vec!["task3"]]);
    }

    #[test]
    fn test_create_execution_phases_parallel_tasks() {
        let dag = build_dag(
            &["task1", "task2", "task3"],
            &[("task3", "task1"), ("task3", "task2")],
        );

        let phases = AgentScheduler::new().create_execution_phases(&dag).unwrap();

        let sizes: Vec<usize> = phases.iter().map(|phase| phase.tasks.len()).collect();
        assert_eq!(sizes, vec![2, 1]);
        assert_eq!(phases[1].tasks[0].id, "task3");
    }

    #[test]
    fn test_detect_circular_dependency_simple() {
        let dag = build_dag(
            &["task1", "task2"],
            &[("task1", "task2"), ("task2", "task1")],
        );

        let result = AgentScheduler::new().detect_circular_dependencies_in_dag(&dag);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Circular dependency"));
    }

    #[test]
    fn test_detect_circular_dependency_self_loop() {
        let dag = build_dag(&["task1"], &[("task1", "task1")]);

        assert!(AgentScheduler::new()
            .detect_circular_dependencies_in_dag(&dag)
            .is_err());
    }

    #[test]
    fn test_detect_circular_dependency_complex() {
        // task1 depends on task3, which depends on task2, which depends on
        // task1: a three-node cycle. task4 is unrelated.
        let dag = build_dag(
            &["task1", "task2", "task3", "task4"],
            &[("task2", "task1"), ("task3", "task2"), ("task1", "task3")],
        );

        assert!(AgentScheduler::new()
            .detect_circular_dependencies_in_dag(&dag)
            .is_err());
    }

    #[test]
    fn test_schedule_with_no_tasks() {
        let no_tasks: Vec<AgentTask> = Vec::new();

        let schedule = AgentScheduler::new().schedule(&no_tasks).unwrap();
        assert!(schedule.phases.is_empty());
    }

    #[test]
    fn test_task_dag_default() {
        let dag = TaskDAG::default();

        assert!(dag.tasks.is_empty());
        assert!(dag.dependencies.is_empty());
        assert!(dag.dependents.is_empty());
    }

    #[test]
    fn test_scheduler_default() {
        // Construction through `Default` only needs to succeed.
        let _scheduler = AgentScheduler::default();
    }
}