1use std::{collections::HashSet, fmt, hash::Hash, io::Read, mem::size_of};
6
7use bytemuck::cast_slice;
8use diskann::{ANNError, ANNResult};
9use diskann_providers::storage::StorageReadProvider;
10use diskann_utils::io::Metadata;
11use tracing::info;
12
13use crate::utils::CMDToolError;
14
/// Fixed-width ground truth as produced by `load_truthset`.
///
/// Every query has the same number of ground-truth entries, so ids (and
/// distances, when present) are stored as flat buffers of
/// `index_num_points * index_dimension` elements.
#[derive(Debug, Clone)]
pub struct TruthSet {
    /// Ground-truth ids for all queries, stored back to back.
    pub index_nodes: Vec<u32>,
    /// Distances matching `index_nodes` element for element, or `None` when the
    /// source file only contained ids.
    pub distances: Option<Vec<f32>>,
    /// First header dimension (`npts`): number of rows in the truth set.
    pub index_num_points: usize,
    /// Second header dimension: number of ground-truth entries per row.
    pub index_dimension: usize,
}
21
/// Variable-width ground truth as produced by `load_range_truthset`.
///
/// Unlike [`TruthSet`], every point owns its own list of ids, so rows may have
/// different lengths (including zero).
#[derive(Debug, Clone)]
pub struct RangeSearchTruthSet {
    /// One id list per point, in file order.
    pub index_nodes: Vec<Vec<u32>>,
    /// Per-point distances; `load_range_truthset` always leaves this as `None`.
    pub distances: Option<Vec<Vec<f32>>>,
    /// Number of points (rows) in the truth set.
    pub index_num_points: usize,
    /// Number of ids stored for each point; `index_dimensions[i]` is the length
    /// of `index_nodes[i]` when built by `load_range_truthset`.
    pub index_dimensions: Vec<u32>,
}
28
/// Validated pair of recall parameters for a "k-recall@n" measurement.
///
/// Values created through [`KRecallAtN::new`] always satisfy `1 <= k <= n`.
/// The fields are private so that this invariant cannot be bypassed from
/// outside this module.
#[derive(Debug, Clone, Copy)]
pub struct KRecallAtN {
    // Number of leading ground-truth entries that count as correct answers.
    k: u32,
    // Number of leading returned results that are inspected for those answers.
    n: u32,
}
49
/// Reasons a `(k, n)` pair is rejected as recall bounds.
///
/// Both variants carry the offending values so the message can report them.
#[derive(Debug, Clone, Copy)]
pub enum RecallBoundsError {
    /// `k` was larger than `n`.
    KGreaterThanN { k: u32, n: u32 },
    /// At least one of `k` and `n` was zero.
    ArgumentIsZero { k: u32, n: u32 },
}

impl fmt::Display for RecallBoundsError {
    /// Writes a human-readable description of the violated bound.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::KGreaterThanN { k, n } => write!(
                f,
                "recall value k ({}) must be less than or equal to n ({})",
                k, n
            ),
            Self::ArgumentIsZero { k, n } => {
                // Pick the message from which of the two values is zero; if neither
                // is (a hand-built variant), the `n` message is used as fallback.
                let message = match (k == 0, n == 0) {
                    (true, true) => "recall values k and n must both be non-zero",
                    (true, false) => "recall values k must be non-zero",
                    (false, _) => "recall values n must be non-zero",
                };
                f.write_str(message)
            }
        }
    }
}

