Skip to main content

lean_ctx/core/
hybrid_search.rs

//! Hybrid search combining BM25 (lexical) with dense vector search.
//!
//! Uses Reciprocal Rank Fusion (RRF) to merge BM25, dense embeddings, and optional
//! property-graph proximity (session neighborhood) into one ranked list.
//!
//! Formula: score(d) = Σ_i w_i / (k + rank_i(d))
//! where k = 60 (the standard constant), i ranges over the retrieval methods, rank_i(d) is the
//! 1-based position of d in ranking i, and w_i is that method's weight (the `HybridConfig`
//! weights for BM25 and dense, 1 for graph proximity; the defaults of 1 give plain RRF).
//!
//! Reference: Cormack, Clarke & Buettcher (2009), "Reciprocal Rank Fusion
//! outperforms Condorcet and individual Rank Learning Methods"
11
12use std::collections::HashMap;
13
14use super::bm25_index::{BM25Index, ChunkKind, SearchResult};
15
16#[cfg(feature = "embeddings")]
17use super::embeddings::{cosine_similarity, EmbeddingEngine};
18
/// Smoothing constant `k` in the RRF term `weight / (k + rank)`. 60 is the standard value;
/// larger values shrink the score gap between adjacent ranks.
const RRF_K: f64 = 60.0;

/// Default weights for standard RRF: equal contribution per ranking (`Σ weight/(k+r)` with weight=1).
const DEFAULT_BM25_WEIGHT: f64 = 1.0;
const DEFAULT_DENSE_WEIGHT: f64 = 1.0;

/// Default number of candidates pulled from each retriever before fusion.
const DEFAULT_CANDIDATES: usize = 50;

/// Configuration for hybrid search behavior.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridConfig {
    /// Multiplier on every BM25 RRF term (`bm25_weight / (k + rank)`).
    pub bm25_weight: f64,
    /// Multiplier on every dense-retrieval RRF term (`dense_weight / (k + rank)`).
    pub dense_weight: f64,
    /// How many BM25 hits to request before fusion.
    pub bm25_candidates: usize,
    /// How many most-similar chunks to keep from dense retrieval before fusion.
    pub dense_candidates: usize,
}

