/// Computes the mean reciprocal rank (MRR) of a sequence of 1-based ranks.
///
/// Each rank `r > 0` contributes `1 / r`. A rank of `0` is treated as a miss:
/// it contributes `0` but still counts toward the denominator.
///
/// Returns `0.0` when `ranks` is empty.
pub fn mean_reciprocal_rank<I>(ranks: I) -> f32
where
    I: Iterator<Item = usize>,
{
    // Accumulate in f64. An f32 running sum has only 24 mantissa bits, so once
    // the total grows, small reciprocals (e.g. 1/100_000 after ~1000 perfect
    // hits) fall below half an ulp and are silently dropped.
    let (sum, count) = ranks.fold((0.0_f64, 0_usize), |(sum, count), rank| {
        let term = if rank > 0 { 1.0 / (rank as f64) } else { 0.0 };
        (sum + term, count + 1)
    });
    if count == 0 {
        0.0
    } else {
        (sum / count as f64) as f32
    }
}
/// Computes Hits@k: the fraction of 1-based ranks that fall in `1..=k`.
///
/// A rank of `0` is treated as a miss but still counts toward the denominator.
/// With `k == 0` no rank can qualify, so the result is `0.0`.
///
/// Returns `0.0` when `ranks` is empty.
pub fn hits_at_k<I>(ranks: I, k: usize) -> f32
where
    I: Iterator<Item = usize>,
{
    // `usize` counters: untyped integer literals would default to `i32`, which
    // overflows past `i32::MAX` queries. The ratio is formed in f64 so that
    // counts above 2^24 are not rounded before dividing.
    let (hits, count) = ranks.fold((0_usize, 0_usize), |(hits, count), rank| {
        let hit = (1..=k).contains(&rank);
        (hits + usize::from(hit), count + 1)
    });
    if count == 0 {
        0.0
    } else {
        (hits as f64 / count as f64) as f32
    }
}
/// Computes the arithmetic mean of a sequence of 1-based ranks.
///
/// A rank of `0` is treated as a miss: it adds `0` to the sum but still counts
/// toward the denominator, so misses pull the mean *down*.
/// NOTE(review): for a lower-is-better metric this makes misses look like
/// excellent results. It is kept because `test_mean_rank_with_zero_ranks`
/// pins it; consider skipping misses or substituting a worst-case rank.
///
/// Returns `0.0` when `ranks` is empty.
pub fn mean_rank<I>(ranks: I) -> f32
where
    I: Iterator<Item = usize>,
{
    // Accumulate in f64. An f32 running sum has only 24 mantissa bits, so once
    // the total of ranks passes ~16.7M, small ranks are rounded away entirely.
    // A rank of 0 adds nothing, which matches the miss handling above.
    let (sum, count) = ranks.fold((0.0_f64, 0_usize), |(sum, count), rank| {
        (sum + rank as f64, count + 1)
    });
    if count == 0 {
        0.0
    } else {
        (sum / count as f64) as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a borrowed slice of ranks into the by-value iterator the metrics take.
    fn seq(values: &[usize]) -> impl Iterator<Item = usize> + '_ {
        values.iter().copied()
    }

    #[test]
    fn test_mean_reciprocal_rank() {
        // (1 + 1/3 + 1/2 + 1/5) / 4 ≈ 0.5083
        let value = mean_reciprocal_rank(seq(&[1, 3, 2, 5]));
        assert!((value - 0.5083).abs() < 1e-3);
    }

    #[test]
    fn test_hits_at_k() {
        // Ranks 1, 3 and 2 are within the top 3; 5 and 10 are not.
        let value = hits_at_k(seq(&[1, 3, 2, 5, 10]), 3);
        assert_eq!(value, 0.6);
    }

    #[test]
    fn test_mean_rank() {
        // (1 + 3 + 2 + 5) / 4
        let value = mean_rank(seq(&[1, 3, 2, 5]));
        assert_eq!(value, 2.75);
    }

    #[test]
    fn test_empty_metrics() {
        // Every metric reports 0.0 when there is nothing to average.
        assert_eq!(mean_reciprocal_rank(std::iter::empty()), 0.0);
        assert_eq!(hits_at_k(std::iter::empty(), 10), 0.0);
        assert_eq!(mean_rank(std::iter::empty()), 0.0);
    }

    #[test]
    fn test_edge_cases_zero_rank() {
        // Zero ranks are misses: they add nothing but still count.
        let value = mean_reciprocal_rank(seq(&[0, 1, 2, 0, 3]));
        assert!(value > 0.0 && value < 1.0);
    }

    #[test]
    fn test_mrr_single_element() {
        assert_eq!(mean_reciprocal_rank(seq(&[1])), 1.0);
        let fifth = mean_reciprocal_rank(seq(&[5]));
        assert!((fifth - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_mrr_all_perfect() {
        assert_eq!(mean_reciprocal_rank(seq(&[1, 1, 1, 1])), 1.0);
    }

    #[test]
    fn test_mrr_all_zero_rank() {
        assert_eq!(mean_reciprocal_rank(seq(&[0, 0, 0, 0])), 0.0);
    }

    #[test]
    fn test_mean_rank_with_zero_ranks() {
        // (0 + 2 + 4 + 0) / 4: misses are averaged in as rank 0.
        let value = mean_rank(seq(&[0, 2, 4, 0]));
        assert!((value - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_rank_single_element() {
        assert_eq!(mean_rank(seq(&[7])), 7.0);
    }

    #[test]
    fn test_hits_at_k_all_hit() {
        assert_eq!(hits_at_k(seq(&[1, 2, 3]), 3), 1.0);
    }

    #[test]
    fn test_hits_at_k_none_hit() {
        assert_eq!(hits_at_k(seq(&[4, 5, 6]), 3), 0.0);
    }

    #[test]
    fn test_hits_at_k_with_zero_rank() {
        // Two misses (rank 0) and two hits out of four.
        assert_eq!(hits_at_k(seq(&[0, 1, 0, 2]), 10), 0.5);
    }

    #[test]
    fn test_hits_at_1() {
        // Only the two rank-1 entries count at k = 1.
        assert_eq!(hits_at_k(seq(&[1, 2, 1, 3]), 1), 0.5);
    }
}