1use crate::task::{Task, TaskId};
29use std::collections::{HashMap, VecDeque};
30use std::sync::RwLock;
31use tokio::sync::broadcast;
32
33#[derive(Debug, Clone)]
35pub enum TaskEvent {
36 Spawned {
38 task_id: TaskId,
39 parent_id: Option<TaskId>,
40 },
41 Started(TaskId),
43 Progress { task_id: TaskId, message: String },
45 Completed { task_id: TaskId, result: TaskResult },
47 Failed { task_id: TaskId, error: String },
49 Killed(TaskId),
51 ChildSpawned { parent_id: TaskId, child_id: TaskId },
53}
54
55#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
57pub struct TaskResult {
58 pub task_id: TaskId,
60 pub output: Option<serde_json::Value>,
62 pub duration_ms: u64,
64}
65
66#[derive(Debug, thiserror::Error)]
68pub enum TaskManagerError {
69 #[error("Task not found: {0}")]
70 TaskNotFound(TaskId),
71
72 #[error("Task already completed: {0}")]
73 TaskAlreadyCompleted(TaskId),
74
75 #[error("Task {0} is still running")]
76 TaskStillRunning(TaskId),
77
78 #[error("Manager is shutdown")]
79 Shutdown,
80
81 #[error("Send error: {0}")]
82 SendError(String),
83}
84
85pub struct TaskManager {
87 tasks: RwLock<HashMap<TaskId, Task>>,
89 results: RwLock<HashMap<TaskId, TaskResult>>,
91 subscribers: RwLock<HashMap<TaskId, broadcast::Sender<TaskEvent>>>,
93 global_events: broadcast::Sender<TaskEvent>,
95 pending_queue: RwLock<VecDeque<TaskId>>,
97 shutdown: RwLock<bool>,
99}
100
101impl Default for TaskManager {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl TaskManager {
108 pub fn new() -> Self {
110 let (global_events, _) = broadcast::channel(1024);
111 Self {
112 tasks: RwLock::new(HashMap::new()),
113 results: RwLock::new(HashMap::new()),
114 subscribers: RwLock::new(HashMap::new()),
115 global_events,
116 pending_queue: RwLock::new(VecDeque::new()),
117 shutdown: RwLock::new(false),
118 }
119 }
120
121 pub fn spawn(&self, task: Task) -> TaskId {
125 if *self.shutdown.read().unwrap() {
126 return task.id;
128 }
129
130 let task_id = task.id;
131 let parent_id = task.parent_id;
132
133 let (tx, _) = broadcast::channel(64);
135 self.subscribers.write().unwrap().insert(task_id, tx);
136
137 self.tasks.write().unwrap().insert(task_id, task);
139
140 self.pending_queue.write().unwrap().push_back(task_id);
142
143 let _ = self
145 .global_events
146 .send(TaskEvent::Spawned { task_id, parent_id });
147
148 task_id
149 }
150
151 pub fn start(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
153 let mut tasks = self.tasks.write().unwrap();
154
155 let task = tasks
156 .get_mut(&task_id)
157 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
158
159 if task.status.is_terminal() {
160 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
161 }
162
163 task.start();
164
165 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
167 let _ = tx.send(TaskEvent::Started(task_id));
168 }
169
170 Ok(())
171 }
172
173 pub fn progress(
175 &self,
176 task_id: TaskId,
177 message: impl Into<String>,
178 ) -> Result<(), TaskManagerError> {
179 let tasks = self.tasks.read().unwrap();
180
181 if !tasks.contains_key(&task_id) {
182 return Err(TaskManagerError::TaskNotFound(task_id));
183 }
184
185 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
187 let _ = tx.send(TaskEvent::Progress {
188 task_id,
189 message: message.into(),
190 });
191 }
192
193 Ok(())
194 }
195
196 pub fn complete(
198 &self,
199 task_id: TaskId,
200 output: Option<serde_json::Value>,
201 ) -> Result<(), TaskManagerError> {
202 let mut tasks = self.tasks.write().unwrap();
203
204 let task = tasks
205 .get_mut(&task_id)
206 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
207
208 if task.status.is_terminal() {
209 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
210 }
211
212 let duration_ms = task.duration_ms().unwrap_or(0);
213 task.complete();
214
215 let result = TaskResult {
217 task_id,
218 output,
219 duration_ms,
220 };
221 self.results
222 .write()
223 .unwrap()
224 .insert(task_id, result.clone());
225
226 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
228 let _ = tx.send(TaskEvent::Completed { task_id, result });
229 }
230
231 Ok(())
232 }
233
234 pub fn fail(&self, task_id: TaskId, error: impl Into<String>) -> Result<(), TaskManagerError> {
236 let mut tasks = self.tasks.write().unwrap();
237
238 let task = tasks
239 .get_mut(&task_id)
240 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
241
242 if task.status.is_terminal() {
243 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
244 }
245
246 let error_msg = error.into();
247 task.fail(&error_msg);
248
249 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
251 let _ = tx.send(TaskEvent::Failed {
252 task_id,
253 error: error_msg,
254 });
255 }
256
257 Ok(())
258 }
259
260 pub fn kill(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
262 let mut tasks = self.tasks.write().unwrap();
263
264 let task = tasks
265 .get_mut(&task_id)
266 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
267
268 if task.status.is_terminal() {
269 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
270 }
271
272 task.kill();
273
274 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
276 let _ = tx.send(TaskEvent::Killed(task_id));
277 }
278
279 Ok(())
280 }
281
282 pub fn get(&self, task_id: TaskId) -> Option<Task> {
284 self.tasks.read().unwrap().get(&task_id).cloned()
285 }
286
287 pub fn get_result(&self, task_id: TaskId) -> Option<TaskResult> {
289 self.results.read().unwrap().get(&task_id).cloned()
290 }
291
292 pub fn is_terminal(&self, task_id: TaskId) -> bool {
294 self.tasks
295 .read()
296 .unwrap()
297 .get(&task_id)
298 .map(|t| t.is_terminal())
299 .unwrap_or(true)
300 }
301
302 pub fn subscribe(&self, task_id: TaskId) -> Option<broadcast::Receiver<TaskEvent>> {
307 self.subscribers
308 .read()
309 .unwrap()
310 .get(&task_id)
311 .map(|tx| tx.subscribe())
312 }
313
314 pub fn subscribe_all(&self) -> broadcast::Receiver<TaskEvent> {
316 self.global_events.subscribe()
317 }
318
319 pub async fn wait(&self, task_id: TaskId) -> Result<TaskResult, TaskManagerError> {
323 if let Some(result) = self.get_result(task_id) {
325 return Ok(result);
326 }
327
328 let mut rx = self
330 .subscribe(task_id)
331 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
332
333 while let Ok(event) = rx.recv().await {
335 match event {
336 TaskEvent::Completed {
337 task_id: id,
338 result,
339 } if id == task_id => {
340 return Ok(result);
341 }
342 TaskEvent::Failed { task_id: id, error } if id == task_id => {
343 return Err(TaskManagerError::SendError(error));
344 }
345 TaskEvent::Killed(id) if id == task_id => {
346 return Err(TaskManagerError::SendError("Task was killed".to_string()));
347 }
348 _ => {}
349 }
350 }
351
352 Err(TaskManagerError::Shutdown)
353 }
354
355 pub fn add_child(&self, parent_id: TaskId, child_id: TaskId) -> Result<(), TaskManagerError> {
357 let mut tasks = self.tasks.write().unwrap();
358
359 let parent = tasks
360 .get_mut(&parent_id)
361 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
362 parent.add_child(child_id);
363
364 if let Some(tx) = self.subscribers.read().unwrap().get(&parent_id) {
366 let _ = tx.send(TaskEvent::ChildSpawned {
367 parent_id,
368 child_id,
369 });
370 }
371
372 Ok(())
373 }
374
375 pub fn spawn_child(&self, parent_id: TaskId, task: Task) -> Result<TaskId, TaskManagerError> {
379 if !self.tasks.read().unwrap().contains_key(&parent_id) {
381 return Err(TaskManagerError::TaskNotFound(parent_id));
382 }
383
384 let child_id = self.spawn(task);
385 self.add_child(parent_id, child_id)?;
386 Ok(child_id)
387 }
388
389 pub async fn wait_children(
393 &self,
394 parent_id: TaskId,
395 ) -> Result<Vec<TaskResult>, TaskManagerError> {
396 let children: Vec<TaskId> = {
397 let tasks = self.tasks.read().unwrap();
398 let parent = tasks
399 .get(&parent_id)
400 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
401 parent.child_ids.clone()
402 };
403
404 let mut results = Vec::new();
405 for child_id in children {
406 match self.wait(child_id).await {
407 Ok(result) => results.push(result),
408 Err(TaskManagerError::TaskStillRunning(_)) => {
409 let result = self.wait(child_id).await?;
411 results.push(result);
412 }
413 Err(e) => return Err(e),
414 }
415 }
416
417 Ok(results)
418 }
419
420 pub fn get_children(&self, parent_id: TaskId) -> Result<Vec<TaskId>, TaskManagerError> {
422 let tasks = self.tasks.read().unwrap();
423 let parent = tasks
424 .get(&parent_id)
425 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
426 Ok(parent.child_ids.clone())
427 }
428
429 pub fn all_children_complete(&self, parent_id: TaskId) -> bool {
431 if let Some(children) = self
432 .tasks
433 .read()
434 .unwrap()
435 .get(&parent_id)
436 .map(|t| &t.child_ids)
437 {
438 children.iter().all(|id| self.is_terminal(*id))
439 } else {
440 true
441 }
442 }
443
444 pub fn pending_tasks(&self) -> Vec<TaskId> {
446 self.pending_queue.read().unwrap().iter().copied().collect()
447 }
448
449 pub fn pop_pending(&self) -> Option<TaskId> {
451 self.pending_queue.write().unwrap().pop_front()
452 }
453
454 pub fn shutdown(&self) {
456 *self.shutdown.write().unwrap() = true;
457 self.tasks.write().unwrap().clear();
458 self.results.write().unwrap().clear();
459 self.subscribers.write().unwrap().clear();
460 self.pending_queue.write().unwrap().clear();
461 }
462
463 pub fn all_tasks(&self) -> HashMap<TaskId, Task> {
465 self.tasks.read().unwrap().clone()
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskStatus;
    use serde_json::json;

    /// Happy path: a started task completes and `wait` returns its output.
    #[tokio::test]
    async fn test_task_manager_spawn_and_wait() {
        let mgr = TaskManager::new();
        let id = mgr.spawn(Task::tool("read", json!({"file_path": "test.txt"})));
        mgr.start(id).unwrap();

        let expected = json!({"success": true, "output": "file content"});
        mgr.complete(id, Some(expected.clone())).unwrap();

        let waited = mgr.wait(id).await.unwrap();
        assert_eq!(waited.output, Some(expected));
        assert!(mgr.is_terminal(id));
    }

    /// A failed task makes `wait` return an error.
    #[tokio::test]
    async fn test_task_manager_fail() {
        let mgr = TaskManager::new();
        let id = mgr.spawn(Task::tool("read", json!({"file_path": "nonexistent.txt"})));
        mgr.start(id).unwrap();
        mgr.fail(id, "File not found").unwrap();

        let outcome = mgr.wait(id).await;
        assert!(outcome.is_err());
    }

    /// Killing a running task leaves it terminal with `Killed` status.
    #[test]
    fn test_task_manager_kill() {
        let mgr = TaskManager::new();
        let id = mgr.spawn(Task::tool("bash", json!({"command": "sleep 100"})));
        mgr.start(id).unwrap();
        mgr.kill(id).unwrap();

        assert!(mgr.is_terminal(id));
        let snapshot = mgr.get(id).unwrap();
        assert_eq!(snapshot.status, TaskStatus::Killed);
    }
}