// oxify_vector/recall_eval.rs

//! Recall Evaluation for ANN Indexes
//!
//! Tools for evaluating the quality of Approximate Nearest Neighbor (ANN) indexes
//! by comparing their results against ground truth exact search.
//!
//! ## Features
//!
//! - **Ground Truth Generation**: Generate exact search results for comparison
//! - **Recall@k Calculation**: Measure how many true nearest neighbors are found
//! - **Precision@k**: Measure accuracy of retrieved results
//! - **nDCG@k**: Normalized Discounted Cumulative Gain for ranking quality
//! - **Configuration Comparison**: Compare different index configurations
//!
//! ## Example
//!
//! ```rust
//! use oxify_vector::recall_eval::{RecallEvaluator, EvaluationConfig};
//! use oxify_vector::{HnswIndex, HnswConfig, SearchConfig, VectorSearchIndex};
//! use std::collections::HashMap;
//!
//! # fn example() -> anyhow::Result<()> {
//! // Create test dataset
//! let mut embeddings = HashMap::new();
//! for i in 0..1000 {
//!     let vec = vec![i as f32 * 0.01, (i * 2) as f32 * 0.01, (i * 3) as f32 * 0.01];
//!     embeddings.insert(format!("doc{}", i), vec);
//! }
//!
//! // Build exact and approximate indexes
//! let mut exact_index = VectorSearchIndex::new(SearchConfig::default());
//! exact_index.build(&embeddings)?;
//!
//! let mut hnsw_index = HnswIndex::new(HnswConfig::default());
//! hnsw_index.build(&embeddings)?;
//!
//! // Evaluate recall
//! let config = EvaluationConfig::default();
//! let evaluator = RecallEvaluator::new(config);
//!
//! let query = vec![0.5, 1.0, 1.5];
//! let metrics = evaluator.evaluate_single_query(
//!     &query,
//!     |q, k| exact_index.search(q, k),
//!     |q, k| hnsw_index.search(q, k),
//! )?;
//!
//! // metrics is a Vec<QueryMetrics>, one for each k value
//! for m in &metrics {
//!     println!("Recall@{}: {:.2}%", m.k, m.recall_at_k * 100.0);
//!     println!("Precision@{}: {:.2}%", m.k, m.precision_at_k * 100.0);
//! }
//! # Ok(())
//! # }
//! ```

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

use crate::types::SearchResult;

/// Evaluation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationConfig {
    /// The k values (top-k cutoffs) at which to compute the metrics
    pub k_values: Vec<usize>,
    /// Whether to calculate nDCG in addition to recall/precision
    pub calculate_ndcg: bool,
    /// Number of test queries to use (for batch evaluation)
    pub num_test_queries: usize,
}

impl Default for EvaluationConfig {
    fn default() -> Self {
        Self {
            k_values: vec![1, 5, 10, 20, 50, 100],
            calculate_ndcg: true,
            num_test_queries: 100,
        }
    }
}

impl EvaluationConfig {
    /// Create config for quick evaluation
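    ///
    /// A minimal doc-test sketch of picking a preset (the values asserted
    /// below simply mirror this preset's definition):
    ///
    /// ```rust
    /// use oxify_vector::recall_eval::EvaluationConfig;
    ///
    /// let cfg = EvaluationConfig::quick();
    /// assert_eq!(cfg.k_values, vec![10, 20]);
    /// assert!(!cfg.calculate_ndcg);
    /// ```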
    pub fn quick() -> Self {
        Self {
            k_values: vec![10, 20],
            calculate_ndcg: false,
            num_test_queries: 10,
        }
    }

    /// Create config for comprehensive evaluation
    pub fn comprehensive() -> Self {
        Self {
            k_values: vec![1, 5, 10, 20, 50, 100, 200],
            calculate_ndcg: true,
            num_test_queries: 1000,
        }
    }
}

/// Evaluation metrics for a single query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryMetrics {
    /// k value (number of results)
    pub k: usize,
    /// Recall@k: Fraction of true top-k results found
    pub recall_at_k: f32,
    /// Precision@k: Fraction of returned results that are relevant
    pub precision_at_k: f32,
    /// nDCG@k: Normalized Discounted Cumulative Gain (None when nDCG is disabled)
    pub ndcg_at_k: Option<f32>,
    /// Number of true positives found
    pub true_positives: usize,
    /// Number of false positives found
    pub false_positives: usize,
}

impl QueryMetrics {
    /// Calculate F1 score (harmonic mean of precision and recall)
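    ///
    /// F1 = 2 · (precision · recall) / (precision + recall),
    /// defined as 0.0 when precision + recall == 0 to avoid dividing by zero.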
    pub fn f1_score(&self) -> f32 {
        if self.precision_at_k + self.recall_at_k == 0.0 {
            0.0
        } else {
            2.0 * (self.precision_at_k * self.recall_at_k)
                / (self.precision_at_k + self.recall_at_k)
        }
    }
}

/// Aggregated evaluation metrics across multiple queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatedMetrics {
    /// k value
    pub k: usize,
    /// Average recall@k
    pub avg_recall: f32,
    /// Average precision@k
    pub avg_precision: f32,
    /// Average nDCG@k
    pub avg_ndcg: Option<f32>,
    /// Standard deviation of recall
    pub std_recall: f32,
    /// Standard deviation of precision
    pub std_precision: f32,
    /// Number of queries evaluated
    pub num_queries: usize,
}

/// Recall evaluator for comparing ANN and exact search
#[derive(Debug, Clone)]
pub struct RecallEvaluator {
    config: EvaluationConfig,
}

impl RecallEvaluator {
    /// Create a new recall evaluator
    pub fn new(config: EvaluationConfig) -> Self {
        Self { config }
    }

    /// Evaluate a single query against ground truth
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `exact_search` - Function that performs exact search (ground truth)
    /// * `ann_search` - Function that performs approximate search
    ///
    /// # Returns
    /// Metrics for each k value in the configuration
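    ///
    /// (See the module-level example above for end-to-end usage against real indexes.)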
    pub fn evaluate_single_query<F, G>(
        &self,
        query: &[f32],
        exact_search: F,
        ann_search: G,
    ) -> Result<Vec<QueryMetrics>>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
        G: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        let mut metrics = Vec::new();

        for &k in &self.config.k_values {
            // Get ground truth
            let ground_truth = exact_search(query, k)?;
            let ground_truth_ids: HashSet<&str> =
                ground_truth.iter().map(|r| r.entity_id.as_str()).collect();

            // Get ANN results
            let ann_results = ann_search(query, k)?;
            let ann_ids: HashSet<&str> = ann_results.iter().map(|r| r.entity_id.as_str()).collect();

            // Calculate recall and precision
            let true_positives = ground_truth_ids.intersection(&ann_ids).count();
            let false_positives = ann_results.len().saturating_sub(true_positives);

            let recall_at_k = if !ground_truth_ids.is_empty() {
                true_positives as f32 / ground_truth_ids.len() as f32
            } else {
                0.0
            };

            let precision_at_k = if !ann_results.is_empty() {
                true_positives as f32 / ann_results.len() as f32
            } else {
                0.0
            };
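            // Note: when both searches return exactly k results, recall@k and
            // precision@k coincide (same numerator and denominator).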

            // Calculate nDCG if enabled
            let ndcg_at_k = if self.config.calculate_ndcg {
                Some(self.calculate_ndcg(&ground_truth, &ann_results, k))
            } else {
                None
            };

            metrics.push(QueryMetrics {
                k,
                recall_at_k,
                precision_at_k,
                ndcg_at_k,
                true_positives,
                false_positives,
            });
        }

        Ok(metrics)
    }

    /// Evaluate multiple queries and aggregate results
    ///
    /// # Arguments
    /// * `queries` - List of query vectors
    /// * `exact_search` - Function that performs exact search (ground truth)
    /// * `ann_search` - Function that performs approximate search
    ///
    /// # Returns
    /// Aggregated metrics for each k value
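    ///
    /// A minimal sketch (not compiled as a doc-test; `exact_index`, `hnsw_index`,
    /// and `queries` are assumed to exist as in the module-level example):
    ///
    /// ```ignore
    /// let evaluator = RecallEvaluator::new(EvaluationConfig::quick());
    /// let report = evaluator.evaluate_batch(
    ///     &queries,
    ///     |q, k| exact_index.search(q, k),
    ///     |q, k| hnsw_index.search(q, k),
    /// )?;
    /// for m in &report {
    ///     println!("k={}: recall {:.3} (+/- {:.3})", m.k, m.avg_recall, m.std_recall);
    /// }
    /// ```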
    pub fn evaluate_batch<F, G>(
        &self,
        queries: &[Vec<f32>],
        exact_search: F,
        ann_search: G,
    ) -> Result<Vec<AggregatedMetrics>>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
        G: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        let mut all_metrics: Vec<Vec<QueryMetrics>> = Vec::new();

        // Evaluate each query
        for query in queries.iter().take(self.config.num_test_queries) {
            let query_metrics = self.evaluate_single_query(query, &exact_search, &ann_search)?;
            all_metrics.push(query_metrics);
        }

        // Aggregate results for each k
        let mut aggregated = Vec::new();
        for &k in &self.config.k_values {
            let metrics_for_k: Vec<&QueryMetrics> = all_metrics
                .iter()
                .filter_map(|qm| qm.iter().find(|m| m.k == k))
                .collect();

            if metrics_for_k.is_empty() {
                continue;
            }

            let recalls: Vec<f32> = metrics_for_k.iter().map(|m| m.recall_at_k).collect();
            let precisions: Vec<f32> = metrics_for_k.iter().map(|m| m.precision_at_k).collect();

            let avg_recall = recalls.iter().sum::<f32>() / recalls.len() as f32;
            let avg_precision = precisions.iter().sum::<f32>() / precisions.len() as f32;

            // Calculate standard deviations
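            // (population form: the variance divides by n, not n - 1)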
            let variance_recall = recalls
                .iter()
                .map(|r| (r - avg_recall).powi(2))
                .sum::<f32>()
                / recalls.len() as f32;
            let std_recall = variance_recall.sqrt();

            let variance_precision = precisions
                .iter()
                .map(|p| (p - avg_precision).powi(2))
                .sum::<f32>()
                / precisions.len() as f32;
            let std_precision = variance_precision.sqrt();

            let avg_ndcg = if self.config.calculate_ndcg {
                let ndcgs: Vec<f32> = metrics_for_k.iter().filter_map(|m| m.ndcg_at_k).collect();
                if !ndcgs.is_empty() {
                    Some(ndcgs.iter().sum::<f32>() / ndcgs.len() as f32)
                } else {
                    None
                }
            } else {
                None
            };

            aggregated.push(AggregatedMetrics {
                k,
                avg_recall,
                avg_precision,
                avg_ndcg,
                std_recall,
                std_precision,
                num_queries: metrics_for_k.len(),
            });
        }

        Ok(aggregated)
    }

    /// Calculate Normalized Discounted Cumulative Gain (nDCG@k)
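    ///
    /// With position-based relevance rel_i = k - i for the i-th ground-truth
    /// result (0-indexed), this computes
    ///
    ///   DCG@k  = Σ_{i=0}^{k-1} rel_i / log2(i + 2)
    ///
    /// and normalizes by IDCG@k, the DCG of the ideal (ground-truth) ordering,
    /// so a perfect ranking scores 1.0.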
    fn calculate_ndcg(
        &self,
        ground_truth: &[SearchResult],
        ann_results: &[SearchResult],
        k: usize,
    ) -> f32 {
        if ground_truth.is_empty() || ann_results.is_empty() {
            return 0.0;
        }

        // Create relevance map from ground truth (position-based relevance)
        let relevance_map: std::collections::HashMap<&str, f32> = ground_truth
            .iter()
            .take(k) // guard: if more than k ground-truth items were returned, `k - i` would underflow
            .enumerate()
            .map(|(i, r)| {
                let relevance = (k - i) as f32; // Higher relevance for higher-ranked items
                (r.entity_id.as_str(), relevance)
            })
            .collect();

        // Calculate DCG for ANN results
        let dcg: f32 = ann_results
            .iter()
            .take(k)
            .enumerate()
            .map(|(i, result)| {
                let relevance = relevance_map.get(result.entity_id.as_str()).unwrap_or(&0.0);
                let discount = ((i + 2) as f32).log2(); // log2(rank + 1), with 1-based rank = i + 1
                relevance / discount
            })
            .sum();

        // Calculate IDCG (ideal DCG) - what we'd get with perfect ranking
        let idcg: f32 = ground_truth
            .iter()
            .take(k)
            .enumerate()
            .map(|(i, _)| {
                let relevance = (k - i) as f32;
                let discount = ((i + 2) as f32).log2();
                relevance / discount
            })
            .sum();

        if idcg == 0.0 {
            0.0
        } else {
            dcg / idcg
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SearchResult;

    fn create_search_results(ids: &[&str], scores: &[f32]) -> Vec<SearchResult> {
        ids.iter()
            .zip(scores.iter())
            .enumerate()
            .map(|(rank, (id, score))| SearchResult {
                entity_id: id.to_string(),
                score: *score,
                distance: 1.0 - score, // Approximate distance from score
                rank: rank + 1,
            })
            .collect()
    }

    #[test]
    fn test_perfect_recall() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: false,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let query = vec![1.0, 2.0, 3.0];

        let exact_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        let ann_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact_fn, ann_fn)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        assert_eq!(metrics[0].k, 3);
        assert!((metrics[0].recall_at_k - 1.0).abs() < 1e-6);
        assert!((metrics[0].precision_at_k - 1.0).abs() < 1e-6);
        assert_eq!(metrics[0].true_positives, 3);
        assert_eq!(metrics[0].false_positives, 0);
    }

    #[test]
    fn test_partial_recall() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: false,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let query = vec![1.0, 2.0, 3.0];

        let exact_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        // ANN returns only 2 out of 3 correct results
        let ann_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc4"],
                &[0.9, 0.8, 0.6],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact_fn, ann_fn)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        assert!((metrics[0].recall_at_k - 2.0 / 3.0).abs() < 1e-6); // 2 out of 3
        assert!((metrics[0].precision_at_k - 2.0 / 3.0).abs() < 1e-6);
        assert_eq!(metrics[0].true_positives, 2);
        assert_eq!(metrics[0].false_positives, 1);
    }

    #[test]
    fn test_zero_recall() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: false,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let query = vec![1.0, 2.0, 3.0];

        let exact_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        // ANN returns completely different results
        let ann_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc4", "doc5", "doc6"],
                &[0.6, 0.5, 0.4],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact_fn, ann_fn)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        assert!((metrics[0].recall_at_k - 0.0).abs() < 1e-6);
        assert!((metrics[0].precision_at_k - 0.0).abs() < 1e-6);
        assert_eq!(metrics[0].true_positives, 0);
        assert_eq!(metrics[0].false_positives, 3);
    }

    #[test]
    fn test_f1_score() {
        let metrics = QueryMetrics {
            k: 10,
            recall_at_k: 0.8,
            precision_at_k: 0.6,
            ndcg_at_k: None,
            true_positives: 8,
            false_positives: 2,
        };

        let f1 = metrics.f1_score();
        let expected_f1 = 2.0 * (0.8 * 0.6) / (0.8 + 0.6);
        assert!((f1 - expected_f1).abs() < 1e-6);
    }

    #[test]
    fn test_f1_score_zero() {
        let metrics = QueryMetrics {
            k: 10,
            recall_at_k: 0.0,
            precision_at_k: 0.0,
            ndcg_at_k: None,
            true_positives: 0,
            false_positives: 10,
        };

        let f1 = metrics.f1_score();
        assert_eq!(f1, 0.0);
    }

    #[test]
    fn test_ndcg_perfect() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: true,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let ground_truth = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);
        let ann_results = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);

        let ndcg = evaluator.calculate_ndcg(&ground_truth, &ann_results, 3);
        assert!((ndcg - 1.0).abs() < 1e-6); // Perfect ranking should have nDCG = 1.0
    }

    #[test]
    fn test_ndcg_reversed() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: true,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let ground_truth = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);
        let ann_results = create_search_results(&["doc3", "doc2", "doc1"], &[0.8, 0.9, 1.0]); // Reversed

        let ndcg = evaluator.calculate_ndcg(&ground_truth, &ann_results, 3);
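        // With k = 3, the reversal gives DCG = 1/log2(2) + 2/log2(3) + 3/log2(4) ≈ 3.76
        // against IDCG ≈ 4.76, i.e. nDCG ≈ 0.79.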
        assert!(ndcg > 0.0 && ndcg < 1.0); // Not perfect but not zero
    }

    #[test]
    fn test_batch_evaluation() {
        let config = EvaluationConfig {
            k_values: vec![3],
            calculate_ndcg: false,
            num_test_queries: 2,
        };
        let evaluator = RecallEvaluator::new(config);

        let queries = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

        let exact_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        let ann_fn = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc4"],
                &[0.9, 0.8, 0.6],
            ))
        };

        let aggregated = evaluator
            .evaluate_batch(&queries, exact_fn, ann_fn)
            .unwrap();

        assert_eq!(aggregated.len(), 1);
        assert_eq!(aggregated[0].k, 3);
        assert_eq!(aggregated[0].num_queries, 2);
        assert!((aggregated[0].avg_recall - 2.0 / 3.0).abs() < 1e-6);
        assert!((aggregated[0].avg_precision - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_multiple_k_values() {
        let config = EvaluationConfig {
            k_values: vec![1, 2, 3],
            calculate_ndcg: false,
            num_test_queries: 10,
        };
        let evaluator = RecallEvaluator::new(config);

        let query = vec![1.0, 2.0, 3.0];

        let exact_fn = |_q: &[f32], k: usize| {
            let all_results = create_search_results(&["doc1", "doc2", "doc3"], &[0.9, 0.8, 0.7]);
            Ok(all_results.into_iter().take(k).collect())
        };

        let ann_fn = |_q: &[f32], k: usize| {
            let all_results = create_search_results(&["doc1", "doc4", "doc5"], &[0.9, 0.7, 0.6]);
            Ok(all_results.into_iter().take(k).collect())
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact_fn, ann_fn)
            .unwrap();

        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].k, 1);
        assert_eq!(metrics[1].k, 2);
        assert_eq!(metrics[2].k, 3);

        // At k=1, we found 1/1 = 100% recall
        assert!((metrics[0].recall_at_k - 1.0).abs() < 1e-6);

        // At k=2, we found 1/2 = 50% recall
        assert!((metrics[1].recall_at_k - 0.5).abs() < 1e-6);

        // At k=3, we found 1/3 = 33.3% recall
        assert!((metrics[2].recall_at_k - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_evaluation_config_presets() {
        let quick = EvaluationConfig::quick();
        assert_eq!(quick.k_values.len(), 2);
        assert!(!quick.calculate_ndcg);
        assert_eq!(quick.num_test_queries, 10);

        let comprehensive = EvaluationConfig::comprehensive();
        assert_eq!(comprehensive.k_values.len(), 7);
        assert!(comprehensive.calculate_ndcg);
        assert_eq!(comprehensive.num_test_queries, 1000);
    }
}