ripvec_core/encoder/ripvec/penalties.rs
1//! File-path penalties + `rerank_topk` with file-saturation decay.
2//!
3//! Port of `~/src/semble/src/semble/ranking/penalties.py`. Two surfaces:
4//!
5//! 1. [`file_path_penalty`] — multiplicative penalty per file path.
6//! Combines penalties for test files, compat/legacy/examples
7//! directories, re-export barrels (`__init__.py`, `package-info.java`),
8//! and `.d.ts` declaration stubs.
9//! 2. [`rerank_topk`] — greedy top-k selection that applies path
10//! penalties (when `penalise_paths == true`) then decays by 0.5 per
11//! extra chunk from the same file beyond a threshold of 1.
12//!
13//! ## Indexing convention
14//!
15//! Where Python uses `dict[Chunk, float]` keyed by hashable Chunk,
16//! Rust uses `(chunk_index, score)` pairs — the same convention ripvec
17//! already uses in [`crate::hybrid`]. This avoids adding `Hash`/`Eq`
18//! impls to [`CodeChunk`] just to satisfy a HashMap key.
19
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27
// ---------------------------------------------------------------------------
// Penalty constants (Python: `_STRONG_PENALTY` etc. in penalties.py:67).
// ---------------------------------------------------------------------------

/// Strongest path prior, used for test files, compat/legacy shims, and
/// example / docs-src code.
pub const STRONG_PENALTY: f32 = 0.3;
/// Mid-strength path prior for re-export barrels and package metadata
/// files (`__init__.py`, `package-info.java`).
pub const MODERATE_PENALTY: f32 = 0.5;
/// Light path prior for TypeScript `.d.ts` declaration stubs, which still
/// carry useful type information.
pub const MILD_PENALTY: f32 = 0.7;

/// How many chunks one file may contribute before saturation decay kicks
/// in for any further chunk from that same file.
pub const FILE_SATURATION_THRESHOLD: usize = 1;

