rag_plusplus_core/eval/
mod.rs

1//! Evaluation Module for RAG++
2//!
3//! Provides metrics and utilities for evaluating retrieval quality.
4//!
5//! # Metrics
6//!
7//! - **Recall@K**: Fraction of relevant items retrieved in top-K
8//! - **Precision@K**: Fraction of retrieved items that are relevant
9//! - **MRR (Mean Reciprocal Rank)**: Average of reciprocal ranks of first relevant item
10//! - **NDCG@K**: Normalized Discounted Cumulative Gain
11//! - **Hit Rate@K**: Whether at least one relevant item is in top-K
12
13use std::collections::HashSet;
14use std::time::{Duration, Instant};
15
/// Evaluation result for a single query.
///
/// Produced by `Evaluator::evaluate_query`; aggregate a batch of these with
/// `EvaluationSummary::from_evaluations`.
#[derive(Debug, Clone)]
pub struct QueryEvaluation {
    /// Identifier of the evaluated query.
    pub query_id: String,
    /// Recall@K: fraction of ground-truth relevant items found in the top-K
    /// results (defined as 1.0 when the ground-truth set is empty).
    pub recall: f64,
    /// Precision@K: fraction of the top-K results that are relevant
    /// (0.0 when no results were returned).
    pub precision: f64,
    /// Reciprocal rank: 1/rank of the first relevant item within the top-K,
    /// or 0.0 if none appears there.
    pub reciprocal_rank: f64,
    /// NDCG@K computed with binary relevance (gain 1 if relevant, else 0).
    pub ndcg: f64,
    /// Whether at least one relevant item appeared in the top-K results.
    pub hit: bool,
    /// Query latency as reported by the caller.
    pub latency: Duration,
    /// Total number of results returned (before truncation to K).
    pub num_results: usize,
}
36
/// Aggregated evaluation metrics across multiple queries.
///
/// Built via `EvaluationSummary::from_evaluations`; `Default` yields an
/// all-zero summary (returned for an empty evaluation set).
#[derive(Debug, Clone, Default)]
pub struct EvaluationSummary {
    /// Number of queries evaluated.
    pub num_queries: usize,
    /// Mean Recall@K over all queries.
    pub mean_recall: f64,
    /// Mean Precision@K over all queries.
    pub mean_precision: f64,
    /// Mean Reciprocal Rank (mean of per-query reciprocal ranks).
    pub mrr: f64,
    /// Mean NDCG@K over all queries.
    pub mean_ndcg: f64,
    /// Hit rate: fraction of queries with at least one relevant item in top-K.
    pub hit_rate: f64,
    /// Mean query latency.
    pub mean_latency: Duration,
    /// 50th-percentile (median) query latency.
    pub p50_latency: Duration,
    /// 95th-percentile query latency.
    pub p95_latency: Duration,
    /// 99th-percentile query latency.
    pub p99_latency: Duration,
}
61
62impl EvaluationSummary {
63    /// Create summary from individual evaluations.
64    pub fn from_evaluations(evals: &[QueryEvaluation]) -> Self {
65        if evals.is_empty() {
66            return Self::default();
67        }
68
69        let n = evals.len() as f64;
70
71        let mean_recall = evals.iter().map(|e| e.recall).sum::<f64>() / n;
72        let mean_precision = evals.iter().map(|e| e.precision).sum::<f64>() / n;
73        let mrr = evals.iter().map(|e| e.reciprocal_rank).sum::<f64>() / n;
74        let mean_ndcg = evals.iter().map(|e| e.ndcg).sum::<f64>() / n;
75        let hit_rate = evals.iter().filter(|e| e.hit).count() as f64 / n;
76
77        let mean_latency_nanos = evals.iter().map(|e| e.latency.as_nanos()).sum::<u128>() / evals.len() as u128;
78        let mean_latency = Duration::from_nanos(mean_latency_nanos as u64);
79
80        // Compute latency percentiles
81        let mut latencies: Vec<Duration> = evals.iter().map(|e| e.latency).collect();
82        latencies.sort();
83
84        let p50_idx = (evals.len() as f64 * 0.50) as usize;
85        let p95_idx = (evals.len() as f64 * 0.95) as usize;
86        let p99_idx = (evals.len() as f64 * 0.99) as usize;
87
88        let p50_latency = latencies.get(p50_idx.min(latencies.len() - 1)).copied().unwrap_or_default();
89        let p95_latency = latencies.get(p95_idx.min(latencies.len() - 1)).copied().unwrap_or_default();
90        let p99_latency = latencies.get(p99_idx.min(latencies.len() - 1)).copied().unwrap_or_default();
91
92        Self {
93            num_queries: evals.len(),
94            mean_recall,
95            mean_precision,
96            mrr,
97            mean_ndcg,
98            hit_rate,
99            mean_latency,
100            p50_latency,
101            p95_latency,
102            p99_latency,
103        }
104    }
105
106    /// Print a formatted summary report.
107    pub fn report(&self) -> String {
108        format!(
109            r#"
110=== RAG++ Evaluation Summary ===
111Queries evaluated: {}
112
113Retrieval Quality:
114  Mean Recall@K:    {:.4}
115  Mean Precision@K: {:.4}
116  MRR:              {:.4}
117  Mean NDCG@K:      {:.4}
118  Hit Rate:         {:.2}%
119
120Latency:
121  Mean:  {:?}
122  P50:   {:?}
123  P95:   {:?}
124  P99:   {:?}
125================================
126"#,
127            self.num_queries,
128            self.mean_recall,
129            self.mean_precision,
130            self.mrr,
131            self.mean_ndcg,
132            self.hit_rate * 100.0,
133            self.mean_latency,
134            self.p50_latency,
135            self.p95_latency,
136            self.p99_latency,
137        )
138    }
139}
140
/// Evaluator for computing retrieval metrics.
///
/// All metrics are "@K": only the first `k` retrieved results are scored.
pub struct Evaluator {
    /// Cut-off K for recall@K, precision@K, NDCG@K, etc.
    k: usize,
}
146
147impl Evaluator {
148    /// Create new evaluator with given K.
149    #[must_use]
150    pub fn new(k: usize) -> Self {
151        Self { k }
152    }
153
154    /// Evaluate a single query.
155    ///
156    /// # Arguments
157    /// * `query_id` - Identifier for this query
158    /// * `retrieved_ids` - IDs returned by the retrieval system (in ranked order)
159    /// * `relevant_ids` - Ground truth relevant IDs
160    /// * `latency` - Query execution time
161    pub fn evaluate_query(
162        &self,
163        query_id: impl Into<String>,
164        retrieved_ids: &[String],
165        relevant_ids: &HashSet<String>,
166        latency: Duration,
167    ) -> QueryEvaluation {
168        let k = self.k.min(retrieved_ids.len());
169        let top_k: Vec<_> = retrieved_ids.iter().take(k).collect();
170
171        // Recall@K: How many relevant items did we find?
172        let relevant_found = top_k.iter().filter(|id| relevant_ids.contains(id.as_str())).count();
173        let recall = if relevant_ids.is_empty() {
174            1.0 // No relevant items means perfect recall
175        } else {
176            relevant_found as f64 / relevant_ids.len() as f64
177        };
178
179        // Precision@K: How many retrieved items are relevant?
180        let precision = if k == 0 {
181            0.0
182        } else {
183            relevant_found as f64 / k as f64
184        };
185
186        // Reciprocal Rank: 1/rank of first relevant item
187        let reciprocal_rank = top_k
188            .iter()
189            .position(|id| relevant_ids.contains(id.as_str()))
190            .map(|pos| 1.0 / (pos + 1) as f64)
191            .unwrap_or(0.0);
192
193        // Hit: Did we find at least one relevant item?
194        let hit = relevant_found > 0;
195
196        // NDCG@K
197        let ndcg = self.compute_ndcg(&top_k, relevant_ids);
198
199        QueryEvaluation {
200            query_id: query_id.into(),
201            recall,
202            precision,
203            reciprocal_rank,
204            ndcg,
205            hit,
206            latency,
207            num_results: retrieved_ids.len(),
208        }
209    }
210
211    /// Compute NDCG (Normalized Discounted Cumulative Gain).
212    fn compute_ndcg(&self, retrieved: &[&String], relevant: &HashSet<String>) -> f64 {
213        if relevant.is_empty() {
214            return 1.0;
215        }
216
217        // DCG: sum of relevance / log2(rank + 1)
218        let dcg: f64 = retrieved
219            .iter()
220            .enumerate()
221            .map(|(i, id)| {
222                let rel = if relevant.contains(id.as_str()) { 1.0 } else { 0.0 };
223                rel / (i as f64 + 2.0).log2()
224            })
225            .sum();
226
227        // Ideal DCG: all relevant items at top
228        let ideal_k = self.k.min(relevant.len());
229        let idcg: f64 = (0..ideal_k)
230            .map(|i| 1.0 / (i as f64 + 2.0).log2())
231            .sum();
232
233        if idcg == 0.0 {
234            0.0
235        } else {
236            dcg / idcg
237        }
238    }
239}
240
/// Benchmark runner for performance evaluation.
///
/// Runs a closure for a number of untimed warm-up iterations, then times a
/// number of measurement iterations and summarizes the durations.
pub struct Benchmarker {
    /// Untimed warm-up iterations executed before measurement begins.
    warmup_iters: usize,
    /// Timed iterations included in the reported statistics.
    measure_iters: usize,
}
248
249impl Benchmarker {
250    /// Create new benchmarker.
251    #[must_use]
252    pub fn new(warmup_iters: usize, measure_iters: usize) -> Self {
253        Self {
254            warmup_iters,
255            measure_iters,
256        }
257    }
258
259    /// Run benchmark on a function.
260    pub fn run<F>(&self, mut f: F) -> BenchmarkResult
261    where
262        F: FnMut(),
263    {
264        // Warm-up
265        for _ in 0..self.warmup_iters {
266            f();
267        }
268
269        // Measure
270        let mut durations = Vec::with_capacity(self.measure_iters);
271        for _ in 0..self.measure_iters {
272            let start = Instant::now();
273            f();
274            durations.push(start.elapsed());
275        }
276
277        // Compute stats
278        durations.sort();
279        let total: Duration = durations.iter().sum();
280        let mean = total / self.measure_iters as u32;
281
282        let p50 = durations[durations.len() / 2];
283        let p95 = durations[(durations.len() as f64 * 0.95) as usize];
284        let p99 = durations[(durations.len() as f64 * 0.99) as usize];
285        let min = durations[0];
286        let max = durations[durations.len() - 1];
287
288        BenchmarkResult {
289            iterations: self.measure_iters,
290            mean,
291            p50,
292            p95,
293            p99,
294            min,
295            max,
296        }
297    }
298}
299
/// Result of a benchmark run.
///
/// Produced by `Benchmarker::run`; all durations summarize the measured
/// (non-warm-up) iterations.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Number of measured iterations
    pub iterations: usize,
    /// Mean duration
    pub mean: Duration,
    /// P50 (median) duration
    pub p50: Duration,
    /// P95 duration
    pub p95: Duration,
    /// P99 duration
    pub p99: Duration,
    /// Minimum duration
    pub min: Duration,
    /// Maximum duration
    pub max: Duration,
}

