1use std::collections::HashMap;
33
34use crate::error::{Result, TextError};
35
/// Strategy for handling n-gram orders that would otherwise have zero precision.
///
/// The variants are interpreted by `sentence_bleu`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SmoothingMethod {
    /// No smoothing: any order with no n-grams or no clipped matches makes the
    /// sentence score `0.0`.
    None,
    /// Adds the wrapped constant to both the clipped count and the total count
    /// of every order. When the hypothesis has no n-grams of an order, the
    /// constant itself is used as that order's precision.
    AddEpsilon(f64),
    /// An order without matches receives precision `1 / 2^k`, where `k` is the
    /// number of consecutive match-less orders seen so far.
    ExponentialDecay,
}
55
56#[derive(Debug, Clone)]
58pub struct BleuConfig {
59 pub max_n: usize,
61 pub weights: Option<Vec<f64>>,
63 pub smoothing: SmoothingMethod,
65}
66
67impl Default for BleuConfig {
68 fn default() -> Self {
69 Self {
70 max_n: 4,
71 weights: None,
72 smoothing: SmoothingMethod::None,
73 }
74 }
75}
76
/// Counts every contiguous n-gram of order `n` in `tokens`.
///
/// Returns a map from each distinct n-gram (as a vector of borrowed tokens) to
/// the number of times it occurs. The map is empty when `n` is zero or when
/// `tokens` holds fewer than `n` tokens.
fn extract_ngrams<'a>(tokens: &'a [&str], n: usize) -> HashMap<Vec<&'a str>, usize> {
    let mut counts: HashMap<Vec<&'a str>, usize> = HashMap::new();

    // A zero-length n-gram is meaningless (and `windows(0)` panics), so treat
    // it as "no n-grams" instead of counting empty slices.
    if n == 0 {
        return counts;
    }

    // `windows` yields nothing when the slice is shorter than `n`.
    for window in tokens.windows(n) {
        *counts.entry(window.to_vec()).or_insert(0) += 1;
    }
    counts
}
88
89fn modified_precision(hypothesis: &[&str], references: &[Vec<&str>], n: usize) -> (usize, usize) {
96 let hyp_ngrams = extract_ngrams(hypothesis, n);
97
98 if hyp_ngrams.is_empty() {
99 return (0, 0);
100 }
101
102 let mut max_ref_counts: HashMap<Vec<&str>, usize> = HashMap::new();
104 for reference in references {
105 let ref_ngrams = extract_ngrams(reference, n);
106 for (ngram, count) in &ref_ngrams {
107 let entry = max_ref_counts.entry(ngram.clone()).or_insert(0);
108 if *count > *entry {
109 *entry = *count;
110 }
111 }
112 }
113
114 let mut clipped_count = 0usize;
116 let mut total_count = 0usize;
117
118 for (ngram, hyp_count) in &hyp_ngrams {
119 let max_ref = max_ref_counts.get(ngram).copied().unwrap_or(0);
120 clipped_count += (*hyp_count).min(max_ref);
121 total_count += *hyp_count;
122 }
123
124 (clipped_count, total_count)
125}
126
/// Picks the reference length closest to `hyp_len`.
///
/// When two references are equally distant, the shorter one wins. Returns `0`
/// when `references` is empty.
fn closest_ref_length(hyp_len: usize, references: &[Vec<&str>]) -> usize {
    references
        .iter()
        .map(|reference| reference.len())
        // Order primarily by distance to the hypothesis, then by length so
        // that ties resolve to the shorter reference.
        .min_by_key(|&len| (len.abs_diff(hyp_len), len))
        .unwrap_or(0)
}
146
/// Computes the BLEU brevity penalty.
///
/// Returns `0.0` for an empty hypothesis, `1.0` when the hypothesis is at
/// least as long as the reference, and `exp(1 - ref_len / hyp_len)` otherwise.
fn brevity_penalty(hyp_len: usize, ref_len: usize) -> f64 {
    match hyp_len {
        0 => 0.0,
        _ => {
            let length_ratio = ref_len as f64 / hyp_len as f64;
            if length_ratio <= 1.0 {
                1.0
            } else {
                (1.0 - length_ratio).exp()
            }
        }
    }
}
161
162pub fn corpus_bleu(
182 hypotheses: &[Vec<&str>],
183 references: &[Vec<Vec<&str>>],
184 max_n: usize,
185) -> Result<f64> {
186 if hypotheses.is_empty() {
187 return Err(TextError::InvalidInput(
188 "Hypotheses list must not be empty".to_string(),
189 ));
190 }
191 if hypotheses.len() != references.len() {
192 return Err(TextError::InvalidInput(format!(
193 "Number of hypotheses ({}) must match number of reference sets ({})",
194 hypotheses.len(),
195 references.len()
196 )));
197 }
198 if max_n == 0 {
199 return Err(TextError::InvalidInput(
200 "max_n must be at least 1".to_string(),
201 ));
202 }
203
204 for (i, refs) in references.iter().enumerate() {
206 if refs.is_empty() {
207 return Err(TextError::InvalidInput(format!(
208 "Reference set at index {} must not be empty",
209 i
210 )));
211 }
212 }
213
214 let weights: Vec<f64> = vec![1.0 / max_n as f64; max_n];
215
216 let mut total_clipped = vec![0usize; max_n];
218 let mut total_count = vec![0usize; max_n];
219 let mut total_hyp_len = 0usize;
220 let mut total_ref_len = 0usize;
221
222 for (hyp, refs) in hypotheses.iter().zip(references.iter()) {
223 total_hyp_len += hyp.len();
224 total_ref_len += closest_ref_length(hyp.len(), refs);
225
226 for n in 1..=max_n {
227 let (clipped, count) = modified_precision(hyp, refs, n);
228 total_clipped[n - 1] += clipped;
229 total_count[n - 1] += count;
230 }
231 }
232
233 let mut log_avg = 0.0f64;
235 for n in 0..max_n {
236 if total_count[n] == 0 || total_clipped[n] == 0 {
237 return Ok(0.0);
239 }
240 let precision = total_clipped[n] as f64 / total_count[n] as f64;
241 log_avg += weights[n] * precision.ln();
242 }
243
244 let bp = brevity_penalty(total_hyp_len, total_ref_len);
245 Ok(bp * log_avg.exp())
246}
247
248pub fn sentence_bleu(
265 hypothesis: &[&str],
266 references: &[Vec<&str>],
267 max_n: usize,
268 smoothing: SmoothingMethod,
269) -> Result<f64> {
270 if references.is_empty() {
271 return Err(TextError::InvalidInput(
272 "References must not be empty".to_string(),
273 ));
274 }
275 if max_n == 0 {
276 return Err(TextError::InvalidInput(
277 "max_n must be at least 1".to_string(),
278 ));
279 }
280
281 if hypothesis.is_empty() {
282 return Ok(0.0);
283 }
284
285 let weights: Vec<f64> = vec![1.0 / max_n as f64; max_n];
286 let ref_len = closest_ref_length(hypothesis.len(), references);
287 let bp = brevity_penalty(hypothesis.len(), ref_len);
288
289 let mut log_avg = 0.0f64;
290 let mut consecutive_zeros = 0u32;
291
292 for n in 1..=max_n {
293 let (clipped, count) = modified_precision(hypothesis, references, n);
294
295 let precision = match smoothing {
296 SmoothingMethod::None => {
297 if count == 0 || clipped == 0 {
298 return Ok(0.0);
299 }
300 clipped as f64 / count as f64
301 }
302 SmoothingMethod::AddEpsilon(eps) => {
303 if count == 0 {
304 eps
305 } else {
306 (clipped as f64 + eps) / (count as f64 + eps)
307 }
308 }
309 SmoothingMethod::ExponentialDecay => {
310 if count == 0 || clipped == 0 {
311 consecutive_zeros += 1;
312 1.0 / 2.0f64.powi(consecutive_zeros as i32)
313 } else {
314 consecutive_zeros = 0;
315 clipped as f64 / count as f64
316 }
317 }
318 };
319
320 if precision <= 0.0 {
321 return Ok(0.0);
322 }
323 log_avg += weights[n - 1] * precision.ln();
324 }
325
326 Ok(bp * log_avg.exp())
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing a score with an exact expected value.
    const TOLERANCE: f64 = 1e-9;

    /// A hypothesis identical to its reference reaches the maximum score.
    #[test]
    fn test_perfect_translation() {
        let tokens = ["the", "cat", "is", "on", "the", "mat"];
        let refs = [tokens.to_vec()];
        let score = sentence_bleu(&tokens, &refs, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            (score - 1.0).abs() < TOLERANCE,
            "Perfect translation should score 1.0, got {}",
            score
        );
    }

    /// Disjoint vocabularies score zero without smoothing.
    #[test]
    fn test_no_overlap() {
        let hypothesis = ["a", "b", "c", "d"];
        let refs = [vec!["e", "f", "g", "h"]];
        let score = sentence_bleu(&hypothesis, &refs, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            score.abs() < TOLERANCE,
            "No overlap should score 0.0, got {}",
            score
        );
    }

    /// A hypothesis shorter than the reference is penalized but not zeroed.
    #[test]
    fn test_brevity_penalty_applied() {
        let hypothesis = ["the", "cat"];
        let refs = [vec!["the", "cat", "is", "on", "the", "mat"]];
        let score = sentence_bleu(&hypothesis, &refs, 1, SmoothingMethod::AddEpsilon(0.1))
            .expect("should compute");
        assert!(score < 1.0, "BP should reduce score for short hyp");
        assert!(score > 0.0, "Score should be positive with partial match");
    }

    /// Matching any one of several references is enough for a perfect score.
    #[test]
    fn test_multiple_references() {
        let hypothesis = ["the", "cat", "sat", "on", "the", "mat"];
        let refs = [
            vec!["the", "cat", "is", "on", "the", "mat"],
            vec!["the", "cat", "sat", "on", "the", "mat"],
        ];
        let score = sentence_bleu(&hypothesis, &refs, 4, SmoothingMethod::None)
            .expect("should compute");
        assert!(
            (score - 1.0).abs() < TOLERANCE,
            "Should match second reference perfectly, got {}",
            score
        );
    }

    /// A corpus whose hypotheses equal their references scores 1.0.
    #[test]
    fn test_corpus_bleu_basic() {
        let hypotheses = vec![
            vec!["the", "cat", "is", "on", "the", "mat"],
            vec!["there", "is", "a", "cat", "on", "the", "mat"],
        ];
        // Each hypothesis gets a single reference identical to itself.
        let references: Vec<Vec<Vec<&str>>> =
            hypotheses.iter().map(|hyp| vec![hyp.clone()]).collect();
        let score = corpus_bleu(&hypotheses, &references, 4).expect("should compute");
        assert!(
            (score - 1.0).abs() < TOLERANCE,
            "Perfect corpus should score 1.0, got {}",
            score
        );
    }

    /// An empty corpus is rejected.
    #[test]
    fn test_corpus_bleu_empty_fails() {
        assert!(corpus_bleu(&[], &[], 4).is_err());
    }

    /// Exponential-decay smoothing rescues a score that is zero unsmoothed.
    #[test]
    fn test_smoothing_exponential_decay() {
        let hypothesis = ["the", "cat"];
        let refs = [vec!["the", "cat", "sat"]];
        let score_with = |method| {
            sentence_bleu(&hypothesis, &refs, 4, method).expect("should compute")
        };

        let score_none = score_with(SmoothingMethod::None);
        let score_smooth = score_with(SmoothingMethod::ExponentialDecay);

        assert!(
            score_none.abs() < TOLERANCE,
            "No smoothing should give 0 with missing n-grams"
        );
        assert!(
            score_smooth > 0.0,
            "Exponential decay smoothing should give positive score"
        );
    }

    /// One substituted token yields a score strictly between 0 and 1.
    #[test]
    fn test_partial_overlap() {
        let hypothesis = ["the", "cat", "sat", "on", "the", "mat"];
        let refs = [vec!["the", "cat", "is", "on", "the", "mat"]];
        let score = sentence_bleu(&hypothesis, &refs, 4, SmoothingMethod::AddEpsilon(0.1))
            .expect("should compute");
        assert!(0.0 < score && score < 1.0, "Partial overlap: got {}", score);
    }
}