1use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// A single search hit: `(dataset index, distance to the query)`.
pub type Neighbour = (usize, f32);

/// Computes the exact `k` nearest neighbours of every query by exhaustive scan.
///
/// Distances are *squared* Euclidean distances (no square root is taken). The
/// result holds one list per query, in query order. Each list contains
/// `min(k, dataset.len())` hits sorted by ascending distance, with ties broken
/// by ascending dataset index.
///
/// Coordinates are paired with `zip`, so when a query and a dataset vector
/// differ in length only their common prefix contributes to the distance.
///
/// Hits whose distance is NaN are ranked after every non-NaN distance, so a
/// corrupt vector cannot displace a real neighbour from the top `k`.
pub fn brute_force_knn(
    dataset: &[Vec<f32>],
    queries: &[Vec<f32>],
    k: usize,
) -> Vec<Vec<Neighbour>> {
    /// Squared Euclidean distance over the common prefix of `a` and `b`.
    fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    /// Total order on hits: ascending distance, NaN last, then ascending index.
    fn nearer_first(a: &Neighbour, b: &Neighbour) -> std::cmp::Ordering {
        let by_distance = match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal),
        };
        by_distance.then_with(|| a.0.cmp(&b.0))
    }

    queries
        .iter()
        .map(|query| {
            let mut hits: Vec<Neighbour> = dataset
                .iter()
                .enumerate()
                .map(|(idx, vector)| (idx, squared_l2(query, vector)))
                .collect();
            if k < hits.len() {
                // Move the k best hits to the front in linear time, then drop
                // the rest so only k elements need a full sort.
                hits.select_nth_unstable_by(k, nearer_first);
                hits.truncate(k);
            }
            // The comparator never reports two distinct hits as equal, so an
            // unstable sort is fully deterministic here.
            hits.sort_unstable_by(nearer_first);
            hits
        })
        .collect()
}
42
43pub fn recall_at_k(ground_truth: &[Vec<Neighbour>], approximate: &[Vec<Neighbour>]) -> f64 {
50 if ground_truth.is_empty() {
51 return 0.0;
52 }
53 let mut total_recall = 0.0;
54 for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
55 let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
56 let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
57 let found = gt_ids.intersection(&ap_ids).count();
58 if gt_ids.is_empty() {
59 continue;
60 }
61 total_recall += found as f64 / gt_ids.len() as f64;
62 }
63 total_recall / ground_truth.len() as f64
64}
65
66pub fn per_query_recall(
68 ground_truth: &[Vec<Neighbour>],
69 approximate: &[Vec<Neighbour>],
70) -> Vec<f64> {
71 ground_truth
72 .iter()
73 .zip(approximate.iter())
74 .map(|(gt, ap)| {
75 let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
76 let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
77 let found = gt_ids.intersection(&ap_ids).count();
78 if gt_ids.is_empty() {
79 0.0
80 } else {
81 found as f64 / gt_ids.len() as f64
82 }
83 })
84 .collect()
85}
86
87pub fn precision(ground_truth: &[Vec<Neighbour>], approximate: &[Vec<Neighbour>]) -> f64 {
91 if ground_truth.is_empty() {
92 return 0.0;
93 }
94 let mut total = 0.0;
95 for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
96 let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
97 let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
98 let found = gt_ids.intersection(&ap_ids).count();
99 if ap_ids.is_empty() {
100 continue;
101 }
102 total += found as f64 / ap_ids.len() as f64;
103 }
104 total / ground_truth.len() as f64
105}
106
107pub fn measure_qps<F>(queries: &[Vec<f32>], mut search_fn: F) -> QpsResult
113where
114 F: FnMut(&[f32]) -> Vec<Neighbour>,
115{
116 let mut latencies = Vec::with_capacity(queries.len());
117 let overall_start = Instant::now();
118
119 for q in queries {
120 let start = Instant::now();
121 let _ = search_fn(q);
122 latencies.push(start.elapsed());
123 }
124
125 let total_time = overall_start.elapsed();
126 let qps = if total_time.as_secs_f64() > 0.0 {
127 queries.len() as f64 / total_time.as_secs_f64()
128 } else {
129 0.0
130 };
131
132 latencies.sort();
133
134 QpsResult {
135 qps,
136 total_queries: queries.len(),
137 total_time,
138 latencies,
139 }
140}
141
142#[derive(Debug, Clone)]
144pub struct QpsResult {
145 pub qps: f64,
147 pub total_queries: usize,
149 pub total_time: Duration,
151 pub latencies: Vec<Duration>,
153}
154
155impl QpsResult {
156 pub fn p50(&self) -> Duration {
158 percentile_duration(&self.latencies, 50.0)
159 }
160
161 pub fn p95(&self) -> Duration {
163 percentile_duration(&self.latencies, 95.0)
164 }
165
166 pub fn p99(&self) -> Duration {
168 percentile_duration(&self.latencies, 99.0)
169 }
170
171 pub fn mean_latency(&self) -> Duration {
173 if self.latencies.is_empty() {
174 return Duration::ZERO;
175 }
176 let total: Duration = self.latencies.iter().sum();
177 total / self.latencies.len() as u32
178 }
179
180 pub fn min_latency(&self) -> Duration {
182 self.latencies.first().copied().unwrap_or(Duration::ZERO)
183 }
184
185 pub fn max_latency(&self) -> Duration {
187 self.latencies.last().copied().unwrap_or(Duration::ZERO)
188 }
189}
190
/// Picks the `pct`-th percentile from an ascending-sorted slice of durations.
///
/// The rank is `pct / 100 * (len - 1)` rounded to the nearest integer and then
/// clamped into the valid index range, so out-of-range or NaN percentages
/// resolve to the first or last element instead of panicking. No interpolation
/// is performed. Returns `Duration::ZERO` for an empty slice.
fn percentile_duration(sorted: &[Duration], pct: f64) -> Duration {
    let last = match sorted.len().checked_sub(1) {
        Some(last) => last,
        None => return Duration::ZERO,
    };
    let scaled = (pct / 100.0) * (sorted.len() as f64 - 1.0);
    // `max(0.0)` maps negative and NaN ranks to the first element.
    let rank = scaled.round().max(0.0) as usize;
    sorted[std::cmp::min(rank, last)]
}
202
/// Stopwatch that measures the wall-clock duration of a labelled phase.
///
/// Create it with [`BuildTimer::start`] and consume it with
/// [`BuildTimer::stop`] to obtain a [`BuildTimeResult`].
#[derive(Debug)]
pub struct BuildTimer {
    /// Label copied into the result when the timer is stopped.
    label: String,
    /// Instant captured when the timer was created.
    start: Instant,
}

