use super::{Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
use std::collections::HashMap;
/// Accumulator for a perplexity-style language-modeling metric.
///
/// The `Metric` implementation in this file adds to these fields in
/// `add_batch`, derives perplexity from them in `compute`, and zeroes them
/// in `reset`.
#[derive(Debug, Clone)]
pub struct LanguageModelingMetric {
/// Running sum of the natural-log probabilities looked up for the reference
/// tokens. A non-positive probability contributes `-50.0` instead of its log.
log_likelihood: f64,
/// Number of reference tokens that were scored. Tokens whose index falls
/// outside the probability distribution are not counted.
num_tokens: usize,
}
impl LanguageModelingMetric {
    /// Creates a metric with nothing accumulated yet: a log-likelihood sum of
    /// `0.0` and a scored-token count of `0`.
    pub fn new() -> Self {
        let log_likelihood = 0.0;
        let num_tokens = 0;
        Self {
            log_likelihood,
            num_tokens,
        }
    }
}
impl Metric for LanguageModelingMetric {
    /// Accumulates the log-likelihood of the reference tokens in one batch.
    ///
    /// `predictions` must be `MetricInput::Probabilities` and `references` must
    /// be `MetricInput::Tokens`. Sequences, and the positions inside each
    /// sequence, are paired with `zip`, so surplus entries on the longer side
    /// are ignored.
    ///
    /// For every paired position the probability stored at the reference
    /// token's index is looked up. A token whose index lies outside the
    /// distribution is skipped and does not count towards `num_tokens`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error for any other combination of
    /// `MetricInput` variants; the accumulated state is left untouched.
    fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
        // Log-probability charged when the reference token was given a
        // non-positive probability. Keeps the running sum finite instead of
        // collapsing to negative infinity on a single zero.
        const ZERO_PROB_LOG_PENALTY: f64 = -50.0;

        match (predictions, references) {
            (MetricInput::Probabilities(probs), MetricInput::Tokens(tokens)) => {
                for (prob_seq, token_seq) in probs.iter().zip(tokens.iter()) {
                    for (prob_dist, &token) in prob_seq.iter().zip(token_seq.iter()) {
                        // `get` folds the bounds check and the lookup into one
                        // step; out-of-range tokens simply yield `None`.
                        if let Some(&token_prob) = prob_dist.get(token as usize) {
                            self.log_likelihood += if token_prob > 0.0 {
                                // Widen before taking the logarithm so the sum
                                // is accumulated at full `f64` precision.
                                // NOTE(review): assumes the probabilities are a
                                // primitive float (`f32` or `f64`) — confirm.
                                f64::from(token_prob).ln()
                            } else {
                                ZERO_PROB_LOG_PENALTY
                            };
                            self.num_tokens += 1;
                        }
                    }
                }
                Ok(())
            },
            _ => Err(TrustformersError::invalid_input_simple(
                "Invalid input types for language modeling metric: expected Probabilities for predictions and Tokens for references".to_string(),
            )),
        }
    }

    /// Computes perplexity from the accumulated state.
    ///
    /// Perplexity is `exp(-log_likelihood / num_tokens)`. When no token has
    /// been scored it is reported as `f64::INFINITY`.
    ///
    /// The returned `MetricResult` carries perplexity as `value` and exposes
    /// `"perplexity"`, `"log_likelihood"` and `"num_tokens"` in `details`;
    /// `metadata` is empty. This implementation always returns `Ok`.
    fn compute(&self) -> Result<MetricResult> {
        let perplexity = if self.num_tokens > 0 {
            (-self.log_likelihood / self.num_tokens as f64).exp()
        } else {
            f64::INFINITY
        };

        let mut details = HashMap::new();
        details.insert("perplexity".to_string(), perplexity);
        details.insert("log_likelihood".to_string(), self.log_likelihood);
        details.insert("num_tokens".to_string(), self.num_tokens as f64);

        Ok(MetricResult {
            name: "language_modeling".to_string(),
            value: perplexity,
            details,
            metadata: HashMap::new(),
        })
    }

    /// Clears the accumulated log-likelihood and token count.
    fn reset(&mut self) {
        self.log_likelihood = 0.0;
        self.num_tokens = 0;
    }

    /// Returns the metric's identifier, `"language_modeling"`.
    fn name(&self) -> &str {
        "language_modeling"
    }
}
impl Default for LanguageModelingMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a single batch into a fresh metric and returns the computed result.
    fn evaluate_single_batch(predictions: &MetricInput, references: &MetricInput) -> MetricResult {
        let mut lm_metric = LanguageModelingMetric::new();
        lm_metric
            .add_batch(predictions, references)
            .expect("add operation failed");
        lm_metric.compute().expect("operation failed in test")
    }

    #[test]
    fn test_language_modeling_metric_basic() {
        // Both reference tokens are index 1, scored at 0.9 and 0.8.
        let predictions =
            MetricInput::Probabilities(vec![vec![vec![0.1, 0.9], vec![0.2, 0.8]]]);
        let references = MetricInput::Tokens(vec![vec![1, 1]]);

        let outcome = evaluate_single_batch(&predictions, &references);

        assert_eq!(outcome.name, "language_modeling");
        assert!(outcome.value >= 1.0);
        assert!(outcome.details.contains_key("perplexity"));
        assert!(outcome.details.contains_key("log_likelihood"));
        assert_eq!(outcome.details.get("num_tokens"), Some(&2.0));
    }

    #[test]
    fn test_language_modeling_metric_perfect_prediction() {
        // Probability 1.0 on the reference token gives a perplexity of exactly 1.
        let outcome = evaluate_single_batch(
            &MetricInput::Probabilities(vec![vec![vec![0.0, 1.0]]]),
            &MetricInput::Tokens(vec![vec![1]]),
        );

        assert_eq!(outcome.value, 1.0);
    }

    #[test]
    fn test_language_modeling_metric_random_prediction() {
        // A uniform two-way distribution gives a perplexity of 2.
        let outcome = evaluate_single_batch(
            &MetricInput::Probabilities(vec![vec![vec![0.5, 0.5]]]),
            &MetricInput::Tokens(vec![vec![0]]),
        );

        let deviation = (outcome.value - 2.0).abs();
        assert!(deviation < 1e-6);
    }

    #[test]
    fn test_language_modeling_metric_zero_probability() {
        // The reference token was assigned probability 0.0.
        let outcome = evaluate_single_batch(
            &MetricInput::Probabilities(vec![vec![vec![1.0, 0.0]]]),
            &MetricInput::Tokens(vec![vec![1]]),
        );

        assert!(outcome.value > 1.0);
    }

    #[test]
    fn test_language_modeling_metric_invalid_token() {
        // Token 5 has no entry in a two-element distribution, so nothing is scored.
        let outcome = evaluate_single_batch(
            &MetricInput::Probabilities(vec![vec![vec![0.5, 0.5]]]),
            &MetricInput::Tokens(vec![vec![5]]),
        );

        assert_eq!(outcome.value, f64::INFINITY);
        assert_eq!(outcome.details.get("num_tokens"), Some(&0.0));
    }

    #[test]
    fn test_language_modeling_metric_multiple_batches() {
        let batches = [
            (
                MetricInput::Probabilities(vec![vec![vec![0.0, 1.0]]]),
                MetricInput::Tokens(vec![vec![1]]),
            ),
            (
                MetricInput::Probabilities(vec![vec![vec![0.5, 0.5]]]),
                MetricInput::Tokens(vec![vec![0]]),
            ),
        ];

        let mut lm_metric = LanguageModelingMetric::new();
        for (predictions, references) in &batches {
            lm_metric
                .add_batch(predictions, references)
                .expect("operation failed in test");
        }
        let outcome = lm_metric.compute().expect("operation failed in test");

        // Two tokens with log-probabilities ln(1.0) and ln(0.5).
        let expected = (-0.5_f64.ln() / 2.0).exp();
        assert!((outcome.value - expected).abs() < 1e-6);
        assert_eq!(outcome.details.get("num_tokens"), Some(&2.0));
    }

    #[test]
    fn test_language_modeling_metric_reset() {
        let mut lm_metric = LanguageModelingMetric::new();
        let predictions = MetricInput::Probabilities(vec![vec![vec![0.5, 0.5]]]);
        let references = MetricInput::Tokens(vec![vec![0]]);
        lm_metric
            .add_batch(&predictions, &references)
            .expect("operation failed in test");

        lm_metric.reset();

        // After a reset the metric reports the same values as an empty one.
        let outcome = lm_metric.compute().expect("operation failed in test");
        assert_eq!(outcome.value, f64::INFINITY);
        assert_eq!(outcome.details.get("num_tokens"), Some(&0.0));
    }

    #[test]
    fn test_language_modeling_metric_invalid_input() {
        // Text predictions are not an accepted input variant for this metric.
        let mut lm_metric = LanguageModelingMetric::new();
        let predictions = MetricInput::Text(vec!["hello".to_string()]);
        let references = MetricInput::Tokens(vec![vec![0]]);

        let outcome = lm_metric.add_batch(&predictions, &references);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_language_modeling_metric_empty_sequences() {
        // One sequence with no positions: nothing is scored.
        let outcome = evaluate_single_batch(
            &MetricInput::Probabilities(vec![vec![]]),
            &MetricInput::Tokens(vec![vec![]]),
        );

        assert_eq!(outcome.value, f64::INFINITY);
    }
}