1use crate::error::AgentError;
10use crate::types::*;
11use std::collections::HashMap;
12use std::sync::{
13 Mutex, OnceLock,
14 atomic::{AtomicU64, Ordering},
15};
16
/// Name reported by `TaskCreateTool::name`.
pub const TASK_CREATE_TOOL_NAME: &str = "TaskCreate";
/// Name reported by `TaskGetTool::name`.
pub const TASK_GET_TOOL_NAME: &str = "TaskGet";
/// Name reported by `TaskListTool::name`.
pub const TASK_LIST_TOOL_NAME: &str = "TaskList";
/// Name reported by `TaskUpdateTool::name`.
pub const TASK_UPDATE_TOOL_NAME: &str = "TaskUpdate";

/// Process-wide task store keyed by task id; created lazily by `get_tasks_map`.
static TASKS: OnceLock<Mutex<HashMap<String, Task>>> = OnceLock::new();
/// Next numeric suffix handed out by `next_task_id`; set back to 1 by `reset_task_store`.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);
25
26fn get_tasks_map() -> &'static Mutex<HashMap<String, Task>> {
27 TASKS.get_or_init(|| Mutex::new(HashMap::new()))
28}
29
30pub fn reset_task_store() {
31 let mut guard = get_tasks_map().lock().unwrap();
32 guard.clear();
33 drop(guard);
34 TASK_COUNTER.store(1, Ordering::SeqCst);
35}
36
/// Returns the process-wide lock that tests hold (via `test_setup`) so they do
/// not run concurrently against the shared global task store.
#[cfg(test)]
pub fn get_test_lock() -> &'static Mutex<()> {
    // `Mutex::new` is `const`, so the lock can live directly in a static.
    static LOCK: Mutex<()> = Mutex::new(());
    &LOCK
}
45
46pub fn get_unfinished_tasks() -> Vec<Task> {
48 let guard = get_tasks_map().lock().unwrap();
49 guard
50 .values()
51 .filter(|t| t.status != "completed" && t.status != "deleted")
52 .cloned()
53 .collect()
54}
55
56pub fn get_all_tasks() -> Vec<Task> {
58 let guard = get_tasks_map().lock().unwrap();
59 guard.values().cloned().collect()
60}
61
62fn next_task_id() -> String {
63 let id = TASK_COUNTER.fetch_add(1, Ordering::SeqCst);
64 format!("task-{}", id)
65}
66
/// One entry in the global task store.
///
/// Serialized with serde in field-declaration order. `active_form` and
/// `internal` are renamed on the wire (`activeForm`, `_internal`); all other
/// fields keep their Rust names.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Task {
    /// Store key; tasks created by `TaskCreateTool` get `task-N` ids.
    pub id: String,
    /// Short title of the task.
    pub subject: String,
    /// What needs to be done.
    pub description: String,
    /// Lifecycle status. New tasks start as `"pending"`; the update tool's
    /// schema lists `pending`, `in_progress`, `completed` and `deleted`.
    pub status: String,
    /// Present-continuous label (per the create tool's schema, shown in a
    /// spinner while the task is in progress).
    #[serde(rename = "activeForm")]
    pub active_form: Option<String>,
    /// Optional owner; `TaskListTool` shows it in square brackets.
    pub owner: Option<String>,
    /// Ids of tasks that this task blocks.
    pub blocks: Vec<String>,
    /// Ids of tasks that block this task.
    // NOTE(review): not renamed, so it serializes as `blocked_by` while the
    // tools' JSON uses `blockedBy` — confirm this asymmetry is intended.
    pub blocked_by: Vec<String>,
    /// When `Some(true)`, the task is hidden from `TaskListTool` output.
    #[serde(rename = "_internal")]
    pub internal: Option<bool>,
}
82
83impl Task {
84 fn new(id: String, subject: String, description: String, active_form: Option<String>) -> Self {
85 Self {
86 id,
87 subject,
88 description,
89 status: "pending".to_string(),
90 active_form,
91 owner: None,
92 blocks: vec![],
93 blocked_by: vec![],
94 internal: None,
95 }
96 }
97}
98
99pub struct TaskCreateTool;
101
102impl TaskCreateTool {
103 pub fn new() -> Self {
104 Self
105 }
106
107 pub fn name(&self) -> &str {
108 TASK_CREATE_TOOL_NAME
109 }
110
111 pub fn description(&self) -> &str {
112 "Create a new task in the task list. Tasks can be tracked with status and can block other tasks."
113 }
114
115 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
116 "TaskCreate".to_string()
117 }
118
119 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
120 input.and_then(|inp| inp["subject"].as_str().map(String::from))
121 }
122
123 pub fn render_tool_result_message(
124 &self,
125 content: &serde_json::Value,
126 ) -> Option<String> {
127 content["content"].as_str().map(|s| s.to_string())
128 }
129
130 pub fn input_schema(&self) -> ToolInputSchema {
131 ToolInputSchema {
132 schema_type: "object".to_string(),
133 properties: serde_json::json!({
134 "subject": {
135 "type": "string",
136 "description": "A brief title for the task"
137 },
138 "description": {
139 "type": "string",
140 "description": "What needs to be done"
141 },
142 "activeForm": {
143 "type": "string",
144 "description": "Present continuous form shown in spinner when in_progress"
145 }
146 }),
147 required: Some(vec!["subject".to_string(), "description".to_string()]),
148 }
149 }
150
151 pub async fn execute(
152 &self,
153 input: serde_json::Value,
154 _context: &ToolContext,
155 ) -> Result<ToolResult, AgentError> {
156 let subject = input["subject"]
157 .as_str()
158 .ok_or_else(|| AgentError::Tool("subject is required".to_string()))?
159 .to_string();
160
161 let description = input["description"]
162 .as_str()
163 .ok_or_else(|| AgentError::Tool("description is required".to_string()))?
164 .to_string();
165
166 let active_form = input["activeForm"].as_str().map(|s| s.to_string());
167
168 let id = next_task_id();
169 let task = Task::new(
170 id.clone(),
171 subject.clone(),
172 description.clone(),
173 active_form.clone(),
174 );
175
176 let mut guard = get_tasks_map().lock().unwrap();
177 guard.insert(id.clone(), task);
178 drop(guard);
179
180 Ok(ToolResult {
181 result_type: "text".to_string(),
182 tool_use_id: "".to_string(),
183 content: format!(
184 "Task created: {}\nSubject: {}\nID: {}",
185 id,
186 subject.clone(),
187 id
188 ),
189 is_error: Some(false),
190 was_persisted: None,
191 })
192 }
193}
194
195impl Default for TaskCreateTool {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201pub struct TaskListTool;
203
204impl TaskListTool {
205 pub fn new() -> Self {
206 Self
207 }
208
209 pub fn name(&self) -> &str {
210 TASK_LIST_TOOL_NAME
211 }
212
213 pub fn description(&self) -> &str {
214 "List all tasks in the task list. Shows task ID, subject, status, and blocking information."
215 }
216
217 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
218 "TaskList".to_string()
219 }
220
221 pub fn get_tool_use_summary(&self, _input: Option<&serde_json::Value>) -> Option<String> {
222 None
223 }
224
225 pub fn render_tool_result_message(
226 &self,
227 content: &serde_json::Value,
228 ) -> Option<String> {
229 let text = content["content"].as_str()?;
230 let lines = text.lines().count();
231 Some(format!("{} lines", lines))
232 }
233
234 pub fn input_schema(&self) -> ToolInputSchema {
235 ToolInputSchema {
236 schema_type: "object".to_string(),
237 properties: serde_json::json!({}),
238 required: None,
239 }
240 }
241
242 pub async fn execute(
243 &self,
244 _input: serde_json::Value,
245 _context: &ToolContext,
246 ) -> Result<ToolResult, AgentError> {
247 let guard = get_tasks_map().lock().unwrap();
248
249 let tasks: Vec<&Task> = guard
251 .values()
252 .filter(|t| t.internal != Some(true) && t.status != "deleted")
253 .collect();
254
255 if tasks.is_empty() {
256 return Ok(ToolResult {
257 result_type: "text".to_string(),
258 tool_use_id: "".to_string(),
259 content: "No tasks.".to_string(),
260 is_error: None,
261 was_persisted: None,
262 });
263 }
264
265 let lines: Vec<String> = tasks
266 .iter()
267 .map(|t| {
268 let blocking_note = if !t.blocks.is_empty() {
269 format!(" (blocks: {})", t.blocks.join(", "))
270 } else {
271 String::new()
272 };
273 let owner_note = if let Some(owner) = &t.owner {
274 format!(" [{}]", owner)
275 } else {
276 String::new()
277 };
278 format!(
279 "{}. {} [{}] - {}{}{}",
280 t.id,
281 t.subject,
282 t.status,
283 t.active_form.as_deref().unwrap_or(""),
284 owner_note,
285 blocking_note
286 )
287 })
288 .collect();
289
290 Ok(ToolResult {
291 result_type: "text".to_string(),
292 tool_use_id: "".to_string(),
293 content: format!("Tasks:\n{}", lines.join("\n")),
294 is_error: Some(false),
295 was_persisted: None,
296 })
297 }
298}
299
300impl Default for TaskListTool {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
306pub struct TaskUpdateTool;
308
309impl TaskUpdateTool {
310 pub fn new() -> Self {
311 Self
312 }
313
314 pub fn name(&self) -> &str {
315 TASK_UPDATE_TOOL_NAME
316 }
317
318 pub fn description(&self) -> &str {
319 "Update an existing task's status, subject, description, or other fields."
320 }
321
322 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
323 "TaskUpdate".to_string()
324 }
325
326 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
327 input.and_then(|inp| inp["taskId"].as_str().map(String::from))
328 }
329
330 pub fn render_tool_result_message(
331 &self,
332 content: &serde_json::Value,
333 ) -> Option<String> {
334 content["content"].as_str().map(|s| s.to_string())
335 }
336
337 pub fn input_schema(&self) -> ToolInputSchema {
338 ToolInputSchema {
339 schema_type: "object".to_string(),
340 properties: serde_json::json!({
341 "taskId": {
342 "type": "string",
343 "description": "The ID of the task to update"
344 },
345 "subject": {
346 "type": "string",
347 "description": "New subject for the task"
348 },
349 "description": {
350 "type": "string",
351 "description": "New description for the task"
352 },
353 "status": {
354 "type": "string",
355 "enum": ["pending", "in_progress", "completed", "deleted"],
356 "description": "New status for the task"
357 },
358 "activeForm": {
359 "type": "string",
360 "description": "New active form"
361 },
362 "owner": {
363 "type": "string",
364 "description": "New owner for the task"
365 },
366 "blocks": {
367 "type": "array",
368 "items": { "type": "string" },
369 "description": "Task IDs that this task blocks"
370 },
371 "blockedBy": {
372 "type": "array",
373 "items": { "type": "string" },
374 "description": "Task IDs that block this task"
375 }
376 }),
377 required: Some(vec!["taskId".to_string()]),
378 }
379 }
380
381 pub async fn execute(
382 &self,
383 input: serde_json::Value,
384 _context: &ToolContext,
385 ) -> Result<ToolResult, AgentError> {
386 let task_id = input["taskId"]
387 .as_str()
388 .ok_or_else(|| AgentError::Tool("taskId is required".to_string()))?;
389
390 let mut guard = get_tasks_map().lock().unwrap();
391 let task = guard
392 .get_mut(task_id)
393 .ok_or_else(|| AgentError::Tool(format!("Task '{}' not found", task_id)))?;
394
395 let mut changes: Vec<String> = Vec::new();
396
397 let old_status = task.status.clone();
398
399 if let Some(subject) = input["subject"].as_str() {
400 task.subject = subject.to_string();
401 changes.push("subject".to_string());
402 }
403 if let Some(description) = input["description"].as_str() {
404 task.description = description.to_string();
405 changes.push("description".to_string());
406 }
407 if let Some(status) = input["status"].as_str() {
408 task.status = status.to_string();
409 changes.push(format!("status: {} -> {}", old_status, status));
410 }
411 if let Some(active_form) = input["activeForm"].as_str() {
412 task.active_form = Some(active_form.to_string());
413 changes.push("activeForm".to_string());
414 }
415 if let Some(owner) = input["owner"].as_str() {
416 task.owner = Some(owner.to_string());
417 changes.push(format!("owner -> {}", owner));
418 }
419 if let Some(blocks) = input["blocks"].as_array() {
420 task.blocks = blocks
421 .iter()
422 .filter_map(|v| v.as_str().map(|s| s.to_string()))
423 .collect();
424 changes.push("blocks".to_string());
425 }
426 if let Some(blocked_by) = input["blockedBy"].as_array() {
427 task.blocked_by = blocked_by
428 .iter()
429 .filter_map(|v| v.as_str().map(|s| s.to_string()))
430 .collect();
431 changes.push("blockedBy".to_string());
432 }
433
434 drop(guard);
435
436 let changes_str = if changes.is_empty() {
437 "no changes".to_string()
438 } else {
439 changes.join(", ")
440 };
441
442 Ok(ToolResult {
443 result_type: "text".to_string(),
444 tool_use_id: "".to_string(),
445 content: format!("Task {} updated: {}", task_id, changes_str),
446 is_error: Some(false),
447 was_persisted: None,
448 })
449 }
450}
451
452impl Default for TaskUpdateTool {
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
458pub struct TaskGetTool;
460
461impl TaskGetTool {
462 pub fn new() -> Self {
463 Self
464 }
465
466 pub fn name(&self) -> &str {
467 TASK_GET_TOOL_NAME
468 }
469
470 pub fn description(&self) -> &str {
471 "Get details of a specific task by ID."
472 }
473
474 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
475 "TaskGet".to_string()
476 }
477
478 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
479 input.and_then(|inp| inp["taskId"].as_str().map(String::from))
480 }
481
482 pub fn render_tool_result_message(
483 &self,
484 content: &serde_json::Value,
485 ) -> Option<String> {
486 let text = content["content"].as_str()?;
487 let lines = text.lines().count();
488 Some(format!("{} lines", lines))
489 }
490
491 pub fn input_schema(&self) -> ToolInputSchema {
492 ToolInputSchema {
493 schema_type: "object".to_string(),
494 properties: serde_json::json!({
495 "taskId": {
496 "type": "string",
497 "description": "The ID of the task to retrieve"
498 }
499 }),
500 required: Some(vec!["taskId".to_string()]),
501 }
502 }
503
504 pub async fn execute(
505 &self,
506 input: serde_json::Value,
507 _context: &ToolContext,
508 ) -> Result<ToolResult, AgentError> {
509 let task_id = input["taskId"]
510 .as_str()
511 .ok_or_else(|| AgentError::Tool("taskId is required".to_string()))?;
512
513 let guard = get_tasks_map().lock().unwrap();
514 let task = guard
515 .get(task_id)
516 .ok_or_else(|| AgentError::Tool(format!("Task '{}' not found", task_id)))?;
517
518 let content = serde_json::to_string_pretty(&serde_json::json!({
519 "id": task.id,
520 "subject": task.subject,
521 "description": task.description,
522 "status": task.status,
523 "activeForm": task.active_form,
524 "owner": task.owner,
525 "blocks": task.blocks,
526 "blockedBy": task.blocked_by
527 }))
528 .unwrap_or_default();
529
530 Ok(ToolResult {
531 result_type: "text".to_string(),
532 tool_use_id: "".to_string(),
533 content,
534 is_error: Some(false),
535 was_persisted: None,
536 })
537 }
538}
539
540impl Default for TaskGetTool {
541 fn default() -> Self {
542 Self::new()
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests on the global task store and resets it to empty.
    ///
    /// A test that fails while holding the lock poisons it. The guard is
    /// recovered instead of unwrapped, so one failure does not make every later
    /// test panic in setup.
    fn test_setup() -> std::sync::MutexGuard<'static, ()> {
        let guard = get_test_lock()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        reset_task_store();
        guard
    }

    /// Extracts the task id from the `ID: ` line of a `TaskCreateTool` reply.
    fn created_task_id(content: &str) -> String {
        content
            .lines()
            .find_map(|line| line.strip_prefix("ID: "))
            .expect("TaskCreate reply should contain an `ID: ` line")
            .trim()
            .to_string()
    }

    #[tokio::test]
    async fn test_task_create_and_get() {
        let _lock = test_setup();

        let create = TaskCreateTool::new();
        let result = create
            .execute(
                serde_json::json!({
                    "subject": "Test Task",
                    "description": "A test task",
                    "activeForm": "Testing"
                }),
                &ToolContext::default(),
            )
            .await;
        assert!(result.is_ok());
        let task_id = created_task_id(&result.unwrap().content);

        let get = TaskGetTool::new();
        let get_result = get
            .execute(
                serde_json::json!({ "taskId": task_id }),
                &ToolContext::default(),
            )
            .await;
        assert!(get_result.is_ok());
        assert!(get_result.unwrap().content.contains("Test Task"));
    }

    #[tokio::test]
    async fn test_task_get_unknown_id_is_error() {
        let _lock = test_setup();

        let result = TaskGetTool::new()
            .execute(
                serde_json::json!({ "taskId": "task-999" }),
                &ToolContext::default(),
            )
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_task_list() {
        let _lock = test_setup();

        let create = TaskCreateTool::new();
        create
            .execute(
                serde_json::json!({ "subject": "Task A", "description": "Desc A" }),
                &ToolContext::default(),
            )
            .await
            .unwrap();

        let list = TaskListTool::new();
        let result = list
            .execute(serde_json::json!({}), &ToolContext::default())
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().content.contains("Task A"));
    }

    #[tokio::test]
    async fn test_task_update_status() {
        let _lock = test_setup();

        let create = TaskCreateTool::new();
        let create_result = create
            .execute(
                serde_json::json!({
                    "subject": "Update Me",
                    "description": "To be updated"
                }),
                &ToolContext::default(),
            )
            .await
            .unwrap();
        let task_id = created_task_id(&create_result.content);

        let update = TaskUpdateTool::new();
        let result = update
            .execute(
                serde_json::json!({
                    "taskId": task_id,
                    "status": "in_progress"
                }),
                &ToolContext::default(),
            )
            .await;
        assert!(result.is_ok());

        let get = TaskGetTool::new();
        let get_result = get
            .execute(
                serde_json::json!({ "taskId": task_id }),
                &ToolContext::default(),
            )
            .await;
        assert!(get_result.unwrap().content.contains("in_progress"));
    }
}