1use anyhow::Result;
57use serde::{Deserialize, Serialize};
58use std::collections::HashSet;
59
60use crate::types::SearchResult;
61
/// Settings that control which retrieval-quality metrics are computed and over how many queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationConfig {
    /// Cutoff values `k`. Metrics are computed once per entry, and each `k` is handed to the
    /// search callbacks as the requested result count.
    pub k_values: Vec<usize>,
    /// Whether NDCG@k is computed in addition to recall and precision.
    pub calculate_ndcg: bool,
    /// Upper bound on how many queries a batch evaluation takes from its input.
    pub num_test_queries: usize,
}
72
73impl Default for EvaluationConfig {
74 fn default() -> Self {
75 Self {
76 k_values: vec![1, 5, 10, 20, 50, 100],
77 calculate_ndcg: true,
78 num_test_queries: 100,
79 }
80 }
81}
82
83impl EvaluationConfig {
84 pub fn quick() -> Self {
86 Self {
87 k_values: vec![10, 20],
88 calculate_ndcg: false,
89 num_test_queries: 10,
90 }
91 }
92
93 pub fn comprehensive() -> Self {
95 Self {
96 k_values: vec![1, 5, 10, 20, 50, 100, 200],
97 calculate_ndcg: true,
98 num_test_queries: 1000,
99 }
100 }
101}
102
/// Retrieval-quality metrics for a single query at a single cutoff `k`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryMetrics {
    /// Cutoff these metrics were computed for.
    pub k: usize,
    /// Fraction of the exact-search (ground-truth) ids that also appear in the approximate
    /// results; `0.0` when the ground truth is empty.
    pub recall_at_k: f32,
    /// Fraction of the approximate results that are ground-truth hits; `0.0` when there are no
    /// approximate results.
    pub precision_at_k: f32,
    /// NDCG at this cutoff, or `None` when NDCG calculation is disabled in the configuration.
    pub ndcg_at_k: Option<f32>,
    /// Number of distinct ids present in both the ground truth and the approximate results.
    pub true_positives: usize,
    /// Number of approximate results not counted as true positives
    /// (result count minus `true_positives`).
    pub false_positives: usize,
}
119
120impl QueryMetrics {
121 pub fn f1_score(&self) -> f32 {
123 if self.precision_at_k + self.recall_at_k == 0.0 {
124 0.0
125 } else {
126 2.0 * (self.precision_at_k * self.recall_at_k)
127 / (self.precision_at_k + self.recall_at_k)
128 }
129 }
130}
131
/// Metrics for one cutoff `k`, aggregated across all evaluated queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatedMetrics {
    /// Cutoff these aggregates were computed for.
    pub k: usize,
    /// Mean recall@k across the evaluated queries.
    pub avg_recall: f32,
    /// Mean precision@k across the evaluated queries.
    pub avg_precision: f32,
    /// Mean NDCG@k across the queries that reported one; `None` when NDCG was not calculated.
    pub avg_ndcg: Option<f32>,
    /// Population standard deviation of recall@k (variance divided by the query count).
    pub std_recall: f32,
    /// Population standard deviation of precision@k (variance divided by the query count).
    pub std_precision: f32,
    /// Number of queries that contributed to these aggregates.
    pub num_queries: usize,
}
150
/// Compares approximate (ANN) search results against exact search results and reports recall,
/// precision and, optionally, NDCG at the configured cutoffs.
#[derive(Debug, Clone)]
pub struct RecallEvaluator {
    /// Cutoffs, NDCG toggle and batch query limit used by the evaluation methods.
    config: EvaluationConfig,
}
156
157impl RecallEvaluator {
158 pub fn new(config: EvaluationConfig) -> Self {
160 Self { config }
161 }
162
163 pub fn evaluate_single_query<F, G>(
173 &self,
174 query: &[f32],
175 exact_search: F,
176 ann_search: G,
177 ) -> Result<Vec<QueryMetrics>>
178 where
179 F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
180 G: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
181 {
182 let mut metrics = Vec::new();
183
184 for &k in &self.config.k_values {
185 let ground_truth = exact_search(query, k)?;
187 let ground_truth_ids: HashSet<&str> =
188 ground_truth.iter().map(|r| r.entity_id.as_str()).collect();
189
190 let ann_results = ann_search(query, k)?;
192 let ann_ids: HashSet<&str> = ann_results.iter().map(|r| r.entity_id.as_str()).collect();
193
194 let true_positives = ground_truth_ids.intersection(&ann_ids).count();
196 let false_positives = ann_results.len().saturating_sub(true_positives);
197
198 let recall_at_k = if !ground_truth_ids.is_empty() {
199 true_positives as f32 / ground_truth_ids.len() as f32
200 } else {
201 0.0
202 };
203
204 let precision_at_k = if !ann_results.is_empty() {
205 true_positives as f32 / ann_results.len() as f32
206 } else {
207 0.0
208 };
209
210 let ndcg_at_k = if self.config.calculate_ndcg {
212 Some(self.calculate_ndcg(&ground_truth, &ann_results, k))
213 } else {
214 None
215 };
216
217 metrics.push(QueryMetrics {
218 k,
219 recall_at_k,
220 precision_at_k,
221 ndcg_at_k,
222 true_positives,
223 false_positives,
224 });
225 }
226
227 Ok(metrics)
228 }
229
230 pub fn evaluate_batch<F, G>(
240 &self,
241 queries: &[Vec<f32>],
242 exact_search: F,
243 ann_search: G,
244 ) -> Result<Vec<AggregatedMetrics>>
245 where
246 F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
247 G: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
248 {
249 let mut all_metrics: Vec<Vec<QueryMetrics>> = Vec::new();
250
251 for query in queries.iter().take(self.config.num_test_queries) {
253 let query_metrics = self.evaluate_single_query(query, &exact_search, &ann_search)?;
254 all_metrics.push(query_metrics);
255 }
256
257 let mut aggregated = Vec::new();
259 for &k in &self.config.k_values {
260 let metrics_for_k: Vec<&QueryMetrics> = all_metrics
261 .iter()
262 .filter_map(|qm| qm.iter().find(|m| m.k == k))
263 .collect();
264
265 if metrics_for_k.is_empty() {
266 continue;
267 }
268
269 let recalls: Vec<f32> = metrics_for_k.iter().map(|m| m.recall_at_k).collect();
270 let precisions: Vec<f32> = metrics_for_k.iter().map(|m| m.precision_at_k).collect();
271
272 let avg_recall = recalls.iter().sum::<f32>() / recalls.len() as f32;
273 let avg_precision = precisions.iter().sum::<f32>() / precisions.len() as f32;
274
275 let variance_recall = recalls
277 .iter()
278 .map(|r| (r - avg_recall).powi(2))
279 .sum::<f32>()
280 / recalls.len() as f32;
281 let std_recall = variance_recall.sqrt();
282
283 let variance_precision = precisions
284 .iter()
285 .map(|p| (p - avg_precision).powi(2))
286 .sum::<f32>()
287 / precisions.len() as f32;
288 let std_precision = variance_precision.sqrt();
289
290 let avg_ndcg = if self.config.calculate_ndcg {
291 let ndcgs: Vec<f32> = metrics_for_k.iter().filter_map(|m| m.ndcg_at_k).collect();
292 if !ndcgs.is_empty() {
293 Some(ndcgs.iter().sum::<f32>() / ndcgs.len() as f32)
294 } else {
295 None
296 }
297 } else {
298 None
299 };
300
301 aggregated.push(AggregatedMetrics {
302 k,
303 avg_recall,
304 avg_precision,
305 avg_ndcg,
306 std_recall,
307 std_precision,
308 num_queries: metrics_for_k.len(),
309 });
310 }
311
312 Ok(aggregated)
313 }
314
315 fn calculate_ndcg(
317 &self,
318 ground_truth: &[SearchResult],
319 ann_results: &[SearchResult],
320 k: usize,
321 ) -> f32 {
322 if ground_truth.is_empty() || ann_results.is_empty() {
323 return 0.0;
324 }
325
326 let relevance_map: std::collections::HashMap<&str, f32> = ground_truth
328 .iter()
329 .enumerate()
330 .map(|(i, r)| {
331 let relevance = (k - i) as f32; (r.entity_id.as_str(), relevance)
333 })
334 .collect();
335
336 let dcg: f32 = ann_results
338 .iter()
339 .take(k)
340 .enumerate()
341 .map(|(i, result)| {
342 let relevance = relevance_map.get(result.entity_id.as_str()).unwrap_or(&0.0);
343 let discount = ((i + 2) as f32).log2(); relevance / discount
345 })
346 .sum();
347
348 let idcg: f32 = ground_truth
350 .iter()
351 .take(k)
352 .enumerate()
353 .map(|(i, _)| {
354 let relevance = (k - i) as f32;
355 let discount = ((i + 2) as f32).log2();
356 relevance / discount
357 })
358 .sum();
359
360 if idcg == 0.0 {
361 0.0
362 } else {
363 dcg / idcg
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SearchResult;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Builds an evaluator from the three configuration values.
    fn evaluator_with(
        k_values: Vec<usize>,
        calculate_ndcg: bool,
        num_test_queries: usize,
    ) -> RecallEvaluator {
        RecallEvaluator::new(EvaluationConfig {
            k_values,
            calculate_ndcg,
            num_test_queries,
        })
    }

    /// Pairs ids with scores into ranked results; rank is 1-based and distance is `1 - score`.
    fn create_search_results(ids: &[&str], scores: &[f32]) -> Vec<SearchResult> {
        let mut results = Vec::new();
        for (index, (id, score)) in ids.iter().zip(scores.iter()).enumerate() {
            results.push(SearchResult {
                entity_id: id.to_string(),
                score: *score,
                distance: 1.0 - score,
                rank: index + 1,
            });
        }
        results
    }

    #[test]
    fn test_perfect_recall() {
        let evaluator = evaluator_with(vec![3], false, 10);
        let query = vec![1.0, 2.0, 3.0];

        let exact = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };
        // The approximate search returns exactly the ground truth.
        let approximate = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact, approximate)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        let at_three = &metrics[0];
        assert_eq!(at_three.k, 3);
        assert_close(at_three.recall_at_k, 1.0);
        assert_close(at_three.precision_at_k, 1.0);
        assert_eq!(at_three.true_positives, 3);
        assert_eq!(at_three.false_positives, 0);
    }

    #[test]
    fn test_partial_recall() {
        let evaluator = evaluator_with(vec![3], false, 10);
        let query = vec![1.0, 2.0, 3.0];

        let exact = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };
        // Two of the three approximate results are ground-truth hits.
        let approximate = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc4"],
                &[0.9, 0.8, 0.6],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact, approximate)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        let at_three = &metrics[0];
        assert_close(at_three.recall_at_k, 2.0 / 3.0);
        assert_close(at_three.precision_at_k, 2.0 / 3.0);
        assert_eq!(at_three.true_positives, 2);
        assert_eq!(at_three.false_positives, 1);
    }

    #[test]
    fn test_zero_recall() {
        let evaluator = evaluator_with(vec![3], false, 10);
        let query = vec![1.0, 2.0, 3.0];

        let exact = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };
        // No overlap with the ground truth at all.
        let approximate = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc4", "doc5", "doc6"],
                &[0.6, 0.5, 0.4],
            ))
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact, approximate)
            .unwrap();

        assert_eq!(metrics.len(), 1);
        let at_three = &metrics[0];
        assert_close(at_three.recall_at_k, 0.0);
        assert_close(at_three.precision_at_k, 0.0);
        assert_eq!(at_three.true_positives, 0);
        assert_eq!(at_three.false_positives, 3);
    }

    #[test]
    fn test_f1_score() {
        let metrics = QueryMetrics {
            k: 10,
            recall_at_k: 0.8,
            precision_at_k: 0.6,
            ndcg_at_k: None,
            true_positives: 8,
            false_positives: 2,
        };

        let expected_f1 = 2.0 * (0.8 * 0.6) / (0.8 + 0.6);
        assert_close(metrics.f1_score(), expected_f1);
    }

    #[test]
    fn test_f1_score_zero() {
        let metrics = QueryMetrics {
            k: 10,
            recall_at_k: 0.0,
            precision_at_k: 0.0,
            ndcg_at_k: None,
            true_positives: 0,
            false_positives: 10,
        };

        assert_eq!(metrics.f1_score(), 0.0);
    }

    #[test]
    fn test_ndcg_perfect() {
        let evaluator = evaluator_with(vec![3], true, 10);

        let ground_truth = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);
        let ann_results = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);

        // Identical ordering must score exactly the ideal DCG.
        let ndcg = evaluator.calculate_ndcg(&ground_truth, &ann_results, 3);
        assert_close(ndcg, 1.0);
    }

    #[test]
    fn test_ndcg_reversed() {
        let evaluator = evaluator_with(vec![3], true, 10);

        let ground_truth = create_search_results(&["doc1", "doc2", "doc3"], &[1.0, 0.9, 0.8]);
        let ann_results = create_search_results(&["doc3", "doc2", "doc1"], &[0.8, 0.9, 1.0]);

        // Right documents in the wrong order: better than nothing, worse than ideal.
        let ndcg = evaluator.calculate_ndcg(&ground_truth, &ann_results, 3);
        assert!(ndcg > 0.0 && ndcg < 1.0);
    }

    #[test]
    fn test_batch_evaluation() {
        let evaluator = evaluator_with(vec![3], false, 2);
        let queries = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

        let exact = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc3"],
                &[0.9, 0.8, 0.7],
            ))
        };
        let approximate = |_q: &[f32], _k: usize| {
            Ok(create_search_results(
                &["doc1", "doc2", "doc4"],
                &[0.9, 0.8, 0.6],
            ))
        };

        let aggregated = evaluator
            .evaluate_batch(&queries, exact, approximate)
            .unwrap();

        assert_eq!(aggregated.len(), 1);
        let at_three = &aggregated[0];
        assert_eq!(at_three.k, 3);
        assert_eq!(at_three.num_queries, 2);
        assert_close(at_three.avg_recall, 2.0 / 3.0);
        assert_close(at_three.avg_precision, 2.0 / 3.0);
    }

    #[test]
    fn test_multiple_k_values() {
        let evaluator = evaluator_with(vec![1, 2, 3], false, 10);
        let query = vec![1.0, 2.0, 3.0];

        // Both callbacks honour `k` by truncating a fixed ranking.
        let exact = |_q: &[f32], k: usize| {
            let ranking = create_search_results(&["doc1", "doc2", "doc3"], &[0.9, 0.8, 0.7]);
            Ok(ranking.into_iter().take(k).collect())
        };
        let approximate = |_q: &[f32], k: usize| {
            let ranking = create_search_results(&["doc1", "doc4", "doc5"], &[0.9, 0.7, 0.6]);
            Ok(ranking.into_iter().take(k).collect())
        };

        let metrics = evaluator
            .evaluate_single_query(&query, exact, approximate)
            .unwrap();

        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].k, 1);
        assert_eq!(metrics[1].k, 2);
        assert_eq!(metrics[2].k, 3);

        // Only doc1 is ever shared, so recall shrinks as the ground truth grows.
        assert_close(metrics[0].recall_at_k, 1.0);
        assert_close(metrics[1].recall_at_k, 0.5);
        assert_close(metrics[2].recall_at_k, 1.0 / 3.0);
    }

    #[test]
    fn test_evaluation_config_presets() {
        let quick = EvaluationConfig::quick();
        assert_eq!(quick.k_values.len(), 2);
        assert!(!quick.calculate_ndcg);
        assert_eq!(quick.num_test_queries, 10);

        let comprehensive = EvaluationConfig::comprehensive();
        assert_eq!(comprehensive.k_values.len(), 7);
        assert!(comprehensive.calculate_ndcg);
        assert_eq!(comprehensive.num_test_queries, 1000);
    }
}