Skip to main content

diskann_tools/utils/
ground_truth.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::utils::compute_bitmap::compute_query_bitmaps;
7use bit_set::BitSet;
8use diskann_label_filter::{read_and_parse_queries, read_baselabels};
9
10use std::{io::Write, mem::size_of, str::FromStr};
11
12use bytemuck::cast_slice;
13use diskann::{
14    neighbor::{Neighbor, NeighborPriorityQueue},
15    utils::VectorRepr,
16};
17use diskann_disk::data_model::GraphDataType;
18use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
19use diskann_providers::utils::{
20    create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator,
21};
22use diskann_utils::{
23    io::{read_bin, Metadata},
24    views::Matrix,
25};
26use diskann_vector::{distance::Metric, DistanceFunction};
27use itertools::Itertools;
28use rayon::prelude::*;
29
30use crate::utils::{search_index_utils, CMDResult, CMDToolError};
31
32pub fn read_labels_and_compute_bitmap(
33    base_label_filename: &str,
34    query_label_filename: &str,
35) -> CMDResult<Vec<BitSet>> {
36    // Read base labels
37    let base_labels = read_baselabels(base_label_filename)?;
38
39    // Read and parse queries
40    let parsed_queries = read_and_parse_queries(query_label_filename)?;
41
42    // Compute the query bitmaps
43    let query_bitmaps = compute_query_bitmaps(base_labels, parsed_queries);
44
45    match query_bitmaps {
46        Ok(bitmaps) => Ok(bitmaps),
47        Err(e) => Err(CMDToolError {
48            details: format!("Error computing query bitmaps: {}", e),
49        }),
50    }
51}
52
53#[allow(clippy::too_many_arguments)]
54#[allow(clippy::panic)]
55/// Computes the true nearest neighbors for a set of queries and writes them to a file.
56///
57/// # Arguments
58///
59/// * `distance_function` - e.g. L2
60/// * `base_file` - The file containing the base vectors.
61/// * `query_file` - The file containing the query vectors.
62/// * `ground_truth_file` - The file to write the ground truth results to.
63/// * `recall_at` - The number of neighbors to compute for each query.
64/// * `insert_file` - Optional file containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
65/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
66pub fn compute_ground_truth_from_datafiles<
67    Data: GraphDataType,
68    StorageProvider: StorageReadProvider + StorageWriteProvider,
69>(
70    storage_provider: &StorageProvider,
71    distance_function: Metric,
72    base_file: &str,
73    query_file: &str,
74    ground_truth_file: &str,
75    vector_filters_file: Option<&str>,
76    recall_at: u32,
77    insert_file: Option<&str>,
78    skip_base: Option<usize>,
79    associated_data_file: Option<String>,
80    base_file_labels: Option<&str>,
81    query_file_labels: Option<&str>,
82) -> CMDResult<()> {
83    let dataset_iterator = VectorDataIterator::<
84        StorageProvider,
85        Data::VectorDataType,
86        Data::AssociatedDataType,
87    >::new(base_file, associated_data_file.clone(), storage_provider)?;
88
89    // both base_file_labels and query_file_labels are provided or both are not provided
90    if !((base_file_labels.is_some() && query_file_labels.is_some())
91        || (base_file_labels.is_none() && query_file_labels.is_none()))
92    {
93        return Err(CMDToolError {
94            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
95        });
96    }
97
98    if base_file_labels.is_some() && vector_filters_file.is_some() {
99        return Err(CMDToolError {
100            details: "Both base_file_labels and vector_filters_file cannot be provided."
101                .to_string(),
102        });
103    }
104
105    let insert_iterator = match insert_file {
106        Some(insert_file) => {
107            let i = VectorDataIterator::<
108                StorageProvider,
109                Data::VectorDataType,
110                Data::AssociatedDataType,
111            >::new(insert_file, Option::None, storage_provider)?;
112            Some(i)
113        }
114        None => None,
115    };
116
117    // Load the query file
118    let query_data =
119        read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
120    let query_num = query_data.nrows();
121
122    let mut query_bitmaps: Option<Vec<BitSet>> = None;
123    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
124    {
125        query_bitmaps = Some(read_labels_and_compute_bitmap(
126            base_file_labels,
127            query_file_labels,
128        )?);
129    }
130
131    // Load the vector filters
132    let vector_filters = match vector_filters_file {
133        Some(vector_filters_file) => {
134            let filters =
135                search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
136
137            assert_eq!(
138                filters.len(),
139                query_num,
140                "Mismatch in query and vector filter sizes"
141            );
142
143            Some(filters)
144        }
145        None => None,
146    };
147
148    let has_vector_filters = vector_filters.is_some();
149    let has_query_bitmaps = query_bitmaps.is_some();
150
151    if has_vector_filters {
152        // copy vector_filters to query_bitmaps one item at a time
153        if let Some(filters) = vector_filters {
154            let mut bitmaps = vec![BitSet::new(); query_num];
155            for (idx_query, filter) in filters.iter().enumerate() {
156                for item in filter.iter() {
157                    if let Ok(idx) = (*item).try_into() {
158                        bitmaps[idx_query].insert(idx);
159                    }
160                }
161            }
162            query_bitmaps = Some(bitmaps)
163        }
164    }
165
166    let ground_truth_result = compute_ground_truth_from_data::<Data, StorageProvider>(
167        distance_function,
168        dataset_iterator,
169        &query_data,
170        recall_at,
171        insert_iterator,
172        skip_base,
173        query_bitmaps,
174    );
175    assert!(
176        &ground_truth_result.is_ok(),
177        "Ground-truth computation failed"
178    );
179    let (ground_truth, id_to_associated_data) = ground_truth_result?;
180
181    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
182
183    if has_vector_filters || has_query_bitmaps {
184        let ground_truth_collection = ground_truth
185            .into_iter()
186            .map(|npq| npq.into_iter().collect())
187            .collect();
188        write_range_search_ground_truth(
189            storage_provider,
190            ground_truth_file,
191            query_num,
192            ground_truth_collection,
193        )
194    } else {
195        // Write results and return
196        let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
197        write_ground_truth::<Data>(
198            storage_provider,
199            ground_truth_file,
200            query_num,
201            recall_at as usize,
202            ground_truth,
203            id_to_associated_data,
204        )
205    }
206}
207
/// Strategy for collapsing the pairwise distances between the vectors of a query multi-vector
/// and the vectors of a base multi-vector into a single score.
///
/// Parsed case-insensitively from `average_pairwise`, `min_pairwise` or `avg_of_mins`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector, the smallest distance to any base vector. These minima are then
    /// averaged over the query vectors.
    AvgofMins,
}

