use async_trait::async_trait;
use qdrant_client::Qdrant;
use qdrant_client::qdrant::{CreateCollectionBuilder, VectorParamsBuilder, Distance, PointStruct, UpsertPointsBuilder, Condition, Filter, DeletePointsBuilder, GetPointsBuilder, Value};
use std::collections::HashMap;
use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};
/// Vector store that keeps nodes as points in a single Qdrant collection.
pub struct QdrantVectorStore {
/// Qdrant client through which every request of this store is issued.
client: Qdrant,
/// Name of the collection that all operations of this store target.
collection_name: String,
}
impl QdrantVectorStore {
pub async fn new(url: &str, collection_name: &str, vector_size: u64) -> Result<Self> {
let client = Qdrant::from_url(url).build().map_err(|e| CerebroError::StorageError(e.to_string()))?;
if !client.collection_exists(collection_name).await.map_err(|e| CerebroError::StorageError(e.to_string()))? {
client.create_collection(
CreateCollectionBuilder::new(collection_name)
.vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine))
).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
}
Ok(Self { client, collection_name: collection_name.into() })
}
}
#[async_trait]
impl VectorStore for QdrantVectorStore {
async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
let mut points = Vec::with_capacity(nodes.len());
for node in nodes {
let mut payload: HashMap<String, Value> = HashMap::new();
payload.insert("document_id".to_string(), node.chunk.document_id.into());
payload.insert("chunk_index".to_string(), (node.chunk.index as i64).into());
payload.insert("text_content".to_string(), node.chunk.text.into());
points.push(PointStruct::new(
node.id,
node.embedding,
payload,
));
}
self.client.upsert_points(UpsertPointsBuilder::new(&self.collection_name, points))
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
let points = self.client.get_points(
GetPointsBuilder::new(&self.collection_name, node_ids.iter().map(|&s| s.to_string().into()).collect::<Vec<qdrant_client::qdrant::PointId>>())
.with_payload(true)
.with_vectors(true)
).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
let mut results = Vec::new();
for point in points.result {
let id = point.id.and_then(|i| i.point_id_options).map(|opt| match opt {
qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
_ => String::new()
}).unwrap_or_default();
if id.is_empty() { continue; }
let payload = point.payload;
let document_id = payload.get("document_id").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
}).unwrap_or_default();
let chunk_index = payload.get("chunk_index").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
_ => 0,
}).unwrap_or(0);
let text_content = payload.get("text_content").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
}).unwrap_or_default();
let embedding = if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) = point.vectors.and_then(|v| v.vectors_options) {
v.data
} else {
vec![]
};
results.push(Node {
id,
chunk: Chunk { document_id, index: chunk_index, text: text_content },
embedding,
edges: vec![],
});
}
Ok(results)
}
async fn search(&self, _text_query: &str, embedding: &[f32], top_k: usize) -> Result<Vec<(Node, f32)>> {
let points = self.client.search_points(
qdrant_client::qdrant::SearchPointsBuilder::new(&self.collection_name, embedding.to_vec(), top_k as u64)
.with_payload(true)
.with_vectors(true)
).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
let mut results = Vec::new();
for point in points.result {
let id = point.id.and_then(|i| i.point_id_options).map(|opt| match opt {
qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
_ => String::new()
}).unwrap_or_default();
if id.is_empty() { continue; }
let payload = point.payload;
let document_id = payload.get("document_id").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
}).unwrap_or_default();
let chunk_index = payload.get("chunk_index").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
_ => 0,
}).unwrap_or(0);
let text_content = payload.get("text_content").and_then(|v| v.kind.clone()).map(|k| match k {
qdrant_client::qdrant::value::Kind::StringValue(s) => s,
_ => String::new(),
}).unwrap_or_default();
let embedding = if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) = point.vectors.and_then(|v| v.vectors_options) {
v.data
} else {
vec![]
};
results.push((Node {
id,
chunk: Chunk { document_id, index: chunk_index, text: text_content },
embedding,
edges: vec![],
}, point.score));
}
Ok(results)
}
async fn delete_document(&self, doc_id: &str) -> Result<()> {
let condition = Condition::matches("document_id", doc_id.to_string());
self.client.delete_points(
DeletePointsBuilder::new(&self.collection_name)
.points(Filter::must([condition]))
).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get_all_nodes(&self) -> Result<Vec<Node>> {
Ok(vec![])
}
}