1use std::collections::HashMap;
45
46use crate::rouge::{tokenize, TokenSeq};
47
/// Strategy used to adjust a per-order n-gram precision before the orders are
/// combined into a BLEU score.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SmoothingMethod {
    /// Use the raw precision `matches / total`. This is the default.
    #[default]
    None,
    /// Add one to both the match count and the n-gram total of every order
    /// that has at least one candidate n-gram.
    AddOne,
    /// Replace a zero-match precision with `1 / (2^k * c_len)`, where `k` is
    /// the number of consecutive zero-match orders seen so far and `c_len` is
    /// the candidate length supplied by the scorer. Orders with matches keep
    /// their raw precision.
    ExpDecay,
}
67
/// Outcome of a BLEU computation.
#[derive(Clone, Debug)]
pub struct BleuScore {
    /// Final score: the brevity penalty times the geometric mean of
    /// `precisions`, or `0.0` when the score collapses.
    pub bleu: f32,
    /// Per-order n-gram precisions; index `i` holds the precision of order `i + 1`.
    pub precisions: Vec<f32>,
    /// Multiplier penalising candidates that are not longer than the reference.
    pub brevity_penalty: f32,
    /// Candidate length divided by reference length (`0.0` if the reference
    /// length is zero).
    pub length_ratio: f32,
}

