1use std::{collections::HashSet, sync::atomic::AtomicBool, time::Instant};
7
8use diskann::utils::IntoUsize;
9use diskann_disk::{
10 data_model::{CachingStrategy, GraphDataType},
11 search::provider::{
12 disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory,
13 },
14 storage::disk_index_reader::DiskIndexReader,
15 utils::{
16 aligned_file_reader::traits::AlignedReaderFactory, instrumentation::PerfLogger, statistics,
17 QueryStatistics,
18 },
19};
20use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
21use diskann_providers::{
22 storage::{get_compressed_pq_file, get_pq_pivot_file},
23 utils::{create_thread_pool, ParallelIteratorInPool},
24};
25use diskann_utils::{
26 io::{read_bin, write_bin},
27 views::MatrixView,
28};
29use diskann_vector::distance::Metric;
30use opentelemetry::global::BoxedSpan;
31#[cfg(feature = "perf_test")]
32use opentelemetry::{
33 trace::{Span, Tracer},
34 KeyValue,
35};
36use ordered_float::OrderedFloat;
37use rayon::prelude::*;
38use tracing::{error, info};
39
40use crate::utils::{search_index_utils, CMDResult, CMDToolError, KRecallAtN};
41
42pub struct SearchDiskIndexParameters<'a> {
43 pub metric: Metric,
44 pub index_path_prefix: &'a str,
45 pub result_output_prefix: &'a str,
46 pub query_file: &'a str,
47 pub truthset_file: &'a str,
48 pub vector_filters_file: Option<&'a str>,
49 pub num_threads: usize,
50 pub recall_at: u32,
51 pub beam_width: u32,
52 pub search_io_limit: u32,
53 pub l_vec: &'a [u32],
54 pub fail_if_recall_below: f32,
55 pub num_nodes_to_cache: usize,
56 pub is_flat_search: bool,
57}
58
59pub fn search_disk_index<Data, StorageType, ReaderFactory>(
60 storage_provider: &StorageType,
61 parameters: SearchDiskIndexParameters,
62 aligned_reader_factory: ReaderFactory,
63) -> CMDResult<i32>
64where
65 Data: GraphDataType<VectorIdType = u32>,
66 StorageType: StorageReadProvider + StorageWriteProvider,
67 ReaderFactory: AlignedReaderFactory,
68{
69 let mut logger = PerfLogger::new("search_disk_index".to_string(), true);
70
71 info!(
72 "Search parameters: #threads: {}, recall_at {}, search_list_size: {:?}, search_io_limit: {}, fail_if_recall_below: {}, beam_width: {}",
73 parameters.num_threads, parameters.recall_at, parameters.l_vec, parameters.search_io_limit, parameters.fail_if_recall_below,parameters.beam_width
74 );
75
76 let queries = read_bin::<Data::VectorDataType>(
78 &mut storage_provider.open_reader(parameters.query_file)?,
79 )?;
80 let query_num = queries.nrows();
81 let vector_filters = match parameters.vector_filters_file {
83 Some(vector_filters_file) => {
84 search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?
85 }
86 None => vec![HashSet::<u32>::new(); query_num],
87 };
88
89 assert_eq!(
90 vector_filters.len(),
91 query_num,
92 "Mismatch in query and vector filter sizes"
93 );
94
95 let mut gt_dim: usize = 0;
96 let mut gt_ids: Option<Vec<u32>> = None;
97
98 let mut gt_ids_variable_length: Option<Vec<Vec<u32>>> = None;
99 let mut gt_dists: Option<Vec<f32>> = None;
100
101 let mut calc_recall_flag = false;
103 if !parameters.truthset_file.is_empty() && storage_provider.exists(parameters.truthset_file) {
104 if parameters.vector_filters_file.is_none() {
105 let ret =
106 search_index_utils::load_truthset(storage_provider, parameters.truthset_file)?;
107 gt_ids = Some(ret.index_nodes);
108 gt_dists = ret.distances;
109 let gt_num = ret.index_num_points;
110 gt_dim = ret.index_dimension;
111
112 if gt_num != query_num {
113 error!("Error. Mismatch in number of queries and ground truth data");
114 }
115 } else {
116 let range_truthset = search_index_utils::load_range_truthset(
117 storage_provider,
118 parameters.truthset_file,
119 )?;
120 gt_ids_variable_length = Some(range_truthset.index_nodes);
121 let gt_num = range_truthset.index_num_points;
122
123 if gt_num != query_num {
124 error!("Error. Mismatch in number of queries and ground truth data");
125 }
126 }
127
128 calc_recall_flag = true;
129 } else {
130 error!(
131 "Truthset file {} not found. Not computing recall",
132 parameters.truthset_file
133 );
134 }
135
136 let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
137 get_pq_pivot_file(parameters.index_path_prefix),
138 get_compressed_pq_file(parameters.index_path_prefix),
139 storage_provider,
140 )?;
141
142 let caching_strategy = if parameters.num_nodes_to_cache > 0 {
143 CachingStrategy::StaticCacheWithBfsNodes(parameters.num_nodes_to_cache)
144 } else {
145 CachingStrategy::None
146 };
147 let vertex_provider_factory =
149 DiskVertexProviderFactory::new(aligned_reader_factory, caching_strategy)?;
150
151 let searcher = DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, ReaderFactory>>::new(
152 parameters.num_threads.into_usize(),
153 parameters.search_io_limit.into_usize(),
154 &index_reader,
155 vertex_provider_factory,
156 parameters.metric,
157 None,
158 )?;
159
160 logger.log_checkpoint("index_loaded");
161
162 let recall_string = format!("Recall@{}", parameters.recall_at);
163 if calc_recall_flag {
164 println!(
165 "{:<6}{:<12}{:<15}{:<20}{:<20}{:<12}{:<16}{:<10}{:<20}{:<12}{:<12}{:<14}{:<16}",
166 "L",
167 "Beamwidth",
168 "QPS",
169 "Mean Latency (us)",
170 "99.9 Latency (us)",
171 "Mean IOs",
172 "Mean IO (us)",
173 "CPU (us)",
174 "PQ Preprocess (us)",
175 "Mean Comps",
176 "Mean Hops",
177 "Cache Hit %",
178 recall_string
179 );
180 } else {
181 println!(
182 "{:<6}{:<12}{:<15}{:<20}{:<20}{:<12}{:<16}{:<10}{:<20}{:<12}{:<12}{:<14}",
183 "L",
184 "Beamwidth",
185 "QPS",
186 "Mean Latency (us)",
187 "99.9 Latency (us)",
188 "Mean IOs",
189 "Mean IO (us)",
190 "CPU (us)",
191 "PQ Preprocess (us)",
192 "Mean Comparisons",
193 "Mean hops",
194 "Cache Hit %",
195 );
196 }
197 println!("{:=<178}", "");
198
199 let mut query_result_ids: Vec<Vec<u32>> = vec![vec![]; parameters.l_vec.len()];
200 let mut query_result_dists: Vec<Vec<f32>> = vec![vec![]; parameters.l_vec.len()];
201 let mut cmp_stats: Vec<u32> = vec![0; query_num];
202 let has_any_search_failed = AtomicBool::new(false);
203
204 let mut best_recall = 0.0;
205
206 let pool = create_thread_pool(parameters.num_threads)?;
207
208 for (test_id, &l) in parameters.l_vec.iter().enumerate() {
209 if l < parameters.recall_at {
210 println!(
211 "Ignoring search with L: {} since it's smaller than K: {}",
212 l, parameters.recall_at
213 );
214 continue;
215 }
216
217 query_result_ids[test_id].resize(parameters.recall_at as usize * query_num, 0);
218 query_result_dists[test_id].resize(parameters.recall_at as usize * query_num, 0.0);
219
220 let mut statistics: Vec<QueryStatistics> = vec![QueryStatistics::default(); query_num];
222 let mut result_counts: Vec<u32> = vec![0; query_num];
223
224 let zipped = cmp_stats
225 .par_iter_mut()
226 .zip(queries.par_row_iter())
227 .zip(vector_filters.par_iter())
228 .zip(query_result_ids[test_id].par_chunks_mut(parameters.recall_at as usize))
229 .zip(query_result_dists[test_id].par_chunks_mut(parameters.recall_at as usize))
230 .zip(statistics.par_iter_mut())
231 .zip(result_counts.par_iter_mut());
232
233 let mut _span: BoxedSpan;
234 #[cfg(feature = "perf_test")]
235 {
236 let tracer = opentelemetry::global::tracer("");
237
238 _span = tracer.start(format!("search-with-L={}-bw={}", l, parameters.beam_width));
240 }
241
242 let test_start = Instant::now();
243 zipped.for_each_in_pool(
244 pool.as_ref(),
245 |(
246 (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats),
247 result_count,
248 )| {
249 let vector_filter_function: Box<dyn Fn(&u32) -> bool + Send + Sync> =
250 if parameters.vector_filters_file.is_none() {
251 Box::new(|_: &u32| true)
252 } else {
253 Box::new(move |vector_id: &u32| vector_filter.contains(vector_id))
254 };
255
256 let result = searcher.search(
257 query,
258 parameters.recall_at,
259 l,
260 Some(parameters.beam_width as usize),
261 Some(vector_filter_function),
262 parameters.is_flat_search,
263 );
264
265 match result {
266 Ok(search_result) => {
267 *result_count = search_result.stats.result_count;
268 *stats = search_result.stats.query_statistics;
269 search_result
270 .results
271 .iter()
272 .take(parameters.recall_at as usize)
273 .enumerate()
274 .for_each(|(i, item)| {
275 query_result_id[i] = item.vertex_id;
276 query_result_dist[i] = item.distance;
277 });
278 }
279 Err(e) => {
280 error!("Error during search: {}", e);
281 has_any_search_failed.store(true, std::sync::atomic::Ordering::Release);
282 }
283 }
284 },
285 );
286
287 let diff = test_start.elapsed();
288 let qps = query_num as f32 / diff.as_secs_f32();
289
290 let mean_latency =
291 statistics::get_mean_stats(&statistics, |stats| stats.total_execution_time_us as f64);
292
293 let latency_999 = statistics::get_percentile_stats(&statistics, 0.999, |stats| {
294 stats.total_execution_time_us
295 });
296
297 let mean_ios = statistics::get_mean_stats(&statistics, |stats| stats.total_io_operations);
298 let mean_io_time = statistics::get_mean_stats(&statistics, |stats| stats.io_time_us as f64);
299 let mean_cpus = statistics::get_mean_stats(&statistics, |stats| stats.cpu_time_us as f64);
300 let mean_pq_preprocess_time = statistics::get_mean_stats(&statistics, |stats| {
301 stats.query_pq_preprocess_time_us as f64
302 });
303 let mean_comps =
304 statistics::get_mean_stats(&statistics, |stats| stats.total_comparisons as f64);
305 let mean_hops = statistics::get_mean_stats(&statistics, |stats| stats.search_hops as f64);
306 let total_ios = statistics::get_sum_stats(&statistics, |stats| stats.total_io_operations);
307 let total_vertices_loaded =
308 statistics::get_sum_stats(&statistics, |stats| stats.total_vertices_loaded);
309 let cache_hit_percentage = if total_vertices_loaded > 0.0 {
310 100.0 * (1.0 - (total_ios / total_vertices_loaded))
311 } else {
312 100.0
313 };
314
315 let mut recall = 0.0;
316 if calc_recall_flag {
317 recall = if let Some(gt_ids_variable_length) = >_ids_variable_length {
318 let our_results_variable_length = query_result_ids[test_id]
319 .chunks_exact(parameters.recall_at as usize)
320 .enumerate()
321 .map(|(i, chunk)| chunk[..result_counts[i] as usize].to_vec())
322 .collect::<Vec<_>>();
323 search_index_utils::calculate_filtered_search_recall(
324 query_num,
325 None,
326 gt_ids_variable_length,
327 &our_results_variable_length,
328 parameters.recall_at,
329 )? as f32
330 } else {
331 search_index_utils::calculate_recall(
332 query_num,
333 gt_ids.as_ref().ok_or_else(|| CMDToolError {
334 details: "GroundTruth IDs not initialized".to_string(),
335 })?,
336 gt_dists.as_ref(),
337 gt_dim,
338 &query_result_ids[test_id],
339 parameters.recall_at,
340 KRecallAtN::new(parameters.recall_at, parameters.recall_at)?,
341 )? as f32
342 };
343
344 best_recall = f32::from(std::cmp::max(
345 OrderedFloat::<f32>(best_recall),
346 OrderedFloat::<f32>(recall),
347 ));
348 }
349
350 if calc_recall_flag {
351 println!(
352 "{:<6}{:<12.2}{:<15.2}{:<20.2}{:<20.2}{:<12.2}{:<16.2}{:<10.2}{:<20.2}{:<12.2}{:<12.2}{:<14.2}{:<16.2}",
353 l,
354 parameters.beam_width,
355 qps,
356 mean_latency,
357 latency_999,
358 mean_ios,
359 mean_io_time,
360 mean_cpus,
361 mean_pq_preprocess_time,
362 mean_comps,
363 mean_hops,
364 cache_hit_percentage,
365 recall,
366 );
367 } else {
368 println!(
369 "{:<6}{:<12.2}{:<15.2}{:<20.2}{:<20.2}{:<12.2}{:<16.2}{:<10.2}{:<20.2}{:<12.2}{:<12.2}{:<14.2}",
370 l,
371 parameters.beam_width,
372 qps,
373 mean_latency,
374 latency_999,
375 mean_ios,
376 mean_io_time,
377 mean_cpus,
378 mean_pq_preprocess_time,
379 mean_comps,
380 mean_hops,
381 cache_hit_percentage,
382 );
383 }
384
385 #[cfg(feature = "perf_test")]
386 {
387 let latency_95 = statistics::get_percentile_stats(&statistics, 0.95, |stats| {
388 stats.total_execution_time_us
389 });
390
391 _span.set_attribute(KeyValue::new("qps", qps as f64));
392 _span.set_attribute(KeyValue::new("mean_latency", mean_latency));
393 _span.set_attribute(KeyValue::new("latency_999", latency_999 as f64));
394 _span.set_attribute(KeyValue::new("latency_95", latency_95 as f64));
395 _span.set_attribute(KeyValue::new("mean_cpus", mean_cpus));
396 _span.set_attribute(KeyValue::new("mean_io_time", mean_io_time));
397 _span.set_attribute(KeyValue::new("mean_ios", mean_ios));
398 _span.set_attribute(KeyValue::new("mean_comps", mean_comps));
399 _span.set_attribute(KeyValue::new("mean_hops", mean_hops));
400 _span.set_attribute(KeyValue::new("recall", recall as f64));
401 _span.end();
402 }
403 }
404
405 logger.log_checkpoint("search_completed");
406
407 info!("Done searching. Now saving results");
408 for (test_id, l_value) in parameters.l_vec.iter().enumerate() {
409 if *l_value < parameters.recall_at {
410 println!(
411 "Ignoring all search with L: {} since it's smaller than K: {}",
412 l_value, parameters.recall_at
413 );
414 }
415
416 let cur_result_path = format!(
417 "{}_{}_idx_uint32.bin",
418 parameters.result_output_prefix, l_value
419 );
420 let view = MatrixView::try_from(
421 query_result_ids[test_id].as_slice(),
422 query_num,
423 parameters.recall_at as usize,
424 )
425 .map_err(|e| CMDToolError {
426 details: e.to_string(),
427 })?;
428 write_bin(
429 view,
430 &mut storage_provider.create_for_write(&cur_result_path)?,
431 )?;
432 }
433
434 if has_any_search_failed.load(std::sync::atomic::Ordering::Acquire) {
435 return Err(CMDToolError {
437 details: "At least one search failed with error. See log for details. Exiting."
438 .to_string(),
439 });
440 }
441
442 if best_recall >= parameters.fail_if_recall_below {
443 Ok(0)
444 } else {
445 println!(
446 "Search failed. Best recall {} is below the threshold {}",
447 best_recall, parameters.fail_if_recall_below
448 );
449 Ok(-1)
450 }
451}