Skip to main content

entrenar/eval/generative/
text_gen.rs

//! Text generation evaluation metrics
//!
//! Provides BLEU, ROUGE (1, 2, L), and perplexity for evaluating
//! text generation, translation, and summarization models.
5
6use std::collections::HashMap;
7
8/// Compute BLEU score with modified n-gram precision and brevity penalty.
9///
10/// Implements the original BLEU algorithm (Papineni et al., 2002).
11/// Returns a value in [0, 1] where 1.0 indicates perfect match.
12///
13/// # Arguments
14/// * `references` - One or more reference translations
15/// * `hypothesis` - The candidate translation
16/// * `max_n` - Maximum n-gram order (typically 4)
17pub fn bleu_score(references: &[&str], hypothesis: &str, max_n: usize) -> f64 {
18    if references.is_empty() || hypothesis.is_empty() {
19        return 0.0;
20    }
21
22    let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
23    if hyp_tokens.is_empty() {
24        return 0.0;
25    }
26
27    let ref_token_lists: Vec<Vec<&str>> =
28        references.iter().map(|r| r.split_whitespace().collect()).collect();
29
30    // Compute modified precision for each n-gram order
31    let mut log_precisions = Vec::new();
32    for n in 1..=max_n {
33        let (clipped, total) = modified_precision(&ref_token_lists, &hyp_tokens, n);
34        if total == 0 {
35            return 0.0;
36        }
37        let precision = clipped as f64 / total as f64;
38        if precision == 0.0 {
39            return 0.0;
40        }
41        log_precisions.push(precision.max(f64::MIN_POSITIVE).ln());
42    }
43
44    // Geometric mean of precisions (uniform weights)
45    let avg_log_precision: f64 =
46        log_precisions.iter().sum::<f64>() / log_precisions.len().max(1) as f64;
47
48    // Brevity penalty
49    let hyp_len = hyp_tokens.len();
50    let closest_ref_len = ref_token_lists
51        .iter()
52        .map(Vec::len)
53        .min_by_key(|&len| (len as isize - hyp_len as isize).unsigned_abs())
54        .unwrap_or(0);
55
56    let bp = if hyp_len >= closest_ref_len {
57        1.0
58    } else if closest_ref_len == 0 {
59        0.0
60    } else {
61        (1.0 - closest_ref_len as f64 / hyp_len as f64).exp()
62    };
63
64    bp * avg_log_precision.exp()
65}
66
67/// Modified n-gram precision: count clipped matches against all references.
68fn modified_precision(references: &[Vec<&str>], hypothesis: &[&str], n: usize) -> (usize, usize) {
69    let hyp_ngrams = extract_ngrams(hypothesis, n);
70    let total: usize = hyp_ngrams.values().sum();
71
72    let mut clipped = 0usize;
73    for (ngram, &hyp_count) in &hyp_ngrams {
74        let max_ref_count = references
75            .iter()
76            .map(|r| {
77                let ref_ngrams = extract_ngrams(r, n);
78                ref_ngrams.get(ngram).copied().unwrap_or(0)
79            })
80            .max()
81            .unwrap_or(0);
82        clipped += hyp_count.min(max_ref_count);
83    }
84
85    (clipped, total)
86}
87
/// Extract n-grams from a token sequence and count occurrences.
///
/// Returns a map from each distinct n-gram (as a vector of tokens) to the
/// number of times it occurs in `tokens`. The map is empty when `n` is 0 or
/// when `tokens` holds fewer than `n` tokens.
fn extract_ngrams<'a>(tokens: &[&'a str], n: usize) -> HashMap<Vec<&'a str>, usize> {
    let mut counts = HashMap::new();
    // `slice::windows` panics on a zero window size; a 0-gram carries no
    // information, so report no n-grams instead.
    if n == 0 {
        return counts;
    }
    // `windows` yields nothing when the slice is shorter than `n`.
    for window in tokens.windows(n) {
        *counts.entry(window.to_vec()).or_insert(0) += 1;
    }
    counts
}
98
99/// Compute ROUGE-N F1 score (n-gram overlap between reference and hypothesis).
100///
101/// Returns F1 score in [0, 1].
102pub fn rouge_n(reference: &str, hypothesis: &str, n: usize) -> f64 {
103    let ref_tokens: Vec<&str> = reference.split_whitespace().collect();
104    let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
105
106    if ref_tokens.len() < n || hyp_tokens.len() < n {
107        return 0.0;
108    }
109
110    let ref_ngrams = extract_ngrams(&ref_tokens, n);
111    let hyp_ngrams = extract_ngrams(&hyp_tokens, n);
112
113    let mut overlap = 0usize;
114    for (ngram, &hyp_count) in &hyp_ngrams {
115        let ref_count = ref_ngrams.get(ngram).copied().unwrap_or(0);
116        overlap += hyp_count.min(ref_count);
117    }
118
119    let ref_total: usize = ref_ngrams.values().sum();
120    let hyp_total: usize = hyp_ngrams.values().sum();
121
122    if ref_total == 0 || hyp_total == 0 {
123        return 0.0;
124    }
125
126    let precision = overlap as f64 / hyp_total as f64;
127    let recall = overlap as f64 / ref_total as f64;
128
129    if precision + recall == 0.0 {
130        return 0.0;
131    }
132
133    2.0 * precision * recall / (precision + recall)
134}
135
136/// Compute ROUGE-L F1 score using longest common subsequence.
137///
138/// Returns F1 score in [0, 1].
139pub fn rouge_l(reference: &str, hypothesis: &str) -> f64 {
140    let ref_tokens: Vec<&str> = reference.split_whitespace().collect();
141    let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
142
143    if ref_tokens.is_empty() || hyp_tokens.is_empty() {
144        return 0.0;
145    }
146
147    let lcs_len = lcs_length(&ref_tokens, &hyp_tokens);
148
149    let precision = lcs_len as f64 / hyp_tokens.len() as f64;
150    let recall = lcs_len as f64 / ref_tokens.len() as f64;
151
152    if precision + recall == 0.0 {
153        return 0.0;
154    }
155
156    2.0 * precision * recall / (precision + recall)
157}
158
/// Compute length of longest common subsequence.
///
/// Uses the standard dynamic-programming recurrence, keeping only the previous
/// and current rows so memory is proportional to the shorter input rather
/// than to the product of both lengths. Returns 0 when either input is empty.
fn lcs_length(a: &[&str], b: &[&str]) -> usize {
    // LCS length is symmetric in its arguments, so lay the shorter sequence
    // along the row to keep the two buffers as small as possible.
    let (outer, inner) = if a.len() >= b.len() { (a, b) } else { (b, a) };

    let width = inner.len() + 1;
    let mut prev = vec![0usize; width];
    let mut curr = vec![0usize; width];

    for &x in outer {
        // curr[0] is never written and stays 0 (the empty-prefix column).
        for (j, &y) in inner.iter().enumerate() {
            curr[j + 1] = if x == y {
                prev[j] + 1
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        // The row just filled becomes the "previous" row for the next token;
        // the stale buffer is fully overwritten on the next pass.
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[inner.len()]
}
177
/// Compute perplexity from log-probabilities.
///
/// Perplexity = exp(-1/N * sum(log_probs)), i.e. the exponential of the mean
/// negative log-probability over the `N` entries of `log_probs`.
///
/// Lower is better. The result is >= 1.0 whenever every log-probability is
/// <= 0, as is the case for valid probabilities.
/// Returns `f64::INFINITY` for empty input.
pub fn perplexity(log_probs: &[f64]) -> f64 {
    match log_probs.len() {
        0 => f64::INFINITY,
        count => {
            let total_log_prob = log_probs.iter().sum::<f64>();
            let mean_neg_log_prob = -total_log_prob / count as f64;
            mean_neg_log_prob.exp()
        }
    }
}
191}