1use crate::store::{MemoryEntry, SearchResult, VectorStore};
8use argentor_core::{ArgentorError, ArgentorResult};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct QdrantStore {
16 #[allow(dead_code)]
18 endpoint: String,
19 #[allow(dead_code)]
21 api_key: Option<String>,
22 #[allow(dead_code)]
24 collection_name: String,
25 #[allow(dead_code)]
27 vector_size: usize,
28 #[cfg(feature = "http-vectorstore")]
30 #[allow(dead_code)]
31 client: Option<reqwest::Client>,
32 entries: RwLock<HashMap<Uuid, MemoryEntry>>,
34}
35
36impl QdrantStore {
37 pub fn new(
39 endpoint: impl Into<String>,
40 collection_name: impl Into<String>,
41 vector_size: usize,
42 ) -> Self {
43 Self {
44 endpoint: endpoint.into(),
45 api_key: None,
46 collection_name: collection_name.into(),
47 vector_size,
48 #[cfg(feature = "http-vectorstore")]
49 client: None,
50 entries: RwLock::new(HashMap::new()),
51 }
52 }
53
54 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
56 self.api_key = Some(key.into());
57 self
58 }
59
60 pub fn endpoint(&self) -> &str {
62 &self.endpoint
63 }
64
65 pub fn collection_name(&self) -> &str {
67 &self.collection_name
68 }
69
70 pub fn vector_size(&self) -> usize {
72 self.vector_size
73 }
74
75 pub fn has_api_key(&self) -> bool {
77 self.api_key.is_some()
78 }
79
80 #[cfg(feature = "http-vectorstore")]
82 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
83 self.client = Some(client);
84 self
85 }
86
87 #[cfg(feature = "http-vectorstore")]
89 #[allow(dead_code)]
90 fn upsert_url(&self) -> String {
91 format!(
92 "{}/collections/{}/points",
93 self.endpoint.trim_end_matches('/'),
94 self.collection_name
95 )
96 }
97}
98
99#[async_trait]
100impl VectorStore for QdrantStore {
101 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
102 if !entry.embedding.is_empty() && entry.embedding.len() != self.vector_size {
103 return Err(ArgentorError::Agent(format!(
104 "Qdrant: embedding dim mismatch (got {}, expected {})",
105 entry.embedding.len(),
106 self.vector_size
107 )));
108 }
109 let mut entries = self.entries.write().await;
110 entries.insert(entry.id, entry);
111 Ok(())
112 }
113
114 async fn search(
115 &self,
116 query_embedding: &[f32],
117 top_k: usize,
118 session_filter: Option<Uuid>,
119 ) -> ArgentorResult<Vec<SearchResult>> {
120 if query_embedding.is_empty() {
121 return Err(ArgentorError::Agent("Empty query embedding".to_string()));
122 }
123 let entries = self.entries.read().await;
124 let mut scored: Vec<SearchResult> = entries
125 .values()
126 .filter(|e| {
127 if let Some(sid) = session_filter {
128 e.session_id == Some(sid)
129 } else {
130 true
131 }
132 })
133 .map(|e| {
134 let score = cosine(query_embedding, &e.embedding);
135 SearchResult {
136 entry: e.clone(),
137 score,
138 }
139 })
140 .collect();
141 scored.sort_by(|a, b| {
142 b.score
143 .partial_cmp(&a.score)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 scored.truncate(top_k);
147 Ok(scored)
148 }
149
150 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
151 let mut entries = self.entries.write().await;
152 Ok(entries.remove(&id).is_some())
153 }
154
155 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
156 let entries = self.entries.read().await;
157 Ok(entries
158 .values()
159 .filter(|e| {
160 if let Some(sid) = session_filter {
161 e.session_id == Some(sid)
162 } else {
163 true
164 }
165 })
166 .cloned()
167 .collect())
168 }
169
170 async fn count(&self) -> ArgentorResult<usize> {
171 let entries = self.entries.read().await;
172 Ok(entries.len())
173 }
174}
175
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the similarity is undefined: the slices differ in
/// length, either vector has zero norm (including both being empty), or the
/// inputs contain NaN/infinite components.
///
/// Sums are accumulated in `f64` so that large but finite `f32` components
/// cannot overflow the intermediate products, which would otherwise turn the
/// result into NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (mut dot, mut norm_a, mut norm_b) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a.sqrt() * norm_b.sqrt());
    if similarity.is_finite() {
        // A finite cosine lies in [-1, 1] up to rounding, so narrowing to
        // `f32` cannot overflow.
        similarity as f32
    } else {
        // NaN or infinite components in the input: treat as "no similarity".
        0.0
    }
}
189
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an entry with a fresh random id, empty metadata and the
    /// current timestamp.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        let id = Uuid::new_v4();
        let created_at = Utc::now();
        MemoryEntry {
            id,
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at,
        }
    }

    /// Store with placeholder connection settings and the given dimension.
    fn store(dim: usize) -> QdrantStore {
        QdrantStore::new("http://x", "c", dim)
    }

    #[test]
    fn test_new_sets_fields() {
        let qdrant = QdrantStore::new("http://localhost:6333", "docs", 768);
        assert_eq!(qdrant.endpoint(), "http://localhost:6333");
        assert_eq!(qdrant.collection_name(), "docs");
        assert_eq!(qdrant.vector_size(), 768);
        assert!(!qdrant.has_api_key());
    }

    #[test]
    fn test_with_api_key() {
        let qdrant = store(3).with_api_key("tok");
        assert!(qdrant.has_api_key());
    }

    #[test]
    fn test_accepts_owned_strings() {
        let endpoint = String::from("http://x");
        let collection = String::from("c");
        let qdrant = QdrantStore::new(endpoint, collection, 16);
        assert_eq!(qdrant.vector_size(), 16);
    }

    #[tokio::test]
    async fn test_insert_and_count() {
        let qdrant = store(2);
        qdrant
            .insert(entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();
        let total = qdrant.count().await.unwrap();
        assert_eq!(total, 1);
    }

    #[tokio::test]
    async fn test_insert_rejects_bad_dim() {
        let qdrant = store(3);
        // Two components against a three-dimensional store.
        let outcome = qdrant.insert(entry("bad", vec![1.0, 0.0], None)).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_insert_allows_empty_embedding() {
        let qdrant = store(3);
        // An empty embedding bypasses the dimension check in `insert`.
        let outcome = qdrant.insert(entry("pending", vec![], None)).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_insert_many() {
        let qdrant = store(2);
        for i in 0..12 {
            let item = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            qdrant.insert(item).await.unwrap();
        }
        let total = qdrant.count().await.unwrap();
        assert_eq!(total, 12);
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let qdrant = store(3);
        let near = entry("near", vec![0.9, 0.1, 0.0], None);
        let far = entry("far", vec![0.0, 0.0, 1.0], None);
        qdrant.insert(near).await.unwrap();
        qdrant.insert(far).await.unwrap();

        let hits = qdrant.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits[0].entry.content, "near");
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let qdrant = store(2);
        for i in 0..6 {
            let item = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            qdrant.insert(item).await.unwrap();
        }
        let hits = qdrant.search(&[1.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn test_search_empty_query_errors() {
        let qdrant = store(2);
        let outcome = qdrant.search(&[], 1, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let qdrant = store(2);
        let session = Uuid::new_v4();
        let in_session = entry("s", vec![1.0, 0.0], Some(session));
        let outside = entry("o", vec![1.0, 0.0], None);
        qdrant.insert(in_session).await.unwrap();
        qdrant.insert(outside).await.unwrap();

        let hits = qdrant.search(&[1.0, 0.0], 5, Some(session)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let qdrant = store(2);
        let item = entry("x", vec![1.0, 0.0], None);
        let id = item.id;
        qdrant.insert(item).await.unwrap();
        let removed = qdrant.delete(id).await.unwrap();
        assert!(removed);
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let qdrant = store(2);
        let removed = qdrant.delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn test_list_all() {
        let qdrant = store(2);
        qdrant
            .insert(entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();
        qdrant
            .insert(entry("b", vec![0.0, 1.0], None))
            .await
            .unwrap();
        let listed = qdrant.list(None).await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let qdrant = store(2);
        let session = Uuid::new_v4();
        qdrant
            .insert(entry("a", vec![1.0, 0.0], Some(session)))
            .await
            .unwrap();
        qdrant
            .insert(entry("b", vec![0.0, 1.0], None))
            .await
            .unwrap();
        let listed = qdrant.list(Some(session)).await.unwrap();
        assert_eq!(listed.len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let qdrant = store(2);
        let mut item = entry("m", vec![1.0, 0.0], None);
        item.metadata.insert("k".into(), serde_json::json!(42));
        let id = item.id;
        qdrant.insert(item).await.unwrap();

        let listed = qdrant.list(None).await.unwrap();
        let stored = listed.iter().find(|candidate| candidate.id == id).unwrap();
        assert_eq!(stored.metadata.get("k").unwrap(), &serde_json::json!(42));
    }

    #[tokio::test]
    async fn test_empty_search_result() {
        let qdrant = store(2);
        let hits = qdrant.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let qdrant = store(2);
        let first = entry("a", vec![1.0, 0.0], None);
        let first_id = first.id;
        qdrant.insert(first).await.unwrap();
        qdrant
            .insert(entry("b", vec![0.0, 1.0], None))
            .await
            .unwrap();
        qdrant.delete(first_id).await.unwrap();
        let total = qdrant.count().await.unwrap();
        assert_eq!(total, 1);
    }
}