Skip to main content

lean_ctx/core/
search_reranking.rs

1//! Post-RRF Reranking Pipeline for code-aware search.
2//!
3//! Scientific foundations:
4//! - Cormack et al. (SIGIR 2009): RRF as unsupervised fusion baseline
5//! - Carbonell & Goldstein (SIGIR 1998): MMR diversity via file-saturation decay
6//! - CoRNStack (ICLR 2025): Definition-boost + noise filtering for code
7//! - SACL (EMNLP 2025): Query-type-adaptive weighting + path enrichment
8//! - SweRank (2025): Multi-stage retrieve-then-rerank for code localization
9//!
10//! Pipeline order (applied after RRF fusion):
11//! 1. Definition Boost — chunks defining the queried symbol rank higher
12//! 2. File Coherence — files with multiple relevant chunks get boosted
13//! 3. Noise Penalties — test/legacy/compat paths get penalized
14//! 4. MMR Diversity — exponential decay per file prevents single-file dominance
15
16use std::collections::{HashMap, HashSet};
17use std::path::Path;
18
19use super::bm25_index::ChunkKind;
20use super::hybrid_search::HybridResult;
21
22// --- Constants (empirically validated by semble's ablation study) ---
23
/// Multiplied by the current maximum RRF score to obtain the additive bonus
/// that `definition_boost` gives to chunks defining the queried symbol.
const DEFINITION_BOOST_MULTIPLIER: f64 = 3.0;
/// Fraction of the maximum RRF score used as the base unit of the bonus
/// handed out by `file_coherence_boost`.
const FILE_COHERENCE_FRAC: f64 = 0.2;
/// Base of the exponential decay that `apply_diversity` applies to further
/// chunks of a file that has already been selected.
const SATURATION_DECAY: f64 = 0.5;
/// Number of already-selected chunks from one file at which
/// `apply_diversity` starts decaying that file's remaining chunks.
const SATURATION_THRESHOLD: usize = 1;

