Skip to main content

diskann_benchmark_core/search/
ids.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_utils::views::{self, Matrix};
7
8use crate::recall;
9
10/// A generic wrapper for containing the results IDs for multiple query searches.
11///
12/// Users should interact with this type via the [`as_rows`](Self::as_rows) methods.
13///
14/// Note that the slices returned by [`as_rows`](Self::as_rows) may have different lengths
15/// depending on how many IDs were actually returned for each query.
16#[derive(Debug)]
17pub struct ResultIds<I> {
18    inner: ResultIdsInner<I>,
19}
20
21impl<I> ResultIds<I> {
22    /// Return the contained IDs as a [`recall::Rows<I>`].
23    pub fn as_rows(&self) -> &dyn recall::Rows<I> {
24        self.inner.as_rows()
25    }
26
27    pub(crate) fn new(inner: ResultIdsInner<I>) -> Self {
28        Self { inner }
29    }
30}
31
32/// A [`recall::Rows<I>`] implementation that is more efficient when the number of returned
33/// IDs is known to be bounded by a fixed size and thus can be stored in a matrix.
34///
35/// The number of valid IDs per row is allowed to be less than this upper bound and is tracked
36/// separately.
37#[derive(Debug)]
38pub(crate) struct Bounded<I> {
39    ids: Matrix<I>,
40    // Must have the same length as `matrix.nrows()`.
41    lengths: Vec<usize>,
42}
43
44impl<I> Bounded<I> {
45    /// Create a new `Bounded` instance with the given `ids` matrix and `lengths` vector.
46    ///
47    /// Argument `lengths` must have the same length as the number of rows in `ids` and the
48    /// value of each entry must be less than or equal to the number of columns in `ids`.
49    ///
50    /// Length values that exceed the number of columns will silently be clamped when accessing rows.
51    ///
52    /// # Panics
53    ///
54    /// Panics if the number of rows in `ids` does not match the length of `lengths`.
55    pub(crate) fn new(ids: Matrix<I>, lengths: Vec<usize>) -> Self {
56        assert_eq!(
57            ids.nrows(),
58            lengths.len(),
59            "an internal invariant was not upheld",
60        );
61        Self { ids, lengths }
62    }
63
64    /// Return the number of rows stored in `self`.
65    pub(crate) fn len(&self) -> usize {
66        self.lengths.len()
67    }
68
69    /// Return an iterator over the valid ID slices. The length of the iterator will be equal to
70    /// [`Bounded::len`].
71    ///
72    /// Note that the yielded slices are not guaranteed to have the same length.
73    pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = &[I]> {
74        std::iter::zip(self.ids.row_iter(), self.lengths.iter()).map(|(row, len)| {
75            match row.get(..*len) {
76                Some(v) => v,
77                None => row,
78            }
79        })
80    }
81}
82
83impl<I> recall::Rows<I> for Bounded<I> {
84    fn nrows(&self) -> usize {
85        self.len()
86    }
87    fn row(&self, index: usize) -> &[I] {
88        let length = self.lengths[index];
89        let row = self.ids.row(index);
90        match row.get(..length) {
91            Some(v) => v,
92            None => row,
93        }
94    }
95    fn ncols(&self) -> Option<usize> {
96        None
97    }
98}
99
100///////////
101// Inner //
102///////////
103
104/// We internally have two representations for result IDs: either a bounded size
105/// container (to reduce the number of heap allocations) or a dynamic vector of vectors.
106///
107/// The former is used when the number of IDs is known to be bounded.
108#[derive(Debug)]
109pub(crate) enum ResultIdsInner<I> {
110    Fixed(Bounded<I>),
111    Dynamic(Vec<Vec<I>>),
112}
113
114impl<I> ResultIdsInner<I> {
115    pub(crate) fn as_rows(&self) -> &dyn recall::Rows<I> {
116        match self {
117            Self::Fixed(bounded) => bounded,
118            Self::Dynamic(ids) => ids,
119        }
120    }
121}
122
123/// A utility tool for aggregating multiple [`ResultIdsInner<I>`] instances into a single
124/// [`ResultIds<I>`]. When aggregating, if all instances are [`ResultIdsInner::Fixed`]
125/// with the same upper bound, then the aggregation will also be [`ResultIdsInner::Fixed`].
126///
127/// Otherwise, the aggregation will be [`ResultIdsInner::Dynamic`].
128#[derive(Debug, Default)]
129pub(crate) enum IdAggregator<I> {
130    /// No ids have been aggregated.
131    #[default]
132    Empty,
133    /// IDs have been aggregated and all of them are bounded with the same size
134    /// stored in `num_ids`. The field `len` stores the total number of rows aggregated
135    /// to help with the final allocation in [`IdAggergator::finish`].
136    Fixed {
137        matrices: Vec<Bounded<I>>,
138        len: usize,
139        num_ids: usize,
140    },
141    /// At least one aggregated IDs instance was dynamic.
142    Dynamic(Vec<ResultIdsInner<I>>),
143}
144
145impl<I> IdAggregator<I>
146where
147    I: Clone + Default,
148{
149    /// Construct a new empty [`IdAggregator`].
150    pub(crate) fn new() -> Self {
151        Self::Empty
152    }
153
154    /// Push `ids` into the aggregator.
155    pub(crate) fn push(&mut self, ids: ResultIdsInner<I>) {
156        // The general logic is as follows:
157        // - If we are empty, we just take the incoming IDs. If they are bounded, we optimistically assume
158        //   that future pushes will also be bounded with the same size. Otherwise, we'll always be `Dynamic`.
159        //
160        // - If we are `Fixed`, we check if the incoming IDs are also bounded with the same size and if so,
161        //   simply append them to the internal list. If not, we convert all previously stored bounded IDs
162        //   into dynamic IDs and switch to `Dynamic` mode.
163        //
164        // - If we are `Dynamic`, we simply append the incoming IDs.
165        //
166        // Possible transitions:
167        // * Empty -> Fixed
168        // * Empty -> Dynamic
169        // * Fixed -> Dynamic
170
171        *self = match std::mem::take(self) {
172            Self::Empty => match ids {
173                ResultIdsInner::Fixed(bounded) => {
174                    let len = bounded.ids.nrows();
175                    let num_ids = bounded.ids.ncols();
176                    Self::Fixed {
177                        matrices: vec![bounded],
178                        len,
179                        num_ids,
180                    }
181                }
182                ResultIdsInner::Dynamic(ids) => Self::Dynamic(vec![ResultIdsInner::Dynamic(ids)]),
183            },
184            Self::Fixed {
185                mut matrices,
186                len,
187                num_ids,
188            } => match ids {
189                ResultIdsInner::Fixed(bounded) => {
190                    if bounded.ids.ncols() == num_ids {
191                        let len = len + bounded.len();
192                        matrices.push(bounded);
193                        Self::Fixed {
194                            matrices,
195                            len,
196                            num_ids,
197                        }
198                    } else {
199                        let mut dynamic: Vec<_> =
200                            matrices.into_iter().map(ResultIdsInner::Fixed).collect();
201                        dynamic.push(ResultIdsInner::Fixed(bounded));
202                        Self::Dynamic(dynamic)
203                    }
204                }
205                ResultIdsInner::Dynamic(ids) => {
206                    let mut dynamic: Vec<_> =
207                        matrices.into_iter().map(ResultIdsInner::Fixed).collect();
208                    dynamic.push(ResultIdsInner::Dynamic(ids));
209                    Self::Dynamic(dynamic)
210                }
211            },
212            Self::Dynamic(mut dynamic) => {
213                dynamic.push(ids);
214                Self::Dynamic(dynamic)
215            }
216        };
217    }
218
219    /// Consume `self`, producing a single [`ResultIds<I>`] containing the concatenation
220    /// of all pushed IDs.
221    pub(crate) fn finish(self) -> ResultIds<I> {
222        // The internal logic is as follows:
223        // * If we are empty, we return an empty dynamic IDs container.
224        // * If we are fixed, we allocate a new **single** matrix and copy all IDs into it.
225        // * If we are dynamic, we concatenate all dynamic ID vectors into a single dynamic container.
226
227        match self {
228            Self::Empty => ResultIds::new(ResultIdsInner::Dynamic(Vec::new())),
229            Self::Fixed {
230                matrices,
231                len,
232                num_ids,
233            } => {
234                let mut dst = Matrix::new(views::Init(|| I::default()), len, num_ids);
235                let mut lengths = Vec::with_capacity(len);
236
237                let mut output_row = 0;
238                for bounded in matrices {
239                    for row in bounded.ids.row_iter() {
240                        dst.row_mut(output_row).clone_from_slice(row);
241                        output_row += 1;
242                    }
243                    lengths.extend_from_slice(&bounded.lengths);
244                }
245
246                ResultIds::new(ResultIdsInner::Fixed(Bounded::new(dst, lengths)))
247            }
248            Self::Dynamic(all) => {
249                let mut dst = Vec::<Vec<I>>::new();
250                for ids in all {
251                    match ids {
252                        ResultIdsInner::Fixed(bounded) => {
253                            bounded.iter().for_each(|row| dst.push(row.into()));
254                        }
255                        ResultIdsInner::Dynamic(dynamic) => {
256                            dynamic.into_iter().for_each(|i| dst.push(i));
257                        }
258                    }
259                }
260
261                ResultIds::new(ResultIdsInner::Dynamic(dst))
262            }
263        }
264    }
265}
266
267///////////
268// Tests //
269///////////
270
#[cfg(test)]
mod tests {
    use super::*;

