/// Mean reciprocal rank (MRR) of a set of ranks.
///
/// Returns the average of `1 / rank` over `ranks`, or `0.0` when `ranks` is empty.
/// Ranks are expected to be 1-based: a rank of `0` contributes `f64::INFINITY`.
pub fn mean_reciprocal_rank(ranks: &[usize]) -> f64 {
    match ranks.len() {
        0 => 0.0,
        n => {
            let reciprocal_total: f64 = ranks.iter().map(|&rank| (rank as f64).recip()).sum();
            reciprocal_total / n as f64
        }
    }
}
/// Hits@k: the fraction of `ranks` that are less than or equal to `k`.
///
/// Returns a value in `[0.0, 1.0]`, or `0.0` when `ranks` is empty.
pub fn hits_at_k(ranks: &[usize], k: usize) -> f64 {
    let total = ranks.len();
    if total == 0 {
        return 0.0;
    }
    let hits = ranks
        .iter()
        .fold(0usize, |acc, &rank| if rank <= k { acc + 1 } else { acc });
    hits as f64 / total as f64
}
/// Arithmetic mean of `ranks`, or `0.0` when `ranks` is empty.
///
/// Ranks are accumulated as `f64` rather than `usize`. Summing many large ranks as integers
/// can overflow: it panics in debug builds and wraps silently in release, and this is easy
/// to hit on 32-bit targets. `f64` represents integer sums exactly while the total stays
/// below 2^53, so results for ordinary inputs are unchanged.
pub fn mean_rank(ranks: &[usize]) -> f64 {
    if ranks.is_empty() {
        return 0.0;
    }
    ranks.iter().map(|&r| r as f64).sum::<f64>() / ranks.len() as f64
}
/// Tie-aware ("realistic") rank of a true candidate among scored candidates.
///
/// Scores are compared so that a **lower** score ranks higher. The result is the mean of:
///
/// * the optimistic rank: `1 +` the number of scores strictly below `true_score`;
/// * the pessimistic rank: the number of scores less than or equal to `true_score`.
///
/// A true candidate tied with others therefore lands in the middle of its tie block.
///
/// `all_scores` is expected to include the true candidate's own score. If it does not, or
/// `true_score` is NaN (which compares false with everything), the raw pessimistic count can
/// fall below the optimistic rank. It is clamped so the result is always at least the
/// optimistic rank, and hence `>= 1.0`.
pub fn realistic_rank(all_scores: &[f32], true_score: f32) -> f64 {
    let mut strictly_better = 0usize;
    let mut at_least_as_good = 0usize;
    for &s in all_scores {
        if s < true_score {
            strictly_better += 1;
        }
        if s <= true_score {
            at_least_as_good += 1;
        }
    }
    let optimistic = strictly_better + 1;
    // When `all_scores` contains `true_score`, `at_least_as_good >= optimistic` already holds.
    // The clamp only affects inputs that omit it, which would otherwise yield ranks below 1.
    let pessimistic = at_least_as_good.max(optimistic);
    (optimistic as f64 + pessimistic as f64) / 2.0
}
/// Adjusted mean rank: [`mean_rank`] divided by `(num_entities + 1) / 2`.
///
/// The divisor is the average of the positions `1..=num_entities`, so a value of `1.0`
/// matches that baseline and smaller values are better. Returns `0.0` when `ranks` is empty
/// or `num_entities` is zero.
pub fn adjusted_mean_rank(ranks: &[usize], num_entities: usize) -> f64 {
    if !ranks.is_empty() && num_entities != 0 {
        let baseline = (num_entities as f64 + 1.0) / 2.0;
        mean_rank(ranks) / baseline
    } else {
        0.0
    }
}
/// Mean reciprocal rank computed separately for each relation.
///
/// `rel_ranks` holds `(relation_id, rank)` pairs. Ranks are grouped by relation id and each
/// group is reduced with [`mean_reciprocal_rank`]. Only relation ids that occur in the input
/// appear as keys of the returned map.
pub fn per_relation_mrr(rel_ranks: &[(usize, usize)]) -> std::collections::HashMap<usize, f64> {
    use std::collections::HashMap;

    let mut by_relation: HashMap<usize, Vec<usize>> = HashMap::new();
    rel_ranks
        .iter()
        .for_each(|&(relation, rank)| by_relation.entry(relation).or_default().push(rank));

    let mut mrr_by_relation = HashMap::with_capacity(by_relation.len());
    for (relation, ranks) in by_relation {
        mrr_by_relation.insert(relation, mean_reciprocal_rank(&ranks));
    }
    mrr_by_relation
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mrr_basic() {
        let ranks = [1, 2, 4];
        let mrr = mean_reciprocal_rank(&ranks);
        // (1 + 1/2 + 1/4) / 3 = 7/12
        assert!((mrr - 7.0 / 12.0).abs() < 1e-12);
    }

    #[test]
    fn hits_at_k_basic() {
        let ranks = [1, 2, 5, 10, 20];
        assert!((hits_at_k(&ranks, 10) - 0.8).abs() < 1e-9);
        assert!((hits_at_k(&ranks, 1) - 0.2).abs() < 1e-9);
    }

    #[test]
    fn mean_rank_basic() {
        let ranks = [1, 3, 5];
        assert!((mean_rank(&ranks) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn adjusted_mean_rank_basic() {
        let ranks = [1, 1, 1];
        let amr = adjusted_mean_rank(&ranks, 100);
        // Mean rank 1 against a baseline of (100 + 1) / 2 = 50.5.
        assert!((amr - 1.0 / 50.5).abs() < 1e-9);
    }

    #[test]
    fn realistic_rank_no_ties() {
        let scores = vec![0.1, 0.5, 0.3, 0.9];
        assert!((realistic_rank(&scores, 0.3) - 2.0).abs() < 1e-9);
    }

    #[test]
    fn realistic_rank_with_ties() {
        // Three-way tie at the top: optimistic rank 1, pessimistic rank 3.
        let scores = vec![0.3, 0.3, 0.3, 0.9];
        assert!((realistic_rank(&scores, 0.3) - 2.0).abs() < 1e-9);
    }

    #[test]
    fn realistic_rank_best() {
        let scores = vec![0.1, 0.5, 0.9];
        assert!((realistic_rank(&scores, 0.1) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn realistic_rank_worst() {
        let scores = vec![0.1, 0.5, 0.9];
        assert!((realistic_rank(&scores, 0.9) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn per_relation_mrr_groups_by_relation() {
        let mrr = per_relation_mrr(&[(0, 1), (1, 4), (0, 2)]);
        assert_eq!(mrr.len(), 2);
        // Relation 0: (1 + 1/2) / 2; relation 1: 1/4.
        assert!((mrr[&0] - 0.75).abs() < 1e-12);
        assert!((mrr[&1] - 0.25).abs() < 1e-12);
    }

    #[test]
    fn per_relation_mrr_empty() {
        assert!(per_relation_mrr(&[]).is_empty());
    }

    #[test]
    fn empty_ranks() {
        assert_eq!(mean_reciprocal_rank(&[]), 0.0);
        assert_eq!(hits_at_k(&[], 10), 0.0);
        assert_eq!(mean_rank(&[]), 0.0);
        assert_eq!(adjusted_mean_rank(&[], 100), 0.0);
        assert_eq!(adjusted_mean_rank(&[1], 0), 0.0);
    }
}