1use crate::error::{SeqError, SeqResult};
11
/// Hyper-parameters for [`LengthPenalty`].
///
/// This struct performs no validation itself; `LengthPenalty::new` checks the
/// values when a scorer is built.
#[derive(Debug, Clone, PartialEq)]
pub struct LengthPenaltyConfig {
    /// Exponent of the length-normalisation term `((5 + len) / 6)^alpha`.
    /// `LengthPenalty::new` rejects negative values.
    pub alpha: f64,
    /// Weight applied to the coverage penalty in `LengthPenalty::score`.
    /// `LengthPenalty::new` rejects negative values.
    pub beta: f64,
    /// Lower length bound.
    /// NOTE(review): not read by `LengthPenalty` itself — confirm where it is enforced.
    pub min_length: usize,
    /// Upper length bound; `LengthPenalty::new` rejects `0`.
    /// NOTE(review): `score` does not reject lengths above this bound.
    pub max_length: usize,
}
26
27#[derive(Debug, Clone)]
31pub struct LengthPenalty {
32 config: LengthPenaltyConfig,
33}
34
35impl LengthPenalty {
36 pub fn new(config: LengthPenaltyConfig) -> SeqResult<Self> {
38 if config.alpha < 0.0 {
39 return Err(SeqError::InvalidParameter {
40 name: "alpha".into(),
41 value: config.alpha,
42 });
43 }
44 if config.beta < 0.0 {
45 return Err(SeqError::InvalidParameter {
46 name: "beta".into(),
47 value: config.beta,
48 });
49 }
50 if config.max_length == 0 {
51 return Err(SeqError::InvalidConfiguration(
52 "max_length must be > 0".into(),
53 ));
54 }
55 Ok(Self { config })
56 }
57
58 #[inline]
65 pub fn lp(&self, length: usize) -> f64 {
66 let ratio = (5.0 + length as f64) / 6.0;
67 ratio.powf(self.config.alpha)
68 }
69
70 pub fn cp(&self, coverage_probs: &[f64], n_source: usize, seq_len: usize) -> f64 {
78 if n_source == 0 || seq_len == 0 || coverage_probs.is_empty() {
79 return 0.0;
80 }
81 let mut coverage = vec![0.0f64; n_source];
83 for t in 0..seq_len {
84 for i in 0..n_source {
85 let idx = t * n_source + i;
86 if idx < coverage_probs.len() {
87 coverage[i] += coverage_probs[idx];
88 }
89 }
90 }
91 let mut penalty = 0.0;
93 for i in 0..n_source {
94 penalty += coverage[i].min(1.0).ln();
95 }
96 penalty
97 }
98
99 pub fn score(
106 &self,
107 log_prob: f64,
108 length: usize,
109 coverage_probs: &[f64],
110 n_source: usize,
111 ) -> SeqResult<f64> {
112 if !log_prob.is_finite() {
113 return Err(SeqError::NumericalInstability(
114 "log_prob is not finite".into(),
115 ));
116 }
117 let lp = self.lp(length);
118 let cp_val = self.cp(coverage_probs, n_source, length);
119 Ok(log_prob / lp - self.config.beta * cp_val.abs())
120 }
121
122 pub fn rank(&self, log_probs: &[f64], lengths: &[usize]) -> Vec<usize> {
127 if log_probs.is_empty() {
128 return Vec::new();
129 }
130 let n = log_probs.len().min(lengths.len());
131 let scores: Vec<f64> = (0..n)
132 .map(|i| {
133 let lp = self.lp(lengths[i]);
134 log_probs[i] / lp
135 })
136 .collect();
137 let mut indices: Vec<usize> = (0..n).collect();
138 indices.sort_by(|&a, &b| {
139 scores[b]
140 .partial_cmp(&scores[a])
141 .unwrap_or(std::cmp::Ordering::Equal)
142 });
143 indices
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scorer with the given weights and permissive length bounds.
    fn make_lp(alpha: f64, beta: f64) -> LengthPenalty {
        LengthPenalty::new(LengthPenaltyConfig {
            alpha,
            beta,
            min_length: 1,
            max_length: 200,
        })
        .expect("LengthPenalty::new failed")
    }

    #[test]
    fn lp_at_length_1() {
        for &alpha in &[0.0, 0.5, 1.0, 2.0] {
            let lp = make_lp(alpha, 0.0);
            let val = lp.lp(1);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "lp(1) should be 1.0 for alpha={alpha}, got {val}"
            );
        }
    }

    #[test]
    fn lp_increases_with_length() {
        let lp = make_lp(0.8, 0.0);
        assert!(
            lp.lp(10) > lp.lp(5),
            "lp(10)={} should be > lp(5)={} for alpha=0.8",
            lp.lp(10),
            lp.lp(5)
        );
    }

    #[test]
    fn alpha_zero_lp_one() {
        let lp = make_lp(0.0, 0.0);
        for length in [1, 5, 10, 100] {
            let val = lp.lp(length);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "alpha=0: lp({length}) should be 1.0, got {val}"
            );
        }
    }

    #[test]
    fn cp_zero_when_full_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 3;
        let seq_len = 3;
        // Each source position accumulates 3 * (1/3) = 1.0 over the sequence.
        let coverage_probs = vec![1.0 / 3.0; n_source * seq_len];
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        assert!(
            cp.abs() < 1e-10,
            "cp should be ~0 for full coverage, got {cp}"
        );
    }

    #[test]
    fn cp_negative_for_under_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 4;
        let seq_len = 2;
        // Every source position ends at 0.6 total mass: covered, but not fully,
        // so the penalty is finite and strictly negative.
        let coverage_probs = vec![0.3f64; n_source * seq_len];
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        let expected = n_source as f64 * 0.6f64.ln();
        assert!(cp < 0.0, "under-coverage should give negative cp, got {cp}");
        assert!(
            (cp - expected).abs() < 1e-12,
            "cp should be {expected}, got {cp}"
        );
    }

    #[test]
    fn score_penalizes_short() {
        // With alpha = 1 the normaliser grows with length, so a long hypothesis
        // can outrank a short one even with a *lower* raw log-probability.
        let lp = make_lp(1.0, 0.0);
        let empty_cov: &[f64] = &[];
        let long = lp.score(-12.0, 20, empty_cov, 0).expect("score long");
        let short = lp.score(-10.0, 3, empty_cov, 0).expect("score short");
        assert!(
            long > short,
            "long_score={long:.4} should be > short_score={short:.4}"
        );
    }

    #[test]
    fn rank_returns_correct_order() {
        let lp = make_lp(0.6, 0.0);
        let log_probs = [-5.0, -2.0, -15.0];
        let lengths = [5, 3, 20];
        let order = lp.rank(&log_probs, &lengths);
        assert_eq!(order[0], 1, "best candidate should be index 1");
        assert_eq!(order[2], 2, "worst candidate should be index 2");
    }

    #[test]
    fn rank_empty_and_mismatched_lengths() {
        let lp = make_lp(0.6, 0.0);
        assert!(lp.rank(&[], &[]).is_empty());
        // Only as many candidates as there are lengths are ranked.
        assert_eq!(lp.rank(&[-1.0, -3.0, -2.0], &[1, 1]), vec![0, 1]);
    }

    #[test]
    fn new_rejects_invalid_config() {
        let cfg = |alpha: f64, beta: f64, max_length: usize| LengthPenaltyConfig {
            alpha,
            beta,
            min_length: 1,
            max_length,
        };
        assert!(matches!(
            LengthPenalty::new(cfg(-0.1, 0.0, 10)),
            Err(SeqError::InvalidParameter { .. })
        ));
        assert!(matches!(
            LengthPenalty::new(cfg(0.6, -0.1, 10)),
            Err(SeqError::InvalidParameter { .. })
        ));
        assert!(matches!(
            LengthPenalty::new(cfg(0.6, 0.1, 0)),
            Err(SeqError::InvalidConfiguration(_))
        ));
        assert!(LengthPenalty::new(cfg(0.0, 0.0, 1)).is_ok());
    }

    #[test]
    fn max_length_exceeded_score_no_panic() {
        let lp = LengthPenalty::new(LengthPenaltyConfig {
            alpha: 0.6,
            beta: 0.0,
            min_length: 1,
            max_length: 10,
        })
        .expect("new");
        let result = lp.score(-5.0, 50, &[], 0);
        assert!(
            result.is_ok(),
            "score should not fail for length > max_length"
        );
    }

    #[test]
    fn beta_zero_no_coverage_penalty() {
        let lp = make_lp(0.6, 0.0);
        let n_source = 3;
        let coverage_probs = vec![0.1f64; n_source * 5];
        let s = lp
            .score(-8.0, 5, &coverage_probs, n_source)
            .expect("score");
        let expected = -8.0 / lp.lp(5);
        assert!(
            (s - expected).abs() < 1e-12,
            "beta=0: score should be log_prob/lp, expected={expected}, got={s}"
        );
    }
}