Skip to main content

ripvec_core/encoder/ripvec/
ranking.rs

1//! Alpha auto-detection and query-driven boosting.
2//!
3//! Port of `~/src/semble/src/semble/ranking/weighting.py` and
4//! `~/src/semble/src/semble/ranking/boosting.py`.
5//!
//! Three main public entry points (the predicates [`is_symbol_query`] and
//! [`is_prose_path`] are public as well):
7//!
8//! - [`resolve_alpha`] — picks the semantic/BM25 blend weight from the
9//!   query shape: 0.3 for bare-symbol queries (lean BM25), 0.5 for
10//!   natural-language queries (balanced).
11//! - [`apply_query_boost`] — adds query-type boosts on top of a score
12//!   map. Returns a new map; callers re-rank afterwards.
13//! - [`boost_multi_chunk_files`] — file-coherence boost; promotes the
14//!   top chunk of files whose chunks collectively score high.
15//!
16//! Where Python keys `combined_scores` by `Chunk`, this port uses
17//! `HashMap<usize, f32>` (chunk index → score) plus `&[CodeChunk]` for
18//! lookups. Same shape as [`crate::encoder::ripvec::penalties::rerank_topk`].
19
20use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
29// ---------------------------------------------------------------------------
30// Alpha selection (weighting.py).
31// ---------------------------------------------------------------------------
32
33/// Semantic blend weight for symbol-shaped queries. Lean BM25.
34pub const ALPHA_SYMBOL: f32 = 0.3;
35/// Semantic blend weight for natural-language queries. Balanced.
36pub const ALPHA_NL: f32 = 0.5;
37
38/// Return the semantic blend weight, optionally overriding by caller.
39///
40/// `alpha = Some(w)` returns `w` directly. `alpha = None` auto-detects
41/// from the query: bare symbol-shaped → [`ALPHA_SYMBOL`], otherwise
42/// [`ALPHA_NL`].
43#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45    if let Some(w) = alpha {
46        return w;
47    }
48    if is_symbol_query(query) {
49        ALPHA_SYMBOL
50    } else {
51        ALPHA_NL
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Symbol-query detection (boosting.py:11).
57// ---------------------------------------------------------------------------
58
59/// Symbol-lookup queries: namespace-qualified, leading underscore, or
60/// containing uppercase/underscore. Plain lowercase words are NL.
61fn symbol_query_re() -> &'static Regex {
62    static RE: OnceLock<Regex> = OnceLock::new();
63    RE.get_or_init(|| {
64        Regex::new(concat!(
65            r"^(?:",
66            // namespace-qualified
67            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68            // leading underscore
69            r"|_[A-Za-z0-9_]*",
70            // contains uppercase or underscore in the body
71            r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72            // starts with uppercase
73            r"|[A-Z][A-Za-z0-9]*",
74            r")$",
75        ))
76        .expect("symbol-query regex compiles")
77    })
78}
79
80/// CamelCase / camelCase identifiers embedded in an NL query.
81/// Excludes plain words and pure acronyms.
82fn embedded_symbol_re() -> &'static Regex {
83    static RE: OnceLock<Regex> = OnceLock::new();
84    RE.get_or_init(|| {
85        Regex::new(concat!(
86            r"\b(?:",
87            // PascalCase: upper, lower run, then upper, then mix.
88            r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89            // camelCase: lower, mix, then upper, then mix.
90            r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91            r")\b",
92        ))
93        .expect("embedded-symbol regex compiles")
94    })
95}
96
97/// Return `true` when the query looks like a bare symbol or
98/// namespace-qualified identifier (`Foo::Bar`, `module.Class`, `_x`,
99/// `getX`, `XMLParser`, etc.).
100#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102    symbol_query_re().is_match(query.trim())
103}
104
/// Return `true` when `path` is a prose / documentation file, judged by its
/// file extension (compared case-insensitively): markdown (`md`,
/// `markdown`, `mdx`), `rst`, plain text (`txt`, `text`), asciidoc
/// (`adoc`, `asciidoc`) or `org`.
///
/// Only a real extension counts. An extension-less file that happens to be
/// named `text`, `org` or `md` is not prose.
///
/// Used to gate the L-12 cross-encoder reranker on the ripvec engine:
/// the ms-marco-trained model helps on natural-language prose but
/// measurably hurts code retrieval (see `examples/semble_bench.rs`
/// against the semble python corpus: -0.060 NDCG@10 average across
/// 8 repos with rerank vs without). Code corpora skip rerank; prose
/// corpora keep it.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    let Some(ext) = Path::new(path).extension().and_then(|e| e.to_str()) else {
        return false;
    };
    matches!(
        ext.to_ascii_lowercase().as_str(),
        "md" | "markdown" | "mdx" | "rst" | "txt" | "text" | "adoc" | "asciidoc" | "org"
    )
}
123
124// ---------------------------------------------------------------------------
125// Definition-keyword scan (boosting.py:36).
126// ---------------------------------------------------------------------------
127
/// Definition keywords recognised across languages, matched case-sensitively
/// so that capitalised prose (e.g. "Module" in a Python docstring) is not
/// mistaken for a definition.
///
/// Language notes: `defmodule` is Elixir, `fun` is Kotlin, `protocol` is
/// Swift, `record` is C# 9+ / Java 16+, and `typedef` is C/C++/Dart.
/// The order is preserved because it determines the order of the generated
/// regex alternation.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class", "module", "defmodule", "def", "interface", "struct", "enum",
    "trait", "type", "func", "function", "object", "abstract class", "data class",
    "fn", "fun", "package", "namespace", "protocol", "record", "typedef",
];

