1use core::marker::PhantomData;
2
3use super::state::FormatOptions;
4use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};
5use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
/// Running accumulator for perplexity over an epoch.
///
/// Perplexity is `exp(mean NLL)`, so it suffices to track the summed
/// negative log-likelihood and the token count.
#[derive(Clone)]
struct PerplexityState {
    // Accumulated negative log-likelihood over every batch since the last reset.
    sum_nll: f64,
    // Number of tokens contributing to `sum_nll` (pad tokens excluded by the caller).
    total_tokens: usize,
    // Perplexity of the most recent batch; `NaN` until the first update.
    current: f64,
}
24
25impl PerplexityState {
26 fn new() -> Self {
27 Self {
28 sum_nll: 0.0,
29 total_tokens: 0,
30 current: f64::NAN,
31 }
32 }
33
34 fn reset(&mut self) {
35 self.sum_nll = 0.0;
36 self.total_tokens = 0;
37 self.current = f64::NAN;
38 }
39
40 fn update(
42 &mut self,
43 sum_log_prob: f64,
44 effective_tokens: usize,
45 format: FormatOptions,
46 ) -> SerializedEntry {
47 let batch_nll = -sum_log_prob;
50
51 self.sum_nll += batch_nll;
53 self.total_tokens += effective_tokens;
54
55 let batch_perplexity = if effective_tokens > 0 {
57 (batch_nll / effective_tokens as f64).exp()
58 } else {
59 f64::INFINITY
60 };
61 self.current = batch_perplexity;
62
63 let epoch_perplexity = if self.total_tokens > 0 {
65 (self.sum_nll / self.total_tokens as f64).exp()
66 } else {
67 f64::INFINITY
68 };
69
70 let (formatted_current, formatted_running) = match format.precision_value() {
72 Some(precision) => (
73 format_float(batch_perplexity, precision),
74 format_float(epoch_perplexity, precision),
75 ),
76 None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
77 };
78
79 let formatted = match format.unit_value() {
80 Some(unit) => {
81 format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
82 }
83 None => format!("epoch {formatted_running} - batch {formatted_current}"),
84 };
85
86 let serialized = NumericEntry::Aggregated {
88 aggregated_value: epoch_perplexity,
89 count: self.total_tokens,
90 }
91 .serialize();
92
93 SerializedEntry::new(formatted, serialized)
94 }
95
96 fn value(&self) -> NumericEntry {
97 let perplexity = if self.total_tokens > 0 {
98 (self.sum_nll / self.total_tokens as f64).exp()
99 } else {
100 f64::INFINITY
101 };
102
103 NumericEntry::Aggregated {
104 aggregated_value: perplexity,
105 count: self.total_tokens,
106 }
107 }
108
109 fn running_value(&self) -> NumericEntry {
110 self.value()
111 }
112}
113
/// Perplexity metric for language modeling.
///
/// Tracks `exp(mean negative log-likelihood)` over the epoch; lower is better.
#[derive(Clone)]
pub struct PerplexityMetric<B: Backend> {
    // Display name of the metric ("Perplexity").
    name: MetricName,
    // Running NLL/token accumulator.
    state: PerplexityState,
    // Optional padding-token index; targets equal to it are excluded from the count.
    pad_token: Option<usize>,
    // Ties the metric to a backend type without storing any backend data.
    _b: PhantomData<B>,
}
139
/// Input for the perplexity metric.
#[derive(new)]
pub struct PerplexityInput<B: Backend> {
    // Model logits, shape `[num_tokens, vocab_size]`; log-softmax is applied
    // over dimension 1 by the metric.
    outputs: Tensor<B, 2>,
    // Target token indices, shape `[num_tokens]`.
    targets: Tensor<B, 1, Int>,
}
148
149impl<B: Backend> Default for PerplexityMetric<B> {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl<B: Backend> PerplexityMetric<B> {
156 pub fn new() -> Self {
158 Self {
159 name: MetricName::new("Perplexity".to_string()),
160 state: PerplexityState::new(),
161 pad_token: Default::default(),
162 _b: PhantomData,
163 }
164 }
165
166 pub fn with_pad_token(mut self, index: usize) -> Self {
172 self.pad_token = Some(index);
173 self
174 }
175}
176
177impl<B: Backend> Metric for PerplexityMetric<B> {
178 type Input = PerplexityInput<B>;
179
180 fn update(
181 &mut self,
182 input: &PerplexityInput<B>,
183 _metadata: &MetricMetadata,
184 ) -> SerializedEntry {
185 let targets = input.targets.clone();
186 let outputs = input.outputs.clone();
187
188 let [total_tokens, _vocab_size] = outputs.dims();
189
190 let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);
192
193 let target_log_probs = log_probs
195 .gather(1, targets.clone().unsqueeze_dim(1))
196 .squeeze_dim(1);
197
198 let (sum_log_prob, effective_tokens) = match self.pad_token {
199 Some(pad_token) => {
200 let mask = targets.clone().not_equal_elem(pad_token as i64);
202
203 let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);
205
206 let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();
208 let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;
209
210 (sum_log_prob, effective_tokens)
211 }
212 None => {
213 let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();
215 (sum_log_prob, total_tokens)
216 }
217 };
218
219 self.state.update(
222 sum_log_prob,
223 effective_tokens,
224 FormatOptions::new(self.name()).precision(2),
225 )
226 }
227
228 fn clear(&mut self) {
229 self.state.reset()
230 }
231
232 fn name(&self) -> MetricName {
233 self.name.clone()
234 }
235
236 fn attributes(&self) -> MetricAttributes {
237 NumericAttributes {
238 unit: None,
239 higher_is_better: false,
240 }
241 .into()
242 }
243}
244
impl<B: Backend> Numeric for PerplexityMetric<B> {
    // Epoch-aggregated perplexity, weighted by token count.
    fn value(&self) -> NumericEntry {
        self.state.value()
    }

