1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use autoagents_core::embeddings::{Embed, Embedding, EmbeddingError, SharedEmbeddingProvider};
5use autoagents_core::one_or_many::OneOrMany;
6use autoagents_core::vector_store::request::{Filter, FilterError};
7use autoagents_core::vector_store::{
8 DEFAULT_VECTOR_NAME, NamedVectorDocument, PreparedDocument, VectorSearchRequest,
9 VectorStoreError, VectorStoreIndex, embed_documents, embed_named_documents, normalize_id,
10};
11use qdrant_client::Payload;
12use qdrant_client::Qdrant;
13use qdrant_client::qdrant::{
14 Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter as QdrantFilter,
15 PointStruct, Range, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
16 VectorsConfigBuilder, condition, with_payload_selector,
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21#[derive(Clone)]
22pub struct QdrantVectorStore {
23 client: Qdrant,
24 collection_name: String,
25 provider: SharedEmbeddingProvider,
26}
27
28impl QdrantVectorStore {
29 fn stable_point_id(source_id: &str) -> String {
30 Uuid::new_v5(&Uuid::NAMESPACE_URL, source_id.as_bytes()).to_string()
33 }
34
35 pub fn new(
36 provider: SharedEmbeddingProvider,
37 url: impl Into<String>,
38 collection_name: impl Into<String>,
39 ) -> Result<Self, VectorStoreError> {
40 Self::with_api_key(provider, url, collection_name, None)
41 }
42
43 pub fn with_api_key(
44 provider: SharedEmbeddingProvider,
45 url: impl Into<String>,
46 collection_name: impl Into<String>,
47 api_key: Option<String>,
48 ) -> Result<Self, VectorStoreError> {
49 let url = url.into();
50 let builder = Qdrant::from_url(&url);
51 let client = if let Some(key) = api_key {
52 builder
53 .api_key(key)
54 .build()
55 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?
56 } else {
57 builder
58 .build()
59 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?
60 };
61
62 Ok(Self {
63 client,
64 collection_name: collection_name.into(),
65 provider,
66 })
67 }
68
69 async fn ensure_collection(&self, dimension: u64) -> Result<(), VectorStoreError> {
70 let request = CreateCollectionBuilder::new(self.collection_name.clone())
71 .vectors_config(VectorParamsBuilder::new(dimension, Distance::Cosine))
72 .build();
73
74 let result = self.client.create_collection(request).await;
75 if let Err(err) = result {
76 let message = err.to_string();
78 if !message.contains("already exists") {
79 return Err(VectorStoreError::DatastoreError(Box::new(err)));
80 }
81 }
82
83 Ok(())
84 }
85
86 async fn ensure_named_collection(
87 &self,
88 dimensions: &HashMap<String, u64>,
89 ) -> Result<(), VectorStoreError> {
90 let request = Self::named_collection_request(&self.collection_name, dimensions);
91
92 let result = self.client.create_collection(request).await;
93 if let Err(err) = result {
94 let message = err.to_string();
95 if !message.contains("already exists") {
96 return Err(VectorStoreError::DatastoreError(Box::new(err)));
97 }
98 }
99
100 Ok(())
101 }
102
103 fn named_collection_request(
104 collection_name: &str,
105 dimensions: &HashMap<String, u64>,
106 ) -> qdrant_client::qdrant::CreateCollection {
107 let mut config = VectorsConfigBuilder::default();
108 for (name, dimension) in dimensions {
109 config.add_named_vector_params(
110 name.clone(),
111 VectorParamsBuilder::new(*dimension, Distance::Cosine),
112 );
113 }
114
115 CreateCollectionBuilder::new(collection_name.to_string())
116 .vectors_config(config)
117 .build()
118 }
119
120 fn payload_for(doc: &PreparedDocument) -> Result<Payload, VectorStoreError> {
121 let payload = serde_json::json!({
122 "raw": doc.raw,
123 "source_id": doc.id,
124 });
125
126 Payload::try_from(payload).map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))
127 }
128
129 fn decode_id(payload: &HashMap<String, qdrant_client::qdrant::Value>) -> Option<String> {
130 payload
131 .get("source_id")
132 .and_then(|value| serde_json::to_value(value).ok())
133 .and_then(|v| v.as_str().map(|id| id.to_string()))
134 }
135
136 fn decode_raw<T>(
137 payload: &HashMap<String, qdrant_client::qdrant::Value>,
138 ) -> Result<Option<T>, VectorStoreError>
139 where
140 T: for<'de> Deserialize<'de>,
141 {
142 if let Some(raw) = payload.get("raw") {
143 let value = serde_json::to_value(raw).map_err(VectorStoreError::JsonError)?;
144 let parsed = serde_json::from_value(value)?;
145 Ok(Some(parsed))
146 } else {
147 Ok(None)
148 }
149 }
150
151 pub async fn delete_documents_by_ids(
153 &self,
154 source_ids: &[String],
155 ) -> Result<(), VectorStoreError> {
156 if source_ids.is_empty() {
157 return Ok(());
158 }
159
160 let point_ids = source_ids
161 .iter()
162 .map(|source_id| Self::stable_point_id(source_id))
163 .collect::<Vec<_>>();
164
165 self.client
166 .delete_points(
167 DeletePointsBuilder::new(self.collection_name.clone())
168 .points(point_ids)
169 .wait(true),
170 )
171 .await
172 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
173
174 Ok(())
175 }
176
177 pub async fn delete_collection_if_exists(&self) -> Result<(), VectorStoreError> {
179 let exists = self
180 .client
181 .collection_exists(self.collection_name.clone())
182 .await
183 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
184 if !exists {
185 return Ok(());
186 }
187
188 self.client
189 .delete_collection(self.collection_name.clone())
190 .await
191 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
192
193 Ok(())
194 }
195
196 fn named_dimensions(vectors: &HashMap<String, Vec<f32>>) -> HashMap<String, u64> {
197 vectors
198 .iter()
199 .map(|(name, vector)| (name.clone(), vector.len() as u64))
200 .collect()
201 }
202}
203
204#[async_trait]
205impl VectorStoreIndex for QdrantVectorStore {
206 type Filter = Filter<serde_json::Value>;
207
208 async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<(), VectorStoreError>
209 where
210 T: Embed + Serialize + Send + Sync + Clone,
211 {
212 let docs: Vec<(String, T)> = documents
213 .into_iter()
214 .map(|doc| (normalize_id(None), doc))
215 .collect();
216 self.insert_documents_with_ids(docs).await
217 }
218
219 async fn insert_documents_with_ids<T>(
220 &self,
221 documents: Vec<(String, T)>,
222 ) -> Result<(), VectorStoreError>
223 where
224 T: Embed + Serialize + Send + Sync + Clone,
225 {
226 let normalized: Vec<(String, T)> = documents
227 .into_iter()
228 .map(|(id, doc)| (normalize_id(Some(id)), doc))
229 .collect();
230 let prepared = embed_documents(&self.provider, normalized).await?;
231 let Some(first) = prepared.first() else {
232 return Ok(());
233 };
234
235 let dim = first
236 .embeddings
237 .iter()
238 .next()
239 .map(|e| e.vec.len())
240 .unwrap_or(0);
241 self.ensure_collection(dim as u64).await?;
242
243 let mut points = Vec::new();
244 for doc in prepared {
245 let payload = Self::payload_for(&doc)?;
246 let vector = combine_embeddings(&doc.embeddings)?;
247
248 let point_id = Self::stable_point_id(&doc.id);
250
251 points.push(PointStruct::new(point_id, vector, payload.clone()));
252 }
253
254 let request = UpsertPointsBuilder::new(self.collection_name.clone(), points).build();
255 self.client
256 .upsert_points(request)
257 .await
258 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
259
260 Ok(())
261 }
262
263 async fn top_n<T>(
264 &self,
265 req: VectorSearchRequest<Self::Filter>,
266 ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
267 where
268 T: for<'de> Deserialize<'de> + Send + Sync,
269 {
270 let vectors = self
271 .provider
272 .embed(vec![req.query().to_string()])
273 .await
274 .map_err(EmbeddingError::Provider)?;
275
276 let Some(vector) = vectors.into_iter().next() else {
277 return Ok(Vec::new());
278 };
279
280 let mut search =
281 SearchPointsBuilder::new(self.collection_name.clone(), vector, req.samples())
282 .with_payload(with_payload_selector::SelectorOptions::Enable(true));
283
284 if let Some(vector_name) = req.query_vector_name()
285 && vector_name != DEFAULT_VECTOR_NAME
286 {
287 search = search.vector_name(vector_name.to_string());
288 }
289
290 if let Some(filter) = req.filter() {
291 search = search.filter(to_qdrant_filter(filter.clone())?);
292 }
293
294 if let Some(threshold) = req.threshold() {
295 search = search.score_threshold(threshold as f32);
296 }
297
298 let response = self
299 .client
300 .search_points(search)
301 .await
302 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
303
304 let mut results = Vec::new();
305 for point in response.result {
306 let id = Self::decode_id(&point.payload)
307 .or_else(|| point.id.map(|id| format!("{id:?}")))
308 .unwrap_or_default();
309
310 if let Some(raw) = Self::decode_raw::<T>(&point.payload)? {
311 results.push((point.score as f64, id, raw));
312 }
313 }
314
315 Ok(results)
316 }
317
318 async fn top_n_ids(
319 &self,
320 req: VectorSearchRequest<Self::Filter>,
321 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
322 let vectors = self
323 .provider
324 .embed(vec![req.query().to_string()])
325 .await
326 .map_err(EmbeddingError::Provider)?;
327
328 let Some(vector) = vectors.into_iter().next() else {
329 return Ok(Vec::new());
330 };
331
332 let mut search =
333 SearchPointsBuilder::new(self.collection_name.clone(), vector, req.samples())
334 .with_payload(with_payload_selector::SelectorOptions::Enable(true));
335
336 if let Some(vector_name) = req.query_vector_name()
337 && vector_name != DEFAULT_VECTOR_NAME
338 {
339 search = search.vector_name(vector_name.to_string());
340 }
341
342 if let Some(filter) = req.filter() {
343 search = search.filter(to_qdrant_filter(filter.clone())?);
344 }
345
346 if let Some(threshold) = req.threshold() {
347 search = search.score_threshold(threshold as f32);
348 }
349
350 let response = self
351 .client
352 .search_points(search)
353 .await
354 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
355
356 let mut results = Vec::new();
357 for point in response.result {
358 let id = Self::decode_id(&point.payload)
359 .or_else(|| point.id.map(|id| format!("{id:?}")))
360 .unwrap_or_default();
361 results.push((point.score as f64, id));
362 }
363
364 Ok(results)
365 }
366
367 async fn insert_documents_with_named_vectors<T>(
368 &self,
369 documents: Vec<NamedVectorDocument<T>>,
370 ) -> Result<(), VectorStoreError>
371 where
372 T: Serialize + Send + Sync + Clone,
373 {
374 let normalized = documents
375 .into_iter()
376 .map(|doc| NamedVectorDocument {
377 id: normalize_id(Some(doc.id)),
378 raw: doc.raw,
379 vectors: doc.vectors,
380 })
381 .collect::<Vec<_>>();
382
383 let prepared = embed_named_documents(&self.provider, normalized).await?;
384 let Some(first) = prepared.first() else {
385 return Ok(());
386 };
387
388 let dimensions = Self::named_dimensions(&first.vectors);
389 self.ensure_named_collection(&dimensions).await?;
390
391 let mut points = Vec::new();
392 for doc in prepared {
393 let source_id = doc.id.clone();
394 let payload = Payload::try_from(serde_json::json!({
395 "raw": doc.raw,
396 "source_id": source_id,
397 }))
398 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
399 let point_id = Self::stable_point_id(&source_id);
400 points.push(PointStruct::new(point_id, doc.vectors, payload));
401 }
402
403 let request = UpsertPointsBuilder::new(self.collection_name.clone(), points).build();
404 self.client
405 .upsert_points(request)
406 .await
407 .map_err(|err| VectorStoreError::DatastoreError(Box::new(err)))?;
408
409 Ok(())
410 }
411}
412
413fn to_qdrant_filter(filter: Filter<serde_json::Value>) -> Result<QdrantFilter, VectorStoreError> {
414 use Filter::*;
415
416 let empty = || QdrantFilter {
417 must: Vec::new(),
418 should: Vec::new(),
419 must_not: Vec::new(),
420 min_should: None,
421 };
422
423 match filter {
424 Eq(key, value) => {
425 let mut filter = empty();
426 filter
427 .must
428 .push(Condition::matches(key, value_to_match_value(value)?));
429 Ok(filter)
430 }
431 Gt(key, value) => {
432 let mut filter = empty();
433 filter.must.push(Condition::range(
434 key,
435 Range {
436 gt: Some(number_to_f64(&value)?),
437 gte: None,
438 lt: None,
439 lte: None,
440 },
441 ));
442 Ok(filter)
443 }
444 Lt(key, value) => {
445 let mut filter = empty();
446 filter.must.push(Condition::range(
447 key,
448 Range {
449 lt: Some(number_to_f64(&value)?),
450 lte: None,
451 gt: None,
452 gte: None,
453 },
454 ));
455 Ok(filter)
456 }
457 And(lhs, rhs) => {
458 let mut left = to_qdrant_filter(*lhs)?;
459 let right = to_qdrant_filter(*rhs)?;
460
461 left.must.extend(right.must);
462 left.must.extend(right.should);
463 Ok(left)
464 }
465 Or(lhs, rhs) => {
466 let left = to_qdrant_filter(*lhs)?;
467 let right = to_qdrant_filter(*rhs)?;
468
469 Ok(QdrantFilter {
470 should: vec![
471 Condition {
472 condition_one_of: Some(condition::ConditionOneOf::Filter(left)),
473 },
474 Condition {
475 condition_one_of: Some(condition::ConditionOneOf::Filter(right)),
476 },
477 ],
478 must: Vec::new(),
479 must_not: Vec::new(),
480 min_should: None,
481 })
482 }
483 }
484}
485
486fn value_to_match_value(
487 value: serde_json::Value,
488) -> Result<qdrant_client::qdrant::r#match::MatchValue, VectorStoreError> {
489 use qdrant_client::qdrant::r#match::MatchValue;
490 match value {
491 serde_json::Value::String(s) => Ok(MatchValue::Keyword(s)),
492 serde_json::Value::Number(num) => {
493 if let Some(i) = num.as_i64() {
494 Ok(MatchValue::Integer(i))
495 } else if let Some(f) = num.as_f64() {
496 Ok(MatchValue::Keyword(f.to_string()))
497 } else {
498 Err(FilterError::TypeError("Unsupported number".into()).into())
499 }
500 }
501 serde_json::Value::Bool(b) => Ok(MatchValue::Boolean(b)),
502 other => Err(FilterError::TypeError(format!("Unsupported filter value {other:?}")).into()),
503 }
504}
505
506fn number_to_f64(value: &serde_json::Value) -> Result<f64, VectorStoreError> {
507 value
508 .as_f64()
509 .or_else(|| value.as_i64().map(|v| v as f64))
510 .ok_or_else(|| FilterError::TypeError(format!("Expected number, got {value:?}")).into())
511}
512
513fn combine_embeddings(embeddings: &OneOrMany<Embedding>) -> Result<Vec<f32>, VectorStoreError> {
514 match embeddings {
515 OneOrMany::One(embedding) => Ok(embedding.vec.to_vec()),
516 OneOrMany::Many(list) => {
517 let Some(first) = list.first() else {
518 return Err(VectorStoreError::EmbeddingError(
519 EmbeddingError::EmbedFailure("no embeddings".into()),
520 ));
521 };
522
523 let dim = first.vec.len();
524 let mut sum = vec![0.0; dim];
525 for embedding in list {
526 if embedding.vec.len() != dim {
527 return Err(VectorStoreError::EmbeddingError(
528 EmbeddingError::EmbedFailure("inconsistent embedding dimensions".into()),
529 ));
530 }
531 for (i, value) in embedding.vec.iter().enumerate() {
532 sum[i] += value;
533 }
534 }
535
536 let count = list.len() as f32;
537 for value in &mut sum {
538 *value /= count;
539 }
540
541 Ok(sum)
542 }
543 }
544}
545
#[cfg(test)]
mod tests {
    // `Embedding`, `OneOrMany`, `Filter`, `HashMap`, `PreparedDocument` and `Uuid` are
    // already in scope through the parent module's imports; only test-only names follow.
    use super::*;
    use autoagents_core::vector_store::request::SearchFilter;
    use qdrant_client::qdrant::r#match::MatchValue;
    use std::sync::Arc;

