use std::collections::{HashMap, HashSet};
use super::distance::DistanceResult;
use super::hnsw::{HnswIndex, NodeId};
use super::vector_metadata::{MetadataFilter, MetadataStore};
/// Tuning parameters for Okapi BM25 scoring.
#[derive(Debug, Clone)]
pub struct BM25Config {
    /// Term-frequency saturation: larger values let repeated occurrences of a
    /// term keep adding score for longer before flattening out.
    pub k1: f32,
    /// Strength of document-length normalization: `0.0` ignores length entirely,
    /// `1.0` normalizes fully against the average length (typically in `[0, 1]`).
    pub b: f32,
}

impl Default for BM25Config {
    /// Returns `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        let (k1, b) = (1.2, 0.75);
        Self { k1, b }
    }
}
/// In-memory inverted index scored with Okapi BM25.
///
/// Terms are stored lower-cased. For every document the index remembers how many
/// terms it was given (its length), so scores can be normalized against the
/// corpus-wide average length.
pub struct SparseIndex {
    /// Lower-cased term -> `(document id, term frequency)` for each document that
    /// contains it. Lists that become empty on removal are pruned.
    postings: HashMap<String, Vec<(NodeId, f32)>>,
    /// Number of terms (repeats included) each indexed document was given.
    doc_lengths: HashMap<NodeId, usize>,
    /// `total_length / doc_count`, or `0.0` while the index is empty.
    avg_doc_length: f32,
    /// Number of indexed documents.
    doc_count: usize,
    /// Sum of all `doc_lengths`, maintained incrementally so keeping the average
    /// current does not require re-summing every document on each update.
    total_length: usize,
    /// BM25 tuning parameters.
    config: BM25Config,
}

impl SparseIndex {
    /// Creates an empty index with [`BM25Config::default`] parameters.
    pub fn new() -> Self {
        Self::with_config(BM25Config::default())
    }

    /// Creates an empty index with the given BM25 parameters.
    pub fn with_config(config: BM25Config) -> Self {
        Self {
            postings: HashMap::new(),
            doc_lengths: HashMap::new(),
            avg_doc_length: 0.0,
            doc_count: 0,
            total_length: 0,
            config,
        }
    }

    /// Indexes `doc_id` from pre-tokenized `terms`.
    ///
    /// Terms are lower-cased before counting, so differently-cased spellings of a
    /// word share a single posting. The document length is `terms.len()`.
    /// Indexing an id that is already present replaces its previous contents
    /// rather than adding a second copy.
    pub fn index(&mut self, doc_id: NodeId, terms: &[String]) {
        // Drop any earlier version first; otherwise its postings would linger and the
        // document would be counted twice in `doc_count` and in every term's `df`.
        if self.doc_lengths.contains_key(&doc_id) {
            self.remove(doc_id);
        }
        // Count on the lower-cased form: counting raw spellings would push one posting
        // per spelling for the same document under the same lower-cased key.
        let mut term_counts: HashMap<String, usize> = HashMap::new();
        for term in terms {
            *term_counts.entry(term.to_lowercase()).or_insert(0) += 1;
        }
        for (term, count) in term_counts {
            self.postings
                .entry(term)
                .or_default()
                .push((doc_id, count as f32));
        }
        self.doc_lengths.insert(doc_id, terms.len());
        self.doc_count += 1;
        self.total_length += terms.len();
        self.refresh_avg_doc_length();
    }

    /// Tokenizes `text` (see `tokenize`) and indexes the result under `doc_id`.
    pub fn index_text(&mut self, doc_id: NodeId, text: &str) {
        let terms: Vec<String> = tokenize(text);
        self.index(doc_id, &terms);
    }

    /// Removes `doc_id` from the index. Unknown ids are ignored.
    pub fn remove(&mut self, doc_id: NodeId) {
        // Prune posting lists that become empty so `vocab_size` only counts terms
        // that still occur in some document.
        self.postings.retain(|_, postings| {
            postings.retain(|(id, _)| *id != doc_id);
            !postings.is_empty()
        });
        if let Some(len) = self.doc_lengths.remove(&doc_id) {
            self.doc_count = self.doc_count.saturating_sub(1);
            self.total_length = self.total_length.saturating_sub(len);
            self.refresh_avg_doc_length();
        }
    }

