use std::collections::HashMap;
use std::fmt::Write as _;

use super::bm25_index::{BM25Index, ChunkKind, SearchResult};

#[cfg(feature = "embeddings")]
use super::embeddings::{cosine_similarity, EmbeddingEngine};

/// Rank-smoothing constant `k` used in every reciprocal-rank term
/// `weight / (k + zero_based_rank + 1)`.
const RRF_K: f64 = 60.0;

/// BM25 list weight used by `HybridConfig::default`.
const DEFAULT_BM25_WEIGHT: f64 = 1.0;
/// Dense list weight used by `HybridConfig::default`.
const DEFAULT_DENSE_WEIGHT: f64 = 1.0;
/// Tuning knobs for hybrid (BM25 + dense) retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridConfig {
    /// Numerator of the reciprocal-rank term contributed by each BM25 hit.
    pub bm25_weight: f64,
    /// Numerator of the reciprocal-rank term contributed by each dense hit.
    pub dense_weight: f64,
    /// How many BM25 candidates are requested from the index before fusion.
    pub bm25_candidates: usize,
    /// How many dense candidates are kept before fusion.
    pub dense_candidates: usize,
}
33impl Default for HybridConfig {
34 fn default() -> Self {
35 Self {
36 bm25_weight: DEFAULT_BM25_WEIGHT,
37 dense_weight: DEFAULT_DENSE_WEIGHT,
38 bm25_candidates: 50,
39 dense_candidates: 50,
40 }
41 }
42}
43
44pub fn reciprocal_rank_fusion(
49 bm25_results: &[SearchResult],
50 dense_results: &[DenseSearchResult],
51 config: &HybridConfig,
52 top_k: usize,
53 graph_file_ranks: Option<&HashMap<String, usize>>,
54) -> Vec<HybridResult> {
55 let mut scores: HashMap<String, HybridResult> = HashMap::new();
56
57 for (rank, result) in bm25_results.iter().enumerate() {
58 let key = result_key(&result.file_path, result.start_line);
59 let rrf_score = config.bm25_weight / (RRF_K + rank as f64 + 1.0);
60
61 let entry = scores.entry(key).or_insert_with(|| HybridResult {
62 file_path: result.file_path.clone(),
63 symbol_name: result.symbol_name.clone(),
64 kind: result.kind.clone(),
65 start_line: result.start_line,
66 end_line: result.end_line,
67 snippet: result.snippet.clone(),
68 rrf_score: 0.0,
69 bm25_score: Some(result.score),
70 dense_score: None,
71 bm25_rank: None,
72 dense_rank: None,
73 });
74 entry.rrf_score += rrf_score;
75 entry.bm25_rank = Some(rank + 1);
76 }
77
78 for (rank, result) in dense_results.iter().enumerate() {
79 let key = result_key(&result.file_path, result.start_line);
80 let rrf_score = config.dense_weight / (RRF_K + rank as f64 + 1.0);
81
82 let entry = scores.entry(key).or_insert_with(|| HybridResult {
83 file_path: result.file_path.clone(),
84 symbol_name: result.symbol_name.clone(),
85 kind: result.kind.clone(),
86 start_line: result.start_line,
87 end_line: result.end_line,
88 snippet: result.snippet.clone(),
89 rrf_score: 0.0,
90 bm25_score: None,
91 dense_score: None,
92 bm25_rank: None,
93 dense_rank: None,
94 });
95 entry.rrf_score += rrf_score;
96 entry.dense_score = Some(result.similarity);
97 entry.dense_rank = Some(rank + 1);
98 }
99
100 if let Some(gr) = graph_file_ranks {
101 if !gr.is_empty() {
102 for entry in scores.values_mut() {
103 if let Some(&rank) = gr.get(&entry.file_path) {
104 entry.rrf_score += 1.0 / (RRF_K + rank as f64 + 1.0);
105 }
106 }
107 }
108 }
109
110 let mut results: Vec<HybridResult> = scores.into_values().collect();
111 results.sort_by(|a, b| {
112 b.rrf_score
113 .partial_cmp(&a.rrf_score)
114 .unwrap_or(std::cmp::Ordering::Equal)
115 });
116 results.truncate(top_k);
117 results
118}
119
/// Runs a hybrid search: BM25 candidates from `index`, dense candidates from the embedding
/// engine (when both `engine` and `chunk_embeddings` are supplied), fused with
/// [`reciprocal_rank_fusion`].
///
/// When there are no dense candidates and no non-empty graph map, fusion is skipped and the
/// first `top_k` BM25 hits are returned directly. In that case `rrf_score` holds the raw BM25
/// score, and `bm25_rank` is filled in so the results are labelled as BM25 hits.
///
/// # Parameters
/// - `query`: search text.
/// - `index`: BM25 index; its `chunks` are also the source for dense results.
/// - `engine` / `chunk_embeddings`: dense retrieval inputs; dense search is skipped unless
///   both are `Some`.
/// - `top_k`: maximum number of results returned.
/// - `config`: candidate counts and fusion weights.
/// - `graph_file_ranks`: optional file-path → rank map passed through to the fusion step.
#[cfg(feature = "embeddings")]
pub fn hybrid_search(
    query: &str,
    index: &BM25Index,
    engine: Option<&EmbeddingEngine>,
    chunk_embeddings: Option<&[Vec<f32>]>,
    top_k: usize,
    config: &HybridConfig,
    graph_file_ranks: Option<&HashMap<String, usize>>,
) -> Vec<HybridResult> {
    let bm25_results = index.search(query, config.bm25_candidates);

    let dense_results = match (engine, chunk_embeddings) {
        (Some(eng), Some(embeddings)) => dense_search(
            query,
            eng,
            &index.chunks,
            embeddings,
            config.dense_candidates,
        ),
        _ => Vec::new(),
    };

    let graph_enhances = graph_file_ranks.is_some_and(|m| !m.is_empty());

    // Nothing to fuse with: hand back the BM25 ranking as-is, keeping its 1-based ranks.
    if dense_results.is_empty() && !graph_enhances {
        return bm25_results
            .into_iter()
            .take(top_k)
            .enumerate()
            .map(|(position, result)| {
                let mut hit = HybridResult::from_bm25(result);
                hit.bm25_rank = Some(position + 1);
                hit
            })
            .collect();
    }

    reciprocal_rank_fusion(
        &bm25_results,
        &dense_results,
        config,
        top_k,
        graph_file_ranks,
    )
}
167#[cfg(not(feature = "embeddings"))]
168pub fn hybrid_search(query: &str, index: &BM25Index, top_k: usize) -> Vec<HybridResult> {
169 index
170 .search(query, top_k)
171 .into_iter()
172 .map(HybridResult::from_bm25)
173 .collect()
174}
175
/// Scores every chunk against `query` by embedding similarity and returns the best `top_k`.
///
/// `embeddings[i]` is treated as the embedding of `chunks[i]`; positions that exist in only
/// one of the two slices are ignored. Returns an empty list when the query cannot be
/// embedded. Each result's snippet is the first five lines of the chunk content.
#[cfg(feature = "embeddings")]
fn dense_search(
    query: &str,
    engine: &EmbeddingEngine,
    chunks: &[super::bm25_index::CodeChunk],
    embeddings: &[Vec<f32>],
    top_k: usize,
) -> Vec<DenseSearchResult> {
    let Ok(query_embedding) = engine.embed(query) else {
        return Vec::new();
    };

    // Pair each chunk with its embedding up front so that embeddings without a chunk cannot
    // occupy top-k slots, and drop NaN similarities, which have no meaningful rank.
    let mut scored: Vec<_> = chunks
        .iter()
        .zip(embeddings)
        .enumerate()
        .filter_map(|(idx, (chunk, emb))| {
            let sim = cosine_similarity(&query_embedding, emb);
            (!sim.is_nan()).then_some((idx, chunk, sim))
        })
        .collect();

    // Stable sort: equally similar chunks keep their index order.
    scored.sort_by(|a, b| b.2.total_cmp(&a.2));
    scored.truncate(top_k);

    scored
        .into_iter()
        .map(|(idx, chunk, sim)| DenseSearchResult {
            chunk_idx: idx,
            similarity: sim,
            file_path: chunk.file_path.clone(),
            symbol_name: chunk.symbol_name.clone(),
            kind: chunk.kind.clone(),
            start_line: chunk.start_line,
            end_line: chunk.end_line,
            snippet: chunk.content.lines().take(5).collect::<Vec<_>>().join("\n"),
        })
        .collect()
}
/// Builds the fusion key for a chunk: `"<file_path>:<start_line>"`.
///
/// Hits from different retrievers that share this key are treated as the same chunk.
fn result_key(file_path: &str, start_line: usize) -> String {
    format!("{file_path}:{start_line}")
}
221#[derive(Debug, Clone)]
223pub struct DenseSearchResult {
224 pub chunk_idx: usize,
225 pub similarity: f32,
226 pub file_path: String,
227 pub symbol_name: String,
228 pub kind: ChunkKind,
229 pub start_line: usize,
230 pub end_line: usize,
231 pub snippet: String,
232}
233
234#[derive(Debug, Clone)]
236pub struct HybridResult {
237 pub file_path: String,
238 pub symbol_name: String,
239 pub kind: ChunkKind,
240 pub start_line: usize,
241 pub end_line: usize,
242 pub snippet: String,
243 pub rrf_score: f64,
244 pub bm25_score: Option<f64>,
245 pub dense_score: Option<f32>,
246 pub bm25_rank: Option<usize>,
247 pub dense_rank: Option<usize>,
248}
249
250impl HybridResult {
251 pub fn from_bm25_public(result: SearchResult) -> Self {
252 Self::from_bm25(result)
253 }
254
255 fn from_bm25(result: SearchResult) -> Self {
256 Self {
257 file_path: result.file_path,
258 symbol_name: result.symbol_name,
259 kind: result.kind,
260 start_line: result.start_line,
261 end_line: result.end_line,
262 snippet: result.snippet,
263 rrf_score: result.score,
264 bm25_score: Some(result.score),
265 dense_score: None,
266 bm25_rank: None,
267 dense_rank: None,
268 }
269 }
270
271 pub fn source_label(&self) -> &'static str {
272 match (self.bm25_rank.is_some(), self.dense_rank.is_some()) {
273 (true, true) => "hybrid",
274 (true, false) => "bm25",
275 (false, true) => "dense",
276 (false, false) => "unknown",
277 }
278 }
279}
280
281pub fn format_hybrid_results(results: &[HybridResult], compact: bool) -> String {
283 if results.is_empty() {
284 return "No results found.".to_string();
285 }
286
287 let mut out = String::new();
288 for (i, r) in results.iter().enumerate() {
289 if compact {
290 out.push_str(&format!(
291 "{}. {:.4} [{}] {}:{}-{} {:?} {}\n",
292 i + 1,
293 r.rrf_score,
294 r.source_label(),
295 r.file_path,
296 r.start_line,
297 r.end_line,
298 r.kind,
299 r.symbol_name,
300 ));
301 } else {
302 let source_info = match (r.bm25_rank, r.dense_rank) {
303 (Some(bm), Some(dn)) => format!("bm25:#{bm} + dense:#{dn}"),
304 (Some(bm), None) => format!("bm25:#{bm}"),
305 (None, Some(dn)) => format!("dense:#{dn}"),
306 _ => String::new(),
307 };
308 out.push_str(&format!(
309 "\n--- Result {} (rrf: {:.4}, {}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
310 i + 1,
311 r.rrf_score,
312 source_info,
313 r.file_path,
314 r.symbol_name,
315 r.kind,
316 r.start_line,
317 r.end_line,
318 r.snippet,
319 ));
320 }
321 }
322 out
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a BM25 hit for a function chunk spanning `line..=line + 10`.
    fn make_bm25_result(file: &str, name: &str, line: usize, score: f64) -> SearchResult {
        SearchResult {
            chunk_idx: 0,
            score,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    /// Builds a dense hit for a function chunk spanning `line..=line + 10`.
    fn make_dense_result(file: &str, name: &str, line: usize, sim: f32) -> DenseSearchResult {
        DenseSearchResult {
            chunk_idx: 0,
            similarity: sim,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    /// A result ranked by both retrievers, shared by the formatting tests.
    fn sample_hybrid_result() -> HybridResult {
        HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }
    }

    #[test]
    fn rrf_basic_fusion() {
        let bm25 = vec![
            make_bm25_result("a.rs", "alpha", 1, 5.0),
            make_bm25_result("b.rs", "beta", 1, 3.0),
            make_bm25_result("c.rs", "gamma", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("b.rs", "beta", 1, 0.95),
            make_dense_result("d.rs", "delta", 1, 0.90),
            make_dense_result("a.rs", "alpha", 1, 0.85),
        ];

        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        // Four distinct chunks across the two lists.
        assert_eq!(results.len(), 4);

        // beta (bm25 #2 + dense #1) strictly beats alpha (bm25 #1 + dense #3).
        let top = &results[0];
        assert_eq!(top.symbol_name, "beta");
        assert!(
            top.bm25_rank.is_some() || top.dense_rank.is_some(),
            "top result should appear in at least one ranking"
        );

        let beta = results.iter().find(|r| r.symbol_name == "beta").unwrap();
        assert!(beta.bm25_rank.is_some() && beta.dense_rank.is_some());
        assert_eq!(beta.source_label(), "hybrid");
    }

    #[test]
    fn rrf_reports_one_based_ranks() {
        let bm25 = vec![
            make_bm25_result("a.rs", "alpha", 1, 5.0),
            make_bm25_result("b.rs", "beta", 1, 3.0),
        ];
        let dense = vec![make_dense_result("b.rs", "beta", 1, 0.9)];

        let results =
            reciprocal_rank_fusion(&bm25, &dense, &HybridConfig::default(), 10, None);

        let alpha = results.iter().find(|r| r.symbol_name == "alpha").unwrap();
        assert_eq!(alpha.bm25_rank, Some(1));
        assert_eq!(alpha.dense_rank, None);
        assert_eq!(alpha.bm25_score, Some(5.0));

        let beta = results.iter().find(|r| r.symbol_name == "beta").unwrap();
        assert_eq!(beta.bm25_rank, Some(2));
        assert_eq!(beta.dense_rank, Some(1));
        assert_eq!(beta.bm25_score, Some(3.0));
        assert_eq!(beta.dense_score, Some(0.9));
    }

    #[test]
    fn rrf_both_rankings_boost() {
        let bm25 = vec![
            make_bm25_result("a.rs", "only_bm25", 1, 5.0),
            make_bm25_result("b.rs", "both", 1, 3.0),
        ];
        let dense = vec![
            make_dense_result("c.rs", "only_dense", 1, 0.99),
            make_dense_result("b.rs", "both", 1, 0.90),
        ];

        let config = HybridConfig {
            bm25_weight: 0.5,
            dense_weight: 0.5,
            ..Default::default()
        };
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        let both = results.iter().find(|r| r.symbol_name == "both").unwrap();
        let only_bm25 = results
            .iter()
            .find(|r| r.symbol_name == "only_bm25")
            .unwrap();
        let only_dense = results
            .iter()
            .find(|r| r.symbol_name == "only_dense")
            .unwrap();

        assert!(
            both.rrf_score > only_bm25.rrf_score,
            "result in both rankings should score higher than BM25-only"
        );
        assert!(
            both.rrf_score > only_dense.rrf_score,
            "result in both rankings should score higher than dense-only"
        );
    }

    #[test]
    fn rrf_respects_top_k() {
        let bm25: Vec<SearchResult> = (0..20)
            .map(|i| make_bm25_result("a.rs", &format!("fn_{i}"), i * 10 + 1, 10.0 - i as f64))
            .collect();

        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 5, None);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn rrf_empty_inputs() {
        let results = reciprocal_rank_fusion(&[], &[], &HybridConfig::default(), 10, None);
        assert!(results.is_empty());
    }

    #[test]
    fn rrf_bm25_only() {
        let bm25 = vec![make_bm25_result("a.rs", "alpha", 1, 5.0)];
        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "bm25");
    }

    #[test]
    fn rrf_dense_only() {
        let dense = vec![make_dense_result("a.rs", "alpha", 1, 0.95)];
        let results = reciprocal_rank_fusion(&[], &dense, &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "dense");
    }

    #[test]
    fn format_empty() {
        assert_eq!(format_hybrid_results(&[], true), "No results found.");
        assert_eq!(format_hybrid_results(&[], false), "No results found.");
    }

    #[test]
    fn format_compact() {
        let results = vec![sample_hybrid_result()];
        let output = format_hybrid_results(&results, true);
        assert!(output.contains("[hybrid]"));
        assert!(output.contains("auth.rs"));
        assert!(output.contains("validate"));
    }

    #[test]
    fn format_verbose() {
        let results = vec![sample_hybrid_result()];
        let output = format_hybrid_results(&results, false);
        assert!(output.contains("rrf: 0.0156"));
        assert!(output.contains("bm25:#1 + dense:#2"));
        assert!(output.contains("fn validate() {}"));
    }

    #[test]
    fn source_label_categories() {
        let mut r = HybridResult {
            file_path: String::new(),
            symbol_name: String::new(),
            kind: ChunkKind::Function,
            start_line: 0,
            end_line: 0,
            snippet: String::new(),
            rrf_score: 0.0,
            bm25_score: None,
            dense_score: None,
            bm25_rank: None,
            dense_rank: None,
        };
        assert_eq!(r.source_label(), "unknown");

        r.bm25_rank = Some(1);
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "hybrid");

        r.dense_rank = None;
        assert_eq!(r.source_label(), "bm25");

        r.bm25_rank = None;
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "dense");
    }

    #[test]
    fn rrf_graph_proximity_boost() {
        let bm25 = vec![
            make_bm25_result("neighbor.rs", "n", 1, 5.0),
            make_bm25_result("weak.rs", "low", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("weak.rs", "low", 1, 0.99),
            make_dense_result("other.rs", "o", 1, 0.50),
        ];
        let mut graph = HashMap::new();
        graph.insert("neighbor.rs".to_string(), 0usize);

        let results =
            reciprocal_rank_fusion(&bm25, &dense, &HybridConfig::default(), 10, Some(&graph));

        let neighbor = results
            .iter()
            .find(|r| r.file_path == "neighbor.rs")
            .unwrap();
        let weak = results.iter().find(|r| r.file_path == "weak.rs").unwrap();
        assert!(
            neighbor.rrf_score > weak.rrf_score,
            "graph neighbor should outrank when it gets a third RRF signal"
        );
    }
}