Skip to main content

diskann_tools/utils/
search_index_utils.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::{collections::HashSet, fmt, hash::Hash, io::Read, mem::size_of};
6
7use bytemuck::cast_slice;
8use diskann::{ANNError, ANNResult};
9use diskann_providers::model::graph::traits::GraphDataType;
10use diskann_providers::storage::StorageReadProvider;
11use diskann_utils::io::Metadata;
12use tracing::{error, info};
13
14use crate::utils::CMDToolError;
15
/// A k-nearest-neighbor ground truth loaded from a truthset file (see `load_truthset`).
///
/// IDs are stored as a flat row-major matrix with `index_num_points` rows of
/// `index_dimension` entries each.
pub struct TruthSet {
    /// Ground-truth neighbor IDs, `index_num_points * index_dimension` entries.
    pub index_nodes: Vec<u32>,
    /// Distances laid out like `index_nodes`, present only when the file contains them.
    pub distances: Option<Vec<f32>>,
    /// Number of rows in the truthset.
    pub index_num_points: usize,
    /// Number of ground-truth entries per row.
    pub index_dimension: usize,
}
22
/// A k-nearest-neighbor ground truth whose entries are associated-data values rather than
/// `u32` IDs (see `load_truthset_with_associated_data`).
///
/// Entries form a flat row-major matrix with `index_num_points` rows of `index_dimension`
/// entries each.
pub struct TruthSetWithAssociatedData<Data: GraphDataType> {
    /// Ground-truth entries, `index_num_points * index_dimension` values.
    pub index_nodes: Vec<<Data as GraphDataType>::AssociatedDataType>,
    /// Distances laid out like `index_nodes`; the loader in this file always fills this in.
    pub distances: Option<Vec<f32>>,
    /// Number of rows in the truthset.
    pub index_num_points: usize,
    /// Number of ground-truth entries per row.
    pub index_dimension: usize,
}
29
/// A range-search ground truth: a variable-length list of IDs per query
/// (see `load_range_truthset`).
pub struct RangeSearchTruthSet {
    /// One ID list per query.
    pub index_nodes: Vec<Vec<u32>>,
    /// Optional per-query distances; `load_range_truthset` always leaves this `None`.
    pub distances: Option<Vec<Vec<f32>>>,
    /// Number of queries.
    pub index_num_points: usize,
    /// Number of IDs stored for each query (the length of each `index_nodes` entry).
    pub index_dimensions: Vec<u32>,
}
36
/// A struct used to indicate the bounds `k` and `n` for recall computation where:
///
/// * k: Is the number of ground truth neighbors to use.
/// * n: Is the number of retrieved neighbors.
///
/// Recall is measured as the fraction of the `k` ground truth neighbors that are
/// present in the `n` retrieved neighbors.
///
/// We make a deliberate choice that the invariant `k <= n` must hold (don't search for
/// more ground truth neighbors than those actually retrieved).
///
/// Furthermore, both `n` and `k` must be non-zero.
///
/// The constructor `new` should be used instead of direct construction because it will
/// enforce these invariants.
#[derive(Debug, Clone, Copy)]
pub struct KRecallAtN {
    /// Number of ground-truth neighbors considered (non-zero, at most `n`).
    k: u32,
    /// Number of retrieved neighbors considered (non-zero).
    n: u32,
}
57
/// Reasons why `KRecallAtN::new` can reject a `(k, n)` pair.
#[derive(Debug, Clone, Copy)]
pub enum RecallBoundsError {
    /// `k` was assigned a higher value than `n`.
    KGreaterThanN { k: u32, n: u32 },
    /// At least one of the arguments is zero; both must be non-zero.
    ///
    /// This alternative captures both values so the error message can distinguish the case
    /// where both arguments are zero from the case where only one is.
    ArgumentIsZero { k: u32, n: u32 },
}
67impl fmt::Display for RecallBoundsError {
68    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69        match self {
70            RecallBoundsError::KGreaterThanN { k, n } => {
71                write!(
72                    f,
73                    "recall value k ({}) must be less than or equal to n ({})",
74                    k, n
75                )
76            }
77            // Match the various argument-is-zero cases.
78            RecallBoundsError::ArgumentIsZero { k, n } => {
79                if *k == 0 && *n == 0 {
80                    write!(f, "recall values k and n must both be non-zero")
81                } else if *k == 0 {
82                    write!(f, "recall values k must be non-zero")
83                } else {
84                    write!(f, "recall values n must be non-zero")
85                }
86            }
87        }
88    }
89}
90
// Opt in to standard error reporting. `RecallBoundsError` already implements `Debug` and
// `Display`, so the default trait methods suffice (e.g. it can be boxed as `dyn Error`).
impl std::error::Error for RecallBoundsError {}
93
94// Allow conversion to `ANNError` for error propagation.
95impl From<RecallBoundsError> for CMDToolError {
96    fn from(err: RecallBoundsError) -> Self {
97        CMDToolError {
98            details: err.to_string(),
99        }
100    }
101}
102
103impl KRecallAtN {
104    /// Construct a new instance of this class.
105    ///
106    /// If the invariant `k <= n` does not hold, than return the error type `KGreaterThanNError`.
107    pub fn new(k: u32, n: u32) -> Result<Self, RecallBoundsError> {
108        if k == 0 || n == 0 {
109            Err(RecallBoundsError::ArgumentIsZero { k, n })
110        } else if k > n {
111            Err(RecallBoundsError::KGreaterThanN { k, n })
112        } else {
113            Ok(KRecallAtN { k, n })
114        }
115    }
116
117    pub fn get_k(self) -> usize {
118        self.k as usize
119    }
120
121    pub fn get_n(self) -> usize {
122        self.n as usize
123    }
124}
125
126/// Calculate the intersection between the top `k` ground truth elements and the top `n`
127/// obtained results.
128#[allow(clippy::too_many_arguments)]
129pub fn calculate_recall<T: Eq + Hash + Copy>(
130    num_queries: usize,
131    ground_truth: &[T],
132    gt_dist: Option<&Vec<f32>>,
133    dim_gt: usize,
134    our_results: &[T],
135    dim_or: u32,
136    recall_bounds: KRecallAtN,
137) -> ANNResult<f64> {
138    let mut total_recall: f64 = 0.0;
139    let (mut gt, mut res): (HashSet<T>, HashSet<T>) = (HashSet::new(), HashSet::new());
140
141    for i in 0..num_queries {
142        gt.clear();
143        res.clear();
144
145        let gt_slice = &ground_truth[dim_gt * i..];
146        let res_slice = &our_results[dim_or as usize * i..];
147        let mut tie_breaker = recall_bounds.get_k();
148
149        if let Some(gt_dist) = gt_dist {
150            let gt_dist_vec = &gt_dist[dim_gt * i..];
151            while tie_breaker < dim_gt
152                && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_bounds.get_k() - 1]
153            {
154                tie_breaker += 1;
155            }
156        }
157
158        (0..tie_breaker).for_each(|idx| {
159            gt.insert(gt_slice[idx]);
160        });
161
162        (0..recall_bounds.get_n()).for_each(|idx| {
163            res.insert(res_slice[idx]);
164        });
165
166        let mut cur_recall: u32 = 0;
167        for v in gt.iter() {
168            if res.contains(v) && cur_recall < recall_bounds.get_k() as u32 {
169                cur_recall += 1;
170            }
171        }
172
173        total_recall += cur_recall as f64;
174    }
175
176    Ok(total_recall / num_queries as f64 * (100.0 / recall_bounds.get_k() as f64))
177}
178
179pub fn calculate_range_search_recall(
180    num_queries: u32,
181    groundtruth: &[Vec<u32>],
182    our_results: &[Vec<u32>],
183) -> ANNResult<f64> {
184    let mut total_recall = 0.0;
185    for i in 0..num_queries as usize {
186        let mut gt: HashSet<u32> = HashSet::new();
187        let mut res: HashSet<u32> = HashSet::new();
188
189        for &item in &groundtruth[i] {
190            gt.insert(item);
191        }
192
193        for &item in &our_results[i] {
194            res.insert(item);
195        }
196
197        let mut cur_recall = 0;
198        for &v in &gt {
199            if res.contains(&v) {
200                cur_recall += 1;
201            }
202        }
203
204        if !gt.is_empty() {
205            total_recall += (100.0 * cur_recall as f64) / gt.len() as f64;
206        } else {
207            total_recall += 100.0;
208        }
209    }
210
211    Ok(total_recall / num_queries as f64)
212}
213
214/// Calculates the filtered search recall for a set of queries.
215///
216/// This function computes the recall percentage for a filtered search scenario, where the recall
217/// is calculated as the percentage of ground truth elements that are present in the retrieved
218/// results, normalized by the maximum of min(k_recall, ground truth size) and retrieved results length.
219///
220/// # Arguments
221///
222/// * `num_queries` - The number of queries for which recall is being calculated.
223/// * `gt_dist` - An optional vector of distances for the ground truth elements.
224/// * `groundtruth` - A slice of vectors, where each vector contains the ground truth IDs for a query.
225/// * `our_results` - A slice of vectors, where each vector contains the retrieved IDs for a query.
226/// * `k_recall` - number of top results to consider from ground_truth for recall calculation. Must be greater than 0.
227///
228/// # Returns
229///
230/// Returns an `ANNResult<f64>` containing the average recall percentage across all queries.
231///
232/// # Assumptions
233///
234/// * The `groundtruth` and `our_results` slices must have the same length as `num_queries`.
235/// * Each vector in 'groundtruth' must be the same length of the corresponding vector in 'gt_dist', if 'gt_dist' is provided.
236///
237/// # Behavior
238///
239/// - If the ground truth for a query is empty, the recall for that query is considered to be 100.0.
240/// - For non-empty ground truth, the recall is calculated as:
241///   `(100.0 * number_of_matches) / max(min(k_recall value, groundtruth size), our_results size)`
242/// - When the vector in gt_dist is provided and k_recall value is less than the size of groundtruth, ties are broken.
243///
244/// # Differences from Range Search Recall
245///
246/// Unlike `calculate_range_search_recall`, this function normalizes the recall by the maximum
247/// of the sizes of the k_recall/groundtruth value and retrieved results, which can lead to different recall
248/// values in cases where the retrieved results contain more elements than the ground truth.
249pub fn calculate_filtered_search_recall(
250    num_queries: usize,
251    gt_dist: Option<&[Vec<f32>]>,
252    groundtruth: &[Vec<u32>],
253    our_results: &[Vec<u32>],
254    k_recall: u32,
255) -> ANNResult<f64> {
256    if k_recall == 0 {
257        return Err(ANNError::log_index_error(format_args!(
258            "k_recall value must be greater than 0, but got {}",
259            k_recall
260        )));
261    }
262
263    if groundtruth.len() != num_queries || our_results.len() != num_queries {
264        return Err(ANNError::log_index_error(format_args!(
265            "groundtruth length ({}) or our_results length ({}) does not match num_queries ({})",
266            groundtruth.len(),
267            our_results.len(),
268            num_queries
269        )));
270    }
271
272    let mut total_recall = 0.0;
273    for i in 0..num_queries {
274        let mut gt: HashSet<u32> = HashSet::new();
275        let mut res: HashSet<u32> = HashSet::new();
276        let gt_cutoff = (k_recall as usize).min(groundtruth[i].len());
277
278        for &item in &groundtruth[i][..gt_cutoff] {
279            //only insert k items from groundtruth
280            gt.insert(item);
281        }
282
283        for &item in &our_results[i] {
284            res.insert(item);
285        }
286
287        if gt_cutoff > 0 {
288            //only break ties when groundtruth is not empty
289            if let Some(gt_dist) = gt_dist {
290                let gt_dist_vec = gt_dist[i].as_slice();
291
292                if gt_dist_vec.len() != groundtruth[i].len() {
293                    return Err(ANNError::log_index_error(format_args!(
294                        "Ground truth distance for query ({}) vector length ({}) is not equal to groundtruth len ({})",
295                        i,
296                        gt_dist_vec.len(),
297                        groundtruth[i].len(),
298                    )));
299                }
300
301                let mut tie_breaker = gt_cutoff;
302
303                while tie_breaker < gt_dist_vec.len() //while there are still ties, add them to gt
304                        && gt_dist_vec[tie_breaker] == gt_dist_vec[gt_cutoff - 1]
305                {
306                    gt.insert(groundtruth[i][tie_breaker]);
307                    tie_breaker += 1;
308                }
309            }
310        }
311
312        let mut cur_recall = 0;
313
314        for &v in &gt {
315            if res.contains(&v) {
316                cur_recall += 1;
317            }
318        }
319
320        if gt_cutoff > 0 {
321            total_recall += (100.0 * cur_recall as f64) / gt_cutoff.max(res.len()) as f64;
322        } else {
323            total_recall += 100.0;
324        }
325    }
326
327    Ok(total_recall / num_queries as f64)
328}
329
330pub fn get_graph_num_frozen_points(
331    storage_provider: &impl StorageReadProvider,
332    graph_file: &str,
333) -> ANNResult<usize> {
334    let mut file = storage_provider.open_reader(graph_file)?;
335    let mut usize_buffer = [0; size_of::<usize>()];
336    let mut u32_buffer = [0; size_of::<u32>()];
337
338    file.read_exact(&mut usize_buffer)?;
339    file.read_exact(&mut u32_buffer)?;
340    file.read_exact(&mut u32_buffer)?;
341    file.read_exact(&mut usize_buffer)?;
342    let file_frozen_pts = usize::from_le_bytes(usize_buffer);
343
344    Ok(file_frozen_pts)
345}
346
347pub fn get_graph_max_observed_degree(
348    storage_provider: &impl StorageReadProvider,
349    graph_file: &str,
350) -> ANNResult<u32> {
351    let mut file = storage_provider.open_reader(graph_file)?;
352    let mut usize_buffer = [0; size_of::<usize>()];
353    let mut u32_buffer = [0; size_of::<u32>()];
354
355    file.read_exact(&mut usize_buffer)?;
356    file.read_exact(&mut u32_buffer)?;
357    let max_observed_degree = u32::from_le_bytes(u32_buffer);
358
359    Ok(max_observed_degree)
360}
361
362pub fn load_truthset(
363    storage_provider: &impl StorageReadProvider,
364    bin_file: &str,
365) -> ANNResult<TruthSet> {
366    let actual_file_size = storage_provider.get_length(bin_file)? as usize;
367    let mut file = storage_provider.open_reader(bin_file)?;
368
369    let metadata = Metadata::read(&mut file)?;
370    let (npts, dim) = metadata.into_dims();
371
372    info!("Metadata: #pts = {npts}, #dims = {dim}... ");
373
374    let expected_file_size_with_dists: usize =
375        2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
376    let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
377
378    // truthset_type: 1 = ids + distances, 2 = ids only
379    let truthset_type : i32 = match actual_file_size {
380        x if x == expected_file_size_with_dists => 1,
381        x if x == expected_file_size_just_ids => 2,
382        _ => return Err(ANNError::log_index_error(format_args!(
383            "Error. File size mismatch. File should have bin format, with npts followed by ngt followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
384            actual_file_size,
385            expected_file_size_with_dists,
386            expected_file_size_just_ids
387        )))
388    };
389
390    let mut ids: Vec<u32> = vec![0; npts * dim];
391    let mut buffer = vec![0; npts * dim * size_of::<u32>()];
392    file.read_exact(&mut buffer)?;
393    ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
394
395    if truthset_type == 1 {
396        let mut dists: Vec<f32> = vec![0.0; npts * dim];
397        let mut buffer = vec![0; npts * dim * size_of::<f32>()];
398        file.read_exact(&mut buffer)?;
399        dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
400
401        return Ok(TruthSet {
402            index_nodes: ids,
403            distances: Some(dists),
404            index_num_points: npts,
405            index_dimension: dim,
406        });
407    }
408
409    Ok(TruthSet {
410        index_nodes: ids,
411        distances: None,
412        index_num_points: npts,
413        index_dimension: dim,
414    })
415}
416
417pub fn load_truthset_with_associated_data<Data: GraphDataType>(
418    storage_provider: &impl StorageReadProvider,
419    bin_file: &str,
420) -> ANNResult<TruthSetWithAssociatedData<Data>> {
421    let mut file = storage_provider.open_reader(bin_file)?;
422
423    let metadata = Metadata::read(&mut file)?;
424    let (npts, dim) = metadata.into_dims();
425
426    info!("Metadata: #pts = {}, #dims = {}...", npts, dim);
427
428    let mut associated_data: Vec<Data::AssociatedDataType> =
429        vec![Data::AssociatedDataType::default(); npts * dim];
430
431    for associated_datum in associated_data.iter_mut().take(npts * dim) {
432        let mut associated_data_buf = vec![0u8; size_of::<Data::AssociatedDataType>()];
433        file.read_exact(&mut associated_data_buf)
434            .map_err(ANNError::log_io_error)?;
435
436        match bincode::deserialize::<Data::AssociatedDataType>(&associated_data_buf) {
437            Ok(datum) => {
438                *associated_datum = datum;
439            }
440            Err(_) => {
441                error!("Error deserializing associated data");
442                return Err(ANNError::log_index_error("Error reading associated data"));
443            }
444        }
445    }
446
447    let mut dists: Vec<f32> = vec![0.0; npts * dim];
448    let mut buffer = vec![0; npts * dim * size_of::<f32>()];
449    file.read_exact(&mut buffer)?;
450    dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
451
452    Ok(TruthSetWithAssociatedData {
453        index_nodes: associated_data,
454        distances: Some(dists),
455        index_num_points: npts,
456        index_dimension: dim,
457    })
458}
459
460// Load the range truthset from the file.
461// The format of the file is as follows:
462// 1. The first 4 bytes are the number of queries.
463// 2. The next 4 bytes are the number of total vector ids.
464// 3. The next (queries * 4) bytes are the numbers of vector ids for each query.
465// 4. The next (total vector ids * 4) bytes are vector ids for each query.
466pub fn load_range_truthset(
467    storage_provider: &impl StorageReadProvider,
468    bin_file: &str,
469) -> ANNResult<RangeSearchTruthSet> {
470    let mut file = storage_provider.open_reader(bin_file)?;
471
472    let metadata = Metadata::read(&mut file)?;
473    let (npts, total_ids) = metadata.into_dims();
474    let mut buffer = [0; size_of::<i32>()];
475
476    info!("Metadata: #pts = {}, #totalIds = {}", npts, total_ids);
477
478    let mut ids: Vec<Vec<u32>> = Vec::new();
479    let mut counts: Vec<u32> = vec![0; npts];
480
481    for count in counts.iter_mut() {
482        file.read_exact(&mut buffer)?;
483        *count = i32::from_le_bytes(buffer) as u32;
484    }
485
486    for &count in &counts {
487        let mut point_ids: Vec<u32> = vec![0; count as usize];
488        let mut buffer = vec![0; count as usize * size_of::<u32>()];
489        file.read_exact(&mut buffer)?;
490        point_ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
491        ids.push(point_ids);
492    }
493
494    Ok(RangeSearchTruthSet {
495        index_nodes: ids,
496        distances: None,
497        index_num_points: npts,
498        index_dimensions: counts,
499    })
500}
501
502// Load the vector filters from the file in the range truthset format.
503pub fn load_vector_filters(
504    storage_provider: &impl StorageReadProvider,
505    bin_file: &str,
506) -> ANNResult<Vec<HashSet<u32>>> {
507    let range_truthset = load_range_truthset(storage_provider, bin_file)?;
508
509    let query_filters: Vec<HashSet<u32>> = range_truthset
510        .index_nodes
511        .into_iter()
512        .map(|filter| filter.into_iter().collect())
513        .collect();
514
515    Ok(query_filters)
516}
517
// Unit tests for the recall helpers, `KRecallAtN`, and `RecallBoundsError`.
#[cfg(test)]
mod test_search_index_utils {
    use super::*;