impl BenchmarkResult {
    /// Operations per second implied by the mean duration.
    #[must_use]
    pub fn throughput(&self) -> f64 {
        // Reciprocal of the mean iteration time in seconds.
        self.mean.as_secs_f64().recip()
    }

    /// Render a human-readable report for this benchmark, labeled `name`.
    pub fn report(&self, name: &str) -> String {
        let ops_per_sec = self.throughput();
        format!(
            r#"
=== Benchmark: {} ===
Iterations: {}
Mean:       {:?}
P50:        {:?}
P95:        {:?}
P99:        {:?}
Min:        {:?}
Max:        {:?}
Throughput: {:.2} ops/sec
======================
"#,
            name,
            self.iterations,
            self.mean,
            self.p50,
            self.p95,
            self.p99,
            self.min,
            self.max,
            ops_per_sec,
        )
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the id strings `doc-0` .. `doc-{n-1}`.
    fn doc_ids(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("doc-{i}")).collect()
    }

    #[test]
    fn test_perfect_recall() {
        let evaluator = Evaluator::new(10);
        let retrieved = doc_ids(10);
        let relevant: HashSet<String> = doc_ids(5).into_iter().collect();

        let eval = evaluator.evaluate_query("q1", &retrieved, &relevant, Duration::from_millis(10));

        // All 5 relevant docs sit inside the top 10.
        assert_eq!(eval.recall, 1.0);
        // Exactly half of the 10 retrieved docs are relevant.
        assert_eq!(eval.precision, 0.5);
        assert!(eval.hit);
    }

