1use super::*;
7use serde::{Deserialize, Serialize};
8
9static TASK_REGISTRY: once_cell::sync::Lazy<dashmap::DashMap<String, TaskEntry>> =
12 once_cell::sync::Lazy::new(dashmap::DashMap::new);
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TaskEntry {
16 pub id: String,
17 pub description: String,
18 pub status: TaskStatus,
19 pub output: Option<String>,
20 pub created_at: String,
21 pub updated_at: String,
22 pub session_id: String,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
26#[serde(rename_all = "snake_case")]
27pub enum TaskStatus {
28 Pending,
29 Running,
30 Completed,
31 Failed,
32 Stopped,
33}
34
35pub fn get_task(id: &str) -> Option<TaskEntry> {
36 TASK_REGISTRY.get(id).map(|e| e.clone())
37}
38
39pub fn list_tasks() -> Vec<TaskEntry> {
40 TASK_REGISTRY.iter().map(|e| e.value().clone()).collect()
41}
42
43pub fn clear_tasks() {
44 TASK_REGISTRY.clear();
45}
46
47pub struct TaskCreateTool;
50
51#[async_trait]
52impl Tool for TaskCreateTool {
53 fn name(&self) -> &str {
54 "TaskCreate"
55 }
56 fn description(&self) -> &str {
57 "Create a new task for tracking sub-agent work."
58 }
59 fn permission_level(&self) -> PermissionLevel {
60 PermissionLevel::None
61 }
62 fn category(&self) -> ToolCategory {
63 ToolCategory::Orchestration
64 }
65
66 fn input_schema(&self) -> Value {
67 serde_json::json!({
68 "type": "object",
69 "properties": {
70 "description": { "type": "string", "description": "What this task does" },
71 "prompt": { "type": "string", "description": "The prompt for the sub-agent (optional)" }
72 },
73 "required": ["description"]
74 })
75 }
76
77 async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
78 #[derive(Deserialize)]
79 #[allow(dead_code)]
80 struct Input {
81 description: String,
82 prompt: Option<String>,
83 }
84
85 let input: Input = match serde_json::from_value(input) {
86 Ok(i) => i,
87 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
88 };
89
90 let id = uuid::Uuid::new_v4().to_string()[..8].to_string();
91 let now = chrono::Utc::now().to_rfc3339();
92 let task = TaskEntry {
93 id: id.clone(),
94 description: input.description.clone(),
95 status: TaskStatus::Pending,
96 output: None,
97 created_at: now.clone(),
98 updated_at: now,
99 session_id: ctx.session_id.clone(),
100 };
101 TASK_REGISTRY.insert(id.clone(), task);
102 ToolResult::success(format!("Task '{}' created: {}", id, input.description))
103 }
104}
105
106pub struct TaskGetTool;
109
110#[async_trait]
111impl Tool for TaskGetTool {
112 fn name(&self) -> &str {
113 "TaskGet"
114 }
115 fn description(&self) -> &str {
116 "Get the status and output of a task by ID."
117 }
118 fn permission_level(&self) -> PermissionLevel {
119 PermissionLevel::None
120 }
121 fn category(&self) -> ToolCategory {
122 ToolCategory::Orchestration
123 }
124
125 fn input_schema(&self) -> Value {
126 serde_json::json!({
127 "type": "object",
128 "properties": {
129 "id": { "type": "string", "description": "Task ID" }
130 },
131 "required": ["id"]
132 })
133 }
134
135 async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
136 #[derive(Deserialize)]
137 struct Input {
138 id: String,
139 }
140
141 let input: Input = match serde_json::from_value(input) {
142 Ok(i) => i,
143 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
144 };
145
146 match get_task(&input.id) {
147 Some(task) => {
148 let output = task.output.as_deref().unwrap_or("(no output yet)");
149 ToolResult::success(format!(
150 "Task [{}] {:?}\n {}\n Output: {}",
151 task.id, task.status, task.description, output
152 ))
153 }
154 None => ToolResult::error(format!("Task '{}' not found", input.id)),
155 }
156 }
157}
158
159pub struct TaskUpdateTool;
162
163#[async_trait]
164impl Tool for TaskUpdateTool {
165 fn name(&self) -> &str {
166 "TaskUpdate"
167 }
168 fn description(&self) -> &str {
169 "Update a task's status and/or output."
170 }
171 fn permission_level(&self) -> PermissionLevel {
172 PermissionLevel::None
173 }
174 fn category(&self) -> ToolCategory {
175 ToolCategory::Orchestration
176 }
177
178 fn input_schema(&self) -> Value {
179 serde_json::json!({
180 "type": "object",
181 "properties": {
182 "id": { "type": "string", "description": "Task ID" },
183 "status": { "type": "string", "enum": ["pending", "running", "completed", "failed", "stopped"] },
184 "output": { "type": "string", "description": "Task output/result text" }
185 },
186 "required": ["id"]
187 })
188 }
189
190 async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
191 #[derive(Deserialize)]
192 struct Input {
193 id: String,
194 status: Option<TaskStatus>,
195 output: Option<String>,
196 }
197
198 let input: Input = match serde_json::from_value(input) {
199 Ok(i) => i,
200 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
201 };
202
203 match TASK_REGISTRY.get_mut(&input.id) {
204 Some(mut entry) => {
205 if let Some(status) = input.status {
206 entry.status = status;
207 }
208 if let Some(output) = input.output {
209 entry.output = Some(output);
210 }
211 entry.updated_at = chrono::Utc::now().to_rfc3339();
212 ToolResult::success(format!("Task '{}' updated", input.id))
213 }
214 None => ToolResult::error(format!("Task '{}' not found", input.id)),
215 }
216 }
217}
218
219pub struct TaskListTool;
222
223#[async_trait]
224impl Tool for TaskListTool {
225 fn name(&self) -> &str {
226 "TaskList"
227 }
228 fn description(&self) -> &str {
229 "List all tasks with their status."
230 }
231 fn permission_level(&self) -> PermissionLevel {
232 PermissionLevel::None
233 }
234 fn category(&self) -> ToolCategory {
235 ToolCategory::Orchestration
236 }
237
238 fn input_schema(&self) -> Value {
239 serde_json::json!({"type": "object", "properties": {}, "required": []})
240 }
241
242 async fn execute(&self, _input: Value, _ctx: &ToolContext) -> ToolResult {
243 let tasks = list_tasks();
244 if tasks.is_empty() {
245 return ToolResult::success("No tasks.");
246 }
247 let lines: Vec<String> = tasks
248 .iter()
249 .map(|t| {
250 let status = format!("{:?}", t.status);
251 format!("- [{}] {} — {}", t.id, status, t.description)
252 })
253 .collect();
254 ToolResult::success(lines.join("\n"))
255 }
256}
257
258pub struct TaskStopTool;
261
262#[async_trait]
263impl Tool for TaskStopTool {
264 fn name(&self) -> &str {
265 "TaskStop"
266 }
267 fn description(&self) -> &str {
268 "Stop/cancel a running task."
269 }
270 fn permission_level(&self) -> PermissionLevel {
271 PermissionLevel::None
272 }
273 fn category(&self) -> ToolCategory {
274 ToolCategory::Orchestration
275 }
276
277 fn input_schema(&self) -> Value {
278 serde_json::json!({
279 "type": "object",
280 "properties": {
281 "id": { "type": "string", "description": "Task ID to stop" }
282 },
283 "required": ["id"]
284 })
285 }
286
287 async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
288 #[derive(Deserialize)]
289 struct Input {
290 id: String,
291 }
292
293 let input: Input = match serde_json::from_value(input) {
294 Ok(i) => i,
295 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
296 };
297
298 match TASK_REGISTRY.get_mut(&input.id) {
299 Some(mut entry) => {
300 entry.status = TaskStatus::Stopped;
301 entry.updated_at = chrono::Utc::now().to_rfc3339();
302 ToolResult::success(format!("Task '{}' stopped", input.id))
303 }
304 None => ToolResult::error(format!("Task '{}' not found", input.id)),
305 }
306 }
307}
308
309pub struct TaskOutputTool;
312
313#[async_trait]
314impl Tool for TaskOutputTool {
315 fn name(&self) -> &str {
316 "TaskOutput"
317 }
318 fn description(&self) -> &str {
319 "Get the full output of a completed task."
320 }
321 fn permission_level(&self) -> PermissionLevel {
322 PermissionLevel::None
323 }
324 fn category(&self) -> ToolCategory {
325 ToolCategory::Orchestration
326 }
327
328 fn input_schema(&self) -> Value {
329 serde_json::json!({
330 "type": "object",
331 "properties": {
332 "id": { "type": "string", "description": "Task ID" }
333 },
334 "required": ["id"]
335 })
336 }
337
338 async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
339 #[derive(Deserialize)]
340 struct Input {
341 id: String,
342 }
343
344 let input: Input = match serde_json::from_value(input) {
345 Ok(i) => i,
346 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
347 };
348
349 match get_task(&input.id) {
350 Some(task) => match &task.output {
351 Some(output) => ToolResult::success(output.clone()),
352 None => ToolResult::success("(no output yet)"),
353 },
354 None => ToolResult::error(format!("Task '{}' not found", input.id)),
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;

    // NOTE: `TASK_REGISTRY` is a process-wide global, and the test harness
    // runs tests concurrently. Each test therefore uses its own session and
    // asserts only on the task IDs it created. Never call `clear_tasks()`
    // here: it would race with other tests and delete their tasks mid-run.

    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "task-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Extracts the ID from a `Task '<id>' created: ...` success message.
    fn created_id(r: &ToolResult) -> String {
        r.content
            .split('\'')
            .nth(1)
            .expect("TaskCreate message should quote the task ID")
            .to_string()
    }

    #[tokio::test]
    async fn test_task_full_lifecycle() {
        let ctx = ToolContext {
            session_id: format!("task-lifecycle-{}", uuid::Uuid::new_v4()),
            ..test_ctx()
        };

        let r = TaskCreateTool
            .execute(serde_json::json!({"description": "Run tests"}), &ctx)
            .await;
        assert!(!r.is_error, "create failed: {}", r.content);
        let id = created_id(&r);

        let r = TaskListTool.execute(serde_json::json!({}), &ctx).await;
        assert!(r.content.contains(&id));
        assert!(r.content.contains("Run tests"));

        let update = TaskUpdateTool;
        let r = update
            .execute(serde_json::json!({"id": &id, "status": "running"}), &ctx)
            .await;
        assert!(!r.is_error, "update to running failed: {}", r.content);
        assert_eq!(get_task(&id).unwrap().status, TaskStatus::Running);

        let r = update
            .execute(
                serde_json::json!({
                    "id": &id,
                    "status": "completed",
                    "output": "All 42 tests passed"
                }),
                &ctx,
            )
            .await;
        assert!(!r.is_error, "update to completed failed: {}", r.content);
        let task = get_task(&id).unwrap();
        assert_eq!(task.status, TaskStatus::Completed);
        assert_eq!(task.output.as_deref(), Some("All 42 tests passed"));

        let r = TaskOutputTool
            .execute(serde_json::json!({"id": &id}), &ctx)
            .await;
        assert!(r.content.contains("42 tests passed"));

        let r = TaskGetTool.execute(serde_json::json!({"id": &id}), &ctx).await;
        assert!(r.content.contains("Completed"));
    }

    #[tokio::test]
    async fn test_task_stop() {
        let ctx = ToolContext {
            session_id: format!("stop-{}", uuid::Uuid::new_v4()),
            ..test_ctx()
        };

        let r = TaskCreateTool
            .execute(serde_json::json!({"description": "Long task"}), &ctx)
            .await;
        assert!(!r.is_error, "create failed: {}", r.content);
        let id = created_id(&r);

        let r = TaskStopTool.execute(serde_json::json!({"id": &id}), &ctx).await;
        assert!(!r.is_error, "stop failed: {}", r.content);
        assert_eq!(get_task(&id).unwrap().status, TaskStatus::Stopped);
    }
}