Skip to main content

ripvec_core/encoder/ripvec/
ranking.rs

//! Alpha auto-detection and query-driven boosting.
//!
//! Port of `~/src/semble/src/semble/ranking/weighting.py` and
//! `~/src/semble/src/semble/ranking/boosting.py`.
//!
//! Three main public entry points:
//!
//! - [`resolve_alpha`] — picks the semantic/BM25 blend weight from the
//!   query shape: 0.3 for bare-symbol queries (lean BM25), 0.5 for
//!   natural-language queries (balanced).
//! - [`apply_query_boost`] — adds query-type boosts on top of a score
//!   map. Returns a new map; callers re-rank afterwards.
//! - [`boost_multi_chunk_files`] — file-coherence boost; promotes the
//!   top chunk of files whose chunks collectively score high. Unlike
//!   `apply_query_boost`, it updates the score map in place.
//!
//! The query classifier [`is_symbol_query`] and the prose-file test
//! [`is_prose_path`] are public as well.
//!
//! Where Python keys `combined_scores` by `Chunk`, this port uses
//! `HashMap<usize, f32>` (chunk index → score) plus `&[CodeChunk]` for
//! lookups. Same shape as [`crate::encoder::ripvec::penalties::rerank_topk`].
19
20use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
29// ---------------------------------------------------------------------------
30// Alpha selection (weighting.py).
31// ---------------------------------------------------------------------------
32
/// Semantic weight in the semantic/BM25 blend when the query looks like a
/// bare symbol. Leans toward BM25.
pub const ALPHA_SYMBOL: f32 = 0.30;
/// Semantic weight in the semantic/BM25 blend for natural-language
/// queries. An even split.
pub const ALPHA_NL: f32 = 0.50;
37
38/// Return the semantic blend weight, optionally overriding by caller.
39///
40/// `alpha = Some(w)` returns `w` directly. `alpha = None` auto-detects
41/// from the query: bare symbol-shaped → [`ALPHA_SYMBOL`], otherwise
42/// [`ALPHA_NL`].
43#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45    if let Some(w) = alpha {
46        return w;
47    }
48    if is_symbol_query(query) {
49        ALPHA_SYMBOL
50    } else {
51        ALPHA_NL
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Symbol-query detection (boosting.py:11).
57// ---------------------------------------------------------------------------
58
59/// Symbol-lookup queries: namespace-qualified, leading underscore, or
60/// containing uppercase/underscore. Plain lowercase words are NL.
61fn symbol_query_re() -> &'static Regex {
62    static RE: OnceLock<Regex> = OnceLock::new();
63    RE.get_or_init(|| {
64        Regex::new(concat!(
65            r"^(?:",
66            // namespace-qualified
67            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68            // leading underscore
69            r"|_[A-Za-z0-9_]*",
70            // contains uppercase or underscore in the body
71            r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72            // starts with uppercase
73            r"|[A-Z][A-Za-z0-9]*",
74            r")$",
75        ))
76        .expect("symbol-query regex compiles")
77    })
78}
79
80/// CamelCase / camelCase identifiers embedded in an NL query.
81/// Excludes plain words and pure acronyms.
82fn embedded_symbol_re() -> &'static Regex {
83    static RE: OnceLock<Regex> = OnceLock::new();
84    RE.get_or_init(|| {
85        Regex::new(concat!(
86            r"\b(?:",
87            // PascalCase: upper, lower run, then upper, then mix.
88            r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89            // camelCase: lower, mix, then upper, then mix.
90            r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91            r")\b",
92        ))
93        .expect("embedded-symbol regex compiles")
94    })
95}
96
97/// Return `true` when the query looks like a bare symbol or
98/// namespace-qualified identifier (`Foo::Bar`, `module.Class`, `_x`,
99/// `getX`, `XMLParser`, etc.).
100#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102    symbol_query_re().is_match(query.trim())
103}
104
/// Return `true` when `path` names a prose / documentation file.
///
/// The decision uses the file extension, compared ASCII case-insensitively:
/// markdown (`md`, `markdown`, `mdx`), reStructuredText (`rst`), plain text
/// (`txt`, `text`), AsciiDoc (`adoc`, `asciidoc`) and Org (`org`).
///
/// Only a real extension counts. A path without one (`Makefile`), or a bare
/// file name that merely spells an extension (`txt`, `org`), is not prose.
///
/// Used to gate the L-12 cross-encoder reranker on the ripvec engine:
/// the ms-marco-trained model helps on natural-language prose but
/// measurably hurts code retrieval (see `examples/semble_bench.rs`
/// against the semble python corpus: -0.060 NDCG@10 average across
/// 8 repos with rerank vs without). Code corpora skip rerank; prose
/// corpora keep it.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    const PROSE_EXTENSIONS: [&str; 9] = [
        "md", "markdown", "mdx", "rst", "txt", "text", "adoc", "asciidoc", "org",
    ];
    // `Path::extension` returns `None` for extension-less names. Splitting
    // on '.' would instead yield the whole path for those, so a file
    // literally named `txt` used to be classified as prose.
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            PROSE_EXTENSIONS
                .iter()
                .any(|&known| ext.eq_ignore_ascii_case(known))
        })
}
123
124// ---------------------------------------------------------------------------
125// Definition-keyword scan (boosting.py:36).
126// ---------------------------------------------------------------------------
127
/// Keywords that introduce a definition across many languages (Python, JS,
/// Go, Rust, Kotlin, Elixir, Swift, C#, Java, C/C++, Dart, ...).
///
/// These are matched case-sensitively, so capitalised prose such as
/// "Module" in a Python docstring does not look like a definition. The
/// order follows the upstream list.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class",
    "module",
    // Elixir.
    "defmodule",
    "def",
    "interface",
    "struct",
    "enum",
    "trait",
    "type",
    "func",
    "function",
    "object",
    "abstract class",
    "data class",
    "fn",
    // Kotlin.
    "fun",
    "package",
    "namespace",
    // Swift.
    "protocol",
    // C# 9+ and Java 16+.
    "record",
    // C, C++ and Dart.
    "typedef",
];
154
/// SQL DDL statements that define a named object.
///
/// Unlike the general keyword list, these are matched case-insensitively,
/// because DDL is conventionally written either all-caps or all-lowercase.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE",
    "CREATE VIEW",
    "CREATE PROCEDURE",
    "CREATE FUNCTION",
];
163
/// Multiple of the top candidate score added for a chunk that defines the
/// queried symbol.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Multiple of the top candidate score available for NL file-path keyword
/// matches; the amount actually granted is scaled by the match ratio.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// Fraction of the top score given to the best chunk of the most coherent file.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Extra scale applied to definition boosts for identifiers embedded in an
/// NL query (half strength).
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length for prefix matches in the embedded-symbol
/// non-candidate scan.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
169
/// English function words dropped from NL queries before file-stem
/// keyword matching. Port of `_STOPWORDS` (boosting.py:82).
///
/// Entries of two characters or fewer are also removed by the keyword
/// length filter; they are kept here for parity with the Python list.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have",
    "how", "if", "in", "is", "it", "not", "of", "on", "or", "the", "to", "was", "what", "when",
    "where", "which", "who", "why", "with",
];
177
178/// Build the general-keyword definition regex for a given symbol name.
179///
180/// Mirrors Python's `_definition_pattern`. The Python upstream uses
181/// `functools.lru_cache(256)`; we mirror that with a process-wide
182/// bounded cache so the per-query hot path in `apply_query_boost`
183/// (called ~100 candidates × ~2 symbols per query) only pays the
184/// regex compile cost on first use of a given symbol name.
185///
186/// Profile evidence (samply, 2026-05-21, post-2A+2B): `apply_query_boost`
187/// accounted for 28.3% of `search_hybrid` wall time on a 1M-chunk
188/// corpus. The bulk was inside `Regex::new` / `RegexBuilder::build`
189/// reachable from this function. Caching collapses 100K+ compiles
190/// to ~20 across a 100-query bench.
191fn definition_pattern_uncached(symbol_name: &str) -> (Regex, Regex) {
192    let escaped = regex::escape(symbol_name);
193    let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
194    // Python's pattern uses `(?:^|(?<=\s))(?:keywords)` — keyword at start-of-line
195    // or preceded by whitespace. Rust's RE2 has no lookbehind, so we use
196    // `(?:^|\s)(?:keywords)` and let the whitespace be consumed; semantically
197    // equivalent for definition-keyword detection.
198    let def_body = DEFINITION_KEYWORDS
199        .iter()
200        .map(|k| regex::escape(k))
201        .collect::<Vec<_>>()
202        .join("|");
203    let sql_body = SQL_DEFINITION_KEYWORDS
204        .iter()
205        .map(|k| regex::escape(k))
206        .collect::<Vec<_>>()
207        .join("|");
208    let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
209    // The Rust `regex` crate uses RE2 and does not support `(?<=\s)` (lookbehind).
210    // Python's pattern is `(?:^|(?<=\s))(?:keywords)` — keywords appear at start of line
211    // or after whitespace. Equivalent without lookbehind: anchor on `(?:^|\s)` and include
212    // the whitespace character in the match (rather than as a lookbehind boundary). The
213    // semantic is identical for definition detection.
214    let no_lookbehind_prefix = r"(?:^|\s)(?:";
215    let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
216    let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
217    let general = RegexBuilder::new(&general_pat)
218        .multi_line(true)
219        .build()
220        .expect("general definition regex compiles");
221    let sql = RegexBuilder::new(&sql_pat)
222        .multi_line(true)
223        .case_insensitive(true)
224        .build()
225        .expect("SQL definition regex compiles");
226    (general, sql)
227}
228
229/// Process-wide bounded cache for `definition_pattern_uncached`, keyed
230/// by symbol name. Bounded at 256 entries with simple stop-inserting
231/// eviction (matches semble's `lru_cache(256)`; on overflow the cache
232/// just stops growing, so subsequent unseen symbols fall through to a
233/// fresh compile each call but the hot working set still hits).
234///
235/// `regex::Regex` is internally `Arc`-shared so clone is cheap.
236fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
237    use std::sync::{Mutex, OnceLock};
238    static CACHE: OnceLock<Mutex<std::collections::HashMap<String, (Regex, Regex)>>> =
239        OnceLock::new();
240    let cache = CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()));
241
242    if let Ok(map) = cache.lock()
243        && let Some(entry) = map.get(symbol_name)
244    {
245        return entry.clone();
246    }
247
248    let pair = definition_pattern_uncached(symbol_name);
249    if let Ok(mut map) = cache.lock()
250        && map.len() < 256
251    {
252        map.insert(symbol_name.to_string(), pair.clone());
253    }
254    pair
255}
256
257/// `true` when `content` contains a definition of `symbol_name`.
258///
259/// Mirrors Python's `_chunk_defines_symbol`. Case-sensitive for general
260/// keywords; case-insensitive for SQL DDL. Namespace-qualified forms
261/// (`defmodule Phoenix.Router` for `Router`) match because the pattern
262/// allows an optional `ns_prefix`.
263///
264/// Substring pre-filter: if `content` doesn't even contain
265/// `symbol_name` as a substring, no definition can possibly exist.
266/// This short-circuit skips two regex evaluations against ~1500-char
267/// chunk content for the vast majority of candidates whose top-ranked
268/// status comes from semantic / BM25 signal rather than from defining
269/// the queried symbol.
270fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
271    if !content.contains(symbol_name) {
272        return false;
273    }
274    let (general, sql) = definition_pattern(symbol_name);
275    general.is_match(content) || sql.is_match(content)
276}
277
278/// `true` when `stem` matches `name` (exact, snake_case-normalised, or
279/// plural). Mirrors Python's `_stem_matches`.
280fn stem_matches(stem: &str, name: &str) -> bool {
281    let stem_norm = stem.replace('_', "");
282    stem == name
283        || stem_norm == name
284        || stem.trim_end_matches('s') == name
285        || stem_norm.trim_end_matches('s') == name
286}
287
288// ---------------------------------------------------------------------------
289// Symbol extraction (boosting.py:137).
290// ---------------------------------------------------------------------------
291
/// Extract the final identifier from a possibly namespace-qualified
/// query. Mirrors `_extract_symbol_name`.
///
/// Surrounding whitespace is ignored. Separators are tried in priority
/// order `::`, `\`, `->`, `.`. The first one present splits the query at
/// its last occurrence, and the tail is returned. A query containing no
/// separator is returned whole.
///
/// Examples: `Sinatra::Base` → `Base`, `App\Http\Kernel` → `Kernel`,
/// `Client` → `Client`, `"Sinatra::Base "` → `Base`.
fn extract_symbol_name(query: &str) -> String {
    // Trim first. `is_symbol_query` trims before classifying, so a query with
    // a trailing space reaches this function, and an untrimmed tail ("Base ")
    // would never stem-match and would mis-anchor the definition regex.
    let query = query.trim();
    ["::", "\\", "->", "."]
        .into_iter()
        .find_map(|separator| query.rsplit_once(separator))
        .map_or(query, |(_, tail)| tail)
        .to_string()
}
304
305/// Return the boost amount for a chunk that defines one of `names`
306/// (0.0 if none match). Mirrors `_definition_tier`.
307fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
308    let any_match = names
309        .iter()
310        .any(|name| chunk_defines_symbol(&chunk.content, name));
311    if !any_match {
312        return 0.0;
313    }
314    let stem = Path::new(&chunk.file_path)
315        .file_stem()
316        .and_then(|s| s.to_str())
317        .unwrap_or_default()
318        .to_ascii_lowercase();
319    let stem_match_bonus = names
320        .iter()
321        .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
322    boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
323}
324
325/// Boost non-candidate chunks whose lowercased file stem satisfies
326/// `stem_ok`. Mirrors `_scan_non_candidates`.
327fn scan_non_candidates(
328    boosted: &mut HashMap<usize, f32>,
329    names: &HashSet<String>,
330    boost_unit: f32,
331    all_chunks: &[CodeChunk],
332    stem_ok: &dyn Fn(&str) -> bool,
333) {
334    for (idx, chunk) in all_chunks.iter().enumerate() {
335        if boosted.contains_key(&idx) {
336            continue;
337        }
338        let stem = Path::new(&chunk.file_path)
339            .file_stem()
340            .and_then(|s| s.to_str())
341            .unwrap_or_default()
342            .to_ascii_lowercase();
343        if !stem_ok(&stem) {
344            continue;
345        }
346        let tier = definition_tier(chunk, names, boost_unit);
347        if tier > 0.0 {
348            boosted.insert(idx, tier);
349        }
350    }
351}
352
353/// Symbol-query branch: boost chunks defining the queried symbol and
354/// stem-matched non-candidates. Mirrors `_boost_symbol_definitions`.
355fn boost_symbol_definitions(
356    boosted: &mut HashMap<usize, f32>,
357    query: &str,
358    max_score: f32,
359    all_chunks: &[CodeChunk],
360) {
361    let symbol_name = extract_symbol_name(query);
362    let trimmed_query = query.trim();
363    let mut names: HashSet<String> = HashSet::new();
364    names.insert(symbol_name.clone());
365    if symbol_name != trimmed_query {
366        names.insert(trimmed_query.to_string());
367    }
368
369    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
370
371    // Pass 1: walk current candidates, add a tier if they define the symbol.
372    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
373    for idx in candidate_indices {
374        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
375        if tier > 0.0 {
376            *boosted.entry(idx).or_insert(0.0) += tier;
377        }
378    }
379
380    // Pass 2: scan non-candidate chunks whose stem matches the symbol name.
381    let symbol_lower = symbol_name.to_ascii_lowercase();
382    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
383        stem_matches(stem, &symbol_lower)
384    });
385}
386
387/// NL-query branch: boost CamelCase / camelCase identifiers embedded in
388/// the query at half strength. Mirrors `_boost_embedded_symbols`.
389fn boost_embedded_symbols(
390    boosted: &mut HashMap<usize, f32>,
391    query: &str,
392    max_score: f32,
393    all_chunks: &[CodeChunk],
394) {
395    let names: HashSet<String> = embedded_symbol_re()
396        .find_iter(query)
397        .map(|m| m.as_str().to_string())
398        .collect();
399    if names.is_empty() {
400        return;
401    }
402
403    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
404
405    // Pass 1: candidates that define the embedded symbol(s).
406    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
407    for idx in candidate_indices {
408        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
409        if tier > 0.0 {
410            *boosted.entry(idx).or_insert(0.0) += tier;
411        }
412    }
413
414    // Pass 2: non-candidate stem-prefix scan.
415    let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
416    let symbols_lower_for_scan = symbols_lower.clone();
417    scan_non_candidates(
418        boosted,
419        &names,
420        boost_unit,
421        all_chunks,
422        &move |stem: &str| {
423            let stem_norm = stem.replace('_', "");
424            symbols_lower_for_scan.iter().any(|sym_lower| {
425                stem == sym_lower
426                    || stem_norm == *sym_lower
427                    || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
428                    || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
429                        && sym_lower.starts_with(stem_norm.as_str()))
430            })
431        },
432    );
433}
434
/// Count how many query `keywords` match some path part. Port of
/// `_count_keyword_matches`.
///
/// A keyword matches when it equals a part. It also matches when the
/// shorter of the keyword and a part is at least 3 bytes long and is a
/// prefix of the longer one (`parse` ↔ `parser`). Each keyword is counted
/// at most once.
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    let prefix_overlap = |a: &str, b: &str| {
        let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        short.len() >= 3 && long.starts_with(short)
    };
    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(*keyword)
                || parts
                    .iter()
                    .any(|part| prefix_overlap(keyword.as_str(), part.as_str()))
        })
        .count()
}
462
463/// Boost chunks whose file paths match NL query keywords. Mirrors
464/// `_boost_stem_matches`. Uses prefix matching for morphological
465/// variants ("dependency" matches "dependencies"). Matches file stems
466/// and the immediate parent directory name.
467fn boost_stem_matches(
468    boosted: &mut HashMap<usize, f32>,
469    query: &str,
470    max_score: f32,
471    chunks: &[CodeChunk],
472) {
473    static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
474    let keyword_re =
475        KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
476    let keywords: HashSet<String> = keyword_re
477        .find_iter(query)
478        .map(|m| m.as_str().to_ascii_lowercase())
479        .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
480        .collect();
481    if keywords.is_empty() {
482        return;
483    }
484
485    let boost = max_score * STEM_BOOST_MULTIPLIER;
486    let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
487    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
488    for idx in candidate_indices {
489        let path = &chunks[idx].file_path;
490        let parts = path_cache
491            .entry(path.clone())
492            .or_insert_with(|| {
493                let mut parts: HashSet<String> = HashSet::new();
494                let p = Path::new(path);
495                if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
496                    parts.extend(split_identifier(stem));
497                }
498                if let Some(parent_name) = p
499                    .parent()
500                    .and_then(Path::file_name)
501                    .and_then(|s| s.to_str())
502                    && !parent_name.is_empty()
503                    && parent_name != "."
504                    && parent_name != ".."
505                {
506                    parts.extend(split_identifier(parent_name));
507                }
508                parts
509            })
510            .clone();
511        let n_matches = count_keyword_matches(&keywords, &parts);
512        if n_matches > 0 {
513            let match_ratio = n_matches as f32 / keywords.len() as f32;
514            if match_ratio >= 0.10 {
515                *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
516            }
517        }
518    }
519}
520
521// ---------------------------------------------------------------------------
522// Public entry: apply_query_boost.
523// ---------------------------------------------------------------------------
524
525/// Apply query-type boosts to candidate scores.
526///
527/// Mirrors `apply_query_boost`. Returns a new map; the input is not
528/// mutated. Empty input passes through.
529///
530/// Branches on [`is_symbol_query`]:
531/// - Symbol-shaped query → [`boost_symbol_definitions`] (×3 base,
532///   ×1.5 on stem match) and scan non-candidate stem-matched chunks.
533/// - NL query → [`boost_stem_matches`] (file-stem keyword overlap) +
534///   [`boost_embedded_symbols`] (half-strength CamelCase scan).
535#[expect(
536    clippy::implicit_hasher,
537    reason = "internal API; callers in the semble pipeline use the default RandomState"
538)]
539#[must_use]
540pub fn apply_query_boost(
541    combined_scores: &HashMap<usize, f32>,
542    query: &str,
543    all_chunks: &[CodeChunk],
544) -> HashMap<usize, f32> {
545    if combined_scores.is_empty() {
546        return HashMap::new();
547    }
548    let max_score = combined_scores
549        .values()
550        .copied()
551        .fold(f32::NEG_INFINITY, f32::max);
552    let mut boosted = combined_scores.clone();
553    if is_symbol_query(query) {
554        boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
555    } else {
556        boost_stem_matches(&mut boosted, query, max_score, all_chunks);
557        boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
558    }
559    boosted
560}
561
562/// Promote files with multiple high-scoring chunks by boosting their
563/// top chunk in place. Mirrors `boost_multi_chunk_files`.
564#[expect(
565    clippy::implicit_hasher,
566    reason = "internal API; callers in the semble pipeline use the default RandomState"
567)]
568pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
569    if scores.is_empty() {
570        return;
571    }
572    let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
573    if max_score == 0.0 || !max_score.is_finite() {
574        return;
575    }
576
577    let mut file_sum: HashMap<String, f32> = HashMap::new();
578    let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
579    for (&idx, &score) in scores.iter() {
580        let path = chunks[idx].file_path.clone();
581        *file_sum.entry(path.clone()).or_insert(0.0) += score;
582        match best_chunk_idx.get(&path) {
583            Some(&best) if scores[&best] >= score => {}
584            _ => {
585                best_chunk_idx.insert(path, idx);
586            }
587        }
588    }
589    let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
590    if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
591        return;
592    }
593    let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
594    for (path, &idx) in &best_chunk_idx {
595        let contribution = boost_unit * file_sum[path] / max_file_sum;
596        *scores.entry(idx).or_insert(0.0) += contribution;
597    }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Code` chunk at `path` whose raw and enriched text are both `content`.
    fn code_chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_owned(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            content: content.to_owned(),
            enriched_content: content.to_owned(),
            qualified_name: None,
        }
    }

    /// `true` when two floats agree within the suite's tolerance.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    // ----- is_symbol_query (boosting.py:132) -----

    #[test]
    fn is_symbol_query_namespace() {
        for query in ["Sinatra::Base", "module.Class", "a->b->c", r"Foo\Bar"] {
            assert!(is_symbol_query(query));
        }
    }

    #[test]
    fn is_symbol_query_pascal() {
        for query in ["Client", "HTTPHandler", "XMLParser"] {
            assert!(is_symbol_query(query));
        }
    }

    #[test]
    fn is_symbol_query_plain_word_rejected() {
        for query in ["session", "retry", "authentication"] {
            assert!(!is_symbol_query(query));
        }
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        for query in ["_private", "__init__"] {
            assert!(is_symbol_query(query));
        }
    }

    // ----- resolve_alpha -----

    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!(close(resolve_alpha("Client", None), ALPHA_SYMBOL));
        assert!(close(resolve_alpha("foo.Bar", None), ALPHA_SYMBOL));
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!(close(resolve_alpha("how does retry work", None), ALPHA_NL));
        assert!(close(resolve_alpha("authentication handling", None), ALPHA_NL));
    }

    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!(close(resolve_alpha("Client", Some(0.7)), 0.7));
    }

    // ----- definition_pattern + chunk_defines_symbol -----

    #[test]
    fn chunk_defines_symbol_class() {
        assert!(chunk_defines_symbol("class Client:\n    pass", "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        assert!(chunk_defines_symbol(
            "def handle_request():\n    pass",
            "handle_request"
        ));
    }

    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        // An Elixir `defmodule` with a namespace prefix still defines the bare name.
        assert!(chunk_defines_symbol(" defmodule Phoenix.Router do\n", "Router"));
    }

    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        let cases = [
            (" create table users (id int)", "users"),
            (" CREATE TABLE Users (id int)", "Users"),
        ];
        for (ddl, table) in cases {
            assert!(chunk_defines_symbol(ddl, table));
        }
    }

    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // ----- boost_symbol_definitions stem multiplier -----

    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        // Both chunks define `Client`; only the one living in client.rs also
        // earns the ×1.5 stem bonus on top of the ×3 base multiplier.
        let chunks = [
            code_chunk("src/client.rs", "struct Client { /* ... */ }"),
            code_chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        // boost_unit = 1.0 (max score) × 3.0 = 3.0
        // chunk 0 (stem match):    1.0 + 3.0 × 1.5 = 5.5
        // chunk 1 (no stem match): 1.0 + 3.0       = 4.0
        assert!(close(boosted[&0], 5.5));
        assert!(close(boosted[&1], 4.0));
    }

    // ----- boost_stem_matches -----

    #[test]
    fn boost_stem_matches_prefix() {
        // "parse" boosts parser.rs through prefix overlap (shorter side
        // "parse" has ≥ 3 chars). That is the real morphological-variant
        // case; the Python docstring's "dependency" ↔ "dependencies"
        // example does not work, because the words diverge at char 9.
        let chunks = [code_chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // ----- boost_embedded_symbols half-strength -----

    #[test]
    fn boost_embedded_symbols_half_strength() {
        // "MyClass" inside an NL query boosts at half the symbol-query strength.
        let chunks = [code_chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        // boost_unit = 1.0 × 3.0 × 0.5 = 1.5
        // chunk 0 (stem match): 1.0 + 1.5 × 1.5 = 3.25
        assert!(close(boosted[&0], 3.25), "got {}", boosted[&0]);
    }

    // ----- boost_multi_chunk_files (file coherence) -----

    #[test]
    fn boost_multi_chunk_files() {
        // foo.rs holds chunks 0 and 1 (sum 1.5); bar.rs holds chunk 2 (sum 1.0).
        // boost_unit = 1.0 (max score) × 0.2 = 0.2
        // foo.rs → its top chunk 0 gains 0.2 × 1.5 / 1.5 = 0.2 → 1.2
        // bar.rs → chunk 2 gains 0.2 × 1.0 / 1.5 ≈ 0.133 → ~1.133
        let chunks = [
            code_chunk("src/foo.rs", ""),
            code_chunk("src/foo.rs", ""),
            code_chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!(close(scores[&0], 1.2), "got {}", scores[&0]);
        assert!(close(scores[&1], 0.5), "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            close(scores[&2], expected_bar),
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    // ----- property test (symbol regex parity) -----

    #[test]
    fn property_symbol_regex_parity_python() {
        // Classified as symbol queries, as in Python.
        let symbols = [
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        // Classified as natural language.
        let non_symbols = [
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}