Skip to main content

ripvec_core/encoder/ripvec/
ranking.rs

//! Alpha auto-detection and query-driven boosting.
//!
//! Port of `~/src/semble/src/semble/ranking/weighting.py` and
//! `~/src/semble/src/semble/ranking/boosting.py`.
//!
//! Main public entry points:
//!
//! - [`resolve_alpha`] — picks the semantic/BM25 blend weight from the
//!   query shape: 0.3 for bare-symbol queries (lean BM25), 0.5 for
//!   natural-language queries (balanced).
//! - [`apply_query_boost`] — adds query-type boosts on top of a score
//!   map. Returns a new map; callers re-rank afterwards.
//! - [`boost_multi_chunk_files`] — file-coherence boost; promotes the
//!   top chunk of files whose chunks collectively score high.
//!
//! The classifiers [`is_symbol_query`] and [`is_prose_path`] are public
//! as well.
//!
//! Where Python keys `combined_scores` by `Chunk`, this port uses
//! `HashMap<usize, f32>` (chunk index → score) plus `&[CodeChunk]` for
//! lookups. Same shape as [`crate::encoder::ripvec::penalties::rerank_topk`].
19
20use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
29// ---------------------------------------------------------------------------
30// Alpha selection (weighting.py).
31// ---------------------------------------------------------------------------
32
33/// Semantic blend weight for symbol-shaped queries. Lean BM25.
34pub const ALPHA_SYMBOL: f32 = 0.3;
35/// Semantic blend weight for natural-language queries. Balanced.
36pub const ALPHA_NL: f32 = 0.5;
37
38/// Return the semantic blend weight, optionally overriding by caller.
39///
40/// `alpha = Some(w)` returns `w` directly. `alpha = None` auto-detects
41/// from the query: bare symbol-shaped → [`ALPHA_SYMBOL`], otherwise
42/// [`ALPHA_NL`].
43#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45    if let Some(w) = alpha {
46        return w;
47    }
48    if is_symbol_query(query) {
49        ALPHA_SYMBOL
50    } else {
51        ALPHA_NL
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Symbol-query detection (boosting.py:11).
57// ---------------------------------------------------------------------------
58
59/// Symbol-lookup queries: namespace-qualified, leading underscore, or
60/// containing uppercase/underscore. Plain lowercase words are NL.
61fn symbol_query_re() -> &'static Regex {
62    static RE: OnceLock<Regex> = OnceLock::new();
63    RE.get_or_init(|| {
64        Regex::new(concat!(
65            r"^(?:",
66            // namespace-qualified
67            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68            // leading underscore
69            r"|_[A-Za-z0-9_]*",
70            // contains uppercase or underscore in the body
71            r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72            // starts with uppercase
73            r"|[A-Z][A-Za-z0-9]*",
74            r")$",
75        ))
76        .expect("symbol-query regex compiles")
77    })
78}
79
80/// CamelCase / camelCase identifiers embedded in an NL query.
81/// Excludes plain words and pure acronyms.
82fn embedded_symbol_re() -> &'static Regex {
83    static RE: OnceLock<Regex> = OnceLock::new();
84    RE.get_or_init(|| {
85        Regex::new(concat!(
86            r"\b(?:",
87            // PascalCase: upper, lower run, then upper, then mix.
88            r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89            // camelCase: lower, mix, then upper, then mix.
90            r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91            r")\b",
92        ))
93        .expect("embedded-symbol regex compiles")
94    })
95}
96
97/// Return `true` when the query looks like a bare symbol or
98/// namespace-qualified identifier (`Foo::Bar`, `module.Class`, `_x`,
99/// `getX`, `XMLParser`, etc.).
100#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102    symbol_query_re().is_match(query.trim())
103}
104
/// Return `true` when `path` is a prose / documentation file
/// (markdown, rst, plain text, asciidoc, org).
///
/// The decision uses the file-name extension only (case-insensitive), as
/// reported by [`Path::extension`]. A path without an extension, such as
/// an extension-less file literally named `txt` or `md`, is never prose.
/// Splitting the raw string on `.` would have treated the whole name as
/// the extension in that case.
///
/// Used to gate the L-12 cross-encoder reranker on the ripvec engine.
/// The ms-marco-trained model helps on natural-language prose but
/// measurably hurts code retrieval (see `examples/semble_bench.rs`
/// against the semble python corpus: -0.060 NDCG@10 average across
/// 8 repos with rerank vs without). Code corpora skip rerank; prose
/// corpora keep it.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            matches!(
                ext.to_ascii_lowercase().as_str(),
                "md" | "markdown" | "mdx" | "rst" | "txt" | "text" | "adoc" | "asciidoc" | "org"
            )
        })
}
123
124// ---------------------------------------------------------------------------
125// Definition-keyword scan (boosting.py:36).
126// ---------------------------------------------------------------------------
127
/// Keywords that introduce a definition in common languages (Python, JS,
/// Go, Rust, Kotlin, Elixir, Swift, C#, Java, C/C++, Dart, ...).
///
/// These are matched case-sensitively, so prose such as a capitalised
/// "Module" in a Python docstring does not count as a definition. The
/// order is preserved as-is because it becomes the order of the regex
/// alternation built from this list.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class", "module",
    "defmodule", // Elixir
    "def", "interface", "struct", "enum", "trait", "type", "func", "function", "object",
    "abstract class", "data class", "fn",
    "fun", // Kotlin
    "package", "namespace",
    "protocol", // Swift
    "record",   // C# 9+, Java 16+
    "typedef",  // C/C++/Dart
];
154
/// SQL DDL definition prefixes. DDL is written in all-caps or all-lowercase
/// by convention, so these are matched with a case-insensitive regex.
const SQL_DEFINITION_KEYWORDS: &[&str] =
    &["CREATE TABLE", "CREATE VIEW", "CREATE PROCEDURE", "CREATE FUNCTION"];

