ripvec_core/encoder/ripvec/
hybrid.rs1use std::collections::{HashMap, HashSet};
14
15use crate::chunk::CodeChunk;
16use crate::encoder::ripvec::bm25::{Bm25Index, search_bm25};
17use crate::encoder::ripvec::penalties::rerank_topk;
18use crate::encoder::ripvec::ranking::{apply_query_boost, boost_multi_chunk_files, resolve_alpha};
19
/// Smoothing constant `k` of Reciprocal Rank Fusion: an item at 1-based rank
/// `r` in a ranked list contributes `1 / (RRF_K + r)` to its fused score.
/// 60 is the value popularised by Cormack et al. (SIGIR 2009).
pub const RRF_K: f32 = 60.0;

/// Over-fetch factor: each retriever (semantic and BM25) is asked for
/// `top_k * CANDIDATE_MULTIPLIER` hits before fusion, boosting and reranking
/// cut the list back down to `top_k`.
const CANDIDATE_MULTIPLIER: usize = 5;
26
/// Dot product of two embedding vectors.
///
/// Equal length is only asserted in debug builds; in release builds `zip`
/// stops at the shorter slice, so a length mismatch is silently truncated.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding length mismatch");
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
32
33#[must_use]
36pub fn search_semantic(
37 query_embedding: &[f32],
38 chunk_embeddings: &[Vec<f32>],
39 top_k: usize,
40 selector: Option<&[usize]>,
41) -> Vec<(usize, f32)> {
42 if top_k == 0 || chunk_embeddings.is_empty() {
43 return Vec::new();
44 }
45 let selector_set: Option<HashSet<usize>> = selector.map(|s| s.iter().copied().collect());
46
47 let mut scored: Vec<(usize, f32)> = chunk_embeddings
48 .iter()
49 .enumerate()
50 .filter(|(i, _)| selector_set.as_ref().is_none_or(|s| s.contains(i)))
51 .map(|(i, emb)| (i, dot(query_embedding, emb)))
52 .collect();
53
54 scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
55 scored.truncate(top_k);
56 scored
57}
58
59fn rrf_scores(ranked: &[(usize, f32)]) -> HashMap<usize, f32> {
63 ranked
64 .iter()
65 .enumerate()
66 .map(|(rank0, (idx, _))| {
67 let rank = rank0 as f32 + 1.0;
68 (*idx, 1.0 / (RRF_K + rank))
69 })
70 .collect()
71}
72
73#[must_use]
84pub fn search_hybrid(
85 query: &str,
86 query_embedding: &[f32],
87 chunk_embeddings: &[Vec<f32>],
88 chunks: &[CodeChunk],
89 bm25: &Bm25Index,
90 top_k: usize,
91 alpha: Option<f32>,
92 selector: Option<&[usize]>,
93) -> Vec<(usize, f32)> {
94 if top_k == 0 || chunks.is_empty() {
95 return Vec::new();
96 }
97 let alpha_weight = resolve_alpha(query, alpha);
98 let candidate_count = top_k.saturating_mul(CANDIDATE_MULTIPLIER);
99
100 let semantic = search_semantic(query_embedding, chunk_embeddings, candidate_count, selector);
101 let bm25_hits = search_bm25(query, bm25, candidate_count, selector);
102
103 let normalized_semantic = rrf_scores(&semantic);
104 let normalized_bm25 = rrf_scores(&bm25_hits);
105
106 let mut combined: HashMap<usize, f32> = HashMap::new();
108 let union: HashSet<usize> = normalized_semantic
109 .keys()
110 .chain(normalized_bm25.keys())
111 .copied()
112 .collect();
113 for idx in union {
114 let s = normalized_semantic.get(&idx).copied().unwrap_or(0.0);
115 let b = normalized_bm25.get(&idx).copied().unwrap_or(0.0);
116 combined.insert(idx, alpha_weight * s + (1.0 - alpha_weight) * b);
117 }
118
119 boost_multi_chunk_files(&mut combined, chunks);
121 let boosted = apply_query_boost(&combined, query, chunks);
123
124 let penalise_paths = alpha_weight < 1.0;
129 let scores_vec: Vec<(usize, f32)> = boosted.into_iter().collect();
130 rerank_topk(&scores_vec, chunks, top_k, penalise_paths)
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::ripvec::bm25::Bm25Index;

    /// Build a one-line chunk whose raw and enriched content are identical.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// L2-normalise `values`; the `1e-12` floor avoids dividing by zero.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let norm: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        values.iter().map(|x| x / norm).collect()
    }

    /// With k = 60, 1-based ranks 1, 2, 3 score 1/61, 1/62, 1/63. The raw
    /// scores in the input do not matter.
    #[test]
    fn rrf_k_60() {
        let ranked = vec![(7, 0.9), (3, 0.8), (5, 0.5)];
        let rrf = rrf_scores(&ranked);
        assert!((rrf[&7] - 1.0 / 61.0).abs() < 1e-7);
        assert!((rrf[&3] - 1.0 / 62.0).abs() < 1e-7);
        assert!((rrf[&5] - 1.0 / 63.0).abs() < 1e-7);
    }

    /// Dense retrieval behaviour:
    /// - orders by descending score;
    /// - breaks ties by ascending chunk index;
    /// - caps the result at `top_k`;
    /// - scores only `selector` indices when a selector is given.
    #[test]
    fn semantic_orders_ties_caps_and_selects() {
        let embeddings = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0], vec![1.0, 0.0]];
        let query = [1.0_f32, 0.0];
        let ids = |hits: Vec<(usize, f32)>| hits.into_iter().map(|(i, _)| i).collect::<Vec<_>>();

        assert_eq!(ids(search_semantic(&query, &embeddings, 2, None)), vec![1, 2]);
        assert_eq!(
            ids(search_semantic(&query, &embeddings, 10, Some(&[0, 2][..]))),
            vec![2, 0]
        );
        assert!(search_semantic(&query, &embeddings, 0, None).is_empty());
        assert!(search_semantic(&query, &[], 5, None).is_empty());
    }

    /// With `top_k = 2`, the semantic winner (chunk 0, identical to the query
    /// embedding) survives fusion, and no more than `top_k` results come back.
    /// NOTE(review): despite the name, this does not observe the internal
    /// candidate count directly.
    #[test]
    fn hybrid_candidate_count_5x_top_k() {
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk(&format!("src/f{i}.rs"), &format!("content {i}")))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..10)
            .map(|i| {
                let mut v = vec![0.0_f32; 10];
                v[i] = 1.0;
                v
            })
            .collect();
        let query_emb = unit_vec(&{
            let mut q = vec![0.0_f32; 10];
            q[0] = 1.0;
            q
        });
        let bm25 = Bm25Index::build(&chunks);
        let results = search_hybrid(
            "content",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert!(!results.is_empty());
        assert!(results.iter().any(|(i, _)| *i == 0));
        assert!(results.len() <= 2);
    }

    /// BM25 returns only the chunk that matches the query term. The
    /// non-matching chunk therefore never receives a BM25 RRF contribution.
    #[test]
    fn hybrid_zero_bm25_excluded_from_fusion() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let bm25 = Bm25Index::build(&chunks);
        let bm = search_bm25("alpha", &bm25, 10, None);
        assert_eq!(bm.len(), 1);
        let rrf = rrf_scores(&bm);
        assert!(
            !rrf.contains_key(&1),
            "BM25 zero-score doc should be excluded"
        );
    }

    /// Two identical chunks from the same file enter `rerank_topk` with
    /// equal fused scores. The second must come out strictly lower, which
    /// shows that the rerank stage runs.
    #[test]
    fn hybrid_applies_rerank_topk() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/a.rs", "alpha bravo"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![1.0_f32, 0.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![1.0_f32, 0.0];
        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert_eq!(results.len(), 2);
        assert!(
            results[0].1 > results[1].1,
            "expected saturation decay; got scores={results:?}"
        );
    }

    /// A zero query embedding gives every chunk a semantic score of 0. The
    /// `auth` query then relies on the lexical, boost and rerank stages to
    /// put `src/auth.rs` first.
    #[test]
    fn hybrid_pipeline_wires_through_boosts_and_rerank() {
        let chunks = vec![
            chunk("src/auth.rs", "fn login() {}"),
            chunk("src/utils.rs", "fn unrelated() {}"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![0.0_f32, 0.0];
        let results = search_hybrid(
            "auth",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert!(!results.is_empty());
        let top = results[0].0;
        assert_eq!(top, 0, "expected auth.rs first; got {results:?}");
    }
}