    // Same epoch aggregate as `value`; no separate running window is kept.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_perplexity_perfect_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Logits strongly peaked on the correct class for every token.
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0],
                    [0.0, 10.0, 0.0],
                    [0.0, 0.0, 10.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity < 1.1,
            "Perfect predictions should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_uniform_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Equal logits give a uniform distribution over 3 classes, so the
        // expected perplexity is the vocabulary size.
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            (perplexity - 3.0).abs() < 0.1,
            "Uniform distribution perplexity should be ~3.0, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_with_padding() {
        let device = Default::default();
        // Token index 3 is padding and must be excluded from the average.
        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);

        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [10.0, 0.0, 0.0, 0.0],
                    [0.0, 10.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 1.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 3, 3], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity < 1.1,
            "Good predictions with padding should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_wrong_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // The first two rows put all their mass on the wrong class.
        let input = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 10.0, 0.0],
                    [10.0, 0.0, 0.0],
                    [0.0, 0.0, 10.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();

        assert!(
            perplexity > 10.0,
            "Wrong predictions should have high perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_multi_batch_aggregation() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Two uniform batches (2 tokens + 1 token) ...
        let input1 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([0, 1], &device),
        );

        let input2 = PerplexityInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.0, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([2], &device),
        );

        let _entry1 = metric.update(&input1, &MetricMetadata::fake());
        let _entry2 = metric.update(&input2, &MetricMetadata::fake());

        let aggregated_perplexity = metric.value().current();

        assert!(
            (aggregated_perplexity - 3.0).abs() < 0.1,
            "Multi-batch aggregated perplexity should be ~3.0, got {}",
            aggregated_perplexity
        );

        // ... must match a single 3-token batch with the same content,
        // proving token-weighted (not batch-averaged) aggregation.
        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
        let single_input = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1, 2], &device),
        );

        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());
        let single_batch_perplexity = single_batch_metric.value().current();

        assert!(
            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
            "Multi-batch ({}) and single-batch ({}) perplexity should match",
            aggregated_perplexity,
            single_batch_perplexity
        );
    }
}