    /// Recomputes `avg_doc_length` from the running totals.
    fn refresh_avg_doc_length(&mut self) {
        self.avg_doc_length = if self.doc_count == 0 {
            0.0
        } else {
            self.total_length as f32 / self.doc_count as f32
        };
    }

    /// Scores every document sharing at least one term with `query` and returns
    /// the `k` best, highest BM25 score first, with ties broken by ascending id.
    ///
    /// The query is tokenized with the same rules as [`SparseIndex::index_text`].
    /// A repeated query term contributes once per occurrence.
    pub fn search(&self, query: &str, k: usize) -> Vec<SparseResult> {
        let query_terms = tokenize(query);
        if query_terms.is_empty() || k == 0 {
            return Vec::new();
        }
        let k1 = self.config.k1;
        let b = self.config.b;
        let mut scores: HashMap<NodeId, f32> = HashMap::new();
        for term in &query_terms {
            // `tokenize` already lower-cases, which matches how postings are keyed.
            if let Some(postings) = self.postings.get(term) {
                let df = postings.len() as f32;
                // The +1 inside the log keeps IDF positive even for terms present in
                // every document.
                let idf = ((self.doc_count as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
                for &(doc_id, tf) in postings {
                    let doc_len = self.doc_lengths.get(&doc_id).copied().unwrap_or(1) as f32;
                    let length_norm = 1.0 - b + b * doc_len / self.avg_doc_length;
                    let tf_component = (tf * (k1 + 1.0)) / (tf + k1 * length_norm);
                    *scores.entry(doc_id).or_insert(0.0) += idf * tf_component;
                }
            }
        }
        let mut results: Vec<SparseResult> = scores
            .into_iter()
            .map(|(id, score)| SparseResult { id, score })
            .collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.id.cmp(&b.id))
        });
        results.truncate(k);
        results
    }

    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.doc_count
    }

    /// Returns `true` when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Number of distinct (lower-cased) terms that occur in at least one document.
    pub fn vocab_size(&self) -> usize {
        self.postings.len()
    }
}

impl Default for SparseIndex {
    fn default() -> Self {
        Self::new()
    }
}
/// A single BM25 hit produced by [`SparseIndex::search`].
#[derive(Clone, Debug)]
pub struct SparseResult {
    /// Identifier of the matching document.
    pub id: NodeId,
    /// BM25 relevance score; higher means more relevant.
    pub score: f32,
}
/// Splits `text` into lower-cased search tokens.
///
/// A token is a maximal run of alphanumeric characters, `-` and `_`, so
/// hyphenated words and snake_case identifiers stay whole (`test-case`, `x_y`).
/// Tokens shorter than two characters are dropped. Tokens made only of `-`/`_`
/// (for example a `--` separator) are also dropped, because they carry no word content.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !(c.is_alphanumeric() || c == '-' || c == '_'))
        // Count characters, not UTF-8 bytes: a lone "é" is one character and is
        // dropped exactly like a lone "a".
        .filter(|token| token.chars().count() >= 2)
        .filter(|token| token.chars().any(char::is_alphanumeric))
        .map(str::to_lowercase)
        .collect()
}
/// Strategy used to merge the dense (vector) and sparse (BM25) result lists.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal rank fusion. The payload is the constant `k` in each hit's
    /// contribution `1 / (k + rank + 1)`.
    RRF(usize),
    /// Weighted sum of normalized scores. The payload is the dense-side weight
    /// `alpha`; the sparse side gets `1 - alpha`.
    Linear(f32),
    /// Sum of per-list z-score normalized scores.
    DBSF,
}

