1use crate::ChunkId;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
/// Retrieval-quality metrics for a single query's ranked result list.
///
/// Instances are normally produced by `RetrievalMetrics::compute`; the map
/// fields are keyed by the cutoff `k` at which the metric was evaluated.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalMetrics {
    /// Recall@k for each requested cutoff `k`.
    pub recall: std::collections::HashMap<usize, f32>,
    /// Precision@k for each requested cutoff `k`.
    pub precision: std::collections::HashMap<usize, f32>,
    /// Reciprocal rank of the first relevant result (`0.0` when none is retrieved).
    pub mrr: f32,
    /// NDCG@k for each requested cutoff `k`.
    pub ndcg: std::collections::HashMap<usize, f32>,
    /// Average precision of this query over the full ranked list.
    pub map: f32,
}
21
22impl RetrievalMetrics {
23 pub fn compute(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k_values: &[usize]) -> Self {
25 contract_pre_configuration!(retrieved);
27 let mut metrics = Self::default();
28
29 for &k in k_values {
30 metrics.recall.insert(k, Self::recall_at_k(retrieved, relevant, k));
31 metrics.precision.insert(k, Self::precision_at_k(retrieved, relevant, k));
32 metrics.ndcg.insert(k, Self::ndcg_at_k(retrieved, relevant, k));
33 }
34
35 metrics.mrr = Self::mean_reciprocal_rank(retrieved, relevant);
36 metrics.map = Self::average_precision(retrieved, relevant);
37
38 metrics
39 }
40
41 #[must_use]
45 pub fn recall_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
46 if relevant.is_empty() {
47 return 0.0;
48 }
49
50 contract_pre_configuration!(retrieved);
52 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
53 let relevant_retrieved = retrieved_k.intersection(relevant).count();
54
55 relevant_retrieved as f32 / relevant.len() as f32
56 }
57
58 #[must_use]
62 pub fn precision_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
63 if k == 0 {
64 return 0.0;
65 }
66
67 contract_pre_configuration!(retrieved);
69 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
70 let relevant_retrieved = retrieved_k.intersection(relevant).count();
71
72 relevant_retrieved as f32 / k as f32
73 }
74
75 #[must_use]
79 pub fn mean_reciprocal_rank(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
80 contract_pre_pagerank!(retrieved);
82 for (rank, id) in retrieved.iter().enumerate() {
83 if relevant.contains(id) {
84 return 1.0 / (rank + 1) as f32;
85 }
86 }
87 0.0
88 }
89
90 #[must_use]
94 pub fn ndcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
95 contract_pre_configuration!(retrieved);
97 let dcg = Self::dcg_at_k(retrieved, relevant, k);
98 let idcg = Self::ideal_dcg_at_k(relevant.len(), k);
99
100 if idcg == 0.0 {
101 0.0
102 } else {
103 dcg / idcg
104 }
105 }
106
107 fn dcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
112 let mut seen = HashSet::new();
113 retrieved
114 .iter()
115 .take(k)
116 .enumerate()
117 .filter(|(_, id)| relevant.contains(id) && seen.insert(**id))
118 .map(|(rank, _)| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
119 .sum()
120 }
121
122 fn ideal_dcg_at_k(num_relevant: usize, k: usize) -> f32 {
124 (0..num_relevant.min(k))
125 .map(|rank| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
126 .sum()
127 }
128
129 #[must_use]
133 pub fn average_precision(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
134 if relevant.is_empty() {
135 return 0.0;
136 }
137
138 contract_pre_configuration!(retrieved);
140 let mut sum_precision = 0.0;
141 let mut relevant_count = 0;
142
143 for (rank, id) in retrieved.iter().enumerate() {
144 if relevant.contains(id) {
145 relevant_count += 1;
146 sum_precision += relevant_count as f32 / (rank + 1) as f32;
147 }
148 }
149
150 sum_precision / relevant.len().max(1) as f32
151 }
152
153 #[must_use]
155 pub fn f1_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
156 contract_pre_configuration!(retrieved);
158 let precision = Self::precision_at_k(retrieved, relevant, k);
159 let recall = Self::recall_at_k(retrieved, relevant, k);
160
161 if precision + recall == 0.0 {
162 0.0
163 } else {
164 2.0 * precision * recall / (precision + recall)
165 }
166 }
167
168 #[must_use]
170 pub fn hit_rate_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
171 contract_pre_configuration!(retrieved);
173 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
174 if retrieved_k.intersection(relevant).next().is_some() {
175 1.0
176 } else {
177 0.0
178 }
179 }
180}
181
/// Retrieval metrics averaged over a collection of queries.
///
/// Built by `AggregatedMetrics::aggregate` from per-query `RetrievalMetrics`;
/// the map fields are keyed by the cutoff `k`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregatedMetrics {
    /// Mean recall@k across the aggregated queries, keyed by `k`.
    pub mean_recall: std::collections::HashMap<usize, f32>,
    /// Mean precision@k across the aggregated queries, keyed by `k`.
    pub mean_precision: std::collections::HashMap<usize, f32>,
    /// Mean of the per-query reciprocal ranks.
    pub mean_mrr: f32,
    /// Mean NDCG@k across the aggregated queries, keyed by `k`.
    pub mean_ndcg: std::collections::HashMap<usize, f32>,
    /// Mean of the per-query average precision values.
    pub map: f32,
    /// Number of per-query metric records that were aggregated.
    pub query_count: usize,
}
198
199impl AggregatedMetrics {
200 pub fn aggregate(metrics: &[RetrievalMetrics]) -> Self {
202 if metrics.is_empty() {
203 return Self::default();
204 }
205
206 let n = metrics.len() as f32;
207 let mut agg = Self { query_count: metrics.len(), ..Default::default() };
208
209 agg.mean_mrr = metrics.iter().map(|m| m.mrr).sum::<f32>() / n;
211 agg.map = metrics.iter().map(|m| m.map).sum::<f32>() / n;
212
213 if let Some(first) = metrics.first() {
215 for &k in first.recall.keys() {
216 let mean_recall = metrics.iter().filter_map(|m| m.recall.get(&k)).sum::<f32>() / n;
217 agg.mean_recall.insert(k, mean_recall);
218
219 let mean_precision =
220 metrics.iter().filter_map(|m| m.precision.get(&k)).sum::<f32>() / n;
221 agg.mean_precision.insert(k, mean_precision);
222
223 let mean_ndcg = metrics.iter().filter_map(|m| m.ndcg.get(&k)).sum::<f32>() / n;
224 agg.mean_ndcg.insert(k, mean_ndcg);
225 }
226 }
227
228 agg
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Absolute tolerance used when comparing floating-point metric values.
    const TOLERANCE: f32 = 0.001;

    /// Builds a deterministic `ChunkId` from an integer.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Ranked result list made from the given integers, in order.
    fn ranked(ns: &[u128]) -> Vec<ChunkId> {
        ns.iter().map(|&n| chunk_id(n)).collect()
    }

    /// Relevant-id set made from the given integers.
    fn id_set(ns: &[u128]) -> HashSet<ChunkId> {
        ns.iter().map(|&n| chunk_id(n)).collect()
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE, "expected {}, got {}", expected, actual);
    }

    /// Per-query metrics whose every value (at k = 1 and k = 2) equals `score`.
    fn uniform_metrics(score: f32) -> RetrievalMetrics {
        RetrievalMetrics {
            mrr: score,
            map: score,
            recall: [(1, score), (2, score)].into(),
            precision: [(1, score), (2, score)].into(),
            ndcg: [(1, score), (2, score)].into(),
        }
    }

    // ---- recall@k ----

    #[test]
    fn test_recall_at_k_perfect() {
        let hits = ranked(&[1, 2, 3]);
        let truth = id_set(&[1, 2, 3]);
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 3), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let hits = ranked(&[1, 4, 5]);
        let truth = id_set(&[1, 2, 3]);
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 3), 1.0 / 3.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let hits = ranked(&[4, 5, 6]);
        let truth = id_set(&[1, 2, 3]);
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let hits = ranked(&[1, 2]);
        let truth: HashSet<ChunkId> = HashSet::new();
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 2), 0.0);
    }

    #[test]
    fn test_recall_at_k_smaller_k() {
        let hits = ranked(&[4, 1, 2]);
        let truth = id_set(&[1, 2]);

        // Only the irrelevant first result is inside the window.
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 1), 0.0);
        // Widening the window to two picks up one of the two relevant ids.
        assert_close(RetrievalMetrics::recall_at_k(&hits, &truth, 2), 0.5);
    }

    // ---- precision@k ----

    #[test]
    fn test_precision_at_k_perfect() {
        let hits = ranked(&[1, 2]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::precision_at_k(&hits, &truth, 2), 1.0);
    }

    #[test]
    fn test_precision_at_k_half() {
        let hits = ranked(&[1, 4]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::precision_at_k(&hits, &truth, 2), 0.5);
    }

    #[test]
    fn test_precision_at_k_zero() {
        assert_close(RetrievalMetrics::precision_at_k(&[], &HashSet::new(), 0), 0.0);
    }

    // ---- reciprocal rank ----

    #[test]
    fn test_mrr_first_position() {
        let hits = ranked(&[1, 2, 3]);
        let truth = id_set(&[1]);
        assert_close(RetrievalMetrics::mean_reciprocal_rank(&hits, &truth), 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let hits = ranked(&[4, 1, 3]);
        let truth = id_set(&[1]);
        assert_close(RetrievalMetrics::mean_reciprocal_rank(&hits, &truth), 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let hits = ranked(&[4, 5, 1]);
        let truth = id_set(&[1]);
        assert_close(RetrievalMetrics::mean_reciprocal_rank(&hits, &truth), 1.0 / 3.0);
    }

    #[test]
    fn test_mrr_not_found() {
        let hits = ranked(&[4, 5, 6]);
        let truth = id_set(&[1]);
        assert_close(RetrievalMetrics::mean_reciprocal_rank(&hits, &truth), 0.0);
    }

    // ---- NDCG@k ----

    #[test]
    fn test_ndcg_perfect_order() {
        let hits = ranked(&[1, 2]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::ndcg_at_k(&hits, &truth, 2), 1.0);
    }

    #[test]
    fn test_ndcg_no_relevant() {
        let hits = ranked(&[3, 4]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::ndcg_at_k(&hits, &truth, 2), 0.0);
    }

    #[test]
    fn test_ndcg_empty_relevant() {
        let hits = ranked(&[1, 2]);
        let truth: HashSet<ChunkId> = HashSet::new();
        assert_close(RetrievalMetrics::ndcg_at_k(&hits, &truth, 2), 0.0);
    }

    // ---- average precision ----

    #[test]
    fn test_ap_perfect() {
        let hits = ranked(&[1, 2, 3]);
        let truth = id_set(&[1, 2, 3]);
        assert_close(RetrievalMetrics::average_precision(&hits, &truth), 1.0);
    }

    #[test]
    fn test_ap_interleaved() {
        let hits = ranked(&[1, 4, 2]);
        let truth = id_set(&[1, 2]);
        // Precision is 1/1 at rank 1 and 2/3 at rank 3; their mean is 5/6.
        assert_close(RetrievalMetrics::average_precision(&hits, &truth), 5.0 / 6.0);
    }

    #[test]
    fn test_ap_empty_relevant() {
        let hits = ranked(&[1, 2]);
        let truth: HashSet<ChunkId> = HashSet::new();
        assert_close(RetrievalMetrics::average_precision(&hits, &truth), 0.0);
    }

    // ---- F1@k ----

    #[test]
    fn test_f1_perfect() {
        let hits = ranked(&[1, 2]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::f1_at_k(&hits, &truth, 2), 1.0);
    }

    #[test]
    fn test_f1_zero() {
        let hits = ranked(&[3, 4]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::f1_at_k(&hits, &truth, 2), 0.0);
    }

    // ---- hit rate@k ----

    #[test]
    fn test_hit_rate_hit() {
        let hits = ranked(&[3, 1, 4]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::hit_rate_at_k(&hits, &truth, 3), 1.0);
    }

    #[test]
    fn test_hit_rate_miss() {
        let hits = ranked(&[3, 4]);
        let truth = id_set(&[1, 2]);
        assert_close(RetrievalMetrics::hit_rate_at_k(&hits, &truth, 2), 0.0);
    }

    // ---- compute ----

    #[test]
    fn test_compute_all_metrics() {
        let hits = ranked(&[1, 4, 2, 5]);
        let truth = id_set(&[1, 2, 3]);
        let cutoffs = vec![1, 2, 5, 10];

        let metrics = RetrievalMetrics::compute(&hits, &truth, &cutoffs);

        assert!(!metrics.recall.is_empty());
        assert!(!metrics.precision.is_empty());
        assert!(!metrics.ndcg.is_empty());
        assert!(metrics.mrr > 0.0);
    }

    // ---- aggregation ----

    #[test]
    fn test_aggregate_empty() {
        let summary = AggregatedMetrics::aggregate(&[]);
        assert_eq!(summary.query_count, 0);
    }

    #[test]
    fn test_aggregate_single() {
        let hits = ranked(&[1, 2]);
        let truth = id_set(&[1, 2]);
        let per_query = RetrievalMetrics::compute(&hits, &truth, &[1, 2]);

        let summary = AggregatedMetrics::aggregate(&[per_query]);
        assert_eq!(summary.query_count, 1);
        assert_close(summary.mean_mrr, 1.0);
    }

    #[test]
    fn test_aggregate_multiple() {
        let summary = AggregatedMetrics::aggregate(&[uniform_metrics(1.0), uniform_metrics(0.5)]);

        assert_eq!(summary.query_count, 2);
        assert_close(summary.mean_mrr, 0.75);
        assert_close(summary.map, 0.75);
    }

    // ---- property tests: every bounded metric stays within [0, 1] ----

    proptest! {
        #[test]
        fn prop_recall_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let hits = ranked(&retrieved_ids);
            let truth = id_set(&relevant_ids);

            let recall = RetrievalMetrics::recall_at_k(&hits, &truth, k);
            prop_assert!((0.0..=1.0).contains(&recall));
        }

        #[test]
        fn prop_precision_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let hits = ranked(&retrieved_ids);
            let truth = id_set(&relevant_ids);

            let precision = RetrievalMetrics::precision_at_k(&hits, &truth, k);
            prop_assert!((0.0..=1.0).contains(&precision));
        }

        #[test]
        fn prop_mrr_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let hits = ranked(&retrieved_ids);
            let truth = id_set(&relevant_ids);

            let mrr = RetrievalMetrics::mean_reciprocal_rank(&hits, &truth);
            prop_assert!((0.0..=1.0).contains(&mrr));
        }

        #[test]
        fn prop_ndcg_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let hits = ranked(&retrieved_ids);
            let truth = id_set(&relevant_ids);

            let ndcg = RetrievalMetrics::ndcg_at_k(&hits, &truth, k);
            prop_assert!((0.0..=1.0).contains(&ndcg));
        }
    }
}