burn_train/metric/perplexity.rs

use core::marker::PhantomData;

use super::state::FormatOptions;
use super::{MetricEntry, MetricMetadata, NumericEntry, format_float};
use crate::metric::{Metric, MetricName, Numeric};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};

/// Custom state for the perplexity metric that correctly accumulates negative log-likelihood.
///
/// Unlike other metrics that can be averaged, perplexity requires special handling:
/// - Accumulate the total negative log-likelihood across all tokens
/// - Accumulate the total number of effective tokens
/// - Compute perplexity as exp(total_nll / total_tokens) only at the end
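///
/// As a small worked illustration (numbers chosen for exposition, not taken from any
/// dataset): a batch with total NLL 2.0 over 2 tokens has batch perplexity exp(1.0) ≈ 2.72,
/// and a batch with NLL 3.0 over 1 token has exp(3.0) ≈ 20.09; the correct epoch value is
/// exp((2.0 + 3.0) / 3) ≈ 5.29, not the mean of the per-batch values (≈ 11.4).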
#[derive(Clone)]
struct PerplexityState {
    /// Sum of negative log-likelihood across all tokens
    sum_nll: f64,
    /// Total number of effective tokens (excluding padding)
    total_tokens: usize,
    /// Current batch perplexity (for display purposes)
    current: f64,
}

impl PerplexityState {
    fn new() -> Self {
        Self {
            sum_nll: 0.0,
            total_tokens: 0,
            current: f64::NAN,
        }
    }

    fn reset(&mut self) {
        self.sum_nll = 0.0;
        self.total_tokens = 0;
        self.current = f64::NAN;
    }

    /// Update state with negative log-likelihood and token count from the current batch
    fn update(
        &mut self,
        sum_log_prob: f64,
        effective_tokens: usize,
        format: FormatOptions,
    ) -> MetricEntry {
        // sum_log_prob is already the sum of log probabilities (negative values)
        // We need to negate it to get negative log-likelihood
        let batch_nll = -sum_log_prob;

        // Accumulate across batches
        self.sum_nll += batch_nll;
        self.total_tokens += effective_tokens;

        // Compute current batch perplexity for display
        let batch_perplexity = if effective_tokens > 0 {
            (batch_nll / effective_tokens as f64).exp()
        } else {
            f64::INFINITY
        };
        self.current = batch_perplexity;

        // Compute running epoch perplexity
        let epoch_perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        // Format for display
        let (formatted_current, formatted_running) = match format.precision_value() {
            Some(precision) => (
                format_float(batch_perplexity, precision),
                format_float(epoch_perplexity, precision),
            ),
            None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
        };

        let formatted = match format.unit_value() {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };

        // Serialize the state for aggregation
        let serialized = NumericEntry::Aggregated {
            sum: self.sum_nll,
            count: self.total_tokens,
            current: epoch_perplexity,
        }
        .serialize();

        MetricEntry::new(format.name().clone(), formatted, serialized)
    }

    fn value(&self) -> NumericEntry {
        let perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        NumericEntry::Aggregated {
            sum: self.sum_nll,
            count: self.total_tokens,
            current: perplexity,
        }
    }
}

/// The perplexity metric.
///
/// Perplexity is a measure of how well a probability distribution or probability model
/// predicts a sample. It is commonly used to evaluate language models. A lower perplexity
/// indicates that the model is more confident in its predictions.
///
/// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss:
/// PPL = exp(H(p, q)) = exp(-1/N * Σ log(p(x_i)))
///
/// where:
/// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q
/// - N is the number of tokens
/// - p(x_i) is the predicted probability of the i-th token
///
/// # Aggregation
/// Unlike other metrics, perplexity cannot be simply averaged across batches.
/// This implementation correctly accumulates the total negative log-likelihood and
/// total token count across batches, then computes perplexity as exp(total_nll / total_tokens).
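///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest); the logits, targets, backend type
/// `B`, and `MetricMetadata` value are placeholders assumed to come from the training loop:
///
/// ```rust,ignore
/// // Logits of shape [batch_size * seq_len, vocab_size], targets of shape [batch_size * seq_len].
/// let mut metric = PerplexityMetric::<B>::new().with_pad_token(0); // pad token id 0 is illustrative
/// let input = PerplexityInput::new(logits, targets);
/// let _entry = metric.update(&input, &metadata);
/// let perplexity = metric.value().current(); // exp(total_nll / total_tokens) so far
/// ```
///
/// As a sanity check, a model that assigns a uniform probability 1/V to every target token
/// has perplexity exactly V, which the tests below exercise with V = 3.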
#[derive(Clone)]
pub struct PerplexityMetric<B: Backend> {
    name: MetricName,
    state: PerplexityState,
    pad_token: Option<usize>,
    _b: PhantomData<B>,
}

/// The [perplexity metric](PerplexityMetric) input type.
#[derive(new)]
pub struct PerplexityInput<B: Backend> {
    /// Logits tensor of shape [batch_size * sequence_length, vocab_size]
    outputs: Tensor<B, 2>,
    /// Target tokens tensor of shape [batch_size * sequence_length]
    targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Default for PerplexityMetric<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> PerplexityMetric<B> {
    /// Creates the metric.
    pub fn new() -> Self {
        Self {
            name: MetricName::new("Perplexity".to_string()),
            state: PerplexityState::new(),
            pad_token: Default::default(),
            _b: PhantomData,
        }
    }

    /// Sets the pad token to exclude from perplexity calculation.
    ///
    /// When a pad token is set, predictions for padding tokens are masked out
    /// and do not contribute to the perplexity calculation. This is important
    /// for variable-length sequences where padding is used.
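    ///
    /// A minimal sketch (the pad token id here is illustrative, not a fixed convention):
    ///
    /// ```rust,ignore
    /// let metric = PerplexityMetric::<B>::new().with_pad_token(0);
    /// ```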
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}

impl<B: Backend> Metric for PerplexityMetric<B> {
    type Input = PerplexityInput<B>;

    fn update(&mut self, input: &PerplexityInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
        let targets = input.targets.clone();
        let outputs = input.outputs.clone();

        let [total_tokens, _vocab_size] = outputs.dims();

        // Convert logits to log probabilities using log_softmax for numerical stability
        let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);

        // Gather the log probabilities for the target tokens
        let target_log_probs = log_probs
            .gather(1, targets.clone().unsqueeze_dim(1))
            .squeeze_dim(1);

        let (sum_log_prob, effective_tokens) = match self.pad_token {
            Some(pad_token) => {
                // Create a mask for non-padding tokens
                let mask = targets.clone().not_equal_elem(pad_token as i64);

                // Apply mask to log probabilities (set padding log probs to 0)
                let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);

                // Sum the log probabilities and count effective tokens
                let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
                let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;

                (sum_log_prob, effective_tokens)
            }
            None => {
                // No padding, use all tokens
                let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
                (sum_log_prob, total_tokens)
            }
        };

        // Pass the sum_log_prob and effective_tokens to the state
        // The state will handle the correct accumulation and perplexity calculation
        self.state.update(
            sum_log_prob,
            effective_tokens,
            FormatOptions::new(self.name()).precision(2),
        )
    }

    fn clear(&mut self) {
        self.state.reset()
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }
}

impl<B: Backend> Numeric for PerplexityMetric<B> {
    fn value(&self) -> super::NumericEntry {
        self.state.value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_perplexity_perfect_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Perfect prediction: target is always the highest probability class
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0], // Very confident prediction for class 0
                    [0.0, 10.0, 0.0], // Very confident prediction for class 1
                    [0.0, 0.0, 10.0], // Very confident prediction for class 2
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Perfect predictions should result in very low perplexity (close to 1.0)
        assert!(
            perplexity < 1.1,
            "Perfect predictions should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_uniform_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Uniform prediction: all classes have equal probability
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Uniform distribution over 3 classes should have perplexity ≈ 3.0
        assert!(
            (perplexity - 3.0).abs() < 0.1,
            "Uniform distribution perplexity should be ~3.0, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_with_padding() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);

        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0, 0.0], // Good prediction for class 0
                    [0.0, 10.0, 0.0, 0.0], // Good prediction for class 1
                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored
                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Should only consider the first two predictions, both of which are confident
        assert!(
            perplexity < 1.1,
            "Good predictions with padding should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_wrong_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Wrong predictions: target class has very low probability
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 10.0, 0.0], // Predicts class 1, but target is 0
                    [10.0, 0.0, 0.0], // Predicts class 0, but target is 1
                    [0.0, 0.0, 10.0], // Predicts class 2, but target is 0
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        // Wrong predictions should result in high perplexity
        assert!(
            perplexity > 10.0,
            "Wrong predictions should have high perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_multi_batch_aggregation() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each)
        let input1 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                ],
                &device,
            ),
            Tensor::from_data([0, 1], &device),
        );

        // Second batch: 1 token with uniform distribution
        let input2 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)
                ],
                &device,
            ),
            Tensor::from_data([2], &device),
        );

        // Update with both batches
        let _entry1 = metric.update(&input1, &MetricMetadata::fake());
        let _entry2 = metric.update(&input2, &MetricMetadata::fake());

        let aggregated_perplexity = metric.value().current();

        // For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986
        // Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958
        // Total tokens: 3
        // Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0
        assert!(
            (aggregated_perplexity - 3.0).abs() < 0.1,
            "Multi-batch aggregated perplexity should be ~3.0, got {}",
            aggregated_perplexity
        );

        // Compare with single batch containing all data
        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
        let single_input = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
        let single_batch_perplexity = single_batch_metric.value().current();

        // Multi-batch and single-batch should give the same result
        assert!(
            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
            "Multi-batch ({}) and single-batch ({}) perplexity should match",
            aggregated_perplexity,
            single_batch_perplexity
        );
    }
}