1use crate::error::AgentError;
7use crate::tools::agent::constants::VERIFICATION_AGENT_TYPE;
8use crate::types::*;
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11
12pub const TODO_WRITE_TOOL_NAME: &str = "TodoWrite";
13
14static TODOS: OnceLock<Mutex<HashMap<String, Vec<TodoItem>>>> = OnceLock::new();
16
17fn get_todos_map() -> &'static Mutex<HashMap<String, Vec<TodoItem>>> {
18 TODOS.get_or_init(|| Mutex::new(HashMap::new()))
19}
20
21pub fn get_unfinished_todos(session_key: &str) -> Vec<TodoItem> {
23 let mut guard = get_todos_map().lock().unwrap();
24 guard
25 .get(session_key)
26 .cloned()
27 .unwrap_or_default()
28 .into_iter()
29 .filter(|t| t.status != "completed")
30 .collect()
31}
32
33pub fn get_all_todos(session_key: &str) -> Vec<TodoItem> {
35 let mut guard = get_todos_map().lock().unwrap();
36 guard.get(session_key).cloned().unwrap_or_default()
37}
38
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct TodoItem {
42 pub content: String,
43 pub status: String, #[serde(rename = "ACTIVE_FORM")]
45 pub active_form: Option<String>,
46}
47
48#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
50pub struct TodoList {
51 pub old_todos: Vec<TodoItem>,
52 pub new_todos: Vec<TodoItem>,
53 pub verification_nudge_needed: Option<bool>,
54}
55
56pub struct TodoWriteTool;
58
59impl TodoWriteTool {
60 pub fn new() -> Self {
61 Self
62 }
63
64 pub fn name(&self) -> &str {
65 TODO_WRITE_TOOL_NAME
66 }
67
68 pub fn description(&self) -> &str {
69 "Update the todo list for this session. Provide the complete updated list of todos."
70 }
71
72 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
73 "TodoWrite".to_string()
74 }
75
76 pub fn get_tool_use_summary(&self, _input: Option<&serde_json::Value>) -> Option<String> {
77 None
78 }
79
80 pub fn render_tool_result_message(
81 &self,
82 content: &serde_json::Value,
83 ) -> Option<String> {
84 content["content"].as_str().map(|s| s.to_string())
85 }
86
87 pub fn input_schema(&self) -> ToolInputSchema {
88 ToolInputSchema {
89 schema_type: "object".to_string(),
90 properties: serde_json::json!({
91 "todos": {
92 "type": "array",
93 "items": {
94 "type": "object",
95 "properties": {
96 "content": { "type": "string", "description": "What needs to be done" },
97 "status": {
98 "type": "string",
99 "enum": ["in_progress", "pending", "completed"],
100 "description": "Current status of the task"
101 },
102 "ACTIVE_FORM": { "type": "string", "description": "Present continuous form for display" }
103 },
104 "required": ["content", "status"]
105 },
106 "description": "The updated todo list"
107 }
108 }),
109 required: Some(vec!["todos".to_string()]),
110 }
111 }
112
113 pub async fn execute(
114 &self,
115 input: serde_json::Value,
116 _context: &ToolContext,
117 ) -> Result<ToolResult, AgentError> {
118 let todos = input["todos"]
119 .as_array()
120 .ok_or_else(|| AgentError::Tool("todos is required".to_string()))?;
121
122 let new_items: Vec<TodoItem> = todos
123 .iter()
124 .filter_map(|t| {
125 let content = t.get("content")?.as_str()?.to_string();
126 let status = t.get("status")?.as_str()?.to_string();
127 let active_form = t
128 .get("ACTIVE_FORM")
129 .and_then(|v| v.as_str())
130 .map(|s| s.to_string());
131 Some(TodoItem {
132 content,
133 status,
134 active_form,
135 })
136 })
137 .collect();
138
139 let todo_key = "default_session".to_string();
141
142 let mut guard = get_todos_map().lock().unwrap();
143 let old_todos = guard.get(&todo_key).cloned().unwrap_or_default();
144
145 let all_done = new_items.iter().all(|t| t.status == "completed");
147 let stored_todos = if all_done { vec![] } else { new_items.clone() };
148
149 guard.insert(todo_key.clone(), stored_todos);
150 drop(guard);
151
152 let verification_nudge_needed = all_done
154 && new_items.len() >= 3
155 && !new_items
156 .iter()
157 .any(|t| t.content.to_lowercase().contains("verif"));
158
159 let base = "Todos have been modified successfully. \
160 Ensure that you continue to use the todo list to track your progress. \
161 Please proceed with the current tasks if applicable";
162
163 let nudge = if verification_nudge_needed {
164 format!(
165 "\n\nNOTE: You just closed out {}+ tasks and none of them was a verification step. \
166 Before writing your final summary, spawn the verification agent (subagent_type=\"{}\"). \
167 You cannot self-assign PARTIAL by listing caveats in your summary — only the verifier issues a verdict.",
168 new_items.len(),
169 VERIFICATION_AGENT_TYPE
170 )
171 } else {
172 String::new()
173 };
174
175 Ok(ToolResult {
176 result_type: "text".to_string(),
177 tool_use_id: "todo_write".to_string(),
178 content: format!("{}{}", base, nudge),
179 is_error: Some(false),
180 was_persisted: None,
181 })
182 }
183}
184
185impl Default for TodoWriteTool {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191pub fn reset_todos_for_testing() {
193 let mut guard = get_todos_map().lock().unwrap();
194 guard.clear();
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::common::clear_all_test_state;

    /// Runs a fresh `TodoWriteTool` on `input` and returns the result text.
    async fn run_todo_write(input: serde_json::Value) -> String {
        let tool = TodoWriteTool::new();
        let context = ToolContext::default();
        tool.execute(input, &context)
            .await
            .expect("TodoWrite should accept a well-formed todo list")
            .content
    }

    #[test]
    fn test_todo_write_tool_name() {
        clear_all_test_state();
        let tool = TodoWriteTool::new();
        assert_eq!(tool.name(), TODO_WRITE_TOOL_NAME);
    }

    #[test]
    fn test_todo_write_schema() {
        clear_all_test_state();
        let tool = TodoWriteTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("todos").is_some());
    }

    #[tokio::test]
    async fn test_todo_write_creates_items() {
        clear_all_test_state();
        let content = run_todo_write(serde_json::json!({
            "todos": [
                { "content": "Task 1", "status": "pending" },
                { "content": "Task 2", "status": "in_progress" }
            ]
        }))
        .await;
        assert!(content.contains("modified successfully"));
        // Unfinished work never triggers the verification nudge.
        assert!(!content.contains("verification step"));
    }

    #[tokio::test]
    async fn test_todo_write_clears_when_all_done() {
        clear_all_test_state();
        let content = run_todo_write(serde_json::json!({
            "todos": [
                { "content": "Task A", "status": "completed" },
                { "content": "Task B", "status": "completed" },
                { "content": "Task C", "status": "completed" },
                { "content": "Task D", "status": "completed" }
            ]
        }))
        .await;
        assert!(content.contains("modified successfully"));
        // Other tests may write the same session key concurrently, so look for this test's
        // own items rather than asserting that the whole stored list is empty.
        assert!(
            !get_all_todos("default_session")
                .iter()
                .any(|t| t.content == "Task A"),
            "a fully completed todo list should not be stored"
        );
    }

    #[tokio::test]
    async fn test_todo_write_verification_nudge() {
        clear_all_test_state();
        let content = run_todo_write(serde_json::json!({
            "todos": [
                { "content": "Implement feature", "status": "completed" },
                { "content": "Write tests", "status": "completed" },
                { "content": "Update docs", "status": "completed" }
            ]
        }))
        .await;
        assert!(content.contains("verification step"));
        assert!(content.contains(VERIFICATION_AGENT_TYPE));
    }

    #[tokio::test]
    async fn test_todo_write_no_nudge_when_verification_task_present() {
        clear_all_test_state();
        let content = run_todo_write(serde_json::json!({
            "todos": [
                { "content": "Implement feature", "status": "completed" },
                { "content": "Write tests", "status": "completed" },
                { "content": "Verify the build", "status": "completed" }
            ]
        }))
        .await;
        assert!(content.contains("modified successfully"));
        assert!(!content.contains("verification step"));
    }

    #[tokio::test]
    async fn test_todo_write_no_nudge_below_threshold() {
        clear_all_test_state();
        let content = run_todo_write(serde_json::json!({
            "todos": [
                { "content": "Implement feature", "status": "completed" },
                { "content": "Write tests", "status": "completed" }
            ]
        }))
        .await;
        assert!(!content.contains("verification step"));
    }

    #[tokio::test]
    async fn test_todo_write_rejects_missing_todos() {
        clear_all_test_state();
        let tool = TodoWriteTool::new();
        let context = ToolContext::default();
        assert!(tool.execute(serde_json::json!({}), &context).await.is_err());
        assert!(tool
            .execute(serde_json::json!({ "todos": "not a list" }), &context)
            .await
            .is_err());
    }
}