use crate::error::{Result, TextError};
/// Result of a METEOR evaluation of one hypothesis against one reference.
#[derive(Debug, Clone, PartialEq)]
pub struct MeteorScore {
    /// Final score: `f_mean * (1 - penalty)`.
    pub score: f64,
    /// Aligned tokens divided by the hypothesis length.
    pub precision: f64,
    /// Aligned tokens divided by the reference length.
    pub recall: f64,
    /// Weighted harmonic mean `P·R / (α·P + (1 − α)·R)` of precision and recall.
    pub f_mean: f64,
    /// Fragmentation penalty `γ · (chunks / matches)^β`, clamped to `[0, 1]`.
    pub penalty: f64,
    /// Number of runs of aligned tokens that are contiguous in both sequences.
    pub chunks: usize,
    /// Number of aligned hypothesis/reference token pairs.
    pub matches: usize,
}
/// Tunable parameters for METEOR scoring.
#[derive(Debug, Clone, PartialEq)]
pub struct MeteorConfig {
    /// Weight `α` in `f_mean = P·R / (α·P + (1 − α)·R)`; must lie strictly inside `(0, 1)`.
    pub alpha: f64,
    /// Exponent `β` applied to the fragmentation ratio `chunks / matches`.
    pub beta: f64,
    /// Scale `γ` of the fragmentation penalty `γ · (chunks / matches)^β`.
    pub gamma: f64,
    /// Enables the alignment stage that matches words by their stems.
    pub use_stemming: bool,
    /// Enables the alignment stage that matches words by normalized edit distance.
    pub use_approximate: bool,
    /// Largest edit distance, divided by the length of the longer word, that still
    /// counts as an approximate match.
    pub approximate_threshold: f64,
}