impl std::error::Error for RecallBoundsError {}
85
86impl From<RecallBoundsError> for CMDToolError {
88 fn from(err: RecallBoundsError) -> Self {
89 CMDToolError {
90 details: err.to_string(),
91 }
92 }
93}
94
95impl KRecallAtN {
96 pub fn new(k: u32, n: u32) -> Result<Self, RecallBoundsError> {
100 if k == 0 || n == 0 {
101 Err(RecallBoundsError::ArgumentIsZero { k, n })
102 } else if k > n {
103 Err(RecallBoundsError::KGreaterThanN { k, n })
104 } else {
105 Ok(KRecallAtN { k, n })
106 }
107 }
108
109 pub fn get_k(self) -> usize {
110 self.k as usize
111 }
112
113 pub fn get_n(self) -> usize {
114 self.n as usize
115 }
116}
117
118#[allow(clippy::too_many_arguments)]
121pub fn calculate_recall<T: Eq + Hash + Copy>(
122 num_queries: usize,
123 ground_truth: &[T],
124 gt_dist: Option<&Vec<f32>>,
125 dim_gt: usize,
126 our_results: &[T],
127 dim_or: u32,
128 recall_bounds: KRecallAtN,
129) -> ANNResult<f64> {
130 let mut total_recall: f64 = 0.0;
131 let (mut gt, mut res): (HashSet<T>, HashSet<T>) = (HashSet::new(), HashSet::new());
132
133 for i in 0..num_queries {
134 gt.clear();
135 res.clear();
136
137 let gt_slice = &ground_truth[dim_gt * i..];
138 let res_slice = &our_results[dim_or as usize * i..];
139 let mut tie_breaker = recall_bounds.get_k();
140
141 if let Some(gt_dist) = gt_dist {
142 let gt_dist_vec = >_dist[dim_gt * i..];
143 while tie_breaker < dim_gt
144 && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_bounds.get_k() - 1]
145 {
146 tie_breaker += 1;
147 }
148 }
149
150 (0..tie_breaker).for_each(|idx| {
151 gt.insert(gt_slice[idx]);
152 });
153
154 (0..recall_bounds.get_n()).for_each(|idx| {
155 res.insert(res_slice[idx]);
156 });
157
158 let mut cur_recall: u32 = 0;
159 for v in gt.iter() {
160 if res.contains(v) && cur_recall < recall_bounds.get_k() as u32 {
161 cur_recall += 1;
162 }
163 }
164
165 total_recall += cur_recall as f64;
166 }
167
168 Ok(total_recall / num_queries as f64 * (100.0 / recall_bounds.get_k() as f64))
169}
170
171pub fn calculate_filtered_search_recall(
217 num_queries: usize,
218 gt_dist: Option<&[Vec<f32>]>,
219 groundtruth: &[Vec<u32>],
220 our_results: &[Vec<u32>],
221 k_recall: u32,
222) -> ANNResult<f64> {
223 if k_recall == 0 {
224 return Err(ANNError::log_index_error(format_args!(
225 "k_recall value must be greater than 0, but got {}",
226 k_recall
227 )));
228 }
229
230 if groundtruth.len() != num_queries || our_results.len() != num_queries {
231 return Err(ANNError::log_index_error(format_args!(
232 "groundtruth length ({}) or our_results length ({}) does not match num_queries ({})",
233 groundtruth.len(),
234 our_results.len(),
235 num_queries
236 )));
237 }
238
239 let mut total_recall = 0.0;
240 for i in 0..num_queries {
241 let mut gt: HashSet<u32> = HashSet::new();
242 let mut res: HashSet<u32> = HashSet::new();
243 let gt_cutoff = (k_recall as usize).min(groundtruth[i].len());
244
245 for &item in &groundtruth[i][..gt_cutoff] {
246 gt.insert(item);
248 }
249
250 for &item in &our_results[i] {
251 res.insert(item);
252 }
253
254 if gt_cutoff > 0 {
255 if let Some(gt_dist) = gt_dist {
257 let gt_dist_vec = gt_dist[i].as_slice();
258
259 if gt_dist_vec.len() != groundtruth[i].len() {
260 return Err(ANNError::log_index_error(format_args!(
261 "Ground truth distance for query ({}) vector length ({}) is not equal to groundtruth len ({})",
262 i,
263 gt_dist_vec.len(),
264 groundtruth[i].len(),
265 )));
266 }
267
268 let mut tie_breaker = gt_cutoff;
269
270 while tie_breaker < gt_dist_vec.len() && gt_dist_vec[tie_breaker] == gt_dist_vec[gt_cutoff - 1]
272 {
273 gt.insert(groundtruth[i][tie_breaker]);
274 tie_breaker += 1;
275 }
276 }
277 }
278
279 let mut cur_recall = 0;
280
281 for &v in > {
282 if res.contains(&v) {
283 cur_recall += 1;
284 }
285 }
286
287 if gt_cutoff > 0 {
288 total_recall += (100.0 * cur_recall as f64) / gt_cutoff.max(res.len()) as f64;
289 } else {
290 total_recall += 100.0;
291 }
292 }
293
294 Ok(total_recall / num_queries as f64)
295}
296
297pub fn load_truthset(
298 storage_provider: &impl StorageReadProvider,
299 bin_file: &str,
300) -> ANNResult<TruthSet> {
301 let actual_file_size = storage_provider.get_length(bin_file)? as usize;
302 let mut file = storage_provider.open_reader(bin_file)?;
303
304 let metadata = Metadata::read(&mut file)?;
305 let (npts, dim) = metadata.into_dims();
306
307 info!("Metadata: #pts = {npts}, #dims = {dim}... ");
308
309 let expected_file_size_with_dists: usize =
310 2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
311 let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
312
313 let truthset_type : i32 = match actual_file_size {
315 x if x == expected_file_size_with_dists => 1,
316 x if x == expected_file_size_just_ids => 2,
317 _ => return Err(ANNError::log_index_error(format_args!(
318 "Error. File size mismatch. File should have bin format, with npts followed by ngt followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
319 actual_file_size,
320 expected_file_size_with_dists,
321 expected_file_size_just_ids
322 )))
323 };
324
325 let mut ids: Vec<u32> = vec![0; npts * dim];
326 let mut buffer = vec![0; npts * dim * size_of::<u32>()];
327 file.read_exact(&mut buffer)?;
328 ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
329
330 if truthset_type == 1 {
331 let mut dists: Vec<f32> = vec![0.0; npts * dim];
332 let mut buffer = vec![0; npts * dim * size_of::<f32>()];
333 file.read_exact(&mut buffer)?;
334 dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
335
336 return Ok(TruthSet {
337 index_nodes: ids,
338 distances: Some(dists),
339 index_num_points: npts,
340 index_dimension: dim,
341 });
342 }
343
344 Ok(TruthSet {
345 index_nodes: ids,
346 distances: None,
347 index_num_points: npts,
348 index_dimension: dim,
349 })
350}
351
352pub fn load_range_truthset(
359 storage_provider: &impl StorageReadProvider,
360 bin_file: &str,
361) -> ANNResult<RangeSearchTruthSet> {
362 let mut file = storage_provider.open_reader(bin_file)?;
363
364 let metadata = Metadata::read(&mut file)?;
365 let (npts, total_ids) = metadata.into_dims();
366 let mut buffer = [0; size_of::<i32>()];
367
368 info!("Metadata: #pts = {}, #totalIds = {}", npts, total_ids);
369
370 let mut ids: Vec<Vec<u32>> = Vec::new();
371 let mut counts: Vec<u32> = vec![0; npts];
372
373 for count in counts.iter_mut() {
374 file.read_exact(&mut buffer)?;
375 *count = i32::from_le_bytes(buffer) as u32;
376 }
377
378 for &count in &counts {
379 let mut point_ids: Vec<u32> = vec![0; count as usize];
380 let mut buffer = vec![0; count as usize * size_of::<u32>()];
381 file.read_exact(&mut buffer)?;
382 point_ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
383 ids.push(point_ids);
384 }
385
386 Ok(RangeSearchTruthSet {
387 index_nodes: ids,
388 distances: None,
389 index_num_points: npts,
390 index_dimensions: counts,
391 })
392}
393
394pub fn load_vector_filters(
396 storage_provider: &impl StorageReadProvider,
397 bin_file: &str,
398) -> ANNResult<Vec<HashSet<u32>>> {
399 let range_truthset = load_range_truthset(storage_provider, bin_file)?;
400
401 let query_filters: Vec<HashSet<u32>> = range_truthset
402 .index_nodes
403 .into_iter()
404 .map(|filter| filter.into_iter().collect())
405 .collect();
406
407 Ok(query_filters)
408}
409
#[cfg(test)]
mod test_search_index_utils {
    use super::*;

