Skip to main content

oxibonsai_eval/
bleu.rs

//! BLEU (Bilingual Evaluation Understudy) implementation.
//!
//! Follows Papineni et al. (2002) — *BLEU: a Method for Automatic Evaluation
//! of Machine Translation* — with additional Chen & Cherry (2014) smoothing
//! variants for sentence-level BLEU on sparse n-grams.
//!
//! ## Algorithm
//!
//! For each n ∈ `[1..=max_n]`:
//!
//! ```text
//! p_n = Σ_{ngram ∈ candidate} min(count_cand(ngram), max_ref_count(ngram))
//!       / Σ_{ngram ∈ candidate} count_cand(ngram)
//! ```
//!
//! Brevity penalty (BP):
//!
//! ```text
//! BP = 1                 if c > r
//!    = exp(1 - r/c)      if 0 < c ≤ r
//!    = 0                 if c == 0
//! ```
//!
//! where `c = sum of candidate lengths` and `r = sum of *closest* reference
//! lengths` (with shortest-ref tie-break).
//!
//! Final score (with `N = max_n`):
//!
//! ```text
//! BLEU = BP · exp( Σ_n (1/N) · log p_n )
//! ```
//!
//! ## Smoothing (sparse sentence-level)
//!
//! - [`SmoothingMethod::None`] — unsmoothed (zero if any p_n = 0).
//! - [`SmoothingMethod::AddOne`] — adds 1 to both the numerator and the
//!   denominator of p_n (Chen & Cherry 2014 method 2).
//! - [`SmoothingMethod::ExpDecay`] — Chen & Cherry 2014 method 3.
//!
//! ## Empty candidate
//!
//! Returns `BleuScore { bleu: 0.0, precisions: [0.0; N], brevity_penalty: 0.0,
//! length_ratio: 0.0 }`. The same all-zero score is returned when no
//! references are supplied.
43
44use std::collections::HashMap;
45
46use crate::rouge::{tokenize, TokenSeq};
47
/// How sparse or zero modified precisions are treated while scoring.
///
/// The smoothed variants target sentence-level scoring, where higher-order
/// n-gram matches are frequently absent. Corpus-level scoring pools counts
/// first and should normally stay on [`SmoothingMethod::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SmoothingMethod {
    /// Classic Papineni BLEU: precisions are used unchanged, so one zero
    /// precision drives the geometric mean (and the score) to 0.
    #[default]
    None,
    /// Laplace / Lidstone-1 smoothing (Chen & Cherry 2014 method 2):
    /// `p_n = (matches + 1) / (total + 1)`.
    AddOne,
    /// Exponential-decay smoothing (Chen & Cherry 2014 method 3).
    ///
    /// An order with zero matches gets `p_n = 1 / (2^k · c)`, where `k` counts
    /// the consecutive zero-match orders seen so far and `c` is the candidate
    /// length. Orders with matches keep their plain precision.
    ExpDecay,
}
67
/// Result of scoring one candidate (or one corpus) against its references.
#[derive(Debug, Clone)]
pub struct BleuScore {
    /// Final BLEU value, in `[0, 1]`.
    pub bleu: f32,
    /// Modified precisions p_1 … p_N, one entry per n-gram order.
    pub precisions: Vec<f32>,
    /// Brevity penalty that was multiplied into `bleu`.
    pub brevity_penalty: f32,
    /// Candidate length divided by effective reference length (`c / r`).
    pub length_ratio: f32,
}