impl Default for HybridConfig {
    /// Plain (equal-weight) RRF over the top `DEFAULT_CANDIDATES` hits of each retriever.
    fn default() -> Self {
        Self {
            bm25_weight: DEFAULT_BM25_WEIGHT,
            dense_weight: DEFAULT_DENSE_WEIGHT,
            bm25_candidates: DEFAULT_CANDIDATES,
            dense_candidates: DEFAULT_CANDIDATES,
        }
    }
}
43
44/// Fuse two ranked result lists using Reciprocal Rank Fusion.
45///
46/// `graph_file_ranks`: optional repo-relative file path → rank (0-based) for neighbors of
47/// recently touched session files; each matching result gets an extra `1/(k+r)` term.
48pub fn reciprocal_rank_fusion(
49    bm25_results: &[SearchResult],
50    dense_results: &[DenseSearchResult],
51    config: &HybridConfig,
52    top_k: usize,
53    graph_file_ranks: Option<&HashMap<String, usize>>,
54) -> Vec<HybridResult> {
55    let mut scores: HashMap<String, HybridResult> = HashMap::new();
56
57    for (rank, result) in bm25_results.iter().enumerate() {
58        let key = result_key(&result.file_path, result.start_line);
59        let rrf_score = config.bm25_weight / (RRF_K + rank as f64 + 1.0);
60
61        let entry = scores.entry(key).or_insert_with(|| HybridResult {
62            file_path: result.file_path.clone(),
63            symbol_name: result.symbol_name.clone(),
64            kind: result.kind.clone(),
65            start_line: result.start_line,
66            end_line: result.end_line,
67            snippet: result.snippet.clone(),
68            rrf_score: 0.0,
69            bm25_score: Some(result.score),
70            dense_score: None,
71            bm25_rank: None,
72            dense_rank: None,
73        });
74        entry.rrf_score += rrf_score;
75        entry.bm25_rank = Some(rank + 1);
76    }
77
78    for (rank, result) in dense_results.iter().enumerate() {
79        let key = result_key(&result.file_path, result.start_line);
80        let rrf_score = config.dense_weight / (RRF_K + rank as f64 + 1.0);
81
82        let entry = scores.entry(key).or_insert_with(|| HybridResult {
83            file_path: result.file_path.clone(),
84            symbol_name: result.symbol_name.clone(),
85            kind: result.kind.clone(),
86            start_line: result.start_line,
87            end_line: result.end_line,
88            snippet: result.snippet.clone(),
89            rrf_score: 0.0,
90            bm25_score: None,
91            dense_score: None,
92            bm25_rank: None,
93            dense_rank: None,
94        });
95        entry.rrf_score += rrf_score;
96        entry.dense_score = Some(result.similarity);
97        entry.dense_rank = Some(rank + 1);
98    }
99
100    if let Some(gr) = graph_file_ranks {
101        if !gr.is_empty() {
102            for entry in scores.values_mut() {
103                if let Some(&rank) = gr.get(&entry.file_path) {
104                    entry.rrf_score += 1.0 / (RRF_K + rank as f64 + 1.0);
105                }
106            }
107        }
108    }
109
110    let mut results: Vec<HybridResult> = scores.into_values().collect();
111    results.sort_by(|a, b| {
112        b.rrf_score
113            .partial_cmp(&a.rrf_score)
114            .unwrap_or(std::cmp::Ordering::Equal)
115    });
116    results.truncate(top_k);
117    results
118}
119
/// Run hybrid search: BM25 + dense retrieval with RRF fusion.
///
/// Dense retrieval runs only when both `engine` and `chunk_embeddings` are provided. When it
/// yields no hits:
/// - with a non-empty `graph_file_ranks`, BM25 hits are still fused with the graph signal;
/// - otherwise the top `top_k` BM25 hits are returned unfused. Their `rrf_score` holds the raw
///   BM25 score, and `bm25_rank` is set to the 1-based position so that `source_label()`
///   reports `"bm25"` rather than `"unknown"`.
#[cfg(feature = "embeddings")]
pub fn hybrid_search(
    query: &str,
    index: &BM25Index,
    engine: Option<&EmbeddingEngine>,
    chunk_embeddings: Option<&[Vec<f32>]>,
    top_k: usize,
    config: &HybridConfig,
    graph_file_ranks: Option<&HashMap<String, usize>>,
) -> Vec<HybridResult> {
    let bm25_results = index.search(query, config.bm25_candidates);

    let dense_results = match (engine, chunk_embeddings) {
        (Some(eng), Some(embeddings)) => dense_search(
            query,
            eng,
            &index.chunks,
            embeddings,
            config.dense_candidates,
        ),
        _ => Vec::new(),
    };

    let graph_enhances = graph_file_ranks.is_some_and(|m| !m.is_empty());

    if dense_results.is_empty() && !graph_enhances {
        return bm25_results
            .into_iter()
            .take(top_k)
            .enumerate()
            .map(|(rank, result)| HybridResult {
                bm25_rank: Some(rank + 1),
                ..HybridResult::from_bm25(result)
            })
            .collect();
    }

    // Also covers "no dense hits but graph signal present": `dense_results` is empty there.
    reciprocal_rank_fusion(
        &bm25_results,
        &dense_results,
        config,
        top_k,
        graph_file_ranks,
    )
}
166
167#[cfg(not(feature = "embeddings"))]
168pub fn hybrid_search(query: &str, index: &BM25Index, top_k: usize) -> Vec<HybridResult> {
169    index
170        .search(query, top_k)
171        .into_iter()
172        .map(HybridResult::from_bm25)
173        .collect()
174}
175
/// Dense vector search over pre-computed chunk embeddings.
///
/// Embeds `query` with `engine` and scores each chunk by the cosine similarity between the query
/// embedding and `embeddings[i]`, which is taken to belong to `chunks[i]`. Returns up to `top_k`
/// hits, most similar first (equal similarities keep chunk order), each with a snippet made of
/// the chunk's first five lines. Returns an empty vector if the query cannot be embedded.
#[cfg(feature = "embeddings")]
fn dense_search(
    query: &str,
    engine: &EmbeddingEngine,
    chunks: &[super::bm25_index::CodeChunk],
    embeddings: &[Vec<f32>],
    top_k: usize,
) -> Vec<DenseSearchResult> {
    let Ok(query_embedding) = engine.embed(query) else {
        return Vec::new();
    };

    // Score only (chunk, embedding) pairs that actually exist, so an embedding without a
    // matching chunk can never take one of the `top_k` slots and then be dropped. NaN
    // similarities carry no ranking information and would break the sort's total order.
    let mut scored: Vec<_> = chunks
        .iter()
        .zip(embeddings)
        .enumerate()
        .map(|(idx, (chunk, emb))| (idx, chunk, cosine_similarity(&query_embedding, emb)))
        .filter(|&(_, _, sim)| !sim.is_nan())
        .collect();

    // `total_cmp` gives `sort_by` the total order it requires; the sort is stable, so ties keep
    // chunk order.
    scored.sort_by(|a, b| b.2.total_cmp(&a.2));
    scored.truncate(top_k);

    scored
        .into_iter()
        .map(|(idx, chunk, sim)| DenseSearchResult {
            chunk_idx: idx,
            similarity: sim,
            file_path: chunk.file_path.clone(),
            symbol_name: chunk.symbol_name.clone(),
            kind: chunk.kind.clone(),
            start_line: chunk.start_line,
            end_line: chunk.end_line,
            snippet: chunk.content.lines().take(5).collect::<Vec<_>>().join("\n"),
        })
        .collect()
}
216
/// Identity key used to merge hits on the same chunk across rankings, formatted as
/// `"<file_path>:<start_line>"`.
fn result_key(file_path: &str, start_line: usize) -> String {
    let line = start_line.to_string();
    [file_path, line.as_str()].join(":")
}
220
/// Result from dense (embedding-based) search.
///
/// Produced by `dense_search`; the location fields are copied from the matching code chunk.
#[derive(Debug, Clone)]
pub struct DenseSearchResult {
    /// Index of the chunk (and of its embedding) in the slices `dense_search` was given.
    pub chunk_idx: usize,
    /// Cosine similarity between the query embedding and the chunk embedding.
    pub similarity: f32,
    pub file_path: String,
    pub symbol_name: String,
    pub kind: ChunkKind,
    pub start_line: usize,
    pub end_line: usize,
    /// The chunk's first five lines, joined with `\n`.
    pub snippet: String,
}
233
/// Fused result combining BM25 and dense scores.
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Location and identity of the chunk. During fusion these are copied from the first
    /// ranking that contained it; BM25 is processed first, so its copy wins when both do.
    pub file_path: String,
    pub symbol_name: String,
    pub kind: ChunkKind,
    pub start_line: usize,
    pub end_line: usize,
    pub snippet: String,
    /// Fused RRF score (higher ranks first). Results built by `from_bm25` without fusion hold
    /// the raw BM25 score here instead.
    pub rrf_score: f64,
    /// Raw BM25 score, if the chunk came from the BM25 ranking.
    pub bm25_score: Option<f64>,
    /// Cosine similarity from dense retrieval, if the chunk appeared in the dense ranking.
    pub dense_score: Option<f32>,
    /// 1-based position in the BM25 ranking, when recorded.
    pub bm25_rank: Option<usize>,
    /// 1-based position in the dense ranking, when recorded.
    pub dense_rank: Option<usize>,
}
249
250impl HybridResult {
251    pub fn from_bm25_public(result: SearchResult) -> Self {
252        Self::from_bm25(result)
253    }
254
255    fn from_bm25(result: SearchResult) -> Self {
256        Self {
257            file_path: result.file_path,
258            symbol_name: result.symbol_name,
259            kind: result.kind,
260            start_line: result.start_line,
261            end_line: result.end_line,
262            snippet: result.snippet,
263            rrf_score: result.score,
264            bm25_score: Some(result.score),
265            dense_score: None,
266            bm25_rank: None,
267            dense_rank: None,
268        }
269    }
270
271    pub fn source_label(&self) -> &'static str {
272        match (self.bm25_rank.is_some(), self.dense_rank.is_some()) {
273            (true, true) => "hybrid",
274            (true, false) => "bm25",
275            (false, true) => "dense",
276            (false, false) => "unknown",
277        }
278    }
279}
280
281/// Format hybrid results for display.
282pub fn format_hybrid_results(results: &[HybridResult], compact: bool) -> String {
283    if results.is_empty() {
284        return "No results found.".to_string();
285    }
286
287    let mut out = String::new();
288    for (i, r) in results.iter().enumerate() {
289        if compact {
290            out.push_str(&format!(
291                "{}. {:.4} [{}] {}:{}-{} {:?} {}\n",
292                i + 1,
293                r.rrf_score,
294                r.source_label(),
295                r.file_path,
296                r.start_line,
297                r.end_line,
298                r.kind,
299                r.symbol_name,
300            ));
301        } else {
302            let source_info = match (r.bm25_rank, r.dense_rank) {
303                (Some(bm), Some(dn)) => format!("bm25:#{bm} + dense:#{dn}"),
304                (Some(bm), None) => format!("bm25:#{bm}"),
305                (None, Some(dn)) => format!("dense:#{dn}"),
306                _ => String::new(),
307            };
308            out.push_str(&format!(
309                "\n--- Result {} (rrf: {:.4}, {}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
310                i + 1,
311                r.rrf_score,
312                source_info,
313                r.file_path,
314                r.symbol_name,
315                r.kind,
316                r.start_line,
317                r.end_line,
318                r.snippet,
319            ));
320        }
321    }
322    out
323}
324
#[cfg(test)]
mod tests {
    // Unit tests for RRF fusion, source labelling, and result formatting. All inputs are
    // built by hand, so no BM25 index or embedding engine is involved.
    use super::*;

