use std::collections::HashSet;
use std::time::{Duration, Instant};
/// Per-query retrieval metrics for a single evaluated query.
#[derive(Debug, Clone)]
pub struct QueryEvaluation {
    /// Identifier of the evaluated query.
    pub query_id: String,
    /// Fraction of the relevant documents that appear in the top-k results.
    pub recall: f64,
    /// Fraction of the top-k results that are relevant.
    pub precision: f64,
    /// 1 / rank of the first relevant result (0.0 when none was found).
    pub reciprocal_rank: f64,
    /// Normalized discounted cumulative gain over the top-k results.
    pub ndcg: f64,
    /// Whether at least one relevant document appeared in the top-k.
    pub hit: bool,
    /// Wall-clock latency of the retrieval call.
    pub latency: Duration,
    /// Total number of results returned (before truncation to k).
    pub num_results: usize,
}

/// Aggregate metrics over a batch of [`QueryEvaluation`]s.
#[derive(Debug, Clone, Default)]
pub struct EvaluationSummary {
    /// Number of queries aggregated.
    pub num_queries: usize,
    /// Arithmetic mean of per-query recall.
    pub mean_recall: f64,
    /// Arithmetic mean of per-query precision.
    pub mean_precision: f64,
    /// Mean reciprocal rank.
    pub mrr: f64,
    /// Arithmetic mean of per-query NDCG.
    pub mean_ndcg: f64,
    /// Fraction of queries with at least one relevant hit (0.0..=1.0).
    pub hit_rate: f64,
    /// Mean retrieval latency.
    pub mean_latency: Duration,
    /// Median (50th percentile) latency.
    pub p50_latency: Duration,
    /// 95th percentile latency.
    pub p95_latency: Duration,
    /// 99th percentile latency.
    pub p99_latency: Duration,
}

impl EvaluationSummary {
    /// Aggregates per-query evaluations into summary statistics.
    ///
    /// Returns `Self::default()` (all fields zero) for an empty slice, so
    /// callers never hit a division by zero.
    #[must_use]
    pub fn from_evaluations(evals: &[QueryEvaluation]) -> Self {
        if evals.is_empty() {
            return Self::default();
        }

        let n = evals.len() as f64;

        let mean_recall = evals.iter().map(|e| e.recall).sum::<f64>() / n;
        let mean_precision = evals.iter().map(|e| e.precision).sum::<f64>() / n;
        let mrr = evals.iter().map(|e| e.reciprocal_rank).sum::<f64>() / n;
        let mean_ndcg = evals.iter().map(|e| e.ndcg).sum::<f64>() / n;
        let hit_rate = evals.iter().filter(|e| e.hit).count() as f64 / n;

        // Average in integer nanoseconds to avoid float rounding; the mean of
        // realistic latencies fits back into u64 nanoseconds.
        let mean_latency_nanos =
            evals.iter().map(|e| e.latency.as_nanos()).sum::<u128>() / evals.len() as u128;
        let mean_latency = Duration::from_nanos(mean_latency_nanos as u64);

        // `sort_unstable` is faster than `sort` and allocation-free; stability
        // is meaningless for plain `Duration` keys, so the result is identical.
        let mut latencies: Vec<Duration> = evals.iter().map(|e| e.latency).collect();
        latencies.sort_unstable();

        // Nearest-rank style percentile over the sorted latencies. `latencies`
        // is non-empty here, so the clamp keeps the index in bounds.
        let percentile = |q: f64| -> Duration {
            let idx = ((latencies.len() as f64 * q) as usize).min(latencies.len() - 1);
            latencies[idx]
        };

        Self {
            num_queries: evals.len(),
            mean_recall,
            mean_precision,
            mrr,
            mean_ndcg,
            hit_rate,
            mean_latency,
            p50_latency: percentile(0.50),
            p95_latency: percentile(0.95),
            p99_latency: percentile(0.99),
        }
    }

