Skip to main content

argentor_memory/
qdrant.rs

1//! Qdrant vector store adapter.
2//!
3//! Qdrant is an open-source vector database (Rust-native) with REST and
4//! gRPC APIs. This adapter implements the [`VectorStore`] trait with a stub
5//! backend; real HTTP calls are gated behind the `http-vectorstore` feature.
6
7use crate::store::{MemoryEntry, SearchResult, VectorStore};
8use argentor_core::{ArgentorError, ArgentorResult};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14/// Qdrant vector store adapter.
15pub struct QdrantStore {
16    /// Qdrant endpoint (e.g., "http://localhost:6333").
17    #[allow(dead_code)]
18    endpoint: String,
19    /// Optional API key (Qdrant Cloud).
20    #[allow(dead_code)]
21    api_key: Option<String>,
22    /// Collection name.
23    #[allow(dead_code)]
24    collection_name: String,
25    /// Vector dimension for the collection.
26    #[allow(dead_code)]
27    vector_size: usize,
28    /// HTTP client — `None` in stub mode.
29    #[cfg(feature = "http-vectorstore")]
30    #[allow(dead_code)]
31    client: Option<reqwest::Client>,
32    /// Stub in-memory storage.
33    entries: RwLock<HashMap<Uuid, MemoryEntry>>,
34}
35
36impl QdrantStore {
37    /// Create a new Qdrant adapter in stub mode.
38    pub fn new(
39        endpoint: impl Into<String>,
40        collection_name: impl Into<String>,
41        vector_size: usize,
42    ) -> Self {
43        Self {
44            endpoint: endpoint.into(),
45            api_key: None,
46            collection_name: collection_name.into(),
47            vector_size,
48            #[cfg(feature = "http-vectorstore")]
49            client: None,
50            entries: RwLock::new(HashMap::new()),
51        }
52    }
53
54    /// Attach an API key (Qdrant Cloud).
55    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
56        self.api_key = Some(key.into());
57        self
58    }
59
60    /// Return the configured endpoint.
61    pub fn endpoint(&self) -> &str {
62        &self.endpoint
63    }
64
65    /// Return the configured collection name.
66    pub fn collection_name(&self) -> &str {
67        &self.collection_name
68    }
69
70    /// Return the configured vector size.
71    pub fn vector_size(&self) -> usize {
72        self.vector_size
73    }
74
75    /// Return whether an API key is configured.
76    pub fn has_api_key(&self) -> bool {
77        self.api_key.is_some()
78    }
79
80    /// Enable real HTTP mode with a [`reqwest::Client`].
81    #[cfg(feature = "http-vectorstore")]
82    pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
83        self.client = Some(client);
84        self
85    }
86
87    /// Build the points upsert URL for this collection.
88    #[cfg(feature = "http-vectorstore")]
89    #[allow(dead_code)]
90    fn upsert_url(&self) -> String {
91        format!(
92            "{}/collections/{}/points",
93            self.endpoint.trim_end_matches('/'),
94            self.collection_name
95        )
96    }
97}
98
99#[async_trait]
100impl VectorStore for QdrantStore {
101    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
102        if !entry.embedding.is_empty() && entry.embedding.len() != self.vector_size {
103            return Err(ArgentorError::Agent(format!(
104                "Qdrant: embedding dim mismatch (got {}, expected {})",
105                entry.embedding.len(),
106                self.vector_size
107            )));
108        }
109        let mut entries = self.entries.write().await;
110        entries.insert(entry.id, entry);
111        Ok(())
112    }
113
114    async fn search(
115        &self,
116        query_embedding: &[f32],
117        top_k: usize,
118        session_filter: Option<Uuid>,
119    ) -> ArgentorResult<Vec<SearchResult>> {
120        if query_embedding.is_empty() {
121            return Err(ArgentorError::Agent("Empty query embedding".to_string()));
122        }
123        let entries = self.entries.read().await;
124        let mut scored: Vec<SearchResult> = entries
125            .values()
126            .filter(|e| {
127                if let Some(sid) = session_filter {
128                    e.session_id == Some(sid)
129                } else {
130                    true
131                }
132            })
133            .map(|e| {
134                let score = cosine(query_embedding, &e.embedding);
135                SearchResult {
136                    entry: e.clone(),
137                    score,
138                }
139            })
140            .collect();
141        scored.sort_by(|a, b| {
142            b.score
143                .partial_cmp(&a.score)
144                .unwrap_or(std::cmp::Ordering::Equal)
145        });
146        scored.truncate(top_k);
147        Ok(scored)
148    }
149
150    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
151        let mut entries = self.entries.write().await;
152        Ok(entries.remove(&id).is_some())
153    }
154
155    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
156        let entries = self.entries.read().await;
157        Ok(entries
158            .values()
159            .filter(|e| {
160                if let Some(sid) = session_filter {
161                    e.session_id == Some(sid)
162                } else {
163                    true
164                }
165            })
166            .cloned()
167            .collect())
168    }
169
170    async fn count(&self) -> ArgentorResult<usize> {
171        let entries = self.entries.read().await;
172        Ok(entries.len())
173    }
174}
175
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has a zero
/// Euclidean norm (which includes two empty slices).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Euclidean norm of a slice.
    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
