use std::cmp::Ordering;
use std::collections::HashMap;
use std::sync::LazyLock;
use indexmap::IndexMap;
use regex::Regex;
use crate::types::Chunk;
/// Score multiplier for the strongest path-based demotion (test files/directories,
/// `compat`/`legacy` directories and example/docs-source directories).
pub const STRONG_PENALTY: f64 = 0.3;
/// Score multiplier for package-level boilerplate files listed in `REEXPORT_FILENAMES`.
pub const MODERATE_PENALTY: f64 = 0.5;
/// Score multiplier for TypeScript declaration files.
pub const MILD_PENALTY: f64 = 0.7;
/// How many chunks from a single file [`rerank_top_k`] admits before saturation decay starts.
pub const FILE_SATURATION_THRESHOLD: usize = 1;
/// Per-step decay for chunks beyond the threshold: the `k`-th chunk past
/// [`FILE_SATURATION_THRESHOLD`] from the same file is scaled by `FILE_SATURATION_DECAY^k`.
pub const FILE_SATURATION_DECAY: f64 = 0.5;
/// Basenames treated as package boilerplate (re-export barrels / package metadata) and
/// penalised with [`MODERATE_PENALTY`].
const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
/// Compiles one of this module's built-in path patterns.
///
/// # Panics
///
/// Panics if `pattern` is not a valid regular expression. Every caller passes a
/// literal, so a failure here is a programming error rather than bad input.
fn compile(pattern: &str) -> Regex {
    let compiled = Regex::new(pattern);
    compiled.expect("penalty regex is valid")
}
/// Matches file names that follow common per-language test-file conventions
/// (`test_*.py`, `*_test.go`, `*Test.java`, `*.spec.ts`, `*_spec.rb`, …), either as the
/// whole path or as the final `/`-separated component. Callers are expected to pass paths
/// that already use `/` separators.
static TEST_FILE_RE: LazyLock<Regex> = LazyLock::new(|| {
    compile(concat!(
        r"(?:^|/)(?:",
        r"test_[^/]*\.py|[^/]*_test\.py",
        r"|[^/]*_test\.go",
        r"|[^/]*Tests?\.java",
        r"|[^/]*Test\.php",
        r"|[^/]*_spec\.rb|[^/]*_test\.rb",
        // JS/TS, including the ES-module / CommonJS extensions (`.mjs`, `.cjs`, `.mts`,
        // `.cts`) that Vitest's default include glob `?(c|m)[jt]s?(x)` also treats as tests.
        r"|[^/]*\.test\.[cm]?[jt]sx?|[^/]*\.spec\.[cm]?[jt]sx?",
        r"|[^/]*Tests?\.kt|[^/]*Spec\.kt",
        r"|[^/]*Tests?\.swift|[^/]*Spec\.swift",
        r"|[^/]*Tests?\.cs",
        r"|test_[^/]*\.cpp|[^/]*_test\.cpp|test_[^/]*\.c|[^/]*_test\.c",
        r"|[^/]*Spec\.scala|[^/]*Suite\.scala|[^/]*Test\.scala",
        r"|[^/]*_test\.dart|test_[^/]*\.dart",
        r"|[^/]*_spec\.lua|[^/]*_test\.lua|test_[^/]*\.lua",
        r"|test_helper[^/]*\.\w+",
        r")$",
    ))
});
/// Matches a `test`, `tests`, `__tests__`, `spec` or `testing` path component.
static TEST_DIR_RE: LazyLock<Regex> = LazyLock::new(|| {
    let pattern = r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)";
    compile(pattern)
});
/// Matches a `compat`, `_compat` or `legacy` path component.
static COMPAT_DIR_RE: LazyLock<Regex> = LazyLock::new(|| {
    let pattern = r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)";
    compile(pattern)
});
/// Matches an `example`, `examples`, `_example`, `_examples`, `doc_src` or `docs_src`
/// path component.
static EXAMPLES_DIR_RE: LazyLock<Regex> = LazyLock::new(|| {
    let pattern = r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)";
    compile(pattern)
});
/// Matches TypeScript declaration files: `.d.ts`, plus the `.d.mts` / `.d.cts` variants
/// emitted for ES-module and CommonJS sources.
static TYPE_DEFS_RE: LazyLock<Regex> = LazyLock::new(|| compile(r"\.d\.[cm]?ts$"));
/// Returns a multiplicative relevance penalty for a chunk's source file, based on its path.
///
/// Starts at `1.0` (no penalty) and multiplies in one factor per matching rule, so a path
/// that hits several rules (e.g. a test file under `examples/`) is penalised cumulatively:
///
/// * test file name or test directory → [`STRONG_PENALTY`] (applied once even if both match)
/// * basename listed in `REEXPORT_FILENAMES` → [`MODERATE_PENALTY`]
/// * `compat` / `_compat` / `legacy` path component → [`STRONG_PENALTY`]
/// * example / docs-source path component → [`STRONG_PENALTY`]
/// * TypeScript declaration file → [`MILD_PENALTY`]
///
/// Windows-style `\` separators are normalised to `/` before any rule is evaluated, so
/// `src\__init__.py` and `src/__init__.py` are treated identically.
pub fn file_path_penalty(file_path: &str) -> f64 {
    let normalised = file_path.replace('\\', "/");
    // Take the basename from the normalised path as well: splitting the raw path on `/`
    // would treat all of `src\__init__.py` as the basename and miss the re-export rule.
    let basename = normalised
        .rsplit_once('/')
        .map_or(normalised.as_str(), |(_, name)| name);

    let mut penalty = 1.0;
    if TEST_FILE_RE.is_match(&normalised) || TEST_DIR_RE.is_match(&normalised) {
        penalty *= STRONG_PENALTY;
    }
    if REEXPORT_FILENAMES.contains(&basename) {
        penalty *= MODERATE_PENALTY;
    }
    if COMPAT_DIR_RE.is_match(&normalised) {
        penalty *= STRONG_PENALTY;
    }
    if EXAMPLES_DIR_RE.is_match(&normalised) {
        penalty *= STRONG_PENALTY;
    }
    if TYPE_DEFS_RE.is_match(&normalised) {
        penalty *= MILD_PENALTY;
    }
    penalty
}
/// Comparator that orders scores from highest to lowest, for use with `sort_by`.
///
/// NaN scores compare equal to each other and sort after every real number, which makes
/// this a total order. A bare `partial_cmp(..).unwrap_or(Equal)` is not one, because NaN
/// would be "equal" to both `1.0` and `2.0`. `slice::sort_by` may panic or produce an
/// unspecified order when its comparator is not a total order.
fn by_score_desc(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither is NaN, so `partial_cmp` always returns `Some`; -0.0 and 0.0 stay equal.
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    }
}
/// Re-ranks retrieval scores and returns the best `top_k` chunks as `(chunk_index, score)`,
/// highest score first.
///
/// Two adjustments are applied on top of the raw scores:
///
/// 1. When `penalise_paths` is true, each score is multiplied by [`file_path_penalty`] for
///    the chunk's file. The penalty is computed once per distinct path.
/// 2. Per-file saturation: chunks are visited in descending penalised-score order. Once a
///    file already has [`FILE_SATURATION_THRESHOLD`] selected chunks, each further chunk
///    from it is scaled by [`FILE_SATURATION_DECAY`] raised to how far past the threshold
///    it is.
///
/// The returned scores are these effective (penalised and decayed) scores. Returns an empty
/// vector when `scores` is empty or `top_k` is zero.
///
/// NOTE: scores are assumed to be non-negative. The penalties and decay are factors in
/// `(0, 1]`, which only demote non-negative scores, and the early exit below relies on that.
///
/// # Panics
///
/// Panics if a key of `scores` is not a valid index into `chunks`.
pub fn rerank_top_k(
    scores: &super::Scores,
    chunks: &[Chunk],
    top_k: usize,
    penalise_paths: bool,
) -> Vec<(usize, f64)> {
    if scores.is_empty() || top_k == 0 {
        return Vec::new();
    }

    let mut penalty_cache: HashMap<&str, f64> = HashMap::new();
    let mut penalised: IndexMap<usize, f64> = IndexMap::with_capacity(scores.len());
    for (&idx, &score) in scores {
        let file_path = chunks[idx].file_path.as_str();
        let pen = if penalise_paths {
            *penalty_cache
                .entry(file_path)
                .or_insert_with(|| file_path_penalty(file_path))
        } else {
            1.0
        };
        penalised.insert(idx, score * pen);
    }

    // Sort (index, score) pairs directly so the comparator needs no map lookups. The sort is
    // stable, so equal scores keep the order in which `scores` yielded them.
    let mut ranked: Vec<(usize, f64)> = penalised.into_iter().collect();
    ranked.sort_by(|a, b| by_score_desc(a.1, b.1));

    let mut file_selected: HashMap<&str, usize> = HashMap::new();
    let mut selected: Vec<(f64, usize)> = Vec::new();
    // Minimum effective score among everything selected so far. It is maintained
    // incrementally (O(1) per chunk) instead of re-scanning `selected` after every push.
    let mut min_selected = f64::INFINITY;
    for (idx, pen_score) in ranked {
        // Decay never raises a non-negative score. So once `top_k` chunks are held and the
        // next penalised score cannot beat the weakest of them, no later chunk can either.
        if selected.len() >= top_k && pen_score <= min_selected {
            break;
        }
        let file_path = chunks[idx].file_path.as_str();
        let count = file_selected.entry(file_path).or_insert(0);
        let already = *count;
        *count += 1;

        let eff_score = if already >= FILE_SATURATION_THRESHOLD {
            let excess = already - FILE_SATURATION_THRESHOLD + 1;
            // Saturate rather than wrap for absurdly large counts; the factor is ~0 anyway.
            let exponent = i32::try_from(excess).unwrap_or(i32::MAX);
            pen_score * FILE_SATURATION_DECAY.powi(exponent)
        } else {
            pen_score
        };
        selected.push((eff_score, idx));
        min_selected = min_selected.min(eff_score);
    }

    selected.sort_by(|a, b| by_score_desc(a.0, b.0));
    selected.truncate(top_k);
    selected
        .into_iter()
        .map(|(score, idx)| (idx, score))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal chunk for `file_path` whose line range is derived from `idx`.
    fn chunk(file_path: &str, idx: u32) -> Chunk {
        Chunk {
            content: format!("chunk {idx}"),
            file_path: file_path.to_string(),
            start_line: idx,
            end_line: idx + 1,
            language: None,
        }
    }

    /// Collects `(chunk_index, score)` pairs into the crate's score container.
    fn scores_from(pairs: &[(usize, f64)]) -> super::super::Scores {
        pairs.iter().copied().collect()
    }

    #[test]
    fn penalises_js_ts_test_files() {
        assert_eq!(file_path_penalty("src/foo.test.ts"), STRONG_PENALTY);
        assert_eq!(file_path_penalty("src/foo.spec.tsx"), STRONG_PENALTY);
    }

    #[test]
    fn penalises_reexport_barrel() {
        assert_eq!(file_path_penalty("src/__init__.py"), MODERATE_PENALTY);
        assert_eq!(file_path_penalty("__init__.py"), MODERATE_PENALTY);
    }

    #[test]
    fn penalises_type_stubs() {
        assert_eq!(file_path_penalty("src/foo.d.ts"), MILD_PENALTY);
        assert_eq!(file_path_penalty("src/__init__.d.ts"), MILD_PENALTY);
    }

    #[test]
    fn test_dir_and_test_file_share_one_strong_branch() {
        assert!((file_path_penalty("tests/test_foo.py") - STRONG_PENALTY).abs() < 1e-10);
    }

    #[test]
    fn ordinary_files_are_unpenalised() {
        assert_eq!(file_path_penalty("src/foo.ts"), 1.0);
    }

    #[test]
    fn compounds_strong_penalties() {
        assert!(
            (file_path_penalty("examples/foo.test.ts") - STRONG_PENALTY * STRONG_PENALTY).abs()
                < 1e-10
        );
    }

    #[test]
    fn penalises_dirs_and_other_languages() {
        assert_eq!(file_path_penalty("compat/foo.ts"), STRONG_PENALTY);
        assert_eq!(file_path_penalty("examples/foo.ts"), STRONG_PENALTY);
        assert_eq!(file_path_penalty("legacy/foo.ts"), STRONG_PENALTY);
        assert_eq!(file_path_penalty("pkg/foo_test.go"), STRONG_PENALTY);
        assert_eq!(file_path_penalty("src/FooTests.java"), STRONG_PENALTY);
    }

    #[test]
    fn normalises_backslashes_before_matching() {
        assert_eq!(file_path_penalty("src\\foo.test.ts"), STRONG_PENALTY);
    }

    #[test]
    fn empty_input_returns_empty() {
        let chunks: Vec<Chunk> = vec![];
        assert!(rerank_top_k(&scores_from(&[]), &chunks, 5, true).is_empty());
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let chunks = [chunk("a.ts", 0)];
        let scores = scores_from(&[(0, 1.0)]);
        assert!(rerank_top_k(&scores, &chunks, 0, true).is_empty());
    }

    #[test]
    fn applies_saturation_decay_within_a_file() {
        let chunks = [
            chunk("src/foo.ts", 0),
            chunk("src/foo.ts", 1),
            chunk("src/foo.ts", 2),
            chunk("src/foo.ts", 3),
        ];
        let scores = scores_from(&[(0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)]);
        let result = rerank_top_k(&scores, &chunks, 4, false);
        assert_eq!(result.len(), 4);
        let s: Vec<f64> = result.iter().map(|&(_, s)| s).collect();
        assert!((s[0] - 1.0).abs() < 1e-10);
        assert!((s[1] - FILE_SATURATION_DECAY).abs() < 1e-10);
        assert!((s[2] - FILE_SATURATION_DECAY.powi(2)).abs() < 1e-10);
        assert!((s[3] - FILE_SATURATION_DECAY.powi(3)).abs() < 1e-10);
    }

    #[test]
    fn truncates_to_topk_after_sorting() {
        let chunks = [chunk("a.ts", 0), chunk("b.ts", 1), chunk("c.ts", 2)];
        let scores = scores_from(&[(0, 0.5), (1, 0.9), (2, 0.1)]);
        let result = rerank_top_k(&scores, &chunks, 2, false);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, 1);
        assert_eq!(result[1].0, 0);
    }

    #[test]
    fn applies_path_penalties_before_sorting() {
        let chunks = [chunk("src/foo.test.ts", 0), chunk("src/foo.ts", 1)];
        let scores = scores_from(&[(0, 0.9), (1, 0.5)]);
        let result = rerank_top_k(&scores, &chunks, 2, true);
        assert_eq!(result[0].0, 1);
        assert_eq!(result[1].0, 0);
        assert!((result[0].1 - 0.5).abs() < 1e-10);
        assert!((result[1].1 - 0.9 * STRONG_PENALTY).abs() < 1e-10);
    }

    #[test]
    fn skips_path_penalties_when_disabled() {
        let chunks = [chunk("src/foo.test.ts", 0), chunk("src/foo.ts", 1)];
        let scores = scores_from(&[(0, 0.9), (1, 0.5)]);
        let result = rerank_top_k(&scores, &chunks, 2, false);
        assert_eq!(result[0].0, 0);
        assert!((result[0].1 - 0.9).abs() < 1e-10);
        assert_eq!(result[1].0, 1);
        assert!((result[1].1 - 0.5).abs() < 1e-10);
    }

    #[test]
    fn mixes_saturation_decay_across_files() {
        let chunks = [
            chunk("a.ts", 0),
            chunk("a.ts", 1),
            chunk("b.ts", 2),
            chunk("b.ts", 3),
        ];
        let scores = scores_from(&[(0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)]);
        let result = rerank_top_k(&scores, &chunks, 4, false);
        assert_eq!(result.len(), 4);
        let s: Vec<f64> = result.iter().map(|&(_, sc)| sc).collect();
        assert!((s[0] - 1.0).abs() < 1e-10);
        assert!((s[1] - 1.0).abs() < 1e-10);
        assert!((s[2] - FILE_SATURATION_DECAY).abs() < 1e-10);
        assert!((s[3] - FILE_SATURATION_DECAY).abs() < 1e-10);
    }

    #[test]
    fn saturation_lets_other_files_overtake() {
        // Three equally strong chunks from one file and one weaker chunk from another.
        // Place `b.ts` strictly between the decayed and undecayed `a.ts` scores.
        let b_score = (1.0 + FILE_SATURATION_DECAY) / 2.0;
        let chunks = [
            chunk("a.ts", 0),
            chunk("a.ts", 1),
            chunk("a.ts", 2),
            chunk("b.ts", 3),
        ];
        let scores = scores_from(&[(0, 1.0), (1, 1.0), (2, 1.0), (3, b_score)]);
        let result = rerank_top_k(&scores, &chunks, 2, false);
        assert_eq!(result.len(), 2);
        assert_eq!(chunks[result[0].0].file_path, "a.ts");
        assert!((result[0].1 - 1.0).abs() < 1e-10);
        assert_eq!(result[1].0, 3);
        assert!((result[1].1 - b_score).abs() < 1e-10);
    }
}