use crate::{
embed::Embedder,
fusion::FusionStrategy,
index::{BM25Index, SparseIndex, VectorStore},
Chunk, ChunkId, Result,
};
use serde::{Deserialize, Serialize};
/// A retrieved [`Chunk`] together with the scores that were recorded for it.
///
/// Every score is optional: [`RetrievalResult::new`] starts with all of them
/// unset and the `with_*` builders fill in only the ones a retriever computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
/// The chunk this result refers to.
pub chunk: Chunk,
/// Score from the dense (vector store) search, if one was recorded.
pub dense_score: Option<f32>,
/// Score from the sparse (BM25) search, if one was recorded.
pub sparse_score: Option<f32>,
/// Score from the multi-vector search; the field only exists when the
/// `multivector` feature is enabled.
#[cfg(feature = "multivector")]
pub multivector_score: Option<f32>,
/// Score produced by fusing the per-source results, if fusion was performed.
pub fused_score: Option<f32>,
/// Score set through `with_rerank_score`. Nothing in this file sets it, so it
/// is presumably produced by a reranking stage elsewhere — confirm with callers.
pub rerank_score: Option<f32>,
}
impl RetrievalResult {
    /// Wraps `chunk` in a result that has no scores attached yet.
    #[must_use]
    pub fn new(chunk: Chunk) -> Self {
        Self {
            chunk,
            dense_score: None,
            sparse_score: None,
            #[cfg(feature = "multivector")]
            multivector_score: None,
            fused_score: None,
            rerank_score: None,
        }
    }

    /// Builder: records `score` as the dense-search score.
    #[must_use]
    pub fn with_dense_score(self, score: f32) -> Self {
        Self { dense_score: Some(score), ..self }
    }

    /// Builder: records `score` as the sparse-search score.
    #[must_use]
    pub fn with_sparse_score(self, score: f32) -> Self {
        Self { sparse_score: Some(score), ..self }
    }

    /// Builder: records `score` as the fused score.
    #[must_use]
    pub fn with_fused_score(self, score: f32) -> Self {
        Self { fused_score: Some(score), ..self }
    }

    /// Builder: records `score` as the rerank score.
    #[must_use]
    pub fn with_rerank_score(self, score: f32) -> Self {
        Self { rerank_score: Some(score), ..self }
    }

    /// Builder: records `score` as the multi-vector score.
    #[cfg(feature = "multivector")]
    #[must_use]
    pub fn with_multivector_score(self, score: f32) -> Self {
        Self { multivector_score: Some(score), ..self }
    }

    /// Returns the first score that is set, checked in the order
    /// rerank, fused, dense, sparse; `0.0` when none of them is set.
    ///
    /// The multi-vector score is not consulted here, even when the
    /// `multivector` feature is enabled.
    #[must_use]
    pub fn best_score(&self) -> f32 {
        // Listed from highest to lowest priority.
        let by_priority =
            [self.rerank_score, self.fused_score, self.dense_score, self.sparse_score];
        by_priority.iter().find_map(|candidate| *candidate).unwrap_or(0.0)
    }

