Skip to main content

ripvec_core/
hybrid.rs

//! Hybrid semantic + keyword search with Reciprocal Rank Fusion (RRF).
//!
//! [`HybridIndex`] wraps a [`SearchIndex`] (dense vector search) and a
//! [`Bm25Index`] (BM25 keyword search) and fuses their ranked results via
//! Reciprocal Rank Fusion so that chunks appearing high in either list
//! bubble to the top of the combined ranking.
7
8use std::collections::HashMap;
9use std::fmt;
10use std::str::FromStr;
11
12use crate::bm25::Bm25Index;
13use crate::chunk::CodeChunk;
14use crate::index::SearchIndex;
15
/// Retrieval strategy selected for a search request.
///
/// The default value is [`SearchMode::Hybrid`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SearchMode {
    /// Combine the semantic (vector) and keyword (BM25) rankings via RRF.
    #[default]
    Hybrid,
    /// Rank with dense-vector cosine similarity only.
    Semantic,
    /// Rank with BM25 keyword matching only.
    Keyword,
}
27
28impl fmt::Display for SearchMode {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            Self::Hybrid => f.write_str("hybrid"),
32            Self::Semantic => f.write_str("semantic"),
33            Self::Keyword => f.write_str("keyword"),
34        }
35    }
36}
37
/// Error produced when a string does not name a known `SearchMode`.
///
/// Wraps the rejected input so it can be echoed back in the message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseSearchModeError(String);

impl fmt::Display for ParseSearchModeError {
    /// Formats the rejected input (debug-quoted) followed by the list of
    /// accepted mode names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(
            f,
            "unknown search mode {:?}; expected hybrid, semantic, or keyword",
            rejected
        )
    }
}

