ripvec_core/encoder/ripvec/penalties.rs
1//! File-path penalties + `rerank_topk` with file-saturation decay.
2//!
3//! Port of `~/src/semble/src/semble/ranking/penalties.py`. Two surfaces:
4//!
5//! 1. [`file_path_penalty`] — multiplicative penalty per file path.
6//! Combines penalties for test files, compat/legacy/examples
7//! directories, re-export barrels (`__init__.py`, `package-info.java`),
8//! and `.d.ts` declaration stubs.
9//! 2. [`rerank_topk`] — greedy top-k selection that applies path
10//! penalties (when `penalise_paths == true`) then decays by 0.5 per
11//! extra chunk from the same file beyond a threshold of 1.
12//!
13//! ## Indexing convention
14//!
15//! Where Python uses `dict[Chunk, float]` keyed by hashable Chunk,
16//! Rust uses `(chunk_index, score)` pairs — the same convention ripvec
17//! already uses in [`crate::hybrid`]. This avoids adding `Hash`/`Eq`
18//! impls to [`CodeChunk`] just to satisfy a HashMap key.
19
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27
// ---------------------------------------------------------------------------
// Tuning constants (Python: `_STRONG_PENALTY` and friends, penalties.py:67).
// ---------------------------------------------------------------------------

/// Multiplier for test files, compat/legacy shims, and example/doc code.
pub const STRONG_PENALTY: f32 = 0.3_f32;

/// Multiplier for re-export barrels and package-level metadata files
/// (`__init__.py`, `package-info.java`).
pub const MODERATE_PENALTY: f32 = 0.5_f32;

/// Multiplier for TypeScript `.d.ts` declaration stubs — penalised only
/// lightly because they still carry useful type information.
pub const MILD_PENALTY: f32 = 0.7_f32;

/// How many chunks a single file may contribute before the saturation
/// decay starts to apply.
pub const FILE_SATURATION_THRESHOLD: usize = 1;