impl Default for FusionMethod {
    /// Reciprocal rank fusion with `k = 60`.
    fn default() -> Self {
        Self::RRF(60)
    }
}
/// Fuses dense and sparse hits with reciprocal rank fusion (RRF).
///
/// A hit at zero-based position `rank` in either list contributes
/// `1 / (k + rank + 1)` to its id's fused score, and contributions from both lists add up.
/// Only rank positions matter, not the raw values. Every id from either list is
/// returned, with no truncation, ordered by fused score (highest first) and then by ascending id.
///
/// `dense_score` carries the raw dense *distance* and `sparse_score` the raw
/// sparse score, each `None` when the id was absent from that list.
pub fn reciprocal_rank_fusion(
    dense_results: &[DistanceResult],
    sparse_results: &[SparseResult],
    k: usize,
) -> Vec<HybridResult> {
    let offset = k as f32;
    let rank_weight = |rank: usize| 1.0 / (offset + rank as f32 + 1.0);

    let mut fused: HashMap<NodeId, f32> = HashMap::new();

    let mut raw_dense: HashMap<NodeId, f32> = HashMap::new();
    for (rank, hit) in dense_results.iter().enumerate() {
        *fused.entry(hit.id).or_insert(0.0) += rank_weight(rank);
        raw_dense.insert(hit.id, hit.distance);
    }

    let mut raw_sparse: HashMap<NodeId, f32> = HashMap::new();
    for (rank, hit) in sparse_results.iter().enumerate() {
        *fused.entry(hit.id).or_insert(0.0) += rank_weight(rank);
        raw_sparse.insert(hit.id, hit.score);
    }

    let mut ranked: Vec<HybridResult> = Vec::with_capacity(fused.len());
    for (id, score) in fused {
        ranked.push(HybridResult {
            id,
            score,
            dense_score: raw_dense.get(&id).copied(),
            sparse_score: raw_sparse.get(&id).copied(),
        });
    }
    ranked.sort_by(|x, y| match y.score.partial_cmp(&x.score) {
        Some(std::cmp::Ordering::Equal) | None => x.id.cmp(&y.id),
        Some(order) => order,
    });
    ranked
}
/// Fuses dense and sparse hits by a weighted sum of per-list normalized scores.
///
/// * Dense distances are min-max normalized and inverted. The closest hit scores
///   `1.0` and the farthest `0.0`; a single hit, or all-equal distances, scores `1.0`.
/// * Sparse scores are divided by the largest sparse score, floored at `1e-6`.
///
/// The fused score is `alpha * dense + (1 - alpha) * sparse`, where a side the
/// id is missing from contributes `0`. `alpha` is expected in `[0, 1]` and is
/// not clamped. `dense_score` / `sparse_score` hold the normalized values. Every
/// id from either list is returned, best first, with ties broken by ascending id.
pub fn linear_fusion(
    dense_results: &[DistanceResult],
    sparse_results: &[SparseResult],
    alpha: f32,
) -> Vec<HybridResult> {
    let (nearest, farthest) = dense_results.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), hit| (f32::min(lo, hit.distance), f32::max(hi, hit.distance)),
    );
    // Floor the span so identical distances never divide by zero.
    let span = (farthest - nearest).max(1e-6);

    let best_sparse = sparse_results
        .iter()
        .fold(f32::NEG_INFINITY, |acc, hit| f32::max(acc, hit.score))
        .max(1e-6);

    let mut sides: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();
    for hit in dense_results {
        let similarity = 1.0 - (hit.distance - nearest) / span;
        sides.entry(hit.id).or_insert((None, None)).0 = Some(similarity);
    }
    for hit in sparse_results {
        sides.entry(hit.id).or_insert((None, None)).1 = Some(hit.score / best_sparse);
    }

    let mut fused: Vec<HybridResult> = Vec::with_capacity(sides.len());
    for (id, (dense, sparse)) in sides {
        let weighted_dense = dense.unwrap_or(0.0) * alpha;
        let weighted_sparse = sparse.unwrap_or(0.0) * (1.0 - alpha);
        fused.push(HybridResult {
            id,
            score: weighted_dense + weighted_sparse,
            dense_score: dense,
            sparse_score: sparse,
        });
    }
    fused.sort_by(|x, y| match y.score.partial_cmp(&x.score) {
        Some(std::cmp::Ordering::Equal) | None => x.id.cmp(&y.id),
        Some(order) => order,
    });
    fused
}
/// Fuses dense and sparse hits by summing per-list z-scores.
///
/// Dense distances are first mapped to `1 / (1 + distance)`, which is larger for
/// closer hits when distances are non-negative (NOTE(review): confirm for every
/// metric the HNSW index can return). Each list is then standardized
/// independently with `z_scores`. An id missing from a list contributes `0` for
/// that list. `dense_score` / `sparse_score` hold the z-scores. Every id from
/// either list is returned, best first, with ties broken by ascending id.
pub fn dbsf_fusion(
    dense_results: &[DistanceResult],
    sparse_results: &[SparseResult],
) -> Vec<HybridResult> {
    let mut sides: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();

    let dense_similarity: Vec<f32> = dense_results
        .iter()
        .map(|hit| 1.0 / (1.0 + hit.distance))
        .collect();
    for (hit, z) in dense_results.iter().zip(z_scores(&dense_similarity)) {
        sides.entry(hit.id).or_insert((None, None)).0 = Some(z);
    }

    let sparse_raw: Vec<f32> = sparse_results.iter().map(|hit| hit.score).collect();
    for (hit, z) in sparse_results.iter().zip(z_scores(&sparse_raw)) {
        sides.entry(hit.id).or_insert((None, None)).1 = Some(z);
    }

    let mut fused: Vec<HybridResult> = Vec::with_capacity(sides.len());
    for (id, (dense, sparse)) in sides {
        fused.push(HybridResult {
            id,
            score: dense.unwrap_or(0.0) + sparse.unwrap_or(0.0),
            dense_score: dense,
            sparse_score: sparse,
        });
    }
    fused.sort_by(|x, y| match y.score.partial_cmp(&x.score) {
        Some(std::cmp::Ordering::Equal) | None => x.id.cmp(&y.id),
        Some(order) => order,
    });
    fused
}