/// Score multiplier used by `path_penalty` for test, compat/legacy and
/// example/docs paths.
const STRONG_PENALTY: f64 = 0.3;
/// Score multiplier used by `path_penalty` for re-export barrel files.
const MODERATE_PENALTY: f64 = 0.5;
/// Score multiplier used by `path_penalty` for type-stub files.
const MILD_PENALTY: f64 = 0.7;
32
33// --- Query Classification (SACL-inspired) ---
34
/// Category assigned to a search query by [`classify_query`].
///
/// [`resolve_weights`] maps each variant to a `(bm25, dense)` weight pair and
/// `definition_boost` only acts on [`QueryType::Symbol`] queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// A single identifier-like token (as decided by `is_symbol_query`).
    Symbol,
    /// Fallback for empty queries and anything that is neither a symbol nor
    /// an architecture query.
    NaturalLanguage,
    /// A query whose lowercased text contains one of the structural keywords
    /// listed in `is_architecture_query`.
    Architecture,
}
41
42/// Classify a search query as Symbol, NL, or Architecture.
43///
44/// Symbol queries: namespace-qualified (`Foo::bar`), leading underscore,
45/// CamelCase single identifier, SCREAMING_CASE.
46/// Architecture queries: contain structural keywords (how, where, pattern, flow, architecture).
47pub fn classify_query(query: &str) -> QueryType {
48    let trimmed = query.trim();
49    if trimmed.is_empty() {
50        return QueryType::NaturalLanguage;
51    }
52
53    if is_symbol_query(trimmed) {
54        return QueryType::Symbol;
55    }
56
57    let lower = trimmed.to_lowercase();
58    if is_architecture_query(&lower) {
59        return QueryType::Architecture;
60    }
61
62    QueryType::NaturalLanguage
63}
64
/// Returns `true` when `query` consists of exactly one whitespace-delimited
/// token that looks like a code identifier.
///
/// A token qualifies when any of the following holds:
/// - it is namespace-qualified: contains `::` or `->`, or contains `.`
///   together with at least one uppercase character;
/// - it starts with `_` and is longer than one byte;
/// - it is SCREAMING_CASE: longer than two bytes, contains `_`, and every
///   character is uppercase, `_` or an ASCII digit;
/// - it is snake_case: longer than two bytes, contains `_`, and every
///   character is alphanumeric or `_`;
/// - it is camelCase: an ASCII lowercase letter directly followed by an ASCII
///   uppercase letter;
/// - it is PascalCase-like: the first character is uppercase and some later
///   character is lowercase.
///
/// The PascalCase check walks `char`s instead of slicing at byte offset 1, so
/// tokens starting with a multi-byte uppercase character (e.g. `Über`) are
/// handled without panicking.
fn is_symbol_query(query: &str) -> bool {
    // Exactly one token is required; multi-word queries are never symbols.
    let mut tokens = query.split_whitespace();
    let token = match (tokens.next(), tokens.next()) {
        (Some(only), None) => only,
        _ => return false,
    };

    // Namespace-qualified: Foo::bar, path.to.Module, obj->field
    if token.contains("::")
        || token.contains("->")
        || (token.contains('.') && token.chars().any(char::is_uppercase))
    {
        return true;
    }

    // Leading underscore
    if token.starts_with('_') && token.len() > 1 {
        return true;
    }

    if token.len() > 2 && token.contains('_') {
        // SCREAMING_CASE: ALL_CAPS_WITH_UNDERSCORES
        let screaming = token
            .chars()
            .all(|c| c.is_uppercase() || c == '_' || c.is_ascii_digit());
        // snake_case identifier
        let snake = token.chars().all(|c| c.is_alphanumeric() || c == '_');
        if screaming || snake {
            return true;
        }
    }

    // camelCase: at least one lowercase -> uppercase transition (ASCII only).
    let has_lower_to_upper = token
        .as_bytes()
        .windows(2)
        .any(|w| w[0].is_ascii_lowercase() && w[1].is_ascii_uppercase());
    if has_lower_to_upper {
        return true;
    }

    // PascalCase-like: uppercase first character followed by any lowercase
    // character. Iterating chars keeps this safe for multi-byte first chars.
    let mut chars = token.chars();
    let starts_upper = chars.next().map_or(false, char::is_uppercase);
    starts_upper && chars.any(char::is_lowercase)
}
113
/// Returns `true` when the query contains any structural keyword phrase.
///
/// Matching is a plain substring test, so `lower` must already be lowercased
/// by the caller.
fn is_architecture_query(lower: &str) -> bool {
    const ARCH_KEYWORDS: [&str; 12] = [
        "how does",
        "how is",
        "where is",
        "where are",
        "architecture",
        "design pattern",
        "data flow",
        "control flow",
        "module structure",
        "component",
        "layer",
        "pipeline",
    ];

    for keyword in ARCH_KEYWORDS.iter() {
        if lower.contains(*keyword) {
            return true;
        }
    }
    false
}
131
132/// Resolve BM25 vs Dense weight based on query type.
133/// Returns (bm25_weight, dense_weight).
134pub fn resolve_weights(query_type: QueryType) -> (f64, f64) {
135    match query_type {
136        QueryType::Symbol => (1.4, 0.6),
137        QueryType::NaturalLanguage => (1.0, 1.0),
138        QueryType::Architecture => (0.6, 1.4),
139    }
140}
141
142// --- Reranking Pipeline ---
143
144/// Apply the full post-RRF reranking pipeline.
145///
146/// Mutates scores in-place for efficiency, then applies diversity-based
147/// selection (MMR-inspired file-saturation decay) to produce final top-k.
148pub fn rerank_pipeline(results: &mut Vec<HybridResult>, query: &str, top_k: usize) {
149    if results.is_empty() {
150        return;
151    }
152
153    let query_type = classify_query(query);
154
155    definition_boost(results, query, query_type);
156    file_coherence_boost(results);
157    apply_noise_penalties(results);
158    *results = apply_diversity(std::mem::take(results), top_k);
159}
160
161// --- Signal 1: Definition Boost ---
162
163fn definition_boost(results: &mut [HybridResult], query: &str, query_type: QueryType) {
164    if query_type != QueryType::Symbol {
165        return;
166    }
167
168    let symbol = extract_symbol_name(query);
169    if symbol.is_empty() {
170        return;
171    }
172
173    let max_score = results.iter().map(|r| r.rrf_score).fold(0.0_f64, f64::max);
174    if max_score == 0.0 {
175        return;
176    }
177
178    let boost = max_score * DEFINITION_BOOST_MULTIPLIER;
179    let symbol_lower = symbol.to_lowercase();
180
181    for result in results.iter_mut() {
182        if is_defining_chunk(result, &symbol_lower) {
183            result.rrf_score += boost;
184        }
185    }
186}
187
/// Return the unqualified tail of a possibly qualified symbol query.
///
/// The query is trimmed and everything up to and including the right-most
/// qualifier separator (`::`, `.` or `->`) is dropped:
/// `Foo::bar` -> `bar`, `obj.method` -> `method`, `ptr->field` -> `field`.
/// Picking the right-most separator of any kind means mixed qualifiers such
/// as `Foo::bar.baz` resolve to `baz` rather than `bar.baz`.
///
/// A query without a separator is returned trimmed but otherwise unchanged;
/// a query ending in a separator yields the empty string.
fn extract_symbol_name(query: &str) -> &str {
    const SEPARATORS: [&str; 3] = ["::", ".", "->"];

    let trimmed = query.trim();
    SEPARATORS
        .iter()
        // Byte offset just past the last occurrence of each separator.
        .filter_map(|&sep| trimmed.rfind(sep).map(|pos| pos + sep.len()))
        .max()
        // Separators are ASCII, so `start` is always a char boundary.
        .map_or(trimmed, |start| &trimmed[start..])
}
204
205fn is_defining_chunk(result: &HybridResult, symbol_lower: &str) -> bool {
206    match result.kind {
207        ChunkKind::Other => false,
208        _ => result.symbol_name.to_lowercase().contains(symbol_lower),
209    }
210}
211
212// --- Signal 2: File Coherence Boost ---
213
214fn file_coherence_boost(results: &mut [HybridResult]) {
215    if results.len() < 2 {
216        return;
217    }
218
219    let max_score = results.iter().map(|r| r.rrf_score).fold(0.0_f64, f64::max);
220    if max_score == 0.0 {
221        return;
222    }
223
224    let mut file_scores: HashMap<String, f64> = HashMap::new();
225    for r in results.iter() {
226        *file_scores.entry(r.file_path.clone()).or_insert(0.0) += r.rrf_score;
227    }
228
229    let max_file_score = file_scores.values().copied().fold(0.0_f64, f64::max);
230    if max_file_score == 0.0 {
231        return;
232    }
233
234    let boost_unit = max_score * FILE_COHERENCE_FRAC;
235    let mut seen: HashSet<String> = HashSet::new();
236
237    for result in results.iter_mut() {
238        if seen.insert(result.file_path.clone()) {
239            let file_score = file_scores.get(&result.file_path).copied().unwrap_or(0.0);
240            result.rrf_score += boost_unit * file_score / max_file_score;
241        }
242    }
243}
244
245// --- Signal 3: Noise Penalties ---
246
247fn apply_noise_penalties(results: &mut [HybridResult]) {
248    for result in results.iter_mut() {
249        let penalty = path_penalty(&result.file_path);
250        if penalty < 1.0 {
251            result.rrf_score *= penalty;
252        }
253    }
254}
255
256fn path_penalty(file_path: &str) -> f64 {
257    let normalized = file_path.replace('\\', "/");
258    let mut penalty = 1.0;
259
260    if is_test_file(&normalized) {
261        penalty *= STRONG_PENALTY;
262    }
263    if is_compat_legacy(&normalized) {
264        penalty *= STRONG_PENALTY;
265    }
266    if is_example_docs(&normalized) {
267        penalty *= STRONG_PENALTY;
268    }
269    if is_reexport_barrel(&normalized) {
270        penalty *= MODERATE_PENALTY;
271    }
272    if is_type_stub(&normalized) {
273        penalty *= MILD_PENALTY;
274    }
275
276    penalty
277}
278
/// Returns `true` when `path` looks like test code.
///
/// `path` is expected to use `/` separators. A path matches when either
/// - its file name follows a test naming convention (`test_*`, `*_test.*`,
///   `*.test.*`, `*.spec.*`, `*Test.java`, `*Tests.java`, `*Test.kt`,
///   `*Test.cs`, `*Tests.swift`, `*_spec.rb`), or
/// - any directory component is `tests`, `test`, `__tests__`, `spec` or
///   `testing`.
///
/// Directory matching works on whole components, so it behaves the same for
/// absolute paths and for relative paths whose first component is a test
/// directory (e.g. `__tests__/app.js`, `spec/models/user.rb`).
fn is_test_file(path: &str) -> bool {
    const NAME_SUFFIXES: [&str; 6] = [
        "Test.java",
        "Tests.java",
        "Test.kt",
        "Test.cs",
        "Tests.swift",
        "_spec.rb",
    ];

    let filename = Path::new(path)
        .file_name()
        .and_then(|f| f.to_str())
        .unwrap_or("");

    let named_like_test = filename.starts_with("test_")
        || filename.contains("_test.")
        || filename.contains(".test.")
        || filename.contains(".spec.")
        || NAME_SUFFIXES
            .iter()
            .any(|suffix| filename.ends_with(*suffix));
    if named_like_test {
        return true;
    }

    // Everything before the last `/` is the directory part; a path without
    // `/` has no directory components at all.
    match path.rsplit_once('/') {
        Some((dirs, _file)) => dirs
            .split('/')
            .any(|dir| matches!(dir, "tests" | "test" | "__tests__" | "spec" | "testing")),
        None => false,
    }
}
316
/// Returns `true` when any directory component of `path` is `compat`,
/// `_compat`, `legacy` or `deprecated`.
///
/// `path` is expected to use `/` separators. Matching is done on whole
/// directory components, so relative paths that start with one of these
/// directories (e.g. `legacy/handler.py`) are detected as well as nested
/// ones. The file name itself is never considered.
fn is_compat_legacy(path: &str) -> bool {
    match path.rsplit_once('/') {
        Some((dirs, _file)) => dirs
            .split('/')
            .any(|dir| matches!(dir, "compat" | "_compat" | "legacy" | "deprecated")),
        // No `/` means there is no directory part.
        None => false,
    }
}
323
/// Returns `true` when any directory component of `path` is `examples`,
/// `example`, `_examples` or `docs_src`.
///
/// `path` is expected to use `/` separators. Matching is done on whole
/// directory components, so every listed directory is recognised both at the
/// start of a relative path (e.g. `docs_src/tutorial.py`) and nested deeper.
/// The file name itself is never considered.
fn is_example_docs(path: &str) -> bool {
    match path.rsplit_once('/') {
        Some((dirs, _file)) => dirs
            .split('/')
            .any(|dir| matches!(dir, "examples" | "example" | "_examples" | "docs_src")),
        // No `/` means there is no directory part.
        None => false,
    }
}
332
/// Returns `true` when the file name of `path` is exactly `__init__.py`,
/// `package-info.java` or `index.ts`.
///
/// Paths without a usable UTF-8 file name never match.
fn is_reexport_barrel(path: &str) -> bool {
    let name = match Path::new(path).file_name() {
        Some(os_name) => os_name.to_str().unwrap_or(""),
        None => "",
    };
    matches!(name, "__init__.py" | "package-info.java" | "index.ts")
}
340
/// Returns `true` when `path` ends in `.d.ts` or `.pyi`, ignoring ASCII case.
// The comparison runs on an ASCII-lowercased copy, so the case-sensitivity
// lint does not apply here.
#[allow(clippy::case_sensitive_file_extension_comparisons)]
fn is_type_stub(path: &str) -> bool {
    let folded = path.to_ascii_lowercase();
    [".d.ts", ".pyi"]
        .iter()
        .any(|suffix| folded.ends_with(*suffix))
}
346
347// --- Signal 4: MMR-Inspired Diversity (File Saturation Decay) ---
348
349fn apply_diversity(mut results: Vec<HybridResult>, top_k: usize) -> Vec<HybridResult> {
350    if results.is_empty() {
351        return results;
352    }
353
354    results.sort_by(|a, b| {
355        b.rrf_score
356            .partial_cmp(&a.rrf_score)
357            .unwrap_or(std::cmp::Ordering::Equal)
358    });
359
360    let mut selected: Vec<HybridResult> = Vec::with_capacity(top_k);
361    let mut file_count: HashMap<&str, usize> = HashMap::new();
362    let mut remaining: Vec<(usize, f64)> = results
363        .iter()
364        .enumerate()
365        .map(|(i, r)| (i, r.rrf_score))
366        .collect();
367
368    while selected.len() < top_k && !remaining.is_empty() {
369        // Compute effective scores with file saturation decay
370        let mut best_idx = 0;
371        let mut best_effective = f64::NEG_INFINITY;
372
373        for (pos, &(orig_idx, base_score)) in remaining.iter().enumerate() {
374            let file = results[orig_idx].file_path.as_str();
375            let count = file_count.get(file).copied().unwrap_or(0);
376            let effective = if count >= SATURATION_THRESHOLD {
377                let excess = count - SATURATION_THRESHOLD + 1;
378                base_score * SATURATION_DECAY.powi(excess as i32)
379            } else {
380                base_score
381            };
382
383            if effective > best_effective {
384                best_effective = effective;
385                best_idx = pos;
386            }
387        }
388
389        let (orig_idx, _) = remaining.remove(best_idx);
390        let file = results[orig_idx].file_path.as_str();
391        *file_count.entry(file).or_insert(0) += 1;
392        selected.push(results[orig_idx].clone());
393    }
394
395    selected
396}
397
398// --- Tests ---
399
// Unit tests for query classification and the individual reranking signals.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `HybridResult` fixture with the given path, symbol, kind and
    /// RRF score; the remaining fields get fixed placeholder values.
    fn make_result(file: &str, symbol: &str, kind: ChunkKind, score: f64) -> HybridResult {
        HybridResult {
            file_path: file.to_string(),
            symbol_name: symbol.to_string(),
            kind,
            start_line: 1,
            end_line: 10,
            snippet: String::new(),
            rrf_score: score,
            bm25_score: Some(score),
            dense_score: None,
            bm25_rank: Some(1),
            dense_rank: None,
        }
    }

    /// Single identifier-like tokens are classified as `Symbol`.
    #[test]
    fn classify_symbol_queries() {
        assert_eq!(classify_query("AuthService"), QueryType::Symbol);
        assert_eq!(classify_query("Foo::bar"), QueryType::Symbol);
        assert_eq!(classify_query("get_user_by_id"), QueryType::Symbol);
        assert_eq!(classify_query("_private"), QueryType::Symbol);
        assert_eq!(classify_query("HTTP_CLIENT"), QueryType::Symbol);
        assert_eq!(classify_query("getUserById"), QueryType::Symbol);
    }

    /// Multi-word queries without architecture keywords fall back to
    /// `NaturalLanguage`.
    #[test]
    fn classify_nl_queries() {
        assert_eq!(
            classify_query("authentication flow"),
            QueryType::NaturalLanguage
        );
        assert_eq!(
            classify_query("save model to disk"),
            QueryType::NaturalLanguage
        );
        assert_eq!(classify_query("error handling"), QueryType::NaturalLanguage);
    }

    /// Queries containing a structural keyword phrase are classified as
    /// `Architecture`.
    #[test]
    fn classify_architecture_queries() {
        assert_eq!(
            classify_query("how does auth work"),
            QueryType::Architecture
        );
        assert_eq!(
            classify_query("where is the data flow"),
            QueryType::Architecture
        );
        assert_eq!(
            classify_query("module structure overview"),
            QueryType::Architecture
        );
    }

    /// The chunk whose symbol name matches the query overtakes a chunk that
    /// started with a higher score.
    #[test]
    fn definition_boost_works() {
        let mut results = vec![
            make_result("src/auth.rs", "authenticate", ChunkKind::Function, 0.5),
            make_result("src/main.rs", "main", ChunkKind::Function, 0.8),
            make_result("src/auth.rs", "AuthService", ChunkKind::Struct, 0.4),
        ];

        definition_boost(&mut results, "AuthService", QueryType::Symbol);

        // AuthService struct should now be highest
        assert!(results[2].rrf_score > results[1].rrf_score);
    }

    /// A result under `tests/` is scaled by `STRONG_PENALTY`, while the
    /// result outside it ends up with the higher score.
    #[test]
    fn noise_penalty_applies() {
        let mut results = vec![
            make_result("src/auth.rs", "auth", ChunkKind::Function, 1.0),
            make_result("tests/test_auth.rs", "test_auth", ChunkKind::Function, 1.0),
        ];

        apply_noise_penalties(&mut results);

        assert!(results[0].rrf_score > results[1].rrf_score);
        assert!((results[1].rrf_score - STRONG_PENALTY).abs() < 0.001);
    }

    /// The first chunk of a file with two hits ends up above its original
    /// score.
    #[test]
    fn file_coherence_boosts_multi_chunk_files() {
        let mut results = vec![
            make_result("src/auth.rs", "login", ChunkKind::Function, 0.5),
            make_result("src/auth.rs", "logout", ChunkKind::Function, 0.4),
            make_result("src/main.rs", "main", ChunkKind::Function, 0.6),
        ];

        file_coherence_boost(&mut results);

        // auth.rs top chunk should be boosted (multi-chunk file)
        assert!(results[0].rrf_score > 0.5);
    }

    /// With three chunks from one file and a lower-scored chunk from another,
    /// a top-3 selection still includes the other file.
    #[test]
    fn diversity_limits_same_file() {
        let results = vec![
            make_result("src/big.rs", "fn1", ChunkKind::Function, 1.0),
            make_result("src/big.rs", "fn2", ChunkKind::Function, 0.9),
            make_result("src/big.rs", "fn3", ChunkKind::Function, 0.8),
            make_result("src/other.rs", "fn4", ChunkKind::Function, 0.7),
        ];

        let diverse = apply_diversity(results, 3);
        // Should include other.rs due to saturation of big.rs
        let files: Vec<&str> = diverse.iter().map(|r| r.file_path.as_str()).collect();
        assert!(files.contains(&"src/other.rs"));
    }

    /// Qualified names are reduced to the part after the separator; plain
    /// identifiers are returned unchanged.
    #[test]
    fn extract_symbol_from_qualified() {
        assert_eq!(extract_symbol_name("Foo::bar"), "bar");
        assert_eq!(extract_symbol_name("obj.method"), "method");
        assert_eq!(extract_symbol_name("ptr->field"), "field");
        assert_eq!(extract_symbol_name("SimpleIdent"), "SimpleIdent");
    }

    /// Each path category maps to its expected multiplier; an ordinary source
    /// path is not penalised.
    #[test]
    fn path_penalties_correct() {
        assert!((path_penalty("src/auth.rs") - 1.0).abs() < 0.001);
        assert!((path_penalty("tests/test_auth.py") - STRONG_PENALTY).abs() < 0.001);
        assert!((path_penalty("src/compat/old.rs") - STRONG_PENALTY).abs() < 0.001);
        assert!((path_penalty("src/types.d.ts") - MILD_PENALTY).abs() < 0.001);
        assert!((path_penalty("src/__init__.py") - MODERATE_PENALTY).abs() < 0.001);
    }
}
531}