    /// Same as `best_score`, but the multi-vector score is checked after the
    /// fused score and before the dense score.
    #[cfg(feature = "multivector")]
    #[must_use]
    pub fn best_score_with_multivector(&self) -> f32 {
        // Listed from highest to lowest priority.
        let by_priority = [
            self.rerank_score,
            self.fused_score,
            self.multivector_score,
            self.dense_score,
            self.sparse_score,
        ];
        by_priority.iter().find_map(|candidate| *candidate).unwrap_or(0.0)
    }
}
/// Settings that control how [`HybridRetriever::retrieve`] gathers and combines candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRetrieverConfig {
/// Number of results requested from each enabled index before fusion.
pub candidates_per_source: usize,
/// Strategy used to fuse the dense and sparse result lists.
pub fusion: FusionStrategy,
/// When `false`, the query is not embedded and the dense result list is empty.
pub use_dense: bool,
/// When `false`, the sparse index is not searched and its result list is empty.
pub use_sparse: bool,
}
impl Default for HybridRetrieverConfig {
    /// Defaults: 50 candidates per source, the default fusion strategy, and
    /// both the dense and the sparse source enabled.
    fn default() -> Self {
        let fusion = FusionStrategy::default();
        Self { candidates_per_source: 50, fusion, use_dense: true, use_sparse: true }
    }
}
/// Retriever that combines a dense [`VectorStore`] search with a sparse [`BM25Index`] search.
///
/// Chunk ids returned by either index are resolved to chunks through the dense
/// store, so the dense store also serves as the chunk storage.
pub struct HybridRetriever<E: Embedder> {
/// Vector store: searched with the query embedding and used to look up chunks by id.
dense: VectorStore,
/// BM25 index searched with the raw query text.
sparse: BM25Index,
/// Produces the query embedding for dense search.
embedder: E,
/// Candidate count, fusion strategy and per-source switches.
config: HybridRetrieverConfig,
}
impl<E: Embedder> HybridRetriever<E> {
    /// Creates a retriever over the given indexes with [`HybridRetrieverConfig::default`].
    #[must_use]
    pub fn new(dense: VectorStore, sparse: BM25Index, embedder: E) -> Self {
        Self { dense, sparse, embedder, config: HybridRetrieverConfig::default() }
    }

    /// Builder: replaces the configuration.
    #[must_use]
    pub fn with_config(mut self, config: HybridRetrieverConfig) -> Self {
        self.config = config;
        self
    }

    /// Shared access to the dense vector store.
    #[must_use]
    pub fn dense_store(&self) -> &VectorStore {
        &self.dense
    }

    /// Mutable access to the dense vector store.
    ///
    /// Changes made through this handle are not mirrored into the sparse index.
    pub fn dense_store_mut(&mut self) -> &mut VectorStore {
        &mut self.dense
    }

    /// Shared access to the sparse BM25 index.
    #[must_use]
    pub fn sparse_index(&self) -> &BM25Index {
        &self.sparse
    }

    /// Mutable access to the sparse BM25 index.
    ///
    /// Changes made through this handle are not mirrored into the dense store.
    pub fn sparse_index_mut(&mut self) -> &mut BM25Index {
        &mut self.sparse
    }

