batuta/agent/tool/
memory.rs1use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::{Tool, ToolResult};
10use crate::agent::capability::Capability;
11use crate::agent::driver::ToolDefinition;
12use crate::agent::memory::MemorySubstrate;
13
14pub struct MemoryTool {
16 substrate: Arc<dyn MemorySubstrate>,
17 agent_id: String,
18}
19
20impl MemoryTool {
21 pub fn new(substrate: Arc<dyn MemorySubstrate>, agent_id: String) -> Self {
23 Self { substrate, agent_id }
24 }
25}
26
27#[async_trait]
28impl Tool for MemoryTool {
29 fn name(&self) -> &'static str {
30 "memory"
31 }
32
33 fn definition(&self) -> ToolDefinition {
34 ToolDefinition {
35 name: "memory".into(),
36 description: "Read and write agent memory. \
37 Actions: 'remember' stores content, \
38 'recall' retrieves relevant memories."
39 .into(),
40 input_schema: serde_json::json!({
41 "type": "object",
42 "properties": {
43 "action": {
44 "type": "string",
45 "enum": ["remember", "recall"],
46 "description": "Action to perform"
47 },
48 "content": {
49 "type": "string",
50 "description": "Content to store (remember) or query (recall)"
51 },
52 "limit": {
53 "type": "integer",
54 "description": "Max memories to recall (default 5)"
55 }
56 },
57 "required": ["action", "content"]
58 }),
59 }
60 }
61
62 async fn execute(&self, input: serde_json::Value) -> ToolResult {
63 let action = input.get("action").and_then(|v| v.as_str()).unwrap_or("");
64 let content = input.get("content").and_then(|v| v.as_str()).unwrap_or("");
65
66 match action {
67 "remember" => self.do_remember(content).await,
68 "recall" => {
69 #[allow(clippy::cast_possible_truncation)]
70 let limit =
71 input.get("limit").and_then(serde_json::Value::as_u64).unwrap_or(5) as usize;
72 self.do_recall(content, limit).await
73 }
74 other => ToolResult::error(format!(
75 "unknown action '{other}', expected 'remember' or 'recall'"
76 )),
77 }
78 }
79
80 fn required_capability(&self) -> Capability {
81 Capability::Memory
82 }
83}
84
85impl MemoryTool {
86 async fn do_remember(&self, content: &str) -> ToolResult {
87 match self
88 .substrate
89 .remember(&self.agent_id, content, crate::agent::memory::MemorySource::ToolResult, None)
90 .await
91 {
92 Ok(id) => ToolResult::success(format!("Stored memory: {id}")),
93 Err(e) => ToolResult::error(format!("Failed to store: {e}")),
94 }
95 }
96
97 async fn do_recall(&self, query: &str, limit: usize) -> ToolResult {
98 match self.substrate.recall(query, limit, None, None).await {
99 Ok(fragments) => {
100 if fragments.is_empty() {
101 return ToolResult::success("No memories found.");
102 }
103 let text = fragments
104 .iter()
105 .enumerate()
106 .map(|(i, f)| {
107 format!("{}. [score={:.2}] {}", i + 1, f.relevance_score, f.content)
108 })
109 .collect::<Vec<_>>()
110 .join("\n");
111 ToolResult::success(text)
112 }
113 Err(e) => ToolResult::error(format!("Failed to recall: {e}")),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::memory::InMemorySubstrate;

    /// Builds a tool for agent `test-agent` over an empty in-memory substrate.
    fn make_tool() -> MemoryTool {
        MemoryTool::new(Arc::new(InMemorySubstrate::new()), "test-agent".into())
    }

    #[tokio::test]
    async fn test_remember_and_recall() {
        let tool = make_tool();

        let remember_input = serde_json::json!({
            "action": "remember",
            "content": "Rust is great for systems programming"
        });
        let stored = tool.execute(remember_input).await;
        assert!(!stored.is_error);
        assert!(stored.content.contains("Stored memory"));

        let recall_input = serde_json::json!({
            "action": "recall",
            "content": "Rust",
            "limit": 3
        });
        let recalled = tool.execute(recall_input).await;
        assert!(!recalled.is_error);
        assert!(recalled.content.contains("systems programming"));
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let tool = make_tool();

        let query = serde_json::json!({
            "action": "recall",
            "content": "nonexistent"
        });
        let outcome = tool.execute(query).await;
        assert!(!outcome.is_error);
        assert!(outcome.content.contains("No memories found"));
    }

    #[tokio::test]
    async fn test_unknown_action() {
        let tool = make_tool();

        let bogus = serde_json::json!({
            "action": "delete",
            "content": "test"
        });
        let outcome = tool.execute(bogus).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("unknown action"));
    }

    #[test]
    fn test_tool_metadata() {
        let tool = make_tool();
        assert_eq!(tool.name(), "memory");
        assert_eq!(tool.required_capability(), Capability::Memory);

        let definition = tool.definition();
        assert_eq!(definition.name, "memory");
        assert!(definition.description.contains("recall"));
    }
}