/// Outcome of a perplexity computation over per-token log-probabilities.
///
/// Returned by `compute_perplexity`: either the computed statistics or the
/// reason the input was rejected.
#[derive(Debug, Clone, PartialEq)]
pub enum PerplexityOutcome {
/// The input was accepted and the statistics were computed.
Ok {
/// Perplexity, computed as `exp(mean_nll)`.
ppl: f64,
/// Mean negative log-likelihood: the negated sum of the log-probabilities
/// divided by their count.
mean_nll: f64,
/// Number of log-probabilities in the input slice.
num_tokens: usize,
},
/// The input slice was empty.
EmptyLogProbs,
/// An entry was NaN or infinite (of either sign).
NonFiniteLogProb,
/// An entry was finite but greater than `0.0`; carries the offending value.
PositiveLogProb(f64),
}
/// Computes the perplexity of a sequence from its per-token log-probabilities.
///
/// The entries are treated as natural logarithms: the result is
/// `ppl = exp(mean_nll)` with `mean_nll = -(sum of log_probs) / log_probs.len()`.
///
/// # Returns
///
/// * [`PerplexityOutcome::EmptyLogProbs`] if `log_probs` is empty.
/// * [`PerplexityOutcome::NonFiniteLogProb`] if the first invalid entry is NaN
///   or infinite (either sign).
/// * [`PerplexityOutcome::PositiveLogProb`], carrying the offending value, if
///   the first invalid entry is finite but greater than `0.0`.
/// * [`PerplexityOutcome::Ok`] otherwise.
///
/// Entries are validated in order, so the first invalid entry decides which
/// rejection is reported.
///
/// The `Ok` values are not range-checked after the computation: a very large
/// mean negative log-likelihood makes `exp` overflow and `ppl` comes back as
/// `f64::INFINITY`.
pub fn compute_perplexity(log_probs: &[f64]) -> PerplexityOutcome {
    if log_probs.is_empty() {
        return PerplexityOutcome::EmptyLogProbs;
    }

    // Validate and accumulate in one pass. The *negated* values are summed
    // starting from +0.0 so that an all-zero input yields a mean of +0.0;
    // negating a zero sum afterwards would produce -0.0, which prints as "-0".
    let mut total_nll = 0.0_f64;
    for &lp in log_probs {
        if !lp.is_finite() {
            return PerplexityOutcome::NonFiniteLogProb;
        }
        if lp > 0.0 {
            return PerplexityOutcome::PositiveLogProb(lp);
        }
        total_nll -= lp;
    }

    let num_tokens = log_probs.len();
    // `usize -> f64` is exact for any realistic token count (below 2^53).
    let mean_nll = total_nll / (num_tokens as f64);

    PerplexityOutcome::Ok {
        ppl: mean_nll.exp(),
        mean_nll,
        num_tokens,
    }
}
/// Reports whether `log_probs` yield a perplexity that is finite and at least `1.0`.
///
/// Returns `false` whenever `compute_perplexity` does not return
/// `PerplexityOutcome::Ok`, and also when the computed perplexity is
/// infinite or NaN.
pub fn classify_ppl_at_least_one(log_probs: &[f64]) -> bool {
    matches!(
        compute_perplexity(log_probs),
        PerplexityOutcome::Ok { ppl, .. } if ppl.is_finite() && ppl >= 1.0
    )
}
/// Self-check: returns `true` when `compute_perplexity` reports
/// `PerplexityOutcome::EmptyLogProbs` for an empty slice.
pub fn classify_empty_distinct() -> bool {
    let no_tokens: [f64; 0] = [];
    compute_perplexity(&no_tokens) == PerplexityOutcome::EmptyLogProbs
}
/// Self-check: returns `true` when a NaN placed between valid entries makes
/// `compute_perplexity` report `PerplexityOutcome::NonFiniteLogProb`.
pub fn classify_nan_rejected() -> bool {
    let with_nan = [-1.0, f64::NAN, -2.0];
    compute_perplexity(&with_nan) == PerplexityOutcome::NonFiniteLogProb
}
/// Self-check: returns `true` when a positive infinity placed between valid
/// entries makes `compute_perplexity` report
/// `PerplexityOutcome::NonFiniteLogProb`.
pub fn classify_inf_rejected() -> bool {
    let with_infinity = [-1.0, f64::INFINITY, -2.0];
    compute_perplexity(&with_infinity) == PerplexityOutcome::NonFiniteLogProb
}
/// Self-check: returns `true` when a finite positive entry (`0.5`) placed
/// between valid entries makes `compute_perplexity` report
/// `PerplexityOutcome::PositiveLogProb`.
pub fn classify_positive_log_prob_rejected() -> bool {
    let with_positive = [-1.0, 0.5, -2.0];
    let outcome = compute_perplexity(&with_positive);
    matches!(outcome, PerplexityOutcome::PositiveLogProb(_))
}
/// Self-check: returns `true` when four log-probabilities of `0.0` give a
/// perplexity within `1e-12` of `1.0`.
pub fn classify_perfect_prediction_is_one() -> bool {
    const TOLERANCE: f64 = 1e-12;

    match compute_perplexity(&[0.0; 4]) {
        PerplexityOutcome::Ok { ppl, .. } => (ppl - 1.0).abs() < TOLERANCE,
        PerplexityOutcome::EmptyLogProbs
        | PerplexityOutcome::NonFiniteLogProb
        | PerplexityOutcome::PositiveLogProb(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{E, LN_2};

    /// Unwraps an `Ok` outcome into `(ppl, mean_nll, num_tokens)` and panics
    /// on any other variant. `#[track_caller]` makes the panic point at the
    /// calling test rather than at this helper.
    #[track_caller]
    fn expect_ok(outcome: PerplexityOutcome) -> (f64, f64, usize) {
        match outcome {
            PerplexityOutcome::Ok {
                ppl,
                mean_nll,
                num_tokens,
            } => (ppl, mean_nll, num_tokens),
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    /// An empty slice maps to its own dedicated variant.
    #[test]
    fn empty_log_probs_distinct_outcome() {
        assert_eq!(compute_perplexity(&[]), PerplexityOutcome::EmptyLogProbs);
    }

    /// All-zero log-probabilities give a perplexity of one and a zero mean NLL.
    #[test]
    fn perfect_prediction_ppl_is_one() {
        let (ppl, mean_nll, num_tokens) = expect_ok(compute_perplexity(&[0.0, 0.0, 0.0]));
        assert!((ppl - 1.0).abs() < 1e-12, "ppl={ppl} expected 1.0");
        assert!(mean_nll.abs() < 1e-12);
        assert_eq!(num_tokens, 3);
    }

    /// Sixteen entries of `ln(1/2)` give a perplexity of two.
    #[test]
    fn uniform_log_prob_gives_correct_ppl() {
        let samples = [-LN_2; 16];
        let (ppl, mean_nll, num_tokens) = expect_ok(compute_perplexity(&samples));
        assert!((ppl - 2.0).abs() < 1e-12, "ppl={ppl}");
        assert!((mean_nll - LN_2).abs() < 1e-12);
        assert_eq!(num_tokens, 16);
    }

    /// Valid (non-positive) log-probabilities never give a perplexity below one.
    #[test]
    fn ppl_at_least_one_invariant_holds() {
        let samples = [-1.0_f64, -0.5, -2.3, -0.01];
        assert!(classify_ppl_at_least_one(&samples));
        match compute_perplexity(&samples) {
            PerplexityOutcome::Ok { ppl, .. } => {
                assert!(ppl >= 1.0, "ppl={ppl} must be >= 1.0");
            }
            _ => panic!("expected Ok"),
        }
    }

    /// NaN entries are rejected as non-finite.
    #[test]
    fn nan_log_prob_rejected() {
        assert_eq!(
            compute_perplexity(&[-1.0, f64::NAN]),
            PerplexityOutcome::NonFiniteLogProb
        );
    }

    /// `+inf` is reported as non-finite rather than as a positive log-prob.
    #[test]
    fn positive_infinity_rejected() {
        assert_eq!(
            compute_perplexity(&[-1.0, f64::INFINITY]),
            PerplexityOutcome::NonFiniteLogProb
        );
    }

    /// `-inf` is rejected as non-finite.
    #[test]
    fn negative_infinity_rejected() {
        assert_eq!(
            compute_perplexity(&[-1.0, f64::NEG_INFINITY]),
            PerplexityOutcome::NonFiniteLogProb
        );
    }

    /// A finite positive entry is rejected and reported back to the caller.
    #[test]
    fn positive_log_prob_rejected() {
        match compute_perplexity(&[-1.0, 0.25, -2.0]) {
            PerplexityOutcome::PositiveLogProb(offending) => {
                assert!((offending - 0.25).abs() < 1e-12);
            }
            other => panic!("expected PositiveLogProb, got {other:?}"),
        }
    }

    /// A single entry of `-1.0` gives a perplexity of `e`.
    #[test]
    fn single_log_prob_works() {
        let (ppl, mean_nll, num_tokens) = expect_ok(compute_perplexity(&[-1.0]));
        assert!((ppl - E).abs() < 1e-12);
        assert!((mean_nll - 1.0).abs() < 1e-12);
        assert_eq!(num_tokens, 1);
    }

    /// A larger mean NLL must give a larger perplexity.
    #[test]
    fn ppl_monotone_in_mean_nll() {
        fn ppl_of(log_probs: &[f64]) -> f64 {
            match compute_perplexity(log_probs) {
                PerplexityOutcome::Ok { ppl, .. } => ppl,
                _ => panic!(),
            }
        }

        let ppl_a = ppl_of(&[-0.5, -0.5, -0.5]);
        let ppl_b = ppl_of(&[-1.0, -1.0, -1.0]);
        assert!(ppl_a < ppl_b, "ppl({ppl_a}) should be < ppl({ppl_b})");
    }

    /// Every zero-argument self-check in the parent module reports success.
    #[test]
    fn classifier_functions_all_pass() {
        assert!(classify_empty_distinct());
        assert!(classify_nan_rejected());
        assert!(classify_inf_rejected());
        assert!(classify_positive_log_prob_rejected());
        assert!(classify_perfect_prediction_is_one());
    }

    /// `num_tokens` equals the input length across several sizes.
    #[test]
    fn num_tokens_matches_input_length() {
        for &n in &[1usize, 5, 100, 1000] {
            let samples = vec![-0.7_f64; n];
            match compute_perplexity(&samples) {
                PerplexityOutcome::Ok { num_tokens, .. } => assert_eq!(num_tokens, n),
                other => panic!("n={n}: expected Ok, got {other:?}"),
            }
        }
    }

    /// A constant mean NLL of 1.8 gives `exp(1.8)`, which lies in `5.5..=7.5`.
    #[test]
    fn known_wikitext_ballpark_ppl() {
        let mean_nll = 1.8_f64;
        let log_probs = vec![-mean_nll; 256];
        let (ppl, _, _) = expect_ok(compute_perplexity(&log_probs));
        assert!((ppl - mean_nll.exp()).abs() < 1e-9);
        assert!((5.5..=7.5).contains(&ppl), "ppl={ppl}");
    }
}