1use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
17use super::base::{PermissionCheckResult, Tool};
18use super::context::{ToolContext, ToolOptions, ToolResult};
19use super::error::ToolError;
20
/// Lifecycle state of a single todo entry.
///
/// Serialized through serde using `snake_case` variant names, i.e.
/// `"pending"`, `"in_progress"` and `"completed"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TodoStatus {
    /// Initial state; this is the variant produced by `Default`.
    #[default]
    Pending,
    /// The entry is currently being worked on.
    InProgress,
    /// The entry is finished.
    Completed,
}
33
/// One entry of a todo list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItem {
    /// Task description (imperative form, e.g. "Run tests").
    pub content: String,
    /// Current lifecycle state of the task.
    pub status: TodoStatus,
    /// Present continuous form of the task (e.g. "Running tests").
    pub active_form: String,
}
44
45impl TodoItem {
46 pub fn new(content: impl Into<String>, active_form: impl Into<String>) -> Self {
48 Self {
49 content: content.into(),
50 status: TodoStatus::Pending,
51 active_form: active_form.into(),
52 }
53 }
54
55 pub fn with_status(
57 content: impl Into<String>,
58 active_form: impl Into<String>,
59 status: TodoStatus,
60 ) -> Self {
61 Self {
62 content: content.into(),
63 status,
64 active_form: active_form.into(),
65 }
66 }
67
68 pub fn is_in_progress(&self) -> bool {
70 self.status == TodoStatus::InProgress
71 }
72
73 pub fn is_completed(&self) -> bool {
75 self.status == TodoStatus::Completed
76 }
77}
78
/// Parameters accepted by the todo-write tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoWriteInput {
    /// The complete, updated todo list supplied by the caller.
    pub todos: Vec<TodoItem>,
}
85
/// In-memory store of todo lists, one list per agent id.
#[derive(Debug, Default)]
pub struct TodoStorage {
    /// Map from agent id to that agent's todo list, guarded by an `RwLock`.
    storage: RwLock<HashMap<String, Vec<TodoItem>>>,
}
92
93impl TodoStorage {
94 pub fn new() -> Self {
96 Self::default()
97 }
98
99 pub fn get_todos(&self, agent_id: &str) -> Vec<TodoItem> {
101 self.storage
102 .read()
103 .unwrap()
104 .get(agent_id)
105 .cloned()
106 .unwrap_or_default()
107 }
108
109 pub fn set_todos(&self, agent_id: &str, todos: Vec<TodoItem>) {
111 let mut storage = self.storage.write().unwrap();
112 if todos.is_empty() {
113 storage.remove(agent_id);
114 } else {
115 storage.insert(agent_id.to_string(), todos);
116 }
117 }
118
119 pub fn get_stats(&self) -> HashMap<String, (usize, usize, usize)> {
121 let storage = self.storage.read().unwrap();
122 storage
123 .iter()
124 .map(|(agent_id, todos)| {
125 let pending = todos
126 .iter()
127 .filter(|t| t.status == TodoStatus::Pending)
128 .count();
129 let in_progress = todos
130 .iter()
131 .filter(|t| t.status == TodoStatus::InProgress)
132 .count();
133 let completed = todos
134 .iter()
135 .filter(|t| t.status == TodoStatus::Completed)
136 .count();
137 (agent_id.clone(), (pending, in_progress, completed))
138 })
139 .collect()
140 }
141}
142
/// Tool that writes an agent's todo list into a [`TodoStorage`].
#[derive(Debug)]
pub struct TodoWriteTool {
    /// Backing store, held behind an `Arc` so the same store can be handed
    /// to the tool from outside.
    storage: Arc<TodoStorage>,
    /// Fallback agent id used when the tool context does not determine one.
    default_agent_id: String,
}
158
159impl Default for TodoWriteTool {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl TodoWriteTool {
166 pub fn new() -> Self {
168 Self {
169 storage: Arc::new(TodoStorage::new()),
170 default_agent_id: "main".to_string(),
171 }
172 }
173
174 pub fn with_storage(storage: Arc<TodoStorage>) -> Self {
176 Self {
177 storage,
178 default_agent_id: "main".to_string(),
179 }
180 }
181
182 pub fn with_default_agent_id(mut self, agent_id: impl Into<String>) -> Self {
184 self.default_agent_id = agent_id.into();
185 self
186 }
187
188 pub fn storage(&self) -> &Arc<TodoStorage> {
190 &self.storage
191 }
192
193 fn validate_todos(&self, todos: &[TodoItem]) -> Result<(), String> {
195 let in_progress_count = todos.iter().filter(|t| t.is_in_progress()).count();
197 if in_progress_count > 1 {
198 return Err("Only one task can be in_progress at a time".to_string());
199 }
200
201 for todo in todos {
203 if todo.content.trim().is_empty() {
204 return Err("Task content cannot be empty".to_string());
205 }
206 if todo.active_form.trim().is_empty() {
207 return Err("Task active_form cannot be empty".to_string());
208 }
209 }
210
211 Ok(())
212 }
213
214 fn get_agent_id(&self, context: &ToolContext) -> String {
216 context
218 .environment
219 .get("AGENT_ID")
220 .cloned()
221 .unwrap_or_else(|| {
222 if context.session_id.is_empty() {
223 self.default_agent_id.clone()
224 } else {
225 context.session_id.clone()
226 }
227 })
228 }
229}
230
231#[async_trait]
232impl Tool for TodoWriteTool {
233 fn name(&self) -> &str {
235 "TodoWrite"
236 }
237
238 fn description(&self) -> &str {
240 "Use this tool to create and manage a structured task list for your current coding session. \
241 This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user. \
242 It also helps the user understand the progress of the task and overall progress of their requests.\n\n\
243 ## When to Use This Tool\n\
244 Use this tool proactively in these scenarios:\n\
245 1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions\n\
246 2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations\n\
247 3. User explicitly requests todo list - When the user directly asks you to use the todo list\n\
248 4. User provides multiple tasks - When users provide a list of things to be done\n\
249 5. After receiving new instructions - Immediately capture user requirements as todos\n\
250 6. When you start working on a task - Mark it as in_progress BEFORE beginning work\n\
251 7. After completing a task - Mark it as completed and add any new follow-up tasks\n\n\
252 ## Task States and Management\n\
253 1. **Task States**: Use these states to track progress:\n\
254 - pending: Task not yet started\n\
255 - in_progress: Currently working on (limit to ONE task at a time)\n\
256 - completed: Task finished successfully\n\
257 2. **Task Management**:\n\
258 - Update task status in real-time as you work\n\
259 - Mark tasks complete IMMEDIATELY after finishing\n\
260 - Exactly ONE task must be in_progress at any time\n\
261 - Complete current tasks before starting new ones\n\
262 - Remove tasks that are no longer relevant from the list entirely"
263 }
264
265 fn input_schema(&self) -> serde_json::Value {
267 serde_json::json!({
268 "type": "object",
269 "properties": {
270 "todos": {
271 "type": "array",
272 "description": "The updated todo list",
273 "items": {
274 "type": "object",
275 "properties": {
276 "content": {
277 "type": "string",
278 "minLength": 1,
279 "description": "Task description (imperative form, e.g., 'Run tests')"
280 },
281 "status": {
282 "type": "string",
283 "enum": ["pending", "in_progress", "completed"],
284 "description": "Task status"
285 },
286 "active_form": {
287 "type": "string",
288 "minLength": 1,
289 "description": "Present continuous form (e.g., 'Running tests')"
290 }
291 },
292 "required": ["content", "status", "active_form"]
293 }
294 }
295 },
296 "required": ["todos"]
297 })
298 }
299
300 async fn execute(
302 &self,
303 params: serde_json::Value,
304 context: &ToolContext,
305 ) -> Result<ToolResult, ToolError> {
306 let input: TodoWriteInput = serde_json::from_value(params)
308 .map_err(|e| ToolError::invalid_params(format!("Invalid input format: {}", e)))?;
309
310 if let Err(error) = self.validate_todos(&input.todos) {
312 return Ok(ToolResult::error(error));
313 }
314
315 let agent_id = self.get_agent_id(context);
317
318 let old_todos = self.storage.get_todos(&agent_id);
320
321 let new_todos = if input.todos.iter().all(|t| t.is_completed()) {
323 Vec::new()
324 } else {
325 input.todos.clone()
326 };
327
328 self.storage.set_todos(&agent_id, new_todos.clone());
330
331 let message = if new_todos.is_empty() && !input.todos.is_empty() {
333 "All tasks completed! Todo list has been automatically cleared. \
334 Ensure that you continue to use the todo list to track your progress for future tasks."
335 } else {
336 "Todos have been modified successfully. \
337 Ensure that you continue to use the todo list to track your progress. \
338 Please proceed with the current tasks if applicable."
339 };
340
341 Ok(ToolResult::success(message)
343 .with_metadata("agent_id", serde_json::json!(agent_id))
344 .with_metadata("old_todos", serde_json::json!(old_todos))
345 .with_metadata("new_todos", serde_json::json!(input.todos))
346 .with_metadata(
347 "auto_cleared",
348 serde_json::json!(new_todos.is_empty() && !input.todos.is_empty()),
349 ))
350 }
351
352 async fn check_permissions(
354 &self,
355 params: &serde_json::Value,
356 _context: &ToolContext,
357 ) -> PermissionCheckResult {
358 match serde_json::from_value::<TodoWriteInput>(params.clone()) {
360 Ok(input) => {
361 if let Err(error) = self.validate_todos(&input.todos) {
363 return PermissionCheckResult::deny(format!("Invalid todos: {}", error));
364 }
365 PermissionCheckResult::allow()
366 }
367 Err(e) => PermissionCheckResult::deny(format!("Invalid input format: {}", e)),
368 }
369 }
370
371 fn options(&self) -> ToolOptions {
373 ToolOptions::new()
374 .with_max_retries(0) .with_base_timeout(std::time::Duration::from_secs(5)) .with_dynamic_timeout(false)
377 }
378}
379
// Unit tests for the todo item model, the storage, and the tool's
// validation, permission check and execution paths.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Context with session id "test-session" and user "test-user".
    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    /// Fresh, empty storage that can be shared with a tool under test.
    fn create_test_storage() -> Arc<TodoStorage> {
        Arc::new(TodoStorage::new())
    }

    // ---- TodoItem ----

    #[test]
    fn test_todo_item_creation() {
        let todo = TodoItem::new("Run tests", "Running tests");
        assert_eq!(todo.content, "Run tests");
        assert_eq!(todo.active_form, "Running tests");
        assert_eq!(todo.status, TodoStatus::Pending);
        assert!(!todo.is_in_progress());
        assert!(!todo.is_completed());
    }

    #[test]
    fn test_todo_item_with_status() {
        let todo =
            TodoItem::with_status("Build project", "Building project", TodoStatus::InProgress);
        assert_eq!(todo.content, "Build project");
        assert_eq!(todo.active_form, "Building project");
        assert_eq!(todo.status, TodoStatus::InProgress);
        assert!(todo.is_in_progress());
        assert!(!todo.is_completed());
    }

    // ---- TodoStorage ----

    #[test]
    fn test_todo_storage_basic_operations() {
        let storage = TodoStorage::new();
        let agent_id = "test-agent";

        // Unknown agent starts out empty.
        assert!(storage.get_todos(agent_id).is_empty());

        let todos = vec![
            TodoItem::new("Task 1", "Doing task 1"),
            TodoItem::with_status("Task 2", "Doing task 2", TodoStatus::InProgress),
        ];
        storage.set_todos(agent_id, todos.clone());

        let retrieved = storage.get_todos(agent_id);
        assert_eq!(retrieved.len(), 2);
        assert_eq!(retrieved[0].content, "Task 1");
        assert_eq!(retrieved[1].content, "Task 2");
        assert_eq!(retrieved[1].status, TodoStatus::InProgress);

        // Storing an empty list clears the agent's todos.
        storage.set_todos(agent_id, vec![]);
        assert!(storage.get_todos(agent_id).is_empty());
    }

    #[test]
    fn test_todo_storage_multi_agent() {
        let storage = TodoStorage::new();
        let agent1 = "agent-1";
        let agent2 = "agent-2";

        storage.set_todos(
            agent1,
            vec![TodoItem::new("Agent 1 Task", "Doing agent 1 task")],
        );
        storage.set_todos(
            agent2,
            vec![TodoItem::new("Agent 2 Task", "Doing agent 2 task")],
        );

        // Each agent sees only its own list.
        let todos1 = storage.get_todos(agent1);
        let todos2 = storage.get_todos(agent2);

        assert_eq!(todos1.len(), 1);
        assert_eq!(todos2.len(), 1);
        assert_eq!(todos1[0].content, "Agent 1 Task");
        assert_eq!(todos2[0].content, "Agent 2 Task");
    }

    #[test]
    fn test_todo_storage_stats() {
        let storage = TodoStorage::new();
        let agent_id = "test-agent";

        let todos = vec![
            TodoItem::new("Pending task", "Doing pending task"),
            TodoItem::with_status(
                "In progress task",
                "Doing in progress task",
                TodoStatus::InProgress,
            ),
            TodoItem::with_status(
                "Completed task",
                "Doing completed task",
                TodoStatus::Completed,
            ),
        ];
        storage.set_todos(agent_id, todos);

        let stats = storage.get_stats();
        assert_eq!(stats.len(), 1);
        // (pending, in_progress, completed)
        assert_eq!(stats[agent_id], (1, 1, 1));
    }

    // ---- Tool metadata and builders ----

    #[test]
    fn test_tool_name() {
        let tool = TodoWriteTool::new();
        assert_eq!(tool.name(), "TodoWrite");
    }

    #[test]
    fn test_tool_description() {
        let tool = TodoWriteTool::new();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("task list"));
        assert!(tool.description().contains("progress"));
    }

    #[test]
    fn test_tool_input_schema() {
        let tool = TodoWriteTool::new();
        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["todos"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("todos")));
    }

    #[test]
    fn test_tool_options() {
        let tool = TodoWriteTool::new();
        let options = tool.options();
        assert_eq!(options.max_retries, 0);
        assert_eq!(options.base_timeout, std::time::Duration::from_secs(5));
        assert!(!options.enable_dynamic_timeout);
    }

    #[test]
    fn test_builder_with_storage() {
        let storage = create_test_storage();
        let tool = TodoWriteTool::with_storage(storage.clone());
        assert!(Arc::ptr_eq(&tool.storage, &storage));
    }

    #[test]
    fn test_builder_with_default_agent_id() {
        let tool = TodoWriteTool::new().with_default_agent_id("custom-agent");
        assert_eq!(tool.default_agent_id, "custom-agent");
    }

    // ---- Validation ----

    #[test]
    fn test_validate_todos_success() {
        let tool = TodoWriteTool::new();
        let todos = vec![
            TodoItem::new("Task 1", "Doing task 1"),
            TodoItem::with_status("Task 2", "Doing task 2", TodoStatus::InProgress),
        ];
        assert!(tool.validate_todos(&todos).is_ok());
    }

    #[test]
    fn test_validate_todos_multiple_in_progress() {
        let tool = TodoWriteTool::new();
        let todos = vec![
            TodoItem::with_status("Task 1", "Doing task 1", TodoStatus::InProgress),
            TodoItem::with_status("Task 2", "Doing task 2", TodoStatus::InProgress),
        ];
        let result = tool.validate_todos(&todos);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("Only one task can be in_progress"));
    }

    #[test]
    fn test_validate_todos_empty_content() {
        let tool = TodoWriteTool::new();
        let todos = vec![TodoItem::new("", "Doing something")];
        let result = tool.validate_todos(&todos);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Task content cannot be empty"));
    }

    #[test]
    fn test_validate_todos_empty_active_form() {
        let tool = TodoWriteTool::new();
        let todos = vec![TodoItem::new("Do something", "")];
        let result = tool.validate_todos(&todos);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("Task active_form cannot be empty"));
    }

    // ---- Agent id resolution ----

    #[test]
    fn test_get_agent_id_from_environment() {
        let tool = TodoWriteTool::new();
        let context = create_test_context().with_env_var("AGENT_ID", "env-agent");
        let agent_id = tool.get_agent_id(&context);
        assert_eq!(agent_id, "env-agent");
    }

    #[test]
    fn test_get_agent_id_from_session() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let agent_id = tool.get_agent_id(&context);
        assert_eq!(agent_id, "test-session");
    }

    #[test]
    fn test_get_agent_id_default() {
        let tool = TodoWriteTool::new();
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let agent_id = tool.get_agent_id(&context);
        assert_eq!(agent_id, "main");
    }

    // ---- Permission checks ----

    #[tokio::test]
    async fn test_check_permissions_valid_input() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let params = serde_json::json!({
            "todos": [
                {
                    "content": "Test task",
                    "status": "pending",
                    "active_form": "Testing task"
                }
            ]
        });

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_check_permissions_invalid_format() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"invalid": "format"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_check_permissions_multiple_in_progress() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let params = serde_json::json!({
            "todos": [
                {
                    "content": "Task 1",
                    "status": "in_progress",
                    "active_form": "Doing task 1"
                },
                {
                    "content": "Task 2",
                    "status": "in_progress",
                    "active_form": "Doing task 2"
                }
            ]
        });

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    // ---- Execution ----

    #[tokio::test]
    async fn test_execute_simple_todos() {
        let storage = create_test_storage();
        let tool = TodoWriteTool::with_storage(storage.clone());
        let context = create_test_context();
        let params = serde_json::json!({
            "todos": [
                {
                    "content": "Run tests",
                    "status": "pending",
                    "active_form": "Running tests"
                },
                {
                    "content": "Build project",
                    "status": "in_progress",
                    "active_form": "Building project"
                }
            ]
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_success());
        assert!(tool_result
            .output
            .unwrap()
            .contains("modified successfully"));

        // The list is stored under the session id of the test context.
        let saved_todos = storage.get_todos("test-session");
        assert_eq!(saved_todos.len(), 2);
        assert_eq!(saved_todos[0].content, "Run tests");
        assert_eq!(saved_todos[1].content, "Build project");
        assert_eq!(saved_todos[1].status, TodoStatus::InProgress);
    }

    #[tokio::test]
    async fn test_execute_auto_clear_completed() {
        let storage = create_test_storage();
        let tool = TodoWriteTool::with_storage(storage.clone());
        let context = create_test_context();
        let params = serde_json::json!({
            "todos": [
                {
                    "content": "Task 1",
                    "status": "completed",
                    "active_form": "Doing task 1"
                },
                {
                    "content": "Task 2",
                    "status": "completed",
                    "active_form": "Doing task 2"
                }
            ]
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_success());
        assert!(tool_result.output.unwrap().contains("All tasks completed"));

        // An all-completed list clears the stored todos.
        let saved_todos = storage.get_todos("test-session");
        assert!(saved_todos.is_empty());

        assert_eq!(
            tool_result.metadata.get("auto_cleared"),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_execute_invalid_todos() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let params = serde_json::json!({
            "todos": [
                {
                    "content": "Task 1",
                    "status": "in_progress",
                    "active_form": "Doing task 1"
                },
                {
                    "content": "Task 2",
                    "status": "in_progress",
                    "active_form": "Doing task 2"
                }
            ]
        });

        // Validation failures come back as an error result, not as `Err`.
        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_error());
        assert!(tool_result
            .error
            .unwrap()
            .contains("Only one task can be in_progress"));
    }

    #[tokio::test]
    async fn test_execute_invalid_input_format() {
        let tool = TodoWriteTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"invalid": "format"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
    }

    #[tokio::test]
    async fn test_execute_with_metadata() {
        let storage = create_test_storage();
        let tool = TodoWriteTool::with_storage(storage.clone());
        let context = create_test_context();

        // Seed a previous list so `old_todos` has something to report.
        let initial_todos = vec![TodoItem::new("Old task", "Doing old task")];
        storage.set_todos("test-session", initial_todos.clone());

        let params = serde_json::json!({
            "todos": [
                {
                    "content": "New task",
                    "status": "pending",
                    "active_form": "Doing new task"
                }
            ]
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_success());

        assert_eq!(
            tool_result.metadata.get("agent_id"),
            Some(&serde_json::json!("test-session"))
        );
        assert!(tool_result.metadata.contains_key("old_todos"));
        assert!(tool_result.metadata.contains_key("new_todos"));
        assert_eq!(
            tool_result.metadata.get("auto_cleared"),
            Some(&serde_json::json!(false))
        );
    }
}