1use bit_set::BitSet;
7use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
8
9use std::{io::Write, mem::size_of, str::FromStr};
10
11use bytemuck::cast_slice;
12use diskann::{
13 neighbor::{Neighbor, NeighborPriorityQueue},
14 utils::VectorRepr,
15};
16use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
17use diskann_providers::{
18 common::AlignedBoxWithSlice,
19 model::graph::traits::GraphDataType,
20 utils::{create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator},
21};
22use diskann_utils::{
23 io::{read_bin, Metadata},
24 views::Matrix,
25};
26use diskann_vector::{distance::Metric, DistanceFunction};
27use itertools::Itertools;
28use rayon::prelude::*;
29
30use crate::utils::{search_index_utils, CMDResult, CMDToolError};
31
32pub fn read_labels_and_compute_bitmap(
33 base_label_filename: &str,
34 query_label_filename: &str,
35) -> CMDResult<Vec<BitSet>> {
36 let base_labels = read_baselabels(base_label_filename)?;
38
39 let parsed_queries = read_and_parse_queries(query_label_filename)?;
41
42 #[allow(clippy::disallowed_methods)]
44 let query_bitmaps: Vec<BitSet> = parsed_queries
45 .par_iter()
46 .map(|(_query_id, query_expr)| {
47 let mut bitmap = BitSet::new();
48 for base_label in base_labels.iter() {
49 if eval_query_expr(query_expr, &base_label.label) {
50 bitmap.insert(base_label.doc_id);
51 }
52 }
53 bitmap
54 })
55 .collect();
56
57 Ok(query_bitmaps)
58}
59
60#[allow(clippy::too_many_arguments)]
61#[allow(clippy::panic)]
62pub fn compute_ground_truth_from_datafiles<
74 Data: GraphDataType,
75 StorageProvider: StorageReadProvider + StorageWriteProvider,
76>(
77 storage_provider: &StorageProvider,
78 distance_function: Metric,
79 base_file: &str,
80 query_file: &str,
81 ground_truth_file: &str,
82 vector_filters_file: Option<&str>,
83 recall_at: u32,
84 insert_file: Option<&str>,
85 skip_base: Option<usize>,
86 associated_data_file: Option<String>,
87 base_file_labels: Option<&str>,
88 query_file_labels: Option<&str>,
89) -> CMDResult<()> {
90 let dataset_iterator = VectorDataIterator::<StorageProvider, Data>::new(
91 base_file,
92 associated_data_file.clone(),
93 storage_provider,
94 )?;
95
96 if !((base_file_labels.is_some() && query_file_labels.is_some())
98 || (base_file_labels.is_none() && query_file_labels.is_none()))
99 {
100 return Err(CMDToolError {
101 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
102 });
103 }
104
105 if base_file_labels.is_some() && vector_filters_file.is_some() {
106 return Err(CMDToolError {
107 details: "Both base_file_labels and vector_filters_file cannot be provided."
108 .to_string(),
109 });
110 }
111
112 let insert_iterator = match insert_file {
113 Some(insert_file) => {
114 let i = VectorDataIterator::<StorageProvider, Data>::new(
115 insert_file,
116 Option::None,
117 storage_provider,
118 )?;
119 Some(i)
120 }
121 None => None,
122 };
123
124 let query_data =
126 read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
127 let query_num = query_data.nrows();
128 let query_dim = query_data.ncols();
129
130 let mut query_bitmaps: Option<Vec<BitSet>> = None;
131 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
132 {
133 query_bitmaps = Some(read_labels_and_compute_bitmap(
134 base_file_labels,
135 query_file_labels,
136 )?);
137 }
138
139 let queries: Vec<_> = query_data.row_iter().collect();
140
141 let vector_filters = match vector_filters_file {
143 Some(vector_filters_file) => {
144 let filters =
145 search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
146
147 assert_eq!(
148 filters.len(),
149 queries.len(),
150 "Mismatch in query and vector filter sizes"
151 );
152
153 Some(filters)
154 }
155 None => None,
156 };
157
158 let has_vector_filters = vector_filters.is_some();
159 let has_query_bitmaps = query_bitmaps.is_some();
160
161 if has_vector_filters {
162 if let Some(filters) = vector_filters {
164 let mut bitmaps = vec![BitSet::new(); queries.len()];
165 for (idx_query, filter) in filters.iter().enumerate() {
166 for item in filter.iter() {
167 if let Ok(idx) = (*item).try_into() {
168 bitmaps[idx_query].insert(idx);
169 }
170 }
171 }
172 query_bitmaps = Some(bitmaps)
173 }
174 }
175
176 let query_aligned_dim = query_dim.next_multiple_of(8);
177 let ground_truth_result = compute_ground_truth_from_data::<
178 Data,
179 StorageProvider,
180 VectorDataIterator<StorageProvider, Data>,
181 >(
182 distance_function,
183 dataset_iterator,
184 queries,
185 query_aligned_dim,
186 recall_at,
187 insert_iterator,
188 skip_base,
189 query_bitmaps,
190 );
191 assert!(
192 &ground_truth_result.is_ok(),
193 "Ground-truth computation failed"
194 );
195 let (ground_truth, id_to_associated_data) = ground_truth_result?;
196
197 assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
198
199 if has_vector_filters || has_query_bitmaps {
200 let ground_truth_collection = ground_truth
201 .into_iter()
202 .map(|npq| npq.into_iter().collect())
203 .collect();
204 write_range_search_ground_truth(
205 storage_provider,
206 ground_truth_file,
207 query_num,
208 ground_truth_collection,
209 )
210 } else {
211 let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
213 write_ground_truth::<Data>(
214 storage_provider,
215 ground_truth_file,
216 query_num,
217 recall_at as usize,
218 ground_truth,
219 id_to_associated_data,
220 )
221 }
222}
223
/// Strategy for collapsing the pairwise distances between a query multi-vector and a base
/// multi-vector into a single score.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector take the smallest distance to any base vector, then average
    /// those minima over the query vectors.
    AvgofMins,
}