impl BuildTimer {
    /// Starts timing immediately and remembers `label` for the result.
    pub fn start(label: impl Into<String>) -> Self {
        Self {
            label: label.into(),
            start: Instant::now(),
        }
    }

    /// Stops the timer, returning the label and the time elapsed since `start`.
    ///
    /// Takes `self` by value, so a timer can only be stopped once.
    pub fn stop(self) -> BuildTimeResult {
        BuildTimeResult {
            label: self.label,
            duration: self.start.elapsed(),
        }
    }
}

/// Outcome of a finished [`BuildTimer`].
#[derive(Debug, Clone)]
pub struct BuildTimeResult {
    /// Label given to `BuildTimer::start`.
    pub label: String,
    /// Time elapsed between `BuildTimer::start` and `BuildTimer::stop`.
    pub duration: Duration,
}
237
/// Estimates the bytes needed to store `n_vectors` raw `f32` vectors of
/// `dimension` components each: `n_vectors * dimension * size_of::<f32>()`.
///
/// The arithmetic saturates at `usize::MAX` instead of overflowing, so absurdly
/// large inputs yield "too big to represent" rather than a wrapped-around small
/// number (release builds) or a panic (debug builds).
pub fn estimate_flat_memory(n_vectors: usize, dimension: usize) -> usize {
    n_vectors
        .saturating_mul(dimension)
        .saturating_mul(std::mem::size_of::<f32>())
}
246
/// Estimates the bytes used by an HNSW-style index: raw vectors plus graph links.
///
/// * vectors: `n_vectors * dimension * size_of::<f32>()`
/// * graph: `n_vectors * m * n_levels * size_of::<usize>()`, i.e. `m` links of
///   one `usize` each for every vector on every one of the `n_levels` levels.
///
/// The arithmetic saturates at `usize::MAX` instead of overflowing.
pub fn estimate_hnsw_memory(
    n_vectors: usize,
    dimension: usize,
    m: usize,
    n_levels: usize,
) -> usize {
    let vector_bytes = n_vectors
        .saturating_mul(dimension)
        .saturating_mul(std::mem::size_of::<f32>());
    let graph_bytes = n_vectors
        .saturating_mul(m)
        .saturating_mul(n_levels)
        .saturating_mul(std::mem::size_of::<usize>());
    vector_bytes.saturating_add(graph_bytes)
}
261
/// Estimates the bytes needed to store product-quantised codes:
/// `n_vectors * n_subspaces`, i.e. one byte per subspace per vector.
///
/// Only the per-vector codes are counted; nothing is added for codebooks.
/// The multiplication saturates at `usize::MAX` instead of overflowing.
pub fn estimate_pq_memory(n_vectors: usize, n_subspaces: usize) -> usize {
    n_vectors.saturating_mul(n_subspaces)
}
267
268#[derive(Debug, Clone)]
272pub struct PrecisionRecallPoint {
273 pub recall: f64,
275 pub precision: f64,
277 pub parameter: String,
279}
280
281pub fn precision_recall_sweep<F>(
287 ground_truth: &[Vec<Neighbour>],
288 queries: &[Vec<f32>],
289 param_values: &[String],
290 mut search_with_param: F,
291) -> Vec<PrecisionRecallPoint>
292where
293 F: FnMut(&str, &[Vec<f32>]) -> Vec<Vec<Neighbour>>,
294{
295 let mut curve = Vec::with_capacity(param_values.len());
296 for param in param_values {
297 let approx = search_with_param(param, queries);
298 let r = recall_at_k(ground_truth, &approx);
299 let p = precision(ground_truth, &approx);
300 curve.push(PrecisionRecallPoint {
301 recall: r,
302 precision: p,
303 parameter: param.clone(),
304 });
305 }
306 curve
307}
308
/// Summary of one benchmark run, renderable as plain text or JSON.
#[derive(Debug, Clone)]
pub struct BenchmarkReport {
    /// Name of the index that was benchmarked.
    pub index_name: String,
    /// Number of vectors in the dataset.
    pub dataset_size: usize,
    /// Number of components per vector.
    pub dimension: usize,
    /// Number of queries that were run.
    pub n_queries: usize,
    /// Number of neighbours requested per query.
    pub k: usize,
    /// Recall at `k`.
    pub recall: f64,
    /// Precision.
    pub precision: f64,
    /// Queries per second.
    pub qps: f64,
    /// Median latency in microseconds.
    pub p50_us: u64,
    /// 95th percentile latency in microseconds.
    pub p95_us: u64,
    /// 99th percentile latency in microseconds.
    pub p99_us: u64,
    /// Memory footprint in bytes.
    pub memory_bytes: usize,
    /// Build time in milliseconds.
    pub build_time_ms: u64,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl BenchmarkReport {
    /// Renders the report as a human-readable multi-line string.
    ///
    /// Metadata entries, if any, are listed under a `Metadata:` heading in
    /// ascending key order so the output is deterministic.
    pub fn to_text(&self) -> String {
        let mut out = String::new();
        out.push_str(&format!(
            "=== ANN Benchmark Report: {} ===\n",
            self.index_name
        ));
        out.push_str(&format!(
            "Dataset: {} vectors × {} dims\n",
            self.dataset_size, self.dimension
        ));
        out.push_str(&format!("Queries: {}, k={}\n", self.n_queries, self.k));
        out.push_str(&format!("Recall@{}: {:.4}\n", self.k, self.recall));
        out.push_str(&format!("Precision: {:.4}\n", self.precision));
        out.push_str(&format!("QPS: {:.1}\n", self.qps));
        out.push_str(&format!(
            "Latency p50: {} µs, p95: {} µs, p99: {} µs\n",
            self.p50_us, self.p95_us, self.p99_us
        ));
        out.push_str(&format!(
            "Memory: {:.2} MB\n",
            self.memory_bytes as f64 / (1024.0 * 1024.0)
        ));
        out.push_str(&format!("Build time: {} ms\n", self.build_time_ms));
        if !self.metadata.is_empty() {
            out.push_str("Metadata:\n");
            for (key, value) in self.sorted_metadata() {
                out.push_str(&format!(" {key}: {value}\n"));
            }
        }
        out
    }

    /// Renders the report as a JSON object.
    ///
    /// * String values are escaped, so names containing quotes, backslashes or
    ///   control characters still produce valid JSON.
    /// * Non-finite `recall`, `precision` or `qps` values are written as `null`
    ///   because JSON has no NaN or infinity literals.
    /// * When `metadata` is non-empty it is appended as a nested `"metadata"`
    ///   object with keys in ascending order; when it is empty the key is
    ///   omitted entirely.
    pub fn to_json(&self) -> String {
        let mut out = String::from("{\n");
        out.push_str(&format!(
            " \"index_name\": \"{}\",\n",
            Self::json_escape(&self.index_name)
        ));
        out.push_str(&format!(" \"dataset_size\": {},\n", self.dataset_size));
        out.push_str(&format!(" \"dimension\": {},\n", self.dimension));
        out.push_str(&format!(" \"n_queries\": {},\n", self.n_queries));
        out.push_str(&format!(" \"k\": {},\n", self.k));
        out.push_str(&format!(
            " \"recall\": {},\n",
            Self::json_float(self.recall, 6)
        ));
        out.push_str(&format!(
            " \"precision\": {},\n",
            Self::json_float(self.precision, 6)
        ));
        out.push_str(&format!(" \"qps\": {},\n", Self::json_float(self.qps, 1)));
        out.push_str(&format!(" \"p50_us\": {},\n", self.p50_us));
        out.push_str(&format!(" \"p95_us\": {},\n", self.p95_us));
        out.push_str(&format!(" \"p99_us\": {},\n", self.p99_us));
        out.push_str(&format!(" \"memory_bytes\": {},\n", self.memory_bytes));
        if self.metadata.is_empty() {
            out.push_str(&format!(" \"build_time_ms\": {}\n", self.build_time_ms));
        } else {
            out.push_str(&format!(" \"build_time_ms\": {},\n", self.build_time_ms));
            out.push_str(" \"metadata\": {\n");
            let entries = self.sorted_metadata();
            for (position, (key, value)) in entries.iter().enumerate() {
                // JSON forbids a trailing comma after the last member.
                let separator = if position + 1 < entries.len() { "," } else { "" };
                out.push_str(&format!(
                    "  \"{}\": \"{}\"{}\n",
                    Self::json_escape(key),
                    Self::json_escape(value),
                    separator
                ));
            }
            out.push_str(" }\n");
        }
        out.push('}');
        out
    }

    /// Metadata entries sorted by key, for deterministic rendering.
    fn sorted_metadata(&self) -> Vec<(&String, &String)> {
        let mut entries: Vec<(&String, &String)> = self.metadata.iter().collect();
        entries.sort();
        entries
    }

    /// Escapes `raw` for use inside a double-quoted JSON string.
    fn json_escape(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }

    /// Formats `value` with a fixed number of decimals, or `null` if it is NaN
    /// or infinite.
    fn json_float(value: f64, decimals: usize) -> String {
        if value.is_finite() {
            format!("{:.*}", decimals, value)
        } else {
            String::from("null")
        }
    }
}
398
399pub fn average_distance_ratio(
404 ground_truth: &[Vec<Neighbour>],
405 approximate: &[Vec<Neighbour>],
406) -> f64 {
407 if ground_truth.is_empty() {
408 return 1.0;
409 }
410 let mut total = 0.0;
411 let mut count = 0usize;
412 for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
413 for (g, a) in gt.iter().zip(ap.iter()) {
414 if g.1 > 1e-12 {
415 total += a.1 as f64 / g.1 as f64;
416 count += 1;
417 }
418 }
419 }
420 if count == 0 {
421 1.0
422 } else {
423 total / count as f64
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight 2-D points with a few exact distance ties.
    fn simple_dataset() -> Vec<Vec<f32>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![5.0, 5.0],
            vec![10.0, 10.0],
        ]
    }

    /// Three queries that coincide with dataset points.
    fn simple_queries() -> Vec<Vec<f32>> {
        vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![5.0, 5.0]]
    }

    /// Exact top-`k` search for a single query, used as the "index" under test
    /// wherever a search closure is needed.
    fn exact_top_k(data: &[Vec<f32>], query: &[f32], k: usize) -> Vec<Neighbour> {
        let mut dists: Vec<(usize, f32)> = data
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let d: f32 = query
                    .iter()
                    .zip(v.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum();
                (i, d)
            })
            .collect();
        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        dists.truncate(k);
        dists
    }

    // ---- brute_force_knn ----

    #[test]
    fn test_brute_force_knn_basic() {
        let data = simple_dataset();
        let queries = vec![vec![0.0, 0.0]];
        let gt = brute_force_knn(&data, &queries, 3);
        assert_eq!(gt.len(), 1);
        assert_eq!(gt[0].len(), 3);
        assert_eq!(gt[0][0].0, 0);
        assert!((gt[0][0].1).abs() < 1e-6);
    }

    #[test]
    fn test_brute_force_knn_k_larger_than_dataset() {
        let data = vec![vec![1.0], vec![2.0]];
        let queries = vec![vec![0.0]];
        let gt = brute_force_knn(&data, &queries, 10);
        assert_eq!(gt[0].len(), 2);
    }

    #[test]
    fn test_brute_force_ordering() {
        let data = simple_dataset();
        let queries = vec![vec![0.0, 0.0]];
        let gt = brute_force_knn(&data, &queries, 4);
        for i in 1..gt[0].len() {
            assert!(gt[0][i].1 >= gt[0][i - 1].1);
        }
    }

    // ---- recall ----

    #[test]
    fn test_recall_perfect() {
        let gt = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0)]];
        let r = recall_at_k(&gt, &ap);
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Perfect recall should be 1.0, got {r}"
        );
    }

    #[test]
    fn test_recall_zero() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(5, 10.0), (6, 11.0)]];
        let r = recall_at_k(&gt, &ap);
        assert!(r.abs() < 1e-10, "No overlap → recall = 0, got {r}");
    }

    #[test]
    fn test_recall_partial() {
        let gt = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0), (3, 2.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0), (5, 5.0), (6, 6.0)]];
        let r = recall_at_k(&gt, &ap);
        assert!((r - 0.5).abs() < 1e-10, "Recall = 0.5, got {r}");
    }

    #[test]
    fn test_recall_empty() {
        let r = recall_at_k(&[], &[]);
        assert!(r.abs() < 1e-10);
    }

    #[test]
    fn test_per_query_recall() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)], vec![(2, 0.0), (3, 1.0)]];
        let ap = vec![
            vec![(0, 0.0), (1, 1.0)],
            vec![(2, 0.0), (5, 5.0)],
        ];
        let pq = per_query_recall(&gt, &ap);
        assert!((pq[0] - 1.0).abs() < 1e-10);
        assert!((pq[1] - 0.5).abs() < 1e-10);
    }

    // ---- precision ----

    #[test]
    fn test_precision_perfect() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0)]];
        let p = precision(&gt, &ap);
        assert!((p - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_precision_half() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(0, 0.0), (5, 10.0)]];
        let p = precision(&gt, &ap);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // ---- throughput and latency ----

    #[test]
    fn test_measure_qps() {
        let queries = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let data = simple_dataset();
        let result = measure_qps(&queries, |q| exact_top_k(&data, q, 3));
        assert!(result.qps > 0.0, "QPS should be positive");
        assert_eq!(result.total_queries, 2);
        assert_eq!(result.latencies.len(), 2);
    }

    #[test]
    fn test_qps_latency_percentiles() {
        let queries: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32, 0.0]).collect();
        let result = measure_qps(&queries, |_q| vec![(0, 0.0)]);
        assert!(result.p50() <= result.p95());
        assert!(result.p95() <= result.p99());
    }

    #[test]
    fn test_qps_mean_latency() {
        let queries = vec![vec![0.0], vec![1.0]];
        let result = measure_qps(&queries, |_q| vec![(0, 0.0)]);
        assert!(result.mean_latency() >= result.min_latency());
        assert!(result.mean_latency() <= result.max_latency());
    }

    // ---- build timer ----

    #[test]
    fn test_build_timer() {
        let timer = BuildTimer::start("test_build");
        let _sum: u64 = (0..1000).sum();
        let result = timer.stop();
        assert_eq!(result.label, "test_build");
        assert!(result.duration >= Duration::ZERO);
    }

    // ---- memory estimates ----

    #[test]
    fn test_estimate_flat_memory() {
        let mem = estimate_flat_memory(1000, 128);
        assert_eq!(mem, 512_000);
    }

    #[test]
    fn test_estimate_hnsw_memory() {
        let mem = estimate_hnsw_memory(1000, 128, 16, 4);
        let vector_bytes = 1000 * 128 * 4;
        // Use the real pointer width so the test also holds on 32-bit targets.
        let graph_bytes = 1000 * 16 * 4 * std::mem::size_of::<usize>();
        assert_eq!(mem, vector_bytes + graph_bytes);
    }

    #[test]
    fn test_estimate_pq_memory() {
        let mem = estimate_pq_memory(10_000, 8);
        assert_eq!(mem, 80_000);
    }

    // ---- distance ratio ----

    #[test]
    fn test_distance_ratio_perfect() {
        let gt = vec![vec![(0, 1.0), (1, 2.0)]];
        let ap = vec![vec![(0, 1.0), (1, 2.0)]];
        let ratio = average_distance_ratio(&gt, &ap);
        assert!((ratio - 1.0).abs() < 1e-6, "Perfect match → ratio = 1.0");
    }

    #[test]
    fn test_distance_ratio_worse() {
        let gt = vec![vec![(0, 1.0), (1, 2.0)]];
        let ap = vec![vec![(0, 2.0), (1, 4.0)]];
        let ratio = average_distance_ratio(&gt, &ap);
        assert!(
            (ratio - 2.0).abs() < 1e-6,
            "Double distances → ratio = 2.0, got {ratio}"
        );
    }

    #[test]
    fn test_distance_ratio_empty() {
        let ratio = average_distance_ratio(&[], &[]);
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    // ---- precision/recall sweep ----

    #[test]
    fn test_precision_recall_sweep() {
        let data = simple_dataset();
        let queries = simple_queries();
        let gt = brute_force_knn(&data, &queries, 3);
        let params = vec!["exact".to_string()];
        let curve = precision_recall_sweep(&gt, &queries, &params, |_param, qs| {
            brute_force_knn(&data, qs, 3)
        });
        assert_eq!(curve.len(), 1);
        assert!((curve[0].recall - 1.0).abs() < 1e-10);
        assert!((curve[0].precision - 1.0).abs() < 1e-10);
    }

    // ---- reports ----

    #[test]
    fn test_report_text() {
        let report = BenchmarkReport {
            index_name: "HNSW".to_string(),
            dataset_size: 10_000,
            dimension: 128,
            n_queries: 1000,
            k: 10,
            recall: 0.95,
            precision: 0.93,
            qps: 5000.0,
            p50_us: 100,
            p95_us: 250,
            p99_us: 500,
            memory_bytes: 10_000_000,
            build_time_ms: 1500,
            metadata: HashMap::new(),
        };
        let text = report.to_text();
        assert!(text.contains("HNSW"));
        assert!(text.contains("10000"));
        assert!(text.contains("0.95"));
    }

    #[test]
    fn test_report_json() {
        let report = BenchmarkReport {
            index_name: "Flat".to_string(),
            dataset_size: 5000,
            dimension: 64,
            n_queries: 500,
            k: 5,
            recall: 1.0,
            precision: 1.0,
            qps: 2000.0,
            p50_us: 200,
            p95_us: 400,
            p99_us: 800,
            memory_bytes: 1_280_000,
            build_time_ms: 0,
            metadata: HashMap::new(),
        };
        let json = report.to_json();
        assert!(json.contains("\"index_name\": \"Flat\""));
        assert!(json.contains("\"recall\": 1.0"));
    }

    #[test]
    fn test_report_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("ef_search".to_string(), "64".to_string());
        let report = BenchmarkReport {
            index_name: "HNSW".to_string(),
            dataset_size: 100,
            dimension: 16,
            n_queries: 10,
            k: 5,
            recall: 0.8,
            precision: 0.8,
            qps: 100.0,
            p50_us: 500,
            p95_us: 1000,
            p99_us: 2000,
            memory_bytes: 10_000,
            build_time_ms: 100,
            metadata: meta,
        };
        let text = report.to_text();
        assert!(text.contains("ef_search"));
        assert!(text.contains("64"));
    }

    // ---- percentile helper ----

    #[test]
    fn test_percentile_empty() {
        let p = percentile_duration(&[], 50.0);
        assert_eq!(p, Duration::ZERO);
    }

    #[test]
    fn test_percentile_single() {
        let durs = vec![Duration::from_micros(100)];
        let p = percentile_duration(&durs, 50.0);
        assert_eq!(p, Duration::from_micros(100));
    }

    #[test]
    fn test_percentile_sorted() {
        let durs: Vec<Duration> = (1..=100).map(Duration::from_micros).collect();
        let p50 = percentile_duration(&durs, 50.0);
        let p99 = percentile_duration(&durs, 99.0);
        assert!(p50 < p99);
        assert!(p50.as_micros() >= 49 && p50.as_micros() <= 51);
    }

    // ---- end to end ----

    #[test]
    fn test_end_to_end_benchmark() {
        let data = simple_dataset();
        let queries = simple_queries();
        let k = 3;

        let gt = brute_force_knn(&data, &queries, k);
        assert_eq!(gt.len(), queries.len());

        let qps_result = measure_qps(&queries, |q| exact_top_k(&data, q, k));

        let approx: Vec<Vec<Neighbour>> = queries
            .iter()
            .map(|q| exact_top_k(&data, q, k))
            .collect();

        let recall = recall_at_k(&gt, &approx);
        assert!(
            (recall - 1.0).abs() < 1e-10,
            "Exact search should give recall = 1.0"
        );

        let prec = precision(&gt, &approx);
        assert!((prec - 1.0).abs() < 1e-10);

        assert!(qps_result.qps > 0.0);
    }
}