189
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Stub-mode store with a placeholder endpoint and collection name.
    fn stub_store(dim: usize) -> QdrantStore {
        QdrantStore::new("http://x", "c", dim)
    }

    /// Build a [`MemoryEntry`] with a fresh random id, empty metadata and the
    /// current timestamp.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_owned(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn test_new_sets_fields() {
        let store = QdrantStore::new("http://localhost:6333", "docs", 768);
        assert_eq!(
            (store.endpoint(), store.collection_name(), store.vector_size()),
            ("http://localhost:6333", "docs", 768)
        );
        assert!(!store.has_api_key());
    }

    #[test]
    fn test_with_api_key() {
        let store = stub_store(3).with_api_key("tok");
        assert!(store.has_api_key());
    }

    #[test]
    fn test_accepts_owned_strings() {
        let endpoint = "http://x".to_string();
        let collection = "c".to_string();
        let store = QdrantStore::new(endpoint, collection, 16);
        assert_eq!(store.vector_size(), 16);
    }

    #[tokio::test]
    async fn test_insert_and_count() {
        let db = stub_store(2);
        db.insert(entry("a", vec![1.0, 0.0], None)).await.unwrap();
        let total = db.count().await.unwrap();
        assert_eq!(total, 1);
    }

    #[tokio::test]
    async fn test_insert_rejects_bad_dim() {
        let db = stub_store(3);
        // Two components against a three-dimensional collection.
        let outcome = db.insert(entry("bad", vec![1.0, 0.0], None)).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_insert_allows_empty_embedding() {
        // Empty embedding is allowed (e.g., for upsert-later patterns).
        let db = stub_store(3);
        let outcome = db.insert(entry("pending", vec![], None)).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_insert_many() {
        let db = stub_store(2);
        let batch = (0..12).map(|i| entry(&format!("e{i}"), vec![1.0, i as f32], None));
        for item in batch {
            db.insert(item).await.unwrap();
        }
        assert_eq!(db.count().await.unwrap(), 12);
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let db = stub_store(3);
        let near = entry("near", vec![0.9, 0.1, 0.0], None);
        let far = entry("far", vec![0.0, 0.0, 1.0], None);
        db.insert(near).await.unwrap();
        db.insert(far).await.unwrap();
        let hits = db.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits[0].entry.content, "near");
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let db = stub_store(2);
        let batch = (0..6).map(|i| entry(&format!("e{i}"), vec![1.0, i as f32], None));
        for item in batch {
            db.insert(item).await.unwrap();
        }
        let hits = db.search(&[1.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn test_search_empty_query_errors() {
        let db = stub_store(2);
        let outcome = db.search(&[], 1, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let db = stub_store(2);
        let session = Uuid::new_v4();
        db.insert(entry("s", vec![1.0, 0.0], Some(session)))
            .await
            .unwrap();
        db.insert(entry("o", vec![1.0, 0.0], None)).await.unwrap();
        let hits = db.search(&[1.0, 0.0], 5, Some(session)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let db = stub_store(2);
        let item = entry("x", vec![1.0, 0.0], None);
        let id = item.id;
        db.insert(item).await.unwrap();
        let removed = db.delete(id).await.unwrap();
        assert!(removed);
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let db = stub_store(2);
        let removed = db.delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn test_list_all() {
        let db = stub_store(2);
        db.insert(entry("a", vec![1.0, 0.0], None)).await.unwrap();
        db.insert(entry("b", vec![0.0, 1.0], None)).await.unwrap();
        let listed = db.list(None).await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let db = stub_store(2);
        let session = Uuid::new_v4();
        db.insert(entry("a", vec![1.0, 0.0], Some(session)))
            .await
            .unwrap();
        db.insert(entry("b", vec![0.0, 1.0], None)).await.unwrap();
        let listed = db.list(Some(session)).await.unwrap();
        assert_eq!(listed.len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let db = stub_store(2);
        let mut tagged = entry("m", vec![1.0, 0.0], None);
        let id = tagged.id;
        tagged.metadata.insert("k".into(), serde_json::json!(42));
        db.insert(tagged).await.unwrap();
        let listed = db.list(None).await.unwrap();
        let found = listed.iter().find(|e| e.id == id).unwrap();
        assert_eq!(found.metadata.get("k"), Some(&serde_json::json!(42)));
    }

    #[tokio::test]
    async fn test_empty_search_result() {
        let db = stub_store(2);
        let hits = db.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let db = stub_store(2);
        let doomed = entry("a", vec![1.0, 0.0], None);
        let doomed_id = doomed.id;
        db.insert(doomed).await.unwrap();
        db.insert(entry("b", vec![0.0, 1.0], None)).await.unwrap();
        db.delete(doomed_id).await.unwrap();
        assert_eq!(db.count().await.unwrap(), 1);
    }
}