Skip to main content

oxicuda_seq/beam/
length_penalty.rs

//! GNMT-style length penalty and coverage penalty for beam search scoring.
//!
//! Reference: Wu et al. 2016, "Google's Neural Machine Translation System:
//! Bridging the Gap between Human and Machine Translation", arXiv:1609.08144.
//!
//! Length penalty:     lp(y) = ((5 + |y|) / 6)^α
//! Coverage penalty:   cp    = Σ_i log(min(Σ_t p_{t,i}, 1.0))
//! Combined score:     score = log_prob / lp(|y|) - β * |cp|
9
10use crate::error::{SeqError, SeqResult};
11
12// ─── Configuration ────────────────────────────────────────────────────────────
13
/// Tuning knobs for GNMT-style length and coverage penalties.
///
/// Both penalties are controlled by a single scalar; `0` means "disabled".
#[derive(Clone, Debug)]
pub struct LengthPenaltyConfig {
    /// Exponent α of the length penalty.  0 = disabled; common choices lie in 0.6–1.0.
    pub alpha: f64,
    /// Weight β of the coverage penalty.  0 = disabled.
    pub beta: f64,
    /// Lower bound on output length; carried for callers, `score()` does not enforce it.
    pub min_length: usize,
    /// Upper bound on output length; carried for callers, `score()` does not enforce it.
    pub max_length: usize,
}
26
27// ─── LengthPenalty ────────────────────────────────────────────────────────────
28
29/// Computes GNMT-style length and coverage penalties for beam-search hypothesis scoring.
30#[derive(Debug, Clone)]
31pub struct LengthPenalty {
32    config: LengthPenaltyConfig,
33}
34
35impl LengthPenalty {
36    /// Create a new `LengthPenalty`.  Returns `Err` if `alpha < 0` or `beta < 0`.
37    pub fn new(config: LengthPenaltyConfig) -> SeqResult<Self> {
38        if config.alpha < 0.0 {
39            return Err(SeqError::InvalidParameter {
40                name: "alpha".into(),
41                value: config.alpha,
42            });
43        }
44        if config.beta < 0.0 {
45            return Err(SeqError::InvalidParameter {
46                name: "beta".into(),
47                value: config.beta,
48            });
49        }
50        if config.max_length == 0 {
51            return Err(SeqError::InvalidConfiguration(
52                "max_length must be > 0".into(),
53            ));
54        }
55        Ok(Self { config })
56    }
57
58    // ── Core penalty functions ────────────────────────────────────────────────
59
60    /// GNMT length penalty: `((5 + length) / (5 + 1))^alpha`.
61    ///
62    /// At `length=1`: returns 1.0.
63    /// Monotonically increasing for `alpha > 0`.
64    #[inline]
65    pub fn lp(&self, length: usize) -> f64 {
66        let ratio = (5.0 + length as f64) / 6.0;
67        ratio.powf(self.config.alpha)
68    }
69
70    /// Coverage penalty: `Σ_i log(min(Σ_t p_{t,i}, 1.0))`.
71    ///
72    /// `coverage_probs` has layout `[seq_len × n_source]`
73    /// (for each target step `t`, attention over `n_source` source tokens).
74    ///
75    /// Returns 0.0 when all source tokens are fully covered (sum ≥ 1.0).
76    /// Returns a negative value when coverage is partial.
77    pub fn cp(&self, coverage_probs: &[f64], n_source: usize, seq_len: usize) -> f64 {
78        if n_source == 0 || seq_len == 0 || coverage_probs.is_empty() {
79            return 0.0;
80        }
81        // Accumulate Σ_t p_{t,i} for each source position i.
82        let mut coverage = vec![0.0f64; n_source];
83        for t in 0..seq_len {
84            for i in 0..n_source {
85                let idx = t * n_source + i;
86                if idx < coverage_probs.len() {
87                    coverage[i] += coverage_probs[idx];
88                }
89            }
90        }
91        // cp = Σ_i log(min(coverage_i, 1.0))
92        let mut penalty = 0.0;
93        for i in 0..n_source {
94            penalty += coverage[i].min(1.0).ln();
95        }
96        penalty
97    }
98
99    /// Combined beam-search score.
100    ///
101    /// `score = log_prob / lp(length) - beta * |cp|`
102    ///
103    /// The magnitude of `cp` is used so that beta ≥ 0 always penalises
104    /// under-coverage (cp is ≤ 0 when coverage < 1).
105    pub fn score(
106        &self,
107        log_prob: f64,
108        length: usize,
109        coverage_probs: &[f64],
110        n_source: usize,
111    ) -> SeqResult<f64> {
112        if !log_prob.is_finite() {
113            return Err(SeqError::NumericalInstability(
114                "log_prob is not finite".into(),
115            ));
116        }
117        let lp = self.lp(length);
118        let cp_val = self.cp(coverage_probs, n_source, length);
119        Ok(log_prob / lp - self.config.beta * cp_val.abs())
120    }
121
122    /// Rank hypotheses by descending combined score (no coverage penalty applied
123    /// in this simplified batch form — coverage is assumed uniform).
124    ///
125    /// Returns indices sorted best-first.
126    pub fn rank(&self, log_probs: &[f64], lengths: &[usize]) -> Vec<usize> {
127        if log_probs.is_empty() {
128            return Vec::new();
129        }
130        let n = log_probs.len().min(lengths.len());
131        let scores: Vec<f64> = (0..n)
132            .map(|i| {
133                let lp = self.lp(lengths[i]);
134                log_probs[i] / lp
135            })
136            .collect();
137        let mut indices: Vec<usize> = (0..n).collect();
138        indices.sort_by(|&a, &b| {
139            scores[b]
140                .partial_cmp(&scores[a])
141                .unwrap_or(std::cmp::Ordering::Equal)
142        });
143        indices
144    }
145}
146
147// ─── Tests ────────────────────────────────────────────────────────────────────
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `LengthPenalty` with the given α/β and a permissive length range.
    fn make_lp(alpha: f64, beta: f64) -> LengthPenalty {
        LengthPenalty::new(LengthPenaltyConfig {
            alpha,
            beta,
            min_length: 1,
            max_length: 200,
        })
        .expect("LengthPenalty::new failed")
    }

    #[test]
    fn lp_at_length_1() {
        // lp(1) = ((5+1)/(5+1))^alpha = 1.0 for any alpha
        for &alpha in &[0.0, 0.5, 1.0, 2.0] {
            let lp = make_lp(alpha, 0.0);
            let val = lp.lp(1);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "lp(1) should be 1.0 for alpha={alpha}, got {val}"
            );
        }
    }

    #[test]
    fn lp_increases_with_length() {
        let lp = make_lp(0.8, 0.0);
        assert!(
            lp.lp(10) > lp.lp(5),
            "lp(10)={} should be > lp(5)={} for alpha=0.8",
            lp.lp(10),
            lp.lp(5)
        );
    }

    #[test]
    fn alpha_zero_lp_one() {
        let lp = make_lp(0.0, 0.0);
        for length in [1, 5, 10, 100] {
            let val = lp.lp(length);
            assert!(
                (val - 1.0).abs() < 1e-12,
                "alpha=0: lp({length}) should be 1.0, got {val}"
            );
        }
    }

    #[test]
    fn cp_zero_when_full_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 3;
        let seq_len = 3;
        // Every entry is 1/3: each row sums to 1 and, over 3 steps, each
        // source column also sums to 1.0 → full coverage.
        let coverage_probs = vec![1.0 / 3.0; n_source * seq_len];
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        assert!(
            cp.abs() < 1e-10,
            "cp should be ~0 for full coverage, got {cp}"
        );
    }

    #[test]
    fn cp_negative_for_under_coverage() {
        let lp = make_lp(0.6, 0.1);
        let n_source = 4;
        let seq_len = 2;
        // Each step attends only to the first source position → positions 1-3 get 0
        let mut coverage_probs = vec![0.0f64; n_source * seq_len];
        for t in 0..seq_len {
            coverage_probs[t * n_source] = 0.3; // only position 0
        }
        let cp = lp.cp(&coverage_probs, n_source, seq_len);
        assert!(cp < 0.0, "under-coverage should give negative cp, got {cp}");
    }

    #[test]
    fn score_penalizes_short() {
        // Length normalisation must be able to flip the raw ordering: the long
        // hypothesis has the worse total log-prob but the better normalised score.
        let lp = make_lp(1.0, 0.0);
        let empty_cov: &[f64] = &[];
        let long_score = lp.score(-12.0, 20, empty_cov, 0).expect("score long");
        let short_score = lp.score(-5.0, 3, empty_cov, 0).expect("score short");
        // With alpha = 1: lp(20) = 25/6 and lp(3) = 8/6, so
        // long_score  = -12.0 / (25/6) = -2.88
        // short_score =  -5.0 / (8/6)  = -3.75
        assert!(
            long_score > short_score,
            "long_score={long_score:.4} should be > short_score={short_score:.4}"
        );
    }

    #[test]
    fn rank_returns_correct_order() {
        let lp = make_lp(0.6, 0.0);
        // Candidate 0: log_prob=-5, len=5  → score = -5/lp(5)
        // Candidate 1: log_prob=-2, len=3  → score = -2/lp(3)  (best)
        // Candidate 2: log_prob=-15, len=20 → score = -15/lp(20) (worst)
        let log_probs = [-5.0, -2.0, -15.0];
        let lengths = [5, 3, 20];
        let order = lp.rank(&log_probs, &lengths);
        assert_eq!(order[0], 1, "best candidate should be index 1");
        assert_eq!(order[2], 2, "worst candidate should be index 2");
    }

    #[test]
    fn max_length_exceeded_score_no_panic() {
        // score() does not enforce max_length, so it must work for length > max_length.
        let lp = LengthPenalty::new(LengthPenaltyConfig {
            alpha: 0.6,
            beta: 0.0,
            min_length: 1,
            max_length: 10,
        })
        .expect("new");
        let result = lp.score(-5.0, 50, &[], 0);
        assert!(
            result.is_ok(),
            "score should not fail for length > max_length"
        );
    }

    #[test]
    fn beta_zero_no_coverage_penalty() {
        let lp = make_lp(0.6, 0.0); // beta=0
        // With beta=0 the coverage term vanishes → score = log_prob / lp(len)
        let n_source = 3;
        let coverage_probs = vec![0.1f64; n_source * 5]; // partial coverage
        let s = lp.score(-8.0, 5, &coverage_probs, n_source).expect("score");
        let expected = -8.0 / lp.lp(5);
        assert!(
            (s - expected).abs() < 1e-12,
            "beta=0: score should be log_prob/lp, expected={expected}, got={s}"
        );
    }
}
289            (s - expected).abs() < 1e-12,
290            "beta=0: score should be log_prob/lp, expected={expected}, got={s}"
291        );
292    }
293}