use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use qdrant_client::qdrant::{
Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, GetPointsBuilder,
PointStruct, UpsertPointsBuilder, Value, VectorParamsBuilder,
};
use qdrant_client::Qdrant;
use std::collections::HashMap;
/// Vector store backed by a single Qdrant collection.
///
/// Every read, write and delete issued through this store targets `collection_name`
/// using the wrapped `client`.
pub struct QdrantVectorStore {
    /// Name of the Qdrant collection all operations are issued against.
    collection_name: String,
    /// Client connected to the Qdrant server.
    client: Qdrant,
}
impl QdrantVectorStore {
pub async fn new(url: &str, collection_name: &str, vector_size: u64) -> Result<Self> {
let client = Qdrant::from_url(url)
.build()
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
if !client
.collection_exists(collection_name)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?
{
client
.create_collection(
CreateCollectionBuilder::new(collection_name)
.vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)),
)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
}
Ok(Self {
client,
collection_name: collection_name.into(),
})
}
}
#[async_trait]
impl VectorStore for QdrantVectorStore {
async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
let mut points = Vec::with_capacity(nodes.len());
for node in nodes {
let mut payload: HashMap<String, Value> = HashMap::new();
payload.insert("document_id".to_string(), node.chunk.document_id.into());
payload.insert("chunk_index".to_string(), (node.chunk.index as i64).into());
payload.insert("text_content".to_string(), node.chunk.text.into());
points.push(PointStruct::new(node.id, node.embedding, payload));
}
self.client
.upsert_points(UpsertPointsBuilder::new(&self.collection_name, points))
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
let points = self
.client
.get_points(
GetPointsBuilder::new(
&self.collection_name,
node_ids
.iter()
.map(|&s| s.to_string().into())
.collect::<Vec<qdrant_client::qdrant::PointId>>(),
)
.with_payload(true)
.with_vectors(true),
)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
let mut results = Vec::new();
for point in points.result {
let id = point
.id
.and_then(|i| i.point_id_options)
.map(|opt| match opt {
qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
_ => String::new(),
})
.unwrap_or_default();
if id.is_empty() {
continue;
}
let payload = point.payload;
let document_id = payload
.get("document_id")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
})
.unwrap_or_default();
let chunk_index = payload
.get("chunk_index")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
_ => 0,
})
.unwrap_or(0);
let text_content = payload
.get("text_content")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
})
.unwrap_or_default();
let embedding =
if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) =
point.vectors.and_then(|v| v.vectors_options)
{
v.data
} else {
vec![]
};
results.push(Node {
id,
chunk: Chunk {
document_id,
index: chunk_index,
text: text_content,
},
embedding,
edges: vec![],
});
}
Ok(results)
}
async fn search(
&self,
_text_query: &str,
embedding: &[f32],
top_k: usize,
) -> Result<Vec<(Node, f32)>> {
let points = self
.client
.search_points(
qdrant_client::qdrant::SearchPointsBuilder::new(
&self.collection_name,
embedding.to_vec(),
top_k as u64,
)
.with_payload(true)
.with_vectors(true),
)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
let mut results = Vec::new();
for point in points.result {
let id = point
.id
.and_then(|i| i.point_id_options)
.map(|opt| match opt {
qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
_ => String::new(),
})
.unwrap_or_default();
if id.is_empty() {
continue;
}
let payload = point.payload;
let document_id = payload
.get("document_id")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
})
.unwrap_or_default();
let chunk_index = payload
.get("chunk_index")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
_ => 0,
})
.unwrap_or(0);
let text_content = payload
.get("text_content")
.and_then(|v| v.kind.clone())
.map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
})
.unwrap_or_default();
let embedding =
if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) =
point.vectors.and_then(|v| v.vectors_options)
{
v.data
} else {
vec![]
};
results.push((
Node {
id,
chunk: Chunk {
document_id,
index: chunk_index,
text: text_content,
},
embedding,
edges: vec![],
},
point.score,
));
}
Ok(results)
}
async fn delete_document(&self, doc_id: &str) -> Result<()> {
let condition = Condition::matches("document_id", doc_id.to_string());
self.client
.delete_points(
DeletePointsBuilder::new(&self.collection_name).points(Filter::must([condition])),
)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get_all_nodes(&self) -> Result<Vec<Node>> {
Ok(vec![])
}
}