impl BleuScore {
    /// Builds the all-zero score used for degenerate inputs, with `max_n`
    /// zeroed precision slots.
    fn zero(max_n: usize) -> Self {
        let precisions = vec![0.0_f32; max_n];
        Self {
            bleu: 0.0,
            precisions,
            brevity_penalty: 0.0,
            length_ratio: 0.0,
        }
    }
}
91
92/// BLEU configuration.
93///
94/// Marked `#[non_exhaustive]` so future fields (e.g. custom n-gram weights)
95/// can be added without breaking downstream code. Construct via
96/// [`BleuConfig::default`] or [`BleuConfig::new`].
97#[derive(Debug, Clone)]
98#[non_exhaustive]
99pub struct BleuConfig {
100    /// Maximum n-gram order (`max_n`); default 4.
101    pub max_n: usize,
102    /// Smoothing method for sparse p_n (default [`SmoothingMethod::None`]).
103    pub smoothing: SmoothingMethod,
104}
105
106impl Default for BleuConfig {
107    fn default() -> Self {
108        Self {
109            max_n: 4,
110            smoothing: SmoothingMethod::None,
111        }
112    }
113}
114
115impl BleuConfig {
116    /// Build a config with explicit parameters.
117    pub fn new(max_n: usize, smoothing: SmoothingMethod) -> Self {
118        Self {
119            max_n: max_n.max(1),
120            smoothing,
121        }
122    }
123}
124
125// ──────────────────────────────────────────────────────────────────────────────
126// Sentence-level BLEU
127// ──────────────────────────────────────────────────────────────────────────────
128
129/// Sentence BLEU for a single candidate against one-or-more references.
130///
131/// Tokenization is word-level (see [`crate::rouge::tokenize`]).
132pub fn sentence_bleu(candidate: &str, references: &[&str], cfg: &BleuConfig) -> BleuScore {
133    let cand = tokenize(candidate);
134    let refs: Vec<TokenSeq> = references.iter().map(|r| tokenize(r)).collect();
135    sentence_bleu_tokens(&cand, &refs, cfg)
136}
137
138/// Sentence BLEU from pre-tokenised inputs.
139pub fn sentence_bleu_tokens(
140    candidate: &TokenSeq,
141    references: &[TokenSeq],
142    cfg: &BleuConfig,
143) -> BleuScore {
144    if candidate.is_empty() {
145        return BleuScore::zero(cfg.max_n);
146    }
147    if references.is_empty() {
148        return BleuScore::zero(cfg.max_n);
149    }
150
151    let c_len = candidate.len();
152    let r_len = closest_ref_length(c_len, references);
153
154    let mut precisions = Vec::with_capacity(cfg.max_n);
155    let mut log_precision_sum = 0.0f64;
156    let mut zero_streak = 0usize;
157    let mut collapsed = false;
158
159    for n in 1..=cfg.max_n {
160        let (matches, total) = match_counts_sentence(candidate, references, n);
161        let (p_n, used_total) =
162            apply_smoothing(matches, total, cfg.smoothing, c_len, &mut zero_streak);
163        precisions.push(p_n);
164
165        if used_total == 0 || p_n <= 0.0 {
166            collapsed = true;
167            log_precision_sum = f64::NEG_INFINITY;
168        } else if !collapsed {
169            log_precision_sum += (p_n as f64).ln();
170        }
171    }
172
173    let bp = brevity_penalty(c_len, r_len);
174    let length_ratio = if r_len == 0 {
175        0.0
176    } else {
177        c_len as f32 / r_len as f32
178    };
179
180    let bleu = if collapsed {
181        0.0
182    } else {
183        let n = cfg.max_n as f64;
184        let geo = (log_precision_sum / n).exp();
185        (bp as f64 * geo) as f32
186    };
187
188    BleuScore {
189        bleu,
190        precisions,
191        brevity_penalty: bp,
192        length_ratio,
193    }
194}
195
196// ──────────────────────────────────────────────────────────────────────────────
197// Corpus-level BLEU
198// ──────────────────────────────────────────────────────────────────────────────
199
200/// Corpus BLEU: aggregate modified precision counts across all sentences
201/// before computing the geometric mean.
202///
203/// `references[i]` is the list of reference translations for the i-th
204/// candidate. At least one reference per candidate is required.
205pub fn corpus_bleu(candidates: &[&str], references: &[Vec<&str>], cfg: &BleuConfig) -> BleuScore {
206    let cands: Vec<TokenSeq> = candidates.iter().map(|c| tokenize(c)).collect();
207    let refs: Vec<Vec<TokenSeq>> = references
208        .iter()
209        .map(|refs_i| refs_i.iter().map(|r| tokenize(r)).collect())
210        .collect();
211    corpus_bleu_tokens(&cands, &refs, cfg)
212}
213
214/// Corpus BLEU from pre-tokenised inputs.
215pub fn corpus_bleu_tokens(
216    candidates: &[TokenSeq],
217    references: &[Vec<TokenSeq>],
218    cfg: &BleuConfig,
219) -> BleuScore {
220    if candidates.is_empty() || candidates.iter().all(|c| c.is_empty()) {
221        return BleuScore::zero(cfg.max_n);
222    }
223    let n_eff = candidates.len().min(references.len());
224    if n_eff == 0 {
225        return BleuScore::zero(cfg.max_n);
226    }
227
228    let mut total_c_len = 0usize;
229    let mut total_r_len = 0usize;
230    let mut match_by_n = vec![0u64; cfg.max_n];
231    let mut total_by_n = vec![0u64; cfg.max_n];
232
233    for i in 0..n_eff {
234        let cand = &candidates[i];
235        let refs = &references[i];
236        if cand.is_empty() || refs.is_empty() {
237            continue;
238        }
239        total_c_len += cand.len();
240        total_r_len += closest_ref_length(cand.len(), refs);
241
242        for n in 1..=cfg.max_n {
243            let (m, t) = match_counts_sentence(cand, refs, n);
244            match_by_n[n - 1] += m as u64;
245            total_by_n[n - 1] += t as u64;
246        }
247    }
248
249    if total_c_len == 0 {
250        return BleuScore::zero(cfg.max_n);
251    }
252
253    // Corpus BLEU uses a single smoothing decision per n across the whole
254    // corpus. For `ExpDecay`, the "zero streak" restarts per corpus evaluation.
255    let mut precisions = Vec::with_capacity(cfg.max_n);
256    let mut log_sum = 0.0f64;
257    let mut collapsed = false;
258    let mut zero_streak = 0usize;
259
260    for n in 0..cfg.max_n {
261        let m = match_by_n[n] as usize;
262        let t = total_by_n[n] as usize;
263        let (p_n, used_total) = apply_smoothing(m, t, cfg.smoothing, total_c_len, &mut zero_streak);
264        precisions.push(p_n);
265        if used_total == 0 || p_n <= 0.0 {
266            collapsed = true;
267            log_sum = f64::NEG_INFINITY;
268        } else if !collapsed {
269            log_sum += (p_n as f64).ln();
270        }
271    }
272
273    let bp = brevity_penalty(total_c_len, total_r_len);
274    let length_ratio = if total_r_len == 0 {
275        0.0
276    } else {
277        total_c_len as f32 / total_r_len as f32
278    };
279
280    let bleu = if collapsed {
281        0.0
282    } else {
283        let nn = cfg.max_n as f64;
284        (bp as f64 * (log_sum / nn).exp()) as f32
285    };
286
287    BleuScore {
288        bleu,
289        precisions,
290        brevity_penalty: bp,
291        length_ratio,
292    }
293}
294
295// ──────────────────────────────────────────────────────────────────────────────
296// Internals
297// ──────────────────────────────────────────────────────────────────────────────
298
299/// Compute clipped match count and total candidate n-gram count for a single sentence.
300fn match_counts_sentence(cand: &TokenSeq, refs: &[TokenSeq], n: usize) -> (usize, usize) {
301    let cand_counts = ngram_counts(cand, n);
302    let total: usize = cand_counts.values().sum();
303    if total == 0 {
304        return (0, 0);
305    }
306
307    // Maximum reference count per n-gram across all references.
308    let mut max_ref: HashMap<Vec<String>, usize> = HashMap::new();
309    for r in refs {
310        let rc = ngram_counts(r, n);
311        for (k, v) in rc {
312            let e = max_ref.entry(k).or_insert(0);
313            if v > *e {
314                *e = v;
315            }
316        }
317    }
318
319    let mut matches = 0usize;
320    for (ngram, &cand_count) in &cand_counts {
321        if let Some(&rc) = max_ref.get(ngram) {
322            matches += cand_count.min(rc);
323        }
324    }
325    (matches, total)
326}
327
328fn ngram_counts(tokens: &TokenSeq, n: usize) -> HashMap<Vec<String>, usize> {
329    let mut counts: HashMap<Vec<String>, usize> = HashMap::new();
330    if n == 0 || tokens.len() < n {
331        return counts;
332    }
333    for w in tokens.windows(n) {
334        *counts.entry(w.to_vec()).or_insert(0) += 1;
335    }
336    counts
337}
338
339/// Find the length of the reference *closest* to the candidate (shortest tie-break).
340fn closest_ref_length(c_len: usize, refs: &[TokenSeq]) -> usize {
341    let mut best: Option<(usize, usize)> = None; // (abs_diff, len)
342    for r in refs {
343        let r_len = r.len();
344        let diff = r_len.max(c_len) - r_len.min(c_len);
345        match best {
346            None => best = Some((diff, r_len)),
347            Some((bd, bl)) => {
348                if diff < bd || (diff == bd && r_len < bl) {
349                    best = Some((diff, r_len));
350                }
351            }
352        }
353    }
354    best.map(|(_, l)| l).unwrap_or(0)
355}
356
/// Brevity penalty for a candidate of length `c_len` against an effective
/// reference length `r_len`.
///
/// * `0.0` for an empty candidate,
/// * `1.0` when the candidate is longer than the reference,
/// * `exp(1 - r/c)` otherwise (which is `1.0` when the lengths are equal).
fn brevity_penalty(c_len: usize, r_len: usize) -> f32 {
    match c_len {
        0 => 0.0,
        c if c > r_len => 1.0,
        c => {
            // Evaluated in f64 and narrowed only at the end.
            let ratio = r_len as f64 / c as f64;
            (1.0f64 - ratio).exp() as f32
        }
    }
}
366
367/// Returns `(p_n, effective_denominator)`; if the denominator is 0,
368/// caller treats the score as collapsed.
369fn apply_smoothing(
370    matches: usize,
371    total: usize,
372    method: SmoothingMethod,
373    c_len: usize,
374    zero_streak: &mut usize,
375) -> (f32, usize) {
376    match method {
377        SmoothingMethod::None => {
378            if total == 0 {
379                (0.0, 0)
380            } else {
381                (matches as f32 / total as f32, total)
382            }
383        }
384        SmoothingMethod::AddOne => {
385            if total == 0 {
386                (0.0, 0)
387            } else if matches == 0 {
388                // When matches are 0, classic add-one gives 1/(total+1). We
389                // follow Chen & Cherry method 2: add 1 to both numerator and
390                // denominator when there's a zero.
391                (1.0 / (total as f32 + 1.0), total + 1)
392            } else {
393                ((matches as f32 + 1.0) / (total as f32 + 1.0), total + 1)
394            }
395        }
396        SmoothingMethod::ExpDecay => {
397            if total == 0 {
398                return (0.0, 0);
399            }
400            if matches == 0 {
401                *zero_streak += 1;
402                let k = *zero_streak as f32;
403                let denom = (2.0f32).powf(k) * c_len.max(1) as f32;
404                (1.0 / denom, total)
405            } else {
406                *zero_streak = 0;
407                (matches as f32 / total as f32, total)
408            }
409        }
410    }
411}