use crate::embedding::Embedder;
use crate::vectorstore::qdrant::Store;
use qdrant_client::qdrant::{CreateCollectionBuilder, Distance, Filter, VectorParamsBuilder};
use qdrant_client::Qdrant;
use std::error::Error;
use std::sync::Arc;
/// Step-by-step builder for a Qdrant-backed [`Store`].
///
/// `client`, `embedder` and `collection_name` start out unset and must be
/// provided before calling [`StoreBuilder::build`]; the remaining fields
/// carry defaults assigned in [`StoreBuilder::new`].
pub struct StoreBuilder {
    /// Qdrant client handed to the resulting store. Required by `build`.
    client: Option<Qdrant>,
    /// Embedder shared with the resulting store. Required by `build`, which
    /// also uses it to discover the vector dimension when it creates the
    /// collection.
    embedder: Option<Arc<dyn Embedder>>,
    /// Name of the Qdrant collection to use. Required by `build`.
    collection_name: Option<String>,
    /// Forwarded unchanged to the store's `content_field`.
    /// Defaults to `"page_content"`.
    content_field: String,
    /// Forwarded unchanged to the store's `metadata_field`.
    /// Defaults to `"metadata"`.
    metadata_field: String,
    /// When `true`, `build` deletes an already existing collection and
    /// creates it again. Defaults to `false`.
    recreate_collection: bool,
    /// Optional Qdrant filter forwarded unchanged to the store's
    /// `search_filter`.
    search_filter: Option<Filter>,
}
impl Default for StoreBuilder {
fn default() -> Self {
Self::new()
}
}
impl StoreBuilder {
pub fn new() -> Self {
StoreBuilder {
client: None,
embedder: None,
collection_name: None,
search_filter: None,
content_field: "page_content".to_string(),
metadata_field: "metadata".to_string(),
recreate_collection: false,
}
}
pub fn client(mut self, client: Qdrant) -> Self {
self.client = Some(client);
self
}
pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self {
self.embedder = Some(Arc::new(embedder));
self
}
pub fn collection_name(mut self, collection_name: &str) -> Self {
self.collection_name = Some(collection_name.to_string());
self
}
pub fn metadata_field(mut self, metadata_field: &str) -> Self {
self.metadata_field = metadata_field.to_string();
self
}
pub fn content_field(mut self, content_field: &str) -> Self {
self.content_field = content_field.to_string();
self
}
pub fn recreate_collection(mut self, recreate_collection: bool) -> Self {
self.recreate_collection = recreate_collection;
self
}
pub fn search_filter(mut self, search_filter: Filter) -> Self {
self.search_filter = Some(search_filter);
self
}
pub async fn build(mut self) -> Result<Store, Box<dyn Error>> {
let client = self.client.take().ok_or("'client' is required")?;
let embedder = self.embedder.take().ok_or("'embedder' is required")?;
let collection_name = self
.collection_name
.take()
.ok_or("'collection_name' is required")?;
let collection_exists = client.collection_exists(&collection_name).await?;
if collection_exists && self.recreate_collection {
client.delete_collection(&collection_name).await?;
}
if !collection_exists || self.recreate_collection {
let embeddings = embedder
.embed_query("Text to retrieve embeddings dimension")
.await?;
let embeddings_dimension = embeddings.len() as u64;
client
.create_collection(
CreateCollectionBuilder::new(&collection_name).vectors_config(
VectorParamsBuilder::new(embeddings_dimension, Distance::Cosine),
),
)
.await?;
}
Ok(Store {
client,
embedder,
collection_name,
search_filter: self.search_filter,
content_field: self.content_field,
metadata_field: self.metadata_field,
})
}
}