Skip to main content

oxirs_vec/
ann_benchmark.rs

1//! ANN (Approximate Nearest Neighbour) recall and latency benchmarking.
2//!
3//! Provides utilities for evaluating the quality and performance of ANN indices:
4//! recall@k, QPS, build time, memory usage, ground-truth generation,
5//! precision-recall tradeoff, latency percentiles, and report generation.
6
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
10// ── Ground truth ────────────────────────────────────────────────────────────
11
/// A single query result: `(vector_id, distance)`.
///
/// When produced by [`brute_force_knn`], the id is the vector's position in
/// the dataset slice and the distance is the squared Euclidean distance to
/// the query.
pub type Neighbour = (usize, f32);
14
15/// Brute-force computation of exact k nearest neighbours for a set of queries.
16///
17/// `dataset` is the indexed corpus; `queries` are the query vectors;
18/// `k` is the number of neighbours to return for each query.
19/// Distance is squared Euclidean.
20pub fn brute_force_knn(
21    dataset: &[Vec<f32>],
22    queries: &[Vec<f32>],
23    k: usize,
24) -> Vec<Vec<Neighbour>> {
25    queries
26        .iter()
27        .map(|q| {
28            let mut dists: Vec<(usize, f32)> = dataset
29                .iter()
30                .enumerate()
31                .map(|(i, v)| {
32                    let d: f32 = q.iter().zip(v.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
33                    (i, d)
34                })
35                .collect();
36            dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
37            dists.truncate(k);
38            dists
39        })
40        .collect()
41}
42
43// ── Recall@k ────────────────────────────────────────────────────────────────
44
45/// Compute recall@k: the fraction of true k-NN that appear in the
46/// approximate result set.
47///
48/// `ground_truth` and `approximate` must be parallel slices (one entry per query).
49pub fn recall_at_k(ground_truth: &[Vec<Neighbour>], approximate: &[Vec<Neighbour>]) -> f64 {
50    if ground_truth.is_empty() {
51        return 0.0;
52    }
53    let mut total_recall = 0.0;
54    for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
55        let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
56        let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
57        let found = gt_ids.intersection(&ap_ids).count();
58        if gt_ids.is_empty() {
59            continue;
60        }
61        total_recall += found as f64 / gt_ids.len() as f64;
62    }
63    total_recall / ground_truth.len() as f64
64}
65
66/// Compute recall@k for individual queries (not averaged).
67pub fn per_query_recall(
68    ground_truth: &[Vec<Neighbour>],
69    approximate: &[Vec<Neighbour>],
70) -> Vec<f64> {
71    ground_truth
72        .iter()
73        .zip(approximate.iter())
74        .map(|(gt, ap)| {
75            let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
76            let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
77            let found = gt_ids.intersection(&ap_ids).count();
78            if gt_ids.is_empty() {
79                0.0
80            } else {
81                found as f64 / gt_ids.len() as f64
82            }
83        })
84        .collect()
85}
86
87// ── Precision ───────────────────────────────────────────────────────────────
88
89/// Compute precision: the fraction of returned results that are true neighbours.
90pub fn precision(ground_truth: &[Vec<Neighbour>], approximate: &[Vec<Neighbour>]) -> f64 {
91    if ground_truth.is_empty() {
92        return 0.0;
93    }
94    let mut total = 0.0;
95    for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
96        let gt_ids: std::collections::HashSet<usize> = gt.iter().map(|n| n.0).collect();
97        let ap_ids: std::collections::HashSet<usize> = ap.iter().map(|n| n.0).collect();
98        let found = gt_ids.intersection(&ap_ids).count();
99        if ap_ids.is_empty() {
100            continue;
101        }
102        total += found as f64 / ap_ids.len() as f64;
103    }
104    total / ground_truth.len() as f64
105}
106
107// ── QPS measurement ─────────────────────────────────────────────────────────
108
109/// Measure queries per second for a given search function.
110///
111/// `search_fn` takes a query vector and returns the approximate result.
112pub fn measure_qps<F>(queries: &[Vec<f32>], mut search_fn: F) -> QpsResult
113where
114    F: FnMut(&[f32]) -> Vec<Neighbour>,
115{
116    let mut latencies = Vec::with_capacity(queries.len());
117    let overall_start = Instant::now();
118
119    for q in queries {
120        let start = Instant::now();
121        let _ = search_fn(q);
122        latencies.push(start.elapsed());
123    }
124
125    let total_time = overall_start.elapsed();
126    let qps = if total_time.as_secs_f64() > 0.0 {
127        queries.len() as f64 / total_time.as_secs_f64()
128    } else {
129        0.0
130    };
131
132    latencies.sort();
133
134    QpsResult {
135        qps,
136        total_queries: queries.len(),
137        total_time,
138        latencies,
139    }
140}
141
142/// Result of a QPS measurement.
143#[derive(Debug, Clone)]
144pub struct QpsResult {
145    /// Queries per second.
146    pub qps: f64,
147    /// Total number of queries executed.
148    pub total_queries: usize,
149    /// Wall-clock time for all queries.
150    pub total_time: Duration,
151    /// Per-query latencies (sorted ascending).
152    pub latencies: Vec<Duration>,
153}
154
155impl QpsResult {
156    /// Median latency.
157    pub fn p50(&self) -> Duration {
158        percentile_duration(&self.latencies, 50.0)
159    }
160
161    /// 95th-percentile latency.
162    pub fn p95(&self) -> Duration {
163        percentile_duration(&self.latencies, 95.0)
164    }
165
166    /// 99th-percentile latency.
167    pub fn p99(&self) -> Duration {
168        percentile_duration(&self.latencies, 99.0)
169    }
170
171    /// Mean latency.
172    pub fn mean_latency(&self) -> Duration {
173        if self.latencies.is_empty() {
174            return Duration::ZERO;
175        }
176        let total: Duration = self.latencies.iter().sum();
177        total / self.latencies.len() as u32
178    }
179
180    /// Minimum latency.
181    pub fn min_latency(&self) -> Duration {
182        self.latencies.first().copied().unwrap_or(Duration::ZERO)
183    }
184
185    /// Maximum latency.
186    pub fn max_latency(&self) -> Duration {
187        self.latencies.last().copied().unwrap_or(Duration::ZERO)
188    }
189}
190
/// Helper: pick the `pct`-th percentile from a duration list sorted ascending.
///
/// The percentile is mapped onto the index range `0..=len - 1` and rounded to
/// the nearest index. Percentiles outside `0..=100` resolve to the first or
/// last element. An empty list yields [`Duration::ZERO`].
fn percentile_duration(sorted: &[Duration], pct: f64) -> Duration {
    match sorted.len() {
        0 => Duration::ZERO,
        len => {
            let fraction = pct / 100.0;
            let rank = (fraction * (len as f64 - 1.0)).round().max(0.0) as usize;
            sorted[rank.min(len - 1)]
        }
    }
}
202
203// ── Build time tracking ─────────────────────────────────────────────────────
204
205/// Track how long an index build takes.
206pub struct BuildTimer {
207    label: String,
208    start: Instant,
209}
210
211impl BuildTimer {
212    /// Start a new build timer with a label.
213    pub fn start(label: impl Into<String>) -> Self {
214        Self {
215            label: label.into(),
216            start: Instant::now(),
217        }
218    }
219
220    /// Stop the timer and return the result.
221    pub fn stop(self) -> BuildTimeResult {
222        BuildTimeResult {
223            label: self.label,
224            duration: self.start.elapsed(),
225        }
226    }
227}
228
/// Result of a build-time measurement, as returned by `BuildTimer::stop`.
#[derive(Debug, Clone)]
pub struct BuildTimeResult {
    /// Label for the build operation (the one given to `BuildTimer::start`).
    pub label: String,
    /// Elapsed time between starting and stopping the timer.
    pub duration: Duration,
}
237
238// ── Memory estimation ───────────────────────────────────────────────────────
239
/// Estimate the memory footprint of a flat vector index (vectors only).
///
/// Returns `n_vectors * dimension * size_of::<f32>()` in bytes. The product
/// saturates at `usize::MAX` instead of overflowing, so absurdly large inputs
/// neither panic (debug builds) nor wrap to a small number (release builds).
pub fn estimate_flat_memory(n_vectors: usize, dimension: usize) -> usize {
    n_vectors
        .saturating_mul(dimension)
        .saturating_mul(std::mem::size_of::<f32>())
}
246
/// Estimate memory for an HNSW-like graph index.
///
/// Accounts for vectors + adjacency lists:
///
/// * vectors: `n_vectors * dimension * size_of::<f32>()`
/// * graph: `n_vectors * m * n_levels * size_of::<usize>()`, i.e. every node
///   is charged `m` neighbour ids (stored as `usize`) on each of `n_levels`
///   levels.
///
/// `m` is the maximum number of connections per layer and `n_levels` the
/// number of HNSW levels. Returns bytes. All arithmetic saturates at
/// `usize::MAX` instead of overflowing.
pub fn estimate_hnsw_memory(
    n_vectors: usize,
    dimension: usize,
    m: usize,
    n_levels: usize,
) -> usize {
    let vector_bytes = n_vectors
        .saturating_mul(dimension)
        .saturating_mul(std::mem::size_of::<f32>());
    let graph_bytes = n_vectors
        .saturating_mul(m)
        .saturating_mul(n_levels)
        .saturating_mul(std::mem::size_of::<usize>());
    vector_bytes.saturating_add(graph_bytes)
}
261
/// Estimate memory for a product-quantised (PQ) index.
///
/// Each vector is charged `n_subspaces` bytes (one code byte per subspace),
/// so the result is `n_vectors * n_subspaces` bytes. Nothing else is
/// included in the estimate. The product saturates at `usize::MAX` instead of
/// overflowing.
pub fn estimate_pq_memory(n_vectors: usize, n_subspaces: usize) -> usize {
    n_vectors.saturating_mul(n_subspaces)
}
267
268// ── Precision-recall tradeoff ───────────────────────────────────────────────
269
/// A data point on the precision-recall curve.
///
/// `precision_recall_sweep` produces one point per parameter value it is given.
#[derive(Debug, Clone)]
pub struct PrecisionRecallPoint {
    /// Recall value [0, 1].
    pub recall: f64,
    /// Precision value [0, 1].
    pub precision: f64,
    /// The parameter setting that produced this point.
    pub parameter: String,
}
280
281/// Run a sweep over a set of parameter values and collect recall/precision
282/// at each setting.
283///
284/// `search_with_param` receives a parameter value and returns the
285/// approximate results for all queries.
286pub fn precision_recall_sweep<F>(
287    ground_truth: &[Vec<Neighbour>],
288    queries: &[Vec<f32>],
289    param_values: &[String],
290    mut search_with_param: F,
291) -> Vec<PrecisionRecallPoint>
292where
293    F: FnMut(&str, &[Vec<f32>]) -> Vec<Vec<Neighbour>>,
294{
295    let mut curve = Vec::with_capacity(param_values.len());
296    for param in param_values {
297        let approx = search_with_param(param, queries);
298        let r = recall_at_k(ground_truth, &approx);
299        let p = precision(ground_truth, &approx);
300        curve.push(PrecisionRecallPoint {
301            recall: r,
302            precision: p,
303            parameter: param.clone(),
304        });
305    }
306    curve
307}
308
309// ── Benchmark report ────────────────────────────────────────────────────────
310
/// A complete benchmark report.
#[derive(Debug, Clone)]
pub struct BenchmarkReport {
    /// Name of the index / algorithm.
    pub index_name: String,
    /// Number of vectors in the dataset.
    pub dataset_size: usize,
    /// Dimensionality.
    pub dimension: usize,
    /// Number of queries.
    pub n_queries: usize,
    /// k used for recall measurement.
    pub k: usize,
    /// Overall recall@k.
    pub recall: f64,
    /// Overall precision.
    pub precision: f64,
    /// QPS result.
    pub qps: f64,
    /// P50 latency in microseconds.
    pub p50_us: u64,
    /// P95 latency in microseconds.
    pub p95_us: u64,
    /// P99 latency in microseconds.
    pub p99_us: u64,
    /// Estimated memory in bytes.
    pub memory_bytes: usize,
    /// Build time in milliseconds.
    pub build_time_ms: u64,
    /// Extra key-value metadata.
    pub metadata: HashMap<String, String>,
}

