Skip to main content

agentzero_tools/
memory_tools.rs

1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use tokio::fs;
8
/// Location of the JSON memory store, relative to the workspace root.
const MEMORY_FILE: &str = ".agentzero/memory.json";
/// Namespace used whenever a tool request does not name one.
const DEFAULT_NAMESPACE: &str = "default";
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13struct MemoryStore {
14    namespaces: HashMap<String, HashMap<String, String>>,
15}
16
17impl MemoryStore {
18    async fn load(workspace_root: &str) -> anyhow::Result<Self> {
19        let path = Path::new(workspace_root).join(MEMORY_FILE);
20        if !path.exists() {
21            return Ok(Self::default());
22        }
23        let data = fs::read_to_string(&path)
24            .await
25            .context("failed to read memory store")?;
26        serde_json::from_str(&data).context("failed to parse memory store")
27    }
28
29    async fn save(&self, workspace_root: &str) -> anyhow::Result<()> {
30        let path = Path::new(workspace_root).join(MEMORY_FILE);
31        if let Some(parent) = path.parent() {
32            fs::create_dir_all(parent)
33                .await
34                .context("failed to create .agentzero directory")?;
35        }
36        let data = serde_json::to_string_pretty(self).context("failed to serialize memory")?;
37        fs::write(&path, data)
38            .await
39            .context("failed to write memory store")
40    }
41}
42
43// --- memory_store ---
44
45#[derive(Debug, Deserialize)]
46struct MemoryStoreInput {
47    key: String,
48    value: String,
49    #[serde(default)]
50    namespace: Option<String>,
51}
52
53#[derive(Debug, Default, Clone, Copy)]
54pub struct MemoryStoreTool;
55
56#[async_trait]
57impl Tool for MemoryStoreTool {
58    fn name(&self) -> &'static str {
59        "memory_store"
60    }
61
62    fn description(&self) -> &'static str {
63        "Store a key-value pair in persistent memory, optionally under a namespace."
64    }
65
66    fn input_schema(&self) -> Option<serde_json::Value> {
67        Some(serde_json::json!({
68            "type": "object",
69            "properties": {
70                "key": { "type": "string", "description": "The key to store" },
71                "value": { "type": "string", "description": "The value to store" },
72                "namespace": { "type": "string", "description": "Optional namespace for grouping" }
73            },
74            "required": ["key", "value"]
75        }))
76    }
77
78    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
79        let req: MemoryStoreInput = serde_json::from_str(input)
80            .context("memory_store expects JSON: {\"key\", \"value\", \"namespace\"?}")?;
81
82        if req.key.trim().is_empty() {
83            return Err(anyhow!("key must not be empty"));
84        }
85
86        let ns = req
87            .namespace
88            .as_deref()
89            .unwrap_or(DEFAULT_NAMESPACE)
90            .to_string();
91        let mut store = MemoryStore::load(&ctx.workspace_root).await?;
92        store
93            .namespaces
94            .entry(ns.clone())
95            .or_default()
96            .insert(req.key.clone(), req.value.clone());
97        store.save(&ctx.workspace_root).await?;
98
99        Ok(ToolResult {
100            output: format!(
101                "stored key={} namespace={} bytes={}",
102                req.key,
103                ns,
104                req.value.len()
105            ),
106        })
107    }
108}
109
110// --- memory_recall ---
111
112#[derive(Debug, Deserialize)]
113struct MemoryRecallInput {
114    #[serde(default)]
115    key: Option<String>,
116    #[serde(default)]
117    namespace: Option<String>,
118    #[serde(default = "default_limit")]
119    limit: usize,
120}
121
122fn default_limit() -> usize {
123    50
124}
125
126#[derive(Debug, Default, Clone, Copy)]
127pub struct MemoryRecallTool;
128
129#[async_trait]
130impl Tool for MemoryRecallTool {
131    fn name(&self) -> &'static str {
132        "memory_recall"
133    }
134
135    fn description(&self) -> &'static str {
136        "Recall stored values from memory by key or list recent entries in a namespace."
137    }
138
139    fn input_schema(&self) -> Option<serde_json::Value> {
140        Some(serde_json::json!({
141            "type": "object",
142            "properties": {
143                "key": { "type": "string", "description": "Specific key to recall" },
144                "namespace": { "type": "string", "description": "Namespace to search within" },
145                "limit": { "type": "integer", "description": "Max entries to return (default: 50)" }
146            }
147        }))
148    }
149
150    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
151        let req: MemoryRecallInput = serde_json::from_str(input)
152            .context("memory_recall expects JSON: {\"key\"?, \"namespace\"?, \"limit\"?}")?;
153
154        let ns = req.namespace.as_deref().unwrap_or(DEFAULT_NAMESPACE);
155        let store = MemoryStore::load(&ctx.workspace_root).await?;
156
157        let entries = match store.namespaces.get(ns) {
158            Some(map) => map,
159            None => {
160                return Ok(ToolResult {
161                    output: "no entries found".to_string(),
162                });
163            }
164        };
165
166        if let Some(ref key) = req.key {
167            match entries.get(key.as_str()) {
168                Some(value) => {
169                    return Ok(ToolResult {
170                        output: value.clone(),
171                    });
172                }
173                None => {
174                    return Ok(ToolResult {
175                        output: format!("key not found: {key}"),
176                    });
177                }
178            }
179        }
180
181        // List all keys in namespace.
182        let limit = if req.limit == 0 { 50 } else { req.limit };
183        let mut keys: Vec<&String> = entries.keys().collect();
184        keys.sort();
185        let results: Vec<String> = keys
186            .iter()
187            .take(limit)
188            .map(|k| format!("{}={}", k, entries[k.as_str()]))
189            .collect();
190
191        if results.is_empty() {
192            return Ok(ToolResult {
193                output: "no entries found".to_string(),
194            });
195        }
196
197        Ok(ToolResult {
198            output: results.join("\n"),
199        })
200    }
201}
202
203// --- memory_forget ---
204
205#[derive(Debug, Deserialize)]
206struct MemoryForgetInput {
207    key: String,
208    #[serde(default)]
209    namespace: Option<String>,
210}
211
212#[derive(Debug, Default, Clone, Copy)]
213pub struct MemoryForgetTool;
214
215#[async_trait]
216impl Tool for MemoryForgetTool {
217    fn name(&self) -> &'static str {
218        "memory_forget"
219    }
220
221    fn description(&self) -> &'static str {
222        "Remove a key-value pair from memory."
223    }
224
225    fn input_schema(&self) -> Option<serde_json::Value> {
226        Some(serde_json::json!({
227            "type": "object",
228            "properties": {
229                "key": { "type": "string", "description": "The key to forget" },
230                "namespace": { "type": "string", "description": "Namespace the key belongs to" }
231            },
232            "required": ["key"]
233        }))
234    }
235
236    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
237        let req: MemoryForgetInput = serde_json::from_str(input)
238            .context("memory_forget expects JSON: {\"key\", \"namespace\"?}")?;
239
240        if req.key.trim().is_empty() {
241            return Err(anyhow!("key must not be empty"));
242        }
243
244        let ns = req
245            .namespace
246            .as_deref()
247            .unwrap_or(DEFAULT_NAMESPACE)
248            .to_string();
249        let mut store = MemoryStore::load(&ctx.workspace_root).await?;
250
251        let removed = store
252            .namespaces
253            .get_mut(&ns)
254            .and_then(|map| map.remove(&req.key))
255            .is_some();
256
257        if removed {
258            store.save(&ctx.workspace_root).await?;
259            Ok(ToolResult {
260                output: format!("forgotten key={} namespace={}", req.key, ns),
261            })
262        } else {
263            Ok(ToolResult {
264                output: format!("key not found: {} in namespace={}", req.key, ns),
265            })
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool};
    use agentzero_core::{Tool, ToolContext};
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Per-test scratch workspace, removed on drop so that a failing assertion
    /// does not leak directories into the system temp dir.
    struct TempWorkspace {
        path: PathBuf,
    }

    impl TempWorkspace {
        /// Creates a fresh directory named from the pid, wall-clock nanos and a
        /// process-wide counter, so concurrently running tests get distinct paths.
        fn new() -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("clock")
                .as_nanos();
            let seq = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir().join(format!(
                "agentzero-memory-tools-{}-{nanos}-{seq}",
                std::process::id()
            ));
            fs::create_dir_all(&path).expect("temp dir should be created");
            Self { path }
        }

        /// Tool context rooted at this workspace.
        fn ctx(&self) -> ToolContext {
            ToolContext::new(self.path.to_string_lossy().to_string())
        }
    }

    impl Drop for TempWorkspace {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not cause a second panic.
            fs::remove_dir_all(&self.path).ok();
        }
    }

    #[tokio::test]
    async fn memory_store_recall_roundtrip() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        let store = MemoryStoreTool;
        store
            .execute(r#"{"key": "greeting", "value": "hello world"}"#, &ctx)
            .await
            .expect("store should succeed");

        let recall = MemoryRecallTool;
        let result = recall
            .execute(r#"{"key": "greeting"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert_eq!(result.output, "hello world");
    }

    #[tokio::test]
    async fn memory_forget_removes_key() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        MemoryStoreTool
            .execute(r#"{"key": "temp", "value": "data"}"#, &ctx)
            .await
            .unwrap();

        let forget = MemoryForgetTool;
        let result = forget
            .execute(r#"{"key": "temp"}"#, &ctx)
            .await
            .expect("forget should succeed");
        assert!(result.output.contains("forgotten"));

        let recall = MemoryRecallTool;
        let result = recall
            .execute(r#"{"key": "temp"}"#, &ctx)
            .await
            .expect("recall should succeed");
        assert!(result.output.contains("key not found"));
    }

    #[tokio::test]
    async fn memory_forget_missing_key_negative_path() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        let result = MemoryForgetTool
            .execute(r#"{"key": "ghost"}"#, &ctx)
            .await
            .expect("forgetting a missing key is reported, not an error");
        assert_eq!(result.output, "key not found: ghost in namespace=default");
    }

    #[tokio::test]
    async fn memory_namespace_isolation() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        MemoryStoreTool
            .execute(r#"{"key": "x", "value": "default_val"}"#, &ctx)
            .await
            .unwrap();
        MemoryStoreTool
            .execute(
                r#"{"key": "x", "value": "custom_val", "namespace": "custom"}"#,
                &ctx,
            )
            .await
            .unwrap();

        let result = MemoryRecallTool
            .execute(r#"{"key": "x"}"#, &ctx)
            .await
            .unwrap();
        assert_eq!(result.output, "default_val");

        let result = MemoryRecallTool
            .execute(r#"{"key": "x", "namespace": "custom"}"#, &ctx)
            .await
            .unwrap();
        assert_eq!(result.output, "custom_val");
    }

    #[tokio::test]
    async fn memory_store_rejects_empty_key_negative_path() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        let err = MemoryStoreTool
            .execute(r#"{"key": "", "value": "data"}"#, &ctx)
            .await
            .expect_err("empty key should fail");
        assert!(err.to_string().contains("key must not be empty"));
    }

    #[tokio::test]
    async fn memory_recall_lists_all_keys() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        MemoryStoreTool
            .execute(r#"{"key": "a", "value": "1"}"#, &ctx)
            .await
            .unwrap();
        MemoryStoreTool
            .execute(r#"{"key": "b", "value": "2"}"#, &ctx)
            .await
            .unwrap();

        let result = MemoryRecallTool
            .execute(r#"{}"#, &ctx)
            .await
            .expect("list should succeed");
        assert!(result.output.contains("a=1"));
        assert!(result.output.contains("b=2"));
    }

    #[tokio::test]
    async fn memory_recall_respects_limit() {
        let workspace = TempWorkspace::new();
        let ctx = workspace.ctx();

        for (key, value) in [("a", "1"), ("b", "2"), ("c", "3")].iter() {
            let input = format!(r#"{{"key": "{key}", "value": "{value}"}}"#);
            MemoryStoreTool.execute(&input, &ctx).await.unwrap();
        }

        let result = MemoryRecallTool
            .execute(r#"{"limit": 2}"#, &ctx)
            .await
            .expect("limited list should succeed");
        assert_eq!(result.output, "a=1\nb=2");
    }
}