1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use tokio::fs;
8
9const MEMORY_FILE: &str = ".agentzero/memory.json";
10const DEFAULT_NAMESPACE: &str = "default";
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13struct MemoryStore {
14 namespaces: HashMap<String, HashMap<String, String>>,
15}
16
17impl MemoryStore {
18 async fn load(workspace_root: &str) -> anyhow::Result<Self> {
19 let path = Path::new(workspace_root).join(MEMORY_FILE);
20 if !path.exists() {
21 return Ok(Self::default());
22 }
23 let data = fs::read_to_string(&path)
24 .await
25 .context("failed to read memory store")?;
26 serde_json::from_str(&data).context("failed to parse memory store")
27 }
28
29 async fn save(&self, workspace_root: &str) -> anyhow::Result<()> {
30 let path = Path::new(workspace_root).join(MEMORY_FILE);
31 if let Some(parent) = path.parent() {
32 fs::create_dir_all(parent)
33 .await
34 .context("failed to create .agentzero directory")?;
35 }
36 let data = serde_json::to_string_pretty(self).context("failed to serialize memory")?;
37 fs::write(&path, data)
38 .await
39 .context("failed to write memory store")
40 }
41}
42
43#[derive(Debug, Deserialize)]
46struct MemoryStoreInput {
47 key: String,
48 value: String,
49 #[serde(default)]
50 namespace: Option<String>,
51}
52
53#[derive(Debug, Default, Clone, Copy)]
54pub struct MemoryStoreTool;
55
56#[async_trait]
57impl Tool for MemoryStoreTool {
58 fn name(&self) -> &'static str {
59 "memory_store"
60 }
61
62 fn description(&self) -> &'static str {
63 "Store a key-value pair in persistent memory, optionally under a namespace."
64 }
65
66 fn input_schema(&self) -> Option<serde_json::Value> {
67 Some(serde_json::json!({
68 "type": "object",
69 "properties": {
70 "key": { "type": "string", "description": "The key to store" },
71 "value": { "type": "string", "description": "The value to store" },
72 "namespace": { "type": "string", "description": "Optional namespace for grouping" }
73 },
74 "required": ["key", "value"]
75 }))
76 }
77
78 async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
79 let req: MemoryStoreInput = serde_json::from_str(input)
80 .context("memory_store expects JSON: {\"key\", \"value\", \"namespace\"?}")?;
81
82 if req.key.trim().is_empty() {
83 return Err(anyhow!("key must not be empty"));
84 }
85
86 let ns = req
87 .namespace
88 .as_deref()
89 .unwrap_or(DEFAULT_NAMESPACE)
90 .to_string();
91 let mut store = MemoryStore::load(&ctx.workspace_root).await?;
92 store
93 .namespaces
94 .entry(ns.clone())
95 .or_default()
96 .insert(req.key.clone(), req.value.clone());
97 store.save(&ctx.workspace_root).await?;
98
99 Ok(ToolResult {
100 output: format!(
101 "stored key={} namespace={} bytes={}",
102 req.key,
103 ns,
104 req.value.len()
105 ),
106 })
107 }
108}
109
110#[derive(Debug, Deserialize)]
113struct MemoryRecallInput {
114 #[serde(default)]
115 key: Option<String>,
116 #[serde(default)]
117 namespace: Option<String>,
118 #[serde(default = "default_limit")]
119 limit: usize,
120}
121
122fn default_limit() -> usize {
123 50
124}
125
126#[derive(Debug, Default, Clone, Copy)]
127pub struct MemoryRecallTool;
128
129#[async_trait]
130impl Tool for MemoryRecallTool {
131 fn name(&self) -> &'static str {
132 "memory_recall"
133 }
134
135 fn description(&self) -> &'static str {
136 "Recall stored values from memory by key or list recent entries in a namespace."
137 }
138
139 fn input_schema(&self) -> Option<serde_json::Value> {
140 Some(serde_json::json!({
141 "type": "object",
142 "properties": {
143 "key": { "type": "string", "description": "Specific key to recall" },
144 "namespace": { "type": "string", "description": "Namespace to search within" },
145 "limit": { "type": "integer", "description": "Max entries to return (default: 50)" }
146 }
147 }))
148 }
149
150 async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
151 let req: MemoryRecallInput = serde_json::from_str(input)
152 .context("memory_recall expects JSON: {\"key\"?, \"namespace\"?, \"limit\"?}")?;
153
154 let ns = req.namespace.as_deref().unwrap_or(DEFAULT_NAMESPACE);
155 let store = MemoryStore::load(&ctx.workspace_root).await?;
156
157 let entries = match store.namespaces.get(ns) {
158 Some(map) => map,
159 None => {
160 return Ok(ToolResult {
161 output: "no entries found".to_string(),
162 });
163 }
164 };
165
166 if let Some(ref key) = req.key {
167 match entries.get(key.as_str()) {
168 Some(value) => {
169 return Ok(ToolResult {
170 output: value.clone(),
171 });
172 }
173 None => {
174 return Ok(ToolResult {
175 output: format!("key not found: {key}"),
176 });
177 }
178 }
179 }
180
181 let limit = if req.limit == 0 { 50 } else { req.limit };
183 let mut keys: Vec<&String> = entries.keys().collect();
184 keys.sort();
185 let results: Vec<String> = keys
186 .iter()
187 .take(limit)
188 .map(|k| format!("{}={}", k, entries[k.as_str()]))
189 .collect();
190
191 if results.is_empty() {
192 return Ok(ToolResult {
193 output: "no entries found".to_string(),
194 });
195 }
196
197 Ok(ToolResult {
198 output: results.join("\n"),
199 })
200 }
201}
202
203#[derive(Debug, Deserialize)]
206struct MemoryForgetInput {
207 key: String,
208 #[serde(default)]
209 namespace: Option<String>,
210}
211
212#[derive(Debug, Default, Clone, Copy)]
213pub struct MemoryForgetTool;
214
215#[async_trait]
216impl Tool for MemoryForgetTool {
217 fn name(&self) -> &'static str {
218 "memory_forget"
219 }
220
221 fn description(&self) -> &'static str {
222 "Remove a key-value pair from memory."
223 }
224
225 fn input_schema(&self) -> Option<serde_json::Value> {
226 Some(serde_json::json!({
227 "type": "object",
228 "properties": {
229 "key": { "type": "string", "description": "The key to forget" },
230 "namespace": { "type": "string", "description": "Namespace the key belongs to" }
231 },
232 "required": ["key"]
233 }))
234 }
235
236 async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
237 let req: MemoryForgetInput = serde_json::from_str(input)
238 .context("memory_forget expects JSON: {\"key\", \"namespace\"?}")?;
239
240 if req.key.trim().is_empty() {
241 return Err(anyhow!("key must not be empty"));
242 }
243
244 let ns = req
245 .namespace
246 .as_deref()
247 .unwrap_or(DEFAULT_NAMESPACE)
248 .to_string();
249 let mut store = MemoryStore::load(&ctx.workspace_root).await?;
250
251 let removed = store
252 .namespaces
253 .get_mut(&ns)
254 .and_then(|map| map.remove(&req.key))
255 .is_some();
256
257 if removed {
258 store.save(&ctx.workspace_root).await?;
259 Ok(ToolResult {
260 output: format!("forgotten key={} namespace={}", req.key, ns),
261 })
262 } else {
263 Ok(ToolResult {
264 output: format!("key not found: {} in namespace={}", req.key, ns),
265 })
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
    use agentzero_core::{Tool, ToolContext};
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Distinguishes workspaces created within the same nanosecond.
    static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Unique scratch workspace that is deleted when dropped.
    ///
    /// Cleaning up in `Drop` rather than at the end of each test means the
    /// directory is also removed when an assertion fails and the test unwinds.
    struct TempWorkspace {
        path: PathBuf,
    }

    impl TempWorkspace {
        /// Creates a fresh directory under the system temp dir.
        fn new() -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("clock")
                .as_nanos();
            let seq = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir().join(format!(
                "agentzero-memory-tools-{}-{nanos}-{seq}",
                std::process::id()
            ));
            fs::create_dir_all(&path).expect("temp dir should be created");
            Self { path }
        }

        /// Builds a tool context rooted at this workspace.
        fn context(&self) -> ToolContext {
            ToolContext::new(self.path.to_string_lossy().to_string())
        }
    }

    impl Drop for TempWorkspace {
        fn drop(&mut self) {
            // Best effort: a leftover temp directory must not fail the test.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    #[tokio::test]
    async fn memory_store_recall_roundtrip() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        MemoryStoreTool
            .execute(r#"{"key": "greeting", "value": "hello world"}"#, &ctx)
            .await
            .expect("store should succeed");

        let result = MemoryRecallTool
            .execute(r#"{"key": "greeting"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert_eq!(result.output, "hello world");
    }

    #[tokio::test]
    async fn memory_forget_removes_key() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        MemoryStoreTool
            .execute(r#"{"key": "temp", "value": "data"}"#, &ctx)
            .await
            .expect("store should succeed");

        let result = MemoryForgetTool
            .execute(r#"{"key": "temp"}"#, &ctx)
            .await
            .expect("forget should succeed");
        assert!(result.output.contains("forgotten"));

        let result = MemoryRecallTool
            .execute(r#"{"key": "temp"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert!(result.output.contains("key not found"));
    }

    #[tokio::test]
    async fn memory_forget_missing_key_reports_not_found() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        let result = MemoryForgetTool
            .execute(r#"{"key": "ghost"}"#, &ctx)
            .await
            .expect("forgetting a missing key is not an error");
        assert!(result.output.contains("key not found"));
    }

    #[tokio::test]
    async fn memory_namespace_isolation() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        MemoryStoreTool
            .execute(r#"{"key": "x", "value": "default_val"}"#, &ctx)
            .await
            .expect("store should succeed");
        MemoryStoreTool
            .execute(
                r#"{"key": "x", "value": "custom_val", "namespace": "custom"}"#,
                &ctx,
            )
            .await
            .expect("store should succeed");

        let result = MemoryRecallTool
            .execute(r#"{"key": "x"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert_eq!(result.output, "default_val");

        let result = MemoryRecallTool
            .execute(r#"{"key": "x", "namespace": "custom"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert_eq!(result.output, "custom_val");
    }

    #[tokio::test]
    async fn memory_store_rejects_empty_key_negative_path() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        let err = MemoryStoreTool
            .execute(r#"{"key": "", "value": "data"}"#, &ctx)
            .await
            .expect_err("empty key should fail");
        assert!(err.to_string().contains("key must not be empty"));
    }

    #[tokio::test]
    async fn memory_recall_lists_all_keys() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        MemoryStoreTool
            .execute(r#"{"key": "a", "value": "1"}"#, &ctx)
            .await
            .expect("store should succeed");
        MemoryStoreTool
            .execute(r#"{"key": "b", "value": "2"}"#, &ctx)
            .await
            .expect("store should succeed");

        let result = MemoryRecallTool
            .execute(r#"{}"#, &ctx)
            .await
            .expect("list should succeed");
        assert!(result.output.contains("a=1"));
        assert!(result.output.contains("b=2"));
    }

    #[tokio::test]
    async fn memory_recall_respects_limit() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.context();

        for (key, value) in [("c", "3"), ("a", "1"), ("b", "2")] {
            let input = format!(r#"{{"key": "{key}", "value": "{value}"}}"#);
            MemoryStoreTool
                .execute(&input, &ctx)
                .await
                .expect("store should succeed");
        }

        // Listings are sorted by key, so the first two entries are a and b.
        let result = MemoryRecallTool
            .execute(r#"{"limit": 2}"#, &ctx)
            .await
            .expect("list should succeed");
        assert_eq!(result.output, "a=1\nb=2");
    }
}