use crate::error::{SeqError, SeqResult};
/// Precision, recall and F1 returned by [`bert_score`] and [`bert_score_idf`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BertScore {
    /// Weighted mean, over candidate tokens, of each token's best cosine
    /// similarity to any reference token (baseline-rescaled when configured).
    pub precision: f64,
    /// Weighted mean, over reference tokens, of each token's best cosine
    /// similarity to any candidate token (baseline-rescaled when configured).
    pub recall: f64,
    /// F1 score combining precision and recall as a harmonic mean; see
    /// `bert_score_idf` for how baseline rescaling interacts with it.
    pub f1: f64,
}
/// Options shared by [`bert_score`] and [`bert_score_idf`].
#[derive(Clone, Debug, Default)]
pub struct BertScoreConfig {
    /// Optional baseline `b` for the linear rescaling `(s - b) / (1 - b)`.
    ///
    /// `None` (the `Default`) leaves scores unrescaled. A set value must be
    /// finite and lie strictly inside `(-1, 1)` to pass validation.
    pub baseline: Option<f64>,
}
impl BertScoreConfig {
    /// Checks that the configuration can be used for scoring.
    ///
    /// # Errors
    ///
    /// Returns `SeqError::InvalidParameter` named `"baseline"` when a
    /// baseline is set but is not a finite number strictly between `-1` and
    /// `1`. The upper bound keeps the `1 - b` divisor of the rescaling away
    /// from zero.
    pub fn validate(&self) -> SeqResult<()> {
        match self.baseline {
            // NaN and infinities fail `is_finite`; the open interval excludes
            // both endpoints.
            Some(b) if !(b.is_finite() && b > -1.0 && b < 1.0) => {
                Err(SeqError::InvalidParameter {
                    name: "baseline".into(),
                    value: b,
                })
            }
            _ => Ok(()),
        }
    }