/// SQL DDL prefixes. SQL is conventionally written either all-caps or
/// all-lowercase, so these are matched case-insensitively.
const SQL_DEFINITION_KEYWORDS: &[&str] =
    &["CREATE TABLE", "CREATE VIEW", "CREATE PROCEDURE", "CREATE FUNCTION"];

/// Definition boost, in units of the current top score.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Path-keyword boost, in units of the current top score; scaled by the
/// fraction of query keywords that match.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// File-coherence boost, as a fraction of the current top score.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Extra scale on the definition boost for symbols embedded in NL queries.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum stem length for prefix matching in the embedded-symbol scan.
const EMBEDDED_STEM_MIN_LEN: usize = 4;

/// Common English stopwords excluded from file-stem keyword matching for NL
/// queries. Mirrors `_STOPWORDS` from boosting.py:82.
///
/// Keywords of two characters or fewer are already filtered out before this
/// list is consulted, so the short entries only keep parity with the Python
/// list.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for",
    "from", "has", "have", "how", "if", "in", "is", "it", "not", "of",
    "on", "or", "the", "to", "was", "what", "when", "where", "which", "who",
    "why", "with",
];
177
178/// Build the general-keyword definition regex for a given symbol name.
179///
180/// Mirrors Python's `_definition_pattern`. The Python upstream uses
181/// `functools.lru_cache(256)`; we mirror that with a process-wide
182/// bounded cache so the per-query hot path in `apply_query_boost`
183/// (called ~100 candidates × ~2 symbols per query) only pays the
184/// regex compile cost on first use of a given symbol name.
185///
186/// Profile evidence (samply, 2026-05-21, post-2A+2B): `apply_query_boost`
187/// accounted for 28.3% of `search_hybrid` wall time on a 1M-chunk
188/// corpus. The bulk was inside `Regex::new` / `RegexBuilder::build`
189/// reachable from this function. Caching collapses 100K+ compiles
190/// to ~20 across a 100-query bench.
191fn definition_pattern_uncached(symbol_name: &str) -> (Regex, Regex) {
192    let escaped = regex::escape(symbol_name);
193    let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
194    // Python's pattern uses `(?:^|(?<=\s))(?:keywords)` — keyword at start-of-line
195    // or preceded by whitespace. Rust's RE2 has no lookbehind, so we use
196    // `(?:^|\s)(?:keywords)` and let the whitespace be consumed; semantically
197    // equivalent for definition-keyword detection.
198    let def_body = DEFINITION_KEYWORDS
199        .iter()
200        .map(|k| regex::escape(k))
201        .collect::<Vec<_>>()
202        .join("|");
203    let sql_body = SQL_DEFINITION_KEYWORDS
204        .iter()
205        .map(|k| regex::escape(k))
206        .collect::<Vec<_>>()
207        .join("|");
208    let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
209    // The Rust `regex` crate uses RE2 and does not support `(?<=\s)` (lookbehind).
210    // Python's pattern is `(?:^|(?<=\s))(?:keywords)` — keywords appear at start of line
211    // or after whitespace. Equivalent without lookbehind: anchor on `(?:^|\s)` and include
212    // the whitespace character in the match (rather than as a lookbehind boundary). The
213    // semantic is identical for definition detection.
214    let no_lookbehind_prefix = r"(?:^|\s)(?:";
215    let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
216    let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
217    let general = RegexBuilder::new(&general_pat)
218        .multi_line(true)
219        .build()
220        .expect("general definition regex compiles");
221    let sql = RegexBuilder::new(&sql_pat)
222        .multi_line(true)
223        .case_insensitive(true)
224        .build()
225        .expect("SQL definition regex compiles");
226    (general, sql)
227}
228
229/// Process-wide bounded cache for `definition_pattern_uncached`, keyed
230/// by symbol name. Bounded at 256 entries with simple stop-inserting
231/// eviction (matches semble's `lru_cache(256)`; on overflow the cache
232/// just stops growing, so subsequent unseen symbols fall through to a
233/// fresh compile each call but the hot working set still hits).
234///
235/// `regex::Regex` is internally `Arc`-shared so clone is cheap.
236fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
237    use std::sync::{Mutex, OnceLock};
238    static CACHE: OnceLock<Mutex<std::collections::HashMap<String, (Regex, Regex)>>> =
239        OnceLock::new();
240    let cache = CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()));
241
242    if let Ok(map) = cache.lock()
243        && let Some(entry) = map.get(symbol_name)
244    {
245        return entry.clone();
246    }
247
248    let pair = definition_pattern_uncached(symbol_name);
249    if let Ok(mut map) = cache.lock()
250        && map.len() < 256
251    {
252        map.insert(symbol_name.to_string(), pair.clone());
253    }
254    pair
255}
256
257/// `true` when `content` contains a definition of `symbol_name`.
258///
259/// Mirrors Python's `_chunk_defines_symbol`. Case-sensitive for general
260/// keywords; case-insensitive for SQL DDL. Namespace-qualified forms
261/// (`defmodule Phoenix.Router` for `Router`) match because the pattern
262/// allows an optional `ns_prefix`.
263///
264/// Substring pre-filter: if `content` doesn't even contain
265/// `symbol_name` as a substring, no definition can possibly exist.
266/// This short-circuit skips two regex evaluations against ~1500-char
267/// chunk content for the vast majority of candidates whose top-ranked
268/// status comes from semantic / BM25 signal rather than from defining
269/// the queried symbol.
270fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
271    if !content.contains(symbol_name) {
272        return false;
273    }
274    let (general, sql) = definition_pattern(symbol_name);
275    general.is_match(content) || sql.is_match(content)
276}
277
278/// `true` when `stem` matches `name` (exact, snake_case-normalised, or
279/// plural). Mirrors Python's `_stem_matches`.
280fn stem_matches(stem: &str, name: &str) -> bool {
281    let stem_norm = stem.replace('_', "");
282    stem == name
283        || stem_norm == name
284        || stem.trim_end_matches('s') == name
285        || stem_norm.trim_end_matches('s') == name
286}
287
288// ---------------------------------------------------------------------------
289// Symbol extraction (boosting.py:137).
290// ---------------------------------------------------------------------------
291
/// Extract the final identifier from a possibly namespace-qualified
/// query. Mirrors `_extract_symbol_name`.
///
/// Surrounding whitespace is ignored in every case. Without trimming,
/// `is_symbol_query` would accept `"Foo::Bar "` but this would return
/// `"Bar "`, a name that can never match a definition or a file stem.
/// Separators are tried in the priority order `::`, `\`, `->`, `.`. The
/// text after the last occurrence of the first separator present is
/// returned; with no separator, the whole trimmed query is returned.
///
/// Examples: `Sinatra::Base` → `Base`, `Client` → `Client`,
/// `Foo::Bar ` → `Bar`, `a::b.c` → `b.c` (priority order, not position).
fn extract_symbol_name(query: &str) -> String {
    let query = query.trim();
    ["::", "\\", "->", "."]
        .into_iter()
        .find_map(|sep| query.rfind(sep).map(|idx| &query[idx + sep.len()..]))
        .unwrap_or(query)
        .to_string()
}
304
305/// Return the boost amount for a chunk that defines one of `names`
306/// (0.0 if none match). Mirrors `_definition_tier`.
307fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
308    let any_match = names
309        .iter()
310        .any(|name| chunk_defines_symbol(&chunk.content, name));
311    if !any_match {
312        return 0.0;
313    }
314    let stem = Path::new(&chunk.file_path)
315        .file_stem()
316        .and_then(|s| s.to_str())
317        .unwrap_or_default()
318        .to_ascii_lowercase();
319    let stem_match_bonus = names
320        .iter()
321        .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
322    boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
323}
324
325/// Boost non-candidate chunks whose lowercased file stem satisfies
326/// `stem_ok`. Mirrors `_scan_non_candidates`.
327fn scan_non_candidates(
328    boosted: &mut HashMap<usize, f32>,
329    names: &HashSet<String>,
330    boost_unit: f32,
331    all_chunks: &[CodeChunk],
332    stem_ok: &dyn Fn(&str) -> bool,
333) {
334    for (idx, chunk) in all_chunks.iter().enumerate() {
335        if boosted.contains_key(&idx) {
336            continue;
337        }
338        let stem = Path::new(&chunk.file_path)
339            .file_stem()
340            .and_then(|s| s.to_str())
341            .unwrap_or_default()
342            .to_ascii_lowercase();
343        if !stem_ok(&stem) {
344            continue;
345        }
346        let tier = definition_tier(chunk, names, boost_unit);
347        if tier > 0.0 {
348            boosted.insert(idx, tier);
349        }
350    }
351}
352
353/// Symbol-query branch: boost chunks defining the queried symbol and
354/// stem-matched non-candidates. Mirrors `_boost_symbol_definitions`.
355fn boost_symbol_definitions(
356    boosted: &mut HashMap<usize, f32>,
357    query: &str,
358    max_score: f32,
359    all_chunks: &[CodeChunk],
360) {
361    let symbol_name = extract_symbol_name(query);
362    let trimmed_query = query.trim();
363    let mut names: HashSet<String> = HashSet::new();
364    names.insert(symbol_name.clone());
365    if symbol_name != trimmed_query {
366        names.insert(trimmed_query.to_string());
367    }
368
369    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
370
371    // Pass 1: walk current candidates, add a tier if they define the symbol.
372    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
373    for idx in candidate_indices {
374        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
375        if tier > 0.0 {
376            *boosted.entry(idx).or_insert(0.0) += tier;
377        }
378    }
379
380    // Pass 2: scan non-candidate chunks whose stem matches the symbol name.
381    let symbol_lower = symbol_name.to_ascii_lowercase();
382    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
383        stem_matches(stem, &symbol_lower)
384    });
385}
386
387/// NL-query branch: boost CamelCase / camelCase identifiers embedded in
388/// the query at half strength. Mirrors `_boost_embedded_symbols`.
389fn boost_embedded_symbols(
390    boosted: &mut HashMap<usize, f32>,
391    query: &str,
392    max_score: f32,
393    all_chunks: &[CodeChunk],
394) {
395    let names: HashSet<String> = embedded_symbol_re()
396        .find_iter(query)
397        .map(|m| m.as_str().to_string())
398        .collect();
399    if names.is_empty() {
400        return;
401    }
402
403    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
404
405    // Pass 1: candidates that define the embedded symbol(s).
406    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
407    for idx in candidate_indices {
408        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
409        if tier > 0.0 {
410            *boosted.entry(idx).or_insert(0.0) += tier;
411        }
412    }
413
414    // Pass 2: non-candidate stem-prefix scan.
415    let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
416    let symbols_lower_for_scan = symbols_lower.clone();
417    scan_non_candidates(
418        boosted,
419        &names,
420        boost_unit,
421        all_chunks,
422        &move |stem: &str| {
423            let stem_norm = stem.replace('_', "");
424            symbols_lower_for_scan.iter().any(|sym_lower| {
425                stem == sym_lower
426                    || stem_norm == *sym_lower
427                    || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
428                    || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
429                        && sym_lower.starts_with(stem_norm.as_str()))
430            })
431        },
432    );
433}
434
/// Number of `keywords` that match at least one of `parts`. Mirrors
/// `_count_keyword_matches`.
///
/// A keyword matches a part when the two are equal. It also matches when
/// the shorter of the two is at least 3 bytes long and is a prefix of the
/// longer (so `parse` matches `parser`, but `ab` does not match `abc`).
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    let prefix_overlap = |a: &str, b: &str| {
        let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        short.len() >= 3 && long.starts_with(short)
    };
    keywords
        .iter()
        .filter(|kw| {
            parts.contains(*kw)
                || parts
                    .iter()
                    .any(|part| prefix_overlap(kw.as_str(), part.as_str()))
        })
        .count()
}
462
463/// Boost chunks whose file paths match NL query keywords. Mirrors
464/// `_boost_stem_matches`. Uses prefix matching for morphological
465/// variants ("dependency" matches "dependencies"). Matches file stems
466/// and the immediate parent directory name.
467fn boost_stem_matches(
468    boosted: &mut HashMap<usize, f32>,
469    query: &str,
470    max_score: f32,
471    chunks: &[CodeChunk],
472) {
473    static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
474    let keyword_re =
475        KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
476    let keywords: HashSet<String> = keyword_re
477        .find_iter(query)
478        .map(|m| m.as_str().to_ascii_lowercase())
479        .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
480        .collect();
481    if keywords.is_empty() {
482        return;
483    }
484
485    let boost = max_score * STEM_BOOST_MULTIPLIER;
486    let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
487    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
488    for idx in candidate_indices {
489        let path = &chunks[idx].file_path;
490        let parts = path_cache
491            .entry(path.clone())
492            .or_insert_with(|| {
493                let mut parts: HashSet<String> = HashSet::new();
494                let p = Path::new(path);
495                if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
496                    parts.extend(split_identifier(stem));
497                }
498                if let Some(parent_name) = p
499                    .parent()
500                    .and_then(Path::file_name)
501                    .and_then(|s| s.to_str())
502                    && !parent_name.is_empty()
503                    && parent_name != "."
504                    && parent_name != ".."
505                {
506                    parts.extend(split_identifier(parent_name));
507                }
508                parts
509            })
510            .clone();
511        let n_matches = count_keyword_matches(&keywords, &parts);
512        if n_matches > 0 {
513            let match_ratio = n_matches as f32 / keywords.len() as f32;
514            if match_ratio >= 0.10 {
515                *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
516            }
517        }
518    }
519}
520
521// ---------------------------------------------------------------------------
522// Public entry: apply_query_boost.
523// ---------------------------------------------------------------------------
524
525/// Apply query-type boosts to candidate scores.
526///
527/// Mirrors `apply_query_boost`. Returns a new map; the input is not
528/// mutated. Empty input passes through.
529///
530/// Branches on [`is_symbol_query`]:
531/// - Symbol-shaped query → [`boost_symbol_definitions`] (×3 base,
532///   ×1.5 on stem match) and scan non-candidate stem-matched chunks.
533/// - NL query → [`boost_stem_matches`] (file-stem keyword overlap) +
534///   [`boost_embedded_symbols`] (half-strength CamelCase scan).
535#[expect(
536    clippy::implicit_hasher,
537    reason = "internal API; callers in the semble pipeline use the default RandomState"
538)]
539#[must_use]
540pub fn apply_query_boost(
541    combined_scores: &HashMap<usize, f32>,
542    query: &str,
543    all_chunks: &[CodeChunk],
544) -> HashMap<usize, f32> {
545    if combined_scores.is_empty() {
546        return HashMap::new();
547    }
548    let max_score = combined_scores
549        .values()
550        .copied()
551        .fold(f32::NEG_INFINITY, f32::max);
552    let mut boosted = combined_scores.clone();
553    if is_symbol_query(query) {
554        boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
555    } else {
556        boost_stem_matches(&mut boosted, query, max_score, all_chunks);
557        boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
558    }
559    boosted
560}
561
562/// Promote files with multiple high-scoring chunks by boosting their
563/// top chunk in place. Mirrors `boost_multi_chunk_files`.
564#[expect(
565    clippy::implicit_hasher,
566    reason = "internal API; callers in the semble pipeline use the default RandomState"
567)]
568pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
569    if scores.is_empty() {
570        return;
571    }
572    let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
573    if max_score == 0.0 || !max_score.is_finite() {
574        return;
575    }
576
577    let mut file_sum: HashMap<String, f32> = HashMap::new();
578    let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
579    for (&idx, &score) in scores.iter() {
580        let path = chunks[idx].file_path.clone();
581        *file_sum.entry(path.clone()).or_insert(0.0) += score;
582        match best_chunk_idx.get(&path) {
583            Some(&best) if scores[&best] >= score => {}
584            _ => {
585                best_chunk_idx.insert(path, idx);
586            }
587        }
588    }
589    let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
590    if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
591        return;
592    }
593    let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
594    for (path, &idx) in &best_chunk_idx {
595        let contribution = boost_unit * file_sum[path] / max_file_sum;
596        *scores.entry(idx).or_insert(0.0) += contribution;
597    }
598}
599
// Unit tests pinning this port's behaviour against the semble Python
// originals (weighting.py / boosting.py), using small synthetic chunks.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `CodeChunk` for tests. Only `file_path` and the two
    /// content fields carry data; the remaining fields are placeholders.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    // ----- is_symbol_query (boosting.py:132) -----

    #[test]
    fn is_symbol_query_namespace() {
        // One case per supported separator: `::`, `.`, `->`, `\`.
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    #[test]
    fn is_symbol_query_plain_word_rejected() {
        // All-lowercase single words are natural language, not symbols.
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    // ----- resolve_alpha -----

    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_explicit_override_wins() {
        // An explicit alpha beats auto-detection even for a symbol query.
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    // ----- definition_pattern + chunk_defines_symbol -----

    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n    pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n    pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        // Elixir defmodule with namespace prefix should match the bare symbol.
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        // Lowercase and uppercase DDL keywords both count as definitions.
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    #[test]
    fn chunk_defines_symbol_negative() {
        // A mere usage with no definition keyword is not a definition.
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // ----- boost_symbol_definitions stem multiplier -----

    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        // Two chunks both define `Client`; one is in client.rs (stem match),
        // the other in unrelated.rs. The stem-matched chunk gets the 1.5x
        // bonus on top of the base ×3 multiplier.
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        // base_boost = max_score (1.0) * 3.0 = 3.0
        // stem-matched chunk 0: 1.0 + 3.0 * 1.5 = 5.5
        // non-matched chunk 1: 1.0 + 3.0 = 4.0
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    // ----- boost_stem_matches -----

    #[test]
    fn boost_stem_matches_prefix() {
        // Query keyword "parse" should boost a chunk in "parser.rs" via
        // prefix overlap (shorter="parse", longer="parser", min-length 3
        // satisfied). This is the real morphological-variant shape;
        // Python semble's docstring example ("dependency" ↔ "dependencies")
        // is misleading because the two diverge at char 9.
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // ----- boost_embedded_symbols half-strength -----

    #[test]
    fn boost_embedded_symbols_half_strength() {
        // Embedded symbol "MyClass" in an NL query boosts at half strength
        // vs a pure symbol query.
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        // base_boost = 1.0 * 3.0 * 0.5 = 1.5
        // stem-matched chunk 0: 1.0 + 1.5 * 1.5 = 3.25
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    // ----- boost_multi_chunk_files (file coherence) -----

    #[test]
    fn boost_multi_chunk_files() {
        // Three chunks: two in foo.rs (sum 1.5), one in bar.rs (sum 1.0).
        // boost_unit = max_score (1.0) * 0.2 = 0.2
        // foo.rs file_sum 1.5; boost to its TOP chunk (idx 0) =
        //   0.2 * 1.5 / 1.5 = 0.2 → final 1.2.
        // bar.rs file_sum 1.0; boost to chunk 2 = 0.2 * 1.0/1.5 ≈ 0.133 → ~1.133.
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        // `super::` disambiguates from this test fn, which shares the name.
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    // ----- property test (symbol regex parity) -----

    #[test]
    fn property_symbol_regex_parity_python() {
        // These should be detected as symbol queries (matches Python).
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        // These should NOT be detected (NL).
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}