use std::collections::HashMap;
use crate::rouge::{tokenize, TokenSeq};
/// Smoothing strategy applied to the per-order n-gram precisions.
///
/// Without smoothing, a single n-gram order with no matches drives the BLEU score to 0,
/// which happens often for short sentences.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SmoothingMethod {
    /// No smoothing: precision is `matches / total`, so any order without matches makes
    /// the score 0. This is the default.
    #[default]
    None,
    /// Adds one to both the match count and the n-gram total of every order that has at
    /// least one candidate n-gram.
    AddOne,
    /// Orders with zero matches get a small precision proportional to `1 / 2^k`, where `k`
    /// grows with each zero-match order.
    ExpDecay,
}
/// Result of a BLEU computation together with the components it was built from.
#[derive(Clone, Debug)]
pub struct BleuScore {
    /// Final score: brevity penalty times the geometric mean of `precisions`.
    pub bleu: f32,
    /// Smoothed clipped precision for each order `1..=max_n` (index 0 holds unigrams).
    pub precisions: Vec<f32>,
    /// Brevity penalty multiplied into the geometric mean.
    pub brevity_penalty: f32,
    /// Candidate length divided by the effective reference length (0 when that length is 0).
    pub length_ratio: f32,
}

impl BleuScore {
    /// Score returned for degenerate inputs: every scalar is 0 and `precisions` holds
    /// `max_n` zeros.
    fn zero(max_n: usize) -> Self {
        let precisions = vec![0.0_f32; max_n];
        Self {
            precisions,
            bleu: 0.0,
            brevity_penalty: 0.0,
            length_ratio: 0.0,
        }
    }
}
/// Parameters controlling a BLEU computation.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate obtains one through
/// [`BleuConfig::new`] or [`Default`] instead of a struct literal.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BleuConfig {
    /// Highest n-gram order considered; orders `1..=max_n` are scored.
    pub max_n: usize,
    /// Smoothing applied to the per-order precisions.
    pub smoothing: SmoothingMethod,
}

impl Default for BleuConfig {
    /// BLEU-4 without smoothing.
    fn default() -> Self {
        Self::new(4, SmoothingMethod::None)
    }
}

