mnemara_core/
evaluation.rs1use serde::{Deserialize, Serialize};
2use std::collections::BTreeSet;
3
4use crate::query::RecallQuery;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct JudgedRecallCase {
8 pub name: String,
9 pub query: RecallQuery,
10 pub relevant_record_ids: Vec<String>,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct RankingMetrics {
15 pub cases: usize,
16 pub hit_rate_at_k: f32,
17 pub recall_at_k: f32,
18 pub mrr: f32,
19 pub ndcg_at_k: f32,
20}
21
22pub fn evaluate_rankings_at_k(rankings: &[(&[String], &[String])], k: usize) -> RankingMetrics {
23 if rankings.is_empty() || k == 0 {
24 return RankingMetrics {
25 cases: rankings.len(),
26 hit_rate_at_k: 0.0,
27 recall_at_k: 0.0,
28 mrr: 0.0,
29 ndcg_at_k: 0.0,
30 };
31 }
32
33 let mut hits = 0.0f32;
34 let mut recall = 0.0f32;
35 let mut reciprocal_rank = 0.0f32;
36 let mut ndcg = 0.0f32;
37
38 for (ranked_ids, relevant_ids) in rankings {
39 let relevant = relevant_ids.iter().cloned().collect::<BTreeSet<_>>();
40 if relevant.is_empty() {
41 continue;
42 }
43 let top_k = ranked_ids.iter().take(k).collect::<Vec<_>>();
44 let matches = top_k
45 .iter()
46 .filter(|record_id| relevant.contains(record_id.as_str()))
47 .count() as f32;
48 if matches > 0.0 {
49 hits += 1.0;
50 }
51 recall += matches / relevant.len() as f32;
52 if let Some(rank) = ranked_ids
53 .iter()
54 .position(|record_id| relevant.contains(record_id.as_str()))
55 {
56 reciprocal_rank += 1.0 / (rank as f32 + 1.0);
57 }
58
59 let dcg = top_k
60 .iter()
61 .enumerate()
62 .filter(|(_, record_id)| relevant.contains(record_id.as_str()))
63 .map(|(index, _)| 1.0 / ((index as f32 + 2.0).log2()))
64 .sum::<f32>();
65 let ideal_hits = relevant.len().min(k);
66 let ideal_dcg = (0..ideal_hits)
67 .map(|index| 1.0 / ((index as f32 + 2.0).log2()))
68 .sum::<f32>();
69 if ideal_dcg > 0.0 {
70 ndcg += dcg / ideal_dcg;
71 }
72 }
73
74 let cases = rankings.len() as f32;
75 RankingMetrics {
76 cases: rankings.len(),
77 hit_rate_at_k: hits / cases,
78 recall_at_k: recall / cases,
79 mrr: reciprocal_rank / cases,
80 ndcg_at_k: ndcg / cases,
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::evaluate_rankings_at_k;

    /// Absolute tolerance used when comparing `f32` metric values.
    const TOLERANCE: f32 = 1e-4;

    /// Builds an owned ID list from string literals.
    fn ids(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|id| id.to_string()).collect()
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(metric: &str, actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "{}: expected {}, got {}",
            metric,
            expected,
            actual
        );
    }

    #[test]
    fn computes_standard_ranking_metrics() {
        let ranked_a = ids(&["a", "b", "c"]);
        let relevant_a = ids(&["b", "d"]);
        let ranked_b = ids(&["x", "y", "z"]);
        let relevant_b = ids(&["x"]);

        let metrics = evaluate_rankings_at_k(
            &[
                (ranked_a.as_slice(), relevant_a.as_slice()),
                (ranked_b.as_slice(), relevant_b.as_slice()),
            ],
            3,
        );

        assert_eq!(metrics.cases, 2);
        // Both cases have a relevant record in the top 3.
        assert_close("hit_rate_at_k", metrics.hit_rate_at_k, 1.0);
        // Case A finds 1 of 2 relevant records, case B finds 1 of 1.
        assert_close("recall_at_k", metrics.recall_at_k, 0.75);
        // First relevant record is at rank 2 in case A and rank 1 in case B.
        assert_close("mrr", metrics.mrr, 0.75);
        // Case A: DCG = 1/log2(3), IDCG = 1 + 1/log2(3). Case B is a perfect 1.0.
        let inv_log2_3 = 1.0 / 3f32.log2();
        let expected_ndcg = (inv_log2_3 / (1.0 + inv_log2_3) + 1.0) / 2.0;
        assert_close("ndcg_at_k", metrics.ndcg_at_k, expected_ndcg);
    }

    #[test]
    fn returns_zero_metrics_for_empty_input_or_zero_k() {
        let ranked = ids(&["x", "y"]);
        let relevant = ids(&["x"]);

        let empty = evaluate_rankings_at_k(&[], 3);
        assert_eq!(empty.cases, 0);
        assert_close("hit_rate_at_k", empty.hit_rate_at_k, 0.0);
        assert_close("recall_at_k", empty.recall_at_k, 0.0);
        assert_close("mrr", empty.mrr, 0.0);
        assert_close("ndcg_at_k", empty.ndcg_at_k, 0.0);

        let zero_k = evaluate_rankings_at_k(&[(ranked.as_slice(), relevant.as_slice())], 0);
        assert_eq!(zero_k.cases, 1);
        assert_close("hit_rate_at_k", zero_k.hit_rate_at_k, 0.0);
        assert_close("recall_at_k", zero_k.recall_at_k, 0.0);
        assert_close("mrr", zero_k.mrr, 0.0);
        assert_close("ndcg_at_k", zero_k.ndcg_at_k, 0.0);
    }

    #[test]
    fn case_without_relevant_ids_still_counts_in_denominator() {
        let ranked = ids(&["x", "y"]);
        let relevant = ids(&["x"]);
        let no_relevant: Vec<String> = Vec::new();

        let metrics = evaluate_rankings_at_k(
            &[
                (ranked.as_slice(), no_relevant.as_slice()),
                (ranked.as_slice(), relevant.as_slice()),
            ],
            2,
        );

        assert_eq!(metrics.cases, 2);
        assert_close("hit_rate_at_k", metrics.hit_rate_at_k, 0.5);
        assert_close("recall_at_k", metrics.recall_at_k, 0.5);
        assert_close("mrr", metrics.mrr, 0.5);
        assert_close("ndcg_at_k", metrics.ndcg_at_k, 0.5);
    }
}
104}