Skip to main content

a3s_code_core/task/
manager.rs

//! Task Manager - Centralized Task Lifecycle Management
//!
//! Provides centralized management for all tasks in a3s-code.
//! Handles spawning, tracking, and coordinating task execution.
//!
//! ## Example
//!
//! ```rust,ignore
//! use a3s_code_core::task::{TaskManager, Task, TaskType};
//!
//! let manager = TaskManager::new();
//!
//! // Spawn a task
//! let task = Task::agent("general", "/workspace", "Analyze this");
//! let task_id = manager.spawn(task);
//!
//! // Wait for completion
//! let result = manager.wait(task_id).await;
//! assert!(result.is_ok());
//!
//! // Subscribe to updates
//! let mut rx = manager.subscribe(task_id).unwrap();
//! while let Ok(event) = rx.recv().await {
//!     println!("Task update: {:?}", event);
//! }
//! ```
27
28use crate::task::{Task, TaskId};
29use std::collections::{HashMap, VecDeque};
30use std::sync::RwLock;
31use tokio::sync::broadcast;
32
33/// Task lifecycle events
34#[derive(Debug, Clone)]
35pub enum TaskEvent {
36    /// Task was spawned
37    Spawned {
38        task_id: TaskId,
39        parent_id: Option<TaskId>,
40    },
41    /// Task started execution
42    Started(TaskId),
43    /// Task progress update
44    Progress { task_id: TaskId, message: String },
45    /// Task completed
46    Completed { task_id: TaskId, result: TaskResult },
47    /// Task failed
48    Failed { task_id: TaskId, error: String },
49    /// Task was killed
50    Killed(TaskId),
51    /// Child task spawned
52    ChildSpawned { parent_id: TaskId, child_id: TaskId },
53}
54
55/// Task execution result
56#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
57pub struct TaskResult {
58    /// Task ID
59    pub task_id: TaskId,
60    /// Whether the task completed successfully
61    pub success: bool,
62    /// Output value (tool result, agent response, etc.)
63    pub output: Option<serde_json::Value>,
64    /// Execution duration in milliseconds
65    pub duration_ms: u64,
66}
67
68/// Task manager errors
69#[derive(Debug, thiserror::Error)]
70pub enum TaskManagerError {
71    #[error("Task not found: {0}")]
72    TaskNotFound(TaskId),
73
74    #[error("Task already completed: {0}")]
75    TaskAlreadyCompleted(TaskId),
76
77    #[error("Task {0} is still running")]
78    TaskStillRunning(TaskId),
79
80    #[error("Manager is shutdown")]
81    Shutdown,
82
83    #[error("Send error: {0}")]
84    SendError(String),
85}
86
87/// Centralized task lifecycle manager
88pub struct TaskManager {
89    /// All tasks indexed by ID
90    tasks: RwLock<HashMap<TaskId, Task>>,
91    /// Task results (kept after completion for retrieval)
92    results: RwLock<HashMap<TaskId, TaskResult>>,
93    /// Event subscribers (one per task)
94    subscribers: RwLock<HashMap<TaskId, broadcast::Sender<TaskEvent>>>,
95    /// Global event broadcast
96    global_events: broadcast::Sender<TaskEvent>,
97    /// Task queue for pending tasks
98    pending_queue: RwLock<VecDeque<TaskId>>,
99    /// Whether manager is shutdown
100    shutdown: RwLock<bool>,
101}
102
103impl Default for TaskManager {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl TaskManager {
110    /// Create a new task manager
111    pub fn new() -> Self {
112        let (global_events, _) = broadcast::channel(1024);
113        Self {
114            tasks: RwLock::new(HashMap::new()),
115            results: RwLock::new(HashMap::new()),
116            subscribers: RwLock::new(HashMap::new()),
117            global_events,
118            pending_queue: RwLock::new(VecDeque::new()),
119            shutdown: RwLock::new(false),
120        }
121    }
122
123    /// Spawn a new task
124    ///
125    /// Returns the task ID.
126    pub fn spawn(&self, task: Task) -> TaskId {
127        if *self.shutdown.read().unwrap() {
128            // Return a dummy ID, caller should check
129            return task.id;
130        }
131
132        let task_id = task.id;
133        let parent_id = task.parent_id;
134
135        // Create subscription channel for this task
136        let (tx, _) = broadcast::channel(64);
137        self.subscribers.write().unwrap().insert(task_id, tx);
138
139        // Add to tasks
140        self.tasks.write().unwrap().insert(task_id, task);
141
142        // Add to pending queue
143        self.pending_queue.write().unwrap().push_back(task_id);
144
145        // Emit spawn event
146        let _ = self
147            .global_events
148            .send(TaskEvent::Spawned { task_id, parent_id });
149
150        task_id
151    }
152
153    /// Start executing a task
154    pub fn start(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
155        let mut tasks = self.tasks.write().unwrap();
156
157        let task = tasks
158            .get_mut(&task_id)
159            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
160
161        if task.status.is_terminal() {
162            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
163        }
164
165        task.start();
166
167        // Notify subscribers
168        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
169            let _ = tx.send(TaskEvent::Started(task_id));
170        }
171
172        Ok(())
173    }
174
175    /// Update task progress
176    pub fn progress(
177        &self,
178        task_id: TaskId,
179        message: impl Into<String>,
180    ) -> Result<(), TaskManagerError> {
181        let tasks = self.tasks.read().unwrap();
182
183        if !tasks.contains_key(&task_id) {
184            return Err(TaskManagerError::TaskNotFound(task_id));
185        }
186
187        // Notify subscribers
188        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
189            let _ = tx.send(TaskEvent::Progress {
190                task_id,
191                message: message.into(),
192            });
193        }
194
195        Ok(())
196    }
197
198    /// Complete a task with result
199    pub fn complete(
200        &self,
201        task_id: TaskId,
202        output: Option<serde_json::Value>,
203    ) -> Result<(), TaskManagerError> {
204        let mut tasks = self.tasks.write().unwrap();
205
206        let task = tasks
207            .get_mut(&task_id)
208            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
209
210        if task.status.is_terminal() {
211            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
212        }
213
214        let duration_ms = task.duration_ms().unwrap_or(0);
215        task.complete();
216
217        // Store result
218        let result = TaskResult {
219            task_id,
220            success: true,
221            output,
222            duration_ms,
223        };
224        self.results
225            .write()
226            .unwrap()
227            .insert(task_id, result.clone());
228
229        // Notify subscribers
230        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
231            let _ = tx.send(TaskEvent::Completed { task_id, result });
232        }
233
234        Ok(())
235    }
236
237    /// Fail a task with error
238    pub fn fail(&self, task_id: TaskId, error: impl Into<String>) -> Result<(), TaskManagerError> {
239        let mut tasks = self.tasks.write().unwrap();
240
241        let task = tasks
242            .get_mut(&task_id)
243            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
244
245        if task.status.is_terminal() {
246            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
247        }
248
249        let error_msg = error.into();
250        task.fail(&error_msg);
251
252        // Store error result so wait() can retrieve it even if called after fail
253        let result = TaskResult {
254            task_id,
255            success: false,
256            output: Some(serde_json::json!({ "error": error_msg.clone() })),
257            duration_ms: task.duration_ms().unwrap_or(0),
258        };
259        self.results.write().unwrap().insert(task_id, result);
260
261        // Notify subscribers
262        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
263            let _ = tx.send(TaskEvent::Failed {
264                task_id,
265                error: error_msg,
266            });
267        }
268
269        Ok(())
270    }
271
272    /// Kill a task
273    pub fn kill(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
274        let mut tasks = self.tasks.write().unwrap();
275
276        let task = tasks
277            .get_mut(&task_id)
278            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
279
280        if task.status.is_terminal() {
281            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
282        }
283
284        task.kill();
285
286        // Store killed result so wait() can retrieve it even if called after kill
287        let result = TaskResult {
288            task_id,
289            success: false,
290            output: Some(serde_json::json!({ "error": "Task was killed" })),
291            duration_ms: task.duration_ms().unwrap_or(0),
292        };
293        self.results.write().unwrap().insert(task_id, result);
294
295        // Notify subscribers
296        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
297            let _ = tx.send(TaskEvent::Killed(task_id));
298        }
299
300        Ok(())
301    }
302
303    /// Get task by ID
304    pub fn get(&self, task_id: TaskId) -> Option<Task> {
305        self.tasks.read().unwrap().get(&task_id).cloned()
306    }
307
308    /// Get task result
309    pub fn get_result(&self, task_id: TaskId) -> Option<TaskResult> {
310        self.results.read().unwrap().get(&task_id).cloned()
311    }
312
313    /// Check if task is terminal
314    pub fn is_terminal(&self, task_id: TaskId) -> bool {
315        self.tasks
316            .read()
317            .unwrap()
318            .get(&task_id)
319            .map(|t| t.is_terminal())
320            .unwrap_or(true)
321    }
322
323    /// Subscribe to task events
324    ///
325    /// Returns a receiver that will receive all events for the task.
326    /// The receiver is closed when the task completes.
327    pub fn subscribe(&self, task_id: TaskId) -> Option<broadcast::Receiver<TaskEvent>> {
328        self.subscribers
329            .read()
330            .unwrap()
331            .get(&task_id)
332            .map(|tx| tx.subscribe())
333    }
334
335    /// Subscribe to all task events (global)
336    pub fn subscribe_all(&self) -> broadcast::Receiver<TaskEvent> {
337        self.global_events.subscribe()
338    }
339
340    /// Wait for a task to complete
341    ///
342    /// Returns the task result when the task completes.
343    pub async fn wait(&self, task_id: TaskId) -> Result<TaskResult, TaskManagerError> {
344        // First check if already completed
345        if let Some(result) = self.get_result(task_id) {
346            if !result.success {
347                let error = result
348                    .output
349                    .as_ref()
350                    .and_then(|v| v.get("error"))
351                    .and_then(|v| v.as_str())
352                    .unwrap_or("Unknown error")
353                    .to_string();
354                return Err(TaskManagerError::SendError(error));
355            }
356            return Ok(result);
357        }
358
359        // Subscribe to events
360        let mut rx = self
361            .subscribe(task_id)
362            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
363
364        // Wait for completion
365        while let Ok(event) = rx.recv().await {
366            match event {
367                TaskEvent::Completed {
368                    task_id: id,
369                    result,
370                } if id == task_id => {
371                    return Ok(result);
372                }
373                TaskEvent::Failed { task_id: id, error } if id == task_id => {
374                    return Err(TaskManagerError::SendError(error));
375                }
376                TaskEvent::Killed(id) if id == task_id => {
377                    return Err(TaskManagerError::SendError("Task was killed".to_string()));
378                }
379                _ => {}
380            }
381        }
382
383        Err(TaskManagerError::Shutdown)
384    }
385
386    /// Add a child task to a parent
387    pub fn add_child(&self, parent_id: TaskId, child_id: TaskId) -> Result<(), TaskManagerError> {
388        let mut tasks = self.tasks.write().unwrap();
389
390        let parent = tasks
391            .get_mut(&parent_id)
392            .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
393        parent.add_child(child_id);
394
395        // Notify about child spawn
396        if let Some(tx) = self.subscribers.read().unwrap().get(&parent_id) {
397            let _ = tx.send(TaskEvent::ChildSpawned {
398                parent_id,
399                child_id,
400            });
401        }
402
403        Ok(())
404    }
405
406    /// Spawn a child task with the given parent ID, setting up parent-child relationship.
407    ///
408    /// This is a convenience method that combines spawn + add_child in one call.
409    pub fn spawn_child(&self, parent_id: TaskId, task: Task) -> Result<TaskId, TaskManagerError> {
410        // Verify parent exists
411        if !self.tasks.read().unwrap().contains_key(&parent_id) {
412            return Err(TaskManagerError::TaskNotFound(parent_id));
413        }
414
415        let child_id = self.spawn(task);
416        self.add_child(parent_id, child_id)?;
417        Ok(child_id)
418    }
419
420    /// Wait for all child tasks of a parent to complete.
421    ///
422    /// Returns a list of results for each child task.
423    pub async fn wait_children(
424        &self,
425        parent_id: TaskId,
426    ) -> Result<Vec<TaskResult>, TaskManagerError> {
427        let children: Vec<TaskId> = {
428            let tasks = self.tasks.read().unwrap();
429            let parent = tasks
430                .get(&parent_id)
431                .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
432            parent.child_ids.clone()
433        };
434
435        let mut results = Vec::new();
436        for child_id in children {
437            match self.wait(child_id).await {
438                Ok(result) => results.push(result),
439                Err(TaskManagerError::TaskStillRunning(_)) => {
440                    // Child still running, wait for it
441                    let result = self.wait(child_id).await?;
442                    results.push(result);
443                }
444                Err(e) => return Err(e),
445            }
446        }
447
448        Ok(results)
449    }
450
451    /// Get all child task IDs for a parent task.
452    pub fn get_children(&self, parent_id: TaskId) -> Result<Vec<TaskId>, TaskManagerError> {
453        let tasks = self.tasks.read().unwrap();
454        let parent = tasks
455            .get(&parent_id)
456            .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
457        Ok(parent.child_ids.clone())
458    }
459
460    /// Check if all children of a parent task are complete.
461    pub fn all_children_complete(&self, parent_id: TaskId) -> bool {
462        if let Some(children) = self
463            .tasks
464            .read()
465            .unwrap()
466            .get(&parent_id)
467            .map(|t| &t.child_ids)
468        {
469            children.iter().all(|id| self.is_terminal(*id))
470        } else {
471            true
472        }
473    }
474
475    /// Get pending task IDs (FIFO)
476    pub fn pending_tasks(&self) -> Vec<TaskId> {
477        self.pending_queue.read().unwrap().iter().copied().collect()
478    }
479
480    /// Pop the next pending task
481    pub fn pop_pending(&self) -> Option<TaskId> {
482        self.pending_queue.write().unwrap().pop_front()
483    }
484
485    /// Shutdown the manager
486    pub fn shutdown(&self) {
487        *self.shutdown.write().unwrap() = true;
488        self.tasks.write().unwrap().clear();
489        self.results.write().unwrap().clear();
490        self.subscribers.write().unwrap().clear();
491        self.pending_queue.write().unwrap().clear();
492    }
493
494    /// Get all tasks
495    pub fn all_tasks(&self) -> HashMap<TaskId, Task> {
496        self.tasks.read().unwrap().clone()
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskStatus;
    use serde_json::json;

    /// A completed task hands its stored output back through `wait`.
    #[tokio::test]
    async fn test_task_manager_spawn_and_wait() {
        let manager = TaskManager::new();
        let task_id = manager.spawn(Task::tool("read", json!({"file_path": "test.txt"})));

        manager.start(task_id).unwrap();

        let expected = json!({"success": true, "output": "file content"});
        manager.complete(task_id, Some(expected.clone())).unwrap();

        let outcome = manager.wait(task_id).await.unwrap();
        assert_eq!(outcome.output, Some(expected));
        assert!(manager.is_terminal(task_id));
    }

    /// Waiting on a failed task yields an error.
    #[tokio::test]
    async fn test_task_manager_fail() {
        let manager = TaskManager::new();
        let task_id = manager.spawn(Task::tool("read", json!({"file_path": "nonexistent.txt"})));

        manager.start(task_id).unwrap();
        manager.fail(task_id, "File not found").unwrap();

        let outcome = manager.wait(task_id).await;
        assert!(outcome.is_err());
    }

    /// Killing a started task leaves it terminal with the `Killed` status.
    #[test]
    fn test_task_manager_kill() {
        let manager = TaskManager::new();
        let task_id = manager.spawn(Task::tool("bash", json!({"command": "sleep 100"})));

        manager.start(task_id).unwrap();
        manager.kill(task_id).unwrap();

        assert!(manager.is_terminal(task_id));
        assert_eq!(manager.get(task_id).unwrap().status, TaskStatus::Killed);
    }
}
552}