use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::OnceLock;
use regex::{Regex, RegexBuilder};
use crate::chunk::CodeChunk;
use crate::encoder::ripvec::tokens::split_identifier;
/// Default `alpha` chosen by [`resolve_alpha`] for queries that look like a
/// code symbol (see [`is_symbol_query`]).
pub const ALPHA_SYMBOL: f32 = 0.3_f32;
/// Default `alpha` chosen by [`resolve_alpha`] for every other (natural
/// language) query.
pub const ALPHA_NL: f32 = 0.5_f32;
/// Picks the `alpha` value to use for `query`.
///
/// A caller-supplied `alpha` always wins. Without one, symbol-like queries
/// (per [`is_symbol_query`]) get [`ALPHA_SYMBOL`] and all other queries get
/// [`ALPHA_NL`].
#[must_use]
pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
    alpha.unwrap_or_else(|| {
        if is_symbol_query(query) {
            ALPHA_SYMBOL
        } else {
            ALPHA_NL
        }
    })
}
/// Lazily compiled matcher that accepts an entire input as a code symbol.
///
/// The pattern is anchored (`^...$`) and accepts one of:
/// - identifiers joined by `::`, `\`, `->` or `.` (e.g. `Sinatra::Base`, `a->b`);
/// - identifiers starting with `_` (e.g. `_private`, `__init__`);
/// - identifiers with an uppercase letter or `_` after the first character
///   (e.g. `getX`, `snake_case`);
/// - identifiers starting with an uppercase letter (e.g. `Client`).
fn symbol_query_re() -> &'static Regex {
    static SYMBOL_QUERY: OnceLock<Regex> = OnceLock::new();
    SYMBOL_QUERY.get_or_init(|| {
        const ALTERNATIVES: [&str; 4] = [
            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
            r"_[A-Za-z0-9_]*",
            r"[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
            r"[A-Z][A-Za-z0-9]*",
        ];
        let pattern = format!("^(?:{})$", ALTERNATIVES.join("|"));
        Regex::new(&pattern).expect("symbol-query regex compiles")
    })
}
/// Lazily compiled matcher for camelCase / PascalCase identifiers that appear
/// inside free text (e.g. `MyClass` in "how does MyClass handle errors").
///
/// Each match is bounded by `\b` and has one of two shapes:
/// - PascalCase with at least two humps: an uppercase letter, a lowercase
///   letter, any alphanumerics, another uppercase letter, any alphanumerics;
/// - camelCase: a lowercase start, any alphanumerics, an uppercase letter,
///   then at least one more alphanumeric character.
fn embedded_symbol_re() -> &'static Regex {
    static EMBEDDED_SYMBOL: OnceLock<Regex> = OnceLock::new();
    EMBEDDED_SYMBOL.get_or_init(|| {
        let pascal = r"[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*";
        let camel = r"[a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+";
        let pattern = format!(r"\b(?:{pascal}|{camel})\b");
        Regex::new(&pattern).expect("embedded-symbol regex compiles")
    })
}
/// Returns `true` when the whole query, ignoring surrounding whitespace, has
/// the shape of a code identifier or qualified path rather than prose.
#[must_use]
pub fn is_symbol_query(query: &str) -> bool {
    let trimmed = query.trim();
    symbol_query_re().is_match(trimmed)
}
/// Returns `true` when `path` names a prose or documentation file.
///
/// The decision is made on the file extension, compared ASCII
/// case-insensitively against `md`, `markdown`, `mdx`, `rst`, `txt`, `text`,
/// `adoc`, `asciidoc` and `org`.
///
/// Only a real extension counts. A file literally named `txt` or `org`, or
/// any path whose final component has no extension, is not prose.
#[must_use]
pub fn is_prose_path(path: &str) -> bool {
    const PROSE_EXTENSIONS: &[&str] = &[
        "md", "markdown", "mdx", "rst", "txt", "text", "adoc", "asciidoc", "org",
    ];
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            PROSE_EXTENSIONS
                .iter()
                .any(|known| ext.eq_ignore_ascii_case(known))
        })
}
/// Keywords that introduce a definition. They are escaped and joined into the
/// case-sensitive "general" regex built by `definition_pattern_uncached`.
const DEFINITION_KEYWORDS: &[&str] = &[
    "class",
    "module",
    "defmodule",
    "def",
    "interface",
    "struct",
    "enum",
    "trait",
    "type",
    "func",
    "function",
    "object",
    "abstract class",
    "data class",
    "fn",
    "fun",
    "package",
    "namespace",
    "protocol",
    "record",
    "typedef",
];