    /// Maps `score` to `(score - b) / (1 - b)` when a baseline `b` is set,
    /// and returns it unchanged otherwise.
    fn rescale(&self, score: f64) -> f64 {
        self.baseline
            .map_or(score, |b| (score - b) / (1.0 - b))
    }
}
/// Euclidean (L2) norm of `v`: the square root of the sum of its squared
/// components.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Cosine similarity of `a` and `b`, given their precomputed L2 norms `na`
/// and `nb`.
///
/// A zero norm on either side yields `0.0` instead of dividing by zero. The
/// quotient is clamped to `[-1, 1]` so rounding cannot push it outside the
/// valid cosine range. Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` entries enter the dot product.
fn cosine(a: &[f64], na: f64, b: &[f64], nb: f64) -> f64 {
    if na != 0.0 && nb != 0.0 {
        let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
        (dot / (na * nb)).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
/// Computes BERTScore with uniform token weights.
///
/// Equivalent to [`bert_score_idf`] with every IDF weight set to `1.0`, so
/// precision and recall are plain means of the per-token best matches.
/// `candidate` holds `n` embeddings and `reference` holds `m` embeddings,
/// each `dim` consecutive values long.
///
/// # Errors
///
/// Same conditions as [`bert_score_idf`]: invalid baseline, zero `n`/`m`/
/// `dim`, or embedding slices whose lengths do not match `rows * dim`.
pub fn bert_score(
    candidate: &[f64],
    n: usize,
    reference: &[f64],
    m: usize,
    dim: usize,
    config: &BertScoreConfig,
) -> SeqResult<BertScore> {
    // Size the uniform weights by the data actually supplied, not by the
    // caller-provided counts. A bogus `n`/`m` (e.g. `usize::MAX`) would
    // otherwise make `vec!` panic on allocation before `bert_score_idf` could
    // report the shape error. Whenever the shapes are valid,
    // `n * dim == candidate.len()` with `dim >= 1`, so `n <= candidate.len()`
    // and the vectors have exactly `n` / `m` entries.
    let cand_idf = vec![1.0; n.min(candidate.len())];
    let ref_idf = vec![1.0; m.min(reference.len())];
    bert_score_idf(candidate, n, reference, m, dim, &cand_idf, &ref_idf, config)
}
#[allow(clippy::too_many_arguments)]
pub fn bert_score_idf(
candidate: &[f64],
n: usize,
reference: &[f64],
m: usize,
dim: usize,
cand_idf: &[f64],
ref_idf: &[f64],
config: &BertScoreConfig,
) -> SeqResult<BertScore> {
config.validate()?;
if n == 0 || m == 0 || dim == 0 {
return Err(SeqError::EmptyInput);
}
if candidate.len() != n * dim {
return Err(SeqError::ShapeMismatch {
expected: n * dim,
got: candidate.len(),
});
}
if reference.len() != m * dim {
return Err(SeqError::ShapeMismatch {
expected: m * dim,
got: reference.len(),
});
}
if cand_idf.len() != n {
return Err(SeqError::LengthMismatch {
a: cand_idf.len(),
b: n,
});
}
if ref_idf.len() != m {
return Err(SeqError::LengthMismatch {
a: ref_idf.len(),
b: m,
});
}
let mut sum_cand_idf = 0.0;
for (idx, &w) in cand_idf.iter().enumerate() {
if !(w.is_finite() && w >= 0.0) {
return Err(SeqError::InvalidParameter {
name: format!("cand_idf[{idx}]"),
value: w,
});
}
sum_cand_idf += w;
}
let mut sum_ref_idf = 0.0;
for (idx, &w) in ref_idf.iter().enumerate() {
if !(w.is_finite() && w >= 0.0) {
return Err(SeqError::InvalidParameter {
name: format!("ref_idf[{idx}]"),
value: w,
});
}
sum_ref_idf += w;
}
if sum_cand_idf <= 0.0 || sum_ref_idf <= 0.0 {
return Err(SeqError::NumericalInstability(
"IDF weights sum to zero on one side".into(),
));
}
let cand_norms: Vec<f64> = (0..n)
.map(|i| l2_norm(&candidate[i * dim..(i + 1) * dim]))
.collect();
let ref_norms: Vec<f64> = (0..m)
.map(|j| l2_norm(&reference[j * dim..(j + 1) * dim]))
.collect();
let mut row_max = vec![f64::NEG_INFINITY; n]; let mut col_max = vec![f64::NEG_INFINITY; m]; for i in 0..n {
let ci = &candidate[i * dim..(i + 1) * dim];
let ni = cand_norms[i];
for j in 0..m {
let rj = &reference[j * dim..(j + 1) * dim];
let sim = cosine(ci, ni, rj, ref_norms[j]);
if sim > row_max[i] {
row_max[i] = sim;
}
if sim > col_max[j] {
col_max[j] = sim;
}
}
}
let mut precision = 0.0;
for i in 0..n {
precision += cand_idf[i] * row_max[i];
}
precision /= sum_cand_idf;
let mut recall = 0.0;
for j in 0..m {
recall += ref_idf[j] * col_max[j];
}
recall /= sum_ref_idf;
precision = config.rescale(precision);
recall = config.rescale(recall);
let f1 = if precision + recall <= 0.0 {
0.0
} else {
2.0 * precision * recall / (precision + recall)
};
Ok(BertScore {
precision,
recall,
f1,
})
}
/// Computes smoothed inverse document frequencies for a corpus of token ids.
///
/// Each document is a sequence of token ids in `0..vocab_size`. A token's
/// document frequency `df` is the number of documents containing it at
/// least once; repeats within one document count once. The result has one
/// entry per vocabulary id:
///
/// `idf[t] = ln((1 + N) / (1 + df[t])) + 1`, where `N` is the number of
/// documents. Since `df[t] <= N`, every weight is at least `1`, and tokens
/// absent from the corpus receive the largest weight.
///
/// # Errors
///
/// * `SeqError::EmptyInput` if `vocab_size` is zero or `documents` is empty.
/// * `SeqError::IndexOutOfBounds` for the first token id (in document order)
///   that is `>= vocab_size`.
pub fn corpus_idf(documents: &[Vec<usize>], vocab_size: usize) -> SeqResult<Vec<f64>> {
    if documents.is_empty() || vocab_size == 0 {
        return Err(SeqError::EmptyInput);
    }

    let mut doc_freq = vec![0.0f64; vocab_size];
    // `last_doc[t]` holds the index of the latest document that already
    // counted token `t`, so repeats inside one document are counted once.
    let mut last_doc: Vec<Option<usize>> = vec![None; vocab_size];

    for (doc_idx, doc) in documents.iter().enumerate() {
        if let Some(&bad) = doc.iter().find(|&&t| t >= vocab_size) {
            return Err(SeqError::IndexOutOfBounds {
                index: bad,
                len: vocab_size,
            });
        }
        for &t in doc {
            if last_doc[t] != Some(doc_idx) {
                last_doc[t] = Some(doc_idx);
                doc_freq[t] += 1.0;
            }
        }
    }

    let total_docs = documents.len() as f64;
    Ok(doc_freq
        .into_iter()
        .map(|d| ((1.0 + total_docs) / (1.0 + d)).ln() + 1.0)
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identical_scores_one() {
        let dim = 3;
        // Three orthonormal tokens compared against themselves.
        let emb = vec![
            1.0, 0.0, 0.0, // token 0
            0.0, 1.0, 0.0, // token 1
            0.0, 0.0, 1.0, // token 2
        ];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&emb, 3, &emb, 3, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 1.0).abs() < 1e-12, "R = {}", s.recall);
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    #[test]
    fn orthogonal_scores_zero() {
        let dim = 2;
        let cand = vec![1.0, 0.0]; // x axis
        let reference = vec![0.0, 1.0]; // y axis
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!(s.precision.abs() < 1e-12);
        assert!(s.recall.abs() < 1e-12);
        assert!(s.f1.abs() < 1e-12);
    }

    #[test]
    fn greedy_matching_asymmetric() {
        let dim = 2;
        // The single candidate token matches reference token 0 exactly.
        // Reference token 1 is orthogonal, so precision is 1 but recall is 0.5.
        let cand = vec![1.0, 0.0];
        let reference = vec![1.0, 0.0, 0.0, 1.0];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 2, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 0.5).abs() < 1e-12, "R = {}", s.recall);
        assert!((s.f1 - (2.0 * 0.5 / 1.5)).abs() < 1e-12, "F1 = {}", s.f1);
    }

    #[test]
    fn scale_invariance() {
        let dim = 3;
        let cand = vec![2.0, 0.0, 0.0];
        let reference = vec![5.0, 0.0, 0.0];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    #[test]
    fn baseline_rescaling() {
        let dim = 2;
        // Unit vectors 60 degrees apart, so the cosine similarity is 0.5.
        let cand = vec![1.0, 0.0];
        let reference = vec![0.5, 3.0f64.sqrt() / 2.0];
        let cfg_raw = BertScoreConfig::default();
        let raw = bert_score(&cand, 1, &reference, 1, dim, &cfg_raw).expect("raw");
        assert!((raw.f1 - 0.5).abs() < 1e-9, "raw f1 = {}", raw.f1);

        let b = 0.25;
        let cfg = BertScoreConfig { baseline: Some(b) };
        let rescaled = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("rescaled");
        let expected = (0.5 - b) / (1.0 - b);
        assert!(
            (rescaled.precision - expected).abs() < 1e-9,
            "P = {}",
            rescaled.precision
        );
        assert!((rescaled.f1 - expected).abs() < 1e-9, "F1 = {}", rescaled.f1);
    }

    #[test]
    fn idf_weighting() {
        let dim = 2;
        let cand = vec![1.0, 0.0];
        let reference = vec![1.0, 0.0, 0.0, 1.0];
        let cfg = BertScoreConfig::default();
        let cand_idf = vec![1.0];

        // Up-weighting the matched reference token raises recall.
        let ref_idf_high_x = vec![10.0, 1.0];
        let s_high = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_x,
            &cfg,
        )
        .expect("score");
        assert!(
            (s_high.recall - 10.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_high.recall
        );

        // Up-weighting the unmatched reference token lowers it.
        let ref_idf_high_y = vec![1.0, 10.0];
        let s_low = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_y,
            &cfg,
        )
        .expect("score");
        assert!(
            (s_low.recall - 1.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_low.recall
        );
        assert!(s_high.recall > s_low.recall);
    }

    #[test]
    fn zero_vector_has_zero_similarity() {
        // A zero-norm token must not divide by zero; it scores 0 against everything.
        let cfg = BertScoreConfig::default();
        let s = bert_score(&[0.0, 0.0], 1, &[1.0, 0.0], 1, 2, &cfg).expect("score");
        assert_eq!(s.precision, 0.0);
        assert_eq!(s.recall, 0.0);
        assert_eq!(s.f1, 0.0);
    }

    #[test]
    fn baseline_must_be_finite_and_inside_open_unit_interval() {
        for bad in [-1.0, 1.0, 2.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let cfg = BertScoreConfig { baseline: Some(bad) };
            assert!(cfg.validate().is_err(), "baseline {bad} accepted");
        }
        for good in [-0.99, 0.0, 0.5, 0.99] {
            let cfg = BertScoreConfig {
                baseline: Some(good),
            };
            assert!(cfg.validate().is_ok(), "baseline {good} rejected");
        }
        assert!(BertScoreConfig::default().validate().is_ok());

        // The scoring entry point also rejects an invalid configuration.
        let bad = BertScoreConfig {
            baseline: Some(1.0),
        };
        assert!(bert_score(&[1.0], 1, &[1.0], 1, 1, &bad).is_err());
    }

    #[test]
    fn idf_and_shape_error_variants() {
        let cfg = BertScoreConfig::default();
        let cand = [1.0, 0.0];
        let reference = [1.0, 0.0];

        // All-zero weights on one side cannot be normalised.
        assert!(matches!(
            bert_score_idf(&cand, 1, &reference, 1, 2, &[0.0], &[1.0], &cfg),
            Err(SeqError::NumericalInstability(_))
        ));
        // Non-finite weights are rejected before use.
        assert!(matches!(
            bert_score_idf(&cand, 1, &reference, 1, 2, &[1.0], &[f64::NAN], &cfg),
            Err(SeqError::InvalidParameter { .. })
        ));
        // A reference slice that is not `m * dim` long is reported precisely.
        assert!(matches!(
            bert_score(&cand, 1, &[1.0, 0.0, 0.0], 1, 2, &cfg),
            Err(SeqError::ShapeMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn corpus_idf_orders_by_rarity() {
        // Token 0 appears in every document (twice in the first, counted
        // once), token 1 in one document and token 2 in none.
        let docs = vec![vec![0usize, 0, 1], vec![0usize], vec![0usize]];
        let idf = corpus_idf(&docs, 3).expect("idf");
        assert_eq!(idf.len(), 3);
        assert!(idf[0] < idf[1], "{} !< {}", idf[0], idf[1]);
        assert!(idf[1] < idf[2], "{} !< {}", idf[1], idf[2]);
        for &w in &idf {
            assert!(w > 0.0, "idf {w} not positive");
        }
        assert!((idf[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn validation_errors() {
        let cfg = BertScoreConfig::default();
        assert!(bert_score(&[], 0, &[1.0], 1, 1, &cfg).is_err());
        assert!(bert_score(&[1.0, 2.0, 3.0], 2, &[1.0, 2.0], 1, 2, &cfg).is_err());
        let bad = BertScoreConfig {
            baseline: Some(1.0),
        };
        assert!(bad.validate().is_err());
        assert!(
            bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[1.0, 1.0], &[1.0], &cfg).is_err()
        );
        assert!(bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[-1.0], &[1.0], &cfg).is_err());
        assert!(corpus_idf(&[vec![5usize]], 3).is_err());
        assert!(corpus_idf(&[], 3).is_err());
        assert!(corpus_idf(&[vec![0usize]], 0).is_err());
    }
}