// No underlying cause to expose: the default `Error` methods are sufficient.
impl std::error::Error for ParseSearchModeError {}
53
54impl FromStr for SearchMode {
55    type Err = ParseSearchModeError;
56
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        match s {
59            "hybrid" => Ok(Self::Hybrid),
60            "semantic" => Ok(Self::Semantic),
61            "keyword" => Ok(Self::Keyword),
62            other => Err(ParseSearchModeError(other.to_string())),
63        }
64    }
65}
66
67/// Combined semantic + keyword search index with RRF fusion.
68///
69/// Build once from chunks and pre-computed embeddings; query repeatedly
70/// via [`search`](Self::search).
71pub struct HybridIndex {
72    /// Semantic (dense vector) search index.
73    pub semantic: SearchIndex,
74    /// BM25 keyword search index.
75    bm25: Bm25Index,
76}
77
78impl HybridIndex {
79    /// Build a `HybridIndex` from raw chunks and their pre-computed embeddings.
80    ///
81    /// Constructs both the [`SearchIndex`] and [`Bm25Index`] in one call.
82    /// `cascade_dim` is forwarded to [`SearchIndex::new`] for optional MRL
83    /// cascade pre-filtering.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if the BM25 index cannot be built (e.g., tantivy
88    /// schema or writer failure).
89    pub fn new(
90        chunks: Vec<CodeChunk>,
91        embeddings: &[Vec<f32>],
92        cascade_dim: Option<usize>,
93    ) -> crate::Result<Self> {
94        let bm25 = Bm25Index::build(&chunks)?;
95        let semantic = SearchIndex::new(chunks, embeddings, cascade_dim);
96        Ok(Self { semantic, bm25 })
97    }
98
99    /// Assemble a `HybridIndex` from pre-built components.
100    ///
101    /// Useful when the caller has already constructed the sub-indices
102    /// separately (e.g., loaded from a cache).
103    #[must_use]
104    pub fn from_parts(semantic: SearchIndex, bm25: Bm25Index) -> Self {
105        Self { semantic, bm25 }
106    }
107
108    /// Search the index and return `(chunk_index, score)` pairs.
109    ///
110    /// Dispatches based on `mode`:
111    /// - [`SearchMode::Semantic`] — pure dense vector search via
112    ///   [`SearchIndex::rank`].
113    /// - [`SearchMode::Keyword`] — pure BM25 keyword search, truncated to
114    ///   `top_k`.
115    /// - [`SearchMode::Hybrid`] — retrieves both ranked lists, fuses them
116    ///   with [`rrf_fuse`], then truncates to `top_k`.
117    ///
118    /// Scores are min-max normalized to `[0, 1]` regardless of mode, so
119    /// a threshold of 0.5 always means "above midpoint of the score range"
120    /// whether the underlying scores are cosine similarity, BM25, or RRF.
121    #[must_use]
122    pub fn search(
123        &self,
124        query_embedding: &[f32],
125        query_text: &str,
126        top_k: usize,
127        threshold: f32,
128        mode: SearchMode,
129    ) -> Vec<(usize, f32)> {
130        let mut raw = match mode {
131            SearchMode::Semantic => {
132                // Fetch more than top_k so normalization has a meaningful range.
133                self.semantic
134                    .rank_turboquant(query_embedding, top_k.max(100), 0.0)
135            }
136            SearchMode::Keyword => self.bm25.search(query_text, top_k.max(100)),
137            SearchMode::Hybrid => {
138                let sem = self
139                    .semantic
140                    .rank_turboquant(query_embedding, top_k.max(100), 0.0);
141                let kw = self.bm25.search(query_text, top_k.max(100));
142                rrf_fuse(&sem, &kw, 60.0)
143            }
144        };
145
146        // Min-max normalize scores to [0, 1] so threshold is model-agnostic.
147        if let (Some(max), Some(min)) = (raw.first().map(|(_, s)| *s), raw.last().map(|(_, s)| *s))
148        {
149            let range = max - min;
150            if range > f32::EPSILON {
151                for (_, score) in &mut raw {
152                    *score = (*score - min) / range;
153                }
154            } else {
155                // All scores identical — normalize to 1.0
156                for (_, score) in &mut raw {
157                    *score = 1.0;
158                }
159            }
160        }
161
162        // Apply threshold on normalized scores, then truncate
163        raw.retain(|(_, score)| *score >= threshold);
164        raw.truncate(top_k);
165        raw
166    }
167
168    /// All chunks in the index.
169    #[must_use]
170    pub fn chunks(&self) -> &[CodeChunk] {
171        &self.semantic.chunks
172    }
173}
174
/// Merge two ranked lists with Reciprocal Rank Fusion.
///
/// Only the position of each `(chunk_index, _score)` entry matters; the
/// scores themselves are ignored. A chunk at 0-based position `rank` in a
/// list contributes `1 / (k + rank + 1)`, and contributions from both lists
/// are summed per chunk index.
///
/// The result contains every chunk index present in either input, ordered by
/// descending fused score with ties broken by ascending chunk index.
///
/// `k` is the usual RRF smoothing constant (conventionally `60.0`): larger
/// values flatten the advantage of the very top positions.
#[must_use]
pub fn rrf_fuse(semantic: &[(usize, f32)], bm25: &[(usize, f32)], k: f32) -> Vec<(usize, f32)> {
    let mut fused: HashMap<usize, f32> = HashMap::new();

    // Semantic list first, then BM25, so per-chunk sums accumulate in a fixed order.
    for ranked in [semantic, bm25] {
        for (rank, &(chunk_idx, _)) in ranked.iter().enumerate() {
            let contribution = 1.0 / (k + rank as f32 + 1.0);
            *fused.entry(chunk_idx).or_default() += contribution;
        }
    }

    let mut ordered: Vec<(usize, f32)> = fused.into_iter().collect();
    // Map iteration order is arbitrary; the index tie-break makes the output deterministic.
    ordered.sort_unstable_by(|&(idx_a, score_a), &(idx_b, score_b)| {
        score_b.total_cmp(&score_a).then(idx_a.cmp(&idx_b))
    });
    ordered
}
203
/// Steepness of the logarithmic saturation used by the PageRank boost.
///
/// The boost applied for a normalized rank `r` is proportional to
/// `ln(1 + beta * r) / ln(1 + beta)`. With `beta = 10` that fraction is
/// roughly:
/// - `r = 0.01` → 4% of the maximum boost
/// - `r = 0.10` → 29%
/// - `r = 0.50` → 75%
/// - `r = 1.00` → 100%
///
/// The curve flattens towards the top, so the highest-ranked definition does
/// not receive a disproportionately larger boost than the runner-up.
const PAGERANK_BETA: f32 = 10.0;
215
216/// Apply a multiplicative PageRank boost to search results.
217///
218/// For each result, looks up the chunk's PageRank score and applies a
219/// log-saturated boost:
220///
221///   `boosted = score * (1 + alpha * log(1 + beta * rank) / log(1 + beta))`
222///
223/// The logarithmic saturation compresses high PageRank values so the
224/// top-ranked definition doesn't dominate. This models a Bayesian prior
225/// where structural importance multiplies with query relevance.
226///
227/// Results are re-sorted after boosting.
228///
229/// `pagerank_by_file` maps relative file paths to their PageRank scores
230/// (pre-normalized to [0, 1] by dividing by max rank).
231/// `alpha` controls overall boost strength. The `alpha` field from
232/// [`RepoGraph`] is recommended (auto-tuned from graph density).
233pub fn boost_with_pagerank<S: std::hash::BuildHasher>(
234    results: &mut [(usize, f32)],
235    chunks: &[CodeChunk],
236    pagerank_by_file: &HashMap<String, f32, S>,
237    alpha: f32,
238) {
239    let log_denom = (1.0 + PAGERANK_BETA).ln();
240    if log_denom <= f32::EPSILON {
241        return;
242    }
243
244    for (idx, score) in results.iter_mut() {
245        if let Some(chunk) = chunks.get(*idx) {
246            // Try definition-level lookup first, fall back to file-level
247            let def_key = format!("{}::{}", chunk.file_path, chunk.name);
248            let rank = pagerank_by_file
249                .get(&def_key)
250                .or_else(|| pagerank_by_file.get(&chunk.file_path))
251                .copied()
252                .unwrap_or(0.0);
253            let saturated = (1.0 + PAGERANK_BETA * rank).ln() / log_denom;
254            *score *= 1.0 + alpha * saturated;
255        }
256    }
257    // Re-sort descending by boosted score
258    results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
259}
260
261/// Build a normalized PageRank lookup table from a [`RepoGraph`].
262///
263/// Returns a map from `"file_path::def_name"` to definition-level PageRank
264/// normalized to `[0, 1]`. Also inserts file-level entries (`"file_path"`)
265/// as aggregated fallback for chunks that don't match a specific definition.
266#[must_use]
267pub fn pagerank_lookup(graph: &crate::repo_map::RepoGraph) -> HashMap<String, f32> {
268    let max_rank = graph.def_ranks.iter().copied().fold(0.0_f32, f32::max);
269    if max_rank <= f32::EPSILON {
270        // Fall back to file-level ranks if no def-level data
271        let file_max = graph.base_ranks.iter().copied().fold(0.0_f32, f32::max);
272        if file_max <= f32::EPSILON {
273            return HashMap::new();
274        }
275        return graph
276            .files
277            .iter()
278            .zip(graph.base_ranks.iter())
279            .map(|(file, &rank)| (file.path.clone(), rank / file_max))
280            .collect();
281    }
282
283    let mut map = HashMap::new();
284
285    // Definition-level entries: "path::name" -> def_rank
286    for (file_idx, file) in graph.files.iter().enumerate() {
287        for (def_idx, def) in file.defs.iter().enumerate() {
288            let flat = graph.def_offsets[file_idx] + def_idx;
289            if let Some(&rank) = graph.def_ranks.get(flat) {
290                let key = format!("{}::{}", file.path, def.name);
291                map.insert(key, rank / max_rank);
292            }
293        }
294        // File-level aggregate for fallback
295        if file_idx < graph.base_ranks.len() {
296            map.insert(file.path.clone(), graph.base_ranks[file_idx] / max_rank);
297        }
298    }
299
300    map
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the minimal `function` chunk (lines 1-10, empty content) that the
    /// PageRank tests need; only `file_path` and `name` vary between tests.
    fn function_chunk(file_path: &'static str, name: &'static str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.into(),
            name: name.into(),
            kind: "function".into(),
            start_line: 1,
            end_line: 10,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    /// RRF returns the union of both lists and ranks a chunk present in both first.
    #[test]
    fn rrf_union_semantics() {
        // sem: [0, 1, 2], bm25: [3, 0, 4]
        // Chunk 0 is in both lists; chunks 1, 2, 3, 4 are each in exactly one.
        let sem = vec![(0, 0.9), (1, 0.8), (2, 0.7)];
        let bm25 = vec![(3, 10.0), (0, 8.0), (4, 6.0)];

        let fused = rrf_fuse(&sem, &bm25, 60.0);
        let indices: Vec<usize> = fused.iter().map(|entry| entry.0).collect();

        // Every one of the five distinct chunks must be present.
        for expected in 0..5 {
            assert!(
                indices.contains(&expected),
                "chunk {expected} missing from fused results"
            );
        }
        assert_eq!(fused.len(), 5);

        // Appearing in both lists earns the top spot.
        assert_eq!(indices[0], 0, "chunk 0 should rank first");
    }

    /// With an empty BM25 list the semantic order is preserved.
    #[test]
    fn rrf_single_list() {
        let sem = vec![(0, 0.9), (1, 0.8)];
        let no_keyword_hits: Vec<(usize, f32)> = Vec::new();

        let fused = rrf_fuse(&sem, &no_keyword_hits, 60.0);

        assert_eq!(fused.len(), 2);
        let (first, second) = (fused[0], fused[1]);
        // Earlier position in the semantic list → larger RRF score.
        assert_eq!(first.0, 0);
        assert_eq!(second.0, 1);
        assert!(first.1 > second.1);
    }

    /// Every mode name parses; unknown names fail and echo the input.
    #[test]
    fn search_mode_roundtrip() {
        let cases = [
            ("hybrid", SearchMode::Hybrid),
            ("semantic", SearchMode::Semantic),
            ("keyword", SearchMode::Keyword),
        ];
        for (text, mode) in cases {
            assert_eq!(text.parse::<SearchMode>().unwrap(), mode);
        }

        let err = "invalid".parse::<SearchMode>();
        assert!(err.is_err(), "expected parse error for 'invalid'");
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("invalid"),
            "error message should echo the bad input"
        );
    }

    /// `Display` yields the lowercase mode names.
    #[test]
    fn search_mode_display() {
        let cases = [
            (SearchMode::Hybrid, "hybrid"),
            (SearchMode::Semantic, "semantic"),
            (SearchMode::Keyword, "keyword"),
        ];
        for (mode, text) in cases {
            assert_eq!(mode.to_string(), text);
        }
    }

    /// Equal scores are separated by PageRank, following the log-saturated formula.
    #[test]
    fn pagerank_boost_amplifies_relevant() {
        let chunks = vec![
            function_chunk("important.rs", "a"),
            function_chunk("obscure.rs", "b"),
        ];

        // Identical starting scores; only the PageRank differs.
        let mut results = vec![(0, 0.8_f32), (1, 0.8)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0); // max PageRank
        pr.insert("obscure.rs".to_string(), 0.1); // low PageRank

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        let (top, runner_up) = (results[0], results[1]);
        assert_eq!(top.0, 0, "important.rs should rank first after boost");
        assert!(top.1 > runner_up.1);

        // Log saturation with beta=10:
        // rank=1.0: saturated = ln(11)/ln(11) = 1.0 → 0.8 * (1 + 0.3 * 1.0) = 1.04
        assert!(
            (top.1 - 1.04).abs() < 0.01,
            "rank=1.0 boost: expected ~1.04, got {}",
            top.1
        );
        // rank=0.1: saturated = ln(2)/ln(11) ≈ 0.289 → 0.8 * (1 + 0.3 * 0.289) ≈ 0.869
        assert!(
            (runner_up.1 - 0.869).abs() < 0.01,
            "rank=0.1 boost: expected ~0.869, got {}",
            runner_up.1
        );
    }

    /// The boost is multiplicative, so a zero score cannot be lifted.
    #[test]
    fn pagerank_boost_zero_relevance_stays_zero() {
        let chunks = vec![function_chunk("important.rs", "a")];

        let mut results = vec![(0, 0.0_f32)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(results[0].1, 0.0);
    }

    /// A chunk with no PageRank entry keeps its score.
    #[test]
    fn pagerank_boost_unknown_file_no_effect() {
        let chunks = vec![function_chunk("unknown.rs", "a")];

        let mut results = vec![(0, 0.5_f32)];
        let pr = HashMap::new(); // empty — no PageRank data

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(results[0].1, 0.5);
    }
}