reasonkit/evaluation/metrics.rs

//! Core retrieval evaluation metrics.
//!
//! Implements standard IR metrics:
//! - Recall@K
//! - Precision@K
//! - NDCG@K
//! - MRR
//! - MAP
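//!
//! A minimal end-to-end sketch. The `reasonkit::evaluation::metrics` import
//! path is assumed from this file's location; adjust it to the actual crate
//! layout (the example is marked `ignore` for that reason):
//!
//! ```ignore
//! use std::collections::HashSet;
//!
//! use reasonkit::evaluation::metrics::{mean_reciprocal_rank, recall_at_k, QueryResult};
//!
//! let retrieved = vec!["doc1".to_string(), "doc2".to_string()];
//! let relevant: HashSet<String> = ["doc1".to_string()].into_iter().collect();
//!
//! // Per-query metric.
//! assert_eq!(recall_at_k(&retrieved, &relevant, 2), 1.0);
//!
//! // Aggregate over queries.
//! let results = vec![QueryResult::new("q1", retrieved, relevant)];
//! assert_eq!(mean_reciprocal_rank(&results), 1.0);
//! ```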
use std::collections::{HashMap, HashSet};

/// Result for a single query
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Query ID
    pub query_id: String,
    /// Retrieved document IDs in ranked order
    pub retrieved_ids: Vec<String>,
    /// Ground truth: set of relevant document IDs
    pub relevant_ids: HashSet<String>,
    /// Optional: graded relevance scores (for NDCG)
    pub relevance_grades: Option<HashMap<String, f64>>,
}

impl QueryResult {
    pub fn new(
        query_id: impl Into<String>,
        retrieved: Vec<String>,
        relevant: HashSet<String>,
    ) -> Self {
        Self {
            query_id: query_id.into(),
            retrieved_ids: retrieved,
            relevant_ids: relevant,
            relevance_grades: None,
        }
    }

    pub fn with_grades(mut self, grades: HashMap<String, f64>) -> Self {
        self.relevance_grades = Some(grades);
        self
    }
}

/// Evaluation result for a single query
#[derive(Debug, Clone, Default)]
pub struct EvaluationResult {
    pub query_id: String,
    pub recall: f64,
    pub precision: f64,
    pub ndcg: f64,
    pub mrr: f64,
    pub ap: f64,
    pub k: usize,
}

/// Retrieval metrics for a single query.
///
/// For a single query, `mrr` holds the reciprocal rank and `map` holds the
/// average precision; averaging these values across queries yields MRR and MAP.
#[derive(Debug, Clone)]
pub struct RetrievalMetrics {
    pub recall_at_k: HashMap<usize, f64>,
    pub precision_at_k: HashMap<usize, f64>,
    pub ndcg_at_k: HashMap<usize, f64>,
    pub mrr: f64,
    pub map: f64,
}

impl RetrievalMetrics {
    /// Compute all metrics for given K values
    pub fn compute_all(
        retrieved: &[String],
        relevant: &HashSet<String>,
        k_values: &[usize],
    ) -> Self {
        let mut recall_at_k = HashMap::new();
        let mut precision_at_k = HashMap::new();
        let mut ndcg_at_k = HashMap::new();

        for &k in k_values {
            recall_at_k.insert(k, recall_at_k_impl(retrieved, relevant, k));
            precision_at_k.insert(k, precision_at_k_impl(retrieved, relevant, k));
            ndcg_at_k.insert(k, ndcg_at_k_binary(retrieved, relevant, k));
        }

        let mrr = mean_reciprocal_rank_single(retrieved, relevant);
        let map = average_precision_impl(retrieved, relevant);

        Self {
            recall_at_k,
            precision_at_k,
            ndcg_at_k,
            mrr,
            map,
        }
    }

    /// Compute metrics for a single K value
    pub fn compute(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> EvaluationResult {
        EvaluationResult {
            query_id: String::new(),
            recall: recall_at_k_impl(retrieved, relevant, k),
            precision: precision_at_k_impl(retrieved, relevant, k),
            ndcg: ndcg_at_k_binary(retrieved, relevant, k),
            mrr: mean_reciprocal_rank_single(retrieved, relevant),
            ap: average_precision_impl(retrieved, relevant),
            k,
        }
    }
}

/// Recall@K: Proportion of relevant documents retrieved in top-K
///
/// Recall@K = |Relevant ∩ Retrieved@K| / |Relevant|
///
/// # Arguments
/// * `retrieved` - Document IDs in ranked order
/// * `relevant` - Set of relevant document IDs
/// * `k` - Number of top results to consider
///
/// # Returns
/// Recall value between 0.0 and 1.0 (0.0 when `relevant` is empty)
pub fn recall_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
    recall_at_k_impl(&retrieved_str, relevant, k)
}

fn recall_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    let top_k: HashSet<_> = retrieved.iter().take(k).cloned().collect();
    let hits = relevant.intersection(&top_k).count();

    hits as f64 / relevant.len() as f64
}

/// Precision@K: Proportion of top-K documents that are relevant
///
/// Precision@K = |Relevant ∩ Retrieved@K| / K
///
/// If fewer than K documents were retrieved, the denominator is the number of
/// retrieved documents rather than K.
///
/// # Arguments
/// * `retrieved` - Document IDs in ranked order
/// * `relevant` - Set of relevant document IDs
/// * `k` - Number of top results to consider
///
/// # Returns
/// Precision value between 0.0 and 1.0
pub fn precision_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
    precision_at_k_impl(&retrieved_str, relevant, k)
}

fn precision_at_k_impl(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if k == 0 {
        return 0.0;
    }

    let actual_k = k.min(retrieved.len());
    if actual_k == 0 {
        return 0.0;
    }

    let hits = retrieved
        .iter()
        .take(actual_k)
        .filter(|doc| relevant.contains(*doc))
        .count();

    hits as f64 / actual_k as f64
}

/// NDCG@K: Normalized Discounted Cumulative Gain
///
/// DCG@K = Σ(i=1 to K) rel_i / log2(i+1)
/// NDCG@K = DCG@K / IDCG@K
///
/// # Arguments
/// * `retrieved` - Document IDs in ranked order
/// * `relevant` - Set of relevant document IDs (binary relevance)
/// * `k` - Number of top results to consider
///
/// # Returns
/// NDCG value between 0.0 and 1.0
pub fn ndcg_at_k(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>, k: usize) -> f64 {
    let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
    ndcg_at_k_binary(&retrieved_str, relevant, k)
}

fn ndcg_at_k_binary(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    // DCG: sum of relevance / log2(rank + 1)
    let dcg: f64 = retrieved
        .iter()
        .take(k)
        .enumerate()
        .filter(|(_, doc)| relevant.contains(*doc))
        .map(|(i, _)| 1.0 / (i as f64 + 2.0).log2()) // log2(i + 2) = log2(rank + 1) for 1-indexed rank
        .sum();

    // IDCG: ideal DCG (all relevant docs ranked at the top)
    let num_relevant_in_k = k.min(relevant.len());
    let idcg: f64 = (0..num_relevant_in_k)
        .map(|i| 1.0 / (i as f64 + 2.0).log2())
        .sum();

    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}

/// NDCG@K with graded relevance
///
/// Uses exponential gain:
/// DCG@K = Σ(i=1 to K) (2^rel_i - 1) / log2(i+1)
/// NDCG@K = DCG@K / IDCG@K
///
/// # Arguments
/// * `retrieved` - Document IDs in ranked order
/// * `relevance_grades` - Map of document ID to relevance grade (0.0 to 1.0 or higher)
/// * `k` - Number of top results to consider
pub fn ndcg_at_k_graded(
    retrieved: &[String],
    relevance_grades: &HashMap<String, f64>,
    k: usize,
) -> f64 {
    if relevance_grades.is_empty() {
        return 0.0;
    }

    // DCG with graded relevance
    let dcg: f64 = retrieved
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, doc)| {
            let rel = relevance_grades.get(doc).copied().unwrap_or(0.0);
            (2_f64.powf(rel) - 1.0) / (i as f64 + 2.0).log2()
        })
        .sum();

    // IDCG: sort grades in descending order (ideal ranking)
    let mut sorted_grades: Vec<f64> = relevance_grades.values().copied().collect();
    sorted_grades.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    let idcg: f64 = sorted_grades
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, &rel)| (2_f64.powf(rel) - 1.0) / (i as f64 + 2.0).log2())
        .sum();

    if idcg == 0.0 {
        return 0.0;
    }

    dcg / idcg
}

/// MRR: Mean Reciprocal Rank across queries
///
/// MRR = (1/|Q|) × Σ(q=1 to |Q|) 1 / rank_q
///
/// where rank_q is the position of the first relevant document for query q
/// (a query with no relevant document retrieved contributes 0.0).
///
/// # Arguments
/// * `results` - Per-query retrieval results with ground truth
///
/// # Returns
/// Mean reciprocal rank between 0.0 and 1.0
pub fn mean_reciprocal_rank(results: &[QueryResult]) -> f64 {
    if results.is_empty() {
        return 0.0;
    }

    let sum: f64 = results
        .iter()
        .map(|r| mean_reciprocal_rank_single(&r.retrieved_ids, &r.relevant_ids))
        .sum();

    sum / results.len() as f64
}

fn mean_reciprocal_rank_single(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    for (i, doc) in retrieved.iter().enumerate() {
        if relevant.contains(doc) {
            return 1.0 / (i as f64 + 1.0);
        }
    }
    0.0
}

/// Average Precision for a single query
///
/// AP = (1/|Relevant|) × Σ(k=1 to n) Precision@k × rel(k)
///
/// # Arguments
/// * `retrieved` - Document IDs in ranked order
/// * `relevant` - Set of relevant document IDs
pub fn average_precision(retrieved: &[impl AsRef<str>], relevant: &HashSet<String>) -> f64 {
    let retrieved_str: Vec<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
    average_precision_impl(&retrieved_str, relevant)
}

fn average_precision_impl(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }

    let mut num_relevant_seen = 0;
    let mut sum_precision = 0.0;

    for (i, doc) in retrieved.iter().enumerate() {
        if relevant.contains(doc) {
            num_relevant_seen += 1;
            // Precision at this position (rank i + 1)
            let precision = num_relevant_seen as f64 / (i as f64 + 1.0);
            sum_precision += precision;
        }
    }

    sum_precision / relevant.len() as f64
}

/// Mean Average Precision across multiple queries
///
/// MAP = (1/|Q|) × Σ(q=1 to |Q|) AP(q)
pub fn mean_average_precision(results: &[QueryResult]) -> f64 {
    if results.is_empty() {
        return 0.0;
    }

    let sum: f64 = results
        .iter()
        .map(|r| average_precision_impl(&r.retrieved_ids, &r.relevant_ids))
        .sum();

    sum / results.len() as f64
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_relevant(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    fn make_retrieved(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_recall_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "d", "e"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 1.0);
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // At k=1, only "a" is retrieved (1/3 relevant)
        assert!((recall_at_k_impl(&retrieved, &relevant, 1) - 1.0 / 3.0).abs() < 0.001);

        // At k=3, "a" and "b" are retrieved (2/3 relevant)
        assert!((recall_at_k_impl(&retrieved, &relevant, 3) - 2.0 / 3.0).abs() < 0.001);

        // At k=5, all are retrieved
        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 5), 1.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = HashSet::new();

        assert_eq!(recall_at_k_impl(&retrieved, &relevant, 3), 0.0);
    }
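
    // Added coverage sketch: the public `recall_at_k` wrapper accepts any slice
    // of `AsRef<str>` items (e.g. `&[&str]`), so callers do not need to build
    // `Vec<String>` up front. Nothing here is assumed beyond the signatures above.
    #[test]
    fn test_recall_at_k_str_slices() {
        let relevant = make_relevant(&["a", "b", "c"]);

        // &[&str] works directly through the AsRef<str> bound.
        let recall = recall_at_k(&["a", "x", "b"], &relevant, 3);
        assert!((recall - 2.0 / 3.0).abs() < 0.001);
    }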

    #[test]
    fn test_precision_at_k_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c", "d", "e"]);

        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 3), 1.0);
    }

    #[test]
    fn test_precision_at_k_partial() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // At k=1: 1 relevant out of 1
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 1), 1.0);

        // At k=2: 1 relevant out of 2
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 2), 0.5);

        // At k=5: 3 relevant out of 5
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 5), 0.6);
    }
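
    // Added coverage sketch: exercises the documented edge cases of Precision@K
    // in this implementation — k = 0 returns 0.0, and when fewer than K documents
    // are retrieved the denominator is the number actually retrieved.
    #[test]
    fn test_precision_at_k_edge_cases() {
        let retrieved = make_retrieved(&["a", "x", "b"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // k = 0 is defined as 0.0.
        assert_eq!(precision_at_k_impl(&retrieved, &relevant, 0), 0.0);

        // k larger than the result list: 2 relevant out of the 3 retrieved.
        assert!((precision_at_k_impl(&retrieved, &relevant, 10) - 2.0 / 3.0).abs() < 0.001);
    }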

    #[test]
    fn test_mrr_first_position() {
        let retrieved = make_retrieved(&["a", "b", "c"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let retrieved = make_retrieved(&["x", "a", "c"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let retrieved = make_retrieved(&["x", "y", "a"]);
        let relevant = make_relevant(&["a"]);

        assert!((mean_reciprocal_rank_single(&retrieved, &relevant) - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_mrr_not_found() {
        let retrieved = make_retrieved(&["x", "y", "z"]);
        let relevant = make_relevant(&["a"]);

        assert_eq!(mean_reciprocal_rank_single(&retrieved, &relevant), 0.0);
    }
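
    // Added coverage sketch: `mean_reciprocal_rank` over several `QueryResult`s
    // averages the per-query reciprocal ranks and returns 0.0 for an empty slice.
    #[test]
    fn test_mrr_multiple_queries() {
        let q1 = QueryResult::new(
            "q1",
            make_retrieved(&["a", "x"]),
            make_relevant(&["a"]),
        ); // RR = 1.0
        let q2 = QueryResult::new(
            "q2",
            make_retrieved(&["x", "a"]),
            make_relevant(&["a"]),
        ); // RR = 0.5

        let mrr = mean_reciprocal_rank(&[q1, q2]);
        assert!((mrr - 0.75).abs() < 0.001);

        assert_eq!(mean_reciprocal_rank(&[]), 0.0);
    }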

    #[test]
    fn test_ndcg_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // Perfect ranking: all relevant docs at top
        assert!((ndcg_at_k_binary(&retrieved, &relevant, 5) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_ndcg_partial() {
        let retrieved = make_retrieved(&["x", "a", "y", "b", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // Not perfect: relevant docs are at positions 2, 4, 5
        let ndcg = ndcg_at_k_binary(&retrieved, &relevant, 5);
        assert!(ndcg > 0.0 && ndcg < 1.0);
    }
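
    // Added coverage sketch: `ndcg_at_k_graded` with a small grade map — ranking
    // documents in descending grade order gives NDCG = 1.0, and reversing that
    // order scores strictly lower.
    #[test]
    fn test_ndcg_graded() {
        let mut grades = HashMap::new();
        grades.insert("a".to_string(), 3.0);
        grades.insert("b".to_string(), 2.0);
        grades.insert("c".to_string(), 1.0);

        // Ideal ordering: highest grade first.
        let ideal = make_retrieved(&["a", "b", "c"]);
        assert!((ndcg_at_k_graded(&ideal, &grades, 3) - 1.0).abs() < 0.001);

        // Worst ordering of the same documents falls below 1.0.
        let reversed = make_retrieved(&["c", "b", "a"]);
        let ndcg = ndcg_at_k_graded(&reversed, &grades, 3);
        assert!(ndcg > 0.0 && ndcg < 1.0);
    }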

    #[test]
    fn test_average_precision() {
        let retrieved = make_retrieved(&["a", "x", "b", "y", "c"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // AP = (1/3) * (1/1 + 2/3 + 3/5) = (1/3) * (1 + 0.667 + 0.6) ≈ 0.756
        let ap = average_precision_impl(&retrieved, &relevant);
        assert!(ap > 0.7 && ap < 0.8);
    }

    #[test]
    fn test_average_precision_perfect() {
        let retrieved = make_retrieved(&["a", "b", "c", "x", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        // AP = (1/3) * (1/1 + 2/2 + 3/3) = (1/3) * 3 = 1.0
        let ap = average_precision_impl(&retrieved, &relevant);
        assert_eq!(ap, 1.0);
    }
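
    // Added coverage sketch: `mean_average_precision` across queries averages
    // the per-query APs (1.0 and 0.0 here) and returns 0.0 for an empty slice.
    #[test]
    fn test_mean_average_precision_multiple_queries() {
        let q1 = QueryResult::new(
            "q1",
            make_retrieved(&["a", "b"]),
            make_relevant(&["a", "b"]),
        ); // AP = 1.0
        let q2 = QueryResult::new(
            "q2",
            make_retrieved(&["x", "y"]),
            make_relevant(&["a"]),
        ); // AP = 0.0

        let map = mean_average_precision(&[q1, q2]);
        assert!((map - 0.5).abs() < 0.001);

        assert_eq!(mean_average_precision(&[]), 0.0);
    }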

    #[test]
    fn test_retrieval_metrics_compute() {
        let retrieved = make_retrieved(&["a", "b", "x", "c", "y"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        let metrics = RetrievalMetrics::compute_all(&retrieved, &relevant, &[5, 10]);

        assert!(metrics.recall_at_k.contains_key(&5));
        assert!(metrics.precision_at_k.contains_key(&5));
        assert!(metrics.ndcg_at_k.contains_key(&5));
        assert!(metrics.mrr > 0.0);
        assert!(metrics.map > 0.0);
    }
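
    // Added coverage sketch: the single-K entry point returns an
    // `EvaluationResult` snapshot, and `with_grades` attaches graded relevance
    // to a `QueryResult` via the builder.
    #[test]
    fn test_retrieval_metrics_compute_single_k() {
        let retrieved = make_retrieved(&["a", "b", "x"]);
        let relevant = make_relevant(&["a", "b", "c"]);

        let result = RetrievalMetrics::compute(&retrieved, &relevant, 3);

        assert_eq!(result.k, 3);
        assert!((result.recall - 2.0 / 3.0).abs() < 0.001);
        assert!((result.precision - 2.0 / 3.0).abs() < 0.001);
        assert_eq!(result.mrr, 1.0);
    }

    #[test]
    fn test_query_result_with_grades() {
        let mut grades = HashMap::new();
        grades.insert("a".to_string(), 2.0);

        let query = QueryResult::new(
            "q1",
            make_retrieved(&["a", "x"]),
            make_relevant(&["a"]),
        )
        .with_grades(grades);

        assert_eq!(query.query_id, "q1");
        assert!(query.relevance_grades.is_some());
    }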
}