Skip to main content

diskann_benchmark_core/search/graph/
multihop.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::sync::Arc;
7
8use diskann::{
9    ANNResult,
10    graph::{self, glue},
11    provider,
12};
13use diskann_utils::{future::AsyncFriendly, views::Matrix};
14
15use crate::search::{self, Search, graph::Strategy};
16
17/// A built-in helper for benchmarking filtered K-nearest neighbors search
18/// using the multi-hop search method.
19///
20/// This is intended to be used in conjunction with [`search::search`] or [`search::search_all`]
21/// and provides some basic additional metrics for the latter. Result aggregation for
22/// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same
23/// aggregator as [`search::graph::knn::KNN`]).
24///
25/// The provided implementation of [`Search`] accepts [`graph::search::Knn`]
26/// and returns [`search::graph::knn::Metrics`] as additional output.
27#[derive(Debug)]
28pub struct MultiHop<DP, T, S>
29where
30    DP: provider::DataProvider,
31{
32    index: Arc<graph::DiskANNIndex<DP>>,
33    queries: Arc<Matrix<T>>,
34    strategy: Strategy<S>,
35    labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
36}
37
38impl<DP, T, S> MultiHop<DP, T, S>
39where
40    DP: provider::DataProvider,
41{
42    /// Construct a new [`MultiHop`] searcher.
43    ///
44    /// If `strategy` is one of the container variants of [`Strategy`], its length
45    /// must match the number of rows in `queries`. If this is the case, then the
46    /// strategies will have a querywise correspondence (see [`search::SearchResults`])
47    /// with the query matrix.
48    ///
49    /// Additionally, the length of `labels` must match the number of rows in `queries`
50    /// and will be used in querywise correspondence with `queries`.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error under the following conditions.
55    ///
56    /// 1. The number of elements in `strategy` is not compatible with the number of rows in
57    ///    `queries`.
58    ///
59    /// 2. The number of label providers in `labels` is not equal to the number of rows in
60    ///    `queries`.
61    pub fn new(
62        index: Arc<graph::DiskANNIndex<DP>>,
63        queries: Arc<Matrix<T>>,
64        strategy: Strategy<S>,
65        labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
66    ) -> anyhow::Result<Arc<Self>> {
67        strategy.length_compatible(queries.nrows())?;
68
69        if labels.len() != queries.nrows() {
70            Err(anyhow::anyhow!(
71                "Number of label providers ({}) must be equal to the number of queries ({})",
72                labels.len(),
73                queries.nrows()
74            ))
75        } else {
76            Ok(Arc::new(Self {
77                index,
78                queries,
79                strategy,
80                labels,
81            }))
82        }
83    }
84}
85
86impl<DP, T, S> Search for MultiHop<DP, T, S>
87where
88    DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
89    S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
90    T: AsyncFriendly + Clone,
91{
92    type Id = DP::ExternalId;
93    type Parameters = graph::search::Knn;
94    type Output = super::knn::Metrics;
95
96    fn num_queries(&self) -> usize {
97        self.queries.nrows()
98    }
99
100    fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
101        search::IdCount::Fixed(parameters.k_value())
102    }
103
104    async fn search<O>(
105        &self,
106        parameters: &Self::Parameters,
107        buffer: &mut O,
108        index: usize,
109    ) -> ANNResult<Self::Output>
110    where
111        O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
112    {
113        let context = DP::Context::default();
114        let multihop_search = graph::search::MultihopSearch::new(*parameters, &*self.labels[index]);
115        let stats = self
116            .index
117            .search(
118                multihop_search,
119                self.strategy.get(index)?,
120                &context,
121                self.queries.row(index),
122                buffer,
123            )
124            .await?;
125
126        Ok(super::knn::Metrics {
127            comparisons: stats.cmps,
128            hops: stats.hops,
129        })
130    }
131}
132
133///////////
134// Tests //
135///////////
136
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    use diskann::graph::{index::QueryLabelProvider, test::provider};

    /// Label provider that accepts only even IDs, so every filtered result must be even.
    #[derive(Debug)]
    struct NoOdds;

    impl graph::index::QueryLabelProvider<u32> for NoOdds {
        fn is_match(&self, id: u32) -> bool {
            id.is_multiple_of(2)
        }
    }

    /// Create `count` shared [`NoOdds`] label providers.
    fn no_odds_labels(count: usize) -> Arc<[Arc<dyn QueryLabelProvider<u32>>]> {
        (0..count)
            .map(|_| Arc::new(NoOdds) as Arc<dyn QueryLabelProvider<u32>>)
            .collect()
    }

    #[test]
    fn test_multihop() {
        const K: usize = 5;
        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();

        let index = search::graph::test_grid_provider();

        // The origin, then one point 4 units out along each of the four axes.
        let points: [[f32; 4]; 5] = [
            [0.0, 0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0, 0.0],
            [0.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 4.0],
        ];
        let mut query_matrix = Matrix::new(0.0f32, points.len(), index.provider().dim());
        for (row, point) in points.iter().enumerate() {
            query_matrix.row_mut(row).copy_from_slice(point);
        }
        let queries = Arc::new(query_matrix);
        let num_queries = queries.nrows();

        let multihop = MultiHop::new(
            index,
            Arc::clone(&queries),
            Strategy::broadcast(provider::Strategy::new()),
            no_odds_labels(num_queries),
        )
        .unwrap();

        // Single run through the standard search interface.
        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            Arc::clone(&multihop),
            graph::search::Knn::new(K, 10, None).unwrap(),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), num_queries);
        let rows = results.ids().as_rows();
        assert_eq!(*rows.row(0).first().unwrap(), 0);

        // Each row must be full and contain only IDs accepted by `NoOdds`.
        for r in 0..rows.nrows() {
            let row = rows.row(r);
            assert_eq!(row.len(), K);
            for &id in row {
                assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r);
            }
        }

        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // Aggregated runs with two different search-list sizes.
        let parameters = [10, 15].map(|search_l| {
            search::Run::new(
                graph::search::Knn::new(K, search_l, None).unwrap(),
                setup.clone(),
            )
        });

        let (recall_k, recall_n) = (K, K);

        let summaries = search::search_all(
            multihop,
            parameters,
            search::graph::knn::Aggregator::new(rows, recall_k, recall_n),
        )
        .unwrap();

        assert_eq!(summaries.len(), 2);
        for summary in summaries {
            assert_eq!(summary.setup, setup);

            let reps = TWO.get();
            assert_eq!(summary.end_to_end_latencies.len(), reps);
            assert_eq!(summary.mean_latencies.len(), reps);
            assert_eq!(summary.p90_latencies.len(), reps);
            assert_eq!(summary.p99_latencies.len(), reps);

            assert_ne!(summary.mean_cmps, 0.0);
            assert_ne!(summary.mean_hops, 0.0);

            let recall = summary.recall;
            assert_eq!(recall.recall_k, recall_k);
            assert_eq!(recall.recall_n, recall_n);
            assert_eq!(recall.num_queries, num_queries);
            assert_eq!(recall.average, 1.0, "we used a search as the groundtruth");
        }
    }

    #[test]
    fn test_multihop_error() {
        let index = search::graph::test_grid_provider();
        let queries = Arc::new(Matrix::new(0.0f32, 2, index.provider().dim()));

        // Deliberately one more label provider than there are queries.
        let labels = no_odds_labels(queries.nrows() + 1);
        let strategy = provider::Strategy::new();

        // A strategy collection whose length disagrees with the queries is rejected first.
        let msg = MultiHop::new(
            index.clone(),
            queries.clone(),
            Strategy::collection([strategy.clone()]),
            labels.clone(),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains("1 strategy was provided when 2 were expected"),
            "failed with {msg}"
        );

        // With a compatible strategy, the label-provider count mismatch is reported.
        let msg = MultiHop::new(
            index,
            queries.clone(),
            Strategy::broadcast(strategy.clone()),
            labels.clone(),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains(
                "Number of label providers (3) must be equal to the number of queries (2)"
            ),
            "failed with {msg}"
        );
    }
}