    /// Adds `chunk` to both the dense store and the sparse index.
    ///
    /// The fallible dense insert runs first, so a chunk the dense store rejects
    /// never reaches the sparse index. Adding to the sparse index first would
    /// leave it holding an id that no retrieval method can resolve to a chunk.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the dense store's `insert`; the sparse
    /// index is left untouched in that case.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        // `insert` takes ownership, so it gets a copy and the original is kept
        // for the sparse index, which only needs a borrow.
        self.dense.insert(chunk.clone())?;
        self.sparse.add(&chunk);
        Ok(())
    }

    /// Indexes every chunk in order via [`Self::index`].
    ///
    /// # Errors
    ///
    /// Stops at the first chunk that fails and returns its error; chunks
    /// indexed before it stay indexed.
    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
        for chunk in chunks {
            self.index(chunk)?;
        }
        Ok(())
    }

    /// Runs the enabled searches, fuses their results and returns up to `k` results.
    ///
    /// Each enabled source is asked for `candidates_per_source` hits, so the
    /// result can be shorter than `k` when fewer candidates are available.
    /// Every result carries the fused score plus whichever per-source scores
    /// exist for that chunk.
    ///
    /// # Errors
    ///
    /// Propagates errors from embedding the query and from the dense search.
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        // Contract macro defined elsewhere in the crate; its exact check is not
        // visible from this file.
        contract_pre_configuration!(query.as_bytes());
        let candidates = self.config.candidates_per_source;
        // The query is only embedded when the dense source is enabled.
        let dense_results = if self.config.use_dense {
            let query_embedding = self.embedder.embed_query(query)?;
            self.dense.search(&query_embedding, candidates)?
        } else {
            Vec::new()
        };
        let sparse_results =
            if self.config.use_sparse { self.sparse.search(query, candidates) } else { Vec::new() };
        let fused = self.config.fusion.fuse(&dense_results, &sparse_results);
        // Per-source scores keyed by chunk id, used to annotate the fused results.
        let dense_scores: std::collections::HashMap<ChunkId, f32> =
            dense_results.into_iter().collect();
        let sparse_scores: std::collections::HashMap<ChunkId, f32> =
            sparse_results.into_iter().collect();
        // Resolve ids to chunks *before* truncating to `k`: an id the dense store
        // cannot resolve must not use up one of the `k` slots while valid
        // lower-ranked candidates are still available.
        let results: Vec<RetrievalResult> = fused
            .into_iter()
            .filter_map(|(chunk_id, fused_score)| {
                let chunk = self.dense.get(chunk_id)?;
                let mut result = RetrievalResult::new(chunk.clone()).with_fused_score(fused_score);
                if let Some(&score) = dense_scores.get(&chunk_id) {
                    result = result.with_dense_score(score);
                }
                if let Some(&score) = sparse_scores.get(&chunk_id) {
                    result = result.with_sparse_score(score);
                }
                Some(result)
            })
            .take(k)
            .collect();
        Ok(results)
    }

    /// Dense-only retrieval: embeds the query and returns up to `k` hits from
    /// the vector store, each carrying only a dense score.
    ///
    /// This does not consult `use_dense` in the configuration.
    ///
    /// # Errors
    ///
    /// Propagates errors from embedding the query and from the dense search.
    pub fn retrieve_dense(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        // Contract macro defined elsewhere in the crate; its exact check is not
        // visible from this file.
        contract_pre_configuration!(query.as_bytes());
        let query_embedding = self.embedder.embed_query(query)?;
        let results = self.dense.search(&query_embedding, k)?;
        let mut retrieval_results = Vec::with_capacity(results.len());
        for (chunk_id, score) in results {
            // Hits whose chunk cannot be looked up are skipped.
            if let Some(chunk) = self.dense.get(chunk_id) {
                retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
            }
        }
        Ok(retrieval_results)
    }

    /// Sparse-only retrieval: searches the BM25 index and returns up to `k`
    /// hits, each carrying only a sparse score.
    ///
    /// Chunks are still resolved through the dense store, so hits whose id is
    /// missing there are skipped. This does not consult `use_sparse` in the
    /// configuration.
    ///
    /// # Errors
    ///
    /// No failing call is visible in the body itself; the `Result` return type
    /// matches the other retrieval methods.
    pub fn retrieve_sparse(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        // Contract macro defined elsewhere in the crate; its exact check is not
        // visible from this file.
        contract_pre_configuration!(query.as_bytes());
        let results = self.sparse.search(query, k);
        let mut retrieval_results = Vec::with_capacity(results.len());
        for (chunk_id, score) in results {
            if let Some(chunk) = self.dense.get(chunk_id) {
                retrieval_results
                    .push(RetrievalResult::new(chunk.clone()).with_sparse_score(score));
            }
        }
        Ok(retrieval_results)
    }

    /// Number of chunks held by the dense store.
    #[must_use]
    pub fn len(&self) -> usize {
        self.dense.len()
    }

    /// `true` when the dense store holds no chunks.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.dense.is_empty()
    }
}
/// Retriever backed only by a dense [`VectorStore`].
pub struct DenseRetriever<E: Embedder> {
/// Vector store that is searched and that resolves chunk ids to chunks.
store: VectorStore,
/// Produces the query embedding.
embedder: E,
}
impl<E: Embedder> DenseRetriever<E> {
    /// Creates a retriever over `store` that embeds queries with `embedder`.
    #[must_use]
    pub fn new(store: VectorStore, embedder: E) -> Self {
        Self { store, embedder }
    }

