diskann_benchmark_core/search/graph/
multihop.rs1use std::sync::Arc;
7
8use diskann::{
9 ANNResult,
10 graph::{self, glue},
11 provider,
12};
13use diskann_utils::{future::AsyncFriendly, views::Matrix};
14
15use crate::search::{self, Search, graph::Strategy};
16
17#[derive(Debug)]
28pub struct MultiHop<DP, T, S>
29where
30 DP: provider::DataProvider,
31{
32 index: Arc<graph::DiskANNIndex<DP>>,
33 queries: Arc<Matrix<T>>,
34 strategy: Strategy<S>,
35 labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
36}
37
38impl<DP, T, S> MultiHop<DP, T, S>
39where
40 DP: provider::DataProvider,
41{
42 pub fn new(
62 index: Arc<graph::DiskANNIndex<DP>>,
63 queries: Arc<Matrix<T>>,
64 strategy: Strategy<S>,
65 labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
66 ) -> anyhow::Result<Arc<Self>> {
67 strategy.length_compatible(queries.nrows())?;
68
69 if labels.len() != queries.nrows() {
70 Err(anyhow::anyhow!(
71 "Number of label providers ({}) must be equal to the number of queries ({})",
72 labels.len(),
73 queries.nrows()
74 ))
75 } else {
76 Ok(Arc::new(Self {
77 index,
78 queries,
79 strategy,
80 labels,
81 }))
82 }
83 }
84}
85
86impl<DP, T, S> Search for MultiHop<DP, T, S>
87where
88 DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
89 S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
90 T: AsyncFriendly + Clone,
91{
92 type Id = DP::ExternalId;
93 type Parameters = graph::search::Knn;
94 type Output = super::knn::Metrics;
95
96 fn num_queries(&self) -> usize {
97 self.queries.nrows()
98 }
99
100 fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
101 search::IdCount::Fixed(parameters.k_value())
102 }
103
104 async fn search<O>(
105 &self,
106 parameters: &Self::Parameters,
107 buffer: &mut O,
108 index: usize,
109 ) -> ANNResult<Self::Output>
110 where
111 O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
112 {
113 let context = DP::Context::default();
114 let multihop_search = graph::search::MultihopSearch::new(*parameters, &*self.labels[index]);
115 let stats = self
116 .index
117 .search(
118 multihop_search,
119 self.strategy.get(index)?,
120 &context,
121 self.queries.row(index),
122 buffer,
123 )
124 .await?;
125
126 Ok(super::knn::Metrics {
127 comparisons: stats.cmps,
128 hops: stats.hops,
129 })
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    use diskann::graph::{index::QueryLabelProvider, test::provider};

    /// Label filter that accepts only even IDs.
    #[derive(Debug)]
    struct NoOdds;

    impl QueryLabelProvider<u32> for NoOdds {
        fn is_match(&self, id: u32) -> bool {
            id.is_multiple_of(2)
        }
    }

    /// Builds `count` label providers that all reject odd IDs.
    fn no_odds_labels(count: usize) -> Arc<[Arc<dyn QueryLabelProvider<u32>>]> {
        (0..count)
            .map(|_| Arc::new(NoOdds) as Arc<dyn QueryLabelProvider<u32>>)
            .collect()
    }

    #[test]
    fn test_multihop() {
        let nearest_neighbors = 5;

        let index = search::graph::test_grid_provider();

        // The origin, followed by one point at distance 4 along each axis.
        let query_points: [[f32; 4]; 5] = [
            [0.0, 0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0, 0.0],
            [0.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 4.0],
        ];
        let mut queries = Matrix::new(0.0f32, query_points.len(), index.provider().dim());
        for (row_index, point) in query_points.iter().enumerate() {
            queries.row_mut(row_index).copy_from_slice(point);
        }

        let queries = Arc::new(queries);
        let num_queries = queries.nrows();

        let multihop = MultiHop::new(
            index,
            queries.clone(),
            Strategy::broadcast(provider::Strategy::new()),
            no_odds_labels(num_queries),
        )
        .unwrap();

        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            multihop.clone(),
            graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), num_queries);
        let rows = results.ids().as_rows();
        assert_eq!(*rows.row(0).first().unwrap(), 0);

        // Every row must be full and contain only IDs accepted by the label filter.
        for r in 0..rows.nrows() {
            let ids = rows.row(r);
            assert_eq!(ids.len(), nearest_neighbors);
            for &id in ids {
                assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r);
            }
        }

        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // One run per search-list size.
        let parameters = [10, 15].map(|search_l| {
            search::Run::new(
                graph::search::Knn::new(nearest_neighbors, search_l, None).unwrap(),
                setup.clone(),
            )
        });

        let (recall_k, recall_n) = (nearest_neighbors, nearest_neighbors);

        // The IDs from the first search serve as the groundtruth for recall.
        let all = search::search_all(
            multihop,
            parameters,
            search::graph::knn::Aggregator::new(rows, recall_k, recall_n),
        )
        .unwrap();

        assert_eq!(all.len(), 2);
        for summary in all {
            assert_eq!(summary.setup, setup);
            assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
            assert_eq!(summary.mean_latencies.len(), TWO.get());
            assert_eq!(summary.p90_latencies.len(), TWO.get());
            assert_eq!(summary.p99_latencies.len(), TWO.get());

            assert_ne!(summary.mean_cmps, 0.0);
            assert_ne!(summary.mean_hops, 0.0);

            let recall = summary.recall;
            assert_eq!(recall.recall_k, recall_k);
            assert_eq!(recall.recall_n, recall_n);
            assert_eq!(recall.num_queries, num_queries);
            assert_eq!(recall.average, 1.0, "we used a search as the groundtruth");
        }
    }

    #[test]
    fn test_multihop_error() {
        let index = search::graph::test_grid_provider();
        let queries = Arc::new(Matrix::new(0.0f32, 2, index.provider().dim()));

        // Deliberately one more label provider than there are queries.
        let labels = no_odds_labels(queries.nrows() + 1);
        let strategy = provider::Strategy::new();

        // A one-element strategy collection cannot cover two queries.
        let message = MultiHop::new(
            index.clone(),
            queries.clone(),
            Strategy::collection([strategy.clone()]),
            labels.clone(),
        )
        .unwrap_err()
        .to_string();
        assert!(
            message.contains("1 strategy was provided when 2 were expected"),
            "failed with {message}"
        );

        // With a broadcast strategy, the label-count mismatch is what gets reported.
        let message = MultiHop::new(index, queries, Strategy::broadcast(strategy), labels)
            .unwrap_err()
            .to_string();
        assert!(
            message.contains(
                "Number of label providers (3) must be equal to the number of queries (2)"
            ),
            "failed with {message}"
        );
    }
}