use std::collections::HashSet;
use uuid::Uuid;
/// Binary hit indicator over the first `k` entries of a ranking.
///
/// Returns `1.0` when at least one of the first `k` ids in `ranked` is a member of
/// `relevant`, and `0.0` otherwise. This includes the cases where `k` is `0` or
/// `ranked` is empty, because no entry is inspected.
///
/// The result is never a fraction: a single match anywhere in the window already
/// yields `1.0`, regardless of how many ids `relevant` contains.
pub fn recall_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
    for candidate in ranked.iter().take(k) {
        if relevant.contains(candidate) {
            // One match is enough; later entries cannot change the outcome.
            return 1.0;
        }
    }
    0.0
}
/// Reciprocal rank of the first relevant entry in `ranked`.
///
/// Ranks are 1-based: a match at the first position scores `1.0`, at the second
/// `0.5`, and so on. Only the earliest entry found in `relevant` contributes;
/// later matches are ignored. Returns `0.0` when no entry of `ranked` is in
/// `relevant`, which covers an empty ranking as well.
pub fn mrr(ranked: &[Uuid], relevant: &HashSet<Uuid>) -> f32 {
    let first_hit = ranked.iter().position(|id| relevant.contains(id));
    match first_hit {
        // `idx` is 0-based, so shift by one to get the 1-based rank.
        Some(idx) => 1.0 / (idx as f32 + 1.0),
        None => 0.0,
    }
}
/// Fraction of the top-`k` window of `ranked` that is present in `relevant`.
///
/// The denominator is always `k`, even when `ranked` holds fewer than `k` entries,
/// so a short ranking cannot reach `1.0`. Every window entry found in `relevant`
/// is counted individually. Returns `0.0` when `k` is `0`.
pub fn precision_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
    match k {
        // Avoid dividing by zero for an empty window.
        0 => 0.0,
        _ => {
            // Clamp the window so a `k` larger than the ranking is fine.
            let window = &ranked[..k.min(ranked.len())];
            let mut matched = 0usize;
            for candidate in window {
                if relevant.contains(candidate) {
                    matched += 1;
                }
            }
            matched as f32 / k as f32
        }
    }
}
/// Fraction of entries in `flags` that are `true`.
///
/// Returns a value in `[0.0, 1.0]`, or `0.0` for an empty slice so that the
/// division is never performed with a zero denominator.
pub fn gated_out_rate(flags: &[bool]) -> f32 {
    match flags.len() {
        0 => 0.0,
        total => {
            let set_count: usize = flags.iter().map(|&flag| usize::from(flag)).sum();
            set_count as f32 / total as f32
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing scores that are not exactly representable in `f32`.
    const EPS: f32 = 1e-6;

    /// Builds a ranking of `n` freshly generated v4 ids.
    fn ids(n: usize) -> Vec<Uuid> {
        std::iter::repeat_with(Uuid::new_v4).take(n).collect()
    }

    /// Collects the given ids into a relevance set.
    fn set_of(items: &[Uuid]) -> HashSet<Uuid> {
        items.iter().copied().collect()
    }

    #[test]
    fn gold_at_rank_three_scores_as_expected() {
        let ranked = ids(5);
        let relevant = set_of(&[ranked[2]]);

        // The single gold id sits at 1-based rank 3.
        assert_eq!(recall_at_k(&ranked, &relevant, 1), 0.0);
        assert_eq!(recall_at_k(&ranked, &relevant, 3), 1.0);
        assert_eq!(recall_at_k(&ranked, &relevant, 5), 1.0);

        let reciprocal = mrr(&ranked, &relevant);
        assert!((reciprocal - 1.0 / 3.0).abs() < EPS);

        let precision = precision_at_k(&ranked, &relevant, 5);
        assert!((precision - 1.0 / 5.0).abs() < EPS);
    }

    #[test]
    fn no_relevant_in_ranking_scores_zero() {
        let ranked = ids(5);
        // A fresh id that is not part of the ranking.
        let relevant = set_of(&[Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranked, &relevant, 5), 0.0);
        assert_eq!(mrr(&ranked, &relevant), 0.0);
        assert_eq!(precision_at_k(&ranked, &relevant, 5), 0.0);
    }

    #[test]
    fn empty_ranking_scores_zero() {
        let ranked: Vec<Uuid> = Vec::new();
        let relevant = set_of(&[Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranked, &relevant, 5), 0.0);
        assert_eq!(mrr(&ranked, &relevant), 0.0);
        assert_eq!(precision_at_k(&ranked, &relevant, 5), 0.0);
    }

    #[test]
    fn precision_counts_intersection_in_window() {
        let ranked = ids(5);
        let relevant = set_of(&[ranked[0], ranked[2]]);

        let at_five = precision_at_k(&ranked, &relevant, 5);
        assert!((at_five - 2.0 / 5.0).abs() < EPS);

        let at_one = precision_at_k(&ranked, &relevant, 1);
        assert!((at_one - 1.0).abs() < EPS);
    }

    #[test]
    fn mrr_uses_first_relevant_rank() {
        let ranked = ids(4);
        let relevant = set_of(&[ranked[1], ranked[3]]);

        // The earlier of the two relevant ids is at 1-based rank 2.
        let reciprocal = mrr(&ranked, &relevant);
        assert!((reciprocal - 1.0 / 2.0).abs() < EPS);
    }

    #[test]
    fn gated_out_rate_is_fraction_true() {
        let one_in_four = gated_out_rate(&[true, false, false, false]);
        assert!((one_in_four - 0.25).abs() < EPS);

        assert_eq!(gated_out_rate(&[false, false]), 0.0);
        assert_eq!(gated_out_rate(&[true, true]), 1.0);
        assert_eq!(gated_out_rate(&[]), 0.0);
    }
}