impl BenchmarkReport {
    /// Format the report as a human-readable string.
    ///
    /// Metadata entries, if any, are listed in ascending key order so the
    /// output is identical from run to run.
    pub fn to_text(&self) -> String {
        let mut out = String::new();
        out.push_str(&format!(
            "=== ANN Benchmark Report: {} ===\n",
            self.index_name
        ));
        out.push_str(&format!(
            "Dataset: {} vectors × {} dims\n",
            self.dataset_size, self.dimension
        ));
        out.push_str(&format!("Queries: {}, k={}\n", self.n_queries, self.k));
        out.push_str(&format!("Recall@{}: {:.4}\n", self.k, self.recall));
        out.push_str(&format!("Precision: {:.4}\n", self.precision));
        out.push_str(&format!("QPS: {:.1}\n", self.qps));
        out.push_str(&format!(
            "Latency p50: {} µs, p95: {} µs, p99: {} µs\n",
            self.p50_us, self.p95_us, self.p99_us
        ));
        out.push_str(&format!(
            "Memory: {:.2} MB\n",
            self.memory_bytes as f64 / (1024.0 * 1024.0)
        ));
        out.push_str(&format!("Build time: {} ms\n", self.build_time_ms));
        if !self.metadata.is_empty() {
            out.push_str("Metadata:\n");
            for (k, v) in self.sorted_metadata() {
                out.push_str(&format!("  {k}: {v}\n"));
            }
        }
        out
    }

