use std::collections::{HashMap, HashSet};
/// Ranked retriever output for one query, paired with its ground-truth judgements.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Identifier of the query these results belong to.
    pub query_id: String,
    /// Retrieved document ids in rank order; index 0 is treated as rank 1 by
    /// the metric functions in this module.
    pub retrieved_ids: Vec<String>,
    /// Ids of the documents judged relevant (binary relevance).
    pub relevant_ids: HashSet<String>,
    /// Optional graded relevance per document id; `None` unless attached with
    /// [`QueryResult::with_grades`].
    pub relevance_grades: Option<HashMap<String, f64>>,
}

impl QueryResult {
    /// Creates a result carrying binary relevance only (no grades attached).
    pub fn new(
        query_id: impl Into<String>,
        retrieved: Vec<String>,
        relevant: HashSet<String>,
    ) -> Self {
        let query_id: String = query_id.into();
        Self {
            relevant_ids: relevant,
            retrieved_ids: retrieved,
            query_id,
            relevance_grades: None,
        }
    }

    /// Builder-style setter that attaches graded relevance judgements,
    /// replacing any grades set previously.
    pub fn with_grades(self, grades: HashMap<String, f64>) -> Self {
        Self {
            relevance_grades: Some(grades),
            ..self
        }
    }
}
/// Metric values for a single query at one cutoff, as produced by
/// `RetrievalMetrics::compute`.
///
/// `recall`, `precision` and `ndcg` are evaluated on the first `k` retrieved
/// ids; `mrr` and `ap` are computed over the whole retrieved list.
#[derive(Debug, Clone, Default)]
pub struct EvaluationResult {
/// Query identifier. `RetrievalMetrics::compute` leaves this as an empty
/// string, so callers that need it must fill it in themselves.
pub query_id: String,
/// Recall@k.
pub recall: f64,
/// Precision@k.
pub precision: f64,
/// nDCG@k with binary relevance.
pub ndcg: f64,
/// Reciprocal rank of the first relevant id (no cutoff).
pub mrr: f64,
/// Average precision over the full retrieved list (no cutoff).
pub ap: f64,
/// Cutoff used for the @k metrics above.
pub k: usize,
}
/// Retrieval metrics for one query, with the @k metrics evaluated at several cutoffs.
#[derive(Debug, Clone)]
pub struct RetrievalMetrics {
    /// Recall for each requested cutoff, keyed by `k`.
    pub recall_at_k: HashMap<usize, f64>,
    /// Precision for each requested cutoff, keyed by `k`.
    pub precision_at_k: HashMap<usize, f64>,
    /// Binary-relevance nDCG for each requested cutoff, keyed by `k`.
    pub ndcg_at_k: HashMap<usize, f64>,
    /// Reciprocal rank of the first relevant id in the full list.
    pub mrr: f64,
    /// Average precision of this single query over the full list.
    pub map: f64,
}

impl RetrievalMetrics {
    /// Evaluates one query at every cutoff in `k_values`.
    ///
    /// A cutoff listed more than once maps to a single entry. `mrr` and `map`
    /// ignore the cutoffs and use the whole retrieved list.
    pub fn compute_all(
        retrieved: &[String],
        relevant: &HashSet<String>,
        k_values: &[usize],
    ) -> Self {
        Self {
            recall_at_k: Self::per_cutoff(k_values, |k| recall_at_k_impl(retrieved, relevant, k)),
            precision_at_k: Self::per_cutoff(k_values, |k| {
                precision_at_k_impl(retrieved, relevant, k)
            }),
            ndcg_at_k: Self::per_cutoff(k_values, |k| ndcg_at_k_binary(retrieved, relevant, k)),
            mrr: mean_reciprocal_rank_single(retrieved, relevant),
            map: average_precision_impl(retrieved, relevant),
        }
    }