    /// Expected recall for one `(k, n)` configuration, given per-query hit counts.
    struct ExpectedRecall {
        pub recall_k: usize,
        pub recall_n: usize,
        // Recall for each component.
        pub components: Vec<usize>,
    }

    impl ExpectedRecall {
        /// Builds an expectation, asserting `k <= n` and that no query has more than `k` hits.
        fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
            assert!(recall_k <= recall_n);
            components.iter().for_each(|x| {
                assert!(*x <= recall_k);
            });
            Self {
                recall_k,
                recall_n,
                components,
            }
        }

        /// Average recall in percent: total hits divided by `num_queries * recall_k`.
        fn compute(&self) -> f64 {
            100.0 * (self.components.iter().sum::<usize>() as f64)
                / ((self.components.len() * self.recall_k) as f64)
        }
    }

    #[test]
    fn test_k_recall_at_n_struct() {
        // Happy paths should succeed.
        for k in 1..=10 {
            for n in k..=10 {
                let v = KRecallAtN::new(k, n).unwrap();
                assert_eq!(v.get_k(), k as usize);
                assert_eq!(v.get_n(), n as usize);
            }
        }

        // Error paths.
        // N.B.: Note the inversion of `k` and `n` in the loop bounds!
        for n in 1..=10 {
            for k in (n + 1)..=11 {
                let v = KRecallAtN::new(k, n).unwrap_err();
                match v {
                    RecallBoundsError::KGreaterThanN { k: k_err, n: n_err } => {
                        assert_eq!(k_err, k);
                        assert_eq!(n_err, n);
                    }
                    RecallBoundsError::ArgumentIsZero { .. } => {
                        panic!("unreachable reached");
                    }
                }
                let message = format!("{}", v);
                assert!(message.contains("recall value k"));
                assert!(message.contains("must be less than or equal to n"));
                assert!(message.contains(&format!("{}", k)));
                assert!(message.contains(&format!("{}", n)));
            }
        }

        // both zero
        let v = KRecallAtN::new(0, 0).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values k and n must both be non-zero");

        // k is zero
        let v = KRecallAtN::new(0, 10).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values k must be non-zero");

        // n is zero
        let v = KRecallAtN::new(10, 0).unwrap_err();
        let message = format!("{}", v);
        assert!(message == "recall values n must be non-zero");
    }

    #[test]
    fn test_compute_recall() {
        // Set up examples for ground truth and retrieved IDs.
        let groundtruth_dim = 10;
        let num_queries = 4;

        let groundtruth: Vec<u32> = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0
            5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3
        ];

        assert_eq!(groundtruth.len(), num_queries * groundtruth_dim);

        // Rows 0, 1 and 3 contain runs of equal distances, which exercise tie breaking.
        let distances: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0
            2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3
        ];

        assert_eq!(distances.len(), groundtruth.len());

        // Shift row 0 by one and row 1 by two.
        let results_dim = 6;
        let our_results: Vec<u32> = vec![
            100, 0, 1, 2, 5, 6, // row 0
            100, 101, 7, 8, 9, 10, // row 1
            0, 1, 2, 3, 4, 5, // row 2
            0, 1, 2, 3, 4, 5, // row 3
        ];
        assert_eq!(our_results.len(), num_queries * results_dim);

        // No ties
        let expected_no_ties = vec![
            // Equal `k` and `n`
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
            // Unequal `k` and `n`.
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
            ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
        ];
        let epsilon = 1e-6; // Define a small tolerance
        for (i, expected) in expected_no_ties.iter().enumerate() {
            println!("No Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                None,
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }

        // With Ties
        let expected_with_ties = vec![
            // Equal `k` and `n`
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in
            ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in
            // Unequal `k` and `n`.
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
            ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
        ];

        for (i, expected) in expected_with_ties.iter().enumerate() {
            println!("With Ties: i = {i}");
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                Some(&distances),
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            );
            let left = recall.unwrap();
            let right = expected.compute();
            assert!(
                (left - right).abs() < epsilon,
                "left = {}, right = {}",
                left,
                right
            );
        }
    }

    #[test]
    fn test_calculate_range_search_recall() {
        assert_eq!(
            calculate_range_search_recall(1, &[vec![5, 6],], &[vec![5, 6, 7, 8, 9],]).unwrap(),
            100.0,
            "Returned more results than ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1],]).unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        assert_eq!(
            calculate_range_search_recall(2, &groundtruth, &our_results).unwrap(),
            70.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1, 2, 3, 4],])
                .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0, 1, 2, 3, 4],], &[vec![0, 1, 12, 13, 14],])
                .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_range_search_recall(1, &[vec![0; 0],], &[vec![0, 1, 2, 3, 4],]).unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall() {
        // With k_recall = 1000 the cutoff is the ground-truth size, so the denominator is
        // max(ground truth size, retrieved size).
        let filtered_search_recall =
            calculate_filtered_search_recall(1, None, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]], 1000)
                .unwrap();
        assert_eq!(
            filtered_search_recall, 40.0,
            "Returned more results than ground truth"
        );

        let range_search_recall =
            calculate_range_search_recall(1, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]]).unwrap();
        assert_eq!(
            range_search_recall, 100.0,
            "Returned more results than ground truth"
        );

        assert_ne!(
            filtered_search_recall, range_search_recall,
            "This test case showcases the difference between range and filtered search"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1],],
                1000
            )
            .unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 1000).unwrap(),
            40.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 12, 13, 14],],
                1000
            )
            .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0; 0],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_with_tie_breaking() {
        // Ground truth with distances
        let gt_distances: Vec<Vec<f32>> = vec![
            vec![0.1, 0.2, 0.3, 0.3, 0.3], // Ties at index 2, 3, 4
            vec![0.1, 0.2, 0.3, 0.4, 0.5], // No ties
        ];

        let groundtruth: Vec<Vec<u32>> = vec![
            vec![0, 1, 2, 3, 4], // Ground truth IDs
            vec![5, 6, 7, 8, 9],
        ];

        let our_results: Vec<Vec<u32>> = vec![
            vec![0, 1, 3, 2, 4], // Matches all ground truth
            vec![5, 6, 7, 8, 9], // Matches all ground truth
        ];

        // Test with tie-breaking
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&gt_distances),
                &groundtruth,
                &our_results,
                3 // k_recall
            )
            .unwrap(),
            80.0, //query 0 has 5 matches including ties and 5 returned results => 100% recall
            // query 1 has 3 matches and 5 returned results => 60% recall
            "Tie-breaking should include all tied elements"
        );

        // Test without tie-breaking
        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 3).unwrap(),
            60.0, // both queries have 3 matches and 5 returned results => 60% recall
            "Without tie-breaking, both queries should match on 3 of 5 elements"
        );

        // Test without tie-breaking and large k
        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 10).unwrap(),
            100.0,
            "Without tie-breaking and with large k, both queries should match on all elements"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_empty_ground_truth() {
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&[vec![], vec![]]),
                &[vec![], vec![]],
                &[vec![0, 1, 2], vec![5, 6, 7],],
                1
            )
            .unwrap(),
            100.0,
            "Empty ground truth should result in 100% recall"
        );
    }

    #[test]
    fn test_recall_bounds_error_display() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let message = format!("{}", error);
        assert!(message.contains("recall value k"));
        assert!(message.contains("must be less than or equal to n"));

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k and n must both be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 5 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values k must be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 5, n: 0 };
        let message = format!("{}", error);
        assert_eq!(message, "recall values n must be non-zero");
    }

    #[test]
    fn test_recall_bounds_error_conversion() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let cmd_error: CMDToolError = error.into();
        assert!(!cmd_error.details.is_empty());
    }

    #[test]
    fn test_k_recall_at_n_getters() {
        let recall = KRecallAtN::new(5, 10).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 10);
    }

    #[test]
    fn test_k_recall_at_n_equal_values() {
        let recall = KRecallAtN::new(5, 5).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 5);
    }
}