agents_toolkit/
todos.rs

1use std::sync::{Arc, RwLock};
2
3use agents_core::agent::{ToolHandle, ToolResponse};
4use agents_core::command::{Command, StateDiff};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole, ToolInvocation};
6use agents_core::state::{AgentStateSnapshot, TodoItem};
7use async_trait::async_trait;
8use serde::Deserialize;
9
10use crate::metadata_from;
11
12#[derive(Clone)]
13pub struct WriteTodosTool {
14    pub name: String,
15    pub state: Arc<RwLock<AgentStateSnapshot>>,
16}
17
18#[derive(Debug, Deserialize)]
19struct WriteTodosArgs {
20    todos: Vec<TodoItem>,
21}
22
23#[async_trait]
24impl ToolHandle for WriteTodosTool {
25    fn name(&self) -> &str {
26        &self.name
27    }
28
29    async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
30        let args: WriteTodosArgs = serde_json::from_value(invocation.args.clone())?;
31        let mut state = self.state.write().expect("todo state write lock poisoned");
32        state.todos = args.todos.clone();
33
34        let command = Command {
35            state: StateDiff {
36                todos: Some(args.todos.clone()),
37                ..StateDiff::default()
38            },
39            messages: vec![AgentMessage {
40                role: MessageRole::Tool,
41                content: MessageContent::Text(format!("Updated todo list to {:?}", args.todos)),
42                metadata: metadata_from(&invocation),
43            }],
44        };
45
46        Ok(ToolResponse::Command(command))
47    }
48}
49
#[cfg(test)]
mod tests {
    // `ToolInvocation`, `Arc`, `RwLock`, etc. all come in through the parent's imports.
    use super::*;
    use serde_json::json;

    /// Builds a `write_todos` tool bound to a fresh default state. Returns the
    /// state handle too, so tests can inspect it after invoking the tool.
    fn tool_with_fresh_state() -> (WriteTodosTool, Arc<RwLock<AgentStateSnapshot>>) {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let tool = WriteTodosTool {
            name: "write_todos".into(),
            state: Arc::clone(&state),
        };
        (tool, state)
    }

    #[tokio::test]
    async fn write_todos_updates_state() {
        let (tool, state) = tool_with_fresh_state();
        let invocation = ToolInvocation {
            tool_name: "write_todos".into(),
            args: json!({
                "todos": [
                    { "content": "Do thing", "status": "pending" },
                    { "content": "Ship", "status": "completed" }
                ]
            }),
            tool_call_id: Some("call-1".into()),
        };

        let response = tool.invoke(invocation).await.unwrap();
        match response {
            ToolResponse::Command(cmd) => {
                assert_eq!(cmd.state.todos.as_ref().unwrap().len(), 2);
                assert_eq!(state.read().unwrap().todos.len(), 2);
                assert_eq!(cmd.messages.len(), 1);
                assert!(matches!(
                    &cmd.messages[0].content,
                    MessageContent::Text(text) if text.starts_with("Updated todo list to")
                ));
                assert_eq!(
                    cmd.messages[0]
                        .metadata
                        .as_ref()
                        .unwrap()
                        .tool_call_id
                        .as_deref(),
                    Some("call-1")
                );
            }
            _ => panic!("expected command"),
        }
    }

    #[tokio::test]
    async fn write_todos_rejects_malformed_args() {
        let (tool, state) = tool_with_fresh_state();
        let before = state.read().unwrap().todos.len();
        let invocation = ToolInvocation {
            tool_name: "write_todos".into(),
            // The required `todos` key is missing, so deserialization must fail.
            args: json!({ "items": [] }),
            tool_call_id: None,
        };

        assert!(tool.invoke(invocation).await.is_err());
        // A rejected call must leave the shared state untouched.
        assert_eq!(state.read().unwrap().todos.len(), before);
    }
}
92}