use smol_str::format_smolstr;
use crate::{
array::Array,
dtype::Dtype,
error::{
Error, LengthMismatchPayload, OutOfRangePayload, RankMismatchPayload, Result,
ShapePairMismatchPayload, try_with_capacity,
},
lm::{
cache::{CacheConfig, make_prompt_cache},
model::Model,
},
ops,
};
/// Smallest window length accepted by this module: a window must hold at least
/// one input token and one target token.
pub const MIN_WINDOW: usize = 2;
/// Batch size constant exported for callers.
///
/// NOTE(review): presumably meant as the default `batch_size` argument for
/// `perplexity`; it is not referenced by the code visible in this file — confirm.
pub const DEFAULT_BATCH_SIZE: usize = 8;
/// Summary statistics produced by `perplexity`.
pub struct PerplexityResult {
/// `exp(mean_loss)`.
pub perplexity: f32,
/// Standard error of the perplexity, computed as
/// `perplexity * std_dev(losses) / sqrt(num_tokens)`.
pub std_error: f32,
/// Mean of all per-token losses.
pub mean_loss: f32,
/// Number of scored tokens, i.e. the element count of `losses`.
pub num_tokens: usize,
/// Per-token losses, flattened per batch and concatenated across batches.
pub losses: Array,
}
pub fn make_windows(tokens: &[i32], sequence_length: usize) -> Result<Array> {
if sequence_length < MIN_WINDOW {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"perplexity::make_windows: sequence_length",
"must be >= MIN_WINDOW (one input + one target)",
format_smolstr!("{sequence_length} (MIN_WINDOW={MIN_WINDOW})"),
)));
}
let num_rows = tokens.len() / sequence_length;
if num_rows == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"perplexity::make_windows: tokens.len()",
"must be >= sequence_length to fill one window",
format_smolstr!("{} (sequence_length={sequence_length})", tokens.len()),
)));
}
let kept = num_rows * sequence_length;
Array::from_slice::<i32>(&tokens[..kept], &(num_rows, sequence_length))
}
/// Unreduced cross-entropy between `logits` and class-index `targets`.
///
/// The last axis of `logits` is treated as the class axis. For every position
/// the result is `logsumexp(logits) - logits[target]`, both taken along that
/// axis; no mean or sum is applied.
///
/// # Errors
/// - `Error::RankMismatch` if `logits` is rank 0 (no class axis).
/// - `Error::LengthMismatch` if `targets.ndim() != logits.ndim() - 1`.
/// - `Error::ShapePairMismatch` if the shape of `targets` differs from the
///   shape of `logits` with the class axis removed.
/// - Any error propagated from the underlying array operations.
pub fn cross_entropy_none(logits: &Array, targets: &Array) -> Result<Array> {
    // `checked_sub` fails exactly when logits is rank 0.
    let Some(class_axis) = logits.ndim().checked_sub(1) else {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "perplexity::cross_entropy_none: logits must have a vocab axis (ndim >= 1)",
            0,
            Vec::new(),
        )));
    };

    let target_rank = targets.ndim();
    if target_rank != class_axis {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "perplexity::cross_entropy_none: targets ndim (must equal logits ndim - 1 for class-index targets)",
            class_axis,
            target_rank,
        )));
    }

    // Leading dimensions of logits (everything except the class axis) must
    // match the targets shape element for element.
    let full_shape = logits.shape();
    let leading_dims = &full_shape[..class_axis];
    let target_dims = targets.shape();
    if target_dims.as_slice() != leading_dims {
        return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
            "perplexity::cross_entropy_none: targets must equal logits with the class axis removed",
            leading_dims.to_vec(),
            target_dims.to_vec(),
        )));
    }

    let axis = class_axis as i32;

    // Gather the logit of the target class: give targets a trailing axis so it
    // can index along the class axis, then drop that axis again.
    let gather_index = targets.expand_dims_axes(&[axis])?;
    let target_logit =
        ops::indexing::take_along_axis(logits, &gather_index, axis)?.squeeze_axes(&[axis])?;

    let log_normalizer = ops::reduction::logsumexp_axes(logits, &[axis], false)?;
    ops::arithmetic::subtract(&log_normalizer, &target_logit)
}
/// Evaluates the perplexity of `model` over a rank-2 `[N, L]` matrix of token windows.
///
/// Rows are processed in batches of `batch_size` (a value of `0` is treated as `1`).
/// For each batch, columns `0..L-1` are fed to the model and columns `1..L` are used
/// as next-token targets. A fresh prompt cache is built from `cache_config` for every
/// batch. Logits are cast to `f32`, scored with [`cross_entropy_none`], flattened, and
/// evaluated before the next batch starts.
///
/// The returned [`PerplexityResult`] holds `exp(mean_loss)`, its standard error
/// (`perplexity * std_dev / sqrt(num_tokens)`), the mean loss, the token count, and the
/// concatenated per-token losses.
///
/// # Errors
/// - `Error::RankMismatch` if `data` is not rank 2.
/// - `Error::OutOfRange` if `L < MIN_WINDOW`.
/// - `Error::OutOfRange` if `N == 0` (there is nothing to score).
/// - Any error propagated from the model forward pass or the array operations.
pub fn perplexity<M: Model>(
    model: &M,
    data: &Array,
    batch_size: usize,
    cache_config: &CacheConfig,
) -> Result<PerplexityResult> {
    let shape = data.shape();
    let (num_rows, seq_len) = match shape.as_slice() {
        [n, l] => (*n, *l),
        other => {
            return Err(Error::RankMismatch(RankMismatchPayload::new(
                "perplexity: data must be a rank-2 [N, L] token matrix",
                other.len() as u32,
                other.to_vec(),
            )));
        }
    };
    if seq_len < MIN_WINDOW {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "perplexity: window length (must hold one input + one target)",
            "must be >= MIN_WINDOW",
            format_smolstr!("{seq_len} (MIN_WINDOW={MIN_WINDOW})"),
        )));
    }
    // With zero rows the batch loop would never run, leaving an empty list to
    // concatenate and a mean/variance over no elements. Reject it up front with a
    // message that names the actual problem.
    if num_rows == 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "perplexity: number of rows in data",
            "must be >= 1 (at least one window to score)",
            format_smolstr!("{num_rows} (shape=[{num_rows}, {seq_len}])"),
        )));
    }

    // A batch size of zero would never advance the loop; clamp it to one row.
    let batch_size = batch_size.max(1);
    let num_batches = num_rows.div_ceil(batch_size);
    let mut all_losses: Vec<Array> = try_with_capacity(num_batches)?;

    let mut start = 0usize;
    while start < num_rows {
        let stop = (start + batch_size).min(num_rows);
        // Rows [start, stop) across the full window width.
        let batch = ops::indexing::slice(
            data,
            &[start as i32, 0],
            &[stop as i32, seq_len as i32],
            &[1, 1],
        )?;
        let rows = (stop - start) as i32;
        // Inputs drop the last column; targets drop the first, so the target at
        // column i is the token that follows the input at column i.
        let inputs =
            ops::indexing::slice(&batch, &[0, 0], &[rows, (seq_len - 1) as i32], &[1, 1])?;
        let targets = ops::indexing::slice(&batch, &[0, 1], &[rows, seq_len as i32], &[1, 1])?;

        // Each batch gets its own cache so no state carries over between batches.
        let mut cache = make_prompt_cache(cache_config);
        let logits = model.forward(&inputs, &mut cache)?;
        let logits = logits.astype(Dtype::F32)?;

        let losses = cross_entropy_none(&logits, &targets)?;
        let mut losses = losses.flatten(0, -1)?;
        // Evaluate this batch's losses before starting the next batch.
        losses.eval()?;
        all_losses.push(losses);
        start = stop;
    }

    // `num_rows >= 1` guarantees at least one batch was pushed.
    let losses = if all_losses.len() == 1 {
        all_losses.into_iter().next().expect("len checked == 1")
    } else {
        let refs: Vec<&Array> = all_losses.iter().collect();
        ops::shape::concatenate(&refs, 0)?
    };

    let mut mean_loss_arr = ops::reduction::mean(&losses, false)?;
    let mean_loss: f32 = mean_loss_arr.item::<f32>()?;
    let perplexity = mean_loss.exp();
    let num_tokens = losses.size();

    // NOTE(review): the trailing `1` is presumably ddof (sample variance). If so, a
    // single scored token leaves no degrees of freedom and `std_error` may come out
    // non-finite — confirm against `ops::reduction::var`.
    let mut var_arr = ops::reduction::var(&losses, false, 1)?;
    let std_dev: f32 = var_arr.item::<f32>()?.sqrt();
    let standard_error = std_dev / (num_tokens as f32).sqrt();
    // Scale the standard error of the mean loss by the perplexity itself.
    let std_error = perplexity * standard_error;

    Ok(PerplexityResult {
        perplexity,
        std_error,
        mean_loss,
        num_tokens,
        losses,
    })
}
// Unit tests are declared as a separate `tests` module, compiled only in test builds.
#[cfg(test)]
mod tests;