Skip to main content

a3s_code_core/task/
manager.rs

1//! Task Manager - Centralized Task Lifecycle Management
2//!
3//! Provides centralized management for all tasks in a3s-code.
4//! Handles spawning, tracking, and coordinating task execution.
5//!
6//! ## Example
7//!
8//! ```rust,ignore
9//! use a3s_code_core::task::{TaskManager, Task, TaskType};
10//!
11//! let manager = TaskManager::new();
12//!
13//! // Spawn a task
14//! let task = Task::agent("general", "/workspace", "Analyze this");
15//! let task_id = manager.spawn(task);
16//!
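//! // Wait for completion (another component must call `complete`, `fail` or `kill`)
//! let result = manager.wait(task_id).await;
//! assert!(result.is_ok());
//!
//! // Subscribe to updates. Subscribe before the events you care about are sent:
//! // earlier events are not replayed, and `recv` yields a `Result`, not an `Option`.
//! let mut rx = manager.subscribe(task_id).unwrap();
//! while let Ok(event) = rx.recv().await {
//!     println!("Task update: {:?}", event);
//! }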
26//! ```
27
28use crate::task::{Task, TaskId};
29use std::collections::{HashMap, VecDeque};
30use std::sync::RwLock;
31use tokio::sync::broadcast;
32
/// Task lifecycle events emitted by [`TaskManager`].
#[derive(Debug, Clone)]
pub enum TaskEvent {
    /// Task was registered by `TaskManager::spawn`; `parent_id` is copied from
    /// the task's own `parent_id` field at spawn time.
    Spawned {
        task_id: TaskId,
        parent_id: Option<TaskId>,
    },
    /// Task started execution (`TaskManager::start`).
    Started(TaskId),
    /// Free-form progress message (`TaskManager::progress`); it is broadcast
    /// only, never stored by the manager.
    Progress { task_id: TaskId, message: String },
    /// Task completed; carries the same result that is kept for
    /// `TaskManager::get_result`.
    Completed { task_id: TaskId, result: TaskResult },
    /// Task failed with the given error message.
    Failed { task_id: TaskId, error: String },
    /// Task was killed (`TaskManager::kill`).
    Killed(TaskId),
    /// A child task was attached to `parent_id` (`TaskManager::add_child`).
    ChildSpawned { parent_id: TaskId, child_id: TaskId },
}
54
/// Result of a successfully completed task.
///
/// Built by `TaskManager::complete` and kept until `TaskManager::shutdown`, so
/// it can be fetched with `get_result` or `wait`. Failed and killed tasks do
/// not produce a `TaskResult`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskResult {
    /// ID of the task this result belongs to
    pub task_id: TaskId,
    /// Output value (tool result, agent response, etc.); `None` when the task
    /// was completed without an output
    pub output: Option<serde_json::Value>,
    /// Execution duration in milliseconds as reported by the task; `0` when
    /// the task could not report a duration
    pub duration_ms: u64,
}
65
/// Errors returned by [`TaskManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum TaskManagerError {
    /// No task with this ID is registered (never spawned, or cleared by
    /// `TaskManager::shutdown`).
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),

    /// The task already reached a terminal state, so it cannot be started,
    /// completed, failed or killed again.
    #[error("Task already completed: {0}")]
    TaskAlreadyCompleted(TaskId),

    /// The task has not finished yet.
    // NOTE(review): no method in this file currently returns this variant.
    #[error("Task {0} is still running")]
    TaskStillRunning(TaskId),

    /// Returned by `TaskManager::wait` when it can no longer receive the task's
    /// events, e.g. after `TaskManager::shutdown` dropped the task's channel.
    #[error("Manager is shutdown")]
    Shutdown,

    /// Generic error message. Besides its literal meaning, `TaskManager::wait`
    /// uses it to report that the awaited task failed (carrying the task's error
    /// message) or was killed.
    #[error("Send error: {0}")]
    SendError(String),
}
84
/// Centralized task lifecycle manager.
///
/// Tasks, results and per-task event channels are only ever removed by
/// `shutdown`; completed tasks stay queryable until then.
pub struct TaskManager {
    /// All tasks indexed by ID
    tasks: RwLock<HashMap<TaskId, Task>>,
    /// Results of successfully completed tasks (kept after completion for retrieval)
    results: RwLock<HashMap<TaskId, TaskResult>>,
    /// Per-task event channels: one sender per task, created in `spawn`
    subscribers: RwLock<HashMap<TaskId, broadcast::Sender<TaskEvent>>>,
    /// Global event broadcast, handed out by `subscribe_all`
    global_events: broadcast::Sender<TaskEvent>,
    /// Spawned task IDs in FIFO order; an ID leaves the queue only through
    /// `pop_pending` (or `shutdown`), regardless of the task's status
    pending_queue: RwLock<VecDeque<TaskId>>,
    /// Set by `shutdown`; once true, `spawn` no longer registers tasks
    shutdown: RwLock<bool>,
}
100
101impl Default for TaskManager {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl TaskManager {
108    /// Create a new task manager
109    pub fn new() -> Self {
110        let (global_events, _) = broadcast::channel(1024);
111        Self {
112            tasks: RwLock::new(HashMap::new()),
113            results: RwLock::new(HashMap::new()),
114            subscribers: RwLock::new(HashMap::new()),
115            global_events,
116            pending_queue: RwLock::new(VecDeque::new()),
117            shutdown: RwLock::new(false),
118        }
119    }
120
121    /// Spawn a new task
122    ///
123    /// Returns the task ID.
124    pub fn spawn(&self, task: Task) -> TaskId {
125        if *self.shutdown.read().unwrap() {
126            // Return a dummy ID, caller should check
127            return task.id;
128        }
129
130        let task_id = task.id;
131        let parent_id = task.parent_id;
132
133        // Create subscription channel for this task
134        let (tx, _) = broadcast::channel(64);
135        self.subscribers.write().unwrap().insert(task_id, tx);
136
137        // Add to tasks
138        self.tasks.write().unwrap().insert(task_id, task);
139
140        // Add to pending queue
141        self.pending_queue.write().unwrap().push_back(task_id);
142
143        // Emit spawn event
144        let _ = self
145            .global_events
146            .send(TaskEvent::Spawned { task_id, parent_id });
147
148        task_id
149    }
150
151    /// Start executing a task
152    pub fn start(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
153        let mut tasks = self.tasks.write().unwrap();
154
155        let task = tasks
156            .get_mut(&task_id)
157            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
158
159        if task.status.is_terminal() {
160            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
161        }
162
163        task.start();
164
165        // Notify subscribers
166        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
167            let _ = tx.send(TaskEvent::Started(task_id));
168        }
169
170        Ok(())
171    }
172
173    /// Update task progress
174    pub fn progress(
175        &self,
176        task_id: TaskId,
177        message: impl Into<String>,
178    ) -> Result<(), TaskManagerError> {
179        let tasks = self.tasks.read().unwrap();
180
181        if !tasks.contains_key(&task_id) {
182            return Err(TaskManagerError::TaskNotFound(task_id));
183        }
184
185        // Notify subscribers
186        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
187            let _ = tx.send(TaskEvent::Progress {
188                task_id,
189                message: message.into(),
190            });
191        }
192
193        Ok(())
194    }
195
196    /// Complete a task with result
197    pub fn complete(
198        &self,
199        task_id: TaskId,
200        output: Option<serde_json::Value>,
201    ) -> Result<(), TaskManagerError> {
202        let mut tasks = self.tasks.write().unwrap();
203
204        let task = tasks
205            .get_mut(&task_id)
206            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
207
208        if task.status.is_terminal() {
209            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
210        }
211
212        let duration_ms = task.duration_ms().unwrap_or(0);
213        task.complete();
214
215        // Store result
216        let result = TaskResult {
217            task_id,
218            output,
219            duration_ms,
220        };
221        self.results
222            .write()
223            .unwrap()
224            .insert(task_id, result.clone());
225
226        // Notify subscribers
227        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
228            let _ = tx.send(TaskEvent::Completed { task_id, result });
229        }
230
231        Ok(())
232    }
233
234    /// Fail a task with error
235    pub fn fail(&self, task_id: TaskId, error: impl Into<String>) -> Result<(), TaskManagerError> {
236        let mut tasks = self.tasks.write().unwrap();
237
238        let task = tasks
239            .get_mut(&task_id)
240            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
241
242        if task.status.is_terminal() {
243            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
244        }
245
246        let error_msg = error.into();
247        task.fail(&error_msg);
248
249        // Notify subscribers
250        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
251            let _ = tx.send(TaskEvent::Failed {
252                task_id,
253                error: error_msg,
254            });
255        }
256
257        Ok(())
258    }
259
260    /// Kill a task
261    pub fn kill(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
262        let mut tasks = self.tasks.write().unwrap();
263
264        let task = tasks
265            .get_mut(&task_id)
266            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
267
268        if task.status.is_terminal() {
269            return Err(TaskManagerError::TaskAlreadyCompleted(task_id));
270        }
271
272        task.kill();
273
274        // Notify subscribers
275        if let Some(tx) = self.subscribers.read().unwrap().get(&task_id) {
276            let _ = tx.send(TaskEvent::Killed(task_id));
277        }
278
279        Ok(())
280    }
281
282    /// Get task by ID
283    pub fn get(&self, task_id: TaskId) -> Option<Task> {
284        self.tasks.read().unwrap().get(&task_id).cloned()
285    }
286
287    /// Get task result
288    pub fn get_result(&self, task_id: TaskId) -> Option<TaskResult> {
289        self.results.read().unwrap().get(&task_id).cloned()
290    }
291
292    /// Check if task is terminal
293    pub fn is_terminal(&self, task_id: TaskId) -> bool {
294        self.tasks
295            .read()
296            .unwrap()
297            .get(&task_id)
298            .map(|t| t.is_terminal())
299            .unwrap_or(true)
300    }
301
302    /// Subscribe to task events
303    ///
304    /// Returns a receiver that will receive all events for the task.
305    /// The receiver is closed when the task completes.
306    pub fn subscribe(&self, task_id: TaskId) -> Option<broadcast::Receiver<TaskEvent>> {
307        self.subscribers
308            .read()
309            .unwrap()
310            .get(&task_id)
311            .map(|tx| tx.subscribe())
312    }
313
314    /// Subscribe to all task events (global)
315    pub fn subscribe_all(&self) -> broadcast::Receiver<TaskEvent> {
316        self.global_events.subscribe()
317    }
318
319    /// Wait for a task to complete
320    ///
321    /// Returns the task result when the task completes.
322    pub async fn wait(&self, task_id: TaskId) -> Result<TaskResult, TaskManagerError> {
323        // First check if already completed
324        if let Some(result) = self.get_result(task_id) {
325            return Ok(result);
326        }
327
328        // Subscribe to events
329        let mut rx = self
330            .subscribe(task_id)
331            .ok_or(TaskManagerError::TaskNotFound(task_id))?;
332
333        // Wait for completion
334        while let Ok(event) = rx.recv().await {
335            match event {
336                TaskEvent::Completed {
337                    task_id: id,
338                    result,
339                } if id == task_id => {
340                    return Ok(result);
341                }
342                TaskEvent::Failed { task_id: id, error } if id == task_id => {
343                    return Err(TaskManagerError::SendError(error));
344                }
345                TaskEvent::Killed(id) if id == task_id => {
346                    return Err(TaskManagerError::SendError("Task was killed".to_string()));
347                }
348                _ => {}
349            }
350        }
351
352        Err(TaskManagerError::Shutdown)
353    }
354
355    /// Add a child task to a parent
356    pub fn add_child(&self, parent_id: TaskId, child_id: TaskId) -> Result<(), TaskManagerError> {
357        let mut tasks = self.tasks.write().unwrap();
358
359        let parent = tasks
360            .get_mut(&parent_id)
361            .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
362        parent.add_child(child_id);
363
364        // Notify about child spawn
365        if let Some(tx) = self.subscribers.read().unwrap().get(&parent_id) {
366            let _ = tx.send(TaskEvent::ChildSpawned {
367                parent_id,
368                child_id,
369            });
370        }
371
372        Ok(())
373    }
374
375    /// Spawn a child task with the given parent ID, setting up parent-child relationship.
376    ///
377    /// This is a convenience method that combines spawn + add_child in one call.
378    pub fn spawn_child(&self, parent_id: TaskId, task: Task) -> Result<TaskId, TaskManagerError> {
379        // Verify parent exists
380        if !self.tasks.read().unwrap().contains_key(&parent_id) {
381            return Err(TaskManagerError::TaskNotFound(parent_id));
382        }
383
384        let child_id = self.spawn(task);
385        self.add_child(parent_id, child_id)?;
386        Ok(child_id)
387    }
388
389    /// Wait for all child tasks of a parent to complete.
390    ///
391    /// Returns a list of results for each child task.
392    pub async fn wait_children(
393        &self,
394        parent_id: TaskId,
395    ) -> Result<Vec<TaskResult>, TaskManagerError> {
396        let children: Vec<TaskId> = {
397            let tasks = self.tasks.read().unwrap();
398            let parent = tasks
399                .get(&parent_id)
400                .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
401            parent.child_ids.clone()
402        };
403
404        let mut results = Vec::new();
405        for child_id in children {
406            match self.wait(child_id).await {
407                Ok(result) => results.push(result),
408                Err(TaskManagerError::TaskStillRunning(_)) => {
409                    // Child still running, wait for it
410                    let result = self.wait(child_id).await?;
411                    results.push(result);
412                }
413                Err(e) => return Err(e),
414            }
415        }
416
417        Ok(results)
418    }
419
420    /// Get all child task IDs for a parent task.
421    pub fn get_children(&self, parent_id: TaskId) -> Result<Vec<TaskId>, TaskManagerError> {
422        let tasks = self.tasks.read().unwrap();
423        let parent = tasks
424            .get(&parent_id)
425            .ok_or(TaskManagerError::TaskNotFound(parent_id))?;
426        Ok(parent.child_ids.clone())
427    }
428
429    /// Check if all children of a parent task are complete.
430    pub fn all_children_complete(&self, parent_id: TaskId) -> bool {
431        if let Some(children) = self
432            .tasks
433            .read()
434            .unwrap()
435            .get(&parent_id)
436            .map(|t| &t.child_ids)
437        {
438            children.iter().all(|id| self.is_terminal(*id))
439        } else {
440            true
441        }
442    }
443
444    /// Get pending task IDs (FIFO)
445    pub fn pending_tasks(&self) -> Vec<TaskId> {
446        self.pending_queue.read().unwrap().iter().copied().collect()
447    }
448
449    /// Pop the next pending task
450    pub fn pop_pending(&self) -> Option<TaskId> {
451        self.pending_queue.write().unwrap().pop_front()
452    }
453
454    /// Shutdown the manager
455    pub fn shutdown(&self) {
456        *self.shutdown.write().unwrap() = true;
457        self.tasks.write().unwrap().clear();
458        self.results.write().unwrap().clear();
459        self.subscribers.write().unwrap().clear();
460        self.pending_queue.write().unwrap().clear();
461    }
462
463    /// Get all tasks
464    pub fn all_tasks(&self) -> HashMap<TaskId, Task> {
465        self.tasks.read().unwrap().clone()
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskStatus;
    use serde_json::json;

    /// A completed task's stored output is what `wait` hands back.
    #[tokio::test]
    async fn test_task_manager_spawn_and_wait() {
        let manager = TaskManager::new();
        let id = manager.spawn(Task::tool("read", json!({"file_path": "test.txt"})));
        manager.start(id).unwrap();

        let payload = json!({"success": true, "output": "file content"});
        manager.complete(id, Some(payload.clone())).unwrap();

        let outcome = manager.wait(id).await.unwrap();
        assert_eq!(outcome.output, Some(payload));
        assert!(manager.is_terminal(id));
    }

    /// Waiting on a task that already failed yields an error.
    // NOTE(review): this requires `wait` to resolve tasks that failed before it
    // was called, since the `Failed` event is not replayed to late subscribers.
    #[tokio::test]
    async fn test_task_manager_fail() {
        let manager = TaskManager::new();
        let id = manager.spawn(Task::tool("read", json!({"file_path": "nonexistent.txt"})));
        manager.start(id).unwrap();
        manager.fail(id, "File not found").unwrap();

        let outcome = manager.wait(id).await;
        assert!(outcome.is_err());
    }

    /// Killing a running task makes it terminal with `Killed` status.
    #[test]
    fn test_task_manager_kill() {
        let manager = TaskManager::new();
        let id = manager.spawn(Task::tool("bash", json!({"command": "sleep 100"})));
        manager.start(id).unwrap();
        manager.kill(id).unwrap();

        assert!(manager.is_terminal(id));
        let snapshot = manager.get(id).unwrap();
        assert_eq!(snapshot.status, TaskStatus::Killed);
    }
}