Skip to main content

ripvec_core/encoder/ripvec/
ranking.rs

1//! Alpha auto-detection and query-driven boosting.
2//!
3//! Port of `~/src/semble/src/semble/ranking/weighting.py` and
4//! `~/src/semble/src/semble/ranking/boosting.py`.
5//!
6//! Three public entry points:
7//!
8//! - [`resolve_alpha`] — picks the semantic/BM25 blend weight from the
9//!   query shape: 0.3 for bare-symbol queries (lean BM25), 0.5 for
10//!   natural-language queries (balanced).
11//! - [`apply_query_boost`] — adds query-type boosts on top of a score
12//!   map. Returns a new map; callers re-rank afterwards.
13//! - [`boost_multi_chunk_files`] — file-coherence boost; promotes the
14//!   top chunk of files whose chunks collectively score high.
15//!
16//! Where Python keys `combined_scores` by `Chunk`, this port uses
17//! `HashMap<usize, f32>` (chunk index → score) plus `&[CodeChunk]` for
18//! lookups. Same shape as [`crate::encoder::ripvec::penalties::rerank_topk`].
19
20use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
29// ---------------------------------------------------------------------------
30// Alpha selection (weighting.py).
31// ---------------------------------------------------------------------------
32
/// Semantic blend weight applied when the query is shaped like a bare
/// symbol. Being below one half, it tilts the blend toward BM25.
pub const ALPHA_SYMBOL: f32 = 0.3;

