1use std::sync::{Arc, RwLock};
2
3use agents_core::agent::{ToolHandle, ToolResponse};
4use agents_core::command::{Command, StateDiff};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole, ToolInvocation};
6use agents_core::state::{AgentStateSnapshot, TodoItem};
7use async_trait::async_trait;
8use serde::Deserialize;
9
10use crate::metadata_from;
11
12#[derive(Clone)]
13pub struct WriteTodosTool {
14 pub name: String,
15 pub state: Arc<RwLock<AgentStateSnapshot>>,
16}
17
18#[derive(Debug, Deserialize)]
19struct WriteTodosArgs {
20 todos: Vec<TodoItem>,
21}
22
23#[async_trait]
24impl ToolHandle for WriteTodosTool {
25 fn name(&self) -> &str {
26 &self.name
27 }
28
29 async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
30 let args: WriteTodosArgs = serde_json::from_value(invocation.args.clone())?;
31 let mut state = self.state.write().expect("todo state write lock poisoned");
32 state.todos = args.todos.clone();
33
34 let command = Command {
35 state: StateDiff {
36 todos: Some(args.todos.clone()),
37 ..StateDiff::default()
38 },
39 messages: vec![AgentMessage {
40 role: MessageRole::Tool,
41 content: MessageContent::Text(format!("Updated todo list to {:?}", args.todos)),
42 metadata: metadata_from(&invocation),
43 }],
44 };
45
46 Ok(ToolResponse::Command(command))
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::messaging::ToolInvocation;
    use serde_json::json;

    /// Invoking the tool should do three things:
    /// - write the todos into the shared state;
    /// - echo them in the returned state diff;
    /// - tag the tool message with the originating call id.
    #[tokio::test]
    async fn write_todos_updates_state() {
        let shared = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let tool = WriteTodosTool {
            name: "write_todos".into(),
            state: Arc::clone(&shared),
        };

        let payload = json!({
            "todos": [
                { "content": "Do thing", "status": "pending" },
                { "content": "Ship", "status": "completed" }
            ]
        });
        let invocation = ToolInvocation {
            tool_name: "write_todos".into(),
            args: payload,
            tool_call_id: Some("call-1".into()),
        };

        let cmd = match tool.invoke(invocation).await.unwrap() {
            ToolResponse::Command(cmd) => cmd,
            _ => panic!("expected command"),
        };

        let diff_todos = cmd.state.todos.as_ref().unwrap();
        assert_eq!(diff_todos.len(), 2);
        assert_eq!(shared.read().unwrap().todos.len(), 2);

        let call_id = cmd.messages[0]
            .metadata
            .as_ref()
            .unwrap()
            .tool_call_id
            .as_deref();
        assert_eq!(call_id, Some("call-1"));
    }
}