1use std::collections::{HashMap, HashSet};
24
25use super::distance::DistanceResult;
26use super::hnsw::{HnswIndex, NodeId};
27use super::vector_metadata::{MetadataFilter, MetadataStore};
28
/// Tuning parameters for Okapi BM25 scoring.
#[derive(Clone, Debug)]
pub struct BM25Config {
    /// Term-frequency saturation parameter (commonly 1.2–2.0).
    pub k1: f32,
    /// Document-length normalization strength (0 = none, 1 = full).
    pub b: f32,
}

impl Default for BM25Config {
    /// The conventional BM25 defaults: `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        BM25Config { k1: 1.2, b: 0.75 }
    }
}
47
/// Inverted index with BM25 scoring over pre-tokenized documents.
pub struct SparseIndex {
    /// Term -> list of `(doc_id, term_frequency)` postings.
    postings: HashMap<String, Vec<(NodeId, f32)>>,
    /// Per-document token count, used for length normalization.
    doc_lengths: HashMap<NodeId, usize>,
    /// Average document length across the corpus (BM25 `avgdl`).
    avg_doc_length: f32,
    /// Number of indexed documents.
    doc_count: usize,
    /// BM25 tuning parameters.
    config: BM25Config,
}
61
62impl SparseIndex {
63 pub fn new() -> Self {
65 Self {
66 postings: HashMap::new(),
67 doc_lengths: HashMap::new(),
68 avg_doc_length: 0.0,
69 doc_count: 0,
70 config: BM25Config::default(),
71 }
72 }
73
74 pub fn with_config(config: BM25Config) -> Self {
76 Self {
77 postings: HashMap::new(),
78 doc_lengths: HashMap::new(),
79 avg_doc_length: 0.0,
80 doc_count: 0,
81 config,
82 }
83 }
84
85 pub fn index(&mut self, doc_id: NodeId, terms: &[String]) {
87 let mut term_counts: HashMap<&str, usize> = HashMap::new();
89 for term in terms {
90 *term_counts.entry(term.as_str()).or_insert(0) += 1;
91 }
92
93 for (term, count) in term_counts {
95 self.postings
96 .entry(term.to_lowercase())
97 .or_default()
98 .push((doc_id, count as f32));
99 }
100
101 self.doc_lengths.insert(doc_id, terms.len());
103 self.doc_count += 1;
104
105 let total_length: usize = self.doc_lengths.values().sum();
107 self.avg_doc_length = total_length as f32 / self.doc_count as f32;
108 }
109
110 pub fn index_text(&mut self, doc_id: NodeId, text: &str) {
112 let terms: Vec<String> = tokenize(text);
113 self.index(doc_id, &terms);
114 }
115
116 pub fn remove(&mut self, doc_id: NodeId) {
118 for postings in self.postings.values_mut() {
120 postings.retain(|(id, _)| *id != doc_id);
121 }
122
123 if self.doc_lengths.remove(&doc_id).is_some() {
125 self.doc_count = self.doc_count.saturating_sub(1);
126
127 if self.doc_count > 0 {
129 let total_length: usize = self.doc_lengths.values().sum();
130 self.avg_doc_length = total_length as f32 / self.doc_count as f32;
131 } else {
132 self.avg_doc_length = 0.0;
133 }
134 }
135 }
136
137 pub fn search(&self, query: &str, k: usize) -> Vec<SparseResult> {
139 let query_terms = tokenize(query);
140
141 if query_terms.is_empty() {
142 return Vec::new();
143 }
144
145 let mut scores: HashMap<NodeId, f32> = HashMap::new();
147
148 for term in &query_terms {
149 let term_lower = term.to_lowercase();
150 if let Some(postings) = self.postings.get(&term_lower) {
151 let df = postings.len() as f32;
153 let idf = ((self.doc_count as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
154
155 for &(doc_id, tf) in postings {
156 let doc_len = self.doc_lengths.get(&doc_id).copied().unwrap_or(1) as f32;
157
158 let tf_component = (tf * (self.config.k1 + 1.0))
160 / (tf
161 + self.config.k1
162 * (1.0 - self.config.b
163 + self.config.b * doc_len / self.avg_doc_length));
164
165 *scores.entry(doc_id).or_insert(0.0) += idf * tf_component;
166 }
167 }
168 }
169
170 let mut results: Vec<SparseResult> = scores
172 .into_iter()
173 .map(|(id, score)| SparseResult { id, score })
174 .collect();
175
176 results.sort_by(|a, b| {
177 b.score
178 .partial_cmp(&a.score)
179 .unwrap_or(std::cmp::Ordering::Equal)
180 .then_with(|| a.id.cmp(&b.id))
181 });
182 results.truncate(k);
183
184 results
185 }
186
187 pub fn len(&self) -> usize {
189 self.doc_count
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.doc_count == 0
195 }
196
197 pub fn vocab_size(&self) -> usize {
199 self.postings.len()
200 }
201}
202
203impl Default for SparseIndex {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
/// A single hit returned by [`SparseIndex::search`].
#[derive(Debug, Clone)]
pub struct SparseResult {
    /// Identifier of the matched document.
    pub id: NodeId,
    /// BM25 relevance score (higher is better).
    pub score: f32,
}
215
/// Splits `text` into lowercase tokens.
///
/// A token boundary is any character that is neither alphanumeric nor `-`
/// nor `_`, so hyphenated and snake_case words stay whole. Tokens shorter
/// than two bytes are discarded.
fn tokenize(text: &str) -> Vec<String> {
    let is_boundary = |c: char| !(c.is_alphanumeric() || c == '-' || c == '_');
    text.split(is_boundary)
        .filter(|token| token.len() >= 2)
        .map(str::to_lowercase)
        .collect()
}
223
/// Strategy for combining a dense result list with a sparse one.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal rank fusion with the given `k` constant.
    RRF(usize),
    /// Weighted sum of normalized scores; the payload is the dense weight alpha.
    Linear(f32),
    /// Distribution-based score fusion (per-modality z-score normalization).
    DBSF,
}

impl Default for FusionMethod {
    /// RRF with the conventional constant of 60.
    fn default() -> Self {
        Self::RRF(60)
    }
}
244
245pub fn reciprocal_rank_fusion(
250 dense_results: &[DistanceResult],
251 sparse_results: &[SparseResult],
252 k: usize,
253) -> Vec<HybridResult> {
254 let mut scores: HashMap<NodeId, f32> = HashMap::new();
255 let mut dense_scores: HashMap<NodeId, f32> = HashMap::new();
256 let mut sparse_scores: HashMap<NodeId, f32> = HashMap::new();
257
258 for (rank, result) in dense_results.iter().enumerate() {
260 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
261 *scores.entry(result.id).or_insert(0.0) += rrf_score;
262 dense_scores.insert(result.id, result.distance);
263 }
264
265 for (rank, result) in sparse_results.iter().enumerate() {
267 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
268 *scores.entry(result.id).or_insert(0.0) += rrf_score;
269 sparse_scores.insert(result.id, result.score);
270 }
271
272 let mut results: Vec<HybridResult> = scores
274 .into_iter()
275 .map(|(id, score)| HybridResult {
276 id,
277 score,
278 dense_score: dense_scores.get(&id).copied(),
279 sparse_score: sparse_scores.get(&id).copied(),
280 })
281 .collect();
282
283 results.sort_by(|a, b| {
284 b.score
285 .partial_cmp(&a.score)
286 .unwrap_or(std::cmp::Ordering::Equal)
287 .then_with(|| a.id.cmp(&b.id))
288 });
289 results
290}
291
292pub fn linear_fusion(
296 dense_results: &[DistanceResult],
297 sparse_results: &[SparseResult],
298 alpha: f32,
299) -> Vec<HybridResult> {
300 let mut scores: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();
301
302 let dense_min = dense_results
304 .iter()
305 .map(|r| r.distance)
306 .fold(f32::INFINITY, f32::min);
307 let dense_max = dense_results
308 .iter()
309 .map(|r| r.distance)
310 .fold(f32::NEG_INFINITY, f32::max);
311 let dense_range = (dense_max - dense_min).max(1e-6);
312
313 for result in dense_results {
314 let normalized = 1.0 - (result.distance - dense_min) / dense_range;
316 scores.entry(result.id).or_insert((None, None)).0 = Some(normalized);
317 }
318
319 let sparse_max = sparse_results
321 .iter()
322 .map(|r| r.score)
323 .fold(f32::NEG_INFINITY, f32::max);
324 let sparse_max = sparse_max.max(1e-6);
325
326 for result in sparse_results {
327 let normalized = result.score / sparse_max;
328 scores.entry(result.id).or_insert((None, None)).1 = Some(normalized);
329 }
330
331 let mut results: Vec<HybridResult> = scores
333 .into_iter()
334 .map(|(id, (dense, sparse))| {
335 let dense_contrib = dense.unwrap_or(0.0) * alpha;
336 let sparse_contrib = sparse.unwrap_or(0.0) * (1.0 - alpha);
337 HybridResult {
338 id,
339 score: dense_contrib + sparse_contrib,
340 dense_score: dense,
341 sparse_score: sparse,
342 }
343 })
344 .collect();
345
346 results.sort_by(|a, b| {
347 b.score
348 .partial_cmp(&a.score)
349 .unwrap_or(std::cmp::Ordering::Equal)
350 .then_with(|| a.id.cmp(&b.id))
351 });
352 results
353}
354
355pub fn dbsf_fusion(
359 dense_results: &[DistanceResult],
360 sparse_results: &[SparseResult],
361) -> Vec<HybridResult> {
362 let mut scores: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();
363
364 if !dense_results.is_empty() {
366 let similarities: Vec<f32> = dense_results
367 .iter()
368 .map(|r| 1.0 / (1.0 + r.distance))
369 .collect();
370 let mean: f32 = similarities.iter().sum::<f32>() / similarities.len() as f32;
371 let variance: f32 = similarities.iter().map(|s| (s - mean).powi(2)).sum::<f32>()
372 / similarities.len() as f32;
373 let std_dev = variance.sqrt().max(1e-6);
374
375 for (result, sim) in dense_results.iter().zip(similarities.iter()) {
376 let z_score = (sim - mean) / std_dev;
377 scores.entry(result.id).or_insert((None, None)).0 = Some(z_score);
378 }
379 }
380
381 if !sparse_results.is_empty() {
383 let mean: f32 =
384 sparse_results.iter().map(|r| r.score).sum::<f32>() / sparse_results.len() as f32;
385 let variance: f32 = sparse_results
386 .iter()
387 .map(|r| (r.score - mean).powi(2))
388 .sum::<f32>()
389 / sparse_results.len() as f32;
390 let std_dev = variance.sqrt().max(1e-6);
391
392 for result in sparse_results {
393 let z_score = (result.score - mean) / std_dev;
394 scores.entry(result.id).or_insert((None, None)).1 = Some(z_score);
395 }
396 }
397
398 let mut results: Vec<HybridResult> = scores
400 .into_iter()
401 .map(|(id, (dense, sparse))| HybridResult {
402 id,
403 score: dense.unwrap_or(0.0) + sparse.unwrap_or(0.0),
404 dense_score: dense,
405 sparse_score: sparse,
406 })
407 .collect();
408
409 results.sort_by(|a, b| {
410 b.score
411 .partial_cmp(&a.score)
412 .unwrap_or(std::cmp::Ordering::Equal)
413 .then_with(|| a.id.cmp(&b.id))
414 });
415 results
416}
417
/// A fused search hit carrying per-modality evidence alongside the score.
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Document identifier.
    pub id: NodeId,
    /// Fused relevance score (higher is better).
    pub score: f32,
    /// Dense-side component, if the document appeared in the dense results.
    /// What it holds depends on the fusion method: the raw distance under
    /// RRF, the normalized value under linear fusion, a z-score under DBSF.
    pub dense_score: Option<f32>,
    /// Sparse-side component, if the document appeared in the sparse results.
    /// Raw BM25 score under RRF, normalized/z-score under the other methods.
    pub sparse_score: Option<f32>,
}
434
/// Hybrid searcher combining a dense (HNSW) and a sparse (BM25) index.
pub struct HybridSearch<'a> {
    /// Dense vector index used for nearest-neighbor queries.
    dense_index: &'a HnswIndex,
    /// Sparse keyword index used for BM25 queries.
    sparse_index: &'a SparseIndex,
    /// Optional metadata store enabling metadata-based pre-filtering.
    metadata: Option<&'a MetadataStore>,
}
448
449impl<'a> HybridSearch<'a> {
450 pub fn new(dense_index: &'a HnswIndex, sparse_index: &'a SparseIndex) -> Self {
452 Self {
453 dense_index,
454 sparse_index,
455 metadata: None,
456 }
457 }
458
459 pub fn with_metadata(mut self, metadata: &'a MetadataStore) -> Self {
461 self.metadata = Some(metadata);
462 self
463 }
464
465 pub fn query(&'a self) -> HybridQueryBuilder<'a> {
467 HybridQueryBuilder::new(self)
468 }
469
470 pub fn search(
472 &self,
473 query_vector: Option<&[f32]>,
474 query_text: Option<&str>,
475 k: usize,
476 fusion: FusionMethod,
477 pre_filter: Option<&HashSet<NodeId>>,
478 post_filter: Option<&dyn Fn(&HybridResult) -> bool>,
479 ) -> Vec<HybridResult> {
480 let fetch_k = k * 3;
482
483 let dense_results = if let Some(vector) = query_vector {
485 if let Some(filter) = pre_filter {
486 self.dense_index.search_filtered(vector, fetch_k, filter)
487 } else {
488 self.dense_index.search(vector, fetch_k)
489 }
490 } else {
491 Vec::new()
492 };
493
494 let sparse_results = if let Some(text) = query_text {
496 let mut results = self.sparse_index.search(text, fetch_k);
497 if let Some(filter) = pre_filter {
499 results.retain(|r| filter.contains(&r.id));
500 }
501 results
502 } else {
503 Vec::new()
504 };
505
506 let mut fused = match fusion {
508 FusionMethod::RRF(k_param) => {
509 reciprocal_rank_fusion(&dense_results, &sparse_results, k_param)
510 }
511 FusionMethod::Linear(alpha) => linear_fusion(&dense_results, &sparse_results, alpha),
512 FusionMethod::DBSF => dbsf_fusion(&dense_results, &sparse_results),
513 };
514
515 if let Some(filter_fn) = post_filter {
517 fused.retain(filter_fn);
518 }
519
520 fused.truncate(k);
522 fused
523 }
524
525 pub fn search_dense(&self, query_vector: &[f32], k: usize) -> Vec<DistanceResult> {
527 self.dense_index.search(query_vector, k)
528 }
529
530 pub fn search_sparse(&self, query_text: &str, k: usize) -> Vec<SparseResult> {
532 self.sparse_index.search(query_text, k)
533 }
534}
535
/// Fluent builder for hybrid queries; created via `HybridSearch::query`.
pub struct HybridQueryBuilder<'a> {
    /// Searcher the built query will execute against.
    search: &'a HybridSearch<'a>,
    /// Optional dense query vector.
    query_vector: Option<Vec<f32>>,
    /// Optional sparse (keyword) query text.
    query_text: Option<String>,
    /// Number of results to return (defaults to 10).
    k: usize,
    /// Fusion strategy for combining the two result lists.
    fusion: FusionMethod,
    /// Optional explicit id pre-filter.
    pre_filter_ids: Option<HashSet<NodeId>>,
    /// Optional metadata pre-filter (needs a metadata store on the searcher).
    metadata_filter: Option<MetadataFilter>,
}
550
551impl<'a> HybridQueryBuilder<'a> {
552 fn new(search: &'a HybridSearch<'a>) -> Self {
553 Self {
554 search,
555 query_vector: None,
556 query_text: None,
557 k: 10,
558 fusion: FusionMethod::default(),
559 pre_filter_ids: None,
560 metadata_filter: None,
561 }
562 }
563
564 pub fn with_vector(mut self, vector: Vec<f32>) -> Self {
566 self.query_vector = Some(vector);
567 self
568 }
569
570 pub fn with_text(mut self, text: impl Into<String>) -> Self {
572 self.query_text = Some(text.into());
573 self
574 }
575
576 pub fn with_both(self, vector: Vec<f32>, text: impl Into<String>) -> Self {
578 self.with_vector(vector).with_text(text)
579 }
580
581 pub fn top_k(mut self, k: usize) -> Self {
583 self.k = k;
584 self
585 }
586
587 pub fn fusion(mut self, method: FusionMethod) -> Self {
589 self.fusion = method;
590 self
591 }
592
593 pub fn rrf(mut self, k: usize) -> Self {
595 self.fusion = FusionMethod::RRF(k);
596 self
597 }
598
599 pub fn linear(mut self, alpha: f32) -> Self {
601 self.fusion = FusionMethod::Linear(alpha);
602 self
603 }
604
605 pub fn filter_ids(mut self, ids: HashSet<NodeId>) -> Self {
607 self.pre_filter_ids = Some(ids);
608 self
609 }
610
611 pub fn filter_metadata(mut self, filter: MetadataFilter) -> Self {
613 self.metadata_filter = Some(filter);
614 self
615 }
616
617 pub fn execute(self) -> Vec<HybridResult> {
619 let pre_filter = if let Some(meta_filter) = &self.metadata_filter {
621 if let Some(meta_store) = self.search.metadata {
622 let matching_ids = meta_store.filter(meta_filter);
624
625 if let Some(ref explicit_ids) = self.pre_filter_ids {
627 Some(matching_ids.intersection(explicit_ids).copied().collect())
628 } else {
629 Some(matching_ids)
630 }
631 } else {
632 self.pre_filter_ids.clone()
633 }
634 } else {
635 self.pre_filter_ids.clone()
636 };
637
638 self.search.search(
639 self.query_vector.as_deref(),
640 self.query_text.as_deref(),
641 self.k,
642 self.fusion,
643 pre_filter.as_ref(),
644 None,
645 )
646 }
647}
648
/// A reranking stage that can rescore fused search results.
pub trait Reranker: Send + Sync {
    /// Returns `(id, new_score)` pairs for the given results. When used via
    /// `RerankerPipeline::rerank`, ids absent from the returned list keep
    /// their current score.
    fn rerank(&self, results: &[HybridResult], query: &str) -> Vec<(NodeId, f32)>;
}
658
/// Reranker intended to boost results that exactly match the query.
pub struct ExactMatchReranker {
    /// Multiplier to apply to exact matches.
    pub boost: f32,
}

