// do_memory_core/search/metrics.rs
//! Retrieval evaluation metrics for benchmarking.
//!
//! Implements standard metrics from MTEB/BEIR methodology for retrieval-first evaluation
//! without LLM calls. These metrics allow objective comparison of retrieval quality.
//!
//! # Metrics Overview
//!
//! | Metric | Formula | Best For |
//! |--------|---------|----------|
//! | Recall@k | (# relevant in top k) / (total relevant) | When missing items is costly |
//! | Precision@k | (# relevant in top k) / k | When user only views top results |
//! | NDCG@k | DCG@k / IDCG@k | Graded relevance (0-3 scale) |
//! | MRR | 1/N Σ (1 / rank_first_relevant) | Question answering |
//! | MAP | 1/Q Σ (precision at each relevant item) | Ranked retrieval |
//!
//! # Example
//!
//! ```
//! use do_memory_core::search::metrics::{recall_at_k, ndcg_at_k, mrr};
//! use std::collections::{HashMap, HashSet};
//!
//! let retrieved = vec![1, 2, 3, 4, 5];
//! let relevant: HashSet<usize> = [2, 4, 6].into_iter().collect();
//!
//! let recall = recall_at_k(&retrieved, &relevant, 5);
//! assert!((recall - 0.666).abs() < 0.01); // 2 of 3 relevant items found
//!
//! let mut rel_scores = HashMap::new();
//! rel_scores.insert(2, 3.0); // highly relevant
//! rel_scores.insert(4, 1.0); // marginally relevant
//! let ndcg = ndcg_at_k(&retrieved, &rel_scores, 5);
//! assert!(ndcg > 0.0 && ndcg <= 1.0);
//! ```
34
35use std::collections::{HashMap, HashSet};
36
/// Calculate Recall@k: fraction of relevant items retrieved in top k.
///
/// # Arguments
///
/// * `retrieved` - Ordered list of retrieved item IDs (ranked by relevance score)
/// * `relevant` - Set of relevant item IDs (ground truth)
/// * `k` - Number of top results to consider
///
/// # Returns
///
/// Recall@k in range [0.0, 1.0]. Returns 1.0 if there are no relevant items.
#[must_use]
pub fn recall_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    // With no ground-truth items there is nothing to miss: recall is perfect.
    if relevant.is_empty() {
        return 1.0;
    }

    let cutoff = k.min(retrieved.len());
    let hits = retrieved
        .iter()
        .take(cutoff)
        .filter(|id| relevant.contains(id))
        .count();

    hits as f64 / relevant.len() as f64
}
62
/// Calculate Precision@k: fraction of top-k results that are relevant.
///
/// # Arguments
///
/// * `retrieved` - Ordered list of retrieved item IDs
/// * `relevant` - Set of relevant item IDs
/// * `k` - Number of top results to consider
///
/// # Returns
///
/// Precision@k in range [0.0, 1.0]. Returns 0.0 when `retrieved` is empty
/// or `k` is 0 (there are no top-k slots to be precise about).
#[must_use]
pub fn precision_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    let k = k.min(retrieved.len());
    // Guards both an empty result list and k == 0. Previously only the
    // empty-list case was handled, so k == 0 produced 0/0 = NaN.
    if k == 0 {
        return 0.0;
    }

    let relevant_in_top_k = retrieved[..k]
        .iter()
        .filter(|id| relevant.contains(id))
        .count();

    relevant_in_top_k as f64 / k as f64
}
88
/// Calculate NDCG@k (Normalized Discounted Cumulative Gain).
///
/// NDCG accounts for graded relevance (not just binary) and position in ranking.
/// Uses the "exponential" gain formula: gain = 2^rel - 1, discount = log2(rank + 1).
///
/// # Arguments
///
/// * `retrieved` - Ordered list of retrieved item IDs
/// * `relevance_scores` - Map of item ID to relevance score (typically 0-3 scale)
/// * `k` - Number of top results to consider
///
/// # Returns
///
/// NDCG@k in range [0.0, 1.0]. Returns 0.0 when `retrieved` is empty, `k` is 0,
/// or no item has a positive relevance score (IDCG would be zero).
#[must_use]
pub fn ndcg_at_k(retrieved: &[usize], relevance_scores: &HashMap<usize, f64>, k: usize) -> f64 {
    let k = k.min(retrieved.len());
    if k == 0 {
        return 0.0;
    }

    // DCG@k: gains of the retrieved order, discounted by log2(position + 2)
    // (position is 0-based, so the first result divides by log2(2) = 1).
    let dcg: f64 = retrieved[..k]
        .iter()
        .enumerate()
        .map(|(i, id)| {
            let rel = relevance_scores.get(id).copied().unwrap_or(0.0);
            (2.0f64.powf(rel) - 1.0) / (2.0 + i as f64).log2()
        })
        .sum();

    // IDCG@k: the best achievable DCG — all known relevance scores sorted
    // descending. total_cmp gives a total order, so a NaN score cannot
    // panic the sort (partial_cmp().unwrap() would).
    let mut ideal_rels: Vec<f64> = relevance_scores.values().copied().collect();
    ideal_rels.sort_unstable_by(|a, b| b.total_cmp(a));
    ideal_rels.truncate(k);

    let idcg: f64 = ideal_rels
        .iter()
        .enumerate()
        .map(|(i, rel)| (2.0f64.powf(*rel) - 1.0) / (2.0 + i as f64).log2())
        .sum();

    // No positive-relevance items: NDCG is conventionally 0 here.
    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}
