Skip to main content

diskann_tools/utils/
ground_truth.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use bit_set::BitSet;
7use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
8
9use std::{io::Write, mem::size_of, str::FromStr};
10
11use bytemuck::cast_slice;
12use diskann::{
13    neighbor::{Neighbor, NeighborPriorityQueue},
14    utils::VectorRepr,
15};
16use diskann_disk::data_model::GraphDataType;
17use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
18use diskann_providers::utils::{
19    create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator,
20};
21use diskann_utils::{
22    io::{read_bin, Metadata},
23    views::Matrix,
24};
25use diskann_vector::{distance::Metric, DistanceFunction};
26use itertools::Itertools;
27use rayon::prelude::*;
28
29use crate::utils::{search_index_utils, CMDResult, CMDToolError};
30
31pub fn read_labels_and_compute_bitmap(
32    base_label_filename: &str,
33    query_label_filename: &str,
34) -> CMDResult<Vec<BitSet>> {
35    // Read base labels
36    let base_labels = read_baselabels(base_label_filename)?;
37
38    // Parse queries and evaluate against labels
39    let parsed_queries = read_and_parse_queries(query_label_filename)?;
40
41    // using the global threadpool is fine here
42    #[allow(clippy::disallowed_methods)]
43    let query_bitmaps: Vec<BitSet> = parsed_queries
44        .par_iter()
45        .map(|(_query_id, query_expr)| {
46            let mut bitmap = BitSet::new();
47            for base_label in base_labels.iter() {
48                if eval_query_expr(query_expr, &base_label.label) {
49                    bitmap.insert(base_label.doc_id);
50                }
51            }
52            bitmap
53        })
54        .collect();
55
56    Ok(query_bitmaps)
57}
58
59#[allow(clippy::too_many_arguments)]
60#[allow(clippy::panic)]
61/// Computes the true nearest neighbors for a set of queries and writes them to a file.
62///
63/// # Arguments
64///
65/// * `distance_function` - e.g. L2
66/// * `base_file` - The file containing the base vectors.
67/// * `query_file` - The file containing the query vectors.
68/// * `ground_truth_file` - The file to write the ground truth results to.
69/// * `recall_at` - The number of neighbors to compute for each query.
70/// * `insert_file` - Optional file containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
71/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
72pub fn compute_ground_truth_from_datafiles<
73    Data: GraphDataType,
74    StorageProvider: StorageReadProvider + StorageWriteProvider,
75>(
76    storage_provider: &StorageProvider,
77    distance_function: Metric,
78    base_file: &str,
79    query_file: &str,
80    ground_truth_file: &str,
81    vector_filters_file: Option<&str>,
82    recall_at: u32,
83    insert_file: Option<&str>,
84    skip_base: Option<usize>,
85    associated_data_file: Option<String>,
86    base_file_labels: Option<&str>,
87    query_file_labels: Option<&str>,
88) -> CMDResult<()> {
89    let dataset_iterator = VectorDataIterator::<
90        StorageProvider,
91        Data::VectorDataType,
92        Data::AssociatedDataType,
93    >::new(base_file, associated_data_file.clone(), storage_provider)?;
94
95    // both base_file_labels and query_file_labels are provided or both are not provided
96    if !((base_file_labels.is_some() && query_file_labels.is_some())
97        || (base_file_labels.is_none() && query_file_labels.is_none()))
98    {
99        return Err(CMDToolError {
100            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
101        });
102    }
103
104    if base_file_labels.is_some() && vector_filters_file.is_some() {
105        return Err(CMDToolError {
106            details: "Both base_file_labels and vector_filters_file cannot be provided."
107                .to_string(),
108        });
109    }
110
111    let insert_iterator = match insert_file {
112        Some(insert_file) => {
113            let i = VectorDataIterator::<
114                StorageProvider,
115                Data::VectorDataType,
116                Data::AssociatedDataType,
117            >::new(insert_file, Option::None, storage_provider)?;
118            Some(i)
119        }
120        None => None,
121    };
122
123    // Load the query file
124    let query_data =
125        read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
126    let query_num = query_data.nrows();
127
128    let mut query_bitmaps: Option<Vec<BitSet>> = None;
129    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
130    {
131        query_bitmaps = Some(read_labels_and_compute_bitmap(
132            base_file_labels,
133            query_file_labels,
134        )?);
135    }
136
137    // Load the vector filters
138    let vector_filters = match vector_filters_file {
139        Some(vector_filters_file) => {
140            let filters =
141                search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
142
143            assert_eq!(
144                filters.len(),
145                query_num,
146                "Mismatch in query and vector filter sizes"
147            );
148
149            Some(filters)
150        }
151        None => None,
152    };
153
154    let has_vector_filters = vector_filters.is_some();
155    let has_query_bitmaps = query_bitmaps.is_some();
156
157    if has_vector_filters {
158        // copy vector_filters to query_bitmaps one item at a time
159        if let Some(filters) = vector_filters {
160            let mut bitmaps = vec![BitSet::new(); query_num];
161            for (idx_query, filter) in filters.iter().enumerate() {
162                for item in filter.iter() {
163                    if let Ok(idx) = (*item).try_into() {
164                        bitmaps[idx_query].insert(idx);
165                    }
166                }
167            }
168            query_bitmaps = Some(bitmaps)
169        }
170    }
171
172    let ground_truth_result = compute_ground_truth_from_data::<Data, StorageProvider>(
173        distance_function,
174        dataset_iterator,
175        &query_data,
176        recall_at,
177        insert_iterator,
178        skip_base,
179        query_bitmaps,
180    );
181    assert!(
182        &ground_truth_result.is_ok(),
183        "Ground-truth computation failed"
184    );
185    let (ground_truth, id_to_associated_data) = ground_truth_result?;
186
187    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
188
189    if has_vector_filters || has_query_bitmaps {
190        let ground_truth_collection = ground_truth
191            .into_iter()
192            .map(|npq| npq.into_iter().collect())
193            .collect();
194        write_range_search_ground_truth(
195            storage_provider,
196            ground_truth_file,
197            query_num,
198            ground_truth_collection,
199        )
200    } else {
201        // Write results and return
202        let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
203        write_ground_truth::<Data>(
204            storage_provider,
205            ground_truth_file,
206            query_num,
207            recall_at as usize,
208            ground_truth,
209            id_to_associated_data,
210        )
211    }
212}
213
/// How the pairwise distances between the vectors of a query multi-vector and the vectors of
/// a base multi-vector are collapsed into a single distance.
///
/// Parsed case-insensitively from `"average_pairwise"`, `"min_pairwise"` and `"avg_of_mins"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector, the distance to its closest base vector; those minima are then
    /// averaged over the query vectors.
    AvgofMins,
}