impl Default for ExactMatchReranker {
    fn default() -> Self {
        ExactMatchReranker { boost: 2.0 }
    }
}
670
impl Reranker for ExactMatchReranker {
    // NOTE(review): pass-through implementation — `self.boost` and the query
    // are currently unused because `HybridResult` carries no document text to
    // match against. Scores are returned unchanged; actual exact-match
    // boosting would need access to the indexed text. TODO confirm intent.
    fn rerank(&self, results: &[HybridResult], _query: &str) -> Vec<(NodeId, f32)> {
        results.iter().map(|r| (r.id, r.score)).collect()
    }
}
677
/// An ordered chain of rerankers applied one after another.
pub struct RerankerPipeline {
    /// Stages executed in insertion order.
    stages: Vec<Box<dyn Reranker>>,
}
682
683impl RerankerPipeline {
684 pub fn new() -> Self {
685 Self { stages: Vec::new() }
686 }
687
688 pub fn add_stage(mut self, reranker: Box<dyn Reranker>) -> Self {
689 self.stages.push(reranker);
690 self
691 }
692
693 pub fn rerank(&self, mut results: Vec<HybridResult>, query: &str) -> Vec<HybridResult> {
694 for stage in &self.stages {
695 let reranked = stage.rerank(&results, query);
696 let score_map: HashMap<NodeId, f32> = reranked.into_iter().collect();
697
698 for result in &mut results {
699 if let Some(&new_score) = score_map.get(&result.id) {
700 result.score = new_score;
701 }
702 }
703
704 results.sort_by(|a, b| {
705 b.score
706 .partial_cmp(&a.score)
707 .unwrap_or(std::cmp::Ordering::Equal)
708 .then_with(|| a.id.cmp(&b.id))
709 });
710 }
711
712 results
713 }
714}
715
716impl Default for RerankerPipeline {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("Hello, World! This is a test-case.");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        // Hyphenated terms survive tokenization as a single token.
        assert!(tokens.contains(&"test-case".to_string()));
        // Tokens shorter than two bytes are dropped.
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_sparse_index() {
        let mut index = SparseIndex::new();
        index.index_text(0, "remote code execution vulnerability");
        index.index_text(1, "cross-site scripting XSS vulnerability");
        index.index_text(2, "SQL injection database vulnerability");

        assert_eq!(index.len(), 3);

        // Document 0 matches both query terms, so it should rank first.
        let results = index.search("code execution", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn test_sparse_remove() {
        let mut index = SparseIndex::new();
        index.index_text(0, "document one");
        index.index_text(1, "document two");
        assert_eq!(index.len(), 2);

        index.remove(0);
        assert_eq!(index.len(), 1);

        // The removed document must no longer surface in searches.
        let results = index.search("document", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_rrf_fusion() {
        let dense = vec![
            DistanceResult::new(1, 0.1),
            DistanceResult::new(2, 0.2),
            DistanceResult::new(3, 0.3),
        ];
        let sparse = vec![
            SparseResult { id: 2, score: 5.0 },
            SparseResult { id: 4, score: 4.0 },
            SparseResult { id: 1, score: 3.0 },
        ];

        let fused = reciprocal_rank_fusion(&dense, &sparse, 60);

        // Ids 1 and 2 appear in both lists, so they should lead the fusion.
        let top_ids: Vec<NodeId> = fused.iter().take(2).map(|r| r.id).collect();
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
    }

    #[test]
    fn test_linear_fusion() {
        let dense = vec![DistanceResult::new(1, 0.1), DistanceResult::new(2, 0.5)];
        let sparse = vec![
            SparseResult { id: 2, score: 10.0 },
            SparseResult { id: 1, score: 5.0 },
        ];

        // Heavy dense weighting favors the closest dense hit.
        let fused_dense = linear_fusion(&dense, &sparse, 0.9);
        assert_eq!(fused_dense[0].id, 1);

        // Heavy sparse weighting favors the top sparse hit.
        let fused_sparse = linear_fusion(&dense, &sparse, 0.1);
        assert_eq!(fused_sparse[0].id, 2);
    }

    #[test]
    fn test_bm25_scoring() {
        let mut index = SparseIndex::new();
        index.index_text(0, "vulnerability vulnerability vulnerability");
        index.index_text(1, "vulnerability in system");
        index.index_text(2, "no relevant terms here");

        let results = index.search("vulnerability", 10);

        // Higher term frequency should yield a strictly higher BM25 score.
        assert_eq!(results[0].id, 0);
        assert!(results[0].score > results[1].score);
    }
}