1use std::collections::{HashMap, HashSet};
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27use crate::encoder::ripvec::tokens::split_identifier;
28
/// Alpha returned by [`resolve_alpha`] when the query is classified as a
/// symbol query and the caller supplied no explicit value.
///
/// NOTE(review): presumably the weight used to blend two score sources —
/// confirm against the call site.
pub const ALPHA_SYMBOL: f32 = 0.3;
/// Alpha returned by [`resolve_alpha`] for every query that is not classified
/// as a symbol query (natural-language queries) when no explicit value is given.
pub const ALPHA_NL: f32 = 0.5;
37
38#[must_use]
44pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
45 if let Some(w) = alpha {
46 return w;
47 }
48 if is_symbol_query(query) {
49 ALPHA_SYMBOL
50 } else {
51 ALPHA_NL
52 }
53}
54
55fn symbol_query_re() -> &'static Regex {
62 static RE: OnceLock<Regex> = OnceLock::new();
63 RE.get_or_init(|| {
64 Regex::new(concat!(
65 r"^(?:",
66 r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
68 r"|_[A-Za-z0-9_]*",
70 r"|[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
72 r"|[A-Z][A-Za-z0-9]*",
74 r")$",
75 ))
76 .expect("symbol-query regex compiles")
77 })
78}
79
80fn embedded_symbol_re() -> &'static Regex {
83 static RE: OnceLock<Regex> = OnceLock::new();
84 RE.get_or_init(|| {
85 Regex::new(concat!(
86 r"\b(?:",
87 r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*",
89 r"|[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+",
91 r")\b",
92 ))
93 .expect("embedded-symbol regex compiles")
94 })
95}
96
97#[must_use]
101pub fn is_symbol_query(query: &str) -> bool {
102 symbol_query_re().is_match(query.trim())
103}
104
/// Returns `true` when `path` has a prose/documentation file extension.
///
/// The extension is the text after the final `.` in `path`, compared ASCII
/// case-insensitively against `md`, `markdown`, `mdx`, `rst`, `txt`, `text`,
/// `adoc`, `asciidoc` and `org`.
///
/// A path without any `.` has no extension and is never prose, so an
/// extensionless file that happens to be called `org` or `txt` is not
/// misclassified.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    const PROSE_EXTENSIONS: [&str; 9] = [
        "md", "markdown", "mdx", "rst", "txt", "text", "adoc", "asciidoc", "org",
    ];
    // `rsplit_once` yields `None` for dot-less paths, unlike `rsplit(..).next()`
    // which would hand back the whole path as if it were the extension.
    path.rsplit_once('.').is_some_and(|(_, ext)| {
        PROSE_EXTENSIONS
            .iter()
            .any(|known| ext.eq_ignore_ascii_case(known))
    })
}
123
/// Keywords that may introduce a definition in source code.
///
/// [`definition_pattern`] regex-escapes these and joins them into the
/// alternation that must precede the symbol name. That pattern is built
/// without the case-insensitive flag, so the keywords match exactly as
/// spelled here.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class",
    "module",
    "defmodule", "def",
    "interface",
    "struct",
    "enum",
    "trait",
    "type",
    "func",
    "function",
    "object",
    "abstract class",
    "data class",
    "fn",
    "fun", "package",
    "namespace",
    "protocol", "record", "typedef", ];
