Skip to main content

ripvec_core/encoder/ripvec/
penalties.rs

1//! File-path penalties + `rerank_topk` with file-saturation decay.
2//!
3//! Port of `~/src/semble/src/semble/ranking/penalties.py`. Two surfaces:
4//!
5//! 1. [`file_path_penalty`] — multiplicative penalty per file path.
6//!    Combines penalties for test files, compat/legacy/examples
7//!    directories, re-export barrels (`__init__.py`, `package-info.java`),
8//!    and `.d.ts` declaration stubs.
9//! 2. [`rerank_topk`] — greedy top-k selection that applies path
10//!    penalties (when `penalise_paths == true`) then decays by 0.5 per
11//!    extra chunk from the same file beyond a threshold of 1.
12//!
13//! ## Indexing convention
14//!
15//! Where Python uses `dict[Chunk, float]` keyed by hashable Chunk,
16//! Rust uses `(chunk_index, score)` pairs — the same convention ripvec
17//! already uses in [`crate::hybrid`]. This avoids adding `Hash`/`Eq`
18//! impls to [`CodeChunk`] just to satisfy a HashMap key.
19
20use std::collections::HashMap;
21use std::path::Path;
22use std::sync::OnceLock;
23
24use regex::{Regex, RegexBuilder};
25
26use crate::chunk::CodeChunk;
27
28// ---------------------------------------------------------------------------
29// Penalty constants (Python: `_STRONG_PENALTY` etc. in penalties.py:67).
30// ---------------------------------------------------------------------------
31
/// Strong multiplier applied by [`file_path_penalty`] to test files and
/// to paths inside test, compat/legacy, or example/docs-src directories.
/// Each matching category multiplies in separately.
pub const STRONG_PENALTY: f32 = 0.3;
/// Moderate multiplier applied by [`file_path_penalty`] to re-export /
/// metadata files (`__init__.py`, `package-info.java`).
pub const MODERATE_PENALTY: f32 = 0.5;
/// Mild multiplier applied by [`file_path_penalty`] to `.d.ts`
/// declaration stubs (still carry useful type info).
pub const MILD_PENALTY: f32 = 0.7;

/// Number of chunks [`rerank_topk`] may pick from one file at full score;
/// every further pick from that file is decayed.
pub const FILE_SATURATION_THRESHOLD: usize = 1;

