use core::marker::PhantomData;
use super::state::FormatOptions;
use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
// Running accumulators for perplexity: exp of the mean negative log-likelihood
// per target token, aggregated across batches.
#[derive(Clone)]
struct PerplexityState {
// Accumulated negative log-likelihood over all batches since the last reset.
sum_nll: f64,
// Number of (non-padding) target tokens accumulated alongside `sum_nll`.
total_tokens: usize,
// Perplexity of the most recent batch; NaN before the first update.
// NOTE(review): written in `update` but never read in this file — possibly
// kept for debugging or external inspection; confirm before removing.
current: f64,
}
impl PerplexityState {
    /// Creates an empty state: no tokens seen, current perplexity undefined (NaN).
    fn new() -> Self {
        Self {
            sum_nll: 0.0,
            total_tokens: 0,
            current: f64::NAN,
        }
    }

    /// Converts an accumulated negative log-likelihood and token count into a
    /// perplexity value.
    ///
    /// Returns `f64::INFINITY` when `tokens` is zero, since the mean NLL is
    /// undefined over an empty token set. (Previously this expression was
    /// duplicated in three places; it is now the single source of truth.)
    fn perplexity(nll: f64, tokens: usize) -> f64 {
        if tokens > 0 {
            (nll / tokens as f64).exp()
        } else {
            f64::INFINITY
        }
    }

    /// Resets the accumulators back to the freshly-constructed state.
    fn reset(&mut self) {
        self.sum_nll = 0.0;
        self.total_tokens = 0;
        self.current = f64::NAN;
    }

    /// Folds one batch into the running state and renders the entry.
    ///
    /// * `sum_log_prob` - sum of the log-probabilities assigned to this batch's
    ///   target tokens (non-positive for proper log-probabilities).
    /// * `effective_tokens` - number of tokens contributing to the sum
    ///   (padding tokens already excluded by the caller).
    /// * `format` - precision/unit options used to render the entry text.
    fn update(
        &mut self,
        sum_log_prob: f64,
        effective_tokens: usize,
        format: FormatOptions,
    ) -> SerializedEntry {
        // Negative log-likelihood of this batch.
        let batch_nll = -sum_log_prob;
        self.sum_nll += batch_nll;
        self.total_tokens += effective_tokens;

        let batch_perplexity = Self::perplexity(batch_nll, effective_tokens);
        self.current = batch_perplexity;
        let epoch_perplexity = Self::perplexity(self.sum_nll, self.total_tokens);

        let (formatted_current, formatted_running) = match format.precision_value() {
            Some(precision) => (
                format_float(batch_perplexity, precision),
                format_float(epoch_perplexity, precision),
            ),
            None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")),
        };
        let formatted = match format.unit_value() {
            Some(unit) => {
                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
            }
            None => format!("epoch {formatted_running} - batch {formatted_current}"),
        };
        // Serialize the epoch-level aggregate with its token count so that
        // downstream consumers can re-weight it correctly.
        let serialized = NumericEntry::Aggregated {
            aggregated_value: epoch_perplexity,
            count: self.total_tokens,
        }
        .serialize();
        SerializedEntry::new(formatted, serialized)
    }

    /// Epoch-level perplexity over all tokens seen since the last reset.
    fn value(&self) -> NumericEntry {
        NumericEntry::Aggregated {
            aggregated_value: Self::perplexity(self.sum_nll, self.total_tokens),
            count: self.total_tokens,
        }
    }

