Skip to main content

argentor_builtins/
memory.rs

1use argentor_core::{ArgentorResult, ToolCall, ToolResult};
2use argentor_memory::{EmbeddingProvider, MemoryEntry, VectorStore};
3use argentor_security::Capability;
4use argentor_skills::skill::{Skill, SkillDescriptor};
5use async_trait::async_trait;
6use chrono::Utc;
7use std::collections::HashMap;
8use std::sync::Arc;
9use uuid::Uuid;
10
11/// Skill that stores text in the vector memory.
12pub struct MemoryStoreSkill {
13    descriptor: SkillDescriptor,
14    store: Arc<dyn VectorStore>,
15    embedder: Arc<dyn EmbeddingProvider>,
16}
17
18impl MemoryStoreSkill {
19    /// Create a new memory-store skill using the given vector store and embedding provider.
20    pub fn new(store: Arc<dyn VectorStore>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
21        Self {
22            descriptor: SkillDescriptor {
23                name: "memory_store".to_string(),
24                description: "Store text in long-term vector memory for later retrieval. \
25                              Use this to save important facts, decisions, or context."
26                    .to_string(),
27                parameters_schema: serde_json::json!({
28                    "type": "object",
29                    "properties": {
30                        "content": {
31                            "type": "string",
32                            "description": "The text content to store in memory"
33                        },
34                        "metadata": {
35                            "type": "object",
36                            "description": "Optional metadata (tags, source, etc.)",
37                            "additionalProperties": true
38                        },
39                        "session_id": {
40                            "type": "string",
41                            "description": "Optional session ID to associate with this memory"
42                        }
43                    },
44                    "required": ["content"]
45                }),
46                required_capabilities: vec![Capability::DatabaseQuery],
47                requires_approval: false,
48            },
49            store,
50            embedder,
51        }
52    }
53}
54
55#[async_trait]
56impl Skill for MemoryStoreSkill {
57    fn descriptor(&self) -> &SkillDescriptor {
58        &self.descriptor
59    }
60
61    async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
62        let content = call.arguments["content"]
63            .as_str()
64            .unwrap_or_default()
65            .to_string();
66
67        if content.is_empty() {
68            return Ok(ToolResult::error(&call.id, "Content cannot be empty"));
69        }
70
71        // Compute embedding
72        let embedding = match self.embedder.embed(&content).await {
73            Ok(emb) => emb,
74            Err(e) => {
75                return Ok(ToolResult::error(
76                    &call.id,
77                    format!("Failed to compute embedding: {e}"),
78                ))
79            }
80        };
81
82        // Parse optional metadata
83        let metadata: HashMap<String, serde_json::Value> = call
84            .arguments
85            .get("metadata")
86            .and_then(|m| serde_json::from_value(m.clone()).ok())
87            .unwrap_or_default();
88
89        // Parse optional session_id
90        let session_id = call
91            .arguments
92            .get("session_id")
93            .and_then(|s| s.as_str())
94            .and_then(|s| Uuid::parse_str(s).ok());
95
96        let entry_id = Uuid::new_v4();
97        let entry = MemoryEntry {
98            id: entry_id,
99            content: content.clone(),
100            embedding,
101            metadata,
102            session_id,
103            created_at: Utc::now(),
104        };
105
106        if let Err(e) = self.store.insert(entry).await {
107            return Ok(ToolResult::error(
108                &call.id,
109                format!("Failed to store memory: {e}"),
110            ));
111        }
112
113        let response = serde_json::json!({
114            "stored": true,
115            "id": entry_id.to_string(),
116            "content_length": content.len(),
117        });
118        Ok(ToolResult::success(&call.id, response.to_string()))
119    }
120}
121
122/// Skill that searches the vector memory for similar text.
123pub struct MemorySearchSkill {
124    descriptor: SkillDescriptor,
125    store: Arc<dyn VectorStore>,
126    embedder: Arc<dyn EmbeddingProvider>,
127}
128
129impl MemorySearchSkill {
130    /// Create a new memory-search skill using the given vector store and embedding provider.
131    pub fn new(store: Arc<dyn VectorStore>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
132        Self {
133            descriptor: SkillDescriptor {
134                name: "memory_search".to_string(),
135                description: "Search long-term vector memory for relevant past information. \
136                              Returns the most semantically similar stored memories."
137                    .to_string(),
138                parameters_schema: serde_json::json!({
139                    "type": "object",
140                    "properties": {
141                        "query": {
142                            "type": "string",
143                            "description": "The search query text"
144                        },
145                        "top_k": {
146                            "type": "integer",
147                            "description": "Number of results to return (default: 5, max: 20)",
148                            "default": 5
149                        },
150                        "session_id": {
151                            "type": "string",
152                            "description": "Optional session ID to filter results"
153                        }
154                    },
155                    "required": ["query"]
156                }),
157                required_capabilities: vec![Capability::DatabaseQuery],
158                requires_approval: false,
159            },
160            store,
161            embedder,
162        }
163    }
164}
165
166#[async_trait]
167impl Skill for MemorySearchSkill {
168    fn descriptor(&self) -> &SkillDescriptor {
169        &self.descriptor
170    }
171
172    async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
173        let query = call.arguments["query"]
174            .as_str()
175            .unwrap_or_default()
176            .to_string();
177
178        if query.is_empty() {
179            return Ok(ToolResult::error(&call.id, "Query cannot be empty"));
180        }
181
182        let top_k = call.arguments["top_k"].as_u64().unwrap_or(5).min(20) as usize;
183
184        let session_filter = call
185            .arguments
186            .get("session_id")
187            .and_then(|s| s.as_str())
188            .and_then(|s| Uuid::parse_str(s).ok());
189
190        // Compute query embedding
191        let query_embedding = match self.embedder.embed(&query).await {
192            Ok(emb) => emb,
193            Err(e) => {
194                return Ok(ToolResult::error(
195                    &call.id,
196                    format!("Failed to compute query embedding: {e}"),
197                ))
198            }
199        };
200
201        // Search
202        let results = match self
203            .store
204            .search(&query_embedding, top_k, session_filter)
205            .await
206        {
207            Ok(r) => r,
208            Err(e) => return Ok(ToolResult::error(&call.id, format!("Search failed: {e}"))),
209        };
210
211        let results_json: Vec<serde_json::Value> = results
212            .iter()
213            .map(|r| {
214                serde_json::json!({
215                    "id": r.entry.id.to_string(),
216                    "content": r.entry.content,
217                    "score": r.score,
218                    "metadata": r.entry.metadata,
219                    "created_at": r.entry.created_at.to_rfc3339(),
220                })
221            })
222            .collect();
223
224        let response = serde_json::json!({
225            "query": query,
226            "results": results_json,
227            "total": results_json.len(),
228        });
229
230        Ok(ToolResult::success(&call.id, response.to_string()))
231    }
232}
233
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use argentor_memory::{InMemoryVectorStore, LocalEmbedding};

    /// Builds a store skill and a search skill that share one in-memory store and one
    /// local embedder, so entries written by the former are visible to the latter.
    fn make_skills() -> (MemoryStoreSkill, MemorySearchSkill) {
        let store: Arc<dyn VectorStore> = Arc::new(InMemoryVectorStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(LocalEmbedding::default());
        let store_skill = MemoryStoreSkill::new(store.clone(), embedder.clone());
        let search_skill = MemorySearchSkill::new(store, embedder);
        (store_skill, search_skill)
    }

    /// Shorthand for constructing a [`ToolCall`].
    fn tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// Stores `content` and asserts the store succeeded, so setup failures surface here
    /// rather than as confusing assertion failures in a later search step.
    async fn store_ok(skill: &MemoryStoreSkill, content: &str) {
        let result = skill
            .execute(tool_call(
                "setup",
                "memory_store",
                serde_json::json!({ "content": content }),
            ))
            .await
            .unwrap();
        assert!(!result.is_error, "setup store failed: {}", result.content);
    }

    /// Runs a search, asserts it succeeded, and returns its parsed JSON payload.
    async fn search_ok(skill: &MemorySearchSkill, arguments: serde_json::Value) -> serde_json::Value {
        let result = skill
            .execute(tool_call("q", "memory_search", arguments))
            .await
            .unwrap();
        assert!(!result.is_error, "search failed: {}", result.content);
        serde_json::from_str(&result.content).unwrap()
    }

    #[tokio::test]
    async fn test_memory_store_basic() {
        let (store_skill, _) = make_skills();
        let content = "Rust is a systems programming language";
        let result = store_skill
            .execute(tool_call(
                "t1",
                "memory_store",
                serde_json::json!({ "content": content }),
            ))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("\"stored\":true"));

        let parsed: serde_json::Value = serde_json::from_str(&result.content).unwrap();
        assert!(Uuid::parse_str(parsed["id"].as_str().unwrap()).is_ok());
        assert_eq!(parsed["content_length"].as_u64().unwrap(), content.len() as u64);
    }

    #[tokio::test]
    async fn test_memory_store_empty_content() {
        let (store_skill, _) = make_skills();
        let result = store_skill
            .execute(tool_call(
                "t2",
                "memory_store",
                serde_json::json!({ "content": "" }),
            ))
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_memory_store_with_metadata() {
        let (store_skill, _) = make_skills();
        let result = store_skill
            .execute(tool_call(
                "t3",
                "memory_store",
                serde_json::json!({
                    "content": "Important decision: use Rust",
                    "metadata": {"tag": "architecture", "priority": "high"}
                }),
            ))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_memory_search_basic() {
        let (store_skill, search_skill) = make_skills();

        for content in [
            "Rust is great for systems",
            "Python for data science",
            "Go for networking",
        ]
        .iter()
        {
            store_ok(&store_skill, content).await;
        }

        let parsed = search_ok(
            &search_skill,
            serde_json::json!({ "query": "systems programming language" }),
        )
        .await;
        assert!(parsed["total"].as_u64().unwrap() > 0);
        // First result should be the Rust entry (most similar).
        assert!(parsed["results"][0]["content"]
            .as_str()
            .unwrap()
            .contains("Rust"));
    }

    #[tokio::test]
    async fn test_memory_search_empty_query() {
        let (_, search_skill) = make_skills();
        let result = search_skill
            .execute(tool_call(
                "q2",
                "memory_search",
                serde_json::json!({ "query": "" }),
            ))
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_memory_search_no_results() {
        let (_, search_skill) = make_skills();
        let parsed = search_ok(&search_skill, serde_json::json!({ "query": "anything" })).await;
        assert_eq!(parsed["total"].as_u64().unwrap(), 0);
    }

    #[tokio::test]
    async fn test_memory_search_with_top_k() {
        let (store_skill, search_skill) = make_skills();

        for i in 0..10 {
            store_ok(&store_skill, &format!("Memory entry number {i}")).await;
        }

        let parsed = search_ok(
            &search_skill,
            serde_json::json!({ "query": "memory entry", "top_k": 3 }),
        )
        .await;
        assert_eq!(parsed["total"].as_u64().unwrap(), 3);
        assert_eq!(parsed["results"].as_array().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_memory_search_top_k_capped_at_twenty() {
        let (store_skill, search_skill) = make_skills();

        for i in 0..25 {
            store_ok(&store_skill, &format!("Capped entry number {i}")).await;
        }

        let parsed = search_ok(
            &search_skill,
            serde_json::json!({ "query": "capped entry", "top_k": 100 }),
        )
        .await;
        assert_eq!(parsed["total"].as_u64().unwrap(), 20);
    }

    #[test]
    fn test_descriptors() {
        let store: Arc<dyn VectorStore> = Arc::new(InMemoryVectorStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(LocalEmbedding::default());

        let ms = MemoryStoreSkill::new(store.clone(), embedder.clone());
        assert_eq!(ms.descriptor().name, "memory_store");

        let msearch = MemorySearchSkill::new(store, embedder);
        assert_eq!(msearch.descriptor().name, "memory_search");
    }
}