1use std::collections::{HashMap, HashSet};
36
/// Recall at `k`: the fraction of relevant ids that appear among the first `k`
/// retrieved ids.
///
/// `k` is clamped to `retrieved.len()`. Each relevant id is credited at most
/// once, so a ranking that repeats a relevant id cannot push recall above 1.0.
/// By convention an empty `relevant` set yields 1.0 (there is nothing to miss).
#[must_use]
pub fn recall_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 1.0;
    }

    let cutoff = k.min(retrieved.len());
    // Collecting into a set credits each relevant id once even if the ranking repeats it.
    // `.copied()` first so `contains` receives `&usize` rather than `&&usize`.
    let found: HashSet<usize> = retrieved[..cutoff]
        .iter()
        .copied()
        .filter(|id| relevant.contains(id))
        .collect();

    found.len() as f64 / relevant.len() as f64
}
62
/// Precision at `k`: the fraction of the top-`k` retrieved ids that are relevant.
///
/// `k` is clamped to `retrieved.len()`, and the denominator is that clamped
/// cutoff, so a ranking shorter than `k` is scored over the items it actually
/// returned. Returns 0.0 when the clamped cutoff is zero (empty `retrieved` or
/// `k == 0`) instead of dividing by zero.
///
/// Each position is judged independently, so a duplicated relevant id counts
/// once per position it occupies.
#[must_use]
pub fn precision_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    let cutoff = k.min(retrieved.len());
    if cutoff == 0 {
        return 0.0;
    }

    // `|&id|` strips the extra reference `filter` adds, so `contains` sees `&usize`.
    let hits = retrieved[..cutoff]
        .iter()
        .filter(|&id| relevant.contains(id))
        .count();

    hits as f64 / cutoff as f64
}
88
/// Normalized discounted cumulative gain at `k` for graded relevance.
///
/// Each retrieved id at 0-based position `i` contributes
/// `(2^rel - 1) / log2(i + 2)`, where `rel` is its grade in `relevance_scores`
/// (0.0 when absent). The sum over the first `k` positions (DCG) is divided by
/// the same sum over the `k` highest grades in `relevance_scores` (IDCG).
///
/// * `k` is clamped to `retrieved.len()`; a zero cutoff returns 0.0.
/// * Returns 0.0 when IDCG is exactly zero (e.g. no graded ids, or all grades 0).
/// * An id repeated in `retrieved` earns gain only at its first position, so
///   duplicates cannot score better than the ideal ranking.
/// * A NaN grade does not panic the ideal-ranking sort; the resulting score
///   should be treated as invalid (it is typically NaN).
#[must_use]
pub fn ndcg_at_k(retrieved: &[usize], relevance_scores: &HashMap<usize, f64>, k: usize) -> f64 {
    let k = k.min(retrieved.len());
    if k == 0 {
        return 0.0;
    }

    // Single definition of the gain formula so DCG and IDCG cannot drift apart.
    let discounted_gain =
        |rel: f64, pos: usize| (2.0f64.powf(rel) - 1.0) / (2.0 + pos as f64).log2();

    let mut seen: HashSet<usize> = HashSet::with_capacity(k);
    let dcg: f64 = retrieved[..k]
        .iter()
        .enumerate()
        .map(|(pos, id)| {
            // Only the first occurrence of an id is graded; repeats add nothing.
            let rel = if seen.insert(*id) {
                relevance_scores.get(id).copied().unwrap_or(0.0)
            } else {
                0.0
            };
            discounted_gain(rel, pos)
        })
        .sum();

    let mut ideal_rels: Vec<f64> = relevance_scores.values().copied().collect();
    // `total_cmp` is a total order, so a NaN grade cannot make the sort panic.
    ideal_rels.sort_unstable_by(|a, b| b.total_cmp(a));

    let idcg: f64 = ideal_rels
        .iter()
        .take(k)
        .enumerate()
        .map(|(pos, &rel)| discounted_gain(rel, pos))
        .sum();

    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}
