1use crate::store::{MemoryEntry, SearchResult, VectorStore};
8use argentor_core::{ArgentorError, ArgentorResult};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct WeaviateStore {
16 #[allow(dead_code)]
18 endpoint: String,
19 #[allow(dead_code)]
21 api_key: Option<String>,
22 #[allow(dead_code)]
24 class_name: String,
25 #[cfg(feature = "http-vectorstore")]
27 #[allow(dead_code)]
28 client: Option<reqwest::Client>,
29 entries: RwLock<HashMap<Uuid, MemoryEntry>>,
31}
32
33impl WeaviateStore {
34 pub fn new(endpoint: impl Into<String>, class_name: impl Into<String>) -> Self {
36 Self {
37 endpoint: endpoint.into(),
38 api_key: None,
39 class_name: class_name.into(),
40 #[cfg(feature = "http-vectorstore")]
41 client: None,
42 entries: RwLock::new(HashMap::new()),
43 }
44 }
45
46 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
48 self.api_key = Some(key.into());
49 self
50 }
51
52 pub fn endpoint(&self) -> &str {
54 &self.endpoint
55 }
56
57 pub fn class_name(&self) -> &str {
59 &self.class_name
60 }
61
62 pub fn has_api_key(&self) -> bool {
64 self.api_key.is_some()
65 }
66
67 #[cfg(feature = "http-vectorstore")]
69 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
70 self.client = Some(client);
71 self
72 }
73
74 #[cfg(feature = "http-vectorstore")]
76 #[allow(dead_code)]
77 fn graphql_url(&self) -> String {
78 format!("{}/v1/graphql", self.endpoint.trim_end_matches('/'))
79 }
80}
81
82#[async_trait]
83impl VectorStore for WeaviateStore {
84 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
85 let mut entries = self.entries.write().await;
86 entries.insert(entry.id, entry);
87 Ok(())
88 }
89
90 async fn search(
91 &self,
92 query_embedding: &[f32],
93 top_k: usize,
94 session_filter: Option<Uuid>,
95 ) -> ArgentorResult<Vec<SearchResult>> {
96 if query_embedding.is_empty() {
97 return Err(ArgentorError::Agent("Empty query embedding".to_string()));
98 }
99 let entries = self.entries.read().await;
100 let mut scored: Vec<SearchResult> = entries
101 .values()
102 .filter(|e| {
103 if let Some(sid) = session_filter {
104 e.session_id == Some(sid)
105 } else {
106 true
107 }
108 })
109 .map(|e| {
110 let score = cosine(query_embedding, &e.embedding);
111 SearchResult {
112 entry: e.clone(),
113 score,
114 }
115 })
116 .collect();
117 scored.sort_by(|a, b| {
118 b.score
119 .partial_cmp(&a.score)
120 .unwrap_or(std::cmp::Ordering::Equal)
121 });
122 scored.truncate(top_k);
123 Ok(scored)
124 }
125
126 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
127 let mut entries = self.entries.write().await;
128 Ok(entries.remove(&id).is_some())
129 }
130
131 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
132 let entries = self.entries.read().await;
133 Ok(entries
134 .values()
135 .filter(|e| {
136 if let Some(sid) = session_filter {
137 e.session_id == Some(sid)
138 } else {
139 true
140 }
141 })
142 .cloned()
143 .collect())
144 }
145
146 async fn count(&self) -> ArgentorResult<usize> {
147 let entries = self.entries.read().await;
148 Ok(entries.len())
149 }
150}
151
/// Cosine similarity between `a` and `b`.
///
/// Degenerate inputs score `0.0` rather than erroring or producing NaN:
/// - slices of different lengths,
/// - a zero-norm vector (including empty slices),
/// - non-finite intermediate results (NaN or infinite components).
///
/// Returning a finite value keeps downstream ranking well-defined.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate the dot product and both squared norms together.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, aa, bb), (&x, &y)| {
            (dot + x * y, aa + x * x, bb + y * y)
        });
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let score = dot / (norm_a * norm_b);
    // inf/inf or NaN inputs would otherwise yield NaN and poison sorting by score.
    if score.is_finite() {
        score
    } else {
        0.0
    }
}
165
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an entry with a fresh id, empty metadata and the current timestamp.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn test_new_sets_fields() {
        let store = WeaviateStore::new("https://my-cluster.weaviate.network", "Document");
        assert_eq!(store.endpoint(), "https://my-cluster.weaviate.network");
        assert_eq!(store.class_name(), "Document");
        assert!(!store.has_api_key());
    }

    #[test]
    fn test_with_api_key() {
        let store = WeaviateStore::new("https://x", "C").with_api_key("secret");
        assert!(store.has_api_key());
    }

    #[test]
    fn test_accepts_owned_strings() {
        let store = WeaviateStore::new(String::from("https://x"), String::from("Class"));
        assert_eq!(store.class_name(), "Class");
    }

    #[test]
    fn test_cosine_degenerate_inputs() {
        assert_eq!(cosine(&[1.0, 0.0], &[1.0, 0.0]), 1.0);
        assert_eq!(cosine(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn test_insert_count() {
        let store = WeaviateStore::new("https://x", "C");
        assert_eq!(store.count().await.unwrap(), 0);
        store
            .insert(entry("hi", vec![1.0, 0.0], None))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_many() {
        let store = WeaviateStore::new("https://x", "C");
        for i in 0..20 {
            store
                .insert(entry(&format!("e{i}"), vec![i as f32], None))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 20);
    }

    #[tokio::test]
    async fn test_insert_same_id_replaces() {
        let store = WeaviateStore::new("https://x", "C");
        let first = entry("old", vec![1.0], None);
        let id = first.id;
        let mut second = first.clone();
        second.content = "new".to_string();
        store.insert(first).await.unwrap();
        store.insert(second).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
        let all = store.list(None).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, id);
        assert_eq!(all[0].content, "new");
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let store = WeaviateStore::new("https://x", "C");
        store
            .insert(entry("near", vec![0.9, 0.1, 0.0], None))
            .await
            .unwrap();
        store
            .insert(entry("far", vec![0.0, 0.0, 1.0], None))
            .await
            .unwrap();
        let r = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        // Check the length first so a regression fails with a clear message, not an index panic.
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].entry.content, "near");
        assert!(r[0].score > r[1].score);
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let store = WeaviateStore::new("https://x", "C");
        for i in 0..8 {
            store
                .insert(entry(&format!("e{i}"), vec![1.0, i as f32 / 8.0], None))
                .await
                .unwrap();
        }
        let r = store.search(&[1.0, 0.0], 4, None).await.unwrap();
        assert_eq!(r.len(), 4);
    }

    #[tokio::test]
    async fn test_search_top_k_zero() {
        let store = WeaviateStore::new("https://x", "C");
        store
            .insert(entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();
        let r = store.search(&[1.0, 0.0], 0, None).await.unwrap();
        assert!(r.is_empty());
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch_scores_zero() {
        let store = WeaviateStore::new("https://x", "C");
        store
            .insert(entry("3d", vec![1.0, 0.0, 0.0], None))
            .await
            .unwrap();
        let r = store.search(&[1.0, 0.0], 1, None).await.unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].score, 0.0);
    }

    #[tokio::test]
    async fn test_search_empty_errors() {
        let store = WeaviateStore::new("https://x", "C");
        assert!(store.search(&[], 1, None).await.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = WeaviateStore::new("https://x", "C");
        let sid = Uuid::new_v4();
        store
            .insert(entry("s", vec![1.0, 0.0], Some(sid)))
            .await
            .unwrap();
        store
            .insert(entry("other", vec![1.0, 0.0], None))
            .await
            .unwrap();
        let r = store.search(&[1.0, 0.0], 5, Some(sid)).await.unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let store = WeaviateStore::new("https://x", "C");
        let e = entry("x", vec![1.0], None);
        let id = e.id;
        store.insert(e).await.unwrap();
        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let store = WeaviateStore::new("https://x", "C");
        assert!(!store.delete(Uuid::new_v4()).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_all() {
        let store = WeaviateStore::new("https://x", "C");
        store.insert(entry("a", vec![1.0], None)).await.unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();
        let all = store.list(None).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let store = WeaviateStore::new("https://x", "C");
        let sid = Uuid::new_v4();
        store
            .insert(entry("a", vec![1.0], Some(sid)))
            .await
            .unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();
        let filtered = store.list(Some(sid)).await.unwrap();
        assert_eq!(filtered.len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let store = WeaviateStore::new("https://x", "C");
        let mut e = entry("with-meta", vec![1.0], None);
        e.metadata.insert("k".to_string(), serde_json::json!("v"));
        let id = e.id;
        store.insert(e).await.unwrap();
        let all = store.list(None).await.unwrap();
        let got = all.iter().find(|x| x.id == id).unwrap();
        assert_eq!(got.metadata.get("k").unwrap(), &serde_json::json!("v"));
    }

    #[tokio::test]
    async fn test_instances_are_isolated() {
        let a = WeaviateStore::new("https://x", "A");
        let b = WeaviateStore::new("https://x", "B");
        a.insert(entry("x", vec![1.0], None)).await.unwrap();
        assert_eq!(a.count().await.unwrap(), 1);
        assert_eq!(b.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        let store = WeaviateStore::new("https://x", "C");
        let r = store.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(r.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let store = WeaviateStore::new("https://x", "C");
        let e = entry("a", vec![1.0], None);
        let id = e.id;
        store.insert(e).await.unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();
        store.delete(id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }
}