use crate::distance::Distance;
use crate::error::{Error, Result};
use crate::index::{BM25Index, Index};
use crate::query::{compare_metadata_values, Filter, FilterEvaluator, OrderBy, SortDirection};
use crate::storage::Storage;
use crate::types::HybridSearchResult;
use super::rrf::{weighted_reciprocal_rank_fusion, RankedResult, DEFAULT_RRF_K};
/// Retrieval strategy selected for a [`HybridSearch::search`] call.
#[derive(Debug, Clone, Default)]
pub enum SearchMode {
    /// Nearest-neighbour search over the vector index only. This is the default mode.
    #[default]
    Vector,
    /// BM25 keyword search only.
    Keyword,
    /// Vector and keyword ranked lists fused with `weighted_reciprocal_rank_fusion`;
    /// each weight is paired with its list when the fusion is computed.
    Hybrid { vector_weight: f32, keyword_weight: f32 },
    /// No ranking at all: stored documents matching the metadata filter.
    FilterOnly,
}
/// Request description consumed by [`HybridSearch::search`].
#[derive(Debug, Clone)]
pub struct HybridSearchParams {
    /// Query embedding: required by `SearchMode::Vector`, used by `Hybrid` when present.
    pub vector: Option<Vec<f32>>,
    /// Free-text query: required by `SearchMode::Keyword`, used by `Hybrid` when present.
    pub text_query: Option<String>,
    /// Metadata filter: required by `SearchMode::FilterOnly`, optional elsewhere.
    pub filter: Option<Filter>,
    /// Retrieval strategy to run.
    pub mode: SearchMode,
    /// Requested number of results (page size).
    pub k: usize,
    /// Number of leading results to skip, for pagination.
    pub offset: usize,
    /// Optional metadata field/direction used to re-sort the results.
    pub order_by: Option<OrderBy>,
}
impl HybridSearchParams {
pub fn vector(query: Vec<f32>, k: usize) -> Self {
Self {
vector: Some(query),
text_query: None,
filter: None,
mode: SearchMode::Vector,
k,
offset: 0,
order_by: None,
}
}
pub fn keyword(query: impl Into<String>, k: usize) -> Self {
Self {
vector: None,
text_query: Some(query.into()),
filter: None,
mode: SearchMode::Keyword,
k,
offset: 0,
order_by: None,
}
}
pub fn hybrid(vector: Vec<f32>, text: impl Into<String>, k: usize) -> Self {
Self {
vector: Some(vector),
text_query: Some(text.into()),
filter: None,
mode: SearchMode::Hybrid {
vector_weight: 0.5,
keyword_weight: 0.5,
},
k,
offset: 0,
order_by: None,
}
}
pub fn filter_only(filter: Filter, limit: usize) -> Self {
Self {
vector: None,
text_query: None,
filter: Some(filter),
mode: SearchMode::FilterOnly,
k: limit,
offset: 0,
order_by: None,
}
}
pub fn with_filter(mut self, filter: Filter) -> Self {
self.filter = Some(filter);
self
}
pub fn with_offset(mut self, offset: usize) -> Self {
self.offset = offset;
self
}
pub fn with_order_by(mut self, order: OrderBy) -> Self {
self.order_by = Some(order);
self
}
pub fn with_weights(mut self, vector_weight: f32, keyword_weight: f32) -> Self {
self.mode = SearchMode::Hybrid {
vector_weight,
keyword_weight,
};
self
}
}
pub struct HybridSearch;
impl HybridSearch {
pub fn search(
params: &HybridSearchParams,
vector_index: &dyn Index,
bm25_index: Option<&BM25Index>,
storage: &dyn Storage,
distance: Distance,
) -> Result<Vec<HybridSearchResult>> {
let mut results = match ¶ms.mode {
SearchMode::Vector => Self::vector_search(params, vector_index, storage, distance)?,
SearchMode::Keyword => Self::keyword_search(params, bm25_index, storage)?,
SearchMode::Hybrid {
vector_weight,
keyword_weight,
} => Self::hybrid_search(
params,
vector_index,
bm25_index,
storage,
distance,
*vector_weight,
*keyword_weight,
)?,
SearchMode::FilterOnly => Self::filter_only_search(params, storage)?,
};
results.retain(|r| {
!matches!(
r.metadata.as_ref().and_then(|m| m.get("deleted")),
Some(crate::types::MetadataValue::Bool(true))
)
});
if let Some(ref order) = params.order_by {
let field = &order.field;
results.sort_by(|a, b| {
let val_a = a.metadata.as_ref().and_then(|m| m.get(field));
let val_b = b.metadata.as_ref().and_then(|m| m.get(field));
let cmp = compare_metadata_values(val_a, val_b);
match order.direction {
SortDirection::Asc => cmp,
SortDirection::Desc => cmp.reverse(),
}
});
}
if params.offset > 0 {
if params.offset >= results.len() {
return Ok(vec![]);
}
results = results.into_iter().skip(params.offset).collect();
}
if let Some(ref _order) = params.order_by {
results.truncate(params.k);
}
Ok(results)
}
fn vector_search(
params: &HybridSearchParams,
index: &dyn Index,
storage: &dyn Storage,
distance: Distance,
) -> Result<Vec<HybridSearchResult>> {
let query = params.vector.as_ref().ok_or_else(|| {
Error::InvalidConfig("Vector query required for vector search".into())
})?;
let search_k = if params.filter.is_some() {
params.k * 10 } else {
params.k
};
let results = index.search(query, search_k, storage, distance)?;
let filtered: Vec<_> = results
.into_iter()
.filter(|r| {
if let Some(filter) = ¶ms.filter {
FilterEvaluator::evaluate(filter, r.metadata.as_ref())
} else {
true
}
})
.take(params.k)
.enumerate()
.map(|(rank, r)| HybridSearchResult {
id: r.id,
score: r.distance, vector_distance: Some(r.distance),
bm25_score: None,
vector_rank: Some(rank),
keyword_rank: None,
metadata: r.metadata,
})
.collect();
Ok(filtered)
}
fn keyword_search(
params: &HybridSearchParams,
bm25_index: Option<&BM25Index>,
storage: &dyn Storage,
) -> Result<Vec<HybridSearchResult>> {
let query = params
.text_query
.as_ref()
.ok_or_else(|| Error::InvalidConfig("Text query required for keyword search".into()))?;
let index = bm25_index
.ok_or_else(|| Error::InvalidConfig("BM25 index required for keyword search".into()))?;
let search_k = if params.filter.is_some() {
params.k * 10
} else {
params.k
};
let results = index.search(query, search_k);
let mut hybrid_results = Vec::new();
for (rank, result) in results.into_iter().enumerate() {
if let Ok(Some(doc)) = storage.get(&result.id) {
if let Some(filter) = ¶ms.filter {
if !FilterEvaluator::evaluate(filter, doc.metadata.as_ref()) {
continue;
}
}
hybrid_results.push(HybridSearchResult {
id: result.id,
score: -result.score, vector_distance: None,
bm25_score: Some(result.score),
vector_rank: None,
keyword_rank: Some(rank),
metadata: doc.metadata,
});
if hybrid_results.len() >= params.k {
break;
}
}
}
Ok(hybrid_results)
}
fn hybrid_search(
params: &HybridSearchParams,
vector_index: &dyn Index,
bm25_index: Option<&BM25Index>,
storage: &dyn Storage,
distance: Distance,
vector_weight: f32,
keyword_weight: f32,
) -> Result<Vec<HybridSearchResult>> {
let fetch_k = params.k * 3;
let vector_results = if let Some(query) = ¶ms.vector {
let results = vector_index.search(query, fetch_k, storage, distance)?;
results
.into_iter()
.enumerate()
.map(|(rank, r)| RankedResult {
id: r.id,
rank,
original_score: r.distance,
})
.collect()
} else {
Vec::new()
};
let keyword_results = if let (Some(query), Some(index)) = (¶ms.text_query, bm25_index) {
index
.search(query, fetch_k)
.into_iter()
.enumerate()
.map(|(rank, result)| RankedResult {
id: result.id,
rank,
original_score: result.score,
})
.collect()
} else {
Vec::new()
};
let vector_info: std::collections::HashMap<_, _> = vector_results
.iter()
.map(|r| (r.id.clone(), (r.rank, r.original_score)))
.collect();
let keyword_info: std::collections::HashMap<_, _> = keyword_results
.iter()
.map(|r| (r.id.clone(), (r.rank, r.original_score)))
.collect();
let rrf_results = weighted_reciprocal_rank_fusion(
vec![
(vector_results, vector_weight),
(keyword_results, keyword_weight),
],
DEFAULT_RRF_K,
);
let mut final_results = Vec::new();
for (id, rrf_score) in rrf_results {
if let Ok(Some(doc)) = storage.get(&id) {
if let Some(filter) = ¶ms.filter {
if !FilterEvaluator::evaluate(filter, doc.metadata.as_ref()) {
continue;
}
}
let (vec_rank, vec_dist) = vector_info
.get(&id)
.map(|(r, d)| (Some(*r), Some(*d)))
.unwrap_or((None, None));
let (kw_rank, kw_score) = keyword_info
.get(&id)
.map(|(r, s)| (Some(*r), Some(*s)))
.unwrap_or((None, None));
final_results.push(HybridSearchResult {
id,
score: -rrf_score, vector_distance: vec_dist,
bm25_score: kw_score,
vector_rank: vec_rank,
keyword_rank: kw_rank,
metadata: doc.metadata,
});
if final_results.len() >= params.k {
break;
}
}
}
Ok(final_results)
}
fn filter_only_search(
params: &HybridSearchParams,
storage: &dyn Storage,
) -> Result<Vec<HybridSearchResult>> {
let filter = params
.filter
.as_ref()
.ok_or_else(|| Error::InvalidConfig("Filter required for filter-only search".into()))?;
let need_all = params.order_by.is_some() || params.offset > 0;
let take_limit = if need_all {
100_000
} else {
params.k
};
let results: Vec<_> = storage
.iter()
.filter(|doc| FilterEvaluator::evaluate(filter, doc.metadata.as_ref()))
.take(take_limit)
.map(|doc| HybridSearchResult {
id: doc.id,
score: 0.0, vector_distance: None,
bm25_score: None,
vector_rank: None,
keyword_rank: None,
metadata: doc.metadata,
})
.collect();
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::FlatIndex;
    use crate::storage::MemoryStorage;
    use crate::types::Metadata;
    use std::sync::Arc;

    /// Fixture rows: (id, embedding, title, content, category), inserted in this order.
    const DOCS: [(&str, [f32; 3], &str, &str, &str); 3] = [
        (
            "doc-1",
            [1.0, 0.0, 0.0],
            "Rust Programming",
            "Learn Rust systems programming",
            "tech",
        ),
        (
            "doc-2",
            [0.0, 1.0, 0.0],
            "Python Guide",
            "Python for beginners programming",
            "tech",
        ),
        (
            "doc-3",
            [0.0, 0.0, 1.0],
            "Cooking Recipes",
            "Delicious food recipes",
            "food",
        ),
    ];

    /// Builds storage, a flat vector index and a BM25 index over `title`/`content`,
    /// each populated with every row of `DOCS`.
    fn setup_test_data() -> (Arc<MemoryStorage>, Arc<FlatIndex>, Arc<BM25Index>) {
        let storage = Arc::new(MemoryStorage::new());
        let vector_index = Arc::new(FlatIndex::new());
        let bm25_index = Arc::new(BM25Index::new(vec!["title".into(), "content".into()]));
        for (id, embedding, title, content, category) in DOCS {
            let mut meta = Metadata::new();
            meta.insert("title", title);
            meta.insert("content", content);
            meta.insert("category", category);
            storage
                .insert(id.into(), Some(embedding.to_vec()), Some(meta.clone()))
                .unwrap();
            vector_index
                .add(id, &embedding, &*storage, Distance::Cosine)
                .unwrap();
            bm25_index.add(id, Some(&meta)).unwrap();
        }
        (storage, vector_index, bm25_index)
    }

    /// Runs `params` against the fixture with Euclidean distance, panicking on error.
    fn run(
        params: &HybridSearchParams,
        vector_index: &dyn Index,
        bm25_index: Option<&BM25Index>,
        storage: &MemoryStorage,
    ) -> Vec<HybridSearchResult> {
        HybridSearch::search(params, vector_index, bm25_index, storage, Distance::Euclidean)
            .unwrap()
    }

    #[test]
    fn test_vector_search() {
        let (storage, vector_index, _) = setup_test_data();
        let params = HybridSearchParams::vector(vec![1.0, 0.0, 0.0], 2);
        let results = run(&params, vector_index.as_ref(), None, storage.as_ref());
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc-1");
    }

    #[test]
    fn test_keyword_search() {
        let (storage, _, bm25_index) = setup_test_data();
        let empty_vector_index = FlatIndex::new();
        let params = HybridSearchParams::keyword("rust programming", 2);
        let results = run(
            &params,
            &empty_vector_index,
            Some(bm25_index.as_ref()),
            storage.as_ref(),
        );
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "doc-1");
    }

    #[test]
    fn test_hybrid_search() {
        let (storage, vector_index, bm25_index) = setup_test_data();
        let params = HybridSearchParams::hybrid(vec![0.0, 1.0, 0.0], "rust", 3);
        let results = run(
            &params,
            vector_index.as_ref(),
            Some(bm25_index.as_ref()),
            storage.as_ref(),
        );
        assert!(!results.is_empty());
    }

    #[test]
    fn test_filter_search() {
        let (storage, vector_index, _) = setup_test_data();
        let params = HybridSearchParams::vector(vec![0.5, 0.5, 0.0], 10)
            .with_filter(Filter::eq("category", "tech"));
        let results = run(&params, vector_index.as_ref(), None, storage.as_ref());
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id == "doc-1" || r.id == "doc-2"));
    }

    #[test]
    fn test_filter_only_search() {
        let (storage, vector_index, _) = setup_test_data();
        let params = HybridSearchParams::filter_only(Filter::eq("category", "food"), 10);
        let results = run(&params, vector_index.as_ref(), None, storage.as_ref());
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc-3");
    }
}