ripvec_core/encoder/ripvec/hybrid.rs
1//! Hybrid search: RRF fusion of semantic + BM25, then boosts and rerank.
2//!
3//! Port of `~/src/semble/src/semble/search.py`. Three entry points:
4//!
5//! - [`search_semantic`] — cosine similarity over the dense index.
6//! - [`search_bm25`](crate::encoder::ripvec::bm25::search_bm25) — BM25
7//! scoring (re-exported from the bm25 module).
8//! - [`search_hybrid`] — fuses both ranked lists via Reciprocal Rank
9//! Fusion (k=60), over-fetching `top_k * 5` candidates, then applies
10//! ripvec's `boost_multi_chunk_files` + `apply_query_boost` + the
11//! penalty-aware `rerank_topk`.
12
13use std::collections::{HashMap, HashSet};
14
15use ndarray::{Array1, Array2, ArrayView1, s};
16use rayon::prelude::*;
17
18use crate::chunk::CodeChunk;
19use crate::encoder::ripvec::bm25::{Bm25Index, search_bm25};
20use crate::encoder::ripvec::penalties::rerank_topk;
21use crate::encoder::ripvec::ranking::{apply_query_boost, boost_multi_chunk_files, resolve_alpha};
22
/// Smoothing constant `k` in the Reciprocal Rank Fusion term
/// `1 / (k + rank)`.
///
/// Kept equal to `_RRF_K = 60` in the Python reference (`search.py:11`).
pub const RRF_K: f32 = 60.0_f32;
26
/// How many times `top_k` each sub-search is asked for when the hybrid
/// candidate pool is assembled.
const CANDIDATE_MULTIPLIER: usize = 5_usize;
29
30/// Parallel matrix-vector multiply: `scores = matrix @ vector`.
31///
32/// Splits the matrix into one row-chunk per rayon worker. Each worker
33/// computes its slice's sgemv via ndarray's BLAS dispatch and writes
34/// into a disjoint output range. The chunk size is rounded up so the
35/// number of shards equals the rayon worker count (no work-stealing
36/// imbalance for symmetric input).
37///
38/// For a 1M-row × 256-col matrix on a 12-core M2 Max this approaches
39/// the aggregate memory-bandwidth ceiling (~250 GB/s) instead of the
40/// single-core ceiling (~50-80 GB/s) Accelerate's serial sgemv
41/// otherwise caps us at.
42/// Row count below which a single serial BLAS sgemv is faster than
43/// rayon-sharded parallel sgemv (the per-thread dispatch overhead
44/// dominates the inner work for small matrices).
45const SGEMV_SERIAL_THRESHOLD: usize = 4096;
46
47/// Parallel matrix-vector multiply via row-sharded BLAS sgemv.
48///
49/// See call site in `search_semantic` for the rationale; in short,
50/// Accelerate's level-2 BLAS is single-threaded on macOS, so we shard
51/// the matrix into row-chunks and call sgemv per worker to saturate
52/// aggregate memory bandwidth.
53///
54/// # Panics
55///
56/// Panics if ndarray returns a non-contiguous slice from
57/// `Array2::slice(s![start..end, ..])`. Row slices of a row-major
58/// matrix are always contiguous, so this is structurally unreachable;
59/// the panic guards against future layout changes that would silently
60/// break correctness.
61#[must_use]
62pub fn parallel_sgemv(matrix: &Array2<f32>, vector: &ArrayView1<f32>) -> Array1<f32> {
63 let n = matrix.nrows();
64 if n == 0 {
65 return Array1::zeros(0);
66 }
67 let n_threads = rayon::current_num_threads().max(1);
68 if n <= SGEMV_SERIAL_THRESHOLD || n_threads == 1 {
69 return matrix.dot(vector);
70 }
71 let chunk_size = n.div_ceil(n_threads);
72 let mut scores = vec![0.0_f32; n];
73 scores
74 .par_chunks_mut(chunk_size)
75 .enumerate()
76 .for_each(|(thread_idx, out)| {
77 let start = thread_idx * chunk_size;
78 let end = (start + out.len()).min(n);
79 let slice = matrix.slice(s![start..end, ..]);
80 let local: Array1<f32> = slice.dot(vector);
81 // SAFETY in spirit: `local` length == `out` length by
82 // construction (`out.len() == end - start` from
83 // par_chunks_mut, and `slice.nrows() == end - start`).
84 out.copy_from_slice(local.as_slice().expect("sgemv output contiguous"));
85 });
86 // `Array1::from_vec` is O(1).
87 Array1::from_vec(scores)
88}
89
90/// Pure semantic search: rank every chunk by dot product against the
91/// query embedding, then take the top-k after optional selector mask.
92///
93/// Math:
94/// scores = chunk_embeddings @ query_embedding
95/// top-k by select_nth_unstable_by, then sort the survivors.
96///
97/// `chunk_embeddings` is row-major `[n_chunks, hidden_dim]`; with the
98/// `cpu-accelerate` feature ndarray's `.dot()` dispatches to Accelerate's
99/// `cblas_sgemv`, which is vendor-tuned and near memory-bandwidth-bound
100/// (1 GB read per query at ~250 GB/s = ~4 ms theoretical floor on 1M
101/// chunks at 256 dim). Earlier scalar pointer-chasing path took 583
102/// ms per query (profile: samply v1, 2026-05-21).
103///
104/// Top-k uses `select_nth_unstable_by` (O(N) average) instead of a
105/// full sort (O(N log N)) — at 1M chunks selecting top-100 that's
106/// ~1M ops vs ~20M.
107#[must_use]
108pub fn search_semantic(
109 query_embedding: &[f32],
110 chunk_embeddings: &Array2<f32>,
111 top_k: usize,
112 selector: Option<&[usize]>,
113) -> Vec<(usize, f32)> {
114 let n_chunks = chunk_embeddings.nrows();
115 if top_k == 0 || n_chunks == 0 {
116 return Vec::new();
117 }
118 debug_assert_eq!(
119 query_embedding.len(),
120 chunk_embeddings.ncols(),
121 "query embedding dim ({}) != chunk embedding dim ({})",
122 query_embedding.len(),
123 chunk_embeddings.ncols(),
124 );
125
126 // GEMV: scores[i] = sum_d chunk_embeddings[i, d] * query[d].
127 //
128 // Accelerate's level-2 BLAS (`cblas_sgemv`) is single-threaded on
129 // macOS — only level-3 (GEMM) gets the multi-thread treatment.
130 // Single-core memory bandwidth on M2 Max is ~50-80 GB/s; the
131 // 1M-chunk × 256-dim matrix is 1 GB, so a single sgemv pays
132 // ~12-20 ms just on memory bandwidth and we measured ~76 ms in
133 // the profile.
134 //
135 // Fix: shard the matrix into row-chunks and dispatch one sgemv
136 // per rayon worker. Each thread reads its slice independently;
137 // aggregate bandwidth on M2 Max scales to ~250 GB/s with all
138 // cores active. Theoretical floor drops to ~4 ms. Each shard's
139 // sgemv is itself BLAS-optimal; we just stop forcing serial.
140 let query: ArrayView1<f32> = ArrayView1::from(query_embedding);
141 let scores: Array1<f32> = parallel_sgemv(chunk_embeddings, &query);
142
143 // Filter by selector if set. Build a HashSet for O(1) membership;
144 // at 1M chunks the HashSet is ~50 ms to build but per-chunk lookup
145 // amortises against the avoided dense scoring elsewhere.
146 let selector_set: Option<HashSet<usize>> = selector.map(|s| s.iter().copied().collect());
147
148 let mut scored: Vec<(usize, f32)> = if let Some(set) = selector_set {
149 scores
150 .iter()
151 .enumerate()
152 .filter(|(i, _)| set.contains(i))
153 .map(|(i, &s)| (i, s))
154 .collect()
155 } else {
156 // No selector: keep everything (we'll partial-sort below).
157 scores.iter().enumerate().map(|(i, &s)| (i, s)).collect()
158 };
159
160 // Top-k via O(N) selection. `select_nth_unstable_by` partitions
161 // around the k-th element; everything before it is in (unsorted)
162 // top-k. We then sort that small slice to recover the ordering.
163 if scored.len() > top_k {
164 scored.select_nth_unstable_by(top_k - 1, |a, b| {
165 b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))
166 });
167 scored.truncate(top_k);
168 }
169 scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
170 scored
171}
172
173/// Convert a list of `(index, raw_score)` to RRF scores.
174/// `rrf_score = 1 / (RRF_K + rank)` where rank is 1-based and the
175/// list is sorted descending by raw_score.
176fn rrf_scores(ranked: &[(usize, f32)]) -> HashMap<usize, f32> {
177 ranked
178 .iter()
179 .enumerate()
180 .map(|(rank0, (idx, _))| {
181 let rank = rank0 as f32 + 1.0;
182 (*idx, 1.0 / (RRF_K + rank))
183 })
184 .collect()
185}
186
187/// Hybrid search: alpha-weighted RRF fusion of semantic + BM25,
188/// followed by file-coherence + query boosts and the penalty-aware
189/// reranker. Mirrors `search.py:search_hybrid`.
190///
191/// `query_embedding` is the embedding of `query` produced by the same
192/// encoder that populated `chunk_embeddings`.
193///
194/// Over-fetches `top_k * 5` candidates from both sub-searches before
195/// fusing, so the merged pool is large enough that the boosts and
196/// reranker can do meaningful work.
197#[must_use]
198pub fn search_hybrid(
199 query: &str,
200 query_embedding: &[f32],
201 chunk_embeddings: &Array2<f32>,
202 chunks: &[CodeChunk],
203 bm25: &Bm25Index,
204 top_k: usize,
205 alpha: Option<f32>,
206 selector: Option<&[usize]>,
207) -> Vec<(usize, f32)> {
208 if top_k == 0 || chunks.is_empty() {
209 return Vec::new();
210 }
211 let alpha_weight = resolve_alpha(query, alpha);
212 let candidate_count = top_k.saturating_mul(CANDIDATE_MULTIPLIER);
213
214 let semantic = search_semantic(query_embedding, chunk_embeddings, candidate_count, selector);
215 let bm25_hits = search_bm25(query, bm25, candidate_count, selector);
216
217 let normalized_semantic = rrf_scores(&semantic);
218 let normalized_bm25 = rrf_scores(&bm25_hits);
219
220 // Union of all chunks present in either ranked list.
221 let mut combined: HashMap<usize, f32> = HashMap::new();
222 let union: HashSet<usize> = normalized_semantic
223 .keys()
224 .chain(normalized_bm25.keys())
225 .copied()
226 .collect();
227 for idx in union {
228 let s = normalized_semantic.get(&idx).copied().unwrap_or(0.0);
229 let b = normalized_bm25.get(&idx).copied().unwrap_or(0.0);
230 combined.insert(idx, alpha_weight * s + (1.0 - alpha_weight) * b);
231 }
232
233 // Multi-chunk-file boost (in-place).
234 boost_multi_chunk_files(&mut combined, chunks);
235 // Query-type boost (returns a new map; matches Python's behaviour).
236 let boosted = apply_query_boost(&combined, query, chunks);
237
238 // Path penalties + saturation rerank.
239 // Semble disables path penalties for pure-semantic queries (α=1.0);
240 // alpha_weight comes from resolve_alpha so the < 1.0 condition matches
241 // Python's `penalise_paths=alpha_weight < 1.0` at search.py:121.
242 let penalise_paths = alpha_weight < 1.0;
243 let scores_vec: Vec<(usize, f32)> = boosted.into_iter().collect();
244 rerank_topk(&scores_vec, chunks, top_k, penalise_paths)
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::ripvec::bm25::Bm25Index;

    /// Build a one-line code chunk at `path` whose plain and enriched
    /// content are both `content`; every other field is a neutral filler.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        let body = content.to_string();
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            content: body.clone(),
            enriched_content: body,
            qualified_name: None,
        }
    }

    /// Scale `values` to unit L2 length; the norm is floored at `1e-12`
    /// so an all-zero input does not divide by zero.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let squared_sum: f32 = values.iter().map(|v| v * v).sum();
        let scale = squared_sum.sqrt().max(1e-12);
        values.iter().map(|v| v / scale).collect()
    }

    /// `test:rrf-k-60` — ranks are 1-based and the smoothing constant is
    /// 60, so positions 1, 2, 3 score 1/61, 1/62, 1/63.
    #[test]
    fn rrf_k_60() {
        let fused = rrf_scores(&[(7, 0.9), (3, 0.8), (5, 0.5)]);
        for (idx, denominator) in [(7_usize, 61.0_f32), (3, 62.0), (5, 63.0)] {
            assert!((fused[&idx] - 1.0 / denominator).abs() < 1e-7);
        }
    }

    /// `test:hybrid-candidate-count-5x-top-k` — runs the hybrid pipeline
    /// with `top_k = 2` over ten chunks (so the 5x over-fetch asks each
    /// sub-search for all ten) and checks that the best semantic hit is
    /// returned and the result is capped at `top_k`.
    #[test]
    fn hybrid_candidate_count_5x_top_k() {
        // Ten chunks, one file each; chunk i's embedding is basis vector
        // e_i, and the query embedding is e_0.
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk(&format!("src/f{i}.rs"), &format!("content {i}")))
            .collect();
        let embeddings = Array2::<f32>::eye(10);
        let mut one_hot = vec![0.0_f32; 10];
        one_hot[0] = 1.0;
        let query_emb = unit_vec(&one_hot);
        let bm25 = Bm25Index::build(&chunks);

        let results = search_hybrid(
            "content",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert!(!results.is_empty());
        assert!(results.iter().any(|&(idx, _)| idx == 0));
        assert!(results.len() <= 2);
    }

    /// `test:hybrid-zero-bm25-excluded-from-fusion` — a document that
    /// `search_bm25` does not return never receives an RRF score.
    #[test]
    fn hybrid_zero_bm25_excluded_from_fusion() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let index = Bm25Index::build(&chunks);
        // Only document 0 contains the term "alpha".
        let lexical_hits = search_bm25("alpha", &index, 10, None);
        assert_eq!(lexical_hits.len(), 1);
        let fused = rrf_scores(&lexical_hits);
        assert!(
            !fused.contains_key(&1),
            "BM25 zero-score doc should be excluded"
        );
    }

    /// `test:hybrid-applies-rerank-topk` — two chunks of the same file
    /// must come back with strictly decreasing scores.
    #[test]
    fn hybrid_applies_rerank_topk() {
        // Same file, same text, same embedding for both chunks.
        let chunks: Vec<CodeChunk> = (0..2).map(|_| chunk("src/a.rs", "alpha bravo")).collect();
        let embeddings = Array2::from_shape_vec((2, 2), vec![1.0_f32, 0.0, 1.0, 0.0]).unwrap();
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = [1.0_f32, 0.0];

        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert_eq!(results.len(), 2);
        let (_, first_score) = results[0];
        let (_, second_score) = results[1];
        assert!(
            first_score > second_score,
            "expected saturation decay; got scores={results:?}"
        );
    }

    /// End-to-end smoke test of the boost and rerank wiring.
    ///
    /// `test:hybrid-applies-query-boost` and
    /// `test:hybrid-applies-multi-chunk-boost` are covered in depth by the
    /// unit tests of the modules that own those steps; here each step is
    /// a single call, so this only checks that a query naming a file stem
    /// ranks that file's chunk first.
    #[test]
    fn hybrid_pipeline_wires_through_boosts_and_rerank() {
        let chunks = vec![
            chunk("src/auth.rs", "fn login() {}"),
            chunk("src/utils.rs", "fn unrelated() {}"),
        ];
        let embeddings = Array2::<f32>::eye(2);
        let bm25 = Bm25Index::build(&chunks);
        // All-zero query embedding: the semantic side carries no signal.
        let query_emb = vec![0.0_f32; 2];

        let results = search_hybrid(
            "auth",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert!(!results.is_empty());
        let (top, _) = results[0];
        assert_eq!(top, 0, "expected auth.rs first; got {results:?}");
    }
}
403}