    /// Builds an [`Embedding`] for `document` from the given components.
    fn embedding(document: &str, components: &[f32]) -> Embedding {
        Embedding {
            document: document.to_string(),
            vec: Arc::from(components.to_vec()),
        }
    }

    #[test]
    fn test_stable_point_id_deterministic() {
        let id1 = QdrantVectorStore::stable_point_id("doc:1");
        let id2 = QdrantVectorStore::stable_point_id("doc:1");
        let id3 = QdrantVectorStore::stable_point_id("doc:2");
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_stable_point_id_is_uuid() {
        let id = QdrantVectorStore::stable_point_id("doc:1");
        assert!(Uuid::parse_str(&id).is_ok());
    }

    #[test]
    fn test_payload_encode_decode() {
        #[derive(Debug, Clone, serde::Deserialize)]
        struct TestDoc {
            name: String,
        }

        let doc = PreparedDocument {
            id: "doc-1".to_string(),
            raw: serde_json::json!({"name":"alpha"}),
            embeddings: OneOrMany::One(embedding("alpha", &[0.1, 0.2])),
        };

        let payload = QdrantVectorStore::payload_for(&doc).unwrap();
        let payload_map: HashMap<String, qdrant_client::qdrant::Value> = payload.clone().into();
        let decoded_id = QdrantVectorStore::decode_id(&payload_map).unwrap();
        assert_eq!(decoded_id, "doc-1");

        let decoded: Option<TestDoc> = QdrantVectorStore::decode_raw(&payload_map).unwrap();
        assert_eq!(decoded.unwrap().name, "alpha");
    }

    #[test]
    fn test_named_dimensions() {
        let vectors = HashMap::from([
            ("a".to_string(), vec![0.1_f32, 0.2_f32]),
            ("b".to_string(), vec![1.0_f32]),
        ]);
        let dims = QdrantVectorStore::named_dimensions(&vectors);
        assert_eq!(dims.get("a"), Some(&2));
        assert_eq!(dims.get("b"), Some(&1));
    }

    #[test]
    fn test_number_to_f64() {
        assert_eq!(number_to_f64(&serde_json::json!(1)).unwrap(), 1.0);
        assert_eq!(number_to_f64(&serde_json::json!(1.5)).unwrap(), 1.5);
        assert!(number_to_f64(&serde_json::json!("x")).is_err());
    }

    #[test]
    fn test_value_to_match_value() {
        let MatchValue::Keyword(val) = value_to_match_value(serde_json::json!("a")).unwrap() else {
            panic!("expected keyword");
        };
        assert_eq!(val, "a");

        let MatchValue::Boolean(val) = value_to_match_value(serde_json::json!(true)).unwrap() else {
            panic!("expected boolean");
        };
        assert!(val);
    }

    #[test]
    fn test_value_to_match_value_numbers_and_errors() {
        let MatchValue::Integer(val) = value_to_match_value(serde_json::json!(42)).unwrap() else {
            panic!("expected integer");
        };
        assert_eq!(val, 42);

        let MatchValue::Keyword(val) = value_to_match_value(serde_json::json!(1.5)).unwrap() else {
            panic!("expected keyword");
        };
        assert_eq!(val, "1.5");

        assert!(value_to_match_value(serde_json::json!([1, 2, 3])).is_err());
    }

    #[test]
    fn test_to_qdrant_filter_lt() {
        let filter = Filter::Lt("num".to_string(), serde_json::json!(10));
        let qdrant = to_qdrant_filter(filter).unwrap();
        assert_eq!(qdrant.must.len(), 1);
    }

    #[test]
    fn test_to_qdrant_filter_and_or() {
        let filter = Filter::Eq("field".to_string(), serde_json::json!("x"))
            .and(Filter::Gt("num".to_string(), serde_json::json!(2)));
        let qdrant = to_qdrant_filter(filter).unwrap();
        assert_eq!(qdrant.must.len(), 2);

        let filter = Filter::Eq("field".to_string(), serde_json::json!("x"))
            .or(Filter::Lt("num".to_string(), serde_json::json!(10)));
        let qdrant = to_qdrant_filter(filter).unwrap();
        assert_eq!(qdrant.should.len(), 2);
    }

    #[test]
    fn test_decode_helpers_missing_fields() {
        let payload: HashMap<String, qdrant_client::qdrant::Value> = HashMap::new();
        assert!(QdrantVectorStore::decode_id(&payload).is_none());
        let raw: Option<serde_json::Value> = QdrantVectorStore::decode_raw(&payload).unwrap();
        assert!(raw.is_none());
    }

    #[test]
    fn test_to_qdrant_filter_eq_and_gt() {
        let filter = Filter::Eq("tag".to_string(), serde_json::json!("alpha"));
        let qdrant = to_qdrant_filter(filter).unwrap();
        assert_eq!(qdrant.must.len(), 1);

        let filter = Filter::Gt("score".to_string(), serde_json::json!(1.5));
        let qdrant = to_qdrant_filter(filter).unwrap();
        assert_eq!(qdrant.must.len(), 1);
    }

    #[test]
    fn test_combine_embeddings() {
        let one = OneOrMany::One(embedding("doc", &[1.0, 2.0]));
        assert_eq!(combine_embeddings(&one).unwrap(), vec![1.0, 2.0]);

        let many = OneOrMany::Many(vec![embedding("a", &[1.0, 3.0]), embedding("b", &[3.0, 5.0])]);
        assert_eq!(combine_embeddings(&many).unwrap(), vec![2.0, 4.0]);
    }

    #[test]
    fn test_combine_embeddings_dimension_mismatch() {
        let many = OneOrMany::Many(vec![embedding("a", &[1.0, 2.0]), embedding("b", &[1.0])]);
        let err = combine_embeddings(&many).unwrap_err();
        assert!(
            err.to_string()
                .contains("inconsistent embedding dimensions")
        );
    }

    #[test]
    fn test_combine_embeddings_empty_many() {
        let err = combine_embeddings(&OneOrMany::Many(Vec::new())).unwrap_err();
        assert!(err.to_string().contains("no embeddings"));
    }
}