use serde::{Deserialize, Serialize};
/// Fraction of `ranks` that are less than or equal to the cutoff `k` (Hits@k).
///
/// Returns `0.0` for an empty slice so callers never divide by zero.
pub fn hits_at_k(ranks: &[usize], k: usize) -> f64 {
    match ranks.len() {
        0 => 0.0,
        total => {
            let within_cutoff = ranks.iter().copied().filter(|&rank| rank <= k).count();
            within_cutoff as f64 / total as f64
        }
    }
}
/// Arithmetic mean of `ranks`.
///
/// Returns `0.0` for an empty slice so callers never divide by zero.
///
/// The sum is accumulated in `u128` rather than `usize`: a `usize` accumulator
/// can overflow (panicking in debug builds, wrapping silently in release
/// builds), which is within reach on 32-bit targets for large evaluation sets.
pub fn mean_rank(ranks: &[usize]) -> f64 {
    if ranks.is_empty() {
        return 0.0;
    }
    // Widening cast: `usize` is never wider than 64 bits on supported targets,
    // so `as u128` is lossless and the sum of a slice cannot overflow `u128`.
    let total: u128 = ranks.iter().map(|&rank| rank as u128).sum();
    total as f64 / ranks.len() as f64
}
/// Mean reciprocal rank (MRR): the average of `1 / rank` over `ranks`.
///
/// Returns `0.0` for an empty slice so callers never divide by zero.
///
/// Ranks are expected to start at 1. A rank of `0` has an infinite
/// reciprocal and therefore makes the whole result infinite.
pub fn mean_reciprocal_rank(ranks: &[usize]) -> f64 {
    let count = ranks.len();
    if count == 0 {
        return 0.0;
    }
    let reciprocal_sum = ranks
        .iter()
        .map(|&rank| (rank as f64).recip())
        .sum::<f64>();
    reciprocal_sum / count as f64
}
/// Rank of `target_entity` within `all_scores` after dropping every *other*
/// entity listed in `known_true_indices`.
///
/// `all_scores` is walked in the order given, so the caller decides the
/// ordering (the scores themselves are not inspected here). The target is
/// never dropped, even when it appears in `known_true_indices`.
///
/// Returns the 1-based position of the first occurrence of the target among
/// the remaining entries, or `all_scores.len() + 1` when the target is absent.
pub fn compute_filtered_rank(
    all_scores: &[(usize, f64)],
    target_entity: usize,
    known_true_indices: &std::collections::HashSet<usize>,
) -> usize {
    all_scores
        .iter()
        .map(|&(candidate, _)| candidate)
        // Keep the target itself plus everything that is not a known positive.
        .filter(|candidate| {
            *candidate == target_entity || !known_true_indices.contains(candidate)
        })
        .position(|candidate| candidate == target_entity)
        .map_or(all_scores.len() + 1, |zero_based| zero_based + 1)
}
/// Ranking metrics in two flavours: computed from raw ranks and from filtered ranks.
///
/// The per-field notes below describe the values as filled in by
/// `EvaluationMetrics::compute`; all fields are public, so other code may
/// construct instances differently.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EvaluationMetrics {
/// Mean of the raw ranks.
pub mean_rank: f64,
/// Mean reciprocal rank of the raw ranks.
pub mean_reciprocal_rank: f64,
/// Fraction of raw ranks that are `<= 1`.
pub hits_at_1: f64,
/// Fraction of raw ranks that are `<= 3`.
pub hits_at_3: f64,
/// Fraction of raw ranks that are `<= 10`.
pub hits_at_10: f64,
/// Mean of the filtered ranks.
pub filtered_mean_rank: f64,
/// Mean reciprocal rank of the filtered ranks.
pub filtered_mrr: f64,
/// Fraction of filtered ranks that are `<= 1`.
pub filtered_hits_at_1: f64,
/// Fraction of filtered ranks that are `<= 3`.
pub filtered_hits_at_3: f64,
/// Fraction of filtered ranks that are `<= 10`.
pub filtered_hits_at_10: f64,
/// Number of rank entries the metrics were computed from.
pub num_test_triples: usize,
}
impl EvaluationMetrics {
    /// Builds every metric from the raw `ranks` and the matching `filtered_ranks`.
    ///
    /// Both slices are summarised independently with `mean_rank`,
    /// `mean_reciprocal_rank` and `hits_at_k` at cutoffs 1, 3 and 10.
    ///
    /// # Panics
    ///
    /// Panics if the two slices differ in length.
    pub fn compute(ranks: &[usize], filtered_ranks: &[usize]) -> Self {
        assert_eq!(
            ranks.len(),
            filtered_ranks.len(),
            "ranks and filtered_ranks must have the same length"
        );

        let [raw_at_1, raw_at_3, raw_at_10] = [1, 3, 10].map(|k| hits_at_k(ranks, k));
        let [kept_at_1, kept_at_3, kept_at_10] =
            [1, 3, 10].map(|k| hits_at_k(filtered_ranks, k));

        Self {
            num_test_triples: ranks.len(),
            mean_rank: mean_rank(ranks),
            filtered_mean_rank: mean_rank(filtered_ranks),
            mean_reciprocal_rank: mean_reciprocal_rank(ranks),
            filtered_mrr: mean_reciprocal_rank(filtered_ranks),
            hits_at_1: raw_at_1,
            hits_at_3: raw_at_3,
            hits_at_10: raw_at_10,
            filtered_hits_at_1: kept_at_1,
            filtered_hits_at_3: kept_at_3,
            filtered_hits_at_10: kept_at_10,
        }
    }

