// diskann_benchmark_core/search/graph/range.rs
use std::{num::NonZeroUsize, sync::Arc};
7
8use diskann::{
9 ANNResult,
10 graph::{self, glue},
11 provider,
12};
13use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
14use diskann_utils::{future::AsyncFriendly, views::Matrix};
15
16use crate::{
17 recall,
18 search::{self, Search, graph::Strategy},
19};
20
/// Benchmark driver for range searches over a DiskANN graph index.
///
/// Bundles the index, the query matrix, and the per-query strategy
/// selection into one `Arc`-shareable unit so worker tasks can run
/// queries concurrently (see the [`Search`] impl below).
#[derive(Debug)]
pub struct Range<DP, T, S>
where
    DP: provider::DataProvider,
{
    /// The graph index every query is executed against.
    index: Arc<graph::DiskANNIndex<DP>>,
    /// Query matrix; one query per row, element type `T`.
    queries: Arc<Matrix<T>>,
    /// Supplies the search strategy for each query row (broadcast or
    /// per-query collection — see [`Strategy`]).
    strategy: Strategy<S>,
}
40
41impl<DP, T, S> Range<DP, T, S>
42where
43 DP: provider::DataProvider,
44{
45 pub fn new(
57 index: Arc<graph::DiskANNIndex<DP>>,
58 queries: Arc<Matrix<T>>,
59 strategy: Strategy<S>,
60 ) -> anyhow::Result<Arc<Self>> {
61 strategy.length_compatible(queries.nrows())?;
62
63 Ok(Arc::new(Self {
64 index,
65 queries,
66 strategy,
67 }))
68 }
69}
70
/// Per-query metrics produced by a range search.
///
/// Currently carries no data; `#[non_exhaustive]` reserves the right to
/// add fields later without breaking downstream code.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Metrics {}
78
79impl<DP, T, S> Search for Range<DP, T, S>
80where
81 DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
82 S: for<'a> glue::DefaultSearchStrategy<DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly,
83 T: AsyncFriendly + Clone,
84{
85 type Id = DP::ExternalId;
86 type Parameters = graph::search::Range;
87 type Output = Metrics;
88
89 fn num_queries(&self) -> usize {
90 self.queries.nrows()
91 }
92
93 fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
94 search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l()))
95 }
96
97 async fn search<O>(
98 &self,
99 parameters: &Self::Parameters,
100 buffer: &mut O,
101 index: usize,
102 ) -> ANNResult<Self::Output>
103 where
104 O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
105 {
106 let context = DP::Context::default();
107 let range_search = *parameters;
108 let _ = self
109 .index
110 .search(
111 range_search,
112 self.strategy.get(index)?,
113 &context,
114 self.queries.row(index),
115 buffer,
116 )
117 .await?;
118
119 Ok(Metrics {})
120 }
121}
122
/// Aggregated results for one benchmark run of range searches.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Summary {
    /// Thread/task/repetition configuration the run used.
    pub setup: search::Setup,

    /// Range-search parameters applied to every query in the run.
    pub parameters: graph::search::Range,

    /// End-to-end wall-clock latency, one entry per repetition.
    pub end_to_end_latencies: Vec<MicroSeconds>,

    /// Mean of the recorded latencies, one entry per repetition
    /// (0.0 when percentile computation failed for that repetition).
    pub mean_latencies: Vec<f64>,

    /// 90th-percentile latency, one entry per repetition
    /// (zero when percentile computation failed).
    pub p90_latencies: Vec<MicroSeconds>,

    /// 99th-percentile latency, one entry per repetition
    /// (zero when percentile computation failed).
    pub p99_latencies: Vec<MicroSeconds>,

    /// Average precision of the returned ids, scored against the
    /// aggregator's groundtruth using the first repetition's results.
    pub average_precision: recall::AveragePrecisionMetrics,
}
160
/// Folds raw range-search results into a [`Summary`], scoring recall
/// against a caller-supplied groundtruth.
pub struct Aggregator<'a, I> {
    /// Expected result rows used as groundtruth for precision scoring.
    groundtruth: &'a dyn crate::recall::Rows<I>,
}
170
171impl<'a, I> Aggregator<'a, I> {
172 pub fn new(groundtruth: &'a dyn crate::recall::Rows<I>) -> Self {
174 Self { groundtruth }
175 }
176}
177
178impl<I> search::Aggregate<graph::search::Range, I, Metrics> for Aggregator<'_, I>
179where
180 I: crate::recall::RecallCompatible,
181{
182 type Output = Summary;
183
184 #[inline(never)]
185 fn aggregate(
186 &mut self,
187 run: search::Run<graph::search::Range>,
188 mut results: Vec<search::SearchResults<I, Metrics>>,
189 ) -> anyhow::Result<Summary> {
190 let average_precision = match results.first() {
192 Some(first) => {
193 crate::recall::average_precision(first.ids().as_rows(), self.groundtruth)?
194 }
195 None => anyhow::bail!("Results must be non-empty"),
196 };
197
198 let mut mean_latencies = Vec::with_capacity(results.len());
199 let mut p90_latencies = Vec::with_capacity(results.len());
200 let mut p99_latencies = Vec::with_capacity(results.len());
201
202 results.iter_mut().for_each(|r| {
203 match percentiles::compute_percentiles(r.latencies_mut()) {
204 Ok(values) => {
205 let percentiles::Percentiles { mean, p90, p99, .. } = values;
206 mean_latencies.push(mean);
207 p90_latencies.push(p90);
208 p99_latencies.push(p99);
209 }
210 Err(_) => {
211 let zero = MicroSeconds::new(0);
212 mean_latencies.push(0.0);
213 p90_latencies.push(zero);
214 p99_latencies.push(zero);
215 }
216 }
217 });
218
219 Ok(Summary {
220 setup: run.setup().clone(),
221 parameters: *run.parameters(),
222 end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
223 mean_latencies,
224 p90_latencies,
225 p99_latencies,
226 average_precision,
227 })
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    use diskann::graph::test::provider;

    /// End-to-end smoke test: range-searches a small grid index with
    /// five axis-aligned queries, then checks result shapes, the
    /// nearest id for the origin query, and the aggregated summaries.
    #[test]
    fn test_range() {
        let index = search::graph::test_grid_provider();

        // One query per row; dimensionality comes from the provider.
        let mut queries = Matrix::new(0.0f32, 5, index.provider().dim());
        queries.row_mut(0).copy_from_slice(&[0.0, 0.0, 0.0, 0.0]);
        queries.row_mut(1).copy_from_slice(&[4.0, 0.0, 0.0, 0.0]);
        queries.row_mut(2).copy_from_slice(&[0.0, 4.0, 0.0, 0.0]);
        queries.row_mut(3).copy_from_slice(&[0.0, 0.0, 4.0, 0.0]);
        queries.row_mut(4).copy_from_slice(&[0.0, 0.0, 0.0, 4.0]);

        let queries = Arc::new(queries);

        // Broadcast: the same strategy serves every query.
        let range = Range::new(
            index,
            queries.clone(),
            Strategy::broadcast(provider::Strategy::new()),
        )
        .unwrap();

        let rt = crate::tokio::runtime(2).unwrap();
        let results = search::search(
            range.clone(),
            graph::search::Range::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
            NonZeroUsize::new(2).unwrap(),
            &rt,
        )
        .unwrap();

        assert_eq!(results.len(), queries.nrows());
        let rows = results.ids().as_rows();
        // The origin query's nearest neighbor is the grid origin (id 0).
        assert_eq!(*rows.row(0).first().unwrap(), 0);
        const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
        let setup = search::Setup {
            threads: TWO,
            tasks: TWO,
            reps: TWO,
        };

        // Two runs with different starting `L` values.
        let parameters = [
            search::Run::new(
                graph::search::Range::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
                setup.clone(),
            ),
            search::Run::new(
                graph::search::Range::with_options(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(),
                setup.clone(),
            ),
        ];

        // Use the earlier search results as the groundtruth.
        let all = search::search_all(range, parameters, Aggregator::new(rows)).unwrap();

        assert_eq!(all.len(), 2);
        for summary in all {
            assert_eq!(summary.setup, setup);
            // One latency entry per repetition.
            assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
            assert_eq!(summary.mean_latencies.len(), TWO.get());
            assert_eq!(summary.p90_latencies.len(), TWO.get());
            assert_eq!(summary.p99_latencies.len(), TWO.get());

            let ap = summary.average_precision;
            assert_eq!(ap.num_queries, queries.nrows());
            assert_eq!(
                ap.average_precision, 1.0,
                "we used a search as the groundtruth"
            );
        }
    }

    /// `Range::new` must reject a strategy collection whose length does
    /// not match the number of query rows.
    #[test]
    fn test_range_error() {
        let index = search::graph::test_grid_provider();

        // Two queries but only one strategy in the collection.
        let queries = Arc::new(Matrix::new(0.0f32, 2, index.provider().dim()));
        let strategy = provider::Strategy::new();

        let err = Range::new(index, queries.clone(), Strategy::collection([strategy])).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("1 strategy was provided when 2 were expected"),
            "failed with {msg}"
        );
    }
}