Skip to main content

ripvec_core/encoder/ripvec/
hybrid.rs

1//! Hybrid search: RRF fusion of semantic + BM25, then boosts and rerank.
2//!
3//! Port of `~/src/semble/src/semble/search.py`. Three entry points:
4//!
5//! - [`search_semantic`] — cosine similarity over the dense index.
//! - [`search_bm25`](crate::encoder::ripvec::bm25::search_bm25) — BM25
//!   scoring (defined in the bm25 module; imported here for use by
//!   [`search_hybrid`], not re-exported by this file).
8//! - [`search_hybrid`] — fuses both ranked lists via Reciprocal Rank
9//!   Fusion (k=60), over-fetching `top_k * 5` candidates, then applies
10//!   ripvec's `boost_multi_chunk_files` + `apply_query_boost` + the
11//!   penalty-aware `rerank_topk`.
12
13use std::collections::{HashMap, HashSet};
14
15use crate::chunk::CodeChunk;
16use crate::encoder::ripvec::bm25::{Bm25Index, search_bm25};
17use crate::encoder::ripvec::penalties::rerank_topk;
18use crate::encoder::ripvec::ranking::{apply_query_boost, boost_multi_chunk_files, resolve_alpha};
19
/// Reciprocal Rank Fusion smoothing constant. Matches Python
/// `_RRF_K = 60` from `search.py:11`.
///
/// Added to the 1-based rank in the denominator of every RRF score,
/// i.e. `1 / (RRF_K + rank)`.
pub const RRF_K: f32 = 60.0;
23
/// Over-fetch factor when assembling the hybrid candidate pool.
///
/// `search_hybrid` asks each sub-search for `top_k * CANDIDATE_MULTIPLIER`
/// hits (saturating on overflow) before fusing the two ranked lists.
const CANDIDATE_MULTIPLIER: usize = 5;
26
/// Dot product of two embedding vectors.
///
/// For L2-normalized inputs this is the same value as their cosine
/// similarity. Both slices are expected to have equal length; that is
/// only checked when debug assertions are enabled, otherwise the
/// pairwise zip silently stops at the shorter slice.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding length mismatch");
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum()
}
32
33/// Pure semantic search: rank every chunk by dot product against the
34/// query embedding, then take the top-k after optional selector mask.
35#[must_use]
36pub fn search_semantic(
37    query_embedding: &[f32],
38    chunk_embeddings: &[Vec<f32>],
39    top_k: usize,
40    selector: Option<&[usize]>,
41) -> Vec<(usize, f32)> {
42    if top_k == 0 || chunk_embeddings.is_empty() {
43        return Vec::new();
44    }
45    let selector_set: Option<HashSet<usize>> = selector.map(|s| s.iter().copied().collect());
46
47    let mut scored: Vec<(usize, f32)> = chunk_embeddings
48        .iter()
49        .enumerate()
50        .filter(|(i, _)| selector_set.as_ref().is_none_or(|s| s.contains(i)))
51        .map(|(i, emb)| (i, dot(query_embedding, emb)))
52        .collect();
53
54    scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
55    scored.truncate(top_k);
56    scored
57}
58
59/// Convert a list of `(index, raw_score)` to RRF scores.
60/// `rrf_score = 1 / (RRF_K + rank)` where rank is 1-based and the
61/// list is sorted descending by raw_score.
62fn rrf_scores(ranked: &[(usize, f32)]) -> HashMap<usize, f32> {
63    ranked
64        .iter()
65        .enumerate()
66        .map(|(rank0, (idx, _))| {
67            let rank = rank0 as f32 + 1.0;
68            (*idx, 1.0 / (RRF_K + rank))
69        })
70        .collect()
71}
72
73/// Hybrid search: alpha-weighted RRF fusion of semantic + BM25,
74/// followed by file-coherence + query boosts and the penalty-aware
75/// reranker. Mirrors `search.py:search_hybrid`.
76///
77/// `query_embedding` is the embedding of `query` produced by the same
78/// encoder that populated `chunk_embeddings`.
79///
80/// Over-fetches `top_k * 5` candidates from both sub-searches before
81/// fusing, so the merged pool is large enough that the boosts and
82/// reranker can do meaningful work.
83#[must_use]
84pub fn search_hybrid(
85    query: &str,
86    query_embedding: &[f32],
87    chunk_embeddings: &[Vec<f32>],
88    chunks: &[CodeChunk],
89    bm25: &Bm25Index,
90    top_k: usize,
91    alpha: Option<f32>,
92    selector: Option<&[usize]>,
93) -> Vec<(usize, f32)> {
94    if top_k == 0 || chunks.is_empty() {
95        return Vec::new();
96    }
97    let alpha_weight = resolve_alpha(query, alpha);
98    let candidate_count = top_k.saturating_mul(CANDIDATE_MULTIPLIER);
99
100    let semantic = search_semantic(query_embedding, chunk_embeddings, candidate_count, selector);
101    let bm25_hits = search_bm25(query, bm25, candidate_count, selector);
102
103    let normalized_semantic = rrf_scores(&semantic);
104    let normalized_bm25 = rrf_scores(&bm25_hits);
105
106    // Union of all chunks present in either ranked list.
107    let mut combined: HashMap<usize, f32> = HashMap::new();
108    let union: HashSet<usize> = normalized_semantic
109        .keys()
110        .chain(normalized_bm25.keys())
111        .copied()
112        .collect();
113    for idx in union {
114        let s = normalized_semantic.get(&idx).copied().unwrap_or(0.0);
115        let b = normalized_bm25.get(&idx).copied().unwrap_or(0.0);
116        combined.insert(idx, alpha_weight * s + (1.0 - alpha_weight) * b);
117    }
118
119    // Multi-chunk-file boost (in-place).
120    boost_multi_chunk_files(&mut combined, chunks);
121    // Query-type boost (returns a new map; matches Python's behaviour).
122    let boosted = apply_query_boost(&combined, query, chunks);
123
124    // Path penalties + saturation rerank.
125    // Semble disables path penalties for pure-semantic queries (α=1.0);
126    // alpha_weight comes from resolve_alpha so the < 1.0 condition matches
127    // Python's `penalise_paths=alpha_weight < 1.0` at search.py:121.
128    let penalise_paths = alpha_weight < 1.0;
129    let scores_vec: Vec<(usize, f32)> = boosted.into_iter().collect();
130    rerank_topk(&scores_vec, chunks, top_k, penalise_paths)
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::ripvec::bm25::Bm25Index;

    /// Build a minimal single-line chunk whose enriched content equals
    /// its raw content.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// L2-normalize `values`; the norm is floored at 1e-12 so an
    /// all-zero input does not divide by zero.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let norm: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        values.iter().map(|x| x / norm).collect()
    }

    /// `test:rrf-k-60` — RRF scores use k=60 with 1-based ranks.
    /// Rank 1 → 1/61; rank 2 → 1/62; rank 3 → 1/63.
    #[test]
    fn rrf_k_60() {
        let ranked = vec![(7, 0.9), (3, 0.8), (5, 0.5)];
        let rrf = rrf_scores(&ranked);
        assert!((rrf[&7] - 1.0 / 61.0).abs() < 1e-7);
        assert!((rrf[&3] - 1.0 / 62.0).abs() < 1e-7);
        assert!((rrf[&5] - 1.0 / 63.0).abs() < 1e-7);
    }

    /// Semantic search orders hits by descending dot product and breaks
    /// exact ties by ascending chunk index.
    #[test]
    fn semantic_orders_by_score_then_index() {
        let embeddings = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0], vec![1.0, 0.0]];
        let query_emb = vec![1.0_f32, 0.0];
        let hits = search_semantic(&query_emb, &embeddings, 3, None);
        let order: Vec<usize> = hits.iter().map(|(i, _)| *i).collect();
        assert_eq!(order, vec![1, 2, 0]);
    }

    /// The selector restricts which chunks are scored, `top_k` caps the
    /// result length, and the degenerate inputs return nothing.
    #[test]
    fn semantic_respects_selector_and_top_k() {
        let embeddings = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0], vec![1.0, 0.0]];
        let query_emb = vec![1.0_f32, 0.0];
        let mask: &[usize] = &[0, 2];

        let hits = search_semantic(&query_emb, &embeddings, 1, Some(mask));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 2);

        assert!(search_semantic(&query_emb, &embeddings, 0, None).is_empty());
        let no_embeddings: Vec<Vec<f32>> = Vec::new();
        assert!(search_semantic(&query_emb, &no_embeddings, 5, None).is_empty());
    }

    /// `top_k == 0` short-circuits hybrid search to an empty result.
    #[test]
    fn hybrid_top_k_zero_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let embeddings = vec![vec![1.0_f32, 0.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![1.0_f32, 0.0];
        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            0,
            Some(0.5),
            None,
        );
        assert!(results.is_empty());
    }

    /// `test:hybrid-candidate-count-5x-top-k` — smoke check that hybrid
    /// search with `top_k = 2` over 10 chunks returns a non-empty list
    /// of at most `top_k` hits that includes the semantic best hit.
    ///
    /// NOTE(review): the size of the over-fetched candidate pool is not
    /// observable through the return value, so the 5x factor itself is
    /// not asserted here.
    #[test]
    fn hybrid_candidate_count_5x_top_k() {
        // 10 chunks; embedding = a unit vector that aligns with chunk
        // idx. Query embedding aligns most strongly with chunk 0.
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk(&format!("src/f{i}.rs"), &format!("content {i}")))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..10)
            .map(|i| {
                let mut v = vec![0.0_f32; 10];
                v[i] = 1.0;
                v
            })
            .collect();
        let query_emb = unit_vec(&{
            let mut q = vec![0.0_f32; 10];
            q[0] = 1.0;
            q
        });
        let bm25 = Bm25Index::build(&chunks);
        let results = search_hybrid(
            "content",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        // top_k=2; the semantic best hit (chunk 0) should be present.
        assert!(!results.is_empty());
        assert!(results.iter().any(|(i, _)| *i == 0));
        assert!(results.len() <= 2);
    }

    /// `test:hybrid-zero-bm25-excluded-from-fusion` — BM25 zero scores
    /// don't enter the RRF pool because `search_bm25` drops them.
    #[test]
    fn hybrid_zero_bm25_excluded_from_fusion() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let bm25 = Bm25Index::build(&chunks);
        // Query "alpha" only matches doc 0 in BM25.
        let bm = search_bm25("alpha", &bm25, 10, None);
        assert_eq!(bm.len(), 1);
        let rrf = rrf_scores(&bm);
        assert!(
            !rrf.contains_key(&1),
            "BM25 zero-score doc should be excluded"
        );
    }

    /// `test:hybrid-applies-rerank-topk` — file-saturation decay applies
    /// when hybrid returns multiple chunks from the same file.
    #[test]
    fn hybrid_applies_rerank_topk() {
        // Two chunks in the same file with identical embeddings get the
        // same raw semantic score; the reranker's saturation decay is
        // expected to push the second chunk's score below the first.
        //
        // NOTE(review): `search_semantic` breaks the raw-score tie by
        // index, so chunk 0 already receives the better semantic rank
        // (and therefore a higher RRF term) before the reranker runs.
        // The strict inequality below is consistent with decay but does
        // not isolate it — confirm against the `rerank_topk` unit tests.
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/a.rs", "alpha bravo"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![1.0_f32, 0.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![1.0_f32, 0.0];
        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert_eq!(results.len(), 2);
        // The first hit's score should be strictly greater than the
        // second's (saturation decay).
        assert!(
            results[0].1 > results[1].1,
            "expected saturation decay; got scores={results:?}"
        );
    }

    /// `test:hybrid-applies-query-boost` and
    /// `test:hybrid-applies-multi-chunk-boost` are presumably covered by
    /// the `apply_query_boost` and `boost_multi_chunk_files` unit tests
    /// in their own module — the wiring in this module is a single call
    /// through each. A non-trivial regression here would require a
    /// behavioural shift in those functions, which their own tests
    /// should catch.
    #[test]
    fn hybrid_pipeline_wires_through_boosts_and_rerank() {
        // Smoke test: a query that touches a chunk whose file stem matches
        // it should bubble up via the apply_query_boost stem-match path.
        //
        // NOTE(review): with an all-zero query vector every semantic
        // score is 0, and the index tie-break already ranks chunk 0
        // (auth.rs) first on the semantic side, so this assertion alone
        // cannot distinguish the stem boost from the tie-break.
        let chunks = vec![
            chunk("src/auth.rs", "fn login() {}"),
            chunk("src/utils.rs", "fn unrelated() {}"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![0.0_f32, 0.0]; // unhelpful semantic vector
        let results = search_hybrid(
            "auth",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        // The auth.rs chunk should rank first because the stem matches.
        assert!(!results.is_empty());
        let top = results[0].0;
        assert_eq!(top, 0, "expected auth.rs first; got {results:?}");
    }
}