1use std::sync::Arc;
9
10use diskann::{
11 ANNResult,
12 graph::{self, glue},
13 provider,
14};
15use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
16use diskann_utils::{future::AsyncFriendly, views::Matrix};
17
18use crate::{
19 recall,
20 search::{self, Search, graph::Strategy},
21 utils,
22};
23
/// Benchmark driver for k-nearest-neighbor searches over a DiskANN graph
/// index.
///
/// Bundles the index, the query set (one query per row of the matrix), and
/// the per-query search strategy; constructed via [`KNN::new`], which
/// validates that the strategy count is compatible with the query count.
#[derive(Debug)]
pub struct KNN<DP, T, S>
where
    DP: provider::DataProvider,
{
    /// The graph index that queries are executed against.
    index: Arc<graph::DiskANNIndex<DP>>,
    /// Query vectors, one per row.
    queries: Arc<Matrix<T>>,
    /// Search strategy of type `S` (broadcast or one per query).
    strategy: Strategy<S>,
}
43
44impl<DP, T, S> KNN<DP, T, S>
45where
46 DP: provider::DataProvider,
47{
48 pub fn new(
60 index: Arc<graph::DiskANNIndex<DP>>,
61 queries: Arc<Matrix<T>>,
62 strategy: Strategy<S>,
63 ) -> anyhow::Result<Arc<Self>> {
64 strategy.length_compatible(queries.nrows())?;
65
66 Ok(Arc::new(Self {
67 index,
68 queries,
69 strategy,
70 }))
71 }
72}
73
/// Per-query statistics reported by a single KNN search.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Metrics {
    /// Number of comparisons performed by the search (`stats.cmps`).
    pub comparisons: u32,
    /// Number of hops taken by the search (`stats.hops`).
    pub hops: u32,
}
87
// `Search` implementation: each call to `search` executes one KNN query
// (row `index` of the query matrix) against the wrapped index.
impl<DP, T, S> Search for KNN<DP, T, S>
where
    DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
    S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
    T: AsyncFriendly + Clone,
{
    type Id = DP::ExternalId;
    type Parameters = graph::search::Knn;
    type Output = Metrics;

    /// Total number of queries (rows in the query matrix).
    fn num_queries(&self) -> usize {
        self.queries.nrows()
    }

    /// Every query yields a fixed number of ids: the `k` of the KNN
    /// parameters.
    fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
        search::IdCount::Fixed(parameters.k_value())
    }

    /// Runs query `index` with `parameters`, writing the resulting ids
    /// into `buffer` and returning the per-query [`Metrics`].
    async fn search<O>(
        &self,
        parameters: &Self::Parameters,
        buffer: &mut O,
        index: usize,
    ) -> ANNResult<Self::Output>
    where
        O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
    {
        // A fresh default context is created for every call.
        // NOTE(review): assumes `DP::Context::default()` is cheap and that
        // contexts are intended to be per-query — confirm with the provider.
        let context = DP::Context::default();
        // `Knn` is `Copy`; take a local copy of the parameters.
        let knn_search = *parameters;
        let stats = self
            .index
            .search(
                knn_search,
                // `get` fails if the strategy collection does not cover
                // `index`; propagate that as a search error.
                self.strategy.get(index)?,
                &context,
                self.queries.row(index),
                buffer,
            )
            .await?;

        Ok(Metrics {
            comparisons: stats.cmps,
            hops: stats.hops,
        })
    }
}
134
/// Aggregated results for one benchmark run of KNN searches.
///
/// Produced by [`Aggregator::aggregate`]; the per-repetition vectors all
/// have one entry per repetition of the run.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Summary {
    /// Thread/task/repetition configuration the run executed with.
    pub setup: search::Setup,

    /// The KNN search parameters used for the run.
    pub parameters: graph::search::Knn,

    /// End-to-end latency of each repetition.
    pub end_to_end_latencies: Vec<MicroSeconds>,

    /// Mean of the recorded latencies for each repetition.
    pub mean_latencies: Vec<f64>,

    /// 90th-percentile latency for each repetition.
    pub p90_latencies: Vec<MicroSeconds>,

    /// 99th-percentile latency for each repetition.
    pub p99_latencies: Vec<MicroSeconds>,

    /// Recall of the returned ids measured against the groundtruth
    /// (computed from the first repetition only).
    pub recall: recall::RecallMetrics,

    /// Mean number of comparisons per query across all repetitions.
    pub mean_cmps: f64,

    /// Mean number of hops per query across all repetitions.
    pub mean_hops: f64,
}
178
/// Aggregates per-repetition search results into a [`Summary`], scoring
/// recall against a fixed groundtruth.
pub struct Aggregator<'a, I> {
    /// Groundtruth neighbor ids, one row per query.
    groundtruth: &'a dyn crate::recall::Rows<I>,
    /// `recall_k` argument forwarded to `crate::recall::knn`.
    recall_k: usize,
    /// `recall_n` argument forwarded to `crate::recall::knn`.
    recall_n: usize,
}
190
191impl<'a, I> Aggregator<'a, I> {
192 pub fn new(
199 groundtruth: &'a dyn crate::recall::Rows<I>,
200 recall_k: usize,
201 recall_n: usize,
202 ) -> Self {
203 Self {
204 groundtruth,
205 recall_k,
206 recall_n,
207 }
208 }
209}
210
211impl<I> search::Aggregate<graph::search::Knn, I, Metrics> for Aggregator<'_, I>
212where
213 I: crate::recall::RecallCompatible,
214{
215 type Output = Summary;
216
217 fn aggregate(
218 &mut self,
219 run: search::Run<graph::search::Knn>,
220 mut results: Vec<search::SearchResults<I, Metrics>>,
221 ) -> anyhow::Result<Summary> {
222 let recall = match results.first() {
224 Some(first) => crate::recall::knn(
225 self.groundtruth,
226 None,
227 first.ids().as_rows(),
228 self.recall_k,
229 self.recall_n,
230 true,
231 )?,
232 None => anyhow::bail!("Results must be non-empty"),
233 };
234
235 let mut mean_latencies = Vec::with_capacity(results.len());
236 let mut p90_latencies = Vec::with_capacity(results.len());
237 let mut p99_latencies = Vec::with_capacity(results.len());
238
239 results.iter_mut().for_each(|r| {
240 match percentiles::compute_percentiles(r.latencies_mut()) {
241 Ok(values) => {
242 let percentiles::Percentiles { mean, p90, p99, .. } = values;
243 mean_latencies.push(mean);
244 p90_latencies.push(p90);
245 p99_latencies.push(p99);
246 }
247 Err(_) => {
248 let zero = MicroSeconds::new(0);
249 mean_latencies.push(0.0);
250 p90_latencies.push(zero);
251 p99_latencies.push(zero);
252 }
253 }
254 });
255
256 Ok(Summary {
257 setup: run.setup().clone(),
258 parameters: *run.parameters(),
259 end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
260 recall,
261 mean_latencies,
262 p90_latencies,
263 p99_latencies,
264 mean_cmps: utils::average_all(
265 results
266 .iter()
267 .flat_map(|r| r.output().iter().map(|o| o.comparisons)),
268 ),
269 mean_hops: utils::average_all(
270 results
271 .iter()
272 .flat_map(|r| r.output().iter().map(|o| o.hops)),
273 ),
274 })
275 }
276}
277
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    use diskann::graph::test::provider;

    /// End-to-end smoke test: search a small grid index and verify result
    /// shapes, non-zero metrics, and perfect recall when the groundtruth is
    /// itself derived from a search.
    #[test]
    fn test_knn() {
        let nearest_neighbors = 5;

        let index = search::graph::test_grid_provider();

        // Five queries: the origin plus one point at distance 4 along each
        // of the four axes.
        let mut queries = Matrix::new(0.0f32, 5, index.provider().dim());
        queries.row_mut(0).copy_from_slice(&[0.0, 0.0, 0.0, 0.0]);
        queries.row_mut(1).copy_from_slice(&[4.0, 0.0, 0.0, 0.0]);
        queries.row_mut(2).copy_from_slice(&[0.0, 4.0, 0.0, 0.0]);
        queries.row_mut(3).copy_from_slice(&[0.0, 0.0, 4.0, 0.0]);
        queries.row_mut(4).copy_from_slice(&[0.0, 0.0, 0.0, 4.0]);

        let queries = Arc::new(queries);

        // One shared strategy broadcast across all queries.
        let knn = KNN::new(
            index,
            queries.clone(),
            Strategy::broadcast(provider::Strategy::new()),
        )
        .unwrap();

        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            knn.clone(),
            graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), queries.nrows());
        let rows = results.ids().as_rows();
        // The origin query's nearest neighbor is vertex 0.
        assert_eq!(*rows.row(0).first().unwrap(), 0);

        // Every query returns exactly `nearest_neighbors` ids.
        for r in 0..rows.nrows() {
            assert_eq!(rows.row(r).len(), nearest_neighbors);
        }

        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // Two runs that differ only in the search-list size (10 vs 15).
        let parameters = [
            search::Run::new(
                graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
                setup.clone(),
            ),
            search::Run::new(
                graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
                setup.clone(),
            ),
        ];

        let recall_k = nearest_neighbors;
        let recall_n = nearest_neighbors;

        // Use the ids from the earlier single search as the groundtruth, so
        // the aggregated recall must be exact.
        let all =
            search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap();

        assert_eq!(all.len(), 2);
        for summary in all {
            assert_eq!(summary.setup, setup);
            // One latency entry per repetition.
            assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
            assert_eq!(summary.mean_latencies.len(), TWO.get());
            assert_eq!(summary.p90_latencies.len(), TWO.get());
            assert_eq!(summary.p99_latencies.len(), TWO.get());

            // A real search must perform at least some work.
            assert_ne!(summary.mean_cmps, 0.0);
            assert_ne!(summary.mean_hops, 0.0);

            let recall = summary.recall;
            assert_eq!(recall.recall_k, recall_k);
            assert_eq!(recall.recall_n, recall_n);
            assert_eq!(recall.num_queries, queries.nrows());
            assert_eq!(recall.average, 1.0, "we used a search as the groundtruth");
        }
    }

    /// Constructing a `KNN` with more strategies than queries must fail
    /// with a descriptive error message.
    #[test]
    fn test_knn_error() {
        let index = search::graph::test_grid_provider();

        // One query, but a collection of two strategies.
        let queries = Arc::new(Matrix::new(0.0f32, 1, index.provider().dim()));
        let strategy = provider::Strategy::new();

        let err = KNN::new(
            index,
            queries.clone(),
            Strategy::collection([strategy.clone(), strategy.clone()]),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("2 strategies were provided when 1 was expected"),
            "failed with {msg}"
        );
    }
}