154
/// SQL statement prefixes that introduce a definition.
///
/// [`definition_pattern`] compiles these into a separate pattern that is
/// built with the case-insensitive flag, so `create table` matches as well.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE",
    "CREATE VIEW",
    "CREATE PROCEDURE",
    "CREATE FUNCTION",
];
163
/// Multiple of the current maximum score that forms the boost unit for a
/// chunk defining the queried symbol (used by `boost_symbol_definitions` and
/// `boost_embedded_symbols`).
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Multiple of the current maximum score that `boost_stem_matches` scales by
/// the keyword match ratio.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// Fraction of the current maximum score that `boost_multi_chunk_files` uses
/// as its boost unit.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Extra factor `boost_embedded_symbols` applies on top of
/// `DEFINITION_BOOST_MULTIPLIER`.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum length (in bytes, via `str::len`) a file stem must have before
/// `boost_embedded_symbols` accepts it as a prefix of a symbol.
const EMBEDDED_STEM_MIN_LEN: usize = 4;
169
/// Lowercase words that `boost_stem_matches` drops from the query before
/// comparing the remaining keywords against path parts.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have",
    "how", "if", "in", "is", "it", "not", "of", "on", "or", "the", "to", "was", "what", "when",
    "where", "which", "who", "why", "with",
];
177
178fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
185 let escaped = regex::escape(symbol_name);
186 let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
187 let def_body = DEFINITION_KEYWORDS
192 .iter()
193 .map(|k| regex::escape(k))
194 .collect::<Vec<_>>()
195 .join("|");
196 let sql_body = SQL_DEFINITION_KEYWORDS
197 .iter()
198 .map(|k| regex::escape(k))
199 .collect::<Vec<_>>()
200 .join("|");
201 let suffix = format!(r")\s+{ns_prefix}{escaped}(?:\s|[<({{:\[;]|$)");
202 let no_lookbehind_prefix = r"(?:^|\s)(?:";
208 let general_pat = format!("{no_lookbehind_prefix}{def_body}{suffix}");
209 let sql_pat = format!("{no_lookbehind_prefix}{sql_body}{suffix}");
210 let general = RegexBuilder::new(&general_pat)
211 .multi_line(true)
212 .build()
213 .expect("general definition regex compiles");
214 let sql = RegexBuilder::new(&sql_pat)
215 .multi_line(true)
216 .case_insensitive(true)
217 .build()
218 .expect("SQL definition regex compiles");
219 (general, sql)
220}
221
222fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
229 let (general, sql) = definition_pattern(symbol_name);
230 general.is_match(content) || sql.is_match(content)
231}
232
233fn stem_matches(stem: &str, name: &str) -> bool {
236 let stem_norm = stem.replace('_', "");
237 stem == name
238 || stem_norm == name
239 || stem.trim_end_matches('s') == name
240 || stem_norm.trim_end_matches('s') == name
241}
242
243fn extract_symbol_name(query: &str) -> String {
252 for separator in &["::", "\\", "->", "."] {
253 if let Some(idx) = query.rfind(separator) {
254 return query[idx + separator.len()..].to_string();
255 }
256 }
257 query.trim().to_string()
258}
259
260fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
263 let any_match = names
264 .iter()
265 .any(|name| chunk_defines_symbol(&chunk.content, name));
266 if !any_match {
267 return 0.0;
268 }
269 let stem = Path::new(&chunk.file_path)
270 .file_stem()
271 .and_then(|s| s.to_str())
272 .unwrap_or_default()
273 .to_ascii_lowercase();
274 let stem_match_bonus = names
275 .iter()
276 .any(|name| stem_matches(&stem, &name.to_ascii_lowercase()));
277 boost_unit * if stem_match_bonus { 1.5 } else { 1.0 }
278}
279
280fn scan_non_candidates(
283 boosted: &mut HashMap<usize, f32>,
284 names: &HashSet<String>,
285 boost_unit: f32,
286 all_chunks: &[CodeChunk],
287 stem_ok: &dyn Fn(&str) -> bool,
288) {
289 for (idx, chunk) in all_chunks.iter().enumerate() {
290 if boosted.contains_key(&idx) {
291 continue;
292 }
293 let stem = Path::new(&chunk.file_path)
294 .file_stem()
295 .and_then(|s| s.to_str())
296 .unwrap_or_default()
297 .to_ascii_lowercase();
298 if !stem_ok(&stem) {
299 continue;
300 }
301 let tier = definition_tier(chunk, names, boost_unit);
302 if tier > 0.0 {
303 boosted.insert(idx, tier);
304 }
305 }
306}
307
308fn boost_symbol_definitions(
311 boosted: &mut HashMap<usize, f32>,
312 query: &str,
313 max_score: f32,
314 all_chunks: &[CodeChunk],
315) {
316 let symbol_name = extract_symbol_name(query);
317 let trimmed_query = query.trim();
318 let mut names: HashSet<String> = HashSet::new();
319 names.insert(symbol_name.clone());
320 if symbol_name != trimmed_query {
321 names.insert(trimmed_query.to_string());
322 }
323
324 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
325
326 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
328 for idx in candidate_indices {
329 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
330 if tier > 0.0 {
331 *boosted.entry(idx).or_insert(0.0) += tier;
332 }
333 }
334
335 let symbol_lower = symbol_name.to_ascii_lowercase();
337 scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
338 stem_matches(stem, &symbol_lower)
339 });
340}
341
342fn boost_embedded_symbols(
345 boosted: &mut HashMap<usize, f32>,
346 query: &str,
347 max_score: f32,
348 all_chunks: &[CodeChunk],
349) {
350 let names: HashSet<String> = embedded_symbol_re()
351 .find_iter(query)
352 .map(|m| m.as_str().to_string())
353 .collect();
354 if names.is_empty() {
355 return;
356 }
357
358 let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
359
360 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
362 for idx in candidate_indices {
363 let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
364 if tier > 0.0 {
365 *boosted.entry(idx).or_insert(0.0) += tier;
366 }
367 }
368
369 let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
371 let symbols_lower_for_scan = symbols_lower.clone();
372 scan_non_candidates(
373 boosted,
374 &names,
375 boost_unit,
376 all_chunks,
377 &move |stem: &str| {
378 let stem_norm = stem.replace('_', "");
379 symbols_lower_for_scan.iter().any(|sym_lower| {
380 stem == sym_lower
381 || stem_norm == *sym_lower
382 || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(stem))
383 || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
384 && sym_lower.starts_with(stem_norm.as_str()))
385 })
386 },
387 );
388}
389
/// Counts how many `keywords` are matched by some entry of `parts`.
///
/// A keyword counts once when it is contained in `parts` verbatim, or when it
/// and some part are in a prefix relationship (the shorter one is a prefix of
/// the longer one) and the shorter one is at least three bytes long.
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    // True when the shorter of the two strings is a >= 3 byte prefix of the other.
    let shares_prefix = |keyword: &str, part: &str| -> bool {
        keyword.len().min(part.len()) >= 3
            && (part.starts_with(keyword) || keyword.starts_with(part))
    };

    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(*keyword)
                || parts
                    .iter()
                    .any(|part| shares_prefix(keyword.as_str(), part.as_str()))
        })
        .count()
}
417
418fn boost_stem_matches(
423 boosted: &mut HashMap<usize, f32>,
424 query: &str,
425 max_score: f32,
426 chunks: &[CodeChunk],
427) {
428 static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
429 let keyword_re =
430 KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
431 let keywords: HashSet<String> = keyword_re
432 .find_iter(query)
433 .map(|m| m.as_str().to_ascii_lowercase())
434 .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
435 .collect();
436 if keywords.is_empty() {
437 return;
438 }
439
440 let boost = max_score * STEM_BOOST_MULTIPLIER;
441 let mut path_cache: HashMap<String, HashSet<String>> = HashMap::new();
442 let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
443 for idx in candidate_indices {
444 let path = &chunks[idx].file_path;
445 let parts = path_cache
446 .entry(path.clone())
447 .or_insert_with(|| {
448 let mut parts: HashSet<String> = HashSet::new();
449 let p = Path::new(path);
450 if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
451 parts.extend(split_identifier(stem));
452 }
453 if let Some(parent_name) = p
454 .parent()
455 .and_then(Path::file_name)
456 .and_then(|s| s.to_str())
457 && !parent_name.is_empty()
458 && parent_name != "."
459 && parent_name != ".."
460 {
461 parts.extend(split_identifier(parent_name));
462 }
463 parts
464 })
465 .clone();
466 let n_matches = count_keyword_matches(&keywords, &parts);
467 if n_matches > 0 {
468 let match_ratio = n_matches as f32 / keywords.len() as f32;
469 if match_ratio >= 0.10 {
470 *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
471 }
472 }
473 }
474}
475
476#[expect(
491 clippy::implicit_hasher,
492 reason = "internal API; callers in the semble pipeline use the default RandomState"
493)]
494#[must_use]
495pub fn apply_query_boost(
496 combined_scores: &HashMap<usize, f32>,
497 query: &str,
498 all_chunks: &[CodeChunk],
499) -> HashMap<usize, f32> {
500 if combined_scores.is_empty() {
501 return HashMap::new();
502 }
503 let max_score = combined_scores
504 .values()
505 .copied()
506 .fold(f32::NEG_INFINITY, f32::max);
507 let mut boosted = combined_scores.clone();
508 if is_symbol_query(query) {
509 boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
510 } else {
511 boost_stem_matches(&mut boosted, query, max_score, all_chunks);
512 boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
513 }
514 boosted
515}
516
517#[expect(
520 clippy::implicit_hasher,
521 reason = "internal API; callers in the semble pipeline use the default RandomState"
522)]
523pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
524 if scores.is_empty() {
525 return;
526 }
527 let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
528 if max_score == 0.0 || !max_score.is_finite() {
529 return;
530 }
531
532 let mut file_sum: HashMap<String, f32> = HashMap::new();
533 let mut best_chunk_idx: HashMap<String, usize> = HashMap::new();
534 for (&idx, &score) in scores.iter() {
535 let path = chunks[idx].file_path.clone();
536 *file_sum.entry(path.clone()).or_insert(0.0) += score;
537 match best_chunk_idx.get(&path) {
538 Some(&best) if scores[&best] >= score => {}
539 _ => {
540 best_chunk_idx.insert(path, idx);
541 }
542 }
543 }
544 let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
545 if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
546 return;
547 }
548 let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
549 for (path, &idx) in &best_chunk_idx {
550 let contribution = boost_unit * file_sum[path] / max_file_sum;
551 *scores.entry(idx).or_insert(0.0) += contribution;
552 }
553}
554
// Unit tests for query classification, definition detection and the boost passes.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CodeChunk` at `path` whose `content` and `enriched_content`
    /// are both `content`; the remaining fields get placeholder values.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    // Identifiers joined by `::`, `.`, `->` or `\` classify as symbol queries.
    #[test]
    fn is_symbol_query_namespace() {
        assert!(is_symbol_query("Sinatra::Base"));
        assert!(is_symbol_query("module.Class"));
        assert!(is_symbol_query("a->b->c"));
        assert!(is_symbol_query(r"Foo\Bar"));
    }

    // Uppercase-initial identifiers classify as symbol queries.
    #[test]
    fn is_symbol_query_pascal() {
        assert!(is_symbol_query("Client"));
        assert!(is_symbol_query("HTTPHandler"));
        assert!(is_symbol_query("XMLParser"));
    }

    // A single all-lowercase word without underscores is not a symbol query.
    #[test]
    fn is_symbol_query_plain_word_rejected() {
        assert!(!is_symbol_query("session"));
        assert!(!is_symbol_query("retry"));
        assert!(!is_symbol_query("authentication"));
    }

    // Identifiers that start with `_` are symbol queries.
    #[test]
    fn is_symbol_query_leading_underscore_accepted() {
        assert!(is_symbol_query("_private"));
        assert!(is_symbol_query("__init__"));
    }

    // Without an explicit alpha, symbol queries resolve to ALPHA_SYMBOL.
    #[test]
    fn resolve_alpha_symbol_0_3() {
        assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
        assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
    }

    // Without an explicit alpha, multi-word queries resolve to ALPHA_NL.
    #[test]
    fn resolve_alpha_nl_0_5() {
        assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
        assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
    }

    // An explicit alpha is returned even for a symbol query.
    #[test]
    fn resolve_alpha_explicit_override_wins() {
        assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
    }

    // `class <Name>:` is recognised as a definition.
    #[test]
    fn chunk_defines_symbol_class() {
        let content = "class Client:\n pass";
        assert!(chunk_defines_symbol(content, "Client"));
    }

    // `def <name>(` is recognised as a definition.
    #[test]
    fn chunk_defines_symbol_def() {
        let content = "def handle_request():\n pass";
        assert!(chunk_defines_symbol(content, "handle_request"));
    }

    // A dotted qualifier before the name (`Phoenix.Router`) is skipped over.
    #[test]
    fn chunk_defines_symbol_namespace_qualified() {
        let content = " defmodule Phoenix.Router do\n";
        assert!(chunk_defines_symbol(content, "Router"));
    }

    // SQL definition keywords match regardless of case.
    #[test]
    fn chunk_defines_symbol_sql_case_insensitive() {
        assert!(chunk_defines_symbol(
            " create table users (id int)",
            "users"
        ));
        assert!(chunk_defines_symbol(
            " CREATE TABLE Users (id int)",
            "Users"
        ));
    }

    // A usage without a definition keyword is not a definition.
    #[test]
    fn chunk_defines_symbol_negative() {
        assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
    }

    // Both chunks define `Client`; only chunk 0 lives in a file whose stem
    // matches, so it gets the x1.5 tier: 1.0 + 3.0 * 1.5 = 5.5 versus
    // 1.0 + 3.0 = 4.0 for chunk 1.
    #[test]
    fn boost_symbol_definitions_stem_multiplier() {
        let chunks = vec![
            chunk("src/client.rs", "struct Client { /* ... */ }"),
            chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
        ];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
        boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
        assert!((boosted[&0] - 5.5).abs() < 1e-6);
        assert!((boosted[&1] - 4.0).abs() < 1e-6);
    }

    // The keyword `parse` is a prefix of the path part `parser`, which is
    // enough for a stem-match boost.
    #[test]
    fn boost_stem_matches_prefix() {
        let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
        assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
    }

    // An embedded symbol boosts at half the definition strength:
    // 1.0 + 3.0 * 0.5 * 1.5 (stem match) = 3.25.
    #[test]
    fn boost_embedded_symbols_half_strength() {
        let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
        let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
        boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
        assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
    }

    // foo.rs sums to 1.5 (the largest file sum), so its best chunk gains the
    // full 0.2; bar.rs sums to 1.0 and gains 0.2 * (1.0 / 1.5). This test fn
    // shadows the function under test, hence the `super::` path.
    #[test]
    fn boost_multi_chunk_files() {
        let chunks = vec![
            chunk("src/foo.rs", ""),
            chunk("src/foo.rs", ""),
            chunk("src/bar.rs", ""),
        ];
        let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
        super::boost_multi_chunk_files(&mut scores, &chunks);
        assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
        assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
        let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
        assert!(
            (scores[&2] - expected_bar).abs() < 1e-6,
            "got {}, expected {}",
            scores[&2],
            expected_bar
        );
    }

    // Table-driven check of the classifier over queries expected to be
    // symbols and queries expected to be natural language.
    // NOTE(review): the test name suggests these tables mirror a Python
    // reference implementation — confirm they stay in sync with it.
    #[test]
    fn property_symbol_regex_parity_python() {
        let symbols = &[
            "Client",
            "handle_request",
            "_private",
            "getX",
            "XMLParser",
            "foo::bar",
            "foo.bar.baz",
            "a->b",
            r"Foo\Bar",
            "__init__",
            "snake_case",
        ];
        for q in symbols {
            assert!(is_symbol_query(q), "expected symbol query: {q:?}");
        }
        let non_symbols = &[
            "session",
            "retry",
            "authentication",
            "how does retry work",
            "user authentication flow",
            "hi",
        ];
        for q in non_symbols {
            assert!(!is_symbol_query(q), "expected NL: {q:?}");
        }
    }
}