    /// Renders the summary as a human-readable multi-line report.
    #[must_use]
    pub fn report(&self) -> String {
        format!(
            r#"
=== RAG++ Evaluation Summary ===
Queries evaluated: {}

Retrieval Quality:
  Mean Recall@K: {:.4}
  Mean Precision@K: {:.4}
  MRR: {:.4}
  Mean NDCG@K: {:.4}
  Hit Rate: {:.2}%

Latency:
  Mean: {:?}
  P50: {:?}
  P95: {:?}
  P99: {:?}
================================
"#,
            self.num_queries,
            self.mean_recall,
            self.mean_precision,
            self.mrr,
            self.mean_ndcg,
            self.hit_rate * 100.0,
            self.mean_latency,
            self.p50_latency,
            self.p95_latency,
            self.p99_latency,
        )
    }
}
141pub struct Evaluator {
143 k: usize,
145}
146
147impl Evaluator {
148 #[must_use]
150 pub fn new(k: usize) -> Self {
151 Self { k }
152 }
153
154 pub fn evaluate_query(
162 &self,
163 query_id: impl Into<String>,
164 retrieved_ids: &[String],
165 relevant_ids: &HashSet<String>,
166 latency: Duration,
167 ) -> QueryEvaluation {
168 let k = self.k.min(retrieved_ids.len());
169 let top_k: Vec<_> = retrieved_ids.iter().take(k).collect();
170
171 let relevant_found = top_k.iter().filter(|id| relevant_ids.contains(id.as_str())).count();
173 let recall = if relevant_ids.is_empty() {
174 1.0 } else {
176 relevant_found as f64 / relevant_ids.len() as f64
177 };
178
179 let precision = if k == 0 {
181 0.0
182 } else {
183 relevant_found as f64 / k as f64
184 };
185
186 let reciprocal_rank = top_k
188 .iter()
189 .position(|id| relevant_ids.contains(id.as_str()))
190 .map(|pos| 1.0 / (pos + 1) as f64)
191 .unwrap_or(0.0);
192
193 let hit = relevant_found > 0;
195
196 let ndcg = self.compute_ndcg(&top_k, relevant_ids);
198
199 QueryEvaluation {
200 query_id: query_id.into(),
201 recall,
202 precision,
203 reciprocal_rank,
204 ndcg,
205 hit,
206 latency,
207 num_results: retrieved_ids.len(),
208 }
209 }
210
211 fn compute_ndcg(&self, retrieved: &[&String], relevant: &HashSet<String>) -> f64 {
213 if relevant.is_empty() {
214 return 1.0;
215 }
216
217 let dcg: f64 = retrieved
219 .iter()
220 .enumerate()
221 .map(|(i, id)| {
222 let rel = if relevant.contains(id.as_str()) { 1.0 } else { 0.0 };
223 rel / (i as f64 + 2.0).log2()
224 })
225 .sum();
226
227 let ideal_k = self.k.min(relevant.len());
229 let idcg: f64 = (0..ideal_k)
230 .map(|i| 1.0 / (i as f64 + 2.0).log2())
231 .sum();
232
233 if idcg == 0.0 {
234 0.0
235 } else {
236 dcg / idcg
237 }
238 }
239}
/// Simple wall-clock micro-benchmark runner: `warmup_iters` untimed calls
/// followed by `measure_iters` timed calls aggregated into a result.
pub struct Benchmarker {
    /// Untimed iterations run before measurement (warms caches etc.).
    warmup_iters: usize,
    /// Timed iterations aggregated into the [`BenchmarkResult`].
    measure_iters: usize,
}

impl Benchmarker {
    /// Creates a benchmarker with the given warm-up and measurement counts.
    #[must_use]
    pub fn new(warmup_iters: usize, measure_iters: usize) -> Self {
        Self {
            warmup_iters,
            measure_iters,
        }
    }

    /// Runs `f` for `warmup_iters` untimed iterations, then `measure_iters`
    /// timed iterations, and returns aggregated timing statistics.
    ///
    /// When `measure_iters == 0`, an all-zero result is returned instead of
    /// panicking (the previous code divided a `Duration` by zero and indexed
    /// an empty vector).
    pub fn run<F>(&self, mut f: F) -> BenchmarkResult
    where
        F: FnMut(),
    {
        for _ in 0..self.warmup_iters {
            f();
        }

        // Guard against divide-by-zero / empty-index panics below.
        if self.measure_iters == 0 {
            return BenchmarkResult {
                iterations: 0,
                mean: Duration::ZERO,
                p50: Duration::ZERO,
                p95: Duration::ZERO,
                p99: Duration::ZERO,
                min: Duration::ZERO,
                max: Duration::ZERO,
            };
        }

        let mut durations = Vec::with_capacity(self.measure_iters);
        for _ in 0..self.measure_iters {
            let start = Instant::now();
            f();
            durations.push(start.elapsed());
        }

        // `sort_unstable` is faster and allocation-free; stability is
        // meaningless for plain `Duration` keys.
        durations.sort_unstable();
        let total: Duration = durations.iter().sum();
        let mean = total / self.measure_iters as u32;

        // Nearest-rank percentile indices; `durations` is non-empty here, so
        // the direct indexing below cannot panic.
        let p50 = durations[durations.len() / 2];
        let p95 = durations[(durations.len() as f64 * 0.95) as usize];
        let p99 = durations[(durations.len() as f64 * 0.99) as usize];

        BenchmarkResult {
            iterations: self.measure_iters,
            mean,
            p50,
            p95,
            p99,
            min: durations[0],
            max: durations[durations.len() - 1],
        }
    }
}

