use crate::dataset::{GoldQuery, RetrievedSet};
/// A ranking-quality metric evaluated for a single query.
///
/// The supertrait bounds require every implementor to be `Send + Sync`.
pub trait RetrievalMetric: Send + Sync {
/// Returns the label of this metric. The `@K` implementations in this file
/// embed their cutoff in the label (for example `recall@5`).
fn name(&self) -> String;
/// Scores the `retrieved` ranking against the `gold` relevance judgments and
/// returns the result as an `f64`.
fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64;
}
/// Recall metric with a rank cutoff.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so that
/// configured metrics can be compared in assertions, deduplicated, or used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RecallAtK {
    /// Number of top-ranked results that are inspected when scoring.
    pub k: usize,
}

impl RecallAtK {
    /// Creates a recall metric that inspects the first `k` ranked results.
    ///
    /// The constructor is `const`, so a metric can be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new(k: usize) -> Self {
        Self { k }
    }
}
impl RetrievalMetric for RecallAtK {
    /// Returns `recall@<k>`.
    fn name(&self) -> String {
        format!("recall@{}", self.k)
    }

    /// Fraction of the gold-relevant documents found in the first `k` entries of
    /// `retrieved.ranked`.
    ///
    /// Returns `1.0` when `gold` reports no relevant documents (vacuously
    /// perfect). Each relevant document id is credited at most once, so a
    /// ranking that repeats an id (e.g. several chunks of one document) cannot
    /// inflate the numerator.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        let total = gold.relevant_count();
        if total == 0 {
            return 1.0;
        }
        // Collect into a set so that a repeated relevant id counts once.
        let hits = retrieved
            .ranked
            .iter()
            .take(self.k)
            .filter(|d| gold.is_relevant(&d.doc_id))
            .map(|d| d.doc_id.as_str())
            .collect::<std::collections::HashSet<&str>>()
            .len();
        hits as f64 / total as f64
    }
}
/// Precision metric with a rank cutoff.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so that
/// configured metrics can be compared in assertions, deduplicated, or used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PrecisionAtK {
    /// Number of top-ranked results that are inspected when scoring.
    pub k: usize,
}

impl PrecisionAtK {
    /// Creates a precision metric that inspects the first `k` ranked results.
    ///
    /// The constructor is `const`, so a metric can be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new(k: usize) -> Self {
        Self { k }
    }
}
impl RetrievalMetric for PrecisionAtK {
    /// Returns `precision@<k>`.
    fn name(&self) -> String {
        let cutoff = self.k;
        format!("precision@{cutoff}")
    }

    /// Share of the first `k` ranked results that `gold` marks as relevant.
    ///
    /// The denominator is always `k`, even when fewer than `k` results were
    /// retrieved. A cutoff of zero yields `0.0`.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        let cutoff = self.k;
        match cutoff {
            // Guard against dividing by zero.
            0 => 0.0,
            _ => {
                let mut relevant_in_window = 0usize;
                for doc in retrieved.ranked.iter().take(cutoff) {
                    if gold.is_relevant(&doc.doc_id) {
                        relevant_in_window += 1;
                    }
                }
                relevant_in_window as f64 / cutoff as f64
            }
        }
    }
}
/// Hit-rate metric with a rank cutoff.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so that
/// configured metrics can be compared in assertions, deduplicated, or used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HitRateAtK {
    /// Number of top-ranked results that are inspected when scoring.
    pub k: usize,
}

impl HitRateAtK {
    /// Creates a hit-rate metric that inspects the first `k` ranked results.
    ///
    /// The constructor is `const`, so a metric can be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new(k: usize) -> Self {
        Self { k }
    }
}
impl RetrievalMetric for HitRateAtK {
    /// Returns `hit_rate@<k>`.
    fn name(&self) -> String {
        format!("hit_rate@{k}", k = self.k)
    }

