1use async_trait::async_trait;
4use serde_json::{Value, json};
5
6use super::{AgentTool, AgentToolResult, MemoryItem, ToolContext, ToolError};
7
8const DEFAULT_LIMIT: usize = 5;
10const MAX_LIMIT: usize = 20;
12
13pub struct MemoryRecallTool;
18
19#[async_trait]
20impl AgentTool for MemoryRecallTool {
21 fn name(&self) -> &str {
22 "memory_recall"
23 }
24
25 fn label(&self) -> &str {
26 "Memory Recall"
27 }
28
29 fn description(&self) -> &str {
30 "Search long-term memory for information relevant to a query. \
31 Returns the most relevant stored memories (facts, preferences, \
32 context, summaries)."
33 }
34
35 fn essential(&self) -> bool {
36 false
37 }
38
39 fn parameters_schema(&self) -> Value {
40 json!({
41 "type": "object",
42 "properties": {
43 "query": {
44 "type": "string",
45 "description": "What to search for in memory."
46 },
47 "limit": {
48 "type": "integer",
49 "minimum": 1,
50 "maximum": 20,
51 "default": 5,
52 "description": "Maximum number of results to return."
53 }
54 },
55 "required": ["query"]
56 })
57 }
58
59 async fn execute(
60 &self,
61 _tool_call_id: &str,
62 params: Value,
63 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
64 ctx: &ToolContext,
65 ) -> Result<AgentToolResult, ToolError> {
66 let backend = ctx.memory.as_ref().ok_or("Memory not configured")?;
67
68 let query = params
69 .get("query")
70 .and_then(|v| v.as_str())
71 .ok_or("Missing required parameter: query")?;
72
73 let limit = params
74 .get("limit")
75 .and_then(|v| v.as_u64())
76 .map(|l| (l as usize).clamp(1, MAX_LIMIT))
77 .unwrap_or(DEFAULT_LIMIT);
78
79 let results = backend.search(query, limit).await?;
80
81 Ok(AgentToolResult::success(format_results(&results)))
82 }
83}
84
85fn format_results(items: &[MemoryItem]) -> String {
87 if items.is_empty() {
88 return "No matching memories found.".to_string();
89 }
90 let mut out = format!(
91 "Found {} memor{}:\n\n",
92 items.len(),
93 if items.len() == 1 { "y" } else { "ies" }
94 );
95 for (i, item) in items.iter().enumerate() {
96 out.push_str(&format!("{}. [{}] {}\n", i + 1, item.kind, item.content));
97 }
98 out
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::MemoryBackend;
    use parking_lot::Mutex;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;

    /// Test double for `MemoryBackend`: serves a fixed list of items and
    /// records the `k` passed to the most recent `search` call.
    #[derive(Debug)]
    struct MockMemory {
        items: Vec<MemoryItem>,
        last_k: Mutex<Option<usize>>,
    }

    impl MockMemory {
        /// Builds a shared mock holding `items`, with no search recorded yet.
        fn with_items(items: Vec<MemoryItem>) -> Arc<Self> {
            Arc::new(Self {
                items,
                last_k: Mutex::new(None),
            })
        }

        /// The `k` seen by the latest `search`, or `None` if never searched.
        fn requested_k(&self) -> Option<usize> {
            *self.last_k.lock()
        }
    }

    impl MemoryBackend for MockMemory {
        fn put<'a>(
            &'a self,
            _content: &'a str,
            _kind: &'a str,
            _subject: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + 'a>> {
            Box::pin(async move { Ok("mem-1".to_string()) })
        }

        fn search<'a>(
            &'a self,
            _query: &'a str,
            k: usize,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryItem>, ToolError>> + Send + 'a>> {
            // Record the limit eagerly so tests can inspect what the tool asked for.
            *self.last_k.lock() = Some(k);
            let items: Vec<MemoryItem> = self.items.iter().take(k).cloned().collect();
            Box::pin(async move { Ok(items) })
        }

        fn list<'a>(
            &'a self,
            _subject: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryItem>, ToolError>> + Send + 'a>> {
            Box::pin(async move { Ok(vec![]) })
        }

        fn delete<'a>(
            &'a self,
            _id: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move { Ok(()) })
        }
    }

    /// Builds a memory item with the given fields and a placeholder subject.
    fn make_item(id: &str, kind: &str, content: &str) -> MemoryItem {
        MemoryItem {
            id: id.into(),
            kind: kind.into(),
            content: content.into(),
            subject: "s".into(),
        }
    }

    /// Invokes the tool with `params`, a fixed call id and no cancellation signal.
    async fn recall(ctx: &ToolContext, params: Value) -> Result<AgentToolResult, ToolError> {
        MemoryRecallTool.execute("c1", params, None, ctx).await
    }

    #[tokio::test]
    async fn recall_returns_formatted_results() {
        let mock = MockMemory::with_items(vec![
            make_item("1", "fact", "Rust is fast"),
            make_item("2", "preference", "Likes dark mode"),
        ]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        let result = recall(&ctx, json!({"query": "rust", "limit": 5}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("[fact] Rust is fast"));
        assert!(result.output.contains("[preference] Likes dark mode"));
        assert_eq!(mock.requested_k(), Some(5));
    }

    #[tokio::test]
    async fn recall_formats_header_and_numbering() {
        let mock = MockMemory::with_items(vec![
            make_item("1", "fact", "Rust is fast"),
            make_item("2", "preference", "Likes dark mode"),
        ]);
        let ctx = ToolContext::default().with_memory(mock);
        let result = recall(&ctx, json!({"query": "rust"})).await.unwrap();
        assert_eq!(
            result.output,
            "Found 2 memories:\n\n1. [fact] Rust is fast\n2. [preference] Likes dark mode\n"
        );
    }

    #[tokio::test]
    async fn recall_uses_singular_header_for_one_result() {
        let mock = MockMemory::with_items(vec![make_item("1", "fact", "Rust is fast")]);
        let ctx = ToolContext::default().with_memory(mock);
        let result = recall(&ctx, json!({"query": "rust"})).await.unwrap();
        assert_eq!(result.output, "Found 1 memory:\n\n1. [fact] Rust is fast\n");
    }

    #[tokio::test]
    async fn recall_reports_empty_results() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock);
        let result = recall(&ctx, json!({"query": "nothing"})).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "No matching memories found.");
    }

    #[tokio::test]
    async fn recall_uses_default_limit() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        recall(&ctx, json!({"query": "x"})).await.unwrap();
        assert_eq!(mock.requested_k(), Some(DEFAULT_LIMIT));
    }

    #[tokio::test]
    async fn recall_uses_default_limit_for_non_integer_limit() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        recall(&ctx, json!({"query": "x", "limit": "many"}))
            .await
            .unwrap();
        assert_eq!(mock.requested_k(), Some(DEFAULT_LIMIT));
    }

    #[tokio::test]
    async fn recall_clamps_oversized_limit() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        recall(&ctx, json!({"query": "x", "limit": 100}))
            .await
            .unwrap();
        assert_eq!(mock.requested_k(), Some(MAX_LIMIT));
    }

    #[tokio::test]
    async fn recall_clamps_zero_limit() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        recall(&ctx, json!({"query": "x", "limit": 0}))
            .await
            .unwrap();
        assert_eq!(mock.requested_k(), Some(1));
    }

    #[tokio::test]
    async fn recall_errors_when_memory_not_configured() {
        let ctx = ToolContext::default();
        let err = recall(&ctx, json!({"query": "x"})).await.unwrap_err();
        assert_eq!(err, "Memory not configured");
    }

    #[tokio::test]
    async fn recall_rejects_missing_query() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock);
        let err = recall(&ctx, json!({"limit": 3})).await.unwrap_err();
        assert!(err.contains("query"));
    }

    #[tokio::test]
    async fn recall_rejects_non_string_query() {
        let mock = MockMemory::with_items(vec![]);
        let ctx = ToolContext::default().with_memory(mock.clone());
        let err = recall(&ctx, json!({"query": 42})).await.unwrap_err();
        assert!(err.contains("query"));
        // The backend must not be consulted when argument validation fails.
        assert_eq!(mock.requested_k(), None);
    }

    #[test]
    fn schema_limit_bounds_match_constants() {
        let schema = MemoryRecallTool.parameters_schema();
        let limit = &schema["properties"]["limit"];
        assert_eq!(limit["maximum"], json!(MAX_LIMIT));
        assert_eq!(limit["default"], json!(DEFAULT_LIMIT));
        assert_eq!(schema["required"], json!(["query"]));
    }
}