use crate::error::{SeqError, SeqResult};
/// Tuning parameters for [`LengthPenalty`].
///
/// [`LengthPenalty::new`] validates these values and rejects a negative `alpha`,
/// a negative `beta` or a zero `max_length`.
#[derive(Clone, Debug)]
pub struct LengthPenaltyConfig {
    /// Exponent applied to `(5 + length) / 6` by [`LengthPenalty::lp`].
    /// With `0.0` the normalization factor is always `1.0`.
    pub alpha: f64,
    /// Weight of the coverage-penalty term subtracted in [`LengthPenalty::score`].
    pub beta: f64,
    /// Minimum hypothesis length.
    /// NOTE(review): not read by `LengthPenalty` itself; presumably enforced by the caller.
    pub min_length: usize,
    /// Maximum hypothesis length. Must be non-zero; `score` does not enforce it.
    pub max_length: usize,
}
/// Scores and ranks candidate sequences using length normalization and a
/// coverage penalty.
///
/// Build one with [`LengthPenalty::new`], which validates the configuration.
#[derive(Clone, Debug)]
pub struct LengthPenalty {
    /// Validated settings. The field is private so it can only be set through `new`.
    config: LengthPenaltyConfig,
}
impl LengthPenalty {
    /// Creates a scorer after validating `config`.
    ///
    /// # Errors
    ///
    /// * [`SeqError::InvalidParameter`] if `alpha` or `beta` is negative, NaN or
    ///   infinite. Non-finite weights are rejected because they poison every score:
    ///   a NaN `alpha` makes [`Self::lp`] NaN for all lengths, and an infinite `beta`
    ///   yields `inf * 0 = NaN` at full coverage.
    /// * [`SeqError::InvalidConfiguration`] if `max_length` is zero.
    pub fn new(config: LengthPenaltyConfig) -> SeqResult<Self> {
        // A plain `< 0.0` check would let NaN and +inf through.
        if !config.alpha.is_finite() || config.alpha < 0.0 {
            return Err(SeqError::InvalidParameter {
                name: "alpha".into(),
                value: config.alpha,
            });
        }
        if !config.beta.is_finite() || config.beta < 0.0 {
            return Err(SeqError::InvalidParameter {
                name: "beta".into(),
                value: config.beta,
            });
        }
        if config.max_length == 0 {
            return Err(SeqError::InvalidConfiguration(
                "max_length must be > 0".into(),
            ));
        }
        Ok(Self { config })
    }

    /// Returns the length-normalization factor `((5 + length) / 6) ^ alpha`.
    ///
    /// The factor is `1.0` at `length == 1` for any `alpha`, and `1.0` everywhere
    /// when `alpha == 0.0`. For `alpha > 0.0` it grows with `length`. This is the
    /// length-penalty form used by GNMT (Wu et al., 2016).
    #[inline]
    pub fn lp(&self, length: usize) -> f64 {
        ((5.0 + length as f64) / 6.0).powf(self.config.alpha)
    }

    /// Returns the coverage penalty `sum_i ln(min(sum_t p[t][i], 1.0))` over source
    /// positions `i`.
    ///
    /// `coverage_probs` is read row-major as `seq_len` rows of `n_source` values,
    /// so `p[t][i] = coverage_probs[t * n_source + i]`. Presumably these are
    /// attention weights from target step `t` to source position `i`.
    ///
    /// Buffer length handling:
    /// * A buffer shorter than `seq_len * n_source` is tolerated; missing entries
    ///   count as zero.
    /// * Data beyond `seq_len` rows is ignored.
    ///
    /// Return value:
    /// * `0.0` when `n_source`, `seq_len` or `coverage_probs` is empty.
    /// * A value `<= 0.0` for non-negative inputs.
    /// * `-inf` when some source position receives no coverage at all.
    pub fn cp(&self, coverage_probs: &[f64], n_source: usize, seq_len: usize) -> f64 {
        if n_source == 0 || seq_len == 0 || coverage_probs.is_empty() {
            return 0.0;
        }
        let mut coverage = vec![0.0f64; n_source];
        // Each chunk is one target step. `chunks` yields a shorter final row for a
        // truncated buffer, and `zip` then only touches the positions that row holds.
        // This avoids the overflow-prone `t * n_source + i` arithmetic.
        for row in coverage_probs.chunks(n_source).take(seq_len) {
            for (acc, &p) in coverage.iter_mut().zip(row) {
                *acc += p;
            }
        }
        coverage.iter().map(|&c| c.min(1.0).ln()).sum()
    }

