1use std::collections::HashMap;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use qdrant_client::Payload;
21use qdrant_client::Qdrant;
22use qdrant_client::qdrant::{
23 CountPointsBuilder, CreateCollectionBuilder, DeletePointsBuilder, Distance, FieldType,
24 PointStruct, PointsIdsList, ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder,
25 VectorParamsBuilder, points_selector::PointsSelectorOneOf,
26};
27use serde_json::Value;
28use sha2::{Digest, Sha256};
29use uuid::Uuid;
30
31use entelix_core::context::ExecutionContext;
32use entelix_core::error::{Error, Result};
33use entelix_memory::{Document, Namespace, VectorFilter, VectorStore};
34
35use crate::error::{QdrantStoreError, QdrantStoreResult};
36use crate::filter::{self, CONTENT_KEY, DOC_ID_KEY, METADATA_KEY, NAMESPACE_KEY};
37
/// Distance function used to compare vectors in a collection.
///
/// The value is converted into the Qdrant client's `Distance` when the
/// collection's vector parameters are built.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Cosine similarity. This is the default metric.
    #[default]
    Cosine,
    /// Dot product.
    Dot,
    /// Euclidean distance.
    Euclidean,
}
53
54impl From<DistanceMetric> for Distance {
55 fn from(m: DistanceMetric) -> Self {
56 match m {
57 DistanceMetric::Cosine => Self::Cosine,
58 DistanceMetric::Dot => Self::Dot,
59 DistanceMetric::Euclidean => Self::Euclid,
60 }
61 }
62}
63
64#[derive(Clone)]
68pub struct QdrantVectorStore {
69 client: Arc<Qdrant>,
70 collection: Arc<str>,
71 dimension: usize,
72}
73
74impl std::fmt::Debug for QdrantVectorStore {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("QdrantVectorStore")
77 .field("collection", &self.collection)
78 .field("dimension", &self.dimension)
79 .finish_non_exhaustive()
80 }
81}
82
83impl QdrantVectorStore {
84 pub fn builder(collection: impl Into<String>, dimension: usize) -> QdrantVectorStoreBuilder {
86 QdrantVectorStoreBuilder::new(collection, dimension)
87 }
88
89 fn point_id(namespace_key: &str, doc_id: &str) -> qdrant_client::qdrant::PointId {
94 let mut hasher = Sha256::new();
95 hasher.update(namespace_key.as_bytes());
96 hasher.update(b":");
97 hasher.update(doc_id.as_bytes());
98 let digest = hasher.finalize();
99 let mut bytes = [0u8; 16];
100 bytes.copy_from_slice(&digest[..16]);
101 Uuid::from_bytes(bytes).to_string().into()
102 }
103
104 fn build_payload(namespace_key: &str, doc_id: &str, document: &Document) -> Payload {
105 let mut map = serde_json::Map::with_capacity(4);
106 map.insert(
107 NAMESPACE_KEY.into(),
108 Value::String(namespace_key.to_owned()),
109 );
110 map.insert(DOC_ID_KEY.into(), Value::String(doc_id.to_owned()));
111 map.insert(CONTENT_KEY.into(), Value::String(document.content.clone()));
112 map.insert(METADATA_KEY.into(), document.metadata.clone());
117 Payload::try_from(Value::Object(map))
118 .expect("payload is a JSON object — Payload::try_from infallible on Object")
119 }
120
121 fn point_to_document(point: qdrant_client::qdrant::ScoredPoint) -> Document {
122 let (doc_id, content, metadata) = extract_payload_fields(&point.payload);
123 Document {
124 doc_id,
125 content,
126 metadata,
127 score: Some(point.score),
128 }
129 }
130
131 fn retrieved_to_document(point: qdrant_client::qdrant::RetrievedPoint) -> Document {
132 let (doc_id, content, metadata) = extract_payload_fields(&point.payload);
133 Document {
134 doc_id,
135 content,
136 metadata,
137 score: None,
138 }
139 }
140}
141
142fn extract_payload_fields(
143 payload: &HashMap<String, qdrant_client::qdrant::Value>,
144) -> (Option<String>, String, Value) {
145 let doc_id = payload
146 .get(DOC_ID_KEY)
147 .and_then(|v| v.as_str().map(ToOwned::to_owned));
148 let content = payload
149 .get(CONTENT_KEY)
150 .and_then(|v| v.as_str().map(ToOwned::to_owned))
151 .unwrap_or_default();
152 let metadata = payload
153 .get(METADATA_KEY)
154 .map_or(Value::Null, qdrant_value_to_json);
155 (doc_id, content, metadata)
156}
157
158fn qdrant_value_to_json(v: &qdrant_client::qdrant::Value) -> Value {
164 match &v.kind {
165 Some(qdrant_client::qdrant::value::Kind::NullValue(_)) | None => Value::Null,
166 Some(qdrant_client::qdrant::value::Kind::DoubleValue(d)) => {
167 serde_json::Number::from_f64(*d).map_or(Value::Null, Value::Number)
168 }
169 Some(qdrant_client::qdrant::value::Kind::IntegerValue(i)) => Value::Number((*i).into()),
170 Some(qdrant_client::qdrant::value::Kind::StringValue(s)) => Value::String(s.clone()),
171 Some(qdrant_client::qdrant::value::Kind::BoolValue(b)) => Value::Bool(*b),
172 Some(qdrant_client::qdrant::value::Kind::ListValue(list)) => {
173 Value::Array(list.values.iter().map(qdrant_value_to_json).collect())
174 }
175 Some(qdrant_client::qdrant::value::Kind::StructValue(s)) => Value::Object(
176 s.fields
177 .iter()
178 .map(|(k, v)| (k.clone(), qdrant_value_to_json(v)))
179 .collect(),
180 ),
181 }
182}
183
184#[must_use]
186pub struct QdrantVectorStoreBuilder {
187 collection: String,
188 dimension: usize,
189 distance: DistanceMetric,
190 url: String,
191 api_key: Option<String>,
192 timeout: Option<std::time::Duration>,
193 skip_create_collection: bool,
194 on_disk: Option<bool>,
195}
196
197impl QdrantVectorStoreBuilder {
198 fn new(collection: impl Into<String>, dimension: usize) -> Self {
199 Self {
200 collection: collection.into(),
201 dimension,
202 distance: DistanceMetric::default(),
203 url: "http://localhost:6334".into(),
204 api_key: None,
205 timeout: None,
206 skip_create_collection: false,
207 on_disk: None,
208 }
209 }
210
211 pub fn with_url(mut self, url: impl Into<String>) -> Self {
214 self.url = url.into();
215 self
216 }
217
218 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
220 self.api_key = Some(api_key.into());
221 self
222 }
223
224 pub const fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
227 self.timeout = Some(timeout);
228 self
229 }
230
231 pub const fn with_distance(mut self, distance: DistanceMetric) -> Self {
234 self.distance = distance;
235 self
236 }
237
238 pub const fn with_existing_collection(mut self) -> Self {
243 self.skip_create_collection = true;
244 self
245 }
246
247 pub const fn with_on_disk(mut self, on_disk: bool) -> Self {
250 self.on_disk = Some(on_disk);
251 self
252 }
253
254 pub async fn build(self) -> QdrantStoreResult<QdrantVectorStore> {
259 let mut config = qdrant_client::config::QdrantConfig::from_url(&self.url);
260 if let Some(api_key) = self.api_key {
261 config.api_key = Some(api_key);
262 }
263 if let Some(timeout) = self.timeout {
264 config.timeout = timeout;
265 }
266 let client = Qdrant::new(config)?;
267
268 if !self.skip_create_collection {
269 let exists = client
272 .collection_exists(&self.collection)
273 .await
274 .unwrap_or(false);
275 if !exists {
276 let mut vector_params =
277 VectorParamsBuilder::new(self.dimension as u64, Distance::from(self.distance));
278 if let Some(on_disk) = self.on_disk {
279 vector_params = vector_params.on_disk(on_disk);
280 }
281 client
282 .create_collection(
283 CreateCollectionBuilder::new(&self.collection)
284 .vectors_config(vector_params),
285 )
286 .await?;
287
288 let _ = client
292 .create_field_index(
293 qdrant_client::qdrant::CreateFieldIndexCollectionBuilder::new(
294 &self.collection,
295 NAMESPACE_KEY,
296 FieldType::Keyword,
297 ),
298 )
299 .await?;
300 let _ = client
301 .create_field_index(
302 qdrant_client::qdrant::CreateFieldIndexCollectionBuilder::new(
303 &self.collection,
304 DOC_ID_KEY,
305 FieldType::Keyword,
306 ),
307 )
308 .await?;
309 }
310 }
311
312 Ok(QdrantVectorStore {
313 client: Arc::new(client),
314 collection: self.collection.into(),
315 dimension: self.dimension,
316 })
317 }
318}
319
320#[async_trait]
321impl VectorStore for QdrantVectorStore {
322 fn dimension(&self) -> usize {
323 self.dimension
324 }
325
326 async fn add(
327 &self,
328 ctx: &ExecutionContext,
329 ns: &Namespace,
330 document: Document,
331 vector: Vec<f32>,
332 ) -> Result<()> {
333 if ctx.is_cancelled() {
334 return Err(Error::Cancelled);
335 }
336 if vector.len() != self.dimension {
337 return Err(Error::invalid_request(format!(
338 "QdrantVectorStore: vector dimension {} does not match \
339 index dimension {}",
340 vector.len(),
341 self.dimension
342 )));
343 }
344 let ns_key = ns.render();
345 let doc_id = document
346 .doc_id
347 .clone()
348 .unwrap_or_else(|| Uuid::new_v4().to_string());
349 let stored_doc = Document {
350 doc_id: Some(doc_id.clone()),
351 ..document
352 };
353 let payload = Self::build_payload(&ns_key, &doc_id, &stored_doc);
354 let point = PointStruct::new(Self::point_id(&ns_key, &doc_id), vector, payload);
355 self.client
356 .upsert_points(UpsertPointsBuilder::new(&*self.collection, vec![point]).wait(true))
357 .await
358 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
359 Ok(())
360 }
361
362 async fn add_batch(
363 &self,
364 ctx: &ExecutionContext,
365 ns: &Namespace,
366 items: Vec<(Document, Vec<f32>)>,
367 ) -> Result<()> {
368 if ctx.is_cancelled() {
369 return Err(Error::Cancelled);
370 }
371 if items.is_empty() {
372 return Ok(());
373 }
374 let ns_key = ns.render();
375 let mut points = Vec::with_capacity(items.len());
376 for (mut document, vector) in items {
377 if vector.len() != self.dimension {
378 return Err(Error::invalid_request(format!(
379 "QdrantVectorStore: vector dimension {} does not match \
380 index dimension {}",
381 vector.len(),
382 self.dimension
383 )));
384 }
385 let doc_id = document
386 .doc_id
387 .clone()
388 .unwrap_or_else(|| Uuid::new_v4().to_string());
389 document.doc_id = Some(doc_id.clone());
390 let payload = Self::build_payload(&ns_key, &doc_id, &document);
391 points.push(PointStruct::new(
392 Self::point_id(&ns_key, &doc_id),
393 vector,
394 payload,
395 ));
396 }
397 self.client
398 .upsert_points(UpsertPointsBuilder::new(&*self.collection, points).wait(true))
399 .await
400 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
401 Ok(())
402 }
403
404 async fn search(
405 &self,
406 ctx: &ExecutionContext,
407 ns: &Namespace,
408 query_vector: &[f32],
409 top_k: usize,
410 ) -> Result<Vec<Document>> {
411 self.search_filtered(ctx, ns, query_vector, top_k, &VectorFilter::All)
412 .await
413 }
414
415 async fn search_filtered(
416 &self,
417 ctx: &ExecutionContext,
418 ns: &Namespace,
419 query_vector: &[f32],
420 top_k: usize,
421 filter: &VectorFilter,
422 ) -> Result<Vec<Document>> {
423 if ctx.is_cancelled() {
424 return Err(Error::Cancelled);
425 }
426 if query_vector.len() != self.dimension {
427 return Err(Error::invalid_request(format!(
428 "QdrantVectorStore: query dimension {} does not match \
429 index dimension {}",
430 query_vector.len(),
431 self.dimension
432 )));
433 }
434 let ns_key = ns.render();
435 let projected = filter::project(Some(filter), &ns_key).map_err(Error::from)?;
436
437 let resp = self
438 .client
439 .search_points(
440 SearchPointsBuilder::new(&*self.collection, query_vector.to_vec(), top_k as u64)
441 .filter(projected)
442 .with_payload(true),
443 )
444 .await
445 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
446 Ok(resp
447 .result
448 .into_iter()
449 .map(Self::point_to_document)
450 .collect())
451 }
452
453 async fn delete(&self, ctx: &ExecutionContext, ns: &Namespace, doc_id: &str) -> Result<()> {
454 if ctx.is_cancelled() {
455 return Err(Error::Cancelled);
456 }
457 let ns_key = ns.render();
458 let pid = Self::point_id(&ns_key, doc_id);
459 self.client
460 .delete_points(
461 DeletePointsBuilder::new(&*self.collection)
462 .points(PointsSelectorOneOf::Points(PointsIdsList {
463 ids: vec![pid],
464 }))
465 .wait(true),
466 )
467 .await
468 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
469 Ok(())
470 }
471
472 async fn update(
473 &self,
474 ctx: &ExecutionContext,
475 ns: &Namespace,
476 doc_id: &str,
477 document: Document,
478 vector: Vec<f32>,
479 ) -> Result<()> {
480 let stored = Document {
484 doc_id: Some(doc_id.to_owned()),
485 ..document
486 };
487 self.add(ctx, ns, stored, vector).await
488 }
489
490 async fn count(
491 &self,
492 ctx: &ExecutionContext,
493 ns: &Namespace,
494 filter: Option<&VectorFilter>,
495 ) -> Result<usize> {
496 if ctx.is_cancelled() {
497 return Err(Error::Cancelled);
498 }
499 let ns_key = ns.render();
500 let projected = filter::project(filter, &ns_key).map_err(Error::from)?;
501 let resp = self
502 .client
503 .count(
504 CountPointsBuilder::new(&*self.collection)
505 .filter(projected)
506 .exact(true),
507 )
508 .await
509 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
510 Ok(resp.result.map(|r| r.count as usize).unwrap_or(0))
511 }
512
513 async fn list(
514 &self,
515 ctx: &ExecutionContext,
516 ns: &Namespace,
517 filter: Option<&VectorFilter>,
518 limit: usize,
519 offset: usize,
520 ) -> Result<Vec<Document>> {
521 if ctx.is_cancelled() {
522 return Err(Error::Cancelled);
523 }
524 let ns_key = ns.render();
525 let projected = filter::project(filter, &ns_key).map_err(Error::from)?;
526 let resp = self
533 .client
534 .scroll(
535 ScrollPointsBuilder::new(&*self.collection)
536 .filter(projected)
537 .limit((limit + offset) as u32)
538 .with_payload(true),
539 )
540 .await
541 .map_err(|e| Error::from(QdrantStoreError::from(e)))?;
542 Ok(resp
543 .result
544 .into_iter()
545 .skip(offset)
546 .take(limit)
547 .map(Self::retrieved_to_document)
548 .collect())
549 }
550}