Skip to main content

ripvec_core/encoder/ripvec/
penalties.rs

1//! File-path penalties + `rerank_topk` with file-saturation decay.
2//!
3//! Port of `~/src/semble/src/semble/ranking/penalties.py`. Two surfaces:
4//!
5//! 1. [`file_path_penalty`] — multiplicative penalty per file path.
6//!    Combines penalties for test files, compat/legacy/examples
7//!    directories, re-export barrels (`__init__.py`, `package-info.java`),
8//!    and `.d.ts` declaration stubs.
9//! 2. [`rerank_topk`] — greedy top-k selection that applies path
10//!    penalties (when `penalise_paths == true`) then decays by 0.5 per
11//!    extra chunk from the same file beyond a threshold of 1.
12//!
13//! ## Indexing convention
14//!
15//! Where Python uses `dict[Chunk, float]` keyed by hashable Chunk,
16//! Rust uses `(chunk_index, score)` pairs — the same convention ripvec
17//! already uses in [`crate::hybrid`]. This avoids adding `Hash`/`Eq`
18//! impls to [`CodeChunk`] just to satisfy a HashMap key.
19
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27
28// ---------------------------------------------------------------------------
29// Penalty constants (Python: `_STRONG_PENALTY` etc. in penalties.py:67).
30// ---------------------------------------------------------------------------
31
/// Strong multiplicative penalty. [`file_path_penalty`] applies it to test
/// files and test directories, compat/legacy directories, and
/// examples/docs-src directories (once per matching category).
pub const STRONG_PENALTY: f32 = 0.3;
/// Moderate multiplicative penalty for re-export / metadata files
/// (`__init__.py`, `package-info.java`).
pub const MODERATE_PENALTY: f32 = 0.5;
/// Mild multiplicative penalty for `.d.ts` declaration stubs (they still
/// carry useful type info, so they are demoted the least).
pub const MILD_PENALTY: f32 = 0.7;

/// Number of chunks [`rerank_topk`] may select from the same file before
/// the saturation decay starts to apply.
pub const FILE_SATURATION_THRESHOLD: usize = 1;

