Skip to main content

cortexai_crew/
task_manager.rs

1//! Task management and dependency resolution
2
3use futures::future::join_all;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6use cortexai_core::{errors::CrewError, Task, TaskResult};
7
8/// Task manager with dependency resolution
9pub struct TaskManager {
10    tasks: HashMap<String, Task>,
11    max_concurrency: usize,
12}
13
14impl TaskManager {
15    pub fn new(max_concurrency: usize) -> Self {
16        Self {
17            tasks: HashMap::new(),
18            max_concurrency,
19        }
20    }
21
22    /// Add a task
23    pub fn add_task(&mut self, task: Task) -> Result<(), CrewError> {
24        // Check for circular dependencies
25        if self.has_circular_dependency(&task)? {
26            return Err(CrewError::CircularDependency);
27        }
28
29        self.tasks.insert(task.id.clone(), task);
30        Ok(())
31    }
32
33    /// Get all tasks
34    pub fn get_all_tasks(&self) -> Vec<Task> {
35        self.tasks.values().cloned().collect()
36    }
37
38    /// Get task count
39    pub fn task_count(&self) -> usize {
40        self.tasks.len()
41    }
42
43    /// Execute tasks respecting dependencies
44    pub async fn execute_with_dependencies<F, Fut>(
45        &self,
46        executor: F,
47    ) -> Result<Vec<TaskResult>, CrewError>
48    where
49        F: Fn(Task) -> Fut + Clone + Send + 'static,
50        Fut: std::future::Future<Output = Result<TaskResult, CrewError>> + Send,
51    {
52        let mut results = Vec::new();
53        let mut completed = HashSet::new();
54        let mut in_progress = HashSet::new();
55
56        // Build dependency graph
57        let dep_graph = self.build_dependency_graph();
58
59        // Find tasks with no dependencies
60        let mut ready_queue: VecDeque<String> = self
61            .tasks
62            .values()
63            .filter(|task| task.dependencies.is_empty())
64            .map(|task| task.id.clone())
65            .collect();
66
67        while !ready_queue.is_empty() || !in_progress.is_empty() {
68            // Execute ready tasks in parallel (up to max_concurrency)
69            let mut batch = Vec::new();
70
71            while batch.len() < self.max_concurrency && !ready_queue.is_empty() {
72                if let Some(task_id) = ready_queue.pop_front() {
73                    if let Some(task) = self.tasks.get(&task_id) {
74                        batch.push(task.clone());
75                        in_progress.insert(task_id);
76                    }
77                }
78            }
79
80            if batch.is_empty() && !in_progress.is_empty() {
81                // Wait a bit for in-progress tasks
82                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
83                continue;
84            }
85
86            // Execute batch
87            let executor_clone = executor.clone();
88            let futures: Vec<_> = batch
89                .iter()
90                .map(|task| executor_clone(task.clone()))
91                .collect();
92
93            let batch_results = join_all(futures).await;
94
95            // Process results
96            for (idx, result) in batch_results.into_iter().enumerate() {
97                let task = &batch[idx];
98
99                match result {
100                    Ok(task_result) => {
101                        results.push(task_result);
102                        completed.insert(task.id.clone());
103                        in_progress.remove(&task.id);
104
105                        // Check if any dependent tasks are now ready
106                        for (dep_task_id, deps) in &dep_graph {
107                            if deps.iter().all(|d| completed.contains(d))
108                                && !completed.contains(dep_task_id)
109                                && !in_progress.contains(dep_task_id)
110                                && !ready_queue.contains(dep_task_id)
111                            {
112                                ready_queue.push_back(dep_task_id.clone());
113                            }
114                        }
115                    }
116                    Err(e) => {
117                        in_progress.remove(&task.id);
118                        return Err(e);
119                    }
120                }
121            }
122        }
123
124        Ok(results)
125    }
126
127    /// Build dependency graph
128    fn build_dependency_graph(&self) -> HashMap<String, Vec<String>> {
129        self.tasks
130            .values()
131            .filter(|task| !task.dependencies.is_empty())
132            .map(|task| (task.id.clone(), task.dependencies.clone()))
133            .collect()
134    }
135
136    /// Check for circular dependencies
137    fn has_circular_dependency(&self, new_task: &Task) -> Result<bool, CrewError> {
138        let mut visited = HashSet::new();
139        let mut stack = vec![new_task.id.clone()];
140
141        while let Some(task_id) = stack.pop() {
142            if visited.contains(&task_id) {
143                return Ok(true); // Circular dependency detected
144            }
145
146            visited.insert(task_id.clone());
147
148            // Get dependencies
149            let deps = if task_id == new_task.id {
150                &new_task.dependencies
151            } else if let Some(task) = self.tasks.get(&task_id) {
152                &task.dependencies
153            } else {
154                continue;
155            };
156
157            for dep_id in deps {
158                if dep_id == &new_task.id {
159                    return Ok(true); // Points back to new task
160                }
161                stack.push(dep_id.clone());
162            }
163        }
164
165        Ok(false)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tasks that depend on each other must not both be accepted.
    #[test]
    fn test_circular_dependency_detection() {
        let mut manager = TaskManager::new(4);

        let first = {
            let mut t = Task::new("Task 1").with_dependencies(vec!["task2".to_string()]);
            t.id = "task1".to_string();
            t
        };
        let second = {
            let mut t = Task::new("Task 2").with_dependencies(vec!["task1".to_string()]);
            t.id = "task2".to_string();
            t
        };

        // Referencing a not-yet-registered id is allowed...
        manager.add_task(first).unwrap();
        // ...but task2 closes the loop task1 -> task2 -> task1 and must be rejected.
        assert!(manager.add_task(second).is_err());
    }
}
188}