1use std::fmt::Write;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
24use anyhow::{Context, Result};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tokio::sync::RwLock;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum TodoStatus {
33 Pending,
35 InProgress,
37 Completed,
39}
40
41impl TodoStatus {
42 #[must_use]
44 pub const fn icon(&self) -> &'static str {
45 match self {
46 Self::Pending => "○",
47 Self::InProgress => "⚡",
48 Self::Completed => "✓",
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TodoItem {
56 pub content: String,
58 pub status: TodoStatus,
60 pub active_form: String,
62}
63
64impl TodoItem {
65 #[must_use]
67 pub fn new(content: impl Into<String>, active_form: impl Into<String>) -> Self {
68 Self {
69 content: content.into(),
70 status: TodoStatus::Pending,
71 active_form: active_form.into(),
72 }
73 }
74
75 #[must_use]
77 pub fn with_status(
78 content: impl Into<String>,
79 active_form: impl Into<String>,
80 status: TodoStatus,
81 ) -> Self {
82 Self {
83 content: content.into(),
84 status,
85 active_form: active_form.into(),
86 }
87 }
88
89 #[must_use]
91 pub const fn icon(&self) -> &'static str {
92 self.status.icon()
93 }
94}
95
96#[derive(Debug, Default)]
98pub struct TodoState {
99 pub items: Vec<TodoItem>,
101 storage_path: Option<PathBuf>,
103}
104
105impl TodoState {
106 #[must_use]
108 pub const fn new() -> Self {
109 Self {
110 items: Vec::new(),
111 storage_path: None,
112 }
113 }
114
115 #[must_use]
117 pub const fn with_storage(path: PathBuf) -> Self {
118 Self {
119 items: Vec::new(),
120 storage_path: Some(path),
121 }
122 }
123
124 pub fn set_storage_path(&mut self, path: PathBuf) {
126 self.storage_path = Some(path);
127 }
128
129 pub async fn load(&mut self) -> Result<()> {
135 if let Some(ref path) = self.storage_path.as_ref().filter(|p| p.exists()) {
136 let content = tokio::fs::read_to_string(path)
137 .await
138 .context("Failed to read todos file")?;
139 self.items = serde_json::from_str(&content).context("Failed to parse todos file")?;
140 }
141 Ok(())
142 }
143
144 pub async fn save(&self) -> Result<()> {
150 if let Some(ref path) = self.storage_path {
151 if let Some(parent) = path.parent() {
153 tokio::fs::create_dir_all(parent)
154 .await
155 .context("Failed to create todos directory")?;
156 }
157 let content =
158 serde_json::to_string_pretty(&self.items).context("Failed to serialize todos")?;
159 tokio::fs::write(path, content)
160 .await
161 .context("Failed to write todos file")?;
162 }
163 Ok(())
164 }
165
166 pub fn set_items(&mut self, items: Vec<TodoItem>) {
168 self.items = items;
169 }
170
171 pub fn add(&mut self, item: TodoItem) {
173 self.items.push(item);
174 }
175
176 #[must_use]
178 pub fn count_by_status(&self) -> (usize, usize, usize) {
179 let pending = self
180 .items
181 .iter()
182 .filter(|i| i.status == TodoStatus::Pending)
183 .count();
184 let in_progress = self
185 .items
186 .iter()
187 .filter(|i| i.status == TodoStatus::InProgress)
188 .count();
189 let completed = self
190 .items
191 .iter()
192 .filter(|i| i.status == TodoStatus::Completed)
193 .count();
194 (pending, in_progress, completed)
195 }
196
197 #[must_use]
199 pub fn current_task(&self) -> Option<&TodoItem> {
200 self.items
201 .iter()
202 .find(|i| i.status == TodoStatus::InProgress)
203 }
204
205 #[must_use]
207 pub fn format_display(&self) -> String {
208 if self.items.is_empty() {
209 return "No tasks".to_string();
210 }
211
212 let (_pending, in_progress, completed) = self.count_by_status();
213 let total = self.items.len();
214
215 let mut output = format!("TODO ({completed}/{total})");
216
217 if in_progress > 0
218 && let Some(current) = self.current_task()
219 {
220 let _ = write!(output, " - {}", current.active_form);
221 }
222
223 output.push('\n');
224
225 for item in &self.items {
226 let _ = writeln!(output, " {} {}", item.icon(), item.content);
227 }
228
229 output
230 }
231
232 #[must_use]
234 pub const fn is_empty(&self) -> bool {
235 self.items.is_empty()
236 }
237
238 #[must_use]
240 pub const fn len(&self) -> usize {
241 self.items.len()
242 }
243}
244
245pub struct TodoWriteTool {
247 state: Arc<RwLock<TodoState>>,
249}
250
251impl TodoWriteTool {
252 #[must_use]
254 pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
255 Self { state }
256 }
257}
258
259#[derive(Debug, Deserialize)]
261struct TodoItemInput {
262 content: String,
263 status: TodoStatus,
264 #[serde(rename = "activeForm")]
265 active_form: String,
266}
267
268#[derive(Debug, Deserialize)]
270struct TodoWriteInput {
271 todos: Vec<TodoItemInput>,
272}
273
274impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoWriteTool {
275 type Name = PrimitiveToolName;
276
277 fn name(&self) -> PrimitiveToolName {
278 PrimitiveToolName::TodoWrite
279 }
280
281 fn display_name(&self) -> &'static str {
282 "Update Tasks"
283 }
284
285 fn description(&self) -> &'static str {
286 "Update the TODO list to track tasks and show progress to the user. \
287 Use this tool frequently to plan complex tasks and mark progress. \
288 Each item needs 'content' (imperative form like 'Fix the bug'), \
289 'status' (pending/in_progress/completed), and 'activeForm' \
290 (present continuous like 'Fixing the bug'). \
291 Mark tasks completed immediately when done - don't batch completions."
292 }
293
294 fn input_schema(&self) -> Value {
295 json!({
296 "type": "object",
297 "required": ["todos"],
298 "properties": {
299 "todos": {
300 "type": "array",
301 "description": "The complete TODO list (replaces existing)",
302 "items": {
303 "type": "object",
304 "required": ["content", "status", "activeForm"],
305 "properties": {
306 "content": {
307 "type": "string",
308 "description": "Task description in imperative form (e.g., 'Fix the bug')"
309 },
310 "status": {
311 "type": "string",
312 "enum": ["pending", "in_progress", "completed"],
313 "description": "Current status of the task"
314 },
315 "activeForm": {
316 "type": "string",
317 "description": "Present continuous form shown during execution (e.g., 'Fixing the bug')"
318 }
319 }
320 }
321 }
322 }
323 })
324 }
325
326 fn tier(&self) -> ToolTier {
327 ToolTier::Observe }
329
330 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
331 let input: TodoWriteInput =
332 serde_json::from_value(input).context("Invalid input for todo_write")?;
333
334 let items: Vec<TodoItem> = input
335 .todos
336 .into_iter()
337 .map(|t| TodoItem {
338 content: t.content,
339 status: t.status,
340 active_form: t.active_form,
341 })
342 .collect();
343
344 let display = {
345 let mut state = self.state.write().await;
346 state.set_items(items);
347
348 if let Err(e) = state.save().await {
350 log::warn!("Failed to save todos: {e}");
351 }
352
353 state.format_display()
354 };
355
356 Ok(ToolResult::success(format!(
357 "TODO list updated.\n\n{display}"
358 )))
359 }
360}
361
362pub struct TodoReadTool {
364 state: Arc<RwLock<TodoState>>,
366}
367
368impl TodoReadTool {
369 #[must_use]
371 pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
372 Self { state }
373 }
374}
375
376impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoReadTool {
377 type Name = PrimitiveToolName;
378
379 fn name(&self) -> PrimitiveToolName {
380 PrimitiveToolName::TodoRead
381 }
382
383 fn display_name(&self) -> &'static str {
384 "Read Tasks"
385 }
386
387 fn description(&self) -> &'static str {
388 "Read the current TODO list to see task status and progress."
389 }
390
391 fn input_schema(&self) -> Value {
392 json!({
393 "type": "object",
394 "properties": {}
395 })
396 }
397
398 fn tier(&self) -> ToolTier {
399 ToolTier::Observe
400 }
401
402 async fn execute(&self, _ctx: &ToolContext<Ctx>, _input: Value) -> Result<ToolResult> {
403 let display = {
404 let state = self.state.read().await;
405 state.format_display()
406 };
407
408 Ok(ToolResult::success(display))
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_todo_status_icons() {
        assert_eq!(TodoStatus::Pending.icon(), "○");
        assert_eq!(TodoStatus::InProgress.icon(), "⚡");
        assert_eq!(TodoStatus::Completed.icon(), "✓");
    }

    #[test]
    fn test_todo_item_new() {
        let item = TodoItem::new("Fix the bug", "Fixing the bug");
        assert_eq!(item.content, "Fix the bug");
        assert_eq!(item.active_form, "Fixing the bug");
        assert_eq!(item.status, TodoStatus::Pending);
    }

    #[test]
    fn test_todo_state_count_by_status() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        state.add(TodoItem::with_status(
            "Task 2",
            "Task 2",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Task 3",
            "Task 3",
            TodoStatus::Completed,
        ));
        state.add(TodoItem::with_status(
            "Task 4",
            "Task 4",
            TodoStatus::Completed,
        ));

        let (pending, in_progress, completed) = state.count_by_status();
        assert_eq!(pending, 1);
        assert_eq!(in_progress, 1);
        assert_eq!(completed, 2);
    }

    #[test]
    fn test_todo_state_empty() {
        let state = TodoState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
        assert_eq!(state.count_by_status(), (0, 0, 0));
        assert!(state.current_task().is_none());
    }

    #[test]
    fn test_todo_state_current_task() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        assert!(state.current_task().is_none());

        state.add(TodoItem::with_status(
            "Task 2",
            "Working on Task 2",
            TodoStatus::InProgress,
        ));
        let current = state.current_task().unwrap();
        assert_eq!(current.content, "Task 2");
        assert_eq!(current.active_form, "Working on Task 2");
    }

    #[test]
    fn test_todo_state_format_display() {
        let mut state = TodoState::new();
        assert_eq!(state.format_display(), "No tasks");

        state.add(TodoItem::with_status(
            "Fix bug",
            "Fixing bug",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Write tests",
            "Writing tests",
            TodoStatus::Pending,
        ));

        let display = state.format_display();
        assert!(display.contains("TODO (0/2)"));
        assert!(display.contains("Fixing bug"));
        assert!(display.contains("⚡ Fix bug"));
        assert!(display.contains("○ Write tests"));
    }

    #[test]
    fn test_todo_state_format_display_without_current_task() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Ship it",
            "Shipping it",
            TodoStatus::Completed,
        ));
        state.add(TodoItem::new("Write docs", "Writing docs"));

        let display = state.format_display();
        // No task is in progress, so the header has no active-form suffix.
        assert!(display.starts_with("TODO (1/2)\n"));
        assert!(!display.contains("Shipping it"));
        assert!(!display.contains("Writing docs"));
        assert!(display.contains("✓ Ship it"));
        assert!(display.contains("○ Write docs"));
    }

    #[test]
    fn test_todo_status_serde() {
        let status = TodoStatus::InProgress;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"in_progress\"");

        let parsed: TodoStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(parsed, TodoStatus::Completed);
    }

    #[test]
    fn test_todo_write_input_reads_camel_case_active_form() {
        let input: TodoWriteInput = serde_json::from_value(json!({
            "todos": [
                {"content": "Fix bug", "status": "in_progress", "activeForm": "Fixing bug"}
            ]
        }))
        .unwrap();
        assert_eq!(input.todos.len(), 1);
        assert_eq!(input.todos[0].content, "Fix bug");
        assert_eq!(input.todos[0].status, TodoStatus::InProgress);
        assert_eq!(input.todos[0].active_form, "Fixing bug");
    }
}