1use std::{collections::HashSet, fmt, hash::Hash, io::Read, mem::size_of};
6
7use bytemuck::cast_slice;
8use diskann::{ANNError, ANNResult};
9use diskann_providers::model::graph::traits::GraphDataType;
10use diskann_providers::storage::StorageReadProvider;
11use diskann_utils::io::Metadata;
12use tracing::{error, info};
13
14use crate::utils::CMDToolError;
15
/// Ground truth for k-NN search, as loaded by `load_truthset`.
///
/// Ids (and optional distances) are stored flattened: `index_num_points` rows of
/// `index_dimension` entries each (presumably row-major, one row per query).
#[derive(Debug, Clone, PartialEq)]
pub struct TruthSet {
    /// Flattened ground-truth neighbor ids.
    pub index_nodes: Vec<u32>,
    /// Flattened ground-truth distances parallel to `index_nodes`, when the file had them.
    pub distances: Option<Vec<f32>>,
    /// Number of rows (the `npts` header value).
    pub index_num_points: usize,
    /// Number of entries per row (the `dim` header value).
    pub index_dimension: usize,
}
22
/// Ground truth whose neighbor entries are associated-data values instead of ids,
/// as loaded by `load_truthset_with_associated_data`.
///
/// Entries are stored flattened: `index_num_points` rows of `index_dimension`
/// entries each (presumably row-major, one row per query).
pub struct TruthSetWithAssociatedData<Data: GraphDataType> {
    /// Flattened ground-truth neighbors, expressed as associated data.
    pub index_nodes: Vec<<Data as GraphDataType>::AssociatedDataType>,
    /// Flattened ground-truth distances parallel to `index_nodes`.
    pub distances: Option<Vec<f32>>,
    /// Number of rows (the `npts` header value).
    pub index_num_points: usize,
    /// Number of entries per row (the `dim` header value).
    pub index_dimension: usize,
}
29
/// Ground truth for range search, as loaded by `load_range_truthset`.
///
/// Unlike [`TruthSet`], every row may hold a different number of ids.
#[derive(Debug, Clone, PartialEq)]
pub struct RangeSearchTruthSet {
    /// One vector of ground-truth ids per row.
    pub index_nodes: Vec<Vec<u32>>,
    /// Optional per-row distances, parallel to `index_nodes`.
    pub distances: Option<Vec<Vec<f32>>>,
    /// Number of rows (the `npts` header value).
    pub index_num_points: usize,
    /// Number of ids stored in each row.
    pub index_dimensions: Vec<u32>,
}
36
/// A validated `k`-recall@`n` specification.
///
/// Built through `KRecallAtN::new`, which rejects zero values and `k > n`, so an
/// instance always satisfies `1 <= k <= n`.
#[derive(Debug, Clone, Copy)]
pub struct KRecallAtN {
    // Number of leading ground-truth neighbors treated as true neighbors.
    k: u32,
    // Number of leading returned results examined per query.
    n: u32,
}
57
/// Reasons `KRecallAtN::new` can reject a `k`-recall@`n` specification.
#[derive(Debug, Clone, Copy)]
pub enum RecallBoundsError {
    /// `k` was larger than `n`; both requested values are kept for the message.
    KGreaterThanN { k: u32, n: u32 },
    /// `k`, `n`, or both were zero; both requested values are kept so the
    /// message can name the offending argument(s).
    ArgumentIsZero { k: u32, n: u32 },
}
67impl fmt::Display for RecallBoundsError {
68 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69 match self {
70 RecallBoundsError::KGreaterThanN { k, n } => {
71 write!(
72 f,
73 "recall value k ({}) must be less than or equal to n ({})",
74 k, n
75 )
76 }
77 RecallBoundsError::ArgumentIsZero { k, n } => {
79 if *k == 0 && *n == 0 {
80 write!(f, "recall values k and n must both be non-zero")
81 } else if *k == 0 {
82 write!(f, "recall values k must be non-zero")
83 } else {
84 write!(f, "recall values n must be non-zero")
85 }
86 }
87 }
88 }
89}
90
91impl std::error::Error for RecallBoundsError {}
93
94impl From<RecallBoundsError> for CMDToolError {
96 fn from(err: RecallBoundsError) -> Self {
97 CMDToolError {
98 details: err.to_string(),
99 }
100 }
101}
102
103impl KRecallAtN {
104 pub fn new(k: u32, n: u32) -> Result<Self, RecallBoundsError> {
108 if k == 0 || n == 0 {
109 Err(RecallBoundsError::ArgumentIsZero { k, n })
110 } else if k > n {
111 Err(RecallBoundsError::KGreaterThanN { k, n })
112 } else {
113 Ok(KRecallAtN { k, n })
114 }
115 }
116
117 pub fn get_k(self) -> usize {
118 self.k as usize
119 }
120
121 pub fn get_n(self) -> usize {
122 self.n as usize
123 }
124}
125
126#[allow(clippy::too_many_arguments)]
129pub fn calculate_recall<T: Eq + Hash + Copy>(
130 num_queries: usize,
131 ground_truth: &[T],
132 gt_dist: Option<&Vec<f32>>,
133 dim_gt: usize,
134 our_results: &[T],
135 dim_or: u32,
136 recall_bounds: KRecallAtN,
137) -> ANNResult<f64> {
138 let mut total_recall: f64 = 0.0;
139 let (mut gt, mut res): (HashSet<T>, HashSet<T>) = (HashSet::new(), HashSet::new());
140
141 for i in 0..num_queries {
142 gt.clear();
143 res.clear();
144
145 let gt_slice = &ground_truth[dim_gt * i..];
146 let res_slice = &our_results[dim_or as usize * i..];
147 let mut tie_breaker = recall_bounds.get_k();
148
149 if let Some(gt_dist) = gt_dist {
150 let gt_dist_vec = >_dist[dim_gt * i..];
151 while tie_breaker < dim_gt
152 && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_bounds.get_k() - 1]
153 {
154 tie_breaker += 1;
155 }
156 }
157
158 (0..tie_breaker).for_each(|idx| {
159 gt.insert(gt_slice[idx]);
160 });
161
162 (0..recall_bounds.get_n()).for_each(|idx| {
163 res.insert(res_slice[idx]);
164 });
165
166 let mut cur_recall: u32 = 0;
167 for v in gt.iter() {
168 if res.contains(v) && cur_recall < recall_bounds.get_k() as u32 {
169 cur_recall += 1;
170 }
171 }
172
173 total_recall += cur_recall as f64;
174 }
175
176 Ok(total_recall / num_queries as f64 * (100.0 / recall_bounds.get_k() as f64))
177}
178
179pub fn calculate_range_search_recall(
180 num_queries: u32,
181 groundtruth: &[Vec<u32>],
182 our_results: &[Vec<u32>],
183) -> ANNResult<f64> {
184 let mut total_recall = 0.0;
185 for i in 0..num_queries as usize {
186 let mut gt: HashSet<u32> = HashSet::new();
187 let mut res: HashSet<u32> = HashSet::new();
188
189 for &item in &groundtruth[i] {
190 gt.insert(item);
191 }
192
193 for &item in &our_results[i] {
194 res.insert(item);
195 }
196
197 let mut cur_recall = 0;
198 for &v in > {
199 if res.contains(&v) {
200 cur_recall += 1;
201 }
202 }
203
204 if !gt.is_empty() {
205 total_recall += (100.0 * cur_recall as f64) / gt.len() as f64;
206 } else {
207 total_recall += 100.0;
208 }
209 }
210
211 Ok(total_recall / num_queries as f64)
212}
213
214pub fn calculate_filtered_search_recall(
250 num_queries: usize,
251 gt_dist: Option<&[Vec<f32>]>,
252 groundtruth: &[Vec<u32>],
253 our_results: &[Vec<u32>],
254 k_recall: u32,
255) -> ANNResult<f64> {
256 if k_recall == 0 {
257 return Err(ANNError::log_index_error(format_args!(
258 "k_recall value must be greater than 0, but got {}",
259 k_recall
260 )));
261 }
262
263 if groundtruth.len() != num_queries || our_results.len() != num_queries {
264 return Err(ANNError::log_index_error(format_args!(
265 "groundtruth length ({}) or our_results length ({}) does not match num_queries ({})",
266 groundtruth.len(),
267 our_results.len(),
268 num_queries
269 )));
270 }
271
272 let mut total_recall = 0.0;
273 for i in 0..num_queries {
274 let mut gt: HashSet<u32> = HashSet::new();
275 let mut res: HashSet<u32> = HashSet::new();
276 let gt_cutoff = (k_recall as usize).min(groundtruth[i].len());
277
278 for &item in &groundtruth[i][..gt_cutoff] {
279 gt.insert(item);
281 }
282
283 for &item in &our_results[i] {
284 res.insert(item);
285 }
286
287 if gt_cutoff > 0 {
288 if let Some(gt_dist) = gt_dist {
290 let gt_dist_vec = gt_dist[i].as_slice();
291
292 if gt_dist_vec.len() != groundtruth[i].len() {
293 return Err(ANNError::log_index_error(format_args!(
294 "Ground truth distance for query ({}) vector length ({}) is not equal to groundtruth len ({})",
295 i,
296 gt_dist_vec.len(),
297 groundtruth[i].len(),
298 )));
299 }
300
301 let mut tie_breaker = gt_cutoff;
302
303 while tie_breaker < gt_dist_vec.len() && gt_dist_vec[tie_breaker] == gt_dist_vec[gt_cutoff - 1]
305 {
306 gt.insert(groundtruth[i][tie_breaker]);
307 tie_breaker += 1;
308 }
309 }
310 }
311
312 let mut cur_recall = 0;
313
314 for &v in > {
315 if res.contains(&v) {
316 cur_recall += 1;
317 }
318 }
319
320 if gt_cutoff > 0 {
321 total_recall += (100.0 * cur_recall as f64) / gt_cutoff.max(res.len()) as f64;
322 } else {
323 total_recall += 100.0;
324 }
325 }
326
327 Ok(total_recall / num_queries as f64)
328}
329
330pub fn get_graph_num_frozen_points(
331 storage_provider: &impl StorageReadProvider,
332 graph_file: &str,
333) -> ANNResult<usize> {
334 let mut file = storage_provider.open_reader(graph_file)?;
335 let mut usize_buffer = [0; size_of::<usize>()];
336 let mut u32_buffer = [0; size_of::<u32>()];
337
338 file.read_exact(&mut usize_buffer)?;
339 file.read_exact(&mut u32_buffer)?;
340 file.read_exact(&mut u32_buffer)?;
341 file.read_exact(&mut usize_buffer)?;
342 let file_frozen_pts = usize::from_le_bytes(usize_buffer);
343
344 Ok(file_frozen_pts)
345}
346
347pub fn get_graph_max_observed_degree(
348 storage_provider: &impl StorageReadProvider,
349 graph_file: &str,
350) -> ANNResult<u32> {
351 let mut file = storage_provider.open_reader(graph_file)?;
352 let mut usize_buffer = [0; size_of::<usize>()];
353 let mut u32_buffer = [0; size_of::<u32>()];
354
355 file.read_exact(&mut usize_buffer)?;
356 file.read_exact(&mut u32_buffer)?;
357 let max_observed_degree = u32::from_le_bytes(u32_buffer);
358
359 Ok(max_observed_degree)
360}
361
362pub fn load_truthset(
363 storage_provider: &impl StorageReadProvider,
364 bin_file: &str,
365) -> ANNResult<TruthSet> {
366 let actual_file_size = storage_provider.get_length(bin_file)? as usize;
367 let mut file = storage_provider.open_reader(bin_file)?;
368
369 let metadata = Metadata::read(&mut file)?;
370 let (npts, dim) = metadata.into_dims();
371
372 info!("Metadata: #pts = {npts}, #dims = {dim}... ");
373
374 let expected_file_size_with_dists: usize =
375 2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
376 let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
377
378 let truthset_type : i32 = match actual_file_size {
380 x if x == expected_file_size_with_dists => 1,
381 x if x == expected_file_size_just_ids => 2,
382 _ => return Err(ANNError::log_index_error(format_args!(
383 "Error. File size mismatch. File should have bin format, with npts followed by ngt followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
384 actual_file_size,
385 expected_file_size_with_dists,
386 expected_file_size_just_ids
387 )))
388 };
389
390 let mut ids: Vec<u32> = vec![0; npts * dim];
391 let mut buffer = vec![0; npts * dim * size_of::<u32>()];
392 file.read_exact(&mut buffer)?;
393 ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
394
395 if truthset_type == 1 {
396 let mut dists: Vec<f32> = vec![0.0; npts * dim];
397 let mut buffer = vec![0; npts * dim * size_of::<f32>()];
398 file.read_exact(&mut buffer)?;
399 dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
400
401 return Ok(TruthSet {
402 index_nodes: ids,
403 distances: Some(dists),
404 index_num_points: npts,
405 index_dimension: dim,
406 });
407 }
408
409 Ok(TruthSet {
410 index_nodes: ids,
411 distances: None,
412 index_num_points: npts,
413 index_dimension: dim,
414 })
415}
416
417pub fn load_truthset_with_associated_data<Data: GraphDataType>(
418 storage_provider: &impl StorageReadProvider,
419 bin_file: &str,
420) -> ANNResult<TruthSetWithAssociatedData<Data>> {
421 let mut file = storage_provider.open_reader(bin_file)?;
422
423 let metadata = Metadata::read(&mut file)?;
424 let (npts, dim) = metadata.into_dims();
425
426 info!("Metadata: #pts = {}, #dims = {}...", npts, dim);
427
428 let mut associated_data: Vec<Data::AssociatedDataType> =
429 vec![Data::AssociatedDataType::default(); npts * dim];
430
431 for associated_datum in associated_data.iter_mut().take(npts * dim) {
432 let mut associated_data_buf = vec![0u8; size_of::<Data::AssociatedDataType>()];
433 file.read_exact(&mut associated_data_buf)
434 .map_err(ANNError::log_io_error)?;
435
436 match bincode::deserialize::<Data::AssociatedDataType>(&associated_data_buf) {
437 Ok(datum) => {
438 *associated_datum = datum;
439 }
440 Err(_) => {
441 error!("Error deserializing associated data");
442 return Err(ANNError::log_index_error("Error reading associated data"));
443 }
444 }
445 }
446
447 let mut dists: Vec<f32> = vec![0.0; npts * dim];
448 let mut buffer = vec![0; npts * dim * size_of::<f32>()];
449 file.read_exact(&mut buffer)?;
450 dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
451
452 Ok(TruthSetWithAssociatedData {
453 index_nodes: associated_data,
454 distances: Some(dists),
455 index_num_points: npts,
456 index_dimension: dim,
457 })
458}
459
460pub fn load_range_truthset(
467 storage_provider: &impl StorageReadProvider,
468 bin_file: &str,
469) -> ANNResult<RangeSearchTruthSet> {
470 let mut file = storage_provider.open_reader(bin_file)?;
471
472 let metadata = Metadata::read(&mut file)?;
473 let (npts, total_ids) = metadata.into_dims();
474 let mut buffer = [0; size_of::<i32>()];
475
476 info!("Metadata: #pts = {}, #totalIds = {}", npts, total_ids);
477
478 let mut ids: Vec<Vec<u32>> = Vec::new();
479 let mut counts: Vec<u32> = vec![0; npts];
480
481 for count in counts.iter_mut() {
482 file.read_exact(&mut buffer)?;
483 *count = i32::from_le_bytes(buffer) as u32;
484 }
485
486 for &count in &counts {
487 let mut point_ids: Vec<u32> = vec![0; count as usize];
488 let mut buffer = vec![0; count as usize * size_of::<u32>()];
489 file.read_exact(&mut buffer)?;
490 point_ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
491 ids.push(point_ids);
492 }
493
494 Ok(RangeSearchTruthSet {
495 index_nodes: ids,
496 distances: None,
497 index_num_points: npts,
498 index_dimensions: counts,
499 })
500}
501
502pub fn load_vector_filters(
504 storage_provider: &impl StorageReadProvider,
505 bin_file: &str,
506) -> ANNResult<Vec<HashSet<u32>>> {
507 let range_truthset = load_range_truthset(storage_provider, bin_file)?;
508
509 let query_filters: Vec<HashSet<u32>> = range_truthset
510 .index_nodes
511 .into_iter()
512 .map(|filter| filter.into_iter().collect())
513 .collect();
514
515 Ok(query_filters)
516}
517
#[cfg(test)]
mod test_search_index_utils {
    use super::*;

