Skip to main content

aprender_rag/
metrics.rs

1//! Retrieval evaluation metrics
2
3use crate::ChunkId;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// Retrieval metrics for evaluation
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
9pub struct RetrievalMetrics {
10    /// Recall@k for various k values
11    pub recall: std::collections::HashMap<usize, f32>,
12    /// Precision@k for various k values
13    pub precision: std::collections::HashMap<usize, f32>,
14    /// Mean Reciprocal Rank
15    pub mrr: f32,
16    /// Normalized Discounted Cumulative Gain@k
17    pub ndcg: std::collections::HashMap<usize, f32>,
18    /// Mean Average Precision
19    pub map: f32,
20}
21
22impl RetrievalMetrics {
23    /// Compute all metrics for a single query
24    pub fn compute(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k_values: &[usize]) -> Self {
25        // Contract: configuration-v1.yaml precondition (pv codegen)
26        contract_pre_configuration!(retrieved);
27        let mut metrics = Self::default();
28
29        for &k in k_values {
30            metrics.recall.insert(k, Self::recall_at_k(retrieved, relevant, k));
31            metrics.precision.insert(k, Self::precision_at_k(retrieved, relevant, k));
32            metrics.ndcg.insert(k, Self::ndcg_at_k(retrieved, relevant, k));
33        }
34
35        metrics.mrr = Self::mean_reciprocal_rank(retrieved, relevant);
36        metrics.map = Self::average_precision(retrieved, relevant);
37
38        metrics
39    }
40
41    /// Compute Recall@k
42    ///
43    /// Recall@k = |relevant ∩ retrieved@k| / |relevant|
44    #[must_use]
45    pub fn recall_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
46        if relevant.is_empty() {
47            return 0.0;
48        }
49
50        // Contract: configuration-v1.yaml precondition (pv codegen)
51        contract_pre_configuration!(retrieved);
52        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
53        let relevant_retrieved = retrieved_k.intersection(relevant).count();
54
55        relevant_retrieved as f32 / relevant.len() as f32
56    }
57
58    /// Compute Precision@k
59    ///
60    /// Precision@k = |relevant ∩ retrieved@k| / k
61    #[must_use]
62    pub fn precision_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
63        if k == 0 {
64            return 0.0;
65        }
66
67        // Contract: configuration-v1.yaml precondition (pv codegen)
68        contract_pre_configuration!(retrieved);
69        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
70        let relevant_retrieved = retrieved_k.intersection(relevant).count();
71
72        relevant_retrieved as f32 / k as f32
73    }
74
75    /// Compute Mean Reciprocal Rank (MRR)
76    ///
77    /// MRR = 1 / rank of first relevant result
78    #[must_use]
79    pub fn mean_reciprocal_rank(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
80        // Contract: pagerank-kernel-v1.yaml precondition (pv codegen)
81        contract_pre_pagerank!(retrieved);
82        for (rank, id) in retrieved.iter().enumerate() {
83            if relevant.contains(id) {
84                return 1.0 / (rank + 1) as f32;
85            }
86        }
87        0.0
88    }
89
90    /// Compute Normalized Discounted Cumulative Gain@k
91    ///
92    /// NDCG@k = DCG@k / IDCG@k
93    #[must_use]
94    pub fn ndcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
95        // Contract: configuration-v1.yaml precondition (pv codegen)
96        contract_pre_configuration!(retrieved);
97        let dcg = Self::dcg_at_k(retrieved, relevant, k);
98        let idcg = Self::ideal_dcg_at_k(relevant.len(), k);
99
100        if idcg == 0.0 {
101            0.0
102        } else {
103            dcg / idcg
104        }
105    }
106
107    /// Compute Discounted Cumulative Gain@k
108    ///
109    /// Note: Each relevant item is counted at most once (at its first occurrence)
110    /// to ensure NDCG remains bounded by 1.0.
111    fn dcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
112        let mut seen = HashSet::new();
113        retrieved
114            .iter()
115            .take(k)
116            .enumerate()
117            .filter(|(_, id)| relevant.contains(id) && seen.insert(**id))
118            .map(|(rank, _)| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
119            .sum()
120    }
121
122    /// Compute Ideal DCG@k (best possible DCG)
123    fn ideal_dcg_at_k(num_relevant: usize, k: usize) -> f32 {
124        (0..num_relevant.min(k))
125            .map(|rank| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
126            .sum()
127    }
128
129    /// Compute Average Precision (AP)
130    ///
131    /// AP = (1/|relevant|) * Σ (Precision@k * rel(k))
132    #[must_use]
133    pub fn average_precision(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
134        if relevant.is_empty() {
135            return 0.0;
136        }
137
138        // Contract: configuration-v1.yaml precondition (pv codegen)
139        contract_pre_configuration!(retrieved);
140        let mut sum_precision = 0.0;
141        let mut relevant_count = 0;
142
143        for (rank, id) in retrieved.iter().enumerate() {
144            if relevant.contains(id) {
145                relevant_count += 1;
146                sum_precision += relevant_count as f32 / (rank + 1) as f32;
147            }
148        }
149
150        sum_precision / relevant.len().max(1) as f32
151    }
152
153    /// Compute F1 score at k
154    #[must_use]
155    pub fn f1_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
156        // Contract: configuration-v1.yaml precondition (pv codegen)
157        contract_pre_configuration!(retrieved);
158        let precision = Self::precision_at_k(retrieved, relevant, k);
159        let recall = Self::recall_at_k(retrieved, relevant, k);
160
161        if precision + recall == 0.0 {
162            0.0
163        } else {
164            2.0 * precision * recall / (precision + recall)
165        }
166    }
167
168    /// Compute Hit Rate (1 if any relevant in top-k, else 0)
169    #[must_use]
170    pub fn hit_rate_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
171        // Contract: configuration-v1.yaml precondition (pv codegen)
172        contract_pre_configuration!(retrieved);
173        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
174        if retrieved_k.intersection(relevant).next().is_some() {
175            1.0
176        } else {
177            0.0
178        }
179    }
180}
181
182/// Aggregated metrics across multiple queries
183#[derive(Debug, Clone, Default, Serialize, Deserialize)]
184pub struct AggregatedMetrics {
185    /// Mean Recall@k
186    pub mean_recall: std::collections::HashMap<usize, f32>,
187    /// Mean Precision@k
188    pub mean_precision: std::collections::HashMap<usize, f32>,
189    /// Mean MRR
190    pub mean_mrr: f32,
191    /// Mean NDCG@k
192    pub mean_ndcg: std::collections::HashMap<usize, f32>,
193    /// Mean Average Precision (MAP)
194    pub map: f32,
195    /// Number of queries
196    pub query_count: usize,
197}
198
199impl AggregatedMetrics {
200    /// Aggregate metrics from multiple queries
201    pub fn aggregate(metrics: &[RetrievalMetrics]) -> Self {
202        if metrics.is_empty() {
203            return Self::default();
204        }
205
206        let n = metrics.len() as f32;
207        let mut agg = Self { query_count: metrics.len(), ..Default::default() };
208
209        // Aggregate MRR and MAP
210        agg.mean_mrr = metrics.iter().map(|m| m.mrr).sum::<f32>() / n;
211        agg.map = metrics.iter().map(|m| m.map).sum::<f32>() / n;
212
213        // Aggregate k-based metrics
214        if let Some(first) = metrics.first() {
215            for &k in first.recall.keys() {
216                let mean_recall = metrics.iter().filter_map(|m| m.recall.get(&k)).sum::<f32>() / n;
217                agg.mean_recall.insert(k, mean_recall);
218
219                let mean_precision =
220                    metrics.iter().filter_map(|m| m.precision.get(&k)).sum::<f32>() / n;
221                agg.mean_precision.insert(k, mean_precision);
222
223                let mean_ndcg = metrics.iter().filter_map(|m| m.ndcg.get(&k)).sum::<f32>() / n;
224                agg.mean_ndcg.insert(k, mean_ndcg);
225            }
226        }
227
228        agg
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f32 = 0.001;

    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Ranked result list built from raw ids, order preserved.
    fn ranked(ids: &[u128]) -> Vec<ChunkId> {
        ids.iter().copied().map(chunk_id).collect()
    }

    /// Relevance judgements built from raw ids.
    fn judged(ids: &[u128]) -> HashSet<ChunkId> {
        ids.iter().copied().map(chunk_id).collect()
    }

    /// Assert that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE, "expected {}, got {}", expected, actual);
    }

    // ============ Recall Tests ============

    #[test]
    fn test_recall_at_k_perfect() {
        let retrieved = ranked(&[1, 2, 3]);
        let relevant = judged(&[1, 2, 3]);

        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let retrieved = ranked(&[1, 4, 5]);
        let relevant = judged(&[1, 2, 3]);

        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3), 1.0 / 3.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let retrieved = ranked(&[4, 5, 6]);
        let relevant = judged(&[1, 2, 3]);

        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[]);

        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 2), 0.0);
    }

    #[test]
    fn test_recall_at_k_smaller_k() {
        let retrieved = ranked(&[4, 1, 2]);
        let relevant = judged(&[1, 2]);

        // k=1 sees only id 4, which is not relevant.
        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 1), 0.0);

        // k=2 additionally sees id 1, one of the two relevant ids.
        assert_close(RetrievalMetrics::recall_at_k(&retrieved, &relevant, 2), 0.5);
    }

    // ============ Precision Tests ============

    #[test]
    fn test_precision_at_k_perfect() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::precision_at_k(&retrieved, &relevant, 2), 1.0);
    }

    #[test]
    fn test_precision_at_k_half() {
        let retrieved = ranked(&[1, 4]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::precision_at_k(&retrieved, &relevant, 2), 0.5);
    }

    #[test]
    fn test_precision_at_k_zero() {
        let precision = RetrievalMetrics::precision_at_k(&[], &HashSet::new(), 0);
        assert_close(precision, 0.0);
    }

    // ============ MRR Tests ============

    #[test]
    fn test_mrr_first_position() {
        let retrieved = ranked(&[1, 2, 3]);
        let relevant = judged(&[1]);

        assert_close(RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let retrieved = ranked(&[4, 1, 3]);
        let relevant = judged(&[1]);

        assert_close(RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant), 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let retrieved = ranked(&[4, 5, 1]);
        let relevant = judged(&[1]);

        assert_close(RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant), 1.0 / 3.0);
    }

    #[test]
    fn test_mrr_not_found() {
        let retrieved = ranked(&[4, 5, 6]);
        let relevant = judged(&[1]);

        assert_close(RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant), 0.0);
    }

    // ============ NDCG Tests ============

    #[test]
    fn test_ndcg_perfect_order() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2), 1.0);
    }

    #[test]
    fn test_ndcg_no_relevant() {
        let retrieved = ranked(&[3, 4]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2), 0.0);
    }

    #[test]
    fn test_ndcg_empty_relevant() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[]);

        assert_close(RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2), 0.0);
    }

    // ============ Average Precision Tests ============

    #[test]
    fn test_ap_perfect() {
        let retrieved = ranked(&[1, 2, 3]);
        let relevant = judged(&[1, 2, 3]);

        // AP = (1/3) * (1/1 + 2/2 + 3/3) = 1.0
        assert_close(RetrievalMetrics::average_precision(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_ap_interleaved() {
        let retrieved = ranked(&[1, 4, 2]);
        let relevant = judged(&[1, 2]);

        // AP = (1/2) * (1/1 + 2/3) = 5/6 ≈ 0.833
        assert_close(RetrievalMetrics::average_precision(&retrieved, &relevant), 5.0 / 6.0);
    }

    #[test]
    fn test_ap_empty_relevant() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[]);

        assert_close(RetrievalMetrics::average_precision(&retrieved, &relevant), 0.0);
    }

    // ============ F1 Tests ============

    #[test]
    fn test_f1_perfect() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::f1_at_k(&retrieved, &relevant, 2), 1.0);
    }

    #[test]
    fn test_f1_zero() {
        let retrieved = ranked(&[3, 4]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::f1_at_k(&retrieved, &relevant, 2), 0.0);
    }

    // ============ Hit Rate Tests ============

    #[test]
    fn test_hit_rate_hit() {
        let retrieved = ranked(&[3, 1, 4]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::hit_rate_at_k(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn test_hit_rate_miss() {
        let retrieved = ranked(&[3, 4]);
        let relevant = judged(&[1, 2]);

        assert_close(RetrievalMetrics::hit_rate_at_k(&retrieved, &relevant, 2), 0.0);
    }

    // ============ Compute Tests ============

    #[test]
    fn test_compute_all_metrics() {
        let retrieved = ranked(&[1, 4, 2, 5]);
        let relevant = judged(&[1, 2, 3]);

        let metrics = RetrievalMetrics::compute(&retrieved, &relevant, &[1, 2, 5, 10]);

        assert!(!metrics.recall.is_empty());
        assert!(!metrics.precision.is_empty());
        assert!(!metrics.ndcg.is_empty());
        assert!(metrics.mrr > 0.0);
    }

    // ============ Aggregation Tests ============

    #[test]
    fn test_aggregate_empty() {
        let agg = AggregatedMetrics::aggregate(&[]);
        assert_eq!(agg.query_count, 0);
    }

    #[test]
    fn test_aggregate_single() {
        let retrieved = ranked(&[1, 2]);
        let relevant = judged(&[1, 2]);
        let single = RetrievalMetrics::compute(&retrieved, &relevant, &[1, 2]);

        let agg = AggregatedMetrics::aggregate(&[single]);
        assert_eq!(agg.query_count, 1);
        assert_close(agg.mean_mrr, 1.0);
    }

    #[test]
    fn test_aggregate_multiple() {
        // A query whose every score equals `score`, reported at k = 1 and k = 2.
        let uniform = |score: f32| RetrievalMetrics {
            mrr: score,
            map: score,
            recall: [(1, score), (2, score)].into(),
            precision: [(1, score), (2, score)].into(),
            ndcg: [(1, score), (2, score)].into(),
        };

        let agg = AggregatedMetrics::aggregate(&[uniform(1.0), uniform(0.5)]);

        assert_eq!(agg.query_count, 2);
        assert_close(agg.mean_mrr, 0.75);
        assert_close(agg.map, 0.75);
    }

    // ============ Property-Based Tests ============

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_recall_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved = ranked(&retrieved_ids);
            let relevant = judged(&relevant_ids);

            let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, k);
            prop_assert!(recall >= 0.0);
            prop_assert!(recall <= 1.0);
        }

        #[test]
        fn prop_precision_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved = ranked(&retrieved_ids);
            let relevant = judged(&relevant_ids);

            let precision = RetrievalMetrics::precision_at_k(&retrieved, &relevant, k);
            prop_assert!(precision >= 0.0);
            prop_assert!(precision <= 1.0);
        }

        #[test]
        fn prop_mrr_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let retrieved = ranked(&retrieved_ids);
            let relevant = judged(&relevant_ids);

            let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
            prop_assert!(mrr >= 0.0);
            prop_assert!(mrr <= 1.0);
        }

        #[test]
        fn prop_ndcg_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved = ranked(&retrieved_ids);
            let relevant = judged(&relevant_ids);

            let ndcg = RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, k);
            prop_assert!(ndcg >= 0.0);
            prop_assert!(ndcg <= 1.0);
        }
    }
}