use super::bm25::{BM25Index, BM25Config};
use super::fusion::{FusionMethod, reciprocal_rank_fusion, weighted_sum_fusion, max_score_fusion};
use crate::embeddings::EmbeddingProvider;
use crate::vector_store::{Filter, SearchResult, VectorStore};
use anyhow::{Context, Result};
use std::sync::Arc;
/// Tuning parameters for [`HybridRetriever`]: how many candidates to fetch from
/// each retriever and how to fuse the dense and sparse rankings.
#[derive(Debug, Clone)]
pub struct HybridConfig {
/// Weight for the dense (embedding) ranking. Only read by
/// [`FusionMethod::WeightedSum`]; reciprocal-rank and max-score fusion ignore it.
pub dense_weight: f32,
/// Weight for the sparse (BM25) ranking. Only read by
/// [`FusionMethod::WeightedSum`]; reciprocal-rank and max-score fusion ignore it.
pub sparse_weight: f32,
/// Constant passed to `reciprocal_rank_fusion`; only used with
/// [`FusionMethod::ReciprocalRank`].
pub rrf_k: f32,
/// Strategy used to merge the dense and sparse rankings.
pub fusion_method: FusionMethod,
/// Over-fetch factor: each retriever is asked for roughly
/// `top_k * retrieval_multiplier` candidates before fusion trims the list to `top_k`.
pub retrieval_multiplier: usize,
/// Configuration used by [`HybridRetriever::new`] to build the BM25 index.
pub bm25_config: BM25Config,
}
impl Default for HybridConfig {
    /// Baseline configuration.
    ///
    /// - Fusion is reciprocal-rank with `rrf_k = 60.0`.
    /// - Each retriever fetches 3x `top_k` candidates.
    /// - The BM25 index is in memory.
    /// - The 0.7 / 0.3 dense/sparse weights are set, but only the
    ///   [`FusionMethod::WeightedSum`] path reads them, so they have no
    ///   effect until the fusion method is switched.
    fn default() -> Self {
        let bm25_config = BM25Config::in_memory();

        Self {
            fusion_method: FusionMethod::ReciprocalRank,
            rrf_k: 60.0,
            dense_weight: 0.7,
            sparse_weight: 0.3,
            retrieval_multiplier: 3,
            bm25_config,
        }
    }
}
impl HybridConfig {
pub fn with_weights(dense_weight: f32, sparse_weight: f32) -> Self {
Self {
dense_weight,
sparse_weight,
..Default::default()
}
}
pub fn with_fusion(mut self, method: FusionMethod) -> Self {
self.fusion_method = method;
self
}
pub fn with_rrf_k(mut self, k: f32) -> Self {
self.rrf_k = k;
self
}
}
/// One fused hit returned by [`HybridRetriever::search`].
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
/// Document identifier; the key on which the dense and sparse rankings are fused.
pub id: String,
/// Fused score produced by the configured [`FusionMethod`].
pub score: f32,
/// Score recorded by the fusion step under the `"dense"` source, if any.
/// NOTE(review): whether this is the raw similarity or a fusion-specific
/// contribution depends on the fusion function; presumably `None` means the
/// document was not in the dense results.
pub dense_score: Option<f32>,
/// Score recorded by the fusion step under the `"sparse"` source, if any
/// (same caveat as `dense_score`).
pub sparse_score: Option<f32>,
/// Metadata taken from the dense store's hit for this id. `None` when the
/// document was not among the dense results (e.g. a BM25-only match).
pub metadata: Option<crate::vector_store::DocumentMetadata>,
}
/// Retriever that combines dense vector search with sparse BM25 keyword search
/// and fuses the two rankings according to a [`HybridConfig`].
pub struct HybridRetriever<V: VectorStore, E: EmbeddingProvider> {
// Vector store queried with the embedded query (shared via `Arc`).
dense_store: Arc<V>,
// BM25 index owned by this retriever; populated via `add_to_sparse_index`.
sparse_index: BM25Index,
// Produces the query embedding for the dense search (shared via `Arc`).
embedding_provider: Arc<E>,
// Fusion and over-fetch settings.
config: HybridConfig,
}
impl<V: VectorStore, E: EmbeddingProvider> HybridRetriever<V, E> {
    /// Creates a retriever with a new BM25 index built from `config.bm25_config`.
    ///
    /// # Errors
    ///
    /// Returns an error (context: "Failed to create BM25 index") if
    /// [`BM25Index::new`] fails.
    pub fn new(
        dense_store: Arc<V>,
        embedding_provider: Arc<E>,
        config: HybridConfig,
    ) -> Result<Self> {
        let sparse_index = BM25Index::new(config.bm25_config.clone())
            .context("Failed to create BM25 index")?;
        Ok(Self::with_index(
            dense_store,
            sparse_index,
            embedding_provider,
            config,
        ))
    }