/// Symbol-definition boost, as a multiple of the top candidate score.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// File-stem keyword boost, as a multiple of the top candidate score.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// File-coherence boost, as a fraction of the top candidate score.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Scale applied to definition boosts for symbols embedded in NL queries.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum stem length for prefix matching against embedded symbols.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
169
/// English stopwords dropped from NL queries before file-stem matching.
/// Same word list as `_STOPWORDS` in boosting.py:82.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for",
    "from", "has", "have", "how", "if", "in", "is", "it", "not", "of",
    "on", "or", "the", "to", "was", "what", "when", "where", "which",
    "who", "why", "with",
];
177
178/// Build the general-keyword definition regex for a given symbol name.
179///
180/// Mirrors Python's `_definition_pattern`. Returns the compiled regex
181/// each call (Python uses `functools.lru_cache(256)` — we forgo the
182/// cache here; the hot path in `apply_query_boost` calls this a
183/// constant number of times per query).
184fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
185    let escaped = regex::escape(symbol_name);
186    let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
187    // Python's pattern uses `(?:^|(?<=\s))(?:keywords)` — keyword at start-of-line
188    // or preceded by whitespace. Rust's RE2 has no lookbehind, so we use
189    // `(?:^|\s)(?:keywords)` and let the whitespace be consumed; semantically
190    // equivalent for definition-keyword detection.
191    let def_body = DEFINITION_KEYWORDS
192        .iter()
193        .map(|k| regex::escape(k))
194        .collect::<Vec<_>>()
195        .join("|");
196    let sql_body = SQL_DEFINITION_KEYWORDS
197        .iter()
198        .map(|k| regex::escape(k))
199        .collect::<Vec<_>>()
200        .join("|");
201    let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
202    // The Rust `regex` crate uses RE2 and does not support `(?<=\s)` (lookbehind).
203    // Python's pattern is `(?:^|(?<=\s))(?:keywords)` — keywords appear at start of line
204    // or after whitespace. Equivalent without lookbehind: anchor on `(?:^|\s)` and include
205    // the whitespace character in the match (rather than as a lookbehind boundary). The
206    // semantic is identical for definition detection.
207    let no_lookbehind_prefix = r"(?:^|\s)(?:";
208    let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
209    let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
210    let general = RegexBuilder::new(&general_pat)
211        .multi_line(true)
212        .build()
213        .expect("general definition regex compiles");
214    let sql = RegexBuilder::new(&sql_pat)
215        .multi_line(true)
216        .case_insensitive(true)
217        .build()
218        .expect("SQL definition regex compiles");
219    (general, sql)
220}
221
222/// `true` when `content` contains a definition of `symbol_name`.
223///
224/// Mirrors Python's `_chunk_defines_symbol`. Case-sensitive for general
225/// keywords; case-insensitive for SQL DDL. Namespace-qualified forms
226/// (`defmodule Phoenix.Router` for `Router`) match because the pattern
227/// allows an optional `ns_prefix`.
228fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
229    let (general, sql) = definition_pattern(symbol_name);
230    general.is_match(content) || sql.is_match(content)
231}
232
233/// `true` when `stem` matches `name` (exact, snake_case-normalised, or
234/// plural). Mirrors Python's `_stem_matches`.
235fn stem_matches(stem: &str, name: &str) -> bool {
236    let stem_norm = stem.replace('_', "");
237    stem == name
238        || stem_norm == name
239        || stem.trim_end_matches('s') == name
240        || stem_norm.trim_end_matches('s') == name
241}
242
243// ---------------------------------------------------------------------------
244// Symbol extraction (boosting.py:137).
245// ---------------------------------------------------------------------------
246
/// Extract the final identifier from a possibly namespace-qualified
/// query. Mirrors `_extract_symbol_name`.
///
/// The query is trimmed first. Then everything up to and including the
/// rightmost separator (`::`, `\`, `->`, `.`) is dropped. Whichever
/// separator kind occurs last wins, so mixed forms such as
/// `Foo::Bar.baz` resolve to the trailing identifier. A fixed-priority
/// scan would stop at `::` and return `Bar.baz`. A query with no
/// separator is returned trimmed.
///
/// Examples: `Sinatra::Base` → `Base`, `Client` → `Client`,
/// `Foo::Bar.baz` → `baz`, `App\Models\User->save` → `save`.
fn extract_symbol_name(query: &str) -> String {
    const SEPARATORS: [&str; 4] = ["::", "\\", "->", "."];
    let query = query.trim();
    // Byte offset just past the last separator of any kind. All separators
    // are ASCII, so this is always a char boundary.
    let start = SEPARATORS
        .iter()
        .filter_map(|sep| query.rfind(sep).map(|idx| idx + sep.len()))
        .max()
        .unwrap_or(0);
    query[start..].to_string()
}
259
260/// Return the boost amount for a chunk that defines one of `names`
261/// (0.0 if none match). Mirrors `_definition_tier`.
262fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
263    let any_match = names
264        .iter()
265        .any(|name| chunk_defines_symbol(&chunk.content, name));
266    if !any_match {
267        return 0.0;
268    }
269    let stem = Path::new(&chunk.file_path)
270        .file_stem()
271        .and_then(|s| s.to_str())
272        .unwrap_or_default()
273        .to_ascii_lowercase();
274    let stem_match_bonus = names
275        .iter()
276        .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
277    boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
278}
279
280/// Boost non-candidate chunks whose lowercased file stem satisfies
281/// `stem_ok`. Mirrors `_scan_non_candidates`.
282fn scan_non_candidates(
283    boosted: &mut HashMap<usize, f32>,
284    names: &HashSet<String>,
285    boost_unit: f32,
286    all_chunks: &[CodeChunk],
287    stem_ok: &dyn Fn(&str) -> bool,
288) {
289    for (idx, chunk) in all_chunks.iter().enumerate() {
290        if boosted.contains_key(&idx) {
291            continue;
292        }
293        let stem = Path::new(&chunk.file_path)
294            .file_stem()
295            .and_then(|s| s.to_str())
296            .unwrap_or_default()
297            .to_ascii_lowercase();
298        if !stem_ok(&stem) {
299            continue;
300        }
301        let tier = definition_tier(chunk, names, boost_unit);
302        if tier > 0.0 {
303            boosted.insert(idx, tier);
304        }
305    }
306}
307
308/// Symbol-query branch: boost chunks defining the queried symbol and
309/// stem-matched non-candidates. Mirrors `_boost_symbol_definitions`.
310fn boost_symbol_definitions(
311    boosted: &mut HashMap<usize, f32>,
312    query: &str,
313    max_score: f32,
314    all_chunks: &[CodeChunk],
315) {
316    let symbol_name = extract_symbol_name(query);
317    let trimmed_query = query.trim();
318    let mut names: HashSet<String> = HashSet::new();
319    names.insert(symbol_name.clone());
320    if symbol_name != trimmed_query {
321        names.insert(trimmed_query.to_string());
322    }
323
324    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
325
326    // Pass 1: walk current candidates, add a tier if they define the symbol.
327    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
328    for idx in candidate_indices {
329        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
330        if tier > 0.0 {
331            *boosted.entry(idx).or_insert(0.0) += tier;
332        }
333    }
334
335    // Pass 2: scan non-candidate chunks whose stem matches the symbol name.
336    let symbol_lower = symbol_name.to_ascii_lowercase();
337    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
338        stem_matches(stem, &symbol_lower)
339    });
340}
341
342/// NL-query branch: boost CamelCase / camelCase identifiers embedded in
343/// the query at half strength. Mirrors `_boost_embedded_symbols`.
344fn boost_embedded_symbols(
345    boosted: &mut HashMap<usize, f32>,
346    query: &str,
347    max_score: f32,
348    all_chunks: &[CodeChunk],
349) {
350    let names: HashSet<String> = embedded_symbol_re()
351        .find_iter(query)
352        .map(|m| m.as_str().to_string())
353        .collect();
354    if names.is_empty() {
355        return;
356    }
357
358    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
359
360    // Pass 1: candidates that define the embedded symbol(s).
361    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
362    for idx in candidate_indices {
363        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
364        if tier > 0.0 {
365            *boosted.entry(idx).or_insert(0.0) += tier;
366        }
367    }
368
369    // Pass 2: non-candidate stem-prefix scan.
370    let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
371    let symbols_lower_for_scan = symbols_lower.clone();
372    scan_non_candidates(
373        boosted,
374        &names,
375        boost_unit,
376        all_chunks,
377        &move |stem: &str| {
378            let stem_norm = stem.replace('_', "");
379            symbols_lower_for_scan.iter().any(|sym_lower| {
380                stem == sym_lower
381                    || stem_norm == *sym_lower
382                    || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
383                    || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
384                        && sym_lower.starts_with(stem_norm.as_str()))
385            })
386        },
387    );
388}
389
/// Count how many query `keywords` match some path part. Mirrors
/// `_count_keyword_matches`.
///
/// A keyword counts once if it equals a part exactly. It also counts if
/// one of the two is a prefix of the other and the shorter side has at
/// least 3 characters (`parse` ↔ `parser`).
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    let prefix_overlap = |keyword: &str| {
        parts.iter().any(|part| {
            let (short, long) = if keyword.len() <= part.len() {
                (keyword, part.as_str())
            } else {
                (part.as_str(), keyword)
            };
            short.len() >= 3 && long.starts_with(short)
        })
    };
    keywords
        .iter()
        .filter(|keyword| parts.contains(*keyword) || prefix_overlap(keyword.as_str()))
        .count()
}
417
418/// Boost chunks whose file paths match NL query keywords. Mirrors
419/// `_boost_stem_matches`. Uses prefix matching for morphological
420/// variants ("dependency" matches "dependencies"). Matches file stems
421/// and the immediate parent directory name.
422fn boost_stem_matches(
423    boosted: &mut HashMap<usize, f32>,
424    query: &str,
425    max_score: f32,
426    chunks: &[CodeChunk],
427) {
428    static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
429    let keyword_re =
430        KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
431    let keywords: HashSet<String> = keyword_re
432        .find_iter(query)
433        .map(|m| m.as_str().to_ascii_lowercase())
434        .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
435        .collect();
436    if keywords.is_empty() {
437        return;
438    }
439
440    let boost = max_score * STEM_BOOST_MULTIPLIER;
441    let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
442    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
443    for idx in candidate_indices {
444        let path = &chunks[idx].file_path;
445        let parts = path_cache
446            .entry(path.clone())
447            .or_insert_with(|| {
448                let mut parts: HashSet<String> = HashSet::new();
449                let p = Path::new(path);
450                if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
451                    parts.extend(split_identifier(stem));
452                }
453                if let Some(parent_name) = p
454                    .parent()
455                    .and_then(Path::file_name)
456                    .and_then(|s| s.to_str())
457                    && !parent_name.is_empty()
458                    && parent_name != "."
459                    && parent_name != ".."
460                {
461                    parts.extend(split_identifier(parent_name));
462                }
463                parts
464            })
465            .clone();
466        let n_matches = count_keyword_matches(&keywords, &parts);
467        if n_matches > 0 {
468            let match_ratio = n_matches as f32 / keywords.len() as f32;
469            if match_ratio >= 0.10 {
470                *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
471            }
472        }
473    }
474}
475
476// ---------------------------------------------------------------------------
477// Public entry: apply_query_boost.
478// ---------------------------------------------------------------------------
479
480/// Apply query-type boosts to candidate scores.
481///
482/// Mirrors `apply_query_boost`. Returns a new map; the input is not
483/// mutated. Empty input passes through.
484///
485/// Branches on [`is_symbol_query`]:
486/// - Symbol-shaped query → [`boost_symbol_definitions`] (×3 base,
487///   ×1.5 on stem match) and scan non-candidate stem-matched chunks.
488/// - NL query → [`boost_stem_matches`] (file-stem keyword overlap) +
489///   [`boost_embedded_symbols`] (half-strength CamelCase scan).
490#[expect(
491    clippy::implicit_hasher,
492    reason = "internal API; callers in the semble pipeline use the default RandomState"
493)]
494#[must_use]
495pub fn apply_query_boost(
496    combined_scores: &HashMap<usize, f32>,
497    query: &str,
498    all_chunks: &[CodeChunk],
499) -> HashMap<usize, f32> {
500    if combined_scores.is_empty() {
501        return HashMap::new();
502    }
503    let max_score = combined_scores
504        .values()
505        .copied()
506        .fold(f32::NEG_INFINITY, f32::max);
507    let mut boosted = combined_scores.clone();
508    if is_symbol_query(query) {
509        boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
510    } else {
511        boost_stem_matches(&mut boosted, query, max_score, all_chunks);
512        boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
513    }
514    boosted
515}
516
517/// Promote files with multiple high-scoring chunks by boosting their
518/// top chunk in place. Mirrors `boost_multi_chunk_files`.
519#[expect(
520    clippy::implicit_hasher,
521    reason = "internal API; callers in the semble pipeline use the default RandomState"
522)]
523pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
524    if scores.is_empty() {
525        return;
526    }
527    let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
528    if max_score == 0.0 || !max_score.is_finite() {
529        return;
530    }
531
532    let mut file_sum: HashMap<String, f32> = HashMap::new();
533    let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
534    for (&idx, &score) in scores.iter() {
535        let path = chunks[idx].file_path.clone();
536        *file_sum.entry(path.clone()).or_insert(0.0) += score;
537        match best_chunk_idx.get(&path) {
538            Some(&best) if scores[&best] >= score => {}
539            _ => {
540                best_chunk_idx.insert(path, idx);
541            }
542        }
543    }
544    let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
545    if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
546        return;
547    }
548    let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
549    for (path, &idx) in &best_chunk_idx {
550        let contribution = boost_unit * file_sum[path] / max_file_sum;
551        *scores.entry(idx).or_insert(0.0) += contribution;
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `CodeChunk` with the given path and content.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// Collect string literals into an owned `HashSet`.
    fn set_of(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    // ----- is_symbol_query (boosting.py:132) -----

    #[test]
    fn is_symbol_query_namespace() {
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    #[test]
    fn is_symbol_query_plain_word_rejected() {
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    // ----- is_prose_path -----

    #[test]
    fn is_prose_path_doc_extensions() {
        assert!(is_prose_path("README.md"));
        assert!(is_prose_path("docs/guide.RST"));
        assert!(is_prose_path("notes/todo.txt"));
        assert!(is_prose_path("book/chapter.adoc"));
        assert!(!is_prose_path("src/main.rs"));
        assert!(!is_prose_path("Makefile"));
    }

    // ----- resolve_alpha -----

    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    // ----- definition_pattern + chunk_defines_symbol -----

    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n    pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n    pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        // Elixir defmodule with namespace prefix should match the bare symbol.
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // ----- stem_matches / extract_symbol_name / count_keyword_matches -----

    #[test]
    fn stem_matches_variants() {
        assert!(stem_matches("client", "client"));
        assert!(stem_matches("clients", "client"));
        assert!(stem_matches("http_client", "httpclient"));
        assert!(!stem_matches("server", "client"));
    }

    #[test]
    fn extract_symbol_name_single_separator() {
        assert_eq!(extract_symbol_name("Sinatra::Base"), "Base");
        assert_eq!(extract_symbol_name("module.Class"), "Class");
        assert_eq!(extract_symbol_name("a->b->c"), "c");
        assert_eq!(extract_symbol_name(r"Foo\Bar"), "Bar");
        assert_eq!(extract_symbol_name("  Client  "), "Client");
    }

    #[test]
    fn count_keyword_matches_exact_and_prefix() {
        assert_eq!(
            count_keyword_matches(&set_of(&["parse", "json"]), &set_of(&["parser"])),
            1
        );
        assert_eq!(
            count_keyword_matches(&set_of(&["foo", "bar"]), &set_of(&["foo", "bar"])),
            2
        );
        // Shorter side below 3 characters never prefix-matches.
        assert_eq!(count_keyword_matches(&set_of(&["db"]), &set_of(&["dbx"])), 0);
    }

    // ----- boost_symbol_definitions stem multiplier -----

    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        // Two chunks both define `Client`; one is in client.rs (stem match),
        // the other in unrelated.rs. The stem-matched chunk gets the 1.5x
        // bonus on top of the base ×3 multiplier.
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        // base_boost = max_score (1.0) * 3.0 = 3.0
        // stem-matched chunk 0: 1.0 + 3.0 * 1.5 = 5.5
        // non-matched chunk 1: 1.0 + 3.0 = 4.0
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    // ----- boost_stem_matches -----

    #[test]
    fn boost_stem_matches_prefix() {
        // Query keyword "parse" should boost a chunk in "parser.rs" via
        // prefix overlap (shorter="parse", longer="parser", min-length 3
        // satisfied). This is the real morphological-variant shape;
        // Python semble's docstring example ("dependency" ↔ "dependencies")
        // is misleading because the two diverge at char 9.
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // ----- boost_embedded_symbols half-strength -----

    #[test]
    fn boost_embedded_symbols_half_strength() {
        // Embedded symbol "MyClass" in an NL query boosts at half strength
        // vs a pure symbol query.
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        // base_boost = 1.0 * 3.0 * 0.5 = 1.5
        // stem-matched chunk 0: 1.0 + 1.5 * 1.5 = 3.25
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    // ----- apply_query_boost -----

    #[test]
    fn apply_query_boost_empty_passthrough() {
        let out = apply_query_boost(&HashMap::new(), "Client", &[]);
        assert!(out.is_empty());
    }

    // ----- boost_multi_chunk_files (file coherence) -----

    #[test]
    fn boost_multi_chunk_files() {
        // Three chunks: two in foo.rs (sum 1.5), one in bar.rs (sum 1.0).
        // boost_unit = max_score (1.0) * 0.2 = 0.2
        // foo.rs file_sum 1.5; boost to its TOP chunk (idx 0) =
        //   0.2 * 1.5 / 1.5 = 0.2 → final 1.2.
        // bar.rs file_sum 1.0; boost to chunk 2 = 0.2 * 1.0/1.5 ≈ 0.133 → ~1.133.
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    #[test]
    fn boost_multi_chunk_files_empty_is_noop() {
        let mut scores: HashMap<usize, f32> = HashMap::new();
        super::boost_multi_chunk_files(&mut scores, &[]);
        assert!(scores.is_empty());
    }

    // ----- property test (symbol regex parity) -----

    #[test]
    fn property_symbol_regex_parity_python() {
        // These should be detected as symbol queries (matches Python).
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        // These should NOT be detected (NL).
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}