    /// Binary indicator: `1.0` if any of the first `k` ranked results is
    /// relevant according to `gold`, otherwise `0.0`.
    ///
    /// Returns `1.0` when `gold` reports no relevant documents.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        if gold.relevant_count() == 0 {
            return 1.0;
        }
        let window = retrieved.ranked.iter().take(self.k);
        for doc in window {
            // The first relevant result settles the outcome.
            if gold.is_relevant(&doc.doc_id) {
                return 1.0;
            }
        }
        0.0
    }
}
/// Reciprocal-rank metric; a unit struct with no configuration.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so the
/// metric can be compared in assertions or used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Mrr;
impl RetrievalMetric for Mrr {
    /// Returns `mrr`.
    fn name(&self) -> String {
        String::from("mrr")
    }

    /// Reciprocal of the 1-based rank of the first relevant result in
    /// `retrieved.ranked`; the whole ranking is scanned, with no cutoff.
    ///
    /// Returns `1.0` when `gold` reports no relevant documents and `0.0` when
    /// no ranked result is relevant.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        if gold.relevant_count() == 0 {
            return 1.0;
        }
        retrieved
            .ranked
            .iter()
            .position(|doc| gold.is_relevant(&doc.doc_id))
            .map_or(0.0, |first_hit| 1.0 / ((first_hit + 1) as f64))
    }
}
/// Average-precision metric with a rank cutoff.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so that
/// configured metrics can be compared in assertions, deduplicated, or used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MapAtK {
    /// Number of top-ranked results that are inspected when scoring.
    pub k: usize,
}

impl MapAtK {
    /// Creates an average-precision metric that inspects the first `k` ranked
    /// results.
    ///
    /// The constructor is `const`, so a metric can be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new(k: usize) -> Self {
        Self { k }
    }
}
impl RetrievalMetric for MapAtK {
    /// Returns `map@<k>`.
    fn name(&self) -> String {
        format!("map@{}", self.k)
    }

    /// Average precision over the first `k` ranked results: the precision at
    /// each rank holding a relevant document is summed and divided by the
    /// number of relevant documents reported by `gold`.
    ///
    /// Returns `1.0` when `gold` reports no relevant documents. A relevant id
    /// that appears more than once is credited only at its first occurrence;
    /// later repeats still occupy a rank but add nothing, so the score cannot
    /// be inflated by duplicates.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        let total = gold.relevant_count();
        if total == 0 {
            return 1.0;
        }
        // Relevant ids already credited; its size doubles as the hit counter.
        let mut credited = std::collections::HashSet::new();
        let mut sum = 0.0_f64;
        for (idx, doc) in retrieved.ranked.iter().take(self.k).enumerate() {
            // `insert` returns false for an id that was credited earlier.
            if gold.is_relevant(&doc.doc_id) && credited.insert(doc.doc_id.as_str()) {
                sum += credited.len() as f64 / ((idx + 1) as f64);
            }
        }
        sum / total as f64
    }
}
/// Normalized discounted cumulative gain metric with a rank cutoff.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the basic traits so that
/// configured metrics can be compared in assertions, deduplicated, or used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NdcgAtK {
    /// Number of top-ranked results that are inspected when scoring.
    pub k: usize,
}

impl NdcgAtK {
    /// Creates an nDCG metric that inspects the first `k` ranked results.
    ///
    /// The constructor is `const`, so a metric can be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new(k: usize) -> Self {
        Self { k }
    }
}
impl RetrievalMetric for NdcgAtK {
    /// Returns `ndcg@<k>`.
    fn name(&self) -> String {
        format!("ndcg@{}", self.k)
    }