    #[test]
    fn test_no_relevant_items() {
        let evaluator = Evaluator::new(10);
        let retrieved = doc_ids(10);
        let relevant = HashSet::new();

        let eval = evaluator.evaluate_query("q1", &retrieved, &relevant, Duration::from_millis(10));

        // Empty ground truth: recall is vacuously perfect, precision is zero.
        assert_eq!(eval.recall, 1.0);
        assert_eq!(eval.precision, 0.0);
        assert!(!eval.hit);
    }

    #[test]
    fn test_mrr() {
        let evaluator = Evaluator::new(10);
        let relevant: HashSet<String> = std::iter::once("a".to_string()).collect();

        // First relevant item at rank 1 -> RR = 1.
        let hits_first: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let eval = evaluator.evaluate_query("q1", &hits_first, &relevant, Duration::ZERO);
        assert!((eval.reciprocal_rank - 1.0).abs() < 1e-6);

        // First relevant item at rank 3 -> RR = 1/3.
        let hits_third: Vec<String> = ["x", "y", "a"].iter().map(|s| s.to_string()).collect();
        let eval = evaluator.evaluate_query("q2", &hits_third, &relevant, Duration::ZERO);
        assert!((eval.reciprocal_rank - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_summary() {
        // Small builder closure keeps the fixture compact.
        let make = |id: &str, recall, precision, rr, ndcg, ms| QueryEvaluation {
            query_id: id.into(),
            recall,
            precision,
            reciprocal_rank: rr,
            ndcg,
            hit: true,
            latency: Duration::from_millis(ms),
            num_results: 10,
        };
        let evals = [
            make("q1", 1.0, 0.5, 1.0, 0.8, 10),
            make("q2", 0.5, 0.25, 0.5, 0.6, 20),
        ];

        let summary = EvaluationSummary::from_evaluations(&evals);

        assert_eq!(summary.num_queries, 2);
        assert!((summary.mean_recall - 0.75).abs() < 1e-6);
        assert!((summary.mrr - 0.75).abs() < 1e-6);
        assert_eq!(summary.hit_rate, 1.0);
    }

    #[test]
    fn test_benchmarker() {
        let benchmarker = Benchmarker::new(2, 10);
        let mut calls = 0_usize;

        let result = benchmarker.run(|| {
            calls += 1;
            std::thread::sleep(Duration::from_micros(100));
        });

        // 2 warm-up runs plus 10 measured runs.
        assert_eq!(calls, 12);
        assert!(result.mean >= Duration::from_micros(100));
        assert!(result.throughput() > 0.0);
    }
}
447}