1use std::fmt::Write;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
24use anyhow::{Context, Result};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tokio::sync::RwLock;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum TodoStatus {
33 Pending,
35 InProgress,
37 Completed,
39}
40
41impl TodoStatus {
42 #[must_use]
44 pub const fn icon(&self) -> &'static str {
45 match self {
46 Self::Pending => "○",
47 Self::InProgress => "⚡",
48 Self::Completed => "✓",
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TodoItem {
56 pub content: String,
58 pub status: TodoStatus,
60 pub active_form: String,
62}
63
64impl TodoItem {
65 #[must_use]
67 pub fn new(content: impl Into<String>, active_form: impl Into<String>) -> Self {
68 Self {
69 content: content.into(),
70 status: TodoStatus::Pending,
71 active_form: active_form.into(),
72 }
73 }
74
75 #[must_use]
77 pub fn with_status(
78 content: impl Into<String>,
79 active_form: impl Into<String>,
80 status: TodoStatus,
81 ) -> Self {
82 Self {
83 content: content.into(),
84 status,
85 active_form: active_form.into(),
86 }
87 }
88
89 #[must_use]
91 pub const fn icon(&self) -> &'static str {
92 self.status.icon()
93 }
94}
95
96#[derive(Debug, Default)]
98pub struct TodoState {
99 pub items: Vec<TodoItem>,
101 storage_path: Option<PathBuf>,
103}
104
105impl TodoState {
106 #[must_use]
108 pub const fn new() -> Self {
109 Self {
110 items: Vec::new(),
111 storage_path: None,
112 }
113 }
114
115 #[must_use]
117 pub const fn with_storage(path: PathBuf) -> Self {
118 Self {
119 items: Vec::new(),
120 storage_path: Some(path),
121 }
122 }
123
124 pub fn set_storage_path(&mut self, path: PathBuf) {
126 self.storage_path = Some(path);
127 }
128
129 pub fn load(&mut self) -> Result<()> {
135 if let Some(ref path) = self.storage_path.as_ref().filter(|p| p.exists()) {
136 let content = std::fs::read_to_string(path).context("Failed to read todos file")?;
137 self.items = serde_json::from_str(&content).context("Failed to parse todos file")?;
138 }
139 Ok(())
140 }
141
142 pub fn save(&self) -> Result<()> {
148 if let Some(ref path) = self.storage_path {
149 if let Some(parent) = path.parent() {
151 std::fs::create_dir_all(parent).context("Failed to create todos directory")?;
152 }
153 let content =
154 serde_json::to_string_pretty(&self.items).context("Failed to serialize todos")?;
155 std::fs::write(path, content).context("Failed to write todos file")?;
156 }
157 Ok(())
158 }
159
160 pub fn set_items(&mut self, items: Vec<TodoItem>) {
162 self.items = items;
163 }
164
165 pub fn add(&mut self, item: TodoItem) {
167 self.items.push(item);
168 }
169
170 #[must_use]
172 pub fn count_by_status(&self) -> (usize, usize, usize) {
173 let pending = self
174 .items
175 .iter()
176 .filter(|i| i.status == TodoStatus::Pending)
177 .count();
178 let in_progress = self
179 .items
180 .iter()
181 .filter(|i| i.status == TodoStatus::InProgress)
182 .count();
183 let completed = self
184 .items
185 .iter()
186 .filter(|i| i.status == TodoStatus::Completed)
187 .count();
188 (pending, in_progress, completed)
189 }
190
191 #[must_use]
193 pub fn current_task(&self) -> Option<&TodoItem> {
194 self.items
195 .iter()
196 .find(|i| i.status == TodoStatus::InProgress)
197 }
198
199 #[must_use]
201 pub fn format_display(&self) -> String {
202 if self.items.is_empty() {
203 return "No tasks".to_string();
204 }
205
206 let (_pending, in_progress, completed) = self.count_by_status();
207 let total = self.items.len();
208
209 let mut output = format!("TODO ({completed}/{total})");
210
211 if in_progress > 0
212 && let Some(current) = self.current_task()
213 {
214 let _ = write!(output, " - {}", current.active_form);
215 }
216
217 output.push('\n');
218
219 for item in &self.items {
220 let _ = writeln!(output, " {} {}", item.icon(), item.content);
221 }
222
223 output
224 }
225
226 #[must_use]
228 pub const fn is_empty(&self) -> bool {
229 self.items.is_empty()
230 }
231
232 #[must_use]
234 pub const fn len(&self) -> usize {
235 self.items.len()
236 }
237}
238
239pub struct TodoWriteTool {
241 state: Arc<RwLock<TodoState>>,
243}
244
245impl TodoWriteTool {
246 #[must_use]
248 pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
249 Self { state }
250 }
251}
252
253#[derive(Debug, Deserialize)]
255struct TodoItemInput {
256 content: String,
257 status: TodoStatus,
258 #[serde(rename = "activeForm")]
259 active_form: String,
260}
261
262#[derive(Debug, Deserialize)]
264struct TodoWriteInput {
265 todos: Vec<TodoItemInput>,
266}
267
268impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoWriteTool {
269 type Name = PrimitiveToolName;
270
271 fn name(&self) -> PrimitiveToolName {
272 PrimitiveToolName::TodoWrite
273 }
274
275 fn display_name(&self) -> &'static str {
276 "Update Tasks"
277 }
278
279 fn description(&self) -> &'static str {
280 "Update the TODO list to track tasks and show progress to the user. \
281 Use this tool frequently to plan complex tasks and mark progress. \
282 Each item needs 'content' (imperative form like 'Fix the bug'), \
283 'status' (pending/in_progress/completed), and 'activeForm' \
284 (present continuous like 'Fixing the bug'). \
285 Mark tasks completed immediately when done - don't batch completions."
286 }
287
288 fn input_schema(&self) -> Value {
289 json!({
290 "type": "object",
291 "required": ["todos"],
292 "properties": {
293 "todos": {
294 "type": "array",
295 "description": "The complete TODO list (replaces existing)",
296 "items": {
297 "type": "object",
298 "required": ["content", "status", "activeForm"],
299 "properties": {
300 "content": {
301 "type": "string",
302 "description": "Task description in imperative form (e.g., 'Fix the bug')"
303 },
304 "status": {
305 "type": "string",
306 "enum": ["pending", "in_progress", "completed"],
307 "description": "Current status of the task"
308 },
309 "activeForm": {
310 "type": "string",
311 "description": "Present continuous form shown during execution (e.g., 'Fixing the bug')"
312 }
313 }
314 }
315 }
316 }
317 })
318 }
319
320 fn tier(&self) -> ToolTier {
321 ToolTier::Observe }
323
324 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
325 let input: TodoWriteInput =
326 serde_json::from_value(input).context("Invalid input for todo_write")?;
327
328 let items: Vec<TodoItem> = input
329 .todos
330 .into_iter()
331 .map(|t| TodoItem {
332 content: t.content,
333 status: t.status,
334 active_form: t.active_form,
335 })
336 .collect();
337
338 let display = {
339 let mut state = self.state.write().await;
340 state.set_items(items);
341
342 if let Err(e) = state.save() {
344 log::warn!("Failed to save todos: {e}");
345 }
346
347 state.format_display()
348 };
349
350 Ok(ToolResult::success(format!(
351 "TODO list updated.\n\n{display}"
352 )))
353 }
354}
355
356pub struct TodoReadTool {
358 state: Arc<RwLock<TodoState>>,
360}
361
362impl TodoReadTool {
363 #[must_use]
365 pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
366 Self { state }
367 }
368}
369
370impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoReadTool {
371 type Name = PrimitiveToolName;
372
373 fn name(&self) -> PrimitiveToolName {
374 PrimitiveToolName::TodoRead
375 }
376
377 fn display_name(&self) -> &'static str {
378 "Read Tasks"
379 }
380
381 fn description(&self) -> &'static str {
382 "Read the current TODO list to see task status and progress."
383 }
384
385 fn input_schema(&self) -> Value {
386 json!({
387 "type": "object",
388 "properties": {}
389 })
390 }
391
392 fn tier(&self) -> ToolTier {
393 ToolTier::Observe
394 }
395
396 async fn execute(&self, _ctx: &ToolContext<Ctx>, _input: Value) -> Result<ToolResult> {
397 let display = {
398 let state = self.state.read().await;
399 state.format_display()
400 };
401
402 Ok(ToolResult::success(display))
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_todo_status_icons() {
        assert_eq!(TodoStatus::Pending.icon(), "○");
        assert_eq!(TodoStatus::InProgress.icon(), "⚡");
        assert_eq!(TodoStatus::Completed.icon(), "✓");
    }

    #[test]
    fn test_todo_item_new() {
        let item = TodoItem::new("Fix the bug", "Fixing the bug");
        assert_eq!(item.content, "Fix the bug");
        assert_eq!(item.active_form, "Fixing the bug");
        assert_eq!(item.status, TodoStatus::Pending);
    }

    #[test]
    fn test_todo_state_count_by_status() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        state.add(TodoItem::with_status(
            "Task 2",
            "Task 2",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Task 3",
            "Task 3",
            TodoStatus::Completed,
        ));
        state.add(TodoItem::with_status(
            "Task 4",
            "Task 4",
            TodoStatus::Completed,
        ));

        let (pending, in_progress, completed) = state.count_by_status();
        assert_eq!(pending, 1);
        assert_eq!(in_progress, 1);
        assert_eq!(completed, 2);
    }

    #[test]
    fn test_todo_state_current_task() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        assert!(state.current_task().is_none());

        state.add(TodoItem::with_status(
            "Task 2",
            "Working on Task 2",
            TodoStatus::InProgress,
        ));
        let current = state.current_task().unwrap();
        assert_eq!(current.content, "Task 2");
        assert_eq!(current.active_form, "Working on Task 2");
    }

    #[test]
    fn test_todo_state_format_display() {
        let mut state = TodoState::new();
        assert_eq!(state.format_display(), "No tasks");

        state.add(TodoItem::with_status(
            "Fix bug",
            "Fixing bug",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Write tests",
            "Writing tests",
            TodoStatus::Pending,
        ));

        let display = state.format_display();
        assert!(display.contains("TODO (0/2)"));
        assert!(display.contains("Fixing bug"));
        assert!(display.contains("⚡ Fix bug"));
        assert!(display.contains("○ Write tests"));
    }

    #[test]
    fn test_todo_status_serde() {
        let status = TodoStatus::InProgress;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"in_progress\"");

        let parsed: TodoStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(parsed, TodoStatus::Completed);
    }

    /// `save` creates missing directories, and `load` restores exactly what was written.
    #[test]
    fn test_todo_state_save_load_roundtrip() {
        let dir = std::env::temp_dir().join(format!(
            "todo_state_roundtrip_{}",
            std::process::id()
        ));
        let path = dir.join("nested").join("todos.json");

        let mut state = TodoState::with_storage(path.clone());
        state.add(TodoItem::with_status(
            "Task 1",
            "Doing task 1",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::new("Task 2", "Doing task 2"));
        state.save().unwrap();

        assert!(path.exists());
        // No temporary file may be left behind next to the target.
        assert!(!path.with_extension("json.tmp").exists());

        let mut loaded = TodoState::with_storage(path);
        loaded.load().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.items[0].content, "Task 1");
        assert_eq!(loaded.items[0].status, TodoStatus::InProgress);
        assert_eq!(loaded.items[1].active_form, "Doing task 2");
        assert_eq!(loaded.items[1].status, TodoStatus::Pending);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Loading from a path that does not exist yet keeps the in-memory items.
    #[test]
    fn test_todo_state_load_missing_file_is_noop() {
        let path = std::env::temp_dir().join(format!(
            "todo_state_missing_{}.json",
            std::process::id()
        ));
        let mut state = TodoState::with_storage(path);
        state.add(TodoItem::new("Keep", "Keeping"));
        state.load().unwrap();
        assert_eq!(state.len(), 1);
    }

    /// Without a storage path, `save` and `load` succeed without touching the disk.
    #[test]
    fn test_todo_state_without_storage_is_noop() {
        let mut state = TodoState::new();
        state.add(TodoItem::new("Task", "Doing task"));
        state.save().unwrap();
        state.load().unwrap();
        assert_eq!(state.len(), 1);
    }
}