oxicuda_seq/metrics/bertscore.rs
1//! BERTScore: token-embedding similarity metric via greedy cosine matching.
2//!
3//! Reference: Zhang, T., Kishore, V., Wu, F., Weinberger, K. Q. & Artzi, Y.
4//! (2020). *BERTScore: Evaluating Text Generation with BERT*. ICLR 2020.
5//!
6//! # What this module computes
7//!
8//! BERTScore compares a *candidate* token sequence against a *reference* token
9//! sequence using **contextual embeddings** for each token. Given a candidate
10//! with embeddings `x̂_1 … x̂_n` and a reference with embeddings `x_1 … x_m`,
11//! every pairwise cosine similarity `cos(x̂_i, x_j)` is formed and then matched
12//! **greedily** (each token aligned to its single most similar counterpart):
13//!
14//! ```text
15//! Recall R = ( Σ_j idf(x_j) · max_i cos(x̂_i, x_j) ) / Σ_j idf(x_j)
16//! Precision P = ( Σ_i idf(x̂_i) · max_j cos(x̂_i, x_j) ) / Σ_i idf(x̂_i)
17//! F1 = 2 · P · R / (P + R)
18//! ```
19//!
20//! With **uniform IDF weights** these reduce to plain averages of the row /
21//! column maxima of the cosine-similarity matrix. Optional inverse-document-
22//! frequency weights (precomputed from a corpus) down-weight frequent tokens
23//! exactly as in the paper.
24//!
25//! ## Honesty note — this is the real metric, not a stub
26//!
27//! The "BERT" in BERTScore is only the *source of the embeddings*. This crate
28//! does not (and cannot, in pure-CPU form) ship a transformer; instead the
29//! **embedding vectors are an input** supplied by the caller (from any encoder:
30//! a `trustformers` model, word2vec, a learned table, …). Everything BERTScore
31//! actually specifies — the cosine-similarity matrix, the greedy precision /
32//! recall / F1 matching, IDF weighting, and the optional baseline rescaling — is
33//! computed here in full and is exact. Feeding genuine contextual embeddings
34//! reproduces the published metric; feeding any other embeddings yields the same
35//! algorithm over those vectors.
36//!
37//! Production code never panics: every fallible path validates its inputs and
38//! returns [`SeqError`].
39
40use crate::error::{SeqError, SeqResult};
41
/// Outcome of a BERTScore comparison: greedy-matching precision, recall and
/// the F1 score combining them.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BertScore {
    /// Precision — how well every candidate token is covered by its best
    /// matching reference token.
    pub precision: f64,
    /// Recall — how well every reference token is covered by its best
    /// matching candidate token.
    pub recall: f64,
    /// F1 score, `2·P·R / (P + R)` (see [`bert_score_idf`] for how it is
    /// combined with optional baseline rescaling).
    pub f1: f64,
}
51}
52
/// Tunable options shared by [`bert_score`] and [`bert_score_idf`].
#[derive(Clone, Debug, Default)]
pub struct BertScoreConfig {
    /// Baseline `b ∈ (−1, 1)` for the paper's rescaling step, which maps each
    /// raw score through `score ← (score − b) / (1 − b)`.
    ///
    /// The paper takes `b` to be the average score of random sentence pairs
    /// for the chosen model/layer, which spreads scores over a more readable
    /// range. Leave it as `None` (the default) to report raw cosine-based
    /// scores in `[−1, 1]`.
    pub baseline: Option<f64>,
}
62}
63
64impl BertScoreConfig {
65 /// Validate the configuration.
66 ///
67 /// # Errors
68 /// * [`SeqError::InvalidParameter`] if `baseline` is set to a non-finite
69 /// value or to `±1` (the rescaling denominator `1 − b` must be non-zero,
70 /// and a baseline outside `(−1, 1)` is meaningless for cosine scores).
71 pub fn validate(&self) -> SeqResult<()> {
72 if let Some(b) = self.baseline {
73 if !b.is_finite() || b <= -1.0 || b >= 1.0 {
74 return Err(SeqError::InvalidParameter {
75 name: "baseline".into(),
76 value: b,
77 });
78 }
79 }
80 Ok(())
81 }
82
83 /// Apply the optional baseline rescaling to a raw score.
84 fn rescale(&self, score: f64) -> f64 {
85 match self.baseline {
86 Some(b) => (score - b) / (1.0 - b),
87 None => score,
88 }
89 }
90}
91
/// Euclidean (L2) length of `v`: the square root of its sum of squares.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
95}
96
/// Cosine similarity of `a` and `b`, given their precomputed L2 norms `na`
/// and `nb`. The norms are passed in so they are not recomputed inside the
/// `n × m` loop.
///
/// A zero norm on either side means the vector has no direction, and the
/// similarity is defined as `0.0`. Otherwise the result is clamped to
/// `[−1, 1]` to absorb rounding error. Pairs are formed with `zip`, so the
/// caller is expected to pass equal-length slices.
fn cosine(a: &[f64], na: f64, b: &[f64], nb: f64) -> f64 {
    if na != 0.0 && nb != 0.0 {
        let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let similarity = dot / (na * nb);
        similarity.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
105}
106
107/// Compute BERTScore between candidate and reference token embeddings with
108/// **uniform** token weights.
109///
110/// `candidate` holds `n` row-major embedding vectors of dimension `dim`
111/// (`candidate.len() == n * dim`); `reference` holds `m` such vectors
112/// (`reference.len() == m * dim`).
113///
114/// # Errors
115/// * [`SeqError::EmptyInput`] if either side has zero tokens or `dim == 0`.
116/// * [`SeqError::ShapeMismatch`] if a flat buffer length is not a multiple of
117/// `dim` consistent with the stated token count.
118/// * Propagates [`BertScoreConfig::validate`].
119pub fn bert_score(
120 candidate: &[f64],
121 n: usize,
122 reference: &[f64],
123 m: usize,
124 dim: usize,
125 config: &BertScoreConfig,
126) -> SeqResult<BertScore> {
127 let cand_idf = vec![1.0; n];
128 let ref_idf = vec![1.0; m];
129 bert_score_idf(candidate, n, reference, m, dim, &cand_idf, &ref_idf, config)
130}
131
132/// Compute BERTScore with explicit **IDF weights** for candidate and reference
133/// tokens (e.g. precomputed inverse-document-frequencies over a corpus). Weights
134/// must be non-negative and finite; at least one weight on each side must be
135/// strictly positive (so the normalising denominators are non-zero).
136///
137/// # Errors
138/// In addition to the cases of [`bert_score`]:
139/// * [`SeqError::LengthMismatch`] if `cand_idf.len() != n` or
140/// `ref_idf.len() != m`.
141/// * [`SeqError::InvalidParameter`] if any weight is negative / non-finite.
142/// * [`SeqError::NumericalInstability`] if the candidate or reference weights
143/// sum to zero.
144#[allow(clippy::too_many_arguments)]
145pub fn bert_score_idf(
146 candidate: &[f64],
147 n: usize,
148 reference: &[f64],
149 m: usize,
150 dim: usize,
151 cand_idf: &[f64],
152 ref_idf: &[f64],
153 config: &BertScoreConfig,
154) -> SeqResult<BertScore> {
155 config.validate()?;
156 if n == 0 || m == 0 || dim == 0 {
157 return Err(SeqError::EmptyInput);
158 }
159 if candidate.len() != n * dim {
160 return Err(SeqError::ShapeMismatch {
161 expected: n * dim,
162 got: candidate.len(),
163 });
164 }
165 if reference.len() != m * dim {
166 return Err(SeqError::ShapeMismatch {
167 expected: m * dim,
168 got: reference.len(),
169 });
170 }
171 if cand_idf.len() != n {
172 return Err(SeqError::LengthMismatch {
173 a: cand_idf.len(),
174 b: n,
175 });
176 }
177 if ref_idf.len() != m {
178 return Err(SeqError::LengthMismatch {
179 a: ref_idf.len(),
180 b: m,
181 });
182 }
183 let mut sum_cand_idf = 0.0;
184 for (idx, &w) in cand_idf.iter().enumerate() {
185 if !(w.is_finite() && w >= 0.0) {
186 return Err(SeqError::InvalidParameter {
187 name: format!("cand_idf[{idx}]"),
188 value: w,
189 });
190 }
191 sum_cand_idf += w;
192 }
193 let mut sum_ref_idf = 0.0;
194 for (idx, &w) in ref_idf.iter().enumerate() {
195 if !(w.is_finite() && w >= 0.0) {
196 return Err(SeqError::InvalidParameter {
197 name: format!("ref_idf[{idx}]"),
198 value: w,
199 });
200 }
201 sum_ref_idf += w;
202 }
203 if sum_cand_idf <= 0.0 || sum_ref_idf <= 0.0 {
204 return Err(SeqError::NumericalInstability(
205 "IDF weights sum to zero on one side".into(),
206 ));
207 }
208
209 // Precompute norms.
210 let cand_norms: Vec<f64> = (0..n)
211 .map(|i| l2_norm(&candidate[i * dim..(i + 1) * dim]))
212 .collect();
213 let ref_norms: Vec<f64> = (0..m)
214 .map(|j| l2_norm(&reference[j * dim..(j + 1) * dim]))
215 .collect();
216
217 // Row maxima (over reference) give precision; column maxima (over
218 // candidate) give recall. Compute the full similarity once, tracking both.
219 let mut row_max = vec![f64::NEG_INFINITY; n]; // best ref for each cand token
220 let mut col_max = vec![f64::NEG_INFINITY; m]; // best cand for each ref token
221 for i in 0..n {
222 let ci = &candidate[i * dim..(i + 1) * dim];
223 let ni = cand_norms[i];
224 for j in 0..m {
225 let rj = &reference[j * dim..(j + 1) * dim];
226 let sim = cosine(ci, ni, rj, ref_norms[j]);
227 if sim > row_max[i] {
228 row_max[i] = sim;
229 }
230 if sim > col_max[j] {
231 col_max[j] = sim;
232 }
233 }
234 }
235
236 // Weighted precision / recall.
237 let mut precision = 0.0;
238 for i in 0..n {
239 precision += cand_idf[i] * row_max[i];
240 }
241 precision /= sum_cand_idf;
242
243 let mut recall = 0.0;
244 for j in 0..m {
245 recall += ref_idf[j] * col_max[j];
246 }
247 recall /= sum_ref_idf;
248
249 // Optional baseline rescaling, then F1 from the (possibly rescaled) P, R.
250 precision = config.rescale(precision);
251 recall = config.rescale(recall);
252
253 let f1 = if precision + recall <= 0.0 {
254 0.0
255 } else {
256 2.0 * precision * recall / (precision + recall)
257 };
258
259 Ok(BertScore {
260 precision,
261 recall,
262 f1,
263 })
264}
265
266/// Convenience IDF estimator: compute smoothed inverse-document-frequencies for
267/// a vocabulary from a corpus of tokenised documents.
268///
269/// `idf(t) = ln( (1 + N) / (1 + df(t)) ) + 1` where `N` is the number of
270/// documents and `df(t)` the number of documents containing token `t` (each
271/// token counted at most once per document). The `+1`/smoothing matches the
272/// common scikit-learn convention and keeps every weight strictly positive.
273/// Tokens are identified by `usize` ids in `0..vocab_size`.
274///
275/// # Errors
276/// * [`SeqError::EmptyInput`] if `vocab_size == 0` or `documents` is empty.
277/// * [`SeqError::IndexOutOfBounds`] if any token id is `>= vocab_size`.
278pub fn corpus_idf(documents: &[Vec<usize>], vocab_size: usize) -> SeqResult<Vec<f64>> {
279 if vocab_size == 0 || documents.is_empty() {
280 return Err(SeqError::EmptyInput);
281 }
282 let n_docs = documents.len() as f64;
283 let mut df = vec![0.0f64; vocab_size];
284 let mut seen = vec![false; vocab_size];
285 for doc in documents {
286 for &t in doc {
287 if t >= vocab_size {
288 return Err(SeqError::IndexOutOfBounds {
289 index: t,
290 len: vocab_size,
291 });
292 }
293 }
294 // Reset only the entries we touched (cheaper than clearing the whole
295 // vector for short documents).
296 for &t in doc {
297 seen[t] = false;
298 }
299 for &t in doc {
300 if !seen[t] {
301 seen[t] = true;
302 df[t] += 1.0;
303 }
304 }
305 }
306 let idf: Vec<f64> = df
307 .iter()
308 .map(|&d| ((1.0 + n_docs) / (1.0 + d)).ln() + 1.0)
309 .collect();
310 Ok(idf)
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical candidate and reference embeddings ⇒ P = R = F1 = 1
    /// (every token matches itself with cosine 1).
    #[test]
    fn identical_scores_one() {
        let dim = 3;
        let emb = vec![
            1.0, 0.0, 0.0, // tok 0
            0.0, 1.0, 0.0, // tok 1
            0.0, 0.0, 1.0, // tok 2
        ];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&emb, 3, &emb, 3, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 1.0).abs() < 1e-12, "R = {}", s.recall);
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Orthogonal embeddings ⇒ cosine 0 everywhere ⇒ all scores 0.
    #[test]
    fn orthogonal_scores_zero() {
        let dim = 2;
        let cand = vec![1.0, 0.0]; // 1 token along x
        let reference = vec![0.0, 1.0]; // 1 token along y
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!(s.precision.abs() < 1e-12);
        assert!(s.recall.abs() < 1e-12);
        assert!(s.f1.abs() < 1e-12);
    }

    /// A zero embedding has no direction, so its cosine with anything is
    /// defined as 0 (no division by a zero norm).
    #[test]
    fn zero_vector_scores_zero() {
        let cfg = BertScoreConfig::default();
        let s = bert_score(&[0.0, 0.0], 1, &[1.0, 0.0], 1, 2, &cfg).expect("score");
        assert_eq!(s.precision, 0.0);
        assert_eq!(s.recall, 0.0);
        assert_eq!(s.f1, 0.0);
    }

    /// Greedy matching: a candidate token aligns to its single most-similar
    /// reference token. Here candidate {x} is scored against reference {x, y}.
    /// Precision uses one candidate token whose best match is x ⇒ 1. Recall
    /// averages best-of-x = 1 and best-of-y = 0 ⇒ 0.5.
    #[test]
    fn greedy_matching_asymmetric() {
        let dim = 2;
        let cand = vec![1.0, 0.0]; // {x}
        let reference = vec![1.0, 0.0, 0.0, 1.0]; // {x, y}
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 2, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 0.5).abs() < 1e-12, "R = {}", s.recall);
        // F1 = 2 * 1 * 0.5 / 1.5
        assert!((s.f1 - (2.0 * 0.5 / 1.5)).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Cosine ignores magnitude: scaling a vector does not change the score.
    #[test]
    fn scale_invariance() {
        let dim = 3;
        let cand = vec![2.0, 0.0, 0.0];
        let reference = vec![5.0, 0.0, 0.0];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Baseline rescaling maps a raw score `r` to `(r − b)/(1 − b)`.
    #[test]
    fn baseline_rescaling() {
        let dim = 2;
        // Make raw precision/recall exactly 0.5 via a 60° angle: cos 60° = 0.5.
        let cand = vec![1.0, 0.0];
        let reference = vec![0.5, 3.0f64.sqrt() / 2.0]; // unit vector at 60°
        let cfg_raw = BertScoreConfig::default();
        let raw = bert_score(&cand, 1, &reference, 1, dim, &cfg_raw).expect("raw");
        assert!((raw.f1 - 0.5).abs() < 1e-9, "raw f1 = {}", raw.f1);

        let b = 0.25;
        let cfg = BertScoreConfig { baseline: Some(b) };
        let rescaled = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("rescaled");
        let expected = (0.5 - b) / (1.0 - b);
        assert!(
            (rescaled.precision - expected).abs() < 1e-9,
            "P = {}",
            rescaled.precision
        );
        assert!(
            (rescaled.f1 - expected).abs() < 1e-9,
            "F1 = {}",
            rescaled.f1
        );
    }

    /// `bert_score` is exactly `bert_score_idf` with unit weights.
    #[test]
    fn uniform_matches_unit_idf() {
        let cfg = BertScoreConfig { baseline: Some(0.1) };
        let cand = [1.0, 0.0, 0.6, 0.8];
        let reference = [0.0, 1.0, 1.0, 1.0];
        let uniform = bert_score(&cand, 2, &reference, 2, 2, &cfg).expect("uniform");
        let unit = bert_score_idf(&cand, 2, &reference, 2, 2, &[1.0, 1.0], &[1.0, 1.0], &cfg)
            .expect("unit idf");
        assert_eq!(uniform, unit);
    }

    /// IDF weighting changes the average toward heavily-weighted tokens.
    #[test]
    fn idf_weighting() {
        let dim = 2;
        // Reference {x, y}. Candidate {x} matches x perfectly (1) and y not at
        // all (0). Up-weighting the y token lowers recall; up-weighting x
        // raises it.
        let cand = vec![1.0, 0.0];
        let reference = vec![1.0, 0.0, 0.0, 1.0];
        let cfg = BertScoreConfig::default();
        let cand_idf = vec![1.0];

        // Weight x token (matched) heavily ⇒ recall → 1.
        let ref_idf_high_x = vec![10.0, 1.0];
        let s_high = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_x,
            &cfg,
        )
        .expect("score");
        // recall = (10*1 + 1*0) / 11
        assert!(
            (s_high.recall - 10.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_high.recall
        );

        // Weight y token (unmatched) heavily ⇒ recall → 0.
        let ref_idf_high_y = vec![1.0, 10.0];
        let s_low = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_y,
            &cfg,
        )
        .expect("score");
        assert!(
            (s_low.recall - 1.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_low.recall
        );
        assert!(s_high.recall > s_low.recall);
    }

    /// `corpus_idf`: rarer tokens get higher IDF than frequent ones, and the
    /// smoothed formula keeps everything positive.
    #[test]
    fn corpus_idf_orders_by_rarity() {
        // token 0 appears in all 3 docs, token 1 in 1 doc, token 2 in 0 docs.
        let docs = vec![vec![0usize, 0, 1], vec![0usize], vec![0usize]];
        let idf = corpus_idf(&docs, 3).expect("idf");
        assert_eq!(idf.len(), 3);
        // df(0)=3, df(1)=1, df(2)=0 ⇒ idf strictly increasing in rarity.
        assert!(idf[0] < idf[1], "{} !< {}", idf[0], idf[1]);
        assert!(idf[1] < idf[2], "{} !< {}", idf[1], idf[2]);
        for &w in &idf {
            assert!(w > 0.0, "idf {w} not positive");
        }
        // df=3, N=3 ⇒ ln((1+3)/(1+3)) + 1 = 1.
        assert!((idf[0] - 1.0).abs() < 1e-12);
    }

    /// Repeated tokens count once per document, and per-document dedup state
    /// does not leak from one document into the next.
    #[test]
    fn corpus_idf_counts_once_per_document() {
        let docs = vec![vec![1usize, 1, 1], vec![0usize], vec![1usize, 0]];
        let idf = corpus_idf(&docs, 2).expect("idf");
        // N = 3, df(0) = 2, df(1) = 2 ⇒ both ln(4/3) + 1.
        let expected = (4.0f64 / 3.0).ln() + 1.0;
        assert!((idf[0] - expected).abs() < 1e-12, "idf[0] = {}", idf[0]);
        assert!((idf[1] - expected).abs() < 1e-12, "idf[1] = {}", idf[1]);
    }

    /// Validation paths: each malformed input is rejected with the specific
    /// error variant documented on the public function.
    #[test]
    fn validation_errors() {
        let cfg = BertScoreConfig::default();
        // empty token list
        assert!(matches!(
            bert_score(&[], 0, &[1.0], 1, 1, &cfg),
            Err(SeqError::EmptyInput)
        ));
        // zero embedding dimension
        assert!(matches!(
            bert_score(&[], 1, &[], 1, 0, &cfg),
            Err(SeqError::EmptyInput)
        ));
        // shape mismatch (n*dim != len)
        assert!(matches!(
            bert_score(&[1.0, 2.0, 3.0], 2, &[1.0, 2.0], 1, 2, &cfg),
            Err(SeqError::ShapeMismatch { .. })
        ));
        // bad baselines: ±1 and non-finite values
        for b in [1.0, -1.0, f64::NAN, f64::INFINITY] {
            let bad = BertScoreConfig { baseline: Some(b) };
            assert!(
                matches!(bad.validate(), Err(SeqError::InvalidParameter { .. })),
                "baseline {b} accepted"
            );
        }
        // idf length mismatch
        assert!(matches!(
            bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[1.0, 1.0], &[1.0], &cfg),
            Err(SeqError::LengthMismatch { .. })
        ));
        // negative idf
        assert!(matches!(
            bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[-1.0], &[1.0], &cfg),
            Err(SeqError::InvalidParameter { .. })
        ));
        // all-zero weights on one side
        assert!(matches!(
            bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[0.0], &[1.0], &cfg),
            Err(SeqError::NumericalInstability(..))
        ));
        // corpus_idf out-of-range id
        assert!(matches!(
            corpus_idf(&[vec![5usize]], 3),
            Err(SeqError::IndexOutOfBounds { .. })
        ));
        // corpus_idf with no documents / empty vocabulary
        let no_docs: Vec<Vec<usize>> = Vec::new();
        assert!(matches!(corpus_idf(&no_docs, 3), Err(SeqError::EmptyInput)));
        assert!(matches!(
            corpus_idf(&[vec![0usize]], 0),
            Err(SeqError::EmptyInput)
        ));
    }
}
495}