    /// BM25 hit fixture: a `ChunkKind::Function` with `start_line = line`,
    /// `end_line = line + 10`, and a stub snippet.
    fn make_bm25_result(file: &str, name: &str, line: usize, score: f64) -> SearchResult {
        SearchResult {
            chunk_idx: 0,
            score,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    /// Dense hit fixture shaped like `make_bm25_result`, carrying `sim` as its similarity.
    fn make_dense_result(file: &str, name: &str, line: usize, sim: f32) -> DenseSearchResult {
        DenseSearchResult {
            chunk_idx: 0,
            similarity: sim,
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: line,
            end_line: line + 10,
            snippet: format!("fn {name}() {{ }}"),
        }
    }

    // A chunk present in both rankings ("beta") is merged into one result labelled "hybrid".
    #[test]
    fn rrf_basic_fusion() {
        let bm25 = vec![
            make_bm25_result("a.rs", "alpha", 1, 5.0),
            make_bm25_result("b.rs", "beta", 1, 3.0),
            make_bm25_result("c.rs", "gamma", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("b.rs", "beta", 1, 0.95),
            make_dense_result("d.rs", "delta", 1, 0.90),
            make_dense_result("a.rs", "alpha", 1, 0.85),
        ];

        let config = HybridConfig::default();
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        assert!(!results.is_empty());

        let top = &results[0];
        assert!(
            top.bm25_rank.is_some() || top.dense_rank.is_some(),
            "top result should appear in at least one ranking"
        );

        let beta = results.iter().find(|r| r.symbol_name == "beta").unwrap();
        assert!(beta.bm25_rank.is_some() && beta.dense_rank.is_some());
        assert_eq!(beta.source_label(), "hybrid");
    }

    // With equal weights, being #2 in both rankings (0.5/62 + 0.5/62) beats being #1 in only
    // one of them (0.5/61).
    #[test]
    fn rrf_both_rankings_boost() {
        let bm25 = vec![
            make_bm25_result("a.rs", "only_bm25", 1, 5.0),
            make_bm25_result("b.rs", "both", 1, 3.0),
        ];
        let dense = vec![
            make_dense_result("c.rs", "only_dense", 1, 0.99),
            make_dense_result("b.rs", "both", 1, 0.90),
        ];

        let config = HybridConfig {
            bm25_weight: 0.5,
            dense_weight: 0.5,
            ..Default::default()
        };
        let results = reciprocal_rank_fusion(&bm25, &dense, &config, 10, None);

        let both = results.iter().find(|r| r.symbol_name == "both").unwrap();
        let only_bm25 = results
            .iter()
            .find(|r| r.symbol_name == "only_bm25")
            .unwrap();
        let only_dense = results
            .iter()
            .find(|r| r.symbol_name == "only_dense")
            .unwrap();

        assert!(
            both.rrf_score > only_bm25.rrf_score,
            "result in both rankings should score higher than BM25-only"
        );
        assert!(
            both.rrf_score > only_dense.rrf_score,
            "result in both rankings should score higher than dense-only"
        );
    }

    // Twenty distinct chunks (different start lines) are cut down to `top_k` after fusion.
    #[test]
    fn rrf_respects_top_k() {
        let bm25: Vec<SearchResult> = (0..20)
            .map(|i| make_bm25_result("a.rs", &format!("fn_{i}"), i * 10 + 1, 10.0 - i as f64))
            .collect();

        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 5, None);
        assert_eq!(results.len(), 5);
    }

    // No inputs means no outputs.
    #[test]
    fn rrf_empty_inputs() {
        let results = reciprocal_rank_fusion(&[], &[], &HybridConfig::default(), 10, None);
        assert!(results.is_empty());
    }

    // A result seen only by BM25 is labelled "bm25".
    #[test]
    fn rrf_bm25_only() {
        let bm25 = vec![make_bm25_result("a.rs", "alpha", 1, 5.0)];
        let results = reciprocal_rank_fusion(&bm25, &[], &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "bm25");
    }

    // A result seen only by dense retrieval is labelled "dense".
    #[test]
    fn rrf_dense_only() {
        let dense = vec![make_dense_result("a.rs", "alpha", 1, 0.95)];
        let results = reciprocal_rank_fusion(&[], &dense, &HybridConfig::default(), 10, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_label(), "dense");
    }

    // Compact output includes the source label, file path and symbol name.
    #[test]
    fn format_compact() {
        let results = vec![HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }];
        let output = format_hybrid_results(&results, true);
        assert!(output.contains("[hybrid]"));
        assert!(output.contains("auth.rs"));
        assert!(output.contains("validate"));
    }

    // Verbose output spells out the rank in each contributing ranking.
    #[test]
    fn format_verbose() {
        let results = vec![HybridResult {
            file_path: "auth.rs".into(),
            symbol_name: "validate".into(),
            kind: ChunkKind::Function,
            start_line: 10,
            end_line: 20,
            snippet: "fn validate() {}".into(),
            rrf_score: 0.0156,
            bm25_score: Some(4.2),
            dense_score: Some(0.91),
            bm25_rank: Some(1),
            dense_rank: Some(2),
        }];
        let output = format_hybrid_results(&results, false);
        assert!(output.contains("bm25:#1 + dense:#2"));
    }

    // Labels follow which ranks are set (scores left as `None` throughout).
    #[test]
    fn source_label_categories() {
        let mut r = HybridResult {
            file_path: String::new(),
            symbol_name: String::new(),
            kind: ChunkKind::Function,
            start_line: 0,
            end_line: 0,
            snippet: String::new(),
            rrf_score: 0.0,
            bm25_score: None,
            dense_score: None,
            bm25_rank: None,
            dense_rank: None,
        };

        r.bm25_rank = Some(1);
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "hybrid");

        r.dense_rank = None;
        assert_eq!(r.source_label(), "bm25");

        r.bm25_rank = None;
        r.dense_rank = Some(1);
        assert_eq!(r.source_label(), "dense");
    }

    // Graph proximity acts as a third RRF signal: neighbor.rs (BM25 #1 + graph rank 0, i.e.
    // 1/61 + 1/61) must outscore weak.rs (BM25 #2 + dense #1, i.e. 1/62 + 1/61).
    #[test]
    fn rrf_graph_proximity_boost() {
        let bm25 = vec![
            make_bm25_result("neighbor.rs", "n", 1, 5.0),
            make_bm25_result("weak.rs", "low", 1, 1.0),
        ];
        let dense = vec![
            make_dense_result("weak.rs", "low", 1, 0.99),
            make_dense_result("other.rs", "o", 1, 0.50),
        ];
        let mut graph = HashMap::new();
        graph.insert("neighbor.rs".to_string(), 0usize);

        let results =
            reciprocal_rank_fusion(&bm25, &dense, &HybridConfig::default(), 10, Some(&graph));

        let neighbor = results
            .iter()
            .find(|r| r.file_path == "neighbor.rs")
            .unwrap();
        let weak = results.iter().find(|r| r.file_path == "weak.rs").unwrap();
        assert!(
            neighbor.rrf_score > weak.rrf_score,
            "graph neighbor should outrank when it gets a third RRF signal"
        );
    }
}