/// SQL statements that introduce a definition. They are matched by the
/// case-insensitive "sql" regex built by `definition_pattern_uncached`.
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
    "CREATE TABLE",
    "CREATE VIEW",
    "CREATE PROCEDURE",
    "CREATE FUNCTION",
];

/// Definition-boost unit, expressed as a multiple of the best input score.
const DEFINITION_BOOST_MULTIPLIER: f32 = 3.0;
/// Path-keyword boost, as a multiple of the best input score. It is further
/// scaled by the fraction of query keywords that matched.
const STEM_BOOST_MULTIPLIER: f32 = 1.0;
/// Multi-chunk file boost unit, as a fraction of the best input score.
const FILE_COHERENCE_BOOST_FRAC: f32 = 0.2;
/// Scale applied to the definition boost for identifiers embedded in prose.
const EMBEDDED_SYMBOL_BOOST_SCALE: f32 = 0.5;
/// Minimum file-stem length, in bytes, for a stem to count as a prefix of an
/// embedded identifier.
const EMBEDDED_STEM_MIN_LEN: usize = 4;

/// Common English words ignored when extracting path keywords from a query.
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have",
    "how", "if", "in", "is", "it", "not", "of", "on", "or", "the", "to", "was", "what", "when",
    "where", "which", "who", "why", "with",
];
/// Compiles the pair of "does this text define `symbol_name`?" regexes.
///
/// Both patterns have the shape
/// `(?:^|\s)(?:KEYWORDS)\s+(?:QUALIFIER(?:\.|::))*NAME(?:\s|[<({:\[;]|$)`.
/// That is, a definition keyword at line start or after whitespace, then
/// whitespace, an optional `.`/`::` qualifier chain, the escaped symbol name,
/// and finally whitespace, one of `<({:[;`, or end of line.
///
/// Returns `(general, sql)`:
/// - `general` uses [`DEFINITION_KEYWORDS`] and is case-sensitive;
/// - `sql` uses [`SQL_DEFINITION_KEYWORDS`] and is case-insensitive.
///
/// Both are multi-line, so `^` and `$` anchor at line boundaries.
fn definition_pattern_uncached(symbol_name: &str) -> (Regex, Regex) {
    /// Escapes every keyword and joins them into a regex alternation.
    fn alternation(keywords: &[&str]) -> String {
        keywords
            .iter()
            .map(|keyword| regex::escape(keyword))
            .collect::<Vec<_>>()
            .join("|")
    }

    let name = regex::escape(symbol_name);
    // Closes the keyword group, then qualifier chain + name + terminator.
    let tail = format!(r")\s+(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*{name}(?:\s|[<({{:\[;]|$)");

    let compile = |keywords: &[&str], case_insensitive: bool, failure: &str| -> Regex {
        let source = format!(r"(?:^|\s)(?:{}{tail}", alternation(keywords));
        RegexBuilder::new(&source)
            .multi_line(true)
            .case_insensitive(case_insensitive)
            .build()
            .expect(failure)
    };

    let general = compile(DEFINITION_KEYWORDS, false, "general definition regex compiles");
    let sql = compile(SQL_DEFINITION_KEYWORDS, true, "SQL definition regex compiles");
    (general, sql)
}
fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
use std::sync::{Mutex, OnceLock};
static CACHE: OnceLock<Mutex<std::collections::HashMap<String, (Regex, Regex)>>> =
OnceLock::new();
let cache = CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()));
if let Ok(map) = cache.lock()
&& let Some(entry) = map.get(symbol_name)
{
return entry.clone();
}
let pair = definition_pattern_uncached(symbol_name);
if let Ok(mut map) = cache.lock()
&& map.len() < 256
{
map.insert(symbol_name.to_string(), pair.clone());
}
pair
}
/// Returns `true` when `content` contains a definition of `symbol_name`.
///
/// Two regexes from [`definition_pattern`] are consulted:
/// - the case-sensitive general one (language definition keywords);
/// - the case-insensitive SQL one (`CREATE TABLE`, ...).
///
/// Cheap substring checks run first, so chunks that cannot match skip the
/// regex work:
/// - the general regex needs the name verbatim, so it is only tried when
///   `content` contains `symbol_name` exactly;
/// - the SQL regex matches the name case-insensitively, so it is tried when
///   `content` contains `symbol_name` ignoring ASCII case. For example, the
///   symbol `Users` still finds `CREATE TABLE users`. Previously the
///   case-sensitive prefilter rejected that chunk unless `Users` happened to
///   appear verbatim elsewhere in it.
fn chunk_defines_symbol(content: &str, symbol_name: &str) -> bool {
    let verbatim = content.contains(symbol_name);
    if !verbatim && !contains_ignore_ascii_case(content, symbol_name) {
        return false;
    }
    let (general, sql) = definition_pattern(symbol_name);
    (verbatim && general.is_match(content)) || sql.is_match(content)
}