    use crate::recall::Rows;

    // Helper to create a Bounded instance for testing. The matrix width is the longest row;
    // each row's valid length is the length of its input vector.
    fn make_bounded(data: Vec<Vec<u32>>) -> Bounded<u32> {
        let nrows = data.len();
        let ncols = data.iter().map(Vec::len).max().unwrap_or(0);

        let mut matrix = Matrix::new(0u32, nrows, ncols);
        let mut lengths = Vec::with_capacity(nrows);

        for (row, row_data) in std::iter::zip(matrix.row_iter_mut(), data.iter()) {
            for (dst, src) in std::iter::zip(row.iter_mut(), row_data.iter()) {
                *dst = *src;
            }
            lengths.push(row_data.len());
        }

        Bounded::new(matrix, lengths)
    }

    #[test]
    fn test_bounded_new_valid() {
        let matrix = Matrix::new(0u32, 3, 5);
        let lengths = vec![2, 3, 1];
        let bounded = Bounded::new(matrix, lengths);

        assert_eq!(bounded.len(), 3);
    }

    #[test]
    fn test_bounded_length_clamping() {
        let matrix = Matrix::new(0u32, 3, 3);
        let lengths = vec![2, 3, 5]; // Last length exceeds number of columns
        let bounded = Bounded::new(matrix, lengths);

        assert_eq!(bounded.len(), 3);
        assert_eq!(bounded.row(0), &[0, 0]);
        assert_eq!(bounded.row(1), &[0, 0, 0]);
        assert_eq!(bounded.row(2), &[0, 0, 0]); // Clamped to 3 columns

        let rows: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(rows[0], &[0, 0]);
        assert_eq!(rows[1], &[0, 0, 0]);
        assert_eq!(rows[2], &[0, 0, 0]); // Clamped to 3 columns
    }

