/// Pair of question-answering metrics for one prediction.
///
/// Both fields are plain `f32` values; the struct itself enforces no range.
#[derive(Debug, Clone, Copy)]
pub struct QaScore {
/// Exact-match value. As filled in by `score_multi`, this is the largest
/// `exact_match` result (0.0 or 1.0) over the supplied references.
pub exact_match: f32,
/// Token-level F1 value. As filled in by `score_multi`, this is the largest
/// `f1_score` result over the supplied references.
pub f1: f32,
}
/// Canonicalises an answer string so superficially different answers compare equal.
///
/// The steps, in order:
/// 1. lowercase the input;
/// 2. replace every character that is neither alphanumeric nor whitespace with a
///    space (so `"don't"` yields the two tokens `don` and `t`);
/// 3. drop the standalone tokens `a`, `an` and `the`;
/// 4. join the surviving tokens with single spaces.
///
/// Because punctuation becomes a separator *before* articles are removed, a
/// letter `a` isolated by punctuation is also dropped (`"U.S.A."` -> `"u s"`).
///
/// Returns an empty string when no token survives.
pub fn normalize_answer(s: &str) -> String {
    // Build the punctuation-free, lowercased text one character at a time.
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.to_lowercase().chars() {
        let keep = ch.is_alphanumeric() || ch.is_whitespace();
        cleaned.push(if keep { ch } else { ' ' });
    }

    // `split_whitespace` collapses runs of whitespace, so the join below
    // produces exactly one space between tokens and none at the ends.
    cleaned
        .split_whitespace()
        .filter(|word| !matches!(*word, "a" | "an" | "the"))
        .collect::<Vec<_>>()
        .join(" ")
}
/// Returns the normalised tokens of `s` as owned strings.
///
/// The input is first run through [`normalize_answer`]; the normalised text is
/// then split on whitespace. The result is empty when normalisation leaves no
/// tokens.
pub fn normalize_tokens(s: &str) -> Vec<String> {
    let normalized = normalize_answer(s);
    normalized.split_whitespace().map(String::from).collect()
}
/// Exact-match metric: `1.0` when `prediction` and `reference` are equal after
/// [`normalize_answer`] has been applied to both, `0.0` otherwise.
pub fn exact_match(prediction: &str, reference: &str) -> f32 {
    let same = normalize_answer(prediction) == normalize_answer(reference);
    // `true` -> 1 -> 1.0, `false` -> 0 -> 0.0; both conversions are lossless.
    f32::from(u8::from(same))
}
/// Token-level F1 between `prediction` and `reference`.
///
/// Both strings are tokenised with [`normalize_tokens`]. The overlap is the
/// multiset intersection size of the two token lists; precision divides it by
/// the number of prediction tokens and recall by the number of reference
/// tokens. The result is the harmonic mean of the two.
///
/// Special cases:
/// * both token lists empty -> `1.0`;
/// * exactly one token list empty -> `0.0`;
/// * no shared tokens -> `0.0`.
pub fn f1_score(prediction: &str, reference: &str) -> f32 {
    let predicted = normalize_tokens(prediction);
    let expected = normalize_tokens(reference);

    // Settle the empty-input cases before any division takes place.
    match (predicted.is_empty(), expected.is_empty()) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    let overlap = common_multiset(&predicted, &expected);
    if overlap == 0 {
        // Guarding here keeps the harmonic mean below from dividing by zero.
        return 0.0;
    }

    let overlap = overlap as f32;
    let precision = overlap / predicted.len() as f32;
    let recall = overlap / expected.len() as f32;
    (2.0 * precision * recall) / (precision + recall)
}
/// Scores `prediction` against several acceptable `references`, keeping the
/// best value of each metric.
///
/// Exact match and F1 are maximised independently, so the two maxima may come
/// from different references. With no references both metrics are `0.0`.
pub fn score_multi(prediction: &str, references: &[&str]) -> QaScore {
    // Starting from zero means an empty slice naturally yields 0.0 / 0.0.
    let worst = QaScore {
        exact_match: 0.0,
        f1: 0.0,
    };

    references.iter().fold(worst, |best, reference| QaScore {
        exact_match: best.exact_match.max(exact_match(prediction, reference)),
        f1: best.f1.max(f1_score(prediction, reference)),
    })
}
/// Averages exact match and F1 over a corpus.
///
/// Each example is a `(prediction, references)` pair scored with
/// [`score_multi`]; the per-example values are summed and divided by the number
/// of examples.
///
/// Returns `(mean_exact_match, mean_f1)`, or `(0.0, 0.0)` for an empty corpus.
pub fn corpus_em_f1(examples: &[(String, Vec<String>)]) -> (f32, f32) {
    if examples.is_empty() {
        // Avoids a 0/0 division below.
        return (0.0, 0.0);
    }

    let (em_total, f1_total) = examples.iter().fold(
        (0.0f32, 0.0f32),
        |(em_acc, f1_acc), (pred, refs)| {
            // `score_multi` wants borrowed `&str` references.
            let borrowed: Vec<&str> = refs.iter().map(|r| r.as_str()).collect();
            let score = score_multi(pred, &borrowed);
            (em_acc + score.exact_match, f1_acc + score.f1)
        },
    );

    let count = examples.len() as f32;
    (em_total / count, f1_total / count)
}
/// Size of the multiset intersection of two token lists.
///
/// Each token counts as many times as it appears in *both* `a` and `b`: for
/// `a = [x, x, y]` and `b = [x, x, x]` the result is `2`. The result is
/// symmetric in its arguments and is `0` when either list is empty.
fn common_multiset(a: &[String], b: &[String]) -> usize {
    use std::collections::HashMap;

    // How many occurrences of each token from `a` are still unmatched.
    let mut remaining: HashMap<&str, usize> = HashMap::new();
    for tok in a {
        *remaining.entry(tok.as_str()).or_default() += 1;
    }

    let mut common = 0usize;
    for tok in b {
        // Look up without inserting: tokens that only occur in `b` must not
        // grow the map.
        if let Some(left) = remaining.get_mut(tok.as_str()) {
            if *left > 0 {
                *left -= 1;
                common += 1;
            }
        }
    }
    common
}