1use bit_set::BitSet;
7use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
8
9use std::{io::Write, mem::size_of, str::FromStr};
10
11use bytemuck::cast_slice;
12use diskann::{
13 neighbor::{Neighbor, NeighborPriorityQueue},
14 utils::VectorRepr,
15};
16use diskann_disk::data_model::GraphDataType;
17use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
18use diskann_providers::utils::{
19 create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator,
20};
21use diskann_utils::{
22 io::{read_bin, Metadata},
23 views::Matrix,
24};
25use diskann_vector::{distance::Metric, DistanceFunction};
26use itertools::Itertools;
27use rayon::prelude::*;
28
29use crate::utils::{search_index_utils, CMDResult, CMDToolError};
30
31pub fn read_labels_and_compute_bitmap(
32 base_label_filename: &str,
33 query_label_filename: &str,
34) -> CMDResult<Vec<BitSet>> {
35 let base_labels = read_baselabels(base_label_filename)?;
37
38 let parsed_queries = read_and_parse_queries(query_label_filename)?;
40
41 #[allow(clippy::disallowed_methods)]
43 let query_bitmaps: Vec<BitSet> = parsed_queries
44 .par_iter()
45 .map(|(_query_id, query_expr)| {
46 let mut bitmap = BitSet::new();
47 for base_label in base_labels.iter() {
48 if eval_query_expr(query_expr, &base_label.label) {
49 bitmap.insert(base_label.doc_id);
50 }
51 }
52 bitmap
53 })
54 .collect();
55
56 Ok(query_bitmaps)
57}
58
59#[allow(clippy::too_many_arguments)]
60#[allow(clippy::panic)]
61pub fn compute_ground_truth_from_datafiles<
73 Data: GraphDataType,
74 StorageProvider: StorageReadProvider + StorageWriteProvider,
75>(
76 storage_provider: &StorageProvider,
77 distance_function: Metric,
78 base_file: &str,
79 query_file: &str,
80 ground_truth_file: &str,
81 vector_filters_file: Option<&str>,
82 recall_at: u32,
83 insert_file: Option<&str>,
84 skip_base: Option<usize>,
85 associated_data_file: Option<String>,
86 base_file_labels: Option<&str>,
87 query_file_labels: Option<&str>,
88) -> CMDResult<()> {
89 let dataset_iterator = VectorDataIterator::<
90 StorageProvider,
91 Data::VectorDataType,
92 Data::AssociatedDataType,
93 >::new(base_file, associated_data_file.clone(), storage_provider)?;
94
95 if !((base_file_labels.is_some() && query_file_labels.is_some())
97 || (base_file_labels.is_none() && query_file_labels.is_none()))
98 {
99 return Err(CMDToolError {
100 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
101 });
102 }
103
104 if base_file_labels.is_some() && vector_filters_file.is_some() {
105 return Err(CMDToolError {
106 details: "Both base_file_labels and vector_filters_file cannot be provided."
107 .to_string(),
108 });
109 }
110
111 let insert_iterator = match insert_file {
112 Some(insert_file) => {
113 let i = VectorDataIterator::<
114 StorageProvider,
115 Data::VectorDataType,
116 Data::AssociatedDataType,
117 >::new(insert_file, Option::None, storage_provider)?;
118 Some(i)
119 }
120 None => None,
121 };
122
123 let query_data =
125 read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
126 let query_num = query_data.nrows();
127
128 let mut query_bitmaps: Option<Vec<BitSet>> = None;
129 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
130 {
131 query_bitmaps = Some(read_labels_and_compute_bitmap(
132 base_file_labels,
133 query_file_labels,
134 )?);
135 }
136
137 let vector_filters = match vector_filters_file {
139 Some(vector_filters_file) => {
140 let filters =
141 search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
142
143 assert_eq!(
144 filters.len(),
145 query_num,
146 "Mismatch in query and vector filter sizes"
147 );
148
149 Some(filters)
150 }
151 None => None,
152 };
153
154 let has_vector_filters = vector_filters.is_some();
155 let has_query_bitmaps = query_bitmaps.is_some();
156
157 if has_vector_filters {
158 if let Some(filters) = vector_filters {
160 let mut bitmaps = vec![BitSet::new(); query_num];
161 for (idx_query, filter) in filters.iter().enumerate() {
162 for item in filter.iter() {
163 if let Ok(idx) = (*item).try_into() {
164 bitmaps[idx_query].insert(idx);
165 }
166 }
167 }
168 query_bitmaps = Some(bitmaps)
169 }
170 }
171
172 let ground_truth_result = compute_ground_truth_from_data::<Data, StorageProvider>(
173 distance_function,
174 dataset_iterator,
175 &query_data,
176 recall_at,
177 insert_iterator,
178 skip_base,
179 query_bitmaps,
180 );
181 assert!(
182 &ground_truth_result.is_ok(),
183 "Ground-truth computation failed"
184 );
185 let (ground_truth, id_to_associated_data) = ground_truth_result?;
186
187 assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
188
189 if has_vector_filters || has_query_bitmaps {
190 let ground_truth_collection = ground_truth
191 .into_iter()
192 .map(|npq| npq.into_iter().collect())
193 .collect();
194 write_range_search_ground_truth(
195 storage_provider,
196 ground_truth_file,
197 query_num,
198 ground_truth_collection,
199 )
200 } else {
201 let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
203 write_ground_truth::<Data>(
204 storage_provider,
205 ground_truth_file,
206 query_num,
207 recall_at as usize,
208 ground_truth,
209 id_to_associated_data,
210 )
211 }
212}
213
/// How the pairwise distances between the vectors of a query multivector and the
/// vectors of a base multivector are combined into a single distance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultivecAggregationMethod {
    /// Mean of the distances over every (query vector, base vector) pair.
    AveragePairwise,
    /// Smallest distance over every (query vector, base vector) pair.
    MinPairwise,
    /// For each query vector take the smallest distance to any base vector, then
    /// average those minima over the query vectors.
    AvgofMins,
}
220
/// Error produced when a string cannot be parsed into an aggregation method.
#[derive(Debug)]
pub enum ParseAggrError {
    /// The input did not name a known aggregation method; holds the rejected input.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // There is a single variant, so the pattern is irrefutable.
        let Self::InvalidFormat(rejected) = self;
        write!(f, "Invalid format for Aggregation Method: {}", rejected)
    }
}

