1use serde::Serialize;
7use std::collections::HashSet;
8
/// Retrieval-quality metrics for a single query, as built by `score_query`.
///
/// Serializable via serde; field names are used as-is for the serialized keys.
#[derive(Debug, Clone, Serialize)]
pub struct QueryScores {
    /// The query text these scores were computed for.
    pub query: String,
    /// Precision over the top 5 ranked results (`precision_at_k` with `k = 5`).
    pub precision_at_5: f64,
    /// Precision over the top 10 ranked results (`precision_at_k` with `k = 10`).
    pub precision_at_10: f64,
    /// Recall over the top 5 ranked results (`recall_at_k` with `k = 5`).
    pub recall_at_5: f64,
    /// Recall over the top 10 ranked results (`recall_at_k` with `k = 10`).
    pub recall_at_10: f64,
    /// Reciprocal rank of the first relevant result over the full ranking; 0.0 if none.
    pub mrr: f64,
    /// nDCG over the top 10 ranked results with binary relevance (`ndcg_at_k` with `k = 10`).
    pub ndcg_at_10: f64,
}
20
/// Unweighted per-metric means across a batch of queries, as built by `aggregate`.
///
/// Serializable via serde; field names are used as-is for the serialized keys.
#[derive(Debug, Clone, Serialize)]
pub struct AggregateScores {
    /// Mean of `QueryScores::precision_at_5`.
    pub mean_precision_at_5: f64,
    /// Mean of `QueryScores::precision_at_10`.
    pub mean_precision_at_10: f64,
    /// Mean of `QueryScores::recall_at_5`.
    pub mean_recall_at_5: f64,
    /// Mean of `QueryScores::recall_at_10`.
    pub mean_recall_at_10: f64,
    /// Mean of `QueryScores::mrr`, i.e. Mean Reciprocal Rank.
    pub mean_mrr: f64,
    /// Mean of `QueryScores::ndcg_at_10`.
    pub mean_ndcg_at_10: f64,
    /// Number of queries averaged. When this is 0, `aggregate` reports every mean as 0.0.
    pub query_count: usize,
}
32
/// Precision@k: the fraction of the top-`k` ranked slots that hold a relevant item.
///
/// The denominator is always `k`, even when fewer than `k` items were retrieved. Missing
/// slots count as non-relevant, which is the `trec_eval` `P_k` convention. Dividing by the
/// number of items actually returned would give a system a perfect P@10 for returning one
/// correct hit. It would also make P@k incomparable across queries with different result
/// counts.
///
/// Each slot is judged independently, so a relevant id that appears twice in the ranking
/// counts once per occurrence.
///
/// Returns `0.0` when `k == 0`.
pub fn precision_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if k == 0 {
        return 0.0;
    }
    let hits = retrieved
        .iter()
        .take(k)
        .filter(|item| relevant.contains(item.as_str()))
        .count();
    hits as f64 / k as f64
}
42
/// Recall@k: the fraction of relevant items that appear among the top-`k` results.
///
/// Each relevant item is credited at most once, however many times its id repeats in the
/// ranking. Without this, repeated ids could push recall above `1.0`.
///
/// Returns `0.0` when `relevant` is empty, since recall is undefined there.
pub fn recall_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }
    // Collect distinct relevant ids seen in the cutoff so duplicates are not double-counted.
    let found: HashSet<&str> = retrieved
        .iter()
        .take(k)
        .map(String::as_str)
        .filter(|item| relevant.contains(*item))
        .collect();
    found.len() as f64 / relevant.len() as f64
}
52
/// Reciprocal rank for one query: `1 / rank` of the first relevant result, where rank is
/// 1-based and the whole ranking is searched.
///
/// Returns `0.0` if no retrieved item is relevant. Averaging this over queries yields MRR.
pub fn mrr(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    retrieved
        .iter()
        .position(|candidate| relevant.contains(candidate.as_str()))
        .map_or(0.0, |zero_based| 1.0 / (zero_based as f64 + 1.0))
}
62
/// Normalized Discounted Cumulative Gain at cutoff `k`, with binary relevance.
///
/// Computed as `DCG@k / IDCG@k`. The ideal ranking places `min(|relevant|, k)` relevant
/// items in the top slots. Each relevant item earns gain only at its first occurrence in
/// the ranking. Otherwise a repeated relevant id would be rewarded twice and the score
/// could exceed `1.0`. The result is therefore in `[0.0, 1.0]`.
///
/// Returns `0.0` when `relevant` is empty or `k == 0`, where the ideal DCG is zero.
pub fn ndcg_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let ideal = ideal_dcg(relevant.len(), k);
    if ideal == 0.0 {
        return 0.0;
    }
    dcg(retrieved, relevant, k) / ideal
}

/// DCG of the top-`k` results. A relevant item at 0-based rank `i` contributes
/// `1 / log2(i + 2)`. Later repeats of an already-credited item contribute nothing.
fn dcg(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    // Ids that have already earned gain; `insert` returns false for repeats.
    let mut credited: HashSet<&str> = HashSet::new();
    retrieved
        .iter()
        .take(k)
        .enumerate()
        .filter(|(_, item)| relevant.contains(item.as_str()) && credited.insert(item.as_str()))
        .map(|(i, _)| 1.0 / (i as f64 + 2.0).log2())
        .sum()
}

