use crate::hnsw::{GraphError, HnswIndex, VectorId};
use crate::sparse::{SparseSearcher, SparseStorage, SparseVector};
use crate::storage::VectorStorage;
use thiserror::Error;
use super::fusion::{linear_fusion, rrf_fusion, FusionMethod, FusionResult};
/// Errors returned by the hybrid (dense + sparse) search API in this module.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum HybridError {
/// The search configuration was rejected; the payload is the human-readable reason.
#[error("Invalid config: {0}")]
InvalidConfig(String),
/// The dense index search failed for a reason other than a dimension mismatch;
/// the payload is the display text of the underlying `GraphError`.
#[error("Dense search error: {0}")]
DenseSearchError(String),
/// The dense index reported a dimension mismatch (`GraphError::DimensionMismatch`).
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
/// Dimension reported as expected by the underlying error.
expected: usize,
/// Dimension reported as actually received by the underlying error.
actual: usize,
},
}
impl From<GraphError> for HybridError {
fn from(err: GraphError) -> Self {
match err {
GraphError::DimensionMismatch { expected, actual } => {
HybridError::DimensionMismatch { expected, actual }
}
other => HybridError::DenseSearchError(other.to_string()),
}
}
}
/// Parameters controlling a hybrid search: how many candidates to request from
/// each retriever and how the two candidate lists are fused.
#[derive(Clone, Debug)]
pub struct HybridSearchConfig {
/// Number of results requested from the dense index search; `0` skips the dense search.
pub dense_k: usize,
/// Number of results requested from the sparse search; `0` skips the sparse search.
pub sparse_k: usize,
/// Value passed to the fusion function as its final argument; must be > 0 to pass
/// `validate`. Presumably the maximum number of fused results — confirm against the
/// fusion module.
pub final_k: usize,
/// Fusion strategy used to combine the dense and sparse result lists.
pub fusion: FusionMethod,
}
impl Default for HybridSearchConfig {
    /// Returns the default configuration: `dense_k = 20`, `sparse_k = 20`,
    /// `final_k = 10`, combined with `FusionMethod::default()`.
    fn default() -> Self {
        Self::new(20, 20, 10, FusionMethod::default())
    }
}
impl HybridSearchConfig {
    /// Creates a configuration from explicit values.
    ///
    /// No validation happens here; call [`HybridSearchConfig::validate`] to check it.
    #[must_use]
    pub fn new(dense_k: usize, sparse_k: usize, final_k: usize, fusion: FusionMethod) -> Self {
        Self {
            dense_k,
            sparse_k,
            final_k,
            fusion,
        }
    }

    /// Creates a configuration that fuses with `FusionMethod::rrf()`.
    #[must_use]
    pub fn rrf(dense_k: usize, sparse_k: usize, final_k: usize) -> Self {
        Self {
            dense_k,
            sparse_k,
            final_k,
            fusion: FusionMethod::rrf(),
        }
    }

    /// Creates a configuration that fuses with `FusionMethod::rrf_with_k(rrf_k)`.
    #[must_use]
    pub fn rrf_with_k(dense_k: usize, sparse_k: usize, final_k: usize, rrf_k: u32) -> Self {
        Self {
            dense_k,
            sparse_k,
            final_k,
            fusion: FusionMethod::rrf_with_k(rrf_k),
        }
    }

    /// Creates a configuration that fuses with `FusionMethod::linear(alpha)`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `FusionMethod::linear` when it rejects `alpha`.
    pub fn linear(
        dense_k: usize,
        sparse_k: usize,
        final_k: usize,
        alpha: f32,
    ) -> Result<Self, String> {
        // Build the fusion method first so a rejected `alpha` short-circuits.
        let fusion = FusionMethod::linear(alpha)?;
        Ok(Self {
            dense_k,
            sparse_k,
            final_k,
            fusion,
        })
    }

