oxibonsai_eval/perplexity.rs
1//! Perplexity evaluator.
2//!
3//! Perplexity (PPL) measures how well a probability distribution predicted a
4//! text sample. Lower is better.
5//!
6//! PPL = exp(−(1/N) · Σ log p(xᵢ | x<ᵢ))
7//!
8//! This module also provides bits-per-byte (BPB), an alternative metric
9//! normalised by the number of UTF-8 bytes in the corpus.
10
11use serde::Serialize;
12
13// ──────────────────────────────────────────────────────────────────────────────
14// PerplexityResult
15// ──────────────────────────────────────────────────────────────────────────────
16
17/// Aggregate statistics from a batch perplexity evaluation.
18#[derive(Debug, Serialize)]
19pub struct PerplexityResult {
20 /// Mean perplexity across all samples.
21 pub mean_ppl: f32,
22 /// Minimum perplexity across all samples.
23 pub min_ppl: f32,
24 /// Maximum perplexity across all samples.
25 pub max_ppl: f32,
26 /// Population standard deviation of perplexity values.
27 pub std_ppl: f32,
28 /// Number of samples evaluated.
29 pub n_samples: usize,
30 /// Total number of tokens processed.
31 pub total_tokens: usize,
32}
33
34// ──────────────────────────────────────────────────────────────────────────────
35// PerplexityEvaluator
36// ──────────────────────────────────────────────────────────────────────────────
37
/// Evaluator that computes perplexity (and bits-per-byte) from model log-probabilities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PerplexityEvaluator {
    /// Sliding-window stride intended for chunking long sequences (default: 512).
    ///
    /// NOTE(review): none of the evaluation methods in this module read this
    /// field — confirm whether chunking is expected to be done by callers.
    pub stride: usize,
    /// Optional cap on the number of leading log-probabilities considered per
    /// sequence; `None` means the whole sequence is used.
    pub max_length: Option<usize>,
}
45
46impl Default for PerplexityEvaluator {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl PerplexityEvaluator {
53 /// Create a new evaluator with sensible defaults (stride = 512, no max length).
54 pub fn new() -> Self {
55 Self {
56 stride: 512,
57 max_length: None,
58 }
59 }
60
61 /// Create an evaluator with the specified sliding-window stride.
62 pub fn with_stride(stride: usize) -> Self {
63 Self {
64 stride,
65 max_length: None,
66 }
67 }
68
69 /// Compute perplexity for a single sequence of log-probabilities.
70 ///
71 /// Each element of `log_probs` is the natural log-probability of the token
72 /// at that position given all preceding tokens.
73 ///
74 /// Returns `f32::INFINITY` when `log_probs` is empty (undefined PPL).
75 pub fn compute(&self, log_probs: &[f32]) -> f32 {
76 let probs = match self.max_length {
77 Some(max) => &log_probs[..log_probs.len().min(max)],
78 None => log_probs,
79 };
80
81 if probs.is_empty() {
82 return f32::INFINITY;
83 }
84
85 let n = probs.len() as f32;
86 let avg_neg_log_prob = -probs.iter().copied().sum::<f32>() / n;
87 avg_neg_log_prob.exp()
88 }
89
90 /// Compute perplexity statistics for a batch of log-probability sequences.
91 ///
92 /// Each inner `Vec<f32>` corresponds to one sample.
93 /// Empty sequences are silently skipped.
94 pub fn compute_batch(&self, log_probs_batch: &[Vec<f32>]) -> PerplexityResult {
95 let ppls: Vec<f32> = log_probs_batch
96 .iter()
97 .filter(|lp| !lp.is_empty())
98 .map(|lp| self.compute(lp))
99 .collect();
100
101 let total_tokens: usize = log_probs_batch.iter().map(Vec::len).sum();
102
103 if ppls.is_empty() {
104 return PerplexityResult {
105 mean_ppl: f32::INFINITY,
106 min_ppl: f32::INFINITY,
107 max_ppl: f32::INFINITY,
108 std_ppl: 0.0,
109 n_samples: 0,
110 total_tokens,
111 };
112 }
113
114 let n = ppls.len() as f32;
115 let mean_ppl = ppls.iter().copied().sum::<f32>() / n;
116 let min_ppl = ppls.iter().cloned().fold(f32::INFINITY, f32::min);
117 let max_ppl = ppls.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
118 let variance = ppls.iter().map(|p| (p - mean_ppl).powi(2)).sum::<f32>() / n;
119 let std_ppl = variance.sqrt();
120
121 PerplexityResult {
122 mean_ppl,
123 min_ppl,
124 max_ppl,
125 std_ppl,
126 n_samples: ppls.len(),
127 total_tokens,
128 }
129 }
130
131 /// Compute perplexity from raw logits and the ground-truth token IDs.
132 ///
133 /// `logits[i]` is the vocabulary-wide logit vector at position `i`.
134 /// `token_ids[i]` is the token that was actually observed at position `i`.
135 ///
136 /// The function applies the log-softmax over each logit vector and selects
137 /// the log-prob corresponding to the ground-truth token.
138 ///
139 /// Panics (via bounds check) if `token_ids[i]` is out of range for `logits[i]`.
140 pub fn from_logits(&self, logits: &[Vec<f32>], token_ids: &[u32]) -> f32 {
141 let len = logits.len().min(token_ids.len());
142 if len == 0 {
143 return f32::INFINITY;
144 }
145
146 let log_probs: Vec<f32> = logits[..len]
147 .iter()
148 .zip(token_ids[..len].iter())
149 .map(|(logit_vec, &token_id)| {
150 let max_logit = logit_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
151 let exp_sum: f32 = logit_vec.iter().map(|&l| (l - max_logit).exp()).sum();
152 let log_sum_exp = max_logit + exp_sum.ln();
153 let tid = token_id as usize;
154 logit_vec[tid] - log_sum_exp
155 })
156 .collect();
157
158 self.compute(&log_probs)
159 }
160
161 /// Compute bits-per-byte (BPB).
162 ///
163 /// BPB normalises perplexity by the number of bytes in the corpus:
164 ///
165 /// BPB = (−Σ log₂ p(xᵢ | x<ᵢ)) / n_bytes
166 ///
167 /// `log_probs` must be natural-log probabilities. Returns `f32::INFINITY`
168 /// when `n_bytes == 0` or `log_probs` is empty.
169 pub fn bits_per_byte(&self, log_probs: &[f32], n_bytes: usize) -> f32 {
170 let probs = match self.max_length {
171 Some(max) => &log_probs[..log_probs.len().min(max)],
172 None => log_probs,
173 };
174
175 if probs.is_empty() || n_bytes == 0 {
176 return f32::INFINITY;
177 }
178
179 // Convert nats to bits: log₂(x) = ln(x) / ln(2)
180 let log2_e: f32 = std::f32::consts::E.log2();
181 let neg_sum_log2_prob: f32 = probs.iter().map(|&lp| -lp * log2_e).sum();
182 neg_sum_log2_prob / n_bytes as f32
183 }
184}