/// DCG of the best possible ranking: `min(num_relevant, k)` relevant items in the top slots.
fn ideal_dcg(num_relevant: usize, k: usize) -> f64 {
    (0..num_relevant.min(k))
        .map(|i| 1.0 / (i as f64 + 2.0).log2())
        .sum()
}
90
91pub fn score_query(
93 query_text: &str,
94 retrieved: &[String],
95 relevant: &HashSet<String>,
96) -> QueryScores {
97 QueryScores {
98 query: query_text.to_string(),
99 precision_at_5: precision_at_k(retrieved, relevant, 5),
100 precision_at_10: precision_at_k(retrieved, relevant, 10),
101 recall_at_5: recall_at_k(retrieved, relevant, 5),
102 recall_at_10: recall_at_k(retrieved, relevant, 10),
103 mrr: mrr(retrieved, relevant),
104 ndcg_at_10: ndcg_at_k(retrieved, relevant, 10),
105 }
106}
107
108pub fn aggregate(scores: &[QueryScores]) -> AggregateScores {
110 let n = scores.len();
111 if n == 0 {
112 return AggregateScores {
113 mean_precision_at_5: 0.0,
114 mean_precision_at_10: 0.0,
115 mean_recall_at_5: 0.0,
116 mean_recall_at_10: 0.0,
117 mean_mrr: 0.0,
118 mean_ndcg_at_10: 0.0,
119 query_count: 0,
120 };
121 }
122
123 let nf = n as f64;
124 AggregateScores {
125 mean_precision_at_5: scores.iter().map(|s| s.precision_at_5).sum::<f64>() / nf,
126 mean_precision_at_10: scores.iter().map(|s| s.precision_at_10).sum::<f64>() / nf,
127 mean_recall_at_5: scores.iter().map(|s| s.recall_at_5).sum::<f64>() / nf,
128 mean_recall_at_10: scores.iter().map(|s| s.recall_at_10).sum::<f64>() / nf,
129 mean_mrr: scores.iter().map(|s| s.mrr).sum::<f64>() / nf,
130 mean_ndcg_at_10: scores.iter().map(|s| s.ndcg_at_10).sum::<f64>() / nf,
131 query_count: n,
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned ranking from string literals.
    fn ranked(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds an owned relevance set from string literals.
    fn judged(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Asserts two floats agree to within 1e-10, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn perfect_precision() {
        let relevant = judged(&["a", "b", "c"]);
        let retrieved = ranked(&["a", "b", "c", "d", "e"]);
        assert_close(precision_at_k(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn zero_precision_when_no_relevant() {
        let relevant = judged(&["x", "y"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(precision_at_k(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn recall_is_fraction_of_relevant_found() {
        let relevant = judged(&["a", "b", "c", "d"]);
        let retrieved = ranked(&["a", "x", "b", "y"]);
        assert_close(recall_at_k(&retrieved, &relevant, 4), 0.5);
        assert_close(recall_at_k(&retrieved, &relevant, 1), 0.25);
    }

    #[test]
    fn recall_is_zero_without_relevant_items() {
        let retrieved = ranked(&["a", "b"]);
        assert_close(recall_at_k(&retrieved, &HashSet::new(), 5), 0.0);
    }

    #[test]
    fn mrr_first_position() {
        let relevant = judged(&["a"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(mrr(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn mrr_second_position() {
        let relevant = judged(&["b"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(mrr(&retrieved, &relevant), 0.5);
    }

    #[test]
    fn mrr_is_zero_when_nothing_relevant_retrieved() {
        let relevant = judged(&["z"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(mrr(&retrieved, &relevant), 0.0);
    }

    #[test]
    fn ndcg_perfect() {
        let relevant = judged(&["a", "b"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(ndcg_at_k(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn ndcg_is_zero_without_relevant_items() {
        let retrieved = ranked(&["a", "b", "c"]);
        assert_close(ndcg_at_k(&retrieved, &HashSet::new(), 10), 0.0);
    }

    #[test]
    fn aggregate_of_nothing_is_all_zero() {
        let agg = aggregate(&[]);
        assert_eq!(agg.query_count, 0);
        assert_close(agg.mean_precision_at_5, 0.0);
        assert_close(agg.mean_recall_at_10, 0.0);
        assert_close(agg.mean_mrr, 0.0);
        assert_close(agg.mean_ndcg_at_10, 0.0);
    }

    #[test]
    fn aggregate_averages_per_query_scores() {
        let scores = vec![
            score_query("q1", &ranked(&["a"]), &judged(&["a"])),
            score_query("q2", &ranked(&["x", "b"]), &judged(&["b"])),
        ];
        let agg = aggregate(&scores);
        assert_eq!(agg.query_count, 2);
        assert_close(agg.mean_mrr, 0.75);
        assert_close(agg.mean_recall_at_5, 1.0);
    }
}