// reddb_server/storage/engine/hybrid.rs
//! Hybrid Search for RedDB
//!
//! Combines dense (vector) and sparse (keyword) search for improved retrieval:
//! - Dense search: semantic similarity via HNSW
//! - Sparse search: BM25-style keyword matching
//! - Fusion: Reciprocal Rank Fusion (RRF), Linear Combination, DBSF
//! - Filtering: pre-filter and post-filter by metadata
//! - Re-ranking: score-adjustment pipeline
//!
//! # Example
//!
//! ```ignore
//! let hybrid = HybridSearch::new(&hnsw_index, &sparse_index);
//! let results = hybrid
//!     .query()
//!     .with_text("CVE remote code execution")
//!     .with_vector(query_embedding)
//!     .linear(0.7) // 70% dense, 30% sparse
//!     .top_k(10)
//!     .execute();
//! ```

23use std::collections::{HashMap, HashSet};
24
25use super::distance::DistanceResult;
26use super::hnsw::{HnswIndex, NodeId};
27use super::vector_metadata::{MetadataFilter, MetadataStore};
28
// ============================================================================
// Sparse Index (BM25-style)
// ============================================================================
32
/// Tunable parameters for BM25 ranking.
#[derive(Clone, Debug)]
pub struct BM25Config {
    /// Term-frequency saturation parameter `k1` (typically 1.2-2.0);
    /// higher values let repeated terms keep contributing to the score.
    pub k1: f32,
    /// Document-length normalization parameter `b` in `[0, 1]`
    /// (typically 0.75); 0 disables length normalization entirely.
    pub b: f32,
}
41
42impl Default for BM25Config {
43    fn default() -> Self {
44        Self { k1: 1.2, b: 0.75 }
45    }
46}
47
/// Sparse inverted index for BM25-style keyword search.
///
/// Documents are identified by `NodeId` so results can be fused with
/// dense (HNSW) results for the same collection.
pub struct SparseIndex {
    /// Term (lowercased) -> list of (doc_id, term_frequency) postings
    postings: HashMap<String, Vec<(NodeId, f32)>>,
    /// Per-document length in terms, used for BM25 length normalization
    doc_lengths: HashMap<NodeId, usize>,
    /// Average document length across the index (0.0 when empty)
    avg_doc_length: f32,
    /// Number of indexed documents
    doc_count: usize,
    /// BM25 configuration (`k1`, `b`)
    config: BM25Config,
}
61
62impl SparseIndex {
63    /// Create a new sparse index
64    pub fn new() -> Self {
65        Self {
66            postings: HashMap::new(),
67            doc_lengths: HashMap::new(),
68            avg_doc_length: 0.0,
69            doc_count: 0,
70            config: BM25Config::default(),
71        }
72    }
73
74    /// Create with custom BM25 config
75    pub fn with_config(config: BM25Config) -> Self {
76        Self {
77            postings: HashMap::new(),
78            doc_lengths: HashMap::new(),
79            avg_doc_length: 0.0,
80            doc_count: 0,
81            config,
82        }
83    }
84
85    /// Index a document with its terms
86    pub fn index(&mut self, doc_id: NodeId, terms: &[String]) {
87        // Count term frequencies
88        let mut term_counts: HashMap<&str, usize> = HashMap::new();
89        for term in terms {
90            *term_counts.entry(term.as_str()).or_insert(0) += 1;
91        }
92
93        // Update postings
94        for (term, count) in term_counts {
95            self.postings
96                .entry(term.to_lowercase())
97                .or_default()
98                .push((doc_id, count as f32));
99        }
100
101        // Update document length
102        self.doc_lengths.insert(doc_id, terms.len());
103        self.doc_count += 1;
104
105        // Recalculate average document length
106        let total_length: usize = self.doc_lengths.values().sum();
107        self.avg_doc_length = total_length as f32 / self.doc_count as f32;
108    }
109
110    /// Index a document from text (tokenizes automatically)
111    pub fn index_text(&mut self, doc_id: NodeId, text: &str) {
112        let terms: Vec<String> = tokenize(text);
113        self.index(doc_id, &terms);
114    }
115
116    /// Remove a document from the index
117    pub fn remove(&mut self, doc_id: NodeId) {
118        // Remove from postings
119        for postings in self.postings.values_mut() {
120            postings.retain(|(id, _)| *id != doc_id);
121        }
122
123        // Remove from doc_lengths
124        if self.doc_lengths.remove(&doc_id).is_some() {
125            self.doc_count = self.doc_count.saturating_sub(1);
126
127            // Recalculate average
128            if self.doc_count > 0 {
129                let total_length: usize = self.doc_lengths.values().sum();
130                self.avg_doc_length = total_length as f32 / self.doc_count as f32;
131            } else {
132                self.avg_doc_length = 0.0;
133            }
134        }
135    }
136
137    /// Search using BM25 scoring
138    pub fn search(&self, query: &str, k: usize) -> Vec<SparseResult> {
139        let query_terms = tokenize(query);
140
141        if query_terms.is_empty() {
142            return Vec::new();
143        }
144
145        // Calculate BM25 scores for each document
146        let mut scores: HashMap<NodeId, f32> = HashMap::new();
147
148        for term in &query_terms {
149            let term_lower = term.to_lowercase();
150            if let Some(postings) = self.postings.get(&term_lower) {
151                // IDF component
152                let df = postings.len() as f32;
153                let idf = ((self.doc_count as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
154
155                for &(doc_id, tf) in postings {
156                    let doc_len = self.doc_lengths.get(&doc_id).copied().unwrap_or(1) as f32;
157
158                    // BM25 TF component
159                    let tf_component = (tf * (self.config.k1 + 1.0))
160                        / (tf
161                            + self.config.k1
162                                * (1.0 - self.config.b
163                                    + self.config.b * doc_len / self.avg_doc_length));
164
165                    *scores.entry(doc_id).or_insert(0.0) += idf * tf_component;
166                }
167            }
168        }
169
170        // Sort by score descending
171        let mut results: Vec<SparseResult> = scores
172            .into_iter()
173            .map(|(id, score)| SparseResult { id, score })
174            .collect();
175
176        results.sort_by(|a, b| {
177            b.score
178                .partial_cmp(&a.score)
179                .unwrap_or(std::cmp::Ordering::Equal)
180                .then_with(|| a.id.cmp(&b.id))
181        });
182        results.truncate(k);
183
184        results
185    }
186
187    /// Get number of indexed documents
188    pub fn len(&self) -> usize {
189        self.doc_count
190    }
191
192    /// Check if index is empty
193    pub fn is_empty(&self) -> bool {
194        self.doc_count == 0
195    }
196
197    /// Get vocabulary size
198    pub fn vocab_size(&self) -> usize {
199        self.postings.len()
200    }
201}
202
203impl Default for SparseIndex {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// A single hit from sparse (BM25) search.
#[derive(Debug, Clone)]
pub struct SparseResult {
    /// Matching document ID
    pub id: NodeId,
    /// BM25 relevance score (higher is better)
    pub score: f32,
}
215
/// Simple tokenizer for text.
///
/// Splits on any character that is not alphanumeric, `-`, or `_`,
/// lowercases each token, and drops tokens shorter than two
/// *characters*. The original filter used byte length (`s.len()`),
/// which let single multi-byte characters (e.g. "é") through.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
        .filter(|s| s.chars().nth(1).is_some()) // Skip single characters
        .map(|s| s.to_lowercase())
        .collect()
}
223
// ============================================================================
// Fusion Methods
// ============================================================================
227
/// Method for combining dense and sparse result lists into one ranking.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion with smoothing constant k (default: 60).
    /// Rank-based, so raw scores need not be comparable across systems.
    RRF(usize),
    /// Linear combination: alpha * dense + (1-alpha) * sparse,
    /// computed over normalized scores.
    Linear(f32),
    /// Distribution-Based Score Fusion: z-score normalize each list,
    /// then sum the normalized scores.
    DBSF,
}
238
239impl Default for FusionMethod {
240    fn default() -> Self {
241        FusionMethod::RRF(60)
242    }
243}
244
245/// Reciprocal Rank Fusion
246///
247/// RRF(d) = Σ 1/(k + rank(d))
248/// Works well when scores aren't comparable across systems
249pub fn reciprocal_rank_fusion(
250    dense_results: &[DistanceResult],
251    sparse_results: &[SparseResult],
252    k: usize,
253) -> Vec<HybridResult> {
254    let mut scores: HashMap<NodeId, f32> = HashMap::new();
255    let mut dense_scores: HashMap<NodeId, f32> = HashMap::new();
256    let mut sparse_scores: HashMap<NodeId, f32> = HashMap::new();
257
258    // Add dense scores
259    for (rank, result) in dense_results.iter().enumerate() {
260        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
261        *scores.entry(result.id).or_insert(0.0) += rrf_score;
262        dense_scores.insert(result.id, result.distance);
263    }
264
265    // Add sparse scores
266    for (rank, result) in sparse_results.iter().enumerate() {
267        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
268        *scores.entry(result.id).or_insert(0.0) += rrf_score;
269        sparse_scores.insert(result.id, result.score);
270    }
271
272    // Convert to results
273    let mut results: Vec<HybridResult> = scores
274        .into_iter()
275        .map(|(id, score)| HybridResult {
276            id,
277            score,
278            dense_score: dense_scores.get(&id).copied(),
279            sparse_score: sparse_scores.get(&id).copied(),
280        })
281        .collect();
282
283    results.sort_by(|a, b| {
284        b.score
285            .partial_cmp(&a.score)
286            .unwrap_or(std::cmp::Ordering::Equal)
287            .then_with(|| a.id.cmp(&b.id))
288    });
289    results
290}
291
292/// Linear combination of scores
293///
294/// score = alpha * normalized_dense + (1 - alpha) * normalized_sparse
295pub fn linear_fusion(
296    dense_results: &[DistanceResult],
297    sparse_results: &[SparseResult],
298    alpha: f32,
299) -> Vec<HybridResult> {
300    let mut scores: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();
301
302    // Normalize dense scores (distance to similarity: 1 / (1 + distance))
303    let dense_min = dense_results
304        .iter()
305        .map(|r| r.distance)
306        .fold(f32::INFINITY, f32::min);
307    let dense_max = dense_results
308        .iter()
309        .map(|r| r.distance)
310        .fold(f32::NEG_INFINITY, f32::max);
311    let dense_range = (dense_max - dense_min).max(1e-6);
312
313    for result in dense_results {
314        // Convert distance to similarity (lower distance = higher similarity)
315        let normalized = 1.0 - (result.distance - dense_min) / dense_range;
316        scores.entry(result.id).or_insert((None, None)).0 = Some(normalized);
317    }
318
319    // Normalize sparse scores (already similarity-based)
320    let sparse_max = sparse_results
321        .iter()
322        .map(|r| r.score)
323        .fold(f32::NEG_INFINITY, f32::max);
324    let sparse_max = sparse_max.max(1e-6);
325
326    for result in sparse_results {
327        let normalized = result.score / sparse_max;
328        scores.entry(result.id).or_insert((None, None)).1 = Some(normalized);
329    }
330
331    // Combine scores
332    let mut results: Vec<HybridResult> = scores
333        .into_iter()
334        .map(|(id, (dense, sparse))| {
335            let dense_contrib = dense.unwrap_or(0.0) * alpha;
336            let sparse_contrib = sparse.unwrap_or(0.0) * (1.0 - alpha);
337            HybridResult {
338                id,
339                score: dense_contrib + sparse_contrib,
340                dense_score: dense,
341                sparse_score: sparse,
342            }
343        })
344        .collect();
345
346    results.sort_by(|a, b| {
347        b.score
348            .partial_cmp(&a.score)
349            .unwrap_or(std::cmp::Ordering::Equal)
350            .then_with(|| a.id.cmp(&b.id))
351    });
352    results
353}
354
355/// Distribution-Based Score Fusion
356///
357/// Normalizes scores based on their distribution (z-score normalization)
358pub fn dbsf_fusion(
359    dense_results: &[DistanceResult],
360    sparse_results: &[SparseResult],
361) -> Vec<HybridResult> {
362    let mut scores: HashMap<NodeId, (Option<f32>, Option<f32>)> = HashMap::new();
363
364    // Z-score normalize dense (convert distance to similarity first)
365    if !dense_results.is_empty() {
366        let similarities: Vec<f32> = dense_results
367            .iter()
368            .map(|r| 1.0 / (1.0 + r.distance))
369            .collect();
370        let mean: f32 = similarities.iter().sum::<f32>() / similarities.len() as f32;
371        let variance: f32 = similarities.iter().map(|s| (s - mean).powi(2)).sum::<f32>()
372            / similarities.len() as f32;
373        let std_dev = variance.sqrt().max(1e-6);
374
375        for (result, sim) in dense_results.iter().zip(similarities.iter()) {
376            let z_score = (sim - mean) / std_dev;
377            scores.entry(result.id).or_insert((None, None)).0 = Some(z_score);
378        }
379    }
380
381    // Z-score normalize sparse
382    if !sparse_results.is_empty() {
383        let mean: f32 =
384            sparse_results.iter().map(|r| r.score).sum::<f32>() / sparse_results.len() as f32;
385        let variance: f32 = sparse_results
386            .iter()
387            .map(|r| (r.score - mean).powi(2))
388            .sum::<f32>()
389            / sparse_results.len() as f32;
390        let std_dev = variance.sqrt().max(1e-6);
391
392        for result in sparse_results {
393            let z_score = (result.score - mean) / std_dev;
394            scores.entry(result.id).or_insert((None, None)).1 = Some(z_score);
395        }
396    }
397
398    // Sum z-scores
399    let mut results: Vec<HybridResult> = scores
400        .into_iter()
401        .map(|(id, (dense, sparse))| HybridResult {
402            id,
403            score: dense.unwrap_or(0.0) + sparse.unwrap_or(0.0),
404            dense_score: dense,
405            sparse_score: sparse,
406        })
407        .collect();
408
409    results.sort_by(|a, b| {
410        b.score
411            .partial_cmp(&a.score)
412            .unwrap_or(std::cmp::Ordering::Equal)
413            .then_with(|| a.id.cmp(&b.id))
414    });
415    results
416}
417
// ============================================================================
// Hybrid Result
// ============================================================================
421
/// Result from hybrid search.
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Document ID
    pub id: NodeId,
    /// Combined (fused) score; higher is better
    pub score: f32,
    /// Raw score from dense search, if this doc appeared in the dense
    /// results (for RRF this is the raw distance, not the RRF term)
    pub dense_score: Option<f32>,
    /// Raw BM25 score from sparse search, if this doc appeared there
    pub sparse_score: Option<f32>,
}
434
// ============================================================================
// Hybrid Search
// ============================================================================
438
/// Hybrid search combining dense and sparse retrieval over the same
/// document collection.
pub struct HybridSearch<'a> {
    /// Dense index (HNSW) for vector similarity
    dense_index: &'a HnswIndex,
    /// Sparse index (BM25) for keyword matching
    sparse_index: &'a SparseIndex,
    /// Optional metadata store; enables metadata-based pre-filtering
    metadata: Option<&'a MetadataStore>,
}
448
449impl<'a> HybridSearch<'a> {
450    /// Create a new hybrid search
451    pub fn new(dense_index: &'a HnswIndex, sparse_index: &'a SparseIndex) -> Self {
452        Self {
453            dense_index,
454            sparse_index,
455            metadata: None,
456        }
457    }
458
459    /// Add metadata store for filtering
460    pub fn with_metadata(mut self, metadata: &'a MetadataStore) -> Self {
461        self.metadata = Some(metadata);
462        self
463    }
464
465    /// Create a query builder
466    pub fn query(&'a self) -> HybridQueryBuilder<'a> {
467        HybridQueryBuilder::new(self)
468    }
469
470    /// Execute hybrid search
471    pub fn search(
472        &self,
473        query_vector: Option<&[f32]>,
474        query_text: Option<&str>,
475        k: usize,
476        fusion: FusionMethod,
477        pre_filter: Option<&HashSet<NodeId>>,
478        post_filter: Option<&dyn Fn(&HybridResult) -> bool>,
479    ) -> Vec<HybridResult> {
480        // Fetch more results for filtering
481        let fetch_k = k * 3;
482
483        // Dense search
484        let dense_results = if let Some(vector) = query_vector {
485            if let Some(filter) = pre_filter {
486                self.dense_index.search_filtered(vector, fetch_k, filter)
487            } else {
488                self.dense_index.search(vector, fetch_k)
489            }
490        } else {
491            Vec::new()
492        };
493
494        // Sparse search
495        let sparse_results = if let Some(text) = query_text {
496            let mut results = self.sparse_index.search(text, fetch_k);
497            // Apply pre-filter to sparse results
498            if let Some(filter) = pre_filter {
499                results.retain(|r| filter.contains(&r.id));
500            }
501            results
502        } else {
503            Vec::new()
504        };
505
506        // Fuse results
507        let mut fused = match fusion {
508            FusionMethod::RRF(k_param) => {
509                reciprocal_rank_fusion(&dense_results, &sparse_results, k_param)
510            }
511            FusionMethod::Linear(alpha) => linear_fusion(&dense_results, &sparse_results, alpha),
512            FusionMethod::DBSF => dbsf_fusion(&dense_results, &sparse_results),
513        };
514
515        // Apply post-filter
516        if let Some(filter_fn) = post_filter {
517            fused.retain(filter_fn);
518        }
519
520        // Return top k
521        fused.truncate(k);
522        fused
523    }
524
525    /// Dense-only search (for comparison)
526    pub fn search_dense(&self, query_vector: &[f32], k: usize) -> Vec<DistanceResult> {
527        self.dense_index.search(query_vector, k)
528    }
529
530    /// Sparse-only search (for comparison)
531    pub fn search_sparse(&self, query_text: &str, k: usize) -> Vec<SparseResult> {
532        self.sparse_index.search(query_text, k)
533    }
534}
535
// ============================================================================
// Query Builder
// ============================================================================
539
/// Fluent builder for hybrid queries; constructed via `HybridSearch::query`.
pub struct HybridQueryBuilder<'a> {
    /// The search engine this query will run against
    search: &'a HybridSearch<'a>,
    /// Query embedding for the dense (HNSW) leg, if any
    query_vector: Option<Vec<f32>>,
    /// Query text for the sparse (BM25) leg, if any
    query_text: Option<String>,
    /// Number of results to return (default: 10)
    k: usize,
    /// How the two result lists are combined (default: RRF(60))
    fusion: FusionMethod,
    /// Optional explicit allow-list of document IDs
    pre_filter_ids: Option<HashSet<NodeId>>,
    /// Optional metadata predicate, resolved to IDs at execute time
    metadata_filter: Option<MetadataFilter>,
}
550
551impl<'a> HybridQueryBuilder<'a> {
552    fn new(search: &'a HybridSearch<'a>) -> Self {
553        Self {
554            search,
555            query_vector: None,
556            query_text: None,
557            k: 10,
558            fusion: FusionMethod::default(),
559            pre_filter_ids: None,
560            metadata_filter: None,
561        }
562    }
563
564    /// Set the query vector for dense search
565    pub fn with_vector(mut self, vector: Vec<f32>) -> Self {
566        self.query_vector = Some(vector);
567        self
568    }
569
570    /// Set the query text for sparse search
571    pub fn with_text(mut self, text: impl Into<String>) -> Self {
572        self.query_text = Some(text.into());
573        self
574    }
575
576    /// Set both vector and text
577    pub fn with_both(self, vector: Vec<f32>, text: impl Into<String>) -> Self {
578        self.with_vector(vector).with_text(text)
579    }
580
581    /// Set number of results to return
582    pub fn top_k(mut self, k: usize) -> Self {
583        self.k = k;
584        self
585    }
586
587    /// Set fusion method
588    pub fn fusion(mut self, method: FusionMethod) -> Self {
589        self.fusion = method;
590        self
591    }
592
593    /// Use RRF fusion
594    pub fn rrf(mut self, k: usize) -> Self {
595        self.fusion = FusionMethod::RRF(k);
596        self
597    }
598
599    /// Use linear fusion with alpha weight for dense
600    pub fn linear(mut self, alpha: f32) -> Self {
601        self.fusion = FusionMethod::Linear(alpha);
602        self
603    }
604
605    /// Pre-filter by document IDs
606    pub fn filter_ids(mut self, ids: HashSet<NodeId>) -> Self {
607        self.pre_filter_ids = Some(ids);
608        self
609    }
610
611    /// Pre-filter by metadata
612    pub fn filter_metadata(mut self, filter: MetadataFilter) -> Self {
613        self.metadata_filter = Some(filter);
614        self
615    }
616
617    /// Execute the query
618    pub fn execute(self) -> Vec<HybridResult> {
619        // Build pre-filter from metadata if available
620        let pre_filter = if let Some(meta_filter) = &self.metadata_filter {
621            if let Some(meta_store) = self.search.metadata {
622                // Use MetadataStore's filter method
623                let matching_ids = meta_store.filter(meta_filter);
624
625                // Intersect with explicit ID filter if present
626                if let Some(ref explicit_ids) = self.pre_filter_ids {
627                    Some(matching_ids.intersection(explicit_ids).copied().collect())
628                } else {
629                    Some(matching_ids)
630                }
631            } else {
632                self.pre_filter_ids.clone()
633            }
634        } else {
635            self.pre_filter_ids.clone()
636        };
637
638        self.search.search(
639            self.query_vector.as_deref(),
640            self.query_text.as_deref(),
641            self.k,
642            self.fusion,
643            pre_filter.as_ref(),
644            None,
645        )
646    }
647}
648
// ============================================================================
// Re-ranking
// ============================================================================
652
/// Re-ranker for adjusting hybrid search results.
///
/// Implementations return `(id, new_score)` pairs; IDs omitted from the
/// returned list keep their existing scores when run through
/// `RerankerPipeline`.
pub trait Reranker: Send + Sync {
    /// Re-rank `results` for `query`, returning adjusted scores
    fn rerank(&self, results: &[HybridResult], query: &str) -> Vec<(NodeId, f32)>;
}
658
/// Simple re-ranker intended to boost exact matches.
///
/// NOTE(review): the current `rerank` impl is a placeholder that
/// returns scores unchanged, so `boost` is not applied yet.
pub struct ExactMatchReranker {
    /// Boost factor for exact matches (currently unused by the
    /// placeholder implementation)
    pub boost: f32,
}
664
665impl Default for ExactMatchReranker {
666    fn default() -> Self {
667        Self { boost: 2.0 }
668    }
669}
670
671impl Reranker for ExactMatchReranker {
672    fn rerank(&self, results: &[HybridResult], _query: &str) -> Vec<(NodeId, f32)> {
673        // This is a placeholder - real implementation would check document content
674        results.iter().map(|r| (r.id, r.score)).collect()
675    }
676}
677
/// Ordered pipeline of re-ranking stages applied after fusion.
pub struct RerankerPipeline {
    /// Stages run in insertion order; each sees the previous stage's output
    stages: Vec<Box<dyn Reranker>>,
}
682
683impl RerankerPipeline {
684    pub fn new() -> Self {
685        Self { stages: Vec::new() }
686    }
687
688    pub fn add_stage(mut self, reranker: Box<dyn Reranker>) -> Self {
689        self.stages.push(reranker);
690        self
691    }
692
693    pub fn rerank(&self, mut results: Vec<HybridResult>, query: &str) -> Vec<HybridResult> {
694        for stage in &self.stages {
695            let reranked = stage.rerank(&results, query);
696            let score_map: HashMap<NodeId, f32> = reranked.into_iter().collect();
697
698            for result in &mut results {
699                if let Some(&new_score) = score_map.get(&result.id) {
700                    result.score = new_score;
701                }
702            }
703
704            results.sort_by(|a, b| {
705                b.score
706                    .partial_cmp(&a.score)
707                    .unwrap_or(std::cmp::Ordering::Equal)
708                    .then_with(|| a.id.cmp(&b.id))
709            });
710        }
711
712        results
713    }
714}
715
716impl Default for RerankerPipeline {
717    fn default() -> Self {
718        Self::new()
719    }
720}
721
// ============================================================================
// Tests
// ============================================================================
725
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize() {
        // Punctuation splits tokens; '-' is kept inside a token.
        let tokens = tokenize("Hello, World! This is a test-case.");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test-case".to_string()));
        assert!(!tokens.contains(&"a".to_string())); // Single char filtered
    }

    #[test]
    fn test_sparse_index() {
        let mut index = SparseIndex::new();

        index.index_text(0, "remote code execution vulnerability");
        index.index_text(1, "cross-site scripting XSS vulnerability");
        index.index_text(2, "SQL injection database vulnerability");

        assert_eq!(index.len(), 3);

        // Doc 0 matches both query terms; the others match neither.
        let results = index.search("code execution", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 0); // Best match
    }

    #[test]
    fn test_sparse_remove() {
        let mut index = SparseIndex::new();

        index.index_text(0, "document one");
        index.index_text(1, "document two");

        assert_eq!(index.len(), 2);

        index.remove(0);
        assert_eq!(index.len(), 1);

        // Removed doc must no longer surface in search results.
        let results = index.search("document", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_rrf_fusion() {
        let dense = vec![
            DistanceResult::new(1, 0.1),
            DistanceResult::new(2, 0.2),
            DistanceResult::new(3, 0.3),
        ];

        let sparse = vec![
            SparseResult { id: 2, score: 5.0 },
            SparseResult { id: 4, score: 4.0 },
            SparseResult { id: 1, score: 3.0 },
        ];

        let fused = reciprocal_rank_fusion(&dense, &sparse, 60);

        // IDs 1 and 2 should have highest scores (appear in both lists,
        // so they accumulate two RRF terms each).
        let top_ids: Vec<NodeId> = fused.iter().take(2).map(|r| r.id).collect();
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
    }

    #[test]
    fn test_linear_fusion() {
        let dense = vec![
            DistanceResult::new(1, 0.1), // closest
            DistanceResult::new(2, 0.5),
        ];

        let sparse = vec![
            SparseResult { id: 2, score: 10.0 }, // best sparse
            SparseResult { id: 1, score: 5.0 },
        ];

        // With high alpha (dense-weighted)
        let fused_dense = linear_fusion(&dense, &sparse, 0.9);
        assert_eq!(fused_dense[0].id, 1); // Dense winner

        // With low alpha (sparse-weighted)
        let fused_sparse = linear_fusion(&dense, &sparse, 0.1);
        assert_eq!(fused_sparse[0].id, 2); // Sparse winner
    }

    #[test]
    fn test_bm25_scoring() {
        let mut index = SparseIndex::new();

        // Document with more relevant terms should score higher
        index.index_text(0, "vulnerability vulnerability vulnerability");
        index.index_text(1, "vulnerability in system");
        index.index_text(2, "no relevant terms here");

        let results = index.search("vulnerability", 10);

        // Doc 0 has highest TF
        assert_eq!(results[0].id, 0);
        assert!(results[0].score > results[1].score);
    }
}
829}