use std::path::Path;
use std::time::Instant;
use comfy_table::{presets::UTF8_FULL, Table};
use console::style;
use crate::cli::args::{CorpusFormat, EvalCommands, EvalCompareArgs, EvalPerplexityArgs};
use crate::cli::error::{print_success, CliError, CliResult};
use crate::corpus::{CorpusReader, GutenbergReader, PlaintextReader, Tokenizer, WikipediaReader};
/// Entry point for the `eval` command family.
///
/// Selects the handler matching the parsed subcommand and forwards the
/// `verbose` / `quiet` flags to it unchanged. The handler's result is
/// returned as-is.
pub fn run(cmd: EvalCommands, verbose: bool, quiet: bool) -> CliResult<()> {
    match cmd {
        EvalCommands::Compare(compare_args) => eval_compare(compare_args, verbose, quiet),
        EvalCommands::Perplexity(perplexity_args) => {
            eval_perplexity(perplexity_args, verbose, quiet)
        }
    }
}
/// Aggregate statistics of one perplexity evaluation, as filled in by
/// `compute_perplexity`.
#[derive(Debug, Clone)]
pub struct PerplexityResult {
/// Corpus-level perplexity, `exp(-log_probability / tokens)`;
/// `f64::INFINITY` when no tokens were scored.
pub perplexity: f64,
/// Sum of the per-sentence log-probabilities reported by the model.
pub log_probability: f64,
/// Number of sentences that produced at least one token and were scored.
pub sentences: u64,
/// Total number of tokens across all scored sentences.
pub tokens: u64,
/// Number of tokens the model reported as not in its vocabulary.
pub oov_tokens: u64,
/// Perplexity of each scored sentence, in corpus order; `None` unless the
/// per-sentence breakdown was requested.
pub per_sentence: Option<Vec<f64>>,
/// Wall-clock duration of the evaluation, in seconds.
pub elapsed_secs: f64,
}
/// Minimal model interface required by `compute_perplexity`, so that
/// different model kinds can be evaluated through one code path.
trait PerplexityModel {
/// Returns the log-probability the model assigns to the token sequence.
/// NOTE(review): callers convert it with `exp`, so this is presumably a
/// natural logarithm — confirm against the model implementations.
fn sentence_log_prob(&self, tokens: &[&str]) -> f64;
/// Returns `true` when `word` is part of the model's vocabulary; used to
/// count out-of-vocabulary tokens.
fn in_vocabulary(&self, word: &str) -> bool;
/// Returns a short human-readable summary of the model, shown in verbose
/// output and in comparison results.
fn description(&self) -> String;
}
/// Wrapper that exposes a plain n-gram model through `PerplexityModel`.
struct NgramPerplexityModel {
/// The wrapped n-gram model, parameterised over a `DynamicDawgChar`
/// dictionary of `NgramEntry` values.
model: crate::ngram::NgramModel<
liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar<crate::ngram::NgramEntry>,
>,
}
/// Forwards every `PerplexityModel` operation to the wrapped n-gram model.
impl PerplexityModel for NgramPerplexityModel {
    /// Log-probability of `tokens` according to the wrapped n-gram model.
    fn sentence_log_prob(&self, tokens: &[&str]) -> f64 {
        let ngram = &self.model;
        ngram.sentence_log_prob(tokens)
    }

    /// Whether the wrapped n-gram model knows `word`.
    fn in_vocabulary(&self, word: &str) -> bool {
        let ngram = &self.model;
        ngram.in_vocabulary(word)
    }

    /// One-line summary containing the model order and vocabulary size.
    fn description(&self) -> String {
        let ngram = &self.model;
        let order = ngram.order();
        let vocab = ngram.vocab_size();
        format!("N-gram (order={}, vocab={})", order, vocab)
    }
}
/// Wrapper that exposes a hybrid language model through `PerplexityModel`.
struct HybridPerplexityModel {
/// The wrapped hybrid model, parameterised over a `DynamicDawgChar`
/// dictionary of `NgramEntry` values.
model: crate::hybrid::HybridLanguageModel<
liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar<crate::ngram::NgramEntry>,
>,
}
/// Adapts the wrapped hybrid model to the `PerplexityModel` interface.
impl PerplexityModel for HybridPerplexityModel {
    /// Log-probability of `tokens` according to the hybrid model itself.
    fn sentence_log_prob(&self, tokens: &[&str]) -> f64 {
        let hybrid = &self.model;
        hybrid.sentence_log_prob(tokens)
    }