    /// Scores a hypothesis as `log_prob / lp(length) - beta * |cp|`.
    ///
    /// `length` is also used as the `seq_len` argument of [`Self::cp`].
    /// `max_length` is not enforced here. When `beta == 0.0` the coverage term is
    /// skipped entirely.
    ///
    /// # Errors
    ///
    /// Returns [`SeqError::NumericalInstability`] if `log_prob` is NaN or infinite.
    /// The returned score can still be `-inf` when `beta > 0.0` and some source
    /// position has zero coverage.
    pub fn score(
        &self,
        log_prob: f64,
        length: usize,
        coverage_probs: &[f64],
        n_source: usize,
    ) -> SeqResult<f64> {
        if !log_prob.is_finite() {
            return Err(SeqError::NumericalInstability(
                "log_prob is not finite".into(),
            ));
        }
        let normalized = log_prob / self.lp(length);
        // With beta == 0 the coverage term must vanish. Evaluating it anyway gives
        // `0.0 * inf = NaN` whenever `cp` is -inf (an uncovered source position).
        if self.config.beta == 0.0 {
            return Ok(normalized);
        }
        let cp_val = self.cp(coverage_probs, n_source, length);
        Ok(normalized - self.config.beta * cp_val.abs())
    }

    /// Returns candidate indices ordered best-first by `log_prob / lp(length)`.
    ///
    /// * Only the first `min(log_probs.len(), lengths.len())` candidates are ranked.
    /// * The coverage penalty is not applied.
    /// * The sort is stable, so equal scores keep their input order.
    /// * NaN scores are placed last.
    pub fn rank(&self, log_probs: &[f64], lengths: &[usize]) -> Vec<usize> {
        let scores: Vec<f64> = log_probs
            .iter()
            .zip(lengths)
            .map(|(&log_prob, &len)| log_prob / self.lp(len))
            .collect();
        let mut indices: Vec<usize> = (0..scores.len()).collect();
        indices.sort_by(|&a, &b| Self::cmp_desc_nan_last(scores[a], scores[b]));
        indices
    }