    /// Checks that the configuration can produce results.
    ///
    /// # Errors
    ///
    /// Returns a message when both `dense_k` and `sparse_k` are zero, or
    /// (checked second) when `final_k` is zero.
    pub fn validate(&self) -> Result<(), String> {
        // Arm order matters: the "both retrievers disabled" check wins over `final_k == 0`.
        match (self.dense_k, self.sparse_k, self.final_k) {
            (0, 0, _) => Err("At least one of dense_k or sparse_k must be > 0".to_string()),
            (_, _, 0) => Err("final_k must be > 0".to_string()),
            _ => Ok(()),
        }
    }
}
/// One fused hit returned by a hybrid search, together with the per-retriever
/// rank and score it was built from.
#[derive(Clone, Debug, PartialEq)]
pub struct HybridSearchResult {
/// Identifier of the matched vector.
pub id: VectorId,
/// Fused score copied from the fusion result.
pub score: f32,
/// Dense rank copied from the fusion result, if any.
pub dense_rank: Option<usize>,
/// Value recorded for this id in the dense score map, if present. When produced by
/// `HybridSearcher::search`, this is the `distance` field of the dense index result.
pub dense_score: Option<f32>,
/// Sparse rank copied from the fusion result, if any.
pub sparse_rank: Option<usize>,
/// Value recorded for this id in the sparse score map, if present.
pub sparse_score: Option<f32>,
}
impl HybridSearchResult {
    /// Builds a result from one fusion entry plus the per-retriever score maps.
    ///
    /// The id, fused score and ranks are copied from `fusion`; `dense_score`
    /// and `sparse_score` are looked up by `fusion.id` and are `None` when the
    /// id is absent from the corresponding map.
    fn from_fusion(
        fusion: &FusionResult,
        dense_scores: &std::collections::HashMap<u64, f32>,
        sparse_scores: &std::collections::HashMap<u64, f32>,
    ) -> Self {
        let key = fusion.id;
        let dense_score = dense_scores.get(&key).copied();
        let sparse_score = sparse_scores.get(&key).copied();

        Self {
            id: VectorId(key),
            score: fusion.score,
            dense_rank: fusion.dense_rank,
            dense_score,
            sparse_rank: fusion.sparse_rank,
            sparse_score,
        }
    }
}
/// Runs hybrid searches over borrowed dense and sparse data.
///
/// The searcher owns nothing: it holds shared references to the dense index,
/// the dense vector storage, and the sparse storage for the lifetime `'a`.
pub struct HybridSearcher<'a> {
// Dense index queried for the dense half of a search.
index: &'a HnswIndex,
// Vector storage handed to the dense index on every search.
dense_storage: &'a VectorStorage,
// Storage the sparse searcher is constructed over.
sparse_storage: &'a SparseStorage,
}
impl<'a> HybridSearcher<'a> {
#[must_use]
pub fn new(
index: &'a HnswIndex,
dense_storage: &'a VectorStorage,
sparse_storage: &'a SparseStorage,
) -> Self {
#[cfg(debug_assertions)]
if dense_storage.len() != sparse_storage.len() {
eprintln!(
"[HybridSearcher] Warning: dense_storage.len()={} != sparse_storage.len()={}. \
This may indicate ID misalignment.",
dense_storage.len(),
sparse_storage.len()
);
}
Self {
index,
dense_storage,
sparse_storage,
}
}
pub fn search(
&self,
dense_query: &[f32],
sparse_query: &SparseVector,
config: &HybridSearchConfig,
) -> Result<Vec<HybridSearchResult>, HybridError> {
config.validate().map_err(HybridError::InvalidConfig)?;
let dense_results = if config.dense_k > 0 {
self.index
.search(dense_query, config.dense_k, self.dense_storage)?
.into_iter()
.map(|r| (r.vector_id.0, r.distance))
.collect::<Vec<_>>()
} else {
Vec::new()
};
let sparse_searcher = SparseSearcher::new(self.sparse_storage);
let sparse_results = if config.sparse_k > 0 {
sparse_searcher.search_u64(sparse_query, config.sparse_k)
} else {
Vec::new()
};
let fused = match &config.fusion {
FusionMethod::Rrf { k } => {
rrf_fusion(&dense_results, &sparse_results, *k, config.final_k)
}
FusionMethod::Linear { alpha } => {
linear_fusion(&dense_results, &sparse_results, *alpha, config.final_k)
}
};
let dense_scores: std::collections::HashMap<u64, f32> =
dense_results.iter().copied().collect();
let sparse_scores: std::collections::HashMap<u64, f32> =
sparse_results.iter().copied().collect();
let results = fused
.iter()
.map(|f| HybridSearchResult::from_fusion(f, &dense_scores, &sparse_scores))
.collect();
Ok(results)
}
pub fn search_dense_only(
&self,
dense_query: &[f32],
k: usize,
) -> Result<Vec<HybridSearchResult>, HybridError> {
let config = HybridSearchConfig {
dense_k: k,
sparse_k: 0,
final_k: k,
fusion: FusionMethod::rrf(), };
let sparse_query = SparseVector::singleton(0, 0.0, 1)
.map_err(|e| HybridError::InvalidConfig(e.to_string()))?;
self.search(dense_query, &sparse_query, &config)
}
pub fn search_sparse_only(
&self,
sparse_query: &SparseVector,
k: usize,
) -> Result<Vec<HybridSearchResult>, HybridError> {
let config = HybridSearchConfig {
dense_k: 0,
sparse_k: k,
final_k: k,
fusion: FusionMethod::rrf(), };
let dense_query = vec![0.0; self.dense_storage.dimensions() as usize];
self.search(&dense_query, sparse_query, &config)
}
#[must_use]
pub fn components(&self) -> (&HnswIndex, &VectorStorage, &SparseStorage) {
(self.index, self.dense_storage, self.sparse_storage)
}
#[must_use]
pub fn dense_count(&self) -> usize {
self.dense_storage.len()
}
#[must_use]
pub fn sparse_count(&self) -> usize {
self.sparse_storage.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an RRF config as a plain struct literal, bypassing the constructors.
    fn rrf_literal(dense_k: usize, sparse_k: usize, final_k: usize) -> HybridSearchConfig {
        HybridSearchConfig {
            dense_k,
            sparse_k,
            final_k,
            fusion: FusionMethod::rrf(),
        }
    }

    // The default config uses 20 / 20 / 10.
    #[test]
    fn test_config_default() {
        let HybridSearchConfig {
            dense_k,
            sparse_k,
            final_k,
            ..
        } = HybridSearchConfig::default();
        assert_eq!((dense_k, sparse_k, final_k), (20, 20, 10));
    }

    // `rrf` keeps the given counts and uses k = 60.
    #[test]
    fn test_config_rrf() {
        let cfg = HybridSearchConfig::rrf(30, 40, 15);
        assert_eq!((cfg.dense_k, cfg.sparse_k, cfg.final_k), (30, 40, 15));
        assert!(matches!(cfg.fusion, FusionMethod::Rrf { k: 60 }));
    }

    // `rrf_with_k` forwards the custom k.
    #[test]
    fn test_config_rrf_with_k() {
        let cfg = HybridSearchConfig::rrf_with_k(20, 20, 10, 100);
        assert!(matches!(cfg.fusion, FusionMethod::Rrf { k: 100 }));
    }

    // `linear` stores the given alpha.
    #[test]
    fn test_config_linear() {
        let cfg = HybridSearchConfig::linear(20, 20, 10, 0.7).unwrap();
        if let FusionMethod::Linear { alpha } = cfg.fusion {
            assert!((alpha - 0.7).abs() < 1e-6);
        } else {
            panic!("Expected Linear fusion");
        }
    }

    // Disabling both retrievers is rejected.
    #[test]
    fn test_config_validation_both_zero() {
        assert!(rrf_literal(0, 0, 10).validate().is_err());
    }

    // `final_k == 0` is rejected.
    #[test]
    fn test_config_validation_final_zero() {
        assert!(rrf_literal(10, 10, 0).validate().is_err());
    }

    // The default, dense-only and sparse-only configs are all accepted.
    #[test]
    fn test_config_validation_valid() {
        assert!(HybridSearchConfig::default().validate().is_ok());
        assert!(rrf_literal(10, 0, 5).validate().is_ok());
        assert!(rrf_literal(0, 10, 5).validate().is_ok());
    }

    // Display output carries the variant prefix and the mismatch numbers.
    #[test]
    fn test_hybrid_error_display() {
        let invalid_text = HybridError::InvalidConfig("test".to_string()).to_string();
        assert!(invalid_text.contains("Invalid config"));

        let mismatch_text = HybridError::DimensionMismatch {
            expected: 128,
            actual: 64,
        }
        .to_string();
        for needle in ["128", "64"] {
            assert!(mismatch_text.contains(needle));
        }
    }

    // A graph dimension mismatch maps onto the structured hybrid variant.
    #[test]
    fn test_hybrid_error_from_graph_error() {
        let converted = HybridError::from(GraphError::DimensionMismatch {
            expected: 128,
            actual: 64,
        });
        assert_eq!(
            converted,
            HybridError::DimensionMismatch {
                expected: 128,
                actual: 64,
            }
        );
    }

    // Ranks come from the fusion entry; scores come from the lookup maps.
    #[test]
    fn test_result_from_fusion() {
        let entry = FusionResult::with_ranks(42, 0.5, Some(1), Some(2));
        let dense: HashMap<u64, f32> = HashMap::from([(42, 0.9), (100, 0.8)]);
        let sparse: HashMap<u64, f32> = HashMap::from([(42, 5.0), (200, 4.0)]);

        let hit = HybridSearchResult::from_fusion(&entry, &dense, &sparse);

        assert_eq!(hit.id.0, 42);
        assert!((hit.score - 0.5).abs() < f32::EPSILON);
        assert_eq!(hit.dense_rank, Some(1));
        assert_eq!(hit.sparse_rank, Some(2));
        assert_eq!(hit.dense_score, Some(0.9));
        assert_eq!(hit.sparse_score, Some(5.0));
    }

    // An id missing from a score map yields `None` for that score.
    #[test]
    fn test_result_from_fusion_missing_scores() {
        let entry = FusionResult::with_ranks(42, 0.5, Some(1), None);
        let dense: HashMap<u64, f32> = HashMap::from([(42, 0.9)]);
        let sparse: HashMap<u64, f32> = HashMap::new();

        let hit = HybridSearchResult::from_fusion(&entry, &dense, &sparse);

        assert_eq!(hit.dense_score, Some(0.9));
        assert_eq!(hit.sparse_score, None);
    }
}