// phago_rag/scoring.rs

//! Information retrieval scoring metrics.
//!
//! Implements the standard IR metrics Precision@k, Recall@k, MRR, and NDCG@k.
//! Used to benchmark bio-rag query performance against baselines.
5
6use serde::Serialize;
7use std::collections::HashSet;
8
9/// Scores for a single query.
10#[derive(Debug, Clone, Serialize)]
11pub struct QueryScores {
12    pub query: String,
13    pub precision_at_5: f64,
14    pub precision_at_10: f64,
15    pub recall_at_5: f64,
16    pub recall_at_10: f64,
17    pub mrr: f64,
18    pub ndcg_at_10: f64,
19}
20
21/// Aggregate scores across multiple queries.
22#[derive(Debug, Clone, Serialize)]
23pub struct AggregateScores {
24    pub mean_precision_at_5: f64,
25    pub mean_precision_at_10: f64,
26    pub mean_recall_at_5: f64,
27    pub mean_recall_at_10: f64,
28    pub mean_mrr: f64,
29    pub mean_ndcg_at_10: f64,
30    pub query_count: usize,
31}
32
/// Compute precision@k: the fraction of the top-k results that are relevant.
///
/// The denominator is the number of results actually inspected,
/// `min(k, retrieved.len())`, so a result list shorter than `k` is not
/// penalised for its missing slots. (Some tools, such as trec_eval, divide by
/// `k` instead.)
///
/// Each relevant item is credited at most once. Repeating the same relevant
/// item in the ranking therefore cannot inflate the score.
///
/// Returns `0.0` when nothing is inspected (`k == 0` or `retrieved` is empty).
pub fn precision_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let inspected = retrieved.len().min(k);
    if inspected == 0 {
        return 0.0;
    }
    // Collect distinct hits so duplicate entries earn credit only once.
    let hits: HashSet<&str> = retrieved[..inspected]
        .iter()
        .map(String::as_str)
        .filter(|item| relevant.contains(*item))
        .collect();
    hits.len() as f64 / inspected as f64
}
42
/// Compute recall@k: the fraction of relevant items found in the top-k results.
///
/// Each relevant item is counted once, even if it appears several times in
/// `retrieved`, so the result never exceeds `1.0`.
///
/// Returns `0.0` when `relevant` is empty. Recall is undefined in that case,
/// and this module scores such queries as zero.
pub fn recall_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    if relevant.is_empty() {
        return 0.0;
    }
    // Distinct relevant items present in the top-k; duplicates collapse here.
    let found: HashSet<&str> = retrieved
        .iter()
        .take(k)
        .map(String::as_str)
        .filter(|item| relevant.contains(*item))
        .collect();
    found.len() as f64 / relevant.len() as f64
}
52
/// Compute the reciprocal rank of the first relevant result.
///
/// Ranks are 1-based: a hit at the top scores `1.0`, a hit at the second
/// position scores `0.5`, and so on. Returns `0.0` when no retrieved item is
/// relevant. Averaging this value over queries (as [`aggregate`] does) gives
/// the mean reciprocal rank.
pub fn mrr(retrieved: &[String], relevant: &HashSet<String>) -> f64 {
    retrieved
        .iter()
        .position(|item| relevant.contains(item.as_str()))
        .map_or(0.0, |idx| 1.0 / (idx as f64 + 1.0))
}
62
/// Compute NDCG@k: Normalized Discounted Cumulative Gain with binary relevance.
///
/// The DCG of the top-k results is divided by the DCG of an ideal ranking that
/// places `min(relevant.len(), k)` relevant items first. A relevant item is
/// credited only at its first occurrence, so duplicates in `retrieved` cannot
/// push the score above `1.0`.
///
/// Returns `0.0` when the ideal DCG is zero, i.e. when `relevant` is empty or
/// `k == 0`.
pub fn ndcg_at_k(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    let ideal = ideal_dcg(relevant.len(), k);
    if ideal == 0.0 {
        return 0.0;
    }
    dcg(retrieved, relevant, k) / ideal
}

/// DCG of the top-k retrieved items with binary gains: the first occurrence of
/// a relevant item at 0-based position `i` contributes `1 / log2(i + 2)`.
fn dcg(retrieved: &[String], relevant: &HashSet<String>, k: usize) -> f64 {
    // Relevant items already credited; later duplicates of them earn no gain.
    let mut credited: HashSet<&str> = HashSet::new();
    retrieved
        .iter()
        .take(k)
        .enumerate()
        .filter(|&(_, item)| relevant.contains(item.as_str()) && credited.insert(item.as_str()))
        .map(|(i, _)| 1.0 / (i as f64 + 2.0).log2())
        .sum()
}