136
/// Mean reciprocal rank over a batch of queries.
///
/// For each query the reciprocal rank is `1 / r`, where `r` is the 1-based
/// position of the first retrieved id found in that query's relevant set.
/// Queries with no such id, or with an empty relevant set, contribute 0. The
/// result is the sum of reciprocal ranks divided by the number of queries.
///
/// Returns 0.0 when there are no queries or when `retrieved_lists` and
/// `relevant_sets` differ in length.
#[must_use]
pub fn mrr(retrieved_lists: &[Vec<usize>], relevant_sets: &[HashSet<usize>]) -> f64 {
    let query_count = retrieved_lists.len();
    if query_count == 0 || query_count != relevant_sets.len() {
        return 0.0;
    }

    let mut rr_total = 0.0;
    for (ranking, targets) in retrieved_lists.iter().zip(relevant_sets) {
        if targets.is_empty() {
            continue;
        }
        if let Some(first_hit) = ranking.iter().position(|id| targets.contains(id)) {
            rr_total += 1.0 / (first_hit + 1) as f64;
        }
    }

    rr_total / query_count as f64
}
172
/// Mean average precision over a batch of queries.
///
/// For each query, average precision (AP) is the sum of precision@i taken at
/// every rank `i` where a not-yet-seen relevant id appears, divided by the size
/// of that query's relevant set. Queries with an empty relevant set contribute
/// 0. The result is the mean AP over all queries.
///
/// A relevant id repeated in a ranking is credited only at its first
/// occurrence, which keeps each AP within `[0, 1]`.
///
/// Returns 0.0 when there are no queries or when `retrieved_lists` and
/// `relevant_sets` differ in length.
#[must_use]
pub fn map(retrieved_lists: &[Vec<usize>], relevant_sets: &[HashSet<usize>]) -> f64 {
    if retrieved_lists.is_empty() || retrieved_lists.len() != relevant_sets.len() {
        return 0.0;
    }

    let average_precisions: f64 = retrieved_lists
        .iter()
        .zip(relevant_sets.iter())
        .map(|(retrieved, relevant)| {
            if relevant.is_empty() {
                return 0.0;
            }

            // Relevant ids already credited; its size is the running hit count.
            let mut credited: HashSet<usize> = HashSet::with_capacity(relevant.len());
            let mut precision_sum = 0.0;

            for (i, id) in retrieved.iter().enumerate() {
                if relevant.contains(id) && credited.insert(*id) {
                    // Hit counts and ranks stay far below 2^53, so these casts are exact.
                    #[allow(clippy::cast_precision_loss)]
                    let precision_at_i = credited.len() as f64 / (i + 1) as f64;
                    precision_sum += precision_at_i;
                }
            }

            precision_sum / relevant.len() as f64
        })
        .sum();

    average_precisions / retrieved_lists.len() as f64
}
217
/// Hit rate at `k`: 1.0 if at least one relevant id appears among the first
/// `k` retrieved ids, otherwise 0.0.
///
/// `k` is clamped to `retrieved.len()`. An empty relevant set, an empty
/// ranking, or `k == 0` all yield 0.0.
#[must_use]
pub fn hit_rate_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    let cutoff = retrieved.len().min(k);
    let top = &retrieved[..cutoff];

    let found = !relevant.is_empty() && top.iter().any(|id| relevant.contains(id));
    if found { 1.0 } else { 0.0 }
}
239
/// Merges several ranked result lists with Reciprocal Rank Fusion (RRF).
///
/// Every occurrence of an item at 0-based position `rank` in a list adds
/// `1 / (k + rank + 1)` to that item's fused score. The scores carried in the
/// input tuples are ignored. The output holds each distinct item once, ordered
/// by fused score from highest to lowest. Items with equal fused scores keep
/// ascending `T` order, because scores are accumulated in a `BTreeMap` and the
/// final sort is stable.
#[must_use]
pub fn reciprocal_rank_fusion<T: Clone + Eq + std::hash::Hash + std::cmp::Ord>(
    result_lists: &[Vec<(T, f32)>],
    k: u32,
) -> Vec<(T, f32)> {
    use std::collections::BTreeMap;

    let scores = result_lists
        .iter()
        .flat_map(|list| list.iter().enumerate())
        .fold(BTreeMap::<T, f32>::new(), |mut acc, (rank, (item, _score))| {
            let contribution = 1.0 / (k as f32 + rank as f32 + 1.0);
            *acc.entry(item.clone()).or_insert(0.0) += contribution;
            acc
        });

    let mut ranked: Vec<(T, f32)> = scores.into_iter().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked
}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for metric values built from small exact fractions.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_recall_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        // 11 is relevant but never retrieved, so perfect recall is unreachable.
        let relevant: HashSet<usize> = [2, 5, 7, 11].into_iter().collect();

        // Top-10 contains 2, 5 and 7: 3 of 4 relevant items.
        assert!((recall_at_k(&retrieved, &relevant, 10) - 0.75).abs() < EPS);
        // Top-3 contains only 2: 1 of 4.
        assert!((recall_at_k(&retrieved, &relevant, 3) - 0.25).abs() < EPS);
        // A zero cutoff finds nothing.
        assert_eq!(recall_at_k(&retrieved, &relevant, 0), 0.0);

        // With nothing to find, recall is defined as perfect.
        let empty: HashSet<usize> = HashSet::new();
        assert_eq!(recall_at_k(&retrieved, &empty, 5), 1.0);
    }

    #[test]
    fn test_precision_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5];
        let relevant: HashSet<usize> = [2, 4].into_iter().collect();

        // Two of the five retrieved items are relevant.
        assert!((precision_at_k(&retrieved, &relevant, 5) - 0.4).abs() < EPS);
        // Only 2 appears among [1, 2, 3].
        assert!((precision_at_k(&retrieved, &relevant, 3) - 1.0 / 3.0).abs() < EPS);
        // A cutoff beyond the list is clamped to the list length.
        assert!((precision_at_k(&retrieved, &relevant, 50) - 0.4).abs() < EPS);

        // Nothing retrieved means zero precision.
        let nothing: Vec<usize> = Vec::new();
        assert_eq!(precision_at_k(&nothing, &relevant, 3), 0.0);
    }

    #[test]
    fn test_ndcg_at_k() {
        let mut rel_scores = HashMap::new();
        rel_scores.insert(1, 3.0);
        rel_scores.insert(2, 2.0);
        rel_scores.insert(3, 0.0);

        // Already in ideal order: DCG equals IDCG.
        let ideal = ndcg_at_k(&[1, 2, 3, 4, 5], &rel_scores, 3);
        assert!((ideal - 1.0).abs() < EPS);

        // Reversed: DCG = 0 + 3/log2(3) + 7/log2(4), IDCG = 7 + 3/log2(3) + 0.
        let gain_2 = 3.0 / 3.0f64.log2();
        let expected = (gain_2 + 3.5) / (7.0 + gain_2);
        let reversed = ndcg_at_k(&[3, 2, 1], &rel_scores, 3);
        assert!((reversed - expected).abs() < EPS);
        assert!(reversed > 0.0 && reversed < 1.0);

        // No retrieved items, or no graded items at all, both score zero.
        let nothing: Vec<usize> = Vec::new();
        assert_eq!(ndcg_at_k(&nothing, &rel_scores, 3), 0.0);
        let ungraded: HashMap<usize, f64> = HashMap::new();
        assert_eq!(ndcg_at_k(&[1, 2], &ungraded, 2), 0.0);
    }

    #[test]
    fn test_mrr() {
        let retrieved_lists: Vec<Vec<usize>> = vec![
            vec![1, 2, 3, 4],
            vec![5, 6, 7, 8],
            vec![9, 10, 11, 12],
        ];
        let relevant_sets: Vec<HashSet<usize>> = vec![
            [2, 4].into_iter().collect(), // first hit at rank 2 -> 1/2
            [7].into_iter().collect(),    // first hit at rank 3 -> 1/3
            [13].into_iter().collect(),   // never retrieved    -> 0
        ];

        // (1/2 + 1/3 + 0) / 3 = 5/18
        let mrr_score = mrr(&retrieved_lists, &relevant_sets);
        assert!((mrr_score - 5.0 / 18.0).abs() < EPS);

        // Mismatched or empty inputs yield 0.0.
        assert_eq!(mrr(&retrieved_lists, &relevant_sets[..2]), 0.0);
        let no_lists: Vec<Vec<usize>> = Vec::new();
        let no_sets: Vec<HashSet<usize>> = Vec::new();
        assert_eq!(mrr(&no_lists, &no_sets), 0.0);
    }

    #[test]
    fn test_map() {
        let retrieved_lists = vec![vec![1, 2, 3, 4, 5]];
        let relevant_sets: Vec<HashSet<usize>> = vec![[2, 4].into_iter().collect()];

        // Hits at ranks 2 and 4: AP = (1/2 + 2/4) / 2 = 0.5.
        let map_score = map(&retrieved_lists, &relevant_sets);
        assert!((map_score - 0.5).abs() < EPS);

        // A query with no relevant items contributes 0 to the mean.
        let blank: Vec<HashSet<usize>> = vec![HashSet::new()];
        assert_eq!(map(&retrieved_lists, &blank), 0.0);
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let list1 = vec![("a", 0.9), ("b", 0.8), ("c", 0.7)];
        let list2 = vec![("c", 0.95), ("a", 0.85), ("d", 0.75)];

        let fused = reciprocal_rank_fusion(&[list1, list2], 60);

        // a: 1/61 + 1/62, c: 1/63 + 1/61, b: 1/62, d: 1/63.
        let order: Vec<&str> = fused.iter().map(|(item, _)| *item).collect();
        assert_eq!(order, vec!["a", "c", "b", "d"]);
        assert!((fused[0].1 - (1.0 / 61.0 + 1.0 / 62.0)).abs() < 1e-6);
    }

    #[test]
    fn test_hit_rate_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5];
        let relevant: HashSet<usize> = [6, 7].into_iter().collect();

        // Neither 6 nor 7 is retrieved.
        assert_eq!(hit_rate_at_k(&retrieved, &relevant, 5), 0.0);

        let relevant2: HashSet<usize> = [3, 7].into_iter().collect();
        // 3 sits at rank 3, so it counts only when k >= 3.
        assert_eq!(hit_rate_at_k(&retrieved, &relevant2, 5), 1.0);
        assert_eq!(hit_rate_at_k(&retrieved, &relevant2, 3), 1.0);
        assert_eq!(hit_rate_at_k(&retrieved, &relevant2, 2), 0.0);
        assert_eq!(hit_rate_at_k(&retrieved, &relevant2, 0), 0.0);
    }
}