use candle_core::{Result, Tensor, D};
#[derive(Debug, Clone)]
/// Result of a perplexity evaluation over a batch of tokens.
pub struct PerplexityResult {
    /// exp(avg_loss); lower is better, 1.0 is a perfect model.
    pub perplexity: f64,
    /// Mean negative log-likelihood per scored token (nats).
    pub avg_loss: f64,
    /// Number of tokens that contributed to the loss (ignored/invalid targets excluded).
    pub num_tokens: usize,
    /// Sum of per-token negative log-likelihoods (nats).
    pub total_loss: f64,
}
impl std::fmt::Display for PerplexityResult {
    /// Renders a compact one-line summary, e.g. `PPL: 12.34 | Loss: 2.5123 | Tokens: 1024`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "PPL: {ppl:.2} | Loss: {loss:.4} | Tokens: {n}",
            ppl = self.perplexity,
            loss = self.avg_loss,
            n = self.num_tokens,
        )
    }
}
/// Computes perplexity from model logits and target token ids.
///
/// `logits` must have shape `(batch, seq_len, vocab)` and `targets` must hold
/// `batch * seq_len` token ids. Targets equal to `ignore_index` (default
/// `-100`), negative, or outside `[0, vocab)` are skipped and do not count
/// toward `num_tokens`.
///
/// # Errors
/// Propagates any shape/dtype/device error from the underlying tensor ops.
pub fn compute_perplexity(
    logits: &Tensor,
    targets: &Tensor,
    ignore_index: Option<i64>,
) -> Result<PerplexityResult> {
    let ignore_idx = ignore_index.unwrap_or(-100);
    let (batch_size, seq_len, vocab_size) = logits.dims3()?;
    let logits_flat = logits.reshape((batch_size * seq_len, vocab_size))?;
    let targets_flat = targets.reshape(batch_size * seq_len)?;
    let log_probs = candle_nn::ops::log_softmax(&logits_flat, D::Minus1)?;
    // Copy the whole log-prob matrix to the host in ONE transfer. The previous
    // version did `get(i)?.get(target)?.to_scalar()?` inside the loop, which
    // forces a device round-trip per token and is very slow on GPU backends.
    let log_probs_host = log_probs.to_vec2::<f32>()?;
    let targets_i64 = targets_flat.to_vec1::<i64>()?;
    let mut total_loss = 0.0f64;
    let mut num_tokens = 0usize;
    for (row, &target) in log_probs_host.iter().zip(targets_i64.iter()) {
        // Skip ignored positions and ids that cannot index the vocabulary.
        if target == ignore_idx || target < 0 || target >= vocab_size as i64 {
            continue;
        }
        total_loss += -f64::from(row[target as usize]);
        num_tokens += 1;
    }
    if num_tokens == 0 {
        // Nothing was scored; perplexity is undefined, report +inf.
        return Ok(PerplexityResult {
            perplexity: f64::INFINITY,
            avg_loss: f64::INFINITY,
            num_tokens: 0,
            total_loss: 0.0,
        });
    }
    let avg_loss = total_loss / num_tokens as f64;
    let perplexity = avg_loss.exp();
    Ok(PerplexityResult {
        perplexity,
        avg_loss,
        num_tokens,
        total_loss,
    })
}
/// Computes perplexity over a long token sequence with a sliding window.
///
/// The sequence is scored in windows of `chunk_size` tokens advanced by a
/// stride of `chunk_size / 2`. `forward_fn` maps an input slice of
/// `chunk_size - 1` tokens to logits of shape `(1, chunk_size - 1, vocab)`.
///
/// On every window after the first, the targets that were already scored by
/// the previous window are masked with `-100` so each token contributes to
/// the loss at most once (the previous version double-counted the overlap,
/// biasing the result). Trailing tokens that do not fill a complete window
/// are not scored.
///
/// # Errors
/// Propagates any error from `forward_fn` or the tensor ops.
pub fn compute_perplexity_chunked<F>(
    token_ids: &[u32],
    chunk_size: usize,
    forward_fn: F,
) -> Result<PerplexityResult>
where
    F: Fn(&[u32]) -> Result<Tensor>,
{
    // A window needs at least one (input, target) pair; chunk_size < 2 would
    // also make the stride 0 and loop forever. Report the empty result.
    if chunk_size < 2 {
        return Ok(PerplexityResult {
            perplexity: f64::INFINITY,
            avg_loss: f64::INFINITY,
            num_tokens: 0,
            total_loss: 0.0,
        });
    }
    let mut total_loss = 0.0f64;
    let mut total_tokens = 0usize;
    let stride = chunk_size / 2; // >= 1 because chunk_size >= 2
    let mut pos = 0;
    while pos + chunk_size <= token_ids.len() {
        let chunk = &token_ids[pos..pos + chunk_size];
        let input = &chunk[..chunk_size - 1];
        let mut targets: Vec<i64> = chunk[1..].iter().map(|&x| x as i64).collect();
        if pos > 0 {
            // The first `chunk_size - stride - 1` targets were already scored
            // by the previous window; mask them so they are skipped.
            let overlap = chunk_size - stride - 1;
            for t in &mut targets[..overlap] {
                *t = -100;
            }
        }
        let logits = forward_fn(input)?;
        let n_targets = targets.len();
        let targets_tensor = Tensor::from_vec(targets, (1, n_targets), logits.device())?;
        let result = compute_perplexity(&logits, &targets_tensor, Some(-100))?;
        total_loss += result.total_loss;
        total_tokens += result.num_tokens;
        pos += stride;
    }
    if total_tokens == 0 {
        return Ok(PerplexityResult {
            perplexity: f64::INFINITY,
            avg_loss: f64::INFINITY,
            num_tokens: 0,
            total_loss: 0.0,
        });
    }
    let avg_loss = total_loss / total_tokens as f64;
    let perplexity = avg_loss.exp();
    Ok(PerplexityResult {
        perplexity,
        avg_loss,
        num_tokens: total_tokens,
        total_loss,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// A model that puts (almost) all probability mass on the correct token
    /// should score a perplexity close to 1.
    #[test]
    fn test_perplexity_calculation() -> Result<()> {
        let dev = Device::Cpu;
        // +10 on the target class, -10 everywhere else, for targets 1, 2, 3.
        let raw: Vec<f32> = vec![
            -10.0, 10.0, -10.0, -10.0, -10.0,
            -10.0, -10.0, 10.0, -10.0, -10.0,
            -10.0, -10.0, -10.0, 10.0, -10.0,
        ];
        let logits = Tensor::from_vec(raw, (1, 3, 5), &dev)?;
        let targets = Tensor::from_vec(vec![1i64, 2, 3], (1, 3), &dev)?;
        let out = compute_perplexity(&logits, &targets, None)?;
        assert!(out.perplexity < 2.0, "PPL should be low: {}", out.perplexity);
        assert_eq!(out.num_tokens, 3);
        Ok(())
    }

    /// Targets equal to the ignore index must be excluded from the count.
    #[test]
    fn test_perplexity_with_ignore_index() -> Result<()> {
        let dev = Device::Cpu;
        let raw = vec![
            -10.0f32, 10.0, -10.0, -10.0, -10.0,
            -10.0, -10.0, 10.0, -10.0, -10.0,
            -10.0, -10.0, -10.0, 10.0, -10.0,
        ];
        let logits = Tensor::from_vec(raw, (1, 3, 5), &dev)?;
        // Middle position is masked with -100 and must not be scored.
        let targets = Tensor::from_vec(vec![1i64, -100, 3], (1, 3), &dev)?;
        let out = compute_perplexity(&logits, &targets, Some(-100))?;
        assert_eq!(out.num_tokens, 2);
        Ok(())
    }

    /// Uniform logits over V classes give a perplexity of exactly V.
    #[test]
    fn test_perplexity_uniform_distribution() -> Result<()> {
        let dev = Device::Cpu;
        let vocab_size = 100;
        let flat = Tensor::from_vec(vec![0.0f32; vocab_size], (1, 1, vocab_size), &dev)?;
        let targets = Tensor::from_vec(vec![50i64], (1, 1), &dev)?;
        let out = compute_perplexity(&flat, &targets, None)?;
        assert!(
            (out.perplexity - vocab_size as f64).abs() < 1.0,
            "PPL should be ~{}: {}",
            vocab_size,
            out.perplexity
        );
        Ok(())
    }
}