Skip to main content

diskann_benchmark_core/search/graph/
range.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{num::NonZeroUsize, sync::Arc};
7
8use diskann::{
9    ANNResult,
10    graph::{self, glue},
11    provider,
12};
13use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
14use diskann_utils::{future::AsyncFriendly, views::Matrix};
15
16use crate::{
17    recall,
18    search::{self, Search, graph::Strategy},
19};
20
21/// A built-in helper for benchmarking the range search method
22/// [`graph::DiskANNIndex::range_search`].
23///
24/// This is intended to be used in conjunction with [`search::search`] or
25/// [`search::search_all`] and provides some basic additional metrics for
26/// the latter. Result aggregation for [`search::search_all`] is provided
27/// by the [`Aggregator`] type.
28///
29/// The provided implementation of [`Search`] accepts
30/// [`graph::search::Range`] and returns [`Metrics`] as additional output.
31#[derive(Debug)]
32pub struct Range<DP, T, S>
33where
34    DP: provider::DataProvider,
35{
36    index: Arc<graph::DiskANNIndex<DP>>,
37    queries: Arc<Matrix<T>>,
38    strategy: Strategy<S>,
39}
40
41impl<DP, T, S> Range<DP, T, S>
42where
43    DP: provider::DataProvider,
44{
45    /// Construct a new [`Range`] searcher.
46    ///
47    /// If `strategy` is one of the container variants of [`Strategy`], its length
48    /// must match the number of rows in `queries`. If this is the case, then the
49    /// strategies will have a querywise correspondence (see [`search::SearchResults`])
50    /// with the query matrix.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the number of elements in `strategy` is not compatible with
55    /// the number of rows in `queries`.
56    pub fn new(
57        index: Arc<graph::DiskANNIndex<DP>>,
58        queries: Arc<Matrix<T>>,
59        strategy: Strategy<S>,
60    ) -> anyhow::Result<Arc<Self>> {
61        strategy.length_compatible(queries.nrows())?;
62
63        Ok(Arc::new(Self {
64            index,
65            queries,
66            strategy,
67        }))
68    }
69}
70
/// Per-query auxiliary output produced by [`Range`] searches.
///
/// There is nothing to record yet, so the struct has no fields. It is declared
/// `#[non_exhaustive]` so that fields can be added later without a breaking change
/// for downstream crates.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub struct Metrics {}
78
79impl<DP, T, S> Search for Range<DP, T, S>
80where
81    DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
82    S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
83    T: AsyncFriendly + Clone,
84{
85    type Id = DP::ExternalId;
86    type Parameters = graph::search::Range;
87    type Output = Metrics;
88
89    fn num_queries(&self) -> usize {
90        self.queries.nrows()
91    }
92
93    fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
94        search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l()))
95    }
96
97    async fn search<O>(
98        &self,
99        parameters: &Self::Parameters,
100        buffer: &mut O,
101        index: usize,
102    ) -> ANNResult<Self::Output>
103    where
104        O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
105    {
106        let context = DP::Context::default();
107        let range_search = *parameters;
108        let _ = self
109            .index
110            .search(
111                range_search,
112                self.strategy.get(index)?,
113                &context,
114                self.queries.row(index),
115                buffer,
116            )
117            .await?;
118
119        Ok(Metrics {})
120    }
121}
122
123/// An [`search::Aggregate`]d summary of multiple [`Range`] search runs
124/// returned by the provided [`Aggregator`].
125///
126/// This struct is marked as non-exhaustive to allow for future additions.
127#[derive(Debug, Clone)]
128#[non_exhaustive]
129pub struct Summary {
130    /// The [`search::Setup`] used for the batch of runs.
131    pub setup: search::Setup,
132
133    /// The [`graph::search::Range`] used for the batch of runs.
134    pub parameters: graph::search::Range,
135
136    /// The end-to-end latency for each repetition in the batch.
137    pub end_to_end_latencies: Vec<MicroSeconds>,
138
139    /// The average latency for individual queries.
140    ///
141    /// This contains one entry per repetition in the batch.
142    pub mean_latencies: Vec<f64>,
143
144    /// The 90th percentile latency for individual queries.
145    ///
146    /// This contains one entry per repetition in the batch.
147    pub p90_latencies: Vec<MicroSeconds>,
148
149    /// The 99th percentile latency for individual queries.
150    ///
151    /// This contains one entry per repetition in the batch.
152    pub p99_latencies: Vec<MicroSeconds>,
153
154    /// The average precision metrics for the batch of runs.
155    ///
156    /// This implementation assumes that search is deterministic and only
157    /// uses the first repetition's results to compute the average precision.
158    pub average_precision: recall::AveragePrecisionMetrics,
159}
160
161/// A [`search::Aggregate`] for collecting the results of multiple [`Range`] search runs.
162///
163/// In addition to collecting latencies and other metrics, this aggregator computes
164/// [`crate::recall::AveragePrecisionMetrics`] using a provided groundtruth.
165///
166/// The aggregated results are available as a [`Summary`].
167pub struct Aggregator<'a, I> {
168    groundtruth: &'a dyn crate::recall::Rows<I>,
169}
170
171impl<'a, I> Aggregator<'a, I> {
172    /// Construct a new [`Aggregator`] using `groundtruth` for average precision computation.
173    pub fn new(groundtruth: &'a dyn crate::recall::Rows<I>) -> Self {
174        Self { groundtruth }
175    }
176}
177
178impl<I> search::Aggregate<graph::search::Range, I, Metrics> for Aggregator<'_, I>
179where
180    I: crate::recall::RecallCompatible,
181{
182    type Output = Summary;
183
184    #[inline(never)]
185    fn aggregate(
186        &mut self,
187        run: search::Run<graph::search::Range>,
188        mut results: Vec<search::SearchResults<I, Metrics>>,
189    ) -> anyhow::Result<Summary> {
190        // Compute the recall using just the first result.
191        let average_precision = match results.first() {
192            Some(first) => {
193                crate::recall::average_precision(first.ids().as_rows(), self.groundtruth)?
194            }
195            None => anyhow::bail!("Results must be non-empty"),
196        };
197
198        let mut mean_latencies = Vec::with_capacity(results.len());
199        let mut p90_latencies = Vec::with_capacity(results.len());
200        let mut p99_latencies = Vec::with_capacity(results.len());
201
202        results.iter_mut().for_each(|r| {
203            match percentiles::compute_percentiles(r.latencies_mut()) {
204                Ok(values) => {
205                    let percentiles::Percentiles { mean, p90, p99, .. } = values;
206                    mean_latencies.push(mean);
207                    p90_latencies.push(p90);
208                    p99_latencies.push(p99);
209                }
210                Err(_) => {
211                    let zero = MicroSeconds::new(0);
212                    mean_latencies.push(0.0);
213                    p90_latencies.push(zero);
214                    p99_latencies.push(zero);
215                }
216            }
217        });
218
219        Ok(Summary {
220            setup: run.setup().clone(),
221            parameters: *run.parameters(),
222            end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
223            mean_latencies,
224            p90_latencies,
225            p99_latencies,
226            average_precision,
227        })
228    }
229}
230
231///////////
232// Tests //
233///////////
234
#[cfg(test)]
mod tests {
    use super::*;

