/// Normalized Discounted Cumulative Gain (NDCG) of a scored list, truncated at `k`.
///
/// Items are ranked by `scores` in descending order. The item at 0-based rank
/// `i` contributes a gain of `2^rel - 1` discounted by `log2(i + 2)`; the sum
/// over the top `k` ranks (DCG) is divided by the DCG of the best possible
/// ordering of `relevance` over the same number of ranks (IDCG).
///
/// * `k == 0`, or `k` larger than the list, evaluates the whole list.
/// * Items with equal scores keep their input order (the sort is stable).
/// * Scores are compared with [`f64::total_cmp`], so a NaN with the sign bit
///   clear ranks above every other score.
/// * Grades of 1024 or more overflow the `f64` gain to infinity; if such an
///   item is inside the evaluated top `k`, the ratio is `inf / inf` and the
///   result is NaN.
///
/// Returns `0.0` for empty input or when every relevance grade is zero;
/// otherwise returns the DCG/IDCG ratio clamped to `[0.0, 1.0]`.
///
/// # Panics
///
/// Panics if `scores` and `relevance` have different lengths.
pub fn ndcg_at_k(scores: &[f64], relevance: &[usize], k: usize) -> f64 {
    /// Exponential gain `2^rel - 1` of one relevance grade.
    fn gain(rel: usize) -> f64 {
        // Clamp before casting: a bare `rel as i32` silently wraps grades
        // above `i32::MAX` into small or negative exponents (e.g.
        // `usize::MAX as i32 == -1`), producing wrong or negative gains.
        let exponent = rel.min(i32::MAX as usize) as i32;
        2_f64.powi(exponent) - 1.0
    }

    /// DCG of relevance grades supplied in rank order (rank 0 first).
    fn dcg(grades_in_rank_order: impl Iterator<Item = usize>) -> f64 {
        grades_in_rank_order
            .enumerate()
            .map(|(i, rel)| gain(rel) / (i as f64 + 2.0).log2())
            .sum()
    }

    assert_eq!(
        scores.len(),
        relevance.len(),
        "scores and relevance must have the same length"
    );
    if scores.is_empty() {
        return 0.0;
    }
    let n = scores.len();
    let cut = if k == 0 || k > n { n } else { k };

    // Actual ranking: stable descending sort by score.
    let mut ranked: Vec<(f64, usize)> = scores
        .iter()
        .copied()
        .zip(relevance.iter().copied())
        .collect();
    ranked.sort_by(|a, b| b.0.total_cmp(&a.0));
    let actual = dcg(ranked.iter().take(cut).map(|&(_, rel)| rel));

    // Ideal ranking: grades sorted from most to least relevant.
    let mut ideal: Vec<usize> = relevance.to_vec();
    ideal.sort_unstable_by(|a, b| b.cmp(a));
    let best = dcg(ideal.iter().take(cut).copied());

    if best == 0.0 {
        0.0
    } else {
        // Keep the reported ratio within [0, 1].
        (actual / best).clamp(0.0, 1.0)
    }
}
/// Reciprocal rank of the first relevant item when items are ranked by score.
///
/// Items are ordered by `scores` from highest to lowest using
/// [`f64::total_cmp`]; equal scores keep their input order because the sort
/// is stable. Returns `1 / r` where `r` is the 1-based rank of the first item
/// with non-zero relevance. Operates on a single list; averaging across
/// queries is left to the caller.
///
/// Returns `0.0` for empty input or when no item is relevant.
///
/// # Panics
///
/// Panics if `scores` and `relevance` have different lengths.
pub fn mrr(scores: &[f64], relevance: &[usize]) -> f64 {
    assert_eq!(
        scores.len(),
        relevance.len(),
        "scores and relevance must have the same length"
    );
    // Permutation of input indices, best score first. Sorting indices with a
    // stable sort yields the same order as stably sorting (score, grade) pairs.
    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by(|&x, &y| scores[y].total_cmp(&scores[x]));
    order
        .iter()
        .position(|&idx| relevance[idx] > 0)
        .map_or(0.0, |pos| 1.0 / (pos + 1) as f64)
}
/// Fraction of relevant items among the top `k` items ranked by score.
///
/// Items are ordered by `scores` from highest to lowest using
/// [`f64::total_cmp`]; equal scores keep their input order because the sort
/// is stable. An item counts as relevant when its grade is non-zero.
///
/// `k == 0` evaluates the whole list, and a `k` larger than the list is
/// clamped to the list length (so the denominator never exceeds the number
/// of items). Returns `0.0` for empty input.
///
/// # Panics
///
/// Panics if `scores` and `relevance` have different lengths.
pub fn precision_at_k(scores: &[f64], relevance: &[usize], k: usize) -> f64 {
    assert_eq!(
        scores.len(),
        relevance.len(),
        "scores and relevance must have the same length"
    );
    let total = scores.len();
    if total == 0 {
        return 0.0;
    }
    // `k == 0` means "the whole list"; a `k` past the end is clamped to it.
    let depth = match k {
        0 => total,
        requested => requested.min(total),
    };
    let mut by_score: Vec<(f64, usize)> = Vec::with_capacity(total);
    by_score.extend(scores.iter().copied().zip(relevance.iter().copied()));
    by_score.sort_by(|lhs, rhs| rhs.0.total_cmp(&lhs.0));
    let hits = by_score
        .iter()
        .take(depth)
        .filter(|&&(_, grade)| grade != 0)
        .count();
    hits as f64 / depth as f64
}
#[cfg(test)]
mod tests {
    // Unit tests for the ranking metrics. Besides the headline values, these
    // pin the shared conventions: `k == 0` or `k > len` evaluates the whole
    // list, equal scores keep input order, and mismatched input lengths panic.
    use super::*;

