Skip to main content

diskann_benchmark_core/
recall.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{collections::HashSet, hash::Hash};
7
8use diskann_utils::{
9    strided::StridedView,
10    views::{Matrix, MatrixView},
11};
12use thiserror::Error;
13
/// Summary statistics produced by [`knn`] for a "k-recall-at-n" computation.
///
/// The per-query value is the number of groundtruth neighbors found among the first
/// `recall_n` results, capped at `recall_k`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RecallMetrics {
    /// The `k` value for `k-recall-at-n`.
    pub recall_k: usize,
    /// The `n` value for `k-recall-at-n`.
    pub recall_n: usize,
    /// The number of queries.
    pub num_queries: usize,
    /// The average recall across all queries: the total number of matches divided by
    /// `recall_k * num_queries`.
    pub average: f64,
    /// The minimum observed per-query match count (max possible value: `recall_k`).
    pub minimum: usize,
    /// The maximum observed per-query match count (max possible value: `recall_k`).
    pub maximum: usize,
}
30
31#[derive(Debug, Error)]
32pub enum ComputeRecallError {
33    #[error("results matrix has {0} rows but ground truth has {1}")]
34    RowsMismatch(usize, usize),
35    #[error("distances matrix has {0} rows but ground truth has {1}")]
36    DistanceRowsMismatch(usize, usize),
37    #[error("recall k value {0} must be less than or equal to recall n {1}")]
38    RecallKAndNError(usize, usize),
39    #[error("number of results per query {0} must be at least the specified recall k {1}")]
40    NotEnoughResults(usize, usize),
41    #[error(
42        "number of groundtruth values per query {0} must be at least the specified recall n {1}"
43    )]
44    NotEnoughGroundTruth(usize, usize),
45    #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")]
46    NotEnoughGroundTruthDistances(usize, usize),
47}
48
/// A read-only, row-oriented view over a collection of slices, such as a dense matrix or a
/// vector of vectors.
///
/// [`knn`] and [`average_precision`] accept their inputs through this trait. It is
/// deliberately `dyn` compatible so those functions can take `&dyn Rows<T>`, which reduces
/// compilation overhead.
///
/// # Contract
///
/// * If [`Self::nrows`] reports `N`, then [`Self::row`] must return a slice for every index
///   in `0..N`. Indices outside that range may panic.
/// * [`Self::ncols`] is an optional hint. Return `Some(K)` only when every row is known to
///   hold exactly `K` elements. This lets callers reject bad input before visiting any rows.
///   Unsafe code must **not** rely on this hint being accurate.
/// * A `None` from [`Self::ncols`] promises nothing about row lengths.
pub trait Rows<T> {
    /// Number of rows (subslices) held by `self`.
    fn nrows(&self) -> usize;

    /// Borrow row `i` as a slice. May panic when `i >= self.nrows()`.
    fn row(&self, i: usize) -> &[T];

    /// The uniform row length, when one is known.
    ///
    /// # Provided Implementation
    ///
    /// Makes no promise and returns `None`.
    fn ncols(&self) -> Option<usize> {
        Option::None
    }
}
85
86impl<T> Rows<T> for Matrix<T> {
87    fn nrows(&self) -> usize {
88        Matrix::<T>::nrows(self)
89    }
90    fn row(&self, i: usize) -> &[T] {
91        Matrix::<T>::row(self, i)
92    }
93    fn ncols(&self) -> Option<usize> {
94        Some(Matrix::<T>::ncols(self))
95    }
96}
97
98impl<T> Rows<T> for MatrixView<'_, T> {
99    fn nrows(&self) -> usize {
100        MatrixView::<'_, T>::nrows(self)
101    }
102    fn row(&self, i: usize) -> &[T] {
103        MatrixView::<'_, T>::row(self, i)
104    }
105    fn ncols(&self) -> Option<usize> {
106        Some(MatrixView::<'_, T>::ncols(self))
107    }
108}
109
110impl<T> Rows<T> for Vec<Vec<T>> {
111    fn nrows(&self) -> usize {
112        self.len()
113    }
114    fn row(&self, i: usize) -> &[T] {
115        &self[i]
116    }
117}
118
/// Shorthand for the bounds an element type needs to be used with [`knn`] and
/// [`average_precision`].
///
/// Elements are stored in hash sets (`Eq + Hash`) and copied out of the input rows
/// (`Clone`); `Debug` is also required. Any type meeting these bounds implements this trait
/// automatically through the blanket implementation that follows.
pub trait RecallCompatible: Clone + Eq + Hash + std::fmt::Debug {}

