Skip to main content

lean_ctx/core/
hybrid_search.rs

1//! Hybrid search combining BM25 (lexical) with dense vector search.
2//!
3//! Uses Reciprocal Rank Fusion (RRF) to merge BM25, dense embeddings, and optional
4//! property-graph proximity (session neighborhood) into one ranked list.
5//!
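//! Formula: score(d) = Σ_i w_i / (k + rank_i(d))
//! where k=60 (standard constant), i ranges over retrieval methods, rank_i(d) is the
//! 1-based rank of d under method i, and w_i is that method's weight (1.0 by default).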
8//!
9//! Reference: Cormack, Clarke & Buettcher (2009), "Reciprocal Rank Fusion
10//! outperforms Condorcet and individual Rank Learning Methods"
11
12use std::collections::HashMap;
13
14use super::bm25_index::{BM25Index, ChunkKind, SearchResult};
15
16#[cfg(feature = "embeddings")]
17use super::embeddings::EmbeddingEngine;
18
/// Smoothing constant `k` in each RRF term `weight / (k + rank)`.
/// 60 is the value the module-level docs describe as the standard constant.
const RRF_K: f64 = 60.0;

/// Default weights for standard RRF: equal contribution per ranking (`Σ weight/(k+r)` with weight=1).
/// Default weight applied to every BM25 rank term.
const DEFAULT_BM25_WEIGHT: f64 = 1.0;
/// Default weight applied to every dense-retrieval rank term.
const DEFAULT_DENSE_WEIGHT: f64 = 1.0;
24
/// Configuration for hybrid search behavior.
///
/// The weights scale each retriever's Reciprocal Rank Fusion term
/// (`weight / (k + rank)`); the candidate counts bound how many hits are
/// requested from each retriever before fusion.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused
/// across calls.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Multiplier applied to every BM25 rank term during fusion.
    pub bm25_weight: f64,
    /// Multiplier applied to every dense (embedding) rank term during fusion.
    pub dense_weight: f64,
    /// Number of hits requested from the BM25 index; also the upper bound on
    /// the number of candidates handed to reranking.
    pub bm25_candidates: usize,
    /// Number of hits requested from the dense search.
    pub dense_candidates: usize,
}
32
33impl Default for HybridConfig {
34    fn default() -> Self {
35        Self {
36            bm25_weight: DEFAULT_BM25_WEIGHT,
37            dense_weight: DEFAULT_DENSE_WEIGHT,
38            bm25_candidates: 50,
39            dense_candidates: 50,
40        }
41    }
42}
43
44/// Fuse two ranked result lists using Reciprocal Rank Fusion.
45///
46/// `graph_file_ranks`: optional repo-relative file path → rank (0-based) for neighbors of
47/// recently touched session files; each matching result gets an extra `1/(k+r)` term.
48pub fn reciprocal_rank_fusion(
49    bm25_results: &[SearchResult],
50    dense_results: &[DenseSearchResult],
51    config: &HybridConfig,
52    top_k: usize,
53    graph_file_ranks: Option<&HashMap<String, usize>>,
54) -> Vec<HybridResult> {
55    let mut scores: HashMap<String, HybridResult> = HashMap::new();
56
57    for (rank, result) in bm25_results.iter().enumerate() {
58        let key = result_key(&result.file_path, result.start_line);
59        let rrf_score = config.bm25_weight / (RRF_K + rank as f64 + 1.0);
60
61        let entry = scores.entry(key).or_insert_with(|| HybridResult {
62            file_path: result.file_path.clone(),
63            symbol_name: result.symbol_name.clone(),
64            kind: result.kind.clone(),
65            start_line: result.start_line,
66            end_line: result.end_line,
67            snippet: result.snippet.clone(),
68            rrf_score: 0.0,
69            bm25_score: Some(result.score),
70            dense_score: None,
71            bm25_rank: None,
72            dense_rank: None,
73        });
74        entry.rrf_score += rrf_score;
75        entry.bm25_rank = Some(rank + 1);
76    }
77
78    for (rank, result) in dense_results.iter().enumerate() {
79        let key = result_key(&result.file_path, result.start_line);
80        let rrf_score = config.dense_weight / (RRF_K + rank as f64 + 1.0);
81
82        let entry = scores.entry(key).or_insert_with(|| HybridResult {
83            file_path: result.file_path.clone(),
84            symbol_name: result.symbol_name.clone(),
85            kind: result.kind.clone(),
86            start_line: result.start_line,
87            end_line: result.end_line,
88            snippet: result.snippet.clone(),
89            rrf_score: 0.0,
90            bm25_score: None,
91            dense_score: None,
92            bm25_rank: None,
93            dense_rank: None,
94        });
95        entry.rrf_score += rrf_score;
96        entry.dense_score = Some(result.similarity);
97        entry.dense_rank = Some(rank + 1);
98    }
99
100    if let Some(gr) = graph_file_ranks {
101        if !gr.is_empty() {
102            for entry in scores.values_mut() {
103                if let Some(&rank) = gr.get(&entry.file_path) {
104                    entry.rrf_score += 1.0 / (RRF_K + rank as f64 + 1.0);
105                }
106            }
107        }
108    }
109
110    let mut results: Vec<HybridResult> = scores.into_values().collect();
111    results.sort_by(|a, b| {
112        b.rrf_score
113            .partial_cmp(&a.rrf_score)
114            .unwrap_or(std::cmp::Ordering::Equal)
115    });
116    results.truncate(top_k);
117    results
118}
119
/// Run hybrid search: BM25 + dense retrieval with RRF fusion + post-RRF reranking.
/// Falls back to BM25-only if embedding engine is not available.
///
/// * `query` – search text, used for BM25, for the query embedding and for reranking.
/// * `index` – BM25 index; its chunks are also the ones the dense search maps onto.
/// * `engine` / `chunk_embeddings` – dense search runs only when both are `Some`.
/// * `top_k` – number of results requested from the reranking step.
/// * `config` – fusion weights and per-retriever candidate counts.
/// * `graph_file_ranks` – optional file path → rank map adding a proximity term
///   during fusion.
///
/// When there are no dense results and no graph ranks, BM25 hits are converted
/// directly (their `rrf_score` is the raw BM25 score) and tagged with their
/// 1-based BM25 rank, matching what the fusion path records.
#[cfg(feature = "embeddings")]
pub fn hybrid_search(
    query: &str,
    index: &BM25Index,
    engine: Option<&EmbeddingEngine>,
    chunk_embeddings: Option<&[Vec<f32>]>,
    top_k: usize,
    config: &HybridConfig,
    graph_file_ranks: Option<&HashMap<String, usize>>,
) -> Vec<HybridResult> {
    let bm25_results = index.search(query, config.bm25_candidates);

    let dense_results = match (engine, chunk_embeddings) {
        (Some(eng), Some(embeddings)) => dense_search(
            query,
            eng,
            &index.chunks,
            embeddings,
            config.dense_candidates,
        ),
        _ => Vec::new(),
    };

    let graph_enhances = graph_file_ranks.is_some_and(|m| !m.is_empty());

    // Over-fetch candidates for reranking (5x top_k, capped at the BM25 candidate
    // budget). Saturating so a very large `top_k` cannot overflow.
    let candidate_count = top_k.saturating_mul(5).min(config.bm25_candidates);

    let mut results = if dense_results.is_empty() && !graph_enhances {
        // Nothing to fuse: keep BM25 order and record each hit's rank so the
        // result is labelled as coming from BM25.
        bm25_results
            .into_iter()
            .take(candidate_count)
            .enumerate()
            .map(|(position, hit)| {
                let mut result = HybridResult::from_bm25(hit);
                result.bm25_rank = Some(position + 1);
                result
            })
            .collect()
    } else {
        // `dense_results` may be empty here (graph-only enhancement); fusion
        // then combines BM25 ranks with the graph proximity term.
        reciprocal_rank_fusion(
            &bm25_results,
            &dense_results,
            config,
            candidate_count,
            graph_file_ranks,
        )
    };

    super::search_reranking::rerank_pipeline(&mut results, query, top_k);
    results
}
179
180#[cfg(not(feature = "embeddings"))]
181pub fn hybrid_search(query: &str, index: &BM25Index, top_k: usize) -> Vec<HybridResult> {
182    let candidate_count = (top_k * 5).min(50);
183    let mut results: Vec<HybridResult> = index
184        .search(query, candidate_count)
185        .into_iter()
186        .map(HybridResult::from_bm25)
187        .collect();
188    super::search_reranking::rerank_pipeline(&mut results, query, top_k);
189    results
190}
191
/// Dense vector search over pre-computed chunk embeddings.
///
/// Embeds `query` with `engine`, asks `brute_force_topk` for the `top_k` best
/// entries of `embeddings`, and turns every returned index into a
/// [`DenseSearchResult`] built from the chunk at that index. Returns an empty list
/// when the query cannot be embedded; indices with no matching chunk are skipped.
#[cfg(feature = "embeddings")]
fn dense_search(
    query: &str,
    engine: &EmbeddingEngine,
    chunks: &[super::bm25_index::CodeChunk],
    embeddings: &[Vec<f32>],
    top_k: usize,
) -> Vec<DenseSearchResult> {
    let query_embedding = match engine.embed(query) {
        Ok(vector) => vector,
        Err(_) => return Vec::new(),
    };

    let mut hits = Vec::new();
    for (idx, sim) in super::hnsw::brute_force_topk(embeddings, &query_embedding, top_k) {
        // Embeddings and chunks are indexed in parallel; tolerate a missing chunk.
        let Some(chunk) = chunks.get(idx) else {
            continue;
        };
        // Preview is the first five lines of the chunk.
        let preview: Vec<_> = chunk.content.lines().take(5).collect();
        hits.push(DenseSearchResult {
            chunk_idx: idx,
            similarity: sim,
            file_path: chunk.file_path.clone(),
            symbol_name: chunk.symbol_name.clone(),
            kind: chunk.kind.clone(),
            start_line: chunk.start_line,
            end_line: chunk.end_line,
            snippet: preview.join("\n"),
        });
    }
    hits
}
227
/// Builds the identity key for a chunk: `"<file_path>:<start_line>"`.
fn result_key(file_path: &str, start_line: usize) -> String {
    let line = start_line.to_string();
    let mut key = String::with_capacity(file_path.len() + 1 + line.len());
    key.push_str(file_path);
    key.push(':');
    key.push_str(&line);
    key
}
231
/// Result from dense (embedding-based) search.
///
/// Field notes describe how `dense_search` in this module fills the struct.
#[derive(Debug, Clone)]
pub struct DenseSearchResult {
    /// Index of the matched chunk in the chunk slice that was searched.
    pub chunk_idx: usize,
    /// Score reported by the top-k selection for this chunk.
    /// NOTE(review): presumably cosine similarity; the metric is not visible here.
    pub similarity: f32,
    /// File path copied from the matched chunk.
    pub file_path: String,
    /// Symbol name copied from the matched chunk.
    pub symbol_name: String,
    /// Chunk kind copied from the matched chunk.
    pub kind: ChunkKind,
    /// Start line copied from the matched chunk.
    pub start_line: usize,
    /// End line copied from the matched chunk.
    pub end_line: usize,
    /// First five lines of the chunk's content, joined with `\n`.
    pub snippet: String,
}
244
/// Fused result combining BM25 and dense scores.
///
/// Location and snippet fields are copied from whichever retriever's hit created
/// the entry first; the score/rank fields record which retrievers returned it.
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// File path of the chunk.
    pub file_path: String,
    /// Symbol name of the chunk.
    pub symbol_name: String,
    /// Kind of the chunk.
    pub kind: ChunkKind,
    /// First line of the chunk; together with `file_path` it identifies the chunk
    /// during fusion.
    pub start_line: usize,
    /// Last line of the chunk.
    pub end_line: usize,
    /// Text preview carried over from the retriever's hit.
    pub snippet: String,
    /// Sum of the RRF terms for this chunk. For values built by `from_bm25` this
    /// holds the raw BM25 score instead.
    pub rrf_score: f64,
    /// Raw BM25 score, when a BM25 hit created this entry.
    pub bm25_score: Option<f64>,
    /// Dense similarity, when the chunk appeared in the dense results.
    pub dense_score: Option<f32>,
    /// 1-based position in the BM25 list as recorded by fusion; `from_bm25`
    /// leaves it `None`.
    pub bm25_rank: Option<usize>,
    /// 1-based position in the dense list as recorded by fusion.
    pub dense_rank: Option<usize>,
}
260
261impl HybridResult {
262    pub fn from_bm25_public(result: SearchResult) -> Self {
263        Self::from_bm25(result)
264    }
265
266    fn from_bm25(result: SearchResult) -> Self {
267        Self {
268            file_path: result.file_path,
269            symbol_name: result.symbol_name,
270            kind: result.kind,
271            start_line: result.start_line,
272            end_line: result.end_line,
273            snippet: result.snippet,
274            rrf_score: result.score,
275            bm25_score: Some(result.score),
276            dense_score: None,
277            bm25_rank: None,
278            dense_rank: None,
279        }
280    }
281
282    pub fn source_label(&self) -> &'static str {
283        match (self.bm25_rank.is_some(), self.dense_rank.is_some()) {
284            (true, true) => "hybrid",
285            (true, false) => "bm25",
286            (false, true) => "dense",
287            (false, false) => "unknown",
288        }
289    }
290}
291
292/// Format hybrid results for display.
293pub fn format_hybrid_results(results: &[HybridResult], compact: bool) -> String {
294    if results.is_empty() {
295        return "No results found.".to_string();
296    }
297
298    let mut out = String::new();
299    for (i, r) in results.iter().enumerate() {
300        if compact {
301            out.push_str(&format!(
302                "{}. {:.4} [{}] {}:{}-{} {:?} {}\n",
303                i + 1,
304                r.rrf_score,
305                r.source_label(),
306                r.file_path,
307                r.start_line,
308                r.end_line,
309                r.kind,
310                r.symbol_name,
311            ));
312        } else {
313            let source_info = match (r.bm25_rank, r.dense_rank) {
314                (Some(bm), Some(dn)) => format!("bm25:#{bm} + dense:#{dn}"),
315                (Some(bm), None) => format!("bm25:#{bm}"),
316                (None, Some(dn)) => format!("dense:#{dn}"),
317                _ => String::new(),
318            };
319            out.push_str(&format!(
320                "\n--- Result {} (rrf: {:.4}, {}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
321                i + 1,
322                r.rrf_score,
323                source_info,
324                r.file_path,
325                r.symbol_name,
326                r.kind,
327                r.start_line,
328                r.end_line,
329                r.snippet,
330            ));
331        }
332    }
333    out
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// BM25 hit for a function chunk covering `line..=line + 10`.
    fn make_bm25_result(file: &str, name: &str, line: usize, score: f64) -> SearchResult {
        let snippet = format!("fn {name}() {{ }}");
        SearchResult {
            chunk_idx: 0,
            score,
            file_path: file.to_owned(),
            symbol_name: name.to_owned(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet,
        }
    }

    /// Dense hit for a function chunk covering `line..=line + 10`.
    fn make_dense_result(file: &str, name: &str, line: usize, sim: f32) -> DenseSearchResult {
        let snippet = format!("fn {name}() {{ }}");
        DenseSearchResult {
            chunk_idx: 0,
            similarity: sim,
            file_path: file.to_owned(),
            symbol_name: name.to_owned(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet,
        }
    }

    /// Looks up a fused result by symbol name; panics when it is missing.
    fn by_symbol<'a>(results: &'a [HybridResult], symbol: &str) -> &'a HybridResult {
        results.iter().find(|r| r.symbol_name == symbol).unwrap()
    }

    /// Looks up a fused result by file path; panics when it is missing.
    fn by_file<'a>(results: &'a [HybridResult], file: &str) -> &'a HybridResult {
        results.iter().find(|r| r.file_path == file).unwrap()
    }

    /// Result found by both retrievers, shared by the formatting tests.
    fn validate_result() -> HybridResult {
        HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }
    }

    #[test]
    fn rrf_basic_fusion() {
        let bm25 = [
            make_bm25_result("a.rs", "alpha", 1, 5.0),
            make_bm25_result("b.rs", "beta", 1, 3.0),
            make_bm25_result("c.rs", "gamma", 1, 1.0),
        ];
        let dense = [
            make_dense_result("b.rs", "beta", 1, 0.95),
            make_dense_result("d.rs", "delta", 1, 0.90),
            make_dense_result("a.rs", "alpha", 1, 0.85),
        ];

        let results = reciprocal_rank_fusion(&bm25, &dense, &HybridConfig::default(), 10, None);

        assert!(!results.is_empty());

        let top = &results[0];
        assert!(
            top.bm25_rank.is_some() || top.dense_rank.is_some(),
            "top result should appear in at least one ranking"
        );

        // "beta" is present in both input lists.
        let beta = by_symbol(&results, "beta");
        assert!(beta.bm25_rank.is_some() && beta.dense_rank.is_some());
        assert_eq!(beta.source_label(), "hybrid");
    }

    #[test]
    fn rrf_both_rankings_boost() {
        let bm25 = [
            make_bm25_result("a.rs", "only_bm25", 1, 5.0),
            make_bm25_result("b.rs", "both", 1, 3.0),
        ];
        let dense = [
            make_dense_result("c.rs", "only_dense", 1, 0.99),
            make_dense_result("b.rs", "both", 1, 0.90),
        ];
        let config = HybridConfig {
            bm25_weight: 0.5,
            dense_weight: 0.5,
            ..HybridConfig::default()
        };

        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        let both = by_symbol(&results, "both");
        let only_bm25 = by_symbol(&results, "only_bm25");
        let only_dense = by_symbol(&results, "only_dense");

        assert!(
            both.rrf_score > only_bm25.rrf_score,
            "result in both rankings should score higher than BM25-only"
        );
        assert!(
            both.rrf_score > only_dense.rrf_score,
            "result in both rankings should score higher than dense-only"
        );
    }

    #[test]
    fn rrf_respects_top_k() {
        // Twenty distinct chunks in one file, with strictly decreasing BM25 scores.
        let mut bm25: Vec<SearchResult> = Vec::new();
        for i in 0..20 {
            let name = format!("fn_{i}");
            bm25.push(make_bm25_result("a.rs", &name, i * 10 + 1, 10.0 - i as f64));
        }

        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 5, None);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn rrf_empty_inputs() {
        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&[], &[], &config, 10, None);
        assert!(results.is_empty());
    }

    #[test]
    fn rrf_bm25_only() {
        let bm25 = [make_bm25_result("a.rs", "alpha", 1, 5.0)];
        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&bm25, &[], &config, 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "bm25");
    }

    #[test]
    fn rrf_dense_only() {
        let dense = [make_dense_result("a.rs", "alpha", 1, 0.95)];
        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&[], &dense, &config, 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "dense");
    }

    #[test]
    fn format_compact() {
        let output = format_hybrid_results(&[validate_result()], true);
        assert!(output.contains("[hybrid]"));
        assert!(output.contains("auth.rs"));
        assert!(output.contains("validate"));
    }

    #[test]
    fn format_verbose() {
        let output = format_hybrid_results(&[validate_result()], false);
        assert!(output.contains("bm25:#1 + dense:#2"));
    }

    #[test]
    fn source_label_categories() {
        // Otherwise-empty result carrying only the given ranks.
        let with_ranks = |bm25_rank: Option<usize>, dense_rank: Option<usize>| HybridResult {
            file_path: String::new(),
            symbol_name: String::new(),
            kind: ChunkKind::Function,
            start_line: 0,
            end_line: 0,
            snippet: String::new(),
            rrf_score: 0.0,
            bm25_score: None,
            dense_score: None,
            bm25_rank,
            dense_rank,
        };

        assert_eq!(with_ranks(Some(1), Some(1)).source_label(), "hybrid");
        assert_eq!(with_ranks(Some(1), None).source_label(), "bm25");
        assert_eq!(with_ranks(None, Some(1)).source_label(), "dense");
    }

    #[test]
    fn rrf_graph_proximity_boost() {
        let bm25 = [
            make_bm25_result("neighbor.rs", "n", 1, 5.0),
            make_bm25_result("weak.rs", "low", 1, 1.0),
        ];
        let dense = [
            make_dense_result("weak.rs", "low", 1, 0.99),
            make_dense_result("other.rs", "o", 1, 0.50),
        ];
        // Only "neighbor.rs" is a graph neighbor, at the best possible rank.
        let graph = HashMap::from([("neighbor.rs".to_string(), 0usize)]);

        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, Some(&graph));

        let neighbor = by_file(&results, "neighbor.rs");
        let weak = by_file(&results, "weak.rs");
        assert!(
            neighbor.rrf_score > weak.rrf_score,
            "graph neighbor should outrank when it gets a third RRF signal"
        );
    }
}
559}