/// Per-excess-chunk multiplier for file saturation. A chunk that is the
/// `excess`-th one past [`FILE_SATURATION_THRESHOLD`] for its file is scaled
/// by `FILE_SATURATION_DECAY` raised to the power `excess`.
pub const FILE_SATURATION_DECAY: f32 = 0.5;
45
46// ---------------------------------------------------------------------------
47// Path-prior regex bank (Python: penalties.py:8-65).
48// ---------------------------------------------------------------------------
49
50/// Regex matching test files across 14 languages plus shared helpers.
51///
52/// Mirrors `_TEST_FILE_RE` from `penalties.py:8`. Lazy-compiled.
53fn test_file_re() -> &'static Regex {
54 static RE: OnceLock<Regex> = OnceLock::new();
55 RE.get_or_init(|| {
56 // Match exactly the Python pattern. The non-capture group after
57 // `(?:^|/)` lists per-language test-file conventions; the
58 // pattern terminates with `$`.
59 let pattern = concat!(
60 r"(?:^|/)",
61 r"(?:",
62 // Python
63 r"test_[^/]*\.py",
64 r"|[^/]*_test\.py",
65 // Go
66 r"|[^/]*_test\.go",
67 // Java
68 r"|[^/]*Tests?\.java",
69 // PHP
70 r"|[^/]*Test\.php",
71 // Ruby
72 r"|[^/]*_spec\.rb",
73 r"|[^/]*_test\.rb",
74 // JavaScript / TypeScript
75 r"|[^/]*\.test\.[jt]sx?",
76 r"|[^/]*\.spec\.[jt]sx?",
77 // Kotlin
78 r"|[^/]*Tests?\.kt",
79 r"|[^/]*Spec\.kt",
80 // Swift
81 r"|[^/]*Tests?\.swift",
82 r"|[^/]*Spec\.swift",
83 // C#
84 r"|[^/]*Tests?\.cs",
85 // C / C++
86 r"|test_[^/]*\.cpp",
87 r"|[^/]*_test\.cpp",
88 r"|test_[^/]*\.c",
89 r"|[^/]*_test\.c",
90 // Scala
91 r"|[^/]*Spec\.scala",
92 r"|[^/]*Suite\.scala",
93 r"|[^/]*Test\.scala",
94 // Dart
95 r"|[^/]*_test\.dart",
96 r"|test_[^/]*\.dart",
97 // Lua
98 r"|[^/]*_spec\.lua",
99 r"|[^/]*_test\.lua",
100 r"|test_[^/]*\.lua",
101 // Shared helper patterns (all languages)
102 r"|test_helpers?[^/]*\.\w+",
103 r")$",
104 );
105 Regex::new(pattern).expect("test-file regex compiles")
106 })
107}
108
109/// Regex matching directories named `test`, `tests`, `__tests__`, `spec`,
110/// `testing`. Mirrors `_TEST_DIR_RE` from `penalties.py:56`.
111fn test_dir_re() -> &'static Regex {
112 static RE: OnceLock<Regex> = OnceLock::new();
113 RE.get_or_init(|| {
114 Regex::new(r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)")
115 .expect("test-dir regex compiles")
116 })
117}
118
119/// Regex matching compat / legacy directories. Mirrors `_COMPAT_DIR_RE`.
120fn compat_dir_re() -> &'static Regex {
121 static RE: OnceLock<Regex> = OnceLock::new();
122 RE.get_or_init(|| {
123 Regex::new(r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)").expect("compat-dir regex compiles")
124 })
125}
126
127/// Regex matching examples / docs-src directories. Mirrors
128/// `_EXAMPLES_DIR_RE`.
129fn examples_dir_re() -> &'static Regex {
130 static RE: OnceLock<Regex> = OnceLock::new();
131 RE.get_or_init(|| {
132 Regex::new(r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)").expect("examples-dir regex compiles")
133 })
134}
135
136/// Regex matching TypeScript `.d.ts` declaration stubs.
137fn type_defs_re() -> &'static Regex {
138 static RE: OnceLock<Regex> = OnceLock::new();
139 RE.get_or_init(|| {
140 RegexBuilder::new(r"\.d\.ts$")
141 .build()
142 .expect("dts regex compiles")
143 })
144}
145
/// Exact file names (final path component) treated as re-export barrels or
/// package-level metadata. `file_path_penalty` scales a path whose file name
/// is in this list by `MODERATE_PENALTY`.
const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
148
149// ---------------------------------------------------------------------------
150// Public API.
151// ---------------------------------------------------------------------------
152
153/// Combined multiplicative penalty for all applicable path patterns.
154///
155/// Always in `(0.0, 1.0]`. Returns `1.0` when no patterns match.
156/// Each matched category multiplies into the result independently, so a
157/// file matching both a test-dir pattern and a re-export filename pays
158/// `STRONG_PENALTY * MODERATE_PENALTY`.
159///
160/// Backslashes in `file_path` are normalised to forward slashes before
161/// matching, so Windows-style paths work too.
162#[must_use]
163pub fn file_path_penalty(file_path: &str) -> f32 {
164 let normalised = file_path.replace('\\', "/");
165 let mut penalty = 1.0_f32;
166 if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
167 penalty *= STRONG_PENALTY;
168 }
169 if let Some(filename) = Path::new(file_path).file_name().and_then(|f| f.to_str())
170 && REEXPORT_FILENAMES.contains(&filename)
171 {
172 penalty *= MODERATE_PENALTY;
173 }
174 if compat_dir_re().is_match(&normalised) {
175 penalty *= STRONG_PENALTY;
176 }
177 if examples_dir_re().is_match(&normalised) {
178 penalty *= STRONG_PENALTY;
179 }
180 if type_defs_re().is_match(&normalised) {
181 penalty *= MILD_PENALTY;
182 }
183 penalty
184}
185
186/// Select the top-k results with optional path penalties and
187/// file-saturation decay.
188///
189/// Mirrors `rerank_topk` from `penalties.py:81`. The greedy pass uses
190/// the early-exit optimisation: once `selected.len() >= top_k` and the
191/// remaining candidate's penalised score cannot beat the current k-th
192/// best, the loop terminates.
193///
194/// - `penalise_paths == true` — apply [`file_path_penalty`] before
195/// sorting and selection. This matches semble's behaviour for hybrid
196/// and BM25-led queries.
197/// - `penalise_paths == false` — bypass path priors (pure-semantic
198/// queries where the priors don't apply).
199///
200/// Returns `(chunk_index, effective_score)` pairs, highest first, with
201/// length capped at `top_k`.
202#[must_use]
203pub fn rerank_topk(
204 scores: &[(usize, f32)],
205 chunks: &[CodeChunk],
206 top_k: usize,
207 penalise_paths: bool,
208) -> Vec<(usize, f32)> {
209 if scores.is_empty() || top_k == 0 {
210 return Vec::new();
211 }
212
213 // Step 1: apply path penalties (or identity), producing (chunk_index, penalised_score).
214 let mut penalty_cache: HashMap<String, f32> = HashMap::new();
215 let mut penalised: Vec<(usize, f32)> = scores
216 .iter()
217 .map(|&(idx, score)| {
218 if !penalise_paths {
219 return (idx, score);
220 }
221 let path = &chunks[idx].file_path;
222 let factor = *penalty_cache
223 .entry(path.clone())
224 .or_insert_with(|| file_path_penalty(path));
225 (idx, score * factor)
226 })
227 .collect();
228
229 // Step 2: sort descending by penalised score with stable tie-break by index.
230 penalised.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
231
232 // Step 3: greedy pick with file-saturation decay and early-exit.
233 let mut file_selected: HashMap<String, usize> = HashMap::new();
234 let mut selected: Vec<(usize, f32)> = Vec::with_capacity(top_k.min(penalised.len()));
235 let mut min_selected = f32::INFINITY;
236
237 for &(idx, pen_score) in &penalised {
238 if selected.len() >= top_k && pen_score <= min_selected {
239 // Remaining (sorted-descending) cannot beat the current k-th best.
240 break;
241 }
242 let path = chunks[idx].file_path.clone();
243 let already = *file_selected.get(&path).unwrap_or(&0);
244 let eff_score = if already >= FILE_SATURATION_THRESHOLD {
245 let excess = already - FILE_SATURATION_THRESHOLD + 1;
246 // Excess is bounded by the number of chunks in the front; the
247 // saturation here is far below i32::MAX in practice. The clamp
248 // is defence-in-depth.
249 let excess_i32 = i32::try_from(excess).unwrap_or(i32::MAX);
250 pen_score * FILE_SATURATION_DECAY.powi(excess_i32)
251 } else {
252 pen_score
253 };
254 selected.push((idx, eff_score));
255 file_selected.insert(path, already + 1);
256
257 if selected.len() >= top_k {
258 // Refresh the min over the current selection.
259 min_selected = selected
260 .iter()
261 .map(|(_, s)| *s)
262 .fold(f32::INFINITY, f32::min);
263 }
264 }
265
266 // Final descending sort on effective scores.
267 selected.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
268 selected.truncate(top_k);
269 selected
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal [`CodeChunk`] whose `file_path` is `path`. Every
    /// other field is an empty string, `None`, `ContentKind::Code`, or `1`.
    fn chunk_at(path: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            content: String::new(),
            enriched_content: String::new(),
            qualified_name: None,
        }
    }

    /// `test:penalties-test-file-regex-14-langs` — every per-language
    /// test-file pattern from `penalties.py:8` matches a representative
    /// file path, and each such path is penalised by exactly
    /// `STRONG_PENALTY`.
    #[test]
    fn penalties_test_file_regex_14_langs() {
        let cases: &[&str] = &[
            // Python
            "src/test_foo.py",
            "src/foo_test.py",
            // Go
            "pkg/foo_test.go",
            // Java
            "src/FooTest.java",
            "src/FooTests.java",
            // PHP
            "src/FooTest.php",
            // Ruby
            "spec/foo_spec.rb",
            "test/foo_test.rb",
            // JS/TS
            "src/foo.test.js",
            "src/foo.spec.ts",
            "src/foo.test.tsx",
            // Kotlin
            "src/FooTest.kt",
            "src/FooTests.kt",
            "src/FooSpec.kt",
            // Swift
            "src/FooTests.swift",
            "src/FooSpec.swift",
            // C#
            "src/FooTest.cs",
            "src/FooTests.cs",
            // C / C++
            "src/test_foo.cpp",
            "src/foo_test.cpp",
            "src/test_foo.c",
            "src/foo_test.c",
            // Scala
            "src/FooSpec.scala",
            "src/FooSuite.scala",
            "src/FooTest.scala",
            // Dart
            "src/foo_test.dart",
            "src/test_foo.dart",
            // Lua
            "src/foo_spec.lua",
            "src/foo_test.lua",
            "src/test_foo.lua",
            // Shared helpers
            "test/test_helper.rb",
            "test/test_helpers.go",
        ];
        for path in cases {
            assert!(
                test_file_re().is_match(path),
                "expected test_file_re to match {path:?}"
            );
            // Paths under `spec/` or `test/` also hit the test-dir regex.
            // The file and dir checks are OR-ed, so the strong penalty is
            // still applied only once.
            assert!(
                (file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6,
                "expected STRONG_PENALTY for {path:?}; got {}",
                file_path_penalty(path)
            );
        }
    }

    /// `test:penalties-compat-dir` — compat / legacy directories trigger
    /// the strong penalty.
    #[test]
    fn penalties_compat_dir() {
        for path in &["compat/foo.py", "src/_compat/bar.rs", "legacy/baz.go"] {
            assert!(
                compat_dir_re().is_match(path),
                "expected compat match for {path:?}"
            );
            assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
        }
    }

    /// `test:penalties-examples-dir` — example / docs-src directories
    /// trigger the strong penalty.
    #[test]
    fn penalties_examples_dir() {
        for path in &[
            "examples/foo.py",
            "_examples/bar.rs",
            "example/baz.go",
            "docs_src/x.md",
        ] {
            assert!(
                examples_dir_re().is_match(path),
                "expected examples match for {path:?}"
            );
            assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
        }
    }

    /// `test:penalties-init-py-reexport` — `__init__.py` carries the
    /// moderate re-export penalty, as does Java's `package-info.java`.
    ///
    /// The Java path deliberately avoids an `/example/` segment, which
    /// would trigger `examples_dir_re`'s strong penalty. This test checks
    /// the re-export penalty in isolation.
    #[test]
    fn penalties_init_py_reexport() {
        assert!((file_path_penalty("src/__init__.py") - MODERATE_PENALTY).abs() < 1e-6);
        assert!(
            (file_path_penalty("src/com/myorg/package-info.java") - MODERATE_PENALTY).abs() < 1e-6
        );
    }

    /// `test:penalties-dts-stub` — `.d.ts` files take the mild penalty.
    #[test]
    fn penalties_dts_stub() {
        assert!((file_path_penalty("src/foo.d.ts") - MILD_PENALTY).abs() < 1e-6);
    }

    /// Non-penalised paths return penalty 1.0.
    #[test]
    fn non_penalized_path_is_identity() {
        assert!((file_path_penalty("src/foo.rs") - 1.0).abs() < 1e-6);
        assert!((file_path_penalty("lib/bar.py") - 1.0).abs() < 1e-6);
    }

    /// `test:rerank-topk-saturation-decay` — the second chunk from a file
    /// is scaled by 0.5, and the third by 0.5^2.
    #[test]
    fn rerank_topk_saturation_decay() {
        let chunks = vec![
            chunk_at("src/foo.rs"),
            chunk_at("src/foo.rs"),
            chunk_at("src/foo.rs"),
            chunk_at("src/bar.rs"),
        ];
        // All four chunks have identical raw scores, and the path penalty is
        // 1.0 for both files. Equal scores are tie-broken by ascending chunk
        // index, so the greedy pass visits chunks 0, 1, 2, 3 in that order.
        let scores = vec![(0, 1.0_f32), (1, 1.0), (2, 1.0), (3, 1.0)];
        let got = rerank_topk(&scores, &chunks, 4, true);
        assert_eq!(got.len(), 4);

        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        // Chunk 0 (first from foo.rs): score 1.0.
        // Chunk 3 (first from bar.rs): score 1.0.
        // Chunk 1 (second from foo.rs): already=1, excess=1, score *= 0.5.
        // Chunk 2 (third from foo.rs): already=2, excess=2, score *= 0.25.
        assert!(
            (scored[&0] - 1.0).abs() < 1e-6,
            "scored[0] = {}",
            scored[&0]
        );
        assert!(
            (scored[&3] - 1.0).abs() < 1e-6,
            "scored[3] = {}",
            scored[&3]
        );
        assert!(
            (scored[&1] - 0.5).abs() < 1e-6,
            "scored[1] = {}",
            scored[&1]
        );
        assert!(
            (scored[&2] - 0.25).abs() < 1e-6,
            "scored[2] = {}",
            scored[&2]
        );
    }

    /// `test:rerank-topk-greedy-early-exit` — a `top_k` smaller than the
    /// input cuts the iteration short while preserving ordering.
    #[test]
    fn rerank_topk_greedy_early_exit() {
        // Ten distinct files, so no saturation decay applies. Scores strictly
        // decrease with index (10.0, 9.0, ...), so the top 3 are 0, 1, 2.
        let chunks: Vec<CodeChunk> = (0..10).map(|i| chunk_at(&format!("src/f{i}.rs"))).collect();
        let scores: Vec<(usize, f32)> = (0..10).map(|i| (i, 10.0 - i as f32)).collect();
        let got = rerank_topk(&scores, &chunks, 3, true);
        assert_eq!(got.len(), 3);
        let indices: Vec<usize> = got.iter().map(|(i, _)| *i).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    /// `property:penalty-regex-parity-python` — non-test paths in
    /// production layouts (src/, lib/, crates/, etc.) carry no penalty,
    /// matching Python's behaviour exactly.
    #[test]
    fn property_penalty_regex_parity_python() {
        let production_paths = &[
            "src/foo.py",
            "lib/parser.rs",
            "crates/ripvec-core/src/encoder/semble/tokens.rs",
            "pkg/server/handler.go",
            "app/models/user.rb",
            "main.c",
            "include/foo.h",
        ];
        for path in production_paths {
            assert!(
                (file_path_penalty(path) - 1.0).abs() < 1e-6,
                "non-test path {path:?} should have penalty 1.0; got {}",
                file_path_penalty(path)
            );
        }
        // Penalised path shapes always score below 1.0. The set covers a
        // test file (strong), a re-export barrel (moderate) and a `.d.ts`
        // stub (mild).
        for path in &["test_foo.py", "src/__init__.py", "src/foo.d.ts"] {
            assert!(
                file_path_penalty(path) < 1.0,
                "test-shaped path {path:?} should be penalised; got 1.0"
            );
        }
    }

    /// `penalise_paths == false` bypasses path priors but still applies
    /// file-saturation decay.
    #[test]
    fn rerank_topk_no_path_penalties_still_decays() {
        let chunks = vec![chunk_at("test/test_foo.py"), chunk_at("test/test_foo.py")];
        let scores = vec![(0, 1.0_f32), (1, 1.0)];
        // With penalise_paths=false, the test_ file's path penalty (0.3)
        // does NOT apply; but the second chunk from the same file still
        // gets the saturation decay (×0.5).
        let got = rerank_topk(&scores, &chunks, 2, false);
        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        assert!((scored[&0] - 1.0).abs() < 1e-6);
        assert!((scored[&1] - 0.5).abs() < 1e-6);
    }
}