Skip to main content

cersei_tools/
tasks.rs

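//! Task system: create, track, update, and manage background tasks.
//!
//! A task is a record describing long-running sub-agent work. This module only
//! keeps those records in a process-wide registry and exposes tools that let the
//! coordinator create tasks, check their status, update them, mark them stopped,
//! and retrieve their output; it does not itself spawn, run, or cancel any work.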
5
6use super::*;
7use serde::{Deserialize, Serialize};
8
9// ─── Task registry ───────────────────────────────────────────────────────────
10
11static TASK_REGISTRY: once_cell::sync::Lazy<dashmap::DashMap<String, TaskEntry>> =
12    once_cell::sync::Lazy::new(dashmap::DashMap::new);
13
/// A single tracked task as stored in the task registry.
///
/// Entries are cloned out of the registry by [`get_task`] and [`list_tasks`],
/// so a `TaskEntry` held by a caller is a snapshot, not a live view.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskEntry {
    /// Task identifier; `TaskCreateTool` assigns the first 8 characters of a
    /// random v4 UUID.
    pub id: String,
    /// Free-form description supplied when the task was created.
    pub description: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Result text recorded via `TaskUpdate`; `None` until some output is set.
    pub output: Option<String>,
    /// Creation timestamp; the tools in this module write RFC 3339 strings
    /// produced by `chrono::Utc::now()`.
    pub created_at: String,
    /// Timestamp of the last modification, in the same format as `created_at`.
    pub updated_at: String,
    /// Id of the session (taken from `ToolContext`) that created the task.
    pub session_id: String,
}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
26#[serde(rename_all = "snake_case")]
27pub enum TaskStatus {
28    Pending,
29    Running,
30    Completed,
31    Failed,
32    Stopped,
33}
34
35pub fn get_task(id: &str) -> Option<TaskEntry> {
36    TASK_REGISTRY.get(id).map(|e| e.clone())
37}
38
39pub fn list_tasks() -> Vec<TaskEntry> {
40    TASK_REGISTRY.iter().map(|e| e.value().clone()).collect()
41}
42
43pub fn clear_tasks() {
44    TASK_REGISTRY.clear();
45}
46
47// ─── TaskCreate ──────────────────────────────────────────────────────────────
48
49pub struct TaskCreateTool;
50
51#[async_trait]
52impl Tool for TaskCreateTool {
53    fn name(&self) -> &str {
54        "TaskCreate"
55    }
56    fn description(&self) -> &str {
57        "Create a new task for tracking sub-agent work."
58    }
59    fn permission_level(&self) -> PermissionLevel {
60        PermissionLevel::None
61    }
62    fn category(&self) -> ToolCategory {
63        ToolCategory::Orchestration
64    }
65
66    fn input_schema(&self) -> Value {
67        serde_json::json!({
68            "type": "object",
69            "properties": {
70                "description": { "type": "string", "description": "What this task does" },
71                "prompt": { "type": "string", "description": "The prompt for the sub-agent (optional)" }
72            },
73            "required": ["description"]
74        })
75    }
76
77    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
78        #[derive(Deserialize)]
79        #[allow(dead_code)]
80        struct Input {
81            description: String,
82            prompt: Option<String>,
83        }
84
85        let input: Input = match serde_json::from_value(input) {
86            Ok(i) => i,
87            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
88        };
89
90        let id = uuid::Uuid::new_v4().to_string()[..8].to_string();
91        let now = chrono::Utc::now().to_rfc3339();
92        let task = TaskEntry {
93            id: id.clone(),
94            description: input.description.clone(),
95            status: TaskStatus::Pending,
96            output: None,
97            created_at: now.clone(),
98            updated_at: now,
99            session_id: ctx.session_id.clone(),
100        };
101        TASK_REGISTRY.insert(id.clone(), task);
102        ToolResult::success(format!("Task '{}' created: {}", id, input.description))
103    }
104}
105
106// ─── TaskGet ─────────────────────────────────────────────────────────────────
107
108pub struct TaskGetTool;
109
110#[async_trait]
111impl Tool for TaskGetTool {
112    fn name(&self) -> &str {
113        "TaskGet"
114    }
115    fn description(&self) -> &str {
116        "Get the status and output of a task by ID."
117    }
118    fn permission_level(&self) -> PermissionLevel {
119        PermissionLevel::None
120    }
121    fn category(&self) -> ToolCategory {
122        ToolCategory::Orchestration
123    }
124
125    fn input_schema(&self) -> Value {
126        serde_json::json!({
127            "type": "object",
128            "properties": {
129                "id": { "type": "string", "description": "Task ID" }
130            },
131            "required": ["id"]
132        })
133    }
134
135    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
136        #[derive(Deserialize)]
137        struct Input {
138            id: String,
139        }
140
141        let input: Input = match serde_json::from_value(input) {
142            Ok(i) => i,
143            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
144        };
145
146        match get_task(&input.id) {
147            Some(task) => {
148                let output = task.output.as_deref().unwrap_or("(no output yet)");
149                ToolResult::success(format!(
150                    "Task [{}] {:?}\n  {}\n  Output: {}",
151                    task.id, task.status, task.description, output
152                ))
153            }
154            None => ToolResult::error(format!("Task '{}' not found", input.id)),
155        }
156    }
157}
158
159// ─── TaskUpdate ──────────────────────────────────────────────────────────────
160
161pub struct TaskUpdateTool;
162
163#[async_trait]
164impl Tool for TaskUpdateTool {
165    fn name(&self) -> &str {
166        "TaskUpdate"
167    }
168    fn description(&self) -> &str {
169        "Update a task's status and/or output."
170    }
171    fn permission_level(&self) -> PermissionLevel {
172        PermissionLevel::None
173    }
174    fn category(&self) -> ToolCategory {
175        ToolCategory::Orchestration
176    }
177
178    fn input_schema(&self) -> Value {
179        serde_json::json!({
180            "type": "object",
181            "properties": {
182                "id": { "type": "string", "description": "Task ID" },
183                "status": { "type": "string", "enum": ["pending", "running", "completed", "failed", "stopped"] },
184                "output": { "type": "string", "description": "Task output/result text" }
185            },
186            "required": ["id"]
187        })
188    }
189
190    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
191        #[derive(Deserialize)]
192        struct Input {
193            id: String,
194            status: Option<TaskStatus>,
195            output: Option<String>,
196        }
197
198        let input: Input = match serde_json::from_value(input) {
199            Ok(i) => i,
200            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
201        };
202
203        match TASK_REGISTRY.get_mut(&input.id) {
204            Some(mut entry) => {
205                if let Some(status) = input.status {
206                    entry.status = status;
207                }
208                if let Some(output) = input.output {
209                    entry.output = Some(output);
210                }
211                entry.updated_at = chrono::Utc::now().to_rfc3339();
212                ToolResult::success(format!("Task '{}' updated", input.id))
213            }
214            None => ToolResult::error(format!("Task '{}' not found", input.id)),
215        }
216    }
217}
218
219// ─── TaskList ────────────────────────────────────────────────────────────────
220
221pub struct TaskListTool;
222
223#[async_trait]
224impl Tool for TaskListTool {
225    fn name(&self) -> &str {
226        "TaskList"
227    }
228    fn description(&self) -> &str {
229        "List all tasks with their status."
230    }
231    fn permission_level(&self) -> PermissionLevel {
232        PermissionLevel::None
233    }
234    fn category(&self) -> ToolCategory {
235        ToolCategory::Orchestration
236    }
237
238    fn input_schema(&self) -> Value {
239        serde_json::json!({"type": "object", "properties": {}, "required": []})
240    }
241
242    async fn execute(&self, _input: Value, _ctx: &ToolContext) -> ToolResult {
243        let tasks = list_tasks();
244        if tasks.is_empty() {
245            return ToolResult::success("No tasks.");
246        }
247        let lines: Vec<String> = tasks
248            .iter()
249            .map(|t| {
250                let status = format!("{:?}", t.status);
251                format!("- [{}] {} — {}", t.id, status, t.description)
252            })
253            .collect();
254        ToolResult::success(lines.join("\n"))
255    }
256}
257
258// ─── TaskStop ────────────────────────────────────────────────────────────────
259
260pub struct TaskStopTool;
261
262#[async_trait]
263impl Tool for TaskStopTool {
264    fn name(&self) -> &str {
265        "TaskStop"
266    }
267    fn description(&self) -> &str {
268        "Stop/cancel a running task."
269    }
270    fn permission_level(&self) -> PermissionLevel {
271        PermissionLevel::None
272    }
273    fn category(&self) -> ToolCategory {
274        ToolCategory::Orchestration
275    }
276
277    fn input_schema(&self) -> Value {
278        serde_json::json!({
279            "type": "object",
280            "properties": {
281                "id": { "type": "string", "description": "Task ID to stop" }
282            },
283            "required": ["id"]
284        })
285    }
286
287    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
288        #[derive(Deserialize)]
289        struct Input {
290            id: String,
291        }
292
293        let input: Input = match serde_json::from_value(input) {
294            Ok(i) => i,
295            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
296        };
297
298        match TASK_REGISTRY.get_mut(&input.id) {
299            Some(mut entry) => {
300                entry.status = TaskStatus::Stopped;
301                entry.updated_at = chrono::Utc::now().to_rfc3339();
302                ToolResult::success(format!("Task '{}' stopped", input.id))
303            }
304            None => ToolResult::error(format!("Task '{}' not found", input.id)),
305        }
306    }
307}
308
309// ─── TaskOutput ──────────────────────────────────────────────────────────────
310
311pub struct TaskOutputTool;
312
313#[async_trait]
314impl Tool for TaskOutputTool {
315    fn name(&self) -> &str {
316        "TaskOutput"
317    }
318    fn description(&self) -> &str {
319        "Get the full output of a completed task."
320    }
321    fn permission_level(&self) -> PermissionLevel {
322        PermissionLevel::None
323    }
324    fn category(&self) -> ToolCategory {
325        ToolCategory::Orchestration
326    }
327
328    fn input_schema(&self) -> Value {
329        serde_json::json!({
330            "type": "object",
331            "properties": {
332                "id": { "type": "string", "description": "Task ID" }
333            },
334            "required": ["id"]
335        })
336    }
337
338    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
339        #[derive(Deserialize)]
340        struct Input {
341            id: String,
342        }
343
344        let input: Input = match serde_json::from_value(input) {
345            Ok(i) => i,
346            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
347        };
348
349        match get_task(&input.id) {
350            Some(task) => match &task.output {
351                Some(output) => ToolResult::success(output.clone()),
352                None => ToolResult::success("(no output yet)"),
353            },
354            None => ToolResult::error(format!("Task '{}' not found", input.id)),
355        }
356    }
357}
358
359// ─── Tests ───────────────────────────────────────────────────────────────────
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;

    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "task-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    #[tokio::test]
    async fn test_task_full_lifecycle() {
        // Do NOT call `clear_tasks()` here: the registry is process-global and tests
        // run concurrently, so clearing it could delete a task another test (e.g.
        // `test_task_stop`) just created and make that test fail intermittently.
        // Every assertion below only concerns the task created by this test.
        let ctx = ToolContext {
            session_id: format!("task-lifecycle-{}", uuid::Uuid::new_v4()),
            ..test_ctx()
        };

        // Create
        let create = TaskCreateTool;
        let r = create
            .execute(serde_json::json!({"description": "Run tests"}), &ctx)
            .await;
        assert!(!r.is_error);
        // Extract ID from "Task 'XXXXXXXX' created: ..."
        let id = r.content.split('\'').nth(1).unwrap().to_string();

        // List
        let list = TaskListTool;
        let r = list.execute(serde_json::json!({}), &ctx).await;
        assert!(r.content.contains("Run tests"));

        // Update to running
        let update = TaskUpdateTool;
        let r = update
            .execute(serde_json::json!({"id": &id, "status": "running"}), &ctx)
            .await;
        assert!(!r.is_error);
        assert_eq!(get_task(&id).unwrap().status, TaskStatus::Running);

        // Update with output
        let r = update
            .execute(
                serde_json::json!({
                    "id": &id,
                    "status": "completed",
                    "output": "All 42 tests passed"
                }),
                &ctx,
            )
            .await;
        assert!(!r.is_error);
        let task = get_task(&id).unwrap();
        assert_eq!(task.status, TaskStatus::Completed);
        assert_eq!(task.output.as_deref(), Some("All 42 tests passed"));

        // Get output
        let output = TaskOutputTool;
        let r = output.execute(serde_json::json!({"id": &id}), &ctx).await;
        assert!(r.content.contains("42 tests passed"));

        // Get status
        let get = TaskGetTool;
        let r = get.execute(serde_json::json!({"id": &id}), &ctx).await;
        assert!(r.content.contains("Completed"));
    }

    #[tokio::test]
    async fn test_task_stop() {
        let ctx = ToolContext {
            session_id: format!("stop-{}", uuid::Uuid::new_v4()),
            ..test_ctx()
        };

        let create = TaskCreateTool;
        let r = create
            .execute(serde_json::json!({"description": "Long task"}), &ctx)
            .await;
        assert!(!r.is_error);
        let id = r.content.split('\'').nth(1).unwrap().to_string();

        let stop = TaskStopTool;
        let r = stop.execute(serde_json::json!({"id": &id}), &ctx).await;
        assert!(!r.is_error);
        assert_eq!(get_task(&id).unwrap().status, TaskStatus::Stopped);
    }

    #[tokio::test]
    async fn test_unknown_task_id_is_error() {
        let ctx = test_ctx();
        // Generated ids are 8 hex characters, so this id can never be registered.
        let missing = serde_json::json!({"id": "no-such-task"});

        assert!(TaskGetTool.execute(missing.clone(), &ctx).await.is_error);
        assert!(TaskOutputTool.execute(missing.clone(), &ctx).await.is_error);
        assert!(TaskStopTool.execute(missing.clone(), &ctx).await.is_error);
        assert!(TaskUpdateTool.execute(missing, &ctx).await.is_error);
    }

    #[tokio::test]
    async fn test_update_rejects_unknown_status() {
        let ctx = test_ctx();
        let r = TaskUpdateTool
            .execute(
                serde_json::json!({"id": "no-such-task", "status": "bogus"}),
                &ctx,
            )
            .await;
        assert!(r.is_error);
        assert!(r.content.contains("Invalid input"));
    }
}