use std::collections::HashMap;
use std::path::Path;
use std::sync::OnceLock;
use regex::{Regex, RegexBuilder};
use crate::chunk::CodeChunk;
/// Score multiplier `file_path_penalty` applies for test files/directories,
/// compat/legacy directories, and examples directories.
pub const STRONG_PENALTY: f32 = 0.3;
/// Score multiplier `file_path_penalty` applies when the file name is listed
/// in `REEXPORT_FILENAMES`.
pub const MODERATE_PENALTY: f32 = 0.5;
/// Score multiplier `file_path_penalty` applies to paths ending in `.d.ts`.
pub const MILD_PENALTY: f32 = 0.7;
/// Number of chunks `rerank_topk` accepts from a single file at full score
/// before the saturation decay starts to apply.
pub const FILE_SATURATION_THRESHOLD: usize = 1;
/// Base of the saturation decay: the n-th chunk past the threshold from the
/// same file has its score multiplied by `FILE_SATURATION_DECAY` to the power n.
pub const FILE_SATURATION_DECAY: f32 = 0.5;
/// Returns the lazily compiled regex that recognises test files by name.
///
/// The pattern is anchored to the final path component: it must start at the
/// beginning of the string or right after a `/`, and end at the end of the
/// string. The regex is compiled on first use and shared afterwards.
fn test_file_re() -> &'static Regex {
    /// File-name shapes treated as tests, grouped by file extension.
    const ALTERNATIVES: &[&str] = &[
        // Python
        r"test_[^/]*\.py",
        r"[^/]*_test\.py",
        // Go
        r"[^/]*_test\.go",
        // Java
        r"[^/]*Tests?\.java",
        // PHP
        r"[^/]*Test\.php",
        // Ruby
        r"[^/]*_spec\.rb",
        r"[^/]*_test\.rb",
        // JavaScript / TypeScript (including JSX/TSX)
        r"[^/]*\.test\.[jt]sx?",
        r"[^/]*\.spec\.[jt]sx?",
        // Kotlin
        r"[^/]*Tests?\.kt",
        r"[^/]*Spec\.kt",
        // Swift
        r"[^/]*Tests?\.swift",
        r"[^/]*Spec\.swift",
        // C#
        r"[^/]*Tests?\.cs",
        // C++
        r"test_[^/]*\.cpp",
        r"[^/]*_test\.cpp",
        // C
        r"test_[^/]*\.c",
        r"[^/]*_test\.c",
        // Scala
        r"[^/]*Spec\.scala",
        r"[^/]*Suite\.scala",
        r"[^/]*Test\.scala",
        // Dart
        r"[^/]*_test\.dart",
        r"test_[^/]*\.dart",
        // Lua
        r"[^/]*_spec\.lua",
        r"[^/]*_test\.lua",
        r"test_[^/]*\.lua",
        // Shared test helpers, any extension
        r"test_helpers?[^/]*\.\w+",
    ];

    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        // Joining with `|` reproduces one alternation group between the
        // path-component anchor and the end-of-string anchor.
        let pattern = format!(r"(?:^|/)(?:{})$", ALTERNATIVES.join("|"));
        Regex::new(&pattern).expect("test-file regex compiles")
    })
}
/// Returns the lazily compiled regex that recognises test directories.
///
/// Matches when any `/`-delimited path component is exactly `test`, `tests`,
/// `__tests__`, `spec`, or `testing`.
fn test_dir_re() -> &'static Regex {
    const PATTERN: &str = r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| Regex::new(PATTERN).expect("test-dir regex compiles"))
}
/// Returns the lazily compiled regex that recognises compatibility-shim
/// directories.
///
/// Matches when any `/`-delimited path component is exactly `compat`,
/// `_compat`, or `legacy`.
fn compat_dir_re() -> &'static Regex {
    /// Compiles the pattern; runs at most once via the `OnceLock` below.
    fn compile() -> Regex {
        let compiled = Regex::new(r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)");
        compiled.expect("compat-dir regex compiles")
    }

    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(compile)
}
/// Returns the lazily compiled regex that recognises example/documentation
/// source directories.
///
/// Matches when any `/`-delimited path component is `example` or `examples`
/// (optionally prefixed with `_`), or `doc_src` / `docs_src`.
fn examples_dir_re() -> &'static Regex {
    const PATTERN: &str = r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        let compiled = Regex::new(PATTERN);
        compiled.expect("examples-dir regex compiles")
    })
}
/// Returns the lazily compiled regex that recognises TypeScript declaration
/// files, i.e. paths ending in `.d.ts`.
fn type_defs_re() -> &'static Regex {
    const PATTERN: &str = r"\.d\.ts$";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| {
        let builder = RegexBuilder::new(PATTERN);
        builder.build().expect("dts regex compiles")
    })
}
/// File names that `file_path_penalty` scales by `MODERATE_PENALTY`; compared
/// against the final path component by exact match.
const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
#[must_use]
pub fn file_path_penalty(file_path: &str) -> f32 {
let normalised = file_path.replace('\\', "/");
let mut penalty = 1.0_f32;
if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
penalty *= STRONG_PENALTY;
}
if let Some(filename) = Path::new(file_path).file_name().and_then(|f| f.to_str())
&& REEXPORT_FILENAMES.contains(&filename)
{
penalty *= MODERATE_PENALTY;
}
if compat_dir_re().is_match(&normalised) {
penalty *= STRONG_PENALTY;
}
if examples_dir_re().is_match(&normalised) {
penalty *= STRONG_PENALTY;
}
if type_defs_re().is_match(&normalised) {
penalty *= MILD_PENALTY;
}
penalty
}
/// Re-ranks scored chunks and returns at most `top_k` of them, best first.
///
/// `scores` holds `(chunk_index, score)` pairs whose indices refer to entries
/// of `chunks`. The procedure is:
///
/// 1. When `penalise_paths` is `true`, each score is multiplied by
///    `file_path_penalty` of the chunk's file path.
/// 2. Candidates are visited in descending penalised score (ties broken by
///    ascending index). Each accepted candidate's effective score is decayed
///    by `FILE_SATURATION_DECAY` once per previously accepted chunk from the
///    same file beyond `FILE_SATURATION_THRESHOLD`.
/// 3. The scan stops once at least `top_k` candidates have been accepted and
///    the next penalised score does not exceed the lowest effective score
///    accepted so far.
/// 4. Accepted candidates are sorted by effective score (same tie-break) and
///    truncated to `top_k`.
///
/// Returns an empty vector when `scores` is empty or `top_k` is zero.
///
/// # Panics
///
/// Panics if a visited candidate's index is out of bounds for `chunks`.
#[must_use]
pub fn rerank_topk(
    scores: &[(usize, f32)],
    chunks: &[CodeChunk],
    top_k: usize,
    penalise_paths: bool,
) -> Vec<(usize, f32)> {
    if scores.is_empty() || top_k == 0 {
        return Vec::new();
    }

    let mut penalised: Vec<(usize, f32)> = if penalise_paths {
        // Keyed on paths borrowed from `chunks`, so a cache hit costs no allocation.
        let mut penalty_cache: HashMap<&str, f32> = HashMap::new();
        scores
            .iter()
            .map(|&(idx, score)| {
                let path: &str = &chunks[idx].file_path;
                let factor = *penalty_cache
                    .entry(path)
                    .or_insert_with(|| file_path_penalty(path));
                (idx, score * factor)
            })
            .collect()
    } else {
        scores.to_vec()
    };
    penalised.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    let mut file_selected: HashMap<&str, usize> = HashMap::new();
    let mut selected: Vec<(usize, f32)> = Vec::with_capacity(top_k.min(penalised.len()));
    // Running minimum of every effective score accepted so far. It is only
    // consulted once `top_k` candidates are in, and equals a fresh fold over
    // `selected` without re-scanning the vector on every iteration.
    let mut min_selected = f32::INFINITY;

    for &(idx, pen_score) in &penalised {
        if selected.len() >= top_k && pen_score <= min_selected {
            break;
        }

        let path: &str = &chunks[idx].file_path;
        let count = file_selected.entry(path).or_insert(0);
        let already = *count;
        *count += 1;

        let eff_score = if already >= FILE_SATURATION_THRESHOLD {
            let excess = already - FILE_SATURATION_THRESHOLD + 1;
            // Saturate the exponent rather than wrap if the count ever exceeds `i32`.
            let excess_i32 = i32::try_from(excess).unwrap_or(i32::MAX);
            pen_score * FILE_SATURATION_DECAY.powi(excess_i32)
        } else {
            pen_score
        };

        selected.push((idx, eff_score));
        min_selected = min_selected.min(eff_score);
    }

    selected.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    selected.truncate(top_k);
    selected
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for `f32` comparisons in these tests.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` is within [`TOLERANCE`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds a placeholder chunk whose only meaningful field is its path.
    fn chunk_at(path: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            content: String::new(),
            enriched_content: String::new(),
            qualified_name: None,
        }
    }

    /// Every supported test-file naming convention matches `test_file_re`
    /// and is penalised by exactly one `STRONG_PENALTY`.
    #[test]
    fn penalties_test_file_regex_14_langs() {
        const TEST_PATHS: &[&str] = &[
            "src/test_foo.py",
            "src/foo_test.py",
            "pkg/foo_test.go",
            "src/FooTest.java",
            "src/FooTests.java",
            "src/FooTest.php",
            "spec/foo_spec.rb",
            "test/foo_test.rb",
            "src/foo.test.js",
            "src/foo.spec.ts",
            "src/foo.test.tsx",
            "src/FooTest.kt",
            "src/FooTests.kt",
            "src/FooSpec.kt",
            "src/FooTests.swift",
            "src/FooSpec.swift",
            "src/FooTest.cs",
            "src/FooTests.cs",
            "src/test_foo.cpp",
            "src/foo_test.cpp",
            "src/test_foo.c",
            "src/foo_test.c",
            "src/FooSpec.scala",
            "src/FooSuite.scala",
            "src/FooTest.scala",
            "src/foo_test.dart",
            "src/test_foo.dart",
            "src/foo_spec.lua",
            "src/foo_test.lua",
            "src/test_foo.lua",
            "test/test_helper.rb",
            "test/test_helpers.go",
        ];

        for &path in TEST_PATHS {
            assert!(
                test_file_re().is_match(path),
                "expected test_file_re to match {path:?}"
            );
            let penalty = file_path_penalty(path);
            assert!(
                close(penalty, STRONG_PENALTY),
                "expected STRONG_PENALTY for {path:?}; got {penalty}"
            );
        }
    }

    /// Compat/legacy directories match and receive `STRONG_PENALTY`.
    #[test]
    fn penalties_compat_dir() {
        const COMPAT_PATHS: &[&str] = &["compat/foo.py", "src/_compat/bar.rs", "legacy/baz.go"];

        for &path in COMPAT_PATHS {
            assert!(
                compat_dir_re().is_match(path),
                "expected compat match for {path:?}"
            );
            assert!(close(file_path_penalty(path), STRONG_PENALTY));
        }
    }

    /// Examples and docs-source directories match and receive `STRONG_PENALTY`.
    #[test]
    fn penalties_examples_dir() {
        const EXAMPLE_PATHS: &[&str] = &[
            "examples/foo.py",
            "_examples/bar.rs",
            "example/baz.go",
            "docs_src/x.md",
        ];

        for &path in EXAMPLE_PATHS {
            assert!(
                examples_dir_re().is_match(path),
                "expected examples match for {path:?}"
            );
            assert!(close(file_path_penalty(path), STRONG_PENALTY));
        }
    }

    /// Files named in `REEXPORT_FILENAMES` receive `MODERATE_PENALTY`.
    #[test]
    fn penalties_init_py_reexport() {
        let init_py = file_path_penalty("src/__init__.py");
        assert!(close(init_py, MODERATE_PENALTY));

        let package_info = file_path_penalty("src/com/myorg/package-info.java");
        assert!(close(package_info, MODERATE_PENALTY));
    }

    /// `.d.ts` declaration files receive `MILD_PENALTY`.
    #[test]
    fn penalties_dts_stub() {
        assert!(close(file_path_penalty("src/foo.d.ts"), MILD_PENALTY));
    }

    /// Ordinary source paths are left untouched (factor `1.0`).
    #[test]
    fn non_penalized_path_is_identity() {
        assert!(close(file_path_penalty("src/foo.rs"), 1.0));
        assert!(close(file_path_penalty("lib/bar.py"), 1.0));
    }

    /// Repeated chunks from one file decay by 0.5 per extra chunk, while the
    /// first chunk of each file keeps its full score.
    #[test]
    fn rerank_topk_saturation_decay() {
        let chunks: Vec<CodeChunk> = ["src/foo.rs", "src/foo.rs", "src/foo.rs", "src/bar.rs"]
            .iter()
            .map(|path| chunk_at(path))
            .collect();
        let scores: Vec<(usize, f32)> = (0..4).map(|idx| (idx, 1.0_f32)).collect();

        let got = rerank_topk(&scores, &chunks, 4, true);
        assert_eq!(got.len(), 4);

        let scored: HashMap<usize, f32> = got.into_iter().collect();
        let expectations: &[(usize, f32)] = &[(0, 1.0), (3, 1.0), (1, 0.5), (2, 0.25)];
        for &(idx, expected) in expectations {
            assert!(
                close(scored[&idx], expected),
                "scored[{idx}] = {}",
                scored[&idx]
            );
        }
    }

    /// With distinct files and strictly decreasing scores, the top three
    /// indices come back in order.
    #[test]
    fn rerank_topk_greedy_early_exit() {
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk_at(&format!("src/f{i}.rs")))
            .collect();
        let scores: Vec<(usize, f32)> = (0..10).map(|i| (i, 10.0 - i as f32)).collect();

        let got = rerank_topk(&scores, &chunks, 3, true);
        assert_eq!(got.len(), 3);

        let indices: Vec<usize> = got.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(indices, [0, 1, 2]);
    }

    /// Production-looking paths are never penalised, while representative
    /// test, re-export, and stub paths always are.
    #[test]
    fn property_penalty_regex_parity_python() {
        const PRODUCTION_PATHS: &[&str] = &[
            "src/foo.py",
            "lib/parser.rs",
            "crates/ripvec-core/src/encoder/semble/tokens.rs",
            "pkg/server/handler.go",
            "app/models/user.rb",
            "main.c",
            "include/foo.h",
        ];
        const PENALISED_PATHS: &[&str] = &["test_foo.py", "src/__init__.py", "src/foo.d.ts"];

        for &path in PRODUCTION_PATHS {
            let penalty = file_path_penalty(path);
            assert!(
                close(penalty, 1.0),
                "non-test path {path:?} should have penalty 1.0; got {penalty}"
            );
        }
        for &path in PENALISED_PATHS {
            assert!(
                file_path_penalty(path) < 1.0,
                "test-shaped path {path:?} should be penalised; got 1.0"
            );
        }
    }

    /// Disabling path penalties leaves scores unscaled by path, but the
    /// per-file saturation decay still applies.
    #[test]
    fn rerank_topk_no_path_penalties_still_decays() {
        let chunks = vec![chunk_at("test/test_foo.py"), chunk_at("test/test_foo.py")];
        let scores = vec![(0, 1.0_f32), (1, 1.0)];

        let got = rerank_topk(&scores, &chunks, 2, false);

        let scored: HashMap<usize, f32> = got.into_iter().collect();
        assert!(close(scored[&0], 1.0));
        assert!(close(scored[&1], 0.5));
    }
}