use async_trait::async_trait;
use qdrant_client::qdrant::{
DeletePointsBuilder, Filter, PointStruct, PointsIdsList, SearchPointsBuilder,
UpsertPointsBuilder,
};
use qdrant_client::Payload;
use serde_json::{json, Value};
use std::sync::Arc;
pub use qdrant_client::Qdrant;
#[deprecated(note = "use `Qdrant` instead")]
pub use qdrant_client::Qdrant as QdrantClient;
use crate::{
embedding::embedder_trait::Embedder,
schemas::Document,
vectorstore::{VecStoreOptions, VectorStore, VectorStoreError},
};
use uuid::Uuid;
/// Vector store backed by a single Qdrant collection.
///
/// Documents are embedded with `embedder` (unless the per-call options supply another
/// embedder) and stored as points whose payload keeps the page content under
/// `content_field` and the metadata under `metadata_field`.
pub struct Store {
    /// Qdrant client used for the upsert, search and delete requests.
    pub client: Qdrant,
    /// Default embedder; an `embedder` set in the per-call options takes precedence.
    pub embedder: Arc<dyn Embedder>,
    /// Name of the Qdrant collection that points are written to, searched in and deleted from.
    pub collection_name: String,
    /// Payload key under which a document's page content is stored.
    pub content_field: String,
    /// Payload key under which a document's metadata is stored.
    pub metadata_field: String,
    /// Optional Qdrant filter applied to every similarity search. Use this instead of the
    /// per-call `filters` option, which `similarity_search` rejects for this store.
    pub search_filter: Option<Filter>,
}
/// Per-call options for [`Store`]; the generic parameter makes `filters` a `serde_json::Value`.
type QdrantOptions = VecStoreOptions<Value>;
#[async_trait]
impl VectorStore for Store {
    type Options = QdrantOptions;

    /// Embeds `docs` and upserts them as points into `self.collection_name`.
    ///
    /// Each document gets a freshly generated UUID v4 as its point id. Every point's payload
    /// holds the page content under `self.content_field` and the metadata under
    /// `self.metadata_field`. If `opt.filters` is a JSON object, its entries are merged into
    /// each payload as well, overwriting keys with the same name. Non-object filters are
    /// ignored.
    ///
    /// # Returns
    /// The ids of the stored points, in the same order as `docs`. An empty `docs` slice
    /// returns an empty list without contacting the embedder or Qdrant.
    ///
    /// # Errors
    /// Fails if embedding fails, if the embedder returns a different number of vectors than
    /// there are documents, if a payload cannot be converted, or if the upsert request fails.
    async fn add_documents(
        &self,
        docs: &[Document],
        opt: &QdrantOptions,
    ) -> Result<Vec<String>, VectorStoreError> {
        // Nothing to embed or store; mirrors the early return in `delete`.
        if docs.is_empty() {
            return Ok(Vec::new());
        }

        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();
        let vectors = embedder.embed_documents(&texts).await?;
        // Zipping would silently drop documents on a count mismatch while still reporting
        // them as stored, so reject it explicitly.
        if vectors.len() != docs.len() {
            return Err(VectorStoreError::from(format!(
                "embedder returned {} vectors for {} documents",
                vectors.len(),
                docs.len()
            )));
        }

        // Generate the ids exactly once. A lazy `map` iterator would call `Uuid::new_v4`
        // again on every traversal, so the ids returned to the caller would not match the
        // ids actually written to Qdrant.
        let ids: Vec<String> = docs.iter().map(|_| Uuid::new_v4().to_string()).collect();

        // Extra payload entries from `opt.filters`; only JSON objects are merged.
        // Resolved once here instead of cloning the whole option per document.
        let extra_payload = match &opt.filters {
            Some(Value::Object(map)) => Some(map),
            _ => None,
        };

        let mut points: Vec<PointStruct> = Vec::with_capacity(docs.len());
        for ((id, doc), vector) in ids.iter().zip(docs).zip(vectors) {
            let mut payload = json!({
                &self.content_field: doc.page_content,
                &self.metadata_field: doc.metadata,
            });
            if let (Value::Object(payload_map), Some(extra)) = (&mut payload, extra_payload) {
                payload_map.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
            }
            let payload =
                Payload::try_from(payload).map_err(|e| VectorStoreError::from(e.to_string()))?;
            // Qdrant stores f32 vectors; the embedder produces wider floats.
            let vector: Vec<f32> = vector.into_iter().map(|f| f as f32).collect();
            points.push(PointStruct::new(id.clone(), vector, payload));
        }

        // `wait(true)` makes the request return only after the upsert has been applied.
        self.client
            .upsert_points(UpsertPointsBuilder::new(&self.collection_name, points).wait(true))
            .await
            .map_err(|e| VectorStoreError::from(e.to_string()))?;

        Ok(ids)
    }

