use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
/// Evaluation input for a single query: the ranked result list produced by the
/// system under test together with the relevance judgments for that query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryEval {
    /// Identifier of the query being evaluated.
    pub query_id: String,
    /// Retrieved document IDs; position 0 is treated as rank 1 by the metrics.
    pub retrieved: Vec<String>,
    /// Relevance grade per document ID. `evaluate_search` treats a grade above
    /// zero as "relevant"; IDs absent from the map are scored as grade 0 when
    /// computing DCG.
    pub judgments: HashMap<String, u32>,
}
/// Aggregate search-quality metrics for a set of queries, as returned by
/// `evaluate_search`. Each score field is the arithmetic mean of the
/// corresponding per-query metric (all zero when no queries were given).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchMetrics {
/// Mean precision@k across the evaluated queries.
pub precision_at_k: f64,
/// Mean recall@k across the evaluated queries.
pub recall_at_k: f64,
/// Mean reciprocal rank across the evaluated queries.
pub mrr: f64,
/// Mean NDCG@k across the evaluated queries.
pub ndcg_at_k: f64,
/// The rank cutoff `k` that was passed to `evaluate_search`.
pub k: usize,
/// Number of queries that were evaluated.
pub n_queries: usize,
}
/// Precision@k: the share of the first `k` ranks that hold a relevant document.
///
/// The denominator is always `k`, so a result list shorter than `k` is
/// penalised for the missing ranks. Returns `0.0` when `k` is zero. Every
/// occurrence of a relevant ID inside the window counts as a hit.
pub fn precision_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if k == 0 {
        return 0.0;
    }

    // Clamp the window so short result lists do not cause an out-of-range slice.
    let window = &retrieved[..retrieved.len().min(k)];

    let mut matched = 0usize;
    for doc_id in window {
        if relevant.contains(doc_id) {
            matched += 1;
        }
    }

    matched as f64 / k as f64
}
/// Recall@k: the share of all relevant documents that appear within the first
/// `k` retrieved results.
///
/// Returns `0.0` when `relevant` is empty. A relevant ID is counted at most
/// once, no matter how often it appears in the window.
pub fn recall_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let total_relevant = relevant.len();
    if total_relevant == 0 {
        return 0.0;
    }

    // Collect the distinct IDs found in the top-k window.
    let cutoff = retrieved.len().min(k);
    let mut window: HashSet<&str> = HashSet::with_capacity(cutoff);
    for doc_id in &retrieved[..cutoff] {
        window.insert(doc_id.as_str());
    }

    let mut found = 0usize;
    for doc_id in relevant {
        if window.contains(doc_id.as_str()) {
            found += 1;
        }
    }

    found as f64 / total_relevant as f64
}
/// Reciprocal rank of the first relevant result.
///
/// Ranks are 1-based, so a relevant document in first position yields `1.0`,
/// in second position `0.5`, and so on. Returns `0.0` when no retrieved
/// document is relevant.
pub fn mrr(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    retrieved
        .iter()
        .position(|doc_id| relevant.contains(doc_id))
        .map_or(0.0, |index| 1.0 / (index as f64 + 1.0))
}
/// Discounted cumulative gain over the first `k` retrieved documents.
///
/// Each document at 1-based rank `r` contributes `(2^grade - 1) / log2(r + 1)`,
/// where `grade` comes from `judgments` (IDs missing from the map count as
/// grade 0).
///
/// A document ID that appears more than once in `retrieved` is credited only at
/// its first occurrence; later duplicates still occupy their rank but add no
/// gain. Without this, a result list that repeats a relevant document could
/// accumulate more gain than the ideal ranking and push NDCG above 1.0.
fn dcg_at_k(retrieved: &[String], judgments: &HashMap<String, u32>, k: usize) -> f64 {
    let mut seen: HashSet<&str> = HashSet::new();
    retrieved
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, id)| {
            // `insert` returns false when the ID was already in the set.
            if !seen.insert(id.as_str()) {
                return 0.0;
            }
            let grade = f64::from(judgments.get(id.as_str()).copied().unwrap_or(0));
            // `i` is 0-based, so the discount for rank i + 1 is log2(i + 2).
            (grade.exp2() - 1.0) / (i as f64 + 2.0).log2()
        })
        .sum()
}
/// Normalised DCG@k: the DCG of the actual ranking divided by the DCG of the
/// best possible ranking of the judged documents.
///
/// The ideal ranking is formed by sorting all grades in `judgments` from
/// highest to lowest and keeping the first `k`. Returns `0.0` when the ideal
/// DCG is zero (for example when there are no judgments, every grade is zero,
/// or `k` is zero).
pub fn ndcg_at_k(retrieved: &[String], judgments: &HashMap<String, u32>, k: usize) -> f64 {
    // Best achievable ordering: highest grades first, cut off at k.
    let mut best_order: Vec<u32> = judgments.values().copied().collect();
    best_order.sort_unstable_by(|lhs, rhs| rhs.cmp(lhs));
    best_order.truncate(k);

    let mut ideal_dcg = 0.0_f64;
    for (index, grade) in best_order.into_iter().enumerate() {
        let gain = f64::from(grade).exp2() - 1.0;
        // `index` is 0-based, so the discount for rank index + 1 is log2(index + 2).
        ideal_dcg += gain / (index as f64 + 2.0).log2();
    }

    // Guard against dividing by zero when nothing could have scored.
    if ideal_dcg == 0.0 {
        return 0.0;
    }

    dcg_at_k(retrieved, judgments, k) / ideal_dcg
}
/// Computes precision@k, recall@k, MRR and NDCG@k for every query and returns
/// their arithmetic means.
///
/// For the set-based metrics (precision, recall, MRR) a document is considered
/// relevant when its judgment grade is greater than zero; NDCG uses the graded
/// judgments directly. When `queries` is empty every score is `0.0` and
/// `n_queries` is `0`.
pub fn evaluate_search(queries: &[QueryEval], k: usize) -> SearchMetrics {
    let n_queries = queries.len();

    // Running totals: (precision, recall, reciprocal rank, ndcg).
    let (total_p, total_r, total_rr, total_ndcg) = queries.iter().fold(
        (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64),
        |(p, r, rr, g), query| {
            // Binary relevance view of the graded judgments.
            let mut relevant: HashSet<String> = HashSet::new();
            for (doc_id, &grade) in &query.judgments {
                if grade > 0 {
                    relevant.insert(doc_id.clone());
                }
            }

            (
                p + precision_at_k(&query.retrieved, &relevant, k),
                r + recall_at_k(&query.retrieved, &relevant, k),
                rr + mrr(&query.retrieved, &relevant),
                g + ndcg_at_k(&query.retrieved, &query.judgments, k),
            )
        },
    );

    // With no queries there is nothing to average; report zero instead of NaN.
    let mean = |total: f64| {
        if n_queries == 0 {
            0.0
        } else {
            total / n_queries as f64
        }
    };

    SearchMetrics {
        precision_at_k: mean(total_p),
        recall_at_k: mean(total_r),
        mrr: mean(total_rr),
        ndcg_at_k: mean(total_ndcg),
        k,
        n_queries,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point metric values.
    const TOLERANCE: f64 = 1e-10;

    /// True when `actual` is within `TOLERANCE` of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds an owned set of relevant document IDs from string literals.
    fn relevant_set(ids: &[&str]) -> HashSet<String> {
        ids.iter().copied().map(String::from).collect()
    }

    /// Builds an owned ranked result list from string literals.
    fn retrieved(ids: &[&str]) -> Vec<String> {
        ids.iter().copied().map(String::from).collect()
    }

    #[test]
    fn precision_perfect() {
        let ranking = retrieved(&["a", "b", "c"]);
        let relevant = relevant_set(&["a", "b", "c"]);
        assert!(approx(precision_at_k(&ranking, &relevant, 3), 1.0));
    }

    #[test]
    fn precision_half() {
        let ranking = retrieved(&["a", "x", "b", "y"]);
        let relevant = relevant_set(&["a", "b"]);
        assert!(approx(precision_at_k(&ranking, &relevant, 4), 0.5));
    }

    #[test]
    fn recall_partial() {
        let ranking = retrieved(&["a", "x"]);
        let relevant = relevant_set(&["a", "b", "c"]);
        assert!(approx(recall_at_k(&ranking, &relevant, 2), 1.0 / 3.0));
    }

    #[test]
    fn mrr_first() {
        let ranking = retrieved(&["a", "b"]);
        let relevant = relevant_set(&["a"]);
        assert!(approx(mrr(&ranking, &relevant), 1.0));
    }

    #[test]
    fn mrr_second() {
        let ranking = retrieved(&["x", "a"]);
        let relevant = relevant_set(&["a"]);
        assert!(approx(mrr(&ranking, &relevant), 0.5));
    }

    #[test]
    fn mrr_none() {
        let ranking = retrieved(&["x", "y"]);
        let relevant = relevant_set(&["a"]);
        assert_eq!(mrr(&ranking, &relevant), 0.0);
    }

    #[test]
    fn ndcg_perfect_binary() {
        let ranking = retrieved(&["a", "b"]);
        let grades = HashMap::from([("a".to_string(), 1), ("b".to_string(), 1)]);
        assert!(approx(ndcg_at_k(&ranking, &grades, 2), 1.0));
    }

    #[test]
    fn ndcg_reversed_graded() {
        // The higher-graded document is ranked second, so the score must be
        // strictly between 0 and 1.
        let ranking = retrieved(&["a", "b"]);
        let grades = HashMap::from([("a".to_string(), 1), ("b".to_string(), 3)]);
        let score = ndcg_at_k(&ranking, &grades, 2);
        assert!(score < 1.0);
        assert!(score > 0.0);
    }

    #[test]
    fn evaluate_search_aggregates() {
        // q1 retrieves only relevant documents, q2 retrieves none.
        let all_hits = QueryEval {
            query_id: "q1".to_string(),
            retrieved: retrieved(&["a", "b"]),
            judgments: HashMap::from([("a".to_string(), 1), ("b".to_string(), 1)]),
        };
        let no_hits = QueryEval {
            query_id: "q2".to_string(),
            retrieved: retrieved(&["x", "y"]),
            judgments: HashMap::from([("a".to_string(), 1)]),
        };

        let metrics = evaluate_search(&[all_hits, no_hits], 2);
        assert_eq!(metrics.n_queries, 2);
        assert_eq!(metrics.k, 2);
        assert!(approx(metrics.precision_at_k, 0.5));
    }

    #[test]
    fn empty_queries() {
        let metrics = evaluate_search(&[], 5);
        assert_eq!(metrics.n_queries, 0);
        assert_eq!(metrics.precision_at_k, 0.0);
    }

    #[test]
    fn recall_empty_relevant() {
        let ranking = retrieved(&["a"]);
        let relevant = relevant_set(&[]);
        assert_eq!(recall_at_k(&ranking, &relevant, 1), 0.0);
    }
}