    /// Returns metrics with every value set to zero and a triple count of zero.
    pub fn zero() -> Self {
        Self {
            num_test_triples: 0,
            // Raw metrics.
            mean_rank: 0.0,
            mean_reciprocal_rank: 0.0,
            hits_at_1: 0.0,
            hits_at_3: 0.0,
            hits_at_10: 0.0,
            // Filtered metrics.
            filtered_mean_rank: 0.0,
            filtered_mrr: 0.0,
            filtered_hits_at_1: 0.0,
            filtered_hits_at_3: 0.0,
            filtered_hits_at_10: 0.0,
        }
    }

    /// Renders the metrics as a multi-line text table.
    ///
    /// The table has a title line with the query count, a rule, a column
    /// header, and one row per metric showing the raw value followed by the
    /// filtered value. Mean rank uses two decimals, the other rows four.
    /// There is no trailing newline.
    pub fn display(&self) -> String {
        let title = format!("KGC Evaluation Metrics ({} queries)", self.num_test_triples);
        let rule = String::from("─────────────────────────────────────────");
        let header = String::from("Metric Raw Filtered");
        let mean_rank_row = format!(
            "Mean Rank {:>10.2} {:>10.2}",
            self.mean_rank, self.filtered_mean_rank
        );
        let mrr_row = format!(
            "MRR {:>10.4} {:>10.4}",
            self.mean_reciprocal_rank, self.filtered_mrr
        );
        let hits_1_row = format!(
            "Hits@1 {:>10.4} {:>10.4}",
            self.hits_at_1, self.filtered_hits_at_1
        );
        let hits_3_row = format!(
            "Hits@3 {:>10.4} {:>10.4}",
            self.hits_at_3, self.filtered_hits_at_3
        );
        let hits_10_row = format!(
            "Hits@10 {:>10.4} {:>10.4}",
            self.hits_at_10, self.filtered_hits_at_10
        );

        [
            title,
            rule,
            header,
            mean_rank_row,
            mrr_row,
            hits_1_row,
            hits_3_row,
            hits_10_row,
        ]
        .join("\n")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Absolute tolerance used when comparing floating-point metrics.
    const TOLERANCE: f64 = 1e-12;

    /// MRR over ranks 1, 2 and 4 is the mean of 1, 1/2 and 1/4.
    #[test]
    fn test_compute_known_ranks_correct_mrr() {
        let ranks = [1usize, 2, 4];
        let filtered = [1usize, 2, 4];
        let m = EvaluationMetrics::compute(&ranks, &filtered);
        let expected_mrr = (1.0 + 0.5 + 0.25) / 3.0;
        assert!(
            (m.mean_reciprocal_rank - expected_mrr).abs() < TOLERANCE,
            "MRR expected {expected_mrr:.6}, got {:.6}",
            m.mean_reciprocal_rank
        );
    }

    /// Every query ranked first gives a perfect MRR.
    #[test]
    fn test_mrr_all_rank_one() {
        let ranks = [1usize; 10];
        let m = EvaluationMetrics::compute(&ranks, &ranks);
        assert!(
            (m.mean_reciprocal_rank - 1.0).abs() < TOLERANCE,
            "expected MRR = 1.0, got {}",
            m.mean_reciprocal_rank
        );
    }

    /// Every query ranked second gives an MRR of one half.
    #[test]
    fn test_mrr_all_rank_two() {
        let ranks = [2usize; 6];
        let m = EvaluationMetrics::compute(&ranks, &ranks);
        assert!(
            (m.mean_reciprocal_rank - 0.5).abs() < TOLERANCE,
            "expected MRR = 0.5, got {}",
            m.mean_reciprocal_rank
        );
    }

    /// Hits@k counts ranks at or below the cutoff.
    #[test]
    fn test_hits_at_k_correctness() {
        let ranks = [1usize, 2, 3, 5, 11, 12];
        assert!((hits_at_k(&ranks, 1) - 1.0 / 6.0).abs() < TOLERANCE);
        assert!((hits_at_k(&ranks, 3) - 3.0 / 6.0).abs() < TOLERANCE);
        assert!((hits_at_k(&ranks, 10) - 4.0 / 6.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_mean_rank_helper() {
        let ranks = [2usize, 4, 6];
        assert!((mean_rank(&ranks) - 4.0).abs() < TOLERANCE);
    }

    /// All helpers return zero rather than dividing by an empty length.
    #[test]
    fn test_empty_slices_return_zero() {
        assert_eq!(hits_at_k(&[], 1), 0.0);
        assert_eq!(mean_rank(&[]), 0.0);
        assert_eq!(mean_reciprocal_rank(&[]), 0.0);
    }

    #[test]
    fn test_num_test_triples_field() {
        let ranks = [1usize, 3, 5, 7];
        let m = EvaluationMetrics::compute(&ranks, &ranks);
        assert_eq!(m.num_test_triples, 4);
    }

    /// `zero()` and `compute` on empty input must agree field for field.
    #[test]
    fn test_zero_matches_empty_compute() {
        assert_eq!(
            EvaluationMetrics::zero(),
            EvaluationMetrics::compute(&[], &[])
        );
    }

    /// A known positive ranked above the target must not count against it.
    #[test]
    fn test_compute_filtered_rank_skips_positives() {
        let scores = [(0usize, 10.0_f64), (1, 8.0), (2, 6.0)];
        let known: HashSet<usize> = HashSet::from([1]);
        let raw = scores.iter().position(|&(e, _)| e == 2).unwrap() + 1;
        let filtered = compute_filtered_rank(&scores, 2, &known);
        assert_eq!(raw, 3, "raw rank should be 3");
        assert_eq!(filtered, 2, "filtered rank should be 2 (entity 1 removed)");
    }

    #[test]
    fn test_filtered_rank_equals_raw_when_no_other_positives() {
        let scores = [(0usize, 5.0_f64), (1, 3.0), (2, 1.0)];
        let known: HashSet<usize> = HashSet::new();
        let raw = scores.iter().position(|&(e, _)| e == 2).unwrap() + 1;
        let filtered = compute_filtered_rank(&scores, 2, &known);
        assert_eq!(raw, filtered, "no other positives → filtered equals raw");
    }

    /// A target absent from the score list falls back to `len + 1`.
    #[test]
    fn test_compute_filtered_rank_missing_target() {
        let scores = [(0usize, 2.0_f64), (1, 1.0)];
        let known: HashSet<usize> = HashSet::new();
        assert_eq!(compute_filtered_rank(&scores, 7, &known), scores.len() + 1);
    }

    #[test]
    fn test_display_non_empty() {
        let ranks = [1usize, 2, 3];
        let m = EvaluationMetrics::compute(&ranks, &ranks);
        let s = m.display();
        assert!(!s.is_empty(), "display() should produce a non-empty string");
        assert!(s.contains("MRR"), "display string should mention MRR");
    }
}