/// Per-excess-chunk multiplier: a chunk that lands `excess` positions past
/// [`FILE_SATURATION_THRESHOLD`] within its own file is scaled by
/// `FILE_SATURATION_DECAY.powi(excess)`.
pub const FILE_SATURATION_DECAY: f32 = 0.5_f32;
45
46// ---------------------------------------------------------------------------
47// Path-prior regex bank (Python: penalties.py:8-65).
48// ---------------------------------------------------------------------------
49
50/// Regex matching test files across 14 languages plus shared helpers.
51///
52/// Mirrors `_TEST_FILE_RE` from `penalties.py:8`. Lazy-compiled.
53fn test_file_re() -> &'static Regex {
54 static RE: OnceLock<Regex> = OnceLock::new();
55 RE.get_or_init(|| {
56 // Match exactly the Python pattern. The non-capture group after
57 // `(?:^|/)` lists per-language test-file conventions; the
58 // pattern terminates with `$`.
59 let pattern = concat!(
60 r"(?:^|/)",
61 r"(?:",
62 // Python
63 r"test_[^/]*\.py",
64 r"|[^/]*_test\.py",
65 // Go
66 r"|[^/]*_test\.go",
67 // Java
68 r"|[^/]*Tests?\.java",
69 // PHP
70 r"|[^/]*Test\.php",
71 // Ruby
72 r"|[^/]*_spec\.rb",
73 r"|[^/]*_test\.rb",
74 // JavaScript / TypeScript
75 r"|[^/]*\.test\.[jt]sx?",
76 r"|[^/]*\.spec\.[jt]sx?",
77 // Kotlin
78 r"|[^/]*Tests?\.kt",
79 r"|[^/]*Spec\.kt",
80 // Swift
81 r"|[^/]*Tests?\.swift",
82 r"|[^/]*Spec\.swift",
83 // C#
84 r"|[^/]*Tests?\.cs",
85 // C / C++
86 r"|test_[^/]*\.cpp",
87 r"|[^/]*_test\.cpp",
88 r"|test_[^/]*\.c",
89 r"|[^/]*_test\.c",
90 // Scala
91 r"|[^/]*Spec\.scala",
92 r"|[^/]*Suite\.scala",
93 r"|[^/]*Test\.scala",
94 // Dart
95 r"|[^/]*_test\.dart",
96 r"|test_[^/]*\.dart",
97 // Lua
98 r"|[^/]*_spec\.lua",
99 r"|[^/]*_test\.lua",
100 r"|test_[^/]*\.lua",
101 // Shared helper patterns (all languages)
102 r"|test_helpers?[^/]*\.\w+",
103 r")$",
104 );
105 Regex::new(pattern).expect("test-file regex compiles")
106 })
107}
108
109/// Regex matching directories named `test`, `tests`, `__tests__`, `spec`,
110/// `testing`. Mirrors `_TEST_DIR_RE` from `penalties.py:56`.
111fn test_dir_re() -> &'static Regex {
112 static RE: OnceLock<Regex> = OnceLock::new();
113 RE.get_or_init(|| {
114 Regex::new(r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)")
115 .expect("test-dir regex compiles")
116 })
117}
118
119/// Regex matching compat / legacy directories. Mirrors `_COMPAT_DIR_RE`.
120fn compat_dir_re() -> &'static Regex {
121 static RE: OnceLock<Regex> = OnceLock::new();
122 RE.get_or_init(|| {
123 Regex::new(r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)").expect("compat-dir regex compiles")
124 })
125}
126
127/// Regex matching examples / docs-src directories. Mirrors
128/// `_EXAMPLES_DIR_RE`.
129fn examples_dir_re() -> &'static Regex {
130 static RE: OnceLock<Regex> = OnceLock::new();
131 RE.get_or_init(|| {
132 Regex::new(r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)").expect("examples-dir regex compiles")
133 })
134}
135
136/// Regex matching TypeScript `.d.ts` declaration stubs.
137fn type_defs_re() -> &'static Regex {
138 static RE: OnceLock<Regex> = OnceLock::new();
139 RE.get_or_init(|| {
140 RegexBuilder::new(r"\.d\.ts$")
141 .build()
142 .expect("dts regex compiles")
143 })
144}
145
146/// Filenames that are re-export barrels or package-level metadata.
147const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
148
149// ---------------------------------------------------------------------------
150// Public API.
151// ---------------------------------------------------------------------------
152
153/// Combined multiplicative penalty for all applicable path patterns.
154///
155/// Always in `(0.0, 1.0]`. Returns `1.0` when no patterns match.
156/// Each matched category multiplies into the result independently, so a
157/// file matching both a test-dir pattern and a re-export filename pays
158/// `STRONG_PENALTY * MODERATE_PENALTY`.
159///
160/// Backslashes in `file_path` are normalised to forward slashes before
161/// matching, so Windows-style paths work too.
162#[must_use]
163pub fn file_path_penalty(file_path: &str) -> f32 {
164 let normalised = file_path.replace('\\', "/");
165 let mut penalty = 1.0_f32;
166 if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
167 penalty *= STRONG_PENALTY;
168 }
169 if let Some(filename) = Path::new(file_path).file_name().and_then(|f| f.to_str())
170 && REEXPORT_FILENAMES.contains(&filename)
171 {
172 penalty *= MODERATE_PENALTY;
173 }
174 if compat_dir_re().is_match(&normalised) {
175 penalty *= STRONG_PENALTY;
176 }
177 if examples_dir_re().is_match(&normalised) {
178 penalty *= STRONG_PENALTY;
179 }
180 if type_defs_re().is_match(&normalised) {
181 penalty *= MILD_PENALTY;
182 }
183 penalty
184}
185
186/// Select the top-k results with optional path penalties and
187/// file-saturation decay.
188///
189/// Mirrors `rerank_topk` from `penalties.py:81`. The greedy pass uses
190/// the early-exit optimisation: once `selected.len() >= top_k` and the
191/// remaining candidate's penalised score cannot beat the current k-th
192/// best, the loop terminates.
193///
194/// - `penalise_paths == true` — apply [`file_path_penalty`] before
195/// sorting and selection. This matches semble's behaviour for hybrid
196/// and BM25-led queries.
197/// - `penalise_paths == false` — bypass path priors (pure-semantic
198/// queries where the priors don't apply).
199///
200/// Returns `(chunk_index, effective_score)` pairs, highest first, with
201/// length capped at `top_k`.
202#[must_use]
203pub fn rerank_topk(
204 scores: &[(usize, f32)],
205 chunks: &[CodeChunk],
206 top_k: usize,
207 penalise_paths: bool,
208) -> Vec<(usize, f32)> {
209 if scores.is_empty() || top_k == 0 {
210 return Vec::new();
211 }
212
213 // Step 1: apply path penalties (or identity), producing (chunk_index, penalised_score).
214 let mut penalty_cache: HashMap<String, f32> = HashMap::new();
215 let mut penalised: Vec<(usize, f32)> = scores
216 .iter()
217 .map(|&(idx, score)| {
218 if !penalise_paths {
219 return (idx, score);
220 }
221 let path = &chunks[idx].file_path;
222 let factor = *penalty_cache
223 .entry(path.clone())
224 .or_insert_with(|| file_path_penalty(path));
225 (idx, score * factor)
226 })
227 .collect();
228
229 // Step 2: sort descending by penalised score with stable tie-break by index.
230 penalised.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
231
232 // Step 3: greedy pick with file-saturation decay and early-exit.
233 let mut file_selected: HashMap<String, usize> = HashMap::new();
234 let mut selected: Vec<(usize, f32)> = Vec::with_capacity(top_k.min(penalised.len()));
235 let mut min_selected = f32::INFINITY;
236
237 for &(idx, pen_score) in &penalised {
238 if selected.len() >= top_k && pen_score <= min_selected {
239 // Remaining (sorted-descending) cannot beat the current k-th best.
240 break;
241 }
242 let path = chunks[idx].file_path.clone();
243 let already = *file_selected.get(&path).unwrap_or(&0);
244 let eff_score = if already >= FILE_SATURATION_THRESHOLD {
245 let excess = already - FILE_SATURATION_THRESHOLD + 1;
246 // Excess is bounded by the number of chunks in the front; the
247 // saturation here is far below i32::MAX in practice. The clamp
248 // is defence-in-depth.
249 let excess_i32 = i32::try_from(excess).unwrap_or(i32::MAX);
250 pen_score * FILE_SATURATION_DECAY.powi(excess_i32)
251 } else {
252 pen_score
253 };
254 selected.push((idx, eff_score));
255 file_selected.insert(path, already + 1);
256
257 if selected.len() >= top_k {
258 // Refresh the min over the current selection.
259 min_selected = selected
260 .iter()
261 .map(|(_, s)| *s)
262 .fold(f32::INFINITY, f32::min);
263 }
264 }
265
266 // Final descending sort on effective scores.
267 selected.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
268 selected.truncate(top_k);
269 selected
270}
271
272#[cfg(test)]
273mod tests {
274 use super::*;
275
276 fn chunk_at(path: &str) -> CodeChunk {
277 CodeChunk {
278 file_path: path.to_string(),
279 name: String::new(),
280 kind: String::new(),
281 start_line: 1,
282 end_line: 1,
283 content: String::new(),
284 enriched_content: String::new(),
285 }
286 }
287
288 /// `test:penalties-test-file-regex-14-langs` — every per-language
289 /// test-file pattern from `penalties.py:8` matches a representative
290 /// file path.
291 #[test]
292 fn penalties_test_file_regex_14_langs() {
293 let cases: &[&str] = &[
294 // Python
295 "src/test_foo.py",
296 "src/foo_test.py",
297 // Go
298 "pkg/foo_test.go",
299 // Java
300 "src/FooTest.java",
301 "src/FooTests.java",
302 // PHP
303 "src/FooTest.php",
304 // Ruby
305 "spec/foo_spec.rb",
306 "test/foo_test.rb",
307 // JS/TS
308 "src/foo.test.js",
309 "src/foo.spec.ts",
310 "src/foo.test.tsx",
311 // Kotlin
312 "src/FooTest.kt",
313 "src/FooTests.kt",
314 "src/FooSpec.kt",
315 // Swift
316 "src/FooTests.swift",
317 "src/FooSpec.swift",
318 // C#
319 "src/FooTest.cs",
320 "src/FooTests.cs",
321 // C / C++
322 "src/test_foo.cpp",
323 "src/foo_test.cpp",
324 "src/test_foo.c",
325 "src/foo_test.c",
326 // Scala
327 "src/FooSpec.scala",
328 "src/FooSuite.scala",
329 "src/FooTest.scala",
330 // Dart
331 "src/foo_test.dart",
332 "src/test_foo.dart",
333 // Lua
334 "src/foo_spec.lua",
335 "src/foo_test.lua",
336 "src/test_foo.lua",
337 // Shared helpers
338 "test/test_helper.rb",
339 "test/test_helpers.go",
340 ];
341 for path in cases {
342 assert!(
343 test_file_re().is_match(path),
344 "expected test_file_re to match {path:?}"
345 );
346 assert!(
347 (file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6,
348 "expected STRONG_PENALTY for {path:?}; got {}",
349 file_path_penalty(path)
350 );
351 }
352 }
353
354 /// `test:penalties-compat-dir` — compat / legacy directories trigger
355 /// the strong penalty.
356 #[test]
357 fn penalties_compat_dir() {
358 for path in &["compat/foo.py", "src/_compat/bar.rs", "legacy/baz.go"] {
359 assert!(
360 compat_dir_re().is_match(path),
361 "expected compat match for {path:?}"
362 );
363 assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
364 }
365 }
366
367 /// `test:penalties-examples-dir` — example / docs-src directories
368 /// trigger the strong penalty.
369 #[test]
370 fn penalties_examples_dir() {
371 for path in &[
372 "examples/foo.py",
373 "_examples/bar.rs",
374 "example/baz.go",
375 "docs_src/x.md",
376 ] {
377 assert!(
378 examples_dir_re().is_match(path),
379 "expected examples match for {path:?}"
380 );
381 assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
382 }
383 }
384
385 /// `test:penalties-init-py-reexport` — `__init__.py` carries the
386 /// moderate re-export penalty (and Java's `package-info.java`).
387 ///
388 /// Java path deliberately avoids `/example/` since that segment
389 /// triggers `examples_dir_re`'s strong penalty (we want to verify
390 /// re-export in isolation here).
391 #[test]
392 fn penalties_init_py_reexport() {
393 assert!((file_path_penalty("src/__init__.py") - MODERATE_PENALTY).abs() < 1e-6);
394 assert!(
395 (file_path_penalty("src/com/myorg/package-info.java") - MODERATE_PENALTY).abs() < 1e-6
396 );
397 }
398
399 /// `test:penalties-dts-stub` — `.d.ts` files take the mild penalty.
400 #[test]
401 fn penalties_dts_stub() {
402 assert!((file_path_penalty("src/foo.d.ts") - MILD_PENALTY).abs() < 1e-6);
403 }
404
405 /// Non-penalized paths return penalty 1.0.
406 #[test]
407 fn non_penalized_path_is_identity() {
408 assert!((file_path_penalty("src/foo.rs") - 1.0).abs() < 1e-6);
409 assert!((file_path_penalty("lib/bar.py") - 1.0).abs() < 1e-6);
410 }
411
412 /// `test:rerank-topk-saturation-decay` — a third chunk from the
413 /// same file is penalised by 0.5^2.
414 #[test]
415 fn rerank_topk_saturation_decay() {
416 let chunks = vec![
417 chunk_at("src/foo.rs"),
418 chunk_at("src/foo.rs"),
419 chunk_at("src/foo.rs"),
420 chunk_at("src/bar.rs"),
421 ];
422 // All four chunks have identical raw scores. The path penalty is 1.0
423 // for both files. Greedy order picks chunks in their submitted
424 // ordering (stable tie-break by index).
425 let scores = vec![(0, 1.0_f32), (1, 1.0), (2, 1.0), (3, 1.0)];
426 let got = rerank_topk(&scores, &chunks, 4, true);
427 assert_eq!(got.len(), 4);
428
429 let scored: HashMap<usize, f32> = got.iter().copied().collect();
430 // Chunk 0 (first from foo.rs): score 1.0.
431 // Chunk 3 (first from bar.rs): score 1.0.
432 // Chunk 1 (second from foo.rs): already=1, excess=1, score *= 0.5.
433 // Chunk 2 (third from foo.rs): already=2, excess=2, score *= 0.25.
434 assert!(
435 (scored[&0] - 1.0).abs() < 1e-6,
436 "scored[0] = {}",
437 scored[&0]
438 );
439 assert!(
440 (scored[&3] - 1.0).abs() < 1e-6,
441 "scored[3] = {}",
442 scored[&3]
443 );
444 assert!(
445 (scored[&1] - 0.5).abs() < 1e-6,
446 "scored[1] = {}",
447 scored[&1]
448 );
449 assert!(
450 (scored[&2] - 0.25).abs() < 1e-6,
451 "scored[2] = {}",
452 scored[&2]
453 );
454 }
455
456 /// `test:rerank-topk-greedy-early-exit` — top_k smaller than the
457 /// input cuts the iteration short while preserving ordering.
458 #[test]
459 fn rerank_topk_greedy_early_exit() {
460 let chunks: Vec<CodeChunk> = (0..10).map(|i| chunk_at(&format!("src/f{i}.rs"))).collect();
461 let scores: Vec<(usize, f32)> = (0..10).map(|i| (i, 10.0 - i as f32)).collect();
462 let got = rerank_topk(&scores, &chunks, 3, true);
463 assert_eq!(got.len(), 3);
464 let indices: Vec<usize> = got.iter().map(|(i, _)| *i).collect();
465 assert_eq!(indices, vec![0, 1, 2]);
466 }
467
468 /// `property:penalty-regex-parity-python` — non-test paths in
469 /// production layouts (src/, lib/, crates/, etc.) carry no penalty,
470 /// matching Python's behaviour exactly.
471 #[test]
472 fn property_penalty_regex_parity_python() {
473 let production_paths = &[
474 "src/foo.py",
475 "lib/parser.rs",
476 "crates/ripvec-core/src/encoder/semble/tokens.rs",
477 "pkg/server/handler.go",
478 "app/models/user.rb",
479 "main.c",
480 "include/foo.h",
481 ];
482 for path in production_paths {
483 assert!(
484 (file_path_penalty(path) - 1.0).abs() < 1e-6,
485 "non-test path {path:?} should have penalty 1.0; got {}",
486 file_path_penalty(path)
487 );
488 }
489 // And test-shaped paths always strong-penalize.
490 for path in &["test_foo.py", "src/__init__.py", "src/foo.d.ts"] {
491 assert!(
492 file_path_penalty(path) < 1.0,
493 "test-shaped path {path:?} should be penalised; got 1.0"
494 );
495 }
496 }
497
498 /// `penalise_paths == false` bypasses path priors but still applies
499 /// file-saturation decay.
500 #[test]
501 fn rerank_topk_no_path_penalties_still_decays() {
502 let chunks = vec![chunk_at("test/test_foo.py"), chunk_at("test/test_foo.py")];
503 let scores = vec![(0, 1.0_f32), (1, 1.0)];
504 // With penalise_paths=false, the test_ file's path penalty (0.3)
505 // does NOT apply; but the second chunk from the same file still
506 // gets the saturation decay (×0.5).
507 let got = rerank_topk(&scores, &chunks, 2, false);
508 let scored: HashMap<usize, f32> = got.iter().copied().collect();
509 assert!((scored[&0] - 1.0).abs() < 1e-6);
510 assert!((scored[&1] - 0.5).abs() < 1e-6);
511 }
512}