use regex::Regex;
use std::collections::HashMap;
use std::path::Path;
use std::sync::OnceLock;
/// Reads a floating-point tuning knob from the environment variable `name`.
///
/// Surrounding whitespace in the value is ignored. `default` is returned when
/// the variable is unset, is not valid Unicode, does not parse as an `f32`,
/// or parses to a non-finite value (`NaN`, `inf`), since a non-finite factor
/// would propagate into every score it is combined with.
fn env_f32(name: &str, default: f32) -> f32 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<f32>().ok())
        .filter(|value| value.is_finite())
        .unwrap_or(default)
}
/// Score multiplier used by `file_path_penalty` for test, compat and example paths.
///
/// Taken from `COLGREP_STRONG_PENALTY` via `env_f32`, falling back to `0.30`.
fn strong_penalty() -> f32 {
    const FALLBACK: f32 = 0.30;
    env_f32("COLGREP_STRONG_PENALTY", FALLBACK)
}
/// Score multiplier used by `file_path_penalty` for `__init__.py` and `package-info.java`.
///
/// Taken from `COLGREP_MODERATE_PENALTY` via `env_f32`, falling back to `0.50`.
fn moderate_penalty() -> f32 {
    const FALLBACK: f32 = 0.50;
    env_f32("COLGREP_MODERATE_PENALTY", FALLBACK)
}
/// Score multiplier used by `file_path_penalty` for `.d.ts` declaration files.
///
/// Taken from `COLGREP_MILD_PENALTY` via `env_f32`, falling back to `0.70`.
fn mild_penalty() -> f32 {
    const FALLBACK: f32 = 0.70;
    env_f32("COLGREP_MILD_PENALTY", FALLBACK)
}
/// Returns the lazily compiled matcher for file names that follow a
/// per-language test naming convention.
///
/// The pattern is written in verbose (`(?x)`) mode: whitespace is ignored and
/// everything after `#` on a line is a regex comment. It is anchored so that
/// one alternative must cover the whole final path component: the match starts
/// at the beginning of the string or right after a `/`, no alternative can
/// cross a `/`, and the match must end at the end of the string. Only `/` is
/// treated as a separator.
///
/// The regex is compiled on first use and cached in a `OnceLock`.
///
/// # Panics
///
/// Panics with the message `test_file_re` if the pattern fails to compile.
fn test_file_re() -> &'static Regex {
static R: OnceLock<Regex> = OnceLock::new();
R.get_or_init(|| {
Regex::new(
r"(?x)
(?:^|/)(?:
test_[^/]*\.py # Python: test_foo.py
| [^/]*_test\.py # Python: foo_test.py
| [^/]*_test\.go # Go
| [^/]*Tests?\.java # Java: FooTest/FooTests.java
| [^/]*Test\.php # PHP: FooTest.php
| [^/]*_spec\.rb # Ruby (RSpec)
| [^/]*_test\.rb # Ruby
| [^/]*\.test\.[jt]sx? # JS/TS: foo.test.js/ts/jsx/tsx
| [^/]*\.spec\.[jt]sx? # JS/TS: foo.spec.*
| [^/]*Tests?\.kt # Kotlin
| [^/]*Spec\.kt # Kotlin (Kotest)
| [^/]*Tests?\.swift # Swift (XCTest)
| [^/]*Spec\.swift # Swift (Quick)
| [^/]*Tests?\.cs # C#
| test_[^/]*\.(?:cpp|cc|cxx) # C++ (Google Test)
| [^/]*_test\.(?:cpp|cc|cxx) # C++
| test_[^/]*\.c # C
| [^/]*_test\.c # C
| [^/]*Spec\.scala # Scala (ScalaTest)
| [^/]*Suite\.scala # Scala (MUnit)
| [^/]*Test\.scala # Scala
| [^/]*_test\.dart # Dart
| test_[^/]*\.dart # Dart
| [^/]*_spec\.lua # Lua (busted)
| [^/]*_test\.lua # Lua
| test_[^/]*\.lua # Lua (luaunit)
| [^/]*_test\.rs # Rust
| tests\.rs # Rust (top-level integration test module)
| [^/]*_test\.exs # Elixir (ExUnit)
| [^/]*Spec\.hs # Haskell (HSpec)
| [^/]*Test\.hs # Haskell (Tasty/HUnit)
| test_[^/]*\.ml # OCaml (Alcotest)
| [^/]*_test\.ml # OCaml
| test[-_][^/]*\.[rR] # R (testthat: test-foo.R / test_foo.R)
| [^/]*_test\.zig # Zig
| test_[^/]*\.zig # Zig
| runtests\.jl # Julia (Pkg convention)
| test_[^/]*\.jl # Julia
| [^/]*_test\.jl # Julia
| [^/]*\.test\.vue # Vue
| [^/]*\.spec\.vue # Vue
| [^/]*\.test\.svelte # Svelte
| [^/]*\.spec\.svelte # Svelte
| tst_[^/]*\.qml # QML (Qt Test)
| [^/]*\.bats # Bash (bats-core)
| test_[^/]*\.(?:sh|bash|zsh) # Shell
| [^/]*_test\.(?:sh|bash|zsh) # Shell
| [^/]*\.Tests\.ps1 # PowerShell (Pester)
| test_helpers?[^/]*\.\w+ # cross-language test helpers
)$
",
)
// The pattern is a compile-time constant, so a failure here is a programming error.
.expect("test_file_re")
})
}
/// Returns the lazily compiled matcher for paths that pass through a test directory.
///
/// A path matches when one of its `/`-separated components is exactly
/// `test`, `tests`, `__tests__`, `spec` or `testing`.
fn test_dir_re() -> &'static Regex {
    const PATTERN: &str = r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        let compiled = Regex::new(PATTERN);
        compiled.unwrap()
    })
}
/// Returns the lazily compiled matcher for paths that pass through a compatibility directory.
///
/// A path matches when one of its `/`-separated components is exactly
/// `compat`, `_compat` or `legacy`.
fn compat_dir_re() -> &'static Regex {
    const PATTERN: &str = r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        let compiled = Regex::new(PATTERN);
        compiled.unwrap()
    })
}
/// Returns the lazily compiled matcher for paths that pass through an examples directory.
///
/// A path matches when one of its `/`-separated components is exactly
/// `example` or `examples` (optionally prefixed with `_`), `doc_src` or `docs_src`.
fn examples_dir_re() -> &'static Regex {
    const PATTERN: &str = r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        let compiled = Regex::new(PATTERN);
        compiled.unwrap()
    })
}
pub fn file_path_penalty(file: &str) -> f32 {
let normalised = file.replace('\\', "/");
let mut penalty: f32 = 1.0;
if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
penalty *= strong_penalty();
}
if compat_dir_re().is_match(&normalised) {
penalty *= strong_penalty();
}
if examples_dir_re().is_match(&normalised) {
penalty *= strong_penalty();
}
if normalised.ends_with(".d.ts") {
penalty *= mild_penalty();
}
let name = Path::new(&normalised)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("");
if matches!(name, "__init__.py" | "package-info.java") {
penalty *= moderate_penalty();
}
penalty
}
/// Decides whether path-based penalties should be applied for `query`.
///
/// Returns `false` when the lowercased query contains `test`, `spec` or
/// `benchmark`, and `true` otherwise. Matching is by substring, so words that
/// merely contain a marker (for example `rspec` or `latest`) also switch the
/// penalty off.
pub fn should_apply_path_penalty(query: &str) -> bool {
    const MARKERS: [&str; 3] = ["test", "spec", "benchmark"];
    let lowered = query.to_lowercase();
    !MARKERS.iter().any(|&marker| lowered.contains(marker))
}
/// Fraction of the top score that `apply_definition_boost` adds to a matching definition.
///
/// Taken from `COLGREP_DEF_BOOST` via `env_f32`, falling back to `0.25`.
fn definition_boost_frac() -> f32 {
    const FALLBACK: f32 = 0.25;
    env_f32("COLGREP_DEF_BOOST", FALLBACK)
}
/// Adds a flat bonus to definition items whose name shares a token with the query.
///
/// The bonus equals the highest current score in `items` multiplied by
/// `definition_boost_frac()`. Nothing is changed when `items` is empty, when
/// the highest score is non-finite or not positive, or when the query yields
/// no tokens.
///
/// * `name` - returns the item's name; items with an empty name are skipped.
/// * `is_definition` - only items for which this returns `true` are considered.
/// * `score` / `set_score` - read and write the item's score.
pub fn apply_definition_boost<T>(
    items: &mut [T],
    query: &str,
    name: impl Fn(&T) -> &str,
    is_definition: impl Fn(&T) -> bool,
    score: impl Fn(&T) -> f32,
    set_score: impl Fn(&mut T, f32),
) {
    if items.is_empty() {
        return;
    }
    let top_score = items
        .iter()
        .map(|item| score(item))
        .fold(f32::NEG_INFINITY, |best, s| best.max(s));
    if !(top_score.is_finite() && top_score > 0.0) {
        return;
    }

    let wanted: std::collections::HashSet<String> =
        next_plaid::text_search::tokenize_identifiers(query)
            .into_iter()
            .collect();
    if wanted.is_empty() {
        return;
    }
    let bonus = top_score * definition_boost_frac();

    // NOTE(review): the name is lowercased before tokenising while the query is
    // tokenised as-is; if `tokenize_identifiers` depends on letter case the two
    // sides may be split differently. Confirm against the tokenizer.
    let name_matches_query = |item: &T| -> bool {
        if !is_definition(item) {
            return false;
        }
        let lowered = name(item).to_lowercase();
        if lowered.is_empty() {
            return false;
        }
        next_plaid::text_search::tokenize_identifiers(&lowered)
            .iter()
            .any(|token| wanted.contains(token))
    };

    for item in items.iter_mut() {
        if name_matches_query(&*item) {
            let boosted = score(&*item) + bonus;
            set_score(item, boosted);
        }
    }
}
/// Fraction of the top score that `apply_path_stem_boost` adds on an exact stem-token match.
///
/// Taken from `COLGREP_STEM_BOOST` via `env_f32`, falling back to `0.40`.
fn path_stem_boost_frac() -> f32 {
    const FALLBACK: f32 = 0.40;
    env_f32("COLGREP_STEM_BOOST", FALLBACK)
}
/// Fraction of the top score that `apply_path_stem_boost` adds on a prefix-only match.
///
/// Taken from `COLGREP_STEM_PREFIX_BOOST` via `env_f32`, falling back to `0.20`.
fn path_stem_prefix_frac() -> f32 {
    const FALLBACK: f32 = 0.20;
    env_f32("COLGREP_STEM_PREFIX_BOOST", FALLBACK)
}
/// Reads a boolean switch from the environment variable `name`.
///
/// Returns `default` when the variable is unset or not valid Unicode.
/// Otherwise the value is trimmed and compared ASCII case-insensitively:
/// `1`, `true` and `yes` mean `true`, and any other value means `false`.
/// Accepting any letter case keeps spellings such as `True` or `YES` from
/// silently turning a flag off.
fn env_flag(name: &str, default: bool) -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];
    match std::env::var(name) {
        Ok(raw) => {
            let value = raw.trim();
            TRUTHY
                .iter()
                .any(|candidate| value.eq_ignore_ascii_case(candidate))
        }
        Err(_) => default,
    }
}
/// Whether `apply_path_stem_boost` drops stopwords from the query before matching.
///
/// Controlled by `COLGREP_STEM_STOPWORDS` via `env_flag`; enabled when unset.
fn stem_stopword_filter() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    env_flag("COLGREP_STEM_STOPWORDS", ENABLED_BY_DEFAULT)
}
/// Whether `apply_path_stem_boost` also compares underscore-stripped and
/// trailing-`s`-stripped variants of each token.
///
/// Controlled by `COLGREP_STEM_PLURAL_SNAKE` via `env_flag`; enabled when unset.
fn stem_plural_snake() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    env_flag("COLGREP_STEM_PLURAL_SNAKE", ENABLED_BY_DEFAULT)
}
/// Lowercase English function words that `apply_path_stem_boost` removes from
/// the query (when stopword filtering is enabled) so they cannot trigger a
/// file-stem match on their own. Entries are grouped by first letter.
const STEM_BOOST_STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at",
    "be", "by",
    "do", "does",
    "for", "from",
    "has", "have", "how",
    "if", "in", "into", "is", "it", "its",
    "of", "on", "or",
    "so",
    "that", "the", "their", "then", "there", "these", "this", "to",
    "was", "were", "what", "when", "where", "which", "who", "why", "with",
];
/// Boosts items whose file stem shares a token with the query.
///
/// The query is tokenised and, when stopword filtering is enabled, tokens
/// listed in `STEM_BOOST_STOPWORDS` are dropped. Each item's file stem is
/// lowercased and tokenised the same way. Then:
///
/// * if any stem token equals any query token, the item gains
///   `max_score * path_stem_boost_frac()`;
/// * otherwise, if one token is a prefix of the other and the shorter one is at
///   least three bytes long, the item gains `max_score * path_stem_prefix_frac()`.
///
/// When plural/snake normalisation is enabled, each token is also compared in
/// its underscore-stripped form and with one trailing `s` removed.
///
/// Nothing is changed when `items` is empty, when the highest score is
/// non-finite or not positive, or when no query tokens remain after filtering.
pub fn apply_path_stem_boost<T>(
    items: &mut [T],
    query: &str,
    file_path: impl Fn(&T) -> &str,
    score: impl Fn(&T) -> f32,
    set_score: impl Fn(&mut T, f32),
) {
    /// True when the shorter string is at least three bytes and prefixes the longer.
    fn shares_prefix(a: &str, b: &str) -> bool {
        let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
        short.len() >= 3 && long.starts_with(short)
    }

    if items.is_empty() {
        return;
    }
    let max_score = items.iter().map(&score).fold(f32::NEG_INFINITY, f32::max);
    if !max_score.is_finite() || max_score <= 0.0 {
        return;
    }

    let drop_stopwords = stem_stopword_filter();
    let query_tokens: std::collections::HashSet<String> =
        next_plaid::text_search::tokenize_identifiers(query)
            .into_iter()
            .filter(|token| {
                !(drop_stopwords
                    && STEM_BOOST_STOPWORDS
                        .iter()
                        .any(|&word| word == token.as_str()))
            })
            .collect();
    if query_tokens.is_empty() {
        return;
    }

    let do_plural_snake = stem_plural_snake();
    let max_boost = max_score * path_stem_boost_frac();
    let max_prefix_boost = max_score * path_stem_prefix_frac();

    // Expands a token into the spellings that should be treated as equivalent.
    let variants = |token: &str| -> Vec<String> {
        let mut out = vec![token.to_string()];
        if do_plural_snake {
            let joined = token.replace('_', "");
            if joined != token {
                out.push(joined);
            }
            if token.len() > 1 {
                if let Some(singular) = token.strip_suffix('s') {
                    out.push(singular.to_string());
                }
            }
        }
        out
    };

    // The query side does not depend on the item, so expand it once up front
    // instead of once per item and stem token.
    let query_variants: Vec<String> = query_tokens
        .iter()
        .flat_map(|token| variants(token.as_str()))
        .collect();

    for item in items.iter_mut() {
        // NOTE(review): the stem is lowercased before tokenising while the query
        // is tokenised as-is; if `tokenize_identifiers` depends on letter case the
        // two sides may be split differently. Confirm against the tokenizer.
        let stem = Path::new(file_path(item))
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_lowercase();
        if stem.is_empty() {
            continue;
        }
        let stem_variants: Vec<String> = next_plaid::text_search::tokenize_identifiers(&stem)
            .iter()
            .flat_map(|token| variants(token.as_str()))
            .collect();

        let exact_hit = stem_variants
            .iter()
            .any(|sv| query_variants.iter().any(|qv| sv == qv));
        // An exact match always wins; the prefix check only matters without one.
        let boost = if exact_hit {
            max_boost
        } else if stem_variants
            .iter()
            .any(|sv| query_variants.iter().any(|qv| shares_prefix(sv, qv)))
        {
            max_prefix_boost
        } else {
            continue;
        };

        let cur = score(item);
        set_score(item, cur + boost);
    }
}
/// Fraction of the top score used as the maximum bonus in `apply_file_coherence_boost`.
///
/// Taken from `COLGREP_COHERENCE_BOOST` via `env_f32`, falling back to `0.20`.
fn file_coherence_boost_frac() -> f32 {
    const FALLBACK: f32 = 0.20;
    env_f32("COLGREP_COHERENCE_BOOST", FALLBACK)
}
/// Boosts the best-scoring item of each file in proportion to the file's total score.
///
/// Items are grouped by `file_path`. For each file the scores are summed and
/// the highest-scoring item is remembered; on a tie the earliest item wins.
/// That item then gains
/// `max_score * file_coherence_boost_frac() * file_sum / max_file_sum`,
/// so the file with the largest sum receives the full bonus and other files a
/// proportional share. Other items are left untouched.
///
/// Nothing is changed when `items` is empty, when the highest score is
/// non-finite or not positive, or when the largest per-file sum is non-finite
/// or not positive.
pub fn apply_file_coherence_boost<T>(
    items: &mut [T],
    file_path: impl Fn(&T) -> &str,
    score: impl Fn(&T) -> f32,
    set_score: impl Fn(&mut T, f32),
) {
    if items.is_empty() {
        return;
    }
    let max_score = items.iter().map(&score).fold(f32::NEG_INFINITY, f32::max);
    if !max_score.is_finite() || max_score <= 0.0 {
        return;
    }

    // The map is keyed by paths borrowed from `items`, which avoids allocating a
    // `String` per item. It lives only inside this block so the shared borrow
    // ends before the scores are written back.
    let updates: Vec<(usize, f32)> = {
        // path -> (sum of scores, index of the best-scoring item)
        let mut per_file: HashMap<&str, (f32, usize)> = HashMap::new();
        for (i, item) in items.iter().enumerate() {
            let s = score(item);
            per_file
                .entry(file_path(item))
                .and_modify(|(sum, top_idx)| {
                    *sum += s;
                    if s > score(&items[*top_idx]) {
                        *top_idx = i;
                    }
                })
                .or_insert((s, i));
        }

        let max_file_sum = per_file
            .values()
            .map(|(sum, _)| *sum)
            .fold(f32::NEG_INFINITY, f32::max);
        if !max_file_sum.is_finite() || max_file_sum <= 0.0 {
            return;
        }

        let boost_unit = max_score * file_coherence_boost_frac();
        per_file
            .into_values()
            .map(|(sum, idx)| (idx, score(&items[idx]) + boost_unit * sum / max_file_sum))
            .collect()
    };

    for (idx, new_score) in updates {
        set_score(&mut items[idx], new_score);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ordinary source paths keep the neutral multiplier of exactly 1.0.
    #[test]
    fn test_canonical_path_no_penalty() {
        for path in ["src/foo.py", "lib/core/Axios.js"] {
            assert_eq!(file_path_penalty(path), 1.0);
        }
    }

    /// Python test files get a strong penalty; `__init__.py` gets some penalty.
    #[test]
    fn test_python_test_files_penalised() {
        for path in ["tests/test_foo.py", "foo_test.py"] {
            assert!(file_path_penalty(path) < 0.5);
        }
        assert!(file_path_penalty("src/__init__.py") < 1.0);
    }

    /// Compat, legacy and examples directories are all strongly penalised.
    #[test]
    fn test_compat_and_examples_penalised() {
        for path in ["compat/old_api.py", "legacy/foo.py", "examples/demo.py"] {
            assert!(file_path_penalty(path) < 0.5);
        }
    }

    /// TypeScript declaration files are only mildly penalised.
    #[test]
    fn test_dts_mild_penalty() {
        let multiplier = file_path_penalty("types/index.d.ts");
        assert!(multiplier < 1.0 && multiplier > 0.5);
    }

    /// Penalties from different categories multiply together.
    #[test]
    fn test_compounding_penalty() {
        let p = file_path_penalty("compat/foo_test.py");
        assert!(p < 0.1, "expected compound penalty, got {p}");
    }

    /// A test file inside a test directory is penalised once, not twice.
    #[test]
    fn test_same_category_does_not_compound() {
        let p = file_path_penalty("tests/foo_test.py");
        assert!((p - 0.3).abs() < 1e-6, "expected 0.3, got {p}");
    }

    /// Queries that mention tests, specs or benchmarks opt out of the penalty.
    #[test]
    fn test_should_apply_path_penalty() {
        assert!(should_apply_path_penalty("how authentication works"));
        for query in ["unit test for foo", "benchmark suite", "rspec setup"] {
            assert!(!should_apply_path_penalty(query));
        }
    }
}