Skip to main content

diskann_tools/utils/
ground_truth.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use bit_set::BitSet;
7use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
8
9use std::{io::Write, mem::size_of, str::FromStr};
10
11use bytemuck::cast_slice;
12use diskann::{
13    neighbor::{Neighbor, NeighborPriorityQueue},
14    utils::VectorRepr,
15};
16use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
17use diskann_providers::{
18    common::AlignedBoxWithSlice,
19    model::graph::traits::GraphDataType,
20    utils::{create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator},
21};
22use diskann_utils::{
23    io::{read_bin, Metadata},
24    views::Matrix,
25};
26use diskann_vector::{distance::Metric, DistanceFunction};
27use itertools::Itertools;
28use rayon::prelude::*;
29
30use crate::utils::{search_index_utils, CMDResult, CMDToolError};
31
32pub fn read_labels_and_compute_bitmap(
33    base_label_filename: &str,
34    query_label_filename: &str,
35) -> CMDResult<Vec<BitSet>> {
36    // Read base labels
37    let base_labels = read_baselabels(base_label_filename)?;
38
39    // Parse queries and evaluate against labels
40    let parsed_queries = read_and_parse_queries(query_label_filename)?;
41
42    // using the global threadpool is fine here
43    #[allow(clippy::disallowed_methods)]
44    let query_bitmaps: Vec<BitSet> = parsed_queries
45        .par_iter()
46        .map(|(_query_id, query_expr)| {
47            let mut bitmap = BitSet::new();
48            for base_label in base_labels.iter() {
49                if eval_query_expr(query_expr, &base_label.label) {
50                    bitmap.insert(base_label.doc_id);
51                }
52            }
53            bitmap
54        })
55        .collect();
56
57    Ok(query_bitmaps)
58}
59
60#[allow(clippy::too_many_arguments)]
61#[allow(clippy::panic)]
62/// Computes the true nearest neighbors for a set of queries and writes them to a file.
63///
64/// # Arguments
65///
66/// * `distance_function` - e.g. L2
67/// * `base_file` - The file containing the base vectors.
68/// * `query_file` - The file containing the query vectors.
69/// * `ground_truth_file` - The file to write the ground truth results to.
70/// * `recall_at` - The number of neighbors to compute for each query.
71/// * `insert_file` - Optional file containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
72/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
73pub fn compute_ground_truth_from_datafiles<
74    Data: GraphDataType,
75    StorageProvider: StorageReadProvider + StorageWriteProvider,
76>(
77    storage_provider: &StorageProvider,
78    distance_function: Metric,
79    base_file: &str,
80    query_file: &str,
81    ground_truth_file: &str,
82    vector_filters_file: Option<&str>,
83    recall_at: u32,
84    insert_file: Option<&str>,
85    skip_base: Option<usize>,
86    associated_data_file: Option<String>,
87    base_file_labels: Option<&str>,
88    query_file_labels: Option<&str>,
89) -> CMDResult<()> {
90    let dataset_iterator = VectorDataIterator::<StorageProvider, Data>::new(
91        base_file,
92        associated_data_file.clone(),
93        storage_provider,
94    )?;
95
96    // both base_file_labels and query_file_labels are provided or both are not provided
97    if !((base_file_labels.is_some() && query_file_labels.is_some())
98        || (base_file_labels.is_none() && query_file_labels.is_none()))
99    {
100        return Err(CMDToolError {
101            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
102        });
103    }
104
105    if base_file_labels.is_some() && vector_filters_file.is_some() {
106        return Err(CMDToolError {
107            details: "Both base_file_labels and vector_filters_file cannot be provided."
108                .to_string(),
109        });
110    }
111
112    let insert_iterator = match insert_file {
113        Some(insert_file) => {
114            let i = VectorDataIterator::<StorageProvider, Data>::new(
115                insert_file,
116                Option::None,
117                storage_provider,
118            )?;
119            Some(i)
120        }
121        None => None,
122    };
123
124    // Load the query file
125    let query_data =
126        read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
127    let query_num = query_data.nrows();
128    let query_dim = query_data.ncols();
129
130    let mut query_bitmaps: Option<Vec<BitSet>> = None;
131    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
132    {
133        query_bitmaps = Some(read_labels_and_compute_bitmap(
134            base_file_labels,
135            query_file_labels,
136        )?);
137    }
138
139    let queries: Vec<_> = query_data.row_iter().collect();
140
141    // Load the vector filters
142    let vector_filters = match vector_filters_file {
143        Some(vector_filters_file) => {
144            let filters =
145                search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
146
147            assert_eq!(
148                filters.len(),
149                queries.len(),
150                "Mismatch in query and vector filter sizes"
151            );
152
153            Some(filters)
154        }
155        None => None,
156    };
157
158    let has_vector_filters = vector_filters.is_some();
159    let has_query_bitmaps = query_bitmaps.is_some();
160
161    if has_vector_filters {
162        // copy vector_filters to query_bitmaps one item at a time
163        if let Some(filters) = vector_filters {
164            let mut bitmaps = vec![BitSet::new(); queries.len()];
165            for (idx_query, filter) in filters.iter().enumerate() {
166                for item in filter.iter() {
167                    if let Ok(idx) = (*item).try_into() {
168                        bitmaps[idx_query].insert(idx);
169                    }
170                }
171            }
172            query_bitmaps = Some(bitmaps)
173        }
174    }
175
176    let query_aligned_dim = query_dim.next_multiple_of(8);
177    let ground_truth_result = compute_ground_truth_from_data::<
178        Data,
179        StorageProvider,
180        VectorDataIterator<StorageProvider, Data>,
181    >(
182        distance_function,
183        dataset_iterator,
184        queries,
185        query_aligned_dim,
186        recall_at,
187        insert_iterator,
188        skip_base,
189        query_bitmaps,
190    );
191    assert!(
192        &ground_truth_result.is_ok(),
193        "Ground-truth computation failed"
194    );
195    let (ground_truth, id_to_associated_data) = ground_truth_result?;
196
197    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
198
199    if has_vector_filters || has_query_bitmaps {
200        let ground_truth_collection = ground_truth
201            .into_iter()
202            .map(|npq| npq.into_iter().collect())
203            .collect();
204        write_range_search_ground_truth(
205            storage_provider,
206            ground_truth_file,
207            query_num,
208            ground_truth_collection,
209        )
210    } else {
211        // Write results and return
212        let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
213        write_ground_truth::<Data>(
214            storage_provider,
215            ground_truth_file,
216            query_num,
217            recall_at as usize,
218            ground_truth,
219            id_to_associated_data,
220        )
221    }
222}
223
/// Strategy used to collapse the pairwise vector distances between a query multi-vector and a
/// base multi-vector into a single score.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector take the smallest distance to any base vector, then average those
    /// minima over the query vectors.
    AvgofMins,
}