    /// Inserts `chunk` into the vector store.
    ///
    /// # Errors
    ///
    /// Returns whatever error the store's `insert` reports.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        self.store.insert(chunk)
    }

    /// Embeds `query`, searches the store for up to `k` hits and returns them
    /// with their dense scores.
    ///
    /// Hits whose chunk cannot be looked up in the store are skipped.
    ///
    /// # Errors
    ///
    /// Propagates errors from embedding the query and from the search.
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        let embedding = self.embedder.embed_query(query)?;
        let hits = self.store.search(&embedding, k)?;
        let mut found = Vec::with_capacity(hits.len());
        found.extend(hits.into_iter().filter_map(|(id, score)| {
            self.store.get(id).map(|chunk| RetrievalResult::new(chunk.clone()).with_dense_score(score))
        }));
        Ok(found)
    }
}
/// Retriever backed only by a [`BM25Index`], keeping its own copy of every indexed chunk.
pub struct SparseRetriever {
/// BM25 index that is searched with the raw query text.
index: BM25Index,
/// Indexed chunks keyed by their id, used to turn search hits back into chunks.
chunks: std::collections::HashMap<ChunkId, Chunk>,
}
impl SparseRetriever {
    /// Creates a retriever with an empty BM25 index and no stored chunks.
    #[must_use]
    pub fn new() -> Self {
        let index = BM25Index::new();
        let chunks = std::collections::HashMap::new();
        Self { index, chunks }
    }

    /// Adds `chunk` to the BM25 index and stores it under its id.
    pub fn index(&mut self, chunk: Chunk) {
        self.index.add(&chunk);
        let id = chunk.id;
        self.chunks.insert(id, chunk);
    }

    /// Searches the BM25 index for up to `k` hits and returns them with their
    /// sparse scores.
    ///
    /// Hits whose id has no stored chunk are skipped.
    #[must_use]
    pub fn retrieve(&self, query: &str, k: usize) -> Vec<RetrievalResult> {
        let mut retrieved = Vec::new();
        for (chunk_id, score) in self.index.search(query, k) {
            if let Some(chunk) = self.chunks.get(&chunk_id) {
                retrieved.push(RetrievalResult::new(chunk.clone()).with_sparse_score(score));
            }
        }
        retrieved
    }
}
impl Default for SparseRetriever {
/// Same as [`SparseRetriever::new`]: an empty index with no stored chunks.
fn default() -> Self {
Self::new()
}
}
/// Retriever backed by a multi-vector `WarpIndex`; only compiled with the
/// `multivector` feature.
#[cfg(feature = "multivector")]
pub struct MultiVectorRetriever<E: crate::multivector::MultiVectorEmbedder> {
/// Index that stores the chunks together with their token embeddings.
index: crate::multivector::WarpIndex,
/// Produces token embeddings for chunk contents and for queries.
embedder: E,
/// Search settings; `retrieve` copies `nprobe`, `bound` and
/// `centroid_score_threshold` from here and supplies its own `k`.
search_config: crate::multivector::WarpSearchConfig,
}
#[cfg(feature = "multivector")]
impl<E: crate::multivector::MultiVectorEmbedder> MultiVectorRetriever<E> {
    /// Creates a retriever with a new index built from `config` and the
    /// default search configuration.
    #[must_use]
    pub fn new(config: crate::multivector::WarpIndexConfig, embedder: E) -> Self {
        let index = crate::multivector::WarpIndex::new(config);
        let search_config = crate::multivector::WarpSearchConfig::default();
        Self { index, embedder, search_config }
    }

    /// Builder: replaces the stored search configuration.
    #[must_use]
    pub fn with_search_config(self, config: crate::multivector::WarpSearchConfig) -> Self {
        Self { search_config: config, ..self }
    }

    /// Embeds the contents of `sample_chunks` as token embeddings and trains
    /// the index on them.
    ///
    /// # Errors
    ///
    /// Propagates errors from the batch embedding and from the index's `train`.
    pub fn train(&mut self, sample_chunks: &[Chunk]) -> Result<()> {
        let mut texts: Vec<&str> = Vec::with_capacity(sample_chunks.len());
        texts.extend(sample_chunks.iter().map(|sample| sample.content.as_str()));
        let token_embeddings = self.embedder.embed_tokens_batch(&texts)?;
        self.index.train(&token_embeddings)?;
        Ok(())
    }

