use serde::Serialize;
/// Summary statistics over per-sequence perplexities, as filled in by
/// `PerplexityEvaluator::compute_batch`.
#[derive(Debug, Serialize)]
pub struct PerplexityResult {
/// Arithmetic mean of the per-sequence perplexities; infinity when no
/// sequence was scored.
pub mean_ppl: f32,
/// Smallest per-sequence perplexity; infinity when no sequence was scored.
pub min_ppl: f32,
/// Largest per-sequence perplexity; infinity when no sequence was scored.
pub max_ppl: f32,
/// Population standard deviation (divides by the sample count) of the
/// per-sequence perplexities; `0.0` when no sequence was scored.
pub std_ppl: f32,
/// Number of non-empty sequences that contributed a perplexity value.
pub n_samples: usize,
/// Sum of the lengths of all input sequences in the batch.
pub total_tokens: usize,
}
/// Computes perplexity-style metrics from per-token log-probabilities or
/// raw logits.
pub struct PerplexityEvaluator {
/// Stride setting (`new` sets it to 512). None of the methods in the
/// inherent `impl` in this file read it.
/// NOTE(review): presumably a sliding-window step used by callers — confirm.
pub stride: usize,
/// When `Some(n)`, `compute` and `bits_per_byte` only look at the first `n`
/// log-probabilities of their input; `None` means no limit.
pub max_length: Option<usize>,
}
impl Default for PerplexityEvaluator {
fn default() -> Self {
Self::new()
}
}
impl PerplexityEvaluator {
    /// Creates an evaluator with a stride of 512 and no length limit.
    pub fn new() -> Self {
        Self::with_stride(512)
    }

    /// Creates an evaluator with the given `stride` and no length limit.
    pub fn with_stride(stride: usize) -> Self {
        Self {
            stride,
            max_length: None,
        }
    }

    /// Returns the leading `max_length` entries of `log_probs`, or the whole
    /// slice when no limit is configured.
    fn limit_to_max_length<'a>(&self, log_probs: &'a [f32]) -> &'a [f32] {
        match self.max_length {
            Some(max) => &log_probs[..log_probs.len().min(max)],
            None => log_probs,
        }
    }

    /// Computes perplexity, `exp(-mean(log_probs))`, for one sequence.
    ///
    /// `log_probs` holds one natural-log probability per token. Only the first
    /// `max_length` entries are used when a limit is set.
    ///
    /// Returns `f32::INFINITY` when there is nothing to score (empty input, or
    /// a limit of zero).
    pub fn compute(&self, log_probs: &[f32]) -> f32 {
        let probs = self.limit_to_max_length(log_probs);
        if probs.is_empty() {
            return f32::INFINITY;
        }
        let mean_neg_log_prob = -probs.iter().copied().sum::<f32>() / probs.len() as f32;
        mean_neg_log_prob.exp()
    }

    /// Computes per-sequence perplexities for a batch and summarizes them.
    ///
    /// Empty sequences are skipped for the statistics. `total_tokens` is the
    /// sum of the lengths of every sequence passed in. The standard deviation
    /// is the population form (divides by the number of scored sequences).
    ///
    /// When no sequence is scored, mean/min/max are `f32::INFINITY`, the
    /// standard deviation is `0.0`, and `n_samples` is `0`.
    pub fn compute_batch(&self, log_probs_batch: &[Vec<f32>]) -> PerplexityResult {
        let ppls: Vec<f32> = log_probs_batch
            .iter()
            .filter(|lp| !lp.is_empty())
            .map(|lp| self.compute(lp))
            .collect();
        let total_tokens: usize = log_probs_batch.iter().map(Vec::len).sum();

        if ppls.is_empty() {
            return PerplexityResult {
                mean_ppl: f32::INFINITY,
                min_ppl: f32::INFINITY,
                max_ppl: f32::INFINITY,
                std_ppl: 0.0,
                n_samples: 0,
                total_tokens,
            };
        }

        let n = ppls.len() as f32;
        let mean_ppl = ppls.iter().copied().sum::<f32>() / n;
        let min_ppl = ppls.iter().copied().fold(f32::INFINITY, f32::min);
        let max_ppl = ppls.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let variance = ppls.iter().map(|p| (p - mean_ppl).powi(2)).sum::<f32>() / n;

        PerplexityResult {
            mean_ppl,
            min_ppl,
            max_ppl,
            std_ppl: variance.sqrt(),
            n_samples: ppls.len(),
            total_tokens,
        }
    }

    /// Computes perplexity from raw logits and the token observed at each
    /// position.
    ///
    /// `logits[i]` is the score row for position `i` and `token_ids[i]` indexes
    /// into it. Positions are paired up to the shorter of the two inputs, and
    /// at most `max_length` positions are scored when a limit is set.
    ///
    /// A token id that falls outside its logit row (including an empty row) is
    /// treated as having zero probability, so the result is `f32::INFINITY`
    /// rather than a panic. `f32::INFINITY` is also returned when there are no
    /// positions to score.
    pub fn from_logits(&self, logits: &[Vec<f32>], token_ids: &[u32]) -> f32 {
        let paired = logits.len().min(token_ids.len());
        // Clamp up front: `compute` would discard anything past `max_length`
        // anyway, so there is no point running a softmax over those rows.
        let len = self.max_length.map_or(paired, |max| paired.min(max));
        if len == 0 {
            return f32::INFINITY;
        }

        let log_probs: Vec<f32> = logits[..len]
            .iter()
            .zip(token_ids[..len].iter())
            .map(|(row, &token_id)| {
                // Checked lookup: an id outside the row has probability zero.
                let target_logit = match row.get(token_id as usize) {
                    Some(&logit) => logit,
                    None => return f32::NEG_INFINITY,
                };
                // Log-sum-exp with the row maximum subtracted so `exp` cannot
                // overflow on large logits.
                let max_logit = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exp_sum: f32 = row.iter().map(|&l| (l - max_logit).exp()).sum();
                target_logit - (max_logit + exp_sum.ln())
            })
            .collect();

        self.compute(&log_probs)
    }

    /// Computes bits per byte: the total negative log-probability converted
    /// from nats to bits, divided by `n_bytes`.
    ///
    /// Only the first `max_length` entries of `log_probs` are used when a limit
    /// is set. Returns `f32::INFINITY` when there is nothing to score or
    /// `n_bytes` is zero.
    pub fn bits_per_byte(&self, log_probs: &[f32], n_bytes: usize) -> f32 {
        let probs = self.limit_to_max_length(log_probs);
        if probs.is_empty() || n_bytes == 0 {
            return f32::INFINITY;
        }
        let total_log_prob: f32 = probs.iter().copied().sum();
        // log2(p) = ln(p) * log2(e); apply the conversion once to the sum.
        -total_log_prob * std::f32::consts::LOG2_E / n_bytes as f32
    }
}