    /// Vocabulary membership is decided by the n-gram component only.
    fn in_vocabulary(&self, word: &str) -> bool {
        let ngram = self.model.ngram_model();
        ngram.in_vocabulary(word)
    }

    /// One-line summary with the n-gram order and both vocabulary sizes.
    fn description(&self) -> String {
        let ngram = self.model.ngram_model();
        let order = ngram.order();
        let ngram_vocab = ngram.vocab_size();
        let emb_vocab = self.model.embedding_model().vocab_size();
        format!(
            "Hybrid (order={}, ngram_vocab={}, emb_vocab={})",
            order, ngram_vocab, emb_vocab
        )
    }
}
/// Loads the model stored at `path` and wraps it for perplexity evaluation.
///
/// The file is first interpreted as a hybrid model; if that fails it is
/// interpreted as a plain n-gram model. The individual load errors are
/// discarded — only a generic failure is reported when neither attempt
/// succeeds.
///
/// # Errors
///
/// Returns a model-load error naming `path` when the file can be read as
/// neither model kind.
fn load_model_for_perplexity(path: &Path) -> CliResult<Box<dyn PerplexityModel>> {
    use crate::hybrid::HybridLanguageModel;
    use crate::ngram::NgramModel;
    use liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar;

    // Hybrid is attempted first; the plain n-gram loader is the fallback.
    if let Ok(hybrid) = HybridLanguageModel::load_portable(path, DynamicDawgChar::new) {
        return Ok(Box::new(HybridPerplexityModel { model: hybrid }));
    }

    match NgramModel::load_portable(path, DynamicDawgChar::new) {
        Ok(ngram) => Ok(Box::new(NgramPerplexityModel { model: ngram })),
        Err(_) => Err(CliError::model_load(
            path.to_path_buf(),
            "Failed to load model (unknown format or corrupted file)".to_string(),
        )),
    }
}
/// Scores every sentence of `reader` with `model` and aggregates the result.
///
/// Each sentence is tokenized with a fresh `Tokenizer`; sentences that yield
/// no tokens are skipped entirely. For the remaining sentences the function
/// accumulates the model's log-probability, the token count and the number
/// of out-of-vocabulary tokens. When `per_sentence` is `true`, the perplexity
/// of each individual sentence is recorded as well.
///
/// The corpus perplexity is `exp(-total_log_prob / total_tokens)`, or
/// `f64::INFINITY` when no tokens were scored at all.
fn compute_perplexity(
    model: &dyn PerplexityModel,
    reader: &dyn CorpusReader,
    per_sentence: bool,
) -> PerplexityResult {
    let timer = Instant::now();
    let tokenizer = Tokenizer::new();

    let mut log_prob_sum = 0.0f64;
    let mut token_count = 0u64;
    let mut sentence_count = 0u64;
    let mut oov_count = 0u64;
    let mut per_sentence_ppl: Option<Vec<f64>> = per_sentence.then(Vec::new);

    for sentence in reader.sentences() {
        let owned_words: Vec<String> = tokenizer.words(&sentence).collect();
        if owned_words.is_empty() {
            continue;
        }
        // The model API works on borrowed string slices.
        let words: Vec<&str> = owned_words.iter().map(String::as_str).collect();
        let word_count = words.len();

        oov_count += words
            .iter()
            .filter(|&&word| !model.in_vocabulary(word))
            .count() as u64;

        let sentence_log_prob = model.sentence_log_prob(&words);
        log_prob_sum += sentence_log_prob;
        token_count += word_count as u64;
        sentence_count += 1;

        if let Some(collected) = per_sentence_ppl.as_mut() {
            let mean_log_prob = sentence_log_prob / word_count as f64;
            collected.push((-mean_log_prob).exp());
        }
    }

    let perplexity = match token_count {
        // Nothing was scored, so perplexity is reported as infinite.
        0 => f64::INFINITY,
        scored => {
            let mean_log_prob = log_prob_sum / scored as f64;
            (-mean_log_prob).exp()
        }
    };

    PerplexityResult {
        perplexity,
        log_probability: log_prob_sum,
        sentences: sentence_count,
        tokens: token_count,
        oov_tokens: oov_count,
        per_sentence: per_sentence_ppl,
        elapsed_secs: timer.elapsed().as_secs_f64(),
    }
}
/// Evaluates one model's perplexity on a test corpus and reports the outcome.
///
/// The model at `args.model` is loaded, the corpus at `args.test_corpus` is
/// opened with the reader matching `args.format`, and the statistics are
/// computed by `compute_perplexity`. When `args.output` is set, the result is
/// written there as pretty-printed JSON. Unless `quiet` is set, progress
/// messages go to stderr and a human-readable summary goes to stdout; with
/// `args.per_sentence` the summary also contains the minimum, maximum and
/// median per-sentence perplexity.
///
/// # Errors
///
/// Returns an error when the model file does not exist or cannot be loaded,
/// when the corpus reader cannot be created, or when the JSON output file
/// cannot be written.
fn eval_perplexity(args: EvalPerplexityArgs, verbose: bool, quiet: bool) -> CliResult<()> {
    // Share of out-of-vocabulary tokens in percent; 0.0 when nothing was scored.
    fn oov_rate_percent(oov_tokens: u64, tokens: u64) -> f64 {
        if tokens > 0 {
            oov_tokens as f64 / tokens as f64 * 100.0
        } else {
            0.0
        }
    }

    if verbose {
        eprintln!("Evaluating perplexity");
        eprintln!(" Model: {}", args.model.display());
        eprintln!(" Corpus: {}", args.test_corpus);
    }
    if !args.model.exists() {
        return Err(CliError::file_not_found(&args.model));
    }

    if !quiet {
        eprintln!("Loading model...");
    }
    let model = load_model_for_perplexity(&args.model)?;
    if verbose {
        eprintln!(" Model type: {}", model.description());
    }

    if !quiet {
        eprintln!("Loading test corpus...");
    }
    let reader = create_corpus_reader(&args.test_corpus, args.format)?;

    if !quiet {
        eprintln!("Computing perplexity...");
    }
    let mut result = compute_perplexity(model.as_ref(), reader.as_ref(), args.per_sentence);
    let oov_rate = oov_rate_percent(result.oov_tokens, result.tokens);

    if let Some(ref output_path) = args.output {
        // The JSON keeps the per-sentence values in corpus order; they are
        // only sorted later, for the console summary.
        let json_output = serde_json::json!({
            "model": args.model.display().to_string(),
            "test_corpus": args.test_corpus,
            "perplexity": result.perplexity,
            "log_probability": result.log_probability,
            "sentences": result.sentences,
            "tokens": result.tokens,
            "oov_tokens": result.oov_tokens,
            "oov_rate": oov_rate,
            "elapsed_secs": result.elapsed_secs,
            "per_sentence": result.per_sentence,
        });
        let serialized = serde_json::to_string_pretty(&json_output)
            .expect("serde_json::Value built via json!() macro must serialize");
        std::fs::write(output_path, serialized)
            .map_err(|e| CliError::io(format!("Failed to write output: {}", e)))?;
        if !quiet {
            eprintln!("Results written to: {}", output_path.display());
        }
    }

    if !quiet {
        println!();
        println!("Model: {}", style(args.model.display()).cyan());
        println!(
            "Test corpus: {} ({} sentences, {} tokens)",
            args.test_corpus, result.sentences, result.tokens
        );
        println!();
        println!(
            "Perplexity: {}",
            style(format!("{:.2}", result.perplexity)).green().bold()
        );
        println!("Log probability: {:.2}", result.log_probability);
        println!(
            "OOV rate: {:.2}% ({} tokens)",
            oov_rate, result.oov_tokens
        );
        println!(
            "Avg tokens/sent: {:.2}",
            result.tokens as f64 / result.sentences.max(1) as f64
        );
        println!("Evaluation time: {:.2}s", result.elapsed_secs);

        // Nothing reads the per-sentence values after this point, so take the
        // vector and sort it in place instead of cloning it.
        if let Some(mut sorted) = result.per_sentence.take() {
            // `total_cmp` is a total order even when a value is NaN. A
            // comparator built from `partial_cmp` is not, and the standard
            // sort may panic or produce an unspecified order in that case.
            sorted.sort_unstable_by(f64::total_cmp);

            // `first`/`last` are `None` only for an empty list, in which case
            // no breakdown is printed.
            if let (Some(min), Some(max)) = (sorted.first(), sorted.last()) {
                let mid = sorted.len() / 2;
                // For an even count the median is the mean of the two middle
                // values; `mid >= 1` holds because the list is non-empty.
                let median = if sorted.len() % 2 == 0 {
                    (sorted[mid - 1] + sorted[mid]) / 2.0
                } else {
                    sorted[mid]
                };
                println!();
                println!("Per-sentence breakdown:");
                println!(" Min perplexity: {:.2}", min);
                println!(" Max perplexity: {:.2}", max);
                println!(" Median: {:.2}", median);
            }
        }
    }
    Ok(())
}
/// Evaluates several models on the same test corpus and compares them.
///
/// All model paths are checked for existence first. The corpus is then read
/// once into memory so that every model is scored on exactly the same
/// sentences. Each model is loaded and evaluated in the order given. When
/// `args.output` is set, the per-model results are written there as
/// pretty-printed JSON. Unless `quiet` is set, progress goes to stderr and a
/// comparison table plus the best model (lowest finite perplexity) go to
/// stdout.
///
/// # Errors
///
/// Returns an error when a model file does not exist or cannot be loaded,
/// when the corpus reader cannot be created, or when the JSON output file
/// cannot be written.
fn eval_compare(args: EvalCompareArgs, verbose: bool, quiet: bool) -> CliResult<()> {
    if verbose {
        eprintln!("Comparing {} models", args.models.len());
        eprintln!(" Corpus: {}", args.test_corpus);
        for model in &args.models {
            eprintln!(" Model: {}", model.display());
        }
    }

    // Fail fast on a missing model before doing any expensive corpus work.
    if let Some(missing) = args.models.iter().find(|model| !model.exists()) {
        return Err(CliError::file_not_found(missing));
    }

    if !quiet {
        eprintln!("Loading test corpus...");
    }
    let reader = create_corpus_reader(&args.test_corpus, args.format)?;
    let sentences: Vec<String> = reader.sentences().collect();
    let total_sentences = sentences.len();
    if !quiet {
        eprintln!("Test corpus: {} sentences", total_sentences);
    }

    // In-memory corpus so the source is read only once, however many models
    // are compared.
    struct MemoryReader {
        sentences: Vec<String>,
    }
    impl CorpusReader for MemoryReader {
        fn documents(&self) -> Box<dyn Iterator<Item = crate::corpus::Document> + Send + '_> {
            Box::new(
                self.sentences
                    .iter()
                    .map(|s| crate::corpus::Document::new(s.clone())),
            )
        }
        fn sentences(&self) -> Box<dyn Iterator<Item = String> + Send + '_> {
            Box::new(self.sentences.iter().cloned())
        }
        fn estimated_tokens(&self) -> Option<usize> {
            Some(
                self.sentences
                    .iter()
                    .map(|s| s.split_whitespace().count())
                    .sum(),
            )
        }
    }
    let memory_reader = MemoryReader { sentences };

    // One `(path, description, result)` entry per model, in input order.
    let mut results = Vec::new();
    for (i, model_path) in args.models.iter().enumerate() {
        if !quiet {
            eprintln!(
                "Evaluating model {}/{}: {}",
                i + 1,
                args.models.len(),
                model_path.display()
            );
        }
        let model = load_model_for_perplexity(model_path)?;
        let result = compute_perplexity(model.as_ref(), &memory_reader, false);
        results.push((model_path.clone(), model.description(), result));
    }

    if let Some(ref output_path) = args.output {
        let json_results: Vec<serde_json::Value> = results
            .iter()
            .map(|(path, desc, result)| {
                serde_json::json!({
                    "model": path.display().to_string(),
                    "description": desc,
                    "perplexity": result.perplexity,
                    "log_probability": result.log_probability,
                    "oov_rate": if result.tokens > 0 { result.oov_tokens as f64 / result.tokens as f64 * 100.0 } else { 0.0 },
                    "elapsed_secs": result.elapsed_secs,
                })
            })
            .collect();
        let json_output = serde_json::json!({
            "test_corpus": args.test_corpus,
            "sentences": total_sentences,
            "models": json_results,
        });
        let serialized = serde_json::to_string_pretty(&json_output)
            .expect("serde_json::Value built via json!() macro must serialize");
        std::fs::write(output_path, serialized)
            .map_err(|e| CliError::io(format!("Failed to write output: {}", e)))?;
        if !quiet {
            eprintln!("Results written to: {}", output_path.display());
        }
    }

    if !quiet {
        println!();
        let mut table = Table::new();
        table.load_preset(UTF8_FULL);
        table.set_header(vec!["Model", "Perplexity", "OOV Rate", "Time (s)"]);
        for (path, _desc, result) in &results {
            let oov_rate = if result.tokens > 0 {
                format!(
                    "{:.2}%",
                    result.oov_tokens as f64 / result.tokens as f64 * 100.0
                )
            } else {
                "N/A".to_string()
            };
            table.add_row(vec![
                path.file_name()
                    .map(|n| n.to_string_lossy().to_string())
                    .unwrap_or_else(|| path.display().to_string()),
                format!("{:.2}", result.perplexity),
                oov_rate,
                format!("{:.2}", result.elapsed_secs),
            ]);
        }
        println!("{}", table);

        // Only a finite perplexity is a meaningful score. A NaN would compare
        // as "equal" to everything under `partial_cmp` and could be reported
        // as the winner; an infinite value (e.g. an empty corpus) says nothing
        // about model quality. Such entries are excluded from the ranking.
        let best = results
            .iter()
            .filter(|entry| entry.2.perplexity.is_finite())
            .min_by(|a, b| a.2.perplexity.total_cmp(&b.2.perplexity));
        if let Some((path, _desc, result)) = best {
            println!();
            print_success(&format!(
                "Best model: {} (perplexity: {:.2})",
                path.display(),
                result.perplexity
            ));
        }
    }
    Ok(())
}
/// Builds the corpus reader matching `format` for the corpus at `path`.
///
/// * Plaintext and Gutenberg corpora may be a directory or a single file.
/// * Wikipedia corpora are a local path; with the `http-corpus` feature
///   enabled, an `http://` or `https://` location is fetched by URL instead.
///
/// # Errors
///
/// Returns a file-not-found error when the local path does not exist, and a
/// corpus error carrying the reader's message when the reader itself cannot
/// be constructed.
fn create_corpus_reader(path: &str, format: CorpusFormat) -> CliResult<Box<dyn CorpusReader>> {
    let location = Path::new(path);

    let reader: Box<dyn CorpusReader> = match format {
        CorpusFormat::Plaintext => {
            if location.is_dir() {
                Box::new(
                    PlaintextReader::from_directory(location)
                        .map_err(|err| CliError::corpus(err.to_string()))?,
                )
            } else if location.exists() {
                Box::new(
                    PlaintextReader::from_file(location)
                        .map_err(|err| CliError::corpus(err.to_string()))?,
                )
            } else {
                return Err(CliError::file_not_found(location));
            }
        }
        CorpusFormat::Wikipedia => {
            // Remote dumps are only recognised when HTTP support is compiled in.
            #[cfg(feature = "http-corpus")]
            if matches!(
                location.to_str(),
                Some(p) if p.starts_with("http://") || p.starts_with("https://")
            ) {
                return Ok(Box::new(
                    WikipediaReader::from_url(
                        location.to_string_lossy().as_ref(),
                        Default::default(),
                    )
                    .map_err(|err| CliError::corpus(err.to_string()))?,
                ));
            }
            if !location.exists() {
                return Err(CliError::file_not_found(location));
            }
            Box::new(
                WikipediaReader::new(location).map_err(|err| CliError::corpus(err.to_string()))?,
            )
        }
        CorpusFormat::Gutenberg => {
            if location.is_dir() {
                Box::new(
                    GutenbergReader::from_directory(location)
                        .map_err(|err| CliError::corpus(err.to_string()))?,
                )
            } else if location.exists() {
                Box::new(
                    GutenbergReader::from_file(location)
                        .map_err(|err| CliError::corpus(err.to_string()))?,
                )
            } else {
                return Err(CliError::file_not_found(location));
            }
        }
    };

    Ok(reader)
}