    /// Expected outcome of a `k`-recall@`n` computation over a batch of queries.
    ///
    /// `components[i]` is the number of matches expected for query `i`; it never
    /// exceeds `recall_k`, mirroring the per-query cap in `calculate_recall`.
    struct ExpectedRecall {
        pub recall_k: usize,
        pub recall_n: usize,
        pub components: Vec<usize>,
    }

    impl ExpectedRecall {
        // Validates the table entry up front so a malformed expectation fails loudly.
        fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
            assert!(recall_k <= recall_n);
            components.iter().for_each(|x| {
                assert!(*x <= recall_k);
            });
            Self {
                recall_k,
                recall_n,
                components,
            }
        }

        // Average recall in percent: total matches / (queries * k) * 100.
        fn compute(&self) -> f64 {
            100.0 * (self.components.iter().sum::<usize>() as f64)
                / ((self.components.len() * self.recall_k) as f64)
        }
    }

    #[test]
    fn test_k_recall_at_n_struct() {
        // Every pair with 1 <= k <= n <= 10 is accepted and round-trips through the getters.
        for k in 1..=10 {
            for n in k..=10 {
                let v = KRecallAtN::new(k, n).unwrap();
                assert_eq!(v.get_k(), k as usize);
                assert_eq!(v.get_n(), n as usize);
            }
        }

        // k > n is rejected as KGreaterThanN carrying both values, and the message names them.
        for n in 1..=10 {
            for k in (n + 1)..=11 {
                let v = KRecallAtN::new(k, n).unwrap_err();
                match v {
                    RecallBoundsError::KGreaterThanN { k: k_err, n: n_err } => {
                        assert_eq!(k_err, k);
                        assert_eq!(n_err, n);
                    }
                    RecallBoundsError::ArgumentIsZero { .. } => {
                        panic!("unreachable reached");
                    }
                }
                let message = format!("{}", v);
                assert!(message.contains("recall value k"));
                assert!(message.contains("must be less than or equal to n"));
                assert!(message.contains(&format!("{}", k)));
                assert!(message.contains(&format!("{}", n)));
            }
        }

        // Zero arguments are rejected; the message names which argument(s) are zero.
        let v = KRecallAtN::new(0, 0).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values k and n must both be non-zero");

        let v = KRecallAtN::new(0, 10).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values k must be non-zero");

        let v = KRecallAtN::new(10, 0).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values n must be non-zero");
    }

    #[test]
    fn test_compute_recall() {
        let groundtruth_dim = 10;
        let num_queries = 4;

        // Flattened row by row: 4 queries x 10 ground-truth ids.
        let groundtruth: Vec<u32> = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // query 0
            5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // query 1
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // query 2
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // query 3
        ];

        assert_eq!(groundtruth.len(), num_queries * groundtruth_dim);

        // Distances parallel to `groundtruth`; the repeated values create ties.
        let distances: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // query 0: ids 3..=6 tie at 3.0
            2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // query 1: ids 6..=9 tie at 3.0
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // query 2: no ties
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // query 3: same as query 0
        ];

        assert_eq!(distances.len(), groundtruth.len());

        let results_dim = 6;
        // Flattened row by row: 4 queries x 6 returned ids (100/101 are never in the ground truth).
        let our_results: Vec<u32> = vec![
            100, 0, 1, 2, 5, 6, // query 0
            100, 101, 7, 8, 9, 10, // query 1
            0, 1, 2, 3, 4, 5, // query 2
            0, 1, 2, 3, 4, 5, // query 3
        ];
        assert_eq!(our_results.len(), num_queries * results_dim);

        // (k, n, per-query matches) when distances are not supplied.
        let expected_no_ties = vec![
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
            ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
        ];
        // Floating-point tolerance for comparing recall percentages.
        let epsilon = 1e-6;
        for (i, expected) in expected_no_ties.iter().enumerate() {
            println!("No Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                None,
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }

        // (k, n, per-query matches) when ties at the k-th distance also count as true neighbors.
        let expected_with_ties = vec![
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]),
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
            ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
        ];

        for (i, expected) in expected_with_ties.iter().enumerate() {
            println!("With Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                Some(&distances),
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }
    }

    // Range-search recall only measures coverage of the ground truth; extra results are free.
    #[test]
    fn test_calculate_range_search_recall() {
        assert_eq!(
            calculate_range_search_recall(1, &[vec![5, 6],], &[vec![5, 6, 7, 8, 9],]).unwrap(),
            100.0,
            "Returned more results than ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1],]).unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        // Average of 40% and 100%.
        assert_eq!(
            calculate_range_search_recall(2, &groundtruth, &our_results).unwrap(),
            70.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1, 2, 3, 4],])
                .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1, 12, 13, 14],])
                .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0; 0],], &[vec![0, 1, 2, 3, 4],]).unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    // Filtered recall divides by max(gt_cutoff, result count), so extra results are penalized.
    #[test]
    fn test_calculate_filtered_search_recall() {
        // 2 hits / max(2, 5) = 40%.
        let filtered_search_recall =
            calculate_filtered_search_recall(1, None, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]], 1000)
                .unwrap();
        assert_eq!(
            filtered_search_recall, 40.0,
            "Returned more results than ground truth"
        );

        let range_search_recall =
            calculate_range_search_recall(1, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]]).unwrap();
        assert_eq!(
            range_search_recall, 100.0,
            "Returned more results than ground truth"
        );

        assert_ne!(
            filtered_search_recall, range_search_recall,
            "This test case showcases the difference between range and filtered search"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1],],
                1000
            )
            .unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        // Both queries score 2 / 5 = 40%.
        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 1000).unwrap(),
            40.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 12, 13, 14],],
                1000
            )
            .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0; 0],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_with_tie_breaking() {
        let gt_distances: Vec<Vec<f32>> = vec![
            vec![0.1, 0.2, 0.3, 0.3, 0.3], // query 0: last three entries tie at 0.3
            vec![0.1, 0.2, 0.3, 0.4, 0.5], // query 1: no ties
        ];

        let groundtruth: Vec<Vec<u32>> = vec![
            vec![0, 1, 2, 3, 4],
            vec![5, 6, 7, 8, 9],
        ];

        let our_results: Vec<Vec<u32>> = vec![
            vec![0, 1, 3, 2, 4], // query 0: all five ids, reordered
            vec![5, 6, 7, 8, 9], // query 1: all five ids
        ];

        // Query 0: k=3 expands to all 5 tied ids -> 5/5; query 1: 3/max(3, 5) -> 60%; average 80%.
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&gt_distances),
                &groundtruth,
                &our_results,
                3
            )
            .unwrap(),
            80.0,
            "Tie-breaking should include all tied elements"
        );

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 3).unwrap(),
            60.0,
            "Without tie-breaking, both queries should match on 3 of 5 elements"
        );

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 10).unwrap(),
            100.0,
            "Without tie-breaking and with large k, both queries should match on all elements"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_empty_ground_truth() {
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&[vec![], vec![]]),
                &[vec![], vec![]],
                &[vec![0, 1, 2], vec![5, 6, 7],],
                1
            )
            .unwrap(),
            100.0,
            "Empty ground truth should result in 100% recall"
        );
    }

    #[test]
    fn test_recall_bounds_error_display() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let message = format!("{}", error);
        assert!(message.contains("recall value k"));
        assert!(message.contains("must be less than or equal to n"));

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k and n must both be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 5 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k must be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 5, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values n must be non-zero");
    }

    // The CMDToolError conversion carries a non-empty message.
    #[test]
    fn test_recall_bounds_error_conversion() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let cmd_error: CMDToolError = error.into();
        assert!(!cmd_error.details.is_empty());
    }

    #[test]
    fn test_k_recall_at_n_getters() {
        let recall = KRecallAtN::new(5, 10).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 10);
    }

    // k == n is a valid boundary case.
    #[test]
    fn test_k_recall_at_n_equal_values() {
        let recall = KRecallAtN::new(5, 5).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 5);
    }
}