burn_train/metric/perplexity.rs

use core::marker::PhantomData;

use super::state::FormatOptions;
use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};

/// Custom state for the perplexity metric that correctly accumulates negative log-likelihood.
///
/// Unlike other metrics that can be averaged, perplexity requires special handling:
/// - Accumulate the total negative log-likelihood across all tokens
/// - Accumulate the total number of effective tokens
/// - Compute perplexity as exp(total_nll / total_tokens) only at the end
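///
/// Averaging per-batch perplexities would be wrong. For example, two
/// single-token batches with NLL sums 2.0 and 0.0 have batch perplexities
/// e^2 ≈ 7.39 and e^0 = 1.0, whose mean is ≈ 4.19, while the correct pooled
/// value is exp((2.0 + 0.0) / 2) = e ≈ 2.72.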
#[derive(Clone)]
struct PerplexityState {
    /// Sum of negative log-likelihood across all tokens
    sum_nll: f64,
    /// Total number of effective tokens (excluding padding)
    total_tokens: usize,
    /// Current batch perplexity (for display purposes)
    current: f64,
}

impl PerplexityState {
    fn new() -> Self {
        Self {
            sum_nll: 0.0,
            total_tokens: 0,
            current: f64::NAN,
        }
    }

    fn reset(&mut self) {
        self.sum_nll = 0.0;
        self.total_tokens = 0;
        self.current = f64::NAN;
    }

    /// Update state with negative log-likelihood and token count from current batch
    fn update(
        &mut self,
        sum_log_prob: f64,
        effective_tokens: usize,
        format: FormatOptions,
    ) -> SerializedEntry {
        // sum_log_prob is already the sum of log probabilities (negative values)
        // We need to negate it to get negative log-likelihood
        let batch_nll = -sum_log_prob;

        // Accumulate across batches
        self.sum_nll += batch_nll;
        self.total_tokens += effective_tokens;

        // Compute current batch perplexity for display
        let batch_perplexity = if effective_tokens > 0 {
            (batch_nll / effective_tokens as f64).exp()
        } else {
            f64::INFINITY
        };
        self.current = batch_perplexity;

        // Compute running epoch perplexity
        let epoch_perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        // Format for display
        let (formatted_current, formatted_running) = match format.precision_value() {
            Some(precision) => (
                format_float(batch_perplexity, precision),
                format_float(epoch_perplexity, precision),
            ),
            None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
        };

        let formatted = match format.unit_value() {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };

        // Serialize the state for aggregation
        let serialized = NumericEntry::Aggregated {
            aggregated_value: epoch_perplexity,
            count: self.total_tokens,
        }
        .serialize();

        SerializedEntry::new(formatted, serialized)
    }

    fn value(&self) -> NumericEntry {
        let perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        NumericEntry::Aggregated {
            aggregated_value: perplexity,
            count: self.total_tokens,
        }
    }

    fn running_value(&self) -> NumericEntry {
        self.value()
    }
}

/// The perplexity metric.
///
/// Perplexity is a measure of how well a probability distribution or probability model
/// predicts a sample. It's commonly used to evaluate language models. A lower perplexity
/// indicates that the model assigns higher probability to the observed tokens.
///
/// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss:
/// PPL = exp(H(p, q)) = exp(-1/N * Σ log(q(x_i)))
///
/// where:
/// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q
/// - N is the number of tokens
/// - q(x_i) is the predicted probability of the i-th token
///
/// # Aggregation
/// Unlike other metrics, perplexity cannot be simply averaged across batches.
/// This implementation correctly accumulates the total negative log-likelihood and
/// total token count across batches, then computes perplexity as exp(total_nll / total_tokens).
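///
/// # Example
///
/// An illustrative sketch of typical usage (`B` stands in for a concrete
/// backend; `logits`, `targets`, and `metadata` are placeholders supplied by
/// the training loop):
///
/// ```rust,ignore
/// let mut metric = PerplexityMetric::<B>::new();
/// // Logits of shape [batch_size * seq_len, vocab_size], targets of shape
/// // [batch_size * seq_len].
/// let input = PerplexityInput::new(logits, targets);
/// metric.update(&input, &metadata);
/// let running_perplexity = metric.value();
/// ```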
#[derive(Clone)]
pub struct PerplexityMetric<B: Backend> {
    name: MetricName,
    state: PerplexityState,
    pad_token: Option<usize>,
    _b: PhantomData<B>,
}

/// The [perplexity metric](PerplexityMetric) input type.
#[derive(new)]
pub struct PerplexityInput<B: Backend> {
    /// Logits tensor of shape [batch_size * sequence_length, vocab_size]
    outputs: Tensor<B, 2>,
    /// Target tokens tensor of shape [batch_size * sequence_length]
    targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Default for PerplexityMetric<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> PerplexityMetric<B> {
    /// Creates the metric.
    pub fn new() -> Self {
        Self {
            name: MetricName::new("Perplexity".to_string()),
            state: PerplexityState::new(),
            pad_token: Default::default(),
            _b: PhantomData,
        }
    }

    /// Sets the pad token to exclude from perplexity calculation.
    ///
    /// When a pad token is set, predictions for padding tokens are masked out
    /// and do not contribute to the perplexity calculation. This is important
    /// for variable-length sequences where padding is used.
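    ///
    /// ```rust,ignore
    /// // Illustrative: targets equal to index 0 are treated as padding and
    /// // excluded from the perplexity statistics.
    /// let metric = PerplexityMetric::<B>::new().with_pad_token(0);
    /// ```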
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}

impl<B: Backend> Metric for PerplexityMetric<B> {
    type Input = PerplexityInput<B>;

    fn update(
        &mut self,
        input: &PerplexityInput<B>,
        _metadata: &MetricMetadata,
    ) -> SerializedEntry {
        let targets = input.targets.clone();
        let outputs = input.outputs.clone();

        let [total_tokens, _vocab_size] = outputs.dims();

        // Convert logits to log probabilities using log_softmax for numerical stability
        let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);

        // Gather the log probabilities for the target tokens
        let target_log_probs = log_probs
            .gather(1, targets.clone().unsqueeze_dim(1))
            .squeeze_dim(1);

        let (sum_log_prob, effective_tokens) = match self.pad_token {
            Some(pad_token) => {
                // Create a mask for non-padding tokens
                let mask = targets.clone().not_equal_elem(pad_token as i64);

                // Apply mask to log probabilities (set padding log probs to 0)
                let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);

                // Sum the log probabilities and count effective tokens
                let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
                let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;

                (sum_log_prob, effective_tokens)
            }
            None => {
                // No padding, use all tokens
                let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
                (sum_log_prob, total_tokens)
            }
        };

        // Pass the sum_log_prob and effective_tokens to the state
        // The state will handle the correct accumulation and perplexity calculation
        self.state.update(
            sum_log_prob,
            effective_tokens,
            FormatOptions::new(self.name()).precision(2),
        )
    }

    fn clear(&mut self) {
        self.state.reset()
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    fn attributes(&self) -> MetricAttributes {
        NumericAttributes {
            unit: None,
            higher_is_better: false,
        }
        .into()
    }
}

impl<B: Backend> Numeric for PerplexityMetric<B> {
    fn value(&self) -> NumericEntry {
        self.state.value()
    }

    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_perplexity_perfect_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Perfect prediction: target is always the highest probability class
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0], // Very confident prediction for class 0
                    [0.0, 10.0, 0.0], // Very confident prediction for class 1
                    [0.0, 0.0, 10.0], // Very confident prediction for class 2
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Perfect predictions should result in very low perplexity (close to 1.0)
        assert!(
            perplexity < 1.1,
            "Perfect predictions should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_uniform_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Uniform prediction: all classes have equal probability
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Uniform distribution over 3 classes should have perplexity ≈ 3.0
        assert!(
            (perplexity - 3.0).abs() < 0.1,
            "Uniform distribution perplexity should be ~3.0, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_with_padding() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);

        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0, 0.0], // Good prediction for class 0
                    [0.0, 10.0, 0.0, 0.0], // Good prediction for class 1
                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored
                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Should only consider the first two predictions, both of which are confident
        assert!(
            perplexity < 1.1,
            "Good predictions with padding should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_wrong_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Wrong predictions: target class has very low probability
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 10.0, 0.0], // Predicts class 1, but target is 0
                    [10.0, 0.0, 0.0], // Predicts class 0, but target is 1
                    [0.0, 0.0, 10.0], // Predicts class 2, but target is 0
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Wrong predictions should result in high perplexity
        assert!(
            perplexity > 10.0,
            "Wrong predictions should have high perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_multi_batch_aggregation() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each)
        let input1 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                ],
                &device,
            ),
            Tensor::from_data([0, 1], &device),
        );

        // Second batch: 1 token with uniform distribution
        let input2 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                ],
                &device,
            ),
            Tensor::from_data([2], &device),
        );

        // Update with both batches
        let _entry1 = metric.update(&input1, &MetricMetadata::fake());
        let _entry2 = metric.update(&input2, &MetricMetadata::fake());

        let aggregated_perplexity = metric.value().current();

        // For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986
        // Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958
        // Total tokens: 3
        // Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0
        assert!(
            (aggregated_perplexity - 3.0).abs() < 0.1,
            "Multi-batch aggregated perplexity should be ~3.0, got {}",
            aggregated_perplexity
        );

        // Compare with single batch containing all data
        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
        let single_input = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
        let single_batch_perplexity = single_batch_metric.value().current();

        // Multi-batch and single-batch should give the same result
        assert!(
            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
            "Multi-batch ({}) and single-batch ({}) perplexity should match",
            aggregated_perplexity,
            single_batch_perplexity
        );
    }
}