/// Standardizes `values` to z-scores `(v - mean) / std_dev`, using the population
/// standard deviation. The deviation is floored at `1e-6` so an all-equal list does
/// not divide by zero. An empty slice yields an empty vector.
fn z_scores(values: &[f32]) -> Vec<f32> {
    let count = values.len() as f32;
    let mean = values.iter().sum::<f32>() / count;
    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;
    let std_dev = variance.sqrt().max(1e-6);
    values.iter().map(|v| (v - mean) / std_dev).collect()
}
/// One entry of a fused hybrid ranking.
#[derive(Clone, Debug)]
pub struct HybridResult {
    /// Document identifier.
    pub id: NodeId,
    /// Fused score; the fusion functions order results by this, highest first.
    pub score: f32,
    /// Dense-side value, or `None` if the id was absent from the dense list. Its
    /// scale depends on the fusion method: raw distance (RRF), normalized
    /// similarity (linear), or z-score (DBSF).
    pub dense_score: Option<f32>,
    /// Sparse-side value, or `None` if the id was absent from the sparse list. It is
    /// the raw BM25 score (RRF), max-normalized score (linear), or z-score (DBSF).
    pub sparse_score: Option<f32>,
}
/// Runs dense (HNSW) and sparse (BM25) retrieval over borrowed indexes and fuses
/// the two ranked lists into a single ranking.
pub struct HybridSearch<'a> {
    dense_index: &'a HnswIndex,
    sparse_index: &'a SparseIndex,
    /// Store consulted for metadata filters; `None` until
    /// [`HybridSearch::with_metadata`] is called.
    metadata: Option<&'a MetadataStore>,
}

impl<'a> HybridSearch<'a> {
    /// Candidates fetched from each index per requested result, so fusion has
    /// enough depth to reorder from.
    const OVERFETCH: usize = 3;

    /// Creates a hybrid searcher over the given indexes, without a metadata store.
    pub fn new(dense_index: &'a HnswIndex, sparse_index: &'a SparseIndex) -> Self {
        Self {
            dense_index,
            sparse_index,
            metadata: None,
        }
    }

