use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use super::bm25::{BM25Config, BM25Index};
use crate::rag::SliceLayer;
use crate::storage::{ChromaDocument, StorageManager};
/// Retrieval strategy used by [`HybridSearcher::search`].
///
/// Serialized and deserialized in lowercase (`"vector"`, `"keyword"`,
/// `"hybrid"`) via `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchMode {
/// Query the vector store only. `HybridSearcher::new` builds no BM25 index
/// when this mode is configured.
Vector,
/// Query the BM25 index only.
Keyword,
/// Query both retrievers and fuse their rankings. This is the default.
#[default]
Hybrid,
}
impl std::str::FromStr for SearchMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"vector" => Ok(SearchMode::Vector),
"keyword" | "bm25" => Ok(SearchMode::Keyword),
"hybrid" => Ok(SearchMode::Hybrid),
other => Err(format!(
"Invalid search mode: '{}'. Use 'vector', 'keyword', or 'hybrid'",
other
)),
}
}
}
/// Tunables for [`HybridSearcher`].
///
/// Every field has a serde default, so a partially specified (or empty)
/// configuration deserializes successfully.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HybridConfig {
/// Which retriever(s) `HybridSearcher::search` dispatches to.
#[serde(default)]
pub mode: SearchMode,
/// Multiplier applied to the vector-side score during fusion (serde default 0.6).
#[serde(default = "default_vector_weight")]
pub vector_weight: f32,
/// Multiplier applied to the BM25-side score during fusion (serde default 0.4).
#[serde(default = "default_bm25_weight")]
pub bm25_weight: f32,
/// When `true`, hybrid search fuses with reciprocal rank fusion instead of
/// the weighted linear combination. Defaults to `false`.
#[serde(default)]
pub use_rrf: bool,
/// The `k` constant in the RRF term `1 / (k + rank + 1)` (serde default 60.0).
#[serde(default = "default_rrf_k")]
pub rrf_k: f32,
/// Settings handed to `BM25Index::new` when the searcher builds its own index.
#[serde(default)]
pub bm25: BM25Config,
}
/// Serde fallback for `HybridConfig::vector_weight` when the field is omitted.
fn default_vector_weight() -> f32 {
    const VECTOR_WEIGHT: f32 = 0.6;
    VECTOR_WEIGHT
}
/// Serde fallback for `HybridConfig::bm25_weight` when the field is omitted.
fn default_bm25_weight() -> f32 {
    const BM25_WEIGHT: f32 = 0.4;
    BM25_WEIGHT
}
/// Serde fallback for `HybridConfig::rrf_k` when the field is omitted.
fn default_rrf_k() -> f32 {
    const RRF_K: f32 = 60.0;
    RRF_K
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
mode: SearchMode::default(),
vector_weight: 0.6,
bm25_weight: 0.4,
use_rrf: false,
rrf_k: 60.0,
bm25: BM25Config::default(),
}
}
}
/// One hit returned by [`HybridSearcher::search`]: the stored document's
/// fields plus the scores that ranked it.
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
/// Document id as held by the storage layer.
pub id: String,
/// Namespace recorded on the stored document.
pub namespace: String,
/// Document text.
pub document: String,
/// Score the result was ranked by: the fused score in hybrid mode, the BM25
/// score in keyword mode, and a constant `1.0` in vector mode.
pub combined_score: f32,
/// Vector-side contribution, if the vector retriever returned this document.
/// In hybrid mode this is derived from rank position, not a raw similarity.
pub vector_score: Option<f32>,
/// BM25-side contribution, if the BM25 retriever returned this document.
/// Raw in keyword mode; normalized or rank-derived in hybrid mode.
pub bm25_score: Option<f32>,
/// Metadata copied from the stored document.
pub metadata: serde_json::Value,
/// Value of `slice_layer()` on the stored document.
pub layer: Option<SliceLayer>,
/// Parent id copied from the stored document.
pub parent_id: Option<String>,
/// Child ids copied from the stored document.
pub children_ids: Vec<String>,
/// Keywords copied from the stored document.
pub keywords: Vec<String>,
}
/// Search front-end that combines a vector store with an optional BM25 index.
pub struct HybridSearcher {
// Backing store: used for vector queries and for fetching documents by id.
storage: Arc<StorageManager>,
// Keyword index; `None` when the searcher was built via `new` in vector mode.
bm25_index: Option<Arc<BM25Index>>,
// Mode, fusion weights and BM25 settings.
config: HybridConfig,
}
impl HybridSearcher {
/// Creates a searcher over `storage`, building a fresh BM25 index from
/// `config.bm25` unless the configured mode is [`SearchMode::Vector`].
///
/// # Errors
/// Returns the error from `BM25Index::new` if the index cannot be created.
pub async fn new(storage: Arc<StorageManager>, config: HybridConfig) -> Result<Self> {
    // Vector-only mode never consults the keyword index, so none is built.
    let bm25_index = match config.mode {
        SearchMode::Vector => None,
        SearchMode::Keyword | SearchMode::Hybrid => {
            let index = BM25Index::new(&config.bm25)?;
            Some(Arc::new(index))
        }
    };

    Ok(Self {
        storage,
        bm25_index,
        config,
    })
}
/// Creates a searcher that reuses an existing BM25 index.
///
/// The index is attached regardless of `config.mode`.
pub fn with_bm25_index(
    storage: Arc<StorageManager>,
    bm25_index: Arc<BM25Index>,
    config: HybridConfig,
) -> Self {
    let bm25_index = Some(bm25_index);
    Self {
        storage,
        bm25_index,
        config,
    }
}
pub fn bm25_index(&self) -> Option<&Arc<BM25Index>> {
self.bm25_index.as_ref()
}
/// Adds `docs` to the vector store and, when a BM25 index is attached, to
/// that index as well.
///
/// The store is written first; the BM25 index receives `(id, namespace,
/// document)` triples copied from the same documents.
///
/// # Errors
/// Propagates the first failure from either write. A BM25 failure happens
/// after the store write has already succeeded.
pub async fn index_documents(&self, docs: &[ChromaDocument]) -> Result<()> {
    self.storage.add_to_store(docs.to_vec()).await?;

    let bm25 = match self.bm25_index.as_ref() {
        Some(index) => index,
        None => return Ok(()),
    };

    let mut entries: Vec<(String, String, String)> = Vec::with_capacity(docs.len());
    for doc in docs {
        entries.push((doc.id.clone(), doc.namespace.clone(), doc.document.clone()));
    }
    bm25.add_documents(&entries).await?;
    Ok(())
}
/// Runs a search using the mode configured on this searcher.
///
/// * `query` - text passed to BM25 (unused in vector mode).
/// * `query_embedding` - embedding passed to the vector store (unused in keyword mode).
/// * `namespace` - optional namespace restriction forwarded to the retrievers.
/// * `limit` - maximum number of results to return.
/// * `layer_filter` - forwarded to the vector store; keyword mode does not receive it.
///
/// # Errors
/// Propagates whatever error the selected search path returns.
pub async fn search(
    &self,
    query: &str,
    query_embedding: Vec<f32>,
    namespace: Option<&str>,
    limit: usize,
    layer_filter: Option<SliceLayer>,
) -> Result<Vec<HybridSearchResult>> {
    match self.config.mode {
        SearchMode::Hybrid => {
            self.hybrid_search(query, query_embedding, namespace, limit, layer_filter)
                .await
        }
        SearchMode::Keyword => self.keyword_only_search(query, namespace, limit).await,
        SearchMode::Vector => {
            self.vector_only_search(query_embedding, namespace, limit, layer_filter)
                .await
        }
    }
}
/// Queries the vector store and returns its hits in the order received.
///
/// Every hit is reported with `combined_score` and `vector_score` fixed at
/// `1.0`; no per-hit similarity is read from the store here.
///
/// # Errors
/// Propagates the error from the vector store query.
async fn vector_only_search(
    &self,
    query_embedding: Vec<f32>,
    namespace: Option<&str>,
    limit: usize,
    layer_filter: Option<SliceLayer>,
) -> Result<Vec<HybridSearchResult>> {
    let hits = self
        .storage
        .search_store_with_layer(namespace, query_embedding, limit, layer_filter)
        .await?;

    let mut results = Vec::with_capacity(hits.len());
    for hit in hits {
        // Read the layer before the document's fields are moved out.
        let layer = hit.slice_layer();
        results.push(HybridSearchResult {
            id: hit.id,
            namespace: hit.namespace,
            document: hit.document,
            combined_score: 1.0,
            vector_score: Some(1.0),
            bm25_score: None,
            metadata: hit.metadata,
            layer,
            parent_id: hit.parent_id,
            children_ids: hit.children_ids,
            keywords: hit.keywords,
        });
    }
    Ok(results)
}
/// Queries the BM25 index and loads each hit from storage.
///
/// Hits whose id is not found in storage are skipped. The BM25 score is
/// reported as both `combined_score` and `bm25_score`.
///
/// # Errors
/// Fails if no BM25 index is attached, if the BM25 query fails, or if a
/// storage lookup fails.
async fn keyword_only_search(
    &self,
    query: &str,
    namespace: Option<&str>,
    limit: usize,
) -> Result<Vec<HybridSearchResult>> {
    let bm25 = match self.bm25_index.as_ref() {
        Some(index) => index,
        None => {
            return Err(anyhow::anyhow!(
                "BM25 index not initialized for keyword search"
            ));
        }
    };
    let hits = bm25.search(query, namespace, limit)?;

    // NOTE(review): with no namespace the BM25 query is unrestricted, yet the
    // storage lookup falls back to "rag" - hits from other namespaces would
    // presumably be dropped here; confirm against the BM25 index behaviour.
    let lookup_namespace = namespace.unwrap_or("rag");

    let mut results = Vec::with_capacity(hits.len());
    for (id, score) in hits {
        let doc = match self.storage.get_document(lookup_namespace, &id).await? {
            Some(doc) => doc,
            None => continue,
        };
        let layer = doc.slice_layer();
        results.push(HybridSearchResult {
            id: doc.id,
            namespace: doc.namespace,
            document: doc.document,
            combined_score: score,
            vector_score: None,
            bm25_score: Some(score),
            metadata: doc.metadata,
            layer,
            parent_id: doc.parent_id,
            children_ids: doc.children_ids,
            keywords: doc.keywords,
        });
    }
    Ok(results)
}
/// Runs vector and BM25 retrieval, fuses the two rankings, and returns up to
/// `limit` results ordered by descending combined score.
///
/// Each retriever is asked for three times `limit` candidates so fusion has
/// room to promote documents that only one side ranked highly. Fusion uses
/// reciprocal rank fusion when `config.use_rrf` is set and the weighted linear
/// combination otherwise. Documents returned by the vector store are reused
/// directly; BM25-only hits are fetched from storage by id, and ids that
/// storage cannot find are skipped.
///
/// Ties on the combined score are broken by ascending document id so the
/// output order does not depend on hash-map iteration order.
///
/// # Errors
/// Propagates failures from the vector store query, the BM25 query, and the
/// per-document storage lookups.
async fn hybrid_search(
    &self,
    query: &str,
    query_embedding: Vec<f32>,
    namespace: Option<&str>,
    limit: usize,
    layer_filter: Option<SliceLayer>,
) -> Result<Vec<HybridSearchResult>> {
    // Saturate instead of overflowing when a caller passes a huge limit.
    let expanded_limit = limit.saturating_mul(3);

    let vector_results = self
        .storage
        .search_store_with_layer(namespace, query_embedding, expanded_limit, layer_filter)
        .await?;
    let bm25_results = if let Some(ref bm25) = self.bm25_index {
        bm25.search(query, namespace, expanded_limit)?
    } else {
        vec![]
    };

    let fused = if self.config.use_rrf {
        self.reciprocal_rank_fusion(&vector_results, &bm25_results)
    } else {
        self.weighted_linear_fusion(&vector_results, &bm25_results)
    };

    // `total_cmp` gives the sort a genuine total order even if a score is NaN
    // (a `partial_cmp` fallback to `Equal` does not), and the id tie-break
    // makes the order reproducible across calls.
    let mut ranked: Vec<_> = fused.into_iter().collect();
    ranked.sort_by(|(id_a, (score_a, _, _)), (id_b, (score_b, _, _))| {
        score_b.total_cmp(score_a).then_with(|| id_a.cmp(id_b))
    });

    // Walk the full ranking and stop once `limit` results are hydrated, so an
    // id that storage no longer has does not shrink the page while further
    // candidates are still available.
    let mut final_results = Vec::with_capacity(limit.min(ranked.len()));
    for (id, (combined_score, vector_score, bm25_score)) in ranked {
        if final_results.len() >= limit {
            break;
        }
        if let Some(doc) = vector_results.iter().find(|d| d.id == id) {
            let layer = doc.slice_layer();
            final_results.push(HybridSearchResult {
                id: doc.id.clone(),
                namespace: doc.namespace.clone(),
                document: doc.document.clone(),
                combined_score,
                vector_score,
                bm25_score,
                metadata: doc.metadata.clone(),
                layer,
                parent_id: doc.parent_id.clone(),
                children_ids: doc.children_ids.clone(),
                keywords: doc.keywords.clone(),
            });
        } else if let Some(doc) = self
            .storage
            // NOTE(review): BM25-only hits are looked up in "rag" when no
            // namespace is given, and they are not checked against
            // `layer_filter` - confirm whether either needs handling here.
            .get_document(namespace.unwrap_or("rag"), &id)
            .await?
        {
            let layer = doc.slice_layer();
            final_results.push(HybridSearchResult {
                id: doc.id,
                namespace: doc.namespace,
                document: doc.document,
                combined_score,
                vector_score,
                bm25_score,
                metadata: doc.metadata,
                layer,
                parent_id: doc.parent_id,
                children_ids: doc.children_ids,
                keywords: doc.keywords,
            });
        }
    }

    tracing::debug!(
        "Hybrid search: {} vector + {} BM25 -> {} fused results",
        vector_results.len(),
        bm25_results.len(),
        final_results.len()
    );
    Ok(final_results)
}
/// Fuses the two result lists into one weighted score per document id.
///
/// Returns a map from id to `(combined, vector_part, bm25_part)`, where the
/// parts are the unweighted normalized scores and `combined` is the sum of
/// each part multiplied by its configured weight.
///
/// * Vector hits are scored by position: the first hit gets `1.0` and each
///   later hit loses `1 / len`.
/// * BM25 scores are divided by the largest BM25 score in the list; if that
///   maximum is not positive, every BM25 part is `0.0`.
fn weighted_linear_fusion(
    &self,
    vector_results: &[ChromaDocument],
    bm25_results: &[(String, f32)],
) -> HashMap<String, (f32, Option<f32>, Option<f32>)> {
    let vector_weight = self.config.vector_weight;
    let bm25_weight = self.config.bm25_weight;
    let mut scores: HashMap<String, (f32, Option<f32>, Option<f32>)> = HashMap::new();

    // `max(1)` keeps the divisor non-zero; with an empty list the loop body
    // never runs anyway.
    let vector_count = vector_results.len().max(1) as f32;
    for (position, doc) in vector_results.iter().enumerate() {
        let rank_score = 1.0 - (position as f32 / vector_count);
        scores.insert(
            doc.id.clone(),
            (rank_score * vector_weight, Some(rank_score), None),
        );
    }

    let mut top_bm25 = 0.0_f32;
    for (_, raw) in bm25_results {
        top_bm25 = top_bm25.max(*raw);
    }

    for (id, raw) in bm25_results {
        let scaled = if top_bm25 > 0.0 { raw / top_bm25 } else { 0.0 };
        let contribution = scaled * bm25_weight;
        match scores.get_mut(id) {
            // Already seen on the vector side: add the keyword contribution.
            Some(entry) => {
                entry.0 += contribution;
                entry.2 = Some(scaled);
            }
            None => {
                scores.insert(id.clone(), (contribution, None, Some(scaled)));
            }
        }
    }

    scores
}
/// Fuses the two result lists with weighted reciprocal rank fusion.
///
/// A hit at zero-based position `rank` in either list contributes
/// `1 / (rrf_k + rank + 1)`, multiplied by that list's configured weight.
/// Returns a map from id to `(combined, vector_part, bm25_part)`, where the
/// parts are the unweighted reciprocal-rank terms. Raw BM25 scores are not
/// used; only the list order matters.
fn reciprocal_rank_fusion(
    &self,
    vector_results: &[ChromaDocument],
    bm25_results: &[(String, f32)],
) -> HashMap<String, (f32, Option<f32>, Option<f32>)> {
    let k = self.config.rrf_k;
    let vector_weight = self.config.vector_weight;
    let bm25_weight = self.config.bm25_weight;
    let reciprocal = |rank: usize| 1.0 / (k + rank as f32 + 1.0);

    let mut scores: HashMap<String, (f32, Option<f32>, Option<f32>)> = HashMap::new();

    for (rank, doc) in vector_results.iter().enumerate() {
        let term = reciprocal(rank);
        scores.insert(doc.id.clone(), (term * vector_weight, Some(term), None));
    }

    for (rank, (id, _)) in bm25_results.iter().enumerate() {
        let term = reciprocal(rank);
        let contribution = term * bm25_weight;
        match scores.get_mut(id) {
            // Present in both lists: accumulate the keyword contribution.
            Some(entry) => {
                entry.0 += contribution;
                entry.2 = Some(term);
            }
            None => {
                scores.insert(id.clone(), (contribution, None, Some(term)));
            }
        }
    }

    scores
}
/// Deletes `ids` from storage (within `namespace`) and from the BM25 index.
///
/// Storage deletions run one id at a time, in order. If one of them fails,
/// the ids that were already processed are still removed from the BM25 index
/// before the storage error is returned, so the keyword index does not keep
/// pointing at documents that storage no longer has.
///
/// Returns the sum of the counts reported by storage for each id.
///
/// # Errors
/// Returns the BM25 error if the index update fails; otherwise returns the
/// storage error that interrupted the loop, if any.
pub async fn delete_documents(&self, namespace: &str, ids: &[String]) -> Result<usize> {
    let mut deleted = 0;
    let mut processed = 0;
    let mut storage_failure = None;

    for id in ids {
        match self.storage.delete_document(namespace, id).await {
            Ok(count) => {
                deleted += count;
                processed += 1;
            }
            Err(err) => {
                storage_failure = Some(err);
                break;
            }
        }
    }

    if let Some(ref bm25) = self.bm25_index {
        // Deletions are sequential, so the handled ids are exactly this prefix.
        let handled = &ids[..processed];
        // Skip the call only when an early failure left nothing to mirror.
        if storage_failure.is_none() || !handled.is_empty() {
            bm25.delete_documents(handled).await?;
        }
    }

    match storage_failure {
        Some(err) => Err(err.into()),
        None => Ok(deleted),
    }
}
pub async fn purge_namespace(&self, namespace: &str) -> Result<usize> {
let deleted = self.storage.purge_namespace(namespace).await?;
if let Some(ref bm25) = self.bm25_index {
bm25.purge_namespace(namespace).await?;
}
Ok(deleted)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every accepted spelling parses to its mode, and unknown text is rejected.
    #[test]
    fn test_search_mode_parsing() {
        let accepted = [
            ("vector", SearchMode::Vector),
            ("keyword", SearchMode::Keyword),
            ("bm25", SearchMode::Keyword),
            ("hybrid", SearchMode::Hybrid),
        ];
        for (text, expected) in &accepted {
            assert_eq!(text.parse::<SearchMode>().unwrap(), *expected);
        }

        assert!("invalid".parse::<SearchMode>().is_err());
    }

    /// `HybridConfig::default()` selects hybrid mode with the stock weights
    /// and leaves RRF switched off.
    #[test]
    fn test_default_config() {
        let HybridConfig {
            mode,
            vector_weight,
            bm25_weight,
            use_rrf,
            ..
        } = HybridConfig::default();

        assert_eq!(mode, SearchMode::Hybrid);
        assert_eq!(vector_weight, 0.6);
        assert_eq!(bm25_weight, 0.4);
        assert!(!use_rrf);
    }
}