    /// Format the report as a JSON string.
    ///
    /// String values are escaped, and non-finite floats (NaN, ±infinity) are
    /// written as `null`, so the output is always valid JSON. When `metadata`
    /// is non-empty it is appended as a trailing `"metadata"` object with keys
    /// in ascending order; with empty metadata that key is omitted.
    pub fn to_json(&self) -> String {
        let mut out = String::from("{\n");
        out.push_str(&format!(
            "  \"index_name\": \"{}\",\n",
            Self::json_escape(&self.index_name)
        ));
        out.push_str(&format!("  \"dataset_size\": {},\n", self.dataset_size));
        out.push_str(&format!("  \"dimension\": {},\n", self.dimension));
        out.push_str(&format!("  \"n_queries\": {},\n", self.n_queries));
        out.push_str(&format!("  \"k\": {},\n", self.k));
        out.push_str(&format!(
            "  \"recall\": {},\n",
            Self::json_number(self.recall, 6)
        ));
        out.push_str(&format!(
            "  \"precision\": {},\n",
            Self::json_number(self.precision, 6)
        ));
        out.push_str(&format!("  \"qps\": {},\n", Self::json_number(self.qps, 1)));
        out.push_str(&format!("  \"p50_us\": {},\n", self.p50_us));
        out.push_str(&format!("  \"p95_us\": {},\n", self.p95_us));
        out.push_str(&format!("  \"p99_us\": {},\n", self.p99_us));
        out.push_str(&format!("  \"memory_bytes\": {},\n", self.memory_bytes));
        out.push_str(&format!("  \"build_time_ms\": {}", self.build_time_ms));
        if !self.metadata.is_empty() {
            let entries: Vec<String> = self
                .sorted_metadata()
                .into_iter()
                .map(|(k, v)| {
                    format!(
                        "    \"{}\": \"{}\"",
                        Self::json_escape(k),
                        Self::json_escape(v)
                    )
                })
                .collect();
            out.push_str(",\n  \"metadata\": {\n");
            out.push_str(&entries.join(",\n"));
            out.push_str("\n  }");
        }
        out.push_str("\n}");
        out
    }