    /// Attaches a metadata store so queries can filter on metadata.
    pub fn with_metadata(mut self, metadata: &'a MetadataStore) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Starts a fluent query against this searcher.
    pub fn query(&'a self) -> HybridQueryBuilder<'a> {
        HybridQueryBuilder::new(self)
    }

    /// Runs a hybrid query and returns at most `k` fused results, best first.
    ///
    /// * `query_vector` / `query_text`: either side may be `None`, in which case it
    ///   contributes no candidates.
    /// * `fusion`: how the two candidate lists are merged.
    /// * `pre_filter`: if set, only these ids may appear in either candidate list.
    /// * `post_filter`: predicate applied to fused results before truncation to `k`.
    pub fn search(
        &self,
        query_vector: Option<&[f32]>,
        query_text: Option<&str>,
        k: usize,
        fusion: FusionMethod,
        pre_filter: Option<&HashSet<NodeId>>,
        post_filter: Option<&dyn Fn(&HybridResult) -> bool>,
    ) -> Vec<HybridResult> {
        // Saturate so an enormous `k` cannot overflow.
        let fetch_k = k.saturating_mul(Self::OVERFETCH);

        let dense_results = match (query_vector, pre_filter) {
            (Some(vector), Some(filter)) => {
                self.dense_index.search_filtered(vector, fetch_k, filter)
            }
            (Some(vector), None) => self.dense_index.search(vector, fetch_k),
            (None, _) => Vec::new(),
        };

        let sparse_results = match (query_text, pre_filter) {
            (Some(text), Some(filter)) => {
                // BM25 search cannot filter natively, so filter the full ranking and only
                // then cut to `fetch_k`. Truncating first could discard every allowed
                // document and return nothing even when allowed matches exist.
                // `SparseIndex::search` builds the full ranking before truncating
                // anyway, so this costs no extra scoring work.
                let mut results = self.sparse_index.search(text, usize::MAX);
                results.retain(|r| filter.contains(&r.id));
                results.truncate(fetch_k);
                results
            }
            (Some(text), None) => self.sparse_index.search(text, fetch_k),
            (None, _) => Vec::new(),
        };

        let mut fused = match fusion {
            FusionMethod::RRF(k_param) => {
                reciprocal_rank_fusion(&dense_results, &sparse_results, k_param)
            }
            FusionMethod::Linear(alpha) => linear_fusion(&dense_results, &sparse_results, alpha),
            FusionMethod::DBSF => dbsf_fusion(&dense_results, &sparse_results),
        };
        if let Some(keep) = post_filter {
            fused.retain(keep);
        }
        fused.truncate(k);
        fused
    }

    /// Dense-only nearest-neighbour search, returning up to `k` hits.
    pub fn search_dense(&self, query_vector: &[f32], k: usize) -> Vec<DistanceResult> {
        self.dense_index.search(query_vector, k)
    }

    /// Sparse-only BM25 search, returning up to `k` hits.
    pub fn search_sparse(&self, query_text: &str, k: usize) -> Vec<SparseResult> {
        self.sparse_index.search(query_text, k)
    }
}
/// Fluent builder for a single [`HybridSearch`] query.
///
/// Created by [`HybridSearch::query`]. Defaults are `top_k(10)` and
/// [`FusionMethod::default`]. Nothing runs until [`HybridQueryBuilder::execute`].
pub struct HybridQueryBuilder<'a> {
    search: &'a HybridSearch<'a>,
    query_vector: Option<Vec<f32>>,
    query_text: Option<String>,
    k: usize,
    fusion: FusionMethod,
    /// Explicit allow-list of candidate ids.
    pre_filter_ids: Option<HashSet<NodeId>>,
    /// Metadata predicate, evaluated against the search's metadata store at execution.
    metadata_filter: Option<MetadataFilter>,
}

impl<'a> HybridQueryBuilder<'a> {
    fn new(search: &'a HybridSearch<'a>) -> Self {
        Self {
            search,
            query_vector: None,
            query_text: None,
            k: 10,
            fusion: FusionMethod::default(),
            pre_filter_ids: None,
            metadata_filter: None,
        }
    }

    /// Sets the dense query vector.
    pub fn with_vector(mut self, vector: Vec<f32>) -> Self {
        self.query_vector = Some(vector);
        self
    }

    /// Sets the sparse (BM25) query text.
    pub fn with_text(mut self, text: impl Into<String>) -> Self {
        self.query_text = Some(text.into());
        self
    }

    /// Sets both the dense query vector and the sparse query text.
    pub fn with_both(self, vector: Vec<f32>, text: impl Into<String>) -> Self {
        self.with_vector(vector).with_text(text)
    }

    /// Sets the maximum number of results returned (default `10`).
    pub fn top_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Sets the fusion strategy.
    pub fn fusion(mut self, method: FusionMethod) -> Self {
        self.fusion = method;
        self
    }

