Skip to main content

scirs2_text/evaluation/
bleu.rs

1//! # BLEU Score (Bilingual Evaluation Understudy)
2//!
3//! Implementation of BLEU score for machine translation evaluation
4//! (Papineni et al. 2002). Supports both corpus-level and sentence-level
5//! BLEU with multiple smoothing methods.
6//!
7//! ## Overview
8//!
9//! BLEU measures how many n-grams in the hypothesis (candidate translation)
10//! appear in the reference(s). It combines modified n-gram precision for
11//! n=1..4 with a brevity penalty to discourage overly short translations.
12//!
13//! Formula: BLEU = BP * exp(sum(w_n * log(p_n)))
14//!
15//! where:
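//! - BP = exp(min(0, 1 - ref_len/hyp_len)) is the brevity penalty, where hyp_len is the
//!   hypothesis length and ref_len is the length of the reference closest to it in length
//!   (both summed over all sentences for corpus-level BLEU)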
17//! - p_n is the modified n-gram precision for order n
18//! - w_n is the weight for order n (default: uniform 1/max_n)
19//!
20//! ## Examples
21//!
22//! ```rust
23//! use scirs2_text::evaluation::bleu::{corpus_bleu, sentence_bleu, SmoothingMethod};
24//!
25//! let hypothesis = vec!["the", "cat", "sat", "on", "the", "mat"];
26//! let reference = vec![vec!["the", "cat", "is", "on", "the", "mat"]];
27//! let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::AddEpsilon(0.1))
28//!     .expect("Operation failed");
29//! assert!(score > 0.0 && score < 1.0);
30//! ```
31
32use std::collections::HashMap;
33
34use crate::error::{Result, TextError};
35
/// Smoothing method for sentence-level BLEU.
///
/// At corpus level, n-gram counts are aggregated and smoothing is typically
/// unnecessary. At sentence level, zero n-gram counts for higher orders are
/// common and smoothing prevents the score from collapsing to zero.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SmoothingMethod {
    /// No smoothing (standard BLEU). If any n-gram precision is zero,
    /// the overall score is zero.
    None,
    /// Add-epsilon smoothing (Chen & Cherry method 1), carrying the epsilon.
    ///
    /// For an n-gram order with no clipped matches, the zero match count is
    /// replaced by epsilon, giving a precision of `eps / (total + eps)`, where
    /// `total` is the number of hypothesis n-grams of that order. If the
    /// hypothesis has no n-grams of that order at all (it is shorter than `n`),
    /// the precision for that order is `eps` itself.
    AddEpsilon(f64),
    /// Geometric-decay smoothing, modelled on Chen & Cherry method 3 (the NIST
    /// geometric sequence).
    ///
    /// For an n-gram order with zero matches, the precision becomes `1 / 2^k`,
    /// where `k` counts the consecutive zero-match orders seen so far
    /// (`1/2`, then `1/4`, ...). Unlike the published method, the value is not
    /// additionally divided by the number of hypothesis n-grams.
    ExponentialDecay,
}
55
56/// Configuration for BLEU score computation.
57#[derive(Debug, Clone)]
58pub struct BleuConfig {
59    /// Maximum n-gram order (default: 4).
60    pub max_n: usize,
61    /// Weights for each n-gram order. If None, uniform weights (1/max_n) are used.
62    pub weights: Option<Vec<f64>>,
63    /// Smoothing method for sentence-level BLEU.
64    pub smoothing: SmoothingMethod,
65}
66
67impl Default for BleuConfig {
68    fn default() -> Self {
69        Self {
70            max_n: 4,
71            weights: None,
72            smoothing: SmoothingMethod::None,
73        }
74    }
75}
76
/// Count the n-grams of order `n` in a token sequence.
///
/// Each key is a contiguous window of `n` tokens and each value is the number
/// of times that window occurs. Returns an empty map when `n` is 0 or when the
/// sequence is shorter than `n`.
fn extract_ngrams<'a>(tokens: &'a [&str], n: usize) -> HashMap<Vec<&'a str>, usize> {
    let mut counts: HashMap<Vec<&'a str>, usize> = HashMap::new();
    // A zero-length n-gram is meaningless (and `windows(0)` would panic).
    if n == 0 {
        return counts;
    }
    // `windows` yields nothing when `tokens.len() < n`.
    for window in tokens.windows(n) {
        *counts.entry(window.to_vec()).or_default() += 1;
    }
    counts
}
88
89/// Compute modified n-gram precision for a single hypothesis against
90/// multiple references.
91///
92/// For each n-gram in the hypothesis, its clipped count is
93/// min(hyp_count, max_ref_count). The modified precision is
94/// sum(clipped_counts) / sum(hyp_counts).
95fn modified_precision(hypothesis: &[&str], references: &[Vec<&str>], n: usize) -> (usize, usize) {
96    let hyp_ngrams = extract_ngrams(hypothesis, n);
97
98    if hyp_ngrams.is_empty() {
99        return (0, 0);
100    }
101
102    // For each n-gram, find the maximum count across all references
103    let mut max_ref_counts: HashMap<Vec<&str>, usize> = HashMap::new();
104    for reference in references {
105        let ref_ngrams = extract_ngrams(reference, n);
106        for (ngram, count) in &ref_ngrams {
107            let entry = max_ref_counts.entry(ngram.clone()).or_insert(0);
108            if *count > *entry {
109                *entry = *count;
110            }
111        }
112    }
113
114    // Compute clipped counts
115    let mut clipped_count = 0usize;
116    let mut total_count = 0usize;
117
118    for (ngram, hyp_count) in &hyp_ngrams {
119        let max_ref = max_ref_counts.get(ngram).copied().unwrap_or(0);
120        clipped_count += (*hyp_count).min(max_ref);
121        total_count += *hyp_count;
122    }
123
124    (clipped_count, total_count)
125}
126
/// Pick the effective reference length used by the brevity penalty.
///
/// Returns the length of the reference whose length is nearest to `hyp_len`;
/// when two references are equally near, the shorter one wins. Returns 0 if
/// `references` is empty.
fn closest_ref_length(hyp_len: usize, references: &[Vec<&str>]) -> usize {
    references
        .iter()
        .map(Vec::len)
        // Ordering by (distance, length) sends ties to the shorter reference.
        .min_by_key(|&len| (len.abs_diff(hyp_len), len))
        .unwrap_or(0)
}
146
/// Brevity penalty: `BP = exp(min(0, 1 - ref_len / hyp_len))`.
///
/// Equals 1.0 when the hypothesis is at least as long as the reference and
/// decays exponentially as it gets shorter. An empty hypothesis yields 0.0.
fn brevity_penalty(hyp_len: usize, ref_len: usize) -> f64 {
    match hyp_len {
        0 => 0.0,
        _ => (1.0 - ref_len as f64 / hyp_len as f64).min(0.0).exp(),
    }
}
161
162/// Compute corpus-level BLEU score.
163///
164/// Aggregates n-gram counts across all sentence pairs before computing
165/// precision. This is the standard way to compute BLEU for evaluation.
166///
167/// # Arguments
168///
169/// * `hypotheses` - List of hypothesis sentences, each as a slice of tokens.
170/// * `references` - List of reference sets. Each entry is a Vec of reference
171///   sentences for the corresponding hypothesis. Each reference is a Vec of tokens.
172/// * `max_n` - Maximum n-gram order (typically 4).
173///
174/// # Returns
175///
176/// The corpus-level BLEU score in `[0.0, 1.0]`.
177///
178/// # Errors
179///
180/// Returns `TextError::InvalidInput` if inputs are empty or mismatched.
181pub fn corpus_bleu(
182    hypotheses: &[Vec<&str>],
183    references: &[Vec<Vec<&str>>],
184    max_n: usize,
185) -> Result<f64> {
186    if hypotheses.is_empty() {
187        return Err(TextError::InvalidInput(
188            "Hypotheses list must not be empty".to_string(),
189        ));
190    }
191    if hypotheses.len() != references.len() {
192        return Err(TextError::InvalidInput(format!(
193            "Number of hypotheses ({}) must match number of reference sets ({})",
194            hypotheses.len(),
195            references.len()
196        )));
197    }
198    if max_n == 0 {
199        return Err(TextError::InvalidInput(
200            "max_n must be at least 1".to_string(),
201        ));
202    }
203
204    // Validate that each reference set is non-empty
205    for (i, refs) in references.iter().enumerate() {
206        if refs.is_empty() {
207            return Err(TextError::InvalidInput(format!(
208                "Reference set at index {} must not be empty",
209                i
210            )));
211        }
212    }
213
214    let weights: Vec<f64> = vec![1.0 / max_n as f64; max_n];
215
216    // Aggregate counts across the corpus
217    let mut total_clipped = vec![0usize; max_n];
218    let mut total_count = vec![0usize; max_n];
219    let mut total_hyp_len = 0usize;
220    let mut total_ref_len = 0usize;
221
222    for (hyp, refs) in hypotheses.iter().zip(references.iter()) {
223        total_hyp_len += hyp.len();
224        total_ref_len += closest_ref_length(hyp.len(), refs);
225
226        for n in 1..=max_n {
227            let (clipped, count) = modified_precision(hyp, refs, n);
228            total_clipped[n - 1] += clipped;
229            total_count[n - 1] += count;
230        }
231    }
232
233    // Compute log-averaged precision
234    let mut log_avg = 0.0f64;
235    for n in 0..max_n {
236        if total_count[n] == 0 || total_clipped[n] == 0 {
237            // If any n-gram precision is zero, corpus BLEU is zero
238            return Ok(0.0);
239        }
240        let precision = total_clipped[n] as f64 / total_count[n] as f64;
241        log_avg += weights[n] * precision.ln();
242    }
243
244    let bp = brevity_penalty(total_hyp_len, total_ref_len);
245    Ok(bp * log_avg.exp())
246}
247
248/// Compute sentence-level BLEU score with optional smoothing.
249///
250/// # Arguments
251///
252/// * `hypothesis` - The hypothesis sentence as a slice of tokens.
253/// * `references` - One or more reference sentences, each as a Vec of tokens.
254/// * `max_n` - Maximum n-gram order (typically 4).
255/// * `smoothing` - Smoothing method to handle zero n-gram counts.
256///
257/// # Returns
258///
259/// The sentence-level BLEU score in `[0.0, 1.0]`.
260///
261/// # Errors
262///
263/// Returns `TextError::InvalidInput` if inputs are invalid.
264pub fn sentence_bleu(
265    hypothesis: &[&str],
266    references: &[Vec<&str>],
267    max_n: usize,
268    smoothing: SmoothingMethod,
269) -> Result<f64> {
270    if references.is_empty() {
271        return Err(TextError::InvalidInput(
272            "References must not be empty".to_string(),
273        ));
274    }
275    if max_n == 0 {
276        return Err(TextError::InvalidInput(
277            "max_n must be at least 1".to_string(),
278        ));
279    }
280
281    if hypothesis.is_empty() {
282        return Ok(0.0);
283    }
284
285    let weights: Vec<f64> = vec![1.0 / max_n as f64; max_n];
286    let ref_len = closest_ref_length(hypothesis.len(), references);
287    let bp = brevity_penalty(hypothesis.len(), ref_len);
288
289    let mut log_avg = 0.0f64;
290    let mut consecutive_zeros = 0u32;
291
292    for n in 1..=max_n {
293        let (clipped, count) = modified_precision(hypothesis, references, n);
294
295        let precision = match smoothing {
296            SmoothingMethod::None => {
297                if count == 0 || clipped == 0 {
298                    return Ok(0.0);
299                }
300                clipped as f64 / count as f64
301            }
302            SmoothingMethod::AddEpsilon(eps) => {
303                if count == 0 {
304                    eps
305                } else {
306                    (clipped as f64 + eps) / (count as f64 + eps)
307                }
308            }
309            SmoothingMethod::ExponentialDecay => {
310                if count == 0 || clipped == 0 {
311                    consecutive_zeros += 1;
312                    1.0 / 2.0f64.powi(consecutive_zeros as i32)
313                } else {
314                    consecutive_zeros = 0;
315                    clipped as f64 / count as f64
316                }
317            }
318        };
319
320        if precision <= 0.0 {
321            return Ok(0.0);
322        }
323        log_avg += weights[n - 1] * precision.ln();
324    }
325
326    Ok(bp * log_avg.exp())
327}
328
#[cfg(test)]
mod tests {
    //! Unit tests for sentence- and corpus-level BLEU and their helpers.
    use super::*;

    #[test]
    fn test_perfect_translation() {
        let hypothesis = vec!["the", "cat", "is", "on", "the", "mat"];
        let reference = vec![vec!["the", "cat", "is", "on", "the", "mat"]];
        let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            (score - 1.0).abs() < 1e-9,
            "Perfect translation should score 1.0, got {}",
            score
        );
    }

    #[test]
    fn test_no_overlap() {
        let hypothesis = vec!["a", "b", "c", "d"];
        let reference = vec![vec!["e", "f", "g", "h"]];
        let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            score.abs() < 1e-9,
            "No overlap should score 0.0, got {}",
            score
        );
    }

    #[test]
    fn test_brevity_penalty_applied() {
        // Short hypothesis vs longer reference
        let hypothesis = vec!["the", "cat"];
        let reference = vec![vec!["the", "cat", "is", "on", "the", "mat"]];
        let score = sentence_bleu(&hypothesis, &reference, 1, SmoothingMethod::AddEpsilon(0.1))
            .expect("should compute");
        // Unigram precision is high but BP should penalize
        assert!(score < 1.0, "BP should reduce score for short hyp");
        assert!(score > 0.0, "Score should be positive with partial match");
    }

    #[test]
    fn test_multiple_references() {
        let hypothesis = vec!["the", "cat", "sat", "on", "the", "mat"];
        let references = vec![
            vec!["the", "cat", "is", "on", "the", "mat"],
            vec!["the", "cat", "sat", "on", "the", "mat"],
        ];
        let score = sentence_bleu(&hypothesis, &references, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            (score - 1.0).abs() < 1e-9,
            "Should match second reference perfectly, got {}",
            score
        );
    }

    #[test]
    fn test_corpus_bleu_basic() {
        let hypotheses = vec![
            vec!["the", "cat", "is", "on", "the", "mat"],
            vec!["there", "is", "a", "cat", "on", "the", "mat"],
        ];
        let references = vec![
            vec![vec!["the", "cat", "is", "on", "the", "mat"]],
            vec![vec!["there", "is", "a", "cat", "on", "the", "mat"]],
        ];
        let score = corpus_bleu(&hypotheses, &references, 4).expect("should compute");
        assert!(
            (score - 1.0).abs() < 1e-9,
            "Perfect corpus should score 1.0, got {}",
            score
        );
    }

    #[test]
    fn test_corpus_bleu_empty_fails() {
        let result = corpus_bleu(&[], &[], 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_smoothing_exponential_decay() {
        // Short hypothesis where higher-order n-grams may be zero
        let hypothesis = vec!["the", "cat"];
        let reference = vec![vec!["the", "cat", "sat"]];
        let score_none = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::None)
            .expect("should compute");
        let score_smooth = sentence_bleu(
            &hypothesis,
            &reference,
            4,
            SmoothingMethod::ExponentialDecay,
        )
        .expect("should compute");
        // With no smoothing, zero 3-gram/4-gram precision kills the score
        assert!(
            score_none.abs() < 1e-9,
            "No smoothing should give 0 with missing n-grams"
        );
        assert!(
            score_smooth > 0.0,
            "Exponential decay smoothing should give positive score"
        );
    }

    #[test]
    fn test_partial_overlap() {
        let hypothesis = vec!["the", "cat", "sat", "on", "the", "mat"];
        let reference = vec![vec!["the", "cat", "is", "on", "the", "mat"]];
        let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::AddEpsilon(0.1))
            .expect("should compute");
        // Should be between 0 and 1
        assert!(score > 0.0 && score < 1.0, "Partial overlap: got {}", score);
    }

    #[test]
    fn test_sentence_bleu_rejects_empty_references() {
        let hypothesis = vec!["the", "cat"];
        let references: Vec<Vec<&str>> = Vec::new();
        assert!(sentence_bleu(&hypothesis, &references, 4, SmoothingMethod::None).is_err());
    }

    #[test]
    fn test_zero_max_n_fails() {
        let hypothesis = vec!["the", "cat"];
        let reference = vec![vec!["the", "cat"]];
        assert!(sentence_bleu(&hypothesis, &reference, 0, SmoothingMethod::None).is_err());
        assert!(corpus_bleu(&[hypothesis], &[reference], 0).is_err());
    }

    #[test]
    fn test_corpus_bleu_mismatched_lengths_fails() {
        let hypotheses = vec![vec!["a", "b"], vec!["c", "d"]];
        let references = vec![vec![vec!["a", "b"]]];
        assert!(corpus_bleu(&hypotheses, &references, 4).is_err());
    }

    #[test]
    fn test_corpus_bleu_empty_reference_set_fails() {
        let hypotheses = vec![vec!["a", "b"]];
        let references: Vec<Vec<Vec<&str>>> = vec![Vec::new()];
        assert!(corpus_bleu(&hypotheses, &references, 4).is_err());
    }

    #[test]
    fn test_empty_hypothesis_scores_zero() {
        let hypothesis: Vec<&str> = Vec::new();
        let reference = vec![vec!["the", "cat"]];
        let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::ExponentialDecay)
            .expect("should compute");
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_exponential_decay_value() {
        let hypothesis = vec!["the", "cat"];
        let reference = vec![vec!["the", "cat", "sat"]];
        let score = sentence_bleu(&hypothesis, &reference, 4, SmoothingMethod::ExponentialDecay)
            .expect("should compute");
        // p1 = p2 = 1, p3 = 1/2, p4 = 1/4; BP = exp(1 - 3/2)
        let expected = (-0.5f64).exp() * 0.125f64.powf(0.25);
        assert!((score - expected).abs() < 1e-12, "got {}", score);
    }

    #[test]
    fn test_add_epsilon_for_missing_order() {
        // A one-token hypothesis has no bigrams, so that order's precision is eps.
        let hypothesis = vec!["a"];
        let reference = vec![vec!["a"]];
        let score = sentence_bleu(&hypothesis, &reference, 2, SmoothingMethod::AddEpsilon(0.1))
            .expect("should compute");
        assert!((score - 0.1f64.sqrt()).abs() < 1e-12, "got {}", score);
    }

    #[test]
    fn test_extract_ngrams_counts_windows() {
        let tokens = ["a", "b", "a", "b"];
        let bigrams = extract_ngrams(&tokens, 2);
        assert_eq!(bigrams.len(), 2);
        assert_eq!(bigrams.get(&vec!["a", "b"]).copied(), Some(2));
        assert_eq!(bigrams.get(&vec!["b", "a"]).copied(), Some(1));
        assert!(extract_ngrams(&tokens, 5).is_empty());
    }

    #[test]
    fn test_modified_precision_clips_by_max_reference_count() {
        // Papineni et al.'s repeated-"the" example.
        let hypothesis = vec!["the", "the", "the"];
        let single = vec![vec!["the", "cat"]];
        assert_eq!(modified_precision(&hypothesis, &single, 1), (1, 3));
        let multiple = vec![vec!["the", "cat"], vec!["the", "the", "mat"]];
        assert_eq!(modified_precision(&hypothesis, &multiple, 1), (2, 3));
        // Order longer than the hypothesis: nothing to count.
        assert_eq!(modified_precision(&hypothesis, &single, 4), (0, 0));
    }

    #[test]
    fn test_closest_ref_length_tie_prefers_shorter() {
        let references = vec![vec!["a"; 7], vec!["a"; 3]];
        assert_eq!(closest_ref_length(5, &references), 3);
        assert_eq!(closest_ref_length(6, &references), 7);
    }

    #[test]
    fn test_brevity_penalty_values() {
        assert_eq!(brevity_penalty(0, 5), 0.0);
        assert_eq!(brevity_penalty(6, 6), 1.0);
        assert_eq!(brevity_penalty(9, 6), 1.0);
        assert!((brevity_penalty(3, 6) - (-1.0f64).exp()).abs() < 1e-12);
    }
}