1use async_trait::async_trait;
4use serde_json::{Value, json};
5
6use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
7
/// Memory categories accepted by the `kind` argument of `memory_retain`;
/// `execute` rejects any other value.
const VALID_KINDS: [&str; 4] = [
    "fact",
    "preference",
    "context",
    "summary",
];
10
/// Agent tool (`memory_retain`) that stores a piece of text in the long-term
/// memory backend attached to the `ToolContext`.
///
/// The tool is stateless, so it derives the cheap common traits to make it
/// debuggable and freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct MemoryRetainTool;
18
19#[async_trait]
20impl AgentTool for MemoryRetainTool {
21 fn name(&self) -> &str {
22 "memory_retain"
23 }
24
25 fn label(&self) -> &str {
26 "Memory Retain"
27 }
28
29 fn description(&self) -> &str {
30 "Store a piece of information to long-term memory for later recall. \
31 Use for facts, preferences, context, or summaries worth remembering \
32 across sessions."
33 }
34
35 fn essential(&self) -> bool {
36 false
37 }
38
39 fn parameters_schema(&self) -> Value {
40 json!({
41 "type": "object",
42 "properties": {
43 "content": {
44 "type": "string",
45 "description": "The text to remember."
46 },
47 "kind": {
48 "type": "string",
49 "enum": ["fact", "preference", "context", "summary"],
50 "default": "fact",
51 "description": "Category of the memory."
52 },
53 "importance": {
54 "type": "number",
55 "minimum": 0.0,
56 "maximum": 1.0,
57 "default": 0.5,
58 "description": "How important this memory is (0–1)."
59 }
60 },
61 "required": ["content"]
62 })
63 }
64
65 async fn execute(
66 &self,
67 _tool_call_id: &str,
68 params: Value,
69 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
70 ctx: &ToolContext,
71 ) -> Result<AgentToolResult, ToolError> {
72 let backend = ctx.memory.as_ref().ok_or("Memory not configured")?;
73
74 let content = params
75 .get("content")
76 .and_then(|v| v.as_str())
77 .ok_or("Missing required parameter: content")?;
78
79 let kind = params
80 .get("kind")
81 .and_then(|v| v.as_str())
82 .unwrap_or("fact");
83 if !VALID_KINDS.contains(&kind) {
84 return Err(format!(
85 "Invalid kind '{}': expected one of {:?}",
86 kind, VALID_KINDS
87 ));
88 }
89
90 if let Some(importance) = params.get("importance").and_then(|v| v.as_f64())
93 && !(0.0..=1.0).contains(&importance)
94 {
95 return Err(format!(
96 "importance must be between 0 and 1, got {}",
97 importance
98 ));
99 }
100
101 let subject = ctx.session_id.as_deref().unwrap_or("default");
102 backend.put(content, kind, subject).await?;
103
104 Ok(AgentToolResult::success(format!(
105 "Retained [{}] to memory.",
106 kind
107 )))
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::MemoryBackend;
    use parking_lot::Mutex;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;

    /// `MemoryBackend` double that records every `put` call and returns empty
    /// results for all other operations.
    #[derive(Debug)]
    struct MockMemory {
        /// `(content, kind, subject)` for each `put`, in call order.
        puts: Mutex<Vec<(String, String, String)>>,
    }

    impl MockMemory {
        fn new() -> Self {
            Self {
                puts: Mutex::new(vec![]),
            }
        }

        /// Number of `put` calls recorded so far.
        fn put_count(&self) -> usize {
            self.puts.lock().len()
        }
    }

    impl MemoryBackend for MockMemory {
        fn put<'a>(
            &'a self,
            content: &'a str,
            kind: &'a str,
            subject: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + 'a>> {
            self.puts
                .lock()
                .push((content.into(), kind.into(), subject.into()));
            Box::pin(async move { Ok("mem-1".to_string()) })
        }

        fn search<'a>(
            &'a self,
            _query: &'a str,
            _k: usize,
        ) -> Pin<
            Box<dyn Future<Output = Result<Vec<crate::tools::MemoryItem>, ToolError>> + Send + 'a>,
        > {
            Box::pin(async move { Ok(vec![]) })
        }

        fn list<'a>(
            &'a self,
            _subject: &'a str,
        ) -> Pin<
            Box<dyn Future<Output = Result<Vec<crate::tools::MemoryItem>, ToolError>> + Send + 'a>,
        > {
            Box::pin(async move { Ok(vec![]) })
        }

        fn delete<'a>(
            &'a self,
            _id: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move { Ok(()) })
        }
    }

    /// Context with `mock` as its memory backend and no session id.
    fn ctx_with(mock: &Arc<MockMemory>) -> ToolContext {
        ToolContext::default().with_memory(mock.clone())
    }

    #[tokio::test]
    async fn retain_calls_put_with_correct_args() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default()
            .with_session("sess-42")
            .with_memory(mock.clone());
        let result = MemoryRetainTool
            .execute(
                "c1",
                json!({"content": "hello", "kind": "fact", "importance": 0.9}),
                None,
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Retained [fact] to memory.");
        let puts = mock.puts.lock();
        assert_eq!(puts.len(), 1);
        assert_eq!(puts[0].0, "hello");
        assert_eq!(puts[0].1, "fact");
        assert_eq!(puts[0].2, "sess-42");
    }

    #[tokio::test]
    async fn retain_defaults_kind_to_fact() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        let result = MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap();
        assert_eq!(result.output, "Retained [fact] to memory.");
        assert_eq!(mock.puts.lock()[0].1, "fact");
    }

    #[tokio::test]
    async fn retain_accepts_every_valid_kind() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        for kind in VALID_KINDS {
            let result = MemoryRetainTool
                .execute("c1", json!({"content": "x", "kind": kind}), None, &ctx)
                .await
                .unwrap();
            assert_eq!(
                result.output,
                format!("Retained [{}] to memory.", kind).as_str()
            );
        }
        let puts = mock.puts.lock();
        let stored: Vec<&str> = puts.iter().map(|(_, kind, _)| kind.as_str()).collect();
        assert_eq!(stored, VALID_KINDS);
    }

    #[tokio::test]
    async fn retain_uses_default_subject_without_session() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap();
        assert_eq!(mock.puts.lock()[0].2, "default");
    }

    #[tokio::test]
    async fn retain_errors_when_memory_not_configured() {
        let ctx = ToolContext::default();
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap_err();
        assert_eq!(err, "Memory not configured");
    }

    #[tokio::test]
    async fn retain_rejects_invalid_kind() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x", "kind": "bogus"}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("Invalid kind"), "unexpected error: {err}");
        // A rejected call must not write anything to the backend.
        assert_eq!(mock.put_count(), 0);
    }

    #[tokio::test]
    async fn retain_rejects_out_of_range_importance() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        // Values just outside both ends of the inclusive 0..=1 range.
        for importance in [1.5, -0.1] {
            let err = MemoryRetainTool
                .execute("c1", json!({"content": "x", "importance": importance}), None, &ctx)
                .await
                .unwrap_err();
            assert!(err.contains("importance"), "unexpected error: {err}");
        }
        assert_eq!(mock.put_count(), 0);
    }

    #[tokio::test]
    async fn retain_accepts_importance_at_bounds() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        // The range check is inclusive: both endpoints must be accepted.
        for importance in [0.0, 1.0] {
            let result = MemoryRetainTool
                .execute("c1", json!({"content": "x", "importance": importance}), None, &ctx)
                .await
                .unwrap();
            assert!(result.success);
        }
        assert_eq!(mock.put_count(), 2);
    }

    #[tokio::test]
    async fn retain_rejects_missing_content() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ctx_with(&mock);
        let err = MemoryRetainTool
            .execute("c1", json!({"kind": "fact"}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("content"), "unexpected error: {err}");
        assert_eq!(mock.put_count(), 0);
    }
}