Skip to main content

oxibonsai_eval/
perplexity.rs

1//! Perplexity evaluator.
2//!
3//! Perplexity (PPL) measures how well a probability distribution predicted a
4//! text sample. Lower is better.
5//!
6//! PPL = exp(−(1/N) · Σ log p(xᵢ | x<ᵢ))
7//!
8//! This module also provides bits-per-byte (BPB), an alternative metric
9//! normalised by the number of UTF-8 bytes in the corpus.
10
11use serde::Serialize;
12
13// ──────────────────────────────────────────────────────────────────────────────
14// PerplexityResult
15// ──────────────────────────────────────────────────────────────────────────────
16
17/// Aggregate statistics from a batch perplexity evaluation.
18#[derive(Debug, Serialize)]
19pub struct PerplexityResult {
20    /// Mean perplexity across all samples.
21    pub mean_ppl: f32,
22    /// Minimum perplexity across all samples.
23    pub min_ppl: f32,
24    /// Maximum perplexity across all samples.
25    pub max_ppl: f32,
26    /// Population standard deviation of perplexity values.
27    pub std_ppl: f32,
28    /// Number of samples evaluated.
29    pub n_samples: usize,
30    /// Total number of tokens processed.
31    pub total_tokens: usize,
32}
33
34// ──────────────────────────────────────────────────────────────────────────────
35// PerplexityEvaluator
36// ──────────────────────────────────────────────────────────────────────────────
37
/// Evaluator that computes perplexity and bits-per-byte from model
/// log-probabilities (or from raw logits).
///
/// Construct with [`PerplexityEvaluator::new`] / [`Default`] or
/// [`PerplexityEvaluator::with_stride`], then adjust the public fields as needed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PerplexityEvaluator {
    /// Sliding-window stride intended for chunking long sequences (default: 512).
    ///
    /// NOTE(review): the evaluation methods on this type do not currently read
    /// this field, so changing it does not affect results — confirm before
    /// relying on it.
    pub stride: usize,
    /// Optional maximum number of positions to evaluate. When set, inputs are
    /// truncated to their first `max_length` elements; `None` evaluates everything.
    pub max_length: Option<usize>,
}
45
46impl Default for PerplexityEvaluator {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl PerplexityEvaluator {
53    /// Create a new evaluator with sensible defaults (stride = 512, no max length).
54    pub fn new() -> Self {
55        Self {
56            stride: 512,
57            max_length: None,
58        }
59    }
60
61    /// Create an evaluator with the specified sliding-window stride.
62    pub fn with_stride(stride: usize) -> Self {
63        Self {
64            stride,
65            max_length: None,
66        }
67    }
68
69    /// Compute perplexity for a single sequence of log-probabilities.
70    ///
71    /// Each element of `log_probs` is the natural log-probability of the token
72    /// at that position given all preceding tokens.
73    ///
74    /// Returns `f32::INFINITY` when `log_probs` is empty (undefined PPL).
75    pub fn compute(&self, log_probs: &[f32]) -> f32 {
76        let probs = match self.max_length {
77            Some(max) => &log_probs[..log_probs.len().min(max)],
78            None => log_probs,
79        };
80
81        if probs.is_empty() {
82            return f32::INFINITY;
83        }
84
85        let n = probs.len() as f32;
86        let avg_neg_log_prob = -probs.iter().copied().sum::<f32>() / n;
87        avg_neg_log_prob.exp()
88    }
89
90    /// Compute perplexity statistics for a batch of log-probability sequences.
91    ///
92    /// Each inner `Vec<f32>` corresponds to one sample.
93    /// Empty sequences are silently skipped.
94    pub fn compute_batch(&self, log_probs_batch: &[Vec<f32>]) -> PerplexityResult {
95        let ppls: Vec<f32> = log_probs_batch
96            .iter()
97            .filter(|lp| !lp.is_empty())
98            .map(|lp| self.compute(lp))
99            .collect();
100
101        let total_tokens: usize = log_probs_batch.iter().map(Vec::len).sum();
102
103        if ppls.is_empty() {
104            return PerplexityResult {
105                mean_ppl: f32::INFINITY,
106                min_ppl: f32::INFINITY,
107                max_ppl: f32::INFINITY,
108                std_ppl: 0.0,
109                n_samples: 0,
110                total_tokens,
111            };
112        }
113
114        let n = ppls.len() as f32;
115        let mean_ppl = ppls.iter().copied().sum::<f32>() / n;
116        let min_ppl = ppls.iter().cloned().fold(f32::INFINITY, f32::min);
117        let max_ppl = ppls.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
118        let variance = ppls.iter().map(|p| (p - mean_ppl).powi(2)).sum::<f32>() / n;
119        let std_ppl = variance.sqrt();
120
121        PerplexityResult {
122            mean_ppl,
123            min_ppl,
124            max_ppl,
125            std_ppl,
126            n_samples: ppls.len(),
127            total_tokens,
128        }
129    }
130
131    /// Compute perplexity from raw logits and the ground-truth token IDs.
132    ///
133    /// `logits[i]` is the vocabulary-wide logit vector at position `i`.
134    /// `token_ids[i]` is the token that was actually observed at position `i`.
135    ///
136    /// The function applies the log-softmax over each logit vector and selects
137    /// the log-prob corresponding to the ground-truth token.
138    ///
139    /// Panics (via bounds check) if `token_ids[i]` is out of range for `logits[i]`.
140    pub fn from_logits(&self, logits: &[Vec<f32>], token_ids: &[u32]) -> f32 {
141        let len = logits.len().min(token_ids.len());
142        if len == 0 {
143            return f32::INFINITY;
144        }
145
146        let log_probs: Vec<f32> = logits[..len]
147            .iter()
148            .zip(token_ids[..len].iter())
149            .map(|(logit_vec, &token_id)| {
150                let max_logit = logit_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
151                let exp_sum: f32 = logit_vec.iter().map(|&l| (l - max_logit).exp()).sum();
152                let log_sum_exp = max_logit + exp_sum.ln();
153                let tid = token_id as usize;
154                logit_vec[tid] - log_sum_exp
155            })
156            .collect();
157
158        self.compute(&log_probs)
159    }
160
161    /// Compute bits-per-byte (BPB).
162    ///
163    /// BPB normalises perplexity by the number of bytes in the corpus:
164    ///
165    /// BPB = (−Σ log₂ p(xᵢ | x<ᵢ)) / n_bytes
166    ///
167    /// `log_probs` must be natural-log probabilities. Returns `f32::INFINITY`
168    /// when `n_bytes == 0` or `log_probs` is empty.
169    pub fn bits_per_byte(&self, log_probs: &[f32], n_bytes: usize) -> f32 {
170        let probs = match self.max_length {
171            Some(max) => &log_probs[..log_probs.len().min(max)],
172            None => log_probs,
173        };
174
175        if probs.is_empty() || n_bytes == 0 {
176            return f32::INFINITY;
177        }
178
179        // Convert nats to bits: log₂(x) = ln(x) / ln(2)
180        let log2_e: f32 = std::f32::consts::E.log2();
181        let neg_sum_log2_prob: f32 = probs.iter().map(|&lp| -lp * log2_e).sum();
182        neg_sum_log2_prob / n_bytes as f32
183    }
184}