1use std::collections::{HashMap, HashSet};
11
/// Retrieval output and ground-truth judgments for a single query.
#[derive(Clone, Debug)]
pub struct QueryResult {
    /// Identifier of the query this result belongs to.
    pub query_id: String,
    /// Retrieved document ids in ranked order; the metric functions treat index 0 as rank 1.
    pub retrieved_ids: Vec<String>,
    /// Ids of the documents judged relevant for this query.
    pub relevant_ids: HashSet<String>,
    /// Optional graded relevance judgments (document id -> grade), in the shape taken by
    /// `ndcg_at_k_graded`; `None` when only binary relevance is available.
    pub relevance_grades: Option<HashMap<String, f64>>,
}
24
25impl QueryResult {
26 pub fn new(
27 query_id: impl Into<String>,
28 retrieved: Vec<String>,
29 relevant: HashSet<String>,
30 ) -> Self {
31 Self {
32 query_id: query_id.into(),
33 retrieved_ids: retrieved,
34 relevant_ids: relevant,
35 relevance_grades: None,
36 }
37 }
38
39 pub fn with_grades(mut self, grades: HashMap<String, f64>) -> Self {
40 self.relevance_grades = Some(grades);
41 self
42 }
43}
44
/// Metric values for one ranking at a single cutoff, as produced by
/// `RetrievalMetrics::compute`.
#[derive(Clone, Debug, Default)]
pub struct EvaluationResult {
    /// Query identifier (left empty by `RetrievalMetrics::compute`).
    pub query_id: String,
    /// Recall@k.
    pub recall: f64,
    /// Precision@k.
    pub precision: f64,
    /// Binary-relevance nDCG@k.
    pub ndcg: f64,
    /// Reciprocal rank of the first relevant document (computed over the whole list, not only the top k).
    pub mrr: f64,
    /// Average precision (computed over the whole list, not only the top k).
    pub ap: f64,
    /// Cutoff used for recall, precision and nDCG.
    pub k: usize,
}
56
/// Metrics for one ranked list across several cutoffs, built by
/// `RetrievalMetrics::compute_all`.
#[derive(Clone, Debug)]
pub struct RetrievalMetrics {
    /// Recall keyed by cutoff k.
    pub recall_at_k: HashMap<usize, f64>,
    /// Precision keyed by cutoff k.
    pub precision_at_k: HashMap<usize, f64>,
    /// Binary-relevance nDCG keyed by cutoff k.
    pub ndcg_at_k: HashMap<usize, f64>,
    /// Reciprocal rank of the first relevant document in the full list.
    pub mrr: f64,
    /// Average precision of this single ranking (MAP over a one-query set).
    pub map: f64,
}
66
67impl RetrievalMetrics {
68 pub fn compute_all(
70 retrieved: &[String],
71 relevant: &HashSet<String>,
72 k_values: &[usize],
73 ) -> Self {
74 let mut recall_at_k = HashMap::new();
75 let mut precision_at_k = HashMap::new();
76 let mut ndcg_at_k = HashMap::new();
77
78 for &k in k_values {
79 recall_at_k.insert(k, recall_at_k_impl(retrieved, relevant, k));
80 precision_at_k.insert(k, precision_at_k_impl(retrieved, relevant, k));
81 ndcg_at_k.insert(k, ndcg_at_k_binary(retrieved, relevant, k));
82 }
83
84 let mrr = mean_reciprocal_rank_single(retrieved, relevant);
85 let map = average_precision_impl(retrieved, relevant);
86
87 Self {
88 recall_at_k,
89 precision_at_k,
90 ndcg_at_k,
91 mrr,
92 map,
93 }
94 }
95
96 pub fn compute(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> EvaluationResult {
98 EvaluationResult {
99 query_id: String::new(),
100 recall: recall_at_k_impl(retrieved, relevant, k),
101 precision: precision_at_k_impl(retrieved, relevant, k),
102 ndcg: ndcg_at_k_binary(retrieved, relevant, k),
103 mrr: mean_reciprocal_rank_single(retrieved, relevant),
104 ap: average_precision_impl(retrieved, relevant),
105 k,
106 }
107 }
108}
109
110pub fn recall_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
122 let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
123 recall_at_k_impl(&retrieved_str, relevant, k)
124}
125
/// Fraction of the relevant documents that appear among the first `k` retrieved ids.
///
/// An id repeated within the top `k` counts once. Returns 0.0 when `relevant` is empty.
fn recall_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    // Distinct relevant ids inside the cutoff window; borrowing avoids cloning each id.
    let found: HashSet<&String> = retrieved
        .iter()
        .take(k)
        .filter(|id| relevant.contains(*id))
        .collect();

    found.len() as f64 / relevant.len() as f64
}
136
137pub fn precision_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
149 let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
150 precision_at_k_impl(&retrieved_str, relevant, k)
151}
152
/// Fraction of the top-`k` retrieved ids that are relevant.
///
/// The denominator is the number of ids actually examined, `min(k, retrieved.len())`,
/// so a ranking shorter than `k` is not penalised for its missing slots. Repeated ids are
/// counted at every occurrence. Returns 0.0 when nothing is examined (`k == 0` or an
/// empty ranking).
fn precision_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let examined = retrieved.len().min(k);
    if examined == 0 {
        return 0.0;
    }

    let mut hits = 0usize;
    for doc in retrieved.iter().take(examined) {
        if relevant.contains(doc) {
            hits += 1;
        }
    }

    hits as f64 / examined as f64
}
171
172pub fn ndcg_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
185 let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
186 ndcg_at_k_binary(&retrieved_str, relevant, k)
187}
188
/// nDCG@k with binary relevance.
///
/// A retrieved id has gain 1 if it is in `relevant` and 0 otherwise. The gain is
/// discounted by `log2(rank + 1)`, where `rank` is 1-based. The ideal DCG assumes that
/// `min(k, relevant.len())` relevant documents fill the top ranks.
///
/// Each relevant id is credited only at its first position in `retrieved`. Later repeats
/// occupy a rank but add no gain. Without this, a ranking that repeats a relevant id
/// could score above 1.0.
///
/// Returns 0.0 when `relevant` is empty or `k == 0`.
fn ndcg_at_k_binary(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    // Ids already credited in the DCG sum.
    let mut credited: HashSet<&String> = HashSet::new();
    let mut dcg: f64 = 0.0;
    for (i, doc) in retrieved.iter().take(k).enumerate() {
        if relevant.contains(doc) && credited.insert(doc) {
            dcg += 1.0 / (i as f64 + 2.0).log2();
        }
    }

    // Best achievable DCG: every one of the top min(k, |relevant|) ranks is relevant.
    let num_relevant_in_k = k.min(relevant.len());
    let idcg: f64 = (0..num_relevant_in_k)
        .map(|i| 1.0 / (i as f64 + 2.0).log2())
        .sum();

    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}
