Skip to main content

agent_sdk/
todo.rs

1//! TODO task tracking for agents.
2//!
3//! This module provides tools for agents to track tasks and show progress.
4//! Task tracking helps agents organize complex work and gives users visibility
5//! into what the agent is working on.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use agent_sdk::todo::{TodoState, TodoWriteTool, TodoReadTool};
11//! use std::sync::Arc;
12//! use tokio::sync::RwLock;
13//!
14//! let state = Arc::new(RwLock::new(TodoState::new()));
15//! let write_tool = TodoWriteTool::new(Arc::clone(&state));
16//! let read_tool = TodoReadTool::new(state);
17//! ```
18
19use std::fmt::Write;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
24use anyhow::{Context, Result};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tokio::sync::RwLock;
28
29/// Status of a TODO item.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum TodoStatus {
33    /// Task not yet started.
34    Pending,
35    /// Task currently being worked on.
36    InProgress,
37    /// Task finished successfully.
38    Completed,
39}
40
41impl TodoStatus {
42    /// Returns the icon for this status.
43    #[must_use]
44    pub const fn icon(&self) -> &'static str {
45        match self {
46            Self::Pending => "○",
47            Self::InProgress => "⚡",
48            Self::Completed => "✓",
49        }
50    }
51}
52
53/// A single TODO item.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TodoItem {
56    /// Task description in imperative form (e.g., "Fix the bug").
57    pub content: String,
58    /// Current status of the task.
59    pub status: TodoStatus,
60    /// Present continuous form shown during execution (e.g., "Fixing the bug").
61    pub active_form: String,
62}
63
64impl TodoItem {
65    /// Creates a new pending TODO item.
66    #[must_use]
67    pub fn new(content: impl Into<String>, active_form: impl Into<String>) -> Self {
68        Self {
69            content: content.into(),
70            status: TodoStatus::Pending,
71            active_form: active_form.into(),
72        }
73    }
74
75    /// Creates a new TODO item with the given status.
76    #[must_use]
77    pub fn with_status(
78        content: impl Into<String>,
79        active_form: impl Into<String>,
80        status: TodoStatus,
81    ) -> Self {
82        Self {
83            content: content.into(),
84            status,
85            active_form: active_form.into(),
86        }
87    }
88
89    /// Returns the icon for this item's status.
90    #[must_use]
91    pub const fn icon(&self) -> &'static str {
92        self.status.icon()
93    }
94}
95
96/// Shared TODO state that can be persisted.
97#[derive(Debug, Default)]
98pub struct TodoState {
99    /// The list of TODO items.
100    pub items: Vec<TodoItem>,
101    /// Optional path for persistence.
102    storage_path: Option<PathBuf>,
103}
104
105impl TodoState {
106    /// Creates a new empty TODO state.
107    #[must_use]
108    pub const fn new() -> Self {
109        Self {
110            items: Vec::new(),
111            storage_path: None,
112        }
113    }
114
115    /// Creates a new TODO state with persistence.
116    #[must_use]
117    pub const fn with_storage(path: PathBuf) -> Self {
118        Self {
119            items: Vec::new(),
120            storage_path: Some(path),
121        }
122    }
123
124    /// Sets the storage path for persistence.
125    pub fn set_storage_path(&mut self, path: PathBuf) {
126        self.storage_path = Some(path);
127    }
128
129    /// Loads todos from storage if path is set.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the file cannot be read or parsed.
134    pub fn load(&mut self) -> Result<()> {
135        if let Some(ref path) = self.storage_path.as_ref().filter(|p| p.exists()) {
136            let content = std::fs::read_to_string(path).context("Failed to read todos file")?;
137            self.items = serde_json::from_str(&content).context("Failed to parse todos file")?;
138        }
139        Ok(())
140    }
141
142    /// Saves todos to storage if path is set.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if the file cannot be written.
147    pub fn save(&self) -> Result<()> {
148        if let Some(ref path) = self.storage_path {
149            // Ensure parent directory exists
150            if let Some(parent) = path.parent() {
151                std::fs::create_dir_all(parent).context("Failed to create todos directory")?;
152            }
153            let content =
154                serde_json::to_string_pretty(&self.items).context("Failed to serialize todos")?;
155            std::fs::write(path, content).context("Failed to write todos file")?;
156        }
157        Ok(())
158    }
159
160    /// Replaces the entire TODO list.
161    pub fn set_items(&mut self, items: Vec<TodoItem>) {
162        self.items = items;
163    }
164
165    /// Adds a new TODO item.
166    pub fn add(&mut self, item: TodoItem) {
167        self.items.push(item);
168    }
169
170    /// Returns the count of items by status.
171    #[must_use]
172    pub fn count_by_status(&self) -> (usize, usize, usize) {
173        let pending = self
174            .items
175            .iter()
176            .filter(|i| i.status == TodoStatus::Pending)
177            .count();
178        let in_progress = self
179            .items
180            .iter()
181            .filter(|i| i.status == TodoStatus::InProgress)
182            .count();
183        let completed = self
184            .items
185            .iter()
186            .filter(|i| i.status == TodoStatus::Completed)
187            .count();
188        (pending, in_progress, completed)
189    }
190
191    /// Returns the currently in-progress item, if any.
192    #[must_use]
193    pub fn current_task(&self) -> Option<&TodoItem> {
194        self.items
195            .iter()
196            .find(|i| i.status == TodoStatus::InProgress)
197    }
198
199    /// Formats the TODO list for display.
200    #[must_use]
201    pub fn format_display(&self) -> String {
202        if self.items.is_empty() {
203            return "No tasks".to_string();
204        }
205
206        let (_pending, in_progress, completed) = self.count_by_status();
207        let total = self.items.len();
208
209        let mut output = format!("TODO ({completed}/{total})");
210
211        if in_progress > 0
212            && let Some(current) = self.current_task()
213        {
214            let _ = write!(output, " - {}", current.active_form);
215        }
216
217        output.push('\n');
218
219        for item in &self.items {
220            let _ = writeln!(output, "  {} {}", item.icon(), item.content);
221        }
222
223        output
224    }
225
226    /// Returns true if there are no items.
227    #[must_use]
228    pub const fn is_empty(&self) -> bool {
229        self.items.is_empty()
230    }
231
232    /// Returns the number of items.
233    #[must_use]
234    pub const fn len(&self) -> usize {
235        self.items.len()
236    }
237}
238
239/// Tool for writing/updating the TODO list.
240pub struct TodoWriteTool {
241    /// Shared TODO state.
242    state: Arc<RwLock<TodoState>>,
243}
244
245impl TodoWriteTool {
246    /// Creates a new `TodoWriteTool`.
247    #[must_use]
248    pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
249        Self { state }
250    }
251}
252
253/// Input for a single TODO item.
254#[derive(Debug, Deserialize)]
255struct TodoItemInput {
256    content: String,
257    status: TodoStatus,
258    #[serde(rename = "activeForm")]
259    active_form: String,
260}
261
262/// Input schema for `TodoWriteTool`.
263#[derive(Debug, Deserialize)]
264struct TodoWriteInput {
265    todos: Vec<TodoItemInput>,
266}
267
268impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoWriteTool {
269    type Name = PrimitiveToolName;
270
271    fn name(&self) -> PrimitiveToolName {
272        PrimitiveToolName::TodoWrite
273    }
274
275    fn display_name(&self) -> &'static str {
276        "Update Tasks"
277    }
278
279    fn description(&self) -> &'static str {
280        "Update the TODO list to track tasks and show progress to the user. \
281         Use this tool frequently to plan complex tasks and mark progress. \
282         Each item needs 'content' (imperative form like 'Fix the bug'), \
283         'status' (pending/in_progress/completed), and 'activeForm' \
284         (present continuous like 'Fixing the bug'). \
285         Mark tasks completed immediately when done - don't batch completions."
286    }
287
288    fn input_schema(&self) -> Value {
289        json!({
290            "type": "object",
291            "required": ["todos"],
292            "properties": {
293                "todos": {
294                    "type": "array",
295                    "description": "The complete TODO list (replaces existing)",
296                    "items": {
297                        "type": "object",
298                        "required": ["content", "status", "activeForm"],
299                        "properties": {
300                            "content": {
301                                "type": "string",
302                                "description": "Task description in imperative form (e.g., 'Fix the bug')"
303                            },
304                            "status": {
305                                "type": "string",
306                                "enum": ["pending", "in_progress", "completed"],
307                                "description": "Current status of the task"
308                            },
309                            "activeForm": {
310                                "type": "string",
311                                "description": "Present continuous form shown during execution (e.g., 'Fixing the bug')"
312                            }
313                        }
314                    }
315                }
316            }
317        })
318    }
319
320    fn tier(&self) -> ToolTier {
321        ToolTier::Observe // No dangerous side effects
322    }
323
324    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
325        let input: TodoWriteInput =
326            serde_json::from_value(input).context("Invalid input for todo_write")?;
327
328        let items: Vec<TodoItem> = input
329            .todos
330            .into_iter()
331            .map(|t| TodoItem {
332                content: t.content,
333                status: t.status,
334                active_form: t.active_form,
335            })
336            .collect();
337
338        let display = {
339            let mut state = self.state.write().await;
340            state.set_items(items);
341
342            // Save to storage if configured
343            if let Err(e) = state.save() {
344                log::warn!("Failed to save todos: {e}");
345            }
346
347            state.format_display()
348        };
349
350        Ok(ToolResult::success(format!(
351            "TODO list updated.\n\n{display}"
352        )))
353    }
354}
355
356/// Tool for reading the current TODO list.
357pub struct TodoReadTool {
358    /// Shared TODO state.
359    state: Arc<RwLock<TodoState>>,
360}
361
362impl TodoReadTool {
363    /// Creates a new `TodoReadTool`.
364    #[must_use]
365    pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
366        Self { state }
367    }
368}
369
370impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoReadTool {
371    type Name = PrimitiveToolName;
372
373    fn name(&self) -> PrimitiveToolName {
374        PrimitiveToolName::TodoRead
375    }
376
377    fn display_name(&self) -> &'static str {
378        "Read Tasks"
379    }
380
381    fn description(&self) -> &'static str {
382        "Read the current TODO list to see task status and progress."
383    }
384
385    fn input_schema(&self) -> Value {
386        json!({
387            "type": "object",
388            "properties": {}
389        })
390    }
391
392    fn tier(&self) -> ToolTier {
393        ToolTier::Observe
394    }
395
396    async fn execute(&self, _ctx: &ToolContext<Ctx>, _input: Value) -> Result<ToolResult> {
397        let display = {
398            let state = self.state.read().await;
399            state.format_display()
400        };
401
402        Ok(ToolResult::success(display))
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a per-test scratch directory under the system temp dir.
    ///
    /// The process id keeps concurrent test runs from colliding.
    fn scratch_dir(test_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "agent_sdk_todo_{test_name}_{}",
            std::process::id()
        ))
    }

    #[test]
    fn test_todo_status_icons() {
        assert_eq!(TodoStatus::Pending.icon(), "○");
        assert_eq!(TodoStatus::InProgress.icon(), "⚡");
        assert_eq!(TodoStatus::Completed.icon(), "✓");
    }

    #[test]
    fn test_todo_item_new() {
        let item = TodoItem::new("Fix the bug", "Fixing the bug");
        assert_eq!(item.content, "Fix the bug");
        assert_eq!(item.active_form, "Fixing the bug");
        assert_eq!(item.status, TodoStatus::Pending);
    }

    #[test]
    fn test_todo_state_count_by_status() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        state.add(TodoItem::with_status(
            "Task 2",
            "Task 2",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Task 3",
            "Task 3",
            TodoStatus::Completed,
        ));
        state.add(TodoItem::with_status(
            "Task 4",
            "Task 4",
            TodoStatus::Completed,
        ));

        let (pending, in_progress, completed) = state.count_by_status();
        assert_eq!(pending, 1);
        assert_eq!(in_progress, 1);
        assert_eq!(completed, 2);
    }

    #[test]
    fn test_todo_state_current_task() {
        let mut state = TodoState::new();
        state.add(TodoItem::with_status(
            "Task 1",
            "Task 1",
            TodoStatus::Pending,
        ));
        assert!(state.current_task().is_none());

        state.add(TodoItem::with_status(
            "Task 2",
            "Working on Task 2",
            TodoStatus::InProgress,
        ));
        let current = state.current_task().unwrap();
        assert_eq!(current.content, "Task 2");
        assert_eq!(current.active_form, "Working on Task 2");
    }

    #[test]
    fn test_todo_state_format_display() {
        let mut state = TodoState::new();
        assert_eq!(state.format_display(), "No tasks");

        state.add(TodoItem::with_status(
            "Fix bug",
            "Fixing bug",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Write tests",
            "Writing tests",
            TodoStatus::Pending,
        ));

        let display = state.format_display();
        assert!(display.contains("TODO (0/2)"));
        assert!(display.contains("Fixing bug"));
        assert!(display.contains("⚡ Fix bug"));
        assert!(display.contains("○ Write tests"));
    }

    #[test]
    fn test_todo_status_serde() {
        let status = TodoStatus::InProgress;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"in_progress\"");

        let parsed: TodoStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(parsed, TodoStatus::Completed);
    }

    #[test]
    fn test_todo_state_save_load_roundtrip() {
        let dir = scratch_dir("roundtrip");
        let _ = std::fs::remove_dir_all(&dir);
        // A nested path exercises parent-directory creation in `save`.
        let path = dir.join("nested").join("todos.json");

        let mut state = TodoState::with_storage(path.clone());
        state.add(TodoItem::with_status(
            "Fix bug",
            "Fixing bug",
            TodoStatus::Completed,
        ));
        state.add(TodoItem::new("Write tests", "Writing tests"));
        state.save().unwrap();

        let mut loaded = TodoState::with_storage(path);
        loaded.load().unwrap();
        // Clean up before asserting so a failure does not leak the directory.
        let _ = std::fs::remove_dir_all(&dir);

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.items[0].content, "Fix bug");
        assert_eq!(loaded.items[0].active_form, "Fixing bug");
        assert_eq!(loaded.items[0].status, TodoStatus::Completed);
        assert_eq!(loaded.items[1].content, "Write tests");
        assert_eq!(loaded.items[1].active_form, "Writing tests");
        assert_eq!(loaded.items[1].status, TodoStatus::Pending);
    }

    #[test]
    fn test_todo_state_load_missing_file_is_noop() {
        let dir = scratch_dir("missing");
        let _ = std::fs::remove_dir_all(&dir);

        let mut state = TodoState::with_storage(dir.join("todos.json"));
        state.add(TodoItem::new("Keep me", "Keeping me"));
        state.load().unwrap();

        assert_eq!(state.len(), 1);
        assert_eq!(state.items[0].content, "Keep me");
    }

    #[test]
    fn test_todo_state_without_storage_is_noop() {
        let mut state = TodoState::new();
        state.add(TodoItem::new("Only in memory", "Staying in memory"));
        assert!(state.save().is_ok());
        assert!(state.load().is_ok());
        assert_eq!(state.len(), 1);
    }
}