use crate::error::{Result, TextError};
use std::path::Path;
/// Minimal scoring interface a language model must expose for evaluation.
pub trait LanguageModelLike {
    /// Total log-probability (natural log) of `tokens` as one sequence, or
    /// `None` if the sequence cannot be scored (e.g. it is empty).
    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64>;
    /// Number of distinct word types the model knows.
    fn vocabulary_size(&self) -> usize;
}
/// Aggregate results of a perplexity evaluation over a corpus.
#[derive(Debug, Clone)]
pub struct PerplexityReport {
    /// Corpus-level perplexity: `exp(-total_log_prob / total_tokens)`,
    /// computed over scored sentences only.
    pub corpus_perplexity: f64,
    /// One entry per input sentence, in order; `NaN` for sentences that
    /// were empty or that the model declined to score.
    pub per_sentence_perplexity: Vec<f64>,
    /// Token count summed over scored sentences only.
    pub total_tokens: usize,
    /// Sum of sentence log-probabilities over scored sentences only.
    pub total_log_prob: f64,
}
/// Evaluates `model`'s perplexity over a tokenized corpus.
///
/// Sentences that are empty, or that the model returns `None` for, are
/// recorded as `NaN` in `per_sentence_perplexity` and excluded from the
/// corpus-level aggregate.
///
/// # Errors
///
/// Returns `TextError::InvalidInput` when `corpus` is empty, or when no
/// sentence could be scored at all (every sentence was empty or rejected
/// by the model).
pub fn perplexity_evaluate(
    model: &dyn LanguageModelLike,
    corpus: &[Vec<&str>],
) -> Result<PerplexityReport> {
    if corpus.is_empty() {
        return Err(TextError::InvalidInput("corpus is empty".into()));
    }
    let mut total_log_prob = 0.0f64;
    let mut total_tokens = 0usize;
    let mut per_sentence = Vec::with_capacity(corpus.len());
    for sentence in corpus {
        if sentence.is_empty() {
            per_sentence.push(f64::NAN);
            continue;
        }
        match model.log_prob_sequence(sentence) {
            Some(lp) => {
                let n = sentence.len();
                // Per-sentence perplexity: exp(-(1/n) * log P(sentence)).
                let ppl = (-lp / n as f64).exp();
                per_sentence.push(ppl);
                total_log_prob += lp;
                total_tokens += n;
            }
            None => {
                // Model declined to score this sentence; keep positional
                // alignment with the input corpus.
                per_sentence.push(f64::NAN);
            }
        }
    }
    if total_tokens == 0 {
        // Fix: the previous message claimed all sentences were empty, but
        // this branch is also reached when the model returned `None` for
        // every non-empty sentence.
        return Err(TextError::InvalidInput(
            "no scorable tokens in corpus (all sentences empty or unscorable)".into(),
        ));
    }
    let corpus_ppl = (-total_log_prob / total_tokens as f64).exp();
    Ok(PerplexityReport {
        corpus_perplexity: corpus_ppl,
        per_sentence_perplexity: per_sentence,
        total_tokens,
        total_log_prob,
    })
}
/// Loads a whitespace-tokenized corpus from `path`: one sentence per line.
///
/// Blank lines (and lines containing only whitespace) are skipped so the
/// result never contains empty sentences.
///
/// # Errors
///
/// Returns `TextError::IoError` if the file cannot be opened or read.
pub fn load_token_corpus(path: impl AsRef<Path>) -> Result<Vec<Vec<String>>> {
    use std::fs::File;
    use std::io::{BufRead, BufReader};
    let handle = File::open(path.as_ref()).map_err(|e| TextError::IoError(e.to_string()))?;
    let mut sentences = Vec::new();
    for entry in BufReader::new(handle).lines() {
        let text = entry.map_err(|e| TextError::IoError(e.to_string()))?;
        let words: Vec<String> = text.split_whitespace().map(String::from).collect();
        if words.is_empty() {
            continue;
        }
        sentences.push(words);
    }
    Ok(sentences)
}
impl LanguageModelLike for crate::language_models::NgramLM {
    /// Sums per-token log-probabilities, scoring each token against the
    /// up-to-(n-1) tokens that precede it.
    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
        if tokens.is_empty() {
            return None;
        }
        // NOTE(review): assumes self.n >= 1; `n - 1` below underflows usize
        // if n == 0 — confirm the NgramLM constructor enforces this.
        let n = self.n;
        let mut log_prob = 0.0f64;
        for i in 0..tokens.len() {
            // Context window is truncated at the sentence start (no padding),
            // unlike the NgramModel impl below which pads with "<START>".
            let ctx_start = if i >= n - 1 { i + 1 - n } else { 0 };
            let context: Vec<&str> = tokens[ctx_start..i].to_vec();
            let word = tokens[i];
            let p = self.probability(word, &context);
            // Floor zero/negative probabilities at 1e-10 so the sum stays finite.
            log_prob += if p <= 0.0 { 1e-10_f64.ln() } else { p.ln() };
        }
        Some(log_prob)
    }
    /// Counts distinct final words across all stored n-gram count keys;
    /// clamped to at least 1 so it is safe as a divisor.
    fn vocabulary_size(&self) -> usize {
        let mut vocab = std::collections::HashSet::new();
        for key in self.counts.keys() {
            // The last element of each key is the predicted word; the prefix
            // is its context — presumably, based on usage here; verify
            // against the NgramLM definition.
            if let Some(word) = key.last() {
                vocab.insert(word.as_str());
            }
        }
        vocab.len().max(1)
    }
}
impl LanguageModelLike for crate::language_model::NgramModel {
    /// Sums per-token log-probabilities, left-padding short contexts with
    /// "<START>" so every query uses a full (n-1)-token context.
    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
        if tokens.is_empty() {
            return None;
        }
        // NOTE(review): assumes order() >= 1; `n - 1` below underflows usize
        // if the order is 0 — confirm the model enforces this.
        let n = self.order();
        let mut log_prob = 0.0f64;
        for i in 0..tokens.len() {
            let ctx_start = if i >= n - 1 { i + 1 - n } else { 0 };
            let context: Vec<&str> = tokens[ctx_start..i].to_vec();
            let word = tokens[i];
            // Pad on the left with "<START>" when fewer than n-1 real tokens
            // precede position i — presumably matching the sentinel this
            // model was trained with; verify against its training code.
            let padded_ctx: Vec<&str> = if context.len() < n - 1 {
                let needed = n - 1 - context.len();
                let mut v: Vec<&str> = vec!["<START>"; needed];
                v.extend_from_slice(&context);
                v
            } else {
                context
            };
            // Unknown n-grams fall back to a 1e-10 floor probability.
            let p = self.probability(&padded_ctx, word).unwrap_or(1e-10);
            log_prob += if p <= 0.0 { 1e-10_f64.ln() } else { p.ln() };
        }
        Some(log_prob)
    }
    fn vocabulary_size(&self) -> usize {
        // NOTE(review): this relies on NgramModel having an INHERENT
        // `vocabulary_size` method, which Rust resolves in preference to
        // this trait method. If no inherent method exists, this call
        // recurses infinitely — verify.
        self.vocabulary_size()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assigns every token probability 1 / vocab.
    struct UniformModel {
        vocab: usize,
    }

    impl LanguageModelLike for UniformModel {
        fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
            if tokens.is_empty() {
                None
            } else {
                Some(-(self.vocab as f64).ln() * tokens.len() as f64)
            }
        }

        fn vocabulary_size(&self) -> usize {
            self.vocab
        }
    }

    /// Assigns every non-empty sequence probability 1 (log-prob 0).
    struct PerfectModel;

    impl LanguageModelLike for PerfectModel {
        fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
            if tokens.is_empty() {
                return None;
            }
            Some(0.0_f64)
        }

        fn vocabulary_size(&self) -> usize {
            1
        }
    }

    #[test]
    fn perplexity_uniform_model_equals_vocab_size() {
        // A uniform model's perplexity should equal its vocabulary size.
        let lm = UniformModel { vocab: 100 };
        let data = vec![vec!["a", "b", "c", "d", "e"]];
        let rep = perplexity_evaluate(&lm, &data).expect("evaluate");
        let diff = (rep.corpus_perplexity - 100.0).abs();
        assert!(diff < 1e-6, "expected 100.0, got {}", rep.corpus_perplexity);
    }

    #[test]
    fn perplexity_of_perfect_predictor_is_one() {
        let lm = PerfectModel;
        let data = vec![vec!["a", "b", "c"]];
        let rep = perplexity_evaluate(&lm, &data).expect("evaluate");
        let diff = (rep.corpus_perplexity - 1.0).abs();
        assert!(diff < 1e-9, "expected 1.0, got {}", rep.corpus_perplexity);
    }

    #[test]
    fn perplexity_corpus_aggregates_token_log_probs() {
        let lm = UniformModel { vocab: 10 };
        let data = vec![vec!["a", "b", "c"], vec!["d", "e"]];
        let rep = perplexity_evaluate(&lm, &data).expect("evaluate");
        assert_eq!(rep.total_tokens, 5);
        // Five tokens, each with log-prob -ln(10).
        let expected_lp = 5.0 * -(10.0f64).ln();
        assert!(
            (rep.total_log_prob - expected_lp).abs() < 1e-9,
            "expected total_log_prob {expected_lp}, got {}",
            rep.total_log_prob
        );
    }

    #[test]
    fn perplexity_empty_corpus_returns_error() {
        let lm = UniformModel { vocab: 10 };
        assert!(perplexity_evaluate(&lm, &[]).is_err());
    }

    #[test]
    fn perplexity_per_sentence_are_positive() {
        let lm = UniformModel { vocab: 5 };
        let data = vec![vec!["a"], vec!["b", "c"]];
        let rep = perplexity_evaluate(&lm, &data).expect("evaluate");
        for &ppl in &rep.per_sentence_perplexity {
            assert!(ppl > 0.0 && ppl.is_finite(), "per-sentence ppl = {ppl}");
        }
    }

    #[test]
    fn perplexity_all_empty_sentences_returns_error() {
        let lm = UniformModel { vocab: 5 };
        let data: Vec<Vec<&str>> = vec![vec![], vec![]];
        assert!(perplexity_evaluate(&lm, &data).is_err());
    }
}