    #[test]
    fn test_ndcg_perfect_ranking() {
        let scores = [0.9, 0.7, 0.4, 0.1];
        let rel = [3, 2, 1, 0];
        let ndcg = ndcg_at_k(&scores, &rel, 4);
        assert!(
            (ndcg - 1.0).abs() < 1e-9,
            "perfect ranking should yield NDCG=1.0, got {ndcg}"
        );
    }

    #[test]
    fn test_ndcg_worst_ranking() {
        let scores = [0.1, 0.4, 0.7, 0.9];
        let rel = [3, 2, 1, 0];
        let ndcg = ndcg_at_k(&scores, &rel, 4);
        assert!(
            ndcg < 0.65 && ndcg > 0.0,
            "reversed graded ranking should give NDCG in (0, 0.65), got {ndcg}"
        );
        let perfect = ndcg_at_k(&[0.9, 0.7, 0.4, 0.1], &rel, 4);
        assert!(ndcg < perfect, "reversed ranking must score below perfect");
    }

    #[test]
    fn test_ndcg_all_zero_relevance() {
        let scores = [0.9, 0.5, 0.1];
        let rel = [0, 0, 0];
        assert_eq!(ndcg_at_k(&scores, &rel, 3), 0.0);
    }

    #[test]
    fn test_ndcg_empty() {
        assert_eq!(ndcg_at_k(&[], &[], 0), 0.0);
    }

    #[test]
    fn test_ndcg_k_truncates() {
        let scores = [0.9, 0.7, 0.5, 0.3];
        let rel = [2, 2, 0, 2];
        let ndcg_k2 = ndcg_at_k(&scores, &rel, 2);
        let ndcg_k4 = ndcg_at_k(&scores, &rel, 4);
        assert!(
            (ndcg_k2 - 1.0).abs() < 1e-9,
            "top-2 perfectly ranked → NDCG@2 should be 1.0, got {ndcg_k2}"
        );
        assert!(
            ndcg_k4 < 1.0,
            "missed relevant item at rank 4 should lower NDCG@4, got {ndcg_k4}"
        );
        assert!(
            ndcg_k2 > ndcg_k4,
            "NDCG@2={ndcg_k2} should exceed NDCG@4={ndcg_k4}"
        );
    }

    #[test]
    fn test_ndcg_k_zero_means_all() {
        let scores = [0.9, 0.5];
        let rel = [2, 1];
        assert!((ndcg_at_k(&scores, &rel, 0) - ndcg_at_k(&scores, &rel, 2)).abs() < 1e-9);
    }

    #[test]
    fn test_ndcg_k_beyond_len_means_all() {
        let scores = [0.9, 0.5, 0.1];
        let rel = [0, 2, 1];
        assert!((ndcg_at_k(&scores, &rel, 10) - ndcg_at_k(&scores, &rel, 3)).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "scores and relevance must have the same length")]
    fn test_ndcg_length_mismatch_panics() {
        let _ = ndcg_at_k(&[0.9, 0.1], &[1], 2);
    }

    #[test]
    fn test_mrr_first_is_relevant() {
        let scores = [0.9, 0.5, 0.1];
        let rel = [1, 0, 0];
        assert!((mrr(&scores, &rel) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_mrr_second_is_relevant() {
        let scores = [0.9, 0.5, 0.1];
        let rel = [0, 1, 0];
        assert!((mrr(&scores, &rel) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_mrr_third_is_relevant() {
        let scores = [0.9, 0.5, 0.1];
        let rel = [0, 0, 2];
        assert!((mrr(&scores, &rel) - 1.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_mrr_ties_keep_input_order() {
        // Equal scores are not reordered, so the earlier (irrelevant) item
        // occupies rank 1 and the relevant one rank 2.
        assert!((mrr(&[0.5, 0.5], &[0, 1]) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_mrr_no_relevant() {
        assert_eq!(mrr(&[0.9, 0.5], &[0, 0]), 0.0);
    }

    #[test]
    fn test_mrr_empty() {
        assert_eq!(mrr(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "scores and relevance must have the same length")]
    fn test_mrr_length_mismatch_panics() {
        let _ = mrr(&[0.9], &[1, 0]);
    }

    #[test]
    fn test_precision_at_k_all_relevant() {
        let scores = [0.9, 0.7, 0.5];
        let rel = [1, 1, 1];
        assert!((precision_at_k(&scores, &rel, 3) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_precision_at_k_half_relevant() {
        let scores = [0.9, 0.8, 0.5, 0.1];
        let rel = [1, 1, 0, 0];
        assert!((precision_at_k(&scores, &rel, 4) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_precision_at_k_truncates() {
        let scores = [0.9, 0.8, 0.3, 0.1];
        let rel = [1, 1, 0, 0];
        assert!((precision_at_k(&scores, &rel, 2) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_precision_at_k_zero_k_means_all() {
        let scores = [0.9, 0.5];
        let rel = [1, 0];
        assert!((precision_at_k(&scores, &rel, 0) - precision_at_k(&scores, &rel, 2)).abs() < 1e-9);
    }

    #[test]
    fn test_precision_at_k_k_beyond_len_divides_by_len() {
        // k is clamped to the list length, so the denominator is 2, not 5.
        let scores = [0.9, 0.1];
        let rel = [1, 0];
        assert!((precision_at_k(&scores, &rel, 5) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_precision_at_k_empty() {
        assert_eq!(precision_at_k(&[], &[], 3), 0.0);
    }

    #[test]
    #[should_panic(expected = "scores and relevance must have the same length")]
    fn test_precision_at_k_length_mismatch_panics() {
        let _ = precision_at_k(&[0.9, 0.1], &[1], 1);
    }
}