Skip to main content

atomcode_core/tool/
todo.rs

1//! TodoWrite tool — lets the agent track tasks during a session.
2
3use std::sync::Arc;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use serde::Deserialize;
8use serde_json::json;
9use tokio::sync::Mutex;
10
11use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
12
/// One entry in the agent's task list.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TodoItem {
    /// Identifier shown to the agent and used to address the task in `update`.
    id: usize,
    /// Free-form task description supplied by the agent.
    content: String,
    /// Expected to be one of "pending", "in_progress", "completed".
    status: String,
}
19
20/// Shared todo list, wrapped in Arc<Mutex> so multiple tool calls can access it.
21/// State IS shared across calls: TodoTool is constructed once via TodoTool::new()
22/// and registered as a single instance in the ToolRegistry. All execute() calls
23/// reference the same Arc<Mutex> fields, so items persist across the session.
24pub struct TodoTool {
25    items: Arc<Mutex<Vec<TodoItem>>>,
26    next_id: Arc<Mutex<usize>>,
27}
28
29impl TodoTool {
30    pub fn new() -> Self {
31        Self {
32            items: Arc::new(Mutex::new(Vec::new())),
33            next_id: Arc::new(Mutex::new(1)),
34        }
35    }
36
37    /// Format the todo list for display.
38    async fn format_list(items: &Arc<Mutex<Vec<TodoItem>>>) -> String {
39        let items = items.lock().await;
40        if items.is_empty() {
41            return "No tasks.".to_string();
42        }
43        let mut out = String::new();
44        for item in items.iter() {
45            let icon = match item.status.as_str() {
46                "completed" => "[x]",
47                "in_progress" => "[>]",
48                _ => "[ ]",
49            };
50            out.push_str(&format!("{} {}. {}\n", icon, item.id, item.content));
51        }
52        out
53    }
54}
55
56#[derive(Deserialize)]
57struct TodoArgs {
58    action: String,
59    #[serde(default)]
60    content: Option<String>,
61    #[serde(default)]
62    id: Option<usize>,
63    #[serde(default)]
64    status: Option<String>,
65}
66
67#[async_trait]
68impl Tool for TodoTool {
69    fn definition(&self) -> ToolDef {
70        ToolDef {
71            name: "todo",
72            description: "Manage a task list to track progress on multi-step work. Use 'add' to create tasks, 'update' to change status, and 'list' to show all tasks.".to_string(),
73            parameters: json!({
74                "type": "object",
75                "properties": {
76                    "action": {
77                        "type": "string",
78                        "enum": ["add", "update", "list"],
79                        "description": "Action: 'add' a new task, 'update' a task's status, or 'list' all tasks"
80                    },
81                    "content": {
82                        "type": "string",
83                        "description": "Task description (required for 'add')"
84                    },
85                    "id": {
86                        "type": "integer",
87                        "description": "Task ID (required for 'update')"
88                    },
89                    "status": {
90                        "type": "string",
91                        "enum": ["pending", "in_progress", "completed"],
92                        "description": "New status (required for 'update')"
93                    }
94                },
95                "required": ["action"]
96            }),
97        }
98    }
99
100    fn approval(&self, _args: &str) -> ApprovalRequirement {
101        ApprovalRequirement::AutoApprove
102    }
103
104    async fn execute(&self, args: &str, _ctx: &ToolContext) -> Result<ToolResult> {
105        let parsed: TodoArgs = serde_json::from_str(args)?;
106
107        match parsed.action.as_str() {
108            "add" => {
109                let content = parsed.content.unwrap_or_else(|| "Untitled task".to_string());
110                let mut id_guard = self.next_id.lock().await;
111                let id = *id_guard;
112                *id_guard += 1;
113                drop(id_guard);
114
115                let item = TodoItem {
116                    id,
117                    content: content.clone(),
118                    status: "pending".to_string(),
119                };
120                self.items.lock().await.push(item);
121
122                Ok(ToolResult {
123                    call_id: String::new(),
124                    output: format!("Added task #{}: {}", id, content),
125                    success: true,
126                })
127            }
128            "update" => {
129                let id = parsed.id.ok_or_else(|| anyhow::anyhow!("'id' is required for update"))?;
130                let status = parsed.status.unwrap_or_else(|| "in_progress".to_string());
131
132                let mut items = self.items.lock().await;
133                if let Some(item) = items.iter_mut().find(|i| i.id == id) {
134                    item.status = status.clone();
135                    Ok(ToolResult {
136                        call_id: String::new(),
137                        output: format!("Task #{} updated to '{}'", id, status),
138                        success: true,
139                    })
140                } else {
141                    Ok(ToolResult {
142                        call_id: String::new(),
143                        output: format!("Task #{} not found", id),
144                        success: false,
145                    })
146                }
147            }
148            "list" => {
149                let list = Self::format_list(&self.items).await;
150                Ok(ToolResult {
151                    call_id: String::new(),
152                    output: list,
153                    success: true,
154                })
155            }
156            other => Ok(ToolResult {
157                call_id: String::new(),
158                output: format!("Unknown action: {}. Use 'add', 'update', or 'list'.", other),
159                success: false,
160            }),
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn add_and_list_tasks() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        // Add
        let r = tool.execute(r#"{"action":"add","content":"Write tests"}"#, &ctx).await.unwrap();
        assert!(r.success);
        assert!(r.output.contains("#1"));

        let r = tool.execute(r#"{"action":"add","content":"Fix bug"}"#, &ctx).await.unwrap();
        assert!(r.output.contains("#2"));

        // List: tasks appear in insertion order with their ids.
        let r = tool.execute(r#"{"action":"list"}"#, &ctx).await.unwrap();
        assert!(r.output.contains("Write tests"));
        assert!(r.output.contains("Fix bug"));
        assert!(r.output.contains("[ ]")); // pending
        assert!(r.output.contains("[ ] 1. Write tests"));
        assert!(r.output.contains("[ ] 2. Fix bug"));
    }

    #[tokio::test]
    async fn update_task_status() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        tool.execute(r#"{"action":"add","content":"Task 1"}"#, &ctx).await.unwrap();

        let r = tool.execute(r#"{"action":"update","id":1,"status":"completed"}"#, &ctx).await.unwrap();
        assert!(r.success);

        let r = tool.execute(r#"{"action":"list"}"#, &ctx).await.unwrap();
        assert!(r.output.contains("[x]")); // completed
    }

    #[tokio::test]
    async fn update_without_status_marks_in_progress() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        tool.execute(r#"{"action":"add","content":"Task 1"}"#, &ctx).await.unwrap();

        let r = tool.execute(r#"{"action":"update","id":1}"#, &ctx).await.unwrap();
        assert!(r.success);
        assert!(r.output.contains("in_progress"));

        let r = tool.execute(r#"{"action":"list"}"#, &ctx).await.unwrap();
        assert!(r.output.contains("[>] 1. Task 1"));
    }

    #[tokio::test]
    async fn update_nonexistent_task_fails() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        let r = tool.execute(r#"{"action":"update","id":99,"status":"completed"}"#, &ctx).await.unwrap();
        assert!(!r.success);
        assert!(r.output.contains("not found"));
    }

    #[tokio::test]
    async fn update_without_id_is_an_error() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        let r = tool.execute(r#"{"action":"update","status":"completed"}"#, &ctx).await;
        assert!(r.is_err());
    }

    #[tokio::test]
    async fn unknown_action_fails() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        let r = tool.execute(r#"{"action":"delete"}"#, &ctx).await.unwrap();
        assert!(!r.success);
        assert!(r.output.contains("Unknown action: delete"));
    }

    #[tokio::test]
    async fn malformed_args_are_an_error() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        assert!(tool.execute("not json", &ctx).await.is_err());
        // `action` is mandatory.
        assert!(tool.execute(r#"{"content":"no action"}"#, &ctx).await.is_err());
    }

    #[tokio::test]
    async fn list_empty_shows_no_tasks() {
        let tool = TodoTool::new();
        let ctx = ToolContext::new(std::path::PathBuf::from("/tmp"));

        let r = tool.execute(r#"{"action":"list"}"#, &ctx).await.unwrap();
        assert!(r.output.contains("No tasks"));
    }
}