Skip to main content

oxicuda_seq/metrics/
bertscore.rs

1//! BERTScore: token-embedding similarity metric via greedy cosine matching.
2//!
3//! Reference: Zhang, T., Kishore, V., Wu, F., Weinberger, K. Q. & Artzi, Y.
4//! (2020). *BERTScore: Evaluating Text Generation with BERT*. ICLR 2020.
5//!
6//! # What this module computes
7//!
8//! BERTScore compares a *candidate* token sequence against a *reference* token
9//! sequence using **contextual embeddings** for each token. Given a candidate
10//! with embeddings `x̂_1 … x̂_n` and a reference with embeddings `x_1 … x_m`,
11//! every pairwise cosine similarity `cos(x̂_i, x_j)` is formed and then matched
12//! **greedily** (each token aligned to its single most similar counterpart):
13//!
14//! ```text
15//! Recall    R = ( Σ_j  idf(x_j)  · max_i cos(x̂_i, x_j) ) / Σ_j idf(x_j)
16//! Precision P = ( Σ_i  idf(x̂_i) · max_j cos(x̂_i, x_j) ) / Σ_i idf(x̂_i)
17//! F1        = 2 · P · R / (P + R)
18//! ```
19//!
20//! With **uniform IDF weights** these reduce to plain averages of the row /
21//! column maxima of the cosine-similarity matrix. Optional inverse-document-
22//! frequency weights (precomputed from a corpus) down-weight frequent tokens
23//! exactly as in the paper.
24//!
25//! ## Honesty note — this is the real metric, not a stub
26//!
27//! The "BERT" in BERTScore is only the *source of the embeddings*. This crate
28//! does not (and cannot, in pure-CPU form) ship a transformer; instead the
29//! **embedding vectors are an input** supplied by the caller (from any encoder:
30//! a `trustformers` model, word2vec, a learned table, …). Everything BERTScore
31//! actually specifies — the cosine-similarity matrix, the greedy precision /
32//! recall / F1 matching, IDF weighting, and the optional baseline rescaling — is
33//! computed here in full and is exact. Feeding genuine contextual embeddings
34//! reproduces the published metric; feeding any other embeddings yields the same
35//! algorithm over those vectors.
36//!
37//! Production code never panics: every fallible path validates its inputs and
38//! returns [`SeqError`].
39
40use crate::error::{SeqError, SeqResult};
41
/// Precision / recall / F1 triple produced by BERTScore.
///
/// Without a baseline the scores lie on the cosine-similarity scale
/// (`[−1, 1]`); when [`BertScoreConfig::baseline`] is set they are rescaled.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BertScore {
    /// Precision: how well each candidate token is covered by the reference
    /// (weighted mean, over candidate tokens, of the best cosine match found
    /// among reference tokens).
    pub precision: f64,
    /// Recall: how well each reference token is covered by the candidate
    /// (weighted mean, over reference tokens, of the best cosine match found
    /// among candidate tokens).
    pub recall: f64,
    /// Harmonic mean (F1) of precision and recall.
    pub f1: f64,
}
52
/// Configuration for BERTScore.
///
/// The derived [`Default`] sets `baseline` to `None`, i.e. no rescaling.
#[derive(Debug, Clone, Default)]
pub struct BertScoreConfig {
    /// Optional baseline value `b ∈ (−1, 1)` used for *rescaling* the raw
    /// scores: `score ← (score − b) / (1 − b)`. The paper rescales against an
    /// empirical baseline (the average score of random sentence pairs for the
    /// chosen model/layer) so that scores spread across a more interpretable
    /// range. `None` disables rescaling (raw cosine scores in `[−1, 1]`).
    /// Non-finite values and values outside `(−1, 1)` are rejected by
    /// [`BertScoreConfig::validate`].
    pub baseline: Option<f64>,
}
63
64impl BertScoreConfig {
65    /// Validate the configuration.
66    ///
67    /// # Errors
68    /// * [`SeqError::InvalidParameter`] if `baseline` is set to a non-finite
69    ///   value or to `±1` (the rescaling denominator `1 − b` must be non-zero,
70    ///   and a baseline outside `(−1, 1)` is meaningless for cosine scores).
71    pub fn validate(&self) -> SeqResult<()> {
72        if let Some(b) = self.baseline {
73            if !b.is_finite() || b <= -1.0 || b >= 1.0 {
74                return Err(SeqError::InvalidParameter {
75                    name: "baseline".into(),
76                    value: b,
77                });
78            }
79        }
80        Ok(())
81    }
82
83    /// Apply the optional baseline rescaling to a raw score.
84    fn rescale(&self, score: f64) -> f64 {
85        match self.baseline {
86            Some(b) => (score - b) / (1.0 - b),
87            None => score,
88        }
89    }
90}
91
/// Euclidean (L2) length of `v`: the square root of the sum of its squared
/// components.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().copied().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
96
/// Cosine similarity of two equal-length vectors whose precomputed L2 norms
/// are `na` and `nb`. Taking the norms as arguments avoids recomputing them
/// inside the `n × m` similarity loop.
///
/// A vector with zero norm has no direction, so any pair involving one scores
/// `0`. The result is clamped to `[−1, 1]` to absorb rounding error.
fn cosine(a: &[f64], na: f64, b: &[f64], nb: f64) -> f64 {
    let degenerate = na == 0.0 || nb == 0.0;
    if degenerate {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    let similarity = dot / (na * nb);
    similarity.clamp(-1.0, 1.0)
}
106
107/// Compute BERTScore between candidate and reference token embeddings with
108/// **uniform** token weights.
109///
110/// `candidate` holds `n` row-major embedding vectors of dimension `dim`
111/// (`candidate.len() == n * dim`); `reference` holds `m` such vectors
112/// (`reference.len() == m * dim`).
113///
114/// # Errors
115/// * [`SeqError::EmptyInput`] if either side has zero tokens or `dim == 0`.
116/// * [`SeqError::ShapeMismatch`] if a flat buffer length is not a multiple of
117///   `dim` consistent with the stated token count.
118/// * Propagates [`BertScoreConfig::validate`].
119pub fn bert_score(
120    candidate: &[f64],
121    n: usize,
122    reference: &[f64],
123    m: usize,
124    dim: usize,
125    config: &BertScoreConfig,
126) -> SeqResult<BertScore> {
127    let cand_idf = vec![1.0; n];
128    let ref_idf = vec![1.0; m];
129    bert_score_idf(candidate, n, reference, m, dim, &cand_idf, &ref_idf, config)
130}
131
132/// Compute BERTScore with explicit **IDF weights** for candidate and reference
133/// tokens (e.g. precomputed inverse-document-frequencies over a corpus). Weights
134/// must be non-negative and finite; at least one weight on each side must be
135/// strictly positive (so the normalising denominators are non-zero).
136///
137/// # Errors
138/// In addition to the cases of [`bert_score`]:
139/// * [`SeqError::LengthMismatch`] if `cand_idf.len() != n` or
140///   `ref_idf.len() != m`.
141/// * [`SeqError::InvalidParameter`] if any weight is negative / non-finite.
142/// * [`SeqError::NumericalInstability`] if the candidate or reference weights
143///   sum to zero.
144#[allow(clippy::too_many_arguments)]
145pub fn bert_score_idf(
146    candidate: &[f64],
147    n: usize,
148    reference: &[f64],
149    m: usize,
150    dim: usize,
151    cand_idf: &[f64],
152    ref_idf: &[f64],
153    config: &BertScoreConfig,
154) -> SeqResult<BertScore> {
155    config.validate()?;
156    if n == 0 || m == 0 || dim == 0 {
157        return Err(SeqError::EmptyInput);
158    }
159    if candidate.len() != n * dim {
160        return Err(SeqError::ShapeMismatch {
161            expected: n * dim,
162            got: candidate.len(),
163        });
164    }
165    if reference.len() != m * dim {
166        return Err(SeqError::ShapeMismatch {
167            expected: m * dim,
168            got: reference.len(),
169        });
170    }
171    if cand_idf.len() != n {
172        return Err(SeqError::LengthMismatch {
173            a: cand_idf.len(),
174            b: n,
175        });
176    }
177    if ref_idf.len() != m {
178        return Err(SeqError::LengthMismatch {
179            a: ref_idf.len(),
180            b: m,
181        });
182    }
183    let mut sum_cand_idf = 0.0;
184    for (idx, &w) in cand_idf.iter().enumerate() {
185        if !(w.is_finite() && w >= 0.0) {
186            return Err(SeqError::InvalidParameter {
187                name: format!("cand_idf[{idx}]"),
188                value: w,
189            });
190        }
191        sum_cand_idf += w;
192    }
193    let mut sum_ref_idf = 0.0;
194    for (idx, &w) in ref_idf.iter().enumerate() {
195        if !(w.is_finite() && w >= 0.0) {
196            return Err(SeqError::InvalidParameter {
197                name: format!("ref_idf[{idx}]"),
198                value: w,
199            });
200        }
201        sum_ref_idf += w;
202    }
203    if sum_cand_idf <= 0.0 || sum_ref_idf <= 0.0 {
204        return Err(SeqError::NumericalInstability(
205            "IDF weights sum to zero on one side".into(),
206        ));
207    }
208
209    // Precompute norms.
210    let cand_norms: Vec<f64> = (0..n)
211        .map(|i| l2_norm(&candidate[i * dim..(i + 1) * dim]))
212        .collect();
213    let ref_norms: Vec<f64> = (0..m)
214        .map(|j| l2_norm(&reference[j * dim..(j + 1) * dim]))
215        .collect();
216
217    // Row maxima (over reference) give precision; column maxima (over
218    // candidate) give recall. Compute the full similarity once, tracking both.
219    let mut row_max = vec![f64::NEG_INFINITY; n]; // best ref for each cand token
220    let mut col_max = vec![f64::NEG_INFINITY; m]; // best cand for each ref token
221    for i in 0..n {
222        let ci = &candidate[i * dim..(i + 1) * dim];
223        let ni = cand_norms[i];
224        for j in 0..m {
225            let rj = &reference[j * dim..(j + 1) * dim];
226            let sim = cosine(ci, ni, rj, ref_norms[j]);
227            if sim > row_max[i] {
228                row_max[i] = sim;
229            }
230            if sim > col_max[j] {
231                col_max[j] = sim;
232            }
233        }
234    }
235
236    // Weighted precision / recall.
237    let mut precision = 0.0;
238    for i in 0..n {
239        precision += cand_idf[i] * row_max[i];
240    }
241    precision /= sum_cand_idf;
242
243    let mut recall = 0.0;
244    for j in 0..m {
245        recall += ref_idf[j] * col_max[j];
246    }
247    recall /= sum_ref_idf;
248
249    // Optional baseline rescaling, then F1 from the (possibly rescaled) P, R.
250    precision = config.rescale(precision);
251    recall = config.rescale(recall);
252
253    let f1 = if precision + recall <= 0.0 {
254        0.0
255    } else {
256        2.0 * precision * recall / (precision + recall)
257    };
258
259    Ok(BertScore {
260        precision,
261        recall,
262        f1,
263    })
264}
265
266/// Convenience IDF estimator: compute smoothed inverse-document-frequencies for
267/// a vocabulary from a corpus of tokenised documents.
268///
269/// `idf(t) = ln( (1 + N) / (1 + df(t)) ) + 1` where `N` is the number of
270/// documents and `df(t)` the number of documents containing token `t` (each
271/// token counted at most once per document). The `+1`/smoothing matches the
272/// common scikit-learn convention and keeps every weight strictly positive.
273/// Tokens are identified by `usize` ids in `0..vocab_size`.
274///
275/// # Errors
276/// * [`SeqError::EmptyInput`] if `vocab_size == 0` or `documents` is empty.
277/// * [`SeqError::IndexOutOfBounds`] if any token id is `>= vocab_size`.
278pub fn corpus_idf(documents: &[Vec<usize>], vocab_size: usize) -> SeqResult<Vec<f64>> {
279    if vocab_size == 0 || documents.is_empty() {
280        return Err(SeqError::EmptyInput);
281    }
282    let n_docs = documents.len() as f64;
283    let mut df = vec![0.0f64; vocab_size];
284    let mut seen = vec![false; vocab_size];
285    for doc in documents {
286        for &t in doc {
287            if t >= vocab_size {
288                return Err(SeqError::IndexOutOfBounds {
289                    index: t,
290                    len: vocab_size,
291                });
292            }
293        }
294        // Reset only the entries we touched (cheaper than clearing the whole
295        // vector for short documents).
296        for &t in doc {
297            seen[t] = false;
298        }
299        for &t in doc {
300            if !seen[t] {
301                seen[t] = true;
302                df[t] += 1.0;
303            }
304        }
305    }
306    let idf: Vec<f64> = df
307        .iter()
308        .map(|&d| ((1.0 + n_docs) / (1.0 + d)).ln() + 1.0)
309        .collect();
310    Ok(idf)
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical candidate and reference embeddings ⇒ P = R = F1 = 1
    /// (every token matches itself with cosine 1).
    #[test]
    fn identical_scores_one() {
        let dim = 3;
        let emb = vec![
            1.0, 0.0, 0.0, // tok 0
            0.0, 1.0, 0.0, // tok 1
            0.0, 0.0, 1.0, // tok 2
        ];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&emb, 3, &emb, 3, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 1.0).abs() < 1e-12, "R = {}", s.recall);
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Orthogonal embeddings ⇒ cosine 0 everywhere ⇒ all scores 0.
    #[test]
    fn orthogonal_scores_zero() {
        let dim = 2;
        let cand = vec![1.0, 0.0]; // 1 token along x
        let reference = vec![0.0, 1.0]; // 1 token along y
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!(s.precision.abs() < 1e-12);
        assert!(s.recall.abs() < 1e-12);
        assert!(s.f1.abs() < 1e-12);
    }

    /// Greedy matching: a candidate token aligns to its single most-similar
    /// reference token. Here candidate {x} is scored against reference {x, y}.
    /// Precision has one candidate token whose best match is x ⇒ 1. Recall
    /// averages best-of-x = 1 and best-of-y = 0 ⇒ 0.5.
    #[test]
    fn greedy_matching_asymmetric() {
        let dim = 2;
        let cand = vec![1.0, 0.0]; // {x}
        let reference = vec![1.0, 0.0, 0.0, 1.0]; // {x, y}
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 2, dim, &cfg).expect("score");
        assert!((s.precision - 1.0).abs() < 1e-12, "P = {}", s.precision);
        assert!((s.recall - 0.5).abs() < 1e-12, "R = {}", s.recall);
        // F1 = 2 * 1 * 0.5 / 1.5
        assert!((s.f1 - (2.0 * 0.5 / 1.5)).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Cosine ignores magnitude: scaling a vector does not change the score.
    #[test]
    fn scale_invariance() {
        let dim = 3;
        let cand = vec![2.0, 0.0, 0.0];
        let reference = vec![5.0, 0.0, 0.0];
        let cfg = BertScoreConfig::default();
        let s = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("score");
        assert!((s.f1 - 1.0).abs() < 1e-12, "F1 = {}", s.f1);
    }

    /// Baseline rescaling maps a raw score `r` to `(r − b)/(1 − b)`.
    #[test]
    fn baseline_rescaling() {
        let dim = 2;
        // Make raw precision/recall exactly 0.5 via a 60° angle: cos 60° = 0.5.
        let cand = vec![1.0, 0.0];
        let reference = vec![0.5, 3.0f64.sqrt() / 2.0]; // unit vector at 60°
        let cfg_raw = BertScoreConfig::default();
        let raw = bert_score(&cand, 1, &reference, 1, dim, &cfg_raw).expect("raw");
        assert!((raw.f1 - 0.5).abs() < 1e-9, "raw f1 = {}", raw.f1);

        let b = 0.25;
        let cfg = BertScoreConfig { baseline: Some(b) };
        let rescaled = bert_score(&cand, 1, &reference, 1, dim, &cfg).expect("rescaled");
        let expected = (0.5 - b) / (1.0 - b);
        assert!(
            (rescaled.precision - expected).abs() < 1e-9,
            "P = {}",
            rescaled.precision
        );
        assert!(
            (rescaled.f1 - expected).abs() < 1e-9,
            "F1 = {}",
            rescaled.f1
        );
    }

    /// IDF weighting changes the average toward heavily-weighted tokens.
    #[test]
    fn idf_weighting() {
        let dim = 2;
        // Reference {x, y}; candidate {x} matches x perfectly (1) and y not at
        // all (0). Up-weighting the y token lowers recall; up-weighting x raises
        // it.
        let cand = vec![1.0, 0.0];
        let reference = vec![1.0, 0.0, 0.0, 1.0];
        let cfg = BertScoreConfig::default();
        let cand_idf = vec![1.0];

        // Weight x token (matched) heavily ⇒ recall → 1.
        let ref_idf_high_x = vec![10.0, 1.0];
        let s_high = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_x,
            &cfg,
        )
        .expect("score");
        // recall = (10*1 + 1*0) / 11
        assert!(
            (s_high.recall - 10.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_high.recall
        );

        // Weight y token (unmatched) heavily ⇒ recall → 0.
        let ref_idf_high_y = vec![1.0, 10.0];
        let s_low = bert_score_idf(
            &cand,
            1,
            &reference,
            2,
            dim,
            &cand_idf,
            &ref_idf_high_y,
            &cfg,
        )
        .expect("score");
        assert!(
            (s_low.recall - 1.0 / 11.0).abs() < 1e-12,
            "R = {}",
            s_low.recall
        );
        assert!(s_high.recall > s_low.recall);
    }

    /// `corpus_idf`: rarer tokens get higher IDF than frequent ones, and the
    /// smoothed formula keeps everything positive.
    #[test]
    fn corpus_idf_orders_by_rarity() {
        // token 0 appears in all 3 docs, token 1 in 1 doc, token 2 in 0 docs.
        let docs = vec![vec![0usize, 0, 1], vec![0usize], vec![0usize]];
        let idf = corpus_idf(&docs, 3).expect("idf");
        assert_eq!(idf.len(), 3);
        // df(0)=3, df(1)=1, df(2)=0  ⇒ idf strictly increasing in rarity.
        assert!(idf[0] < idf[1], "{} !< {}", idf[0], idf[1]);
        assert!(idf[1] < idf[2], "{} !< {}", idf[1], idf[2]);
        for &w in &idf {
            assert!(w > 0.0, "idf {w} not positive");
        }
        // df=3, N=3 ⇒ ln((1+3)/(1+3)) + 1 = 1.
        assert!((idf[0] - 1.0).abs() < 1e-12);
    }

    /// Validation paths.
    #[test]
    fn validation_errors() {
        let cfg = BertScoreConfig::default();
        // empty
        assert!(bert_score(&[], 0, &[1.0], 1, 1, &cfg).is_err());
        // shape mismatch (n*dim != len)
        assert!(bert_score(&[1.0, 2.0, 3.0], 2, &[1.0, 2.0], 1, 2, &cfg).is_err());
        // bad baseline = 1.0
        let bad = BertScoreConfig {
            baseline: Some(1.0),
        };
        assert!(bad.validate().is_err());
        // idf length mismatch
        assert!(
            bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[1.0, 1.0], &[1.0], &cfg).is_err()
        );
        // negative idf
        assert!(bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[-1.0], &[1.0], &cfg).is_err());
        // corpus_idf out-of-range id
        assert!(corpus_idf(&[vec![5usize]], 3).is_err());
    }

    /// A zero-norm embedding has no direction, so it matches nothing and every
    /// score is 0 rather than a division-by-zero artefact.
    #[test]
    fn zero_vector_scores_zero() {
        let cfg = BertScoreConfig::default();
        let s = bert_score(&[0.0, 0.0], 1, &[1.0, 0.0], 1, 2, &cfg).expect("score");
        assert_eq!(s.precision, 0.0);
        assert_eq!(s.recall, 0.0);
        assert_eq!(s.f1, 0.0);
    }

    /// All-zero IDF weights on one side leave the weighted mean undefined.
    #[test]
    fn zero_idf_sum_is_rejected() {
        let cfg = BertScoreConfig::default();
        let r = bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[0.0], &[1.0], &cfg);
        assert!(matches!(r, Err(SeqError::NumericalInstability(_))));
        let r = bert_score_idf(&[1.0, 0.0], 1, &[1.0, 0.0], 1, 2, &[1.0], &[0.0], &cfg);
        assert!(matches!(r, Err(SeqError::NumericalInstability(_))));
    }

    /// A zero embedding dimension is treated as empty input.
    #[test]
    fn zero_dim_is_empty_input() {
        let cfg = BertScoreConfig::default();
        assert!(matches!(
            bert_score(&[], 1, &[], 1, 0, &cfg),
            Err(SeqError::EmptyInput)
        ));
    }

    /// Baselines must be finite and strictly inside (−1, 1).
    #[test]
    fn baseline_validation_bounds() {
        for b in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY, -1.0, 1.0, 1.5] {
            let cfg = BertScoreConfig { baseline: Some(b) };
            assert!(cfg.validate().is_err(), "baseline {b} accepted");
        }
        assert!(BertScoreConfig { baseline: Some(-0.5) }.validate().is_ok());
        assert!(BertScoreConfig::default().validate().is_ok());
    }

    /// Repeated tokens inside one document contribute a single document count.
    #[test]
    fn corpus_idf_counts_each_document_once() {
        let docs = vec![vec![1usize, 1, 1], vec![1usize]];
        let idf = corpus_idf(&docs, 2).expect("idf");
        // df(1) = 2 = N ⇒ idf = 1; df(0) = 0 ⇒ ln(3) + 1.
        assert!((idf[1] - 1.0).abs() < 1e-12);
        assert!((idf[0] - (3.0f64.ln() + 1.0)).abs() < 1e-12);
    }

    /// An empty corpus or vocabulary is rejected.
    #[test]
    fn corpus_idf_rejects_empty_input() {
        assert!(matches!(corpus_idf(&[], 3), Err(SeqError::EmptyInput)));
        assert!(matches!(
            corpus_idf(&[vec![0usize]], 0),
            Err(SeqError::EmptyInput)
        ));
    }
}