Skip to main content

argentor_memory/
weaviate.rs

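//! Weaviate vector store adapter.
//!
//! Weaviate is an open-source vector database with GraphQL + REST APIs.
//! This adapter implements the [`VectorStore`] trait with an in-memory stub
//! backend. The `http-vectorstore` feature adds an optional `reqwest` client and
//! URL helpers, but no real HTTP calls are made yet: every operation is served
//! from the in-memory store regardless of that feature.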
6
7use crate::store::{MemoryEntry, SearchResult, VectorStore};
8use argentor_core::{ArgentorError, ArgentorResult};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14/// Weaviate vector store adapter.
15pub struct WeaviateStore {
16    /// Base endpoint (e.g., "https://my-cluster.weaviate.network").
17    #[allow(dead_code)]
18    endpoint: String,
19    /// Optional API key for authenticated clusters.
20    #[allow(dead_code)]
21    api_key: Option<String>,
22    /// Weaviate class (schema) name.
23    #[allow(dead_code)]
24    class_name: String,
25    /// HTTP client — `None` in stub mode.
26    #[cfg(feature = "http-vectorstore")]
27    #[allow(dead_code)]
28    client: Option<reqwest::Client>,
29    /// Stub in-memory storage.
30    entries: RwLock<HashMap<Uuid, MemoryEntry>>,
31}
32
33impl WeaviateStore {
34    /// Create a new Weaviate adapter in stub mode (no API key).
35    pub fn new(endpoint: impl Into<String>, class_name: impl Into<String>) -> Self {
36        Self {
37            endpoint: endpoint.into(),
38            api_key: None,
39            class_name: class_name.into(),
40            #[cfg(feature = "http-vectorstore")]
41            client: None,
42            entries: RwLock::new(HashMap::new()),
43        }
44    }
45
46    /// Attach an API key (used for Weaviate Cloud / secured clusters).
47    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
48        self.api_key = Some(key.into());
49        self
50    }
51
52    /// Return the configured endpoint.
53    pub fn endpoint(&self) -> &str {
54        &self.endpoint
55    }
56
57    /// Return the configured class name.
58    pub fn class_name(&self) -> &str {
59        &self.class_name
60    }
61
62    /// Return whether an API key is configured.
63    pub fn has_api_key(&self) -> bool {
64        self.api_key.is_some()
65    }
66
67    /// Enable real HTTP mode with a [`reqwest::Client`].
68    #[cfg(feature = "http-vectorstore")]
69    pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
70        self.client = Some(client);
71        self
72    }
73
74    /// Build the GraphQL endpoint URL.
75    #[cfg(feature = "http-vectorstore")]
76    #[allow(dead_code)]
77    fn graphql_url(&self) -> String {
78        format!("{}/v1/graphql", self.endpoint.trim_end_matches('/'))
79    }
80}
81
82#[async_trait]
83impl VectorStore for WeaviateStore {
84    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
85        let mut entries = self.entries.write().await;
86        entries.insert(entry.id, entry);
87        Ok(())
88    }
89
90    async fn search(
91        &self,
92        query_embedding: &[f32],
93        top_k: usize,
94        session_filter: Option<Uuid>,
95    ) -> ArgentorResult<Vec<SearchResult>> {
96        if query_embedding.is_empty() {
97            return Err(ArgentorError::Agent("Empty query embedding".to_string()));
98        }
99        let entries = self.entries.read().await;
100        let mut scored: Vec<SearchResult> = entries
101            .values()
102            .filter(|e| {
103                if let Some(sid) = session_filter {
104                    e.session_id == Some(sid)
105                } else {
106                    true
107                }
108            })
109            .map(|e| {
110                let score = cosine(query_embedding, &e.embedding);
111                SearchResult {
112                    entry: e.clone(),
113                    score,
114                }
115            })
116            .collect();
117        scored.sort_by(|a, b| {
118            b.score
119                .partial_cmp(&a.score)
120                .unwrap_or(std::cmp::Ordering::Equal)
121        });
122        scored.truncate(top_k);
123        Ok(scored)
124    }
125
126    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
127        let mut entries = self.entries.write().await;
128        Ok(entries.remove(&id).is_some())
129    }
130
131    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
132        let entries = self.entries.read().await;
133        Ok(entries
134            .values()
135            .filter(|e| {
136                if let Some(sid) = session_filter {
137                    e.session_id == Some(sid)
138                } else {
139                    true
140                }
141            })
142            .cloned()
143            .collect())
144    }
145
146    async fn count(&self) -> ArgentorResult<usize> {
147        let entries = self.entries.read().await;
148        Ok(entries.len())
149    }
150}
151
/// Cosine similarity between `a` and `b`.
///
/// Yields `0.0` when the slices differ in length or when either vector has zero
/// norm; otherwise `dot(a, b) / (|a| * |b|)` (NaN inputs propagate as NaN).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));
        if norm_a != 0.0 && norm_b != 0.0 {
            let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
            return dot / (norm_a * norm_b);
        }
    }
    0.0
}
165
#[cfg(test)]
// Test code surfaces failures by panicking, so unwrap/expect are acceptable here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Build a `MemoryEntry` with a fresh id, empty metadata and the current time.
    fn make_entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    /// Stub-mode store pointed at a dummy endpoint.
    fn stub(class: &str) -> WeaviateStore {
        WeaviateStore::new("https://x", class)
    }

    /// Insert every item in order, panicking on the first failure.
    async fn seed(store: &WeaviateStore, items: Vec<MemoryEntry>) {
        for item in items {
            store.insert(item).await.unwrap();
        }
    }

    #[test]
    fn test_new_sets_fields() {
        let store = WeaviateStore::new("https://my-cluster.weaviate.network", "Document");
        assert_eq!(store.endpoint(), "https://my-cluster.weaviate.network");
        assert_eq!(store.class_name(), "Document");
        assert!(!store.has_api_key());
    }

    #[test]
    fn test_with_api_key() {
        let keyed = stub("C").with_api_key("secret");
        assert!(keyed.has_api_key());
    }

    #[test]
    fn test_accepts_owned_strings() {
        let store = WeaviateStore::new(String::from("https://x"), String::from("Class"));
        assert_eq!(store.class_name(), "Class");
    }

    #[tokio::test]
    async fn test_insert_count() {
        let store = stub("C");
        assert_eq!(store.count().await.unwrap(), 0);
        seed(&store, vec![make_entry("hi", vec![1.0, 0.0], None)]).await;
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_many() {
        let store = stub("C");
        let items = (0..20)
            .map(|i| make_entry(&format!("e{i}"), vec![i as f32], None))
            .collect();
        seed(&store, items).await;
        assert_eq!(store.count().await.unwrap(), 20);
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let store = stub("C");
        seed(
            &store,
            vec![
                make_entry("near", vec![0.9, 0.1, 0.0], None),
                make_entry("far", vec![0.0, 0.0, 1.0], None),
            ],
        )
        .await;
        let hits = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        let (first, second) = (&hits[0], &hits[1]);
        assert_eq!(first.entry.content, "near");
        assert!(first.score > second.score);
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let store = stub("C");
        let items = (0..8)
            .map(|i| make_entry(&format!("e{i}"), vec![1.0, i as f32 / 8.0], None))
            .collect();
        seed(&store, items).await;
        let hits = store.search(&[1.0, 0.0], 4, None).await.unwrap();
        assert_eq!(hits.len(), 4);
    }

    #[tokio::test]
    async fn test_search_empty_errors() {
        let outcome = stub("C").search(&[], 1, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = stub("C");
        let sid = Uuid::new_v4();
        seed(
            &store,
            vec![
                make_entry("s", vec![1.0, 0.0], Some(sid)),
                make_entry("other", vec![1.0, 0.0], None),
            ],
        )
        .await;
        let hits = store.search(&[1.0, 0.0], 5, Some(sid)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let store = stub("C");
        let victim = make_entry("x", vec![1.0], None);
        let victim_id = victim.id;
        seed(&store, vec![victim]).await;
        assert!(store.delete(victim_id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let removed = stub("C").delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn test_list_all() {
        let store = stub("C");
        seed(
            &store,
            vec![
                make_entry("a", vec![1.0], None),
                make_entry("b", vec![0.5], None),
            ],
        )
        .await;
        assert_eq!(store.list(None).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let store = stub("C");
        let sid = Uuid::new_v4();
        seed(
            &store,
            vec![
                make_entry("a", vec![1.0], Some(sid)),
                make_entry("b", vec![0.5], None),
            ],
        )
        .await;
        assert_eq!(store.list(Some(sid)).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let store = stub("C");
        let mut tagged = make_entry("with-meta", vec![1.0], None);
        tagged
            .metadata
            .insert("k".to_string(), serde_json::json!("v"));
        let tagged_id = tagged.id;
        seed(&store, vec![tagged]).await;
        let listed = store.list(None).await.unwrap();
        let got = listed.iter().find(|x| x.id == tagged_id).unwrap();
        assert_eq!(got.metadata.get("k").unwrap(), &serde_json::json!("v"));
    }

    #[tokio::test]
    async fn test_instances_are_isolated() {
        let (left, right) = (stub("A"), stub("B"));
        seed(&left, vec![make_entry("x", vec![1.0], None)]).await;
        assert_eq!(left.count().await.unwrap(), 1);
        assert_eq!(right.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        let hits = stub("C").search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let store = stub("C");
        let doomed = make_entry("a", vec![1.0], None);
        let doomed_id = doomed.id;
        seed(&store, vec![doomed, make_entry("b", vec![0.5], None)]).await;
        store.delete(doomed_id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }
}