Skip to main content

autoagents_qdrant/
lib.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use autoagents_core::embeddings::{Embed, Embedding, EmbeddingError, SharedEmbeddingProvider};
5use autoagents_core::one_or_many::OneOrMany;
6use autoagents_core::vector_store::request::{Filter, FilterError};
7use autoagents_core::vector_store::{
8    DEFAULT_VECTOR_NAME, NamedVectorDocument, PreparedDocument, VectorSearchRequest,
9    VectorStoreError, VectorStoreIndex, embed_documents, embed_named_documents, normalize_id,
10};
11use qdrant_client::Payload;
12use qdrant_client::Qdrant;
13use qdrant_client::qdrant::{
14    Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter as QdrantFilter,
15    PointStruct, Range, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
16    VectorsConfigBuilder, condition, with_payload_selector,
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21#[derive(Clone)]
22pub struct QdrantVectorStore {
23    client: Qdrant,
24    collection_name: String,
25    provider: SharedEmbeddingProvider,
26}
27
28impl QdrantVectorStore {
29    fn stable_point_id(source_id: &str) -> String {
30        // Qdrant point ids are UUID/u64. Convert arbitrary logical ids
31        // (e.g. "path:start:end") into a deterministic UUIDv5.
32        Uuid::new_v5(&Uuid::NAMESPACE_URL, source_id.as_bytes()).to_string()
33    }
34
35    pub fn new(
36        provider: SharedEmbeddingProvider,
37        url: impl Into<String>,
38        collection_name: impl Into<String>,
39    ) -> Result<Self, VectorStoreError> {
40        Self::with_api_key(provider, url, collection_name, None)
41    }
42
43    pub fn with_api_key(
44        provider: SharedEmbeddingProvider,
45        url: impl Into<String>,
46        collection_name: impl Into<String>,
47        api_key: Option<String>,
48    ) -> Result<Self, VectorStoreError> {
49        let url = url.into();
50        let builder = Qdrant::from_url(&url);
51        let client = if let Some(key) = api_key {
52            builder
53                .api_key(key)
54                .build()
55                .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?
56        } else {
57            builder
58                .build()
59                .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?
60        };
61
62        Ok(Self {
63            client,
64            collection_name: collection_name.into(),
65            provider,
66        })
67    }
68
69    async fn ensure_collection(&self, dimension: u64) -> Result<(), VectorStoreError> {
70        let request = CreateCollectionBuilder::new(self.collection_name.clone())
71            .vectors_config(VectorParamsBuilder::new(dimension, Distance::Cosine))
72            .build();
73
74        let result = self.client.create_collection(request).await;
75        if let Err(err) = result {
76            // Ignore already existing collections to keep the operation idempotent.
77            let message = err.to_string();
78            if !message.contains("already exists") {
79                return Err(VectorStoreError::DatastoreError(Box::new(err)));
80            }
81        }
82
83        Ok(())
84    }
85
86    async fn ensure_named_collection(
87        &self,
88        dimensions: &HashMap<String, u64>,
89    ) -> Result<(), VectorStoreError> {
90        let request = Self::named_collection_request(&self.collection_name, dimensions);
91
92        let result = self.client.create_collection(request).await;
93        if let Err(err) = result {
94            let message = err.to_string();
95            if !message.contains("already exists") {
96                return Err(VectorStoreError::DatastoreError(Box::new(err)));
97            }
98        }
99
100        Ok(())
101    }
102
103    fn named_collection_request(
104        collection_name: &str,
105        dimensions: &HashMap<String, u64>,
106    ) -> qdrant_client::qdrant::CreateCollection {
107        let mut config = VectorsConfigBuilder::default();
108        for (name, dimension) in dimensions {
109            config.add_named_vector_params(
110                name.clone(),
111                VectorParamsBuilder::new(*dimension, Distance::Cosine),
112            );
113        }
114
115        CreateCollectionBuilder::new(collection_name.to_string())
116            .vectors_config(config)
117            .build()
118    }
119
120    fn payload_for(doc: &PreparedDocument) -> Result<Payload, VectorStoreError> {
121        let payload = serde_json::json!({
122            "raw": doc.raw,
123            "source_id": doc.id,
124        });
125
126        Payload::try_from(payload).map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))
127    }
128
129    fn decode_id(payload: &HashMap<String, qdrant_client::qdrant::Value>) -> Option<String> {
130        payload
131            .get("source_id")
132            .and_then(|value| serde_json::to_value(value).ok())
133            .and_then(|v| v.as_str().map(|id| id.to_string()))
134    }
135
136    fn decode_raw<T>(
137        payload: &HashMap<String, qdrant_client::qdrant::Value>,
138    ) -> Result<Option<T>, VectorStoreError>
139    where
140        T: for<'de> Deserialize<'de>,
141    {
142        if let Some(raw) = payload.get("raw") {
143            let value = serde_json::to_value(raw).map_err(VectorStoreError::JsonError)?;
144            let parsed = serde_json::from_value(value)?;
145            Ok(Some(parsed))
146        } else {
147            Ok(None)
148        }
149    }
150
151    /// Deletes documents using their logical/source IDs (the IDs used for upsert).
152    pub async fn delete_documents_by_ids(
153        &self,
154        source_ids: &[String],
155    ) -> Result<(), VectorStoreError> {
156        if source_ids.is_empty() {
157            return Ok(());
158        }
159
160        let point_ids = source_ids
161            .iter()
162            .map(|source_id| Self::stable_point_id(source_id))
163            .collect::<Vec<_>>();
164
165        self.client
166            .delete_points(
167                DeletePointsBuilder::new(self.collection_name.clone())
168                    .points(point_ids)
169                    .wait(true),
170            )
171            .await
172            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
173
174        Ok(())
175    }
176
177    /// Deletes this collection if it already exists.
178    pub async fn delete_collection_if_exists(&self) -> Result<(), VectorStoreError> {
179        let exists = self
180            .client
181            .collection_exists(self.collection_name.clone())
182            .await
183            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
184        if !exists {
185            return Ok(());
186        }
187
188        self.client
189            .delete_collection(self.collection_name.clone())
190            .await
191            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
192
193        Ok(())
194    }
195
196    fn named_dimensions(vectors: &HashMap<String, Vec<f32>>) -> HashMap<String, u64> {
197        vectors
198            .iter()
199            .map(|(name, vector)| (name.clone(), vector.len() as u64))
200            .collect()
201    }
202}
203
204#[async_trait]
205impl VectorStoreIndex for QdrantVectorStore {
206    type Filter = Filter<serde_json::Value>;
207
208    async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<(), VectorStoreError>
209    where
210        T: Embed + Serialize + Send + Sync + Clone,
211    {
212        let docs: Vec<(String, T)> = documents
213            .into_iter()
214            .map(|doc| (normalize_id(None), doc))
215            .collect();
216        self.insert_documents_with_ids(docs).await
217    }
218
219    async fn insert_documents_with_ids<T>(
220        &self,
221        documents: Vec<(String, T)>,
222    ) -> Result<(), VectorStoreError>
223    where
224        T: Embed + Serialize + Send + Sync + Clone,
225    {
226        let normalized: Vec<(String, T)> = documents
227            .into_iter()
228            .map(|(id, doc)| (normalize_id(Some(id)), doc))
229            .collect();
230        let prepared = embed_documents(&self.provider, normalized).await?;
231        let Some(first) = prepared.first() else {
232            return Ok(());
233        };
234
235        let dim = first
236            .embeddings
237            .iter()
238            .next()
239            .map(|e| e.vec.len())
240            .unwrap_or(0);
241        self.ensure_collection(dim as u64).await?;
242
243        let mut points = Vec::new();
244        for doc in prepared {
245            let payload = Self::payload_for(&doc)?;
246            let vector = combine_embeddings(&doc.embeddings)?;
247
248            // Keep logical id in payload and map point id to a stable UUID.
249            let point_id = Self::stable_point_id(&doc.id);
250
251            points.push(PointStruct::new(point_id, vector, payload.clone()));
252        }
253
254        let request = UpsertPointsBuilder::new(self.collection_name.clone(), points).build();
255        self.client
256            .upsert_points(request)
257            .await
258            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
259
260        Ok(())
261    }
262
263    async fn top_n<T>(
264        &self,
265        req: VectorSearchRequest<Self::Filter>,
266    ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
267    where
268        T: for<'de> Deserialize<'de> + Send + Sync,
269    {
270        let vectors = self
271            .provider
272            .embed(vec![req.query().to_string()])
273            .await
274            .map_err(EmbeddingError::Provider)?;
275
276        let Some(vector) = vectors.into_iter().next() else {
277            return Ok(Vec::new());
278        };
279
280        let mut search =
281            SearchPointsBuilder::new(self.collection_name.clone(), vector, req.samples())
282                .with_payload(with_payload_selector::SelectorOptions::Enable(true));
283
284        if let Some(vector_name) = req.query_vector_name()
285            && vector_name != DEFAULT_VECTOR_NAME
286        {
287            search = search.vector_name(vector_name.to_string());
288        }
289
290        if let Some(filter) = req.filter() {
291            search = search.filter(to_qdrant_filter(filter.clone())?);
292        }
293
294        if let Some(threshold) = req.threshold() {
295            search = search.score_threshold(threshold as f32);
296        }
297
298        let response = self
299            .client
300            .search_points(search)
301            .await
302            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
303
304        let mut results = Vec::new();
305        for point in response.result {
306            let id = Self::decode_id(&point.payload)
307                .or_else(|| point.id.map(|id| format!("{id:?}")))
308                .unwrap_or_default();
309
310            if let Some(raw) = Self::decode_raw::<T>(&point.payload)? {
311                results.push((point.score as f64, id, raw));
312            }
313        }
314
315        Ok(results)
316    }
317
318    async fn top_n_ids(
319        &self,
320        req: VectorSearchRequest<Self::Filter>,
321    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
322        let vectors = self
323            .provider
324            .embed(vec![req.query().to_string()])
325            .await
326            .map_err(EmbeddingError::Provider)?;
327
328        let Some(vector) = vectors.into_iter().next() else {
329            return Ok(Vec::new());
330        };
331
332        let mut search =
333            SearchPointsBuilder::new(self.collection_name.clone(), vector, req.samples())
334                .with_payload(with_payload_selector::SelectorOptions::Enable(true));
335
336        if let Some(vector_name) = req.query_vector_name()
337            && vector_name != DEFAULT_VECTOR_NAME
338        {
339            search = search.vector_name(vector_name.to_string());
340        }
341
342        if let Some(filter) = req.filter() {
343            search = search.filter(to_qdrant_filter(filter.clone())?);
344        }
345
346        if let Some(threshold) = req.threshold() {
347            search = search.score_threshold(threshold as f32);
348        }
349
350        let response = self
351            .client
352            .search_points(search)
353            .await
354            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
355
356        let mut results = Vec::new();
357        for point in response.result {
358            let id = Self::decode_id(&point.payload)
359                .or_else(|| point.id.map(|id| format!("{id:?}")))
360                .unwrap_or_default();
361            results.push((point.score as f64, id));
362        }
363
364        Ok(results)
365    }
366
367    async fn insert_documents_with_named_vectors<T>(
368        &self,
369        documents: Vec<NamedVectorDocument<T>>,
370    ) -> Result<(), VectorStoreError>
371    where
372        T: Serialize + Send + Sync + Clone,
373    {
374        let normalized = documents
375            .into_iter()
376            .map(|doc| NamedVectorDocument {
377                id: normalize_id(Some(doc.id)),
378                raw: doc.raw,
379                vectors: doc.vectors,
380            })
381            .collect::<Vec<_>>();
382
383        let prepared = embed_named_documents(&self.provider, normalized).await?;
384        let Some(first) = prepared.first() else {
385            return Ok(());
386        };
387
388        let dimensions = Self::named_dimensions(&first.vectors);
389        self.ensure_named_collection(&dimensions).await?;
390
391        let mut points = Vec::new();
392        for doc in prepared {
393            let source_id = doc.id.clone();
394            let payload = Payload::try_from(serde_json::json!({
395                "raw": doc.raw,
396                "source_id": source_id,
397            }))
398            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
399            let point_id = Self::stable_point_id(&source_id);
400            points.push(PointStruct::new(point_id, doc.vectors, payload));
401        }
402
403        let request = UpsertPointsBuilder::new(self.collection_name.clone(), points).build();
404        self.client
405            .upsert_points(request)
406            .await
407            .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
408
409        Ok(())
410    }
411}
412
413fn to_qdrant_filter(filter: Filter<serde_json::Value>) -> Result<QdrantFilter, VectorStoreError> {
414    use Filter::*;
415
416    let empty = || QdrantFilter {
417        must: Vec::new(),
418        should: Vec::new(),
419        must_not: Vec::new(),
420        min_should: None,
421    };
422
423    match filter {
424        Eq(key, value) => {
425            let mut filter = empty();
426            filter
427                .must
428                .push(Condition::matches(key, value_to_match_value(value)?));
429            Ok(filter)
430        }
431        Gt(key, value) => {
432            let mut filter = empty();
433            filter.must.push(Condition::range(
434                key,
435                Range {
436                    gt: Some(number_to_f64(&value)?),
437                    gte: None,
438                    lt: None,
439                    lte: None,
440                },
441            ));
442            Ok(filter)
443        }
444        Lt(key, value) => {
445            let mut filter = empty();
446            filter.must.push(Condition::range(
447                key,
448                Range {
449                    lt: Some(number_to_f64(&value)?),
450                    lte: None,
451                    gt: None,
452                    gte: None,
453                },
454            ));
455            Ok(filter)
456        }
457        And(lhs, rhs) => {
458            let mut left = to_qdrant_filter(*lhs)?;
459            let right = to_qdrant_filter(*rhs)?;
460
461            left.must.extend(right.must);
462            left.must.extend(right.should);
463            Ok(left)
464        }
465        Or(lhs, rhs) => {
466            let left = to_qdrant_filter(*lhs)?;
467            let right = to_qdrant_filter(*rhs)?;
468
469            Ok(QdrantFilter {
470                should: vec![
471                    Condition {
472                        condition_one_of: Some(condition::ConditionOneOf::Filter(left)),
473                    },
474                    Condition {
475                        condition_one_of: Some(condition::ConditionOneOf::Filter(right)),
476                    },
477                ],
478                must: Vec::new(),
479                must_not: Vec::new(),
480                min_should: None,
481            })
482        }
483    }
484}
485
486fn value_to_match_value(
487    value: serde_json::Value,
488) -> Result<qdrant_client::qdrant::r#match::MatchValue, VectorStoreError> {
489    use qdrant_client::qdrant::r#match::MatchValue;
490    match value {
491        serde_json::Value::String(s) => Ok(MatchValue::Keyword(s)),
492        serde_json::Value::Number(num) => {
493            if let Some(i) = num.as_i64() {
494                Ok(MatchValue::Integer(i))
495            } else if let Some(f) = num.as_f64() {
496                Ok(MatchValue::Keyword(f.to_string()))
497            } else {
498                Err(FilterError::TypeError("Unsupported number".into()).into())
499            }
500        }
501        serde_json::Value::Bool(b) => Ok(MatchValue::Boolean(b)),
502        other => Err(FilterError::TypeError(format!("Unsupported filter value {other:?}")).into()),
503    }
504}
505
506fn number_to_f64(value: &serde_json::Value) -> Result<f64, VectorStoreError> {
507    value
508        .as_f64()
509        .or_else(|| value.as_i64().map(|v| v as f64))
510        .ok_or_else(|| FilterError::TypeError(format!("Expected number, got {value:?}")).into())
511}
512
513fn combine_embeddings(embeddings: &OneOrMany<Embedding>) -> Result<Vec<f32>, VectorStoreError> {
514    match embeddings {
515        OneOrMany::One(embedding) => Ok(embedding.vec.to_vec()),
516        OneOrMany::Many(list) => {
517            let Some(first) = list.first() else {
518                return Err(VectorStoreError::EmbeddingError(
519                    EmbeddingError::EmbedFailure("no embeddings".into()),
520                ));
521            };
522
523            let dim = first.vec.len();
524            let mut sum = vec![0.0; dim];
525            for embedding in list {
526                if embedding.vec.len() != dim {
527                    return Err(VectorStoreError::EmbeddingError(
528                        EmbeddingError::EmbedFailure("inconsistent embedding dimensions".into()),
529                    ));
530                }
531                for (i, value) in embedding.vec.iter().enumerate() {
532                    sum[i] += value;
533                }
534            }
535
536            let count = list.len() as f32;
537            for value in &mut sum {
538                *value /= count;
539            }
540
541            Ok(sum)
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_core::embeddings::Embedding;
    use autoagents_core::one_or_many::OneOrMany;
    use autoagents_core::vector_store::request::{Filter, SearchFilter};
    use qdrant_client::qdrant::r#match::MatchValue;
    use std::sync::Arc;

    /// Payload representation returned by Qdrant for a stored point.
    type PayloadMap = HashMap<String, qdrant_client::qdrant::Value>;

    /// Builds an embedding for `document` with the given components.
    fn embedding(document: &str, components: &[f32]) -> Embedding {
        Embedding {
            document: document.to_string(),
            vec: Arc::from(components.to_vec()),
        }
    }

    #[test]
    fn test_stable_point_id_deterministic() {
        let first = QdrantVectorStore::stable_point_id("doc:1");
        let again = QdrantVectorStore::stable_point_id("doc:1");
        let other = QdrantVectorStore::stable_point_id("doc:2");
        assert_eq!(first, again);
        assert_ne!(first, other);
    }

    #[test]
    fn test_payload_encode_decode() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct TestDoc {
            name: String,
        }

        let doc = PreparedDocument {
            id: "doc-1".to_string(),
            raw: serde_json::json!({"name":"alpha"}),
            embeddings: OneOrMany::One(embedding("alpha", &[0.1, 0.2])),
        };

        let payload_map: PayloadMap = QdrantVectorStore::payload_for(&doc).unwrap().into();
        assert_eq!(
            QdrantVectorStore::decode_id(&payload_map).as_deref(),
            Some("doc-1")
        );

        let decoded: Option<TestDoc> = QdrantVectorStore::decode_raw(&payload_map).unwrap();
        assert_eq!(decoded.map(|doc| doc.name).as_deref(), Some("alpha"));
    }

    #[test]
    fn test_named_dimensions() {
        let vectors = HashMap::from([
            ("a".to_string(), vec![0.1_f32, 0.2_f32]),
            ("b".to_string(), vec![1.0_f32]),
        ]);
        let dims = QdrantVectorStore::named_dimensions(&vectors);
        assert_eq!(dims.get("a"), Some(&2));
        assert_eq!(dims.get("b"), Some(&1));
    }

    #[test]
    fn test_number_to_f64() {
        assert_eq!(number_to_f64(&serde_json::json!(1)).unwrap(), 1.0);
        assert_eq!(number_to_f64(&serde_json::json!(1.5)).unwrap(), 1.5);
        assert!(number_to_f64(&serde_json::json!("x")).is_err());
    }

    #[test]
    fn test_value_to_match_value() {
        let keyword = value_to_match_value(serde_json::json!("a")).unwrap();
        assert!(
            matches!(&keyword, MatchValue::Keyword(text) if text == "a"),
            "expected keyword"
        );

        let boolean = value_to_match_value(serde_json::json!(true)).unwrap();
        assert!(
            matches!(boolean, MatchValue::Boolean(true)),
            "expected boolean"
        );
    }

    #[test]
    fn test_value_to_match_value_numbers_and_errors() {
        let integer = value_to_match_value(serde_json::json!(42)).unwrap();
        assert!(
            matches!(integer, MatchValue::Integer(42)),
            "expected integer"
        );

        let float = value_to_match_value(serde_json::json!(1.5)).unwrap();
        assert!(
            matches!(&float, MatchValue::Keyword(text) if text == "1.5"),
            "expected keyword"
        );

        assert!(value_to_match_value(serde_json::json!([1, 2, 3])).is_err());
    }

    #[test]
    fn test_to_qdrant_filter_lt() {
        let qdrant = to_qdrant_filter(Filter::Lt("num".to_string(), serde_json::json!(10))).unwrap();
        assert_eq!(qdrant.must.len(), 1);
    }

    #[test]
    fn test_to_qdrant_filter_and_or() {
        let conjunction = Filter::Eq("field".to_string(), serde_json::json!("x"))
            .and(Filter::Gt("num".to_string(), serde_json::json!(2)));
        assert_eq!(to_qdrant_filter(conjunction).unwrap().must.len(), 2);

        let disjunction = Filter::Eq("field".to_string(), serde_json::json!("x"))
            .or(Filter::Lt("num".to_string(), serde_json::json!(10)));
        assert_eq!(to_qdrant_filter(disjunction).unwrap().should.len(), 2);
    }

    #[test]
    fn test_decode_helpers_missing_fields() {
        let payload = PayloadMap::new();
        assert!(QdrantVectorStore::decode_id(&payload).is_none());
        let raw: Option<serde_json::Value> = QdrantVectorStore::decode_raw(&payload).unwrap();
        assert!(raw.is_none());
    }

    #[test]
    fn test_to_qdrant_filter_eq_and_gt() {
        let eq = to_qdrant_filter(Filter::Eq("tag".to_string(), serde_json::json!("alpha"))).unwrap();
        assert_eq!(eq.must.len(), 1);

        let gt = to_qdrant_filter(Filter::Gt("score".to_string(), serde_json::json!(1.5))).unwrap();
        assert_eq!(gt.must.len(), 1);
    }

    #[test]
    fn test_combine_embeddings() {
        let one = OneOrMany::One(embedding("doc", &[1.0, 2.0]));
        assert_eq!(combine_embeddings(&one).unwrap(), vec![1.0, 2.0]);

        let many = OneOrMany::Many(vec![
            embedding("a", &[1.0, 3.0]),
            embedding("b", &[3.0, 5.0]),
        ]);
        assert_eq!(combine_embeddings(&many).unwrap(), vec![2.0, 4.0]);
    }

    #[test]
    fn test_combine_embeddings_dimension_mismatch() {
        let many = OneOrMany::Many(vec![
            embedding("a", &[1.0, 2.0]),
            embedding("b", &[1.0]),
        ]);
        let err = combine_embeddings(&many).unwrap_err();
        assert!(
            err.to_string()
                .contains("inconsistent embedding dimensions")
        );
    }
}