/// ASCII case-insensitive substring test; an empty `needle` always matches.
fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    needle.is_empty()
        || haystack
            .as_bytes()
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
fn stem_matches(stem: &str, name: &str) -> bool {
let stem_norm = stem.replace('_', "");
stem == name
|| stem_norm == name
|| stem.trim_end_matches('s') == name
|| stem_norm.trim_end_matches('s') == name
}
/// Returns the final segment of a (possibly qualified) symbol query.
///
/// The query is trimmed first. It is then cut after the right-most occurrence
/// of any separator (`::`, `\`, `->` or `.`). So ` Sinatra::Base ` yields
/// `Base`, and mixed forms such as `Foo::Bar.baz` yield `baz`.
///
/// A query with no separator, or one that ends in a separator (leaving
/// nothing after it), is returned trimmed but otherwise unchanged. This keeps
/// the result from ever being an empty name.
fn extract_symbol_name(query: &str) -> String {
    const SEPARATORS: [&str; 4] = ["::", "\\", "->", "."];
    let query = query.trim();
    // Byte offset just past the right-most separator of any kind.
    let tail_start = SEPARATORS
        .iter()
        .filter_map(|sep| query.rfind(*sep).map(|idx| idx + sep.len()))
        .max();
    match tail_start {
        Some(start) if start < query.len() => query[start..].to_string(),
        _ => query.to_string(),
    }
}
/// Scores how strongly `chunk` defines one of `names`.
///
/// Returns `0.0` when the chunk's content defines none of the names (per
/// [`chunk_defines_symbol`]). Otherwise returns `boost_unit`, multiplied by
/// 1.5 when the chunk's lowercased file stem also matches one of the
/// lowercased names (per [`stem_matches`]).
fn definition_tier(chunk: &CodeChunk, names: &HashSet<String>, boost_unit: f32) -> f32 {
    let defines_any = names
        .iter()
        .any(|name| chunk_defines_symbol(&chunk.content, name));
    if !defines_any {
        return 0.0;
    }

    let file_stem = Path::new(&chunk.file_path)
        .file_stem()
        .and_then(|s| s.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    let stem_also_matches = names
        .iter()
        .any(|name| stem_matches(&file_stem, &name.to_ascii_lowercase()));

    let multiplier = if stem_also_matches { 1.5 } else { 1.0 };
    boost_unit * multiplier
}
/// Adds definition boosts for chunks that are not yet in `boosted`.
///
/// Each chunk whose index is absent from `boosted`, and whose lowercased file
/// stem passes `stem_ok`, is scored with [`definition_tier`]. Chunks with a
/// positive tier are inserted with that tier as their score. Existing entries
/// are never modified.
fn scan_non_candidates(
    boosted: &mut HashMap<usize, f32>,
    names: &HashSet<String>,
    boost_unit: f32,
    all_chunks: &[CodeChunk],
    stem_ok: &dyn Fn(&str) -> bool,
) {
    let additions: Vec<(usize, f32)> = all_chunks
        .iter()
        .enumerate()
        .filter(|(idx, _)| !boosted.contains_key(idx))
        .filter(|(_, chunk)| {
            let lowered_stem = Path::new(&chunk.file_path)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or_default()
                .to_ascii_lowercase();
            stem_ok(&lowered_stem)
        })
        .map(|(idx, chunk)| (idx, definition_tier(chunk, names, boost_unit)))
        .filter(|&(_, tier)| tier > 0.0)
        .collect();
    boosted.extend(additions);
}
/// Boosts chunks that define the symbol named by a symbol-like `query`.
///
/// Two names are looked for: the query's last segment (see
/// [`extract_symbol_name`]) and the whole trimmed query. The boost unit is
/// `max_score * DEFINITION_BOOST_MULTIPLIER`.
///
/// Each existing candidate that defines one of the names gains its
/// [`definition_tier`]. Chunks outside `boosted` are added only when their
/// file stem matches the lowercased last segment (per [`stem_matches`]).
fn boost_symbol_definitions(
    boosted: &mut HashMap<usize, f32>,
    query: &str,
    max_score: f32,
    all_chunks: &[CodeChunk],
) {
    let symbol_name = extract_symbol_name(query);
    let symbol_lower = symbol_name.to_ascii_lowercase();
    // A set: when the query has no qualifier both entries coincide.
    let names: HashSet<String> = [symbol_name, query.trim().to_string()]
        .into_iter()
        .collect();
    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;

    let gains: Vec<(usize, f32)> = boosted
        .keys()
        .map(|&idx| (idx, definition_tier(&all_chunks[idx], &names, boost_unit)))
        .filter(|&(_, tier)| tier > 0.0)
        .collect();
    for (idx, tier) in gains {
        *boosted.entry(idx).or_insert(0.0) += tier;
    }

    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &|stem: &str| {
        stem_matches(stem, &symbol_lower)
    });
}
/// Boosts chunks that define camelCase/PascalCase identifiers found inside a
/// natural-language query (e.g. `MyClass` in "how does MyClass handle errors").
///
/// Identifiers are extracted with `embedded_symbol_re`. The boost unit is
/// `max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE`.
///
/// Existing candidates that define one of the identifiers gain their
/// [`definition_tier`]. Other chunks are added only if their lowercased file
/// stem, with or without underscores, does one of the following:
/// - equals a lowercased identifier;
/// - is at least `EMBEDDED_STEM_MIN_LEN` bytes long and is a prefix of a
///   lowercased identifier.
fn boost_embedded_symbols(
    boosted: &mut HashMap<usize, f32>,
    query: &str,
    max_score: f32,
    all_chunks: &[CodeChunk],
) {
    let names: HashSet<String> = embedded_symbol_re()
        .find_iter(query)
        .map(|m| m.as_str().to_string())
        .collect();
    if names.is_empty() {
        return;
    }

    let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
    for idx in candidate_indices {
        let tier = definition_tier(&all_chunks[idx], &names, boost_unit);
        if tier > 0.0 {
            *boosted.entry(idx).or_insert(0.0) += tier;
        }
    }

    // Lowercased once up front; the predicate below borrows it, so no copy
    // of the list is needed.
    let symbols_lower: Vec<String> = names.iter().map(|n| n.to_ascii_lowercase()).collect();
    let stem_ok = |stem: &str| {
        let stem_norm = stem.replace('_', "");
        symbols_lower.iter().any(|sym| {
            let sym = sym.as_str();
            stem == sym
                || stem_norm == sym
                || (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym.starts_with(stem))
                || (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
                    && sym.starts_with(stem_norm.as_str()))
        })
    };
    scan_non_candidates(boosted, &names, boost_unit, all_chunks, &stem_ok);
}
/// Counts how many `keywords` match at least one entry of `parts`.
///
/// A keyword matches if `parts` contains it exactly. It also matches if, for
/// some part, the shorter of the two strings is at least 3 bytes long and is
/// a prefix of the longer one (so `parse` matches `parser`, and vice versa).
/// Each keyword is counted at most once.
fn count_keyword_matches(keywords: &HashSet<String>, parts: &HashSet<String>) -> usize {
    /// Prefix relation between two words, ignoring pairs whose shorter side
    /// is under three bytes.
    fn prefix_related(a: &str, b: &str) -> bool {
        let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        short.len() >= 3 && long.starts_with(short)
    }

    keywords
        .iter()
        .filter(|keyword| {
            parts.contains(*keyword) || parts.iter().any(|part| prefix_related(keyword, part))
        })
        .count()
}
/// Boosts candidates whose file path shares words with a natural-language query.
///
/// Keywords are the query's identifier-like words, lowercased. Only words
/// longer than two bytes that are not in [`STOPWORDS`] are kept.
///
/// For each chunk already in `boosted`, the file stem and the immediate
/// parent directory name are split with `split_identifier` and compared
/// against the keywords via [`count_keyword_matches`]. A candidate whose
/// matched fraction is at least 10% gains
/// `max_score * STEM_BOOST_MULTIPLIER * fraction`. Chunks not already in
/// `boosted` are never added.
///
/// NOTE(review): keywords are lowercased here, which presumably relies on
/// `split_identifier` returning lowercase parts — confirm.
fn boost_stem_matches(
    boosted: &mut HashMap<usize, f32>,
    query: &str,
    max_score: f32,
    chunks: &[CodeChunk],
) {
    static KEYWORD_RE: OnceLock<Regex> = OnceLock::new();
    let keyword_re =
        KEYWORD_RE.get_or_init(|| Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").expect("keyword regex"));
    let keywords: HashSet<String> = keyword_re
        .find_iter(query)
        .map(|m| m.as_str().to_ascii_lowercase())
        .filter(|w| w.len() > 2 && !STOPWORDS.contains(&w.as_str()))
        .collect();
    if keywords.is_empty() {
        return;
    }

    let boost = max_score * STEM_BOOST_MULTIPLIER;
    // Path -> identifier parts of its stem and parent directory. Chunks from
    // the same file share one entry. Keys borrow from `chunks`, and the set is
    // borrowed below rather than cloned per candidate.
    let mut path_cache: HashMap<&str, HashSet<String>> = HashMap::new();
    let candidate_indices: Vec<usize> = boosted.keys().copied().collect();
    for idx in candidate_indices {
        let path = chunks[idx].file_path.as_str();
        let parts = path_cache.entry(path).or_insert_with(|| {
            let mut words: HashSet<String> = HashSet::new();
            let p = Path::new(path);
            if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
                words.extend(split_identifier(stem));
            }
            if let Some(parent_name) = p
                .parent()
                .and_then(Path::file_name)
                .and_then(|s| s.to_str())
                && !parent_name.is_empty()
                && parent_name != "."
                && parent_name != ".."
            {
                words.extend(split_identifier(parent_name));
            }
            words
        });
        let n_matches = count_keyword_matches(&keywords, parts);
        if n_matches > 0 {
            let match_ratio = n_matches as f32 / keywords.len() as f32;
            if match_ratio >= 0.10 {
                *boosted.entry(idx).or_insert(0.0) += boost * match_ratio;
            }
        }
    }
}
/// Applies query-aware boosts on top of `combined_scores` and returns the
/// boosted copy; the input map is not modified.
///
/// - Symbol-like queries (see [`is_symbol_query`]) boost chunks that define
///   the queried symbol.
/// - Other queries get path-keyword boosts and boosts for camelCase/PascalCase
///   identifiers embedded in the text.
///
/// Every boost is proportional to the largest input score. Some boosts may
/// add entries for chunks that were not in `combined_scores`.
///
/// If the largest input score is not a positive, finite number, the scores
/// are returned unchanged. A zero maximum would add nothing anyway. A
/// negative maximum would turn the path-keyword "boost" into a penalty, and a
/// non-finite one would push boosted scores to infinity. This mirrors the
/// guard in [`boost_multi_chunk_files`].
///
/// Every index in `combined_scores` must be a valid index into `all_chunks`.
#[expect(
    clippy::implicit_hasher,
    reason = "internal API; callers in the semble pipeline use the default RandomState"
)]
#[must_use]
pub fn apply_query_boost(
    combined_scores: &HashMap<usize, f32>,
    query: &str,
    all_chunks: &[CodeChunk],
) -> HashMap<usize, f32> {
    if combined_scores.is_empty() {
        return HashMap::new();
    }
    let max_score = combined_scores
        .values()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    let mut boosted = combined_scores.clone();
    if !max_score.is_finite() || max_score <= 0.0 {
        return boosted;
    }
    if is_symbol_query(query) {
        boost_symbol_definitions(&mut boosted, query, max_score, all_chunks);
    } else {
        boost_stem_matches(&mut boosted, query, max_score, all_chunks);
        boost_embedded_symbols(&mut boosted, query, max_score, all_chunks);
    }
    boosted
}
/// Rewards files that contribute several scored chunks.
///
/// For every file with at least one entry in `scores`:
/// 1. The file's summed score is divided by the largest per-file sum.
/// 2. That ratio is multiplied by `max_score * FILE_COHERENCE_BOOST_FRAC`.
/// 3. The result is added to the file's single highest-scoring chunk.
///
/// When several chunks of a file tie for the highest score, the lowest chunk
/// index receives the boost, so the outcome does not depend on `HashMap`
/// iteration order.
///
/// Does nothing when `scores` is empty, when the maximum score is zero or not
/// finite, or when the largest per-file sum is not positive and finite.
/// Every key in `scores` must be a valid index into `chunks`.
#[expect(
    clippy::implicit_hasher,
    reason = "internal API; callers in the semble pipeline use the default RandomState"
)]
pub fn boost_multi_chunk_files(scores: &mut HashMap<usize, f32>, chunks: &[CodeChunk]) {
    if scores.is_empty() {
        return;
    }
    let max_score = scores.values().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_score == 0.0 || !max_score.is_finite() {
        return;
    }

    // Keys borrow the path from `chunks`; no per-entry String clones.
    let mut file_sum: HashMap<&str, f32> = HashMap::new();
    // Per file: (index, score) of its strongest chunk.
    let mut best_chunk: HashMap<&str, (usize, f32)> = HashMap::new();
    for (&idx, &score) in scores.iter() {
        let path = chunks[idx].file_path.as_str();
        *file_sum.entry(path).or_insert(0.0) += score;
        best_chunk
            .entry(path)
            .and_modify(|best| {
                // Highest score wins; ties go to the lowest index so the
                // choice is deterministic.
                if score > best.1 || (score == best.1 && idx < best.0) {
                    *best = (idx, score);
                }
            })
            .or_insert((idx, score));
    }

    let max_file_sum = file_sum.values().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_file_sum <= 0.0 || !max_file_sum.is_finite() {
        return;
    }

    let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
    for (path, &(idx, _)) in &best_chunk {
        let contribution = boost_unit * file_sum[path] / max_file_sum;
        *scores.entry(idx).or_insert(0.0) += contribution;
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Builds a minimal one-line chunk whose raw and enriched content are identical.
fn chunk(path: &str, content: &str) -> CodeChunk {
CodeChunk {
file_path: path.to_string(),
name: String::new(),
kind: String::new(),
start_line: 1,
end_line: 1,
content: content.to_string(),
enriched_content: content.to_string(),
}
}
// Names joined by `::`, `.`, `->` or `\` are symbol queries.
#[test]
fn is_symbol_query_namespace() {
assert!(is_symbol_query("Sinatra::Base"));
assert!(is_symbol_query("module.Class"));
assert!(is_symbol_query("a->b->c"));
assert!(is_symbol_query(r"Foo\Bar"));
}
// Capitalised single words, including ones with all-caps prefixes, are symbols.
#[test]
fn is_symbol_query_pascal() {
assert!(is_symbol_query("Client"));
assert!(is_symbol_query("HTTPHandler"));
assert!(is_symbol_query("XMLParser"));
}
// Plain lowercase words are treated as natural language.
#[test]
fn is_symbol_query_plain_word_rejected() {
assert!(!is_symbol_query("session"));
assert!(!is_symbol_query("retry"));
assert!(!is_symbol_query("authentication"));
}
// A leading underscore makes an otherwise lowercase word a symbol.
#[test]
fn is_symbol_query_leading_underscore_accepted() {
assert!(is_symbol_query("_private"));
assert!(is_symbol_query("__init__"));
}
// Without an explicit alpha, symbol queries get ALPHA_SYMBOL.
#[test]
fn resolve_alpha_symbol_0_3() {
assert!((resolve_alpha("Client", None) - ALPHA_SYMBOL).abs() < 1e-6);
assert!((resolve_alpha("foo.Bar", None) - ALPHA_SYMBOL).abs() < 1e-6);
}
// Without an explicit alpha, natural-language queries get ALPHA_NL.
#[test]
fn resolve_alpha_nl_0_5() {
assert!((resolve_alpha("how does retry work", None) - ALPHA_NL).abs() < 1e-6);
assert!((resolve_alpha("authentication handling", None) - ALPHA_NL).abs() < 1e-6);
}
// An explicit alpha overrides the query-based default.
#[test]
fn resolve_alpha_explicit_override_wins() {
assert!((resolve_alpha("Client", Some(0.7)) - 0.7).abs() < 1e-6);
}
// `class NAME:` counts as a definition.
#[test]
fn chunk_defines_symbol_class() {
let content = "class Client:\n pass";
assert!(chunk_defines_symbol(content, "Client"));
}
// `def NAME(` counts as a definition.
#[test]
fn chunk_defines_symbol_def() {
let content = "def handle_request():\n pass";
assert!(chunk_defines_symbol(content, "handle_request"));
}
// A `.`-qualified prefix (`Phoenix.`) may sit between the keyword and the name.
#[test]
fn chunk_defines_symbol_namespace_qualified() {
let content = " defmodule Phoenix.Router do\n";
assert!(chunk_defines_symbol(content, "Router"));
}
// SQL definition keywords match in lowercase and uppercase.
#[test]
fn chunk_defines_symbol_sql_case_insensitive() {
assert!(chunk_defines_symbol(
" create table users (id int)",
"users"
));
assert!(chunk_defines_symbol(
" CREATE TABLE Users (id int)",
"Users"
));
}
// A plain usage with no definition keyword is not a definition.
#[test]
fn chunk_defines_symbol_negative() {
assert!(!chunk_defines_symbol("client.do_thing()", "Client"));
}
// With max_score 1.0 the unit is 1.0 * DEFINITION_BOOST_MULTIPLIER = 3.0.
// The chunk whose stem matches the symbol gets 1.0 + 3.0 * 1.5 = 5.5; the
// other defining chunk gets 1.0 + 3.0 = 4.0.
#[test]
fn boost_symbol_definitions_stem_multiplier() {
let chunks = vec![
chunk("src/client.rs", "struct Client { /* ... */ }"),
chunk("src/unrelated.rs", "struct Client { /* ... */ }"),
];
let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 1.0)]);
boost_symbol_definitions(&mut boosted, "Client", 1.0, &chunks);
assert!((boosted[&0] - 5.5).abs() < 1e-6);
assert!((boosted[&1] - 4.0).abs() < 1e-6);
}
// The keyword `parse` should prefix-match the `parser` stem part (assumes
// split_identifier("parser") yields "parser"), raising the candidate's score.
#[test]
fn boost_stem_matches_prefix() {
let chunks = vec![chunk("src/parser.rs", "fn run() {}")];
let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
boost_stem_matches(&mut boosted, "parse json structure", 1.0, &chunks);
assert!(boosted[&0] > 1.0, "expected stem-match boost on parser.rs");
}
// Unit = 1.0 * 3.0 * EMBEDDED_SYMBOL_BOOST_SCALE (0.5) = 1.5. The stem
// `myclass` matches `MyClass`, so the gain is 1.5 * 1.5 = 2.25 on top of 1.0.
#[test]
fn boost_embedded_symbols_half_strength() {
let chunks = vec![chunk("src/myclass.rs", "struct MyClass {}")];
let mut boosted: HashMap<usize, f32> = HashMap::from([(0, 1.0_f32)]);
boost_embedded_symbols(&mut boosted, "how does MyClass handle errors", 1.0, &chunks);
assert!((boosted[&0] - 3.25).abs() < 1e-6, "got {}", boosted[&0]);
}
// foo.rs sums to 1.5 (the largest file sum) and bar.rs to 1.0; the unit is
// 1.0 * 0.2. foo.rs's best chunk gains 0.2, bar.rs's only chunk gains
// 0.2 * (1.0 / 1.5), and the weaker foo.rs chunk is untouched. The test shares
// its name with the function under test, hence the explicit `super::` path.
#[test]
fn boost_multi_chunk_files() {
let chunks = vec![
chunk("src/foo.rs", ""),
chunk("src/foo.rs", ""),
chunk("src/bar.rs", ""),
];
let mut scores: HashMap<usize, f32> = HashMap::from([(0, 1.0), (1, 0.5), (2, 1.0)]);
super::boost_multi_chunk_files(&mut scores, &chunks);
assert!((scores[&0] - 1.2).abs() < 1e-6, "got {}", scores[&0]);
assert!((scores[&1] - 0.5).abs() < 1e-6, "non-best chunk unchanged");
let expected_bar = 1.0 + 0.2 * (1.0 / 1.5);
assert!(
(scores[&2] - expected_bar).abs() < 1e-6,
"got {}, expected {}",
scores[&2],
expected_bar
);
}
// Fixed lists of symbol and natural-language queries; presumably kept in sync
// with the Python implementation the test name refers to.
#[test]
fn property_symbol_regex_parity_python() {
let symbols = &[
"Client",
"handle_request",
"_private",
"getX",
"XMLParser",
"foo::bar",
"foo.bar.baz",
"a->b",
r"Foo\Bar",
"__init__",
"snake_case",
];
for q in symbols {
assert!(is_symbol_query(q), "expected symbol query: {q:?}");
}
let non_symbols = &[
"session",
"retry",
"authentication",
"how does retry work",
"user authentication flow",
"hi",
];
for q in non_symbols {
assert!(!is_symbol_query(q), "expected NL: {q:?}");
}
}
}