    use diskann::graph::test::provider;

    #[test]
    fn test_range() {
        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();

        // Every range search in this test differs only in its second option.
        let range_params = |l| {
            graph::search::Range::with_options(None, l, None, 2.0, None, 0.8, 1.2).unwrap()
        };

        let index = search::graph::test_grid_provider();

        // The origin, followed by one point four units out along each axis.
        let points: [[f32; 4]; 5] = [
            [0.0, 0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0, 0.0],
            [0.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 4.0],
        ];
        let mut queries = Matrix::new(0.0f32, points.len(), index.provider().dim());
        for (row, point) in points.iter().enumerate() {
            queries.row_mut(row).copy_from_slice(point);
        }
        let queries = Arc::new(queries);

        let range = Range::new(
            index,
            Arc::clone(&queries),
            Strategy::broadcast(provider::Strategy::new()),
        )
        .unwrap();

        // Single-parameter entry point.
        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            Arc::clone(&range),
            range_params(10),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), queries.nrows());
        let rows = results.ids().as_rows();
        assert_eq!(*rows.row(0).first().unwrap(), 0);

        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // Aggregated entry point, using the search above as the groundtruth.
        let parameters = [
            search::Run::new(range_params(10), setup.clone()),
            search::Run::new(range_params(15), setup.clone()),
        ];
        let all = search::search_all(range, parameters, Aggregator::new(rows)).unwrap();

        assert_eq!(all.len(), 2);
        for summary in all {
            assert_eq!(summary.setup, setup);
            assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
            assert_eq!(summary.mean_latencies.len(), TWO.get());
            assert_eq!(summary.p90_latencies.len(), TWO.get());
            assert_eq!(summary.p99_latencies.len(), TWO.get());

            let ap = summary.average_precision;
            assert_eq!(ap.num_queries, queries.nrows());
            assert_eq!(
                ap.average_precision, 1.0,
                "we used a search as the groundtruth"
            );
        }
    }

    #[test]
    fn test_range_error() {
        let index = search::graph::test_grid_provider();
        let dim = index.provider().dim();
        let queries = Arc::new(Matrix::new(0.0f32, 2, dim));

        // One strategy for two queries: construction must be rejected.
        let strategies = Strategy::collection([provider::Strategy::new()]);
        let err = Range::new(index, Arc::clone(&queries), strategies).unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("1 strategy was provided when 2 were expected"),
            "failed with {msg}"
        );
    }
}