use ahash::{AHashMap, AHashSet};
use regex::Regex;
use std::path::Path;
use std::sync::LazyLock;
use crate::tokenizer::split_identifier;
use crate::types::Chunk;
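// A whole query shaped like a single code identifier: a qualified path
// (foo::bar, Foo.bar, obj->method), a leading-underscore name,
// camelCase / snake_case / SCREAMING_SNAKE, or a capitalised word.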
static SYMBOL_QUERY_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(
r"(?x)
^(?:
[A-Za-z_][A-Za-z0-9_]* (?: (?:::|\\|->|\.) [A-Za-z_][A-Za-z0-9_]* )+ # qualified: foo::bar, Foo.bar
| _[A-Za-z0-9_]* # leading underscore
| [A-Za-z][A-Za-z0-9]* [A-Z_] [A-Za-z0-9_]* # camelCase / snake_case / SCREAMING_SNAKE
| [A-Z][A-Za-z0-9]* # leading uppercase: Foo, Manifest
)$",
)
.unwrap()
});
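// Identifiers embedded in natural-language queries. Only names with at
// least two case humps (FooBar, fooBar) qualify; a lone capitalised word
// is too common in prose to treat as a symbol.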
static EMBEDDED_SYMBOL_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(
r"(?x)
\b(?:
[A-Z][a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]* # PascalCase with a second hump: FooBar (a lone Foo is excluded)
| [a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]+ # camelCase: fooBar
)\b",
)
.unwrap()
});
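// Unicode-aware identifier tokens: a letter or underscore, then letters,
// digits, or underscores.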
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[\p{L}_][\p{L}\p{N}_]*").unwrap());
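// Minimum file-stem length before a prefix match against an embedded
// symbol is trusted.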
const EMBEDDED_STEM_MIN_LEN: usize = 4;
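// Embedded symbols receive half the boost of an explicit symbol query.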
const EMBEDDED_SYMBOL_BOOST_SCALE: f64 = 0.5;
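// Keywords that introduce a definition across common languages.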
const DEFINITION_KEYWORDS: &[&str] = &[
"class",
"module",
"defmodule",
"def",
"interface",
"struct",
"enum",
"trait",
"type",
"func",
"function",
"object",
"abstract class",
"data class",
"fn",
"fun",
"package",
"namespace",
"protocol",
"record",
"typedef",
"const",
"static",
];
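// SQL DDL forms, matched case-insensitively.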
const SQL_DEFINITION_KEYWORDS: &[&str] = &[
"CREATE TABLE",
"CREATE VIEW",
"CREATE PROCEDURE",
"CREATE FUNCTION",
];
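// Boost sizes as multiples of the current maximum score: definition hits
// dominate, stem matches top out at parity with the best hit, and file
// coherence adds at most a small fraction.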
const DEFINITION_BOOST_MULTIPLIER: f64 = 3.0;
const STEM_BOOST_MULTIPLIER: f64 = 1.0;
const FILE_COHERENCE_BOOST_FRAC: f64 = 0.2;
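// English function words dropped when extracting query keywords.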
static STOPWORDS: LazyLock<AHashSet<&'static str>> = LazyLock::new(|| {
"a an and are as at be by do does for from has have how if in is it not of on or the to was \
what when where which who why with"
.split_whitespace()
.collect()
});
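/// Returns true when the trimmed query is shaped like a single code symbol.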
pub fn is_symbol_query(query: &str) -> bool {
SYMBOL_QUERY_RE.is_match(query.trim())
}
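/// Resolves the score-blending weight. An explicit `alpha` always wins;
/// otherwise symbol-shaped queries get 0.3 and natural-language queries
/// get 0.5, e.g. `resolve_alpha("Manifest", None) == 0.3`.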
pub fn resolve_alpha(query: &str, alpha: Option<f64>) -> f64 {
    match alpha {
        Some(a) => a,
        None => {
            if is_symbol_query(query) {
                0.3
            } else {
                0.5
            }
        }
    }
}
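/// Applies query-dependent boosts in place. Symbol queries boost chunks
/// that define the symbol; natural-language queries boost path-stem
/// matches and definitions of embedded camelCase identifiers. A no-op
/// when no score is positive.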
pub fn apply_query_boost(scores: &mut [f64], query: &str, chunks: &[Chunk]) {
if scores.is_empty() {
return;
}
let max_score = current_max(scores);
if max_score <= 0.0 || max_score.is_nan() {
return;
}
if is_symbol_query(query) {
boost_symbol_definitions(scores, query, max_score, chunks);
} else {
boost_stem_matches(scores, query, max_score, chunks);
boost_embedded_symbols(scores, query, max_score, chunks);
}
}
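/// Rewards files with several scoring chunks: each file's best chunk
/// gains a bonus proportional to the file's score sum, scaled so the
/// strongest file receives `FILE_COHERENCE_BOOST_FRAC` of the max score.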
pub fn boost_multi_chunk_files(scores: &mut [f64], chunks: &[Chunk]) {
let max_score = current_max(scores);
if max_score <= 0.0 || max_score.is_nan() {
return;
}
let mut file_sum: AHashMap<&str, f64> = AHashMap::new();
let mut best_idx: AHashMap<&str, (usize, f64)> = AHashMap::new();
for (i, &score) in scores.iter().enumerate() {
if score <= 0.0 || score.is_nan() {
continue;
}
let fp = chunks[i].file_path.as_str();
*file_sum.entry(fp).or_insert(0.0) += score;
let entry = best_idx.entry(fp).or_insert((i, f64::NEG_INFINITY));
if score > entry.1 {
*entry = (i, score);
}
}
let max_file_sum = file_sum.values().copied().fold(f64::NEG_INFINITY, f64::max);
if max_file_sum <= 0.0 || max_file_sum.is_nan() {
return;
}
let boost_unit = max_score * FILE_COHERENCE_BOOST_FRAC;
for (fp, (idx, _)) in &best_idx {
let sum = *file_sum.get(*fp).unwrap_or(&0.0);
scores[*idx] += boost_unit * sum / max_file_sum;
}
}
fn current_max(scores: &[f64]) -> f64 {
scores.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}
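// Last segment of a qualified symbol: "foo::Bar" -> "Bar".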
fn extract_symbol_name(query: &str) -> String {
let query = query.trim();
for separator in ["::", "\\", "->", "."] {
if query.contains(separator) {
return query.rsplit(separator).next().unwrap_or(query).to_string();
}
}
query.to_string()
}
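// Builds the (general, SQL) definition regexes for one symbol:
// "<keyword> [Ns::|Ns.]<symbol>" followed by whitespace, a delimiter,
// or end of input. The SQL variant is case-insensitive.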
fn definition_pattern(symbol_name: &str) -> (Regex, Regex) {
let escaped = regex::escape(symbol_name);
let ns_prefix = r"(?:[A-Za-z_][A-Za-z0-9_]*(?:\.|::))*";
let suffix = format!(r"\s+{ns_prefix}{escaped}(?:\s|[<({{\[:;]|$)");
let kw_body: String = DEFINITION_KEYWORDS
.iter()
.map(|k| regex::escape(k))
.collect::<Vec<_>>()
.join("|");
let sql_body: String = SQL_DEFINITION_KEYWORDS
.iter()
.map(|k| regex::escape(k))
.collect::<Vec<_>>()
.join("|");
let general = Regex::new(&format!(r"\b(?:{kw_body}){suffix}")).unwrap();
let sql = Regex::new(&format!(r"(?i)\b(?:{sql_body}){suffix}")).unwrap();
(general, sql)
}
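// Compiled definition patterns for a set of candidate symbol names.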
struct DefinitionMatchers {
patterns: Vec<(Regex, Regex)>,
names: Vec<String>,
}
impl DefinitionMatchers {
fn for_names<I: IntoIterator<Item = String>>(names: I) -> Self {
let names: Vec<String> = names.into_iter().collect();
let patterns = names.iter().map(|n| definition_pattern(n)).collect();
Self { patterns, names }
}
fn defines_any(&self, content: &str) -> bool {
self.patterns
.iter()
.any(|(g, s)| g.is_match(content) || s.is_match(content))
}
}
// Loose stem/name comparison: exact, underscore-insensitive, or with a
// trailing plural 's' stripped ("chunks" matches "chunk").
fn stem_matches(stem: &str, name: &str) -> bool {
    let stem_norm = stem.replace('_', "");
    stem == name
        || stem_norm == name
        || stem.strip_suffix('s').unwrap_or(stem) == name
        || stem_norm.strip_suffix('s').unwrap_or(&stem_norm) == name
}
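// Grants the full boost when the chunk defines one of the names, and
// 1.5x when the file stem also matches a name, on the theory that a
// definition usually lives in a file named after it.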
fn definition_tier(chunk: &Chunk, matchers: &DefinitionMatchers, boost_unit: f64) -> f64 {
if !matchers.defines_any(&chunk.content) {
return 0.0;
}
let stem = file_stem_lower(&chunk.file_path);
if matchers
.names
.iter()
.any(|n| stem_matches(&stem, &n.to_lowercase()))
{
boost_unit * 1.5
} else {
boost_unit
}
}
fn file_stem_lower(file_path: &str) -> String {
Path::new(file_path)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("")
.to_lowercase()
}
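// A cheap substring test gates the regex work. Boosts apply even to
// zero-score chunks so a buried definition can surface in the results.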
fn boost_symbol_definitions(scores: &mut [f64], query: &str, max_score: f64, chunks: &[Chunk]) {
let symbol_name = extract_symbol_name(query);
let trimmed = query.trim().to_string();
let mut names = vec![symbol_name.clone()];
if symbol_name != trimmed {
names.push(trimmed);
}
let matchers = DefinitionMatchers::for_names(names.clone());
let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER;
for (i, chunk) in chunks.iter().enumerate() {
if !names.iter().any(|n| chunk.content.contains(n)) {
continue;
}
let tier = definition_tier(chunk, &matchers, boost_unit);
if tier > 0.0 {
scores[i] += tier;
}
}
}
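// Chunks already in the pool (positive score) are tested directly;
// zero-score chunks are only considered when the file stem resembles
// one of the embedded symbols.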
fn boost_embedded_symbols(scores: &mut [f64], query: &str, max_score: f64, chunks: &[Chunk]) {
let names: Vec<String> = EMBEDDED_SYMBOL_RE
.find_iter(query)
.map(|m| m.as_str().to_string())
.collect();
if names.is_empty() {
return;
}
let matchers = DefinitionMatchers::for_names(names.clone());
let boost_unit = max_score * DEFINITION_BOOST_MULTIPLIER * EMBEDDED_SYMBOL_BOOST_SCALE;
let symbols_lower: Vec<String> = names.iter().map(|s| s.to_lowercase()).collect();
for (i, chunk) in chunks.iter().enumerate() {
let in_pool = scores[i] > 0.0;
if in_pool {
let tier = definition_tier(chunk, &matchers, boost_unit);
if tier > 0.0 {
scores[i] += tier;
}
} else {
let stem = file_stem_lower(&chunk.file_path);
let stem_norm = stem.replace('_', "");
let matches = symbols_lower.iter().any(|sym_lower| {
stem == *sym_lower
|| stem_norm == *sym_lower
|| (stem.len() >= EMBEDDED_STEM_MIN_LEN && sym_lower.starts_with(&stem))
|| (stem_norm.len() >= EMBEDDED_STEM_MIN_LEN
&& sym_lower.starts_with(&stem_norm))
});
if !matches {
continue;
}
let tier = definition_tier(chunk, &matchers, boost_unit);
if tier > 0.0 {
scores[i] += tier;
}
}
}
}
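// Boosts pooled chunks whose file and parent-directory name parts
// overlap the query keywords, in proportion to the fraction of keywords
// matched.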
fn boost_stem_matches(scores: &mut [f64], query: &str, max_score: f64, chunks: &[Chunk]) {
let keywords: AHashSet<String> = TOKEN_RE
.find_iter(query)
.map(|m| m.as_str().to_lowercase())
.filter(|w| w.chars().count() > 2 && !STOPWORDS.contains(w.as_str()))
.collect();
if keywords.is_empty() {
return;
}
let boost = max_score * STEM_BOOST_MULTIPLIER;
let mut path_cache: AHashMap<&str, AHashSet<String>> = AHashMap::new();
for (i, chunk) in chunks.iter().enumerate() {
if scores[i] <= 0.0 || scores[i].is_nan() {
continue;
}
let parts = path_cache
.entry(chunk.file_path.as_str())
.or_insert_with(|| build_path_parts(&chunk.file_path));
let n_matches = count_keyword_matches(&keywords, parts);
if n_matches > 0 {
let match_ratio = n_matches as f64 / keywords.len() as f64;
if match_ratio >= 0.10 {
scores[i] += boost * match_ratio;
}
}
}
}
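// Name parts drawn from the file stem and its immediate parent
// directory, split into identifier sub-words.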
fn build_path_parts(file_path: &str) -> AHashSet<String> {
let path = Path::new(file_path);
let mut parts: AHashSet<String> = AHashSet::new();
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
parts.extend(split_identifier(stem));
}
if let Some(parent_name) = path
.parent()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
&& ![".", "/", ".."].contains(&parent_name)
{
parts.extend(split_identifier(parent_name));
}
parts
}
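// Counts exact part matches first; leftover keywords still count when
// they share a prefix of at least three characters with some part, in
// either direction.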
fn count_keyword_matches(keywords: &AHashSet<String>, parts: &AHashSet<String>) -> usize {
let mut n_matches = 0usize;
let mut residual: Vec<&String> = Vec::with_capacity(keywords.len());
for kw in keywords {
if parts.contains(kw) {
n_matches += 1;
} else {
residual.push(kw);
}
}
if residual.is_empty() {
return n_matches;
}
for keyword in residual {
for part in parts {
let (shorter, longer) = if keyword.len() <= part.len() {
(keyword.as_str(), part.as_str())
} else {
(part.as_str(), keyword.as_str())
};
if shorter.len() >= 3 && longer.starts_with(shorter) {
n_matches += 1;
break;
}
}
}
n_matches
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn symbol_queries_recognised() {
for q in [
"Manifest",
"TopK",
"CamelCase",
"parse_config",
"PREVIEW_FILE_CACHE",
"_private_thing",
"foo::bar",
"module::Type",
"obj->method",
"Foo.bar",
] {
assert!(is_symbol_query(q), "expected symbol query: {q:?}");
}
}
#[test]
fn natural_language_not_a_symbol_query() {
for q in [
"parse the config file",
"how does auth work",
"rate limiting middleware",
"fn parse_config", ] {
assert!(!is_symbol_query(q), "did not expect symbol query: {q:?}");
}
}
#[test]
fn embedded_camelcase_is_extracted() {
let hits: Vec<&str> = EMBEDDED_SYMBOL_RE
.find_iter("how does FooBar interact with bazQux today")
.map(|m| m.as_str())
.collect();
assert!(hits.contains(&"FooBar"), "FooBar not found in {hits:?}");
assert!(hits.contains(&"bazQux"), "bazQux not found in {hits:?}");
}
#[test]
fn symbol_def_boost_lifts_a_buried_definition() {
use crate::types::Chunk;
let chunks = vec![
Chunk {
content: "const TOP_K: usize = 50;".into(),
file_path: "src/app.rs".into(),
start_line: 1,
end_line: 1,
language: Some("rust".into()),
},
Chunk {
content: "fn search(top_k: usize) { call(top_k); }".into(),
file_path: "src/search.rs".into(),
start_line: 1,
end_line: 1,
language: Some("rust".into()),
},
];
let mut scores = vec![0.0, 1.0];
apply_query_boost(&mut scores, "TOP_K", &chunks);
assert!(
scores[0] > scores[1],
"expected the const definition to outrank a reference chunk: scores={scores:?}"
);
}
#[test]
fn definition_pattern_recognises_const_and_static() {
let (general, _) = definition_pattern("PREVIEW_FILE_CACHE");
assert!(general.is_match("const PREVIEW_FILE_CACHE: usize = 8;"));
assert!(general.is_match("static PREVIEW_FILE_CACHE: usize = 8;"));
}
#[test]
fn definition_pattern_compiles_and_matches() {
    let (general, _sql) = definition_pattern("Manifest");
    assert!(general.is_match("pub struct Manifest {"));
    assert!(general.is_match(" fn Manifest() {}"));
    assert!(!general.is_match("classManifest {"));
    let (_, sql_lower) = definition_pattern("users");
    assert!(sql_lower.is_match("create table users ("));
}
}