Skip to main content

diskann_quantization/product/tables/transposed/
table.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use super::pivots;
9use crate::{
10    product::tables::basic::TableCompressionError,
11    traits::CompressInto,
12    views::{ChunkOffsets, ChunkOffsetsView},
13};
14use diskann_utils::{
15    strided,
16    views::{self, MatrixView, MutMatrixView},
17};
18use thiserror::Error;
19
/// A PQ table that stores the pivots for each chunk in a miniature block-transpose to
/// facilitate faster compression. The exact layout is not documented (as it is for the
/// `BasicTable`) because it may be subject to change.
///
/// The advantage of this table over the `BasicTable` is that this table uses a more
/// hardware friendly layout for the pivots, meaning that compression (particularly batch
/// compression) is much faster than the basic table, at the cost of a slightly higher
/// memory footprint.
///
/// # Invariants (Dev Docs)
///
/// * `pivots.len() == schema.len()`: The number of PQ chunks must agree.
/// * `pivots[i].dimension() == schema.at(i).len()` for all inbounds `i`.
/// * `pivots[i].total() == ncenters` for all inbounds `i`.
/// * `largest = max(schema.at(i).len() for i)`.
#[derive(Debug)]
pub struct TransposedTable {
    /// Block-transposed pivot storage, one `Chunk` per PQ chunk.
    pivots: Box<[pivots::Chunk]>,
    /// The chunk-boundary schema mapping full-precision dimensions to PQ chunks.
    offsets: crate::views::ChunkOffsets,
    /// The largest dimension in offsets.
    largest: usize,
    /// The number of centers in each block.
    ncenters: usize,
}
44
/// Errors that can occur when constructing a [`TransposedTable`] from raw parts.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TransposedTableError {
    /// The pivot matrix's column count disagrees with the dimensionality implied by
    /// the chunk offsets.
    #[error("pivots have {pivot_dim} dimensions while the offsets expect {offsets_dim}")]
    DimMismatch {
        pivot_dim: usize,
        offsets_dim: usize,
    },
    /// Building the block-transposed pivots for chunk `problem` (of `total`) failed.
    #[error("error constructing pivot {problem} of {total}")]
    PivotError {
        problem: usize,
        total: usize,
        source: pivots::ChunkConstructionError,
    },
}
60
impl TransposedTable {
    /// Construct a new `TransposedTable` from raw parts.
    ///
    /// # Error
    ///
    /// Returns an error if
    ///
    /// * `pivots.ncols() != offsets.dim()`: Pivots must have the dimensionality expected
    ///   by the offsets.
    ///
    /// * `pivots.nrows() == 0`: The pivot table cannot be empty.
    #[allow(clippy::expect_used)]
    pub fn from_parts(
        pivots: views::MatrixView<f32>,
        offsets: ChunkOffsets,
    ) -> Result<Self, TransposedTableError> {
        let pivot_dim = pivots.ncols();
        let offsets_dim = offsets.dim();
        if pivot_dim != offsets_dim {
            return Err(TransposedTableError::DimMismatch {
                pivot_dim,
                offsets_dim,
            });
        }

        let ncenters = pivots.nrows();
        let mut largest = 0;
        // Build one block-transposed `Chunk` per schema entry, tracking the widest chunk
        // so `compress_batch` can size its packing buffer appropriately.
        let pivots: Box<[_]> = (0..offsets.len())
            .map(|i| {
                let range = offsets.at(i);
                largest = largest.max(range.len());
                let view = strided::StridedView::try_shrink_from(
                    &(pivots.as_slice()[range.start..]),
                    pivots.nrows(),
                    range.len(),
                    offsets.dim(),
                )
                .expect(
                    "the check on `pivot_dim` and `offsets_dim` should cause this to never error",
                );
                pivots::Chunk::new(view).map_err(|source| TransposedTableError::PivotError {
                    problem: i,
                    total: offsets.len(),
                    source,
                })
            })
            .collect::<Result<Box<[_]>, TransposedTableError>>()?;

        debug_assert_eq!(pivots.len(), offsets.len());
        Ok(Self {
            pivots,
            offsets,
            largest,
            ncenters,
        })
    }

    /// Return the number of pivots in each PQ chunk.
    pub fn ncenters(&self) -> usize {
        self.ncenters
    }

    /// Return the number of PQ chunks.
    pub fn nchunks(&self) -> usize {
        self.offsets.len()
    }