/// Error returned when a string cannot be parsed into a [`MultivecAggregationMethod`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseAggrError {
    /// The input did not name a known aggregation method; carries the original input.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidFormat(input) => {
                write!(f, "Invalid format for Aggregation Method: {}", input)
            }
        }
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses `average_pairwise`, `min_pairwise` or `avg_of_mins` (case-insensitive).
    ///
    /// # Errors
    ///
    /// Returns [`ParseAggrError::InvalidFormat`] holding the unmodified input for any other
    /// string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "average_pairwise" => Ok(MultivecAggregationMethod::AveragePairwise),
            "min_pairwise" => Ok(MultivecAggregationMethod::MinPairwise),
            "avg_of_mins" => Ok(MultivecAggregationMethod::AvgofMins),
            _ => Err(ParseAggrError::InvalidFormat(String::from(s))),
        }
    }
}
258
259#[allow(clippy::too_many_arguments)]
260#[allow(clippy::panic)]
261/// Computes the true nearest neighbors for a set of queries and writes them to a file.
262///
263/// # Arguments
264///
265/// * `distance_function` - e.g. L2
266/// * `aggregation_method` - e.g. Average or Min
267/// * `base_file` - The file containing the base vectors.
268/// * `query_file` - The file containing the query vectors.
269/// * `ground_truth_file` - The file to write the ground truth results to.
270/// * `recall_at` - The number of neighbors to compute for each query.
271/// * `base_file_labels` - Optional labels file for the base vectors to filter which base vectors to consider per query.
272/// * `query_file_labels` - Optional labels file for the query vectors to filter which base vectors to consider per query.
273pub fn compute_multivec_ground_truth_from_datafiles<
274    Data: GraphDataType,
275    StorageProvider: StorageReadProvider + StorageWriteProvider,
276>(
277    storage_provider: &StorageProvider,
278    distance_function: Metric,
279    aggregation_method: MultivecAggregationMethod,
280    base_file: &str,
281    query_file: &str,
282    ground_truth_file: &str,
283    recall_at: u32,
284    base_file_labels: Option<&str>,
285    query_file_labels: Option<&str>,
286) -> CMDResult<()> {
287    let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
288        Data::VectorDataType,
289        StorageProvider,
290    >(storage_provider, base_file)?;
291
292    let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
293        Data::VectorDataType,
294        StorageProvider,
295    >(storage_provider, query_file)?;
296
297    // both base_file_labels and query_file_labels are provided or both are not provided
298    if !((base_file_labels.is_some() && query_file_labels.is_some())
299        || (base_file_labels.is_none() && query_file_labels.is_none()))
300    {
301        return Err(CMDToolError {
302            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
303        });
304    }
305
306    let mut query_bitmaps: Option<Vec<BitSet>> = None;
307    if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
308    {
309        query_bitmaps = Some(read_labels_and_compute_bitmap(
310            base_file_labels,
311            query_file_labels,
312        )?);
313    }
314
315    let has_query_bitmaps = query_bitmaps.is_some();
316
317    let ground_truth =
318        compute_multivec_ground_truth_from_data::<Data::VectorDataType, StorageProvider>(
319            distance_function,
320            aggregation_method,
321            base_vectors,
322            query_vectors,
323            query_dim,
324            recall_at,
325            query_bitmaps,
326        )?;
327
328    if has_query_bitmaps {
329        let ground_truth_collection = ground_truth
330            .into_iter()
331            .map(|npq| npq.into_iter().collect())
332            .collect();
333        write_range_search_ground_truth(
334            storage_provider,
335            ground_truth_file,
336            query_num,
337            ground_truth_collection,
338        )
339    } else {
340        // Write results and return
341        write_ground_truth::<Data>(
342            storage_provider,
343            ground_truth_file,
344            query_num,
345            recall_at as usize,
346            ground_truth,
347            Option::None,
348        )
349    }
350}
351
352pub fn compute_range_search_ground_truth_from_datafiles<
353    Data: GraphDataType,
354    StorageProvider: StorageReadProvider + StorageWriteProvider,
355>(
356    storage_provider: &StorageProvider,
357    distance_function: Metric,
358    base_file: &str,
359    query_file: &str,
360    ground_truth_file: &str,
361    range_threshold: f32,
362    tags_file: &str,
363) -> CMDResult<()> {
364    if !tags_file.is_empty() {
365        // We have not implemented tags yet so let the user know!
366        return Err(CMDToolError {
367            details: "Tag files are not implemented for the ground_truth computation yet."
368                .to_string(),
369        });
370    }
371
372    let dataset_iterator = VectorDataIterator::<StorageProvider, Data>::new(
373        base_file,
374        Option::None,
375        storage_provider,
376    )?;
377
378    // Load the query file
379    let query_data =
380        read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
381    let query_num = query_data.nrows();
382    let query_dim = query_data.ncols();
383    let queries: Vec<_> = query_data.row_iter().collect();
384
385    let query_aligned_dim = query_dim.next_multiple_of(8);
386    let ground_truth_result = compute_range_search_ground_truth_from_data::<
387        Data,
388        StorageProvider,
389        VectorDataIterator<StorageProvider, Data>,
390    >(
391        distance_function,
392        dataset_iterator,
393        queries,
394        query_aligned_dim,
395        range_threshold,
396    );
397    assert!(
398        &ground_truth_result.is_ok(),
399        "Ground-truth computation failed"
400    );
401    let ground_truth = ground_truth_result?;
402
403    assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
404
405    // Write results
406    let _res = write_range_search_ground_truth(
407        storage_provider,
408        ground_truth_file,
409        query_num,
410        ground_truth,
411    );
412
413    Ok(())
414}
415
416fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
417    storage_provider: &StorageProvider,
418    ground_truth_file: &str,
419    number_of_queries: usize,
420    ground_truth: Vec<Vec<Neighbor<u32>>>,
421) -> CMDResult<()> {
422    let mut file = storage_provider.create_for_write(ground_truth_file)?;
423
424    let queue_sizes: Vec<u32> = ground_truth
425        .iter()
426        .map(|queue| queue.len() as u32)
427        .collect();
428    let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
429
430    // Metadata
431    Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
432
433    // Write queue sizes array.
434    let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
435    queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
436    file.write_all(&queue_sizes_buffer)?;
437
438    let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
439
440    // Write the neighbor IDs array.
441    for query_neighbors in ground_truth {
442        for neighbor in query_neighbors.iter() {
443            neighbor_ids.push(neighbor.id);
444        }
445    }
446
447    // Write neighbor IDs
448    let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
449    id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
450    file.write_all(&id_buffer)?;
451
452    // Make sure everything is written to disk
453    file.flush()?;
454
455    Ok(())
456}
457
458/// Writes out a ground truth file.  ground_truth is a vector of NeighborPriorityQueue objects
459/// where the order of queue objects corresponds to the order of queries used to compute this
460/// ground truth.
461fn write_ground_truth<Data: GraphDataType>(
462    storage_provider: &impl StorageWriteProvider,
463    ground_truth_file: &str,
464    number_of_queries: usize,
465    number_of_neighbors: usize,
466    ground_truth: Vec<NeighborPriorityQueue<u32>>,
467    id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
468) -> CMDResult<()> {
469    let mut file = storage_provider.create_for_write(ground_truth_file)?;
470
471    Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
472
473    let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
474    let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
475
476    // In the file, we write the neighbor IDs array first, then write the distances array.
477    for mut query_neighbors in ground_truth {
478        while let Some(closest_node) = query_neighbors.closest_notvisited() {
479            gt_ids.push(closest_node.id);
480            gt_distances.push(closest_node.distance);
481        }
482    }
483
484    // Write neighbor IDs or Associated Data
485    if let Some(id_to_associated_data) = id_to_associated_data {
486        let mut associated_data_buffer = Vec::<u8>::new();
487        for id in gt_ids {
488            let associated_data = id_to_associated_data[id as usize];
489            let serialized_associated_data =
490                bincode::serialize(&associated_data).map_err(|e| CMDToolError {
491                    details: format!("Failed to serialize associated data: {}", e),
492                })?;
493            associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
494        }
495        file.write_all(&associated_data_buffer)?;
496    } else {
497        let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
498        id_buffer.clone_from_slice(cast_slice::<u32, u8>(&gt_ids));
499        file.write_all(&id_buffer)?;
500    }
501
502    // Write neighbor distances
503    let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
504    distance_buffer.clone_from_slice(cast_slice::<f32, u8>(&gt_distances));
505    file.write_all(&distance_buffer)?;
506
507    // Make sure everything is written to disk
508    file.flush()?;
509
510    Ok(())
511}
512
513type Npq = Vec<NeighborPriorityQueue<u32>>;
514/// Computes the true nearest neighbors for a set of queries and dataset iterators
515///
516/// # Arguments
517///
518/// * `distance_function` - e.g. L2
519/// * `dataset_iter` - The iterator over the dataset vectors, associated data, and
520/// * `queries` - A vector of query vectors
521/// * `query_aligned_dimmensions` - The number of dimensions to align the query vectors to for optimized distance comparison.
522/// * `recall_at` - The number of neighbors to compute for each query.
523/// * `insert_iterator` - Optional iterator containing more dataset vectors. This may be useful if you are testing recall for an index that has points dynamically inserted into it.
524/// * `skip_base` - Optional number of base points to skip. This is useful if you want to compute the ground truth for a set where the first skip_base points are deleted from the index.
525#[allow(clippy::too_many_arguments)]
526pub fn compute_ground_truth_from_data<Data, VectorReader, VectorIteratorType>(
527    distance_function: Metric,
528    dataset_iter: VectorDataIterator<VectorReader, Data>,
529    queries: Vec<&[Data::VectorDataType]>,
530    query_aligned_dimmensions: usize,
531    recall_at: u32,
532    insert_iter: Option<VectorDataIterator<VectorReader, Data>>,
533    skip_base: Option<usize>,
534    query_bitmaps: Option<Vec<BitSet>>,
535) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
536where
537    Data: GraphDataType,
538    VectorReader: StorageReadProvider,
539{
540    let query_num = queries.len();
541
542    let mut aligned_queries = Vec::with_capacity(query_num);
543    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
544    for query in queries {
545        let mut aligned_query = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
546        aligned_query[..query.len()].copy_from_slice(query);
547        aligned_queries.push(aligned_query);
548        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
549    }
550    let mut queries_and_neighbor_queue: Vec<_> = aligned_queries
551        .iter()
552        .zip(neighbor_queues.iter_mut())
553        .collect();
554
555    let distance_comparer =
556        Data::VectorDataType::distance(distance_function, Some(query_aligned_dimmensions));
557
558    let batch_size = 10_000;
559    let mut aligned_data_batch = Vec::with_capacity(batch_size);
560    for _ in 0..batch_size {
561        aligned_data_batch.push(AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?);
562    }
563
564    let pool = create_thread_pool(0)?;
565
566    let mut num_base_points: usize = 0;
567    let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
568    let skip_base = skip_base.unwrap_or(0);
569    // Loop over all the raw data
570    for chunk in dataset_iter
571        .skip(skip_base)
572        .enumerate()
573        .chunks(batch_size)
574        .into_iter()
575    {
576        let mut points = 0;
577        for (idx, (data_vector, associated_data)) in chunk {
578            aligned_data_batch[idx % batch_size][..data_vector.len()].copy_from_slice(&data_vector);
579            id_to_associated_data.push(associated_data);
580            points += 1;
581        }
582
583        if points == 0 {
584            continue;
585        }
586
587        // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
588        queries_and_neighbor_queue
589            .par_iter_mut()
590            .enumerate()
591            .for_each_in_pool(
592                &pool,
593                |(idx_query, (aligned_query, ref mut neighbor_queue))| {
594                    for (idx_in_batch, aligned_data) in
595                        aligned_data_batch[..points].iter().enumerate()
596                    {
597                        let idx = (num_base_points + idx_in_batch) as u32;
598
599                        let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
600                            if let Ok(idx_usize) = idx.try_into() {
601                                bitmaps[idx_query].contains(idx_usize)
602                            } else {
603                                false
604                            }
605                        } else {
606                            true
607                        };
608
609                        if allowed_by_bitmap {
610                            let distance = distance_comparer
611                                .evaluate_similarity(&**aligned_data, aligned_query);
612                            neighbor_queue.insert(Neighbor { id: idx, distance });
613                        }
614                    }
615                },
616            );
617
618        num_base_points += points;
619    }
620
621    let mut aligned_data = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
622
623    if let Some(insert_iter) = insert_iter {
624        for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
625            aligned_data[..data_vector.len()].copy_from_slice(&data_vector);
626            // For each node in the raw data, calculate the distance to each query vector and store it in the priority queue for that query.  This will find the closest N neighbors for each query.
627            for (idx_query, (aligned_query, ref mut neighbor_queue)) in
628                queries_and_neighbor_queue.iter_mut().enumerate()
629            {
630                let idx = (num_base_points + insert_idx) as u32;
631
632                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
633                    if let Ok(idx_usize) = idx.try_into() {
634                        bitmaps[idx_query].contains(idx_usize)
635                    } else {
636                        false
637                    }
638                } else {
639                    true
640                };
641
642                if allowed_by_bitmap {
643                    let distance =
644                        distance_comparer.evaluate_similarity(&*aligned_data, aligned_query);
645                    neighbor_queue.insert(Neighbor { id: idx, distance })
646                }
647            }
648        }
649    }
650
651    Ok((neighbor_queues, id_to_associated_data))
652}
653
654#[allow(clippy::too_many_arguments)]
655pub fn compute_multivec_ground_truth_from_data<T, VectorReader>(
656    distance_function: Metric,
657    aggregation_method: MultivecAggregationMethod,
658    base_vectors: Vec<Matrix<T>>,
659    queries: Vec<Matrix<T>>,
660    query_dim: usize,
661    recall_at: u32,
662    query_bitmaps: Option<Vec<BitSet>>,
663) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
664where
665    T: VectorRepr,
666    VectorReader: StorageReadProvider,
667{
668    let query_num = queries.len();
669
670    let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
671    //
672    for _ in 0..query_num {
673        neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
674    }
675    let mut query_multivecs_and_neighbor_queue: Vec<_> =
676        queries.iter().zip(neighbor_queues.iter_mut()).collect();
677
678    let distance_comparer = T::distance(distance_function, Some(query_dim));
679
680    let pool = create_thread_pool(0)?;
681
682    // for each query multivec, compute chamfer distance in parallel
683
684    query_multivecs_and_neighbor_queue
685        .par_iter_mut()
686        .enumerate()
687        .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
688            for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
689                // check if calculation is allowed by bitmap if present
690                let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
691                    bitmaps[query_idx].contains(idx_base)
692                } else {
693                    true
694                };
695
696                if allowed_by_bitmap {
697                    // compute distance between query_multivec and base_multivec
698                    let distance = match aggregation_method {
699                        MultivecAggregationMethod::AveragePairwise => {
700                            let mut total_distance = 0.0;
701                            for query_vec in query_multivec.row_iter() {
702                                for base_vec in base_multivec.row_iter() {
703                                    let dist =
704                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
705                                    total_distance += dist;
706                                }
707                            }
708                            total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
709                        }
710                        MultivecAggregationMethod::MinPairwise => {
711                            let mut min_distance = f32::MAX;
712                            for query_vec in query_multivec.row_iter() {
713                                for base_vec in base_multivec.row_iter() {
714                                    let dist =
715                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
716                                    min_distance = min_distance.min(dist);
717                                }
718                            }
719                            min_distance
720                        }
721                        MultivecAggregationMethod::AvgofMins => {
722                            let mut distance = 0_f32;
723                            for query_vec in query_multivec.row_iter() {
724                                let mut local_min = f32::MAX;
725                                for base_vec in base_multivec.row_iter() {
726                                    let dist =
727                                        distance_comparer.evaluate_similarity(query_vec, base_vec);
728                                    local_min = local_min.min(dist);
729                                }
730                                distance += local_min;
731                            }
732                            distance / query_multivec.nrows() as f32
733                        }
734                    };
735                    // insert into neighbor queue
736                    let idx = idx_base as u32;
737                    neighbor_queue.insert(Neighbor { id: idx, distance });
738                }
739            }
740        });
741
742    Ok(neighbor_queues)
743}
744
745pub fn compute_range_search_ground_truth_from_data<Data, VectorReader, VectorIteratorType>(
746    distance_function: Metric,
747    dataset_iter: VectorDataIterator<VectorReader, Data>,
748    queries: Vec<&[Data::VectorDataType]>,
749    query_aligned_dimmensions: usize,
750    range_threshold: f32,
751) -> CMDResult<Vec<Vec<Neighbor<u32>>>>
752where
753    Data: GraphDataType,
754    VectorReader: StorageReadProvider,
755{
756    let query_num = queries.len();
757    let mut neighbor_queues: Vec<Vec<Neighbor<u32>>> = Vec::with_capacity(query_num);
758    for _ in 0..query_num {
759        neighbor_queues.push(Vec::new());
760    }
761
762    let mut queries_and_neighbor_queue: Vec<_> =
763        queries.iter().zip(neighbor_queues.iter_mut()).collect();
764
765    let distance_comparer =
766        Data::VectorDataType::distance(distance_function, Some(query_aligned_dimmensions));
767
768    let mut aligned_data = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
769    let mut aligned_query = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
770
771    for (idx, (data_vector, _associated_data)) in dataset_iter.enumerate() {
772        aligned_data[..data_vector.len()].copy_from_slice(&data_vector);
773        for (query, ref mut neighbor_queue) in queries_and_neighbor_queue.iter_mut() {
774            aligned_query[..query.len()].copy_from_slice(query);
775            let distance = distance_comparer.evaluate_similarity(&*aligned_data, &aligned_query);
776            if distance <= range_threshold {
777                neighbor_queue.push(Neighbor {
778                    id: idx as u32,
779                    distance,
780                })
781            }
782        }
783    }
784
785    Ok(neighbor_queues)
786}