/// Multiplicative penalty per extra chunk from the same file beyond
/// [`FILE_SATURATION_THRESHOLD`]. Excess chunks pay `decay^excess`, so
/// with the current values the second chunk of a file is scaled by 0.5
/// and the third by 0.25.
pub const FILE_SATURATION_DECAY: f32 = 0.5;
45
46// ---------------------------------------------------------------------------
47// Path-prior regex bank (Python: penalties.py:8-65).
48// ---------------------------------------------------------------------------
49
50/// Regex matching test files across 14 languages plus shared helpers.
51///
52/// Mirrors `_TEST_FILE_RE` from `penalties.py:8`. Lazy-compiled.
53fn test_file_re() -> &'static Regex {
54    static RE: OnceLock<Regex> = OnceLock::new();
55    RE.get_or_init(|| {
56        // Match exactly the Python pattern. The non-capture group after
57        // `(?:^|/)` lists per-language test-file conventions; the
58        // pattern terminates with `$`.
59        let pattern = concat!(
60            r"(?:^|/)",
61            r"(?:",
62            // Python
63            r"test_[^/]*\.py",
64            r"|[^/]*_test\.py",
65            // Go
66            r"|[^/]*_test\.go",
67            // Java
68            r"|[^/]*Tests?\.java",
69            // PHP
70            r"|[^/]*Test\.php",
71            // Ruby
72            r"|[^/]*_spec\.rb",
73            r"|[^/]*_test\.rb",
74            // JavaScript / TypeScript
75            r"|[^/]*\.test\.[jt]sx?",
76            r"|[^/]*\.spec\.[jt]sx?",
77            // Kotlin
78            r"|[^/]*Tests?\.kt",
79            r"|[^/]*Spec\.kt",
80            // Swift
81            r"|[^/]*Tests?\.swift",
82            r"|[^/]*Spec\.swift",
83            // C#
84            r"|[^/]*Tests?\.cs",
85            // C / C++
86            r"|test_[^/]*\.cpp",
87            r"|[^/]*_test\.cpp",
88            r"|test_[^/]*\.c",
89            r"|[^/]*_test\.c",
90            // Scala
91            r"|[^/]*Spec\.scala",
92            r"|[^/]*Suite\.scala",
93            r"|[^/]*Test\.scala",
94            // Dart
95            r"|[^/]*_test\.dart",
96            r"|test_[^/]*\.dart",
97            // Lua
98            r"|[^/]*_spec\.lua",
99            r"|[^/]*_test\.lua",
100            r"|test_[^/]*\.lua",
101            // Shared helper patterns (all languages)
102            r"|test_helpers?[^/]*\.\w+",
103            r")$",
104        );
105        Regex::new(pattern).expect("test-file regex compiles")
106    })
107}
108
109/// Regex matching directories named `test`, `tests`, `__tests__`, `spec`,
110/// `testing`. Mirrors `_TEST_DIR_RE` from `penalties.py:56`.
111fn test_dir_re() -> &'static Regex {
112    static RE: OnceLock<Regex> = OnceLock::new();
113    RE.get_or_init(|| {
114        Regex::new(r"(?:^|/)(?:tests?|__tests__|spec|testing)(?:/|$)")
115            .expect("test-dir regex compiles")
116    })
117}
118
119/// Regex matching compat / legacy directories. Mirrors `_COMPAT_DIR_RE`.
120fn compat_dir_re() -> &'static Regex {
121    static RE: OnceLock<Regex> = OnceLock::new();
122    RE.get_or_init(|| {
123        Regex::new(r"(?:^|/)(?:compat|_compat|legacy)(?:/|$)").expect("compat-dir regex compiles")
124    })
125}
126
127/// Regex matching examples / docs-src directories. Mirrors
128/// `_EXAMPLES_DIR_RE`.
129fn examples_dir_re() -> &'static Regex {
130    static RE: OnceLock<Regex> = OnceLock::new();
131    RE.get_or_init(|| {
132        Regex::new(r"(?:^|/)(?:_?examples?|docs?_src)(?:/|$)").expect("examples-dir regex compiles")
133    })
134}
135
136/// Regex matching TypeScript `.d.ts` declaration stubs.
137fn type_defs_re() -> &'static Regex {
138    static RE: OnceLock<Regex> = OnceLock::new();
139    RE.get_or_init(|| {
140        RegexBuilder::new(r"\.d\.ts$")
141            .build()
142            .expect("dts regex compiles")
143    })
144}
145
/// Filenames that are re-export barrels or package-level metadata.
///
/// [`file_path_penalty`] compares the final path component against this
/// list by exact string equality and applies [`MODERATE_PENALTY`] on a
/// match.
const REEXPORT_FILENAMES: &[&str] = &["__init__.py", "package-info.java"];
148
149// ---------------------------------------------------------------------------
150// Public API.
151// ---------------------------------------------------------------------------
152
153/// Combined multiplicative penalty for all applicable path patterns.
154///
155/// Always in `(0.0, 1.0]`. Returns `1.0` when no patterns match.
156/// Each matched category multiplies into the result independently, so a
157/// file matching both a test-dir pattern and a re-export filename pays
158/// `STRONG_PENALTY * MODERATE_PENALTY`.
159///
160/// Backslashes in `file_path` are normalised to forward slashes before
161/// matching, so Windows-style paths work too.
162#[must_use]
163pub fn file_path_penalty(file_path: &str) -> f32 {
164    let normalised = file_path.replace('\\', "/");
165    let mut penalty = 1.0_f32;
166    if test_file_re().is_match(&normalised) || test_dir_re().is_match(&normalised) {
167        penalty *= STRONG_PENALTY;
168    }
169    if let Some(filename) = Path::new(file_path).file_name().and_then(|f| f.to_str())
170        && REEXPORT_FILENAMES.contains(&filename)
171    {
172        penalty *= MODERATE_PENALTY;
173    }
174    if compat_dir_re().is_match(&normalised) {
175        penalty *= STRONG_PENALTY;
176    }
177    if examples_dir_re().is_match(&normalised) {
178        penalty *= STRONG_PENALTY;
179    }
180    if type_defs_re().is_match(&normalised) {
181        penalty *= MILD_PENALTY;
182    }
183    penalty
184}
185
186/// Select the top-k results with optional path penalties and
187/// file-saturation decay.
188///
189/// Mirrors `rerank_topk` from `penalties.py:81`. The greedy pass uses
190/// the early-exit optimisation: once `selected.len() >= top_k` and the
191/// remaining candidate's penalised score cannot beat the current k-th
192/// best, the loop terminates.
193///
194/// - `penalise_paths == true` — apply [`file_path_penalty`] before
195///   sorting and selection. This matches semble's behaviour for hybrid
196///   and BM25-led queries.
197/// - `penalise_paths == false` — bypass path priors (pure-semantic
198///   queries where the priors don't apply).
199///
200/// Returns `(chunk_index, effective_score)` pairs, highest first, with
201/// length capped at `top_k`.
202#[must_use]
203pub fn rerank_topk(
204    scores: &[(usize, f32)],
205    chunks: &[CodeChunk],
206    top_k: usize,
207    penalise_paths: bool,
208) -> Vec<(usize, f32)> {
209    if scores.is_empty() || top_k == 0 {
210        return Vec::new();
211    }
212
213    // Step 1: apply path penalties (or identity), producing (chunk_index, penalised_score).
214    let mut penalty_cache: HashMap<String, f32> = HashMap::new();
215    let mut penalised: Vec<(usize, f32)> = scores
216        .iter()
217        .map(|&(idx, score)| {
218            if !penalise_paths {
219                return (idx, score);
220            }
221            let path = &chunks[idx].file_path;
222            let factor = *penalty_cache
223                .entry(path.clone())
224                .or_insert_with(|| file_path_penalty(path));
225            (idx, score * factor)
226        })
227        .collect();
228
229    // Step 2: sort descending by penalised score with stable tie-break by index.
230    penalised.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
231
232    // Step 3: greedy pick with file-saturation decay and early-exit.
233    let mut file_selected: HashMap<String, usize> = HashMap::new();
234    let mut selected: Vec<(usize, f32)> = Vec::with_capacity(top_k.min(penalised.len()));
235    let mut min_selected = f32::INFINITY;
236
237    for &(idx, pen_score) in &penalised {
238        if selected.len() >= top_k && pen_score <= min_selected {
239            // Remaining (sorted-descending) cannot beat the current k-th best.
240            break;
241        }
242        let path = chunks[idx].file_path.clone();
243        let already = *file_selected.get(&path).unwrap_or(&0);
244        let eff_score = if already >= FILE_SATURATION_THRESHOLD {
245            let excess = already - FILE_SATURATION_THRESHOLD + 1;
246            // Excess is bounded by the number of chunks in the front; the
247            // saturation here is far below i32::MAX in practice. The clamp
248            // is defence-in-depth.
249            let excess_i32 = i32::try_from(excess).unwrap_or(i32::MAX);
250            pen_score * FILE_SATURATION_DECAY.powi(excess_i32)
251        } else {
252            pen_score
253        };
254        selected.push((idx, eff_score));
255        file_selected.insert(path, already + 1);
256
257        if selected.len() >= top_k {
258            // Refresh the min over the current selection.
259            min_selected = selected
260                .iter()
261                .map(|(_, s)| *s)
262                .fold(f32::INFINITY, f32::min);
263        }
264    }
265
266    // Final descending sort on effective scores.
267    selected.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
268    selected.truncate(top_k);
269    selected
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `1e-6` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    /// Build a minimal [`CodeChunk`] whose only meaningful field is its
    /// file path.
    fn chunk_at(path: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_owned(),
            name: String::new(),
            kind: String::new(),
            content_kind: crate::chunk::ContentKind::Code,
            start_line: 1,
            symbol_line: 1,
            end_line: 1,
            content: String::new(),
            enriched_content: String::new(),
            qualified_name: None,
        }
    }

    /// `test:penalties-test-file-regex-14-langs` — one representative
    /// path per test-file alternative must match `test_file_re` and
    /// receive exactly the strong penalty.
    #[test]
    fn penalties_test_file_regex_14_langs() {
        const TEST_FILE_PATHS: &[&str] = &[
            // Python
            "src/test_foo.py",
            "src/foo_test.py",
            // Go
            "pkg/foo_test.go",
            // Java
            "src/FooTest.java",
            "src/FooTests.java",
            // PHP
            "src/FooTest.php",
            // Ruby
            "spec/foo_spec.rb",
            "test/foo_test.rb",
            // JS/TS
            "src/foo.test.js",
            "src/foo.spec.ts",
            "src/foo.test.tsx",
            // Kotlin
            "src/FooTest.kt",
            "src/FooTests.kt",
            "src/FooSpec.kt",
            // Swift
            "src/FooTests.swift",
            "src/FooSpec.swift",
            // C#
            "src/FooTest.cs",
            "src/FooTests.cs",
            // C / C++
            "src/test_foo.cpp",
            "src/foo_test.cpp",
            "src/test_foo.c",
            "src/foo_test.c",
            // Scala
            "src/FooSpec.scala",
            "src/FooSuite.scala",
            "src/FooTest.scala",
            // Dart
            "src/foo_test.dart",
            "src/test_foo.dart",
            // Lua
            "src/foo_spec.lua",
            "src/foo_test.lua",
            "src/test_foo.lua",
            // Shared helpers
            "test/test_helper.rb",
            "test/test_helpers.go",
        ];
        for &path in TEST_FILE_PATHS {
            assert!(
                test_file_re().is_match(path),
                "expected test_file_re to match {path:?}"
            );
            assert!(
                close(file_path_penalty(path), STRONG_PENALTY),
                "expected STRONG_PENALTY for {path:?}; got {}",
                file_path_penalty(path)
            );
        }
    }

    /// `test:penalties-compat-dir` — paths under compat / legacy
    /// directories match `compat_dir_re` and get the strong penalty.
    #[test]
    fn penalties_compat_dir() {
        let compat_paths = ["compat/foo.py", "src/_compat/bar.rs", "legacy/baz.go"];
        for path in compat_paths {
            assert!(
                compat_dir_re().is_match(path),
                "expected compat match for {path:?}"
            );
            assert!(close(file_path_penalty(path), STRONG_PENALTY));
        }
    }

    /// `test:penalties-examples-dir` — paths under example / docs-src
    /// directories match `examples_dir_re` and get the strong penalty.
    #[test]
    fn penalties_examples_dir() {
        let example_paths = [
            "examples/foo.py",
            "_examples/bar.rs",
            "example/baz.go",
            "docs_src/x.md",
        ];
        for path in example_paths {
            assert!(
                examples_dir_re().is_match(path),
                "expected examples match for {path:?}"
            );
            assert!(close(file_path_penalty(path), STRONG_PENALTY));
        }
    }

    /// `test:penalties-init-py-reexport` — `__init__.py` and
    /// `package-info.java` carry the moderate re-export penalty.
    ///
    /// The Java path deliberately avoids an `/example/` segment, which
    /// would additionally trigger `examples_dir_re`'s strong penalty;
    /// the re-export penalty is checked in isolation here.
    #[test]
    fn penalties_init_py_reexport() {
        let init_py = file_path_penalty("src/__init__.py");
        assert!(close(init_py, MODERATE_PENALTY));

        let package_info = file_path_penalty("src/com/myorg/package-info.java");
        assert!(close(package_info, MODERATE_PENALTY));
    }

    /// `test:penalties-dts-stub` — `.d.ts` files take the mild penalty.
    #[test]
    fn penalties_dts_stub() {
        let dts = file_path_penalty("src/foo.d.ts");
        assert!(close(dts, MILD_PENALTY));
    }

    /// Paths that match no category keep a penalty of exactly 1.0.
    #[test]
    fn non_penalized_path_is_identity() {
        for path in ["src/foo.rs", "lib/bar.py"] {
            assert!(close(file_path_penalty(path), 1.0));
        }
    }

    /// `test:rerank-topk-saturation-decay` — the second chunk from a
    /// file is scaled by 0.5 and the third by 0.5^2.
    #[test]
    fn rerank_topk_saturation_decay() {
        let chunks: Vec<CodeChunk> = ["src/foo.rs", "src/foo.rs", "src/foo.rs", "src/bar.rs"]
            .into_iter()
            .map(chunk_at)
            .collect();
        // Identical raw scores and a path penalty of 1.0 for both files,
        // so the greedy pass visits the chunks in index order.
        let scores: Vec<(usize, f32)> = (0..4).map(|idx| (idx, 1.0_f32)).collect();
        let got = rerank_topk(&scores, &chunks, 4, true);
        assert_eq!(got.len(), 4);

        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        let expectations: [(usize, f32); 4] = [
            // First chunk from foo.rs: undecayed.
            (0, 1.0),
            // First chunk from bar.rs: undecayed.
            (3, 1.0),
            // Second chunk from foo.rs: already=1, excess=1, ×0.5.
            (1, 0.5),
            // Third chunk from foo.rs: already=2, excess=2, ×0.25.
            (2, 0.25),
        ];
        for (idx, expected) in expectations {
            assert!(
                close(scored[&idx], expected),
                "scored[{idx}] = {}",
                scored[&idx]
            );
        }
    }

    /// `test:rerank-topk-greedy-early-exit` — with `top_k` smaller than
    /// the input, only the best `top_k` entries come back, in order.
    #[test]
    fn rerank_topk_greedy_early_exit() {
        let mut chunks: Vec<CodeChunk> = Vec::with_capacity(10);
        let mut scores: Vec<(usize, f32)> = Vec::with_capacity(10);
        for i in 0..10 {
            chunks.push(chunk_at(&format!("src/f{i}.rs")));
            scores.push((i, 10.0 - i as f32));
        }

        let got = rerank_topk(&scores, &chunks, 3, true);
        assert_eq!(got.len(), 3);

        let indices: Vec<usize> = got.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    /// `property:penalty-regex-parity-python` — paths in production
    /// layouts (src/, lib/, crates/, etc.) carry no penalty, while paths
    /// in a penalised category score strictly below 1.0.
    #[test]
    fn property_penalty_regex_parity_python() {
        let production_paths = [
            "src/foo.py",
            "lib/parser.rs",
            "crates/ripvec-core/src/encoder/semble/tokens.rs",
            "pkg/server/handler.go",
            "app/models/user.rb",
            "main.c",
            "include/foo.h",
        ];
        for path in production_paths {
            assert!(
                close(file_path_penalty(path), 1.0),
                "non-test path {path:?} should have penalty 1.0; got {}",
                file_path_penalty(path)
            );
        }

        // One path per penalised kind: a test file (strong), a re-export
        // barrel (moderate), and a `.d.ts` stub (mild). Each must come
        // out below 1.0.
        let penalised_paths = ["test_foo.py", "src/__init__.py", "src/foo.d.ts"];
        for path in penalised_paths {
            assert!(
                file_path_penalty(path) < 1.0,
                "test-shaped path {path:?} should be penalised; got 1.0"
            );
        }
    }

    /// `penalise_paths == false` bypasses path priors but still applies
    /// file-saturation decay.
    #[test]
    fn rerank_topk_no_path_penalties_still_decays() {
        let chunks: Vec<CodeChunk> = (0..2).map(|_| chunk_at("test/test_foo.py")).collect();
        let scores: Vec<(usize, f32)> = vec![(0, 1.0), (1, 1.0)];

        // The test file's 0.3 path penalty is skipped, but the second
        // chunk from the same file is still decayed by ×0.5.
        let got = rerank_topk(&scores, &chunks, 2, false);
        let scored: HashMap<usize, f32> = got.iter().copied().collect();
        assert!(close(scored[&0], 1.0));
        assert!(close(scored[&1], 0.5));
    }
}
515}