136
/// Calculate Mean Reciprocal Rank (MRR).
///
/// MRR measures where the first relevant item appears in the ranking.
///
/// # Arguments
///
/// * `retrieved_lists` - List of ranked result lists (one per query)
/// * `relevant_sets` - List of relevant item sets (one per query)
///
/// # Returns
///
/// MRR in range [0.0, 1.0]. Higher is better.
#[must_use]
pub fn mrr(retrieved_lists: &[Vec<usize>], relevant_sets: &[HashSet<usize>]) -> f64 {
    let query_count = retrieved_lists.len();
    if query_count == 0 || query_count != relevant_sets.len() {
        return 0.0;
    }

    let mut total = 0.0;
    for (retrieved, relevant) in retrieved_lists.iter().zip(relevant_sets) {
        // A query with no ground truth contributes 0 to the mean.
        if relevant.is_empty() {
            continue;
        }
        // Reciprocal rank of the first relevant hit; 0 if none was retrieved.
        if let Some(pos) = retrieved.iter().position(|id| relevant.contains(id)) {
            total += 1.0 / (pos + 1) as f64;
        }
    }

    total / query_count as f64
}
172
/// Calculate Mean Average Precision (MAP).
///
/// MAP considers precision at each relevant item position.
///
/// # Arguments
///
/// * `retrieved_lists` - List of ranked result lists (one per query)
/// * `relevant_sets` - List of relevant item sets (one per query)
///
/// # Returns
///
/// MAP in range [0.0, 1.0]. Higher is better.
#[must_use]
pub fn map(retrieved_lists: &[Vec<usize>], relevant_sets: &[HashSet<usize>]) -> f64 {
    let query_count = retrieved_lists.len();
    if query_count == 0 || query_count != relevant_sets.len() {
        return 0.0;
    }

    let mut total_ap = 0.0;
    for (retrieved, relevant) in retrieved_lists.iter().zip(relevant_sets) {
        // A query with no ground truth contributes 0 to the mean.
        if relevant.is_empty() {
            continue;
        }

        // Average precision: precision sampled at each relevant hit,
        // normalized by the total number of relevant items.
        let mut hits: u32 = 0;
        let mut precision_sum = 0.0;
        for (rank, id) in retrieved.iter().enumerate() {
            if relevant.contains(id) {
                hits += 1;
                #[allow(clippy::cast_precision_loss)]
                let precision_here = f64::from(hits) / (rank + 1) as f64;
                precision_sum += precision_here;
            }
        }

        total_ap += precision_sum / relevant.len() as f64;
    }

    total_ap / query_count as f64
}
217
/// Calculate Hit Rate@k: binary metric checking if any relevant item is in top k.
///
/// # Arguments
///
/// * `retrieved` - Ordered list of retrieved item IDs
/// * `relevant` - Set of relevant item IDs
/// * `k` - Number of top results to consider
///
/// # Returns
///
/// 1.0 if any relevant item is in top k, 0.0 otherwise.
#[must_use]
pub fn hit_rate_at_k(retrieved: &[usize], relevant: &HashSet<usize>, k: usize) -> f64 {
    let cutoff = k.min(retrieved.len());
    // Nothing to inspect, or nothing that could count as a hit.
    if cutoff == 0 || relevant.is_empty() {
        return 0.0;
    }

    match retrieved[..cutoff].iter().find(|id| relevant.contains(id)) {
        Some(_) => 1.0,
        None => 0.0,
    }
}
239
/// Reciprocal Rank Fusion (RRF) for combining multiple ranked lists.
///
/// RRF is robust to score scale differences and doesn't require normalization.
/// Each appearance of an item at 0-based rank `r` contributes `1 / (k + r + 1)`.
///
/// # Arguments
///
/// * `result_lists` - Multiple ranked lists with (item_id, score) tuples;
///   only the order matters, the input scores are ignored
/// * `k` - RRF constant (typically 60)
///
/// # Returns
///
/// Fused ranked list with RRF scores, sorted by score descending.
#[must_use]
pub fn reciprocal_rank_fusion<T: Clone + std::cmp::Ord>(
    result_lists: &[Vec<(T, f32)>],
    k: u32,
) -> Vec<(T, f32)> {
    use std::collections::BTreeMap;

    // BTreeMap only needs Ord; the previous `Eq + Hash` bounds were unused,
    // so they have been dropped (a backward-compatible relaxation).
    let mut rrf_scores: BTreeMap<T, f32> = BTreeMap::new();

    for list in result_lists {
        for (rank, (item, _)) in list.iter().enumerate() {
            let rrf_contribution = 1.0 / (k as f32 + rank as f32 + 1.0);
            *rrf_scores.entry(item.clone()).or_insert(0.0) += rrf_contribution;
        }
    }

    let mut fused: Vec<(T, f32)> = rrf_scores.into_iter().collect();
    // total_cmp is a total order on f32, so the comparator is consistent
    // even in the presence of NaN (partial_cmp would silently treat it
    // as Equal and could yield an arbitrary order).
    fused.sort_by(|a, b| b.1.total_cmp(&a.1));
    fused
}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recall_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let relevant: HashSet<usize> = [2, 5, 7, 11].into_iter().collect();

        // 3 of the 4 relevant items appear in the top 10.
        assert!((recall_at_k(&retrieved, &relevant, 10) - 0.75).abs() < 0.001);

        // Only 1 of the 4 relevant items appears in the top 3.
        assert!((recall_at_k(&retrieved, &relevant, 3) - 0.25).abs() < 0.001);

        // With no relevant items, recall is defined as perfect.
        let no_relevant: HashSet<usize> = HashSet::new();
        assert_eq!(recall_at_k(&retrieved, &no_relevant, 5), 1.0);
    }

    #[test]
    fn test_precision_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5];
        let relevant: HashSet<usize> = [2, 4].into_iter().collect();

        // 2 relevant out of the top 5.
        assert!((precision_at_k(&retrieved, &relevant, 5) - 0.4).abs() < 0.001);

        // 1 relevant out of the top 3.
        assert!((precision_at_k(&retrieved, &relevant, 3) - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_ndcg_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5];
        let mut scores = HashMap::new();
        scores.insert(1, 3.0); // highly relevant
        scores.insert(2, 2.0); // relevant
        scores.insert(3, 0.0); // not relevant

        // Items 1 and 2 lead the ranking, so NDCG must be positive and bounded.
        let ndcg = ndcg_at_k(&retrieved, &scores, 3);
        assert!(ndcg > 0.0 && ndcg <= 1.0);
    }

    #[test]
    fn test_mrr() {
        let retrieved_lists = vec![
            vec![1, 2, 3, 4],    // first relevant item at rank 2
            vec![5, 6, 7, 8],    // first relevant item at rank 3
            vec![9, 10, 11, 12], // no relevant items retrieved
        ];
        let relevant_sets = vec![
            [2, 4].into_iter().collect(),
            [7].into_iter().collect(),
            [13].into_iter().collect(),
        ];

        // MRR = (1/2 + 1/3 + 0) / 3 ≈ 0.277
        let score = mrr(&retrieved_lists, &relevant_sets);
        assert!((score - 0.277).abs() < 0.01);
    }

    #[test]
    fn test_map() {
        let retrieved_lists = vec![vec![1, 2, 3, 4, 5]];
        let relevant_sets: Vec<HashSet<usize>> = vec![[2, 4].into_iter().collect()];

        // Precision at rank 2 = 1/2, at rank 4 = 2/4, so AP = (0.5 + 0.5) / 2 = 0.5.
        let score = map(&retrieved_lists, &relevant_sets);
        assert!((score - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let list_a = vec![("a", 0.9), ("b", 0.8), ("c", 0.7)];
        let list_b = vec![("c", 0.95), ("a", 0.85), ("d", 0.75)];

        let fused = reciprocal_rank_fusion(&[list_a, list_b], 60);

        // "a" and "c" rank highly in both inputs and must survive fusion.
        assert!(!fused.is_empty());
        assert!(fused.iter().any(|(item, _)| *item == "a"));
        assert!(fused.iter().any(|(item, _)| *item == "c"));
    }

    #[test]
    fn test_hit_rate_at_k() {
        let retrieved = vec![1, 2, 3, 4, 5];

        // No relevant items among the top 5.
        let miss: HashSet<usize> = [6, 7].into_iter().collect();
        assert_eq!(hit_rate_at_k(&retrieved, &miss, 5), 0.0);

        // At least one relevant item among the top 5.
        let hit: HashSet<usize> = [3, 7].into_iter().collect();
        assert_eq!(hit_rate_at_k(&retrieved, &hit, 5), 1.0);
    }
}
372}