Skip to main content

diskann_quantization/
views.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use diskann_utils::views::{DenseData, MutDenseData};
9use std::ops::{Index, IndexMut};
10use thiserror::Error;
11
12//////////////////
13// ChunkOffsets //
14//////////////////
15
16/// A wrapper class for PQ chunk offsets.
17///
18/// Upon construction, this class guarantees that the underlying chunk offset plan is valid.
19/// A valid PQ chunk offset plan records the starting offsets of each chunk such that chunk
20/// `i` of a slice `x` can be accessed using `x[offsets[i]..offsets[i+1]]`.
21///
22/// In particular, a valid PQ chunk offset plan has the following properties:
23///
24/// * It has a length of at least 2.
25/// * Its first entry is 0.
26/// * For `i` in `0..offsets.len()`, `offsets[i] < offsets[i+1]`.
27#[derive(Debug, Clone, Copy, PartialEq)]
28pub struct ChunkOffsetsBase<T>
29where
30    T: DenseData<Elem = usize>,
31{
32    // Pre-compute the associated dimension for better locality.
33    //
34    // We could extract this as `offsets.last().unwrap() - 1`, but hoisting it out into
35    // the struct means it is readily available when needed.
36    dim: NonZeroUsize,
37    // Chunk Offsets
38    offsets: T,
39}
40
41#[derive(Error, Debug)]
42#[non_exhaustive]
43pub enum ChunkOffsetError {
44    #[error("offsets must have a length of at least 2, found {0}")]
45    LengthNotAtLeastTwo(usize),
46    #[error("offsets must begin at 0, not {0}")]
47    DoesNotBeginWithZero(usize),
48    #[error(
49        "offsets must be strictly increasing, \
50         instead entry {start_val} at position {start} is followed by {next_val}"
51    )]
52    NonMonotonic {
53        start_val: usize,
54        start: usize,
55        next_val: usize,
56    },
57}
58
59impl<T> ChunkOffsetsBase<T>
60where
61    T: DenseData<Elem = usize>,
62{
63    /// Construct a new `ChunkOffset` from a raw slice.
64    ///
65    /// Returns an error if:
66    /// * The length of `offsets` is less than 2.
67    /// * The entries in `offsets` are not strictly increasing.
68    pub fn new(offsets: T) -> Result<Self, ChunkOffsetError> {
69        let slice = offsets.as_slice();
70
71        // Validate that the length is correct.
72        let len = slice.len();
73        if len < 2 {
74            return Err(ChunkOffsetError::LengthNotAtLeastTwo(len));
75        }
76
77        // Check that we don't start at zero.
78        let start = slice[0];
79        if start != 0 {
80            return Err(ChunkOffsetError::DoesNotBeginWithZero(start));
81        }
82
83        // What follows is a convoluted dance to safely get the source dimension as a
84        // `NonZeroUsize` while validating the monotonicity of the offsets in the provided
85        // slice.
86        //
87        // We can seed the `NonZeroUsize` with the knowledge that we've already checked
88        // that the length of the `slice` vector is at least two.
89        let mut last: NonZeroUsize = match NonZeroUsize::new(slice[1]) {
90            Some(x) => Ok(x),
91            None => Err(ChunkOffsetError::NonMonotonic {
92                start_val: start,
93                start: 0,
94                next_val: 0,
95            }),
96        }?;
97
98        // Now that we've successfully initialized a `NonZeroUsize` - we can use it going
99        // forward.
100        //
101        // Validate that `offsets` is monotonic.
102        for i in 2..slice.len() {
103            let start_val = slice[i - 1];
104            let next_val = NonZeroUsize::new(slice[i]);
105            last = match next_val {
106                Some(next_val) => {
107                    if start_val >= next_val.get() {
108                        Err(ChunkOffsetError::NonMonotonic {
109                            start_val,
110                            start: i - 1,
111                            next_val: next_val.get(),
112                        })
113                    } else {
114                        Ok(next_val)
115                    }
116                }
117                // If we hit this case, then `slice[i]` was zero.
118                None => Err(ChunkOffsetError::NonMonotonic {
119                    start_val,
120                    start: i - 1,
121                    next_val: 0,
122                }),
123            }?;
124        }
125
126        // The last entry in `offset` is one-past the end.
127        Ok(Self { dim: last, offsets })
128    }
129
130    /// Return the number of chunks associated with this mapping.
131    ///
132    /// This will be one-less than the length of the provided slice.
133    pub fn len(&self) -> usize {
134        // This invariant should hold by construction, and allows us to safely subtract
135        // 1 from the length of the underlying `offsets` span.
136        debug_assert!(self.offsets.as_slice().len() >= 2);
137        self.offsets.as_slice().len() - 1
138    }
139
140    /// Return whether the offsets are empty.
141    pub fn is_empty(&self) -> bool {
142        // by class invariant, there must always be at least one chunk.
143        false
144    }
145
146    /// Return the dimensionality of the vector data associated with this chunking schema.
147    pub fn dim(&self) -> usize {
148        self.dim.get()
149    }
150
151    /// Return the dimensionality of the vector data associated with this chunking schema.
152    ///
153    /// By class invariant, the dimensionality must be nonzero, and this expressed in the
154    /// retuen type.
155    ///
156    /// This method cannot fail and will not panic.
157    pub fn dim_nonzero(&self) -> NonZeroUsize {
158        self.dim
159    }
160
161    /// Return a range containing the start and one-past-the-end indices for chunk `i`.
162    ///
163    /// # Panics
164    ///
165    /// Panics if `i >= self.len()`.
166    pub fn at(&self, i: usize) -> core::ops::Range<usize> {
167        assert!(
168            i < self.len(),
169            "index {i} must be less than len {}",
170            self.len()
171        );
172        let slice = self.offsets.as_slice();
173        slice[i]..slice[i + 1]
174    }
175
176    /// Return `self` as a view.
177    pub fn as_view(&self) -> ChunkOffsetsView<'_> {
178        ChunkOffsetsBase {
179            dim: self.dim,
180            offsets: self.offsets.as_slice(),
181        }
182    }
183
184    /// Return a `'static` copy of `self`.
185    pub fn to_owned(&self) -> ChunkOffsets {
186        ChunkOffsetsBase {
187            dim: self.dim,
188            offsets: self.offsets.as_slice().into(),
189        }
190    }
191
192    /// Return the underlying data as a slice.
193    pub fn as_slice(&self) -> &[usize] {
194        self.offsets.as_slice()
195    }
196}
197
/// A [`ChunkOffsetsBase`] whose offsets are borrowed as a `&'a [usize]`.
pub type ChunkOffsetsView<'a> = ChunkOffsetsBase<&'a [usize]>;
/// A [`ChunkOffsetsBase`] whose offsets are owned as a `Box<[usize]>`.
pub type ChunkOffsets = ChunkOffsetsBase<Box<[usize]>>;
200
201/// Allow chunk offsets view to be converted directly to slices.
202impl<'a> From<ChunkOffsetsView<'a>> for &'a [usize] {
203    fn from(view: ChunkOffsetsView<'a>) -> Self {
204        view.offsets
205    }
206}
207
208///////////////
209// ChunkView //
210///////////////
211
212/// A view over a slice that partitions the slice into chunks corresponding to a valid
213/// PQ chunking configuration.
214///
215/// This class maintains the invariant that the provided chunking configuration if valid
216/// and that the data being partitioned has the correct length.
217#[derive(Debug, Clone, Copy)]
218pub struct ChunkViewImpl<'a, T>
219where
220    T: DenseData,
221{
222    data: T,
223    offsets: ChunkOffsetsView<'a>,
224}
225
226#[derive(Error, Debug)]
227#[non_exhaustive]
228#[error(
229    "error in chunk view construction, got a slice of length {got} but \
230         the provided chunking schema expects a length of {should}"
231)]
232pub struct ChunkViewError {
233    got: usize,
234    should: usize,
235}
236
237impl<'a, T> ChunkViewImpl<'a, T>
238where
239    T: DenseData,
240{
241    /// Construct a new `ChunkView`.
242    ///
243    /// Returns an error if `data.len() != offsets.dim()`.
244    pub fn new<U>(data: U, offsets: ChunkOffsetsView<'a>) -> Result<Self, ChunkViewError>
245    where
246        T: From<U>,
247    {
248        let data: T = data.into();
249
250        // Use the `offsets` as the source of truth because it is more likely that a
251        // `ChunkOffsetsView` will be constructed once and reused, so is more likely to be
252        // the desired outcome.
253        let got = data.as_slice().len();
254        let should = offsets.dim();
255        if got != should {
256            Err(ChunkViewError { got, should })
257        } else {
258            Ok(Self { data, offsets })
259        }
260    }
261
262    /// Return the number of partitions in the chunking view.
263    pub fn len(&self) -> usize {
264        self.offsets.len()
265    }
266
267    pub fn is_empty(&self) -> bool {
268        self.offsets.is_empty()
269    }
270}
271
272/// Return the `i`th chunk of the view.
273///
274/// # Panics
275///
276/// Panics if `i >= self.len()`.
277impl<T> Index<usize> for ChunkViewImpl<'_, T>
278where
279    T: DenseData,
280{
281    type Output = [T::Elem];
282
283    fn index(&self, i: usize) -> &Self::Output {
284        &(self.data.as_slice())[self.offsets.at(i)]
285    }
286}
287
288/// Return the `i`th chunk of the view.
289///
290/// # Panics
291///
292/// Panics if `i >= self.len()`.
293impl<T> IndexMut<usize> for ChunkViewImpl<'_, T>
294where
295    T: MutDenseData,
296{
297    fn index_mut(&mut self, i: usize) -> &mut Self::Output {
298        &mut (self.data.as_mut_slice())[self.offsets.at(i)]
299    }
300}
301
/// A [`ChunkViewImpl`] over an immutably borrowed slice of `T`.
pub type ChunkView<'a, T> = ChunkViewImpl<'a, &'a [T]>;
/// A [`ChunkViewImpl`] over a mutably borrowed slice of `T`.
pub type MutChunkView<'a, T> = ChunkViewImpl<'a, &'a mut [T]>;
304
#[cfg(test)]
mod tests {
    use super::*;
    use diskann_utils::lazy_format;

    /// This function is only callable with copyable types.
    ///
    /// This lets us test for types we expect to be `Copy`: a call with a non-`Copy`
    /// argument fails to compile.
    fn is_copyable<T: Copy>(_x: T) -> bool {
        true
    }

    //////////////////
    // ChunkOffsets //
    //////////////////

    /// A valid offset plan reports the expected length, dimension, and per-chunk ranges,
    /// and round-trips through `as_slice`, `to_owned`, and `as_view`.
    #[test]
    fn chunk_offset_happy_path() {
        let offsets_raw: Vec<usize> = vec![0, 1, 3, 6, 10, 12, 13, 14];
        let offsets = ChunkOffsetsView::new(offsets_raw.as_slice()).unwrap();

        assert_eq!(offsets.len(), offsets_raw.len() - 1);
        assert_eq!(offsets.dim(), *offsets_raw.last().unwrap());
        assert!(!offsets.is_empty());

        assert_eq!(offsets.at(0), 0..1);
        assert_eq!(offsets.at(1), 1..3);
        assert_eq!(offsets.at(2), 3..6);
        assert_eq!(offsets.at(3), 6..10);
        assert_eq!(offsets.at(4), 10..12);
        assert_eq!(offsets.at(5), 12..13);
        assert_eq!(offsets.at(6), 13..14);

        // Finally, make sure the type is copyable.
        assert!(is_copyable(offsets));
        // Make sure `as_slice()` properly round-trips.
        assert_eq!(offsets.as_slice(), offsets_raw.as_slice());

        // `to_owned` must copy the contents into separate storage.
        let offsets_owned = offsets.to_owned();
        assert_eq!(offsets_owned.as_slice(), offsets_raw.as_slice());
        assert_ne!(
            offsets_owned.as_slice().as_ptr(),
            offsets_raw.as_slice().as_ptr()
        );
        assert_eq!(offsets_owned.dim, offsets.dim);

        // `as_view`.
        let offsets_view = offsets_owned.as_view();
        assert_eq!(offsets_view, offsets);
        // ensure that pointers are preserved.
        assert_eq!(
            offsets_view.as_slice().as_ptr(),
            offsets_owned.as_slice().as_ptr()
        );
    }

    /// Requesting a chunk index past the end panics with a descriptive message.
    #[test]
    #[should_panic(expected = "index 5 must be less than len 3")]
    fn chunk_offset_indexing_panic() {
        let offsets = ChunkOffsets::new(Box::new([0, 1, 2, 3])).unwrap();

        // panics
        let _ = offsets.at(5);
    }

    // Construction errors.
    /// Each way of violating the offset-plan invariants yields the matching error message.
    #[test]
    fn chunk_offset_construction_errors() {
        // Pass an empty slice.
        let offsets = ChunkOffsets::new(Box::new([]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must have a length of at least 2, found 0"
        );

        // Pass a slice with length 1.
        let offsets = ChunkOffsets::new(Box::new([0]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must have a length of at least 2, found 1"
        );

        // Doesn't start with zero.
        let offsets = ChunkOffsets::new(Box::new([10, 11, 12, 13]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must begin at 0, not 10"
        );

        // Non-monotonic cases - zero sized chunk
        let offsets = ChunkOffsets::new(Box::new([0, 10, 20, 30, 30, 40, 41]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 30 at position 3 \
            is followed by 30"
        );

        // Non-monotonic cases - decreasing size
        let offsets = ChunkOffsets::new(Box::new([0, 10, 9, 10, 20]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 10 at position 1 \
            is followed by 9"
        );

        // Non-monotonic cases - some dimension after the first is zero.
        let offsets = ChunkOffsets::new(Box::new([0, 10, 11, 12, 0]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 12 at position 3 \
            is followed by 0"
        );

        // Non-monotonic cases - second entry is zero.
        let offsets = ChunkOffsets::new(Box::new([0, 0, 11, 12, 20]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 0 at position 0 \
            is followed by 0"
        );
    }

    ///////////////
    // ChunkView //
    ///////////////

    /// Assert that `view` yields, for every chunk, the same elements as slicing `data`
    /// manually with consecutive entries of `offsets`.
    ///
    /// `context` is included in every assertion message to identify the calling test.
    fn check_chunk_view<T>(
        view: &ChunkViewImpl<'_, T>,
        data: &[i32],
        offsets: &[usize],
        context: &dyn std::fmt::Display,
    ) where
        T: DenseData<Elem = i32>,
    {
        assert_eq!(view.len(), offsets.len() - 1, "{}", context);

        // Ensure that each yielded slice matches the one we retrieve manually.
        for i in 0..view.len() {
            let context = lazy_format!("start = {}, {}", i, context);
            let start = offsets[i];
            let stop = offsets[i + 1];

            let expected = &data[start..stop];
            let retrieved = &view[i];

            assert_eq!(retrieved, expected, "{}", context);
        }
    }

    /// An immutable chunk view has the expected length, is `Copy`, and yields the
    /// expected chunks.
    #[test]
    fn test_immutable_chunkview() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        //                        |-----|  |--|  |--------|  |
        //                          c0      c1       c2      c3
        let offsets: Vec<usize> = vec![0, 3, 5, 9, 10];

        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let chunk_view = ChunkView::new(data.as_slice(), chunks).unwrap();

        assert_eq!(chunk_view.len(), offsets.len() - 1);
        assert_eq!(chunk_view.len(), chunks.len());

        assert!(is_copyable(chunk_view));
        let context = lazy_format!("chunkview happy path");
        check_chunk_view(&chunk_view, &data, &offsets, &context);
    }

    /// Constructing a chunk view fails when the data length disagrees with the offsets.
    #[test]
    fn test_chunkview_construction_error() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        //                        |-----|  |--|  |--------|  |
        //                          c0      c1       c2      c3
        let offsets: Vec<usize> = vec![0, 3, 5, 9]; // One too short.

        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let chunk_view = ChunkView::new(data.as_slice(), chunks);
        assert!(chunk_view.is_err());
        assert_eq!(
            chunk_view.unwrap_err().to_string(),
            "error in chunk view construction, got a slice of length 10 but \
             the provided chunking schema expects a length of 9"
        );
    }

    /// A mutable chunk view yields the expected chunks and writes through to the
    /// underlying data.
    #[test]
    fn test_mutable_chunkview() {
        let mut data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        //                            |-----|  |--|  |--------|  |
        //                              c0      c1       c2      c3
        let offsets: Vec<usize> = vec![0, 3, 5, 9, 10];

        // We need to clone the original data before constructing the mutable chunk view
        // to avoid both a mutable and immutable borrow in the checking function.
        let data_clone = data.clone();

        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let mut chunk_view = MutChunkView::new(data.as_mut_slice(), chunks).unwrap();

        assert_eq!(chunk_view.len(), offsets.len() - 1);
        assert_eq!(chunk_view.len(), chunks.len());

        let context = lazy_format!("mutchunkview happy path");
        check_chunk_view(&chunk_view, &data_clone, &offsets, &context);

        // Make sure that we can assign through the view: every element of chunk `i` is
        // overwritten with the value `i`.
        for i in 0..chunk_view.len() {
            let i_i32: i32 = i.try_into().unwrap();

            chunk_view[i].iter_mut().for_each(|d| *d = i_i32);
        }

        // chunk 0
        assert_eq!(data[0], 0);
        assert_eq!(data[1], 0);
        assert_eq!(data[2], 0);

        // chunk 1
        assert_eq!(data[3], 1);
        assert_eq!(data[4], 1);

        // chunk 2
        assert_eq!(data[5], 2);
        assert_eq!(data[6], 2);
        assert_eq!(data[7], 2);
        assert_eq!(data[8], 2);

        // chunk 3
        assert_eq!(data[9], 3);
    }
}