use std::collections::HashSet;
/// Computes the mean reciprocal rank (MRR) over a batch of queries.
///
/// Query `i` pairs `ranked_results[i]` with `ground_truth[i]`. A query scores
/// `1 / rank`, where `rank` is the 1-based position of the first ranked item
/// that appears in its ground truth, or `0.0` when no ranked item matches.
///
/// The scores are summed over the paired queries and divided by
/// `ranked_results.len()`. If `ground_truth` is shorter than `ranked_results`,
/// the unpaired queries add nothing to the sum but still count in the divisor.
///
/// Returns `0.0` when `ranked_results` is empty.
pub fn mrr(ranked_results: &[Vec<String>], ground_truth: &[Vec<String>]) -> f64 {
    // Reciprocal of the 1-based rank of the first relevant item, or 0.0 if
    // nothing in `ranked` is relevant.
    fn reciprocal_rank(ranked: &[String], truth: &[String]) -> f64 {
        let relevant: HashSet<&str> = truth.iter().map(String::as_str).collect();
        match ranked
            .iter()
            .position(|item| relevant.contains(item.as_str()))
        {
            Some(index) => 1.0 / (index as f64 + 1.0),
            None => 0.0,
        }
    }

    let query_count = ranked_results.len();
    if query_count == 0 {
        return 0.0;
    }

    let total: f64 = ranked_results
        .iter()
        .zip(ground_truth)
        .map(|(ranked, truth)| reciprocal_rank(ranked, truth))
        .sum();
    total / query_count as f64
}
/// Precision of one ranked list within its top `k` entries.
///
/// The cutoff is clamped to `ranked.len()`, so a list shorter than `k` is
/// scored against its own length rather than against `k`. The result is the
/// share of the first `cutoff` entries that occur in `truth`.
///
/// Returns `0.0` when the clamped cutoff is zero (`k == 0` or `ranked` is
/// empty).
fn precision_at_k_single(ranked: &[String], truth: &[String], k: usize) -> f64 {
    let cutoff = ranked.len().min(k);
    if cutoff == 0 {
        return 0.0;
    }

    let relevant: HashSet<&str> = truth.iter().map(String::as_str).collect();
    let hits = ranked
        .iter()
        .take(cutoff)
        .filter(|item| relevant.contains(item.as_str()))
        .count();
    hits as f64 / cutoff as f64
}
/// Mean precision-at-`k` over a batch of queries.
///
/// Query `i` pairs `ranked_results[i]` with `ground_truth[i]` and is scored by
/// `precision_at_k_single`. The per-query scores are summed over the paired
/// queries and divided by `ranked_results.len()`; if `ground_truth` is
/// shorter, the unpaired queries add nothing to the sum but still count in
/// the divisor.
///
/// Returns `0.0` when `ranked_results` is empty.
pub fn precision_at_k(
    ranked_results: &[Vec<String>],
    ground_truth: &[Vec<String>],
    k: usize,
) -> f64 {
    match ranked_results.len() {
        0 => 0.0,
        query_count => {
            let total: f64 = ranked_results
                .iter()
                .zip(ground_truth)
                .map(|(ranked, truth)| precision_at_k_single(ranked, truth, k))
                .sum();
            total / query_count as f64
        }
    }
}
/// Precision of a predicted set against the actual set.
///
/// Both inputs are treated as sets: repeated entries are counted once, so a
/// duplicated prediction cannot inflate or deflate the score. The result is
/// the number of distinct predicted items that also occur in `actual`,
/// divided by the number of distinct predicted items, and always lies in
/// `[0.0, 1.0]`.
///
/// Returns `0.0` when `predicted` is empty.
pub fn blast_precision(predicted: &[String], actual: &[String]) -> f64 {
    let predicted_set: HashSet<&str> = predicted.iter().map(String::as_str).collect();
    if predicted_set.is_empty() {
        return 0.0;
    }

    let actual_set: HashSet<&str> = actual.iter().map(String::as_str).collect();
    let hits = predicted_set.intersection(&actual_set).count();
    hits as f64 / predicted_set.len() as f64
}
/// Recall of a predicted set against the actual set.
///
/// Both inputs are treated as sets: repeated entries are counted once. The
/// result is the number of distinct actual items that also occur in
/// `predicted`, divided by the number of distinct actual items, and always
/// lies in `[0.0, 1.0]`. Counting raw entries instead would let duplicated
/// predictions push recall above `1.0` and duplicated actual entries pull it
/// below its true value.
///
/// Returns `0.0` when `actual` is empty.
pub fn blast_recall(predicted: &[String], actual: &[String]) -> f64 {
    let actual_set: HashSet<&str> = actual.iter().map(String::as_str).collect();
    if actual_set.is_empty() {
        return 0.0;
    }

    let predicted_set: HashSet<&str> = predicted.iter().map(String::as_str).collect();
    let hits = actual_set.intersection(&predicted_set).count();
    hits as f64 / actual_set.len() as f64
}
/// Harmonic mean of `precision` and `recall` (the F1 score).
///
/// Returns `0.0` when `precision + recall` equals zero, which avoids a
/// division by zero; otherwise returns
/// `2 * precision * recall / (precision + recall)`.
pub fn f1(precision: f64, recall: f64) -> f64 {
    let denominator = precision + recall;
    if denominator == 0.0 {
        0.0
    } else {
        2.0 * precision * recall / denominator
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned list of names from string literals.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    ///
    /// `#[track_caller]` keeps the reported panic location on the calling test.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    // --- mrr ---------------------------------------------------------------

    #[test]
    fn mrr_perfect_ranking() {
        let ranked = vec![names(&["a", "b"]), names(&["c", "d"])];
        let truth = vec![names(&["a"]), names(&["c"])];
        assert_close(mrr(&ranked, &truth), 1.0, f64::EPSILON);
    }

    #[test]
    fn mrr_second_position() {
        let ranked = vec![names(&["x", "a", "b"])];
        let truth = vec![names(&["a"])];
        assert_close(mrr(&ranked, &truth), 0.5, f64::EPSILON);
    }

    #[test]
    fn mrr_no_match() {
        let ranked = vec![names(&["x", "y"])];
        let truth = vec![names(&["a"])];
        assert_close(mrr(&ranked, &truth), 0.0, f64::EPSILON);
    }

    #[test]
    fn mrr_mixed() {
        // First query hits at rank 1, second at rank 3.
        let ranked = vec![names(&["a", "b", "c"]), names(&["x", "y", "a"])];
        let truth = vec![names(&["a"]), names(&["a"])];
        let expected = (1.0 + 1.0 / 3.0) / 2.0;
        assert_close(mrr(&ranked, &truth), expected, 1e-10);
    }

    #[test]
    fn mrr_empty_queries() {
        let ranked: Vec<Vec<String>> = Vec::new();
        let truth: Vec<Vec<String>> = Vec::new();
        assert_close(mrr(&ranked, &truth), 0.0, f64::EPSILON);
    }

    // --- precision_at_k ----------------------------------------------------

    #[test]
    fn precision_at_k_all_relevant() {
        let ranked = vec![names(&["a", "b", "c"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 3), 1.0, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_none_relevant() {
        let ranked = vec![names(&["x", "y", "z"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 3), 0.0, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_partial() {
        let ranked = vec![names(&["a", "x", "b", "y", "c"])];
        let truth = vec![names(&["a", "b", "c"])];
        assert_close(precision_at_k(&ranked, &truth, 5), 0.6, f64::EPSILON);
    }

    #[test]
    fn precision_at_k_fewer_results_than_k() {
        // Only three results for k = 5: the cutoff clamps to the list length.
        let ranked = vec![names(&["a", "b", "c"])];
        let truth = vec![names(&["a", "b", "c", "d", "e"])];
        assert_close(precision_at_k(&ranked, &truth, 5), 1.0, f64::EPSILON);
    }

    // --- blast_precision / blast_recall ------------------------------------

    #[test]
    fn blast_precision_perfect() {
        let predicted = names(&["a", "b"]);
        let actual = names(&["a", "b"]);
        assert_close(blast_precision(&predicted, &actual), 1.0, f64::EPSILON);
    }

    #[test]
    fn blast_precision_empty_predicted() {
        let predicted: Vec<String> = Vec::new();
        let actual = names(&["a"]);
        assert_close(blast_precision(&predicted, &actual), 0.0, f64::EPSILON);
    }

    #[test]
    fn blast_recall_perfect() {
        let predicted = names(&["a", "b"]);
        let actual = names(&["a", "b"]);
        assert_close(blast_recall(&predicted, &actual), 1.0, f64::EPSILON);
    }

    #[test]
    fn blast_recall_empty_actual() {
        let predicted = names(&["a"]);
        let actual: Vec<String> = Vec::new();
        assert_close(blast_recall(&predicted, &actual), 0.0, f64::EPSILON);
    }

    // --- f1 ----------------------------------------------------------------

    #[test]
    fn f1_balanced() {
        assert_close(f1(0.75, 0.75), 0.75, f64::EPSILON);
    }

    #[test]
    fn f1_zero_both() {
        assert_close(f1(0.0, 0.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn f1_typical() {
        let expected = 2.0 * 0.8 * 0.6 / (0.8 + 0.6);
        assert_close(f1(0.8, 0.6), expected, 1e-10);
    }
}