Skip to main content

agent_sdk/
todo.rs

//! TODO task tracking for agents.
//!
//! This module provides tools for agents to track tasks and show progress.
//! Task tracking helps agents organize complex work and gives users visibility
//! into what the agent is working on.
//!
//! # Example
//!
//! ```no_run
//! use agent_sdk::todo::{TodoState, TodoWriteTool, TodoReadTool};
//! use std::sync::Arc;
//! use tokio::sync::RwLock;
//!
//! let state = Arc::new(RwLock::new(TodoState::new()));
//! let write_tool = TodoWriteTool::new(Arc::clone(&state));
//! let read_tool = TodoReadTool::new(state);
//! ```
18
19use std::fmt::Write;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
24use anyhow::{Context, Result};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tokio::sync::RwLock;
28
29/// Status of a TODO item.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum TodoStatus {
33    /// Task not yet started.
34    Pending,
35    /// Task currently being worked on.
36    InProgress,
37    /// Task finished successfully.
38    Completed,
39}
40
41impl TodoStatus {
42    /// Returns the icon for this status.
43    #[must_use]
44    pub const fn icon(&self) -> &'static str {
45        match self {
46            Self::Pending => "○",
47            Self::InProgress => "⚡",
48            Self::Completed => "✓",
49        }
50    }
51}
52
53/// A single TODO item.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TodoItem {
56    /// Task description in imperative form (e.g., "Fix the bug").
57    pub content: String,
58    /// Current status of the task.
59    pub status: TodoStatus,
60    /// Present continuous form shown during execution (e.g., "Fixing the bug").
61    pub active_form: String,
62}
63
64impl TodoItem {
65    /// Creates a new pending TODO item.
66    #[must_use]
67    pub fn new(content: impl Into<String>, active_form: impl Into<String>) -> Self {
68        Self {
69            content: content.into(),
70            status: TodoStatus::Pending,
71            active_form: active_form.into(),
72        }
73    }
74
75    /// Creates a new TODO item with the given status.
76    #[must_use]
77    pub fn with_status(
78        content: impl Into<String>,
79        active_form: impl Into<String>,
80        status: TodoStatus,
81    ) -> Self {
82        Self {
83            content: content.into(),
84            status,
85            active_form: active_form.into(),
86        }
87    }
88
89    /// Returns the icon for this item's status.
90    #[must_use]
91    pub const fn icon(&self) -> &'static str {
92        self.status.icon()
93    }
94}
95
96/// Shared TODO state that can be persisted.
97#[derive(Debug, Default)]
98pub struct TodoState {
99    /// The list of TODO items.
100    pub items: Vec<TodoItem>,
101    /// Optional path for persistence.
102    storage_path: Option<PathBuf>,
103}
104
105impl TodoState {
106    /// Creates a new empty TODO state.
107    #[must_use]
108    pub const fn new() -> Self {
109        Self {
110            items: Vec::new(),
111            storage_path: None,
112        }
113    }
114
115    /// Creates a new TODO state with persistence.
116    #[must_use]
117    pub const fn with_storage(path: PathBuf) -> Self {
118        Self {
119            items: Vec::new(),
120            storage_path: Some(path),
121        }
122    }
123
124    /// Sets the storage path for persistence.
125    pub fn set_storage_path(&mut self, path: PathBuf) {
126        self.storage_path = Some(path);
127    }
128
129    /// Loads todos from storage if path is set.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the file cannot be read or parsed.
134    pub async fn load(&mut self) -> Result<()> {
135        if let Some(ref path) = self.storage_path.as_ref().filter(|p| p.exists()) {
136            let content = tokio::fs::read_to_string(path)
137                .await
138                .context("Failed to read todos file")?;
139            self.items = serde_json::from_str(&content).context("Failed to parse todos file")?;
140        }
141        Ok(())
142    }
143
144    /// Saves todos to storage if path is set.
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if the file cannot be written.
149    pub async fn save(&self) -> Result<()> {
150        if let Some(ref path) = self.storage_path {
151            // Ensure parent directory exists
152            if let Some(parent) = path.parent() {
153                tokio::fs::create_dir_all(parent)
154                    .await
155                    .context("Failed to create todos directory")?;
156            }
157            let content =
158                serde_json::to_string_pretty(&self.items).context("Failed to serialize todos")?;
159            tokio::fs::write(path, content)
160                .await
161                .context("Failed to write todos file")?;
162        }
163        Ok(())
164    }
165
166    /// Replaces the entire TODO list.
167    pub fn set_items(&mut self, items: Vec<TodoItem>) {
168        self.items = items;
169    }
170
171    /// Adds a new TODO item.
172    pub fn add(&mut self, item: TodoItem) {
173        self.items.push(item);
174    }
175
176    /// Returns the count of items by status.
177    #[must_use]
178    pub fn count_by_status(&self) -> (usize, usize, usize) {
179        let pending = self
180            .items
181            .iter()
182            .filter(|i| i.status == TodoStatus::Pending)
183            .count();
184        let in_progress = self
185            .items
186            .iter()
187            .filter(|i| i.status == TodoStatus::InProgress)
188            .count();
189        let completed = self
190            .items
191            .iter()
192            .filter(|i| i.status == TodoStatus::Completed)
193            .count();
194        (pending, in_progress, completed)
195    }
196
197    /// Returns the currently in-progress item, if any.
198    #[must_use]
199    pub fn current_task(&self) -> Option<&TodoItem> {
200        self.items
201            .iter()
202            .find(|i| i.status == TodoStatus::InProgress)
203    }
204
205    /// Formats the TODO list for display.
206    #[must_use]
207    pub fn format_display(&self) -> String {
208        if self.items.is_empty() {
209            return "No tasks".to_string();
210        }
211
212        let (_pending, in_progress, completed) = self.count_by_status();
213        let total = self.items.len();
214
215        let mut output = format!("TODO ({completed}/{total})");
216
217        if in_progress > 0
218            && let Some(current) = self.current_task()
219        {
220            let _ = write!(output, " - {}", current.active_form);
221        }
222
223        output.push('\n');
224
225        for item in &self.items {
226            let _ = writeln!(output, "  {} {}", item.icon(), item.content);
227        }
228
229        output
230    }
231
232    /// Returns true if there are no items.
233    #[must_use]
234    pub const fn is_empty(&self) -> bool {
235        self.items.is_empty()
236    }
237
238    /// Returns the number of items.
239    #[must_use]
240    pub const fn len(&self) -> usize {
241        self.items.len()
242    }
243}
244
245/// Tool for writing/updating the TODO list.
246pub struct TodoWriteTool {
247    /// Shared TODO state.
248    state: Arc<RwLock<TodoState>>,
249}
250
251impl TodoWriteTool {
252    /// Creates a new `TodoWriteTool`.
253    #[must_use]
254    pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
255        Self { state }
256    }
257}
258
259/// Input for a single TODO item.
260#[derive(Debug, Deserialize)]
261struct TodoItemInput {
262    content: String,
263    status: TodoStatus,
264    #[serde(rename = "activeForm")]
265    active_form: String,
266}
267
268/// Input schema for `TodoWriteTool`.
269#[derive(Debug, Deserialize)]
270struct TodoWriteInput {
271    todos: Vec<TodoItemInput>,
272}
273
274impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoWriteTool {
275    type Name = PrimitiveToolName;
276
277    fn name(&self) -> PrimitiveToolName {
278        PrimitiveToolName::TodoWrite
279    }
280
281    fn display_name(&self) -> &'static str {
282        "Update Tasks"
283    }
284
285    fn description(&self) -> &'static str {
286        "Update the TODO list to track tasks and show progress to the user. \
287         Use this tool frequently to plan complex tasks and mark progress. \
288         Each item needs 'content' (imperative form like 'Fix the bug'), \
289         'status' (pending/in_progress/completed), and 'activeForm' \
290         (present continuous like 'Fixing the bug'). \
291         Mark tasks completed immediately when done - don't batch completions."
292    }
293
294    fn input_schema(&self) -> Value {
295        json!({
296            "type": "object",
297            "required": ["todos"],
298            "properties": {
299                "todos": {
300                    "type": "array",
301                    "description": "The complete TODO list (replaces existing)",
302                    "items": {
303                        "type": "object",
304                        "required": ["content", "status", "activeForm"],
305                        "properties": {
306                            "content": {
307                                "type": "string",
308                                "description": "Task description in imperative form (e.g., 'Fix the bug')"
309                            },
310                            "status": {
311                                "type": "string",
312                                "enum": ["pending", "in_progress", "completed"],
313                                "description": "Current status of the task"
314                            },
315                            "activeForm": {
316                                "type": "string",
317                                "description": "Present continuous form shown during execution (e.g., 'Fixing the bug')"
318                            }
319                        }
320                    }
321                }
322            }
323        })
324    }
325
326    fn tier(&self) -> ToolTier {
327        ToolTier::Observe // No dangerous side effects
328    }
329
330    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
331        let input: TodoWriteInput =
332            serde_json::from_value(input).context("Invalid input for todo_write")?;
333
334        let items: Vec<TodoItem> = input
335            .todos
336            .into_iter()
337            .map(|t| TodoItem {
338                content: t.content,
339                status: t.status,
340                active_form: t.active_form,
341            })
342            .collect();
343
344        let display = {
345            let mut state = self.state.write().await;
346            state.set_items(items);
347
348            // Save to storage if configured
349            if let Err(e) = state.save().await {
350                log::warn!("Failed to save todos: {e}");
351            }
352
353            state.format_display()
354        };
355
356        Ok(ToolResult::success(format!(
357            "TODO list updated.\n\n{display}"
358        )))
359    }
360}
361
362/// Tool for reading the current TODO list.
363pub struct TodoReadTool {
364    /// Shared TODO state.
365    state: Arc<RwLock<TodoState>>,
366}
367
368impl TodoReadTool {
369    /// Creates a new `TodoReadTool`.
370    #[must_use]
371    pub const fn new(state: Arc<RwLock<TodoState>>) -> Self {
372        Self { state }
373    }
374}
375
376impl<Ctx: Send + Sync + 'static> Tool<Ctx> for TodoReadTool {
377    type Name = PrimitiveToolName;
378
379    fn name(&self) -> PrimitiveToolName {
380        PrimitiveToolName::TodoRead
381    }
382
383    fn display_name(&self) -> &'static str {
384        "Read Tasks"
385    }
386
387    fn description(&self) -> &'static str {
388        "Read the current TODO list to see task status and progress."
389    }
390
391    fn input_schema(&self) -> Value {
392        json!({
393            "type": "object",
394            "properties": {}
395        })
396    }
397
398    fn tier(&self) -> ToolTier {
399        ToolTier::Observe
400    }
401
402    async fn execute(&self, _ctx: &ToolContext<Ctx>, _input: Value) -> Result<ToolResult> {
403        let display = {
404            let state = self.state.read().await;
405            state.format_display()
406        };
407
408        Ok(ToolResult::success(display))
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TodoState` holding one item per `(content, active_form, status)`
    /// triple, in the order given.
    fn state_with(specs: &[(&str, &str, TodoStatus)]) -> TodoState {
        let mut state = TodoState::new();
        for &(content, active_form, status) in specs {
            state.add(TodoItem::with_status(content, active_form, status));
        }
        state
    }

    #[test]
    fn test_todo_status_icons() {
        let expected = [
            (TodoStatus::Pending, "○"),
            (TodoStatus::InProgress, "⚡"),
            (TodoStatus::Completed, "✓"),
        ];
        for (status, icon) in expected {
            assert_eq!(status.icon(), icon);
        }
    }

    #[test]
    fn test_todo_item_new() {
        let TodoItem {
            content,
            status,
            active_form,
        } = TodoItem::new("Fix the bug", "Fixing the bug");
        assert_eq!(content, "Fix the bug");
        assert_eq!(active_form, "Fixing the bug");
        assert_eq!(status, TodoStatus::Pending);
    }

    #[test]
    fn test_todo_state_count_by_status() {
        let state = state_with(&[
            ("Task 1", "Task 1", TodoStatus::Pending),
            ("Task 2", "Task 2", TodoStatus::InProgress),
            ("Task 3", "Task 3", TodoStatus::Completed),
            ("Task 4", "Task 4", TodoStatus::Completed),
        ]);

        assert_eq!(state.count_by_status(), (1, 1, 2));
    }

    #[test]
    fn test_todo_state_current_task() {
        let mut state = state_with(&[("Task 1", "Task 1", TodoStatus::Pending)]);
        assert!(state.current_task().is_none());

        state.add(TodoItem::with_status(
            "Task 2",
            "Working on Task 2",
            TodoStatus::InProgress,
        ));
        let current = state
            .current_task()
            .expect("an in-progress task was just added");
        assert_eq!(current.content, "Task 2");
        assert_eq!(current.active_form, "Working on Task 2");
    }

    #[test]
    fn test_todo_state_format_display() {
        let mut state = TodoState::new();
        assert_eq!(state.format_display(), "No tasks");

        state.add(TodoItem::with_status(
            "Fix bug",
            "Fixing bug",
            TodoStatus::InProgress,
        ));
        state.add(TodoItem::with_status(
            "Write tests",
            "Writing tests",
            TodoStatus::Pending,
        ));

        let display = state.format_display();
        for needle in ["TODO (0/2)", "Fixing bug", "⚡ Fix bug", "○ Write tests"] {
            assert!(
                display.contains(needle),
                "expected {needle:?} in {display:?}"
            );
        }
    }

    #[test]
    fn test_todo_status_serde() {
        assert_eq!(
            serde_json::to_string(&TodoStatus::InProgress).unwrap(),
            "\"in_progress\""
        );

        let parsed: TodoStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(parsed, TodoStatus::Completed);
    }
}