use crate::databases::bm25_helpers::{self, SharedIdfStats};
use crate::databases::traits::VectorDatabase;
use crate::glob_utils;
use anyhow::{Context, Result};
use brainwires_core::{ChunkMetadata, DatabaseStats, SearchResult};
use qdrant_client::qdrant::vectors_config::Config;
use qdrant_client::qdrant::{
Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointStruct,
SearchPointsBuilder, UpsertPointsBuilder, VectorParams, VectorsConfig,
};
use qdrant_client::{Payload, Qdrant};
use serde_json::json;
/// Name of the single Qdrant collection holding all code-chunk embeddings.
const COLLECTION_NAME: &str = "code_embeddings";

/// Vector-database backend backed by a remote Qdrant instance.
///
/// Holds the gRPC client plus shared BM25 IDF statistics that are refreshed
/// from the stored `content` payloads and used for hybrid (vector + keyword)
/// scoring during search.
pub struct QdrantDatabase {
    // gRPC client for the Qdrant server.
    client: Qdrant,
    // Shared (read/write-locked) IDF statistics for BM25 keyword scoring.
    idf_stats: SharedIdfStats,
}
impl QdrantDatabase {
    /// Connects to a Qdrant instance at the default local gRPC address.
    pub async fn new() -> Result<Self> {
        Self::with_url(&Self::default_url()).await
    }

    /// Default Qdrant endpoint (6334 is Qdrant's gRPC port).
    pub fn default_url() -> String {
        "http://localhost:6334".to_string()
    }

    /// Builds a client for `url` and eagerly warms the BM25 IDF statistics.
    ///
    /// IDF initialization is best-effort: a failure (e.g. the collection does
    /// not exist yet) is logged as a warning but does not prevent construction.
    pub async fn with_url(url: &str) -> Result<Self> {
        tracing::info!("Connecting to Qdrant at {}", url);
        let client = Qdrant::from_url(url)
            .build()
            .context("Failed to create Qdrant client")?;
        let db = Self {
            client,
            idf_stats: bm25_helpers::new_shared_idf_stats(),
        };
        if let Err(e) = db.refresh_idf_stats().await {
            tracing::warn!("Failed to initialize IDF stats: {}", e);
        }
        Ok(db)
    }

    /// Rebuilds the shared BM25 IDF statistics by scrolling every point's
    /// `content` payload out of the collection, 100 points per page.
    ///
    /// Pagination follows Qdrant's `next_page_offset` cursor. A scroll error
    /// (typically "collection not found" before the first index run) ends the
    /// scan early with whatever documents were collected so far; it is logged
    /// rather than propagated, keeping this method best-effort.
    async fn refresh_idf_stats(&self) -> Result<()> {
        use qdrant_client::qdrant::ScrollPointsBuilder;
        tracing::info!("Refreshing IDF statistics...");
        let mut documents = Vec::new();
        let mut offset: Option<qdrant_client::qdrant::PointId> = None;
        loop {
            let mut builder = ScrollPointsBuilder::new(COLLECTION_NAME)
                .with_payload(true)
                .limit(100);
            if let Some(ref point_id) = offset {
                builder = builder.offset(point_id.clone());
            }
            let scroll_result = match self.client.scroll(builder).await {
                Ok(result) => result,
                Err(e) => {
                    // Best-effort scan: record why it stopped instead of
                    // silently discarding the error.
                    tracing::debug!("Stopping IDF scroll early: {}", e);
                    break;
                }
            };
            if scroll_result.result.is_empty() {
                break;
            }
            for point in &scroll_result.result {
                if let Some(content) = point.payload.get("content").and_then(|v| v.as_str()) {
                    documents.push(content.to_string());
                }
            }
            offset = scroll_result.next_page_offset;
            if offset.is_none() {
                break;
            }
        }
        tracing::info!("Refreshing IDF stats from {} documents", documents.len());
        bm25_helpers::update_idf_stats(&self.idf_stats, &documents).await;
        Ok(())
    }