impl<U: Clone + Eq + Hash + std::fmt::Debug> RecallCompatible for U {}
123
124/// Compute the K-nearest-neighbors recall value "K-recall-at-N".
125///
126/// For each entry in `groundtruth` and `results`, this computes the `recall_k` number of
127/// elements of `groundtruth` that are present in the first `recall_n` entries of `results`.
128///
129/// If `groundtruth_distances` is provided, then it will be used to allow ties when matching
130/// the last values of each entry of `results`. Values will be counted towards the recall if
131/// they have the same distance as the last ordered candidate.
132///
133/// If `allow_insufficient_results`, an error will not be given if an entry in `results`
134/// has fewer than `recall_n` candidates.
135pub fn knn<T>(
136    groundtruth: &dyn Rows<T>,
137    groundtruth_distances: Option<StridedView<'_, f32>>,
138    results: &dyn Rows<T>,
139    recall_k: usize,
140    recall_n: usize,
141    allow_insufficient_results: bool,
142) -> Result<RecallMetrics, ComputeRecallError>
143where
144    T: RecallCompatible,
145{
146    if recall_k > recall_n {
147        return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n));
148    }
149
150    let nrows = results.nrows();
151    if nrows != groundtruth.nrows() {
152        return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows()));
153    }
154
155    if let Some(cols) = results.ncols()
156        && cols < recall_n
157        && !allow_insufficient_results
158    {
159        return Err(ComputeRecallError::NotEnoughResults(cols, recall_n));
160    }
161
162    // Validate groundtruth size for fixed-size sources
163    match groundtruth.ncols() {
164        Some(ncols) if ncols < recall_k => {
165            return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k));
166        }
167        _ => {}
168    }
169
170    if let Some(distances) = groundtruth_distances {
171        if nrows != distances.nrows() {
172            return Err(ComputeRecallError::DistanceRowsMismatch(
173                distances.nrows(),
174                nrows,
175            ));
176        }
177
178        match groundtruth.ncols() {
179            Some(ncols) if distances.ncols() != ncols => {
180                return Err(ComputeRecallError::NotEnoughGroundTruthDistances(
181                    distances.ncols(),
182                    ncols,
183                ));
184            }
185            _ => {}
186        }
187    }
188
189    // The actual recall computation for fixed-size groundtruth
190    let mut recall_values: Vec<usize> = Vec::new();
191    let mut this_groundtruth = HashSet::new();
192    let mut this_results = HashSet::new();
193
194    for i in 0..results.nrows() {
195        let result = results.row(i);
196        if !allow_insufficient_results && result.len() < recall_n {
197            return Err(ComputeRecallError::NotEnoughResults(result.len(), recall_n));
198        }
199
200        let gt_row = groundtruth.row(i);
201        if gt_row.len() < recall_k {
202            return Err(ComputeRecallError::NotEnoughGroundTruth(
203                gt_row.len(),
204                recall_k,
205            ));
206        }
207
208        // Populate the groundtruth using the top-k
209        this_groundtruth.clear();
210        this_groundtruth.extend(gt_row.iter().take(recall_k).cloned());
211
212        // If we have distances, then continue to append distances as long as the distance
213        // value is constant
214        if let Some(distances) = groundtruth_distances
215            && recall_k > 0
216        {
217            let distances_row = distances.row(i);
218            if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 {
219                let last_distance = distances_row[recall_k - 1];
220                for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) {
221                    if *d == last_distance {
222                        this_groundtruth.insert(g.clone());
223                    } else {
224                        break;
225                    }
226                }
227            }
228        }
229
230        this_results.clear();
231        this_results.extend(result.iter().take(recall_n).cloned());
232
233        // Count the overlap
234        let r = this_groundtruth
235            .iter()
236            .filter(|i| this_results.contains(i))
237            .count()
238            .min(recall_k);
239
240        recall_values.push(r);
241    }
242
243    // Perform post-processing
244    let total: usize = recall_values.iter().sum();
245    let minimum = recall_values.iter().min().unwrap_or(&0);
246    let maximum = recall_values.iter().max().unwrap_or(&0);
247
248    // We explicitly check that each groundtruth row has at least `recall_k` elements.
249    let div = recall_k * nrows;
250    let average = (total as f64) / (div as f64);
251
252    Ok(RecallMetrics {
253        recall_k,
254        recall_n,
255        num_queries: nrows,
256        average,
257        minimum: *minimum,
258        maximum: *maximum,
259    })
260}
261
262#[derive(Debug, Clone)]
263#[non_exhaustive]
264pub struct AveragePrecisionMetrics {
265    /// The number of queries.
266    pub num_queries: usize,
267    /// The average precision
268    pub average_precision: f64,
269}
270
271#[derive(Debug, Error)]
272pub enum AveragePrecisionError {
273    #[error("results has {0} elements but ground truth has {1}")]
274    EntriesMismatch(usize, usize),
275}
276
277/// Compute average precision of a range search result
278pub fn average_precision<T>(
279    results: &dyn Rows<T>,
280    groundtruth: &dyn Rows<T>,
281) -> Result<AveragePrecisionMetrics, AveragePrecisionError>
282where
283    T: RecallCompatible,
284{
285    let nrows = results.nrows();
286    let groundtruth_nrows = groundtruth.nrows();
287    if nrows != groundtruth_nrows {
288        return Err(AveragePrecisionError::EntriesMismatch(
289            nrows,
290            groundtruth_nrows,
291        ));
292    }
293
294    // The actual recall computation.
295    let mut num_gt_results = 0;
296    let mut num_reported_results = 0;
297
298    let mut scratch = HashSet::new();
299    let nrows = results.nrows();
300
301    for i in 0..nrows {
302        let result = results.row(i);
303        let gt = groundtruth.row(i);
304
305        scratch.clear();
306        scratch.extend(result.iter().cloned());
307        num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count();
308        num_gt_results += gt.len();
309    }
310
311    // Perform post-processing.
312    let average_precision = (num_reported_results as f64) / (num_gt_results as f64);
313
314    Ok(AveragePrecisionMetrics {
315        average_precision,
316        num_queries: nrows,
317    })
318}
319
320///////////
321// Tests //
322///////////
323
324#[cfg(test)]
325mod tests {
326    use diskann_utils::views::{self, Matrix};
327
328    use super::*;
329
330    fn test_rows_inner(rows: &dyn Rows<usize>, ncols: Option<usize>) {
331        assert_eq!(rows.ncols(), ncols);
332        assert_eq!(rows.nrows(), 3);
333        assert_eq!(rows.row(0), &[0, 1, 2, 3]);
334        assert_eq!(rows.row(1), &[4, 5, 6, 7]);
335        assert_eq!(rows.row(2), &[8, 9, 10, 11]);
336    }
337
338    #[test]
339    fn test_rows() {
340        let mut i = 0usize;
341        let mat = Matrix::new(
342            views::Init(|| {
343                let v = i;
344                i += 1;
345                v
346            }),
347            3,
348            4,
349        );
350
351        test_rows_inner(&mat, Some(4));
352        test_rows_inner(&(mat.as_view()), Some(4));
353
354        let vecs = vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]];
355        test_rows_inner(&vecs, None);
356    }
357
358    struct ExpectedRecall {
359        recall_k: usize,
360        recall_n: usize,
361        // Recall for each component.
362        components: Vec<usize>,
363    }
364
365    impl ExpectedRecall {
366        fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
367            assert!(recall_k <= recall_n);
368            components.iter().for_each(|x| {
369                assert!(*x <= recall_k);
370            });
371            Self {
372                recall_k,
373                recall_n,
374                components,
375            }
376        }
377
378        fn compute_recall(&self) -> f64 {
379            (self.components.iter().sum::<usize>() as f64)
380                / ((self.components.len() * self.recall_k) as f64)
381        }
382    }
383
384    #[test]
385    fn test_happy_path() {
386        let groundtruth = Matrix::try_from(
387            vec![
388                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0
389                5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1
390                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2
391                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3
392            ]
393            .into(),
394            4,
395            10,
396        )
397        .unwrap();
398
399        let distances = Matrix::try_from(
400            vec![
401                0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0
402                2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1
403                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2
404                0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3
405            ]
406            .into(),
407            4,
408            10,
409        )
410        .unwrap();
411
412        // Shift row 0 by one and row 1 by two.
413        let our_results = Matrix::try_from(
414            vec![
415                100, 0, 1, 2, 5, 6, // row 0
416                100, 101, 7, 8, 9, 10, // row 1
417                0, 1, 2, 3, 4, 5, // row 2
418                0, 1, 2, 3, 4, 5, // row 3
419            ]
420            .into(),
421            4,
422            6,
423        )
424        .unwrap();
425
426        //---------//
427        // No Ties //
428        //---------//
429        let expected_no_ties = vec![
430            // Equal `k` and `n`
431            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
432            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
433            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
434            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
435            ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
436            ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
437            // Unequal `k` and `n`.
438            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
439            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
440            ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
441            ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
442        ];
443        let epsilon = 1e-6; // Define a small tolerance
444
445        for (i, expected) in expected_no_ties.iter().enumerate() {
446            assert_eq!(expected.components.len(), our_results.nrows());
447            let recall = knn(
448                &groundtruth,
449                None,
450                &our_results,
451                expected.recall_k,
452                expected.recall_n,
453                false,
454            )
455            .unwrap();
456
457            let left = recall.average;
458            let right = expected.compute_recall();
459            assert!(
460                (left - right).abs() < epsilon,
461                "left = {}, right = {} on input {}",
462                left,
463                right,
464                i
465            );
466
467            assert_eq!(recall.num_queries, our_results.nrows());
468            assert_eq!(recall.recall_k, expected.recall_k);
469            assert_eq!(recall.recall_n, expected.recall_n);
470            assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
471            assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
472        }
473
474        //-----------//
475        // With Ties //
476        //-----------//
477        let expected_with_ties = vec![
478            // Equal `k` and `n`
479            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
480            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
481            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
482            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
483            ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in
484            ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in
485            // Unequal `k` and `n`.
486            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
487            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
488            ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
489            ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
490        ];
491
492        for (i, expected) in expected_with_ties.iter().enumerate() {
493            assert_eq!(expected.components.len(), our_results.nrows());
494            let recall = knn(
495                &groundtruth,
496                Some(distances.as_view().into()),
497                &our_results,
498                expected.recall_k,
499                expected.recall_n,
500                false,
501            )
502            .unwrap();
503
504            let left = recall.average;
505            let right = expected.compute_recall();
506            assert!(
507                (left - right).abs() < epsilon,
508                "left = {}, right = {} on input {}",
509                left,
510                right,
511                i
512            );
513
514            assert_eq!(recall.num_queries, our_results.nrows());
515            assert_eq!(recall.recall_k, expected.recall_k);
516            assert_eq!(recall.recall_n, expected.recall_n);
517            assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
518            assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
519        }
520    }
521
522    #[test]
523    fn test_errors() {
524        // k greater than n
525        {
526            let groundtruth = Matrix::<u32>::new(0, 10, 10);
527            let results = Matrix::<u32>::new(0, 10, 10);
528            let err = knn(&groundtruth, None, &results, 11, 10, false).unwrap_err();
529            assert!(matches!(err, ComputeRecallError::RecallKAndNError(..)));
530        }
531
532        // Unequal rows
533        {
534            let groundtruth = Matrix::<u32>::new(0, 11, 10);
535            let results = Matrix::<u32>::new(0, 10, 10);
536            let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
537            assert!(matches!(err, ComputeRecallError::RowsMismatch(..)));
538            let err_allow_insufficient_results =
539                knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
540            assert!(matches!(
541                err_allow_insufficient_results,
542                ComputeRecallError::RowsMismatch(..)
543            ));
544        }
545
546        // Not enough results
547        {
548            let groundtruth = Matrix::<u32>::new(0, 10, 10);
549            let results = Matrix::<u32>::new(0, 10, 5);
550            let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err();
551            assert!(matches!(err, ComputeRecallError::NotEnoughResults(..)));
552            let _ = knn(&groundtruth, None, &results, 5, 10, true);
553        }
554
555        // Not enough results - dynamic
556        {
557            let groundtruth = Matrix::<u32>::new(0, 10, 10);
558            let results: Vec<_> = (0..10).map(|_| vec![0; 5]).collect();
559            let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err();
560            assert!(matches!(err, ComputeRecallError::NotEnoughResults(..)));
561            let _ = knn(&groundtruth, None, &results, 5, 10, true);
562        }
563
564        // Not enough groundtruth
565        {
566            let groundtruth = Matrix::<u32>::new(0, 10, 5);
567            let results = Matrix::<u32>::new(0, 10, 10);
568            let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
569            assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..)));
570            let err_allow_insufficient_results =
571                knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
572            assert!(matches!(
573                err_allow_insufficient_results,
574                ComputeRecallError::NotEnoughGroundTruth(..)
575            ));
576        }
577
578        // Not enough groundtruth - dynamic
579        {
580            let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect();
581            let results = Matrix::<u32>::new(0, 10, 10);
582            let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
583            assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..)));
584            let err_allow_insufficient_results =
585                knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
586            assert!(matches!(
587                err_allow_insufficient_results,
588                ComputeRecallError::NotEnoughGroundTruth(..)
589            ));
590        }
591
592        // Distance Row Mismatch
593        {
594            let groundtruth = Matrix::<u32>::new(0, 10, 10);
595            let distances = Matrix::<f32>::new(0.0, 9, 10);
596            let results = Matrix::<u32>::new(0, 10, 10);
597            let err = knn(
598                &groundtruth,
599                Some(distances.as_view().into()),
600                &results,
601                10,
602                10,
603                false,
604            )
605            .unwrap_err();
606            assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..)));
607        }
608
609        // Distance Cols Mismatch
610        {
611            let groundtruth = Matrix::<u32>::new(0, 10, 10);
612            let distances = Matrix::<f32>::new(0.0, 10, 9);
613            let results = Matrix::<u32>::new(0, 10, 10);
614            let err = knn(
615                &groundtruth,
616                Some(distances.as_view().into()),
617                &results,
618                10,
619                10,
620                false,
621            )
622            .unwrap_err();
623            assert!(matches!(
624                err,
625                ComputeRecallError::NotEnoughGroundTruthDistances(..)
626            ));
627        }
628    }
629}