/// Error returned when a string does not name a [`MultivecAggregationMethod`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseAggrError {
    /// The input (kept exactly as given) is not a recognised aggregation method.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidFormat(input) => {
                write!(f, "Invalid format for Aggregation Method: {}", input)
            }
        }
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses an aggregation method name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns [`ParseAggrError::InvalidFormat`] carrying the original input when the name is
    /// not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "average_pairwise" => Ok(MultivecAggregationMethod::AveragePairwise),
            "min_pairwise" => Ok(MultivecAggregationMethod::MinPairwise),
            "avg_of_mins" => Ok(MultivecAggregationMethod::AvgofMins),
            _ => Err(ParseAggrError::InvalidFormat(String::from(s))),
        }
    }
}
248
249#[allow(clippy::too_many_arguments)]
250#[allow(clippy::panic)]
251/// Computes the true nearest neighbors for a set of queries and writes them to a file.
252///
253/// # Arguments
254///
255/// * `distance_function` - e.g. L2
256/// * `aggregation_method` - e.g. Average or Min
257/// * `base_file` - The file containing the base vectors.
258/// * `query_file` - The file containing the query vectors.
259/// * `ground_truth_file` - The file to write the ground truth results to.
260/// * `recall_at` - The number of neighbors to compute for each query.
261/// * `base_file_labels` - Optional labels file for the base vectors to filter which base vectors to consider per query.
262/// * `query_file_labels` - Optional labels file for the query vectors to filter which base vectors to consider per query.
263pub fn compute_multivec_ground_truth_from_datafiles<
264    Data: GraphDataType,
265    StorageProvider: StorageReadProvider + StorageWriteProvider,
266>(
267    storage_provider: &StorageProvider,
268    distance_function: Metric,
269    aggregation_method: MultivecAggregationMethod,
270    base_file: &str,
271    query_file: &str,
272    ground_truth_file: &str,
273    recall_at: u32,
274    base_file_labels: Option<&str>,
275    query_file_labels: Option<&str>,
276) -> CMDResult<()> {
277    let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
278        Data::VectorDataType,
279        StorageProvider,
280    >(storage_provider, base_file)?;
281
282    let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
283        Data::VectorDataType,
284        StorageProvider,
285    >(storage_provider, query_file)?;
286
287    // both base_file_labels and query_file_labels are provided or both are not provided
288    if !((base_file_labels.is_some() && query_file_labels.is_some())
289        || (base_file_labels.is_none() && query_file_labels.is_none()))
290    {
291        return Err(CMDToolError {
292            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
293        });
294    }
295
296    let mut query_bitmaps: Option<Vec<BitSet>> = None;
297    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
298    {
299        query_bitmaps = Some(read_labels_and_compute_bitmap(
300            base_file_labels,
301            query_file_labels,
302        )?);
303    }
304
305    let has_query_bitmaps = query_bitmaps.is_some();
306
307    let ground_truth = compute_multivec_ground_truth_from_data::<Data::VectorDataType>(
308        distance_function,
309        aggregation_method,
310        base_vectors,
311        query_vectors,
312        query_dim,
313        recall_at,
314        query_bitmaps,
315    )?;
316
317    if has_query_bitmaps {
318        let ground_truth_collection = ground_truth
319            .into_iter()
320            .map(|npq| npq.into_iter().collect())
321            .collect();
322        write_range_search_ground_truth(
323            storage_provider,
324            ground_truth_file,
325            query_num,
326            ground_truth_collection,
327        )
328    } else {
329        // Write results and return
330        write_ground_truth::<Data>(
331            storage_provider,
332            ground_truth_file,
333            query_num,
334            recall_at as usize,
335            ground_truth,
336            Option::None,
337        )
338    }
339}
340
341fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
342    storage_provider: &StorageProvider,
343    ground_truth_file: &str,
344    number_of_queries: usize,
345    ground_truth: Vec<Vec<Neighbor<u32>>>,
346) -> CMDResult<()> {
347    let mut file = storage_provider.create_for_write(ground_truth_file)?;
348
349    let queue_sizes: Vec<u32> = ground_truth
350        .iter()
351        .map(|queue| queue.len() as u32)
352        .collect();
353    let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
354
355    // Metadata
356    Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
357
358    // Write queue sizes array.
359    let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
360    queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
361    file.write_all(&queue_sizes_buffer)?;
362
363    let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
364
365    // Write the neighbor IDs array.
366    for query_neighbors in ground_truth {
367        for neighbor in query_neighbors.iter() {
368            neighbor_ids.push(neighbor.id);
369        }
370    }
371
372    // Write neighbor IDs
373    let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
374    id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
375    file.write_all(&id_buffer)?;
376
377    // Make sure everything is written to disk
378    file.flush()?;
379
380    Ok(())
381}
382
383/// Writes out a ground truth file.  ground_truth is a vector of NeighborPriorityQueue objects
384/// where the order of queue objects corresponds to the order of queries used to compute this
385/// ground truth.
386fn write_ground_truth<Data: GraphDataType>(
387    storage_provider: &impl StorageWriteProvider,
388    ground_truth_file: &str,
389    number_of_queries: usize,
390    number_of_neighbors: usize,
391    ground_truth: Vec<NeighborPriorityQueue<u32>>,
392    id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
393) -> CMDResult<()> {
394    let mut file = storage_provider.create_for_write(ground_truth_file)?;
395
396    Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
397
398    let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
399    let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
400
401    // In the file, we write the neighbor IDs array first, then write the distances array.
402    for mut query_neighbors in ground_truth {
403        while let Some(closest_node) = query_neighbors.closest_notvisited() {
404            gt_ids.push(closest_node.id);
405            gt_distances.push(closest_node.distance);
406        }
407    }
408
409    // Write neighbor IDs or Associated Data
410    if let Some(id_to_associated_data) = id_to_associated_data {
411        let mut associated_data_buffer = Vec::<u8>::new();
412        for id in gt_ids {
413            let associated_data = id_to_associated_data[id as usize];
414            let serialized_associated_data =
415                bincode::serialize(&associated_data).map_err(|e| CMDToolError {
416                    details: format!("Failed to serialize associated data: {}", e),
417                })?;
418            associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
419        }
420        file.write_all(&associated_data_buffer)?;
421    } else {
422        let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
423        id_buffer.clone_from_slice(cast_slice::<u32, u8>(&gt_ids));
424        file.write_all(&id_buffer)?;
425    }
426
427    // Write neighbor distances
428    let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
429    distance_buffer.clone_from_slice(cast_slice::<f32, u8>(&gt_distances));
430    file.write_all(&distance_buffer)?;
431
432    // Make sure everything is written to disk
433    file.flush()?;
434
435    Ok(())
436}
437
438type Npq = Vec<NeighborPriorityQueue<u32>>;
439/// Computes the true nearest neighbors for a set of queries and dataset iterators
440///
441/// # Arguments
442///
443/// * `distance_function` - e.g. L2
444/// * `dataset_iter` - The iterator over the dataset vectors and associated data.
445/// * `queries` - Query vectors as a row-major `Matrix` of shape `num_queries × query_dim`.
446///   `query_dim` is inferred from `queries.ncols()`.
447/// * `recall_at` - The number of neighbors to compute for each query.
448/// * `insert_iter` - Optional iterator containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
449/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
450/// * `query_bitmaps` - Optional per-query bitmaps restricting which base point ids contribute to that query's neighbors.
451#[allow(clippy::too_many_arguments)]
452pub fn compute_ground_truth_from_data<Data, VectorReader>(
453    distance_function: Metric,
454    dataset_iter: VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
455    queries: &Matrix<Data::VectorDataType>,
456    recall_at: u32,
457    insert_iter: Option<
458        VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
459    >,
460    skip_base: Option<usize>,
461    query_bitmaps: Option<Vec<BitSet>>,
462) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
463where
464    Data: GraphDataType,
465    VectorReader: StorageReadProvider,
466{
467    let query_num = queries.nrows();
468    let query_dim = queries.ncols();
469
470    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = (0..query_num)
471        .map(|_| NeighborPriorityQueue::new(recall_at as usize))
472        .collect();
473    let mut queries_and_neighbor_queue: Vec<_> =
474        queries.row_iter().zip(neighbor_queues.iter_mut()).collect();
475
476    let distance_comparer = Data::VectorDataType::distance(distance_function, Some(query_dim));
477
478    let batch_size = 10_000;
479    let mut data_batch: Vec<Box<[Data::VectorDataType]>> = Vec::with_capacity(batch_size);
480
481    let pool = create_thread_pool(0)?;
482
483    let mut num_base_points: usize = 0;
484    let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
485    let skip_base = skip_base.unwrap_or(0);
486    // Loop over all the raw data
487    for chunk in dataset_iter.skip(skip_base).chunks(batch_size).into_iter() {
488        data_batch.clear();
489        for (data_vector, associated_data) in chunk {
490            data_batch.push(data_vector);
491            id_to_associated_data.push(associated_data);
492        }
493        let points = data_batch.len();
494
495        if points == 0 {
496            continue;
497        }
498
499        // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
500        queries_and_neighbor_queue
501            .par_iter_mut()
502            .enumerate()
503            .for_each_in_pool(&pool, |(idx_query, (query, ref mut neighbor_queue))| {
504                for (idx_in_batch, data) in data_batch.iter().enumerate() {
505                    let idx = (num_base_points + idx_in_batch) as u32;
506
507                    let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
508                        if let Ok(idx_usize) = idx.try_into() {
509                            bitmaps[idx_query].contains(idx_usize)
510                        } else {
511                            false
512                        }
513                    } else {
514                        true
515                    };
516
517                    if allowed_by_bitmap {
518                        let distance = distance_comparer.evaluate_similarity(data, query);
519                        neighbor_queue.insert(Neighbor { id: idx, distance });
520                    }
521                }
522            });
523
524        num_base_points += points;
525    }
526
527    if let Some(insert_iter) = insert_iter {
528        for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
529            // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
530            for (idx_query, (query, ref mut neighbor_queue)) in
531                queries_and_neighbor_queue.iter_mut().enumerate()
532            {
533                let idx = (num_base_points + insert_idx) as u32;
534
535                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
536                    if let Ok(idx_usize) = idx.try_into() {
537                        bitmaps[idx_query].contains(idx_usize)
538                    } else {
539                        false
540                    }
541                } else {
542                    true
543                };
544
545                if allowed_by_bitmap {
546                    let distance = distance_comparer.evaluate_similarity(&data_vector, query);
547                    neighbor_queue.insert(Neighbor { id: idx, distance })
548                }
549            }
550        }
551    }
552
553    Ok((neighbor_queues, id_to_associated_data))
554}
555
556#[allow(clippy::too_many_arguments)]
557pub fn compute_multivec_ground_truth_from_data<T>(
558    distance_function: Metric,
559    aggregation_method: MultivecAggregationMethod,
560    base_vectors: Vec<Matrix<T>>,
561    queries: Vec<Matrix<T>>,
562    query_dim: usize,
563    recall_at: u32,
564    query_bitmaps: Option<Vec<BitSet>>,
565) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
566where
567    T: VectorRepr,
568{
569    let query_num = queries.len();
570
571    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
572    //
573    for _ in 0..query_num {
574        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
575    }
576    let mut query_multivecs_and_neighbor_queue: Vec<_> =
577        queries.iter().zip(neighbor_queues.iter_mut()).collect();
578
579    let distance_comparer = T::distance(distance_function, Some(query_dim));
580
581    let pool = create_thread_pool(0)?;
582
583    // for each query multivec, compute chamfer distance in parallel
584
585    query_multivecs_and_neighbor_queue
586        .par_iter_mut()
587        .enumerate()
588        .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
589            for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
590                // check if calculation is allowed by bitmap if present
591                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
592                    bitmaps[query_idx].contains(idx_base)
593                } else {
594                    true
595                };
596
597                if allowed_by_bitmap {
598                    // compute distance between query_multivec and base_multivec
599                    let distance = match aggregation_method {
600                        MultivecAggregationMethod::AveragePairwise => {
601                            let mut total_distance = 0.0;
602                            for query_vec in query_multivec.row_iter() {
603                                for base_vec in base_multivec.row_iter() {
604                                    let dist =
605                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
606                                    total_distance += dist;
607                                }
608                            }
609                            total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
610                        }
611                        MultivecAggregationMethod::MinPairwise => {
612                            let mut min_distance = f32::MAX;
613                            for query_vec in query_multivec.row_iter() {
614                                for base_vec in base_multivec.row_iter() {
615                                    let dist =
616                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
617                                    min_distance = min_distance.min(dist);
618                                }
619                            }
620                            min_distance
621                        }
622                        MultivecAggregationMethod::AvgofMins => {
623                            let mut distance = 0_f32;
624                            for query_vec in query_multivec.row_iter() {
625                                let mut local_min = f32::MAX;
626                                for base_vec in base_multivec.row_iter() {
627                                    let dist =
628                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
629                                    local_min = local_min.min(dist);
630                                }
631                                distance += local_min;
632                            }
633                            distance / query_multivec.nrows() as f32
634                        }
635                    };
636                    // insert into neighbor queue
637                    let idx = idx_base as u32;
638                    neighbor_queue.insert(Neighbor { id: idx, distance });
639                }
640            }
641        });
642
643    Ok(neighbor_queues)
644}