    /// Embeds the chunk's content and inserts the chunk with that embedding.
    ///
    /// # Errors
    ///
    /// Propagates errors from embedding the content and from the index's `insert`.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        let token_embedding = self.embedder.embed_tokens(&chunk.content)?;
        self.index.insert(chunk, token_embedding)?;
        Ok(())
    }

    /// Indexes every chunk in order via [`Self::index`].
    ///
    /// # Errors
    ///
    /// Stops at the first chunk that fails and returns its error; chunks
    /// indexed before it stay indexed.
    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
        chunks.into_iter().try_for_each(|chunk| self.index(chunk))
    }

    /// Delegates to the index's `build`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `build` reports.
    pub fn build(&mut self) -> Result<()> {
        self.index.build()
    }

    /// Embeds `query`, searches the index for `k` results and returns the hits
    /// with their multi-vector scores.
    ///
    /// Hits whose chunk cannot be looked up in the index are skipped.
    ///
    /// # Errors
    ///
    /// Propagates errors from embedding the query and from the search.
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        let query_tokens = self.embedder.embed_tokens(query)?;
        // Start from a config for this call's `k`, then carry over the stored
        // probing and scoring settings.
        let stored = &self.search_config;
        let per_call_config = crate::multivector::WarpSearchConfig::with_k(k)
            .nprobe(stored.nprobe)
            .bound(stored.bound)
            .centroid_score_threshold(stored.centroid_score_threshold);
        let hits = self.index.search(&query_tokens, &per_call_config)?;
        let mut found = Vec::with_capacity(hits.len());
        found.extend(hits.into_iter().filter_map(|(chunk_id, score)| {
            self.index
                .get_chunk(&chunk_id)
                .map(|chunk| RetrievalResult::new(chunk.clone()).with_multivector_score(score))
        }));
        Ok(found)
    }

    /// Number of chunks reported by the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.index.num_chunks()
    }

    /// `true` when the index reports zero chunks.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.index.num_chunks() == 0
    }

    /// Shared access to the underlying index.
    #[must_use]
    pub fn warp_index(&self) -> &crate::multivector::WarpIndex {
        &self.index
    }

    /// Shared access to the embedder.
    #[must_use]
    pub fn embedder(&self) -> &E {
        // Contract macro defined elsewhere in the crate; its exact check is not
        // visible from this file.
        contract_pre_embedding_lookup!();
        &self.embedder
    }

    /// Memory usage as reported by the index.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.index.memory_usage()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{embed::MockEmbedder, DocumentId};

    /// Builds a chunk spanning the whole of `content` with `embedding` attached.
    fn create_test_chunk(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        chunk.set_embedding(embedding);
        chunk
    }

    /// A fresh result has no scores.
    #[test]
    fn test_retrieval_result_new() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk);
        assert!(result.dense_score.is_none());
        assert!(result.sparse_score.is_none());
        assert!(result.fused_score.is_none());
        assert!(result.rerank_score.is_none());
    }

    /// Each builder sets exactly its own score.
    #[test]
    fn test_retrieval_result_with_scores() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk)
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);
        assert_eq!(result.dense_score, Some(0.9));
        assert_eq!(result.sparse_score, Some(0.8));
        assert_eq!(result.fused_score, Some(0.85));
        assert_eq!(result.rerank_score, Some(0.95));
    }

    /// `best_score` prefers rerank over fused over dense.
    #[test]
    fn test_retrieval_result_best_score_priority() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_rerank_score(0.9);
        assert!((result.best_score() - 0.9).abs() < 0.001);
        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_fused_score(0.7);
        assert!((result.best_score() - 0.7).abs() < 0.001);
        let result = RetrievalResult::new(chunk).with_dense_score(0.5);
        assert!((result.best_score() - 0.5).abs() < 0.001);
    }

    /// With no scores set, `best_score` falls back to zero.
    #[test]
    fn test_retrieval_result_best_score_default() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk);
        assert!((result.best_score() - 0.0).abs() < 0.001);
    }

    /// The default configuration enables both sources with 50 candidates each.
    #[test]
    fn test_hybrid_config_default() {
        let config = HybridRetrieverConfig::default();
        assert_eq!(config.candidates_per_source, 50);
        assert!(config.use_dense);
        assert!(config.use_sparse);
    }

    /// A newly built hybrid retriever is empty.
    #[test]
    fn test_hybrid_retriever_new() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();
        let retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());
    }

    /// Indexing one chunk makes the length one.
    #[test]
    fn test_hybrid_retriever_index() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        let chunk = create_test_chunk("machine learning is great", vec![0.0; 64]);
        retriever.index(chunk).unwrap();
        assert_eq!(retriever.len(), 1);
    }

    /// Batch indexing adds every chunk.
    #[test]
    fn test_hybrid_retriever_index_batch() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        let chunks = vec![
            create_test_chunk("first document", vec![1.0; 64]),
            create_test_chunk("second document", vec![0.5; 64]),
        ];
        retriever.index_batch(chunks).unwrap();
        assert_eq!(retriever.len(), 2);
    }

    /// Hybrid retrieval returns at least one and at most `k` results.
    #[test]
    fn test_hybrid_retriever_retrieve() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        retriever
            .index(create_test_chunk("machine learning algorithms", vec![1.0, 0.0, 0.0]))
            .unwrap();
        retriever
            .index(create_test_chunk("deep learning neural networks", vec![0.9, 0.1, 0.0]))
            .unwrap();
        retriever.index(create_test_chunk("cooking recipes", vec![0.0, 0.0, 1.0])).unwrap();
        let results = retriever.retrieve("machine learning", 2).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    /// `retrieve_dense` attaches only a dense score.
    #[test]
    fn test_hybrid_retriever_retrieve_dense_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        retriever.index(create_test_chunk("test doc", vec![1.0, 0.0, 0.0])).unwrap();
        let results = retriever.retrieve_dense("test", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].dense_score.is_some());
        assert!(results[0].sparse_score.is_none());
    }

    /// `retrieve_sparse` attaches only a sparse score.
    #[test]
    fn test_hybrid_retriever_retrieve_sparse_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();
        let results = retriever.retrieve_sparse("machine", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].sparse_score.is_some());
        assert!(results[0].dense_score.is_none());
    }

    /// `with_config` stores the supplied configuration.
    #[test]
    fn test_hybrid_retriever_config() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::Linear { dense_weight: 0.7 },
            use_dense: true,
            use_sparse: true,
        };
        let retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);
        assert_eq!(retriever.config.candidates_per_source, 100);
    }

    /// The dense retriever finds the single indexed chunk with a dense score.
    #[test]
    fn test_dense_retriever() {
        let embedder = MockEmbedder::new(3);
        let store = VectorStore::with_dimension(3);
        let mut retriever = DenseRetriever::new(store, embedder);
        retriever.index(create_test_chunk("test document", vec![1.0, 0.0, 0.0])).unwrap();
        let results = retriever.retrieve("test", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].dense_score.is_some());
    }

    /// An empty sparse retriever returns nothing.
    #[test]
    fn test_sparse_retriever_new() {
        let retriever = SparseRetriever::new();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    /// The sparse retriever finds an indexed chunk by one of its terms.
    #[test]
    fn test_sparse_retriever_index() {
        let mut retriever = SparseRetriever::new();
        // Derive the end offset from the text so it always matches the content,
        // as `create_test_chunk` does.
        let content = "machine learning test";
        let chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        retriever.index(chunk);
        let results = retriever.retrieve("machine", 10);
        assert_eq!(results.len(), 1);
        assert!(results[0].sparse_score.is_some());
    }

    /// A term shared by two chunks retrieves both.
    #[test]
    fn test_sparse_retriever_multiple() {
        let mut retriever = SparseRetriever::new();
        // End offsets come from the text length rather than hand-counted literals.
        for content in ["rust programming language", "python programming language"] {
            retriever.index(Chunk::new(DocumentId::new(), content.to_string(), 0, content.len()));
        }
        let results = retriever.retrieve("programming", 10);
        assert_eq!(results.len(), 2);
    }

    /// All four store accessors are callable on a fresh retriever.
    #[test]
    fn test_hybrid_retriever_store_accessors() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        let _dense_store = retriever.dense_store();
        let _sparse_index = retriever.sparse_index();
        let dense_mut = retriever.dense_store_mut();
        assert!(dense_mut.is_empty());
        let sparse_mut = retriever.sparse_index_mut();
        let _ = sparse_mut;
    }

    /// `is_empty` flips once a chunk is indexed.
    #[test]
    fn test_hybrid_retriever_is_empty() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();
        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());
        retriever.index(create_test_chunk("test", vec![0.0; 64])).unwrap();
        assert!(!retriever.is_empty());
    }

    /// `Default` yields an empty sparse retriever.
    #[test]
    fn test_sparse_retriever_default() {
        let retriever = SparseRetriever::default();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    /// `best_score` falls back to the sparse score when it is the only one set.
    #[test]
    fn test_retrieval_result_best_score_sparse_fallback() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk).with_sparse_score(0.75);
        assert!((result.best_score() - 0.75).abs() < 0.001);
    }

    /// Retrieval still succeeds with the dense source switched off.
    #[test]
    fn test_hybrid_retriever_with_dense_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: false,
            use_sparse: true,
        };
        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);
        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();
        let results = retriever.retrieve("machine", 10).unwrap();
        assert!(results.len() <= 10);
    }

    /// Retrieval still succeeds with the sparse source switched off.
    #[test]
    fn test_hybrid_retriever_with_sparse_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();
        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: true,
            use_sparse: false,
        };
        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);
        retriever.index(create_test_chunk("test content", vec![1.0, 0.0, 0.0])).unwrap();
        let results = retriever.retrieve("test", 10).unwrap();
        assert!(results.len() <= 10);
    }

    /// The configuration survives a JSON round trip.
    #[test]
    fn test_hybrid_retriever_config_serialization() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::RRF { k: 60.0 },
            use_dense: true,
            use_sparse: false,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: HybridRetrieverConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.candidates_per_source, deserialized.candidates_per_source);
        assert_eq!(config.use_dense, deserialized.use_dense);
        assert_eq!(config.use_sparse, deserialized.use_sparse);
    }

    /// A result's scores survive a JSON round trip.
    #[test]
    fn test_retrieval_result_serialization() {
        let chunk = Chunk::new(DocumentId::new(), "test content".to_string(), 0, 12);
        let result = RetrievalResult::new(chunk)
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: RetrievalResult = serde_json::from_str(&json).unwrap();
        assert_eq!(result.dense_score, deserialized.dense_score);
        assert_eq!(result.sparse_score, deserialized.sparse_score);
        assert_eq!(result.fused_score, deserialized.fused_score);
        assert_eq!(result.rerank_score, deserialized.rerank_score);
    }

    use proptest::prelude::*;

    proptest! {
        // Scores passed to the builders are stored unchanged.
        #[test]
        fn prop_retrieval_result_scores_preserved(
            dense in 0.0f32..1.0,
            sparse in 0.0f32..1.0,
            fused in 0.0f32..1.0
        ) {
            let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
            let result = RetrievalResult::new(chunk)
                .with_dense_score(dense)
                .with_sparse_score(sparse)
                .with_fused_score(fused);
            prop_assert!((result.dense_score.unwrap() - dense).abs() < 0.0001);
            prop_assert!((result.sparse_score.unwrap() - sparse).abs() < 0.0001);
            prop_assert!((result.fused_score.unwrap() - fused).abs() < 0.0001);
        }

        // Hybrid retrieval never returns more than `k` results.
        #[test]
        fn prop_hybrid_retriever_respects_k(k in 1usize..10) {
            let embedder = MockEmbedder::new(3);
            let dense = VectorStore::with_dimension(3);
            let sparse = BM25Index::new();
            let mut retriever = HybridRetriever::new(dense, sparse, embedder);
            for i in 0..20 {
                // One-hot embedding that cycles through the three axes.
                let mut emb = vec![0.0; 3];
                emb[i % 3] = 1.0;
                retriever.index(create_test_chunk(
                    &format!("document number {i} about testing"),
                    emb,
                )).unwrap();
            }
            let results = retriever.retrieve("testing", k).unwrap();
            prop_assert!(results.len() <= k);
        }
    }
}