    /// Return the dimensionality of the full-precision vectors associated with this table.
    pub fn dim(&self) -> usize {
        self.offsets.dim()
    }

    /// Return a view over the schema offsets.
    pub fn view_offsets(&self) -> ChunkOffsetsView<'_> {
        self.offsets.as_view()
    }

    /// Perform PQ compression on the dataset by mapping each chunk in `data` to its nearest
    /// neighbor in the corresponding entry in `chunks`.
    ///
    /// The index of the nearest neighbor is provided to `compression_delegate` along with its
    /// corresponding row in data and chunk index.
    ///
    /// Calls to `compression_delegate` may occur in any order.
    ///
    /// Visitor will be invoked for all rows in `0..data.nrows()` and all chunks in
    /// `0..schema.len()`.
    ///
    /// # Panics
    ///
    /// Panics under the following conditions:
    /// * `data.cols() != self.dim()`: The number of columns in the source dataset must match
    ///   the number of dimensions expected by the schema.
    #[allow(clippy::expect_used)]
    pub fn compress_batch<T, F, DelegateError>(
        &self,
        data: views::MatrixView<'_, T>,
        mut compression_delegate: F,
    ) -> Result<(), CompressError<DelegateError>>
    where
        T: Copy + Into<f32>,
        F: FnMut(RowChunk, pivots::CompressionResult) -> Result<(), DelegateError>,
        DelegateError: std::error::Error,
    {
        assert_eq!(
            data.ncols(),
            self.dim(),
            "schema expects {} dimensions but data has {}",
            self.dim(),
            data.ncols()
        );

        // The batch size expected by the compression micro-kernel.
        const SUB_BATCH_SIZE: usize = pivots::Chunk::batchsize();

        let dim_nonzero = self.offsets.dim_nonzero();
        // Scratch space sized for the widest chunk; rows are gathered (and converted to
        // `f32`) into this buffer before being handed to the micro-kernel.
        let mut packing_buffer: Box<[f32]> =
            (0..self.largest * SUB_BATCH_SIZE).map(|_| 0.0).collect();

        // Stride along the chunks in the source matrix to keep the associated `Chunk` in the
        // cache.
        let nrows = data.nrows();
        let ncols = data.ncols();
        let slice = data.as_slice();

        for (i, chunk) in self.pivots.iter().enumerate() {
            let range = self.offsets.at(i);
            // Zero-length chunks contribute nothing; skip them entirely.
            if let Some(chunk_dim) = NonZeroUsize::new(range.len()) {
                // Construct a view for the packing buffer for this chunk.
                let mut packing_view = views::MutMatrixView::try_from(
                    &mut packing_buffer[..SUB_BATCH_SIZE * chunk_dim.get()],
                    SUB_BATCH_SIZE,
                    chunk_dim.get(),
                )
                .expect("the packing buffer should have been sized correctly");

                for row_start in (0..nrows).step_by(SUB_BATCH_SIZE) {
                    let row_end = nrows.min(row_start + SUB_BATCH_SIZE);

                    // If this is a full batch, then use the batched micro-kernel.
                    if row_end - row_start == SUB_BATCH_SIZE {
                        // When computing the stop offset, don't try to adjust for the length
                        // of the underlying span.
                        //
                        // The control on our loop bounds mean we should never be in a situation
                        // where this would occur, so we'd rather hit the indexing panic early.
                        let mut linear_start = row_start * ncols + range.start;
                        packing_view.row_iter_mut().for_each(|row| {
                            pack(row, &slice[linear_start..linear_start + chunk_dim.get()]);
                            linear_start += dim_nonzero.get();
                        });

                        let result = chunk.find_closest_batch(packing_view.as_view().into());

                        // Invoke the delegate with the results.
                        // If the delegate returns an error, wrap that error inside a
                        // `CompressError` to add context and forward the error upward.
                        for (j, &r) in result.iter().enumerate() {
                            compression_delegate(
                                RowChunk {
                                    row: row_start + j,
                                    chunk: i,
                                },
                                r,
                            )
                            .map_err(|inner| CompressError {
                                inner,
                                row: row_start + j,
                                chunk: i,
                                nearest: r.into_inner(),
                            })?;
                        }
                    } else {
                        // Handle remainders one at a time.
                        for row in row_start..row_end {
                            let linear_start = row * ncols + range.start;
                            let linear_stop = linear_start + range.len();

                            // Pre-convert to f32.
                            let packed = &mut packing_view.row_mut(0);
                            pack(packed, &slice[linear_start..linear_stop]);
                            let result = chunk.find_closest(packed);

                            compression_delegate(RowChunk { row, chunk: i }, result).map_err(
                                |inner| CompressError {
                                    inner,
                                    row,
                                    chunk: i,
                                    nearest: result.into_inner(),
                                },
                            )?;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Compute the operation defined by `T` for each chunk in the query on all corresponding
    /// pivots for that chunk, storing the result in the output matrix.
    ///
    /// For example, this can be used to compute squared L2 distances between chunks of the
    /// query and the pivot table to create a fast run-time lookup table for these distances.
    ///
    /// This is currently implemented for the following operation types `T`:
    ///
    /// * `quantization::distances::SquaredL2`
    /// * `quantization::distances::InnerProduct`
    ///
    /// # Arguments
    ///
    /// * `query`: The query slice to process. Must have length `self.dim()`.
    /// * `partials`: Output matrix for the partial results. The result of the computation
    ///   of chunk `i` against pivot `j` will be stored into `pivots[(i, j)]`.
    ///
    ///   Must have `nrows = self.nchunks()` and `ncols = self.ncenters()`.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `query.len() != self.dim()`.
    /// * `partials.nrows() != self.nchunks()`.
    /// * `partials.ncols() != self.ncenters()`.
    pub fn process_into<T>(&self, query: &[f32], mut partials: MutMatrixView<'_, f32>)
    where
        T: pivots::ProcessInto,
    {
        // Check Requirements
        assert_eq!(
            query.len(),
            self.dim(),
            "query has the wrong number of dimensions"
        );
        assert_eq!(
            partials.ncols(),
            self.ncenters(),
            "output has the wrong number of columns"
        );
        assert_eq!(
            partials.nrows(),
            self.nchunks(),
            "output has the wrong number of rows"
        );

        // Loop over each chunk.
        std::iter::zip(self.pivots.iter(), partials.row_iter_mut())
            .enumerate()
            .for_each(|(i, (pivot, out))| {
                let range = self.offsets.at(i);
                T::process_into(pivot, &query[range], out);
            });
    }
}
315
/// Row and chunk indexes provided to the `compression_delegate` argument of `compress` to
/// describe the position of the nearest neighbor being provided.
pub struct RowChunk {
    /// Row index into the source data matrix.
    row: usize,
    /// Index of the PQ chunk within the offsets schema.
    chunk: usize,
}
322
/// Error returned by `compress_batch` when the user-supplied compression delegate fails,
/// recording where in the batch the failure occurred.
#[derive(Error, Debug)]
#[error(
    "compression delegate returned \"{inner}\" when processing row {row} and chunk \
    {chunk} with nearest center {nearest}"
)]
pub struct CompressError<DelegateError: std::error::Error> {
    /// The error produced by the delegate.
    inner: DelegateError,
    /// Row being processed when the delegate failed.
    row: usize,
    /// PQ chunk being processed when the delegate failed.
    chunk: usize,
    /// Index of the nearest center that was passed to the delegate.
    nearest: u32,
}
334
/// Copy `src` into `dst`, converting each element to `f32` along the way.
///
/// The two slices must have the same length (checked only in debug builds).
#[inline(always)]
fn pack<T>(dst: &mut [f32], src: &[T])
where
    T: Copy + Into<f32>,
{
    debug_assert_eq!(dst.len(), src.len());
    for (out, &value) in dst.iter_mut().zip(src) {
        *out = value.into();
    }
}
343
344///////////////
345// Coompress //
346///////////////
347
impl CompressInto<&[f32], &mut [u8]> for TransposedTable {
    type Error = TableCompressionError;
    type Output = ();

    /// Compress the full-precision vector `from` into the PQ byte buffer `to`.
    ///
    /// Compression is performed by partitioning `from` into chunks according to the offsets
    /// schema in the table and then finding the closest pivot according to the L2 distance.
    ///
    /// The final compressed value is the index of the closest pivot.
    ///
    /// # Errors
    ///
    /// Returns errors under the following conditions:
    ///
    /// * `self.ncenters() > 256`: If the number of centers exceeds 256, then it cannot be
    ///   guaranteed that the index of the closest pivot for a chunk will fit losslessly in
    ///   an 8-bit integer.
    ///
    /// * `from.len() != self.dim()`: The full precision vector must have the dimensionality
    ///   expected by the compression.
    ///
    /// * `to.len() != self.nchunks()`: The PQ buffer must be sized appropriately.
    ///
    /// * If any chunk is sufficiently far from all centers that its distance becomes
    ///   infinity to all centers.
    ///
    /// # Allocates
    ///
    /// This function should not allocate when successful.
    ///
    /// # Parallelism
    ///
    /// This function is single-threaded.
    fn compress_into(&self, from: &[f32], to: &mut [u8]) -> Result<(), Self::Error> {
        // Validate all preconditions before writing any output bytes.
        if self.ncenters() > 256 {
            return Err(Self::Error::CannotCompressToByte(self.ncenters()));
        }
        if from.len() != self.dim() {
            return Err(Self::Error::InvalidInputDim(self.dim(), from.len()));
        }
        if to.len() != self.nchunks() {
            return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.len()));
        }

        // Each chunk contributes exactly one output byte: the index of its closest pivot.
        std::iter::zip(self.pivots.iter(), to.iter_mut())
            .enumerate()
            .try_for_each(|(i, (pivot, to))| {
                let range = self.offsets.at(i);
                let result = pivot.find_closest(&from[range]);
                result.map(
                    |v| *to = v as u8, // conversion guaranteed to be lossless
                    || Self::Error::InfinityOrNaN(i),
                )
            })
    }
}
405
/// Errors returned by the batch `CompressInto` implementation for `TransposedTable`.
#[derive(Error, Debug)]
pub enum TableBatchCompressionError {
    /// More than 256 centers cannot be encoded losslessly in a single byte.
    #[error("num centers ({0}) must be at most 256 to compress into a byte vector")]
    CannotCompressToByte(usize),
    /// The input's column count does not match the table's dimensionality.
    #[error("invalid input len - expected {0}, got {1}")]
    InvalidInputDim(usize, usize),
    /// The output's column count does not match the number of PQ chunks.
    #[error("invalid PQ buffer len - expected {0}, got {1}")]
    InvalidOutputDim(usize, usize),
    /// The input and output matrices have differing row counts.
    #[error(
        "input and output must have the same number of rows - instead, got {0} and {1} \
         (respectively)"
    )]
    UnequalRows(usize, usize),
    /// A non-finite distance was produced while compressing the given chunk/row pair.
    #[error(
        "a value of infinity or NaN was observed while compressing chunk {0} of batch input {1}"
    )]
    InfinityOrNaN(usize, usize),
}
424
impl<T> CompressInto<MatrixView<'_, T>, MutMatrixView<'_, u8>> for TransposedTable
where
    T: Copy + Into<f32>,
{
    type Error = TableBatchCompressionError;
    type Output = ();

    /// Compress each full-precision row in `from` into the corresponding row in `to`.
    ///
    /// Compression is performed by partitioning `from` into chunks according to the offsets
    /// schema in the table and then finding the closest pivot according to the L2 distance.
    ///
    /// The final compressed value is the index of the closest pivot.
    ///
    /// # Errors
    ///
    /// Returns errors under the following conditions:
    ///
    /// * `self.ncenters() > 256`: If the number of centers exceeds 256, then it cannot be
    ///   guaranteed that the index of the closest pivot for a chunk will fit losslessly in
    ///   an 8-bit integer.
    ///
    /// * `from.ncols() != self.dim()`: The full precision data must have the dimensionality
    ///   expected by the compression.
    ///
    /// * `to.ncols() != self.nchunks()`: The PQ buffer must be sized appropriately.
    ///
    /// * `from.nrows() == to.nrows()`: The input and output buffers must have the same
    ///   number of elements.
    ///
    /// * If any chunk is sufficiently far from all centers that its distance becomes
    ///   infinity to all centers.
    ///
    /// # Allocates
    ///
    /// Allocates scratch memory proportional to the length of the largest chunk.
    ///
    /// # Parallelism
    ///
    /// This function is single-threaded.
    fn compress_into(
        &self,
        from: MatrixView<'_, T>,
        mut to: MutMatrixView<'_, u8>,
    ) -> Result<(), Self::Error> {
        // Validate all preconditions before writing any output.
        if self.ncenters() > 256 {
            return Err(Self::Error::CannotCompressToByte(self.ncenters()));
        }
        if from.ncols() != self.dim() {
            return Err(Self::Error::InvalidInputDim(self.dim(), from.ncols()));
        }
        if to.ncols() != self.nchunks() {
            return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.ncols()));
        }
        if from.nrows() != to.nrows() {
            return Err(Self::Error::UnequalRows(from.nrows(), to.nrows()));
        }

        // The `CompressionError` already has all the information we need, so we make the
        // `Delegate` error light-weight.
        #[derive(Debug, Error)]
        #[error("unreachable")]
        struct PassThrough;

        // Do the compression. We will reformat the NaN warning from `CompressionError`
        // if needed.
        let result = self.compress_batch(
            from,
            |RowChunk { row, chunk }, result| -> Result<(), PassThrough> {
                result.map(|v| to[(row, chunk)] = v as u8, || PassThrough)
            },
        );

        // The only delegate-path failure is a non-finite distance; translate it into this
        // implementation's error type.
        result.map_err(|err| Self::Error::InfinityOrNaN(err.chunk, err.row))
    }
}
501
502#[cfg(test)]
503mod test_compression {
504    use std::collections::HashSet;
505
506    use diskann_vector::{PureDistanceFunction, distance};
507    use rand::{
508        Rng, SeedableRng,
509        distr::{Distribution, StandardUniform, Uniform},
510        rngs::StdRng,
511    };
512
513    use super::*;
514    #[cfg(not(miri))]
515    use crate::product::tables::test::{
516        check_pqtable_batch_compression_errors, check_pqtable_single_compression_errors,
517    };
518    use crate::{
519        distances::{InnerProduct, SquaredL2},
520        error::format,
521        product::tables::test::{create_dataset, create_pivot_tables},
522    };
523    use diskann_utils::lazy_format;
524
525    //////////////////////////////
526    // Transposed Table Methods //
527    //////////////////////////////
528
529    // Test that an error is returned when the dimension between the pivots and offsets
530    // disagree.
531    #[test]
532    fn error_on_mismatch_dim() {
533        let pivots = views::Matrix::new(0.0, 3, 5);
534        let offsets = ChunkOffsets::new(Box::new([0, 1, 6])).unwrap();
535        let result = TransposedTable::from_parts(pivots.as_view(), offsets);
536        assert!(result.is_err(), "dimensions are not equal");
537        assert_eq!(
538            result.unwrap_err().to_string(),
539            "pivots have 5 dimensions while the offsets expect 6"
540        );
541    }
542
543    // Test that an error is returned when the dimension between the pivots and offsets
544    // disagree.
545    #[test]
546    fn error_on_empty() {
547        let pivots = views::Matrix::new(0.0, 0, 5);
548        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
549        let result = TransposedTable::from_parts(pivots.as_view(), offsets);
550        assert!(result.is_err(), "dimensions are not equal");
551
552        let expected = [
553            "error constructing pivot 0 of 2",
554            "    caused by: cannot construct a Chunk from a source with zero length",
555        ]
556        .join("\n");
557
558        assert_eq!(format(&result.unwrap_err()), expected,);
559    }
560
    // Construct tables over several dimensions and center counts, then verify the
    // accessors and that the transposed pivot layout round-trips the original values.
    #[test]
    fn basic_table() {
        let mut rng = StdRng::seed_from_u64(0xd96bac968083ec29);
        for dim in [5, 10, 12] {
            // Sweep over enough totals to ensure the inner chunks have a non-trivial layout.
            for total in [1, 2, 3, 7, 8, 9, 10] {
                let pivots = views::Matrix::new(
                    views::Init(|| -> f32 { StandardUniform {}.sample(&mut rng) }),
                    total,
                    dim,
                );
                let offsets = ChunkOffsets::new(Box::new([0, 1, 3, dim])).unwrap();
                let table = TransposedTable::from_parts(pivots.as_view(), offsets.clone()).unwrap();

                assert_eq!(table.ncenters(), total);
                assert_eq!(table.nchunks(), offsets.len());
                assert_eq!(table.dim(), offsets.dim());

                // This kind of looks into the guts of this data structure, but is an extra
                // check that the plumbing was performed properly.
                for chunk in 0..offsets.len() {
                    let range = offsets.at(chunk);
                    let pivot = &table.pivots[chunk];
                    for row in 0..total {
                        let r = &pivots.row(row)[range.clone()];
                        for (col, expected) in r.iter().enumerate() {
                            assert_eq!(pivot.get(row, col), *expected);
                        }
                    }
                }

                assert_eq!(table.view_offsets(), offsets.as_view());
            }
        }
    }
596
597    /////////////////
598    // Compression //
599    /////////////////
600
    /// Delegate error type for tests whose compression delegate can never fail.
    #[derive(Error, Debug)]
    #[error("unreachable reached")]
    struct Infallible;
604
    // End-to-end check of all three compression entry points (direct `compress_batch`,
    // the batch trait, and the single-vector trait) against a dataset whose nearest
    // centers are known by construction.
    #[test]
    fn test_happy_path() {
        // Feed in chunks of dimension 1, 2, 3, ... 16.
        //
        // If we're using MIRI, max out at 7 dimensions.
        let offsets: Vec<usize> = if cfg!(miri) {
            vec![0, 1, 3, 6, 10, 15, 21, 28, 36]
        } else {
            vec![
                0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, 136,
            ]
        };

        let schema = ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = StdRng::seed_from_u64(0x88e3d3366501ad6c);

        // Dataset sizes straddle the micro-kernel sub-batch boundary so both the batched
        // and the one-at-a-time remainder paths are exercised.
        let num_data = if cfg!(miri) {
            vec![0, 7, 8, 10]
        } else {
            vec![0, 1, 2, 3, 4, 16, 17, 18, 19]
        };

        let num_trials = if cfg!(miri) { 1 } else { 10 };

        // Strategically pick `num_centers` so we cover the corner cases in lower-level
        // handling.
        //
        // This includes:
        // * Full blocks with an even number of total blocks (num centers = 16);
        // * Full blocks with an odd number of total blocks (num centers = 24);
        // * Partially full blocks with the last block even (num centers = 13);
        // * Partially full blocks with the last block odd (num centers = 17);
        for &num_centers in [16, 24, 13, 17].iter() {
            for &num_data in num_data.iter() {
                for trial in 0..num_trials {
                    let context = lazy_format!(
                        "happy path, num centers = {}, num data = {}, trial = {}",
                        num_centers,
                        num_data,
                        trial,
                    );

                    println!("Currently = {}", context);

                    let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
                    let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();
                    let (data, expected) = create_dataset(schema, num_centers, num_data, &mut rng);

                    let mut called = HashSet::<(usize, usize)>::new();

                    // Direct method call.
                    table
                        .compress_batch(
                            data.as_view(),
                            |RowChunk { row, chunk }, value| -> Result<(), Infallible> {
                                assert!(value.is_okay());
                                // Ensure that this is the expected value.
                                assert_eq!(
                                value.unwrap() as usize,
                                expected[(row, chunk)],
                                "failed at (row = {row}, chunk = {chunk}). data = {:?}, context: {}",
                                &(data.row(row)[schema.at(chunk)]),
                                context,
                            );

                                // (A) Ensure that this combination of row and chunk hasn't
                                //     been called before.
                                // (B) Record that this combination has been called.
                                assert!(
                                    called.insert((row, chunk)),
                                    "row {row} and chunk {chunk}, called multiple times. Context = {}",
                                    context,
                                );

                                Ok(())
                            },
                        )
                        .unwrap();

                    assert_eq!(called.len(), num_data * schema.len());

                    // Trait Interface.
                    let mut output = views::Matrix::new(0, num_data, schema.len());
                    table
                        .compress_into(data.as_view(), output.as_mut_view())
                        .unwrap();

                    assert_eq!(output.nrows(), expected.nrows());
                    assert_eq!(output.ncols(), expected.ncols());
                    for row in 0..output.nrows() {
                        for col in 0..output.ncols() {
                            assert_eq!(
                                output[(row, col)] as usize,
                                expected[(row, col)],
                                "failed on row {}, col {}. Context = {}",
                                row,
                                col,
                                context,
                            );
                        }
                    }

                    // Trait interface - single step.
                    let mut output = vec![0; schema.len()];
                    for (i, (row, expected)) in
                        std::iter::zip(data.row_iter(), expected.row_iter()).enumerate()
                    {
                        table.compress_into(row, output.as_mut_slice()).unwrap();
                        for (d, (o, e)) in
                            std::iter::zip(output.iter(), expected.iter()).enumerate()
                        {
                            assert_eq!(
                                *o as usize, *e,
                                "failed on row {}, col {}. Context = {}",
                                i, d, context
                            );
                        }
                    }
                }
            }
        }
    }
727
728    ////////////////////
729    // Error Handling //
730    ////////////////////
731
    #[test]
    #[should_panic(expected = "schema expects 4 dimensions but data has 5")]
    fn panic_on_dim_mismatch() {
        // The schema expects 4 dimensions while the data row below has 5, so
        // `compress_batch` must panic before invoking the delegate.
        let offsets = [0, 4];
        let data: Vec<f32> = vec![0.0; 5];

        let schema = ChunkOffsetsView::new(&offsets).unwrap();
        let (pivots, offsets) = create_pivot_tables(schema.to_owned(), 3);
        let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();

        // should panic
        let _ = table.compress_batch(
            views::MatrixView::try_from(data.as_slice(), 1, 5).unwrap(),
            |_, _| -> Result<(), Infallible> { panic!("this shouldn't be called") },
        );
    }
748
    /// Delegate error carrying an arbitrary payload, used to verify error propagation.
    #[derive(Error, Debug)]
    #[error("compression delegate error with {0}")]
    struct DelegateError(u64);
752
    // The strategy here is to construct a delegate that returns an error on a specific
    // row and chunk. We then make sure that the error is propagated successfully along
    // with the recorded row and chunk.
    #[test]
    fn test_delegate_error_propagation() {
        let offsets: Vec<usize> = vec![0, 1, 7];
        let schema = ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = StdRng::seed_from_u64(0xc35a90da17fafa2a);

        let num_centers = 3;
        let num_data = 7;

        let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
        let table = TransposedTable::from_parts(pivots.as_view(), offsets).unwrap();
        let (data, _) = create_dataset(schema, num_centers, num_data, &mut rng);

        let data_view =
            views::MatrixView::try_from(data.as_slice(), num_data, schema.dim()).unwrap();
        let distribution = rand_distr::StandardUniform {};

        // Fail at every (row, chunk) position in turn so propagation is checked from
        // both the batched and the remainder code paths.
        for row in 0..data_view.nrows() {
            for chunk in 0..schema.len() {
                let context = lazy_format!("row = {row}, chunk = {chunk}");

                // Generate a random number for the delegate error.
                let value: u64 = rng.sample(distribution);

                let result = table.compress_batch(
                    data_view,
                    |RowChunk {
                         row: this_row,
                         chunk: this_chunk,
                     },
                     _| {
                        if this_row == row && this_chunk == chunk {
                            Err(DelegateError(value))
                        } else {
                            Ok(())
                        }
                    },
                );
                assert!(result.is_err(), "{}", context);

                // The propagated error message must embed the delegate's message and the
                // failing position.
                let message = result.unwrap_err().to_string();
                assert!(
                    message.contains(&format!("{}", DelegateError(value))),
                    "{}",
                    context
                );
                assert!(message.contains("delegate returned"));
                assert!(message.contains(&format!("when processing row {row} and chunk {chunk}")));
            }
        }
    }
807
808    #[test]
809    #[cfg(not(miri))]
810    fn test_table_single_compression_errors() {
811        check_pqtable_single_compression_errors(
812            &|pivots: views::Matrix<f32>, offsets| {
813                TransposedTable::from_parts(pivots.as_view(), offsets).unwrap()
814            },
815            &"TranposedTable",
816        )
817    }
818
819    #[test]
820    #[cfg(not(miri))]
821    fn test_table_batch_compression_errors() {
822        check_pqtable_batch_compression_errors(
823            &|pivots: views::Matrix<f32>, offsets| {
824                TransposedTable::from_parts(pivots.as_view(), offsets).unwrap()
825            },
826            &"TranposedTable",
827        )
828    }
829
830    /////////////////////////
831    // Test `process_into` //
832    /////////////////////////
833
    /// Drives `process_into` for both supported distance metrics against a
    /// randomly-generated pivot table and query, verifying each
    /// `(chunk, center)` partial against a directly-computed reference value.
    fn test_process_into_impl(
        num_chunks: usize,
        num_centers: usize,
        num_trials: usize,
        rng: &mut StdRng,
    ) {
        // Choose the chunk size randomly from this distribution. Keep choosing chunk
        // sizes until the desired `num_chunks` is reached.
        //
        // The sum of chunk sizes gives the dimensionality.
        let chunk_size_distribution = Uniform::<usize>::new(1, 6).unwrap();

        // Use integer values for the value distribution to avoid dealing with floating
        // point rounding.
        let value_distribution = Uniform::<i32>::new(-10, 10).unwrap();

        for trial in 0..num_trials {
            // Cumulative offsets: entry `i` is the starting dimension of chunk `i`.
            let mut offsets: Vec<usize> = vec![0];
            for _ in 0..num_chunks {
                let chunk_size = chunk_size_distribution.sample(rng);
                offsets.push(offsets.last().unwrap() + chunk_size);
            }

            let offsets = ChunkOffsets::new(offsets.into()).unwrap();
            let dim = offsets.dim();
            // Random pivot matrix: `num_centers` rows by `dim` columns.
            let pivots = views::Matrix::<f32>::new(
                views::Init(|| value_distribution.sample(rng) as f32),
                num_centers,
                dim,
            );

            let table = TransposedTable::from_parts(pivots.as_view(), offsets.clone()).unwrap();

            // `output` receives one partial distance per (chunk, center) pair.
            let mut output = views::Matrix::<f32>::new(0.0, num_chunks, num_centers);
            let query: Vec<_> = (0..dim)
                .map(|_| value_distribution.sample(rng) as f32)
                .collect();

            // Inner Product
            table.process_into::<InnerProduct>(&query, output.as_mut_view());

            for chunk in 0..num_chunks {
                let range = offsets.at(chunk);
                let query_chunk = &query[range.clone()];
                for center in 0..num_centers {
                    let data_chunk = &pivots.row(center)[range.clone()];
                    // Reference value computed directly on the chunk slices.
                    let expected: f32 = distance::InnerProduct::evaluate(query_chunk, data_chunk);
                    assert_eq!(
                        output[(chunk, center)],
                        expected,
                        "failed on (chunk, center) = ({}, {}) - offsets = {:?} - trial = {}",
                        chunk,
                        center,
                        offsets,
                        trial,
                    );
                }
            }

            // Squared L2
            table.process_into::<SquaredL2>(&query, output.as_mut_view());

            for chunk in 0..num_chunks {
                let range = offsets.at(chunk);
                let query_chunk = &query[range.clone()];
                for center in 0..num_centers {
                    let data_chunk = &pivots.row(center)[range.clone()];
                    let expected: f32 = distance::SquaredL2::evaluate(query_chunk, data_chunk);
                    assert_eq!(
                        output[(chunk, center)],
                        expected,
                        "failed on (chunk, center) = ({}, {}) - offsets = {:?} - trial = {}",
                        chunk,
                        center,
                        offsets,
                        trial,
                    );
                }
            }
        }
    }
915
916    #[test]
917    fn test_process_into() {
918        let mut rng = StdRng::seed_from_u64(0x0e3cf3ba4b27e7f8);
919
920        let num_chunks_range = if cfg!(miri) { 4..5 } else { 1..5 };
921
922        let num_centers: Vec<usize> = if cfg!(miri) {
923            vec![1, 7, 16, 33, 47]
924        } else {
925            (1..48).collect()
926        };
927
928        for num_chunks in num_chunks_range {
929            for num_centers in num_centers.clone() {
930                test_process_into_impl(num_chunks, num_centers, 2, &mut rng);
931            }
932        }
933    }
934
935    #[test]
936    #[should_panic(expected = "query has the wrong number of dimensions")]
937    fn test_process_into_panics_query() {
938        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
939        let data = views::Matrix::<f32>::new(0.0, 3, 5);
940        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
941        assert_eq!(table.dim(), 5);
942
943        // query has the wrong length.
944        let query = vec![0.0; table.dim() - 1];
945        let mut partials = views::Matrix::new(0.0, table.nchunks(), table.ncenters());
946        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
947    }
948
949    #[test]
950    #[should_panic(expected = "output has the wrong number of rows")]
951    fn test_process_into_panics_partials_rows() {
952        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
953        let data = views::Matrix::<f32>::new(0.0, 3, 5);
954        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
955        assert_eq!(table.dim(), 5);
956
957        let query = vec![0.0; table.dim()];
958        // partials has the wrong numbers of rows.
959        let mut partials = views::Matrix::new(0.0, table.nchunks() - 1, table.ncenters());
960        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
961    }
962
    #[test]
    #[should_panic(expected = "output has the wrong number of columns")]
    fn test_process_into_panics_partials_cols() {
        let offsets = ChunkOffsets::new(Box::new([0, 1, 5])).unwrap();
        let data = views::Matrix::<f32>::new(0.0, 3, 5);
        let table = TransposedTable::from_parts(data.as_view(), offsets).unwrap();
        assert_eq!(table.dim(), 5);

        let query = vec![0.0; table.dim()];
        // partials has the wrong number of columns (one fewer than `ncenters`);
        // the previous comment incorrectly said "rows".
        let mut partials = views::Matrix::new(0.0, table.nchunks(), table.ncenters() - 1);
        table.process_into::<InnerProduct>(&query, partials.as_mut_view());
    }
976}