    /// Metadata entries ordered by key, for reproducible output.
    fn sorted_metadata(&self) -> Vec<(&String, &String)> {
        let mut entries: Vec<(&String, &String)> = self.metadata.iter().collect();
        entries.sort_unstable();
        entries
    }

    /// Render a float with `decimals` fractional digits, or `null` when the
    /// value has no JSON representation (NaN or ±infinity).
    fn json_number(value: f64, decimals: usize) -> String {
        if value.is_finite() {
            format!("{:.*}", decimals, value)
        } else {
            String::from("null")
        }
    }

    /// Escape `raw` so it can be placed between double quotes in JSON.
    fn json_escape(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if c.is_control() => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }
}
398
399// ── Distance ratio (approximation quality) ──────────────────────────────────
400
401/// Compute the average distance ratio: sum of approx_dist / true_dist for
402/// each query's k-th neighbour.  A perfect index has ratio = 1.0.
403pub fn average_distance_ratio(
404    ground_truth: &[Vec<Neighbour>],
405    approximate: &[Vec<Neighbour>],
406) -> f64 {
407    if ground_truth.is_empty() {
408        return 1.0;
409    }
410    let mut total = 0.0;
411    let mut count = 0usize;
412    for (gt, ap) in ground_truth.iter().zip(approximate.iter()) {
413        for (g, a) in gt.iter().zip(ap.iter()) {
414            if g.1 > 1e-12 {
415                total += a.1 as f64 / g.1 as f64;
416                count += 1;
417            }
418        }
419    }
420    if count == 0 {
421        1.0
422    } else {
423        total / count as f64
424    }
425}
426
427// ═══════════════════════════════════════════════════════════════════════════════
428// Tests
429// ═══════════════════════════════════════════════════════════════════════════════
430
#[cfg(test)]
mod tests {
    //! Unit tests for the ANN benchmarking utilities, plus one end-to-end run
    //! that chains ground truth, search timing, recall and precision.