    /// Returns true when the target collection already exists on the server.
    async fn collection_exists(&self) -> Result<bool> {
        let collections = self
            .client
            .list_collections()
            .await
            .context("Failed to list collections")?;
        Ok(collections
            .collections
            .iter()
            .any(|c| c.name == COLLECTION_NAME))
    }
}
#[async_trait::async_trait]
impl VectorDatabase for QdrantDatabase {
    /// Creates the collection with cosine distance and the given vector
    /// `dimension`, unless it already exists (then this is a no-op).
    async fn initialize(&self, dimension: usize) -> Result<()> {
        if self.collection_exists().await? {
            tracing::info!("Collection '{}' already exists", COLLECTION_NAME);
            return Ok(());
        }
        tracing::info!(
            "Creating collection '{}' with dimension {}",
            COLLECTION_NAME,
            dimension
        );
        self.client
            .create_collection(
                CreateCollectionBuilder::new(COLLECTION_NAME).vectors_config(VectorsConfig {
                    config: Some(Config::Params(VectorParams {
                        size: dimension as u64,
                        distance: Distance::Cosine.into(),
                        ..Default::default()
                    })),
                }),
            )
            .await
            .context("Failed to create collection")?;
        Ok(())
    }

    /// Upserts one point per chunk, then refreshes the BM25 IDF statistics.
    ///
    /// `embeddings`, `metadata` and `contents` are parallel vectors; if their
    /// lengths differ, `zip` silently truncates to the shortest.
    ///
    /// Point IDs are stable hashes of the chunk identity (root path, file
    /// path, line span), so re-indexing a chunk overwrites its previous point
    /// while chunks from other files or batches keep their own IDs. (The old
    /// enumerate-based IDs 0..n made every call clobber earlier batches.)
    /// `root_path` is stored in the payload so the `root_path` filters used by
    /// `search_filtered`, `count_by_root_path` and `get_indexed_files` match.
    async fn store_embeddings(
        &self,
        embeddings: Vec<Vec<f32>>,
        metadata: Vec<ChunkMetadata>,
        contents: Vec<String>,
        root_path: &str,
    ) -> Result<usize> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        if embeddings.is_empty() {
            return Ok(0);
        }
        let count = embeddings.len();
        tracing::debug!("Storing {} embeddings", count);
        let points: Vec<PointStruct> = embeddings
            .into_iter()
            .zip(metadata)
            .zip(contents)
            .map(|((embedding, meta), content)| {
                // Stable 64-bit id from the chunk's identity. Collisions are
                // possible in principle but require two distinct chunks to
                // hash identically, which is vanishingly unlikely.
                let mut hasher = DefaultHasher::new();
                root_path.hash(&mut hasher);
                meta.file_path.hash(&mut hasher);
                meta.start_line.hash(&mut hasher);
                meta.end_line.hash(&mut hasher);
                let point_id = hasher.finish();
                let payload: Payload = json!({
                    "file_path": meta.file_path,
                    "root_path": root_path,
                    "project": meta.project,
                    "start_line": meta.start_line,
                    "end_line": meta.end_line,
                    "language": meta.language,
                    "extension": meta.extension,
                    "file_hash": meta.file_hash,
                    "indexed_at": meta.indexed_at,
                    "content": content,
                })
                .try_into()
                .expect("JSON object always converts to Payload");
                PointStruct::new(point_id, embedding, payload)
            })
            .collect();
        self.client
            .upsert_points(UpsertPointsBuilder::new(COLLECTION_NAME, points))
            .await
            .context("Failed to upsert points")?;
        // Best-effort: stale IDF stats degrade ranking but not correctness.
        if let Err(e) = self.refresh_idf_stats().await {
            tracing::warn!("Failed to refresh IDF stats after indexing: {}", e);
        }
        Ok(count)
    }

    /// Convenience wrapper: `search_filtered` with no extension/language/path
    /// filters.
    async fn search(
        &self,
        query_vector: Vec<f32>,
        query_text: &str,
        limit: usize,
        min_score: f32,
        project: Option<String>,
        root_path: Option<String>,
        hybrid: bool,
    ) -> Result<Vec<SearchResult>> {
        self.search_filtered(
            query_vector,
            query_text,
            limit,
            min_score,
            project,
            root_path,
            hybrid,
            vec![],
            vec![],
            vec![],
        )
        .await
    }

    /// Vector search with optional server-side filters (project, extension,
    /// language) and client-side post-filters (root path, glob patterns).
    ///
    /// When `hybrid` is true, each hit's vector score is combined with a BM25
    /// keyword score and results are re-sorted by the combined score. Note
    /// that `limit` and `min_score` are applied by Qdrant to the *vector*
    /// scores before hybrid re-ranking, and client-side filters may return
    /// fewer than `limit` results. Points missing required payload fields
    /// (`content`, `file_path`, `start_line`, `end_line`) are skipped.
    async fn search_filtered(
        &self,
        query_vector: Vec<f32>,
        query_text: &str,
        limit: usize,
        min_score: f32,
        project: Option<String>,
        root_path: Option<String>,
        hybrid: bool,
        file_extensions: Vec<String>,
        languages: Vec<String>,
        path_patterns: Vec<String>,
    ) -> Result<Vec<SearchResult>> {
        tracing::debug!(
            "Searching with limit={}, min_score={}, project={:?}, root_path={:?}, hybrid={}, filters: ext={:?}, lang={:?}, path={:?}",
            limit,
            min_score,
            project,
            root_path,
            hybrid,
            file_extensions,
            languages,
            path_patterns
        );
        // Server-side filter: all conditions must match.
        let mut filter = Filter::default();
        let mut must_conditions = vec![];
        if let Some(proj) = project {
            must_conditions.push(Condition::matches("project", proj));
        }
        if !file_extensions.is_empty() {
            must_conditions.push(Condition::matches(
                "extension",
                file_extensions.into_iter().collect::<Vec<_>>(),
            ));
        }
        if !languages.is_empty() {
            must_conditions.push(Condition::matches(
                "language",
                languages.into_iter().collect::<Vec<_>>(),
            ));
        }
        if !must_conditions.is_empty() {
            filter.must = must_conditions;
        }
        let mut search_builder =
            SearchPointsBuilder::new(COLLECTION_NAME, query_vector, limit as u64)
                .score_threshold(min_score)
                .with_payload(true);
        if !filter.must.is_empty() {
            search_builder = search_builder.filter(filter);
        }
        let search_result = self
            .client
            .search_points(search_builder)
            .await
            .context("Failed to search points")?;
        let mut results: Vec<SearchResult> = Vec::new();
        for point in search_result.result {
            let payload = point.payload;
            let vector_score = point.score;
            let content = match payload.get("content").and_then(|v| v.as_str()) {
                Some(c) => c.to_string(),
                None => continue,
            };
            let (final_score, keyword_score) = if hybrid {
                let kw_score =
                    bm25_helpers::calculate_bm25_score(&self.idf_stats, query_text, &content).await;
                (
                    bm25_helpers::combine_scores(vector_score, kw_score),
                    Some(kw_score),
                )
            } else {
                (vector_score, None)
            };
            let file_path = match payload.get("file_path").and_then(|v| v.as_str()) {
                Some(p) => p.to_string(),
                None => continue,
            };
            let start_line = match payload.get("start_line").and_then(|v| v.as_integer()) {
                Some(l) => l as usize,
                None => continue,
            };
            let end_line = match payload.get("end_line").and_then(|v| v.as_integer()) {
                Some(l) => l as usize,
                None => continue,
            };
            let language = payload
                .get("language")
                .and_then(|v| v.as_str().map(String::from))
                .unwrap_or_else(|| "Unknown".to_string());
            let project = payload
                .get("project")
                .and_then(|v| v.as_str().map(String::from));
            let result_root_path = payload
                .get("root_path")
                .and_then(|v| v.as_str().map(String::from));
            // Root-path filtering is done client-side; points without a
            // stored root_path are excluded when a filter is given.
            if let Some(ref filter_path) = root_path
                && result_root_path.as_ref() != Some(filter_path)
            {
                continue;
            }
            let indexed_at = payload
                .get("indexed_at")
                .and_then(|v| v.as_integer())
                .unwrap_or(0);
            results.push(SearchResult {
                file_path,
                root_path: result_root_path,
                content,
                score: final_score,
                vector_score,
                keyword_score,
                start_line,
                end_line,
                language,
                project,
                indexed_at,
            });
        }
        // Qdrant returned points ordered by vector score; only hybrid scoring
        // can change the ordering.
        if hybrid {
            results.sort_by(|a, b| b.score.total_cmp(&a.score));
        }
        if !path_patterns.is_empty() {
            results.retain(|r| glob_utils::matches_any_pattern(&r.file_path, &path_patterns));
        }
        Ok(results)
    }

    /// Deletes every point whose `file_path` payload matches exactly.
    ///
    /// Always returns 0: the Qdrant delete API does not report how many
    /// points were removed.
    async fn delete_by_file(&self, file_path: &str) -> Result<usize> {
        tracing::debug!("Deleting embeddings for file: {}", file_path);
        let filter = Filter::must([Condition::matches("file_path", file_path.to_string())]);
        self.client
            .delete_points(DeletePointsBuilder::new(COLLECTION_NAME).points(filter))
            .await
            .context("Failed to delete points")?;
        Ok(0)
    }

    /// Drops the entire collection and zeroes the in-memory IDF statistics.
    /// The collection is recreated on the next `initialize` call.
    async fn clear(&self) -> Result<()> {
        tracing::info!("Clearing all embeddings from collection");
        self.client
            .delete_collection(COLLECTION_NAME)
            .await
            .context("Failed to delete collection")?;
        let mut stats = self.idf_stats.write().await;
        stats.total_docs = 0;
        stats.doc_frequencies.clear();
        Ok(())
    }

    /// Reports the point count; a per-language breakdown is not computed for
    /// this backend, so `language_breakdown` is always empty.
    async fn get_statistics(&self) -> Result<DatabaseStats> {
        let collection_info = self
            .client
            .collection_info(COLLECTION_NAME)
            .await
            .context("Failed to get collection info")?;
        let points_count = collection_info
            .result
            .and_then(|r| r.points_count)
            .unwrap_or(0);
        Ok(DatabaseStats {
            total_points: points_count as usize,
            total_vectors: points_count as usize,
            language_breakdown: vec![],
        })
    }

    /// No-op: Qdrant persists upserts server-side; nothing to flush locally.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }

    /// Counts points whose `root_path` payload matches exactly.
    async fn count_by_root_path(&self, root_path: &str) -> Result<usize> {
        use qdrant_client::qdrant::CountPointsBuilder;
        let filter = Filter::must([Condition::matches("root_path", root_path.to_string())]);
        let count_result = self
            .client
            .count(CountPointsBuilder::new(COLLECTION_NAME).filter(filter))
            .await
            .context("Failed to count points by root path")?;
        Ok(count_result.result.map(|r| r.count).unwrap_or(0) as usize)
    }

    /// Returns the distinct `file_path` values indexed under `root_path`,
    /// collected by scrolling the filtered points 1000 at a time.
    /// Order of the returned paths is unspecified (HashSet iteration).
    async fn get_indexed_files(&self, root_path: &str) -> Result<Vec<String>> {
        use qdrant_client::qdrant::ScrollPointsBuilder;
        let filter = Filter::must([Condition::matches("root_path", root_path.to_string())]);
        let mut file_paths = std::collections::HashSet::new();
        let mut offset: Option<qdrant_client::qdrant::PointId> = None;
        loop {
            let mut builder = ScrollPointsBuilder::new(COLLECTION_NAME)
                .filter(filter.clone())
                .with_payload(true)
                .limit(1000);
            if let Some(ref point_id) = offset {
                builder = builder.offset(point_id.clone());
            }
            let scroll_result = self
                .client
                .scroll(builder)
                .await
                .context("Failed to scroll points")?;
            if scroll_result.result.is_empty() {
                break;
            }
            for point in &scroll_result.result {
                if let Some(file_path) = point.payload.get("file_path").and_then(|v| v.as_str()) {
                    file_paths.insert(file_path.to_string());
                }
            }
            offset = scroll_result.next_page_offset;
            if offset.is_none() {
                break;
            }
        }
        Ok(file_paths.into_iter().collect())
    }
}
impl Default for QdrantDatabase {
    /// Blocking constructor for synchronous contexts.
    ///
    /// NOTE(review): spins up a throwaway Tokio runtime and panics on any
    /// runtime-creation or connection failure. `block_on` will also panic if
    /// this is invoked from within an existing Tokio runtime — only call
    /// `Default` from sync, non-async entry points; prefer
    /// `QdrantDatabase::new().await` everywhere else.
    fn default() -> Self {
        tokio::runtime::Runtime::new()
            .expect("failed to create tokio runtime")
            .block_on(Self::new())
            .expect("Failed to create default Qdrant client")
    }
}