Skip to main content

cersei_tools/
todo_write.rs

1//! TodoWrite tool: structured task list management.
2
3use super::*;
4use serde::{Deserialize, Serialize};
5
6/// Global todo storage keyed by session_id.
7static TODO_REGISTRY: once_cell::sync::Lazy<dashmap::DashMap<String, Vec<TodoItem>>> =
8    once_cell::sync::Lazy::new(dashmap::DashMap::new);
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TodoItem {
12    pub content: String,
13    pub status: TodoStatus,
14    #[serde(rename = "activeForm")]
15    pub active_form: String,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[serde(rename_all = "snake_case")]
20pub enum TodoStatus {
21    Pending,
22    InProgress,
23    Completed,
24}
25
26/// Get the current todo list for a session.
27pub fn get_todos(session_id: &str) -> Vec<TodoItem> {
28    TODO_REGISTRY
29        .get(session_id)
30        .map(|v| v.clone())
31        .unwrap_or_default()
32}
33
34/// Clear todos for a session.
35pub fn clear_todos(session_id: &str) {
36    TODO_REGISTRY.remove(session_id);
37}
38
/// Tool that replaces a session's structured todo list.
///
/// The tool is stateless: lists are stored per session in a module-level registry,
/// so the struct itself carries no data and is trivially copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct TodoWriteTool;
40
41#[async_trait]
42impl Tool for TodoWriteTool {
43    fn name(&self) -> &str {
44        "TodoWrite"
45    }
46    fn description(&self) -> &str {
47        "Create and manage a structured task list. Tracks progress with pending/in_progress/completed states."
48    }
49    fn permission_level(&self) -> PermissionLevel {
50        PermissionLevel::None
51    }
52
53    fn input_schema(&self) -> Value {
54        serde_json::json!({
55            "type": "object",
56            "properties": {
57                "todos": {
58                    "type": "array",
59                    "description": "The complete updated todo list",
60                    "items": {
61                        "type": "object",
62                        "properties": {
63                            "content": { "type": "string", "description": "Task description (imperative form)" },
64                            "status": { "type": "string", "enum": ["pending", "in_progress", "completed"] },
65                            "activeForm": { "type": "string", "description": "Present continuous form (e.g., 'Running tests')" }
66                        },
67                        "required": ["content", "status", "activeForm"]
68                    }
69                }
70            },
71            "required": ["todos"]
72        })
73    }
74
75    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
76        #[derive(Deserialize)]
77        struct Input {
78            todos: Vec<TodoItem>,
79        }
80
81        let input: Input = match serde_json::from_value(input) {
82            Ok(i) => i,
83            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
84        };
85
86        TODO_REGISTRY.insert(ctx.session_id.clone(), input.todos.clone());
87
88        let summary: Vec<String> = input
89            .todos
90            .iter()
91            .enumerate()
92            .map(|(i, t)| {
93                let icon = match t.status {
94                    TodoStatus::Completed => "[x]",
95                    TodoStatus::InProgress => "[>]",
96                    TodoStatus::Pending => "[ ]",
97                };
98                format!("{}. {} {}", i + 1, icon, t.content)
99            })
100            .collect();
101
102        ToolResult::success(format!(
103            "Todos updated ({} items):\n{}",
104            input.todos.len(),
105            summary.join("\n")
106        ))
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;

    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "todo-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Context bound to its own session id, so tests sharing the global
    /// registry cannot observe each other's lists.
    fn ctx_for(session_id: &str) -> ToolContext {
        ToolContext {
            session_id: session_id.to_string(),
            ..test_ctx()
        }
    }

    #[tokio::test]
    async fn test_todo_create_and_read() {
        clear_todos("todo-test");
        let tool = TodoWriteTool;
        let result = tool
            .execute(
                serde_json::json!({
                    "todos": [
                        {"content": "Build feature", "status": "in_progress", "activeForm": "Building feature"},
                        {"content": "Write tests", "status": "pending", "activeForm": "Writing tests"}
                    ]
                }),
                &test_ctx(),
            )
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("2 items"));
        assert!(result.content.contains("1. [>] Build feature"));
        assert!(result.content.contains("2. [ ] Write tests"));

        let todos = get_todos("todo-test");
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, TodoStatus::InProgress);
        assert_eq!(todos[0].active_form, "Building feature");
        assert_eq!(todos[1].status, TodoStatus::Pending);
    }

    #[tokio::test]
    async fn test_todo_update() {
        clear_todos("todo-test2");
        let tool = TodoWriteTool;
        let ctx = ctx_for("todo-test2");

        // Create
        let created = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task A", "status": "pending", "activeForm": "Doing A"}]
                }),
                &ctx,
            )
            .await;
        assert!(!created.is_error);

        // Update to completed
        let updated = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task A", "status": "completed", "activeForm": "Doing A"}]
                }),
                &ctx,
            )
            .await;
        assert!(!updated.is_error);
        assert!(updated.content.contains("1. [x] Task A"));

        let todos = get_todos("todo-test2");
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].status, TodoStatus::Completed);
    }

    #[tokio::test]
    async fn test_todo_invalid_status_is_rejected() {
        clear_todos("todo-test3");
        let tool = TodoWriteTool;
        let result = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task B", "status": "done", "activeForm": "Doing B"}]
                }),
                &ctx_for("todo-test3"),
            )
            .await;
        assert!(result.is_error);
        // A rejected write must not touch the stored list.
        assert!(get_todos("todo-test3").is_empty());
    }

    #[tokio::test]
    async fn test_clear_todos_removes_list() {
        clear_todos("todo-test4");
        let tool = TodoWriteTool;
        let result = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task C", "status": "pending", "activeForm": "Doing C"}]
                }),
                &ctx_for("todo-test4"),
            )
            .await;
        assert!(!result.is_error);
        assert_eq!(get_todos("todo-test4").len(), 1);

        clear_todos("todo-test4");
        assert!(get_todos("todo-test4").is_empty());
    }
}
180}