    use super::*;

    /// Eight 2-D points; the first four form the unit square.
    fn simple_dataset() -> Vec<Vec<f32>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![5.0, 5.0],
            vec![10.0, 10.0],
        ]
    }

    /// Three queries that coincide with points of [`simple_dataset`].
    fn simple_queries() -> Vec<Vec<f32>> {
        vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![5.0, 5.0]]
    }

    /// Exact linear scan (squared Euclidean), used as a stand-in for an
    /// approximate search wherever a test needs a search closure.
    fn linear_scan(data: &[Vec<f32>], q: &[f32], k: usize) -> Vec<Neighbour> {
        let mut dists: Vec<(usize, f32)> = data
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let d: f32 = q.iter().zip(v.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
                (i, d)
            })
            .collect();
        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        dists.truncate(k);
        dists
    }

    // ── Brute-force ground truth ────────────────────────────────────────────

    #[test]
    fn test_brute_force_knn_basic() {
        let data = simple_dataset();
        let queries = vec![vec![0.0, 0.0]];
        let gt = brute_force_knn(&data, &queries, 3);
        assert_eq!(gt.len(), 1);
        assert_eq!(gt[0].len(), 3);
        // Nearest to (0,0) should be index 0 (itself) with distance 0
        assert_eq!(gt[0][0].0, 0);
        assert!((gt[0][0].1).abs() < 1e-6);
    }

    #[test]
    fn test_brute_force_knn_k_larger_than_dataset() {
        let data = vec![vec![1.0], vec![2.0]];
        let queries = vec![vec![0.0]];
        let gt = brute_force_knn(&data, &queries, 10);
        // Should return at most dataset.len() neighbours
        assert_eq!(gt[0].len(), 2);
    }

    #[test]
    fn test_brute_force_ordering() {
        let data = simple_dataset();
        let queries = vec![vec![0.0, 0.0]];
        let gt = brute_force_knn(&data, &queries, 4);
        // Distances should be non-decreasing
        for i in 1..gt[0].len() {
            assert!(gt[0][i].1 >= gt[0][i - 1].1);
        }
    }

    // ── Recall@k ────────────────────────────────────────────────────────────

    #[test]
    fn test_recall_perfect() {
        let gt = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0)]];
        let r = recall_at_k(&gt, &ap);
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Perfect recall should be 1.0, got {r}"
        );
    }

    #[test]
    fn test_recall_zero() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(5, 10.0), (6, 11.0)]];
        let r = recall_at_k(&gt, &ap);
        assert!(r.abs() < 1e-10, "No overlap → recall = 0, got {r}");
    }

    #[test]
    fn test_recall_partial() {
        let gt = vec![vec![(0, 0.0), (1, 1.0), (2, 1.0), (3, 2.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0), (5, 5.0), (6, 6.0)]];
        let r = recall_at_k(&gt, &ap);
        // 2 out of 4 true neighbours found
        assert!((r - 0.5).abs() < 1e-10, "Recall = 0.5, got {r}");
    }

    #[test]
    fn test_recall_empty() {
        let r = recall_at_k(&[], &[]);
        assert!(r.abs() < 1e-10);
    }

    #[test]
    fn test_per_query_recall() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)], vec![(2, 0.0), (3, 1.0)]];
        let ap = vec![
            vec![(0, 0.0), (1, 1.0)], // perfect
            vec![(2, 0.0), (5, 5.0)], // 1 of 2
        ];
        let pq = per_query_recall(&gt, &ap);
        assert!((pq[0] - 1.0).abs() < 1e-10);
        assert!((pq[1] - 0.5).abs() < 1e-10);
    }

    // ── Precision ───────────────────────────────────────────────────────────

    #[test]
    fn test_precision_perfect() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(0, 0.0), (1, 1.0)]];
        let p = precision(&gt, &ap);
        assert!((p - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_precision_half() {
        let gt = vec![vec![(0, 0.0), (1, 1.0)]];
        let ap = vec![vec![(0, 0.0), (5, 10.0)]]; // 1 true of 2 returned
        let p = precision(&gt, &ap);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // ── QPS measurement ─────────────────────────────────────────────────────

    #[test]
    fn test_measure_qps() {
        let queries = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let data = simple_dataset();
        let result = measure_qps(&queries, |q| linear_scan(&data, q, 3));
        assert!(result.qps > 0.0, "QPS should be positive");
        assert_eq!(result.total_queries, 2);
        assert_eq!(result.latencies.len(), 2);
    }

    #[test]
    fn test_qps_latency_percentiles() {
        let queries: Vec<Vec<f32>> = (0..100).map(|i| vec![i as f32, 0.0]).collect();
        let result = measure_qps(&queries, |_q| vec![(0, 0.0)]);
        // p50 <= p95 <= p99 (monotonicity)
        assert!(result.p50() <= result.p95());
        assert!(result.p95() <= result.p99());
    }

    #[test]
    fn test_qps_mean_latency() {
        let queries = vec![vec![0.0], vec![1.0]];
        let result = measure_qps(&queries, |_q| vec![(0, 0.0)]);
        assert!(result.mean_latency() >= result.min_latency());
        assert!(result.mean_latency() <= result.max_latency());
    }

    // ── Build timer ─────────────────────────────────────────────────────────

    #[test]
    fn test_build_timer() {
        let timer = BuildTimer::start("test_build");
        // Do a tiny amount of work
        let _sum: u64 = (0..1000).sum();
        let result = timer.stop();
        assert_eq!(result.label, "test_build");
        assert!(result.duration >= Duration::ZERO);
    }

    // ── Memory estimation ───────────────────────────────────────────────────

    #[test]
    fn test_estimate_flat_memory() {
        let mem = estimate_flat_memory(1000, 128);
        // 1000 * 128 * 4 = 512,000 bytes
        assert_eq!(mem, 512_000);
    }

    #[test]
    fn test_estimate_hnsw_memory() {
        let mem = estimate_hnsw_memory(1000, 128, 16, 4);
        let vector_bytes = 1000 * 128 * 4;
        // Neighbour ids are `usize`; ask for its width instead of assuming a
        // 64-bit target so the test also holds on 32-bit platforms.
        let graph_bytes = 1000 * 16 * 4 * std::mem::size_of::<usize>();
        assert_eq!(mem, vector_bytes + graph_bytes);
    }

    #[test]
    fn test_estimate_pq_memory() {
        let mem = estimate_pq_memory(10_000, 8);
        assert_eq!(mem, 80_000);
    }

    // ── Distance ratio ──────────────────────────────────────────────────────

    #[test]
    fn test_distance_ratio_perfect() {
        let gt = vec![vec![(0, 1.0), (1, 2.0)]];
        let ap = vec![vec![(0, 1.0), (1, 2.0)]];
        let ratio = average_distance_ratio(&gt, &ap);
        assert!((ratio - 1.0).abs() < 1e-6, "Perfect match → ratio = 1.0");
    }

    #[test]
    fn test_distance_ratio_worse() {
        let gt = vec![vec![(0, 1.0), (1, 2.0)]];
        let ap = vec![vec![(0, 2.0), (1, 4.0)]]; // double the true distances
        let ratio = average_distance_ratio(&gt, &ap);
        assert!(
            (ratio - 2.0).abs() < 1e-6,
            "Double distances → ratio = 2.0, got {ratio}"
        );
    }

    #[test]
    fn test_distance_ratio_empty() {
        let ratio = average_distance_ratio(&[], &[]);
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    // ── Precision-recall sweep ──────────────────────────────────────────────

    #[test]
    fn test_precision_recall_sweep() {
        let data = simple_dataset();
        let queries = simple_queries();
        let gt = brute_force_knn(&data, &queries, 3);
        let params = vec!["exact".to_string()];
        let curve = precision_recall_sweep(&gt, &queries, &params, |_param, qs| {
            brute_force_knn(&data, qs, 3) // exact → perfect recall
        });
        assert_eq!(curve.len(), 1);
        assert!((curve[0].recall - 1.0).abs() < 1e-10);
        assert!((curve[0].precision - 1.0).abs() < 1e-10);
    }

    // ── Benchmark report ────────────────────────────────────────────────────

    #[test]
    fn test_report_text() {
        let report = BenchmarkReport {
            index_name: "HNSW".to_string(),
            dataset_size: 10_000,
            dimension: 128,
            n_queries: 1000,
            k: 10,
            recall: 0.95,
            precision: 0.93,
            qps: 5000.0,
            p50_us: 100,
            p95_us: 250,
            p99_us: 500,
            memory_bytes: 10_000_000,
            build_time_ms: 1500,
            metadata: HashMap::new(),
        };
        let text = report.to_text();
        assert!(text.contains("HNSW"));
        assert!(text.contains("10000"));
        assert!(text.contains("0.95"));
    }

    #[test]
    fn test_report_json() {
        let report = BenchmarkReport {
            index_name: "Flat".to_string(),
            dataset_size: 5000,
            dimension: 64,
            n_queries: 500,
            k: 5,
            recall: 1.0,
            precision: 1.0,
            qps: 2000.0,
            p50_us: 200,
            p95_us: 400,
            p99_us: 800,
            memory_bytes: 1_280_000,
            build_time_ms: 0,
            metadata: HashMap::new(),
        };
        let json = report.to_json();
        assert!(json.contains("\"index_name\": \"Flat\""));
        assert!(json.contains("\"recall\": 1.0"));
    }

    #[test]
    fn test_report_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("ef_search".to_string(), "64".to_string());
        let report = BenchmarkReport {
            index_name: "HNSW".to_string(),
            dataset_size: 100,
            dimension: 16,
            n_queries: 10,
            k: 5,
            recall: 0.8,
            precision: 0.8,
            qps: 100.0,
            p50_us: 500,
            p95_us: 1000,
            p99_us: 2000,
            memory_bytes: 10_000,
            build_time_ms: 100,
            metadata: meta,
        };
        let text = report.to_text();
        assert!(text.contains("ef_search"));
        assert!(text.contains("64"));
    }

    // ── Percentile helper ───────────────────────────────────────────────────

    #[test]
    fn test_percentile_empty() {
        let p = percentile_duration(&[], 50.0);
        assert_eq!(p, Duration::ZERO);
    }

    #[test]
    fn test_percentile_single() {
        let durs = vec![Duration::from_micros(100)];
        let p = percentile_duration(&durs, 50.0);
        assert_eq!(p, Duration::from_micros(100));
    }

    #[test]
    fn test_percentile_sorted() {
        let durs: Vec<Duration> = (1..=100).map(Duration::from_micros).collect();
        let p50 = percentile_duration(&durs, 50.0);
        let p99 = percentile_duration(&durs, 99.0);
        assert!(p50 < p99);
        // p50 should be around 50 µs
        assert!(p50.as_micros() >= 49 && p50.as_micros() <= 51);
    }

    // ── Integration: end-to-end benchmark ───────────────────────────────────

    #[test]
    fn test_end_to_end_benchmark() {
        let data = simple_dataset();
        let queries = simple_queries();
        let k = 3;

        // Ground truth
        let gt = brute_force_knn(&data, &queries, k);
        assert_eq!(gt.len(), queries.len());

        // "Approximate" search (same as exact for this test)
        let qps_result = measure_qps(&queries, |q| linear_scan(&data, q, k));

        // Collect approximate results
        let approx: Vec<Vec<Neighbour>> = queries
            .iter()
            .map(|q| linear_scan(&data, q, k))
            .collect();

        let recall = recall_at_k(&gt, &approx);
        assert!(
            (recall - 1.0).abs() < 1e-10,
            "Exact search should give recall = 1.0"
        );

        let prec = precision(&gt, &approx);
        assert!((prec - 1.0).abs() < 1e-10);

        assert!(qps_result.qps > 0.0);
    }
}