Skip to main content

diskann_tools/utils/
search_index_utils.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::{collections::HashSet, fmt, hash::Hash, io::Read, mem::size_of};
6
7use bytemuck::cast_slice;
8use diskann::{ANNError, ANNResult};
9use diskann_providers::storage::StorageReadProvider;
10use diskann_utils::io::Metadata;
11use tracing::info;
12
13use crate::utils::CMDToolError;
14
/// A k-NN ground-truth set loaded from a binary truthset file.
///
/// `index_nodes` holds `index_num_points * index_dimension` ids laid out row by row, with
/// `index_dimension` neighbor ids per row. `distances`, when present, has the same layout and
/// holds the distance value stored alongside each id.
#[derive(Debug, Clone)]
pub struct TruthSet {
    /// Neighbor ids, `index_dimension` per row.
    pub index_nodes: Vec<u32>,
    /// Optional distances, laid out exactly like `index_nodes`.
    pub distances: Option<Vec<f32>>,
    /// Number of rows (points) in the set.
    pub index_num_points: usize,
    /// Number of ground-truth neighbors stored per row.
    pub index_dimension: usize,
}
21
/// A range-search ground-truth set: each query has a variable-length list of matching ids.
#[derive(Debug, Clone)]
pub struct RangeSearchTruthSet {
    /// Matching ids for each query; `index_nodes[q]` has `index_dimensions[q]` entries.
    pub index_nodes: Vec<Vec<u32>>,
    /// Optional per-id distances (presumably parallel to `index_nodes`);
    /// `load_range_truthset` always leaves this `None`.
    pub distances: Option<Vec<Vec<f32>>>,
    /// Number of queries.
    pub index_num_points: usize,
    /// Number of ids recorded for each query.
    pub index_dimensions: Vec<u32>,
}
28
/// A struct used to indicate the bounds `k` and `n` for recall computation where:
///
/// * k: Is the number of ground truth neighbors to use.
/// * n: Is the number of retrieved neighbors.
///
/// Recall is measured as the fraction of the `k` ground truth neighbors that are
/// present in the `n` retrieved neighbors.
///
/// We make a deliberate choice that the invariant `k <= n` must hold (don't search for
/// more ground truth neighbors than those actually retrieved).
///
/// Furthermore, both `n` and `k` must be non-zero.
///
/// The constructor [`KRecallAtN::new`] should be used instead of direct construction
/// because it enforces these invariants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KRecallAtN {
    k: u32,
    n: u32,
}
49
/// Error returned by [`KRecallAtN::new`] when the requested recall bounds are invalid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecallBoundsError {
    /// `k` was assigned a higher value than `n`.
    KGreaterThanN { k: u32, n: u32 },
    /// At least one of `k` and `n` is zero; both must be non-zero.
    ///
    /// Both values are captured so the error message can say whether `k`, `n`, or both
    /// were zero.
    ArgumentIsZero { k: u32, n: u32 },
}
59impl fmt::Display for RecallBoundsError {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        match self {
62            RecallBoundsError::KGreaterThanN { k, n } => {
63                write!(
64                    f,
65                    "recall value k ({}) must be less than or equal to n ({})",
66                    k, n
67                )
68            }
69            // Match the various argument-is-zero cases.
70            RecallBoundsError::ArgumentIsZero { k, n } => {
71                if *k == 0 && *n == 0 {
72                    write!(f, "recall values k and n must both be non-zero")
73                } else if *k == 0 {
74                    write!(f, "recall values k must be non-zero")
75                } else {
76                    write!(f, "recall values n must be non-zero")
77                }
78            }
79        }
80    }
81}
82
83// opt-in to error reporting.
84impl std::error::Error for RecallBoundsError {}
85
86// Allow conversion to `ANNError` for error propagation.
87impl From<RecallBoundsError> for CMDToolError {
88    fn from(err: RecallBoundsError) -> Self {
89        CMDToolError {
90            details: err.to_string(),
91        }
92    }
93}
94
95impl KRecallAtN {
96    /// Construct a new instance of this class.
97    ///
98    /// If the invariant `k <= n` does not hold, than return the error type `KGreaterThanNError`.
99    pub fn new(k: u32, n: u32) -> Result<Self, RecallBoundsError> {
100        if k == 0 || n == 0 {
101            Err(RecallBoundsError::ArgumentIsZero { k, n })
102        } else if k > n {
103            Err(RecallBoundsError::KGreaterThanN { k, n })
104        } else {
105            Ok(KRecallAtN { k, n })
106        }
107    }
108
109    pub fn get_k(self) -> usize {
110        self.k as usize
111    }
112
113    pub fn get_n(self) -> usize {
114        self.n as usize
115    }
116}
117
118/// Calculate the intersection between the top `k` ground truth elements and the top `n`
119/// obtained results.
120#[allow(clippy::too_many_arguments)]
121pub fn calculate_recall<T: Eq + Hash + Copy>(
122    num_queries: usize,
123    ground_truth: &[T],
124    gt_dist: Option<&Vec<f32>>,
125    dim_gt: usize,
126    our_results: &[T],
127    dim_or: u32,
128    recall_bounds: KRecallAtN,
129) -> ANNResult<f64> {
130    let mut total_recall: f64 = 0.0;
131    let (mut gt, mut res): (HashSet<T>, HashSet<T>) = (HashSet::new(), HashSet::new());
132
133    for i in 0..num_queries {
134        gt.clear();
135        res.clear();
136
137        let gt_slice = &ground_truth[dim_gt * i..];
138        let res_slice = &our_results[dim_or as usize * i..];
139        let mut tie_breaker = recall_bounds.get_k();
140
141        if let Some(gt_dist) = gt_dist {
142            let gt_dist_vec = &gt_dist[dim_gt * i..];
143            while tie_breaker < dim_gt
144                && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_bounds.get_k() - 1]
145            {
146                tie_breaker += 1;
147            }
148        }
149
150        (0..tie_breaker).for_each(|idx| {
151            gt.insert(gt_slice[idx]);
152        });
153
154        (0..recall_bounds.get_n()).for_each(|idx| {
155            res.insert(res_slice[idx]);
156        });
157
158        let mut cur_recall: u32 = 0;
159        for v in gt.iter() {
160            if res.contains(v) && cur_recall < recall_bounds.get_k() as u32 {
161                cur_recall += 1;
162            }
163        }
164
165        total_recall += cur_recall as f64;
166    }
167
168    Ok(total_recall / num_queries as f64 * (100.0 / recall_bounds.get_k() as f64))
169}
170
171/// Calculates the filtered search recall for a set of queries.
172///
173/// This function computes the recall percentage for a filtered search scenario, where each
174/// query has an associated set of ground truth IDs and a set of retrieved IDs produced by a
175/// filtered search (for example, a search constrained by metadata or predicates).
176///
177/// For each query, the recall is defined as the number of ground truth elements that appear
178/// in the retrieved results divided by:
179///
180/// * `min(k_recall, ground truth size)` if `k_recall` is less than or equal to the size of the
181///   ground truth set, or
182/// * the length of the retrieved results, whichever is larger.
183///
184/// The function returns the average of this per-query recall value, expressed as a percentage.
185///
186/// # Arguments
187///
188/// * `num_queries` - The number of queries for which recall is being calculated.
189/// * `gt_dist` - An optional vector of distances for the ground truth elements.
190/// * `groundtruth` - A slice of vectors, where each vector contains the ground truth IDs for a query.
191/// * `our_results` - A slice of vectors, where each vector contains the retrieved IDs for a query.
192/// * `k_recall` - Number of top results to consider from `groundtruth` for recall calculation. Must be greater than 0.
193///
194/// # Returns
195///
196/// Returns an `ANNResult<f64>` containing the average filtered-search recall percentage across all queries.
197///
198/// # Assumptions
199///
200/// * The `groundtruth` and `our_results` slices must have the same length as `num_queries`.
201/// * Each vector in `groundtruth` must be the same length as the corresponding vector in `gt_dist`,
202///   if `gt_dist` is provided.
203///
204/// # Behavior
205///
206/// - If the ground truth for a query is empty, the recall for that query is considered to be 100.0.
207/// - For non-empty ground truth, the recall is calculated as:
208///   `(100.0 * number_of_matches) / max(min(k_recall value, groundtruth size), our_results size)`
209/// - When the vector in gt_dist is provided and k_recall value is less than the size of groundtruth, ties are broken.
210///
211/// # Differences from Range Search Recall
212///
213/// Unlike `calculate_range_search_recall`, this function normalizes the recall by the maximum
214/// of the sizes of the k_recall/groundtruth value and retrieved results, which can lead to different recall
215/// values in cases where the retrieved results contain more elements than the ground truth.
216pub fn calculate_filtered_search_recall(
217    num_queries: usize,
218    gt_dist: Option<&[Vec<f32>]>,
219    groundtruth: &[Vec<u32>],
220    our_results: &[Vec<u32>],
221    k_recall: u32,
222) -> ANNResult<f64> {
223    if k_recall == 0 {
224        return Err(ANNError::log_index_error(format_args!(
225            "k_recall value must be greater than 0, but got {}",
226            k_recall
227        )));
228    }
229
230    if groundtruth.len() != num_queries || our_results.len() != num_queries {
231        return Err(ANNError::log_index_error(format_args!(
232            "groundtruth length ({}) or our_results length ({}) does not match num_queries ({})",
233            groundtruth.len(),
234            our_results.len(),
235            num_queries
236        )));
237    }
238
239    let mut total_recall = 0.0;
240    for i in 0..num_queries {
241        let mut gt: HashSet<u32> = HashSet::new();
242        let mut res: HashSet<u32> = HashSet::new();
243        let gt_cutoff = (k_recall as usize).min(groundtruth[i].len());
244
245        for &item in &groundtruth[i][..gt_cutoff] {
246            //only insert k items from groundtruth
247            gt.insert(item);
248        }
249
250        for &item in &our_results[i] {
251            res.insert(item);
252        }
253
254        if gt_cutoff > 0 {
255            //only break ties when groundtruth is not empty
256            if let Some(gt_dist) = gt_dist {
257                let gt_dist_vec = gt_dist[i].as_slice();
258
259                if gt_dist_vec.len() != groundtruth[i].len() {
260                    return Err(ANNError::log_index_error(format_args!(
261                        "Ground truth distance for query ({}) vector length ({}) is not equal to groundtruth len ({})",
262                        i,
263                        gt_dist_vec.len(),
264                        groundtruth[i].len(),
265                    )));
266                }
267
268                let mut tie_breaker = gt_cutoff;
269
270                while tie_breaker < gt_dist_vec.len() //while there are still ties, add them to gt
271                        && gt_dist_vec[tie_breaker] == gt_dist_vec[gt_cutoff - 1]
272                {
273                    gt.insert(groundtruth[i][tie_breaker]);
274                    tie_breaker += 1;
275                }
276            }
277        }
278
279        let mut cur_recall = 0;
280
281        for &v in &gt {
282            if res.contains(&v) {
283                cur_recall += 1;
284            }
285        }
286
287        if gt_cutoff > 0 {
288            total_recall += (100.0 * cur_recall as f64) / gt_cutoff.max(res.len()) as f64;
289        } else {
290            total_recall += 100.0;
291        }
292    }
293
294    Ok(total_recall / num_queries as f64)
295}
296
297pub fn load_truthset(
298    storage_provider: &impl StorageReadProvider,
299    bin_file: &str,
300) -> ANNResult<TruthSet> {
301    let actual_file_size = storage_provider.get_length(bin_file)? as usize;
302    let mut file = storage_provider.open_reader(bin_file)?;
303
304    let metadata = Metadata::read(&mut file)?;
305    let (npts, dim) = metadata.into_dims();
306
307    info!("Metadata: #pts = {npts}, #dims = {dim}... ");
308
309    let expected_file_size_with_dists: usize =
310        2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
311    let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
312
313    // truthset_type: 1 = ids + distances, 2 = ids only
314    let truthset_type : i32 = match actual_file_size {
315        x if x == expected_file_size_with_dists => 1,
316        x if x == expected_file_size_just_ids => 2,
317        _ => return Err(ANNError::log_index_error(format_args!(
318            "Error. File size mismatch. File should have bin format, with npts followed by ngt followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
319            actual_file_size,
320            expected_file_size_with_dists,
321            expected_file_size_just_ids
322        )))
323    };
324
325    let mut ids: Vec<u32> = vec![0; npts * dim];
326    let mut buffer = vec![0; npts * dim * size_of::<u32>()];
327    file.read_exact(&mut buffer)?;
328    ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
329
330    if truthset_type == 1 {
331        let mut dists: Vec<f32> = vec![0.0; npts * dim];
332        let mut buffer = vec![0; npts * dim * size_of::<f32>()];
333        file.read_exact(&mut buffer)?;
334        dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
335
336        return Ok(TruthSet {
337            index_nodes: ids,
338            distances: Some(dists),
339            index_num_points: npts,
340            index_dimension: dim,
341        });
342    }
343
344    Ok(TruthSet {
345        index_nodes: ids,
346        distances: None,
347        index_num_points: npts,
348        index_dimension: dim,
349    })
350}
351
352// Load the range truthset from the file.
353// The format of the file is as follows:
354// 1. The first 4 bytes are the number of queries.
355// 2. The next 4 bytes are the number of total vector ids.
356// 3. The next (queries * 4) bytes are the numbers of vector ids for each query.
357// 4. The next (total vector ids * 4) bytes are vector ids for each query.
358pub fn load_range_truthset(
359    storage_provider: &impl StorageReadProvider,
360    bin_file: &str,
361) -> ANNResult<RangeSearchTruthSet> {
362    let mut file = storage_provider.open_reader(bin_file)?;
363
364    let metadata = Metadata::read(&mut file)?;
365    let (npts, total_ids) = metadata.into_dims();
366    let mut buffer = [0; size_of::<i32>()];
367
368    info!("Metadata: #pts = {}, #totalIds = {}", npts, total_ids);
369
370    let mut ids: Vec<Vec<u32>> = Vec::new();
371    let mut counts: Vec<u32> = vec![0; npts];
372
373    for count in counts.iter_mut() {
374        file.read_exact(&mut buffer)?;
375        *count = i32::from_le_bytes(buffer) as u32;
376    }
377
378    for &count in &counts {
379        let mut point_ids: Vec<u32> = vec![0; count as usize];
380        let mut buffer = vec![0; count as usize * size_of::<u32>()];
381        file.read_exact(&mut buffer)?;
382        point_ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
383        ids.push(point_ids);
384    }
385
386    Ok(RangeSearchTruthSet {
387        index_nodes: ids,
388        distances: None,
389        index_num_points: npts,
390        index_dimensions: counts,
391    })
392}
393
394// Load the vector filters from the file in the range truthset format.
395pub fn load_vector_filters(
396    storage_provider: &impl StorageReadProvider,
397    bin_file: &str,
398) -> ANNResult<Vec<HashSet<u32>>> {
399    let range_truthset = load_range_truthset(storage_provider, bin_file)?;
400
401    let query_filters: Vec<HashSet<u32>> = range_truthset
402        .index_nodes
403        .into_iter()
404        .map(|filter| filter.into_iter().collect())
405        .collect();
406
407    Ok(query_filters)
408}
409
#[cfg(test)]
mod test_search_index_utils {
    //! Unit tests for recall bounds, recall computation and filtered-search recall.
    use super::*;

