Skip to main content

diskann_tools/utils/
ground_truth.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use bit_set::BitSet;
7use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
8
9use std::{io::Write, mem::size_of, str::FromStr};
10
11use bytemuck::cast_slice;
12use diskann::{
13    neighbor::{Neighbor, NeighborPriorityQueue},
14    utils::VectorRepr,
15};
16use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
17use diskann_providers::{
18    common::AlignedBoxWithSlice,
19    model::graph::traits::GraphDataType,
20    utils::{
21        create_thread_pool, file_util, write_metadata, ParallelIteratorInPool, VectorDataIterator,
22    },
23};
24use diskann_utils::views::Matrix;
25use diskann_vector::{distance::Metric, DistanceFunction};
26use itertools::Itertools;
27use rayon::prelude::*;
28
29use crate::utils::{search_index_utils, CMDResult, CMDToolError};
30
31pub fn read_labels_and_compute_bitmap(
32    base_label_filename: &str,
33    query_label_filename: &str,
34) -> CMDResult<Vec<BitSet>> {
35    // Read base labels
36    let base_labels = read_baselabels(base_label_filename)?;
37
38    // Parse queries and evaluate against labels
39    let parsed_queries = read_and_parse_queries(query_label_filename)?;
40
41    // using the global threadpool is fine here
42    #[allow(clippy::disallowed_methods)]
43    let query_bitmaps: Vec<BitSet> = parsed_queries
44        .par_iter()
45        .map(|(_query_id, query_expr)| {
46            let mut bitmap = BitSet::new();
47            for base_label in base_labels.iter() {
48                if eval_query_expr(query_expr, &base_label.label) {
49                    bitmap.insert(base_label.doc_id);
50                }
51            }
52            bitmap
53        })
54        .collect();
55
56    Ok(query_bitmaps)
57}
58
59#[allow(clippy::too_many_arguments)]
60#[allow(clippy::panic)]
61/// Computes the true nearest neighbors for a set of queries and writes them to a file.
62///
63/// # Arguments
64///
65/// * `distance_function` - e.g. L2
66/// * `base_file` - The file containing the base vectors.
67/// * `query_file` - The file containing the query vectors.
68/// * `ground_truth_file` - The file to write the ground truth results to.
69/// * `recall_at` - The number of neighbors to compute for each query.
70/// * `insert_file` - Optional file containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
71/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
72pub fn compute_ground_truth_from_datafiles<
73    Data: GraphDataType,
74    StorageProvider: StorageReadProvider + StorageWriteProvider,
75>(
76    storage_provider: &StorageProvider,
77    distance_function: Metric,
78    base_file: &str,
79    query_file: &str,
80    ground_truth_file: &str,
81    vector_filters_file: Option<&str>,
82    recall_at: u32,
83    insert_file: Option<&str>,
84    skip_base: Option<usize>,
85    associated_data_file: Option<String>,
86    base_file_labels: Option<&str>,
87    query_file_labels: Option<&str>,
88) -> CMDResult<()> {
89    let dataset_iterator = VectorDataIterator::<StorageProvider, Data>::new(
90        base_file,
91        associated_data_file.clone(),
92        storage_provider,
93    )?;
94
95    // both base_file_labels and query_file_labels are provided or both are not provided
96    if !((base_file_labels.is_some() && query_file_labels.is_some())
97        || (base_file_labels.is_none() && query_file_labels.is_none()))
98    {
99        return Err(CMDToolError {
100            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
101        });
102    }
103
104    if base_file_labels.is_some() && vector_filters_file.is_some() {
105        return Err(CMDToolError {
106            details: "Both base_file_labels and vector_filters_file cannot be provided."
107                .to_string(),
108        });
109    }
110
111    let insert_iterator = match insert_file {
112        Some(insert_file) => {
113            let i = VectorDataIterator::<StorageProvider, Data>::new(
114                insert_file,
115                Option::None,
116                storage_provider,
117            )?;
118            Some(i)
119        }
120        None => None,
121    };
122
123    // Load the query file
124    let (raw_query_data, query_num, query_dim) = file_util::load_bin::<
125        Data::VectorDataType,
126        StorageProvider,
127    >(storage_provider, query_file, 0)?;
128
129    let mut query_bitmaps: Option<Vec<BitSet>> = None;
130    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
131    {
132        query_bitmaps = Some(read_labels_and_compute_bitmap(
133            base_file_labels,
134            query_file_labels,
135        )?);
136    }
137
138    let queries: Vec<_> = raw_query_data.chunks(query_dim).collect();
139
140    // Load the vector filters
141    let vector_filters = match vector_filters_file {
142        Some(vector_filters_file) => {
143            let filters =
144                search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
145
146            assert_eq!(
147                filters.len(),
148                queries.len(),
149                "Mismatch in query and vector filter sizes"
150            );
151
152            Some(filters)
153        }
154        None => None,
155    };
156
157    let has_vector_filters = vector_filters.is_some();
158    let has_query_bitmaps = query_bitmaps.is_some();
159
160    if has_vector_filters {
161        // copy vector_filters to query_bitmaps one item at a time
162        if let Some(filters) = vector_filters {
163            let mut bitmaps = vec![BitSet::new(); queries.len()];
164            for (idx_query, filter) in filters.iter().enumerate() {
165                for item in filter.iter() {
166                    if let Ok(idx) = (*item).try_into() {
167                        bitmaps[idx_query].insert(idx);
168                    }
169                }
170            }
171            query_bitmaps = Some(bitmaps)
172        }
173    }
174
175    let query_aligned_dim = query_dim.next_multiple_of(8);
176    let ground_truth_result = compute_ground_truth_from_data::<
177        Data,
178        StorageProvider,
179        VectorDataIterator<StorageProvider, Data>,
180    >(
181        distance_function,
182        dataset_iterator,
183        queries,
184        query_aligned_dim,
185        recall_at,
186        insert_iterator,
187        skip_base,
188        query_bitmaps,
189    );
190    assert!(
191        &ground_truth_result.is_ok(),
192        "Ground-truth computation failed"
193    );
194    let (ground_truth, id_to_associated_data) = ground_truth_result?;
195
196    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
197
198    if has_vector_filters || has_query_bitmaps {
199        let ground_truth_collection = ground_truth
200            .into_iter()
201            .map(|npq| npq.into_iter().collect())
202            .collect();
203        write_range_search_ground_truth(
204            storage_provider,
205            ground_truth_file,
206            query_num,
207            ground_truth_collection,
208        )
209    } else {
210        // Write results and return
211        let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
212        write_ground_truth::<Data>(
213            storage_provider,
214            ground_truth_file,
215            query_num,
216            recall_at as usize,
217            ground_truth,
218            id_to_associated_data,
219        )
220    }
221}
222
/// Strategy for reducing the pairwise distances between a query multi-vector and a base
/// multi-vector to a single score.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector take the smallest distance to any base vector, then average
    /// those minima over the query vectors.
    AvgofMins,
}
229
/// Error produced when a string does not name a known aggregation method.
#[derive(Debug)]
pub enum ParseAggrError {
    /// The rejected input, exactly as it was supplied.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    /// Formats the error as a message that includes the rejected input.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so the destructuring is irrefutable.
        let Self::InvalidFormat(input) = self;
        write!(f, "Invalid format for Aggregation Method: {input}")
    }
}

