cersei_tools/
todo_write.rs1use super::*;
4use serde::{Deserialize, Serialize};
5
6static TODO_REGISTRY: once_cell::sync::Lazy<dashmap::DashMap<String, Vec<TodoItem>>> =
8 once_cell::sync::Lazy::new(dashmap::DashMap::new);
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TodoItem {
12 pub content: String,
13 pub status: TodoStatus,
14 #[serde(rename = "activeForm")]
15 pub active_form: String,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[serde(rename_all = "snake_case")]
20pub enum TodoStatus {
21 Pending,
22 InProgress,
23 Completed,
24}
25
26pub fn get_todos(session_id: &str) -> Vec<TodoItem> {
28 TODO_REGISTRY
29 .get(session_id)
30 .map(|v| v.clone())
31 .unwrap_or_default()
32}
33
34pub fn clear_todos(session_id: &str) {
36 TODO_REGISTRY.remove(session_id);
37}
38
/// The `TodoWrite` tool: replaces the calling session's todo list and reports a
/// numbered checklist of it.
#[derive(Debug, Clone, Copy, Default)]
pub struct TodoWriteTool;
40
41#[async_trait]
42impl Tool for TodoWriteTool {
43 fn name(&self) -> &str {
44 "TodoWrite"
45 }
46 fn description(&self) -> &str {
47 "Create and manage a structured task list. Tracks progress with pending/in_progress/completed states."
48 }
49 fn permission_level(&self) -> PermissionLevel {
50 PermissionLevel::None
51 }
52
53 fn input_schema(&self) -> Value {
54 serde_json::json!({
55 "type": "object",
56 "properties": {
57 "todos": {
58 "type": "array",
59 "description": "The complete updated todo list",
60 "items": {
61 "type": "object",
62 "properties": {
63 "content": { "type": "string", "description": "Task description (imperative form)" },
64 "status": { "type": "string", "enum": ["pending", "in_progress", "completed"] },
65 "activeForm": { "type": "string", "description": "Present continuous form (e.g., 'Running tests')" }
66 },
67 "required": ["content", "status", "activeForm"]
68 }
69 }
70 },
71 "required": ["todos"]
72 })
73 }
74
75 async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
76 #[derive(Deserialize)]
77 struct Input {
78 todos: Vec<TodoItem>,
79 }
80
81 let input: Input = match serde_json::from_value(input) {
82 Ok(i) => i,
83 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
84 };
85
86 TODO_REGISTRY.insert(ctx.session_id.clone(), input.todos.clone());
87
88 let summary: Vec<String> = input
89 .todos
90 .iter()
91 .enumerate()
92 .map(|(i, t)| {
93 let icon = match t.status {
94 TodoStatus::Completed => "[x]",
95 TodoStatus::InProgress => "[>]",
96 TodoStatus::Pending => "[ ]",
97 };
98 format!("{}. {} {}", i + 1, icon, t.content)
99 })
100 .collect();
101
102 ToolResult::success(format!(
103 "Todos updated ({} items):\n{}",
104 input.todos.len(),
105 summary.join("\n")
106 ))
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;

    /// Context for session `todo-test`. Other tests override `session_id` so that tests
    /// running in parallel never share an entry in the global registry.
    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "todo-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    #[tokio::test]
    async fn test_todo_create_and_read() {
        clear_todos("todo-test");
        let tool = TodoWriteTool;
        let result = tool
            .execute(
                serde_json::json!({
                    "todos": [
                        {"content": "Build feature", "status": "in_progress", "activeForm": "Building feature"},
                        {"content": "Write tests", "status": "pending", "activeForm": "Writing tests"}
                    ]
                }),
                &test_ctx(),
            )
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("2 items"));
        assert!(result.content.contains("1. [>] Build feature"));
        assert!(result.content.contains("2. [ ] Write tests"));

        let todos = get_todos("todo-test");
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, TodoStatus::InProgress);
        assert_eq!(todos[1].status, TodoStatus::Pending);
    }

    #[tokio::test]
    async fn test_todo_update() {
        clear_todos("todo-test2");
        let tool = TodoWriteTool;
        let ctx = ToolContext {
            session_id: "todo-test2".into(),
            ..test_ctx()
        };

        let first = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task A", "status": "pending", "activeForm": "Doing A"}]
                }),
                &ctx,
            )
            .await;
        assert!(!first.is_error);

        let second = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task A", "status": "completed", "activeForm": "Doing A"}]
                }),
                &ctx,
            )
            .await;
        assert!(!second.is_error);

        // The second write must replace the stored list, not append to it.
        let todos = get_todos("todo-test2");
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].status, TodoStatus::Completed);
    }

    #[tokio::test]
    async fn test_todo_invalid_input_is_rejected() {
        clear_todos("todo-test3");
        let tool = TodoWriteTool;
        let ctx = ToolContext {
            session_id: "todo-test3".into(),
            ..test_ctx()
        };

        let result = tool
            .execute(
                serde_json::json!({
                    "todos": [{"content": "Task", "status": "blocked", "activeForm": "Doing"}]
                }),
                &ctx,
            )
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("Invalid input"));
        // A rejected write must not touch the registry.
        assert!(get_todos("todo-test3").is_empty());
    }
}