/// DCG of an ideal ranking: `min(num_relevant, k)` relevant items at the top.
fn ideal_dcg(num_relevant: usize, k: usize) -> f64 {
    (0..num_relevant.min(k))
        .map(|i| 1.0 / (i as f64 + 2.0).log2())
        .sum()
}
90
91/// Compute all scores for a single query.
92pub fn score_query(
93    query_text: &str,
94    retrieved: &[String],
95    relevant: &HashSet<String>,
96) -> QueryScores {
97    QueryScores {
98        query: query_text.to_string(),
99        precision_at_5: precision_at_k(retrieved, relevant, 5),
100        precision_at_10: precision_at_k(retrieved, relevant, 10),
101        recall_at_5: recall_at_k(retrieved, relevant, 5),
102        recall_at_10: recall_at_k(retrieved, relevant, 10),
103        mrr: mrr(retrieved, relevant),
104        ndcg_at_10: ndcg_at_k(retrieved, relevant, 10),
105    }
106}
107
108/// Aggregate scores across multiple queries.
109pub fn aggregate(scores: &[QueryScores]) -> AggregateScores {
110    let n = scores.len();
111    if n == 0 {
112        return AggregateScores {
113            mean_precision_at_5: 0.0,
114            mean_precision_at_10: 0.0,
115            mean_recall_at_5: 0.0,
116            mean_recall_at_10: 0.0,
117            mean_mrr: 0.0,
118            mean_ndcg_at_10: 0.0,
119            query_count: 0,
120        };
121    }
122
123    let nf = n as f64;
124    AggregateScores {
125        mean_precision_at_5: scores.iter().map(|s| s.precision_at_5).sum::<f64>() / nf,
126        mean_precision_at_10: scores.iter().map(|s| s.precision_at_10).sum::<f64>() / nf,
127        mean_recall_at_5: scores.iter().map(|s| s.recall_at_5).sum::<f64>() / nf,
128        mean_recall_at_10: scores.iter().map(|s| s.recall_at_10).sum::<f64>() / nf,
129        mean_mrr: scores.iter().map(|s| s.mrr).sum::<f64>() / nf,
130        mean_ndcg_at_10: scores.iter().map(|s| s.ndcg_at_10).sum::<f64>() / nf,
131        query_count: n,
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned relevance set from string literals.
    fn set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Build an owned ranked result list from string literals.
    fn ranked(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Floating-point comparison with a fixed tolerance.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn perfect_precision() {
        let relevant = set(&["a", "b", "c"]);
        let retrieved = ranked(&["a", "b", "c", "d", "e"]);
        assert!(approx(precision_at_k(&retrieved, &relevant, 3), 1.0));
    }

    #[test]
    fn zero_precision_when_no_relevant() {
        let relevant = set(&["x", "y"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert!(approx(precision_at_k(&retrieved, &relevant, 3), 0.0));
    }

    #[test]
    fn partial_precision_and_empty_results() {
        let relevant = set(&["a", "b", "c"]);
        let retrieved = ranked(&["a", "b", "c", "d", "e"]);
        assert!(approx(precision_at_k(&retrieved, &relevant, 5), 0.6));
        assert!(approx(precision_at_k(&[], &relevant, 5), 0.0));
    }

    #[test]
    fn partial_recall() {
        let relevant = set(&["a", "b", "c", "d"]);
        let retrieved = ranked(&["a", "x", "b", "y", "z"]);
        assert!(approx(recall_at_k(&retrieved, &relevant, 5), 0.5));
        assert!(approx(recall_at_k(&retrieved, &relevant, 1), 0.25));
    }

    #[test]
    fn recall_zero_when_relevant_set_empty() {
        let retrieved = ranked(&["a", "b"]);
        assert!(approx(recall_at_k(&retrieved, &HashSet::new(), 5), 0.0));
    }

    #[test]
    fn mrr_first_position() {
        let relevant = set(&["a"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert!(approx(mrr(&retrieved, &relevant), 1.0));
    }

    #[test]
    fn mrr_second_position() {
        let relevant = set(&["b"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert!(approx(mrr(&retrieved, &relevant), 0.5));
    }

    #[test]
    fn mrr_zero_when_nothing_relevant() {
        let relevant = set(&["z"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert!(approx(mrr(&retrieved, &relevant), 0.0));
    }

    #[test]
    fn ndcg_perfect() {
        let relevant = set(&["a", "b"]);
        let retrieved = ranked(&["a", "b", "c"]);
        assert!(approx(ndcg_at_k(&retrieved, &relevant, 3), 1.0));
    }

    #[test]
    fn ndcg_penalises_late_hit() {
        let relevant = set(&["a"]);
        let retrieved = ranked(&["x", "a"]);
        assert!(approx(ndcg_at_k(&retrieved, &relevant, 2), 1.0 / 3f64.log2()));
    }

    #[test]
    fn ndcg_zero_when_relevant_set_empty() {
        let retrieved = ranked(&["a", "b"]);
        assert!(approx(ndcg_at_k(&retrieved, &HashSet::new(), 2), 0.0));
    }

    #[test]
    fn aggregate_empty_is_all_zero() {
        let summary = aggregate(&[]);
        assert_eq!(summary.query_count, 0);
        assert!(approx(summary.mean_mrr, 0.0));
        assert!(approx(summary.mean_ndcg_at_10, 0.0));
    }

    #[test]
    fn aggregate_averages_per_query_scores() {
        let relevant = set(&["a"]);
        let scores = vec![
            score_query("q1", &ranked(&["a"]), &relevant),
            score_query("q2", &ranked(&["x", "a"]), &relevant),
        ];
        let summary = aggregate(&scores);
        assert_eq!(summary.query_count, 2);
        assert!(approx(summary.mean_mrr, 0.75));
        assert!(approx(summary.mean_recall_at_5, 1.0));
    }
}