    /// Hand-computed expectation for one `(k, n)` recall configuration.
    struct ExpectedRecall {
        pub recall_k: usize,
        pub recall_n: usize,
        /// Expected number of hits for each query, in query order.
        pub components: Vec<usize>,
    }

    impl ExpectedRecall {
        /// Builds an expectation, checking that it is internally consistent:
        /// `k <= n` and no query can have more than `k` hits.
        fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
            assert!(recall_k <= recall_n);
            components.iter().for_each(|x| {
                assert!(*x <= recall_k);
            });
            Self {
                recall_k,
                recall_n,
                components,
            }
        }

        /// Average recall in percent implied by the per-query hit counts.
        fn compute(&self) -> f64 {
            100.0 * (self.components.iter().sum::<usize>() as f64)
                / ((self.components.len() * self.recall_k) as f64)
        }
    }

    #[test]
    fn test_k_recall_at_n_struct() {
        // Every pair with 1 <= k <= n must be accepted.
        for k in 1..=10 {
            for n in k..=10 {
                let v = KRecallAtN::new(k, n).unwrap();
                assert_eq!(v.get_k(), k as usize);
                assert_eq!(v.get_n(), n as usize);
            }
        }

        // Every pair with k > n >= 1 must be rejected with `KGreaterThanN`.
        for n in 1..=10 {
            for k in (n + 1)..=11 {
                let v = KRecallAtN::new(k, n).unwrap_err();
                match v {
                    RecallBoundsError::KGreaterThanN { k: k_err, n: n_err } => {
                        assert_eq!(k_err, k);
                        assert_eq!(n_err, n);
                    }
                    RecallBoundsError::ArgumentIsZero { .. } => {
                        panic!("unreachable reached");
                    }
                }
                let message = format!("{}", v);
                assert!(message.contains("recall value k"));
                assert!(message.contains("must be less than or equal to n"));
                assert!(message.contains(&format!("{}", k)));
                assert!(message.contains(&format!("{}", n)));
            }
        }

        // Zero arguments are reported with a message naming the offending value(s).
        let v = KRecallAtN::new(0, 0).unwrap_err();
        let message = format!("{}", v);
        assert_eq!(message, "recall values k and n must both be non-zero");

        let v = KRecallAtN::new(0, 10).unwrap_err();
        let message = format!("{}", v);
        assert_eq!(message, "recall values k must be non-zero");

        let v = KRecallAtN::new(10, 0).unwrap_err();
        let message = format!("{}", v);
        assert_eq!(message, "recall values n must be non-zero");
    }

    #[test]
    fn test_compute_recall() {
        let groundtruth_dim = 10;
        let num_queries = 4;

        // One row of `groundtruth_dim` ids per query.
        #[rustfmt::skip]
        let groundtruth: Vec<u32> = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
            5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
        ];

        assert_eq!(groundtruth.len(), num_queries * groundtruth_dim);

        // Distances matching `groundtruth`; repeated values create ties.
        #[rustfmt::skip]
        let distances: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0,
            2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0,
        ];

        assert_eq!(distances.len(), groundtruth.len());

        // One row of `results_dim` ids per query.
        let results_dim = 6;
        #[rustfmt::skip]
        let our_results: Vec<u32> = vec![
            100, 0, 1, 2, 5, 6,
            100, 101, 7, 8, 9, 10,
            0, 1, 2, 3, 4, 5,
            0, 1, 2, 3, 4, 5,
        ];
        assert_eq!(our_results.len(), num_queries * results_dim);

        let expected_no_ties = vec![
            // k == n
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
            // k < n
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
            ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
        ];
        let epsilon = 1e-6;

        for (i, expected) in expected_no_ties.iter().enumerate() {
            println!("No Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                None,
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }

        let expected_with_ties = vec![
            // k == n
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            // Ties in queries 0 and 1 widen the accepted ground truth here.
            ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]),
            // k < n
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
            ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
        ];

        for (i, expected) in expected_with_ties.iter().enumerate() {
            println!("With Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                Some(&distances),
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }
    }

    #[test]
    fn test_calculate_filtered_search_recall() {
        let filtered_search_recall =
            calculate_filtered_search_recall(1, None, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]], 1000)
                .unwrap();
        assert_eq!(
            filtered_search_recall, 40.0,
            "Returned more results than ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1],],
                1000
            )
            .unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 1000).unwrap(),
            40.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 12, 13, 14],],
                1000
            )
            .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0; 0],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_with_tie_breaking() {
        // Query 0 has three entries tied at 0.3; query 1 has no ties.
        let gt_distances: Vec<Vec<f32>> = vec![
            vec![0.1, 0.2, 0.3, 0.3, 0.3],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
        ];

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6, 7, 8, 9]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1, 3, 2, 4], vec![5, 6, 7, 8, 9]];

        // With distances, query 0 accepts all five ids (5 of 5) while query 1
        // still only expects three (3 of 5): (100 + 60) / 2.
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&gt_distances),
                &groundtruth,
                &our_results,
                3
            )
            .unwrap(),
            80.0,
            "Tie-breaking should include all tied elements"
        );

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 3).unwrap(),
            60.0,
            "Without tie-breaking, both queries should match on 3 of 5 elements"
        );

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 10).unwrap(),
            100.0,
            "Without tie-breaking and with large k, both queries should match on all elements"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_empty_ground_truth() {
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&[vec![], vec![]]),
                &[vec![], vec![]],
                &[vec![0, 1, 2], vec![5, 6, 7],],
                1
            )
            .unwrap(),
            100.0,
            "Empty ground truth should result in 100% recall"
        );
    }

    #[test]
    fn test_recall_bounds_error_display() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let message = format!("{}", error);
        assert!(message.contains("recall value k"));
        assert!(message.contains("must be less than or equal to n"));

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k and n must both be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 5 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k must be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 5, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values n must be non-zero");
    }

    #[test]
    fn test_recall_bounds_error_conversion() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let cmd_error: CMDToolError = error.into();
        assert!(!cmd_error.details.is_empty());
    }

    #[test]
    fn test_k_recall_at_n_getters() {
        let recall = KRecallAtN::new(5, 10).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 10);
    }

    #[test]
    fn test_k_recall_at_n_equal_values() {
        let recall = KRecallAtN::new(5, 5).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 5);
    }
}