impl Default for MeteorConfig {
    /// `alpha = 0.9`, `beta = 3.0`, `gamma = 0.5`, both optional stages enabled,
    /// `approximate_threshold = 0.4`.
    fn default() -> Self {
        Self {
            alpha: 0.9,
            beta: 3.0,
            gamma: 0.5,
            use_stemming: true,
            use_approximate: true,
            approximate_threshold: 0.4,
        }
    }
}
/// Crude suffix-stripping stemmer.
///
/// Lowercases `word`, then removes the first suffix from a fixed list whose removal
/// leaves at least three characters. If no listed suffix applies, it drops a
/// trailing plural `s` (but not `ss`). Words of three characters or fewer are only
/// lowercased.
fn simple_stem(word: &str) -> String {
    // Tried in order; the first applicable suffix wins.
    const SUFFIXES: &[&str] = &[
        "ational", "tional", "ences", "ances", "ments", "ously", "ively", "ation", "ness", "ment",
        "able", "ible", "ting", "ally", "ence", "ance", "ings", "ized", "ling", "ful", "ous",
        "ive", "ize", "ing", "ies", "ied", "ion", "ers", "est", "ess", "ism", "ist", "ity", "ble",
        "ent", "ant", "ary", "ery", "ory", "al", "ly", "er", "ed", "en", "es", "ty",
    ];
    // Minimum number of characters a stem must keep.
    const MIN_STEM_CHARS: usize = 3;

    let w = word.to_lowercase();
    // Count characters, not bytes: `str::len` is a byte length, so multibyte words
    // could otherwise be cut down to stems shorter than MIN_STEM_CHARS.
    let char_len = w.chars().count();
    if char_len <= MIN_STEM_CHARS {
        return w;
    }

    // All suffixes are ASCII, so `suffix.len()` equals its character count and the
    // prefix left by `strip_suffix` always ends on a char boundary.
    if let Some(stem) = SUFFIXES.iter().find_map(|&suffix| {
        w.strip_suffix(suffix)
            .filter(|_| char_len - suffix.len() >= MIN_STEM_CHARS)
    }) {
        return stem.to_string();
    }

    // Plural "-s"; words ending in "-ss" (e.g. "glass") are left intact.
    if let Some(stem) = w.strip_suffix('s') {
        if !stem.ends_with('s') {
            return stem.to_string();
        }
    }
    w
}
/// Levenshtein distance between `a` and `b`: the minimum number of single-character
/// insertions, deletions and substitutions (each costing 1) that turn one into the
/// other. Distances are counted in `char`s, not bytes.
fn edit_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // `row[j]` holds the distance between the processed prefix of `source` and the
    // first `j` characters of `target`. It starts as the distance from the empty
    // prefix, which covers the empty-input cases without special handling.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, &sc) in source.iter().enumerate() {
        // `diagonal` is the previous row's value one column to the left.
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &tc) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diagonal + usize::from(sc != tc);
            row[j + 1] = substitute.min(above + 1).min(row[j] + 1);
            diagonal = above;
        }
    }

    row[target.len()]
}
/// One aligned token pair: hypothesis position `hyp_idx` matched to reference
/// position `ref_idx`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Alignment {
    /// Index into the hypothesis token slice.
    hyp_idx: usize,
    /// Index into the reference token slice.
    ref_idx: usize,
}
/// Aligns hypothesis tokens to reference tokens one-to-one, in up to three stages:
///
/// 1. exact match of the lowercased tokens;
/// 2. if `config.use_stemming`, equal `simple_stem` stems;
/// 3. if `config.use_approximate`, edit distance divided by the longer word's
///    character count is at most `config.approximate_threshold`.
///
/// Tokens aligned in an earlier stage are not reconsidered by later stages.
fn build_alignment(
    hypothesis: &[&str],
    reference: &[&str],
    config: &MeteorConfig,
) -> Vec<Alignment> {
    let hyp_lower: Vec<String> = hypothesis.iter().map(|w| w.to_lowercase()).collect();
    let ref_lower: Vec<String> = reference.iter().map(|w| w.to_lowercase()).collect();
    let mut hyp_matched = vec![false; hypothesis.len()];
    let mut ref_matched = vec![false; reference.len()];
    let mut alignments: Vec<Alignment> = Vec::new();

    // Stage 1: exact (case-insensitive) match.
    stage_match(
        &hyp_lower,
        &ref_lower,
        &mut hyp_matched,
        &mut ref_matched,
        &mut alignments,
        |h, r| h == r,
    );

    // Stage 2: stem match.
    if config.use_stemming {
        let hyp_stems: Vec<String> = hyp_lower.iter().map(|w| simple_stem(w)).collect();
        let ref_stems: Vec<String> = ref_lower.iter().map(|w| simple_stem(w)).collect();
        stage_match(
            &hyp_stems,
            &ref_stems,
            &mut hyp_matched,
            &mut ref_matched,
            &mut alignments,
            |h, r| h == r,
        );
    }

    // Stage 3: approximate match by normalized edit distance.
    if config.use_approximate {
        let threshold = config.approximate_threshold;
        stage_match(
            &hyp_lower,
            &ref_lower,
            &mut hyp_matched,
            &mut ref_matched,
            &mut alignments,
            |h, r| {
                // `edit_distance` counts chars, so normalize by char count too.
                // `str::len` is a byte length and would loosen the threshold 2–3x
                // for non-Latin scripts (e.g. any two CJK characters would match).
                let max_len = h.chars().count().max(r.chars().count());
                if max_len == 0 {
                    return true;
                }
                let dist = edit_distance(h, r);
                (dist as f64 / max_len as f64) <= threshold
            },
        );
    }

    alignments
}
/// Runs one alignment stage.
///
/// Walks the hypothesis in order. Each still-unmatched hypothesis form is paired
/// with the closest (by position) still-unmatched reference form accepted by
/// `matches`; on equal distance, the lowest reference index wins. Each new pair
/// is appended to `alignments` and flagged in both `*_matched` slices.
fn stage_match<F>(
    hyp_forms: &[String],
    ref_forms: &[String],
    hyp_matched: &mut [bool],
    ref_matched: &mut [bool],
    alignments: &mut Vec<Alignment>,
    matches: F,
) where
    F: Fn(&str, &str) -> bool,
{
    for (h_idx, h_form) in hyp_forms.iter().enumerate() {
        if hyp_matched[h_idx] {
            continue;
        }

        // `min_by_key` returns the first of several equal minima, i.e. the lowest index.
        let nearest = ref_forms
            .iter()
            .enumerate()
            .filter(|&(r_idx, r_form)| !ref_matched[r_idx] && matches(h_form, r_form))
            .min_by_key(|&(r_idx, _)| h_idx.abs_diff(r_idx))
            .map(|(r_idx, _)| r_idx);

        if let Some(r_idx) = nearest {
            hyp_matched[h_idx] = true;
            ref_matched[r_idx] = true;
            alignments.push(Alignment {
                hyp_idx: h_idx,
                ref_idx: r_idx,
            });
        }
    }
}
fn count_chunks(alignments: &[Alignment]) -> usize {
if alignments.is_empty() {
return 0;
}
let mut sorted = alignments.to_vec();
sorted.sort_by_key(|a| a.hyp_idx);
let mut chunks = 1usize;
for i in 1..sorted.len() {
let hyp_contiguous = sorted[i].hyp_idx == sorted[i - 1].hyp_idx + 1;
let ref_contiguous = sorted[i].ref_idx == sorted[i - 1].ref_idx + 1;
if !hyp_contiguous || !ref_contiguous {
chunks += 1;
}
}
chunks
}
/// Computes the METEOR score of `hypothesis` against a single `reference`.
///
/// Tokens are aligned in up to three stages (exact, stem, approximate) according to
/// `config`. With `m` aligned pairs:
///
/// * `precision = m / hypothesis.len()`, `recall = m / reference.len()`
/// * `f_mean = P·R / (α·P + (1 − α)·R)`
/// * `penalty = clamp(γ · (chunks / m)^β, 0, 1)`
/// * `score = f_mean · (1 − penalty)`
///
/// Two empty sequences score `1.0`. If exactly one is empty, or nothing aligns,
/// every field is zero.
///
/// # Errors
///
/// Returns `TextError::InvalidInput` if `alpha` is not strictly inside `(0, 1)`
/// (NaN included), or if `beta` or `gamma` is NaN.
pub fn meteor_score(
    hypothesis: &[&str],
    reference: &[&str],
    config: &MeteorConfig,
) -> Result<MeteorScore> {
    /// All-zero result used when nothing can be aligned.
    fn zero_score() -> MeteorScore {
        MeteorScore {
            score: 0.0,
            precision: 0.0,
            recall: 0.0,
            f_mean: 0.0,
            penalty: 0.0,
            chunks: 0,
            matches: 0,
        }
    }

    // Every comparison with NaN is false, so the range test alone would accept a
    // NaN alpha; check it explicitly.
    if config.alpha.is_nan() || config.alpha <= 0.0 || config.alpha >= 1.0 {
        return Err(TextError::InvalidInput(format!(
            "Alpha must be in (0, 1), got {}",
            config.alpha
        )));
    }
    // A NaN beta or gamma can make the penalty, and so the score, NaN:
    // `f64::clamp` does not filter NaN out.
    if config.beta.is_nan() {
        return Err(TextError::InvalidInput("Beta must not be NaN".to_string()));
    }
    if config.gamma.is_nan() {
        return Err(TextError::InvalidInput("Gamma must not be NaN".to_string()));
    }

    let hyp_len = hypothesis.len();
    let ref_len = reference.len();
    if hyp_len == 0 && ref_len == 0 {
        return Ok(MeteorScore {
            score: 1.0,
            precision: 1.0,
            recall: 1.0,
            f_mean: 1.0,
            penalty: 0.0,
            chunks: 0,
            matches: 0,
        });
    }
    if hyp_len == 0 || ref_len == 0 {
        return Ok(zero_score());
    }

    let alignments = build_alignment(hypothesis, reference, config);
    let matches = alignments.len();
    if matches == 0 {
        return Ok(zero_score());
    }

    let precision = matches as f64 / hyp_len as f64;
    let recall = matches as f64 / ref_len as f64;
    let alpha = config.alpha;
    // matches > 0 gives P, R > 0, and alpha is in (0, 1), so the denominator is positive.
    let f_mean = (precision * recall) / (alpha * precision + (1.0 - alpha) * recall);

    let chunks = count_chunks(&alignments);
    let frag = chunks as f64 / matches as f64;
    let penalty = (config.gamma * frag.powf(config.beta)).clamp(0.0, 1.0);
    let score = f_mean * (1.0 - penalty);

    Ok(MeteorScore {
        score,
        precision,
        recall,
        f_mean,
        penalty,
        chunks,
        matches,
    })
}
/// Scores `hypothesis` against each entry of `references` and returns the best
/// result. A later reference replaces the current best only with a strictly higher
/// `score`, so on ties the earliest reference wins.
///
/// # Errors
///
/// Returns `TextError::InvalidInput` when `references` is empty. Otherwise it
/// propagates the first error returned by `meteor_score`.
pub fn meteor_score_multi(
    hypothesis: &[&str],
    references: &[Vec<&str>],
    config: &MeteorConfig,
) -> Result<MeteorScore> {
    if references.is_empty() {
        return Err(TextError::InvalidInput(
            "References must not be empty".to_string(),
        ));
    }

    let best = references.iter().try_fold(
        None,
        |best: Option<MeteorScore>, reference| -> Result<Option<MeteorScore>> {
            let candidate = meteor_score(hypothesis, reference, config)?;
            let improves = match &best {
                None => true,
                Some(current) => candidate.score > current.score,
            };
            Ok(if improves { Some(candidate) } else { best })
        },
    )?;

    best.ok_or_else(|| TextError::InvalidInput("No references provided".to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exact_match_score() {
        let hypothesis = vec!["the", "cat", "is", "on", "the", "mat"];
        let reference = vec!["the", "cat", "is", "on", "the", "mat"];
        let config = MeteorConfig::default();
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert!(
            (result.precision - 1.0).abs() < 1e-9,
            "Precision should be 1.0"
        );
        assert!((result.recall - 1.0).abs() < 1e-9, "Recall should be 1.0");
        assert!(result.score > 0.9, "Perfect match should score high");
    }

    #[test]
    fn test_no_match_score() {
        let hypothesis = vec!["a", "b", "c"];
        let reference = vec!["x", "y", "z"];
        let config = MeteorConfig {
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert!(result.score.abs() < 1e-9, "No match should score 0.0");
    }

    #[test]
    fn test_empty_inputs() {
        let config = MeteorConfig::default();
        let empty: [&str; 0] = [];
        let both = meteor_score(&empty, &empty, &config).expect("should compute");
        assert!(
            (both.score - 1.0).abs() < 1e-9,
            "Two empty sequences are a perfect match"
        );
        let hyp_only = meteor_score(&["a"], &empty, &config).expect("should compute");
        assert!(hyp_only.score.abs() < 1e-9, "Empty reference should score 0.0");
        let ref_only = meteor_score(&empty, &["a"], &config).expect("should compute");
        assert!(ref_only.score.abs() < 1e-9, "Empty hypothesis should score 0.0");
    }

    #[test]
    fn test_partial_match_with_stemming() {
        let hypothesis = vec!["the", "cats", "sitting", "on", "the", "mats"];
        let reference = vec!["the", "cat", "sat", "on", "the", "mat"];
        let config = MeteorConfig {
            use_stemming: true,
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        // "the", "on", "the" match exactly; "cats"/"mats" match "cat"/"mat" by stem;
        // "sitting" stems to "sit", which does not match "sat".
        assert_eq!(
            result.matches, 5,
            "Exact plus stem matches: got {}",
            result.matches
        );
        assert!(
            result.score > 0.0,
            "Partial match should give positive score"
        );
    }

    #[test]
    fn test_stemming_disabled_keeps_only_exact_matches() {
        let hypothesis = vec!["the", "cats", "sitting", "on", "the", "mats"];
        let reference = vec!["the", "cat", "sat", "on", "the", "mat"];
        let config = MeteorConfig {
            use_stemming: false,
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert_eq!(result.matches, 3, "Only exact matches expected");
    }

    #[test]
    fn test_fragmentation_penalty() {
        let hypothesis = vec!["mat", "the", "on", "sat", "cat", "the"];
        let reference = vec!["the", "cat", "sat", "on", "the", "mat"];
        let config = MeteorConfig {
            use_stemming: false,
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert!(result.chunks > 1, "Scrambled order should produce chunks");
        assert!(
            result.penalty > 0.0,
            "Should have fragmentation penalty: {}",
            result.penalty
        );
    }

    #[test]
    fn test_approximate_matching() {
        let hypothesis = vec!["colour", "neighbours"];
        let reference = vec!["color", "neighbors"];
        let config = MeteorConfig {
            use_stemming: false,
            use_approximate: true,
            approximate_threshold: 0.4,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert_eq!(result.matches, 2, "Both should match approximately");
        assert!(result.score > 0.0);
    }

    #[test]
    fn test_invalid_alpha() {
        for alpha in [0.0, 1.0, -0.5, 1.5] {
            let result = meteor_score(
                &["a"],
                &["a"],
                &MeteorConfig {
                    alpha,
                    ..Default::default()
                },
            );
            assert!(result.is_err(), "alpha = {} should be rejected", alpha);
        }
    }

    #[test]
    fn test_multi_reference() {
        let hypothesis = vec!["the", "cat", "sat"];
        let references = vec![vec!["a", "dog", "ran"], vec!["the", "cat", "sat"]];
        let config = MeteorConfig::default();
        let result = meteor_score_multi(&hypothesis, &references, &config).expect("should compute");
        assert!(
            result.score > 0.8,
            "Should match second reference well: {}",
            result.score
        );
    }

    #[test]
    fn test_multi_reference_requires_references() {
        let references: Vec<Vec<&str>> = Vec::new();
        let result = meteor_score_multi(&["a"], &references, &MeteorConfig::default());
        assert!(result.is_err(), "Empty reference list should be rejected");
    }

    #[test]
    fn test_simple_stem() {
        assert_eq!(simple_stem("running"), "runn");
        assert_eq!(simple_stem("cats"), "cat");
        assert_eq!(simple_stem("happiness"), "happi");
        assert_eq!(simple_stem("the"), "the");
    }

    #[test]
    fn test_edit_distance() {
        assert_eq!(edit_distance("kitten", "sitting"), 3);
        assert_eq!(edit_distance("", "abc"), 3);
        assert_eq!(edit_distance("abc", ""), 3);
        assert_eq!(edit_distance("flaw", "lawn"), 2);
        assert_eq!(edit_distance("café", "cafe"), 1);
    }

    #[test]
    fn test_count_chunks() {
        let pair = |hyp_idx, ref_idx| Alignment { hyp_idx, ref_idx };
        assert_eq!(count_chunks(&[]), 0);
        assert_eq!(count_chunks(&[pair(0, 0), pair(1, 1), pair(2, 2)]), 1);
        assert_eq!(count_chunks(&[pair(2, 0), pair(0, 1), pair(1, 2)]), 2);
        assert_eq!(count_chunks(&[pair(0, 3), pair(1, 2), pair(2, 1)]), 3);
    }
}