Skip to main content

diskann_tools/utils/
search_disk_index.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{collections::HashSet, sync::atomic::AtomicBool, time::Instant};
7
8use diskann::utils::IntoUsize;
9use diskann_disk::{
10    data_model::{CachingStrategy, GraphDataType},
11    search::provider::{
12        disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory,
13    },
14    storage::disk_index_reader::DiskIndexReader,
15    utils::{
16        aligned_file_reader::traits::AlignedReaderFactory, instrumentation::PerfLogger, statistics,
17        QueryStatistics,
18    },
19};
20use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
21use diskann_providers::{
22    storage::{get_compressed_pq_file, get_pq_pivot_file},
23    utils::{create_thread_pool, ParallelIteratorInPool},
24};
25use diskann_utils::{
26    io::{read_bin, write_bin},
27    views::MatrixView,
28};
29use diskann_vector::distance::Metric;
30use opentelemetry::global::BoxedSpan;
31#[cfg(feature = "perf_test")]
32use opentelemetry::{
33    trace::{Span, Tracer},
34    KeyValue,
35};
36use ordered_float::OrderedFloat;
37use rayon::prelude::*;
38use tracing::{error, info};
39
40use crate::utils::{search_index_utils, CMDResult, CMDToolError, KRecallAtN};
41
42pub struct SearchDiskIndexParameters<'a> {
43    pub metric: Metric,
44    pub index_path_prefix: &'a str,
45    pub result_output_prefix: &'a str,
46    pub query_file: &'a str,
47    pub truthset_file: &'a str,
48    pub vector_filters_file: Option<&'a str>,
49    pub num_threads: usize,
50    pub recall_at: u32,
51    pub beam_width: u32,
52    pub search_io_limit: u32,
53    pub l_vec: &'a [u32],
54    pub fail_if_recall_below: f32,
55    pub num_nodes_to_cache: usize,
56    pub is_flat_search: bool,
57}
58
59pub fn search_disk_index<Data, StorageType, ReaderFactory>(
60    storage_provider: &StorageType,
61    parameters: SearchDiskIndexParameters,
62    aligned_reader_factory: ReaderFactory,
63) -> CMDResult<i32>
64where
65    Data: GraphDataType<VectorIdType = u32>,
66    StorageType: StorageReadProvider + StorageWriteProvider,
67    ReaderFactory: AlignedReaderFactory,
68{
69    let mut logger = PerfLogger::new("search_disk_index".to_string(), true);
70
71    info!(
72        "Search parameters: #threads: {}, recall_at {}, search_list_size: {:?}, search_io_limit: {}, fail_if_recall_below: {}, beam_width: {}",
73        parameters.num_threads, parameters.recall_at, parameters.l_vec, parameters.search_io_limit, parameters.fail_if_recall_below,parameters.beam_width
74    );
75
76    // Load the query file
77    let queries = read_bin::<Data::VectorDataType>(
78        &mut storage_provider.open_reader(parameters.query_file)?,
79    )?;
80    let query_num = queries.nrows();
81    // Load the vector filters
82    let vector_filters = match parameters.vector_filters_file {
83        Some(vector_filters_file) => {
84            search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?
85        }
86        None => vec![HashSet::<u32>::new(); query_num],
87    };
88
89    assert_eq!(
90        vector_filters.len(),
91        query_num,
92        "Mismatch in query and vector filter sizes"
93    );
94
95    let mut gt_dim: usize = 0;
96    let mut gt_ids: Option<Vec<u32>> = None;
97
98    let mut gt_ids_variable_length: Option<Vec<Vec<u32>>> = None;
99    let mut gt_dists: Option<Vec<f32>> = None;
100
101    // Check for ground truth
102    let mut calc_recall_flag = false;
103    if !parameters.truthset_file.is_empty() && storage_provider.exists(parameters.truthset_file) {
104        if parameters.vector_filters_file.is_none() {
105            let ret =
106                search_index_utils::load_truthset(storage_provider, parameters.truthset_file)?;
107            gt_ids = Some(ret.index_nodes);
108            gt_dists = ret.distances;
109            let gt_num = ret.index_num_points;
110            gt_dim = ret.index_dimension;
111
112            if gt_num != query_num {
113                error!("Error. Mismatch in number of queries and ground truth data");
114            }
115        } else {
116            let range_truthset = search_index_utils::load_range_truthset(
117                storage_provider,
118                parameters.truthset_file,
119            )?;
120            gt_ids_variable_length = Some(range_truthset.index_nodes);
121            let gt_num = range_truthset.index_num_points;
122
123            if gt_num != query_num {
124                error!("Error. Mismatch in number of queries and ground truth data");
125            }
126        }
127
128        calc_recall_flag = true;
129    } else {
130        error!(
131            "Truthset file {} not found. Not computing recall",
132            parameters.truthset_file
133        );
134    }
135
136    let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
137        get_pq_pivot_file(parameters.index_path_prefix),
138        get_compressed_pq_file(parameters.index_path_prefix),
139        storage_provider,
140    )?;
141
142    let caching_strategy = if parameters.num_nodes_to_cache > 0 {
143        CachingStrategy::StaticCacheWithBfsNodes(parameters.num_nodes_to_cache)
144    } else {
145        CachingStrategy::None
146    };
147    // Create the vertex provider factory
148    let vertex_provider_factory =
149        DiskVertexProviderFactory::new(aligned_reader_factory, caching_strategy)?;
150
151    let searcher = DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, ReaderFactory>>::new(
152        parameters.num_threads.into_usize(),
153        parameters.search_io_limit.into_usize(),
154        &index_reader,
155        vertex_provider_factory,
156        parameters.metric,
157        None,
158    )?;
159
160    logger.log_checkpoint("index_loaded");
161
162    let recall_string = format!("Recall@{}", parameters.recall_at);
163    if calc_recall_flag {
164        println!(
165            "{:<6}{:<12}{:<15}{:<20}{:<20}{:<12}{:<16}{:<10}{:<20}{:<12}{:<12}{:<14}{:<16}",
166            "L",
167            "Beamwidth",
168            "QPS",
169            "Mean Latency (us)",
170            "99.9 Latency (us)",
171            "Mean IOs",
172            "Mean IO (us)",
173            "CPU (us)",
174            "PQ Preprocess (us)",
175            "Mean Comps",
176            "Mean Hops",
177            "Cache Hit %",
178            recall_string
179        );
180    } else {
181        println!(
182            "{:<6}{:<12}{:<15}{:<20}{:<20}{:<12}{:<16}{:<10}{:<20}{:<12}{:<12}{:<14}",
183            "L",
184            "Beamwidth",
185            "QPS",
186            "Mean Latency (us)",
187            "99.9 Latency (us)",
188            "Mean IOs",
189            "Mean IO (us)",
190            "CPU (us)",
191            "PQ Preprocess (us)",
192            "Mean Comparisons",
193            "Mean hops",
194            "Cache Hit %",
195        );
196    }
197    println!("{:=<178}", "");
198
199    let mut query_result_ids: Vec<Vec<u32>> = vec![vec![]; parameters.l_vec.len()];
200    let mut query_result_dists: Vec<Vec<f32>> = vec![vec![]; parameters.l_vec.len()];
201    let mut cmp_stats: Vec<u32> = vec![0; query_num];
202    let has_any_search_failed = AtomicBool::new(false);
203
204    let mut best_recall = 0.0;
205
206    let pool = create_thread_pool(parameters.num_threads)?;
207
208    for (test_id, &l) in parameters.l_vec.iter().enumerate() {
209        if l < parameters.recall_at {
210            println!(
211                "Ignoring search with L: {} since it's smaller than K: {}",
212                l, parameters.recall_at
213            );
214            continue;
215        }
216
217        query_result_ids[test_id].resize(parameters.recall_at as usize * query_num, 0);
218        query_result_dists[test_id].resize(parameters.recall_at as usize * query_num, 0.0);
219
220        // Assuming `QueryStats` is a struct that you have defined elsewhere
221        let mut statistics: Vec<QueryStatistics> = vec![QueryStatistics::default(); query_num];
222        let mut result_counts: Vec<u32> = vec![0; query_num];
223
224        let zipped = cmp_stats
225            .par_iter_mut()
226            .zip(queries.par_row_iter())
227            .zip(vector_filters.par_iter())
228            .zip(query_result_ids[test_id].par_chunks_mut(parameters.recall_at as usize))
229            .zip(query_result_dists[test_id].par_chunks_mut(parameters.recall_at as usize))
230            .zip(statistics.par_iter_mut())
231            .zip(result_counts.par_iter_mut());
232
233        let mut _span: BoxedSpan;
234        #[cfg(feature = "perf_test")]
235        {
236            let tracer = opentelemetry::global::tracer("");
237
238            // Start a span for the search iteration.
239            _span = tracer.start(format!("search-with-L={}-bw={}", l, parameters.beam_width));
240        }
241
242        let test_start = Instant::now();
243        zipped.for_each_in_pool(
244            pool.as_ref(),
245            |(
246                (((((_cmp, query), vector_filter), query_result_id), query_result_dist), stats),
247                result_count,
248            )| {
249                let vector_filter_function: Box<dyn Fn(&u32) -> bool + Send + Sync> =
250                    if parameters.vector_filters_file.is_none() {
251                        Box::new(|_: &u32| true)
252                    } else {
253                        Box::new(move |vector_id: &u32| vector_filter.contains(vector_id))
254                    };
255
256                let result = searcher.search(
257                    query,
258                    parameters.recall_at,
259                    l,
260                    Some(parameters.beam_width as usize),
261                    Some(vector_filter_function),
262                    parameters.is_flat_search,
263                );
264
265                match result {
266                    Ok(search_result) => {
267                        *result_count = search_result.stats.result_count;
268                        *stats = search_result.stats.query_statistics;
269                        search_result
270                            .results
271                            .iter()
272                            .take(parameters.recall_at as usize)
273                            .enumerate()
274                            .for_each(|(i, item)| {
275                                query_result_id[i] = item.vertex_id;
276                                query_result_dist[i] = item.distance;
277                            });
278                    }
279                    Err(e) => {
280                        error!("Error during search: {}", e);
281                        has_any_search_failed.store(true, std::sync::atomic::Ordering::Release);
282                    }
283                }
284            },
285        );
286
287        let diff = test_start.elapsed();
288        let qps = query_num as f32 / diff.as_secs_f32();
289
290        let mean_latency =
291            statistics::get_mean_stats(&statistics, |stats| stats.total_execution_time_us as f64);
292
293        let latency_999 = statistics::get_percentile_stats(&statistics, 0.999, |stats| {
294            stats.total_execution_time_us
295        });
296
297        let mean_ios = statistics::get_mean_stats(&statistics, |stats| stats.total_io_operations);
298        let mean_io_time = statistics::get_mean_stats(&statistics, |stats| stats.io_time_us as f64);
299        let mean_cpus = statistics::get_mean_stats(&statistics, |stats| stats.cpu_time_us as f64);
300        let mean_pq_preprocess_time = statistics::get_mean_stats(&statistics, |stats| {
301            stats.query_pq_preprocess_time_us as f64
302        });
303        let mean_comps =
304            statistics::get_mean_stats(&statistics, |stats| stats.total_comparisons as f64);
305        let mean_hops = statistics::get_mean_stats(&statistics, |stats| stats.search_hops as f64);
306        let total_ios = statistics::get_sum_stats(&statistics, |stats| stats.total_io_operations);
307        let total_vertices_loaded =
308            statistics::get_sum_stats(&statistics, |stats| stats.total_vertices_loaded);
309        let cache_hit_percentage = if total_vertices_loaded > 0.0 {
310            100.0 * (1.0 - (total_ios / total_vertices_loaded))
311        } else {
312            100.0
313        };
314
315        let mut recall = 0.0;
316        if calc_recall_flag {
317            recall = if let Some(gt_ids_variable_length) = &gt_ids_variable_length {
318                let our_results_variable_length = query_result_ids[test_id]
319                    .chunks_exact(parameters.recall_at as usize)
320                    .enumerate()
321                    .map(|(i, chunk)| chunk[..result_counts[i] as usize].to_vec())
322                    .collect::<Vec<_>>();
323                search_index_utils::calculate_filtered_search_recall(
324                    query_num,
325                    None,
326                    gt_ids_variable_length,
327                    &our_results_variable_length,
328                    parameters.recall_at,
329                )? as f32
330            } else {
331                search_index_utils::calculate_recall(
332                    query_num,
333                    gt_ids.as_ref().ok_or_else(|| CMDToolError {
334                        details: "GroundTruth IDs not initialized".to_string(),
335                    })?,
336                    gt_dists.as_ref(),
337                    gt_dim,
338                    &query_result_ids[test_id],
339                    parameters.recall_at,
340                    KRecallAtN::new(parameters.recall_at, parameters.recall_at)?,
341                )? as f32
342            };
343
344            best_recall = f32::from(std::cmp::max(
345                OrderedFloat::<f32>(best_recall),
346                OrderedFloat::<f32>(recall),
347            ));
348        }
349
350        if calc_recall_flag {
351            println!(
352                "{:<6}{:<12.2}{:<15.2}{:<20.2}{:<20.2}{:<12.2}{:<16.2}{:<10.2}{:<20.2}{:<12.2}{:<12.2}{:<14.2}{:<16.2}",
353                l,
354                parameters.beam_width,
355                qps,
356                mean_latency,
357                latency_999,
358                mean_ios,
359                mean_io_time,
360                mean_cpus,
361                mean_pq_preprocess_time,
362                mean_comps,
363                mean_hops,
364                cache_hit_percentage,
365                recall,
366            );
367        } else {
368            println!(
369                "{:<6}{:<12.2}{:<15.2}{:<20.2}{:<20.2}{:<12.2}{:<16.2}{:<10.2}{:<20.2}{:<12.2}{:<12.2}{:<14.2}",
370                l,
371                parameters.beam_width,
372                qps,
373                mean_latency,
374                latency_999,
375                mean_ios,
376                mean_io_time,
377                mean_cpus,
378                mean_pq_preprocess_time,
379                mean_comps,
380                mean_hops,
381                cache_hit_percentage,
382            );
383        }
384
385        #[cfg(feature = "perf_test")]
386        {
387            let latency_95 = statistics::get_percentile_stats(&statistics, 0.95, |stats| {
388                stats.total_execution_time_us
389            });
390
391            _span.set_attribute(KeyValue::new("qps", qps as f64));
392            _span.set_attribute(KeyValue::new("mean_latency", mean_latency));
393            _span.set_attribute(KeyValue::new("latency_999", latency_999 as f64));
394            _span.set_attribute(KeyValue::new("latency_95", latency_95 as f64));
395            _span.set_attribute(KeyValue::new("mean_cpus", mean_cpus));
396            _span.set_attribute(KeyValue::new("mean_io_time", mean_io_time));
397            _span.set_attribute(KeyValue::new("mean_ios", mean_ios));
398            _span.set_attribute(KeyValue::new("mean_comps", mean_comps));
399            _span.set_attribute(KeyValue::new("mean_hops", mean_hops));
400            _span.set_attribute(KeyValue::new("recall", recall as f64));
401            _span.end();
402        }
403    }
404
405    logger.log_checkpoint("search_completed");
406
407    info!("Done searching. Now saving results");
408    for (test_id, l_value) in parameters.l_vec.iter().enumerate() {
409        if *l_value < parameters.recall_at {
410            println!(
411                "Ignoring all search with L: {} since it's smaller than K: {}",
412                l_value, parameters.recall_at
413            );
414        }
415
416        let cur_result_path = format!(
417            "{}_{}_idx_uint32.bin",
418            parameters.result_output_prefix, l_value
419        );
420        let view = MatrixView::try_from(
421            query_result_ids[test_id].as_slice(),
422            query_num,
423            parameters.recall_at as usize,
424        )
425        .map_err(|e| CMDToolError {
426            details: e.to_string(),
427        })?;
428        write_bin(
429            view,
430            &mut storage_provider.create_for_write(&cur_result_path)?,
431        )?;
432    }
433
434    if has_any_search_failed.load(std::sync::atomic::Ordering::Acquire) {
435        // Exit with error. The above stats might still be useful to the user if only a few searched failed, so allowed printing them.
436        return Err(CMDToolError {
437            details: "At least one search failed with error. See log for details. Exiting."
438                .to_string(),
439        });
440    }
441
442    if best_recall >= parameters.fail_if_recall_below {
443        Ok(0)
444    } else {
445        println!(
446            "Search failed. Best recall {} is below the threshold {}",
447            best_recall, parameters.fail_if_recall_below
448        );
449        Ok(-1)
450    }
451}