1use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
29pub const ALPHA_SYMBOL: f32 = 0.3;
35pub const ALPHA_NL: f32 = 0.5;
37
38#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45 if let Some(w) = alpha {
46 return w;
47 }
48 if is_symbol_query(query) {
49 ALPHA_SYMBOL
50 } else {
51 ALPHA_NL
52 }
53}
54
55fn symbol_query_re() -> &'static Regex {
62 static RE: OnceLock<Regex> = OnceLock::new();
63 RE.get_or_init(|| {
64 Regex::new(concat!(
65 r"^(?:",
66 r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68 r"|_[A-Za-z0-9_]*",
70 r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72 r"|[A-Z][A-Za-z0-9]*",
74 r")$",
75 ))
76 .expect("symbol-query regex compiles")
77 })
78}
79
80fn embedded_symbol_re() -> &'static Regex {
83 static RE: OnceLock<Regex> = OnceLock::new();
84 RE.get_or_init(|| {
85 Regex::new(concat!(
86 r"\b(?:",
87 r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89 r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91 r")\b",
92 ))
93 .expect("embedded-symbol regex compiles")
94 })
95}
96
97#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102 symbol_query_re().is_match(query.trim())
103}
104
/// Returns `true` when `path` has a prose/documentation file extension
/// (`md`, `markdown`, `mdx`, `rst`, `txt`, `text`, `adoc`, `asciidoc`, `org`),
/// compared ASCII case-insensitively.
///
/// The extension is read from the final path component with
/// [`Path::extension`]. A path without any extension is therefore never
/// classified as prose, even when its whole name coincides with one of the
/// extensions above. Examples are a bare file named `txt` or `org`, or a
/// dotfile such as `.md`.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    let Some(ext) = Path::new(path).extension().and_then(|e| e.to_str()) else {
        return false;
    };
    let ext = ext.to_ascii_lowercase();
    matches!(
        ext.as_str(),
        "md" | "markdown" | "mdx" | "rst" | "txt" | "text" | "adoc" | "asciidoc" | "org"
    )
}
123
/// Keywords that introduce a named definition, used to build the
/// case-sensitive "general" definition regex in `definition_pattern_uncached`.
///
/// Each entry is regex-escaped and joined into one alternation. Multi-word
/// entries are matched literally, including the single space between words.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class", "module", "defmodule", "def", "interface", "struct", "enum",
    "trait", "type", "func", "function", "object",
    // Two-word forms.
    "abstract class", "data class",
    "fn", "fun", "package", "namespace", "protocol", "record", "typedef",
];
154
/// SQL statement prefixes that define a named object. They are compiled into
/// a separate regex that `definition_pattern_uncached` builds with case
/// insensitivity enabled.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE", "CREATE VIEW", "CREATE PROCEDURE", "CREATE FUNCTION",
];
163
/// Definition boost unit, as a multiple of the highest incoming score.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Path-keyword boost, as a multiple of the highest incoming score
/// (further scaled by the fraction of query keywords matched).
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// File-coherence bonus, as a fraction of the highest incoming score.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Extra scale applied to the definition boost for symbols that are only
/// embedded in a natural-language query.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length before a stem may match an embedded symbol by prefix.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
169
/// Lowercase function words ignored when extracting path keywords from a
/// natural-language query.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for",
    "from", "has", "have", "how", "if", "in", "is", "it", "not", "of", "on",
    "or", "the", "to", "was", "what", "when", "where", "which", "who", "why",
    "with",
];
177
178fn definition_pattern_uncached(symbol_name: &str) -> (Regex, Regex) {
192 let escaped = regex::escape(symbol_name);
193 let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
194 let def_body = DEFINITION_KEYWORDS
199 .iter()
200 .map(|k| regex::escape(k))
201 .collect::<Vec<_>>()
202 .join("|");
203 let sql_body = SQL_DEFINITION_KEYWORDS
204 .iter()
205 .map(|k| regex::escape(k))
206 .collect::<Vec<_>>()
207 .join("|");
208 let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
209 let no_lookbehind_prefix = r"(?:^|\s)(?:";
215 let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
216 let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
217 let general = RegexBuilder::new(&general_pat)
218 .multi_line(true)
219 .build()
220 .expect("general definition regex compiles");
221 let sql = RegexBuilder::new(&sql_pat)
222 .multi_line(true)
223 .case_insensitive(true)
224 .build()
225 .expect("SQL definition regex compiles");
226 (general, sql)
227}
228
229fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
237 use std::sync::{Mutex, OnceLock};
238 static CACHE: OnceLock<Mutex<std::collections::HashMap<String, (Regex, Regex)>>> =
239 OnceLock::new();
240 let cache = CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()));
241
242 if let Ok(map) = cache.lock()
243 && let Some(entry) = map.get(symbol_name)
244 {
245 return entry.clone();
246 }
247
248 let pair = definition_pattern_uncached(symbol_name);
249 if let Ok(mut map) = cache.lock()
250 && map.len() < 256
251 {
252 map.insert(symbol_name.to_string(), pair.clone());
253 }
254 pair
255}
256
257fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
271 if !content.contains(symbol_name) {
272 return false;
273 }
274 let (general, sql) = definition_pattern(symbol_name);
275 general.is_match(content) || sql.is_match(content)
276}
277
278fn stem_matches(stem: &str, name: &str) -> bool {
281 let stem_norm = stem.replace('_', "");
282 stem == name
283 || stem_norm == name
284 || stem.trim_end_matches('s') == name
285 || stem_norm.trim_end_matches('s') == name
286}
287
/// Returns the final identifier of a possibly-qualified symbol query, e.g.
/// `"Sinatra::Base"` → `"Base"`, `r"App\Http\Kernel"` → `"Kernel"`,
/// `"a->b->c"` → `"c"`, `"foo.bar"` → `"bar"`.
///
/// Surrounding whitespace is ignored, so `"Foo::Bar\n"` still yields `"Bar"`.
/// When separators are mixed (`"Foo::bar.baz"`), the split happens after
/// whichever separator occurs last in the string. If the query has no
/// separator, or ends in one, the whole trimmed query is returned.
fn extract_symbol_name(query: &str) -> String {
    const SEPARATORS: [&str; 4] = ["::", "\\", "->", "."];
    let query = query.trim();
    // Byte offset just past the right-most separator occurrence, if any.
    let tail_start = SEPARATORS
        .iter()
        .filter_map(|sep| query.rfind(sep).map(|idx| idx + sep.len()))
        .max();
    match tail_start {
        Some(start) if start < query.len() => query[start..].to_string(),
        _ => query.to_string(),
    }
}
304
305fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
308 let any_match = names
309 .iter()
310 .any(|name| chunk_defines_symbol(&chunk.content, name));
311 if !any_match {
312 return 0.0;
313 }
314 let stem = Path::new(&chunk.file_path)
315 .file_stem()
316 .and_then(|s| s.to_str())
317 .unwrap_or_default()
318 .to_ascii_lowercase();
319 let stem_match_bonus = names
320 .iter()
321 .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
322 boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
323}
324
325fn scan_non_candidates(
328 boosted: &mut HashMap<usize, f32>,
329 names: &HashSet<String>,
330 boost_unit: f32,
331 all_chunks: &[CodeChunk],
332 stem_ok: &dyn Fn(&str) -> bool,
333) {
334 for (idx, chunk) in all_chunks.iter().enumerate() {
335 if boosted.contains_key(&idx) {
336 continue;
337 }
338 let stem = Path::new(&chunk.file_path)
339 .file_stem()
340 .and_then(|s| s.to_str())
341 .unwrap_or_default()
342 .to_ascii_lowercase();
343 if !stem_ok(&stem) {
344 continue;
345 }
346 let tier = definition_tier(chunk, names, boost_unit);
347 if tier > 0.0 {
348 boosted.insert(idx, tier);
349 }
350 }
351}
352
353fn boost_symbol_definitions(
356 boosted: &mut HashMap<usize, f32>,
357 query: &str,
358 max_score: f32,
359 all_chunks: &[CodeChunk],
360) {
361 let symbol_name = extract_symbol_name(query);
362 let trimmed_query = query.trim();
363 let mut names: HashSet<String> = HashSet::new();
364 names.insert(symbol_name.clone());
365 if symbol_name != trimmed_query {
366 names.insert(trimmed_query.to_string());
367 }
368
369 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
370
371 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
373 for idx in candidate_indices {
374 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
375 if tier > 0.0 {
376 *boosted.entry(idx).or_insert(0.0) += tier;
377 }
378 }
379
380 let symbol_lower = symbol_name.to_ascii_lowercase();
382 scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
383 stem_matches(stem, &symbol_lower)
384 });
385}
386
387fn boost_embedded_symbols(
390 boosted: &mut HashMap<usize, f32>,
391 query: &str,
392 max_score: f32,
393 all_chunks: &[CodeChunk],
394) {
395 let names: HashSet<String> = embedded_symbol_re()
396 .find_iter(query)
397 .map(|m| m.as_str().to_string())
398 .collect();
399 if names.is_empty() {
400 return;
401 }
402
403 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
404
405 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
407 for idx in candidate_indices {
408 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
409 if tier > 0.0 {
410 *boosted.entry(idx).or_insert(0.0) += tier;
411 }
412 }
413
414 let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
416 let symbols_lower_for_scan = symbols_lower.clone();
417 scan_non_candidates(
418 boosted,
419 &names,
420 boost_unit,
421 all_chunks,
422 &move |stem: &str| {
423 let stem_norm = stem.replace('_', "");
424 symbols_lower_for_scan.iter().any(|sym_lower| {
425 stem == sym_lower
426 || stem_norm == *sym_lower
427 || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
428 || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
429 && sym_lower.starts_with(stem_norm.as_str()))
430 })
431 },
432 );
433}
434
/// Counts how many `keywords` are matched by the identifier `parts` of a path.
///
/// A keyword counts at most once. It matches when it appears verbatim in
/// `parts`, or when some part is prefix-related to it: the shorter of the
/// two is at least three bytes long and the longer one starts with it
/// (`parse` ~ `parser`).
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    let prefix_related = |keyword: &str, part: &str| -> bool {
        let (shorter, longer) = if keyword.len() <= part.len() {
            (keyword, part)
        } else {
            (part, keyword)
        };
        shorter.len() >= 3 && longer.starts_with(shorter)
    };

    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(*keyword)
                || parts
                    .iter()
                    .any(|part| prefix_related(keyword.as_str(), part.as_str()))
        })
        .count()
}
462
463fn boost_stem_matches(
468 boosted: &mut HashMap<usize, f32>,
469 query: &str,
470 max_score: f32,
471 chunks: &[CodeChunk],
472) {
473 static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
474 let keyword_re =
475 KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
476 let keywords: HashSet<String> = keyword_re
477 .find_iter(query)
478 .map(|m| m.as_str().to_ascii_lowercase())
479 .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
480 .collect();
481 if keywords.is_empty() {
482 return;
483 }
484
485 let boost = max_score * STEM_BOOST_MULTIPLIER;
486 let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
487 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
488 for idx in candidate_indices {
489 let path = &chunks[idx].file_path;
490 let parts = path_cache
491 .entry(path.clone())
492 .or_insert_with(|| {
493 let mut parts: HashSet<String> = HashSet::new();
494 let p = Path::new(path);
495 if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
496 parts.extend(split_identifier(stem));
497 }
498 if let Some(parent_name) = p
499 .parent()
500 .and_then(Path::file_name)
501 .and_then(|s| s.to_str())
502 && !parent_name.is_empty()
503 && parent_name != "."
504 && parent_name != ".."
505 {
506 parts.extend(split_identifier(parent_name));
507 }
508 parts
509 })
510 .clone();
511 let n_matches = count_keyword_matches(&keywords, &parts);
512 if n_matches > 0 {
513 let match_ratio = n_matches as f32 / keywords.len() as f32;
514 if match_ratio >= 0.10 {
515 *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
516 }
517 }
518 }
519}
520
521#[expect(
536 clippy::implicit_hasher,
537 reason = "internal API; callers in the semble pipeline use the default RandomState"
538)]
539#[must_use]
540pub fn apply_query_boost(
541 combined_scores: &HashMap<usize, f32>,
542 query: &str,
543 all_chunks: &[CodeChunk],
544) -> HashMap<usize, f32> {
545 if combined_scores.is_empty() {
546 return HashMap::new();
547 }
548 let max_score = combined_scores
549 .values()
550 .copied()
551 .fold(f32::NEG_INFINITY, f32::max);
552 let mut boosted = combined_scores.clone();
553 if is_symbol_query(query) {
554 boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
555 } else {
556 boost_stem_matches(&mut boosted, query, max_score, all_chunks);
557 boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
558 }
559 boosted
560}
561
562#[expect(
565 clippy::implicit_hasher,
566 reason = "internal API; callers in the semble pipeline use the default RandomState"
567)]
568pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
569 if scores.is_empty() {
570 return;
571 }
572 let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
573 if max_score == 0.0 || !max_score.is_finite() {
574 return;
575 }
576
577 let mut file_sum: HashMap<String, f32> = HashMap::new();
578 let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
579 for (&idx, &score) in scores.iter() {
580 let path = chunks[idx].file_path.clone();
581 *file_sum.entry(path.clone()).or_insert(0.0) += score;
582 match best_chunk_idx.get(&path) {
583 Some(&best) if scores[&best] >= score => {}
584 _ => {
585 best_chunk_idx.insert(path, idx);
586 }
587 }
588 }
589 let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
590 if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
591 return;
592 }
593 let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
594 for (path, &idx) in &best_chunk_idx {
595 let contribution = boost_unit * file_sum[path] / max_file_sum;
596 *scores.entry(idx).or_insert(0.0) += contribution;
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Code` chunk at `path` whose raw and enriched content are both
    /// `content`; every other field is a neutral placeholder.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        let text = content.to_string();
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            enriched_content: text.clone(),
            content: text,
            qualified_name: None,
        }
    }

    /// Float comparison with the 1e-6 tolerance used throughout these tests.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn is_symbol_query_namespace() {
        for q in ["Sinatra::Base", "module.Class", "a->b->c", r"Foo\Bar"] {
            assert!(is_symbol_query(q));
        }
    }

    #[test]
    fn is_symbol_query_pascal() {
        for q in ["Client", "HTTPHandler", "XMLParser"] {
            assert!(is_symbol_query(q));
        }
    }

    #[test]
    fn is_symbol_query_plain_word_rejected() {
        for q in ["session", "retry", "authentication"] {
            assert!(!is_symbol_query(q));
        }
    }

    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        for q in ["_private", "__init__"] {
            assert!(is_symbol_query(q));
        }
    }

    #[test]
    fn resolve_alpha_symbol_0_3() {
        for q in ["Client", "foo.Bar"] {
            assert!(close(resolve_alpha(q, None), ALPHA_SYMBOL));
        }
    }

    #[test]
    fn resolve_alpha_nl_0_5() {
        for q in ["how does retry work", "authentication handling"] {
            assert!(close(resolve_alpha(q, None), ALPHA_NL));
        }
    }

    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!(close(resolve_alpha("Client", Some(0.7)), 0.7));
    }

    #[test]
    fn chunk_defines_symbol_class() {
        assert!(chunk_defines_symbol("class Client:\n pass", "Client"));
    }

    #[test]
    fn chunk_defines_symbol_def() {
        assert!(chunk_defines_symbol("def handle_request():\n pass", "handle_request"));
    }

    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        // The `Phoenix.` qualifier is absorbed by the namespace prefix.
        assert!(chunk_defines_symbol(" defmodule Phoenix.Router do\n", "Router"));
    }

    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        let cases = [
            (" create table users (id int)", "users"),
            (" CREATE TABLE Users (id int)", "Users"),
        ];
        for (content, name) in cases {
            assert!(chunk_defines_symbol(content, name));
        }
    }

    #[test]
    fn chunk_defines_symbol_negative() {
        // A method call on a lowercase variable is a mention, not a definition.
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        // Both files define `Client`; only client.rs earns the 1.5x stem bonus.
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        assert!(close(boosted[&0], 1.0 + 3.0 * 1.5));
        assert!(close(boosted[&1], 1.0 + 3.0));
    }

    #[test]
    fn boost_stem_matches_prefix() {
        // "parse" is a prefix of the stem "parser", so the chunk is boosted.
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    #[test]
    fn boost_embedded_symbols_half_strength() {
        // 1.0 + (1.0 * 3.0 * 0.5) * 1.5 (stem bonus) = 3.25.
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        let got = boosted[&0];
        assert!(close(got, 3.25), "got {}", got);
    }

    #[test]
    fn boost_multi_chunk_files() {
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        // foo.rs has the largest file sum (1.5), so its best chunk gets the full 0.2.
        let got_foo = scores[&0];
        assert!(close(got_foo, 1.2), "got {}", got_foo);
        assert!(close(scores[&1], 0.5), "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        let got_bar = scores[&2];
        assert!(
            close(got_bar, expected_bar),
            "got {}, expected {}",
            got_bar,
            expected_bar
        );
    }

    #[test]
    fn property_symbol_regex_parity_python() {
        let symbols = [
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        let non_symbols = [
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}