    /// Evaluates every metric for one query at a single cutoff `k`.
    ///
    /// The returned `query_id` is empty; callers attach their own identifier.
    pub fn compute(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> EvaluationResult {
        EvaluationResult {
            k,
            recall: recall_at_k_impl(retrieved, relevant, k),
            precision: precision_at_k_impl(retrieved, relevant, k),
            ndcg: ndcg_at_k_binary(retrieved, relevant, k),
            mrr: mean_reciprocal_rank_single(retrieved, relevant),
            ap: average_precision_impl(retrieved, relevant),
            ..Default::default()
        }
    }

    /// Builds a `k -> metric(k)` map; later duplicates of a cutoff overwrite earlier ones.
    fn per_cutoff(k_values: &[usize], metric: impl Fn(usize) -> f64) -> HashMap<usize, f64> {
        k_values.iter().map(|&k| (k, metric(k))).collect()
    }
}
/// Recall@k for any slice of string-like ids.
///
/// Only the first `k` ids can influence recall@k, so just those are copied
/// into owned `String`s before delegating to the `&[String]` implementation.
pub fn recall_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let head: Vec<String> = retrieved
        .iter()
        .take(k)
        .map(|id| id.as_ref().to_owned())
        .collect();
    recall_at_k_impl(&head, relevant, k)
}
/// Fraction of the relevant set that appears among the first `k` retrieved ids.
///
/// Each relevant id is counted at most once, even if it is retrieved several
/// times. Returns `0.0` when `relevant` is empty.
fn recall_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }
    // Distinct relevant ids inside the cutoff; borrowing avoids cloning every id.
    let found: HashSet<&String> = retrieved
        .iter()
        .take(k)
        .filter(|id| relevant.contains(*id))
        .collect();
    found.len() as f64 / relevant.len() as f64
}
/// Precision@k for any slice of string-like ids.
///
/// Only the first `min(k, retrieved.len())` ids can influence the result, so
/// just that prefix is copied into owned `String`s before delegating to the
/// `&[String]` implementation.
pub fn precision_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let prefix = &retrieved[..k.min(retrieved.len())];
    let head: Vec<String> = prefix.iter().map(|id| String::from(id.as_ref())).collect();
    precision_at_k_impl(&head, relevant, k)
}
/// Share of the first `min(k, retrieved.len())` retrieved ids that are relevant.
///
/// The denominator is truncated to the list length, so a list shorter than `k`
/// is not penalised for the missing slots. Every occurrence of a relevant id
/// counts as a hit. Returns `0.0` when `k` is zero or nothing was retrieved.
fn precision_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let window = &retrieved[..k.min(retrieved.len())];
    if window.is_empty() {
        return 0.0;
    }
    let hits = window.iter().filter(|id| relevant.contains(*id)).count();
    hits as f64 / window.len() as f64
}
/// Binary-relevance nDCG@k for any slice of string-like ids.
///
/// Ids past position `k` cannot affect nDCG@k, so only the first `k` are
/// copied into owned `String`s before delegating to the `&[String]` version.
pub fn ndcg_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let mut head = Vec::with_capacity(k.min(retrieved.len()));
    head.extend(retrieved.iter().take(k).map(|id| id.as_ref().to_string()));
    ndcg_at_k_binary(&head, relevant, k)
}
/// Normalised discounted cumulative gain at cutoff `k` with binary relevance.
///
/// A relevant id at 0-based position `i` contributes `1 / log2(i + 2)` to DCG.
/// IDCG is the DCG of a ranking with `min(k, |relevant|)` relevant ids on top.
/// Only the first occurrence of each relevant id is credited, so a retriever
/// that repeats an id cannot push DCG above IDCG; the result lies in `[0, 1]`.
///
/// Returns `0.0` when `relevant` is empty or `k` is zero.
fn ndcg_at_k_binary(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }
    let discount = |rank: usize| 1.0 / (rank as f64 + 2.0).log2();

    // Relevant ids already credited; repeats after the first occurrence earn nothing.
    let mut credited: HashSet<&str> = HashSet::new();
    let dcg: f64 = retrieved
        .iter()
        .take(k)
        .enumerate()
        .filter(|&(_, doc)| relevant.contains(doc) && credited.insert(doc.as_str()))
        .map(|(i, _)| discount(i))
        .sum();

    let ideal_hits = k.min(relevant.len());
    let idcg: f64 = (0..ideal_hits).map(discount).sum();
    if idcg == 0.0 {
        return 0.0;
    }
    dcg / idcg
}
/// Normalised discounted cumulative gain at cutoff `k` with graded relevance.
///
/// The id at 0-based position `i` contributes `(2^grade - 1) / log2(i + 2)`;
/// ids missing from `relevance_grades` have grade `0.0`. Only the first
/// occurrence of each id is credited, so repeats cannot inflate DCG.
///
/// IDCG is built from the positive grades sorted in descending order. Grades
/// `<= 0` (and NaN) cannot improve an ideal ranking, so they are left out of it.
/// This keeps negative grades from shrinking IDCG and pushing nDCG above 1.
///
/// Returns `0.0` when `relevance_grades` is empty, `k` is zero, or no grade is
/// positive.
pub fn ndcg_at_k_graded(
    retrieved: &[String],
    relevance_grades: &HashMap<String, f64>,
    k: usize,
) -> f64 {
    if relevance_grades.is_empty() {
        return 0.0;
    }
    let gain = |grade: f64| 2_f64.powf(grade) - 1.0;
    let discount = |rank: usize| (rank as f64 + 2.0).log2();

    // Ids already credited; later repeats of the same id contribute nothing.
    let mut credited: HashSet<&str> = HashSet::new();
    let dcg: f64 = retrieved
        .iter()
        .take(k)
        .enumerate()
        .filter(|&(_, doc)| credited.insert(doc.as_str()))
        .map(|(i, doc)| gain(relevance_grades.get(doc).copied().unwrap_or(0.0)) / discount(i))
        .sum();

    let mut ideal_grades: Vec<f64> = relevance_grades
        .values()
        .copied()
        .filter(|&grade| grade > 0.0)
        .collect();
    // `total_cmp` is a total order, so the sort never sees an inconsistent comparator.
    ideal_grades.sort_by(|a, b| b.total_cmp(a));
    let idcg: f64 = ideal_grades
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, &grade)| gain(grade) / discount(i))
        .sum();

    if idcg <= 0.0 {
        return 0.0;
    }
    dcg / idcg
}
/// Mean reciprocal rank across queries.
///
/// Each query contributes `1 / rank` of its first relevant id (or `0.0` if
/// none was retrieved). Returns `0.0` when `results` is empty.
pub fn mean_reciprocal_rank(results: &[QueryResult]) -> f64 {
    match results.len() {
        0 => 0.0,
        query_count => {
            let total: f64 = results
                .iter()
                .map(|query| mean_reciprocal_rank_single(&query.retrieved_ids, &query.relevant_ids))
                .sum();
            total / query_count as f64
        }
    }
}
/// Reciprocal rank of the first relevant id in `retrieved` (1-based rank).
///
/// Returns `0.0` when no retrieved id is relevant.
fn mean_reciprocal_rank_single(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    retrieved
        .iter()
        .position(|doc| relevant.contains(doc))
        .map_or(0.0, |index| 1.0 / (index as f64 + 1.0))
}
/// Average precision for any slice of string-like ids.
///
/// AP looks at the whole ranking, so every id is copied into an owned
/// `String` before delegating to the `&[String]` implementation.
pub fn average_precision(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>) -> f64 {
    let mut owned = Vec::with_capacity(retrieved.len());
    owned.extend(retrieved.iter().map(|id| id.as_ref().to_owned()));
    average_precision_impl(&owned, relevant)
}
/// Average precision of one ranked list.
///
/// Sums precision@rank at the rank where each relevant id is first retrieved,
/// then divides by `|relevant|`, so relevant ids never retrieved count as zero.
/// Repeated occurrences of an already-found id are not new hits. Without this,
/// the hit counter could exceed `|relevant|` and push AP above 1.
///
/// Returns `0.0` when `relevant` is empty.
fn average_precision_impl(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }
    let mut found: HashSet<&str> = HashSet::with_capacity(relevant.len());
    let mut sum_precision = 0.0;
    for (i, doc) in retrieved.iter().enumerate() {
        if relevant.contains(doc) && found.insert(doc.as_str()) {
            // `found.len()` is the number of distinct relevant ids seen up to rank i + 1.
            sum_precision += found.len() as f64 / (i as f64 + 1.0);
            if found.len() == relevant.len() {
                // Every relevant id has been found; later ranks cannot add anything.
                break;
            }
        }
    }
    sum_precision / relevant.len() as f64
}
/// Mean of per-query average precision over `results`.
///
/// Returns `0.0` when `results` is empty.
pub fn mean_average_precision(results: &[QueryResult]) -> f64 {
    let query_count = results.len();
    if query_count == 0 {
        return 0.0;
    }
    let mut total = 0.0;
    for query in results {
        total += average_precision_impl(&query.retrieved_ids, &query.relevant_ids);
    }
    total / query_count as f64
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_relevant(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    fn make_retrieved(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    /// nDCG rank discount for a 0-based position, used to build exact expectations.
    fn ndcg_discount(i: usize) -> f64 {
        1.0 / (i as f64 + 2.0).log2()
    }

    #[test]
    fn test_recall_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "d", "e"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 1.0);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert!((recall_at_k_impl(&retrieved, &relevant, 1) - 1.0 / 3.0).abs() < 0.001);
        assert!((recall_at_k_impl(&retrieved, &relevant, 3) - 2.0 / 3.0).abs() < 0.001);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = HashSet::new();
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_zero_cutoff() {
        let retrieved = make_retrieved(&["a", "b"]);
        let relevant = make_relevant(&["a"]);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 0), 0.0);
    }

    #[test]
    fn test_recall_at_k_cutoff_beyond_list() {
        let retrieved = make_retrieved(&["a", "x"]);
        let relevant = make_relevant(&["a", "b"]);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 10), 0.5);
    }

    #[test]
    fn test_precision_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c", "d", "e"]);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn test_precision_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 1), 1.0);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 2), 0.5);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 5), 0.6);
    }

    #[test]
    fn test_precision_at_k_zero_and_empty() {
        let relevant = make_relevant(&["a"]);
        assert_eq!(precision_at_k_impl(&make_retrieved(&["a"]), &relevant, 0), 0.0);
        assert_eq!(precision_at_k_impl(&[], &relevant, 5), 0.0);
    }

    #[test]
    fn test_precision_at_k_cutoff_beyond_list() {
        // The denominator is truncated to the number of retrieved ids.
        let retrieved = make_retrieved(&["a", "x", "b"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert!((precision_at_k_impl(&retrieved, &relevant, 10) - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_mrr_first_position() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a"]);
        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let retrieved = make_retrieved(&["x", "a", "c"]);
        let relevant = make_relevant(&["a"]);
        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let retrieved = make_retrieved(&["x", "y", "a"]);
        let relevant = make_relevant(&["a"]);
        assert!((mean_reciprocal_rank_single(&retrieved, &relevant) - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_mrr_not_found() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a"]);
        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.0);
    }

    #[test]
    fn test_mrr_empty_retrieved() {
        let relevant = make_relevant(&["a"]);
        assert_eq!(mean_reciprocal_rank_single(&[], &relevant), 0.0);
    }

    #[test]
    fn test_ndcg_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        assert!((ndcg_at_k_binary(&retrieved, &relevant, 5) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_ndcg_partial() {
        let retrieved = make_retrieved(&["x", "a", "y", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let ndcg = ndcg_at_k_binary(&retrieved, &relevant, 5);
        assert!(ndcg > 0.0 && ndcg < 1.0);
    }

    #[test]
    fn test_ndcg_partial_exact() {
        let retrieved = make_retrieved(&["x", "a", "y", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let dcg = ndcg_discount(1) + ndcg_discount(3) + ndcg_discount(4);
        let idcg = ndcg_discount(0) + ndcg_discount(1) + ndcg_discount(2);
        assert!((ndcg_at_k_binary(&retrieved, &relevant, 5) - dcg / idcg).abs() < 1e-12);
    }

    #[test]
    fn test_ndcg_degenerate_inputs() {
        let retrieved = make_retrieved(&["a", "b"]);
        assert_eq!(ndcg_at_k_binary(&retrieved, &HashSet::new(), 2), 0.0);
        assert_eq!(ndcg_at_k_binary(&retrieved, &make_relevant(&["a"]), 0), 0.0);
    }

    #[test]
    fn test_average_precision() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let ap = average_precision_impl(&retrieved, &relevant);
        assert!(ap > 0.7 && ap < 0.8);
    }

    #[test]
    fn test_average_precision_exact() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let expected = (1.0 + 2.0 / 3.0 + 3.0 / 5.0) / 3.0;
        assert!((average_precision_impl(&retrieved, &relevant) - expected).abs() < 1e-12);
    }

    #[test]
    fn test_average_precision_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let ap = average_precision_impl(&retrieved, &relevant);
        assert_eq!(ap, 1.0);
    }

    #[test]
    fn test_average_precision_degenerate_inputs() {
        let retrieved = make_retrieved(&["x", "y"]);
        assert_eq!(average_precision_impl(&retrieved, &HashSet::new()), 0.0);
        assert_eq!(average_precision_impl(&retrieved, &make_relevant(&["a"])), 0.0);
    }

    #[test]
    fn test_retrieval_metrics_compute() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let metrics = RetrievalMetrics::compute_all(&retrieved, &relevant, &[5, 10]);
        assert!(metrics.recall_at_k.contains_key(&5));
        assert!(metrics.precision_at_k.contains_key(&5));
        assert!(metrics.ndcg_at_k.contains_key(&5));
        assert!(metrics.mrr > 0.0);
        assert!(metrics.map > 0.0);
    }

    #[test]
    fn test_retrieval_metrics_compute_all_values() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let metrics = RetrievalMetrics::compute_all(&retrieved, &relevant, &[1, 5]);
        assert_eq!(metrics.recall_at_k[&5], 1.0);
        assert_eq!(metrics.precision_at_k[&1], 1.0);
        assert_eq!(metrics.precision_at_k[&5], 0.6);
        assert_eq!(metrics.mrr, 1.0);
    }

    #[test]
    fn test_retrieval_metrics_single_cutoff_values() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);
        let result = RetrievalMetrics::compute(&retrieved, &relevant, 3);
        assert_eq!(result.k, 3);
        assert!(result.query_id.is_empty());
        assert!((result.recall - 2.0 / 3.0).abs() < 1e-12);
        assert!((result.precision - 2.0 / 3.0).abs() < 1e-12);
        assert_eq!(result.mrr, 1.0);
        let expected_ap = (1.0 + 1.0 + 3.0 / 4.0) / 3.0;
        assert!((result.ap - expected_ap).abs() < 1e-12);
    }
}