Skip to main content

heartbit_core/tool/builtins/
todo.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, RwLock};
4
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7
8use crate::error::Error;
9use crate::llm::types::ToolDefinition;
10use crate::tool::{Tool, ToolOutput};
11
12// --- TodoStore ---
13
14/// A single item in the agent's task list.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct TodoItem {
17    /// The task description.
18    pub content: String,
19    /// Current status of the task.
20    pub status: TodoStatus,
21    /// Priority level of the task.
22    pub priority: TodoPriority,
23}
24
25/// Status of a to-do item.
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27#[serde(rename_all = "snake_case")]
28pub enum TodoStatus {
29    /// Task has not been started.
30    Pending,
31    /// Task is currently being worked on (at most one at a time).
32    InProgress,
33    /// Task has been successfully completed.
34    Completed,
35    /// Task was abandoned.
36    Cancelled,
37    /// Task could not be completed due to an error.
38    Failed,
39    /// Task is waiting on an external dependency.
40    Blocked,
41}
42
43impl std::fmt::Display for TodoStatus {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            TodoStatus::Pending => write!(f, "pending"),
47            TodoStatus::InProgress => write!(f, "in_progress"),
48            TodoStatus::Completed => write!(f, "completed"),
49            TodoStatus::Cancelled => write!(f, "cancelled"),
50            TodoStatus::Failed => write!(f, "failed"),
51            TodoStatus::Blocked => write!(f, "blocked"),
52        }
53    }
54}
55
56/// Priority level of a to-do item.
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58#[serde(rename_all = "snake_case")]
59pub enum TodoPriority {
60    /// Must be done immediately.
61    Critical,
62    /// Should be done next.
63    High,
64    /// Normal priority.
65    Medium,
66    /// Do when convenient.
67    Low,
68}
69
70impl std::fmt::Display for TodoPriority {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            TodoPriority::Critical => write!(f, "critical"),
74            TodoPriority::High => write!(f, "high"),
75            TodoPriority::Medium => write!(f, "medium"),
76            TodoPriority::Low => write!(f, "low"),
77        }
78    }
79}
80
81/// Shared in-process store for agent to-do items.
82///
83/// Accessed via the `todo_read` and `todo_write` builtin tools.
84/// Thread-safe: backed by a `std::sync::RwLock`.
85pub struct TodoStore {
86    todos: RwLock<Vec<TodoItem>>,
87}
88
89impl Default for TodoStore {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl TodoStore {
96    /// Create an empty `TodoStore`.
97    pub fn new() -> Self {
98        Self {
99            todos: RwLock::new(Vec::new()),
100        }
101    }
102
103    fn set(&self, todos: Vec<TodoItem>) -> Result<(), String> {
104        // Validate: at most 1 item can be in_progress
105        let in_progress_count = todos
106            .iter()
107            .filter(|t| t.status == TodoStatus::InProgress)
108            .count();
109        if in_progress_count > 1 {
110            return Err(format!(
111                "Only 1 item can be in_progress at a time (got {in_progress_count})"
112            ));
113        }
114
115        let mut guard = self.todos.write().expect("todo store lock poisoned");
116        *guard = todos;
117        Ok(())
118    }
119
120    fn get_all(&self) -> Vec<TodoItem> {
121        let guard = self.todos.read().expect("todo store lock poisoned");
122        guard.clone()
123    }
124}
125
126// --- Tools ---
127
128/// Create the `todo_read` and `todo_write` tool pair sharing a single store.
129pub fn todo_tools(store: Arc<TodoStore>) -> Vec<Arc<dyn Tool>> {
130    vec![
131        Arc::new(TodoWriteTool {
132            store: store.clone(),
133        }),
134        Arc::new(TodoReadTool { store }),
135    ]
136}
137
138struct TodoWriteTool {
139    store: Arc<TodoStore>,
140}
141
142impl Tool for TodoWriteTool {
143    fn definition(&self) -> ToolDefinition {
144        ToolDefinition {
145            name: "todowrite".into(),
146            description:
147                "Write/replace the full todo list. Only 1 item can be in_progress at a time. \
148                          This replaces the entire list (not append)."
149                    .into(),
150            input_schema: json!({
151                "type": "object",
152                "properties": {
153                    "todos": {
154                        "type": "array",
155                        "items": {
156                            "type": "object",
157                            "properties": {
158                                "content": {"type": "string"},
159                                "status": {
160                                    "type": "string",
161                                    "enum": ["pending", "in_progress", "completed", "cancelled", "failed", "blocked"]
162                                },
163                                "priority": {
164                                    "type": "string",
165                                    "enum": ["critical", "high", "medium", "low"]
166                                }
167                            },
168                            "required": ["content", "status", "priority"]
169                        }
170                    }
171                },
172                "required": ["todos"]
173            }),
174        }
175    }
176
177    fn execute(
178        &self,
179        _ctx: &crate::ExecutionContext,
180        input: serde_json::Value,
181    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
182        Box::pin(async move {
183            let todos_value = input
184                .get("todos")
185                .ok_or_else(|| Error::Agent("todos is required".into()))?;
186
187            let todos: Vec<TodoItem> = serde_json::from_value(todos_value.clone())
188                .map_err(|e| Error::Agent(format!("Invalid todo list: {e}")))?;
189
190            if let Err(msg) = self.store.set(todos) {
191                return Ok(ToolOutput::error(msg));
192            }
193
194            let all = self.store.get_all();
195            Ok(ToolOutput::success(format!(
196                "Todo list updated ({} items)",
197                all.len()
198            )))
199        })
200    }
201}
202
203struct TodoReadTool {
204    store: Arc<TodoStore>,
205}
206
207impl Tool for TodoReadTool {
208    fn definition(&self) -> ToolDefinition {
209        ToolDefinition {
210            name: "todoread".into(),
211            description: "Read the current todo list.".into(),
212            input_schema: json!({"type": "object"}),
213        }
214    }
215
216    fn execute(
217        &self,
218        _ctx: &crate::ExecutionContext,
219        _input: serde_json::Value,
220    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
221        Box::pin(async move {
222            let todos = self.store.get_all();
223
224            if todos.is_empty() {
225                return Ok(ToolOutput::success("No todos."));
226            }
227
228            let mut output = String::new();
229            for (i, todo) in todos.iter().enumerate() {
230                let status_icon = match todo.status {
231                    TodoStatus::Pending => "[ ]",
232                    TodoStatus::InProgress => "[>]",
233                    TodoStatus::Completed => "[x]",
234                    TodoStatus::Cancelled => "[-]",
235                    TodoStatus::Failed => "[!]",
236                    TodoStatus::Blocked => "[B]",
237                };
238                let priority_tag = match todo.priority {
239                    TodoPriority::Critical => " [CRITICAL]",
240                    TodoPriority::High => " [HIGH]",
241                    TodoPriority::Medium => "",
242                    TodoPriority::Low => " [low]",
243                };
244                output.push_str(&format!(
245                    "{}. {} {}{}\n",
246                    i + 1,
247                    status_icon,
248                    todo.content,
249                    priority_tag
250                ));
251            }
252
253            Ok(ToolOutput::success(output))
254        })
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Find a tool by its advertised name, so tests do not depend on the order
    /// in which `todo_tools` returns them.
    fn tool_named<'a>(tools: &'a [Arc<dyn Tool>], name: &str) -> &'a Arc<dyn Tool> {
        tools
            .iter()
            .find(|t| t.definition().name == name)
            .unwrap_or_else(|| panic!("tool {name} is not registered"))
    }

    #[test]
    fn definition_names() {
        let store = Arc::new(TodoStore::new());
        let tools = todo_tools(store);
        let names: Vec<String> = tools.iter().map(|t| t.definition().name).collect();
        assert!(names.contains(&"todowrite".to_string()));
        assert!(names.contains(&"todoread".to_string()));
    }

    #[test]
    fn display_matches_serde_wire_names() {
        // The write schema advertises the serde names; Display must agree with them.
        for status in [
            TodoStatus::Pending,
            TodoStatus::InProgress,
            TodoStatus::Completed,
            TodoStatus::Cancelled,
            TodoStatus::Failed,
            TodoStatus::Blocked,
        ] {
            let wire = serde_json::to_value(&status).unwrap();
            assert_eq!(wire, json!(status.to_string()));
        }
        for priority in [
            TodoPriority::Critical,
            TodoPriority::High,
            TodoPriority::Medium,
            TodoPriority::Low,
        ] {
            let wire = serde_json::to_value(&priority).unwrap();
            assert_eq!(wire, json!(priority.to_string()));
        }
    }

    #[tokio::test]
    async fn todowrite_and_read() {
        let store = Arc::new(TodoStore::new());
        let tools = todo_tools(store);
        let write_tool = tool_named(&tools, "todowrite");
        let read_tool = tool_named(&tools, "todoread");

        // Write some todos
        let result = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "todos": [
                        {"content": "Fix bug", "status": "in_progress", "priority": "high"},
                        {"content": "Write tests", "status": "pending", "priority": "medium"}
                    ]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error, "got error: {}", result.content);
        assert!(result.content.contains("2 items"));

        // Read them back
        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Fix bug"));
        assert!(result.content.contains("[HIGH]"));
        assert!(result.content.contains("Write tests"));
        assert!(result.content.contains("[>]")); // in_progress
        // Numbering starts at 1 and medium priority carries no tag.
        assert!(result.content.contains("1. [>] Fix bug [HIGH]"));
        assert!(result.content.contains("2. [ ] Write tests"));
    }

    #[tokio::test]
    async fn todowrite_rejects_multiple_in_progress() {
        let store = Arc::new(TodoStore::new());
        let tools = todo_tools(store);
        let write_tool = tool_named(&tools, "todowrite");

        let result = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "todos": [
                        {"content": "Task 1", "status": "in_progress", "priority": "high"},
                        {"content": "Task 2", "status": "in_progress", "priority": "high"}
                    ]
                }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Only 1 item"));
    }

    #[tokio::test]
    async fn todowrite_requires_todos_field() {
        let tools = todo_tools(Arc::new(TodoStore::new()));
        let write_tool = tool_named(&tools, "todowrite");

        let result = write_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn todowrite_rejects_unknown_status() {
        let tools = todo_tools(Arc::new(TodoStore::new()));
        let write_tool = tool_named(&tools, "todowrite");

        let result = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"todos": [{"content": "X", "status": "done", "priority": "high"}]}),
            )
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rejected_write_keeps_previous_list() {
        let tools = todo_tools(Arc::new(TodoStore::new()));
        let write_tool = tool_named(&tools, "todowrite");
        let read_tool = tool_named(&tools, "todoread");

        let accepted = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"todos": [{"content": "Keep me", "status": "pending", "priority": "medium"}]}),
            )
            .await
            .unwrap();
        assert!(!accepted.is_error);

        let rejected = write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "todos": [
                        {"content": "Task 1", "status": "in_progress", "priority": "high"},
                        {"content": "Task 2", "status": "in_progress", "priority": "high"}
                    ]
                }),
            )
            .await
            .unwrap();
        assert!(rejected.is_error);

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(result.content.contains("Keep me"));
        assert!(!result.content.contains("Task 1"));
    }

    #[tokio::test]
    async fn todoread_empty() {
        let store = Arc::new(TodoStore::new());
        let tools = todo_tools(store);
        let read_tool = tool_named(&tools, "todoread");

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("No todos"));
    }

    #[tokio::test]
    async fn todowrite_replaces_full_list() {
        let store = Arc::new(TodoStore::new());
        let tools = todo_tools(store);
        let write_tool = tool_named(&tools, "todowrite");
        let read_tool = tool_named(&tools, "todoread");

        // First write
        write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"todos": [{"content": "Old", "status": "pending", "priority": "low"}]}),
            )
            .await
            .unwrap();

        // Second write replaces
        write_tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"todos": [{"content": "New", "status": "completed", "priority": "high"}]}),
            )
            .await
            .unwrap();

        let result = read_tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await
            .unwrap();
        assert!(result.content.contains("New"));
        assert!(!result.content.contains("Old"));
    }
}