diskann_quantization/product/tables/
basic.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use crate::traits::CompressInto;
7use crate::views::{ChunkOffsetsBase, ChunkOffsetsView};
8use diskann_utils::views::{DenseData, MatrixBase, MatrixView};
9use diskann_vector::{distance::SquaredL2, PureDistanceFunction};
10use thiserror::Error;
11
12/// A basic PQ table that stores the pivot table in the following dense, row-major form:
13/// ```text
14///           | -- chunk 0 -- | -- chunk 1 -- | -- chunk 2 -- | .... | -- chunk N-1 -- |
15///           +------------------------------------------------------------------------+
16///  pivot 0  | c000 c001 ... | c010 c011 ... | c020 c021 ... | .... |       ...       |
17///  pivot 1  | c100 c101 ... | c110 c111 ... | c120 c121 ... | .... |       ...       |
18///    ...    |      ...      |      ...      |      ...      | .... |       ...       |
19///  pivot K  | cK00 cK01 ... | cK10 cK11 ... | cK20 cK21 ... | .... |       ...       |
20/// ```
21/// The member `offsets` describes the number of dimensions of each chunk.
22///
23/// # Invariants
24///
25/// * `offsets.dim() == pivots.nrows()`: The dimensionality of the two must agree.
26#[derive(Debug, Clone)]
27pub struct BasicTableBase<T, U>
28where
29    T: DenseData<Elem = f32>,
30    U: DenseData<Elem = usize>,
31{
32    pivots: MatrixBase<T>,
33    offsets: ChunkOffsetsBase<U>,
34}
35
36/// A `BasicTableBase` that owns its contents.
37pub type BasicTable = BasicTableBase<Box<[f32]>, Box<[usize]>>;
38
39/// A `BasicTableBase` that references its contents. Construction of such a table will
40/// not result in a memory allocation.
41pub type BasicTableView<'a> = BasicTableBase<&'a [f32], &'a [usize]>;
42
43#[derive(Error, Debug)]
44#[non_exhaustive]
45pub enum BasicTableError {
46    #[error("pivots have {pivot_dim} dimensions while the offsets expect {offsets_dim}")]
47    DimMismatch {
48        pivot_dim: usize,
49        offsets_dim: usize,
50    },
51    #[error("pivots cannot be empty")]
52    PivotsEmpty,
53}
54
55impl<T, U> BasicTableBase<T, U>
56where
57    T: DenseData<Elem = f32>,
58    U: DenseData<Elem = usize>,
59{
60    /// Construct a new `BasicTableBase` over the pivot table and offsets.
61    ///
62    /// # Error
63    ///
64    /// Returns an error if `pivots.ncols() != offsets.dim()` or if `pivots.nrows() == 0`.
65    pub fn new(
66        pivots: MatrixBase<T>,
67        offsets: ChunkOffsetsBase<U>,
68    ) -> Result<Self, BasicTableError> {
69        let pivot_dim = pivots.ncols();
70        let offsets_dim = offsets.dim();
71
72        if pivot_dim != offsets_dim {
73            Err(BasicTableError::DimMismatch {
74                pivot_dim,
75                offsets_dim,
76            })
77        } else if pivots.nrows() == 0 {
78            Err(BasicTableError::PivotsEmpty)
79        } else {
80            Ok(Self { pivots, offsets })
81        }
82    }
83
84    /// Return a view over the pivot table.
85    pub fn view_pivots(&self) -> MatrixView<'_, f32> {
86        self.pivots.as_view()
87    }
88
89    /// Return a view over the schema offsets.
90    pub fn view_offsets(&self) -> ChunkOffsetsView<'_> {
91        self.offsets.as_view()
92    }
93
94    /// Return the number of pivots in each PQ chunk.
95    pub fn ncenters(&self) -> usize {
96        self.pivots.nrows()
97    }
98
99    /// Return the number of PQ chunks.
100    pub fn nchunks(&self) -> usize {
101        self.offsets.len()
102    }
103
104    /// Return the dimensionality of the full-precision vectors associated with this table.
105    pub fn dim(&self) -> usize {
106        self.pivots.ncols()
107    }
108}
109
110#[derive(Error, Debug)]
111#[non_exhaustive]
112pub enum TableCompressionError {
113    #[error("num centers ({0}) must be at most 256 to compress into a byte vector")]
114    CannotCompressToByte(usize),
115    #[error("invalid input len - expected {0}, got {1}")]
116    InvalidInputDim(usize, usize),
117    #[error("invalid PQ buffer len - expected {0}, got {1}")]
118    InvalidOutputDim(usize, usize),
119    #[error("a value of infinity or NaN was observed while compressing chunk {0}")]
120    InfinityOrNaN(usize),
121}
122
123impl<T, U> CompressInto<&[f32], &mut [u8]> for BasicTableBase<T, U>
124where
125    T: DenseData<Elem = f32>,
126    U: DenseData<Elem = usize>,
127{
128    type Error = TableCompressionError;
129    type Output = ();
130
131    /// Compress the full-precision vector `from` into the PQ byte buffer `to`.
132    ///
133    /// Compression is performed by partitioning `from` into chunks according to the offsets
134    /// schema in the table and then finding the closest pivot according to the L2 distance.
135    ///
136    /// The final compressed value is the index of the closest pivot.
137    ///
138    /// # Errors
139    ///
140    /// Returns errors under the following conditions:
141    ///
142    /// * `self.ncenters() > 256`: If the number of centers exceeds 256, then it cannot be
143    ///   guaranteed that the index of the closest pivot for a chunk will fit losslessly in
144    ///   an 8-bit integer.
145    ///
146    /// * `from.len() != self.dim()`: The full precision vector must have the dimensionality
147    ///   expected by the compression.
148    ///
149    /// * `to.len() != self.nchunks()`: The PQ buffer must be sized appropriately.
150    ///
151    /// * If any chunk is sufficiently far from all centers that its distance becomes
152    ///   infinity to all centers.
153    ///
154    /// # Allocates
155    ///
156    /// This function should not allocate when successful.
157    ///
158    /// # Parallelism
159    ///
160    /// This function is single-threaded.
161    fn compress_into(&self, from: &[f32], to: &mut [u8]) -> Result<(), Self::Error> {
162        if self.ncenters() > 256 {
163            return Err(Self::Error::CannotCompressToByte(self.ncenters()));
164        }
165        if from.len() != self.dim() {
166            return Err(Self::Error::InvalidInputDim(self.dim(), from.len()));
167        }
168        if to.len() != self.nchunks() {
169            return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.len()));
170        }
171
172        to.iter_mut().enumerate().try_for_each(|(chunk, to)| {
173            let mut min_distance = f32::INFINITY;
174            let mut min_index = usize::MAX;
175            let range = self.offsets.at(chunk);
176            let slice = &from[range.clone()];
177
178            self.pivots.row_iter().enumerate().for_each(|(index, row)| {
179                let distance: f32 = SquaredL2::evaluate(slice, &row[range.clone()]);
180                if distance < min_distance {
181                    min_distance = distance;
182                    min_index = index;
183                }
184            });
185
186            if min_distance.is_infinite() {
187                Err(Self::Error::InfinityOrNaN(chunk))
188            } else {
189                // This is guaranteed to be lossless because we have at most 256 centers.
190                *to = min_index as u8;
191                Ok(())
192            }
193        })
194    }
195}
196
///////////
// Tests //
///////////
200
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views};
    use rand::{
        distr::{Distribution, StandardUniform},
        SeedableRng,
    };

    use super::*;
    use crate::product::tables::test::{
        check_pqtable_single_compression_errors, create_dataset, create_pivot_tables,
    };

    /// Assert that a failed compression left every output byte at its `u8::MAX` fill value.
    fn assert_output_unmodified(output: &[u8]) {
        assert!(
            output.iter().all(|i| *i == u8::MAX),
            "output vector should be unmodified"
        );
    }

    /////////////////////////
    // Basic Table Methods //
    /////////////////////////

    // The constructor must reject pivots whose column count differs from the
    // dimensionality of the offsets (5 columns versus 6 dimensions here).
    #[test]
    fn error_on_mismatch_dim() {
        let pivots = views::Matrix::new(0.0, 3, 5);
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 6])).unwrap();
        let outcome = BasicTable::new(pivots, offsets);
        assert!(outcome.is_err(), "dimensions are not equal");
        assert_eq!(
            outcome.unwrap_err().to_string(),
            "pivots have 5 dimensions while the offsets expect 6"
        );
    }

    // The constructor must reject a pivot matrix with zero rows.
    #[test]
    fn error_on_no_pivots() {
        let pivots = views::Matrix::new(0.0, 0, 5);
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 2, 5])).unwrap();
        let outcome = BasicTable::new(pivots, offsets);
        assert!(outcome.is_err(), "pivots is empty");
        assert_eq!(outcome.unwrap_err().to_string(), "pivots cannot be empty",);
    }

    // The accessors must reflect the pivots and offsets the table was built from.
    #[test]
    fn basic_table() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xd96bac968083ec29);
        for dim in [5, 10, 12] {
            for total in [1, 2, 3] {
                let pivots = views::Matrix::new(
                    views::Init(|| -> f32 { StandardUniform {}.sample(&mut rng) }),
                    total,
                    dim,
                );
                let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 3, dim])).unwrap();

                let table = BasicTable::new(pivots.clone(), offsets.clone()).unwrap();

                assert_eq!(table.ncenters(), total);
                assert_eq!(table.nchunks(), offsets.len());
                assert_eq!(table.dim(), offsets.dim());
                assert_eq!(table.view_pivots().as_view(), pivots.as_view());
                assert_eq!(table.view_offsets().as_view(), offsets.as_view());
            }
        }
    }

    /////////////////
    // Compression //
    /////////////////

    #[test]
    fn test_happy_path() {
        // Feed in chunks of dimension 1, 2, 3, ... 16.
        //
        // Under MIRI, only the first 8 chunks (dimensions 1 through 8) are used.
        let offsets: Vec<usize> = if cfg!(miri) {
            vec![0, 1, 3, 6, 10, 15, 21, 28, 36]
        } else {
            vec![
                0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, 136,
            ]
        };

        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xda5b2e661eabacea);

        let num_data = 20;
        let num_trials = if cfg!(miri) { 1 } else { 10 };

        for num_centers in [16, 24, 13, 17] {
            for trial in 0..num_trials {
                let context = lazy_format!(
                    "happy path, num centers = {}, num data = {}, trial = {}",
                    num_centers,
                    num_data,
                    trial,
                );

                println!("Currently = {}", context);

                let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
                let table = BasicTable::new(pivots, offsets).unwrap();
                let (data, assignments) = create_dataset(schema, num_centers, num_data, &mut rng);

                // One output byte per chunk, reused across all rows of the dataset.
                let mut output = vec![0; schema.len()];
                for (input, expected) in data.row_iter().zip(assignments.row_iter()) {
                    table.compress_into(input, &mut output).unwrap();
                    for (entry, (want, got)) in expected.iter().zip(output.iter()).enumerate() {
                        let got: usize = (*got).into();
                        assert_eq!(*want, got, "unexpected assignment at dim {}", entry);
                    }
                }
            }
        }
    }

    #[test]
    fn test_compression_error() {
        let dim = 10;
        let num_chunks = 3;
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 4, 9, 10])).unwrap();

        // Case 1: more than 256 centers cannot be encoded in a byte.
        {
            let pivots = views::Matrix::new(0.0, 257, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim];
            let mut output = vec![u8::MAX; num_chunks];
            let outcome = table.compress_into(&input, &mut output);
            assert!(outcome.is_err());
            assert_eq!(
                outcome.unwrap_err().to_string(),
                "num centers (257) must be at most 256 to compress into a byte vector"
            );
            assert_output_unmodified(&output);
        }

        // Case 2: the input vector is one element shorter than the table expects.
        {
            let pivots = views::Matrix::new(0.0, 10, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim - 1];
            let mut output = vec![u8::MAX; num_chunks];
            let outcome = table.compress_into(&input, &mut output);
            assert!(outcome.is_err());
            assert_eq!(
                outcome.unwrap_err().to_string(),
                format!("invalid input len - expected {}, got {}", dim, dim - 1),
            );
            assert_output_unmodified(&output);
        }

        // Case 3: the output buffer is one element shorter than the number of chunks.
        {
            let pivots = views::Matrix::new(0.0, 10, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim];
            let mut output = vec![u8::MAX; num_chunks - 1];
            let outcome = table.compress_into(&input, &mut output);
            assert!(outcome.is_err());
            assert_eq!(
                outcome.unwrap_err().to_string(),
                format!(
                    "invalid PQ buffer len - expected {}, got {}",
                    num_chunks,
                    num_chunks - 1
                ),
            );
            assert_output_unmodified(&output);
        }
    }

    // Run the shared single-compression error checks against `BasicTable`.
    #[test]
    fn test_table_single_compression_errors() {
        check_pqtable_single_compression_errors(
            &|pivots: views::Matrix<f32>, offsets| BasicTable::new(pivots, offsets).unwrap(),
            &"BasicTable",
        )
    }
}
390        )
391    }
392}