    /// Compares two scores for a descending sort, treating NaN as the worst value.
    ///
    /// This is a total order, which `sort_by` requires. Mapping NaN comparisons to
    /// `Equal` is not transitive and can make the sort panic. Apart from NaN it
    /// matches `partial_cmp`, so `0.0` and `-0.0` still compare equal.
    fn cmp_desc_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Both are non-NaN here, so `partial_cmp` always returns `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scorer with the given weights and a permissive length range.
    fn make_lp(alpha: f64, beta: f64) -> LengthPenalty {
        LengthPenalty::new(LengthPenaltyConfig {
            alpha,
            beta,
            min_length: 1,
            max_length: 200,
        })
        .expect("LengthPenalty::new failed")
    }

    #[test]
    fn new_rejects_invalid_config() {
        let base = LengthPenaltyConfig {
            alpha: 0.6,
            beta: 0.2,
            min_length: 1,
            max_length: 200,
        };
        assert!(LengthPenalty::new(LengthPenaltyConfig { alpha: -0.1, ..base.clone() }).is_err());
        assert!(LengthPenalty::new(LengthPenaltyConfig { beta: -0.1, ..base.clone() }).is_err());
        assert!(LengthPenalty::new(LengthPenaltyConfig { max_length: 0, ..base }).is_err());
    }

    #[test]
    fn lp_at_length_1() {
        for &alpha in &[0.0, 0.5, 1.0, 2.0] {
            let lp = make_lp(alpha, 0.0);
            let val = lp.lp(1);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "lp(1) should be 1.0 for alpha={alpha}, got {val}"
            );
        }
    }

    #[test]
    fn lp_increases_with_length() {
        let lp = make_lp(0.8, 0.0);
        assert!(
            lp.lp(10) > lp.lp(5),
            "lp(10)={} should be > lp(5)={} for alpha=0.8",
            lp.lp(10),
            lp.lp(5)
        );
    }

    #[test]
    fn alpha_zero_lp_one() {
        let lp = make_lp(0.0, 0.0);
        for length in [1, 5, 10, 100] {
            let val = lp.lp(length);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "alpha=0: lp({length}) should be 1.0, got {val}"
            );
        }
    }

    #[test]
    fn cp_zero_for_empty_inputs() {
        let lp = make_lp(0.6, 0.1);
        assert_eq!(lp.cp(&[], 3, 3), 0.0);
        assert_eq!(lp.cp(&[0.5], 0, 3), 0.0);
        assert_eq!(lp.cp(&[0.5], 1, 0), 0.0);
    }

    #[test]
    fn cp_zero_when_full_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 3;
        let seq_len = 3;
        let coverage_probs = vec![1.0 / 3.0; n_source * seq_len];
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        assert!(
            cp.abs() < 1e-10,
            "cp should be ~0 for full coverage, got {cp}"
        );
    }

    #[test]
    fn cp_negative_for_under_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 4;
        let seq_len = 2;
        let mut coverage_probs = vec![0.0f64; n_source * seq_len];
        // Only source position 0 receives any coverage (0.3 per step).
        for t in 0..seq_len {
            coverage_probs[t * n_source] = 0.3;
        }
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        assert!(cp < 0.0, "under-coverage should give negative cp, got {cp}");
    }

    #[test]
    fn score_penalizes_short() {
        let lp = make_lp(1.0, 0.0);
        let empty_cov: &[f64] = &[];

        let better_long = lp.score(-6.0, 20, empty_cov, 0).expect("score better_long");
        let worse_short = lp.score(-10.0, 3, empty_cov, 0).expect("score worse_short");
        assert!(
            better_long > worse_short,
            "better_long_score={better_long:.4} should > worse_short_score={worse_short:.4}"
        );

        // Same total log-probability: normalization alone must favour the longer
        // hypothesis.
        let long = lp.score(-10.0, 20, empty_cov, 0).expect("score long");
        let short = lp.score(-10.0, 3, empty_cov, 0).expect("score short");
        assert!(
            long > short,
            "equal log_prob: long_score={long:.4} should > short_score={short:.4}"
        );
    }

    #[test]
    fn rank_returns_correct_order() {
        let lp = make_lp(0.6, 0.0);
        let log_probs = [-5.0, -2.0, -15.0];
        let lengths = [5, 3, 20];
        let order = lp.rank(&log_probs, &lengths);
        assert_eq!(order[0], 1, "best candidate should be index 1");
        assert_eq!(order[2], 2, "worst candidate should be index 2");
    }

    #[test]
    fn rank_truncates_to_shorter_input() {
        let lp = make_lp(0.6, 0.0);
        assert!(lp.rank(&[], &[1, 2]).is_empty());
        assert_eq!(lp.rank(&[-1.0, -2.0, -3.0], &[4, 4]), vec![0, 1]);
    }

    #[test]
    fn max_length_exceeded_score_no_panic() {
        let lp = LengthPenalty::new(LengthPenaltyConfig {
            alpha: 0.6,
            beta: 0.0,
            min_length: 1,
            max_length: 10,
        })
        .expect("new");
        let result = lp.score(-5.0, 50, &[], 0);
        assert!(
            result.is_ok(),
            "score should not fail for length > max_length"
        );
    }

    #[test]
    fn beta_zero_no_coverage_penalty() {
        let lp = make_lp(0.6, 0.0);
        let n_source = 3;
        let coverage_probs = vec![0.1f64; n_source * 5];
        let s = lp
            .score(-8.0, 5, &coverage_probs, n_source)
            .expect("score");
        let expected = -8.0 / lp.lp(5);
        assert!(
            (s - expected).abs() < 1e-12,
            "beta=0: score should be log_prob/lp, expected={expected}, got={s}"
        );
    }
}