    /// Shorthand for `fusion(FusionMethod::RRF(k))`.
    pub fn rrf(mut self, k: usize) -> Self {
        self.fusion = FusionMethod::RRF(k);
        self
    }

    /// Shorthand for `fusion(FusionMethod::Linear(alpha))`.
    pub fn linear(mut self, alpha: f32) -> Self {
        self.fusion = FusionMethod::Linear(alpha);
        self
    }

    /// Restricts candidates to `ids`.
    pub fn filter_ids(mut self, ids: HashSet<NodeId>) -> Self {
        self.pre_filter_ids = Some(ids);
        self
    }

    /// Restricts candidates to documents matching `filter` in the metadata store.
    pub fn filter_metadata(mut self, filter: MetadataFilter) -> Self {
        self.metadata_filter = Some(filter);
        self
    }

    /// Runs the query and returns up to `k` fused results, best first.
    ///
    /// The candidate allow-list is the intersection of `filter_ids` and the ids
    /// matching `filter_metadata`, for whichever of the two are set. NOTE: if a metadata
    /// filter was given but the search has no metadata store attached, the metadata
    /// filter is ignored and only `filter_ids` (if any) applies.
    pub fn execute(self) -> Vec<HybridResult> {
        // `self` is consumed, so the explicit id set is moved rather than cloned.
        let Self {
            search,
            query_vector,
            query_text,
            k,
            fusion,
            pre_filter_ids,
            metadata_filter,
        } = self;
        let pre_filter = match (metadata_filter.as_ref(), search.metadata) {
            (Some(meta_filter), Some(meta_store)) => {
                let matching_ids = meta_store.filter(meta_filter);
                match pre_filter_ids {
                    Some(explicit_ids) => {
                        Some(matching_ids.intersection(&explicit_ids).copied().collect())
                    }
                    None => Some(matching_ids),
                }
            }
            // No metadata filter, or no store to evaluate it against.
            _ => pre_filter_ids,
        };
        search.search(
            query_vector.as_deref(),
            query_text.as_deref(),
            k,
            fusion,
            pre_filter.as_ref(),
            None,
        )
    }
}
/// A re-scoring stage for fused hybrid results.
///
/// Implementations return `(id, new_score)` pairs. [`RerankerPipeline`] overwrites
/// the score of every result whose id appears in the output and leaves the rest
/// untouched. Implementations must be `Send + Sync`.
pub trait Reranker: Send + Sync {
    /// Computes new scores for `results`, given the original `query` text.
    fn rerank(&self, results: &[HybridResult], query: &str) -> Vec<(NodeId, f32)>;
}
/// Reranker intended to boost results that exactly match the query.
///
/// NOTE(review): `rerank` currently ignores both the query and `boost` and returns
/// every result's existing score unchanged, so this stage is a no-op. `HybridResult`
/// carries no document text, so real exact matching would need text from elsewhere.
pub struct ExactMatchReranker {
    /// Multiplier intended for exact matches (default `2.0`); not yet applied.
    pub boost: f32,
}

impl Default for ExactMatchReranker {
    fn default() -> Self {
        ExactMatchReranker { boost: 2.0 }
    }
}

impl Reranker for ExactMatchReranker {
    fn rerank(&self, results: &[HybridResult], _query: &str) -> Vec<(NodeId, f32)> {
        let mut scored = Vec::with_capacity(results.len());
        for hit in results {
            scored.push((hit.id, hit.score));
        }
        scored
    }
}
/// Applies a sequence of [`Reranker`] stages to fused results, in insertion order.
pub struct RerankerPipeline {
    stages: Vec<Box<dyn Reranker>>,
}

impl RerankerPipeline {
    /// Creates a pipeline with no stages; `rerank` then returns its input unchanged.
    pub fn new() -> Self {
        RerankerPipeline { stages: vec![] }
    }

    /// Appends `reranker` as the last stage.
    pub fn add_stage(mut self, reranker: Box<dyn Reranker>) -> Self {
        self.stages.push(reranker);
        self
    }