    /// The running value is the same epoch-level aggregate as [`Self::value`].
    fn running_value(&self) -> NumericEntry {
        self.value()
    }
}
// Perplexity metric for language modeling: exp of the mean negative
// log-likelihood per target token, with optional padding-token exclusion.
#[derive(Clone)]
pub struct PerplexityMetric<B: Backend> {
// Display name of the metric ("Perplexity").
name: MetricName,
// Running NLL / token-count accumulators.
state: PerplexityState,
// Optional padding-token index; target positions equal to it are excluded.
pad_token: Option<usize>,
// Ties the metric to a backend without storing backend data.
_b: PhantomData<B>,
}
// Input for the perplexity metric: raw model outputs plus target indices.
#[derive(new)]
pub struct PerplexityInput<B: Backend> {
// Raw scores of shape [num_tokens, vocab_size]; log-softmax is applied
// internally, so these should be unnormalized logits.
outputs: Tensor<B, 2>,
// Target token indices, one per row of `outputs`.
targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Default for PerplexityMetric<B> {
// Equivalent to `PerplexityMetric::new()`: no padding token configured.
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> PerplexityMetric<B> {
    /// Creates the metric with empty accumulators and no padding token.
    pub fn new() -> Self {
        Self {
            name: MetricName::new("Perplexity".to_string()),
            state: PerplexityState::new(),
            pad_token: None,
            _b: PhantomData,
        }
    }

    /// Builder-style setter: target positions equal to `index` will be
    /// excluded from the perplexity computation.
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}
impl<B: Backend> Metric for PerplexityMetric<B> {
    type Input = PerplexityInput<B>;

    /// Computes this batch's summed target log-probability and effective token
    /// count, then folds them into the running state.
    fn update(
        &mut self,
        input: &PerplexityInput<B>,
        _metadata: &MetricMetadata,
    ) -> SerializedEntry {
        let outputs = input.outputs.clone();
        let targets = input.targets.clone();
        let [total_tokens, _vocab_size] = outputs.dims();

        // Log-probability the model assigns to each target token.
        let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);
        let target_log_probs = log_probs
            .gather(1, targets.clone().unsqueeze_dim(1))
            .squeeze_dim(1);

        let (sum_log_prob, effective_tokens) = if let Some(pad_token) = self.pad_token {
            // Zero out the contribution of padding positions and count only
            // the remaining (real) tokens.
            let keep = targets.not_equal_elem(pad_token as i64);
            let summed = target_log_probs
                .mask_fill(keep.clone().bool_not(), 0.0)
                .sum()
                .into_scalar()
                .elem::<f64>();
            let kept = keep.int().sum().into_scalar().elem::<i64>() as usize;
            (summed, kept)
        } else {
            // No padding configured: every token counts.
            (
                target_log_probs.sum().into_scalar().elem::<f64>(),
                total_tokens,
            )
        };

        self.state.update(
            sum_log_prob,
            effective_tokens,
            FormatOptions::new(self.name()).precision(2),
        )
    }

    /// Clears the running accumulators (e.g. between epochs).
    fn clear(&mut self) {
        self.state.reset()
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Lower perplexity is better; the value is unitless.
    fn attributes(&self) -> MetricAttributes {
        NumericAttributes {
            unit: None,
            higher_is_better: false,
        }
        .into()
    }
}
impl<B: Backend> Numeric for PerplexityMetric<B> {
// Epoch-level perplexity aggregated over all tokens seen since the last clear.
fn value(&self) -> NumericEntry {
self.state.value()
}
// Same aggregate as `value`; the state keeps a single running value.
fn running_value(&self) -> NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_perplexity_perfect_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Each row places almost all probability mass on the correct class.
        let outputs = Tensor::from_data(
            [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
            &device,
        );
        let targets = Tensor::from_data([0, 1, 2], &device);
        let input = PerplexityInput::new(outputs, targets);

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();
        assert!(
            perplexity < 1.1,
            "Perfect predictions should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_uniform_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Equal logits => uniform distribution over 3 classes => perplexity ~3.
        let outputs = Tensor::from_data(
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            &device,
        );
        let targets = Tensor::from_data([0, 1, 2], &device);
        let input = PerplexityInput::new(outputs, targets);

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();
        assert!(
            (perplexity - 3.0).abs() < 0.1,
            "Uniform distribution perplexity should be ~3.0, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_with_padding() {
        let device = Default::default();
        // Token index 3 is padding: the last two positions must be ignored.
        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);

        let outputs = Tensor::from_data(
            [
                [10.0, 0.0, 0.0, 0.0],
                [0.0, 10.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            &device,
        );
        let targets = Tensor::from_data([0, 1, 3, 3], &device);
        let input = PerplexityInput::new(outputs, targets);

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();
        assert!(
            perplexity < 1.1,
            "Good predictions with padding should have low perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_wrong_prediction() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // The first two rows confidently predict the wrong class.
        let outputs = Tensor::from_data(
            [[0.0, 10.0, 0.0], [10.0, 0.0, 0.0], [0.0, 0.0, 10.0]],
            &device,
        );
        let targets = Tensor::from_data([0, 1, 0], &device);
        let input = PerplexityInput::new(outputs, targets);

        let _entry = metric.update(&input, &MetricMetadata::fake());
        let perplexity = metric.value().current();
        assert!(
            perplexity > 10.0,
            "Wrong predictions should have high perplexity, got {}",
            perplexity
        );
    }

    #[test]
    fn test_perplexity_multi_batch_aggregation() {
        let device = Default::default();
        let mut metric = PerplexityMetric::<TestBackend>::new();

        // Two uniform batches (2 tokens + 1 token) ...
        let input1 = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1], &device),
        );
        let input2 = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0]], &device),
            Tensor::from_data([2], &device),
        );
        let _entry1 = metric.update(&input1, &MetricMetadata::fake());
        let _entry2 = metric.update(&input2, &MetricMetadata::fake());

        let aggregated_perplexity = metric.value().current();
        assert!(
            (aggregated_perplexity - 3.0).abs() < 0.1,
            "Multi-batch aggregated perplexity should be ~3.0, got {}",
            aggregated_perplexity
        );

        // ... must aggregate to the same value as one 3-token batch.
        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();
        let single_input = PerplexityInput::new(
            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),
            Tensor::from_data([0, 1, 2], &device),
        );
        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());

        let single_batch_perplexity = single_batch_metric.value().current();
        assert!(
            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,
            "Multi-batch ({}) and single-batch ({}) perplexity should match",
            aggregated_perplexity,
            single_batch_perplexity
        );
    }
}