    /// Normalized discounted cumulative gain over the first `k` ranked results.
    ///
    /// DCG uses the exponential gain `2^grade - 1` discounted by
    /// `log2(rank + 1)` with 1-based ranks. The ideal DCG is built from the
    /// grades in `gold.relevant_docs`, sorted descending and truncated to `k`.
    ///
    /// Returns `1.0` when `gold` reports no relevant documents or when the
    /// ideal DCG is zero. A document id that appears more than once earns gain
    /// only at its first occurrence; later repeats still occupy a rank but add
    /// nothing, so DCG cannot exceed the ideal DCG because of duplicates.
    fn score(&self, gold: &GoldQuery, retrieved: &RetrievedSet) -> f64 {
        /// Discounted gain of `grade` at zero-based position `idx`;
        /// non-positive grades contribute nothing.
        fn discounted_gain(grade: f64, idx: usize) -> f64 {
            if grade <= 0.0 {
                0.0
            } else {
                ((2.0_f64).powf(grade) - 1.0) / ((idx as f64 + 2.0).log2())
            }
        }

        if gold.relevant_count() == 0 {
            return 1.0;
        }

        // Ids that have already been awarded their gain.
        let mut seen = std::collections::HashSet::new();
        let dcg = retrieved
            .ranked
            .iter()
            .take(self.k)
            .enumerate()
            .map(|(idx, doc)| {
                // `insert` returns false for a repeated id.
                if seen.insert(doc.doc_id.as_str()) {
                    discounted_gain(gold.grade(&doc.doc_id) as f64, idx)
                } else {
                    0.0
                }
            })
            .sum::<f64>();

        let mut grades: Vec<u8> = gold.relevant_docs.values().copied().collect();
        grades.sort_unstable_by(|a, b| b.cmp(a));
        let idcg = grades
            .into_iter()
            .take(self.k)
            .enumerate()
            .map(|(idx, grade)| discounted_gain(f64::from(grade), idx))
            .sum::<f64>();

        if idcg == 0.0 { 1.0 } else { dcg / idcg }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::dataset::RetrievedDoc;
    use std::collections::HashMap;

    /// Builds a gold query from `(doc_id, grade)` pairs.
    fn gold(relevant: &[(&str, u8)]) -> GoldQuery {
        let relevant_docs = relevant
            .iter()
            .map(|&(doc_id, grade)| (doc_id.to_owned(), grade))
            .collect::<HashMap<_, _>>();
        GoldQuery {
            query_id: "q".into(),
            query: "test".into(),
            relevant_docs,
            reference_answer: None,
        }
    }

    /// Builds a ranking from ids in rank order, with scores that decrease by
    /// 0.01 per position.
    fn retrieved(ids: &[&str]) -> RetrievedSet {
        let ranked = ids
            .iter()
            .enumerate()
            .map(|(position, &doc_id)| RetrievedDoc {
                doc_id: doc_id.to_string(),
                score: 1.0 - (position as f64) * 0.01,
            })
            .collect();
        RetrievedSet {
            query_id: "q".into(),
            ranked,
        }
    }

    #[test]
    fn recall_monotonic_in_k() {
        let judgments = gold(&[("d1", 1), ("d2", 1), ("d3", 1)]);
        let ranking = retrieved(&["x", "d1", "y", "d2", "d3"]);
        let recall_at = |k: usize| RecallAtK::new(k).score(&judgments, &ranking);
        let (at_1, at_3, at_10) = (recall_at(1), recall_at(3), recall_at(10));
        assert!(at_1 <= at_3);
        assert!(at_3 <= at_10);
        assert!((at_10 - 1.0).abs() < 1e-9);
    }

    #[test]
    fn precision_correct() {
        let judgments = gold(&[("d1", 1), ("d2", 1)]);
        let ranking = retrieved(&["d1", "x", "d2", "y"]);
        let precision = PrecisionAtK::new(4).score(&judgments, &ranking);
        assert!((precision - 0.5).abs() < 1e-9);
    }

    #[test]
    fn mrr_perfect_rank_is_one() {
        let judgments = gold(&[("d1", 1)]);
        let ranking = retrieved(&["d1", "x", "y"]);
        let reciprocal_rank = Mrr.score(&judgments, &ranking);
        assert!((reciprocal_rank - 1.0).abs() < 1e-9);
    }

    #[test]
    fn mrr_zero_when_no_hit() {
        let judgments = gold(&[("d1", 1)]);
        let ranking = retrieved(&["x", "y", "z"]);
        let reciprocal_rank = Mrr.score(&judgments, &ranking);
        assert_eq!(reciprocal_rank, 0.0);
    }

    #[test]
    fn ndcg_perfect_when_ordered_by_grade() {
        let judgments = gold(&[("d1", 3), ("d2", 2), ("d3", 1)]);
        let ranking = retrieved(&["d1", "d2", "d3"]);
        let metric = NdcgAtK::new(3);
        let s = metric.score(&judgments, &ranking);
        assert!((s - 1.0).abs() < 1e-9, "got {s}");
    }

    #[test]
    fn ndcg_drops_with_bad_ordering() {
        let judgments = gold(&[("d1", 3), ("d2", 1)]);
        let metric = NdcgAtK::new(2);
        let best = metric.score(&judgments, &retrieved(&["d1", "d2"]));
        let swapped = metric.score(&judgments, &retrieved(&["d2", "d1"]));
        assert!(best > swapped);
    }

    #[test]
    fn hit_rate_binary() {
        let judgments = gold(&[("d1", 1)]);
        let metric = HitRateAtK::new(3);
        let with_hit = retrieved(&["a", "d1", "c"]);
        let without_hit = retrieved(&["a", "b", "c"]);
        assert_eq!(metric.score(&judgments, &with_hit), 1.0);
        assert_eq!(metric.score(&judgments, &without_hit), 0.0);
    }

    #[test]
    fn map_matches_hand_computation() {
        let judgments = gold(&[("d1", 1), ("d2", 1)]);
        let ranking = retrieved(&["d1", "x", "d2", "y"]);
        // Hits at ranks 1 and 3: precisions 1/1 and 2/3, averaged over 2 relevant docs.
        let expected = (1.0 + 2.0 / 3.0) / 2.0;
        let s = MapAtK::new(4).score(&judgments, &ranking);
        assert!((s - expected).abs() < 1e-9, "got {s}");
    }

    #[test]
    fn empty_relevance_is_vacuously_perfect() {
        let no_judgments = gold(&[]);
        let ranking = retrieved(&["a", "b", "c"]);
        assert_eq!(RecallAtK::new(3).score(&no_judgments, &ranking), 1.0);
        assert_eq!(NdcgAtK::new(3).score(&no_judgments, &ranking), 1.0);
        assert_eq!(Mrr.score(&no_judgments, &ranking), 1.0);
    }

    #[test]
    fn empty_relevance_is_vacuously_perfect_for_every_metric() {
        let no_judgments = gold(&[]);
        let ranking = retrieved(&["a", "b", "c"]);
        let metrics: Vec<Box<dyn RetrievalMetric>> = vec![
            Box::new(RecallAtK::new(3)),
            Box::new(HitRateAtK::new(3)),
            Box::new(Mrr),
            Box::new(MapAtK::new(3)),
            Box::new(NdcgAtK::new(3)),
        ];

        // Non-empty ranking, empty judgments.
        for metric in &metrics {
            let name = metric.name();
            let score = metric.score(&no_judgments, &ranking);
            assert_eq!(
                score, 1.0,
                "{name} broke the vacuous-perfect contract for empty relevance"
            );
        }

        // Empty ranking, empty judgments.
        let empty = retrieved(&[]);
        for metric in &metrics {
            assert_eq!(metric.score(&no_judgments, &empty), 1.0);
        }
    }

    #[test]
    fn precision_at_k_is_not_vacuously_perfect() {
        let no_judgments = gold(&[]);
        let metric = PrecisionAtK::new(3);
        let ranking = retrieved(&["a", "b", "c"]);
        let nothing = retrieved(&[]);
        assert_eq!(metric.score(&no_judgments, &ranking), 0.0);
        assert_eq!(metric.score(&no_judgments, &nothing), 0.0);
    }
}