    #[test]
    #[should_panic(expected = "an internal invariant was not upheld")]
    fn test_bounded_new_mismatched_lengths() {
        let matrix = Matrix::new(0u32, 3, 5);
        let lengths = vec![2, 3]; // Only 2 lengths for 3 rows
        Bounded::new(matrix, lengths);
    }

    #[test]
    fn test_bounded() {
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4, 5], vec![6]]);
        assert_eq!(bounded.len(), 3);

        // `Rows`
        assert_eq!(bounded.nrows(), 3);
        assert_eq!(bounded.row(0), &[1, 2]);
        assert_eq!(bounded.row(1), &[3, 4, 5]);
        assert_eq!(bounded.row(2), &[6]);
        assert_eq!(bounded.ncols(), None);

        // Iterator
        let rows: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], &[1, 2]);
        assert_eq!(rows[1], &[3, 4, 5]);
        assert_eq!(rows[2], &[6]);
    }

    #[test]
    fn test_result_ids_inner_fixed() {
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4, 5]]);
        let inner = ResultIdsInner::Fixed(bounded);

        let rows = inner.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4, 5]);
    }

    #[test]
    fn test_result_ids_inner_dynamic() {
        let vecs = vec![vec![1, 2, 3], vec![4], vec![5, 6]];
        let inner = ResultIdsInner::Dynamic(vecs);

        let rows = inner.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
        assert_eq!(rows.row(2), &[5, 6]);
    }

    #[test]
    fn test_result_ids_wrapper() {
        let bounded = make_bounded(vec![vec![10], vec![20, 30]]);
        let result = ResultIds::new(ResultIdsInner::Fixed(bounded));

        let rows = result.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[10]);
        assert_eq!(rows.row(1), &[20, 30]);
    }

    // IdAggregator Tests

    #[test]
    fn test_aggregator_empty_finish() {
        let aggregator = IdAggregator::<u32>::new();
        let result = aggregator.finish();

        let rows = result.as_rows();
        assert_eq!(rows.nrows(), 0);
        assert_eq!(rows.ncols(), None);
    }

    #[test]
    fn test_aggregator_empty_to_fixed() {
        let mut aggregator = IdAggregator::new();

        let bounded = make_bounded(vec![vec![1, 2], vec![3], vec![4, 5]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        // Should be in Fixed state
        match aggregator {
            IdAggregator::Fixed { len, num_ids, .. } => {
                assert_eq!(len, 3);
                assert_eq!(num_ids, 2);
            }
            _ => panic!("Expected Fixed state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3]);
        assert_eq!(rows.row(2), &[4, 5]);
    }

    #[test]
    fn test_aggregator_empty_to_dynamic() {
        let mut aggregator = IdAggregator::new();

        let vecs = vec![vec![1, 2, 3], vec![4]];
        aggregator.push(ResultIdsInner::Dynamic(vecs));

        // Should be in Dynamic state
        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 1);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
    }

    #[test]
    fn test_aggregator_fixed_stays_fixed_same_size() {
        let mut aggregator = IdAggregator::new();

        // Push first bounded with 3 columns
        let bounded1 = make_bounded(vec![vec![1, 2, 3], vec![4, 5]]);
        aggregator.push(ResultIdsInner::Fixed(bounded1));

        // Push second bounded with 3 columns
        let bounded2 = make_bounded(vec![vec![6, 7, 8]]);
        aggregator.push(ResultIdsInner::Fixed(bounded2));

        // Should still be in Fixed state with accumulated length
        match &aggregator {
            IdAggregator::Fixed {
                len,
                num_ids,
                matrices,
            } => {
                assert_eq!(*len, 3); // 2 + 1 rows
                assert_eq!(*num_ids, 3);
                assert_eq!(matrices.len(), 2);
            }
            _ => panic!("Expected Fixed state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4, 5]);
        assert_eq!(rows.row(2), &[6, 7, 8]);
    }

    #[test]
    fn test_aggregator_fixed_to_dynamic_different_sizes() {
        let mut aggregator = IdAggregator::new();

        // Push first bounded with 2 columns
        let bounded1 = make_bounded(vec![vec![1, 2], vec![3, 4]]);
        aggregator.push(ResultIdsInner::Fixed(bounded1));

        // Push second bounded with 3 columns (different size)
        let bounded2 = make_bounded(vec![vec![5, 6, 7]]);
        aggregator.push(ResultIdsInner::Fixed(bounded2));

        // Should transition to Dynamic
        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 2);
            }
            _ => panic!("Expected Dynamic state after size mismatch"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4]);
        assert_eq!(rows.row(2), &[5, 6, 7]);
    }

    #[test]
    fn test_aggregator_fixed_to_dynamic_incoming_dynamic() {
        let mut aggregator = IdAggregator::new();

        // Start with Fixed
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        // Push dynamic
        let vecs = vec![vec![5, 6, 7]];
        aggregator.push(ResultIdsInner::Dynamic(vecs));

        // Should transition to Dynamic
        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 2);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4]);
        assert_eq!(rows.row(2), &[5, 6, 7]);
    }

    #[test]
    fn test_aggregator_dynamic_stays_dynamic() {
        let mut aggregator = IdAggregator::new();

        // Start with Dynamic
        let vecs1 = vec![vec![1, 2]];
        aggregator.push(ResultIdsInner::Dynamic(vecs1));

        // Push more dynamic
        let vecs2 = vec![vec![3, 4, 5]];
        aggregator.push(ResultIdsInner::Dynamic(vecs2));

        // Push bounded
        let bounded = make_bounded(vec![vec![6, 7]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        // Should remain Dynamic
        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 3);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4, 5]);
        assert_eq!(rows.row(2), &[6, 7]);
    }

    #[test]
    fn test_aggregator_finish_fixed_single_matrix() {
        let mut aggregator = IdAggregator::new();

        let bounded1 = make_bounded(vec![vec![1, 2], vec![3, 4]]);
        aggregator.push(ResultIdsInner::Fixed(bounded1));

        let bounded2 = make_bounded(vec![vec![5, 6], vec![7, 8]]);
        aggregator.push(ResultIdsInner::Fixed(bounded2));

        let result = aggregator.finish();
        let rows = result.as_rows();

        assert_eq!(rows.nrows(), 4);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4]);
        assert_eq!(rows.row(2), &[5, 6]);
        assert_eq!(rows.row(3), &[7, 8]);
    }

    #[test]
    fn test_aggregator_finish_dynamic() {
        let mut aggregator = IdAggregator::new();

        let vecs1 = vec![vec![1, 2, 3], vec![4]];
        aggregator.push(ResultIdsInner::Dynamic(vecs1));

        let bounded = make_bounded(vec![vec![5, 6]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        let vecs2 = vec![vec![7, 8, 9, 10]];
        aggregator.push(ResultIdsInner::Dynamic(vecs2));

        let result = aggregator.finish();
        let rows = result.as_rows();

        assert_eq!(rows.nrows(), 4);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
        assert_eq!(rows.row(2), &[5, 6]);
        assert_eq!(rows.row(3), &[7, 8, 9, 10]);
    }

    #[test]
    fn test_aggregator_finish_preserves_variable_lengths() {
        let mut aggregator = IdAggregator::new();

        // Different row lengths with same max columns
        let bounded = make_bounded(vec![vec![1, 2, 3], vec![4], vec![5, 6]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        let result = aggregator.finish();
        let rows = result.as_rows();

        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
        assert_eq!(rows.row(2), &[5, 6]);
    }

    #[test]
    fn test_aggregator_fixed_clamps_lengths_across_batches() {
        let mut aggregator = IdAggregator::new();

        // Second row claims more valid IDs than there are columns.
        aggregator.push(ResultIdsInner::Fixed(Bounded::new(
            Matrix::new(7u32, 2, 2),
            vec![1, 5],
        )));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![8, 9]])));

        let result = aggregator.finish();
        let rows = result.as_rows();

        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[7]);
        assert_eq!(rows.row(1), &[7, 7]); // Clamped to 2 columns
        assert_eq!(rows.row(2), &[8, 9]);
    }
}