Skip to main content

diskann_quantization/product/tables/transposed/
table.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use super::pivots;
9use crate::{
10    product::tables::basic::TableCompressionError,
11    traits::CompressInto,
12    views::{ChunkOffsets, ChunkOffsetsView},
13};
14use diskann_utils::{
15    strided,
16    views::{self, MatrixView, MutMatrixView},
17};
18use thiserror::Error;
19
/// A PQ table that stores the pivots for each chunk in a miniature block-transpose to
/// facilitate faster compression. The exact layout is not documented (as it is for the
/// `BasicTable`) because it may be subject to change.
///
/// The advantage of this table over the `BasicTable` is that this table uses a more
/// hardware friendly layout for the pivots, meaning that compression (particularly batch
/// compression) is much faster than the basic table, at the cost of a slightly higher
/// memory footprint.
///
/// # Invariants (Dev Docs)
///
/// * `pivots.len() == schema.len()`: The number of PQ chunks must agree.
/// * `pivots[i].dimension() == schema.at(i).len()` for all inbounds `i`.
/// * `pivots[i].total() == ncenters` for all inbounds `i`.
/// * `largest = max(schema.at(i).len() for i)`.
#[derive(Debug)]
pub struct TransposedTable {
    /// One block-transposed pivot group per PQ chunk.
    pivots: Box<[pivots::Chunk]>,
    /// The chunk boundaries partitioning the full-precision dimensions.
    offsets: crate::views::ChunkOffsets,
    /// The largest dimension in offsets.
    largest: usize,
    /// The number of centers in each block.
    ncenters: usize,
}
44
/// Errors that can occur when constructing a [`TransposedTable`] from raw parts.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TransposedTableError {
    /// The pivot matrix and the chunk offsets disagree on dimensionality.
    #[error("pivots have {pivot_dim} dimensions while the offsets expect {offsets_dim}")]
    DimMismatch {
        /// Number of columns in the provided pivot matrix.
        pivot_dim: usize,
        /// Dimensionality expected by the chunk offsets.
        offsets_dim: usize,
    },
    /// Construction of an individual pivot chunk failed (e.g. an empty pivot table).
    #[error("error constructing pivot {problem} of {total}")]
    PivotError {
        /// Index of the chunk whose construction failed.
        problem: usize,
        /// Total number of chunks being constructed.
        total: usize,
        // Named `source` so `thiserror` treats it as the error's cause automatically.
        source: pivots::ChunkConstructionError,
    },
}
60
impl TransposedTable {
    /// Construct a new `TransposedTable` from raw parts.
    ///
    /// # Errors
    ///
    /// Returns an error if
    ///
    /// * `pivots.ncols() != offsets.dim()`: Pivots must have the dimensionality expected
    ///   by the offsets.
    ///
    /// * `pivots.nrows() == 0`: The pivot table cannot be empty.
    #[allow(clippy::expect_used)]
    pub fn from_parts(
        pivots: views::MatrixView<f32>,
        offsets: ChunkOffsets,
    ) -> Result<Self, TransposedTableError> {
        let pivot_dim = pivots.ncols();
        let offsets_dim = offsets.dim();
        if pivot_dim != offsets_dim {
            return Err(TransposedTableError::DimMismatch {
                pivot_dim,
                offsets_dim,
            });
        }

        let ncenters = pivots.nrows();
        let mut largest = 0;
        // Build one transposed `Chunk` per offset range, tracking the widest chunk so
        // `compress_batch` can size its packing buffer.
        let pivots: Box<[_]> = (0..offsets.len())
            .map(|i| {
                let range = offsets.at(i);
                largest = largest.max(range.len());
                let view = strided::StridedView::try_shrink_from(
                    &(pivots.as_slice()[range.start..]),
                    pivots.nrows(),
                    range.len(),
                    offsets.dim(),
                )
                .expect(
                    "the check on `pivot_dim` and `offsets_dim` should cause this to never error",
                );
                pivots::Chunk::new(view).map_err(|source| TransposedTableError::PivotError {
                    problem: i,
                    total: offsets.len(),
                    source,
                })
            })
            .collect::<Result<Box<[_]>, TransposedTableError>>()?;

        debug_assert_eq!(pivots.len(), offsets.len());
        Ok(Self {
            pivots,
            offsets,
            largest,
            ncenters,
        })
    }

    /// Return the number of pivots in each PQ chunk.
    pub fn ncenters(&self) -> usize {
        self.ncenters
    }