impl std::error::Error for ParseAggrError {}
244
245impl FromStr for MultivecAggregationMethod {
246    type Err = ParseAggrError;
247
248    fn from_str(s: &str) -> Result<Self, Self::Err> {
249        match s.to_lowercase().as_str() {
250            "average_pairwise" => Ok(MultivecAggregationMethod::AveragePairwise),
251            "min_pairwise" => Ok(MultivecAggregationMethod::MinPairwise),
252            "avg_of_mins" => Ok(MultivecAggregationMethod::AvgofMins),
253            _ => Err(ParseAggrError::InvalidFormat(String::from(s))),
254        }
255    }
256}
257
258#[allow(clippy::too_many_arguments)]
259#[allow(clippy::panic)]
260/// Computes the true nearest neighbors for a set of queries and writes them to a file.
261///
262/// # Arguments
263///
264/// * `distance_function` - e.g. L2
265/// * `aggregation_method` - e.g. Average or Min
266/// * `base_file` - The file containing the base vectors.
267/// * `query_file` - The file containing the query vectors.
268/// * `ground_truth_file` - The file to write the ground truth results to.
269/// * `recall_at` - The number of neighbors to compute for each query.
270/// * `base_file_labels` - Optional labels file for the base vectors to filter which base vectors to consider per query.
271/// * `query_file_labels` - Optional labels file for the query vectors to filter which base vectors to consider per query.
272pub fn compute_multivec_ground_truth_from_datafiles<
273    Data: GraphDataType,
274    StorageProvider: StorageReadProvider + StorageWriteProvider,
275>(
276    storage_provider: &StorageProvider,
277    distance_function: Metric,
278    aggregation_method: MultivecAggregationMethod,
279    base_file: &str,
280    query_file: &str,
281    ground_truth_file: &str,
282    recall_at: u32,
283    base_file_labels: Option<&str>,
284    query_file_labels: Option<&str>,
285) -> CMDResult<()> {
286    let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
287        Data::VectorDataType,
288        StorageProvider,
289    >(storage_provider, base_file)?;
290
291    let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
292        Data::VectorDataType,
293        StorageProvider,
294    >(storage_provider, query_file)?;
295
296    // both base_file_labels and query_file_labels are provided or both are not provided
297    if !((base_file_labels.is_some() && query_file_labels.is_some())
298        || (base_file_labels.is_none() && query_file_labels.is_none()))
299    {
300        return Err(CMDToolError {
301            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
302        });
303    }
304
305    let mut query_bitmaps: Option<Vec<BitSet>> = None;
306    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
307    {
308        query_bitmaps = Some(read_labels_and_compute_bitmap(
309            base_file_labels,
310            query_file_labels,
311        )?);
312    }
313
314    let has_query_bitmaps = query_bitmaps.is_some();
315
316    let ground_truth =
317        compute_multivec_ground_truth_from_data::<Data::VectorDataType, StorageProvider>(
318            distance_function,
319            aggregation_method,
320            base_vectors,
321            query_vectors,
322            query_dim,
323            recall_at,
324            query_bitmaps,
325        )?;
326
327    if has_query_bitmaps {
328        let ground_truth_collection = ground_truth
329            .into_iter()
330            .map(|npq| npq.into_iter().collect())
331            .collect();
332        write_range_search_ground_truth(
333            storage_provider,
334            ground_truth_file,
335            query_num,
336            ground_truth_collection,
337        )
338    } else {
339        // Write results and return
340        write_ground_truth::<Data>(
341            storage_provider,
342            ground_truth_file,
343            query_num,
344            recall_at as usize,
345            ground_truth,
346            Option::None,
347        )
348    }
349}
350
351pub fn compute_range_search_ground_truth_from_datafiles<
352    Data: GraphDataType,
353    StorageProvider: StorageReadProvider + StorageWriteProvider,
354>(
355    storage_provider: &StorageProvider,
356    distance_function: Metric,
357    base_file: &str,
358    query_file: &str,
359    ground_truth_file: &str,
360    range_threshold: f32,
361    tags_file: &str,
362) -> CMDResult<()> {
363    if !tags_file.is_empty() {
364        // We have not implemented tags yet so let the user know!
365        return Err(CMDToolError {
366            details: "Tag files are not implemented for the ground_truth computation yet."
367                .to_string(),
368        });
369    }
370
371    let dataset_iterator = VectorDataIterator::<StorageProvider, Data>::new(
372        base_file,
373        Option::None,
374        storage_provider,
375    )?;
376
377    // Load the query file
378    let (raw_query_data, query_num, query_dim) = file_util::load_bin::<
379        Data::VectorDataType,
380        StorageProvider,
381    >(storage_provider, query_file, 0)?;
382    let queries: Vec<_> = raw_query_data.chunks(query_dim).collect();
383
384    let query_aligned_dim = query_dim.next_multiple_of(8);
385    let ground_truth_result = compute_range_search_ground_truth_from_data::<
386        Data,
387        StorageProvider,
388        VectorDataIterator<StorageProvider, Data>,
389    >(
390        distance_function,
391        dataset_iterator,
392        queries,
393        query_aligned_dim,
394        range_threshold,
395    );
396    assert!(
397        &ground_truth_result.is_ok(),
398        "Ground-truth computation failed"
399    );
400    let ground_truth = ground_truth_result?;
401
402    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
403
404    // Write results
405    let _res = write_range_search_ground_truth(
406        storage_provider,
407        ground_truth_file,
408        query_num,
409        ground_truth,
410    );
411
412    Ok(())
413}
414
415fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
416    storage_provider: &StorageProvider,
417    ground_truth_file: &str,
418    number_of_queries: usize,
419    ground_truth: Vec<Vec<Neighbor<u32>>>,
420) -> CMDResult<()> {
421    let mut file = storage_provider.create_for_write(ground_truth_file)?;
422
423    let queue_sizes: Vec<u32> = ground_truth
424        .iter()
425        .map(|queue| queue.len() as u32)
426        .collect();
427    let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
428
429    // Metadata
430    write_metadata(&mut file, number_of_queries, total_number_of_neighbors)?;
431
432    // Write queue sizes array.
433    let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
434    queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
435    file.write_all(&queue_sizes_buffer)?;
436
437    let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
438
439    // Write the neighbor IDs array.
440    for query_neighbors in ground_truth {
441        for neighbor in query_neighbors.iter() {
442            neighbor_ids.push(neighbor.id);
443        }
444    }
445
446    // Write neighbor IDs
447    let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
448    id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
449    file.write_all(&id_buffer)?;
450
451    // Make sure everything is written to disk
452    file.flush()?;
453
454    Ok(())
455}
456
457/// Writes out a ground truth file.  ground_truth is a vector of NeighborPriorityQueue objects
458/// where the order of queue objects corresponds to the order of queries used to compute this
459/// ground truth.
460fn write_ground_truth<Data: GraphDataType>(
461    storage_provider: &impl StorageWriteProvider,
462    ground_truth_file: &str,
463    number_of_queries: usize,
464    number_of_neighbors: usize,
465    ground_truth: Vec<NeighborPriorityQueue<u32>>,
466    id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
467) -> CMDResult<()> {
468    let mut file = storage_provider.create_for_write(ground_truth_file)?;
469
470    write_metadata(&mut file, number_of_queries, number_of_neighbors)?;
471
472    let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
473    let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
474
475    // In the file, we write the neighbor IDs array first, then write the distances array.
476    for mut query_neighbors in ground_truth {
477        while query_neighbors.has_notvisited_node() {
478            let closest_node = query_neighbors.closest_notvisited();
479
480            gt_ids.push(closest_node.id);
481            gt_distances.push(closest_node.distance);
482        }
483    }
484
485    // Write neighbor IDs or Associated Data
486    if let Some(id_to_associated_data) = id_to_associated_data {
487        let mut associated_data_buffer = Vec::<u8>::new();
488        for id in gt_ids {
489            let associated_data = id_to_associated_data[id as usize];
490            let serialized_associated_data =
491                bincode::serialize(&associated_data).map_err(|e| CMDToolError {
492                    details: format!("Failed to serialize associated data: {}", e),
493                })?;
494            associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
495        }
496        file.write_all(&associated_data_buffer)?;
497    } else {
498        let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
499        id_buffer.clone_from_slice(cast_slice::<u32, u8>(&gt_ids));
500        file.write_all(&id_buffer)?;
501    }
502
503    // Write neighbor distances
504    let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
505    distance_buffer.clone_from_slice(cast_slice::<f32, u8>(&gt_distances));
506    file.write_all(&distance_buffer)?;
507
508    // Make sure everything is written to disk
509    file.flush()?;
510
511    Ok(())
512}
513
514type Npq = Vec<NeighborPriorityQueue<u32>>;
515/// Computes the true nearest neighbors for a set of queries and dataset iterators
516///
517/// # Arguments
518///
519/// * `distance_function` - e.g. L2
520/// * `dataset_iter` - The iterator over the dataset vectors, associated data, and
521/// * `queries` - A vector of query vectors
522/// * `query_aligned_dimmensions` - The number of dimensions to align the query vectors to for optimized distance comparison.
523/// * `recall_at` - The number of neighbors to compute for each query.
524/// * `insert_iterator` - Optional iterator containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
525/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
526#[allow(clippy::too_many_arguments)]
527pub fn compute_ground_truth_from_data<Data, VectorReader, VectorIteratorType>(
528    distance_function: Metric,
529    dataset_iter: VectorDataIterator<VectorReader, Data>,
530    queries: Vec<&[Data::VectorDataType]>,
531    query_aligned_dimmensions: usize,
532    recall_at: u32,
533    insert_iter: Option<VectorDataIterator<VectorReader, Data>>,
534    skip_base: Option<usize>,
535    query_bitmaps: Option<Vec<BitSet>>,
536) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
537where
538    Data: GraphDataType,
539    VectorReader: StorageReadProvider,
540{
541    let query_num = queries.len();
542
543    let mut aligned_queries = Vec::with_capacity(query_num);
544    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
545    for query in queries {
546        let mut aligned_query = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
547        aligned_query[..query.len()].copy_from_slice(query);
548        aligned_queries.push(aligned_query);
549        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
550    }
551    let mut queries_and_neighbor_queue: Vec<_> = aligned_queries
552        .iter()
553        .zip(neighbor_queues.iter_mut())
554        .collect();
555
556    let distance_comparer =
557        Data::VectorDataType::distance(distance_function, Some(query_aligned_dimmensions));
558
559    let batch_size = 10_000;
560    let mut aligned_data_batch = Vec::with_capacity(batch_size);
561    for _ in 0..batch_size {
562        aligned_data_batch.push(AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?);
563    }
564
565    let pool = create_thread_pool(0)?;
566
567    let mut num_base_points: usize = 0;
568    let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
569    let skip_base = skip_base.unwrap_or(0);
570    // Loop over all the raw data
571    for chunk in dataset_iter
572        .skip(skip_base)
573        .enumerate()
574        .chunks(batch_size)
575        .into_iter()
576    {
577        let mut points = 0;
578        for (idx, (data_vector, associated_data)) in chunk {
579            aligned_data_batch[idx % batch_size][..data_vector.len()].copy_from_slice(&data_vector);
580            id_to_associated_data.push(associated_data);
581            points += 1;
582        }
583
584        if points == 0 {
585            continue;
586        }
587
588        // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
589        queries_and_neighbor_queue
590            .par_iter_mut()
591            .enumerate()
592            .for_each_in_pool(
593                &pool,
594                |(idx_query, (aligned_query, ref mut neighbor_queue))| {
595                    for (idx_in_batch, aligned_data) in
596                        aligned_data_batch[..points].iter().enumerate()
597                    {
598                        let idx = (num_base_points + idx_in_batch) as u32;
599
600                        let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
601                            if let Ok(idx_usize) = idx.try_into() {
602                                bitmaps[idx_query].contains(idx_usize)
603                            } else {
604                                false
605                            }
606                        } else {
607                            true
608                        };
609
610                        if allowed_by_bitmap {
611                            let distance = distance_comparer
612                                .evaluate_similarity(&**aligned_data, aligned_query);
613                            neighbor_queue.insert(Neighbor { id: idx, distance });
614                        }
615                    }
616                },
617            );
618
619        num_base_points += points;
620    }
621
622    let mut aligned_data = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
623
624    if let Some(insert_iter) = insert_iter {
625        for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
626            aligned_data[..data_vector.len()].copy_from_slice(&data_vector);
627            // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
628            for (idx_query, (aligned_query, ref mut neighbor_queue)) in
629                queries_and_neighbor_queue.iter_mut().enumerate()
630            {
631                let idx = (num_base_points + insert_idx) as u32;
632
633                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
634                    if let Ok(idx_usize) = idx.try_into() {
635                        bitmaps[idx_query].contains(idx_usize)
636                    } else {
637                        false
638                    }
639                } else {
640                    true
641                };
642
643                if allowed_by_bitmap {
644                    let distance =
645                        distance_comparer.evaluate_similarity(&*aligned_data, aligned_query);
646                    neighbor_queue.insert(Neighbor { id: idx, distance })
647                }
648            }
649        }
650    }
651
652    Ok((neighbor_queues, id_to_associated_data))
653}
654
655#[allow(clippy::too_many_arguments)]
656pub fn compute_multivec_ground_truth_from_data<T, VectorReader>(
657    distance_function: Metric,
658    aggregation_method: MultivecAggregationMethod,
659    base_vectors: Vec<Matrix<T>>,
660    queries: Vec<Matrix<T>>,
661    query_dim: usize,
662    recall_at: u32,
663    query_bitmaps: Option<Vec<BitSet>>,
664) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
665where
666    T: VectorRepr,
667    VectorReader: StorageReadProvider,
668{
669    let query_num = queries.len();
670
671    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
672    //
673    for _ in 0..query_num {
674        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
675    }
676    let mut query_multivecs_and_neighbor_queue: Vec<_> =
677        queries.iter().zip(neighbor_queues.iter_mut()).collect();
678
679    let distance_comparer = T::distance(distance_function, Some(query_dim));
680
681    let pool = create_thread_pool(0)?;
682
683    // for each query multivec, compute chamfer distance in parallel
684
685    query_multivecs_and_neighbor_queue
686        .par_iter_mut()
687        .enumerate()
688        .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
689            for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
690                // check if calculation is allowed by bitmap if present
691                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
692                    bitmaps[query_idx].contains(idx_base)
693                } else {
694                    true
695                };
696
697                if allowed_by_bitmap {
698                    // compute distance between query_multivec and base_multivec
699                    let distance = match aggregation_method {
700                        MultivecAggregationMethod::AveragePairwise => {
701                            let mut total_distance = 0.0;
702                            for query_vec in query_multivec.row_iter() {
703                                for base_vec in base_multivec.row_iter() {
704                                    let dist =
705                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
706                                    total_distance += dist;
707                                }
708                            }
709                            total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
710                        }
711                        MultivecAggregationMethod::MinPairwise => {
712                            let mut min_distance = f32::MAX;
713                            for query_vec in query_multivec.row_iter() {
714                                for base_vec in base_multivec.row_iter() {
715                                    let dist =
716                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
717                                    min_distance = min_distance.min(dist);
718                                }
719                            }
720                            min_distance
721                        }
722                        MultivecAggregationMethod::AvgofMins => {
723                            let mut distance = 0_f32;
724                            for query_vec in query_multivec.row_iter() {
725                                let mut local_min = f32::MAX;
726                                for base_vec in base_multivec.row_iter() {
727                                    let dist =
728                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
729                                    local_min = local_min.min(dist);
730                                }
731                                distance += local_min;
732                            }
733                            distance / query_multivec.nrows() as f32
734                        }
735                    };
736                    // insert into neighbor queue
737                    let idx = idx_base as u32;
738                    neighbor_queue.insert(Neighbor { id: idx, distance });
739                }
740            }
741        });
742
743    Ok(neighbor_queues)
744}
745
746pub fn compute_range_search_ground_truth_from_data<Data, VectorReader, VectorIteratorType>(
747    distance_function: Metric,
748    dataset_iter: VectorDataIterator<VectorReader, Data>,
749    queries: Vec<&[Data::VectorDataType]>,
750    query_aligned_dimmensions: usize,
751    range_threshold: f32,
752) -> CMDResult<Vec<Vec<Neighbor<u32>>>>
753where
754    Data: GraphDataType,
755    VectorReader: StorageReadProvider,
756{
757    let query_num = queries.len();
758    let mut neighbor_queues: Vec<Vec<Neighbor<u32>>> = Vec::with_capacity(query_num);
759    for _ in 0..query_num {
760        neighbor_queues.push(Vec::new());
761    }
762
763    let mut queries_and_neighbor_queue: Vec<_> =
764        queries.iter().zip(neighbor_queues.iter_mut()).collect();
765
766    let distance_comparer =
767        Data::VectorDataType::distance(distance_function, Some(query_aligned_dimmensions));
768
769    let mut aligned_data = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
770    let mut aligned_query = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
771
772    for (idx, (data_vector, _associated_data)) in dataset_iter.enumerate() {
773        aligned_data[..data_vector.len()].copy_from_slice(&data_vector);
774        for (query, ref mut neighbor_queue) in queries_and_neighbor_queue.iter_mut() {
775            aligned_query[..query.len()].copy_from_slice(query);
776            let distance = distance_comparer.evaluate_similarity(&*aligned_data, &aligned_query);
777            if distance <= range_threshold {
778                neighbor_queue.push(Neighbor {
779                    id: idx as u32,
780                    distance,
781                })
782            }
783        }
784    }
785
786    Ok(neighbor_queues)
787}