/// Error returned when a string does not name a [`MultivecAggregationMethod`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseAggrError {
    /// The input, exactly as supplied, matched no known aggregation method.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidFormat(input) => {
                write!(f, "Invalid format for Aggregation Method: {}", input)
            }
        }
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses `average_pairwise`, `min_pairwise` or `avg_of_mins`, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns [`ParseAggrError::InvalidFormat`] carrying the unmodified input when it
    /// matches none of the known names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The accepted names are pure ASCII, so an ASCII case-insensitive comparison is
        // enough and avoids allocating a lowercased copy of the input.
        if s.eq_ignore_ascii_case("average_pairwise") {
            Ok(Self::AveragePairwise)
        } else if s.eq_ignore_ascii_case("min_pairwise") {
            Ok(Self::MinPairwise)
        } else if s.eq_ignore_ascii_case("avg_of_mins") {
            Ok(Self::AvgofMins)
        } else {
            Err(ParseAggrError::InvalidFormat(s.to_owned()))
        }
    }
}
258
259#[allow(clippy::too_many_arguments)]
260#[allow(clippy::panic)]
261pub fn compute_multivec_ground_truth_from_datafiles<
274 Data: GraphDataType,
275 StorageProvider: StorageReadProvider + StorageWriteProvider,
276>(
277 storage_provider: &StorageProvider,
278 distance_function: Metric,
279 aggregation_method: MultivecAggregationMethod,
280 base_file: &str,
281 query_file: &str,
282 ground_truth_file: &str,
283 recall_at: u32,
284 base_file_labels: Option<&str>,
285 query_file_labels: Option<&str>,
286) -> CMDResult<()> {
287 let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
288 Data::VectorDataType,
289 StorageProvider,
290 >(storage_provider, base_file)?;
291
292 let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
293 Data::VectorDataType,
294 StorageProvider,
295 >(storage_provider, query_file)?;
296
297 if !((base_file_labels.is_some() && query_file_labels.is_some())
299 || (base_file_labels.is_none() && query_file_labels.is_none()))
300 {
301 return Err(CMDToolError {
302 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
303 });
304 }
305
306 let mut query_bitmaps: Option<Vec<BitSet>> = None;
307 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
308 {
309 query_bitmaps = Some(read_labels_and_compute_bitmap(
310 base_file_labels,
311 query_file_labels,
312 )?);
313 }
314
315 let has_query_bitmaps = query_bitmaps.is_some();
316
317 let ground_truth =
318 compute_multivec_ground_truth_from_data::<Data::VectorDataType, StorageProvider>(
319 distance_function,
320 aggregation_method,
321 base_vectors,
322 query_vectors,
323 query_dim,
324 recall_at,
325 query_bitmaps,
326 )?;
327
328 if has_query_bitmaps {
329 let ground_truth_collection = ground_truth
330 .into_iter()
331 .map(|npq| npq.into_iter().collect())
332 .collect();
333 write_range_search_ground_truth(
334 storage_provider,
335 ground_truth_file,
336 query_num,
337 ground_truth_collection,
338 )
339 } else {
340 write_ground_truth::<Data>(
342 storage_provider,
343 ground_truth_file,
344 query_num,
345 recall_at as usize,
346 ground_truth,
347 Option::None,
348 )
349 }
350}
351
352fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
353 storage_provider: &StorageProvider,
354 ground_truth_file: &str,
355 number_of_queries: usize,
356 ground_truth: Vec<Vec<Neighbor<u32>>>,
357) -> CMDResult<()> {
358 let mut file = storage_provider.create_for_write(ground_truth_file)?;
359
360 let queue_sizes: Vec<u32> = ground_truth
361 .iter()
362 .map(|queue| queue.len() as u32)
363 .collect();
364 let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
365
366 Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
368
369 let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
371 queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
372 file.write_all(&queue_sizes_buffer)?;
373
374 let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
375
376 for query_neighbors in ground_truth {
378 for neighbor in query_neighbors.iter() {
379 neighbor_ids.push(neighbor.id);
380 }
381 }
382
383 let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
385 id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
386 file.write_all(&id_buffer)?;
387
388 file.flush()?;
390
391 Ok(())
392}
393
394fn write_ground_truth<Data: GraphDataType>(
398 storage_provider: &impl StorageWriteProvider,
399 ground_truth_file: &str,
400 number_of_queries: usize,
401 number_of_neighbors: usize,
402 ground_truth: Vec<NeighborPriorityQueue<u32>>,
403 id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
404) -> CMDResult<()> {
405 let mut file = storage_provider.create_for_write(ground_truth_file)?;
406
407 Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
408
409 let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
410 let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
411
412 for mut query_neighbors in ground_truth {
414 while let Some(closest_node) = query_neighbors.closest_notvisited() {
415 gt_ids.push(closest_node.id);
416 gt_distances.push(closest_node.distance);
417 }
418 }
419
420 if let Some(id_to_associated_data) = id_to_associated_data {
422 let mut associated_data_buffer = Vec::<u8>::new();
423 for id in gt_ids {
424 let associated_data = id_to_associated_data[id as usize];
425 let serialized_associated_data =
426 bincode::serialize(&associated_data).map_err(|e| CMDToolError {
427 details: format!("Failed to serialize associated data: {}", e),
428 })?;
429 associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
430 }
431 file.write_all(&associated_data_buffer)?;
432 } else {
433 let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
434 id_buffer.clone_from_slice(cast_slice::<u32, u8>(>_ids));
435 file.write_all(&id_buffer)?;
436 }
437
438 let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
440 distance_buffer.clone_from_slice(cast_slice::<f32, u8>(>_distances));
441 file.write_all(&distance_buffer)?;
442
443 file.flush()?;
445
446 Ok(())
447}
448
449type Npq = Vec<NeighborPriorityQueue<u32>>;
450#[allow(clippy::too_many_arguments)]
462pub fn compute_ground_truth_from_data<Data, VectorReader, VectorIteratorType>(
463 distance_function: Metric,
464 dataset_iter: VectorDataIterator<VectorReader, Data>,
465 queries: Vec<&[Data::VectorDataType]>,
466 query_aligned_dimmensions: usize,
467 recall_at: u32,
468 insert_iter: Option<VectorDataIterator<VectorReader, Data>>,
469 skip_base: Option<usize>,
470 query_bitmaps: Option<Vec<BitSet>>,
471) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
472where
473 Data: GraphDataType,
474 VectorReader: StorageReadProvider,
475{
476 let query_num = queries.len();
477
478 let mut aligned_queries = Vec::with_capacity(query_num);
479 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
480 for query in queries {
481 let mut aligned_query = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
482 aligned_query[..query.len()].copy_from_slice(query);
483 aligned_queries.push(aligned_query);
484 neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
485 }
486 let mut queries_and_neighbor_queue: Vec<_> = aligned_queries
487 .iter()
488 .zip(neighbor_queues.iter_mut())
489 .collect();
490
491 let distance_comparer =
492 Data::VectorDataType::distance(distance_function, Some(query_aligned_dimmensions));
493
494 let batch_size = 10_000;
495 let mut aligned_data_batch = Vec::with_capacity(batch_size);
496 for _ in 0..batch_size {
497 aligned_data_batch.push(AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?);
498 }
499
500 let pool = create_thread_pool(0)?;
501
502 let mut num_base_points: usize = 0;
503 let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
504 let skip_base = skip_base.unwrap_or(0);
505 for chunk in dataset_iter
507 .skip(skip_base)
508 .enumerate()
509 .chunks(batch_size)
510 .into_iter()
511 {
512 let mut points = 0;
513 for (idx, (data_vector, associated_data)) in chunk {
514 aligned_data_batch[idx % batch_size][..data_vector.len()].copy_from_slice(&data_vector);
515 id_to_associated_data.push(associated_data);
516 points += 1;
517 }
518
519 if points == 0 {
520 continue;
521 }
522
523 queries_and_neighbor_queue
525 .par_iter_mut()
526 .enumerate()
527 .for_each_in_pool(
528 &pool,
529 |(idx_query, (aligned_query, ref mut neighbor_queue))| {
530 for (idx_in_batch, aligned_data) in
531 aligned_data_batch[..points].iter().enumerate()
532 {
533 let idx = (num_base_points + idx_in_batch) as u32;
534
535 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
536 if let Ok(idx_usize) = idx.try_into() {
537 bitmaps[idx_query].contains(idx_usize)
538 } else {
539 false
540 }
541 } else {
542 true
543 };
544
545 if allowed_by_bitmap {
546 let distance = distance_comparer
547 .evaluate_similarity(&**aligned_data, aligned_query);
548 neighbor_queue.insert(Neighbor { id: idx, distance });
549 }
550 }
551 },
552 );
553
554 num_base_points += points;
555 }
556
557 let mut aligned_data = AlignedBoxWithSlice::new(query_aligned_dimmensions, 32)?;
558
559 if let Some(insert_iter) = insert_iter {
560 for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
561 aligned_data[..data_vector.len()].copy_from_slice(&data_vector);
562 for (idx_query, (aligned_query, ref mut neighbor_queue)) in
564 queries_and_neighbor_queue.iter_mut().enumerate()
565 {
566 let idx = (num_base_points + insert_idx) as u32;
567
568 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
569 if let Ok(idx_usize) = idx.try_into() {
570 bitmaps[idx_query].contains(idx_usize)
571 } else {
572 false
573 }
574 } else {
575 true
576 };
577
578 if allowed_by_bitmap {
579 let distance =
580 distance_comparer.evaluate_similarity(&*aligned_data, aligned_query);
581 neighbor_queue.insert(Neighbor { id: idx, distance })
582 }
583 }
584 }
585 }
586
587 Ok((neighbor_queues, id_to_associated_data))
588}
589
590#[allow(clippy::too_many_arguments)]
591pub fn compute_multivec_ground_truth_from_data<T, VectorReader>(
592 distance_function: Metric,
593 aggregation_method: MultivecAggregationMethod,
594 base_vectors: Vec<Matrix<T>>,
595 queries: Vec<Matrix<T>>,
596 query_dim: usize,
597 recall_at: u32,
598 query_bitmaps: Option<Vec<BitSet>>,
599) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
600where
601 T: VectorRepr,
602 VectorReader: StorageReadProvider,
603{
604 let query_num = queries.len();
605
606 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
607 for _ in 0..query_num {
609 neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
610 }
611 let mut query_multivecs_and_neighbor_queue: Vec<_> =
612 queries.iter().zip(neighbor_queues.iter_mut()).collect();
613
614 let distance_comparer = T::distance(distance_function, Some(query_dim));
615
616 let pool = create_thread_pool(0)?;
617
618 query_multivecs_and_neighbor_queue
621 .par_iter_mut()
622 .enumerate()
623 .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
624 for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
625 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
627 bitmaps[query_idx].contains(idx_base)
628 } else {
629 true
630 };
631
632 if allowed_by_bitmap {
633 let distance = match aggregation_method {
635 MultivecAggregationMethod::AveragePairwise => {
636 let mut total_distance = 0.0;
637 for query_vec in query_multivec.row_iter() {
638 for base_vec in base_multivec.row_iter() {
639 let dist =
640 distance_comparer.evaluate_similarity(query_vec, base_vec);
641 total_distance += dist;
642 }
643 }
644 total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
645 }
646 MultivecAggregationMethod::MinPairwise => {
647 let mut min_distance = f32::MAX;
648 for query_vec in query_multivec.row_iter() {
649 for base_vec in base_multivec.row_iter() {
650 let dist =
651 distance_comparer.evaluate_similarity(query_vec, base_vec);
652 min_distance = min_distance.min(dist);
653 }
654 }
655 min_distance
656 }
657 MultivecAggregationMethod::AvgofMins => {
658 let mut distance = 0_f32;
659 for query_vec in query_multivec.row_iter() {
660 let mut local_min = f32::MAX;
661 for base_vec in base_multivec.row_iter() {
662 let dist =
663 distance_comparer.evaluate_similarity(query_vec, base_vec);
664 local_min = local_min.min(dist);
665 }
666 distance += local_min;
667 }
668 distance / query_multivec.nrows() as f32
669 }
670 };
671 let idx = idx_base as u32;
673 neighbor_queue.insert(Neighbor { id: idx, distance });
674 }
675 }
676 });
677
678 Ok(neighbor_queues)
679}