    /// Runs every stage over `results`.
    ///
    /// After each stage, results whose id the stage returned take the new score and
    /// the others keep their previous score. Results are then re-sorted by score
    /// (highest first, ties by ascending id) before the next stage sees them.
    pub fn rerank(&self, mut results: Vec<HybridResult>, query: &str) -> Vec<HybridResult> {
        for stage in self.stages.iter() {
            let new_scores: HashMap<NodeId, f32> =
                stage.rerank(&results, query).into_iter().collect();
            results.iter_mut().for_each(|hit| {
                if let Some(&score) = new_scores.get(&hit.id) {
                    hit.score = score;
                }
            });
            results.sort_by(|x, y| match y.score.partial_cmp(&x.score) {
                Some(std::cmp::Ordering::Equal) | None => x.id.cmp(&y.id),
                Some(order) => order,
            });
        }
        results
    }
}

impl Default for RerankerPipeline {
    fn default() -> Self {
        RerankerPipeline::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Indexes each `(id, text)` pair, in order, into a fresh default-configured index.
    fn build_index(docs: &[(NodeId, &str)]) -> SparseIndex {
        let mut index = SparseIndex::new();
        for &(id, text) in docs {
            index.index_text(id, text);
        }
        index
    }

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("Hello, World! This is a test-case.");
        let has = |word: &str| tokens.iter().any(|t| t == word);
        assert!(has("hello"));
        assert!(has("world"));
        // Hyphenated words survive as a single token.
        assert!(has("test-case"));
        // Single-character words are filtered out.
        assert!(!has("a"));
    }

    #[test]
    fn test_sparse_index() {
        let index = build_index(&[
            (0, "remote code execution vulnerability"),
            (1, "cross-site scripting XSS vulnerability"),
            (2, "SQL injection database vulnerability"),
        ]);
        assert_eq!(index.len(), 3);
        let hits = index.search("code execution", 10);
        // Only doc 0 contains either query term, so it must rank first.
        assert_eq!(hits.first().map(|hit| hit.id), Some(0));
    }

    #[test]
    fn test_sparse_remove() {
        let mut index = build_index(&[(0, "document one"), (1, "document two")]);
        assert_eq!(index.len(), 2);
        index.remove(0);
        assert_eq!(index.len(), 1);
        let ids: Vec<NodeId> = index
            .search("document", 10)
            .into_iter()
            .map(|hit| hit.id)
            .collect();
        assert_eq!(ids, vec![1]);
    }

    #[test]
    fn test_rrf_fusion() {
        let dense = vec![
            DistanceResult::new(1, 0.1),
            DistanceResult::new(2, 0.2),
            DistanceResult::new(3, 0.3),
        ];
        let sparse = vec![
            SparseResult { id: 2, score: 5.0 },
            SparseResult { id: 4, score: 4.0 },
            SparseResult { id: 1, score: 3.0 },
        ];
        // Ids 1 and 2 appear in both lists, so they must outrank single-list hits 3 and 4.
        let fused = reciprocal_rank_fusion(&dense, &sparse, 60);
        let mut top_two: Vec<NodeId> = fused.iter().take(2).map(|r| r.id).collect();
        top_two.sort();
        assert_eq!(top_two, vec![1, 2]);
    }

    #[test]
    fn test_linear_fusion() {
        let dense = vec![DistanceResult::new(1, 0.1), DistanceResult::new(2, 0.5)];
        let sparse = vec![
            SparseResult { id: 2, score: 10.0 },
            SparseResult { id: 1, score: 5.0 },
        ];
        // Dense favours id 1 and sparse favours id 2; alpha decides which side wins.
        let dense_heavy = linear_fusion(&dense, &sparse, 0.9);
        assert_eq!(dense_heavy[0].id, 1);
        let sparse_heavy = linear_fusion(&dense, &sparse, 0.1);
        assert_eq!(sparse_heavy[0].id, 2);
    }

    #[test]
    fn test_bm25_scoring() {
        let index = build_index(&[
            (0, "vulnerability vulnerability vulnerability"),
            (1, "vulnerability in system"),
            (2, "no relevant terms here"),
        ]);
        let hits = index.search("vulnerability", 10);
        // Docs 0 and 1 have the same length, so the higher term frequency must win.
        assert_eq!(hits[0].id, 0);
        assert!(hits[0].score > hits[1].score);
    }
}