/// Semantic blend weight applied to natural-language queries: an even
/// split between the semantic and BM25 signals.
pub const ALPHA_NL: f32 = 0.5;
37
38/// Return the semantic blend weight, optionally overriding by caller.
39///
40/// `alpha = Some(w)` returns `w` directly. `alpha = None` auto-detects
41/// from the query: bare symbol-shaped → [`ALPHA_SYMBOL`], otherwise
42/// [`ALPHA_NL`].
43#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45    if let Some(w) = alpha {
46        return w;
47    }
48    if is_symbol_query(query) {
49        ALPHA_SYMBOL
50    } else {
51        ALPHA_NL
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Symbol-query detection (boosting.py:11).
57// ---------------------------------------------------------------------------
58
59/// Symbol-lookup queries: namespace-qualified, leading underscore, or
60/// containing uppercase/underscore. Plain lowercase words are NL.
61fn symbol_query_re() -> &'static Regex {
62    static RE: OnceLock<Regex> = OnceLock::new();
63    RE.get_or_init(|| {
64        Regex::new(concat!(
65            r"^(?:",
66            // namespace-qualified
67            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68            // leading underscore
69            r"|_[A-Za-z0-9_]*",
70            // contains uppercase or underscore in the body
71            r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72            // starts with uppercase
73            r"|[A-Z][A-Za-z0-9]*",
74            r")$",
75        ))
76        .expect("symbol-query regex compiles")
77    })
78}
79
80/// CamelCase / camelCase identifiers embedded in an NL query.
81/// Excludes plain words and pure acronyms.
82fn embedded_symbol_re() -> &'static Regex {
83    static RE: OnceLock<Regex> = OnceLock::new();
84    RE.get_or_init(|| {
85        Regex::new(concat!(
86            r"\b(?:",
87            // PascalCase: upper, lower run, then upper, then mix.
88            r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89            // camelCase: lower, mix, then upper, then mix.
90            r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91            r")\b",
92        ))
93        .expect("embedded-symbol regex compiles")
94    })
95}
96
97/// Return `true` when the query looks like a bare symbol or
98/// namespace-qualified identifier (`Foo::Bar`, `module.Class`, `_x`,
99/// `getX`, `XMLParser`, etc.).
100#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102    symbol_query_re().is_match(query.trim())
103}
104
105// ---------------------------------------------------------------------------
106// Definition-keyword scan (boosting.py:36).
107// ---------------------------------------------------------------------------
108
/// Definition keywords shared across languages (Python, JS, Go, Rust,
/// Kotlin, Elixir, Swift, C#, Java, C/C++, Dart, …).
///
/// They are matched case-sensitively, so capitalised prose such as
/// "Module" inside a Python docstring does not count as a definition.
/// Each entry becomes one escaped alternative of the general definition
/// regex.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class", "module",
    "defmodule", // Elixir
    "def", "interface", "struct", "enum", "trait", "type", "func", "function", "object",
    "abstract class", "data class", "fn",
    "fun", // Kotlin
    "package", "namespace",
    "protocol", // Swift
    "record",   // C# 9+, Java 16+
    "typedef",  // C/C++/Dart
];
135
/// SQL DDL statements that count as definitions.
///
/// DDL is typically written either all-caps or all-lowercase, so these are
/// compiled into a case-insensitive regex rather than matched verbatim.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE", "CREATE VIEW",
    "CREATE PROCEDURE", "CREATE FUNCTION",
];
144
/// Definition boost, as a multiple of the top candidate score.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Largest file-path keyword boost, as a multiple of the top candidate
/// score; it is scaled by the fraction of query keywords that matched.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// File-coherence boost for a file's best chunk, as a fraction of the top
/// score; it is scaled by the file's share of the largest per-file sum.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Dampening factor for definition boosts of symbols embedded in NL queries.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length (bytes) for prefix-matching an embedded symbol.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
150
/// English stopwords dropped from NL-query keywords before file-path
/// matching. Mirrors `_STOPWORDS` from boosting.py:82.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for",
    "from", "has", "have", "how", "if", "in", "is", "it", "not", "of",
    "on", "or", "the", "to", "was", "what", "when", "where", "which", "who",
    "why", "with",
];
158
159/// Build the general-keyword definition regex for a given symbol name.
160///
161/// Mirrors Python's `_definition_pattern`. Returns the compiled regex
162/// each call (Python uses `functools.lru_cache(256)` — we forgo the
163/// cache here; the hot path in `apply_query_boost` calls this a
164/// constant number of times per query).
165fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
166    let escaped = regex::escape(symbol_name);
167    let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
168    // Python's pattern uses `(?:^|(?<=\s))(?:keywords)` — keyword at start-of-line
169    // or preceded by whitespace. Rust's RE2 has no lookbehind, so we use
170    // `(?:^|\s)(?:keywords)` and let the whitespace be consumed; semantically
171    // equivalent for definition-keyword detection.
172    let def_body = DEFINITION_KEYWORDS
173        .iter()
174        .map(|k| regex::escape(k))
175        .collect::<Vec<_>>()
176        .join("|");
177    let sql_body = SQL_DEFINITION_KEYWORDS
178        .iter()
179        .map(|k| regex::escape(k))
180        .collect::<Vec<_>>()
181        .join("|");
182    let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
183    // The Rust `regex` crate uses RE2 and does not support `(?<=\s)` (lookbehind).
184    // Python's pattern is `(?:^|(?<=\s))(?:keywords)` — keywords appear at start of line
185    // or after whitespace. Equivalent without lookbehind: anchor on `(?:^|\s)` and include
186    // the whitespace character in the match (rather than as a lookbehind boundary). The
187    // semantic is identical for definition detection.
188    let no_lookbehind_prefix = r"(?:^|\s)(?:";
189    let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
190    let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
191    let general = RegexBuilder::new(&general_pat)
192        .multi_line(true)
193        .build()
194        .expect("general definition regex compiles");
195    let sql = RegexBuilder::new(&sql_pat)
196        .multi_line(true)
197        .case_insensitive(true)
198        .build()
199        .expect("SQL definition regex compiles");
200    (general, sql)
201}
202
203/// `true` when `content` contains a definition of `symbol_name`.
204///
205/// Mirrors Python's `_chunk_defines_symbol`. Case-sensitive for general
206/// keywords; case-insensitive for SQL DDL. Namespace-qualified forms
207/// (`defmodule Phoenix.Router` for `Router`) match because the pattern
208/// allows an optional `ns_prefix`.
209fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
210    let (general, sql) = definition_pattern(symbol_name);
211    general.is_match(content) || sql.is_match(content)
212}
213
/// Report whether file `stem` names `name`. Mirrors Python's `_stem_matches`.
///
/// Two forms of the stem are tried: as-is, and with every `_` removed.
/// Each form counts as a match if it equals `name`, either directly or
/// after stripping all trailing `s` characters (a crude plural fold).
/// The comparison is case-sensitive; callers lowercase both sides.
fn stem_matches(stem: &str, name: &str) -> bool {
    let without_underscores = stem.replace('_', "");
    [stem, without_underscores.as_str()]
        .iter()
        .any(|form| *form == name || form.trim_end_matches('s') == name)
}
223
224// ---------------------------------------------------------------------------
225// Symbol extraction (boosting.py:137).
226// ---------------------------------------------------------------------------
227
/// Extract the final identifier from a possibly namespace-qualified
/// query. Mirrors `_extract_symbol_name`.
///
/// Surrounding whitespace is stripped before splitting, so `"Sinatra::Base "`
/// yields `Base` rather than `"Base "`. A trailing space could never match a
/// definition regex or a file stem. [`is_symbol_query`] accepts such padded
/// queries because it trims them itself.
///
/// Separators are tried in priority order: `::`, `\`, `->`, `.`. The text
/// after the last occurrence of the first separator present is returned. If
/// no separator is present, the trimmed query is returned.
///
/// Examples: `Sinatra::Base` → `Base`, `Client` → `Client`,
/// `a::b.C` → `b.C` (`::` outranks `.`).
fn extract_symbol_name(query: &str) -> String {
    let query = query.trim();
    ["::", "\\", "->", "."]
        .iter()
        .find_map(|separator| {
            query
                .rfind(separator)
                .map(|idx| &query[idx + separator.len()..])
        })
        .unwrap_or(query)
        .to_string()
}
240
241/// Return the boost amount for a chunk that defines one of `names`
242/// (0.0 if none match). Mirrors `_definition_tier`.
243fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
244    let any_match = names
245        .iter()
246        .any(|name| chunk_defines_symbol(&chunk.content, name));
247    if !any_match {
248        return 0.0;
249    }
250    let stem = Path::new(&chunk.file_path)
251        .file_stem()
252        .and_then(|s| s.to_str())
253        .unwrap_or_default()
254        .to_ascii_lowercase();
255    let stem_match_bonus = names
256        .iter()
257        .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
258    boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
259}
260
261/// Boost non-candidate chunks whose lowercased file stem satisfies
262/// `stem_ok`. Mirrors `_scan_non_candidates`.
263fn scan_non_candidates(
264    boosted: &mut HashMap<usize, f32>,
265    names: &HashSet<String>,
266    boost_unit: f32,
267    all_chunks: &[CodeChunk],
268    stem_ok: &dyn Fn(&str) -> bool,
269) {
270    for (idx, chunk) in all_chunks.iter().enumerate() {
271        if boosted.contains_key(&idx) {
272            continue;
273        }
274        let stem = Path::new(&chunk.file_path)
275            .file_stem()
276            .and_then(|s| s.to_str())
277            .unwrap_or_default()
278            .to_ascii_lowercase();
279        if !stem_ok(&stem) {
280            continue;
281        }
282        let tier = definition_tier(chunk, names, boost_unit);
283        if tier > 0.0 {
284            boosted.insert(idx, tier);
285        }
286    }
287}
288
289/// Symbol-query branch: boost chunks defining the queried symbol and
290/// stem-matched non-candidates. Mirrors `_boost_symbol_definitions`.
291fn boost_symbol_definitions(
292    boosted: &mut HashMap<usize, f32>,
293    query: &str,
294    max_score: f32,
295    all_chunks: &[CodeChunk],
296) {
297    let symbol_name = extract_symbol_name(query);
298    let trimmed_query = query.trim();
299    let mut names: HashSet<String> = HashSet::new();
300    names.insert(symbol_name.clone());
301    if symbol_name != trimmed_query {
302        names.insert(trimmed_query.to_string());
303    }
304
305    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
306
307    // Pass 1: walk current candidates, add a tier if they define the symbol.
308    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
309    for idx in candidate_indices {
310        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
311        if tier > 0.0 {
312            *boosted.entry(idx).or_insert(0.0) += tier;
313        }
314    }
315
316    // Pass 2: scan non-candidate chunks whose stem matches the symbol name.
317    let symbol_lower = symbol_name.to_ascii_lowercase();
318    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
319        stem_matches(stem, &symbol_lower)
320    });
321}
322
323/// NL-query branch: boost CamelCase / camelCase identifiers embedded in
324/// the query at half strength. Mirrors `_boost_embedded_symbols`.
325fn boost_embedded_symbols(
326    boosted: &mut HashMap<usize, f32>,
327    query: &str,
328    max_score: f32,
329    all_chunks: &[CodeChunk],
330) {
331    let names: HashSet<String> = embedded_symbol_re()
332        .find_iter(query)
333        .map(|m| m.as_str().to_string())
334        .collect();
335    if names.is_empty() {
336        return;
337    }
338
339    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
340
341    // Pass 1: candidates that define the embedded symbol(s).
342    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
343    for idx in candidate_indices {
344        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
345        if tier > 0.0 {
346            *boosted.entry(idx).or_insert(0.0) += tier;
347        }
348    }
349
350    // Pass 2: non-candidate stem-prefix scan.
351    let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
352    let symbols_lower_for_scan = symbols_lower.clone();
353    scan_non_candidates(
354        boosted,
355        &names,
356        boost_unit,
357        all_chunks,
358        &move |stem: &str| {
359            let stem_norm = stem.replace('_', "");
360            symbols_lower_for_scan.iter().any(|sym_lower| {
361                stem == sym_lower
362                    || stem_norm == *sym_lower
363                    || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
364                    || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
365                        && sym_lower.starts_with(stem_norm.as_str()))
366            })
367        },
368    );
369}
370
/// Count how many `keywords` match some path part. Mirrors
/// `_count_keyword_matches`.
///
/// A keyword matches a part on exact equality. Failing that, it matches
/// when the shorter of the two is at least 3 bytes long and is a prefix of
/// the longer. Each keyword is counted at most once.
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(*keyword)
                || parts.iter().any(|part| {
                    let (shorter, longer) = if keyword.len() <= part.len() {
                        (keyword.as_str(), part.as_str())
                    } else {
                        (part.as_str(), keyword.as_str())
                    };
                    shorter.len() >= 3 && longer.starts_with(shorter)
                })
        })
        .count()
}
398
399/// Boost chunks whose file paths match NL query keywords. Mirrors
400/// `_boost_stem_matches`. Uses prefix matching for morphological
401/// variants ("dependency" matches "dependencies"). Matches file stems
402/// and the immediate parent directory name.
403fn boost_stem_matches(
404    boosted: &mut HashMap<usize, f32>,
405    query: &str,
406    max_score: f32,
407    chunks: &[CodeChunk],
408) {
409    static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
410    let keyword_re =
411        KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
412    let keywords: HashSet<String> = keyword_re
413        .find_iter(query)
414        .map(|m| m.as_str().to_ascii_lowercase())
415        .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
416        .collect();
417    if keywords.is_empty() {
418        return;
419    }
420
421    let boost = max_score * STEM_BOOST_MULTIPLIER;
422    let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
423    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
424    for idx in candidate_indices {
425        let path = &chunks[idx].file_path;
426        let parts = path_cache
427            .entry(path.clone())
428            .or_insert_with(|| {
429                let mut parts: HashSet<String> = HashSet::new();
430                let p = Path::new(path);
431                if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
432                    parts.extend(split_identifier(stem));
433                }
434                if let Some(parent_name) = p
435                    .parent()
436                    .and_then(Path::file_name)
437                    .and_then(|s| s.to_str())
438                    && !parent_name.is_empty()
439                    && parent_name != "."
440                    && parent_name != ".."
441                {
442                    parts.extend(split_identifier(parent_name));
443                }
444                parts
445            })
446            .clone();
447        let n_matches = count_keyword_matches(&keywords, &parts);
448        if n_matches > 0 {
449            let match_ratio = n_matches as f32 / keywords.len() as f32;
450            if match_ratio >= 0.10 {
451                *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
452            }
453        }
454    }
455}
456
457// ---------------------------------------------------------------------------
458// Public entry: apply_query_boost.
459// ---------------------------------------------------------------------------
460
461/// Apply query-type boosts to candidate scores.
462///
463/// Mirrors `apply_query_boost`. Returns a new map; the input is not
464/// mutated. Empty input passes through.
465///
466/// Branches on [`is_symbol_query`]:
467/// - Symbol-shaped query → [`boost_symbol_definitions`] (×3 base,
468///   ×1.5 on stem match) and scan non-candidate stem-matched chunks.
469/// - NL query → [`boost_stem_matches`] (file-stem keyword overlap) +
470///   [`boost_embedded_symbols`] (half-strength CamelCase scan).
471#[expect(
472    clippy::implicit_hasher,
473    reason = "internal API; callers in the semble pipeline use the default RandomState"
474)]
475#[must_use]
476pub fn apply_query_boost(
477    combined_scores: &HashMap<usize, f32>,
478    query: &str,
479    all_chunks: &[CodeChunk],
480) -> HashMap<usize, f32> {
481    if combined_scores.is_empty() {
482        return HashMap::new();
483    }
484    let max_score = combined_scores
485        .values()
486        .copied()
487        .fold(f32::NEG_INFINITY, f32::max);
488    let mut boosted = combined_scores.clone();
489    if is_symbol_query(query) {
490        boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
491    } else {
492        boost_stem_matches(&mut boosted, query, max_score, all_chunks);
493        boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
494    }
495    boosted
496}
497
498/// Promote files with multiple high-scoring chunks by boosting their
499/// top chunk in place. Mirrors `boost_multi_chunk_files`.
500#[expect(
501    clippy::implicit_hasher,
502    reason = "internal API; callers in the semble pipeline use the default RandomState"
503)]
504pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
505    if scores.is_empty() {
506        return;
507    }
508    let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
509    if max_score == 0.0 || !max_score.is_finite() {
510        return;
511    }
512
513    let mut file_sum: HashMap<String, f32> = HashMap::new();
514    let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
515    for (&idx, &score) in scores.iter() {
516        let path = chunks[idx].file_path.clone();
517        *file_sum.entry(path.clone()).or_insert(0.0) += score;
518        match best_chunk_idx.get(&path) {
519            Some(&best) if scores[&best] >= score => {}
520            _ => {
521                best_chunk_idx.insert(path, idx);
522            }
523        }
524    }
525    let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
526    if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
527        return;
528    }
529    let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
530    for (path, &idx) in &best_chunk_idx {
531        let contribution = boost_unit * file_sum[path] / max_file_sum;
532        *scores.entry(idx).or_insert(0.0) += contribution;
533    }
534}
535
// Unit tests for alpha selection, symbol detection, definition matching and
// the individual boosting passes in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `CodeChunk` from just a path and content. The other fields get
    /// placeholder values, and `enriched_content` mirrors `content`.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    // ----- is_symbol_query (boosting.py:132) -----

    // Every supported namespace separator (`::`, `.`, `->`, `\`) marks a symbol.
    #[test]
    fn is_symbol_query_namespace() {
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    // Leading uppercase (including acronym-prefixed names) marks a symbol.
    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    // Plain lowercase words are natural language, not symbols.
    #[test]
    fn is_symbol_query_plain_word_rejected() {
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    // ----- resolve_alpha -----

    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    // An explicit weight bypasses query-shape detection entirely.
    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    // ----- definition_pattern + chunk_defines_symbol -----

    // The trailing `:` is one of the accepted terminators after the symbol.
    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n    pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n    pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        // Elixir defmodule with namespace prefix should match the bare symbol.
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    // SQL DDL keywords match in either case.
    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    // A mere usage (no definition keyword, different case) is not a definition.
    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // ----- boost_symbol_definitions stem multiplier -----

    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        // Two chunks both define `Client`; one is in client.rs (stem match),
        // the other in unrelated.rs. The stem-matched chunk gets the 1.5x
        // bonus on top of the base ×3 multiplier.
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        // base_boost = max_score (1.0) * 3.0 = 3.0
        // stem-matched chunk 0: 1.0 + 3.0 * 1.5 = 5.5
        // non-matched chunk 1: 1.0 + 3.0 = 4.0
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    // ----- boost_stem_matches -----

    #[test]
    fn boost_stem_matches_prefix() {
        // Query keyword "parse" should boost a chunk in "parser.rs" via
        // prefix overlap (shorter="parse", longer="parser", min-length 3
        // satisfied). This is the real morphological-variant shape;
        // Python semble's docstring example ("dependency" ↔ "dependencies")
        // is misleading because the two diverge at char 9.
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // ----- boost_embedded_symbols half-strength -----

    #[test]
    fn boost_embedded_symbols_half_strength() {
        // Embedded symbol "MyClass" in an NL query boosts at half strength
        // vs a pure symbol query.
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        // base_boost = 1.0 * 3.0 * 0.5 = 1.5
        // stem-matched chunk 0: 1.0 + 1.5 * 1.5 = 3.25
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    // ----- boost_multi_chunk_files (file coherence) -----

    // Named like the function under test, hence the explicit `super::` call.
    #[test]
    fn boost_multi_chunk_files() {
        // Three chunks: two in foo.rs (sum 1.5), one in bar.rs (sum 1.0).
        // boost_unit = max_score (1.0) * 0.2 = 0.2
        // foo.rs file_sum 1.5; boost to its TOP chunk (idx 0) =
        //   0.2 * 1.5 / 1.5 = 0.2 → final 1.2.
        // bar.rs file_sum 1.0; boost to chunk 2 = 0.2 * 1.0/1.5 ≈ 0.133 → ~1.133.
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    // ----- property test (symbol regex parity) -----

    // Table-driven regression check of `is_symbol_query` classification; the
    // name refers to the intended parity with the Python regex.
    #[test]
    fn property_symbol_regex_parity_python() {
        // These should be detected as symbol queries (matches Python).
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        // These should NOT be detected (NL).
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}