cortexai_crew/
task_manager.rs1use futures::future::join_all;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6use cortexai_core::{errors::CrewError, Task, TaskResult};
7
8pub struct TaskManager {
10 tasks: HashMap<String, Task>,
11 max_concurrency: usize,
12}
13
14impl TaskManager {
15 pub fn new(max_concurrency: usize) -> Self {
16 Self {
17 tasks: HashMap::new(),
18 max_concurrency,
19 }
20 }
21
22 pub fn add_task(&mut self, task: Task) -> Result<(), CrewError> {
24 if self.has_circular_dependency(&task)? {
26 return Err(CrewError::CircularDependency);
27 }
28
29 self.tasks.insert(task.id.clone(), task);
30 Ok(())
31 }
32
33 pub fn get_all_tasks(&self) -> Vec<Task> {
35 self.tasks.values().cloned().collect()
36 }
37
38 pub fn task_count(&self) -> usize {
40 self.tasks.len()
41 }
42
43 pub async fn execute_with_dependencies<F, Fut>(
45 &self,
46 executor: F,
47 ) -> Result<Vec<TaskResult>, CrewError>
48 where
49 F: Fn(Task) -> Fut + Clone + Send + 'static,
50 Fut: std::future::Future<Output = Result<TaskResult, CrewError>> + Send,
51 {
52 let mut results = Vec::new();
53 let mut completed = HashSet::new();
54 let mut in_progress = HashSet::new();
55
56 let dep_graph = self.build_dependency_graph();
58
59 let mut ready_queue: VecDeque<String> = self
61 .tasks
62 .values()
63 .filter(|task| task.dependencies.is_empty())
64 .map(|task| task.id.clone())
65 .collect();
66
67 while !ready_queue.is_empty() || !in_progress.is_empty() {
68 let mut batch = Vec::new();
70
71 while batch.len() < self.max_concurrency && !ready_queue.is_empty() {
72 if let Some(task_id) = ready_queue.pop_front() {
73 if let Some(task) = self.tasks.get(&task_id) {
74 batch.push(task.clone());
75 in_progress.insert(task_id);
76 }
77 }
78 }
79
80 if batch.is_empty() && !in_progress.is_empty() {
81 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
83 continue;
84 }
85
86 let executor_clone = executor.clone();
88 let futures: Vec<_> = batch
89 .iter()
90 .map(|task| executor_clone(task.clone()))
91 .collect();
92
93 let batch_results = join_all(futures).await;
94
95 for (idx, result) in batch_results.into_iter().enumerate() {
97 let task = &batch[idx];
98
99 match result {
100 Ok(task_result) => {
101 results.push(task_result);
102 completed.insert(task.id.clone());
103 in_progress.remove(&task.id);
104
105 for (dep_task_id, deps) in &dep_graph {
107 if deps.iter().all(|d| completed.contains(d))
108 && !completed.contains(dep_task_id)
109 && !in_progress.contains(dep_task_id)
110 && !ready_queue.contains(dep_task_id)
111 {
112 ready_queue.push_back(dep_task_id.clone());
113 }
114 }
115 }
116 Err(e) => {
117 in_progress.remove(&task.id);
118 return Err(e);
119 }
120 }
121 }
122 }
123
124 Ok(results)
125 }
126
127 fn build_dependency_graph(&self) -> HashMap<String, Vec<String>> {
129 self.tasks
130 .values()
131 .filter(|task| !task.dependencies.is_empty())
132 .map(|task| (task.id.clone(), task.dependencies.clone()))
133 .collect()
134 }
135
136 fn has_circular_dependency(&self, new_task: &Task) -> Result<bool, CrewError> {
138 let mut visited = HashSet::new();
139 let mut stack = vec![new_task.id.clone()];
140
141 while let Some(task_id) = stack.pop() {
142 if visited.contains(&task_id) {
143 return Ok(true); }
145
146 visited.insert(task_id.clone());
147
148 let deps = if task_id == new_task.id {
150 &new_task.dependencies
151 } else if let Some(task) = self.tasks.get(&task_id) {
152 &task.dependencies
153 } else {
154 continue;
155 };
156
157 for dep_id in deps {
158 if dep_id == &new_task.id {
159 return Ok(true); }
161 stack.push(dep_id.clone());
162 }
163 }
164
165 Ok(false)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task via `Task::new(name)`, attaches `deps`, then overwrites its `id` so
    /// tests can refer to it by a fixed id.
    fn make_task(name: &str, id: &str, deps: &[&str]) -> Task {
        let mut task = Task::new(name)
            .with_dependencies(deps.iter().map(|d| d.to_string()).collect::<Vec<String>>());
        task.id = id.to_string();
        task
    }

    #[test]
    fn test_circular_dependency_detection() {
        let mut manager = TaskManager::new(4);

        manager
            .add_task(make_task("Task 1", "task1", &["task2"]))
            .unwrap();
        let result = manager.add_task(make_task("Task 2", "task2", &["task1"]));

        assert!(matches!(result, Err(CrewError::CircularDependency)));
        // The rejected task must not have been registered.
        assert_eq!(manager.task_count(), 1);
    }

    #[test]
    fn test_self_dependency_is_rejected() {
        let mut manager = TaskManager::new(4);

        let result = manager.add_task(make_task("Task 1", "task1", &["task1"]));

        assert!(matches!(result, Err(CrewError::CircularDependency)));
        assert_eq!(manager.task_count(), 0);
    }

    #[test]
    fn test_acyclic_dependency_is_accepted() {
        let mut manager = TaskManager::new(4);

        manager.add_task(make_task("Task 1", "task1", &[])).unwrap();
        manager
            .add_task(make_task("Task 2", "task2", &["task1"]))
            .unwrap();

        assert_eq!(manager.task_count(), 2);
    }
}