1use crate::task::{Task, TaskId};
29use std::collections::{HashMap, VecDeque};
30use std::sync::RwLock;
31use tokio::sync::broadcast;
32
33#[derive(Debug, Clone)]
35pub enum TaskEvent {
36 Spawned {
38 task_id: TaskId,
39 parent_id: Option<TaskId>,
40 },
41 Started(TaskId),
43 Progress { task_id: TaskId, message: String },
45 Completed { task_id: TaskId, result: TaskResult },
47 Failed { task_id: TaskId, error: String },
49 Killed(TaskId),
51 ChildSpawned { parent_id: TaskId, child_id: TaskId },
53}
54
55#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
57pub struct TaskResult {
58 pub task_id: TaskId,
60 pub success: bool,
62 pub output: Option<serde_json::Value>,
64 pub duration_ms: u64,
66}
67
68#[derive(Debug, thiserror::Error)]
70pub enum TaskManagerError {
71 #[error("Task not found: {0}")]
72 TaskNotFound(TaskId),
73
74 #[error("Task already completed: {0}")]
75 TaskAlreadyCompleted(TaskId),
76
77 #[error("Task {0} is still running")]
78 TaskStillRunning(TaskId),
79
80 #[error("Manager is shutdown")]
81 Shutdown,
82
83 #[error("Send error: {0}")]
84 SendError(String),
85}
86
87pub struct TaskManager {
89 tasks: RwLock<HashMap<TaskId, Task>>,
91 results: RwLock<HashMap<TaskId, TaskResult>>,
93 subscribers: RwLock<HashMap<TaskId, broadcast::Sender<TaskEvent>>>,
95 global_events: broadcast::Sender<TaskEvent>,
97 pending_queue: RwLock<VecDeque<TaskId>>,
99 shutdown: RwLock<bool>,
101}
102
103impl Default for TaskManager {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl TaskManager {
110 pub fn new() -> Self {
112 let (global_events, _) = broadcast::channel(1024);
113 Self {
114 tasks: RwLock::new(HashMap::new()),
115 results: RwLock::new(HashMap::new()),
116 subscribers: RwLock::new(HashMap::new()),
117 global_events,
118 pending_queue: RwLock::new(VecDeque::new()),
119 shutdown: RwLock::new(false),
120 }
121 }
122
123 pub fn spawn(&self, task: Task) -> TaskId {
127 if *self.shutdown.read().unwrap() {
128 return task.id;
130 }
131
132 let task_id = task.id;
133 let parent_id = task.parent_id;
134
135 let (tx, _) = broadcast::channel(64);
137 self.subscribers.write().unwrap().insert(task_id, tx);
138
139 self.tasks.write().unwrap().insert(task_id, task);
141
142 self.pending_queue.write().unwrap().push_back(task_id);
144
145 let _ = self
147 .global_events
148 .send(TaskEvent::Spawned { task_id, parent_id });
149
150 task_id
151 }
152
153 pub fn start(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
155 let mut tasks = self.tasks.write().unwrap();
156
157 let task = tasks
158 .get_mut(&task_id)
159 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
160
161 if task.status.is_terminal() {
162 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
163 }
164
165 task.start();
166
167 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
169 let _ = tx.send(TaskEvent::Started(task_id));
170 }
171
172 Ok(())
173 }
174
175 pub fn progress(
177 &self,
178 task_id: TaskId,
179 message: impl Into<String>,
180 ) -> Result<(), TaskManagerError> {
181 let tasks = self.tasks.read().unwrap();
182
183 if !tasks.contains_key(&task_id) {
184 return Err(TaskManagerError::TaskNotFound(task_id));
185 }
186
187 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
189 let _ = tx.send(TaskEvent::Progress {
190 task_id,
191 message: message.into(),
192 });
193 }
194
195 Ok(())
196 }
197
198 pub fn complete(
200 &self,
201 task_id: TaskId,
202 output: Option<serde_json::Value>,
203 ) -> Result<(), TaskManagerError> {
204 let mut tasks = self.tasks.write().unwrap();
205
206 let task = tasks
207 .get_mut(&task_id)
208 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
209
210 if task.status.is_terminal() {
211 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
212 }
213
214 let duration_ms = task.duration_ms().unwrap_or(0);
215 task.complete();
216
217 let result = TaskResult {
219 task_id,
220 success: true,
221 output,
222 duration_ms,
223 };
224 self.results
225 .write()
226 .unwrap()
227 .insert(task_id, result.clone());
228
229 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
231 let _ = tx.send(TaskEvent::Completed { task_id, result });
232 }
233
234 Ok(())
235 }
236
237 pub fn fail(&self, task_id: TaskId, error: impl Into<String>) -> Result<(), TaskManagerError> {
239 let mut tasks = self.tasks.write().unwrap();
240
241 let task = tasks
242 .get_mut(&task_id)
243 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
244
245 if task.status.is_terminal() {
246 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
247 }
248
249 let error_msg = error.into();
250 task.fail(&error_msg);
251
252 let result = TaskResult {
254 task_id,
255 success: false,
256 output: Some(serde_json::json!({ "error": error_msg.clone() })),
257 duration_ms: task.duration_ms().unwrap_or(0),
258 };
259 self.results.write().unwrap().insert(task_id, result);
260
261 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
263 let _ = tx.send(TaskEvent::Failed {
264 task_id,
265 error: error_msg,
266 });
267 }
268
269 Ok(())
270 }
271
272 pub fn kill(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
274 let mut tasks = self.tasks.write().unwrap();
275
276 let task = tasks
277 .get_mut(&task_id)
278 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
279
280 if task.status.is_terminal() {
281 return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
282 }
283
284 task.kill();
285
286 let result = TaskResult {
288 task_id,
289 success: false,
290 output: Some(serde_json::json!({ "error": "Task was killed" })),
291 duration_ms: task.duration_ms().unwrap_or(0),
292 };
293 self.results.write().unwrap().insert(task_id, result);
294
295 if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
297 let _ = tx.send(TaskEvent::Killed(task_id));
298 }
299
300 Ok(())
301 }
302
303 pub fn get(&self, task_id: TaskId) -> Option<Task> {
305 self.tasks.read().unwrap().get(&task_id).cloned()
306 }
307
308 pub fn get_result(&self, task_id: TaskId) -> Option<TaskResult> {
310 self.results.read().unwrap().get(&task_id).cloned()
311 }
312
313 pub fn is_terminal(&self, task_id: TaskId) -> bool {
315 self.tasks
316 .read()
317 .unwrap()
318 .get(&task_id)
319 .map(|t| t.is_terminal())
320 .unwrap_or(true)
321 }
322
323 pub fn subscribe(&self, task_id: TaskId) -> Option<broadcast::Receiver<TaskEvent>> {
328 self.subscribers
329 .read()
330 .unwrap()
331 .get(&task_id)
332 .map(|tx| tx.subscribe())
333 }
334
335 pub fn subscribe_all(&self) -> broadcast::Receiver<TaskEvent> {
337 self.global_events.subscribe()
338 }
339
340 pub async fn wait(&self, task_id: TaskId) -> Result<TaskResult, TaskManagerError> {
344 if let Some(result) = self.get_result(task_id) {
346 if !result.success {
347 let error = result
348 .output
349 .as_ref()
350 .and_then(|v| v.get("error"))
351 .and_then(|v| v.as_str())
352 .unwrap_or("Unknown error")
353 .to_string();
354 return Err(TaskManagerError::SendError(error));
355 }
356 return Ok(result);
357 }
358
359 let mut rx = self
361 .subscribe(task_id)
362 .ok_or(TaskManagerError::TaskNotFound(task_id))?;
363
364 while let Ok(event) = rx.recv().await {
366 match event {
367 TaskEvent::Completed {
368 task_id: id,
369 result,
370 } if id == task_id => {
371 return Ok(result);
372 }
373 TaskEvent::Failed { task_id: id, error } if id == task_id => {
374 return Err(TaskManagerError::SendError(error));
375 }
376 TaskEvent::Killed(id) if id == task_id => {
377 return Err(TaskManagerError::SendError("Task was killed".to_string()));
378 }
379 _ => {}
380 }
381 }
382
383 Err(TaskManagerError::Shutdown)
384 }
385
386 pub fn add_child(&self, parent_id: TaskId, child_id: TaskId) -> Result<(), TaskManagerError> {
388 let mut tasks = self.tasks.write().unwrap();
389
390 let parent = tasks
391 .get_mut(&parent_id)
392 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
393 parent.add_child(child_id);
394
395 if let Some(tx) = self.subscribers.read().unwrap().get(&parent_id) {
397 let _ = tx.send(TaskEvent::ChildSpawned {
398 parent_id,
399 child_id,
400 });
401 }
402
403 Ok(())
404 }
405
406 pub fn spawn_child(&self, parent_id: TaskId, task: Task) -> Result<TaskId, TaskManagerError> {
410 if !self.tasks.read().unwrap().contains_key(&parent_id) {
412 return Err(TaskManagerError::TaskNotFound(parent_id));
413 }
414
415 let child_id = self.spawn(task);
416 self.add_child(parent_id, child_id)?;
417 Ok(child_id)
418 }
419
420 pub async fn wait_children(
424 &self,
425 parent_id: TaskId,
426 ) -> Result<Vec<TaskResult>, TaskManagerError> {
427 let children: Vec<TaskId> = {
428 let tasks = self.tasks.read().unwrap();
429 let parent = tasks
430 .get(&parent_id)
431 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
432 parent.child_ids.clone()
433 };
434
435 let mut results = Vec::new();
436 for child_id in children {
437 match self.wait(child_id).await {
438 Ok(result) => results.push(result),
439 Err(TaskManagerError::TaskStillRunning(_)) => {
440 let result = self.wait(child_id).await?;
442 results.push(result);
443 }
444 Err(e) => return Err(e),
445 }
446 }
447
448 Ok(results)
449 }
450
451 pub fn get_children(&self, parent_id: TaskId) -> Result<Vec<TaskId>, TaskManagerError> {
453 let tasks = self.tasks.read().unwrap();
454 let parent = tasks
455 .get(&parent_id)
456 .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
457 Ok(parent.child_ids.clone())
458 }
459
460 pub fn all_children_complete(&self, parent_id: TaskId) -> bool {
462 if let Some(children) = self
463 .tasks
464 .read()
465 .unwrap()
466 .get(&parent_id)
467 .map(|t| &t.child_ids)
468 {
469 children.iter().all(|id| self.is_terminal(*id))
470 } else {
471 true
472 }
473 }
474
475 pub fn pending_tasks(&self) -> Vec<TaskId> {
477 self.pending_queue.read().unwrap().iter().copied().collect()
478 }
479
480 pub fn pop_pending(&self) -> Option<TaskId> {
482 self.pending_queue.write().unwrap().pop_front()
483 }
484
485 pub fn shutdown(&self) {
487 *self.shutdown.write().unwrap() = true;
488 self.tasks.write().unwrap().clear();
489 self.results.write().unwrap().clear();
490 self.subscribers.write().unwrap().clear();
491 self.pending_queue.write().unwrap().clear();
492 }
493
494 pub fn all_tasks(&self) -> HashMap<TaskId, Task> {
496 self.tasks.read().unwrap().clone()
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskStatus;
    use serde_json::json;

    /// Registers `task` with `manager` and moves it into the started state.
    fn spawn_started(manager: &TaskManager, task: Task) -> TaskId {
        let id = manager.spawn(task);
        manager.start(id).unwrap();
        id
    }

    /// A completed task yields its output through `wait` and is terminal.
    #[tokio::test]
    async fn test_task_manager_spawn_and_wait() {
        let manager = TaskManager::new();
        let id = spawn_started(
            &manager,
            Task::tool("read", json!({"file_path": "test.txt"})),
        );

        let expected = json!({"success": true, "output": "file content"});
        manager.complete(id, Some(expected.clone())).unwrap();

        let outcome = manager.wait(id).await.unwrap();
        assert_eq!(outcome.output, Some(expected));
        assert!(manager.is_terminal(id));
    }

    /// Waiting on a failed task returns an error.
    #[tokio::test]
    async fn test_task_manager_fail() {
        let manager = TaskManager::new();
        let id = spawn_started(
            &manager,
            Task::tool("read", json!({"file_path": "nonexistent.txt"})),
        );

        manager.fail(id, "File not found").unwrap();

        let outcome = manager.wait(id).await;
        assert!(outcome.is_err());
    }

    /// Killing a started task makes it terminal with status `Killed`.
    #[test]
    fn test_task_manager_kill() {
        let manager = TaskManager::new();
        let id = spawn_started(
            &manager,
            Task::tool("bash", json!({"command": "sleep 100"})),
        );

        manager.kill(id).unwrap();

        assert!(manager.is_terminal(id));
        assert_eq!(manager.get(id).unwrap().status, TaskStatus::Killed);
    }
}
552}