1use crate::utils::compute_bitmap::compute_query_bitmaps;
7use bit_set::BitSet;
8use diskann_label_filter::{read_and_parse_queries, read_baselabels};
9
10use std::{io::Write, mem::size_of, str::FromStr};
11
12use bytemuck::cast_slice;
13use diskann::{
14 neighbor::{Neighbor, NeighborPriorityQueue},
15 utils::VectorRepr,
16};
17use diskann_disk::data_model::GraphDataType;
18use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
19use diskann_providers::utils::{
20 create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator,
21};
22use diskann_utils::{
23 io::{read_bin, Metadata},
24 views::Matrix,
25};
26use diskann_vector::{distance::Metric, DistanceFunction};
27use itertools::Itertools;
28use rayon::prelude::*;
29
30use crate::utils::{search_index_utils, CMDResult, CMDToolError};
31
32pub fn read_labels_and_compute_bitmap(
33 base_label_filename: &str,
34 query_label_filename: &str,
35) -> CMDResult<Vec<BitSet>> {
36 let base_labels = read_baselabels(base_label_filename)?;
38
39 let parsed_queries = read_and_parse_queries(query_label_filename)?;
41
42 let query_bitmaps = compute_query_bitmaps(base_labels, parsed_queries);
44
45 match query_bitmaps {
46 Ok(bitmaps) => Ok(bitmaps),
47 Err(e) => Err(CMDToolError {
48 details: format!("Error computing query bitmaps: {}", e),
49 }),
50 }
51}
52
53#[allow(clippy::too_many_arguments)]
54#[allow(clippy::panic)]
55pub fn compute_ground_truth_from_datafiles<
67 Data: GraphDataType,
68 StorageProvider: StorageReadProvider + StorageWriteProvider,
69>(
70 storage_provider: &StorageProvider,
71 distance_function: Metric,
72 base_file: &str,
73 query_file: &str,
74 ground_truth_file: &str,
75 vector_filters_file: Option<&str>,
76 recall_at: u32,
77 insert_file: Option<&str>,
78 skip_base: Option<usize>,
79 associated_data_file: Option<String>,
80 base_file_labels: Option<&str>,
81 query_file_labels: Option<&str>,
82) -> CMDResult<()> {
83 let dataset_iterator = VectorDataIterator::<
84 StorageProvider,
85 Data::VectorDataType,
86 Data::AssociatedDataType,
87 >::new(base_file, associated_data_file.clone(), storage_provider)?;
88
89 if !((base_file_labels.is_some() && query_file_labels.is_some())
91 || (base_file_labels.is_none() && query_file_labels.is_none()))
92 {
93 return Err(CMDToolError {
94 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
95 });
96 }
97
98 if base_file_labels.is_some() && vector_filters_file.is_some() {
99 return Err(CMDToolError {
100 details: "Both base_file_labels and vector_filters_file cannot be provided."
101 .to_string(),
102 });
103 }
104
105 let insert_iterator = match insert_file {
106 Some(insert_file) => {
107 let i = VectorDataIterator::<
108 StorageProvider,
109 Data::VectorDataType,
110 Data::AssociatedDataType,
111 >::new(insert_file, Option::None, storage_provider)?;
112 Some(i)
113 }
114 None => None,
115 };
116
117 let query_data =
119 read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
120 let query_num = query_data.nrows();
121
122 let mut query_bitmaps: Option<Vec<BitSet>> = None;
123 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
124 {
125 query_bitmaps = Some(read_labels_and_compute_bitmap(
126 base_file_labels,
127 query_file_labels,
128 )?);
129 }
130
131 let vector_filters = match vector_filters_file {
133 Some(vector_filters_file) => {
134 let filters =
135 search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
136
137 assert_eq!(
138 filters.len(),
139 query_num,
140 "Mismatch in query and vector filter sizes"
141 );
142
143 Some(filters)
144 }
145 None => None,
146 };
147
148 let has_vector_filters = vector_filters.is_some();
149 let has_query_bitmaps = query_bitmaps.is_some();
150
151 if has_vector_filters {
152 if let Some(filters) = vector_filters {
154 let mut bitmaps = vec![BitSet::new(); query_num];
155 for (idx_query, filter) in filters.iter().enumerate() {
156 for item in filter.iter() {
157 if let Ok(idx) = (*item).try_into() {
158 bitmaps[idx_query].insert(idx);
159 }
160 }
161 }
162 query_bitmaps = Some(bitmaps)
163 }
164 }
165
166 let ground_truth_result = compute_ground_truth_from_data::<Data, StorageProvider>(
167 distance_function,
168 dataset_iterator,
169 &query_data,
170 recall_at,
171 insert_iterator,
172 skip_base,
173 query_bitmaps,
174 );
175 assert!(
176 &ground_truth_result.is_ok(),
177 "Ground-truth computation failed"
178 );
179 let (ground_truth, id_to_associated_data) = ground_truth_result?;
180
181 assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
182
183 if has_vector_filters || has_query_bitmaps {
184 let ground_truth_collection = ground_truth
185 .into_iter()
186 .map(|npq| npq.into_iter().collect())
187 .collect();
188 write_range_search_ground_truth(
189 storage_provider,
190 ground_truth_file,
191 query_num,
192 ground_truth_collection,
193 )
194 } else {
195 let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
197 write_ground_truth::<Data>(
198 storage_provider,
199 ground_truth_file,
200 query_num,
201 recall_at as usize,
202 ground_truth,
203 id_to_associated_data,
204 )
205 }
206}
207
/// How the pairwise distances between the rows of a query multi-vector and the
/// rows of a base multi-vector are reduced to a single distance.
#[derive(Debug, Clone)]
pub enum MultivecAggregationMethod {
    /// Mean of the distance over every (query row, base row) pair.
    AveragePairwise,
    /// Smallest distance over every (query row, base row) pair.
    MinPairwise,
    /// For each query row, the smallest distance to any base row; averaged over
    /// the query rows.
    AvgofMins,
}

