1use std::collections::HashMap;
13
14use super::bm25_index::{BM25Index, ChunkKind, SearchResult};
15
16#[cfg(feature = "embeddings")]
17use super::embeddings::EmbeddingEngine;
18
/// Smoothing constant `k` in each RRF term `weight / (k + rank)`.
///
/// Larger values flatten the difference between adjacent ranks; 60 is the
/// conventional choice for Reciprocal Rank Fusion.
const RRF_K: f64 = 60.0;

/// Default multiplier applied to every BM25 RRF term.
const DEFAULT_BM25_WEIGHT: f64 = 1.0;
/// Default multiplier applied to every dense (embedding) RRF term.
const DEFAULT_DENSE_WEIGHT: f64 = 1.0;
/// Default number of hits requested from each retriever before fusion.
const DEFAULT_CANDIDATES: usize = 50;

/// Tuning knobs for [`hybrid_search`] and [`reciprocal_rank_fusion`].
#[derive(Debug, Clone, PartialEq)]
pub struct HybridConfig {
    /// Multiplier on each BM25 RRF contribution.
    pub bm25_weight: f64,
    /// Multiplier on each dense-retrieval RRF contribution.
    pub dense_weight: f64,
    /// Number of hits to request from the BM25 index.
    pub bm25_candidates: usize,
    /// Number of nearest chunks to request from dense retrieval.
    pub dense_candidates: usize,
}

