1use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
/// `alpha` returned by [`resolve_alpha`] when no explicit value is given and the
/// query looks like a code symbol (see [`is_symbol_query`]).
pub const ALPHA_SYMBOL: f32 = 0.3;
/// `alpha` returned by [`resolve_alpha`] when no explicit value is given and the
/// query does not look like a code symbol (natural-language queries).
pub const ALPHA_NL: f32 = 0.5;
37
38#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45 if let Some(w) = alpha {
46 return w;
47 }
48 if is_symbol_query(query) {
49 ALPHA_SYMBOL
50 } else {
51 ALPHA_NL
52 }
53}
54
55fn symbol_query_re() -> &'static Regex {
62 static RE: OnceLock<Regex> = OnceLock::new();
63 RE.get_or_init(|| {
64 Regex::new(concat!(
65 r"^(?:",
66 r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68 r"|_[A-Za-z0-9_]*",
70 r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72 r"|[A-Z][A-Za-z0-9]*",
74 r")$",
75 ))
76 .expect("symbol-query regex compiles")
77 })
78}
79
80fn embedded_symbol_re() -> &'static Regex {
83 static RE: OnceLock<Regex> = OnceLock::new();
84 RE.get_or_init(|| {
85 Regex::new(concat!(
86 r"\b(?:",
87 r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89 r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91 r")\b",
92 ))
93 .expect("embedded-symbol regex compiles")
94 })
95}
96
97#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102 symbol_query_re().is_match(query.trim())
103}
104
/// Returns `true` when `path` has a prose/documentation file extension
/// (Markdown, reStructuredText, plain text, AsciiDoc or Org), compared
/// ASCII-case-insensitively.
///
/// Only a real extension counts. A file with no extension whose whole name
/// happens to be `md`, `txt`, `org`, ... is *not* prose. The previous
/// `rsplit('.')` approach returned the entire string when there was no dot
/// and so misclassified such files.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            matches!(
                ext.to_ascii_lowercase().as_str(),
                "md" | "markdown" | "mdx" | "rst" | "txt" | "text" | "adoc" | "asciidoc" | "org"
            )
        })
}
123
/// Keywords that introduce a definition across common languages.
///
/// `definition_pattern_uncached` regex-escapes these and joins them into one
/// alternation. That alternation must be followed by whitespace and the
/// (optionally namespace-qualified) symbol name. The resulting regex is
/// case-sensitive. Multi-word entries such as `"abstract class"` are matched
/// with their literal single space.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class",
    "module",
    "defmodule",
    "def",
    "interface",
    "struct",
    "enum",
    "trait",
    "type",
    "func",
    "function",
    "object",
    "abstract class",
    "data class",
    "fn",
    "fun",
    "package",
    "namespace",
    "protocol",
    "record",
    "typedef",
];
154
/// SQL DDL prefixes that define a named object.
///
/// These are compiled into a separate regex built with `case_insensitive(true)`,
/// so `create table` and `CREATE TABLE` both match (as does the symbol name
/// itself, case-insensitively).
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE",
    "CREATE VIEW",
    "CREATE PROCEDURE",
    "CREATE FUNCTION",
];
163
/// Definition boost unit for symbol queries: `max_score * DEFINITION_BOOST_MULTIPLIER`
/// (`definition_tier` multiplies it by a further 1.5 when the file stem also matches).
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Path-keyword boost: `max_score * STEM_BOOST_MULTIPLIER * match_ratio`.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// Largest file-coherence boost, as a fraction of `max_score`. A file's best
/// chunk receives it scaled by the file's score sum over the largest file sum.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Fraction of the definition boost granted to identifiers embedded in
/// natural-language queries.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length for the "stem is a prefix of the symbol" match used
/// when scanning non-candidate files for embedded symbols.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
169
/// Lowercase words dropped when extracting path keywords from a query in
/// `boost_stem_matches`. Entries of two letters or fewer are already excluded
/// there by the `len() > 2` filter; they are listed here for completeness.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have",
    "how", "if", "in", "is", "it", "not", "of", "on", "or", "the", "to", "was", "what", "when",
    "where", "which", "who", "why", "with",
];
177
178fn definition_pattern_uncached(symbol_name: &str) -> (Regex, Regex) {
192 let escaped = regex::escape(symbol_name);
193 let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
194 let def_body = DEFINITION_KEYWORDS
199 .iter()
200 .map(|k| regex::escape(k))
201 .collect::<Vec<_>>()
202 .join("|");
203 let sql_body = SQL_DEFINITION_KEYWORDS
204 .iter()
205 .map(|k| regex::escape(k))
206 .collect::<Vec<_>>()
207 .join("|");
208 let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
209 let no_lookbehind_prefix = r"(?:^|\s)(?:";
215 let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
216 let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
217 let general = RegexBuilder::new(&general_pat)
218 .multi_line(true)
219 .build()
220 .expect("general definition regex compiles");
221 let sql = RegexBuilder::new(&sql_pat)
222 .multi_line(true)
223 .case_insensitive(true)
224 .build()
225 .expect("SQL definition regex compiles");
226 (general, sql)
227}
228
229fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
237 use std::sync::{Mutex, OnceLock};
238 static CACHE: OnceLock<Mutex<std::collections::HashMap<String, (Regex, Regex)>>> =
239 OnceLock::new();
240 let cache = CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()));
241
242 if let Ok(map) = cache.lock()
243 && let Some(entry) = map.get(symbol_name)
244 {
245 return entry.clone();
246 }
247
248 let pair = definition_pattern_uncached(symbol_name);
249 if let Ok(mut map) = cache.lock()
250 && map.len() < 256
251 {
252 map.insert(symbol_name.to_string(), pair.clone());
253 }
254 pair
255}
256
257fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
271 if !content.contains(symbol_name) {
272 return false;
273 }
274 let (general, sql) = definition_pattern(symbol_name);
275 general.is_match(content) || sql.is_match(content)
276}
277
278fn stem_matches(stem: &str, name: &str) -> bool {
281 let stem_norm = stem.replace('_', "");
282 stem == name
283 || stem_norm == name
284 || stem.trim_end_matches('s') == name
285 || stem_norm.trim_end_matches('s') == name
286}
287
/// Extracts the trailing symbol from a (possibly qualified) symbol query.
///
/// The query is trimmed first, matching `is_symbol_query`, which also trims.
/// Then everything after the *rightmost* `::`, `\`, `->` or `.` is returned, so
/// `Foo::Bar.baz` yields `baz` rather than `Bar.baz`. If there is no separator,
/// or nothing follows the last one, the trimmed query is returned whole.
fn extract_symbol_name(query: &str) -> String {
    let query = query.trim();
    let start = ["::", "\\", "->", "."]
        .into_iter()
        .filter_map(|sep| query.rfind(sep).map(|idx| idx + sep.len()))
        .max()
        .unwrap_or(0);
    // Separators are ASCII, so `start` is always on a char boundary.
    let tail = &query[start..];
    let name = if tail.is_empty() { query } else { tail };
    name.to_string()
}
304
305fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
308 let any_match = names
309 .iter()
310 .any(|name| chunk_defines_symbol(&chunk.content, name));
311 if !any_match {
312 return 0.0;
313 }
314 let stem = Path::new(&chunk.file_path)
315 .file_stem()
316 .and_then(|s| s.to_str())
317 .unwrap_or_default()
318 .to_ascii_lowercase();
319 let stem_match_bonus = names
320 .iter()
321 .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
322 boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
323}
324
325fn scan_non_candidates(
328 boosted: &mut HashMap<usize, f32>,
329 names: &HashSet<String>,
330 boost_unit: f32,
331 all_chunks: &[CodeChunk],
332 stem_ok: &dyn Fn(&str) -> bool,
333) {
334 for (idx, chunk) in all_chunks.iter().enumerate() {
335 if boosted.contains_key(&idx) {
336 continue;
337 }
338 let stem = Path::new(&chunk.file_path)
339 .file_stem()
340 .and_then(|s| s.to_str())
341 .unwrap_or_default()
342 .to_ascii_lowercase();
343 if !stem_ok(&stem) {
344 continue;
345 }
346 let tier = definition_tier(chunk, names, boost_unit);
347 if tier > 0.0 {
348 boosted.insert(idx, tier);
349 }
350 }
351}
352
353fn boost_symbol_definitions(
356 boosted: &mut HashMap<usize, f32>,
357 query: &str,
358 max_score: f32,
359 all_chunks: &[CodeChunk],
360) {
361 let symbol_name = extract_symbol_name(query);
362 let trimmed_query = query.trim();
363 let mut names: HashSet<String> = HashSet::new();
364 names.insert(symbol_name.clone());
365 if symbol_name != trimmed_query {
366 names.insert(trimmed_query.to_string());
367 }
368
369 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
370
371 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
373 for idx in candidate_indices {
374 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
375 if tier > 0.0 {
376 *boosted.entry(idx).or_insert(0.0) += tier;
377 }
378 }
379
380 let symbol_lower = symbol_name.to_ascii_lowercase();
382 scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
383 stem_matches(stem, &symbol_lower)
384 });
385}
386
387fn boost_embedded_symbols(
390 boosted: &mut HashMap<usize, f32>,
391 query: &str,
392 max_score: f32,
393 all_chunks: &[CodeChunk],
394) {
395 let names: HashSet<String> = embedded_symbol_re()
396 .find_iter(query)
397 .map(|m| m.as_str().to_string())
398 .collect();
399 if names.is_empty() {
400 return;
401 }
402
403 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
404
405 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
407 for idx in candidate_indices {
408 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
409 if tier > 0.0 {
410 *boosted.entry(idx).or_insert(0.0) += tier;
411 }
412 }
413
414 let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
416 let symbols_lower_for_scan = symbols_lower.clone();
417 scan_non_candidates(
418 boosted,
419 &names,
420 boost_unit,
421 all_chunks,
422 &move |stem: &str| {
423 let stem_norm = stem.replace('_', "");
424 symbols_lower_for_scan.iter().any(|sym_lower| {
425 stem == sym_lower
426 || stem_norm == *sym_lower
427 || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
428 || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
429 && sym_lower.starts_with(stem_norm.as_str()))
430 })
431 },
432 );
433}
434
/// Counts how many `keywords` relate to at least one path `part`.
///
/// A keyword counts once if it equals a part, or if the shorter of the two
/// (keyword vs. part) is at least 3 characters and is a prefix of the longer
/// (so `parse` ~ `parser` in either direction).
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    let shares_prefix = |a: &str, b: &str| -> bool {
        let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        short.len() >= 3 && long.starts_with(short)
    };
    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(keyword.as_str())
                || parts
                    .iter()
                    .any(|part| shares_prefix(keyword.as_str(), part.as_str()))
        })
        .count()
}
462
463fn boost_stem_matches(
468 boosted: &mut HashMap<usize, f32>,
469 query: &str,
470 max_score: f32,
471 chunks: &[CodeChunk],
472) {
473 static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
474 let keyword_re =
475 KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
476 let keywords: HashSet<String> = keyword_re
477 .find_iter(query)
478 .map(|m| m.as_str().to_ascii_lowercase())
479 .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
480 .collect();
481 if keywords.is_empty() {
482 return;
483 }
484
485 let boost = max_score * STEM_BOOST_MULTIPLIER;
486 let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
487 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
488 for idx in candidate_indices {
489 let path = &chunks[idx].file_path;
490 let parts = path_cache
491 .entry(path.clone())
492 .or_insert_with(|| {
493 let mut parts: HashSet<String> = HashSet::new();
494 let p = Path::new(path);
495 if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
496 parts.extend(split_identifier(stem));
497 }
498 if let Some(parent_name) = p
499 .parent()
500 .and_then(Path::file_name)
501 .and_then(|s| s.to_str())
502 && !parent_name.is_empty()
503 && parent_name != "."
504 && parent_name != ".."
505 {
506 parts.extend(split_identifier(parent_name));
507 }
508 parts
509 })
510 .clone();
511 let n_matches = count_keyword_matches(&keywords, &parts);
512 if n_matches > 0 {
513 let match_ratio = n_matches as f32 / keywords.len() as f32;
514 if match_ratio >= 0.10 {
515 *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
516 }
517 }
518 }
519}
520
521#[expect(
536 clippy::implicit_hasher,
537 reason = "internal API; callers in the semble pipeline use the default RandomState"
538)]
539#[must_use]
540pub fn apply_query_boost(
541 combined_scores: &HashMap<usize, f32>,
542 query: &str,
543 all_chunks: &[CodeChunk],
544) -> HashMap<usize, f32> {
545 if combined_scores.is_empty() {
546 return HashMap::new();
547 }
548 let max_score = combined_scores
549 .values()
550 .copied()
551 .fold(f32::NEG_INFINITY, f32::max);
552 let mut boosted = combined_scores.clone();
553 if is_symbol_query(query) {
554 boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
555 } else {
556 boost_stem_matches(&mut boosted, query, max_score, all_chunks);
557 boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
558 }
559 boosted
560}
561
562#[expect(
565 clippy::implicit_hasher,
566 reason = "internal API; callers in the semble pipeline use the default RandomState"
567)]
568pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
569 if scores.is_empty() {
570 return;
571 }
572 let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
573 if max_score == 0.0 || !max_score.is_finite() {
574 return;
575 }
576
577 let mut file_sum: HashMap<String, f32> = HashMap::new();
578 let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
579 for (&idx, &score) in scores.iter() {
580 let path = chunks[idx].file_path.clone();
581 *file_sum.entry(path.clone()).or_insert(0.0) += score;
582 match best_chunk_idx.get(&path) {
583 Some(&best) if scores[&best] >= score => {}
584 _ => {
585 best_chunk_idx.insert(path, idx);
586 }
587 }
588 }
589 let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
590 if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
591 return;
592 }
593 let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
594 for (path, &idx) in &best_chunk_idx {
595 let contribution = boost_unit * file_sum[path] / max_file_sum;
596 *scores.entry(idx).or_insert(0.0) += contribution;
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CodeChunk` with the given path and content. The functions
    /// exercised here read only `file_path` and `content`; the other fields get
    /// placeholder values.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    // Qualified paths with each supported separator: `::`, `.`, `->`, `\`.
    #[test]
    fn is_symbol_query_namespace() {
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    // All-lowercase single words are treated as natural language, not symbols.
    #[test]
    fn is_symbol_query_plain_word_rejected() {
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    // With no explicit alpha, symbol-shaped queries resolve to ALPHA_SYMBOL.
    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    // `:` after the name satisfies the trailing-delimiter group of the definition regex.
    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    // `Phoenix.` is consumed by the optional namespace-prefix group.
    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    // SQL keywords match regardless of case.
    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    // A usage (method call on a lowercase variable) is not a definition.
    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // boost_unit = 1.0 * DEFINITION_BOOST_MULTIPLIER (3.0).
    // client.rs also matches the stem, so it gets 1.5x: 1.0 + 4.5 = 5.5.
    // unrelated.rs gets the plain unit: 1.0 + 3.0 = 4.0.
    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    // Keyword `parse` is a >=3-char prefix of the path part `parser`
    // (assumes split_identifier keeps `parser` as a single part — confirm).
    #[test]
    fn boost_stem_matches_prefix() {
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // boost_unit = 1.0 * 3.0 * 0.5 = 1.5.
    // The stem `myclass` matches `MyClass`, giving 1.5x: 1.0 + 2.25 = 3.25.
    #[test]
    fn boost_embedded_symbols_half_strength() {
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    // max_score = 1.0, so boost_unit = 0.2.
    // foo.rs sums to 1.5 (the largest file sum): its best chunk (0) gets +0.2.
    // bar.rs sums to 1.0: its chunk gets 0.2 * 1.0 / 1.5.
    #[test]
    fn boost_multi_chunk_files() {
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    // Pins symbol/NL classification on a fixed list of queries, covering every
    // regex alternative.
    #[test]
    fn property_symbol_regex_parity_python() {
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}
813}