    /// Creates a retriever around an already-built BM25 index.
    ///
    /// `config.bm25_config` is kept in the stored config. It is not used to
    /// rebuild or reconfigure `sparse_index`.
    pub fn with_index(
        dense_store: Arc<V>,
        sparse_index: BM25Index,
        embedding_provider: Arc<E>,
        config: HybridConfig,
    ) -> Self {
        Self {
            dense_store,
            sparse_index,
            embedding_provider,
            config,
        }
    }

    /// Adds a document to the sparse (BM25) index only; the dense store is
    /// not touched. See [`Self::commit_sparse_index`].
    pub fn add_to_sparse_index(
        &mut self,
        id: &str,
        tool_name: &str,
        skill_name: &str,
        description: &str,
        full_text: &str,
    ) -> Result<()> {
        self.sparse_index
            .add_document(id, tool_name, skill_name, description, full_text)
    }

    /// Calls [`BM25Index::commit`] on the sparse index.
    pub fn commit_sparse_index(&mut self) -> Result<()> {
        self.sparse_index.commit()
    }

    /// Calls [`BM25Index::clear`] on the sparse index; the dense store is not modified.
    pub fn clear_sparse_index(&mut self) -> Result<()> {
        self.sparse_index.clear()
    }

    /// Runs dense and sparse retrieval for `query` and fuses the two rankings.
    ///
    /// Each retriever is asked for `top_k * config.retrieval_multiplier`
    /// candidates. The product saturates instead of overflowing and is never
    /// fewer than `top_k`, so the fusion step has a deeper pool to re-rank. At
    /// most `top_k` fused results are returned. When `top_k` is zero, an empty
    /// list is returned without embedding or searching.
    ///
    /// `filter` is applied to the dense search only; the BM25 search is
    /// unfiltered. Metadata comes from the dense hits, so a document found
    /// only by BM25 is returned with `metadata: None`.
    ///
    /// # Errors
    ///
    /// Fails if embedding the query, the dense search, or the sparse search fails.
    pub async fn search(
        &self,
        query: &str,
        filter: Option<Filter>,
        top_k: usize,
    ) -> Result<Vec<HybridSearchResult>> {
        if top_k == 0 {
            return Ok(Vec::new());
        }

        // A plain `top_k * multiplier` panics (debug) or wraps (release) for
        // huge `top_k`. A multiplier of 0 would also request zero candidates
        // and make every search come back empty.
        let candidate_k = top_k
            .saturating_mul(self.config.retrieval_multiplier)
            .max(top_k);

        let query_embedding = self
            .embedding_provider
            .embed_query(query)
            .await
            .context("Failed to embed query")?;
        let dense_results = self
            .dense_store
            .search(query_embedding, filter, candidate_k)
            .await
            .context("Dense search failed")?;
        // NOTE(review): `BM25Index::search` is synchronous and runs on the
        // executor thread here; if it is slow on large indexes, consider a
        // blocking task.
        let sparse_results = self
            .sparse_index
            .search(query, candidate_k)
            .context("Sparse search failed")?;

        // One pass over the dense hits. It yields the (id, score) ranking the
        // fusion functions take and an id -> metadata lookup used to decorate
        // the fused results afterwards.
        let (dense_ranked, dense_metadata): (
            Vec<(String, f32)>,
            std::collections::HashMap<String, _>,
        ) = dense_results
            .into_iter()
            .map(|r| ((r.id.clone(), r.score), (r.id, r.metadata)))
            .unzip();
        let sparse_ranked: Vec<(String, f32)> = sparse_results
            .iter()
            .map(|r| (r.id.clone(), r.score))
            .collect();

        let fused_results = match self.config.fusion_method {
            FusionMethod::ReciprocalRank => reciprocal_rank_fusion(
                vec![("dense", dense_ranked), ("sparse", sparse_ranked)],
                self.config.rrf_k,
                top_k,
            ),
            FusionMethod::WeightedSum => weighted_sum_fusion(
                vec![
                    ("dense", self.config.dense_weight, dense_ranked),
                    ("sparse", self.config.sparse_weight, sparse_ranked),
                ],
                top_k,
            ),
            FusionMethod::MaxScore => max_score_fusion(
                vec![("dense", dense_ranked), ("sparse", sparse_ranked)],
                top_k,
            ),
        };

        Ok(fused_results
            .into_iter()
            .map(|fused| HybridSearchResult {
                id: fused.id.clone(),
                score: fused.score,
                dense_score: fused.source_scores.get("dense").copied(),
                sparse_score: fused.source_scores.get("sparse").copied(),
                metadata: dense_metadata.get(&fused.id).cloned(),
            })
            .collect())
    }