impl BleuConfig {
    /// Builds a configuration. A `max_n` of 0 is raised to 1 so that at least unigrams
    /// are scored.
    pub fn new(max_n: usize, smoothing: SmoothingMethod) -> Self {
        let max_n = if max_n == 0 { 1 } else { max_n };
        Self { max_n, smoothing }
    }
}
/// Sentence-level BLEU on raw strings.
///
/// The candidate and every reference are tokenized with [`tokenize`] and scored by
/// [`sentence_bleu_tokens`].
pub fn sentence_bleu(candidate: &str, references: &[&str], cfg: &BleuConfig) -> BleuScore {
    let cand_tokens = tokenize(candidate);
    let mut ref_tokens: Vec<TokenSeq> = Vec::with_capacity(references.len());
    for &reference in references {
        ref_tokens.push(tokenize(reference));
    }
    sentence_bleu_tokens(&cand_tokens, &ref_tokens, cfg)
}
/// Sentence-level BLEU for a pre-tokenized candidate against one or more references.
///
/// For each order `n` in `1..=max_n`:
/// - the clipped n-gram precision is computed, with each n-gram's count clipped by its
///   maximum count in any single reference;
/// - that precision is passed through `cfg.smoothing`.
///
/// The score is the brevity penalty times the geometric mean of those precisions. The
/// reference length used for the penalty is that of the reference closest in length to
/// the candidate (the shorter one on ties).
///
/// Returns an all-zero score when the candidate or the reference list is empty. `bleu` is
/// `0.0` as soon as any order has no candidate n-grams or a non-positive (smoothed)
/// precision. A `max_n` below 1 is treated as 1, matching [`BleuConfig::new`].
pub fn sentence_bleu_tokens(
    candidate: &TokenSeq,
    references: &[TokenSeq],
    cfg: &BleuConfig,
) -> BleuScore {
    // `BleuConfig`'s fields are public, so `max_n` can be 0 despite `new` clamping it.
    // With 0 orders the geometric mean below would be `exp(0 / 0) = NaN`.
    let max_n = cfg.max_n.max(1);
    if candidate.is_empty() || references.is_empty() {
        return BleuScore::zero(max_n);
    }
    let c_len = candidate.len();
    let r_len = closest_ref_length(c_len, references);

    let mut precisions = Vec::with_capacity(max_n);
    let mut log_precision_sum = 0.0f64;
    // Shared across orders so exponential-decay smoothing can count zero-match orders.
    let mut zero_streak = 0usize;
    let mut collapsed = false;
    for n in 1..=max_n {
        let (matches, total) = match_counts_sentence(candidate, references, n);
        let (p_n, used_total) =
            apply_smoothing(matches, total, cfg.smoothing, c_len, &mut zero_streak);
        precisions.push(p_n);
        if used_total == 0 || p_n <= 0.0 {
            // log(0) = -inf: the geometric mean, and therefore BLEU, is 0.
            collapsed = true;
        } else if !collapsed {
            log_precision_sum += f64::from(p_n).ln();
        }
    }

    let bp = brevity_penalty(c_len, r_len);
    let length_ratio = if r_len == 0 {
        0.0
    } else {
        c_len as f32 / r_len as f32
    };
    let bleu = if collapsed {
        0.0
    } else {
        let geo_mean = (log_precision_sum / max_n as f64).exp();
        (f64::from(bp) * geo_mean) as f32
    };
    BleuScore {
        bleu,
        precisions,
        brevity_penalty: bp,
        length_ratio,
    }
}
/// Corpus-level BLEU on raw strings.
///
/// `candidates[i]` is scored against the reference set `references[i]`. Everything is
/// tokenized with [`tokenize`] and then scored by [`corpus_bleu_tokens`].
pub fn corpus_bleu(candidates: &[&str], references: &[Vec<&str>], cfg: &BleuConfig) -> BleuScore {
    let mut cand_tokens: Vec<TokenSeq> = Vec::with_capacity(candidates.len());
    for &candidate in candidates {
        cand_tokens.push(tokenize(candidate));
    }
    let ref_tokens = references
        .iter()
        .map(|group| group.iter().map(|&r| tokenize(r)).collect::<Vec<TokenSeq>>())
        .collect::<Vec<Vec<TokenSeq>>>();
    corpus_bleu_tokens(&cand_tokens, &ref_tokens, cfg)
}
/// Corpus-level BLEU over pre-tokenized candidates.
///
/// `candidates[i]` is scored against the reference set `references[i]`. The following are
/// summed over the whole corpus before precisions, smoothing and the brevity penalty are
/// applied once:
/// - clipped n-gram matches;
/// - candidate n-gram totals;
/// - candidate lengths;
/// - closest reference lengths.
///
/// The result is therefore not an average of sentence scores.
///
/// * Pairs beyond the shorter of `candidates` / `references` are ignored.
/// * Pairs whose reference set is empty are skipped entirely.
/// * An empty candidate with a non-empty reference set still adds its closest reference
///   length. Empty outputs are thus penalised through the brevity penalty instead of
///   silently disappearing from the corpus statistics.
/// * A `max_n` below 1 is treated as 1, matching [`BleuConfig::new`].
///
/// Returns an all-zero score when there are no candidates, when all candidates are empty,
/// or when no scored pair contributes any candidate tokens.
pub fn corpus_bleu_tokens(
    candidates: &[TokenSeq],
    references: &[Vec<TokenSeq>],
    cfg: &BleuConfig,
) -> BleuScore {
    // `BleuConfig`'s fields are public, so `max_n` can be 0 despite `new` clamping it.
    // With 0 orders the geometric mean below would be `exp(0 / 0) = NaN`.
    let max_n = cfg.max_n.max(1);
    // `all` is true for an empty slice, so this also covers "no candidates".
    if candidates.iter().all(|c| c.is_empty()) {
        return BleuScore::zero(max_n);
    }

    let mut total_c_len = 0usize;
    let mut total_r_len = 0usize;
    let mut match_by_n = vec![0u64; max_n];
    let mut total_by_n = vec![0u64; max_n];
    for (cand, refs) in candidates.iter().zip(references) {
        if refs.is_empty() {
            // Nothing to compare against: the pair contributes no statistics.
            continue;
        }
        // Accumulate lengths even for an empty candidate so its reference length still
        // counts towards the brevity penalty.
        total_c_len += cand.len();
        total_r_len += closest_ref_length(cand.len(), refs);
        if cand.is_empty() {
            // An empty candidate has no n-grams to match.
            continue;
        }
        for n in 1..=max_n {
            let (m, t) = match_counts_sentence(cand, refs, n);
            match_by_n[n - 1] += m as u64;
            total_by_n[n - 1] += t as u64;
        }
    }
    if total_c_len == 0 {
        return BleuScore::zero(max_n);
    }

    let mut precisions = Vec::with_capacity(max_n);
    let mut log_sum = 0.0f64;
    let mut collapsed = false;
    // Shared across orders so exponential-decay smoothing can count zero-match orders.
    let mut zero_streak = 0usize;
    for (&m, &t) in match_by_n.iter().zip(&total_by_n) {
        let (p_n, used_total) = apply_smoothing(
            m as usize,
            t as usize,
            cfg.smoothing,
            total_c_len,
            &mut zero_streak,
        );
        precisions.push(p_n);
        if used_total == 0 || p_n <= 0.0 {
            // log(0) = -inf: the geometric mean, and therefore BLEU, is 0.
            collapsed = true;
        } else if !collapsed {
            log_sum += f64::from(p_n).ln();
        }
    }

    let bp = brevity_penalty(total_c_len, total_r_len);
    let length_ratio = if total_r_len == 0 {
        0.0
    } else {
        total_c_len as f32 / total_r_len as f32
    };
    let bleu = if collapsed {
        0.0
    } else {
        (f64::from(bp) * (log_sum / max_n as f64).exp()) as f32
    };
    BleuScore {
        bleu,
        precisions,
        brevity_penalty: bp,
        length_ratio,
    }
}
/// Clipped n-gram match count for a single order `n`.
///
/// Each distinct candidate n-gram contributes `min(candidate count, max count in any single
/// reference)` matches. Returns `(matches, total)`, where `total` is the number of
/// candidate n-grams of order `n`. Returns `(0, 0)` when the candidate has none.
fn match_counts_sentence(cand: &TokenSeq, refs: &[TokenSeq], n: usize) -> (usize, usize) {
    let cand_counts = ngram_counts(cand, n);
    let total = cand_counts.values().sum::<usize>();
    if total == 0 {
        return (0, 0);
    }

    // Highest count of each n-gram within any one reference (the clipping ceiling).
    let mut ceiling: HashMap<Vec<String>, usize> = HashMap::new();
    for (gram, count) in refs.iter().flat_map(|r| ngram_counts(r, n)) {
        let slot = ceiling.entry(gram).or_default();
        *slot = (*slot).max(count);
    }

    let matches: usize = cand_counts
        .iter()
        .filter_map(|(gram, &c)| ceiling.get(gram).map(|&cap| c.min(cap)))
        .sum();
    (matches, total)
}
/// Counts every contiguous run of `n` tokens in `tokens`.
///
/// The map is empty when `n == 0` or when there are fewer than `n` tokens.
fn ngram_counts(tokens: &TokenSeq, n: usize) -> HashMap<Vec<String>, usize> {
    let mut tally = HashMap::new();
    // `windows(0)` panics, and the `len` check makes the "too short" case explicit.
    if n > 0 && tokens.len() >= n {
        for window in tokens.windows(n) {
            *tally.entry(window.to_vec()).or_default() += 1;
        }
    }
    tally
}
/// Length of the reference whose length is closest to `c_len`.
///
/// Ties are broken towards the shorter reference. Returns 0 when `refs` is empty.
fn closest_ref_length(c_len: usize, refs: &[TokenSeq]) -> usize {
    refs.iter()
        .map(|r| r.len())
        // Lexicographic key: smallest distance first, then the shorter length.
        .min_by_key(|&len| (len.abs_diff(c_len), len))
        .unwrap_or(0)
}
/// BLEU brevity penalty for candidate length `c_len` and reference length `r_len`.
///
/// The penalty is:
/// - `0` for an empty candidate;
/// - `1` when the candidate is longer than the reference;
/// - `exp(1 - r_len / c_len)` otherwise.
fn brevity_penalty(c_len: usize, r_len: usize) -> f32 {
    match c_len {
        0 => 0.0,
        c if c > r_len => 1.0,
        c => {
            let ratio = r_len as f64 / c as f64;
            (1.0f64 - ratio).exp() as f32
        }
    }
}
/// Turns the clipped counts for one n-gram order into a (possibly smoothed) precision.
///
/// * `matches` – clipped candidate n-grams of this order found in a reference.
/// * `total` – number of candidate n-grams of this order.
/// * `method` – smoothing strategy to apply.
/// * `_c_len` – candidate length. No strategy uses it; it is kept so that existing call
///   sites stay unchanged.
/// * `zero_streak` – state shared across the orders of a single score. It counts the
///   zero-match orders seen so far under [`SmoothingMethod::ExpDecay`].
///
/// Returns `(precision, denominator)`. A denominator of `0` means the order had no
/// candidate n-grams at all, and callers treat that precision as undefined.
fn apply_smoothing(
    matches: usize,
    total: usize,
    method: SmoothingMethod,
    _c_len: usize,
    zero_streak: &mut usize,
) -> (f32, usize) {
    // Without candidate n-grams of this order the precision is undefined for every method.
    if total == 0 {
        return (0.0, 0);
    }
    match method {
        SmoothingMethod::None => (matches as f32 / total as f32, total),
        // Add one to numerator and denominator (Lin & Och 2004). This also covers
        // `matches == 0`, which yields `1 / (total + 1)`.
        SmoothingMethod::AddOne => ((matches as f32 + 1.0) / (total as f32 + 1.0), total + 1),
        SmoothingMethod::ExpDecay => {
            if matches == 0 {
                // NIST geometric smoothing (mteval-v13a; Chen & Cherry 2014, method 3).
                // The k-th zero-match order gets precision 1 / (2^k * total), where `total`
                // is that order's candidate n-gram count. As in mteval, k is never reset
                // within one score.
                *zero_streak += 1;
                let denom = 2.0f32.powf(*zero_streak as f32) * total as f32;
                (1.0 / denom, total)
            } else {
                (matches as f32 / total as f32, total)
            }
        }
    }
}