215
/// nDCG@k with graded relevance and exponential gain `2^grade - 1`.
///
/// `relevance_grades` maps document ids to grades. Ids absent from the map get grade 0,
/// which gives gain 0. The position discount is `log2(rank + 1)`, where `rank` is 1-based.
/// The ideal DCG is built from the `k` largest grades in the map, sorted in descending
/// order.
///
/// Each id is credited only at its first position in `retrieved`. Later repeats occupy a
/// rank but add no gain, so repeating an id cannot inflate the score.
///
/// Returns 0.0 when the map is empty or the ideal DCG is zero.
pub fn ndcg_at_k_graded(
    retrieved: &[String],
    relevance_grades: &HashMap<String, f64>,
    k: usize,
) -> f64 {
    if relevance_grades.is_empty() {
        return 0.0;
    }

    let mut credited: HashSet<&String> = HashSet::new();
    let mut dcg: f64 = 0.0;
    for (i, doc) in retrieved.iter().take(k).enumerate() {
        if !credited.insert(doc) {
            continue;
        }
        let rel = relevance_grades.get(doc).copied().unwrap_or(0.0);
        dcg += (2_f64.powf(rel) - 1.0) / (i as f64 + 2.0).log2();
    }

    // Ideal ordering: highest grades first. NaN grades compare as equal rather than panicking.
    let mut sorted_grades: Vec<f64> = relevance_grades.values().copied().collect();
    sorted_grades.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    let idcg: f64 = sorted_grades
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, &rel)| (2_f64.powf(rel) - 1.0) / (i as f64 + 2.0).log2())
        .sum();

    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}