    /// Runs only the dense (embedding) search, bypassing BM25 and fusion.
    ///
    /// # Errors
    ///
    /// Fails if embedding the query or the dense search fails.
    pub async fn dense_only(
        &self,
        query: &str,
        filter: Option<Filter>,
        top_k: usize,
    ) -> Result<Vec<SearchResult>> {
        let query_embedding = self
            .embedding_provider
            .embed_query(query)
            .await
            .context("Failed to embed query")?;
        self.dense_store
            .search(query_embedding, filter, top_k)
            .await
            .context("Dense search failed")
    }

    /// Runs only the sparse (BM25) search, bypassing embeddings and fusion.
    ///
    /// # Errors
    ///
    /// Fails if the BM25 search fails.
    pub fn sparse_only(
        &self,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<super::bm25::BM25SearchResult>> {
        self.sparse_index
            .search(query, top_k)
            .context("Sparse search failed")
    }

    /// Returns the configuration this retriever was built with.
    pub fn config(&self) -> &HybridConfig {
        &self.config
    }

    /// Number of documents reported by the sparse (BM25) index.
    pub fn sparse_document_count(&self) -> u64 {
        self.sparse_index.document_count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two `f32`s agree within 0.001, reporting both values on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 0.001,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_config_default() {
        let config = HybridConfig::default();
        assert_close(config.dense_weight, 0.7);
        assert_close(config.sparse_weight, 0.3);
        assert_close(config.rrf_k, 60.0);
        assert_eq!(config.retrieval_multiplier, 3);
        assert_eq!(config.fusion_method, FusionMethod::ReciprocalRank);
    }

    #[test]
    fn test_config_builder() {
        let config = HybridConfig::with_weights(0.5, 0.5)
            .with_fusion(FusionMethod::WeightedSum)
            .with_rrf_k(30.0);
        assert_close(config.dense_weight, 0.5);
        assert_close(config.sparse_weight, 0.5);
        assert_eq!(config.fusion_method, FusionMethod::WeightedSum);
        assert_close(config.rrf_k, 30.0);
    }

    #[test]
    fn test_with_weights_keeps_other_defaults() {
        // `with_weights` alone leaves RRF fusion in place, so the weights are
        // not consulted until the fusion method is changed.
        let config = HybridConfig::with_weights(0.2, 0.8);
        assert_close(config.dense_weight, 0.2);
        assert_close(config.sparse_weight, 0.8);
        assert_close(config.rrf_k, 60.0);
        assert_eq!(config.retrieval_multiplier, 3);
        assert_eq!(config.fusion_method, FusionMethod::ReciprocalRank);
    }
}