/// Error returned when a string does not name a [`MultivecAggregationMethod`].
#[derive(Debug)]
pub enum ParseAggrError {
    /// Carries the rejected input exactly as it was supplied.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::InvalidFormat(rejected) = self;
        write!(f, "Invalid format for Aggregation Method: {}", rejected)
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses `average_pairwise`, `min_pairwise` or `avg_of_mins`, ignoring case.
    /// The input is not trimmed; anything else is rejected with the original text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let method = if normalized == "average_pairwise" {
            Self::AveragePairwise
        } else if normalized == "min_pairwise" {
            Self::MinPairwise
        } else if normalized == "avg_of_mins" {
            Self::AvgofMins
        } else {
            return Err(ParseAggrError::InvalidFormat(s.to_owned()));
        };
        Ok(method)
    }
}
242
243#[allow(clippy::too_many_arguments)]
244#[allow(clippy::panic)]
245pub fn compute_multivec_ground_truth_from_datafiles<
258 Data: GraphDataType,
259 StorageProvider: StorageReadProvider + StorageWriteProvider,
260>(
261 storage_provider: &StorageProvider,
262 distance_function: Metric,
263 aggregation_method: MultivecAggregationMethod,
264 base_file: &str,
265 query_file: &str,
266 ground_truth_file: &str,
267 recall_at: u32,
268 base_file_labels: Option<&str>,
269 query_file_labels: Option<&str>,
270) -> CMDResult<()> {
271 let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
272 Data::VectorDataType,
273 StorageProvider,
274 >(storage_provider, base_file)?;
275
276 let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
277 Data::VectorDataType,
278 StorageProvider,
279 >(storage_provider, query_file)?;
280
281 if !((base_file_labels.is_some() && query_file_labels.is_some())
283 || (base_file_labels.is_none() && query_file_labels.is_none()))
284 {
285 return Err(CMDToolError {
286 details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
287 });
288 }
289
290 let mut query_bitmaps: Option<Vec<BitSet>> = None;
291 if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
292 {
293 query_bitmaps = Some(read_labels_and_compute_bitmap(
294 base_file_labels,
295 query_file_labels,
296 )?);
297 }
298
299 let has_query_bitmaps = query_bitmaps.is_some();
300
301 let ground_truth = compute_multivec_ground_truth_from_data::<Data::VectorDataType>(
302 distance_function,
303 aggregation_method,
304 base_vectors,
305 query_vectors,
306 query_dim,
307 recall_at,
308 query_bitmaps,
309 )?;
310
311 if has_query_bitmaps {
312 let ground_truth_collection = ground_truth
313 .into_iter()
314 .map(|npq| npq.into_iter().collect())
315 .collect();
316 write_range_search_ground_truth(
317 storage_provider,
318 ground_truth_file,
319 query_num,
320 ground_truth_collection,
321 )
322 } else {
323 write_ground_truth::<Data>(
325 storage_provider,
326 ground_truth_file,
327 query_num,
328 recall_at as usize,
329 ground_truth,
330 Option::None,
331 )
332 }
333}
334
335fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
336 storage_provider: &StorageProvider,
337 ground_truth_file: &str,
338 number_of_queries: usize,
339 ground_truth: Vec<Vec<Neighbor<u32>>>,
340) -> CMDResult<()> {
341 let mut file = storage_provider.create_for_write(ground_truth_file)?;
342
343 let queue_sizes: Vec<u32> = ground_truth
344 .iter()
345 .map(|queue| queue.len() as u32)
346 .collect();
347 let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
348
349 Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
351
352 let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
354 queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
355 file.write_all(&queue_sizes_buffer)?;
356
357 let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
358
359 for query_neighbors in ground_truth {
361 for neighbor in query_neighbors.iter() {
362 neighbor_ids.push(neighbor.id);
363 }
364 }
365
366 let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
368 id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
369 file.write_all(&id_buffer)?;
370
371 file.flush()?;
373
374 Ok(())
375}
376
377fn write_ground_truth<Data: GraphDataType>(
381 storage_provider: &impl StorageWriteProvider,
382 ground_truth_file: &str,
383 number_of_queries: usize,
384 number_of_neighbors: usize,
385 ground_truth: Vec<NeighborPriorityQueue<u32>>,
386 id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
387) -> CMDResult<()> {
388 let mut file = storage_provider.create_for_write(ground_truth_file)?;
389
390 Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
391
392 let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
393 let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
394
395 for mut query_neighbors in ground_truth {
397 while let Some(closest_node) = query_neighbors.closest_notvisited() {
398 gt_ids.push(closest_node.id);
399 gt_distances.push(closest_node.distance);
400 }
401 }
402
403 if let Some(id_to_associated_data) = id_to_associated_data {
405 let mut associated_data_buffer = Vec::<u8>::new();
406 for id in gt_ids {
407 let associated_data = id_to_associated_data[id as usize];
408 let serialized_associated_data =
409 bincode::serialize(&associated_data).map_err(|e| CMDToolError {
410 details: format!("Failed to serialize associated data: {}", e),
411 })?;
412 associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
413 }
414 file.write_all(&associated_data_buffer)?;
415 } else {
416 let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
417 id_buffer.clone_from_slice(cast_slice::<u32, u8>(>_ids));
418 file.write_all(&id_buffer)?;
419 }
420
421 let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
423 distance_buffer.clone_from_slice(cast_slice::<f32, u8>(>_distances));
424 file.write_all(&distance_buffer)?;
425
426 file.flush()?;
428
429 Ok(())
430}
431
432type Npq = Vec<NeighborPriorityQueue<u32>>;
433#[allow(clippy::too_many_arguments)]
446pub fn compute_ground_truth_from_data<Data, VectorReader>(
447 distance_function: Metric,
448 dataset_iter: VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
449 queries: &Matrix<Data::VectorDataType>,
450 recall_at: u32,
451 insert_iter: Option<
452 VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
453 >,
454 skip_base: Option<usize>,
455 query_bitmaps: Option<Vec<BitSet>>,
456) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
457where
458 Data: GraphDataType,
459 VectorReader: StorageReadProvider,
460{
461 let query_num = queries.nrows();
462 let query_dim = queries.ncols();
463
464 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = (0..query_num)
465 .map(|_| NeighborPriorityQueue::new(recall_at as usize))
466 .collect();
467 let mut queries_and_neighbor_queue: Vec<_> =
468 queries.row_iter().zip(neighbor_queues.iter_mut()).collect();
469
470 let distance_comparer = Data::VectorDataType::distance(distance_function, Some(query_dim));
471
472 let batch_size = 10_000;
473 let mut data_batch: Vec<Box<[Data::VectorDataType]>> = Vec::with_capacity(batch_size);
474
475 let pool = create_thread_pool(0)?;
476
477 let mut num_base_points: usize = 0;
478 let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
479 let skip_base = skip_base.unwrap_or(0);
480 for chunk in dataset_iter.skip(skip_base).chunks(batch_size).into_iter() {
482 data_batch.clear();
483 for (data_vector, associated_data) in chunk {
484 data_batch.push(data_vector);
485 id_to_associated_data.push(associated_data);
486 }
487 let points = data_batch.len();
488
489 if points == 0 {
490 continue;
491 }
492
493 queries_and_neighbor_queue
495 .par_iter_mut()
496 .enumerate()
497 .for_each_in_pool(
498 pool.as_ref(),
499 |(idx_query, (query, ref mut neighbor_queue))| {
500 for (idx_in_batch, data) in data_batch.iter().enumerate() {
501 let idx = (num_base_points + idx_in_batch) as u32;
502
503 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
504 if let Ok(idx_usize) = idx.try_into() {
505 bitmaps[idx_query].contains(idx_usize)
506 } else {
507 false
508 }
509 } else {
510 true
511 };
512
513 if allowed_by_bitmap {
514 let distance = distance_comparer.evaluate_similarity(data, query);
515 neighbor_queue.insert(Neighbor { id: idx, distance });
516 }
517 }
518 },
519 );
520
521 num_base_points += points;
522 }
523
524 if let Some(insert_iter) = insert_iter {
525 for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
526 for (idx_query, (query, ref mut neighbor_queue)) in
528 queries_and_neighbor_queue.iter_mut().enumerate()
529 {
530 let idx = (num_base_points + insert_idx) as u32;
531
532 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
533 if let Ok(idx_usize) = idx.try_into() {
534 bitmaps[idx_query].contains(idx_usize)
535 } else {
536 false
537 }
538 } else {
539 true
540 };
541
542 if allowed_by_bitmap {
543 let distance = distance_comparer.evaluate_similarity(&data_vector, query);
544 neighbor_queue.insert(Neighbor { id: idx, distance })
545 }
546 }
547 }
548 }
549
550 Ok((neighbor_queues, id_to_associated_data))
551}
552
553#[allow(clippy::too_many_arguments)]
554pub fn compute_multivec_ground_truth_from_data<T>(
555 distance_function: Metric,
556 aggregation_method: MultivecAggregationMethod,
557 base_vectors: Vec<Matrix<T>>,
558 queries: Vec<Matrix<T>>,
559 query_dim: usize,
560 recall_at: u32,
561 query_bitmaps: Option<Vec<BitSet>>,
562) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
563where
564 T: VectorRepr,
565{
566 let query_num = queries.len();
567
568 let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
569 for _ in 0..query_num {
571 neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
572 }
573 let mut query_multivecs_and_neighbor_queue: Vec<_> =
574 queries.iter().zip(neighbor_queues.iter_mut()).collect();
575
576 let distance_comparer = T::distance(distance_function, Some(query_dim));
577
578 let pool = create_thread_pool(0)?;
579
580 query_multivecs_and_neighbor_queue
583 .par_iter_mut()
584 .enumerate()
585 .for_each_in_pool(
586 pool.as_ref(),
587 |(query_idx, (query_multivec, neighbor_queue))| {
588 for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
589 let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
591 bitmaps[query_idx].contains(idx_base)
592 } else {
593 true
594 };
595
596 if allowed_by_bitmap {
597 let distance = match aggregation_method {
599 MultivecAggregationMethod::AveragePairwise => {
600 let mut total_distance = 0.0;
601 for query_vec in query_multivec.row_iter() {
602 for base_vec in base_multivec.row_iter() {
603 let dist = distance_comparer
604 .evaluate_similarity(query_vec, base_vec);
605 total_distance += dist;
606 }
607 }
608 total_distance
609 / (query_multivec.nrows() * base_multivec.nrows()) as f32
610 }
611 MultivecAggregationMethod::MinPairwise => {
612 let mut min_distance = f32::MAX;
613 for query_vec in query_multivec.row_iter() {
614 for base_vec in base_multivec.row_iter() {
615 let dist = distance_comparer
616 .evaluate_similarity(query_vec, base_vec);
617 min_distance = min_distance.min(dist);
618 }
619 }
620 min_distance
621 }
622 MultivecAggregationMethod::AvgofMins => {
623 let mut distance = 0_f32;
624 for query_vec in query_multivec.row_iter() {
625 let mut local_min = f32::MAX;
626 for base_vec in base_multivec.row_iter() {
627 let dist = distance_comparer
628 .evaluate_similarity(query_vec, base_vec);
629 local_min = local_min.min(dist);
630 }
631 distance += local_min;
632 }
633 distance / query_multivec.nrows() as f32
634 }
635 };
636 let idx = idx_base as u32;
638 neighbor_queue.insert(Neighbor { id: idx, distance });
639 }
640 }
641 },
642 );
643
644 Ok(neighbor_queues)
645}