    /// Return the number of PQ chunks.
    pub fn nchunks(&self) -> usize {
        self.offsets.len()
    }

    /// Return the dimensionality of the full-precision vectors associated with this table.
    pub fn dim(&self) -> usize {
        self.offsets.dim()
    }

    /// Return a view over the schema offsets.
    pub fn view_offsets(&self) -> ChunkOffsetsView<'_> {
        self.offsets.as_view()
    }

    /// Perform PQ compression on the dataset by mapping each chunk in `data` to its nearest
    /// neighbor in the corresponding entry in `chunks`.
    ///
    /// The index of the nearest neighbor is provided to `compression_delegate` along with its
    /// corresponding row in data and chunk index.
    ///
    /// Calls to `compression_delegate` may occur in any order.
    ///
    /// Visitor will be invoked for all rows in `0..data.nrows()` and all chunks in
    /// `0..schema.len()`.
    ///
    /// # Panics
    ///
    /// Panics under the following conditions:
    /// * `data.cols() != self.dim()`: The number of columns in the source dataset must match
    ///   the number of dimensions expected by the schema.
    #[allow(clippy::expect_used)]
    pub fn compress_batch<T, F, DelegateError>(
        &self,
        data: views::MatrixView<'_, T>,
        mut compression_delegate: F,
    ) -> Result<(), CompressError<DelegateError>>
    where
        T: Copy + Into<f32>,
        F: FnMut(RowChunk, pivots::CompressionResult) -> Result<(), DelegateError>,
        DelegateError: std::error::Error,
    {
        assert_eq!(
            data.ncols(),
            self.dim(),
            "schema expects {} dimensions but data has {}",
            self.dim(),
            data.ncols()
        );

        // The batch size expected by the compression micro-kernel.
        const SUB_BATCH_SIZE: usize = pivots::Chunk::batchsize();

        let dim_nonzero = self.offsets.dim_nonzero();
        // Scratch space large enough to hold a full sub-batch of the widest chunk,
        // pre-converted to `f32`.
        let mut packing_buffer: Box<[f32]> =
            (0..self.largest * SUB_BATCH_SIZE).map(|_| 0.0).collect();

        // Stride along the chunks in the source matrix to keep the associated `Chunk` in the
        // cache.
        let nrows = data.nrows();
        let ncols = data.ncols();
        let slice = data.as_slice();

        for (i, chunk) in self.pivots.iter().enumerate() {
            let range = self.offsets.at(i);
            // Zero-length chunks contribute nothing; skip them entirely.
            if let Some(chunk_dim) = NonZeroUsize::new(range.len()) {
                // Construct a view for the packing buffer for this chunk.
                let mut packing_view = views::MutMatrixView::try_from(
                    &mut packing_buffer[..SUB_BATCH_SIZE * chunk_dim.get()],
                    SUB_BATCH_SIZE,
                    chunk_dim.get(),
                )
                .expect("the packing buffer should have been sized correctly");

                for row_start in (0..nrows).step_by(SUB_BATCH_SIZE) {
                    let row_end = nrows.min(row_start + SUB_BATCH_SIZE);

                    // If this is a full batch, then use the batched micro-kernel.
                    if row_end - row_start == SUB_BATCH_SIZE {
                        // When computing the stop offset, don't try to adjust for the length
                        // of the underlying span.
                        //
                        // The control on our loop bounds mean we should never be in a situation
                        // where this would occur, so we'd rather hit the indexing panic early.
                        let mut linear_start = row_start * ncols + range.start;
                        packing_view.row_iter_mut().for_each(|row| {
                            pack(row, &slice[linear_start..linear_start + chunk_dim.get()]);
                            linear_start += dim_nonzero.get();
                        });

                        let result = chunk.find_closest_batch(packing_view.as_view().into());

                        // Invoke the delegate with the results.
                        // If the delegate returns an error, wrap that error inside a
                        // `CompressError` to add context and forward the error upward.
                        for (j, &r) in result.iter().enumerate() {
                            compression_delegate(
                                RowChunk {
                                    row: row_start + j,
                                    chunk: i,
                                },
                                r,
                            )
                            .map_err(|inner| CompressError {
                                inner,
                                row: row_start + j,
                                chunk: i,
                                nearest: r.into_inner(),
                            })?;
                        }
                    } else {
                        // Handle remainders one at a time.
                        for row in row_start..row_end {
                            let linear_start = row * ncols + range.start;
                            let linear_stop = linear_start + range.len();

                            // Pre-convert to f32.
                            let packed = &mut packing_view.row_mut(0);
                            pack(packed, &slice[linear_start..linear_stop]);
                            let result = chunk.find_closest(packed);

                            compression_delegate(RowChunk { row, chunk: i }, result).map_err(
                                |inner| CompressError {
                                    inner,
                                    row,
                                    chunk: i,
                                    nearest: result.into_inner(),
                                },
                            )?;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Compute the operation defined by `T` for each chunk in the query on all corresponding
    /// pivots for that chunk, storing the result in the output matrix.
    ///
    /// For example, this can be used to compute squared L2 distances between chunks of the
    /// query and the pivot table to create a fast run-time lookup table for these distances.
    ///
    /// This is currently implemented for the following operation types `T`:
    ///
    /// * `quantization::distances::SquaredL2`
    /// * `quantization::distances::InnerProduct`
    ///
    /// # Arguments
    ///
    /// * `query`: The query slice to process. Must have length `self.dim()`.
    /// * `partials`: Output matrix for the partial results. The result of the computation
    ///   of chunk `i` against pivot `j` will be stored into `pivots[(i, j)]`.
    ///
    ///   Must have `nrows = self.nchunks()` and `ncols = self.ncenters()`.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `query.len() != self.dim()`.
    /// * `partials.nrows() != self.nchunks()`.
    /// * `partials.ncols() != self.ncenters()`.
    pub fn process_into<T>(&self, query: &[f32], mut partials: MutMatrixView<'_, f32>)
    where
        T: pivots::ProcessInto,
    {
        // Check Requirements
        assert_eq!(
            query.len(),
            self.dim(),
            "query has the wrong number of dimensions"
        );
        assert_eq!(
            partials.ncols(),
            self.ncenters(),
            "output has the wrong number of columns"
        );
        assert_eq!(
            partials.nrows(),
            self.nchunks(),
            "output has the wrong number of rows"
        );

        // Loop over each chunk, writing chunk `i`'s results into output row `i`.
        std::iter::zip(self.pivots.iter(), partials.row_iter_mut())
            .enumerate()
            .for_each(|(i, (pivot, out))| {
                let range = self.offsets.at(i);
                T::process_into(pivot, &query[range], out);
            });
    }
}
315
/// Row and chunk indexes provided to the `compression_delegate` argument of `compress` to
/// describe the position of the nearest neighbor being provided.
pub struct RowChunk {
    /// The row in the source data matrix being compressed.
    pub row: usize,
    /// The index of the PQ chunk within the row.
    pub chunk: usize,
}
322
/// Error returned by [`TransposedTable::compress_batch`] when the user-supplied
/// compression delegate fails; records where in the batch the failure occurred.
#[derive(Error, Debug)]
#[error(
    "compression delegate returned \"{inner}\" when processing row {row} and chunk \
    {chunk} with nearest center {nearest}"
)]
pub struct CompressError<DelegateError: std::error::Error> {
    /// The error produced by the delegate.
    inner: DelegateError,
    /// The data row being processed when the delegate failed.
    row: usize,
    /// The PQ chunk being processed when the delegate failed.
    chunk: usize,
    /// The index of the nearest center that was being reported to the delegate.
    nearest: u32,
}
334
/// Copy `src` into `dst`, converting each element to `f32` along the way.
///
/// The two slices must have equal length (checked in debug builds only).
#[inline(always)]
fn pack<T>(dst: &mut [f32], src: &[T])
where
    T: Copy + Into<f32>,
{
    debug_assert_eq!(dst.len(), src.len());
    for (slot, &value) in dst.iter_mut().zip(src) {
        *slot = value.into();
    }
}
343
//////////////
// Compress //
//////////////
347
impl CompressInto<&[f32], &mut [u8]> for TransposedTable {
    type Error = TableCompressionError;
    type Output = ();

    /// Compress the full-precision vector `from` into the PQ byte buffer `to`.
    ///
    /// Compression is performed by partitioning `from` into chunks according to the offsets
    /// schema in the table and then finding the closest pivot according to the L2 distance.
    ///
    /// The final compressed value is the index of the closest pivot.
    ///
    /// # Errors
    ///
    /// Returns errors under the following conditions:
    ///
    /// * `self.ncenters() > 256`: If the number of centers exceeds 256, then it cannot be
    ///   guaranteed that the index of the closest pivot for a chunk will fit losslessly in
    ///   an 8-bit integer.
    ///
    /// * `from.len() != self.dim()`: The full precision vector must have the dimensionality
    ///   expected by the compression.
    ///
    /// * `to.len() != self.nchunks()`: The PQ buffer must be sized appropriately.
    ///
    /// * If any chunk is sufficiently far from all centers that its distance becomes
    ///   infinity to all centers.
    ///
    /// # Allocates
    ///
    /// This function should not allocate when successful.
    ///
    /// # Parallelism
    ///
    /// This function is single-threaded.
    fn compress_into(&self, from: &[f32], to: &mut [u8]) -> Result<(), Self::Error> {
        // Validate sizes up front so the hot loop below can index without surprises.
        if self.ncenters() > 256 {
            return Err(Self::Error::CannotCompressToByte(self.ncenters()));
        }
        if from.len() != self.dim() {
            return Err(Self::Error::InvalidInputDim(self.dim(), from.len()));
        }
        if to.len() != self.nchunks() {
            return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.len()));
        }

        // For each chunk, find the closest pivot and write its index into the output byte.
        // Stops at the first chunk whose distances are all non-finite.
        std::iter::zip(self.pivots.iter(), to.iter_mut())
            .enumerate()
            .try_for_each(|(i, (pivot, to))| {
                let range = self.offsets.at(i);
                let result = pivot.find_closest(&from[range]);
                result.map(
                    |v| *to = v as u8, // conversion guaranteed to be lossless
                    || Self::Error::InfinityOrNaN(i),
                )
            })
    }
}
405
/// Errors returned by the batch [`CompressInto`] implementation for [`TransposedTable`].
#[derive(Error, Debug)]
pub enum TableBatchCompressionError {
    /// The pivot index cannot be stored losslessly in a `u8`.
    #[error("num centers ({0}) must be at most 256 to compress into a byte vector")]
    CannotCompressToByte(usize),
    /// The input matrix does not have the dimensionality expected by the table.
    #[error("invalid input len - expected {0}, got {1}")]
    InvalidInputDim(usize, usize),
    /// The output buffer does not have one byte per PQ chunk.
    #[error("invalid PQ buffer len - expected {0}, got {1}")]
    InvalidOutputDim(usize, usize),
    /// The input and output matrices have differing row counts.
    #[error(
        "input and output must have the same number of rows - instead, got {0} and {1} \
         (respectively)"
    )]
    UnequalRows(usize, usize),
    /// A chunk's distance to every center was non-finite; fields are (chunk, row).
    #[error(
        "a value of infinity or NaN was observed while compressing chunk {0} of batch input {1}"
    )]
    InfinityOrNaN(usize, usize),
}
424
impl<T> CompressInto<MatrixView<'_, T>, MutMatrixView<'_, u8>> for TransposedTable
where
    T: Copy + Into<f32>,
{
    type Error = TableBatchCompressionError;
    type Output = ();

    /// Compress each full-precision row in `from` into the corresponding row in `to`.
    ///
    /// Compression is performed by partitioning `from` into chunks according to the offsets
    /// schema in the table and then finding the closest pivot according to the L2 distance.
    ///
    /// The final compressed value is the index of the closest pivot.
    ///
    /// # Errors
    ///
    /// Returns errors under the following conditions:
    ///
    /// * `self.ncenters() > 256`: If the number of centers exceeds 256, then it cannot be
    ///   guaranteed that the index of the closest pivot for a chunk will fit losslessly in
    ///   an 8-bit integer.
    ///
    /// * `from.ncols() != self.dim()`: The full precision data must have the dimensionality
    ///   expected by the compression.
    ///
    /// * `to.ncols() != self.nchunks()`: The PQ buffer must be sized appropriately.
    ///
    /// * `from.nrows() == to.nrows()`: The input and output buffers must have the same
    ///   number of elements.
    ///
    /// * If any chunk is sufficiently far from all centers that its distance becomes
    ///   infinity to all centers.
    ///
    /// # Allocates
    ///
    /// Allocates scratch memory proportional to the length of the largest chunk.
    ///
    /// # Parallelism
    ///
    /// This function is single-threaded.
    fn compress_into(
        &self,
        from: MatrixView<'_, T>,
        mut to: MutMatrixView<'_, u8>,
    ) -> Result<(), Self::Error> {
        // Validate all shape constraints before delegating to `compress_batch`.
        if self.ncenters() > 256 {
            return Err(Self::Error::CannotCompressToByte(self.ncenters()));
        }
        if from.ncols() != self.dim() {
            return Err(Self::Error::InvalidInputDim(self.dim(), from.ncols()));
        }
        if to.ncols() != self.nchunks() {
            return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.ncols()));
        }
        if from.nrows() != to.nrows() {
            return Err(Self::Error::UnequalRows(from.nrows(), to.nrows()));
        }

        // The `CompressionError` already has all the information we need, so we make the
        // `Delegate` error light-weight.
        #[derive(Debug, Error)]
        #[error("unreachable")]
        struct PassThrough;

        // Do the compression. We will reformat the NaN warning from `CompressionError`
        // if needed.
        let result = self.compress_batch(
            from,
            |RowChunk { row, chunk }, result| -> Result<(), PassThrough> {
                result.map(|v| to[(row, chunk)] = v as u8, || PassThrough)
            },
        );

        // The delegate only fails via `PassThrough` on a non-finite distance, so any
        // error from `compress_batch` is reported as `InfinityOrNaN` at its position.
        result.map_err(|err| Self::Error::InfinityOrNaN(err.chunk, err.row))
    }
}
501
502#[cfg(test)]
503mod test_compression {
504    use std::collections::HashSet;
505
506    use diskann_vector::{PureDistanceFunction, distance};
507    use rand::{
508        Rng, SeedableRng,
509        distr::{Distribution, StandardUniform, Uniform},
510        rngs::StdRng,
511    };
512
513    use super::*;
514    use crate::{
515        distances::{InnerProduct, SquaredL2},
516        error::format,
517        product::tables::test::{
518            check_pqtable_batch_compression_errors, check_pqtable_single_compression_errors,
519            create_dataset, create_pivot_tables,
520        },
521    };
522    use diskann_utils::lazy_format;
523
524    //////////////////////////////
525    // Transposed Table Methods //
526    //////////////////////////////
527
528    // Test that an error is returned when the dimension between the pivots and offsets
529    // disagree.
530    #[test]
531    fn error_on_mismatch_dim() {
532        let pivots = views::Matrix::new(0.0, 3, 5);
533        let offsets = ChunkOffsets::new(Box::new([0, 1, 6])).unwrap();
534        let result = TransposedTable::from_parts(pivots.as_view(), offsets);
535        assert!(result.is_err(), "dimensions are not equal");
536        assert_eq!(
537            result.unwrap_err().to_string(),
538            "pivots have 5 dimensions while the offsets expect 6"
539        );
540    }
541
542    // Test that an error is returned when the dimension between the pivots and offsets
543    // disagree.
544    #[test]
545    fn error_on_empty() {
546        let pivots = views::Matrix::new(0.0, 0, 5);
547        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
548        let result = TransposedTable::from_parts(pivots.as_view(), offsets);
549        assert!(result.is_err(), "dimensions are not equal");
550
551        let expected = [
552            "error constructing pivot 0 of 2",
553            "    caused by: cannot construct a Chunk from a source with zero length",
554        ]
555        .join("\n");
556
557        assert_eq!(format(&result.unwrap_err()), expected,);
558    }
559
    // Smoke test: accessors and internal layout plumbing over several table shapes.
    #[test]
    fn basic_table() {
        let mut rng = StdRng::seed_from_u64(0xd96bac968083ec29);
        for dim in [5, 10, 12] {
            // Sweep over enough totals to ensure the inner chunks have a non-trivial layout.
            for total in [1, 2, 3, 7, 8, 9, 10] {
                let pivots = views::Matrix::new(
                    views::Init(|| -> f32 { StandardUniform {}.sample(&mut rng) }),
                    total,
                    dim,
                );
                let offsets = ChunkOffsets::new(Box::new([0, 1, 3, dim])).unwrap();
                let table = TransposedTable::from_parts(pivots.as_view(), offsets.clone()).unwrap();

                // Accessors must reflect the construction arguments.
                assert_eq!(table.ncenters(), total);
                assert_eq!(table.nchunks(), offsets.len());
                assert_eq!(table.dim(), offsets.dim());

                // This kind of looks into the guts of this data structure, but is an extra
                // check that the plumbing was performed properly.
                for chunk in 0..offsets.len() {
                    let range = offsets.at(chunk);
                    let pivot = &table.pivots[chunk];
                    for row in 0..total {
                        let r = &pivots.row(row)[range.clone()];
                        for (col, expected) in r.iter().enumerate() {
                            assert_eq!(pivot.get(row, col), *expected);
                        }
                    }
                }

                assert_eq!(table.view_offsets(), offsets.as_view());
            }
        }
    }
595
596    /////////////////
597    // Compression //
598    /////////////////
599
    /// Delegate error type for tests whose compression delegates never fail.
    #[derive(Error, Debug)]
    #[error("unreachable reached")]
    struct Infallible;
603
    // End-to-end check: `compress_batch`, the batch trait interface, and the
    // single-vector trait interface must all agree with the precomputed expected codes.
    #[test]
    fn test_happy_path() {
        // Feed in chunks of dimension 1, 2, 3, ... 16.
        //
        // If we're using MIRI, max out at 7 dimensions.
        let offsets: Vec<usize> = if cfg!(miri) {
            vec![0, 1, 3, 6, 10, 15, 21, 28, 36]
        } else {
            vec![
                0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, 136,
            ]
        };

        let schema = ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = StdRng::seed_from_u64(0x88e3d3366501ad6c);

        let num_data = if cfg!(miri) {
            vec![0, 8, 9, 10, 11]
        } else {
            vec![0, 1, 2, 3, 4, 16, 17, 18, 19]
        };

        let num_trials = if cfg!(miri) { 1 } else { 10 };

        // Strategically pick `num_centers` so we cover the corner cases in lower-level
        // handling.
        //
        // This includes:
        // * Full blocks with an even number of total blocks (num centers = 16);
        // * Full blocks with an odd number of total blocks (num centers = 24);
        // * Partially full blocks with the last block even (num centers = 13);
        // * Partially full blocks with the last block odd (num centers = 17);
        for &num_centers in [16, 24, 13, 17].iter() {
            for &num_data in num_data.iter() {
                for trial in 0..num_trials {
                    let context = lazy_format!(
                        "happy path, num centers = {}, num data = {}, trial = {}",
                        num_centers,
                        num_data,
                        trial,
                    );

                    println!("Currently = {}", context);

                    let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
                    let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();
                    let (data, expected) = create_dataset(schema, num_centers, num_data, &mut rng);

                    let mut called = HashSet::<(usize, usize)>::new();

                    // Direct method call.
                    table
                        .compress_batch(
                            data.as_view(),
                            |RowChunk { row, chunk }, value| -> Result<(), Infallible> {
                                assert!(value.is_okay());
                                // Ensure that this is the expected value.
                                assert_eq!(
                                value.unwrap() as usize,
                                expected[(row, chunk)],
                                "failed at (row = {row}, chunk = {chunk}). data = {:?}, context: {}",
                                &(data.row(row)[schema.at(chunk)]),
                                context,
                            );

                                // (A) Ensure that this combination of row and chunk hasn't
                                //     been called before.
                                // (B) Record that this combination has been called.
                                assert!(
                                    called.insert((row, chunk)),
                                    "row {row} and chunk {chunk}, called multiple times. Context = {}",
                                    context,
                                );

                                Ok(())
                            },
                        )
                        .unwrap();

                    // Every (row, chunk) pair must have been visited exactly once.
                    assert_eq!(called.len(), num_data * schema.len());

                    // Trait Interface.
                    let mut output = views::Matrix::new(0, num_data, schema.len());
                    table
                        .compress_into(data.as_view(), output.as_mut_view())
                        .unwrap();

                    assert_eq!(output.nrows(), expected.nrows());
                    assert_eq!(output.ncols(), expected.ncols());
                    for row in 0..output.nrows() {
                        for col in 0..output.ncols() {
                            assert_eq!(
                                output[(row, col)] as usize,
                                expected[(row, col)],
                                "failed on row {}, col {}. Context = {}",
                                row,
                                col,
                                context,
                            );
                        }
                    }

                    // Trait interface - single step.
                    let mut output = vec![0; schema.len()];
                    for (i, (row, expected)) in
                        std::iter::zip(data.row_iter(), expected.row_iter()).enumerate()
                    {
                        table.compress_into(row, output.as_mut_slice()).unwrap();
                        for (d, (o, e)) in
                            std::iter::zip(output.iter(), expected.iter()).enumerate()
                        {
                            assert_eq!(
                                *o as usize, *e,
                                "failed on row {}, col {}. Context = {}",
                                i, d, context
                            );
                        }
                    }
                }
            }
        }
    }
726
727    ////////////////////
728    // Error Handling //
729    ////////////////////
730
731    #[test]
732    #[should_panic(expected = "schema expects 4 dimensions but data has 5")]
733    fn panic_on_dim_mismatch() {
734        let offsets = [0, 4];
735        let data: Vec<f32> = vec![0.0; 5];
736
737        let schema = ChunkOffsetsView::new(&offsets).unwrap();
738        let (pivots, offsets) = create_pivot_tables(schema.to_owned(), 3);
739        let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();
740
741        // should panic
742        let _ = table.compress_batch(
743            views::MatrixView::try_from(data.as_slice(), 1, 5).unwrap(),
744            |_, _| -> Result<(), Infallible> { panic!("this shouldn't be called") },
745        );
746    }
747
    /// Test delegate error carrying a sentinel value used to verify propagation.
    #[derive(Error, Debug)]
    #[error("compression delegate error with {0}")]
    struct DelegateError(u64);
751
752    // The strategy here is to construct a delegate that returns an error on a specific
753    // row and chunk. We then make sure that the error is propagated successfully along
754    // with the recorded row and chunk.
    #[test]
    fn test_delegate_error_propagation() {
        let offsets: Vec<usize> = vec![0, 1, 7];
        let schema = ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = StdRng::seed_from_u64(0xc35a90da17fafa2a);

        let num_centers = 3;
        let num_data = 7;

        let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
        let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();
        let (data, _) = create_dataset(schema, num_centers, num_data, &mut rng);

        let data_view =
            views::MatrixView::try_from(data.as_slice(), num_data, schema.dim()).unwrap();
        let distribution = rand_distr::StandardUniform {};

        // Trigger a delegate failure at every (row, chunk) position in turn.
        for row in 0..data_view.nrows() {
            for chunk in 0..schema.len() {
                let context = lazy_format!("row = {row}, chunk = {chunk}");

                // Generate a random number for the delegate error.
                let value: u64 = rng.sample(distribution);

                let result = table.compress_batch(
                    data_view,
                    |RowChunk {
                         row: this_row,
                         chunk: this_chunk,
                     },
                     _| {
                        if this_row == row && this_chunk == chunk {
                            Err(DelegateError(value))
                        } else {
                            Ok(())
                        }
                    },
                );
                assert!(result.is_err(), "{}", context);

                // The wrapped error must embed the delegate's message and the position.
                let message = result.unwrap_err().to_string();
                assert!(
                    message.contains(&format!("{}", DelegateError(value))),
                    "{}",
                    context
                );
                assert!(message.contains("delegate returned"));
                assert!(message.contains(&format!("when processing row {row} and chunk {chunk}")));
            }
        }
    }
806
807    #[test]
808    #[cfg(not(miri))]
809    fn test_table_single_compression_errors() {
810        check_pqtable_single_compression_errors(
811            &|pivots: views::Matrix<f32>, offsets| {
812                TransposedTable::from_parts(pivots.as_view(), offsets).unwrap()
813            },
814            &"TranposedTable",
815        )
816    }
817
818    #[test]
819    #[cfg(not(miri))]
820    fn test_table_batch_compression_errors() {
821        check_pqtable_batch_compression_errors(
822            &|pivots: views::Matrix<f32>, offsets| {
823                TransposedTable::from_parts(pivots.as_view(), offsets).unwrap()
824            },
825            &"TranposedTable",
826        )
827    }
828
829    /////////////////////////
830    // Test `process_into` //
831    /////////////////////////
832
    /// Shared driver for `process_into` correctness checks.
    ///
    /// For each of `num_trials` rounds, builds a random schema with `num_chunks`
    /// chunks, fills a pivot table with `num_centers` centers of integer-valued
    /// floats, and verifies that `process_into` reproduces — for both `InnerProduct`
    /// and `SquaredL2` — the per-(chunk, center) partial distances computed directly
    /// from the raw pivot rows.
    ///
    /// NOTE(review): the order in which values are drawn from `rng` (chunk sizes,
    /// then pivots, then query) determines the generated data — do not reorder the
    /// sampling statements.
    fn test_process_into_impl(
        num_chunks: usize,
        num_centers: usize,
        num_trials: usize,
        rng: &mut StdRng,
    ) {
        // Choose the chunk size randomly from this distribution. Keep choosing chunks
        // sizes until the desired `num_chunks` is reached.
        //
        // The sum of chunk sizes gives the dimensionality.
        let chunk_size_distribution = Uniform::<usize>::new(1, 6).unwrap();

        // Use integer values for the value distribution to avoid dealing with floating
        // point rounding.
        let value_distribution = Uniform::<i32>::new(-10, 10).unwrap();

        for trial in 0..num_trials {
            // Offsets are a cumulative sum of chunk sizes, starting at 0.
            let mut offsets: Vec<usize> = vec![0];
            for _ in 0..num_chunks {
                let chunk_size = chunk_size_distribution.sample(rng);
                offsets.push(offsets.last().unwrap() + chunk_size);
            }

            let offsets = ChunkOffsets::new(offsets.into()).unwrap();
            let dim = offsets.dim();
            // Pivots: `num_centers` rows, one full-dimensional centroid per row.
            let pivots = views::Matrix::<f32>::new(
                views::Init(|| value_distribution.sample(rng) as f32),
                num_centers,
                dim,
            );

            let table = TransposedTable::from_parts(pivots.as_view(), offsets.clone()).unwrap();

            // Output holds one partial distance per (chunk, center) pair.
            let mut output = views::Matrix::<f32>::new(0.0, num_chunks, num_centers);
            let query: Vec<_> = (0..dim)
                .map(|_| value_distribution.sample(rng) as f32)
                .collect();

            // Inner Product
            table.process_into::<InnerProduct>(&query, output.as_mut_view());

            // Integer-valued inputs make these float distances exact, so `assert_eq!`
            // on f32 is safe here.
            for chunk in 0..num_chunks {
                let range = offsets.at(chunk);
                let query_chunk = &query[range.clone()];
                for center in 0..num_centers {
                    let data_chunk = &pivots.row(center)[range.clone()];
                    let expected: f32 = distance::InnerProduct::evaluate(query_chunk, data_chunk);
                    assert_eq!(
                        output[(chunk, center)],
                        expected,
                        "failed on (chunk, center) = ({}, {}) - offsets = {:?} - trial = {}",
                        chunk,
                        center,
                        offsets,
                        trial,
                    );
                }
            }

            // Squared L2
            table.process_into::<SquaredL2>(&query, output.as_mut_view());

            for chunk in 0..num_chunks {
                let range = offsets.at(chunk);
                let query_chunk = &query[range.clone()];
                for center in 0..num_centers {
                    let data_chunk = &pivots.row(center)[range.clone()];
                    let expected: f32 = distance::SquaredL2::evaluate(query_chunk, data_chunk);
                    assert_eq!(
                        output[(chunk, center)],
                        expected,
                        "failed on (chunk, center) = ({}, {}) - offsets = {:?} - trial = {}",
                        chunk,
                        center,
                        offsets,
                        trial,
                    );
                }
            }
        }
    }
914
915    #[test]
916    fn test_process_into() {
917        let mut rng = StdRng::seed_from_u64(0x0e3cf3ba4b27e7f8);
918        for num_chunks in 1..5 {
919            for num_centers in 1..48 {
920                test_process_into_impl(num_chunks, num_centers, 2, &mut rng);
921            }
922        }
923    }
924
925    #[test]
926    #[should_panic(expected = "query has the wrong number of dimensions")]
927    fn test_process_into_panics_query() {
928        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
929        let data = views::Matrix::<f32>::new(0.0, 3, 5);
930        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
931        assert_eq!(table.dim(), 5);
932
933        // query has the wrong length.
934        let query = vec![0.0; table.dim() - 1];
935        let mut partials = views::Matrix::new(0.0, table.nchunks(), table.ncenters());
936        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
937    }
938
939    #[test]
940    #[should_panic(expected = "output has the wrong number of rows")]
941    fn test_process_into_panics_partials_rows() {
942        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
943        let data = views::Matrix::<f32>::new(0.0, 3, 5);
944        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
945        assert_eq!(table.dim(), 5);
946
947        let query = vec![0.0; table.dim()];
948        // partials has the wrong numbers of rows.
949        let mut partials = views::Matrix::new(0.0, table.nchunks() - 1, table.ncenters());
950        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
951    }
952
953    #[test]
954    #[should_panic(expected = "output has the wrong number of columns")]
955    fn test_process_into_panics_partials_cols() {
956        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
957        let data = views::Matrix::<f32>::new(0.0, 3, 5);
958        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
959        assert_eq!(table.dim(), 5);
960
961        let query = vec![0.0; table.dim()];
962        // partials has the wrong numbers of rows.
963        let mut partials = views::Matrix::new(0.0, table.nchunks(), table.ncenters() - 1);
964        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
965    }
966}