1use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
/// Blend weight that [`resolve_alpha`] falls back to when
/// [`is_symbol_query`] classifies the query as a code symbol.
pub const ALPHA_SYMBOL: f32 = 0.3;
/// Blend weight that [`resolve_alpha`] falls back to for every other
/// (natural-language) query.
pub const ALPHA_NL: f32 = 0.5;
37
38#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45 if let Some(w) = alpha {
46 return w;
47 }
48 if is_symbol_query(query) {
49 ALPHA_SYMBOL
50 } else {
51 ALPHA_NL
52 }
53}
54
55fn symbol_query_re() -> &'static Regex {
62 static RE: OnceLock<Regex> = OnceLock::new();
63 RE.get_or_init(|| {
64 Regex::new(concat!(
65 r"^(?:",
66 r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68 r"|_[A-Za-z0-9_]*",
70 r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72 r"|[A-Z][A-Za-z0-9]*",
74 r")$",
75 ))
76 .expect("symbol-query regex compiles")
77 })
78}
79
80fn embedded_symbol_re() -> &'static Regex {
83 static RE: OnceLock<Regex> = OnceLock::new();
84 RE.get_or_init(|| {
85 Regex::new(concat!(
86 r"\b(?:",
87 r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89 r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91 r")\b",
92 ))
93 .expect("embedded-symbol regex compiles")
94 })
95}
96
97#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102 symbol_query_re().is_match(query.trim())
103}
104
/// Definition keywords recognised by `definition_pattern`.
///
/// A chunk is treated as defining a symbol when one of these keywords appears
/// at the start of a line or after whitespace, followed by whitespace and the
/// (optionally namespace-qualified) symbol name. Matching is case-sensitive.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class",
    "module",
    "defmodule",
    "def",
    "interface",
    "struct",
    "enum",
    "trait",
    "type",
    "func",
    "function",
    "object",
    "abstract class",
    "data class",
    "fn",
    "fun",
    "package",
    "namespace",
    "protocol",
    "record",
    "typedef",
];
135
/// SQL DDL prefixes recognised by `definition_pattern`; unlike
/// `DEFINITION_KEYWORDS` they are matched case-insensitively.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE",
    "CREATE VIEW",
    "CREATE PROCEDURE",
    "CREATE FUNCTION",
];
144
/// Multiple of the top candidate score used as the boost unit for chunks
/// that define the queried symbol.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Multiple of the top candidate score used for path-keyword (stem) boosts,
/// before scaling by the fraction of matched keywords.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// Fraction of the top score added to each file's best chunk, scaled by that
/// file's summed score relative to the largest per-file sum.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Extra factor applied to the definition boost unit when the symbols were
/// found embedded in a natural-language query.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length for the stem-is-prefix-of-symbol test used when
/// scanning non-candidate chunks for embedded symbols.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
150
/// Lower-case English function words that are discarded from a
/// natural-language query before its words are compared with file-path parts.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have",
    "how", "if", "in", "is", "it", "not", "of", "on", "or", "the", "to", "was", "what", "when",
    "where", "which", "who", "why", "with",
];
158
159fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
166 let escaped = regex::escape(symbol_name);
167 let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
168 let def_body = DEFINITION_KEYWORDS
173 .iter()
174 .map(|k| regex::escape(k))
175 .collect::<Vec<_>>()
176 .join("|");
177 let sql_body = SQL_DEFINITION_KEYWORDS
178 .iter()
179 .map(|k| regex::escape(k))
180 .collect::<Vec<_>>()
181 .join("|");
182 let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
183 let no_lookbehind_prefix = r"(?:^|\s)(?:";
189 let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
190 let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
191 let general = RegexBuilder::new(&general_pat)
192 .multi_line(true)
193 .build()
194 .expect("general definition regex compiles");
195 let sql = RegexBuilder::new(&sql_pat)
196 .multi_line(true)
197 .case_insensitive(true)
198 .build()
199 .expect("SQL definition regex compiles");
200 (general, sql)
201}
202
203fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
210 let (general, sql) = definition_pattern(symbol_name);
211 general.is_match(content) || sql.is_match(content)
212}
213
214fn stem_matches(stem: &str, name: &str) -> bool {
217 let stem_norm = stem.replace('_', "");
218 stem == name
219 || stem_norm == name
220 || stem.trim_end_matches('s') == name
221 || stem_norm.trim_end_matches('s') == name
222}
223
/// Extracts the bare symbol name from a possibly qualified query.
///
/// Separators are tried in priority order `::`, `\`, `->`, `.`. For the first
/// separator present, the text after its last occurrence is returned
/// (`"Sinatra::Base"` → `"Base"`). When no separator is present, the query
/// itself is returned.
///
/// Surrounding whitespace is removed before splitting, so `"foo::Bar "` yields
/// `"Bar"`. A trailing space would otherwise end up in the escaped definition
/// regex and prevent any match.
fn extract_symbol_name(query: &str) -> String {
    let query = query.trim();
    ["::", "\\", "->", "."]
        .into_iter()
        .find_map(|sep| query.rfind(sep).map(|idx| &query[idx + sep.len()..]))
        .unwrap_or(query)
        .to_string()
}
240
241fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
244 let any_match = names
245 .iter()
246 .any(|name| chunk_defines_symbol(&chunk.content, name));
247 if !any_match {
248 return 0.0;
249 }
250 let stem = Path::new(&chunk.file_path)
251 .file_stem()
252 .and_then(|s| s.to_str())
253 .unwrap_or_default()
254 .to_ascii_lowercase();
255 let stem_match_bonus = names
256 .iter()
257 .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
258 boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
259}
260
261fn scan_non_candidates(
264 boosted: &mut HashMap<usize, f32>,
265 names: &HashSet<String>,
266 boost_unit: f32,
267 all_chunks: &[CodeChunk],
268 stem_ok: &dyn Fn(&str) -> bool,
269) {
270 for (idx, chunk) in all_chunks.iter().enumerate() {
271 if boosted.contains_key(&idx) {
272 continue;
273 }
274 let stem = Path::new(&chunk.file_path)
275 .file_stem()
276 .and_then(|s| s.to_str())
277 .unwrap_or_default()
278 .to_ascii_lowercase();
279 if !stem_ok(&stem) {
280 continue;
281 }
282 let tier = definition_tier(chunk, names, boost_unit);
283 if tier > 0.0 {
284 boosted.insert(idx, tier);
285 }
286 }
287}
288
289fn boost_symbol_definitions(
292 boosted: &mut HashMap<usize, f32>,
293 query: &str,
294 max_score: f32,
295 all_chunks: &[CodeChunk],
296) {
297 let symbol_name = extract_symbol_name(query);
298 let trimmed_query = query.trim();
299 let mut names: HashSet<String> = HashSet::new();
300 names.insert(symbol_name.clone());
301 if symbol_name != trimmed_query {
302 names.insert(trimmed_query.to_string());
303 }
304
305 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
306
307 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
309 for idx in candidate_indices {
310 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
311 if tier > 0.0 {
312 *boosted.entry(idx).or_insert(0.0) += tier;
313 }
314 }
315
316 let symbol_lower = symbol_name.to_ascii_lowercase();
318 scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
319 stem_matches(stem, &symbol_lower)
320 });
321}
322
323fn boost_embedded_symbols(
326 boosted: &mut HashMap<usize, f32>,
327 query: &str,
328 max_score: f32,
329 all_chunks: &[CodeChunk],
330) {
331 let names: HashSet<String> = embedded_symbol_re()
332 .find_iter(query)
333 .map(|m| m.as_str().to_string())
334 .collect();
335 if names.is_empty() {
336 return;
337 }
338
339 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
340
341 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
343 for idx in candidate_indices {
344 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
345 if tier > 0.0 {
346 *boosted.entry(idx).or_insert(0.0) += tier;
347 }
348 }
349
350 let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
352 let symbols_lower_for_scan = symbols_lower.clone();
353 scan_non_candidates(
354 boosted,
355 &names,
356 boost_unit,
357 all_chunks,
358 &move |stem: &str| {
359 let stem_norm = stem.replace('_', "");
360 symbols_lower_for_scan.iter().any(|sym_lower| {
361 stem == sym_lower
362 || stem_norm == *sym_lower
363 || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
364 || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
365 && sym_lower.starts_with(stem_norm.as_str()))
366 })
367 },
368 );
369}
370
/// Counts how many `keywords` match at least one entry of `parts`.
///
/// A keyword matches a part when the two are equal, or when the shorter of
/// the two is at least three characters long and is a prefix of the longer
/// one (so `parse` matches `parser`, and `config` matches `conf`).
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    // True when the shorter string (>= 3 chars) is a prefix of the longer one.
    fn prefix_related(a: &str, b: &str) -> bool {
        let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        shorter.len() >= 3 && longer.starts_with(shorter)
    }

    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(keyword.as_str())
                || parts
                    .iter()
                    .any(|part| prefix_related(keyword.as_str(), part.as_str()))
        })
        .count()
}
398
399fn boost_stem_matches(
404 boosted: &mut HashMap<usize, f32>,
405 query: &str,
406 max_score: f32,
407 chunks: &[CodeChunk],
408) {
409 static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
410 let keyword_re =
411 KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
412 let keywords: HashSet<String> = keyword_re
413 .find_iter(query)
414 .map(|m| m.as_str().to_ascii_lowercase())
415 .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
416 .collect();
417 if keywords.is_empty() {
418 return;
419 }
420
421 let boost = max_score * STEM_BOOST_MULTIPLIER;
422 let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
423 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
424 for idx in candidate_indices {
425 let path = &chunks[idx].file_path;
426 let parts = path_cache
427 .entry(path.clone())
428 .or_insert_with(|| {
429 let mut parts: HashSet<String> = HashSet::new();
430 let p = Path::new(path);
431 if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
432 parts.extend(split_identifier(stem));
433 }
434 if let Some(parent_name) = p
435 .parent()
436 .and_then(Path::file_name)
437 .and_then(|s| s.to_str())
438 && !parent_name.is_empty()
439 && parent_name != "."
440 && parent_name != ".."
441 {
442 parts.extend(split_identifier(parent_name));
443 }
444 parts
445 })
446 .clone();
447 let n_matches = count_keyword_matches(&keywords, &parts);
448 if n_matches > 0 {
449 let match_ratio = n_matches as f32 / keywords.len() as f32;
450 if match_ratio >= 0.10 {
451 *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
452 }
453 }
454 }
455}
456
457#[expect(
472 clippy::implicit_hasher,
473 reason = "internal API; callers in the semble pipeline use the default RandomState"
474)]
475#[must_use]
476pub fn apply_query_boost(
477 combined_scores: &HashMap<usize, f32>,
478 query: &str,
479 all_chunks: &[CodeChunk],
480) -> HashMap<usize, f32> {
481 if combined_scores.is_empty() {
482 return HashMap::new();
483 }
484 let max_score = combined_scores
485 .values()
486 .copied()
487 .fold(f32::NEG_INFINITY, f32::max);
488 let mut boosted = combined_scores.clone();
489 if is_symbol_query(query) {
490 boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
491 } else {
492 boost_stem_matches(&mut boosted, query, max_score, all_chunks);
493 boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
494 }
495 boosted
496}
497
498#[expect(
501 clippy::implicit_hasher,
502 reason = "internal API; callers in the semble pipeline use the default RandomState"
503)]
504pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
505 if scores.is_empty() {
506 return;
507 }
508 let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
509 if max_score == 0.0 || !max_score.is_finite() {
510 return;
511 }
512
513 let mut file_sum: HashMap<String, f32> = HashMap::new();
514 let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
515 for (&idx, &score) in scores.iter() {
516 let path = chunks[idx].file_path.clone();
517 *file_sum.entry(path.clone()).or_insert(0.0) += score;
518 match best_chunk_idx.get(&path) {
519 Some(&best) if scores[&best] >= score => {}
520 _ => {
521 best_chunk_idx.insert(path, idx);
522 }
523 }
524 }
525 let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
526 if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
527 return;
528 }
529 let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
530 for (path, &idx) in &best_chunk_idx {
531 let contribution = boost_unit * file_sum[path] / max_file_sum;
532 *scores.entry(idx).or_insert(0.0) += contribution;
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `CodeChunk` for these tests: only `file_path` and the
    /// content fields vary; name/kind are empty and the line span is `1..=1`.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// Queries qualified with `::`, `.`, `->` or `\` are classified as symbols.
    #[test]
    fn is_symbol_query_namespace() {
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    /// PascalCase words, including all-caps acronym prefixes, are symbols.
    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    /// A single lower-case word with no `_` or capitals is natural language.
    #[test]
    fn is_symbol_query_plain_word_rejected() {
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    /// Identifiers starting with `_` are symbols.
    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    /// Without an override, symbol queries resolve to `ALPHA_SYMBOL`.
    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    /// Without an override, natural-language queries resolve to `ALPHA_NL`.
    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    /// An explicit alpha is returned as-is, even for a symbol query.
    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    /// `class Name:` at the start of the content is a definition.
    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    /// `def name(` is a definition; `(` is an accepted terminator after the name.
    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    /// The `Phoenix.` qualifier is absorbed by the namespace-prefix part of the pattern.
    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    /// SQL definitions match regardless of keyword case.
    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    /// A mere usage (no definition keyword) is not a definition.
    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    /// With `max_score = 1.0` the boost unit is `1.0 * DEFINITION_BOOST_MULTIPLIER = 3.0`.
    /// Chunk 0's stem matches the symbol, so it gets `1.0 + 3.0 * 1.5 = 5.5`;
    /// chunk 1's stem does not, so it gets `1.0 + 3.0 = 4.0`.
    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    /// The keyword `parse` is a >= 3-char prefix of the path part for `parser`
    /// (assuming `split_identifier("parser")` yields `parser`), so the
    /// candidate's score must increase.
    #[test]
    fn boost_stem_matches_prefix() {
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    /// Embedded-symbol boost unit is `1.0 * 3.0 * 0.5 = 1.5`; the `myclass`
    /// stem matches, giving `1.5 * 1.5 = 2.25`, so the total is `1.0 + 2.25 = 3.25`.
    #[test]
    fn boost_embedded_symbols_half_strength() {
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    /// File sums are foo = 1.5 and bar = 1.0; the boost unit is `1.0 * 0.2`.
    /// foo's best chunk gets `0.2 * 1.5 / 1.5`, bar's gets `0.2 * 1.0 / 1.5`,
    /// and the non-best chunk of foo is untouched.
    #[test]
    fn boost_multi_chunk_files() {
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    /// Fixed lists of queries that must (and must not) be classified as
    /// symbols; the test name indicates these mirror a Python implementation.
    #[test]
    fn property_symbol_regex_parity_python() {
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}
749}