    /// Returns up to `limit` documents whose embeddings are most similar to `query`.
    ///
    /// The query is embedded with `opt.embedder` if set, otherwise with `self.embedder`.
    /// `opt.score_threshold` and `self.search_filter` are forwarded to Qdrant when present.
    ///
    /// # Errors
    /// - `InvalidParameter` if `opt.name_space` or `opt.filters` is set. Qdrant has no
    ///   namespaces here, and JSON filters cannot be translated; use `search_filter` instead.
    /// - Embedding or search-request failures.
    /// - A returned point whose payload lacks the content or metadata field, or whose
    ///   metadata cannot be deserialized.
    async fn similarity_search(
        &self,
        query: &str,
        limit: usize,
        opt: &QdrantOptions,
    ) -> Result<Vec<Document>, VectorStoreError> {
        if opt.name_space.is_some() {
            return Err(VectorStoreError::InvalidParameter(
                "Qdrant doesn't support namespaces".to_string(),
            ));
        }
        if opt.filters.is_some() {
            return Err(VectorStoreError::InvalidParameter(
                "'qdrant_client' doesn't support 'serde_json::Value' filters. \
                 Use `search_filter` when constructing VectorStore instead"
                    .to_string(),
            ));
        }

        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let query_vector: Vec<f32> = embedder
            .embed_query(query)
            .await?
            .into_iter()
            .map(|f| f as f32)
            .collect();

        let mut operation =
            SearchPointsBuilder::new(&self.collection_name, query_vector, limit as u64)
                .with_payload(true);
        if let Some(score_threshold) = opt.score_threshold {
            operation = operation.score_threshold(score_threshold);
        }
        if let Some(filter) = &self.search_filter {
            operation = operation.filter(filter.clone());
        }

        let results = self
            .client
            .search_points(operation)
            .await
            .map_err(|e| VectorStoreError::from(e.to_string()))?;

        results
            .result
            .into_iter()
            .map(|scored_point| -> Result<Document, VectorStoreError> {
                let score = scored_point.score as f64;
                let payload = scored_point.payload;

                let content = payload.get(&self.content_field).ok_or_else(|| {
                    VectorStoreError::from(format!(
                        "Qdrant point payload is missing the content field `{}`",
                        self.content_field
                    ))
                })?;
                // Convert through JSON so string content comes back as the raw text. The
                // qdrant `Value` `Display` output renders strings with surrounding quotes.
                let page_content = match content.clone().into_json() {
                    Value::String(text) => text,
                    other => other.to_string(),
                };

                let raw_metadata = payload.get(&self.metadata_field).ok_or_else(|| {
                    VectorStoreError::from(format!(
                        "Qdrant point payload is missing the metadata field `{}`",
                        self.metadata_field
                    ))
                })?;
                let metadata = serde_json::from_value(raw_metadata.clone().into_json())
                    .map_err(|e| VectorStoreError::from(e.to_string()))?;

                Ok(Document {
                    page_content,
                    metadata,
                    score,
                })
            })
            .collect()
    }

    /// Deletes the points with the given `ids` from `self.collection_name`.
    ///
    /// An empty `ids` slice is a no-op and sends no request. Each id is converted with
    /// `PointId::from(&str)`; the ids returned by `add_documents` are UUID strings.
    ///
    /// # Errors
    /// Returns an error if the delete request fails.
    async fn delete(&self, ids: &[String], _opt: &QdrantOptions) -> Result<(), VectorStoreError> {
        if ids.is_empty() {
            return Ok(());
        }

        let point_ids: Vec<qdrant_client::qdrant::PointId> = ids
            .iter()
            .map(|s| qdrant_client::qdrant::PointId::from(s.as_str()))
            .collect();

        self.client
            .delete_points(
                DeletePointsBuilder::new(&self.collection_name)
                    .points(PointsIdsList { ids: point_ids })
                    .wait(true),
            )
            .await
            .map_err(|e| VectorStoreError::from(e.to_string()))?;

        Ok(())
    }
}