use llm_coding_tools_core::operations::{read_todos, write_todos};
use llm_coding_tools_core::tool_names;
use llm_coding_tools_core::{ToolContext, ToolError, ToolOutput};
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use schemars::{schema_for, JsonSchema};
use serde::Deserialize;
pub use llm_coding_tools_core::{Todo, TodoPriority, TodoState, TodoStatus};
/// Arguments accepted by [`TodoWriteTool`].
///
/// Deserialized from the tool call and also used to generate the JSON schema
/// advertised in the tool definition.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct TodoWriteArgs {
    /// Todo items handed to the write operation.
    pub todos: Vec<Todo>,
}
/// Arguments accepted by [`TodoReadTool`].
///
/// The struct has no fields: reading the todo list takes no input.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct TodoReadArgs {}
/// Tool that writes todo items into a [`TodoState`].
#[derive(Debug, Clone)]
pub struct TodoWriteTool {
    /// State handle passed to `write_todos` on every call.
    state: TodoState,
}
impl TodoWriteTool {
pub fn new(state: TodoState) -> Self {
Self { state }
}
}
impl Tool for TodoWriteTool {
const NAME: &'static str = tool_names::TODO_WRITE;
type Error = ToolError;
type Args = TodoWriteArgs;
type Output = ToolOutput;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: <Self as Tool>::NAME.to_string(),
description: "Replace the todo list with new items.".to_string(),
parameters: serde_json::to_value(schema_for!(TodoWriteArgs))
.expect("schema serialization should never fail"),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let message = write_todos(&self.state, args.todos)?;
Ok(ToolOutput::new(message))
}
}
/// Tool that reads the todo list held in a [`TodoState`].
#[derive(Debug, Clone)]
pub struct TodoReadTool {
    /// State handle passed to `read_todos` on every call.
    state: TodoState,
}
impl TodoReadTool {
pub fn new(state: TodoState) -> Self {
Self { state }
}
}
/// Exposes [`TodoReadTool`] through rig's [`Tool`] interface.
impl Tool for TodoReadTool {
    const NAME: &'static str = tool_names::TODO_READ;

    type Error = ToolError;
    type Args = TodoReadArgs;
    type Output = ToolOutput;

    /// Builds the tool definition: the tool name, a fixed description, and
    /// the JSON schema generated from [`TodoReadArgs`].
    ///
    /// The prompt argument is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be converted into a JSON value.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let schema = schema_for!(TodoReadArgs);
        let parameters =
            serde_json::to_value(schema).expect("schema serialization should never fail");

        ToolDefinition {
            // `NAME` is qualified because `ToolContext` declares a constant
            // with the same name for this type.
            name: String::from(<Self as Tool>::NAME),
            description: String::from("Read the current todo list."),
            parameters,
        }
    }

    /// Wraps the text returned by `read_todos` in a [`ToolOutput`].
    ///
    /// The arguments carry no data and are ignored; this method always
    /// returns `Ok`.
    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
        let listing = read_todos(&self.state);
        Ok(ToolOutput::new(listing))
    }
}
/// Associates [`TodoWriteTool`] with its context text.
impl ToolContext for TodoWriteTool {
    const NAME: &'static str = tool_names::TODO_WRITE;

    /// Returns `llm_coding_tools_core::context::TODO_WRITE`.
    fn context(&self) -> &'static str {
        let text: &'static str = llm_coding_tools_core::context::TODO_WRITE;
        text
    }
}
/// Associates [`TodoReadTool`] with its context text.
impl ToolContext for TodoReadTool {
    const NAME: &'static str = tool_names::TODO_READ;

    /// Returns `llm_coding_tools_core::context::TODO_READ`.
    fn context(&self) -> &'static str {
        let text: &'static str = llm_coding_tools_core::context::TODO_READ;
        text
    }
}
/// A write tool and a read tool bundled together.
///
/// The constructors on this type build both tools from a single
/// [`TodoState`] (one of them from a `clone` of it).
///
/// `Debug` and `Clone` are derived to match [`TodoWriteTool`] and
/// [`TodoReadTool`], which both derive them.
#[derive(Debug, Clone)]
pub struct TodoTools {
    /// Tool that replaces the todo list.
    pub write: TodoWriteTool,
    /// Tool that reads the current todo list.
    pub read: TodoReadTool,
}
impl TodoTools {
    /// Creates a tool pair backed by a freshly constructed [`TodoState`].
    pub fn new() -> Self {
        Self::with_state(TodoState::new())
    }

    /// Creates a tool pair backed by the supplied `state`.
    ///
    /// The write tool receives a clone of `state`; the read tool takes
    /// ownership of the original value.
    pub fn with_state(state: TodoState) -> Self {
        let write = TodoWriteTool::new(state.clone());
        let read = TodoReadTool::new(state);
        Self { write, read }
    }
}
impl Default for TodoTools {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a medium-priority todo whose content is `"Task {id}"`.
    fn make_todo(id: &str, status: TodoStatus) -> Todo {
        let content = format!("Task {id}");
        Todo {
            id: id.to_owned(),
            content,
            status,
            priority: TodoPriority::Medium,
        }
    }

    /// Writing two todos reports two tasks, and a later read lists both.
    #[tokio::test]
    async fn write_and_read_todos() {
        let tools = TodoTools::new();

        let todos = vec![
            make_todo("1", TodoStatus::Pending),
            make_todo("2", TodoStatus::Completed),
        ];
        let written = tools.write.call(TodoWriteArgs { todos }).await.unwrap();
        assert!(written.content.contains("2 task(s)"));

        let listing = tools.read.call(TodoReadArgs {}).await.unwrap();
        for expected in ["Task 1", "Task 2"] {
            assert!(listing.content.contains(expected));
        }
    }

    /// A read tool built from the same state sees what the write tool stored.
    #[tokio::test]
    async fn shared_state_works() {
        let state = TodoState::new();
        let writer = TodoWriteTool::new(state.clone());
        let reader = TodoReadTool::new(state);

        let todos = vec![make_todo("shared", TodoStatus::InProgress)];
        writer.call(TodoWriteArgs { todos }).await.unwrap();

        let listing = reader.call(TodoReadArgs {}).await.unwrap();
        assert!(listing.content.contains("shared"));
    }

    /// Reading before anything was written yields the "No tasks." message.
    #[tokio::test]
    async fn empty_list_returns_no_tasks() {
        let tools = TodoTools::new();
        let listing = tools.read.call(TodoReadArgs {}).await.unwrap();
        assert_eq!(listing.content, "No tasks.");
    }
}