impl BleuScore {
    /// Builds the score returned for degenerate inputs: every field is zero
    /// and `precisions` holds `max_n` zeros.
    fn zero(max_n: usize) -> Self {
        let precisions = vec![0.0_f32; max_n];
        Self {
            bleu: 0.0,
            precisions,
            brevity_penalty: 0.0,
            length_ratio: 0.0,
        }
    }
}
91
92#[derive(Debug, Clone)]
98#[non_exhaustive]
99pub struct BleuConfig {
100 pub max_n: usize,
102 pub smoothing: SmoothingMethod,
104}
105
106impl Default for BleuConfig {
107 fn default() -> Self {
108 Self {
109 max_n: 4,
110 smoothing: SmoothingMethod::None,
111 }
112 }
113}
114
115impl BleuConfig {
116 pub fn new(max_n: usize, smoothing: SmoothingMethod) -> Self {
118 Self {
119 max_n: max_n.max(1),
120 smoothing,
121 }
122 }
123}
124
125pub fn sentence_bleu(candidate: &str, references: &[&str], cfg: &BleuConfig) -> BleuScore {
133 let cand = tokenize(candidate);
134 let refs: Vec<TokenSeq> = references.iter().map(|r| tokenize(r)).collect();
135 sentence_bleu_tokens(&cand, &refs, cfg)
136}
137
138pub fn sentence_bleu_tokens(
140 candidate: &TokenSeq,
141 references: &[TokenSeq],
142 cfg: &BleuConfig,
143) -> BleuScore {
144 if candidate.is_empty() {
145 return BleuScore::zero(cfg.max_n);
146 }
147 if references.is_empty() {
148 return BleuScore::zero(cfg.max_n);
149 }
150
151 let c_len = candidate.len();
152 let r_len = closest_ref_length(c_len, references);
153
154 let mut precisions = Vec::with_capacity(cfg.max_n);
155 let mut log_precision_sum = 0.0f64;
156 let mut zero_streak = 0usize;
157 let mut collapsed = false;
158
159 for n in 1..=cfg.max_n {
160 let (matches, total) = match_counts_sentence(candidate, references, n);
161 let (p_n, used_total) =
162 apply_smoothing(matches, total, cfg.smoothing, c_len, &mut zero_streak);
163 precisions.push(p_n);
164
165 if used_total == 0 || p_n <= 0.0 {
166 collapsed = true;
167 log_precision_sum = f64::NEG_INFINITY;
168 } else if !collapsed {
169 log_precision_sum += (p_n as f64).ln();
170 }
171 }
172
173 let bp = brevity_penalty(c_len, r_len);
174 let length_ratio = if r_len == 0 {
175 0.0
176 } else {
177 c_len as f32 / r_len as f32
178 };
179
180 let bleu = if collapsed {
181 0.0
182 } else {
183 let n = cfg.max_n as f64;
184 let geo = (log_precision_sum / n).exp();
185 (bp as f64 * geo) as f32
186 };
187
188 BleuScore {
189 bleu,
190 precisions,
191 brevity_penalty: bp,
192 length_ratio,
193 }
194}
195
196pub fn corpus_bleu(candidates: &[&str], references: &[Vec<&str>], cfg: &BleuConfig) -> BleuScore {
206 let cands: Vec<TokenSeq> = candidates.iter().map(|c| tokenize(c)).collect();
207 let refs: Vec<Vec<TokenSeq>> = references
208 .iter()
209 .map(|refs_i| refs_i.iter().map(|r| tokenize(r)).collect())
210 .collect();
211 corpus_bleu_tokens(&cands, &refs, cfg)
212}
213
214pub fn corpus_bleu_tokens(
216 candidates: &[TokenSeq],
217 references: &[Vec<TokenSeq>],
218 cfg: &BleuConfig,
219) -> BleuScore {
220 if candidates.is_empty() || candidates.iter().all(|c| c.is_empty()) {
221 return BleuScore::zero(cfg.max_n);
222 }
223 let n_eff = candidates.len().min(references.len());
224 if n_eff == 0 {
225 return BleuScore::zero(cfg.max_n);
226 }
227
228 let mut total_c_len = 0usize;
229 let mut total_r_len = 0usize;
230 let mut match_by_n = vec![0u64; cfg.max_n];
231 let mut total_by_n = vec![0u64; cfg.max_n];
232
233 for i in 0..n_eff {
234 let cand = &candidates[i];
235 let refs = &references[i];
236 if cand.is_empty() || refs.is_empty() {
237 continue;
238 }
239 total_c_len += cand.len();
240 total_r_len += closest_ref_length(cand.len(), refs);
241
242 for n in 1..=cfg.max_n {
243 let (m, t) = match_counts_sentence(cand, refs, n);
244 match_by_n[n - 1] += m as u64;
245 total_by_n[n - 1] += t as u64;
246 }
247 }
248
249 if total_c_len == 0 {
250 return BleuScore::zero(cfg.max_n);
251 }
252
253 let mut precisions = Vec::with_capacity(cfg.max_n);
256 let mut log_sum = 0.0f64;
257 let mut collapsed = false;
258 let mut zero_streak = 0usize;
259
260 for n in 0..cfg.max_n {
261 let m = match_by_n[n] as usize;
262 let t = total_by_n[n] as usize;
263 let (p_n, used_total) = apply_smoothing(m, t, cfg.smoothing, total_c_len, &mut zero_streak);
264 precisions.push(p_n);
265 if used_total == 0 || p_n <= 0.0 {
266 collapsed = true;
267 log_sum = f64::NEG_INFINITY;
268 } else if !collapsed {
269 log_sum += (p_n as f64).ln();
270 }
271 }
272
273 let bp = brevity_penalty(total_c_len, total_r_len);
274 let length_ratio = if total_r_len == 0 {
275 0.0
276 } else {
277 total_c_len as f32 / total_r_len as f32
278 };
279
280 let bleu = if collapsed {
281 0.0
282 } else {
283 let nn = cfg.max_n as f64;
284 (bp as f64 * (log_sum / nn).exp()) as f32
285 };
286
287 BleuScore {
288 bleu,
289 precisions,
290 brevity_penalty: bp,
291 length_ratio,
292 }
293}
294
295fn match_counts_sentence(cand: &TokenSeq, refs: &[TokenSeq], n: usize) -> (usize, usize) {
301 let cand_counts = ngram_counts(cand, n);
302 let total: usize = cand_counts.values().sum();
303 if total == 0 {
304 return (0, 0);
305 }
306
307 let mut max_ref: HashMap<Vec<String>, usize> = HashMap::new();
309 for r in refs {
310 let rc = ngram_counts(r, n);
311 for (k, v) in rc {
312 let e = max_ref.entry(k).or_insert(0);
313 if v > *e {
314 *e = v;
315 }
316 }
317 }
318
319 let mut matches = 0usize;
320 for (ngram, &cand_count) in &cand_counts {
321 if let Some(&rc) = max_ref.get(ngram) {
322 matches += cand_count.min(rc);
323 }
324 }
325 (matches, total)
326}
327
328fn ngram_counts(tokens: &TokenSeq, n: usize) -> HashMap<Vec<String>, usize> {
329 let mut counts: HashMap<Vec<String>, usize> = HashMap::new();
330 if n == 0 || tokens.len() < n {
331 return counts;
332 }
333 for w in tokens.windows(n) {
334 *counts.entry(w.to_vec()).or_insert(0) += 1;
335 }
336 counts
337}
338
339fn closest_ref_length(c_len: usize, refs: &[TokenSeq]) -> usize {
341 let mut best: Option<(usize, usize)> = None; for r in refs {
343 let r_len = r.len();
344 let diff = r_len.max(c_len) - r_len.min(c_len);
345 match best {
346 None => best = Some((diff, r_len)),
347 Some((bd, bl)) => {
348 if diff < bd || (diff == bd && r_len < bl) {
349 best = Some((diff, r_len));
350 }
351 }
352 }
353 }
354 best.map(|(_, l)| l).unwrap_or(0)
355}
356
/// Computes the BLEU brevity penalty for candidate length `c_len` against
/// reference length `r_len`.
///
/// Returns `0.0` for an empty candidate, `1.0` when the candidate is longer
/// than the reference, and `exp(1 - r_len / c_len)` otherwise (which is also
/// `1.0` when the lengths are equal).
fn brevity_penalty(c_len: usize, r_len: usize) -> f32 {
    match c_len {
        0 => 0.0,
        c if c > r_len => 1.0,
        c => {
            let ratio = r_len as f64 / c as f64;
            (1.0f64 - ratio).exp() as f32
        }
    }
}
366
367fn apply_smoothing(
370 matches: usize,
371 total: usize,
372 method: SmoothingMethod,
373 c_len: usize,
374 zero_streak: &mut usize,
375) -> (f32, usize) {
376 match method {
377 SmoothingMethod::None => {
378 if total == 0 {
379 (0.0, 0)
380 } else {
381 (matches as f32 / total as f32, total)
382 }
383 }
384 SmoothingMethod::AddOne => {
385 if total == 0 {
386 (0.0, 0)
387 } else if matches == 0 {
388 (1.0 / (total as f32 + 1.0), total + 1)
392 } else {
393 ((matches as f32 + 1.0) / (total as f32 + 1.0), total + 1)
394 }
395 }
396 SmoothingMethod::ExpDecay => {
397 if total == 0 {
398 return (0.0, 0);
399 }
400 if matches == 0 {
401 *zero_streak += 1;
402 let k = *zero_streak as f32;
403 let denom = (2.0f32).powf(k) * c_len.max(1) as f32;
404 (1.0 / denom, total)
405 } else {
406 *zero_streak = 0;
407 (matches as f32 / total as f32, total)
408 }
409 }
410 }
411}