1use std::collections::HashSet;
11
/// A single evaluation query together with its ground-truth relevance judgments.
#[derive(Debug, Clone)]
pub struct EvalQuery {
    /// Query string passed to the retrieval function under evaluation.
    pub query: String,
    /// File paths judged relevant for this query (binary relevance).
    pub relevant_files: Vec<String>,
    /// Optionally, the single file expected at rank 1.
    /// NOTE(review): not read by `evaluate` in this file — presumably consumed
    /// by a caller elsewhere; confirm before removing.
    pub expected_top: Option<String>,
}
19
/// Aggregate retrieval-quality metrics, averaged over all evaluated queries.
#[derive(Debug, Clone)]
pub struct EvalReport {
    /// Number of queries that were evaluated.
    pub query_count: usize,
    /// Mean precision over the top 5 retrieved items.
    pub precision_at_5: f64,
    /// Mean precision over the top 10 retrieved items.
    pub precision_at_10: f64,
    /// Mean recall within the top 10 retrieved items.
    pub recall_at_10: f64,
    /// Mean reciprocal rank of the first relevant result per query.
    pub mrr: f64,
    /// Mean normalized discounted cumulative gain over the top 10.
    pub ndcg_at_10: f64,
    /// Per-query score breakdown, in the same order as the input queries.
    pub per_query: Vec<QueryScore>,
}
31
/// Metric breakdown for one evaluated query.
#[derive(Debug, Clone)]
pub struct QueryScore {
    /// The query string this score belongs to.
    pub query: String,
    /// Precision at k for this query (populated with k = 10 by `evaluate`).
    pub precision_at_k: f64,
    /// Recall within the top 10 results (as populated by `evaluate`).
    pub recall: f64,
    /// 1 / rank of the first relevant result; 0.0 if none was retrieved.
    pub reciprocal_rank: f64,
    /// Normalized DCG for this query (computed at k = 10 by `evaluate`).
    pub ndcg: f64,
}
40
/// One ranked result returned by the retrieval function.
#[derive(Debug, Clone)]
pub struct RetrievedItem {
    /// Path of the retrieved file; matched against `EvalQuery::relevant_files`.
    pub file_path: String,
    /// Retrieval score. Only the ordering of the result list is used by the
    /// metrics in this module; the score value itself is never read here.
    pub score: f64,
}
47
48pub fn evaluate(queries: &[EvalQuery], retrieve_fn: &dyn Fn(&str) -> Vec<RetrievedItem>) -> EvalReport {
50 let mut per_query = Vec::with_capacity(queries.len());
51 let mut sum_p5 = 0.0;
52 let mut sum_p10 = 0.0;
53 let mut sum_r10 = 0.0;
54 let mut sum_mrr = 0.0;
55 let mut sum_ndcg = 0.0;
56
57 for q in queries {
58 let results = retrieve_fn(&q.query);
59 let relevant: HashSet<&str> = q.relevant_files.iter().map(|s| s.as_str()).collect();
60
61 let p5 = precision_at_k(&results, &relevant, 5);
62 let p10 = precision_at_k(&results, &relevant, 10);
63 let r10 = recall_at_k(&results, &relevant, 10);
64 let rr = reciprocal_rank(&results, &relevant);
65 let ndcg = ndcg_at_k(&results, &relevant, 10);
66
67 sum_p5 += p5;
68 sum_p10 += p10;
69 sum_r10 += r10;
70 sum_mrr += rr;
71 sum_ndcg += ndcg;
72
73 per_query.push(QueryScore {
74 query: q.query.clone(),
75 precision_at_k: p10,
76 recall: r10,
77 reciprocal_rank: rr,
78 ndcg,
79 });
80 }
81
82 let n = queries.len().max(1) as f64;
83 EvalReport {
84 query_count: queries.len(),
85 precision_at_5: sum_p5 / n,
86 precision_at_10: sum_p10 / n,
87 recall_at_10: sum_r10 / n,
88 mrr: sum_mrr / n,
89 ndcg_at_10: sum_ndcg / n,
90 per_query,
91 }
92}
93
94fn precision_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
95 let top_k: Vec<&RetrievedItem> = results.iter().take(k).collect();
96 if top_k.is_empty() {
97 return 0.0;
98 }
99 let hits = top_k
100 .iter()
101 .filter(|r| relevant.contains(r.file_path.as_str()))
102 .count();
103 hits as f64 / top_k.len() as f64
104}
105
106fn recall_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
107 if relevant.is_empty() {
108 return 0.0;
109 }
110 let hits = results
111 .iter()
112 .take(k)
113 .filter(|r| relevant.contains(r.file_path.as_str()))
114 .count();
115 hits as f64 / relevant.len() as f64
116}
117
118fn reciprocal_rank(results: &[RetrievedItem], relevant: &HashSet<&str>) -> f64 {
119 for (i, r) in results.iter().enumerate() {
120 if relevant.contains(r.file_path.as_str()) {
121 return 1.0 / (i + 1) as f64;
122 }
123 }
124 0.0
125}
126
127fn ndcg_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
128 let dcg = results
129 .iter()
130 .take(k)
131 .enumerate()
132 .map(|(i, r)| {
133 let gain = if relevant.contains(r.file_path.as_str()) { 1.0 } else { 0.0 };
134 gain / (2.0f64 + i as f64).log2()
135 })
136 .sum::<f64>();
137
138 let ideal_count = relevant.len().min(k);
139 let ideal_dcg: f64 = (0..ideal_count)
140 .map(|i| 1.0 / (2.0f64 + i as f64).log2())
141 .sum();
142
143 if ideal_dcg == 0.0 {
144 return 0.0;
145 }
146 dcg / ideal_dcg
147}
148
149impl EvalReport {
150 pub fn to_compact_string(&self) -> String {
151 format!(
152 "P@5={:.3} P@10={:.3} R@10={:.3} MRR={:.3} nDCG@10={:.3} (n={})",
153 self.precision_at_5,
154 self.precision_at_10,
155 self.recall_at_10,
156 self.mrr,
157 self.ndcg_at_10,
158 self.query_count,
159 )
160 }
161
162 pub fn passed_threshold(&self, min_mrr: f64, min_ndcg: f64) -> bool {
163 self.mrr >= min_mrr && self.ndcg_at_10 >= min_ndcg
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a ranked result list from file names; scores descend (10, 9, …)
    // so list order matches score order.
    fn items(files: &[&str]) -> Vec<RetrievedItem> {
        files
            .iter()
            .enumerate()
            .map(|(i, f)| RetrievedItem {
                file_path: f.to_string(),
                score: 10.0 - i as f64,
            })
            .collect()
    }

    // Both of the top-2 results are relevant -> P@2 = 1.0.
    #[test]
    fn precision_at_k_perfect() {
        let relevant: HashSet<&str> = ["a.rs", "b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 1.0).abs() < 1e-6);
    }

    // One of the top-2 results is relevant -> P@2 = 0.5.
    #[test]
    fn precision_at_k_half() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 0.5).abs() < 1e-6);
    }

    // No retrieved item is relevant -> P@2 = 0.0.
    #[test]
    fn precision_at_k_none() {
        let relevant: HashSet<&str> = ["x.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 0.0).abs() < 1e-6);
    }

    // The single relevant file appears within the top 3 -> recall = 1.0.
    #[test]
    fn recall_at_k_full() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((recall_at_k(&results, &relevant, 3) - 1.0).abs() < 1e-6);
    }

    // Only 1 of 2 relevant files retrieved -> recall = 0.5.
    #[test]
    fn recall_at_k_partial() {
        let relevant: HashSet<&str> = ["a.rs", "d.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((recall_at_k(&results, &relevant, 3) - 0.5).abs() < 1e-6);
    }

    // First relevant hit at rank 1 -> RR = 1.0.
    #[test]
    fn mrr_first_position() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 1.0).abs() < 1e-6);
    }

    // First relevant hit at rank 2 -> RR = 0.5.
    #[test]
    fn mrr_second_position() {
        let relevant: HashSet<&str> = ["b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 0.5).abs() < 1e-6);
    }

    // No relevant item retrieved -> RR = 0.0.
    #[test]
    fn mrr_not_found() {
        let relevant: HashSet<&str> = ["x.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 0.0).abs() < 1e-6);
    }

    // All relevant docs in the best possible order -> nDCG = 1.0.
    #[test]
    fn ndcg_perfect() {
        let relevant: HashSet<&str> = ["a.rs", "b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        let score = ndcg_at_k(&results, &relevant, 3);
        assert!((score - 1.0).abs() < 1e-6, "perfect ranking should give nDCG=1.0, got {score}");
    }

    // Relevant doc at rank 2 instead of rank 1 -> 0 < nDCG < 1.
    #[test]
    fn ndcg_imperfect() {
        let relevant: HashSet<&str> = ["b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        let score = ndcg_at_k(&results, &relevant, 3);
        assert!(score > 0.0 && score < 1.0, "imperfect ranking: {score}");
    }

    // End-to-end: a stub retriever that ranks every relevant file first
    // should yield strong aggregate MRR and nDCG.
    #[test]
    fn evaluate_pipeline() {
        let queries = vec![
            EvalQuery {
                query: "authentication".to_string(),
                relevant_files: vec!["auth.rs".to_string()],
                expected_top: Some("auth.rs".to_string()),
            },
            EvalQuery {
                query: "database connection".to_string(),
                relevant_files: vec!["db.rs".to_string(), "pool.rs".to_string()],
                expected_top: Some("db.rs".to_string()),
            },
        ];

        let report = evaluate(&queries, &|q| {
            if q.contains("auth") {
                items(&["auth.rs", "user.rs", "session.rs"])
            } else {
                items(&["db.rs", "pool.rs", "config.rs"])
            }
        });

        assert_eq!(report.query_count, 2);
        assert!(report.mrr > 0.5, "MRR should be high: {}", report.mrr);
        assert!(report.ndcg_at_10 > 0.5, "nDCG should be high: {}", report.ndcg_at_10);
    }

    // A perfect single-query run passes a 0.5/0.5 gate but cannot pass an
    // impossible (>1.0) threshold.
    #[test]
    fn report_threshold() {
        let queries = vec![EvalQuery {
            query: "test".to_string(),
            relevant_files: vec!["test.rs".to_string()],
            expected_top: None,
        }];
        let report = evaluate(&queries, &|_| items(&["test.rs", "other.rs"]));
        assert!(report.passed_threshold(0.5, 0.5));
        assert!(!report.passed_threshold(2.0, 2.0));
    }
}