259
260pub fn mean_reciprocal_rank(results: &[QueryResult]) -> f64 {
271 if results.is_empty() {
272 return 0.0;
273 }
274
275 let sum: f64 = results
276 .iter()
277 .map(|r| mean_reciprocal_rank_single(&r.retrieved_ids, &r.relevant_ids))
278 .sum();
279
280 sum / results.len() as f64
281}
282
/// Reciprocal rank for a single query: `1 / rank` of the first relevant retrieved id
/// (1-based rank), or 0.0 if no retrieved id is relevant.
///
/// Despite its name, this scores one ranking; `mean_reciprocal_rank` averages it over queries.
fn mean_reciprocal_rank_single(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    retrieved
        .iter()
        .position(|doc| relevant.contains(doc))
        .map_or(0.0, |index| 1.0 / (index as f64 + 1.0))
}
291
292pub fn average_precision(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>) -> f64 {
300 let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
301 average_precision_impl(&retrieved_str, relevant)
302}
303
/// Average precision (AP) of one ranking.
///
/// AP is the sum, over each relevant id found in `retrieved`, of the precision at the
/// rank where it first appears, divided by `relevant.len()`. Relevant ids that are never
/// retrieved therefore contribute 0.
///
/// A repeated id counts as a hit only at its first occurrence; later repeats count as
/// non-relevant ranks. This keeps the result within [0, 1].
///
/// Returns 0.0 when `relevant` is empty.
fn average_precision_impl(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    // Distinct relevant ids seen so far; its size is the running hit count.
    let mut credited: HashSet<&String> = HashSet::new();
    let mut sum_precision: f64 = 0.0;

    for (i, doc) in retrieved.iter().enumerate() {
        if relevant.contains(doc) && credited.insert(doc) {
            // Precision at this 1-based rank, counting only distinct hits.
            let precision = credited.len() as f64 / (i as f64 + 1.0);
            sum_precision += precision;
        }
    }

    sum_precision / relevant.len() as f64
}
323
324pub fn mean_average_precision(results: &[QueryResult]) -> f64 {
328 if results.is_empty() {
329 return 0.0;
330 }
331
332 let sum: f64 = results
333 .iter()
334 .map(|r| average_precision_impl(&r.retrieved_ids, &r.relevant_ids))
335 .sum();
336
337 sum / results.len() as f64
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point metric values.
    const EPS: f64 = 1e-9;

    fn make_relevant(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    fn make_retrieved(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// DCG position discount for a 0-based position: 1 / log2(position + 2).
    fn discount(position: usize) -> f64 {
        1.0 / (position as f64 + 2.0).log2()
    }

    #[test]
    fn test_recall_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "d", "e"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 1.0);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // k = 1 sees only "a".
        assert_close(recall_at_k_impl(&retrieved, &relevant, 1), 1.0 / 3.0);
        // k = 3 sees "a" and "b".
        assert_close(recall_at_k_impl(&retrieved, &relevant, 3), 2.0 / 3.0);
        // k = 5 sees all three relevant documents.
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = HashSet::new();

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_zero_k() {
        let retrieved = make_retrieved(&["a", "b"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 0), 0.0);
    }

    #[test]
    fn test_precision_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c", "d", "e"]);

        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn test_precision_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 1), 1.0);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 2), 0.5);
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 5), 0.6);
    }

    #[test]
    fn test_precision_at_k_edge_cases() {
        let relevant = make_relevant(&["a"]);
        let retrieved = make_retrieved(&["a", "x"]);

        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 0), 0.0);
        assert_eq!(precision_at_k_impl(&[], &relevant, 3), 0.0);
        // k beyond the list length: the denominator is the number of ids actually retrieved.
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 10), 0.5);
    }

    #[test]
    fn test_mrr_first_position() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let retrieved = make_retrieved(&["x", "a", "c"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let retrieved = make_retrieved(&["x", "y", "a"]);
        let relevant = make_relevant(&["a"]);

        assert_close(mean_reciprocal_rank_single(&retrieved, &relevant), 1.0 / 3.0);
    }

    #[test]
    fn test_mrr_not_found() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.0);
    }

    #[test]
    fn test_mean_reciprocal_rank_over_queries() {
        let results = vec![
            QueryResult::new("q1", make_retrieved(&["a", "b"]), make_relevant(&["a"])),
            QueryResult::new("q2", make_retrieved(&["x", "a"]), make_relevant(&["a"])),
            QueryResult::new("q3", make_retrieved(&["x", "y"]), make_relevant(&["a"])),
        ];

        // (1 + 1/2 + 0) / 3
        assert_close(mean_reciprocal_rank(&results), 0.5);
        assert_eq!(mean_reciprocal_rank(&[]), 0.0);
    }

    #[test]
    fn test_ndcg_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_close(ndcg_at_k_binary(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_ndcg_partial() {
        let retrieved = make_retrieved(&["x", "a", "y", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // Relevant ids sit at 0-based positions 1, 3 and 4; the ideal puts them at 0, 1, 2.
        let expected = (discount(1) + discount(3) + discount(4))
            / (discount(0) + discount(1) + discount(2));
        assert_close(ndcg_at_k_binary(&retrieved, &relevant, 5), expected);
    }

    #[test]
    fn test_ndcg_empty_relevant_and_zero_k() {
        let retrieved = make_retrieved(&["a", "b"]);

        assert_eq!(ndcg_at_k_binary(&retrieved, &HashSet::new(), 2), 0.0);
        assert_eq!(ndcg_at_k_binary(&retrieved, &make_relevant(&["a"]), 0), 0.0);
    }

    #[test]
    fn test_ndcg_graded() {
        let mut grades: HashMap<String, f64> = HashMap::new();
        grades.insert("a".to_string(), 2.0);
        grades.insert("b".to_string(), 1.0);

        let ideal = make_retrieved(&["a", "b"]);
        assert_close(ndcg_at_k_graded(&ideal, &grades, 2), 1.0);

        // Gains are 2^grade - 1: "b" (gain 1) first, then "a" (gain 3).
        let swapped = make_retrieved(&["b", "a"]);
        let expected = (1.0 * discount(0) + 3.0 * discount(1))
            / (3.0 * discount(0) + 1.0 * discount(1));
        assert_close(ndcg_at_k_graded(&swapped, &grades, 2), expected);

        assert_eq!(ndcg_at_k_graded(&ideal, &HashMap::new(), 2), 0.0);
    }

    #[test]
    fn test_average_precision() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // Precision at each relevant hit: 1/1, 2/3, 3/5.
        let expected = (1.0 + 2.0 / 3.0 + 3.0 / 5.0) / 3.0;
        assert_close(average_precision_impl(&retrieved, &relevant), expected);
    }

    #[test]
    fn test_average_precision_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(average_precision_impl(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_average_precision_edge_cases() {
        let retrieved = make_retrieved(&["x", "y"]);

        assert_eq!(average_precision_impl(&retrieved, &HashSet::new()), 0.0);
        assert_eq!(average_precision_impl(&retrieved, &make_relevant(&["a"])), 0.0);
    }

    #[test]
    fn test_mean_average_precision_over_queries() {
        let results = vec![
            QueryResult::new("q1", make_retrieved(&["a", "b"]), make_relevant(&["a", "b"])),
            QueryResult::new("q2", make_retrieved(&["x", "y"]), make_relevant(&["a"])),
        ];

        // (1.0 + 0.0) / 2
        assert_close(mean_average_precision(&results), 0.5);
        assert_eq!(mean_average_precision(&[]), 0.0);
    }

    #[test]
    fn test_public_wrappers_accept_str_slices() {
        let relevant = make_relevant(&["a", "b"]);

        assert_close(recall_at_k(&["a", "x", "b"], &relevant, 2), 0.5);
        assert_close(precision_at_k(&["a", "x", "b"], &relevant, 2), 0.5);
        assert_close(ndcg_at_k(&["a", "b", "x"], &relevant, 3), 1.0);
        assert_close(average_precision(&["a", "b", "x"], &relevant), 1.0);
    }

    #[test]
    fn test_query_result_builder() {
        let mut grades: HashMap<String, f64> = HashMap::new();
        grades.insert("a".to_string(), 2.0);

        let result = QueryResult::new("q1", make_retrieved(&["a"]), make_relevant(&["a"]));
        assert_eq!(result.query_id, "q1");
        assert!(result.relevance_grades.is_none());

        let result = result.with_grades(grades);
        assert_eq!(
            result.relevance_grades.as_ref().and_then(|g| g.get("a")).copied(),
            Some(2.0)
        );
    }

    #[test]
    fn test_retrieval_metrics_compute() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        let metrics = RetrievalMetrics::compute_all(&retrieved, &relevant, &[5, 10]);

        assert!(metrics.recall_at_k.contains_key(&5));
        assert!(metrics.precision_at_k.contains_key(&5));
        assert!(metrics.ndcg_at_k.contains_key(&5));

        assert_eq!(metrics.recall_at_k[&5], 1.0);
        assert_eq!(metrics.recall_at_k[&10], 1.0);
        assert_close(metrics.precision_at_k[&5], 0.6);
        // Only five ids were retrieved, so precision@10 also divides by five.
        assert_close(metrics.precision_at_k[&10], 0.6);

        let expected_ndcg = (discount(0) + discount(1) + discount(3))
            / (discount(0) + discount(1) + discount(2));
        assert_close(metrics.ndcg_at_k[&5], expected_ndcg);
        assert_close(metrics.ndcg_at_k[&10], expected_ndcg);

        assert_eq!(metrics.mrr, 1.0);
        // Hits at ranks 1, 2 and 4: (1/1 + 2/2 + 3/4) / 3.
        assert_close(metrics.map, (1.0 + 1.0 + 0.75) / 3.0);
    }

    #[test]
    fn test_retrieval_metrics_compute_single_cutoff() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        let result = RetrievalMetrics::compute(&retrieved, &relevant, 3);

        assert!(result.query_id.is_empty());
        assert_eq!(result.k, 3);
        assert_close(result.recall, 2.0 / 3.0);
        assert_close(result.precision, 2.0 / 3.0);
        assert_close(
            result.ndcg,
            (discount(0) + discount(1)) / (discount(0) + discount(1) + discount(2)),
        );
        assert_eq!(result.mrr, 1.0);
        assert_close(result.ap, (1.0 + 1.0 + 0.75) / 3.0);
    }
}