impl Default for HybridConfig {
    /// Equal weighting of both retrievers with 50 candidates from each.
    fn default() -> Self {
        Self {
            bm25_weight: DEFAULT_BM25_WEIGHT,
            dense_weight: DEFAULT_DENSE_WEIGHT,
            bm25_candidates: DEFAULT_CANDIDATES,
            dense_candidates: DEFAULT_CANDIDATES,
        }
    }
}
43
44pub fn reciprocal_rank_fusion(
49 bm25_results: &[SearchResult],
50 dense_results: &[DenseSearchResult],
51 config: &HybridConfig,
52 top_k: usize,
53 graph_file_ranks: Option<&HashMap<String, usize>>,
54) -> Vec<HybridResult> {
55 let mut scores: HashMap<String, HybridResult> = HashMap::new();
56
57 for (rank, result) in bm25_results.iter().enumerate() {
58 let key = result_key(&result.file_path, result.start_line);
59 let rrf_score = config.bm25_weight / (RRF_K + rank as f64 + 1.0);
60
61 let entry = scores.entry(key).or_insert_with(|| HybridResult {
62 file_path: result.file_path.clone(),
63 symbol_name: result.symbol_name.clone(),
64 kind: result.kind.clone(),
65 start_line: result.start_line,
66 end_line: result.end_line,
67 snippet: result.snippet.clone(),
68 rrf_score: 0.0,
69 bm25_score: Some(result.score),
70 dense_score: None,
71 bm25_rank: None,
72 dense_rank: None,
73 });
74 entry.rrf_score += rrf_score;
75 entry.bm25_rank = Some(rank + 1);
76 }
77
78 for (rank, result) in dense_results.iter().enumerate() {
79 let key = result_key(&result.file_path, result.start_line);
80 let rrf_score = config.dense_weight / (RRF_K + rank as f64 + 1.0);
81
82 let entry = scores.entry(key).or_insert_with(|| HybridResult {
83 file_path: result.file_path.clone(),
84 symbol_name: result.symbol_name.clone(),
85 kind: result.kind.clone(),
86 start_line: result.start_line,
87 end_line: result.end_line,
88 snippet: result.snippet.clone(),
89 rrf_score: 0.0,
90 bm25_score: None,
91 dense_score: None,
92 bm25_rank: None,
93 dense_rank: None,
94 });
95 entry.rrf_score += rrf_score;
96 entry.dense_score = Some(result.similarity);
97 entry.dense_rank = Some(rank + 1);
98 }
99
100 if let Some(gr) = graph_file_ranks {
101 if !gr.is_empty() {
102 for entry in scores.values_mut() {
103 if let Some(&rank) = gr.get(&entry.file_path) {
104 entry.rrf_score += 1.0 / (RRF_K + rank as f64 + 1.0);
105 }
106 }
107 }
108 }
109
110 let mut results: Vec<HybridResult> = scores.into_values().collect();
111 results.sort_by(|a, b| {
112 b.rrf_score
113 .partial_cmp(&a.rrf_score)
114 .unwrap_or(std::cmp::Ordering::Equal)
115 });
116 results.truncate(top_k);
117 results
118}
119
#[cfg(feature = "embeddings")]
/// Runs BM25 retrieval, optional dense retrieval and optional graph
/// boosting, then hands the candidates to the reranker.
///
/// * BM25 always runs, requesting `config.bm25_candidates` hits.
/// * Dense retrieval runs only when both `engine` and `chunk_embeddings` are
///   given. `chunk_embeddings` must be index-aligned with `index.chunks`,
///   because hits are mapped back to chunks by position.
/// * If dense retrieval produced hits, or `graph_file_ranks` is non-empty,
///   the signals are fused with [`reciprocal_rank_fusion`].
/// * Otherwise the BM25 hits are used directly, tagged with their 1-based
///   BM25 rank so they are labelled as "bm25".
///
/// At most `min(5 * top_k, config.bm25_candidates)` candidates are passed to
/// `search_reranking::rerank_pipeline` together with `top_k`. Presumably the
/// reranker trims to `top_k`; confirm in that module.
pub fn hybrid_search(
    query: &str,
    index: &BM25Index,
    engine: Option<&EmbeddingEngine>,
    chunk_embeddings: Option<&[Vec<f32>]>,
    top_k: usize,
    config: &HybridConfig,
    graph_file_ranks: Option<&HashMap<String, usize>>,
) -> Vec<HybridResult> {
    let bm25_results = index.search(query, config.bm25_candidates);

    let dense_results = match (engine, chunk_embeddings) {
        (Some(eng), Some(embeddings)) => dense_search(
            query,
            eng,
            &index.chunks,
            embeddings,
            config.dense_candidates,
        ),
        _ => Vec::new(),
    };

    let graph_enhances = graph_file_ranks.is_some_and(|m| !m.is_empty());

    // Over-fetch so the reranker has room to reorder. Saturate so a huge
    // `top_k` cannot overflow.
    let candidate_count = top_k.saturating_mul(5).min(config.bm25_candidates);

    let mut results = if dense_results.is_empty() && !graph_enhances {
        // Nothing to fuse. Keep BM25 order and record each hit's rank so
        // `source_label()` and the formatter report "bm25", not "unknown".
        bm25_results
            .into_iter()
            .take(candidate_count)
            .enumerate()
            .map(|(i, hit)| HybridResult {
                bm25_rank: Some(i + 1),
                ..HybridResult::from_bm25(hit)
            })
            .collect()
    } else {
        // An empty `dense_results` here is the graph-only case. RRF then
        // sees just BM25 plus the graph signal.
        reciprocal_rank_fusion(
            &bm25_results,
            &dense_results,
            config,
            candidate_count,
            graph_file_ranks,
        )
    };

    super::search_reranking::rerank_pipeline(&mut results, query, top_k);
    results
}
179
180#[cfg(not(feature = "embeddings"))]
181pub fn hybrid_search(query: &str, index: &BM25Index, top_k: usize) -> Vec<HybridResult> {
182 let candidate_count = (top_k * 5).min(50);
183 let mut results: Vec<HybridResult> = index
184 .search(query, candidate_count)
185 .into_iter()
186 .map(HybridResult::from_bm25)
187 .collect();
188 super::search_reranking::rerank_pipeline(&mut results, query, top_k);
189 results
190}
191
#[cfg(feature = "embeddings")]
/// Embeds `query` and returns the `top_k` most similar chunks as
/// [`DenseSearchResult`]s, in the order produced by
/// `hnsw::brute_force_topk`.
///
/// `embeddings[i]` must correspond to `chunks[i]`. Hits whose index falls
/// outside `chunks` are skipped. If embedding the query fails, an empty list
/// is returned (best-effort: dense retrieval simply contributes nothing).
fn dense_search(
    query: &str,
    engine: &EmbeddingEngine,
    chunks: &[super::bm25_index::CodeChunk],
    embeddings: &[Vec<f32>],
    top_k: usize,
) -> Vec<DenseSearchResult> {
    let query_vec = match engine.embed(query) {
        Ok(vector) => vector,
        Err(_) => return Vec::new(),
    };

    let mut hits = Vec::new();
    for (chunk_idx, similarity) in super::hnsw::brute_force_topk(embeddings, &query_vec, top_k) {
        let Some(chunk) = chunks.get(chunk_idx) else {
            continue;
        };
        // Preview: the first five lines of the chunk body.
        let preview: Vec<&str> = chunk.content.lines().take(5).collect();
        hits.push(DenseSearchResult {
            chunk_idx,
            similarity,
            file_path: chunk.file_path.clone(),
            symbol_name: chunk.symbol_name.clone(),
            kind: chunk.kind.clone(),
            start_line: chunk.start_line,
            end_line: chunk.end_line,
            snippet: preview.join("\n"),
        });
    }
    hits
}
227
/// Builds the key used to merge hits across retrievers:
/// `"<file_path>:<start_line>"`.
fn result_key(file_path: &str, start_line: usize) -> String {
    let line = start_line.to_string();
    let mut key = String::with_capacity(file_path.len() + 1 + line.len());
    key.push_str(file_path);
    key.push(':');
    key.push_str(&line);
    key
}
231
/// One hit from dense (embedding) retrieval, carrying enough chunk metadata
/// to be merged with BM25 hits by [`reciprocal_rank_fusion`].
#[derive(Debug, Clone)]
pub struct DenseSearchResult {
    /// Position of the chunk in the slice that was searched.
    pub chunk_idx: usize,
    /// Score reported by `hnsw::brute_force_topk` for this chunk. Presumably
    /// cosine similarity (higher is better); confirm against that helper.
    pub similarity: f32,
    /// Path of the file containing the chunk.
    pub file_path: String,
    /// Symbol name of the chunk.
    pub symbol_name: String,
    /// Kind of the chunk.
    pub kind: ChunkKind,
    /// First line of the chunk. Together with `file_path` it forms the merge
    /// key used during fusion.
    pub start_line: usize,
    /// Last line of the chunk.
    pub end_line: usize,
    /// Preview text. When built by `dense_search` it holds the first five
    /// lines of the chunk content.
    pub snippet: String,
}
244
/// A search hit after optional fusion of BM25, dense and graph signals.
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Path of the file containing the chunk.
    pub file_path: String,
    /// Symbol name of the chunk.
    pub symbol_name: String,
    /// Kind of the chunk.
    pub kind: ChunkKind,
    /// First line of the chunk. Together with `file_path` it forms the merge
    /// key used during fusion.
    pub start_line: usize,
    /// Last line of the chunk.
    pub end_line: usize,
    /// Snippet taken from whichever retriever first produced this hit
    /// (BM25 first during fusion).
    pub snippet: String,
    /// Fused RRF score (higher is better). Results built by
    /// `HybridResult::from_bm25` store the raw BM25 score here instead.
    pub rrf_score: f64,
    /// Raw BM25 score, when the hit came from BM25.
    pub bm25_score: Option<f64>,
    /// Dense similarity, when the hit came from dense retrieval.
    pub dense_score: Option<f32>,
    /// 1-based position in the BM25 list, when recorded.
    pub bm25_rank: Option<usize>,
    /// 1-based position in the dense list, when recorded.
    pub dense_rank: Option<usize>,
}
260
261impl HybridResult {
262 pub fn from_bm25_public(result: SearchResult) -> Self {
263 Self::from_bm25(result)
264 }
265
266 fn from_bm25(result: SearchResult) -> Self {
267 Self {
268 file_path: result.file_path,
269 symbol_name: result.symbol_name,
270 kind: result.kind,
271 start_line: result.start_line,
272 end_line: result.end_line,
273 snippet: result.snippet,
274 rrf_score: result.score,
275 bm25_score: Some(result.score),
276 dense_score: None,
277 bm25_rank: None,
278 dense_rank: None,
279 }
280 }
281
282 pub fn source_label(&self) -> &'static str {
283 match (self.bm25_rank.is_some(), self.dense_rank.is_some()) {
284 (true, true) => "hybrid",
285 (true, false) => "bm25",
286 (false, true) => "dense",
287 (false, false) => "unknown",
288 }
289 }
290}
291
292pub fn format_hybrid_results(results: &[HybridResult], compact: bool) -> String {
294 if results.is_empty() {
295 return "No results found.".to_string();
296 }
297
298 let mut out = String::new();
299 for (i, r) in results.iter().enumerate() {
300 if compact {
301 out.push_str(&format!(
302 "{}. {:.4} [{}] {}:{}-{} {:?} {}\n",
303 i + 1,
304 r.rrf_score,
305 r.source_label(),
306 r.file_path,
307 r.start_line,
308 r.end_line,
309 r.kind,
310 r.symbol_name,
311 ));
312 } else {
313 let source_info = match (r.bm25_rank, r.dense_rank) {
314 (Some(bm), Some(dn)) => format!("bm25:#{bm} + dense:#{dn}"),
315 (Some(bm), None) => format!("bm25:#{bm}"),
316 (None, Some(dn)) => format!("dense:#{dn}"),
317 _ => String::new(),
318 };
319 out.push_str(&format!(
320 "\n--- Result {} (rrf: {:.4}, {}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
321 i + 1,
322 r.rrf_score,
323 source_info,
324 r.file_path,
325 r.symbol_name,
326 r.kind,
327 r.start_line,
328 r.end_line,
329 r.snippet,
330 ));
331 }
332 }
333 out
334}
335
// Unit tests for RRF fusion, source labelling and result formatting.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a BM25 hit for `fn {name}` with start_line = `line` and
    // end_line = `line + 10`.
    fn make_bm25_result(file: &str, name: &str, line: usize, score: f64) -> SearchResult {
        SearchResult {
            chunk_idx: 0,
            score,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    // Dense-retrieval counterpart of `make_bm25_result`, with `sim` as the
    // similarity.
    fn make_dense_result(file: &str, name: &str, line: usize, sim: f32) -> DenseSearchResult {
        DenseSearchResult {
            chunk_idx: 0,
            similarity: sim,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    // "beta" appears in both lists (same file and start line), so it must be
    // merged into one result carrying both ranks.
    #[test]
    fn rrf_basic_fusion() {
        let bm25 = vec![
            make_bm25_result("a.rs", "alpha", 1, 5.0),
            make_bm25_result("b.rs", "beta", 1, 3.0),
            make_bm25_result("c.rs", "gamma", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("b.rs", "beta", 1, 0.95),
            make_dense_result("d.rs", "delta", 1, 0.90),
            make_dense_result("a.rs", "alpha", 1, 0.85),
        ];

        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        assert!(!results.is_empty());

        let top = &results[0];
        assert!(
            top.bm25_rank.is_some() || top.dense_rank.is_some(),
            "top result should appear in at least one ranking"
        );

        let beta = results.iter().find(|r| r.symbol_name == "beta").unwrap();
        assert!(beta.bm25_rank.is_some() && beta.dense_rank.is_some());
        assert_eq!(beta.source_label(), "hybrid");
    }

    // "both" is 2nd in each list: 0.5/62 + 0.5/62. That beats a single 1st
    // place in one list (0.5/61).
    #[test]
    fn rrf_both_rankings_boost() {
        let bm25 = vec![
            make_bm25_result("a.rs", "only_bm25", 1, 5.0),
            make_bm25_result("b.rs", "both", 1, 3.0),
        ];
        let dense = vec![
            make_dense_result("c.rs", "only_dense", 1, 0.99),
            make_dense_result("b.rs", "both", 1, 0.90),
        ];

        let config = HybridConfig {
            bm25_weight: 0.5,
            dense_weight: 0.5,
            ..Default::default()
        };
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        let both = results.iter().find(|r| r.symbol_name == "both").unwrap();
        let only_bm25 = results
            .iter()
            .find(|r| r.symbol_name == "only_bm25")
            .unwrap();
        let only_dense = results
            .iter()
            .find(|r| r.symbol_name == "only_dense")
            .unwrap();

        assert!(
            both.rrf_score > only_bm25.rrf_score,
            "result in both rankings should score higher than BM25-only"
        );
        assert!(
            both.rrf_score > only_dense.rrf_score,
            "result in both rankings should score higher than dense-only"
        );
    }

    // 20 distinct keys (start lines 1, 11, 21, ...) are cut down to top_k = 5.
    #[test]
    fn rrf_respects_top_k() {
        let bm25: Vec<SearchResult> = (0..20)
            .map(|i| make_bm25_result("a.rs", &format!("fn_{i}"), i * 10 + 1, 10.0 - i as f64))
            .collect();

        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 5, None);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn rrf_empty_inputs() {
        let results = reciprocal_rank_fusion(&[], &[], &HybridConfig::default(), 10, None);
        assert!(results.is_empty());
    }

    // A hit present only in the BM25 list is labelled "bm25".
    #[test]
    fn rrf_bm25_only() {
        let bm25 = vec![make_bm25_result("a.rs", "alpha", 1, 5.0)];
        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "bm25");
    }

    // A hit present only in the dense list is labelled "dense".
    #[test]
    fn rrf_dense_only() {
        let dense = vec![make_dense_result("a.rs", "alpha", 1, 0.95)];
        let results = reciprocal_rank_fusion(&[], &dense, &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "dense");
    }

    // Compact output includes the source label, file and symbol on one line.
    #[test]
    fn format_compact() {
        let results = vec![HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }];
        let output = format_hybrid_results(&results, true);
        assert!(output.contains("[hybrid]"));
        assert!(output.contains("auth.rs"));
        assert!(output.contains("validate"));
    }

    // Verbose output spells out both ranks in the result header.
    #[test]
    fn format_verbose() {
        let results = vec![HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }];
        let output = format_hybrid_results(&results, false);
        assert!(output.contains("bm25:#1 + dense:#2"));
    }

    // The label depends only on which ranks are present, not on the scores.
    #[test]
    fn source_label_categories() {
        let mut r = HybridResult {
            file_path: String::new(),
            symbol_name: String::new(),
            kind: ChunkKind::Function,
            start_line: 0,
            end_line: 0,
            snippet: String::new(),
            rrf_score: 0.0,
            bm25_score: None,
            dense_score: None,
            bm25_rank: None,
            dense_rank: None,
        };

        r.bm25_rank = Some(1);
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "hybrid");

        r.dense_rank = None;
        assert_eq!(r.source_label(), "bm25");

        r.bm25_rank = None;
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "dense");
    }

    // neighbor.rs: 1/61 (BM25 #1) + 1/61 (graph rank 0).
    // weak.rs:     1/62 (BM25 #2) + 1/61 (dense #1).
    // The graph signal must push the neighbor ahead.
    #[test]
    fn rrf_graph_proximity_boost() {
        let bm25 = vec![
            make_bm25_result("neighbor.rs", "n", 1, 5.0),
            make_bm25_result("weak.rs", "low", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("weak.rs", "low", 1, 0.99),
            make_dense_result("other.rs", "o", 1, 0.50),
        ];
        let mut graph = HashMap::new();
        graph.insert("neighbor.rs".to_string(), 0usize);

        let results =
            reciprocal_rank_fusion(&bm25, &dense, &HybridConfig::default(), 10, Some(&graph));

        let neighbor = results
            .iter()
            .find(|r| r.file_path == "neighbor.rs")
            .unwrap();
        let weak = results.iter().find(|r| r.file_path == "weak.rs").unwrap();
        assert!(
            neighbor.rrf_score > weak.rrf_score,
            "graph neighbor should outrank when it gets a third RRF signal"
        );
    }
}