impl std::error::Error for ParseAggrError {}
235
236impl FromStr for MultivecAggregationMethod {
237 type Err = ParseAggrError;
238
239 fn from_str(s: &str) -> Result<Self, Self::Err> {
240 match s.to_lowercase().as_str() {
241 "average_pairwise" => Ok(MultivecAggregationMethod::AveragePairwise),
242 "min_pairwise" => Ok(MultivecAggregationMethod::MinPairwise),
243 "avg_of_mins" => Ok(MultivecAggregationMethod::AvgofMins),
244 _ => Err(ParseAggrError::InvalidFormat(String::from(s))),
245 }
246 }
247}
248
249#[allow(clippy::too_many_arguments)]
250#[allow(clippy::panic)]
251pub fn compute_multivec_ground_truth_from_datafiles<
264 Data: GraphDataType,
265 StorageProvider: StorageReadProvider + StorageWriteProvider,
266>(
267 storage_provider: &StorageProvider,
268 distance_function: Metric,
269 aggregation_method: MultivecAggregationMethod,
270 base_file: &str,
271 query_file: &str,
272 ground_truth_file: &str,
273 recall_at: u32,
274 base_file_labels: Option<&str>,
275 query_file_labels: Option<&str>,
276) -> CMDResult<()> {
277 let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
278 Data::VectorDataType,
279 StorageProvider,
280 >(storage_provider, base_file)?;
281
282 let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
283 Data::VectorDataType,
284 StorageProvider,
285 >(storage_provider, query_file)?;
286
287 if !((base_file_labels.is_some() && query_file_labels.is_some())
289 || (base_file_labels.is_none() && query_file_labels.is_none()))
290 {
291 return Err(CMDToolError {
292 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
293 });
294 }
295
296 let mut query_bitmaps: Option<Vec<BitSet>> = None;
297 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
298 {
299 query_bitmaps = Some(read_labels_and_compute_bitmap(
300 base_file_labels,
301 query_file_labels,
302 )?);
303 }
304
305 let has_query_bitmaps = query_bitmaps.is_some();
306
307 let ground_truth = compute_multivec_ground_truth_from_data::<Data::VectorDataType>(
308 distance_function,
309 aggregation_method,
310 base_vectors,
311 query_vectors,
312 query_dim,
313 recall_at,
314 query_bitmaps,
315 )?;
316
317 if has_query_bitmaps {
318 let ground_truth_collection = ground_truth
319 .into_iter()
320 .map(|npq| npq.into_iter().collect())
321 .collect();
322 write_range_search_ground_truth(
323 storage_provider,
324 ground_truth_file,
325 query_num,
326 ground_truth_collection,
327 )
328 } else {
329 write_ground_truth::<Data>(
331 storage_provider,
332 ground_truth_file,
333 query_num,
334 recall_at as usize,
335 ground_truth,
336 Option::None,
337 )
338 }
339}
340
341fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
342 storage_provider: &StorageProvider,
343 ground_truth_file: &str,
344 number_of_queries: usize,
345 ground_truth: Vec<Vec<Neighbor<u32>>>,
346) -> CMDResult<()> {
347 let mut file = storage_provider.create_for_write(ground_truth_file)?;
348
349 let queue_sizes: Vec<u32> = ground_truth
350 .iter()
351 .map(|queue| queue.len() as u32)
352 .collect();
353 let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
354
355 Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
357
358 let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
360 queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
361 file.write_all(&queue_sizes_buffer)?;
362
363 let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
364
365 for query_neighbors in ground_truth {
367 for neighbor in query_neighbors.iter() {
368 neighbor_ids.push(neighbor.id);
369 }
370 }
371
372 let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
374 id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
375 file.write_all(&id_buffer)?;
376
377 file.flush()?;
379
380 Ok(())
381}
382
383fn write_ground_truth<Data: GraphDataType>(
387 storage_provider: &impl StorageWriteProvider,
388 ground_truth_file: &str,
389 number_of_queries: usize,
390 number_of_neighbors: usize,
391 ground_truth: Vec<NeighborPriorityQueue<u32>>,
392 id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
393) -> CMDResult<()> {
394 let mut file = storage_provider.create_for_write(ground_truth_file)?;
395
396 Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
397
398 let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
399 let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
400
401 for mut query_neighbors in ground_truth {
403 while let Some(closest_node) = query_neighbors.closest_notvisited() {
404 gt_ids.push(closest_node.id);
405 gt_distances.push(closest_node.distance);
406 }
407 }
408
409 if let Some(id_to_associated_data) = id_to_associated_data {
411 let mut associated_data_buffer = Vec::<u8>::new();
412 for id in gt_ids {
413 let associated_data = id_to_associated_data[id as usize];
414 let serialized_associated_data =
415 bincode::serialize(&associated_data).map_err(|e| CMDToolError {
416 details: format!("Failed to serialize associated data: {}", e),
417 })?;
418 associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
419 }
420 file.write_all(&associated_data_buffer)?;
421 } else {
422 let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
423 id_buffer.clone_from_slice(cast_slice::<u32, u8>(>_ids));
424 file.write_all(&id_buffer)?;
425 }
426
427 let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
429 distance_buffer.clone_from_slice(cast_slice::<f32, u8>(>_distances));
430 file.write_all(&distance_buffer)?;
431
432 file.flush()?;
434
435 Ok(())
436}
437
438type Npq = Vec<NeighborPriorityQueue<u32>>;
439#[allow(clippy::too_many_arguments)]
452pub fn compute_ground_truth_from_data<Data, VectorReader>(
453 distance_function: Metric,
454 dataset_iter: VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
455 queries: &Matrix<Data::VectorDataType>,
456 recall_at: u32,
457 insert_iter: Option<
458 VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
459 >,
460 skip_base: Option<usize>,
461 query_bitmaps: Option<Vec<BitSet>>,
462) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
463where
464 Data: GraphDataType,
465 VectorReader: StorageReadProvider,
466{
467 let query_num = queries.nrows();
468 let query_dim = queries.ncols();
469
470 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = (0..query_num)
471 .map(|_| NeighborPriorityQueue::new(recall_at as usize))
472 .collect();
473 let mut queries_and_neighbor_queue: Vec<_> =
474 queries.row_iter().zip(neighbor_queues.iter_mut()).collect();
475
476 let distance_comparer = Data::VectorDataType::distance(distance_function, Some(query_dim));
477
478 let batch_size = 10_000;
479 let mut data_batch: Vec<Box<[Data::VectorDataType]>> = Vec::with_capacity(batch_size);
480
481 let pool = create_thread_pool(0)?;
482
483 let mut num_base_points: usize = 0;
484 let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
485 let skip_base = skip_base.unwrap_or(0);
486 for chunk in dataset_iter.skip(skip_base).chunks(batch_size).into_iter() {
488 data_batch.clear();
489 for (data_vector, associated_data) in chunk {
490 data_batch.push(data_vector);
491 id_to_associated_data.push(associated_data);
492 }
493 let points = data_batch.len();
494
495 if points == 0 {
496 continue;
497 }
498
499 queries_and_neighbor_queue
501 .par_iter_mut()
502 .enumerate()
503 .for_each_in_pool(&pool, |(idx_query, (query, ref mut neighbor_queue))| {
504 for (idx_in_batch, data) in data_batch.iter().enumerate() {
505 let idx = (num_base_points + idx_in_batch) as u32;
506
507 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
508 if let Ok(idx_usize) = idx.try_into() {
509 bitmaps[idx_query].contains(idx_usize)
510 } else {
511 false
512 }
513 } else {
514 true
515 };
516
517 if allowed_by_bitmap {
518 let distance = distance_comparer.evaluate_similarity(data, query);
519 neighbor_queue.insert(Neighbor { id: idx, distance });
520 }
521 }
522 });
523
524 num_base_points += points;
525 }
526
527 if let Some(insert_iter) = insert_iter {
528 for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
529 for (idx_query, (query, ref mut neighbor_queue)) in
531 queries_and_neighbor_queue.iter_mut().enumerate()
532 {
533 let idx = (num_base_points + insert_idx) as u32;
534
535 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
536 if let Ok(idx_usize) = idx.try_into() {
537 bitmaps[idx_query].contains(idx_usize)
538 } else {
539 false
540 }
541 } else {
542 true
543 };
544
545 if allowed_by_bitmap {
546 let distance = distance_comparer.evaluate_similarity(&data_vector, query);
547 neighbor_queue.insert(Neighbor { id: idx, distance })
548 }
549 }
550 }
551 }
552
553 Ok((neighbor_queues, id_to_associated_data))
554}
555
556#[allow(clippy::too_many_arguments)]
557pub fn compute_multivec_ground_truth_from_data<T>(
558 distance_function: Metric,
559 aggregation_method: MultivecAggregationMethod,
560 base_vectors: Vec<Matrix<T>>,
561 queries: Vec<Matrix<T>>,
562 query_dim: usize,
563 recall_at: u32,
564 query_bitmaps: Option<Vec<BitSet>>,
565) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
566where
567 T: VectorRepr,
568{
569 let query_num = queries.len();
570
571 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
572 for _ in 0..query_num {
574 neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
575 }
576 let mut query_multivecs_and_neighbor_queue: Vec<_> =
577 queries.iter().zip(neighbor_queues.iter_mut()).collect();
578
579 let distance_comparer = T::distance(distance_function, Some(query_dim));
580
581 let pool = create_thread_pool(0)?;
582
583 query_multivecs_and_neighbor_queue
586 .par_iter_mut()
587 .enumerate()
588 .for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
589 for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
590 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
592 bitmaps[query_idx].contains(idx_base)
593 } else {
594 true
595 };
596
597 if allowed_by_bitmap {
598 let distance = match aggregation_method {
600 MultivecAggregationMethod::AveragePairwise => {
601 let mut total_distance = 0.0;
602 for query_vec in query_multivec.row_iter() {
603 for base_vec in base_multivec.row_iter() {
604 let dist =
605 distance_comparer.evaluate_similarity(query_vec, base_vec);
606 total_distance += dist;
607 }
608 }
609 total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
610 }
611 MultivecAggregationMethod::MinPairwise => {
612 let mut min_distance = f32::MAX;
613 for query_vec in query_multivec.row_iter() {
614 for base_vec in base_multivec.row_iter() {
615 let dist =
616 distance_comparer.evaluate_similarity(query_vec, base_vec);
617 min_distance = min_distance.min(dist);
618 }
619 }
620 min_distance
621 }
622 MultivecAggregationMethod::AvgofMins => {
623 let mut distance = 0_f32;
624 for query_vec in query_multivec.row_iter() {
625 let mut local_min = f32::MAX;
626 for base_vec in base_multivec.row_iter() {
627 let dist =
628 distance_comparer.evaluate_similarity(query_vec, base_vec);
629 local_min = local_min.min(dist);
630 }
631 distance += local_min;
632 }
633 distance / query_multivec.nrows() as f32
634 }
635 };
636 let idx = idx_base as u32;
638 neighbor_queue.insert(Neighbor { id: idx, distance });
639 }
640 }
641 });
642
643 Ok(neighbor_queues)
644}