/// Error returned when a string does not name a [`MultivecAggregationMethod`].
#[derive(Debug, PartialEq, Eq)]
pub enum ParseAggrError {
    /// The unrecognized input, verbatim.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidFormat(input) => {
                write!(f, "Invalid format for Aggregation Method: {}", input)
            }
        }
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses an aggregation method name, ignoring case. Unknown names yield
    /// `ParseAggrError::InvalidFormat` carrying the original input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "average_pairwise" => Ok(MultivecAggregationMethod::AveragePairwise),
            "min_pairwise" => Ok(MultivecAggregationMethod::MinPairwise),
            "avg_of_mins" => Ok(MultivecAggregationMethod::AvgofMins),
            _ => Err(ParseAggrError::InvalidFormat(String::from(s))),
        }
    }
}
242
243#[allow(clippy::too_many_arguments)]
244#[allow(clippy::panic)]
245/// Computes the true nearest neighbors for a set of queries and writes them to a file.
246///
247/// # Arguments
248///
249/// * `distance_function` - e.g. L2
250/// * `aggregation_method` - e.g. Average or Min
251/// * `base_file` - The file containing the base vectors.
252/// * `query_file` - The file containing the query vectors.
253/// * `ground_truth_file` - The file to write the ground truth results to.
254/// * `recall_at` - The number of neighbors to compute for each query.
255/// * `base_file_labels` - Optional labels file for the base vectors to filter which base vectors to consider per query.
256/// * `query_file_labels` - Optional labels file for the query vectors to filter which base vectors to consider per query.
257pub fn compute_multivec_ground_truth_from_datafiles<
258    Data: GraphDataType,
259    StorageProvider: StorageReadProvider + StorageWriteProvider,
260>(
261    storage_provider: &StorageProvider,
262    distance_function: Metric,
263    aggregation_method: MultivecAggregationMethod,
264    base_file: &str,
265    query_file: &str,
266    ground_truth_file: &str,
267    recall_at: u32,
268    base_file_labels: Option<&str>,
269    query_file_labels: Option<&str>,
270) -> CMDResult<()> {
271    let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
272        Data::VectorDataType,
273        StorageProvider,
274    >(storage_provider, base_file)?;
275
276    let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
277        Data::VectorDataType,
278        StorageProvider,
279    >(storage_provider, query_file)?;
280
281    // both base_file_labels and query_file_labels are provided or both are not provided
282    if !((base_file_labels.is_some() && query_file_labels.is_some())
283        || (base_file_labels.is_none() && query_file_labels.is_none()))
284    {
285        return Err(CMDToolError {
286            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
287        });
288    }
289
290    let mut query_bitmaps: Option<Vec<BitSet>> = None;
291    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
292    {
293        query_bitmaps = Some(read_labels_and_compute_bitmap(
294            base_file_labels,
295            query_file_labels,
296        )?);
297    }
298
299    let has_query_bitmaps = query_bitmaps.is_some();
300
301    let ground_truth = compute_multivec_ground_truth_from_data::<Data::VectorDataType>(
302        distance_function,
303        aggregation_method,
304        base_vectors,
305        query_vectors,
306        query_dim,
307        recall_at,
308        query_bitmaps,
309    )?;
310
311    if has_query_bitmaps {
312        let ground_truth_collection = ground_truth
313            .into_iter()
314            .map(|npq| npq.into_iter().collect())
315            .collect();
316        write_range_search_ground_truth(
317            storage_provider,
318            ground_truth_file,
319            query_num,
320            ground_truth_collection,
321        )
322    } else {
323        // Write results and return
324        write_ground_truth::<Data>(
325            storage_provider,
326            ground_truth_file,
327            query_num,
328            recall_at as usize,
329            ground_truth,
330            Option::None,
331        )
332    }
333}
334
335fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
336    storage_provider: &StorageProvider,
337    ground_truth_file: &str,
338    number_of_queries: usize,
339    ground_truth: Vec<Vec<Neighbor<u32>>>,
340) -> CMDResult<()> {
341    let mut file = storage_provider.create_for_write(ground_truth_file)?;
342
343    let queue_sizes: Vec<u32> = ground_truth
344        .iter()
345        .map(|queue| queue.len() as u32)
346        .collect();
347    let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
348
349    // Metadata
350    Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
351
352    // Write queue sizes array.
353    let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
354    queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
355    file.write_all(&queue_sizes_buffer)?;
356
357    let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
358
359    // Write the neighbor IDs array.
360    for query_neighbors in ground_truth {
361        for neighbor in query_neighbors.iter() {
362            neighbor_ids.push(neighbor.id);
363        }
364    }
365
366    // Write neighbor IDs
367    let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
368    id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
369    file.write_all(&id_buffer)?;
370
371    // Make sure everything is written to disk
372    file.flush()?;
373
374    Ok(())
375}
376
377/// Writes out a ground truth file.  ground_truth is a vector of NeighborPriorityQueue objects
378/// where the order of queue objects corresponds to the order of queries used to compute this
379/// ground truth.
380fn write_ground_truth<Data: GraphDataType>(
381    storage_provider: &impl StorageWriteProvider,
382    ground_truth_file: &str,
383    number_of_queries: usize,
384    number_of_neighbors: usize,
385    ground_truth: Vec<NeighborPriorityQueue<u32>>,
386    id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
387) -> CMDResult<()> {
388    let mut file = storage_provider.create_for_write(ground_truth_file)?;
389
390    Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
391
392    let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
393    let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
394
395    // In the file, we write the neighbor IDs array first, then write the distances array.
396    for mut query_neighbors in ground_truth {
397        while let Some(closest_node) = query_neighbors.closest_notvisited() {
398            gt_ids.push(closest_node.id);
399            gt_distances.push(closest_node.distance);
400        }
401    }
402
403    // Write neighbor IDs or Associated Data
404    if let Some(id_to_associated_data) = id_to_associated_data {
405        let mut associated_data_buffer = Vec::<u8>::new();
406        for id in gt_ids {
407            let associated_data = id_to_associated_data[id as usize];
408            let serialized_associated_data =
409                bincode::serialize(&associated_data).map_err(|e| CMDToolError {
410                    details: format!("Failed to serialize associated data: {}", e),
411                })?;
412            associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
413        }
414        file.write_all(&associated_data_buffer)?;
415    } else {
416        let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
417        id_buffer.clone_from_slice(cast_slice::<u32, u8>(&gt_ids));
418        file.write_all(&id_buffer)?;
419    }
420
421    // Write neighbor distances
422    let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
423    distance_buffer.clone_from_slice(cast_slice::<f32, u8>(&gt_distances));
424    file.write_all(&distance_buffer)?;
425
426    // Make sure everything is written to disk
427    file.flush()?;
428
429    Ok(())
430}
431
432type Npq = Vec<NeighborPriorityQueue<u32>>;
433/// Computes the true nearest neighbors for a set of queries and dataset iterators
434///
435/// # Arguments
436///
437/// * `distance_function` - e.g. L2
438/// * `dataset_iter` - The iterator over the dataset vectors and associated data.
439/// * `queries` - Query vectors as a row-major `Matrix` of shape `num_queries × query_dim`.
440///   `query_dim` is inferred from `queries.ncols()`.
441/// * `recall_at` - The number of neighbors to compute for each query.
442/// * `insert_iter` - Optional iterator containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
443/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
444/// * `query_bitmaps` - Optional per-query bitmaps restricting which base point ids contribute to that query's neighbors.
445#[allow(clippy::too_many_arguments)]
446pub fn compute_ground_truth_from_data<Data, VectorReader>(
447    distance_function: Metric,
448    dataset_iter: VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
449    queries: &Matrix<Data::VectorDataType>,
450    recall_at: u32,
451    insert_iter: Option<
452        VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
453    >,
454    skip_base: Option<usize>,
455    query_bitmaps: Option<Vec<BitSet>>,
456) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
457where
458    Data: GraphDataType,
459    VectorReader: StorageReadProvider,
460{
461    let query_num = queries.nrows();
462    let query_dim = queries.ncols();
463
464    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = (0..query_num)
465        .map(|_| NeighborPriorityQueue::new(recall_at as usize))
466        .collect();
467    let mut queries_and_neighbor_queue: Vec<_> =
468        queries.row_iter().zip(neighbor_queues.iter_mut()).collect();
469
470    let distance_comparer = Data::VectorDataType::distance(distance_function, Some(query_dim));
471
472    let batch_size = 10_000;
473    let mut data_batch: Vec<Box<[Data::VectorDataType]>> = Vec::with_capacity(batch_size);
474
475    let pool = create_thread_pool(0)?;
476
477    let mut num_base_points: usize = 0;
478    let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
479    let skip_base = skip_base.unwrap_or(0);
480    // Loop over all the raw data
481    for chunk in dataset_iter.skip(skip_base).chunks(batch_size).into_iter() {
482        data_batch.clear();
483        for (data_vector, associated_data) in chunk {
484            data_batch.push(data_vector);
485            id_to_associated_data.push(associated_data);
486        }
487        let points = data_batch.len();
488
489        if points == 0 {
490            continue;
491        }
492
493        // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
494        queries_and_neighbor_queue
495            .par_iter_mut()
496            .enumerate()
497            .for_each_in_pool(
498                pool.as_ref(),
499                |(idx_query, (query, ref mut neighbor_queue))| {
500                    for (idx_in_batch, data) in data_batch.iter().enumerate() {
501                        let idx = (num_base_points + idx_in_batch) as u32;
502
503                        let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
504                            if let Ok(idx_usize) = idx.try_into() {
505                                bitmaps[idx_query].contains(idx_usize)
506                            } else {
507                                false
508                            }
509                        } else {
510                            true
511                        };
512
513                        if allowed_by_bitmap {
514                            let distance = distance_comparer.evaluate_similarity(data, query);
515                            neighbor_queue.insert(Neighbor { id: idx, distance });
516                        }
517                    }
518                },
519            );
520
521        num_base_points += points;
522    }
523
524    if let Some(insert_iter) = insert_iter {
525        for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
526            // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
527            for (idx_query, (query, ref mut neighbor_queue)) in
528                queries_and_neighbor_queue.iter_mut().enumerate()
529            {
530                let idx = (num_base_points + insert_idx) as u32;
531
532                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
533                    if let Ok(idx_usize) = idx.try_into() {
534                        bitmaps[idx_query].contains(idx_usize)
535                    } else {
536                        false
537                    }
538                } else {
539                    true
540                };
541
542                if allowed_by_bitmap {
543                    let distance = distance_comparer.evaluate_similarity(&data_vector, query);
544                    neighbor_queue.insert(Neighbor { id: idx, distance })
545                }
546            }
547        }
548    }
549
550    Ok((neighbor_queues, id_to_associated_data))
551}
552
553#[allow(clippy::too_many_arguments)]
554pub fn compute_multivec_ground_truth_from_data<T>(
555    distance_function: Metric,
556    aggregation_method: MultivecAggregationMethod,
557    base_vectors: Vec<Matrix<T>>,
558    queries: Vec<Matrix<T>>,
559    query_dim: usize,
560    recall_at: u32,
561    query_bitmaps: Option<Vec<BitSet>>,
562) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
563where
564    T: VectorRepr,
565{
566    let query_num = queries.len();
567
568    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
569    //
570    for _ in 0..query_num {
571        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
572    }
573    let mut query_multivecs_and_neighbor_queue: Vec<_> =
574        queries.iter().zip(neighbor_queues.iter_mut()).collect();
575
576    let distance_comparer = T::distance(distance_function, Some(query_dim));
577
578    let pool = create_thread_pool(0)?;
579
580    // for each query multivec, compute chamfer distance in parallel
581
582    query_multivecs_and_neighbor_queue
583        .par_iter_mut()
584        .enumerate()
585        .for_each_in_pool(
586            pool.as_ref(),
587            |(query_idx, (query_multivec, neighbor_queue))| {
588                for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
589                    // check if calculation is allowed by bitmap if present
590                    let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
591                        bitmaps[query_idx].contains(idx_base)
592                    } else {
593                        true
594                    };
595
596                    if allowed_by_bitmap {
597                        // compute distance between query_multivec and base_multivec
598                        let distance = match aggregation_method {
599                            MultivecAggregationMethod::AveragePairwise => {
600                                let mut total_distance = 0.0;
601                                for query_vec in query_multivec.row_iter() {
602                                    for base_vec in base_multivec.row_iter() {
603                                        let dist = distance_comparer
604                                            .evaluate_similarity(query_vec, base_vec);
605                                        total_distance += dist;
606                                    }
607                                }
608                                total_distance
609                                    / (query_multivec.nrows() * base_multivec.nrows()) as f32
610                            }
611                            MultivecAggregationMethod::MinPairwise => {
612                                let mut min_distance = f32::MAX;
613                                for query_vec in query_multivec.row_iter() {
614                                    for base_vec in base_multivec.row_iter() {
615                                        let dist = distance_comparer
616                                            .evaluate_similarity(query_vec, base_vec);
617                                        min_distance = min_distance.min(dist);
618                                    }
619                                }
620                                min_distance
621                            }
622                            MultivecAggregationMethod::AvgofMins => {
623                                let mut distance = 0_f32;
624                                for query_vec in query_multivec.row_iter() {
625                                    let mut local_min = f32::MAX;
626                                    for base_vec in base_multivec.row_iter() {
627                                        let dist = distance_comparer
628                                            .evaluate_similarity(query_vec, base_vec);
629                                        local_min = local_min.min(dist);
630                                    }
631                                    distance += local_min;
632                                }
633                                distance / query_multivec.nrows() as f32
634                            }
635                        };
636                        // insert into neighbor queue
637                        let idx = idx_base as u32;
638                        neighbor_queue.insert(Neighbor { id: idx, distance });
639                    }
640                }
641            },
642        );
643
644    Ok(neighbor_queues)
645}