Skip to main content

mnemara_core/
evaluation.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeSet;
3
4use crate::query::RecallQuery;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct JudgedRecallCase {
8    pub name: String,
9    pub query: RecallQuery,
10    pub relevant_record_ids: Vec<String>,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct RankingMetrics {
15    pub cases: usize,
16    pub hit_rate_at_k: f32,
17    pub recall_at_k: f32,
18    pub mrr: f32,
19    pub ndcg_at_k: f32,
20}
21
22pub fn evaluate_rankings_at_k(rankings: &[(&[String], &[String])], k: usize) -> RankingMetrics {
23    if rankings.is_empty() || k == 0 {
24        return RankingMetrics {
25            cases: rankings.len(),
26            hit_rate_at_k: 0.0,
27            recall_at_k: 0.0,
28            mrr: 0.0,
29            ndcg_at_k: 0.0,
30        };
31    }
32
33    let mut hits = 0.0f32;
34    let mut recall = 0.0f32;
35    let mut reciprocal_rank = 0.0f32;
36    let mut ndcg = 0.0f32;
37
38    for (ranked_ids, relevant_ids) in rankings {
39        let relevant = relevant_ids.iter().cloned().collect::<BTreeSet<_>>();
40        if relevant.is_empty() {
41            continue;
42        }
43        let top_k = ranked_ids.iter().take(k).collect::<Vec<_>>();
44        let matches = top_k
45            .iter()
46            .filter(|record_id| relevant.contains(record_id.as_str()))
47            .count() as f32;
48        if matches > 0.0 {
49            hits += 1.0;
50        }
51        recall += matches / relevant.len() as f32;
52        if let Some(rank) = ranked_ids
53            .iter()
54            .position(|record_id| relevant.contains(record_id.as_str()))
55        {
56            reciprocal_rank += 1.0 / (rank as f32 + 1.0);
57        }
58
59        let dcg = top_k
60            .iter()
61            .enumerate()
62            .filter(|(_, record_id)| relevant.contains(record_id.as_str()))
63            .map(|(index, _)| 1.0 / ((index as f32 + 2.0).log2()))
64            .sum::<f32>();
65        let ideal_hits = relevant.len().min(k);
66        let ideal_dcg = (0..ideal_hits)
67            .map(|index| 1.0 / ((index as f32 + 2.0).log2()))
68            .sum::<f32>();
69        if ideal_dcg > 0.0 {
70            ndcg += dcg / ideal_dcg;
71        }
72    }
73
74    let cases = rankings.len() as f32;
75    RankingMetrics {
76        cases: rankings.len(),
77        hit_rate_at_k: hits / cases,
78        recall_at_k: recall / cases,
79        mrr: reciprocal_rank / cases,
80        ndcg_at_k: ndcg / cases,
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::evaluate_rankings_at_k;

    /// Asserts that a metric equals its expected value within a small tolerance.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Two cases at k = 3:
    /// - case A ranks one of its two relevant IDs at position 2;
    /// - case B ranks its single relevant ID first.
    #[test]
    fn computes_standard_ranking_metrics() {
        let ranked_a = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let relevant_a = vec!["b".to_string(), "d".to_string()];
        let ranked_b = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let relevant_b = vec!["x".to_string()];

        let metrics = evaluate_rankings_at_k(
            &[
                (ranked_a.as_slice(), relevant_a.as_slice()),
                (ranked_b.as_slice(), relevant_b.as_slice()),
            ],
            3,
        );

        assert_eq!(metrics.cases, 2);
        // Both cases have a relevant ID inside the top 3.
        assert_close(metrics.hit_rate_at_k, 1.0);
        // (1/2 + 1/1) / 2
        assert_close(metrics.recall_at_k, 0.75);
        // (1/2 + 1/1) / 2
        assert_close(metrics.mrr, 0.75);
        // Case A: (1/log2 3) / (1 + 1/log2 3) ~= 0.386853; case B: 1.0.
        assert_close(metrics.ndcg_at_k, 0.693_426);
    }
}
104}