Skip to main content

batuta/agent/tool/
memory.rs

1//! Memory tool — read/write agent persistent state.
2//!
3//! Wraps a `MemorySubstrate` as a `Tool` for use in the agent loop.
4//! Supports two actions: "remember" (store) and "recall" (retrieve).
5
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::{Tool, ToolResult};
10use crate::agent::capability::Capability;
11use crate::agent::driver::ToolDefinition;
12use crate::agent::memory::MemorySubstrate;
13
/// Tool for reading and writing agent memory.
///
/// Pairs a shared [`MemorySubstrate`] handle with the ID of the agent the
/// tool was created for.
pub struct MemoryTool {
    /// Backing store that "remember" and "recall" requests are delegated to.
    substrate: Arc<dyn MemorySubstrate>,
    /// Agent ID handed to the substrate whenever a memory is stored.
    agent_id: String,
}
19
impl MemoryTool {
    /// Create a new memory tool for the given agent.
    ///
    /// `substrate` is the shared memory backend the tool delegates to.
    /// `agent_id` is stored unchanged and supplied to the substrate when
    /// memories are written.
    pub fn new(substrate: Arc<dyn MemorySubstrate>, agent_id: String) -> Self {
        Self { substrate, agent_id }
    }
}
26
27#[async_trait]
28impl Tool for MemoryTool {
29    fn name(&self) -> &'static str {
30        "memory"
31    }
32
33    fn definition(&self) -> ToolDefinition {
34        ToolDefinition {
35            name: "memory".into(),
36            description: "Read and write agent memory. \
37                Actions: 'remember' stores content, \
38                'recall' retrieves relevant memories."
39                .into(),
40            input_schema: serde_json::json!({
41                "type": "object",
42                "properties": {
43                    "action": {
44                        "type": "string",
45                        "enum": ["remember", "recall"],
46                        "description": "Action to perform"
47                    },
48                    "content": {
49                        "type": "string",
50                        "description": "Content to store (remember) or query (recall)"
51                    },
52                    "limit": {
53                        "type": "integer",
54                        "description": "Max memories to recall (default 5)"
55                    }
56                },
57                "required": ["action", "content"]
58            }),
59        }
60    }
61
62    async fn execute(&self, input: serde_json::Value) -> ToolResult {
63        let action = input.get("action").and_then(|v| v.as_str()).unwrap_or("");
64        let content = input.get("content").and_then(|v| v.as_str()).unwrap_or("");
65
66        match action {
67            "remember" => self.do_remember(content).await,
68            "recall" => {
69                #[allow(clippy::cast_possible_truncation)]
70                let limit =
71                    input.get("limit").and_then(serde_json::Value::as_u64).unwrap_or(5) as usize;
72                self.do_recall(content, limit).await
73            }
74            other => ToolResult::error(format!(
75                "unknown action '{other}', expected 'remember' or 'recall'"
76            )),
77        }
78    }
79
80    fn required_capability(&self) -> Capability {
81        Capability::Memory
82    }
83}
84
85impl MemoryTool {
86    async fn do_remember(&self, content: &str) -> ToolResult {
87        match self
88            .substrate
89            .remember(&self.agent_id, content, crate::agent::memory::MemorySource::ToolResult, None)
90            .await
91        {
92            Ok(id) => ToolResult::success(format!("Stored memory: {id}")),
93            Err(e) => ToolResult::error(format!("Failed to store: {e}")),
94        }
95    }
96
97    async fn do_recall(&self, query: &str, limit: usize) -> ToolResult {
98        match self.substrate.recall(query, limit, None, None).await {
99            Ok(fragments) => {
100                if fragments.is_empty() {
101                    return ToolResult::success("No memories found.");
102                }
103                let text = fragments
104                    .iter()
105                    .enumerate()
106                    .map(|(i, f)| {
107                        format!("{}. [score={:.2}] {}", i + 1, f.relevance_score, f.content)
108                    })
109                    .collect::<Vec<_>>()
110                    .join("\n");
111                ToolResult::success(text)
112            }
113            Err(e) => ToolResult::error(format!("Failed to recall: {e}")),
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::memory::InMemorySubstrate;

    /// Build a tool backed by a fresh in-memory substrate.
    fn make_tool() -> MemoryTool {
        let backend = Arc::new(InMemorySubstrate::new());
        MemoryTool::new(backend, "test-agent".into())
    }

    /// Run a single tool call with the given JSON input.
    async fn call(tool: &MemoryTool, input: serde_json::Value) -> ToolResult {
        tool.execute(input).await
    }

    #[tokio::test]
    async fn test_remember_and_recall() {
        let tool = make_tool();

        // Store one memory.
        let stored = call(
            &tool,
            serde_json::json!({
                "action": "remember",
                "content": "Rust is great for systems programming"
            }),
        )
        .await;
        assert!(!stored.is_error);
        assert!(stored.content.contains("Stored memory"));

        // Query it back.
        let recalled = call(
            &tool,
            serde_json::json!({
                "action": "recall",
                "content": "Rust",
                "limit": 3
            }),
        )
        .await;
        assert!(!recalled.is_error);
        assert!(recalled.content.contains("systems programming"));
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let tool = make_tool();

        let recalled = call(
            &tool,
            serde_json::json!({
                "action": "recall",
                "content": "nonexistent"
            }),
        )
        .await;
        assert!(!recalled.is_error);
        assert!(recalled.content.contains("No memories found"));
    }

    #[tokio::test]
    async fn test_unknown_action() {
        let tool = make_tool();

        let outcome = call(
            &tool,
            serde_json::json!({
                "action": "delete",
                "content": "test"
            }),
        )
        .await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("unknown action"));
    }

    #[test]
    fn test_tool_metadata() {
        let tool = make_tool();
        assert_eq!(tool.name(), "memory");
        assert_eq!(tool.required_capability(), Capability::Memory);

        let definition = tool.definition();
        assert_eq!(definition.name, "memory");
        assert!(definition.description.contains("recall"));
    }
}