// diskann_benchmark_core/search/graph/knn.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! A built-in helper for benchmarking K-nearest neighbors.
7
8use std::sync::Arc;
9
10use diskann::{
11    ANNResult,
12    graph::{self, glue},
13    provider,
14};
15use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
16use diskann_utils::{future::AsyncFriendly, views::Matrix};
17
18use crate::{
19    recall,
20    search::{self, Search, graph::Strategy},
21    utils,
22};
23
24/// A built-in helper for benchmarking the K-nearest neighbors method
25/// [`graph::DiskANNIndex::search`].
26///
27/// This is intended to be used in conjunction with [`search::search`] or
28/// [`search::search_all`] and provides some basic additional metrics for
29/// the latter. Result aggregation for [`search::search_all`] is provided
30/// by the [`Aggregator`] type.
31///
32/// The provided implementation of [`Search`] accepts [`graph::search::Knn`]
33/// and returns [`Metrics`] as additional output.
34#[derive(Debug)]
35pub struct KNN<DP, T, S>
36where
37    DP: provider::DataProvider,
38{
39    index: Arc<graph::DiskANNIndex<DP>>,
40    queries: Arc<Matrix<T>>,
41    strategy: Strategy<S>,
42}
43
44impl<DP, T, S> KNN<DP, T, S>
45where
46    DP: provider::DataProvider,
47{
48    /// Construct a new [`KNN`] searcher.
49    ///
50    /// If `strategy` is one of the container variants of [`Strategy`], its length
51    /// must match the number of rows in `queries`. If this is the case, then the
52    /// strategies will have a querywise correspondence (see [`search::SearchResults`])
53    /// with the query matrix.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the number of elements in `strategy` is not compatible with
58    /// the number of rows in `queries`.
59    pub fn new(
60        index: Arc<graph::DiskANNIndex<DP>>,
61        queries: Arc<Matrix<T>>,
62        strategy: Strategy<S>,
63    ) -> anyhow::Result<Arc<Self>> {
64        strategy.length_compatible(queries.nrows())?;
65
66        Ok(Arc::new(Self {
67            index,
68            queries,
69            strategy,
70        }))
71    }
72}
73
/// Per-query metrics emitted by a [`KNN`] search.
///
/// # Note
///
/// Non-exhaustive: additional metrics may be added in future versions.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Metrics {
    /// Distance comparisons performed while searching.
    pub comparisons: u32,
    /// Candidates expanded (graph hops) while searching.
    pub hops: u32,
}
87
88impl<DP, T, S> Search for KNN<DP, T, S>
89where
90    DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
91    S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
92    T: AsyncFriendly + Clone,
93{
94    type Id = DP::ExternalId;
95    type Parameters = graph::search::Knn;
96    type Output = Metrics;
97
98    fn num_queries(&self) -> usize {
99        self.queries.nrows()
100    }
101
102    fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
103        search::IdCount::Fixed(parameters.k_value())
104    }
105
106    async fn search<O>(
107        &self,
108        parameters: &Self::Parameters,
109        buffer: &mut O,
110        index: usize,
111    ) -> ANNResult<Self::Output>
112    where
113        O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
114    {
115        let context = DP::Context::default();
116        let knn_search = *parameters;
117        let stats = self
118            .index
119            .search(
120                knn_search,
121                self.strategy.get(index)?,
122                &context,
123                self.queries.row(index),
124                buffer,
125            )
126            .await?;
127
128        Ok(Metrics {
129            comparisons: stats.cmps,
130            hops: stats.hops,
131        })
132    }
133}
134
135/// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs
136/// returned by the provided [`Aggregator`].
137///
138/// This struct is marked as non-exhaustive to allow for future additions.
139#[derive(Debug, Clone)]
140#[non_exhaustive]
141pub struct Summary {
142    /// The [`search::Setup`] used for the batch of runs.
143    pub setup: search::Setup,
144
145    /// The [`Search::Parameters`] used for the batch of runs.
146    pub parameters: graph::search::Knn,
147
148    /// The end-to-end latency for each repetition in the batch.
149    pub end_to_end_latencies: Vec<MicroSeconds>,
150
151    /// The average latency for individual queries.
152    ///
153    /// This contains one entry per repetition in the batch.
154    pub mean_latencies: Vec<f64>,
155
156    /// The 90th percentile latency for individual queries.
157    ///
158    /// This contains one entry per repetition in the batch.
159    pub p90_latencies: Vec<MicroSeconds>,
160
161    /// The 99th percentile latency for individual queries.
162    ///
163    /// This contains one entry per repetition in the batch.
164    pub p99_latencies: Vec<MicroSeconds>,
165
166    /// The recall metrics for search.
167    ///
168    /// This implementation assumes that search is deterministic and only
169    /// uses the first repetition's results to compute recall.
170    pub recall: recall::RecallMetrics,
171
172    /// The average number of distance comparisons per query.
173    pub mean_cmps: f64,
174
175    /// The average number of neighbor hops per query.
176    pub mean_hops: f64,
177}
178
179/// A [`search::Aggregate`] for collecting the results of multiple [`KNN`] search runs.
180///
181/// In addition to collecting latencies and other metrics, this aggregator computes
182/// recall using a provided groundtruth.
183///
184/// The aggregated results are available as a [`Summary`].
185pub struct Aggregator<'a, I> {
186    groundtruth: &'a dyn crate::recall::Rows<I>,
187    recall_k: usize,
188    recall_n: usize,
189}
190
191impl<'a, I> Aggregator<'a, I> {
192    /// Construct a new [`Aggregator`] using `groundtruth` for recall computation.
193    ///
194    /// Recall will be computed as `recall_k`-NN recall over the top `recall_n` neighbors.
195    ///
196    /// This implementation allows fewer than `recall_n` neighbors to be returned
197    /// per query without error.
198    pub fn new(
199        groundtruth: &'a dyn crate::recall::Rows<I>,
200        recall_k: usize,
201        recall_n: usize,
202    ) -> Self {
203        Self {
204            groundtruth,
205            recall_k,
206            recall_n,
207        }
208    }
209}
210
211impl<I> search::Aggregate<graph::search::Knn, I, Metrics> for Aggregator<'_, I>
212where
213    I: crate::recall::RecallCompatible,
214{
215    type Output = Summary;
216
217    fn aggregate(
218        &mut self,
219        run: search::Run<graph::search::Knn>,
220        mut results: Vec<search::SearchResults<I, Metrics>>,
221    ) -> anyhow::Result<Summary> {
222        // Compute the recall using just the first result.
223        let recall = match results.first() {
224            Some(first) => crate::recall::knn(
225                self.groundtruth,
226                None,
227                first.ids().as_rows(),
228                self.recall_k,
229                self.recall_n,
230                true,
231            )?,
232            None => anyhow::bail!("Results must be non-empty"),
233        };
234
235        let mut mean_latencies = Vec::with_capacity(results.len());
236        let mut p90_latencies = Vec::with_capacity(results.len());
237        let mut p99_latencies = Vec::with_capacity(results.len());
238
239        results.iter_mut().for_each(|r| {
240            match percentiles::compute_percentiles(r.latencies_mut()) {
241                Ok(values) => {
242                    let percentiles::Percentiles { mean, p90, p99, .. } = values;
243                    mean_latencies.push(mean);
244                    p90_latencies.push(p90);
245                    p99_latencies.push(p99);
246                }
247                Err(_) => {
248                    let zero = MicroSeconds::new(0);
249                    mean_latencies.push(0.0);
250                    p90_latencies.push(zero);
251                    p99_latencies.push(zero);
252                }
253            }
254        });
255
256        Ok(Summary {
257            setup: run.setup().clone(),
258            parameters: *run.parameters(),
259            end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
260            recall,
261            mean_latencies,
262            p90_latencies,
263            p99_latencies,
264            mean_cmps: utils::average_all(
265                results
266                    .iter()
267                    .flat_map(|r| r.output().iter().map(|o| o.comparisons)),
268            ),
269            mean_hops: utils::average_all(
270                results
271                    .iter()
272                    .flat_map(|r| r.output().iter().map(|o| o.hops)),
273            ),
274        })
275    }
276}
277
278///////////
279// Tests //
280///////////
281
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    use diskann::graph::test::provider;

    #[test]
    fn test_knn() {
        const K: usize = 5;

        let index = search::graph::test_grid_provider();

        // Five axis-aligned probe points, one per matrix row.
        let probes = [
            [0.0, 0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0, 0.0],
            [0.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 4.0],
        ];
        let mut queries = Matrix::new(0.0f32, 5, index.provider().dim());
        for (i, probe) in probes.iter().enumerate() {
            queries.row_mut(i).copy_from_slice(probe);
        }
        let queries = Arc::new(queries);

        let knn = KNN::new(
            index,
            queries.clone(),
            Strategy::broadcast(provider::Strategy::new()),
        )
        .unwrap();

        // Exercise the plain `search` entry point first.
        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            knn.clone(),
            graph::search::Knn::new(K, 10, None).unwrap(),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), queries.nrows());
        let rows = results.ids().as_rows();
        assert_eq!(*rows.row(0).first().unwrap(), 0);
        for row in 0..rows.nrows() {
            assert_eq!(rows.row(row).len(), K);
        }

        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // Now exercise `search_all` with aggregation over two runs that
        // differ only in their search window.
        let parameters = [10, 15].map(|window| {
            search::Run::new(
                graph::search::Knn::new(K, window, None).unwrap(),
                setup.clone(),
            )
        });

        let all = search::search_all(knn, parameters, Aggregator::new(rows, K, K)).unwrap();

        assert_eq!(all.len(), 2);
        for summary in all {
            assert_eq!(summary.setup, setup);
            assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
            assert_eq!(summary.mean_latencies.len(), TWO.get());
            assert_eq!(summary.p90_latencies.len(), TWO.get());
            assert_eq!(summary.p99_latencies.len(), TWO.get());

            assert_ne!(summary.mean_cmps, 0.0);
            assert_ne!(summary.mean_hops, 0.0);

            assert_eq!(summary.recall.recall_k, K);
            assert_eq!(summary.recall.recall_n, K);
            assert_eq!(summary.recall.num_queries, queries.nrows());
            assert_eq!(
                summary.recall.average, 1.0,
                "we used a search as the groundtruth"
            );
        }
    }

    #[test]
    fn test_knn_error() {
        let index = search::graph::test_grid_provider();

        let queries = Arc::new(Matrix::new(0.0f32, 1, index.provider().dim()));
        let strategy = provider::Strategy::new();

        // Two strategies against a single query row must be rejected.
        let err = KNN::new(
            index,
            queries.clone(),
            Strategy::collection([strategy.clone(), strategy.clone()]),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("2 strategies were provided when 1 was expected"),
            "failed with {msg}"
        );
    }
}