1use argentor_core::{ArgentorResult, ToolCall, ToolResult};
2use argentor_memory::{EmbeddingProvider, MemoryEntry, VectorStore};
3use argentor_security::Capability;
4use argentor_skills::skill::{Skill, SkillDescriptor};
5use async_trait::async_trait;
6use chrono::Utc;
7use std::collections::HashMap;
8use std::sync::Arc;
9use uuid::Uuid;
10
11pub struct MemoryStoreSkill {
13 descriptor: SkillDescriptor,
14 store: Arc<dyn VectorStore>,
15 embedder: Arc<dyn EmbeddingProvider>,
16}
17
18impl MemoryStoreSkill {
19 pub fn new(store: Arc<dyn VectorStore>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
21 Self {
22 descriptor: SkillDescriptor {
23 name: "memory_store".to_string(),
24 description: "Store text in long-term vector memory for later retrieval. \
25 Use this to save important facts, decisions, or context."
26 .to_string(),
27 parameters_schema: serde_json::json!({
28 "type": "object",
29 "properties": {
30 "content": {
31 "type": "string",
32 "description": "The text content to store in memory"
33 },
34 "metadata": {
35 "type": "object",
36 "description": "Optional metadata (tags, source, etc.)",
37 "additionalProperties": true
38 },
39 "session_id": {
40 "type": "string",
41 "description": "Optional session ID to associate with this memory"
42 }
43 },
44 "required": ["content"]
45 }),
46 required_capabilities: vec![Capability::DatabaseQuery],
47 requires_approval: false,
48 },
49 store,
50 embedder,
51 }
52 }
53}
54
55#[async_trait]
56impl Skill for MemoryStoreSkill {
57 fn descriptor(&self) -> &SkillDescriptor {
58 &self.descriptor
59 }
60
61 async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
62 let content = call.arguments["content"]
63 .as_str()
64 .unwrap_or_default()
65 .to_string();
66
67 if content.is_empty() {
68 return Ok(ToolResult::error(&call.id, "Content cannot be empty"));
69 }
70
71 let embedding = match self.embedder.embed(&content).await {
73 Ok(emb) => emb,
74 Err(e) => {
75 return Ok(ToolResult::error(
76 &call.id,
77 format!("Failed to compute embedding: {e}"),
78 ))
79 }
80 };
81
82 let metadata: HashMap<String, serde_json::Value> = call
84 .arguments
85 .get("metadata")
86 .and_then(|m| serde_json::from_value(m.clone()).ok())
87 .unwrap_or_default();
88
89 let session_id = call
91 .arguments
92 .get("session_id")
93 .and_then(|s| s.as_str())
94 .and_then(|s| Uuid::parse_str(s).ok());
95
96 let entry_id = Uuid::new_v4();
97 let entry = MemoryEntry {
98 id: entry_id,
99 content: content.clone(),
100 embedding,
101 metadata,
102 session_id,
103 created_at: Utc::now(),
104 };
105
106 if let Err(e) = self.store.insert(entry).await {
107 return Ok(ToolResult::error(
108 &call.id,
109 format!("Failed to store memory: {e}"),
110 ));
111 }
112
113 let response = serde_json::json!({
114 "stored": true,
115 "id": entry_id.to_string(),
116 "content_length": content.len(),
117 });
118 Ok(ToolResult::success(&call.id, response.to_string()))
119 }
120}
121
122pub struct MemorySearchSkill {
124 descriptor: SkillDescriptor,
125 store: Arc<dyn VectorStore>,
126 embedder: Arc<dyn EmbeddingProvider>,
127}
128
129impl MemorySearchSkill {
130 pub fn new(store: Arc<dyn VectorStore>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
132 Self {
133 descriptor: SkillDescriptor {
134 name: "memory_search".to_string(),
135 description: "Search long-term vector memory for relevant past information. \
136 Returns the most semantically similar stored memories."
137 .to_string(),
138 parameters_schema: serde_json::json!({
139 "type": "object",
140 "properties": {
141 "query": {
142 "type": "string",
143 "description": "The search query text"
144 },
145 "top_k": {
146 "type": "integer",
147 "description": "Number of results to return (default: 5, max: 20)",
148 "default": 5
149 },
150 "session_id": {
151 "type": "string",
152 "description": "Optional session ID to filter results"
153 }
154 },
155 "required": ["query"]
156 }),
157 required_capabilities: vec![Capability::DatabaseQuery],
158 requires_approval: false,
159 },
160 store,
161 embedder,
162 }
163 }
164}
165
166#[async_trait]
167impl Skill for MemorySearchSkill {
168 fn descriptor(&self) -> &SkillDescriptor {
169 &self.descriptor
170 }
171
172 async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
173 let query = call.arguments["query"]
174 .as_str()
175 .unwrap_or_default()
176 .to_string();
177
178 if query.is_empty() {
179 return Ok(ToolResult::error(&call.id, "Query cannot be empty"));
180 }
181
182 let top_k = call.arguments["top_k"].as_u64().unwrap_or(5).min(20) as usize;
183
184 let session_filter = call
185 .arguments
186 .get("session_id")
187 .and_then(|s| s.as_str())
188 .and_then(|s| Uuid::parse_str(s).ok());
189
190 let query_embedding = match self.embedder.embed(&query).await {
192 Ok(emb) => emb,
193 Err(e) => {
194 return Ok(ToolResult::error(
195 &call.id,
196 format!("Failed to compute query embedding: {e}"),
197 ))
198 }
199 };
200
201 let results = match self
203 .store
204 .search(&query_embedding, top_k, session_filter)
205 .await
206 {
207 Ok(r) => r,
208 Err(e) => return Ok(ToolResult::error(&call.id, format!("Search failed: {e}"))),
209 };
210
211 let results_json: Vec<serde_json::Value> = results
212 .iter()
213 .map(|r| {
214 serde_json::json!({
215 "id": r.entry.id.to_string(),
216 "content": r.entry.content,
217 "score": r.score,
218 "metadata": r.entry.metadata,
219 "created_at": r.entry.created_at.to_rfc3339(),
220 })
221 })
222 .collect();
223
224 let response = serde_json::json!({
225 "query": query,
226 "results": results_json,
227 "total": results_json.len(),
228 });
229
230 Ok(ToolResult::success(&call.id, response.to_string()))
231 }
232}
233
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use argentor_memory::{InMemoryVectorStore, LocalEmbedding};

    /// Builds a store skill and a search skill that share one in-memory vector store and
    /// embedder, so memories written through the former are visible to the latter.
    fn make_skills() -> (MemoryStoreSkill, MemorySearchSkill) {
        let store: Arc<dyn VectorStore> = Arc::new(InMemoryVectorStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(LocalEmbedding::default());
        let store_skill = MemoryStoreSkill::new(store.clone(), embedder.clone());
        let search_skill = MemorySearchSkill::new(store, embedder);
        (store_skill, search_skill)
    }

    /// Shorthand for building a `ToolCall`.
    fn tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// Stores `content` and fails the test if the store skill reports an error, so search
    /// assertions never run against a silently incomplete store.
    async fn store_text(store_skill: &MemoryStoreSkill, id: &str, content: &str) {
        let call = tool_call(id, "memory_store", serde_json::json!({ "content": content }));
        let result = store_skill.execute(call).await.unwrap();
        assert!(!result.is_error, "store failed: {}", result.content);
    }

    #[tokio::test]
    async fn test_memory_store_basic() {
        let (store_skill, _) = make_skills();
        let call = tool_call(
            "t1",
            "memory_store",
            serde_json::json!({"content": "Rust is a systems programming language"}),
        );
        let result = store_skill.execute(call).await.unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("\"stored\":true"));
    }

    #[tokio::test]
    async fn test_memory_store_empty_content() {
        let (store_skill, _) = make_skills();
        let call = tool_call("t2", "memory_store", serde_json::json!({"content": ""}));
        let result = store_skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_memory_store_with_metadata() {
        let (store_skill, _) = make_skills();
        let call = tool_call(
            "t3",
            "memory_store",
            serde_json::json!({
                "content": "Important decision: use Rust",
                "metadata": {"tag": "architecture", "priority": "high"}
            }),
        );
        let result = store_skill.execute(call).await.unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_memory_store_with_session_id() {
        let (store_skill, _) = make_skills();
        let call = tool_call(
            "t4",
            "memory_store",
            serde_json::json!({
                "content": "Session-scoped note",
                "session_id": Uuid::new_v4().to_string()
            }),
        );
        let result = store_skill.execute(call).await.unwrap();
        assert!(!result.is_error, "{}", result.content);
        assert!(result.content.contains("\"stored\":true"));
    }

    #[tokio::test]
    async fn test_memory_search_basic() {
        let (store_skill, search_skill) = make_skills();

        let contents = [
            "Rust is great for systems",
            "Python for data science",
            "Go for networking",
        ];
        for (i, content) in contents.iter().enumerate() {
            store_text(&store_skill, &format!("s{i}"), content).await;
        }

        let call = tool_call(
            "q1",
            "memory_search",
            serde_json::json!({"query": "systems programming language"}),
        );
        let result = search_skill.execute(call).await.unwrap();
        assert!(!result.is_error, "{}", result.content);

        let parsed: serde_json::Value = serde_json::from_str(&result.content).unwrap();
        assert!(parsed["total"].as_u64().unwrap() > 0);
        assert!(parsed["results"][0]["content"]
            .as_str()
            .unwrap()
            .contains("Rust"));
    }

    #[tokio::test]
    async fn test_memory_search_empty_query() {
        let (_, search_skill) = make_skills();
        let call = tool_call("q2", "memory_search", serde_json::json!({"query": ""}));
        let result = search_skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_memory_search_no_results() {
        let (_, search_skill) = make_skills();
        let call = tool_call("q3", "memory_search", serde_json::json!({"query": "anything"}));
        let result = search_skill.execute(call).await.unwrap();
        assert!(!result.is_error);
        let parsed: serde_json::Value = serde_json::from_str(&result.content).unwrap();
        assert_eq!(parsed["total"].as_u64().unwrap(), 0);
    }

    #[tokio::test]
    async fn test_memory_search_with_session_id() {
        let (_, search_skill) = make_skills();
        let call = tool_call(
            "q4",
            "memory_search",
            serde_json::json!({
                "query": "anything",
                "session_id": Uuid::new_v4().to_string()
            }),
        );
        let result = search_skill.execute(call).await.unwrap();
        assert!(!result.is_error, "{}", result.content);
        let parsed: serde_json::Value = serde_json::from_str(&result.content).unwrap();
        assert_eq!(parsed["total"].as_u64().unwrap(), 0);
    }

    #[tokio::test]
    async fn test_memory_search_with_top_k() {
        let (store_skill, search_skill) = make_skills();

        for i in 0..10 {
            store_text(
                &store_skill,
                &format!("s{i}"),
                &format!("Memory entry number {i}"),
            )
            .await;
        }

        let call = tool_call(
            "q",
            "memory_search",
            serde_json::json!({"query": "memory entry", "top_k": 3}),
        );
        let result = search_skill.execute(call).await.unwrap();
        // Check for a tool error first so a failure reports its message instead of a JSON
        // parse panic.
        assert!(!result.is_error, "{}", result.content);
        let parsed: serde_json::Value = serde_json::from_str(&result.content).unwrap();
        assert_eq!(parsed["total"].as_u64().unwrap(), 3);
    }

    #[test]
    fn test_descriptors() {
        let (ms, msearch) = make_skills();
        assert_eq!(ms.descriptor().name, "memory_store");
        assert_eq!(msearch.descriptor().name, "memory_search");
    }
}