use core::marker::PhantomData;

use super::state::FormatOptions;
use super::{MetricEntry, MetricMetadata, NumericEntry, format_float};
use crate::metric::{Metric, MetricName, Numeric};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};

#[derive(Clone)]
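/// Internal state for the perplexity metric: accumulates the negative log-likelihood and
/// the token count so the running perplexity can be computed as `exp(sum_nll / total_tokens)`.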
struct PerplexityState {
    /// Accumulated negative log-likelihood over all processed tokens.
    sum_nll: f64,
    /// Number of tokens (excluding padding) processed so far.
    total_tokens: usize,
    /// Perplexity of the most recent batch.
    current: f64,
}

impl PerplexityState {
    fn new() -> Self {
        Self {
            sum_nll: 0.0,
            total_tokens: 0,
            current: f64::NAN,
        }
    }

    fn reset(&mut self) {
        self.sum_nll = 0.0;
        self.total_tokens = 0;
        self.current = f64::NAN;
    }

    fn update(
        &mut self,
        sum_log_prob: f64,
        effective_tokens: usize,
        format: FormatOptions,
    ) -> MetricEntry {
        // The batch negative log-likelihood is the negated sum of target log-probabilities.
        let batch_nll = -sum_log_prob;

        self.sum_nll += batch_nll;
        self.total_tokens += effective_tokens;

        // Perplexity of this batch: exp(mean negative log-likelihood per token).
        let batch_perplexity = if effective_tokens > 0 {
            (batch_nll / effective_tokens as f64).exp()
        } else {
            f64::INFINITY
        };
        self.current = batch_perplexity;

        // Running perplexity over every token seen so far this epoch.
        let epoch_perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        let (formatted_current, formatted_running) = match format.precision_value() {
            Some(precision) => (
                format_float(batch_perplexity, precision),
                format_float(epoch_perplexity, precision),
            ),
            None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
        };

        let formatted = match format.unit_value() {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };

        let serialized = NumericEntry::Aggregated {
            sum: self.sum_nll,
            count: self.total_tokens,
            current: epoch_perplexity,
        }
        .serialize();

        MetricEntry::new(format.name().clone(), formatted, serialized)
    }

    fn value(&self) -> NumericEntry {
        let perplexity = if self.total_tokens > 0 {
            (self.sum_nll / self.total_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        NumericEntry::Aggregated {
            sum: self.sum_nll,
            count: self.total_tokens,
            current: perplexity,
        }
    }
}

#[derive(Clone)]
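/// The perplexity metric.
///
/// Perplexity is computed as `exp(mean negative log-likelihood)` over the target tokens,
/// optionally excluding a padding token.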
pub struct PerplexityMetric<B: Backend> {
    name: MetricName,
    state: PerplexityState,
    pad_token: Option<usize>,
    _b: PhantomData<B>,
}

#[derive(new)]
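/// The [perplexity metric](PerplexityMetric) input type.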
pub struct PerplexityInput<B: Backend> {
    /// Model logits with shape `[num_tokens, vocab_size]`.
    outputs: Tensor<B, 2>,
    /// Target token indices with shape `[num_tokens]`.
    targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Default for PerplexityMetric<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> PerplexityMetric<B> {
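    /// Creates the metric.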
    pub fn new() -> Self {
        Self {
            name: MetricName::new("Perplexity".to_string()),
            state: PerplexityState::new(),
            pad_token: Default::default(),
            _b: PhantomData,
        }
    }

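    /// Sets the padding token index.
    ///
    /// Tokens equal to this index are excluded from the perplexity computation.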
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}

impl<B: Backend> Metric for PerplexityMetric<B> {
    type Input = PerplexityInput<B>;

    fn update(&mut self, input: &PerplexityInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
        let targets = input.targets.clone();
        let outputs = input.outputs.clone();

        let [total_tokens, _vocab_size] = outputs.dims();

        // Log-probabilities over the vocabulary for every token.
        let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);

        // Log-probability assigned to each target token.
        let target_log_probs = log_probs
            .gather(1, targets.clone().unsqueeze_dim(1))
            .squeeze_dim(1);

        let (sum_log_prob, effective_tokens) = match self.pad_token {
            Some(pad_token) => {
                // Mask selecting the non-padding tokens.
                let mask = targets.clone().not_equal_elem(pad_token as i64);

                // Zero out the contribution of padding tokens.
                let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);

                let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
                let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;

                (sum_log_prob, effective_tokens)
            }
            None => {
                let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
                (sum_log_prob, total_tokens)
            }
        };

        self.state.update(
            sum_log_prob,
            effective_tokens,
            FormatOptions::new(self.name()).precision(2),
        )
    }

    fn clear(&mut self) {
        self.state.reset()
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }
}

impl<B: Backend> Numeric for PerplexityMetric<B> {
    fn value(&self) -> super::NumericEntry {
        self.state.value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_perplexity_perfect_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0],
                    [0.0, 10.0, 0.0],
                    [0.0, 0.0, 10.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity < 1.1,
            "Perfect predictions should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_uniform_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Uniform logits over 3 classes: the perplexity equals the vocabulary size.
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            (perplexity - 3.0).abs() < 0.1,
            "Uniform distribution perplexity should be ~3.0, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_with_padding() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);

        // The last two targets are the padding token and must be excluded from the result.
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0, 0.0],
                    [0.0, 10.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 1.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 3, 3], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity < 1.1,
            "Good predictions with padding should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_wrong_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 10.0, 0.0],
                    [10.0, 0.0, 0.0],
                    [0.0, 0.0, 10.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity > 10.0,
            "Wrong predictions should have high perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_multi_batch_aggregation() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // First batch: two tokens with uniform logits.
        let input1 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1], &device),
        );

        // Second batch: a single token with uniform logits.
        let input2 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([2], &device),
        );

        let _entry1 = metric.update(&input1, &MetricMetadata::fake());
        let _entry2 = metric.update(&input2, &MetricMetadata::fake());

        let aggregated_perplexity = metric.value().current();

        // The aggregated value is exp(total NLL / total tokens), so uniform logits over
        // 3 classes should still yield a perplexity of ~3.0 across batches.
        assert!(
            (aggregated_perplexity - 3.0).abs() < 0.1,
            "Multi-batch aggregated perplexity should be ~3.0, got {}",
            aggregated_perplexity
        );

        // Processing the same tokens in a single batch must give the same result.
        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
        let single_input = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
        let single_batch_perplexity = single_batch_metric.value().current();

        assert!(
            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
            "Multi-batch ({}) and single-batch ({}) perplexity should match",
            aggregated_perplexity,
            single_batch_perplexity
        );
    }
}