use crate::bm25::{
adapter::{Bm25Adapter, SearchOptions as Bm25SearchOptions},
result::Bm25Result,
};
use crate::common::namespace::Namespace;
use crate::hybrid::{
config::{FusionMethod, HybridConfig},
confidence::apply_confidence,
error::{Error, Result},
fusion::{FusedScore, FusionStrategy, ScoredItem},
fusion::{linear::Linear, rrf::Rrf},
result::HybridResult,
};
use crate::vector::{
adapter::{SearchOptions as VectorSearchOptions, VectorAdapter},
result::VectorResult,
};
#[derive(Debug, Clone, Default)]
pub struct HybridSearchOptions {
pub bm25_candidates: Option<usize>,
pub vector_candidates: Option<usize>,
pub bm25_weight: Option<f32>,
pub vector_weight: Option<f32>,
pub fusion: Option<FusionMethod>,
pub limit: usize,
pub compute_confidence: bool,
}
impl HybridSearchOptions {
pub fn new() -> Self {
Self { limit: 20, compute_confidence: true, ..Default::default() }
}
pub fn with_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn without_confidence(mut self) -> Self {
self.compute_confidence = false;
self
}
}
pub struct HybridSearch<B, V>
where
B: Bm25Adapter,
V: VectorAdapter,
{
bm25: B,
vector: V,
config: HybridConfig,
}
impl<B: Bm25Adapter, V: VectorAdapter> HybridSearch<B, V> {
pub fn new(bm25: B, vector: V, config: HybridConfig) -> Self {
Self { bm25, vector, config }
}
pub async fn search(
&self,
query: &str,
query_vector: &[f32],
namespace: Option<&Namespace>,
options: HybridSearchOptions,
) -> Result<Vec<HybridResult>> {
let bm25_limit = options.bm25_candidates.unwrap_or(self.config.bm25_candidates);
let vec_limit = options.vector_candidates.unwrap_or(self.config.vector_candidates);
let bm25_weight = options.bm25_weight.unwrap_or(self.config.bm25_weight);
let vector_weight = options.vector_weight.unwrap_or(self.config.vector_weight);
let fusion_method = options.fusion.unwrap_or(self.config.fusion);
let bm25_results = self.fetch_bm25(query, namespace, bm25_limit).await?;
let vector_results = self.fetch_vector(query_vector, namespace, vec_limit).await?;
let bm25_items: Vec<ScoredItem> = bm25_results
.iter()
.map(|r| ScoredItem::new(r.id.clone(), r.score))
.collect();
let vector_items: Vec<ScoredItem> = vector_results
.iter()
.map(|r| ScoredItem::new(r.id.clone(), r.score()))
.collect();
let fused: std::collections::HashMap<String, FusedScore> = match fusion_method {
FusionMethod::Rrf => {
Rrf::new(bm25_weight, vector_weight, self.config.rrf_k)
.fuse(&bm25_items, &vector_items)
}
FusionMethod::Linear => {
Linear::new(bm25_weight, vector_weight, self.config.normalisation)
.fuse(&bm25_items, &vector_items)
}
};
let mut results: Vec<HybridResult> = fused
.into_iter()
.map(|(id, score)| {
HybridResult::new(id, score.hybrid, score.bm25, score.vector, fusion_method_label(fusion_method))
})
.collect();
results.sort_by(|a, b| {
b.hybrid_score
.partial_cmp(&a.hybrid_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(options.limit);
if options.compute_confidence {
apply_confidence(&mut results);
}
Ok(results)
}
async fn fetch_bm25(
&self,
query: &str,
namespace: Option<&Namespace>,
limit: usize,
) -> Result<Vec<Bm25Result>> {
self.bm25
.search(query, namespace, Bm25SearchOptions::default().with_limit(limit))
.await
.map_err(Error::from)
}
async fn fetch_vector(
&self,
query_vector: &[f32],
namespace: Option<&Namespace>,
limit: usize,
) -> Result<Vec<VectorResult>> {
self.vector
.nearest_neighbors(
query_vector,
namespace,
VectorSearchOptions::default().with_limit(limit),
)
.await
.map_err(Error::from)
}
}
fn fusion_method_label(method: FusionMethod) -> &'static str {
match method {
FusionMethod::Rrf => "rrf",
FusionMethod::Linear => "linear",
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bm25::adapters::MemoryBm25Adapter;
use crate::bm25::adapter::IndexDocument;
use crate::bm25::config::Bm25Config;
use crate::vector::adapters::MemoryVectorAdapter;
use crate::vector::VectorConfig;
use crate::hybrid::config::HybridConfig;
async fn setup() -> HybridSearch<MemoryBm25Adapter, MemoryVectorAdapter> {
let bm25 = MemoryBm25Adapter::connect(Bm25Config::default()).await.unwrap();
let vector = MemoryVectorAdapter::connect(VectorConfig::with_dimension(2)).await.unwrap();
HybridSearch::new(bm25, vector, HybridConfig::default())
}
async fn setup_with_data() -> HybridSearch<MemoryBm25Adapter, MemoryVectorAdapter> {
let bm25 = MemoryBm25Adapter::connect(Bm25Config::default()).await.unwrap();
let vector = MemoryVectorAdapter::connect(VectorConfig::with_dimension(2)).await.unwrap();
bm25.index(IndexDocument::new("rust", "rust systems programming language"), None).await.unwrap();
bm25.index(IndexDocument::new("python", "python scripting easy language"), None).await.unwrap();
bm25.index(IndexDocument::new("go", "go concurrency goroutines"), None).await.unwrap();
vector.upsert("rust", vec![1.0, 0.0], None, None).await.unwrap();
vector.upsert("python", vec![0.7, 0.7], None, None).await.unwrap();
vector.upsert("go", vec![0.0, 1.0], None, None).await.unwrap();
HybridSearch::new(bm25, vector, HybridConfig::default())
}
fn opts() -> HybridSearchOptions {
HybridSearchOptions::new()
}
#[tokio::test]
async fn search_empty_index_returns_no_results() {
let hs = setup().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn search_returns_results_when_data_present() {
let hs = setup_with_data().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
assert!(!results.is_empty());
}
#[tokio::test]
async fn search_result_for_rust_query_favours_rust_document() {
let hs = setup_with_data().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
assert_eq!(results[0].id, "rust");
}
#[tokio::test]
async fn results_are_sorted_descending_by_hybrid_score() {
let hs = setup_with_data().await;
let results = hs.search("language", &[0.8, 0.2], None, opts()).await.unwrap();
for window in results.windows(2) {
assert!(window[0].hybrid_score >= window[1].hybrid_score);
}
}
#[tokio::test]
async fn confidence_is_set_when_requested() {
let hs = setup_with_data().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
assert!(results.iter().all(|r| r.confidence_score.is_some()));
}
#[tokio::test]
async fn confidence_not_set_when_disabled() {
let hs = setup_with_data().await;
let results = hs
.search("rust", &[1.0, 0.0], None, opts().without_confidence())
.await
.unwrap();
assert!(results.iter().all(|r| r.confidence_score.is_none()));
}
#[tokio::test]
async fn limit_is_respected() {
let hs = setup_with_data().await;
let results = hs.search("language", &[0.7, 0.3], None, opts().with_limit(1)).await.unwrap();
assert_eq!(results.len(), 1);
}
#[tokio::test]
async fn result_in_both_sources_has_both_scores() {
let hs = setup_with_data().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
let rust = results.iter().find(|r| r.id == "rust");
if let Some(r) = rust {
assert!(r.bm25_score.is_some());
assert!(r.vector_score.is_some());
}
}
#[tokio::test]
async fn linear_fusion_returns_results() {
let hs = setup_with_data().await;
let mut o = opts();
o.fusion = Some(FusionMethod::Linear);
let results = hs.search("rust", &[1.0, 0.0], None, o).await.unwrap();
assert!(!results.is_empty());
}
#[tokio::test]
async fn results_have_positive_hybrid_score() {
let hs = setup_with_data().await;
let results = hs.search("rust", &[1.0, 0.0], None, opts()).await.unwrap();
assert!(results.iter().all(|r| r.hybrid_score >= 0.0));
}
}