    /// Expected recall for one `(k, n)` pair, expressed as the per-query hit counts.
    struct ExpectedRecall {
        recall_k: usize,
        recall_n: usize,
        // Number of ground-truth hits for each query (each at most `recall_k`).
        components: Vec<usize>,
    }

    impl ExpectedRecall {
        fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
            assert!(recall_k <= recall_n);
            assert!(
                components.iter().all(|&x| x <= recall_k),
                "every component must be at most recall_k ({})",
                recall_k
            );
            Self {
                recall_k,
                recall_n,
                components,
            }
        }

        /// Recall percentage implied by `components`.
        fn compute(&self) -> f64 {
            100.0 * (self.components.iter().sum::<usize>() as f64)
                / ((self.components.len() * self.recall_k) as f64)
        }
    }

    #[test]
    fn test_k_recall_at_n_struct() {
        // Happy paths should succeed.
        for k in 1..=10 {
            for n in k..=10 {
                let v = KRecallAtN::new(k, n).unwrap();
                assert_eq!(v.get_k(), k as usize);
                assert_eq!(v.get_n(), n as usize);
            }
        }

        // Error paths.
        // N.B.: Note the inversion of `k` and `n` in the loop bounds!
        for n in 1..=10 {
            for k in (n + 1)..=11 {
                let v = KRecallAtN::new(k, n).unwrap_err();
                match v {
                    RecallBoundsError::KGreaterThanN { k: k_err, n: n_err } => {
                        assert_eq!(k_err, k);
                        assert_eq!(n_err, n);
                    }
                    RecallBoundsError::ArgumentIsZero { .. } => {
                        unreachable!("k > n with non-zero arguments must yield KGreaterThanN");
                    }
                }
                let message = v.to_string();
                assert!(message.contains("recall value k"));
                assert!(message.contains("must be less than or equal to n"));
                assert!(message.contains(&k.to_string()));
                assert!(message.contains(&n.to_string()));
            }
        }

        // both zero
        let v = KRecallAtN::new(0, 0).unwrap_err();
        assert_eq!(v.to_string(), "recall values k and n must both be non-zero");

        // k is zero
        let v = KRecallAtN::new(0, 10).unwrap_err();
        assert_eq!(v.to_string(), "recall values k must be non-zero");

        // n is zero
        let v = KRecallAtN::new(10, 0).unwrap_err();
        assert_eq!(v.to_string(), "recall values n must be non-zero");
    }

    #[test]
    fn test_compute_recall() {
        // Set up examples for ground truth and retrieved IDs.
        let groundtruth_dim = 10;
        let num_queries = 4;

        let groundtruth: Vec<u32> = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0
            5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3
        ];

        assert_eq!(groundtruth.len(), num_queries * groundtruth_dim);

        let distances: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0
            2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2
            0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3
        ];

        assert_eq!(distances.len(), groundtruth.len());

        // Shift row 0 by one and row 1 by two.
        let results_dim = 6;
        let our_results: Vec<u32> = vec![
            100, 0, 1, 2, 5, 6, // row 0
            100, 101, 7, 8, 9, 10, // row 1
            0, 1, 2, 3, 4, 5, // row 2
            0, 1, 2, 3, 4, 5, // row 3
        ];
        assert_eq!(our_results.len(), num_queries * results_dim);

        // No ties
        let expected_no_ties = vec![
            // Equal `k` and `n`
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
            ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
            // Unequal `k` and `n`.
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
            ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
        ];
        let epsilon = 1e-6; // Tolerance for floating-point comparison.
        for (i, expected) in expected_no_ties.iter().enumerate() {
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                None,
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            )
            .unwrap();
            let want = expected.compute();
            assert!(
                (recall - want).abs() < epsilon,
                "no ties, case {}: left = {}, right = {}",
                i,
                recall,
                want
            );
        }

        // With Ties
        let expected_with_ties = vec![
            // Equal `k` and `n`
            ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
            ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
            ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
            ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
            ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in
            ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in
            // Unequal `k` and `n`.
            ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
            ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
            ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
            ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
        ];

        for (i, expected) in expected_with_ties.iter().enumerate() {
            assert_eq!(expected.components.len(), num_queries);
            let recall = calculate_recall(
                num_queries,
                &groundtruth,
                Some(&distances),
                groundtruth_dim,
                &our_results,
                results_dim as u32,
                KRecallAtN::new(expected.recall_k as u32, expected.recall_n as u32).unwrap(),
            )
            .unwrap();
            let want = expected.compute();
            assert!(
                (recall - want).abs() < epsilon,
                "with ties, case {}: left = {}, right = {}",
                i,
                recall,
                want
            );
        }
    }

    #[test]
    fn test_calculate_filtered_search_recall() {
        let filtered_search_recall =
            calculate_filtered_search_recall(1, None, &[vec![5, 6]], &[vec![5, 6, 7, 8, 9]], 1000)
                .unwrap();
        assert_eq!(
            filtered_search_recall, 40.0,
            "Returned more results than ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1],],
                1000
            )
            .unwrap(),
            40.0,
            "Returned less results than ground truth"
        );

        let groundtruth: Vec<Vec<u32>> = vec![vec![0, 1, 2, 3, 4], vec![5, 6]];

        let our_results: Vec<Vec<u32>> = vec![vec![0, 1], vec![5, 6, 7, 8, 9]];

        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 1000).unwrap(),
            40.0,
            "Combination of both cases"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The result matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0, 1, 2, 3, 4],],
                &[vec![0, 1, 12, 13, 14],],
                1000
            )
            .unwrap(),
            40.0,
            "The result partially matched the ground truth"
        );

        assert_eq!(
            calculate_filtered_search_recall(
                1,
                None,
                &[vec![0; 0],],
                &[vec![0, 1, 2, 3, 4],],
                1000
            )
            .unwrap(),
            100.0,
            "The empty ground truth"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_with_tie_breaking() {
        // Ground truth with distances
        let gt_distances: Vec<Vec<f32>> = vec![
            vec![0.1, 0.2, 0.3, 0.3, 0.3], // Ties at index 2, 3, 4
            vec![0.1, 0.2, 0.3, 0.4, 0.5], // No ties
        ];

        let groundtruth: Vec<Vec<u32>> = vec![
            vec![0, 1, 2, 3, 4], // Ground truth IDs
            vec![5, 6, 7, 8, 9],
        ];

        let our_results: Vec<Vec<u32>> = vec![
            vec![0, 1, 3, 2, 4], // Matches all ground truth
            vec![5, 6, 7, 8, 9], // Matches all ground truth
        ];

        // Test with tie-breaking.
        // Query 0 has 5 matches including ties and 5 returned results => 100% recall;
        // query 1 has 3 matches and 5 returned results => 60% recall.
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&gt_distances),
                &groundtruth,
                &our_results,
                3 // k_recall
            )
            .unwrap(),
            80.0,
            "Tie-breaking should include all tied elements"
        );

        // Test without tie-breaking: both queries have 3 matches and 5 returned results.
        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 3).unwrap(),
            60.0,
            "Without tie-breaking, both queries should match on 3 of 5 elements"
        );

        // Test without tie-breaking and large k
        assert_eq!(
            calculate_filtered_search_recall(2, None, &groundtruth, &our_results, 10).unwrap(),
            100.0,
            "Without tie-breaking and with large k, both queries should match on all elements"
        );
    }

    #[test]
    fn test_calculate_filtered_search_recall_empty_ground_truth() {
        assert_eq!(
            calculate_filtered_search_recall(
                2,
                Some(&[vec![], vec![]]),
                &[vec![], vec![]],
                &[vec![0, 1, 2], vec![5, 6, 7],],
                1
            )
            .unwrap(),
            100.0,
            "Empty ground truth should result in 100% recall"
        );
    }

    #[test]
    fn test_recall_bounds_error_display() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let message = error.to_string();
        assert!(message.contains("recall value k"));
        assert!(message.contains("must be less than or equal to n"));

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 0 };
        assert_eq!(
            error.to_string(),
            "recall values k and n must both be non-zero"
        );

        let error = RecallBoundsError::ArgumentIsZero { k: 0, n: 5 };
        assert_eq!(error.to_string(), "recall values k must be non-zero");

        let error = RecallBoundsError::ArgumentIsZero { k: 5, n: 0 };
        assert_eq!(error.to_string(), "recall values n must be non-zero");
    }

    #[test]
    fn test_recall_bounds_error_conversion() {
        let error = RecallBoundsError::KGreaterThanN { k: 10, n: 5 };
        let cmd_error: CMDToolError = error.into();
        assert!(!cmd_error.details.is_empty());
    }

    #[test]
    fn test_k_recall_at_n_getters() {
        let recall = KRecallAtN::new(5, 10).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 10);
    }

    #[test]
    fn test_k_recall_at_n_equal_values() {
        let recall = KRecallAtN::new(5, 5).unwrap();
        assert_eq!(recall.get_k(), 5);
        assert_eq!(recall.get_n(), 5);
    }
}