/// Aggregated timing statistics produced by [`Benchmarker::run`].
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Number of timed iterations.
    pub iterations: usize,
    /// Mean latency over all timed iterations.
    pub mean: Duration,
    /// Median latency.
    pub p50: Duration,
    /// 95th percentile latency.
    pub p95: Duration,
    /// 99th percentile latency.
    pub p99: Duration,
    /// Fastest iteration.
    pub min: Duration,
    /// Slowest iteration.
    pub max: Duration,
}

impl BenchmarkResult {
    /// Operations per second implied by the mean latency.
    ///
    /// Returns `f64::INFINITY` when the mean is zero (e.g. zero iterations
    /// or sub-resolution timings); it never panics.
    #[must_use]
    pub fn throughput(&self) -> f64 {
        1.0 / self.mean.as_secs_f64()
    }

    /// Human-readable report for this benchmark run, labeled with `name`.
    #[must_use]
    pub fn report(&self, name: &str) -> String {
        format!(
            r#"
=== Benchmark: {} ===
Iterations: {}
Mean: {:?}
P50: {:?}
P95: {:?}
P99: {:?}
Min: {:?}
Max: {:?}
Throughput: {:.2} ops/sec
======================
"#,
            name,
            self.iterations,
            self.mean,
            self.p50,
            self.p95,
            self.p99,
            self.min,
            self.max,
            self.throughput(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_perfect_recall() {
        let evaluator = Evaluator::new(10);
        let retrieved: Vec<String> = (0..10).map(|i| format!("doc-{i}")).collect();
        let relevant: HashSet<String> = (0..5).map(|i| format!("doc-{i}")).collect();

        let eval = evaluator.evaluate_query("q1", &retrieved, &relevant, Duration::from_millis(10));

        // All 5 relevant ids are retrieved, and they make up half of the top-10.
        assert_eq!(eval.recall, 1.0);
        assert_eq!(eval.precision, 0.5);
        assert!(eval.hit);
    }

    #[test]
    fn test_no_relevant_items() {
        let evaluator = Evaluator::new(10);
        let retrieved: Vec<String> = (0..10).map(|i| format!("doc-{i}")).collect();
        let relevant = HashSet::new();

        let eval = evaluator.evaluate_query("q1", &retrieved, &relevant, Duration::from_millis(10));

        // With an empty relevant set, recall defaults to 1.0 and nothing can hit.
        assert_eq!(eval.recall, 1.0);
        assert_eq!(eval.precision, 0.0);
        assert!(!eval.hit);
    }

    #[test]
    fn test_mrr() {
        let evaluator = Evaluator::new(10);
        let relevant: HashSet<_> = ["a".to_string()].into();

        // Relevant id at rank 1 -> reciprocal rank 1.
        let hit_first = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let eval = evaluator.evaluate_query("q1", &hit_first, &relevant, Duration::ZERO);
        assert!((eval.reciprocal_rank - 1.0).abs() < 1e-6);

        // Relevant id at rank 3 -> reciprocal rank 1/3.
        let hit_third = vec!["x".to_string(), "y".to_string(), "a".to_string()];
        let eval = evaluator.evaluate_query("q2", &hit_third, &relevant, Duration::ZERO);
        assert!((eval.reciprocal_rank - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_summary() {
        // Helper to cut down on struct-literal noise.
        let make = |id: &str, recall, precision, rr, ndcg, ms| QueryEvaluation {
            query_id: id.into(),
            recall,
            precision,
            reciprocal_rank: rr,
            ndcg,
            hit: true,
            latency: Duration::from_millis(ms),
            num_results: 10,
        };
        let evals = vec![
            make("q1", 1.0, 0.5, 1.0, 0.8, 10),
            make("q2", 0.5, 0.25, 0.5, 0.6, 20),
        ];

        let summary = EvaluationSummary::from_evaluations(&evals);

        assert_eq!(summary.num_queries, 2);
        assert!((summary.mean_recall - 0.75).abs() < 1e-6);
        assert!((summary.mrr - 0.75).abs() < 1e-6);
        assert_eq!(summary.hit_rate, 1.0);
    }

    #[test]
    fn test_benchmarker() {
        let benchmarker = Benchmarker::new(2, 10);
        let mut counter = 0;

        let result = benchmarker.run(|| {
            counter += 1;
            std::thread::sleep(Duration::from_micros(100));
        });

        // 2 warm-up + 10 measured invocations.
        assert_eq!(counter, 12);
        assert!(result.mean >= Duration::from_micros(100));
        assert!(result.throughput() > 0.0);
    }
}
447}