/// Multiplicative penalty per extra chunk from the same file beyond
/// [`FILE_SATURATION_THRESHOLD`]. Excess chunks pay `decay^excess`, i.e.
/// the second chunk from a file is scaled by `0.5`, the third by `0.25`.
pub const FILE_SATURATION_DECAY: f32 = 0.5;
45
46// ---------------------------------------------------------------------------
47// Path-prior regex bank (Python: penalties.py:8-65).
48// ---------------------------------------------------------------------------
49
50/// Regex matching test files across 14 languages plus shared helpers.
51///
52/// Mirrors `_TEST_FILE_RE` from `penalties.py:8`. Lazy-compiled.
53fn test_file_re() -> &'static Regex {
54    static RE: OnceLock<Regex> = OnceLock::new();
55    RE.get_or_init(|| {
56        // Match exactly the Python pattern. The non-capture group after
57        // `(?:^|/)` lists per-language test-file conventions; the
58        // pattern terminates with `$`.
59        let pattern = concat!(
60            r"(?:^|/)",
61            r"(?:",
62            // Python
63            r"test_[^/]*\.py",
64            r"|[^/]*_test\.py",
65            // Go
66            r"|[^/]*_test\.go",
67            // Java
68            r"|[^/]*Tests?\.java",
69            // PHP
70            r"|[^/]*Test\.php",
71            // Ruby
72            r"|[^/]*_spec\.rb",
73            r"|[^/]*_test\.rb",
74            // JavaScript / TypeScript
75            r"|[^/]*\.test\.[jt]sx?",
76            r"|[^/]*\.spec\.[jt]sx?",
77            // Kotlin
78            r"|[^/]*Tests?\.kt",
79            r"|[^/]*Spec\.kt",
80            // Swift
81            r"|[^/]*Tests?\.swift",
82            r"|[^/]*Spec\.swift",
83            // C#
84            r"|[^/]*Tests?\.cs",
85            // C / C++
86            r"|test_[^/]*\.cpp",
87            r"|[^/]*_test\.cpp",
88            r"|test_[^/]*\.c",
89            r"|[^/]*_test\.c",
90            // Scala
91            r"|[^/]*Spec\.scala",
92            r"|[^/]*Suite\.scala",
93            r"|[^/]*Test\.scala",
94            // Dart
95            r"|[^/]*_test\.dart",
96            r"|test_[^/]*\.dart",
97            // Lua
98            r"|[^/]*_spec\.lua",
99            r"|[^/]*_test\.lua",
100            r"|test_[^/]*\.lua",
101            // Shared helper patterns (all languages)
102            r"|test_helpers?[^/]*\.\w+",
103            r")$",
104        );
105        Regex::new(pattern).expect("test-file regex compiles")
106    })
107}
108
109/// Regex matching directories named `test`, `tests`, `__tests__`, `spec`,
110/// `testing`. Mirrors `_TEST_DIR_RE` from `penalties.py:56`.
111fn test_dir_re() -> &'static Regex {
112    static RE: OnceLock<Regex> = OnceLock::new();
113    RE.get_or_init(|| {
114        Regex::new(r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)")
115            .expect("test-dir regex compiles")
116    })
117}
118
119/// Regex matching compat / legacy directories. Mirrors `_COMPAT_DIR_RE`.
120fn compat_dir_re() -> &'static Regex {
121    static RE: OnceLock<Regex> = OnceLock::new();
122    RE.get_or_init(|| {
123        Regex::new(r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)").expect("compat-dir regex compiles")
124    })
125}
126
127/// Regex matching examples / docs-src directories. Mirrors
128/// `_EXAMPLES_DIR_RE`.
129fn examples_dir_re() -> &'static Regex {
130    static RE: OnceLock<Regex> = OnceLock::new();
131    RE.get_or_init(|| {
132        Regex::new(r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)").expect("examples-dir regex compiles")
133    })
134}
135
136/// Regex matching TypeScript `.d.ts` declaration stubs.
137fn type_defs_re() -> &'static Regex {
138    static RE: OnceLock<Regex> = OnceLock::new();
139    RE.get_or_init(|| {
140        RegexBuilder::new(r"\.d\.ts$")
141            .build()
142            .expect("dts regex compiles")
143    })
144}
145
/// Filenames that are re-export barrels or package-level metadata.
///
/// [`file_path_penalty`] compares the final path component against this
/// list with an exact, case-sensitive match and applies
/// [`MODERATE_PENALTY`] on a hit.
const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
148
149// ---------------------------------------------------------------------------
150// Public API.
151// ---------------------------------------------------------------------------
152
153/// Combined multiplicative penalty for all applicable path patterns.
154///
155/// Always in `(0.0, 1.0]`. Returns `1.0` when no patterns match.
156/// Each matched category multiplies into the result independently, so a
157/// file matching both a test-dir pattern and a re-export filename pays
158/// `STRONG_PENALTY * MODERATE_PENALTY`.
159///
160/// Backslashes in `file_path` are normalised to forward slashes before
161/// matching, so Windows-style paths work too.
162#[must_use]
163pub fn file_path_penalty(file_path: &str) -> f32 {
164    let normalised = file_path.replace('\\', "/");
165    let mut penalty = 1.0_f32;
166    if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
167        penalty *= STRONG_PENALTY;
168    }
169    if let Some(filename) = Path::new(file_path).file_name().and_then(|f| f.to_str())
170        && REEXPORT_FILENAMES.contains(&filename)
171    {
172        penalty *= MODERATE_PENALTY;
173    }
174    if compat_dir_re().is_match(&normalised) {
175        penalty *= STRONG_PENALTY;
176    }
177    if examples_dir_re().is_match(&normalised) {
178        penalty *= STRONG_PENALTY;
179    }
180    if type_defs_re().is_match(&normalised) {
181        penalty *= MILD_PENALTY;
182    }
183    penalty
184}
185
186/// Select the top-k results with optional path penalties and
187/// file-saturation decay.
188///
189/// Mirrors `rerank_topk` from `penalties.py:81`. The greedy pass uses
190/// the early-exit optimisation: once `selected.len() >= top_k` and the
191/// remaining candidate's penalised score cannot beat the current k-th
192/// best, the loop terminates.
193///
194/// - `penalise_paths == true` — apply [`file_path_penalty`] before
195///   sorting and selection. This matches semble's behaviour for hybrid
196///   and BM25-led queries.
197/// - `penalise_paths == false` — bypass path priors (pure-semantic
198///   queries where the priors don't apply).
199///
200/// Returns `(chunk_index, effective_score)` pairs, highest first, with
201/// length capped at `top_k`.
202#[must_use]
203pub fn rerank_topk(
204    scores: &[(usize, f32)],
205    chunks: &[CodeChunk],
206    top_k: usize,
207    penalise_paths: bool,
208) -> Vec<(usize, f32)> {
209    if scores.is_empty() || top_k == 0 {
210        return Vec::new();
211    }
212
213    // Step 1: apply path penalties (or identity), producing (chunk_index, penalised_score).
214    let mut penalty_cache: HashMap<String, f32> = HashMap::new();
215    let mut penalised: Vec<(usize, f32)> = scores
216        .iter()
217        .map(|&(idx, score)| {
218            if !penalise_paths {
219                return (idx, score);
220            }
221            let path = &chunks[idx].file_path;
222            let factor = *penalty_cache
223                .entry(path.clone())
224                .or_insert_with(|| file_path_penalty(path));
225            (idx, score * factor)
226        })
227        .collect();
228
229    // Step 2: sort descending by penalised score with stable tie-break by index.
230    penalised.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
231
232    // Step 3: greedy pick with file-saturation decay and early-exit.
233    let mut file_selected: HashMap<String, usize> = HashMap::new();
234    let mut selected: Vec<(usize, f32)> = Vec::with_capacity(top_k.min(penalised.len()));
235    let mut min_selected = f32::INFINITY;
236
237    for &(idx, pen_score) in &penalised {
238        if selected.len() >= top_k && pen_score <= min_selected {
239            // Remaining (sorted-descending) cannot beat the current k-th best.
240            break;
241        }
242        let path = chunks[idx].file_path.clone();
243        let already = *file_selected.get(&path).unwrap_or(&0);
244        let eff_score = if already >= FILE_SATURATION_THRESHOLD {
245            let excess = already - FILE_SATURATION_THRESHOLD + 1;
246            // Excess is bounded by the number of chunks in the front; the
247            // saturation here is far below i32::MAX in practice. The clamp
248            // is defence-in-depth.
249            let excess_i32 = i32::try_from(excess).unwrap_or(i32::MAX);
250            pen_score * FILE_SATURATION_DECAY.powi(excess_i32)
251        } else {
252            pen_score
253        };
254        selected.push((idx, eff_score));
255        file_selected.insert(path, already + 1);
256
257        if selected.len() >= top_k {
258            // Refresh the min over the current selection.
259            min_selected = selected
260                .iter()
261                .map(|(_, s)| *s)
262                .fold(f32::INFINITY, f32::min);
263        }
264    }
265
266    // Final descending sort on effective scores.
267    selected.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
268    selected.truncate(top_k);
269    selected
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal [`CodeChunk`] whose only meaningful field is its
    /// file path; the reranker reads nothing else.
    fn chunk_at(path: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    /// `test:penalties-test-file-regex-14-langs` — every per-language
    /// test-file pattern from `penalties.py:8` matches a representative
    /// file path.
    #[test]
    fn penalties_test_file_regex_14_langs() {
        let cases: &[&str] = &[
            // Python
            "src/test_foo.py",
            "src/foo_test.py",
            // Go
            "pkg/foo_test.go",
            // Java
            "src/FooTest.java",
            "src/FooTests.java",
            // PHP
            "src/FooTest.php",
            // Ruby
            "spec/foo_spec.rb",
            "test/foo_test.rb",
            // JS/TS
            "src/foo.test.js",
            "src/foo.spec.ts",
            "src/foo.test.tsx",
            // Kotlin
            "src/FooTest.kt",
            "src/FooTests.kt",
            "src/FooSpec.kt",
            // Swift
            "src/FooTests.swift",
            "src/FooSpec.swift",
            // C#
            "src/FooTest.cs",
            "src/FooTests.cs",
            // C / C++
            "src/test_foo.cpp",
            "src/foo_test.cpp",
            "src/test_foo.c",
            "src/foo_test.c",
            // Scala
            "src/FooSpec.scala",
            "src/FooSuite.scala",
            "src/FooTest.scala",
            // Dart
            "src/foo_test.dart",
            "src/test_foo.dart",
            // Lua
            "src/foo_spec.lua",
            "src/foo_test.lua",
            "src/test_foo.lua",
            // Shared helpers
            "test/test_helper.rb",
            "test/test_helpers.go",
        ];
        for path in cases {
            assert!(
                test_file_re().is_match(path),
                "expected test_file_re to match {path:?}"
            );
            assert!(
                (file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6,
                "expected STRONG_PENALTY for {path:?}; got {}",
                file_path_penalty(path)
            );
        }
    }

    /// `test:penalties-compat-dir` — compat / legacy directories trigger
    /// the strong penalty.
    #[test]
    fn penalties_compat_dir() {
        for path in &["compat/foo.py", "src/_compat/bar.rs", "legacy/baz.go"] {
            assert!(
                compat_dir_re().is_match(path),
                "expected compat match for {path:?}"
            );
            assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
        }
    }

    /// `test:penalties-examples-dir` — example / docs-src directories
    /// trigger the strong penalty.
    #[test]
    fn penalties_examples_dir() {
        for path in &[
            "examples/foo.py",
            "_examples/bar.rs",
            "example/baz.go",
            "docs_src/x.md",
        ] {
            assert!(
                examples_dir_re().is_match(path),
                "expected examples match for {path:?}"
            );
            assert!((file_path_penalty(path) - STRONG_PENALTY).abs() < 1e-6);
        }
    }

    /// `test:penalties-init-py-reexport` — `__init__.py` carries the
    /// moderate re-export penalty (and Java's `package-info.java`).
    ///
    /// Java path deliberately avoids `/example/` since that segment
    /// triggers `examples_dir_re`'s strong penalty (we want to verify
    /// re-export in isolation here).
    #[test]
    fn penalties_init_py_reexport() {
        assert!((file_path_penalty("src/__init__.py") - MODERATE_PENALTY).abs() < 1e-6);
        assert!(
            (file_path_penalty("src/com/myorg/package-info.java") - MODERATE_PENALTY).abs() < 1e-6
        );
    }

    /// `test:penalties-dts-stub` — `.d.ts` files take the mild penalty.
    #[test]
    fn penalties_dts_stub() {
        assert!((file_path_penalty("src/foo.d.ts") - MILD_PENALTY).abs() < 1e-6);
    }

    /// Non-penalized paths return penalty 1.0.
    #[test]
    fn non_penalized_path_is_identity() {
        assert!((file_path_penalty("src/foo.rs") - 1.0).abs() < 1e-6);
        assert!((file_path_penalty("lib/bar.py") - 1.0).abs() < 1e-6);
    }

    /// Independent categories multiply: a re-export barrel inside a test
    /// directory pays both the strong and the moderate penalty.
    #[test]
    fn penalties_combine_multiplicatively() {
        let got = file_path_penalty("tests/__init__.py");
        let want = STRONG_PENALTY * MODERATE_PENALTY;
        assert!(
            (got - want).abs() < 1e-6,
            "expected {want} for tests/__init__.py; got {got}"
        );
    }

    /// Backslash-separated paths are normalised before the directory
    /// patterns are applied.
    #[test]
    fn penalties_backslash_paths_match_directory_patterns() {
        let got = file_path_penalty(r"src\tests\foo.py");
        assert!(
            (got - STRONG_PENALTY).abs() < 1e-6,
            "expected STRONG_PENALTY for a backslash test-dir path; got {got}"
        );
    }

    /// `test:rerank-topk-saturation-decay` — a third chunk from the
    /// same file is penalised by 0.5^2.
    #[test]
    fn rerank_topk_saturation_decay() {
        let chunks = vec![
            chunk_at("src/foo.rs"),
            chunk_at("src/foo.rs"),
            chunk_at("src/foo.rs"),
            chunk_at("src/bar.rs"),
        ];
        // All four chunks have identical raw scores. The path penalty is 1.0
        // for both files. Greedy order picks chunks in their submitted
        // ordering (stable tie-break by index).
        let scores = vec![(0, 1.0_f32), (1, 1.0), (2, 1.0), (3, 1.0)];
        let got = rerank_topk(&scores, &chunks, 4, true);
        assert_eq!(got.len(), 4);

        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        // Chunk 0 (first from foo.rs): score 1.0.
        // Chunk 3 (first from bar.rs): score 1.0.
        // Chunk 1 (second from foo.rs): already=1, excess=1, score *= 0.5.
        // Chunk 2 (third from foo.rs):  already=2, excess=2, score *= 0.25.
        assert!(
            (scored[&0] - 1.0).abs() < 1e-6,
            "scored[0] = {}",
            scored[&0]
        );
        assert!(
            (scored[&3] - 1.0).abs() < 1e-6,
            "scored[3] = {}",
            scored[&3]
        );
        assert!(
            (scored[&1] - 0.5).abs() < 1e-6,
            "scored[1] = {}",
            scored[&1]
        );
        assert!(
            (scored[&2] - 0.25).abs() < 1e-6,
            "scored[2] = {}",
            scored[&2]
        );
    }

    /// `test:rerank-topk-greedy-early-exit` — top_k smaller than the
    /// input cuts the iteration short while preserving ordering.
    #[test]
    fn rerank_topk_greedy_early_exit() {
        let chunks: Vec<CodeChunk> = (0..10).map(|i| chunk_at(&format!("src/f{i}.rs"))).collect();
        let scores: Vec<(usize, f32)> = (0..10).map(|i| (i, 10.0 - i as f32)).collect();
        let got = rerank_topk(&scores, &chunks, 3, true);
        assert_eq!(got.len(), 3);
        let indices: Vec<usize> = got.iter().map(|(i, _)| *i).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    /// Saturation decay can push a same-file chunk below a chunk from
    /// another file, so the greedy pass must keep looking past `top_k`
    /// selections while a candidate can still beat the current minimum.
    #[test]
    fn rerank_topk_saturation_promotes_other_file() {
        let chunks = vec![
            chunk_at("src/foo.rs"),
            chunk_at("src/foo.rs"),
            chunk_at("src/bar.rs"),
        ];
        let scores = vec![(0, 1.0_f32), (1, 0.9), (2, 0.8)];
        // Chunk 1 decays to 0.45, so chunk 2 (0.8, first from bar.rs)
        // takes the second slot.
        let got = rerank_topk(&scores, &chunks, 2, true);
        assert_eq!(got.len(), 2);
        assert_eq!(got[0].0, 0);
        assert_eq!(got[1].0, 2);
        assert!((got[0].1 - 1.0).abs() < 1e-6, "got[0] = {}", got[0].1);
        assert!((got[1].1 - 0.8).abs() < 1e-6, "got[1] = {}", got[1].1);
    }

    /// Empty `scores` or `top_k == 0` yields an empty result.
    #[test]
    fn rerank_topk_empty_inputs() {
        let chunks = vec![chunk_at("src/foo.rs")];
        assert!(rerank_topk(&[], &chunks, 5, true).is_empty());
        assert!(rerank_topk(&[(0, 1.0)], &chunks, 0, true).is_empty());
    }

    /// `property:penalty-regex-parity-python` — non-test paths in
    /// production layouts (src/, lib/, crates/, etc.) carry no penalty,
    /// matching Python's behaviour exactly.
    #[test]
    fn property_penalty_regex_parity_python() {
        let production_paths = &[
            "src/foo.py",
            "lib/parser.rs",
            "crates/ripvec-core/src/encoder/semble/tokens.rs",
            "pkg/server/handler.go",
            "app/models/user.rb",
            "main.c",
            "include/foo.h",
        ];
        for path in production_paths {
            assert!(
                (file_path_penalty(path) - 1.0).abs() < 1e-6,
                "non-test path {path:?} should have penalty 1.0; got {}",
                file_path_penalty(path)
            );
        }
        // And paths matching any penalty category (test file, re-export
        // barrel, declaration stub) score strictly below 1.0.
        for path in &["test_foo.py", "src/__init__.py", "src/foo.d.ts"] {
            assert!(
                file_path_penalty(path) < 1.0,
                "penalised path {path:?} should score below 1.0; got {}",
                file_path_penalty(path)
            );
        }
    }

    /// `penalise_paths == false` bypasses path priors but still applies
    /// file-saturation decay.
    #[test]
    fn rerank_topk_no_path_penalties_still_decays() {
        let chunks = vec![chunk_at("test/test_foo.py"), chunk_at("test/test_foo.py")];
        let scores = vec![(0, 1.0_f32), (1, 1.0)];
        // With penalise_paths=false, the test_ file's path penalty (0.3)
        // does NOT apply; but the second chunk from the same file still
        // gets the saturation decay (×0.5).
        let got = rerank_topk(&scores, &chunks, 2, false);
        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        assert!((scored[&0] - 1.0).abs() < 1e-6);
        assert!((scored[&1] - 0.5).abs() < 1e-6);
    }
}
512}