// diskann_quantization/scalar/quantizer.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_vector::{MathematicalValue, PureDistanceFunction};
7use thiserror::Error;
8
9use super::{
10    bit_scale, inverse_bit_scale,
11    vectors::{
12        CompensatedCosineNormalized, CompensatedIP, CompensatedSquaredL2, Compensation,
13        MutCompensatedVectorRef,
14    },
15};
16use crate::{
17    AsFunctor, CompressInto,
18    bits::{MutBitSlice, PermutationStrategy, Representation, Unsigned},
19};
20
/// A central parameter collection for a scalar quantization schema.
///
/// # Example
///
/// A self-contained end-to-end example containing training, compression, and distance
/// computations is shown below.
///
/// ```rust
/// use diskann_quantization::{
///     AsFunctor, CompressInto,
///     distances,
///     num::Positive, bits::MutBitSlice,
///     scalar::{
///         self,
///         ScalarQuantizer,
///         train::ScalarQuantizationParameters,
///         CompensatedVector, MutCompensatedVectorRef,
///         CompensatedIP, CompensatedSquaredL2,
///     }
/// };
/// use diskann_utils::{views::Matrix, Reborrow, ReborrowMut};
/// use diskann_vector::DistanceFunction;
///
/// // A small training set consisting of two 5-dimensional vectors.
/// let mut data = Matrix::<f32>::new(0.0, 2, 5);
/// data.row_mut(0).copy_from_slice(&[-1.0, -1.0, -1.0, -1.0, -1.0]);
/// data.row_mut(1).copy_from_slice(&[1.0, 1.0, 1.0, 1.0, 1.0]);
///
/// let trainer = ScalarQuantizationParameters::new(Positive::new(1.0).unwrap());
/// let quantizer: ScalarQuantizer = trainer.train(data.as_view());
///
/// // The dimension of the quantizer is based on the dimension of the training data.
/// assert_eq!(quantizer.dim(), data.ncols());
///
/// // Compress the two input vectors.
/// // For one vector, we will use the "boxed" API. The other we will construct "manually".
///
/// // Boxed API
/// let mut c0 = CompensatedVector::<8>::new_boxed(data.ncols());
///
/// // Manual construction.
/// let mut buffer: Vec<u8> = vec![0; c0.vector().bytes()];
/// let mut compensation = scalar::Compensation(0.0);
/// let mut c1 = MutCompensatedVectorRef::new(
///     MutBitSlice::new(buffer.as_mut_slice(), data.ncols()).unwrap(),
///     &mut compensation
/// );
///
/// quantizer.compress_into(data.row(0), c0.reborrow_mut()).unwrap();
/// quantizer.compress_into(data.row(1), c1.reborrow_mut()).unwrap();
///
/// // Compute inner product.
/// let ip: CompensatedIP = quantizer.as_functor();
/// let distance: distances::Result<f32> = ip.evaluate_similarity(c0.reborrow(), c1.reborrow());
///
/// // The inner product computation to `f32` is the same as a SimilarityScore and is
/// // therefore the negative of the mathematical value.
/// assert!((distance.unwrap() - 5.0).abs() < 0.00001);
///
/// // Compute squared euclidean distance.
/// let l2: CompensatedSquaredL2 = quantizer.as_functor();
/// let distance: distances::Result<f32> = l2.evaluate_similarity(c0.reborrow(), c1.reborrow());
/// assert!((distance.unwrap() - 20.0).abs() < 0.00001);
/// ```
#[derive(Clone, Debug)]
pub struct ScalarQuantizer {
    /// The scaling parameter applied to each vector component.
    scale: f32,

    /// The amount each data point is shifted.
    ///
    /// This is computed as the dataset mean subtracted by the scaling parameter.
    /// The additional subtraction is needed to ensure we can map encodings into an unsigned
    /// integer.
    ///
    /// For datasets that have components with non-zero mean, this can greatly improve the
    /// quality of quantization by decreasing the observed dynamic range across all vector
    /// components, but this shift must be applied regardless of whether or not the mean
    /// is calculated.
    shift: Vec<f32>,

    /// The square norm of the shift.
    /// This quantity is useful when computing dot-products.
    shift_square_norm: f32,

    /// When processing queries, it may be beneficial to modify the query norm to match the
    /// dataset norm.
    ///
    /// This is only applicable when `InnerProduct` and `Cosine` are used, but serves to
    /// move the query into the dynamic range of the quantization.
    mean_norm: Option<f32>,
}
113
impl ScalarQuantizer {
    /// Construct a new scalar quantizer.
    ///
    /// The square norm of `shift` is computed eagerly so that inner-product
    /// compensation terms can later be derived without rescanning the shift vector.
    pub fn new(scale: f32, shift: Vec<f32>, mean_norm: Option<f32>) -> Self {
        let shift_square_norm: MathematicalValue<f32> =
            diskann_vector::distance::InnerProduct::evaluate(&*shift, &*shift);

        Self {
            scale,
            shift,
            shift_square_norm: shift_square_norm.into_inner(),
            mean_norm,
        }
    }

    /// Return the number of dimensions this ScalarQuantizer has been trained for.
    pub fn dim(&self) -> usize {
        self.shift.len()
    }

    /// Return the scaling coefficient.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Return the square norm of the dataset shift.
    pub fn shift_square_norm(&self) -> f32 {
        self.shift_square_norm
    }

    /// Return the per-dimension shift vector.
    ///
    /// This vector is meant to accomplish two goals:
    ///
    /// 1. Centers the data around the training dataset mean.
    /// 2. Offsets each dimension into a range that can be encoded in unsigned values.
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    /// Return the average norm of vectors in the training set, or `None` if it was
    /// not computed during training.
    pub fn mean_norm(&self) -> Option<f32> {
        self.mean_norm
    }

    /// Rescale the argument so it has the average norm of the training set.
    ///
    /// This can be used to help with compression queries that come from a different
    /// distribution when the norm of the query may be safely discarded for purposes of
    /// distance computations.
    ///
    /// # Errors
    ///
    /// This operation can fail if the mean norm was not computed during training, in
    /// which case [`MeanNormMissing`] is returned and `x` is left unmodified.
    pub fn rescale(&self, x: &mut [f32]) -> Result<(), MeanNormMissing> {
        match self.mean_norm {
            Some(mean_norm) => {
                rescale(x, mean_norm);
                Ok(())
            }
            None => Err(MeanNormMissing),
        }
    }

    /// A private compression method used by the implementations of `CompressInto`.
    ///
    /// This function works by shifting each dimension by `self.shift`, dividing by
    /// `self.scale`, and rounding to the nearest integer.
    ///
    /// Values that exceed the dynamic range of the quantization are clamped.
    ///
    /// To help with computing compensation coefficients, `callback` is included which
    /// is given the compressed value as a floating point number.
    ///
    /// # Errors
    ///
    /// Returns [`InputContainsNaN`] if any component of `from` is `NaN`. Note that the
    /// destination is still written in that case (NaN components encode as code 0).
    ///
    /// # Panics
    ///
    /// Panics if `from.len()` or `into.len()` differ from `self.dim()`.
    ///
    /// # Notes
    ///
    /// This function allows the `ScalarQuantizer` to compress to bit-widths other than the
    /// one assigned to the quantizer. Though we have to compute a correcting factor for the
    /// scale, this allows us to mix and match compression bit-widths.
    fn compress<const NBITS: usize, T, F, Perm>(
        &self,
        from: &[T],
        mut into: MutBitSlice<'_, NBITS, Unsigned, Perm>,
        mut callback: F,
    ) -> Result<(), InputContainsNaN>
    where
        T: Copy + Into<f32>,
        F: FnMut(f32, usize),
        Unsigned: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let len = self.shift.len();
        assert_eq!(from.len(), len);
        assert_eq!(into.len(), len);

        let domain = Unsigned::domain_const::<NBITS>();
        let min = *domain.start() as f32;
        let max = *domain.end() as f32;
        // Fold the bit-width correction factor into the reciprocal scale so each
        // component needs only a single multiply below.
        let inverse_scale = bit_scale::<NBITS>() / (self.scale);
        let mut nan_check = false;

        std::iter::zip(from.iter(), self.shift.iter())
            .enumerate()
            .for_each(|(i, (&f, &s))| {
                // Center and scale this component.
                // Then clamp to the unsigned dynamic range representable by the quantizer.
                let f: f32 = f.into();
                // Record NaNs and keep going; the error is reported once at the end so
                // the loop body stays branch-free.
                nan_check |= f.is_nan();

                let code: f32 = ((f - s) * inverse_scale).clamp(min, max).round();

                // Let the callback do some work on the final code if desired.
                callback(code, i);

                // SAFETY: We've checked that `into` and `from` have the same length.
                // The iterator will ensure the `i < into.len()`.
                //
                // By construction, `code` is in the domain of this `Unsigned` so the conversion
                // to `u8` is lossless.
                unsafe { into.set_unchecked(i, code as u8) };
            });

        if nan_check {
            Err(InputContainsNaN)
        } else {
            Ok(())
        }
    }

    /// Compare two `ScalarQuantizer` instances field by field.
    /// On success, returns `Ok(())`. On failure, returns `Err(SQComparisonError)`
    /// explaining which field differs.
    ///
    /// Comparisons are exact (`==` on floats, no tolerance); consequently, `NaN`-valued
    /// fields never compare as equal.
    pub fn compare(&self, other: &Self) -> Result<(), SQComparisonError> {
        if self.scale != other.scale {
            return Err(SQComparisonError::Scale(self.scale, other.scale));
        }

        if self.shift.len() != other.shift.len() {
            return Err(SQComparisonError::ShiftLength(
                self.shift.len(),
                other.shift.len(),
            ));
        }

        for (i, (a, b)) in self.shift.iter().zip(other.shift.iter()).enumerate() {
            if a != b {
                return Err(SQComparisonError::ShiftElement {
                    index: i,
                    a: *a,
                    b: *b,
                });
            }
        }

        if self.shift_square_norm != other.shift_square_norm {
            return Err(SQComparisonError::ShiftSquareNorm(
                self.shift_square_norm,
                other.shift_square_norm,
            ));
        }

        match (&self.mean_norm, &other.mean_norm) {
            (Some(a), Some(b)) => {
                if a != b {
                    return Err(SQComparisonError::MeanNorm(*a, *b));
                }
            }
            (None, None) => {
                // both are None, no issue
            }
            _ => {
                return Err(SQComparisonError::MeanNormPresence);
            }
        }

        Ok(())
    }
}
289
/// Error returned by [`ScalarQuantizer::rescale`] when the quantizer was constructed
/// without a mean norm.
#[derive(Debug, Error, Clone, Copy)]
#[error("mean norm is missing from the quantizer")]
#[non_exhaustive]
pub struct MeanNormMissing;
294
/// Error returned by compression routines when the input vector contains at least
/// one `NaN` component.
#[derive(Debug, Error, Clone, Copy)]
#[error("input contains NaN")]
#[non_exhaustive]
pub struct InputContainsNaN;
299
300fn rescale(x: &mut [f32], to_norm: f32) {
301    let norm_square: MathematicalValue<f32> =
302        diskann_vector::distance::InnerProduct::evaluate(&*x, &*x);
303    let norm = norm_square.into_inner().sqrt();
304    if norm == 0.0 {
305        return;
306    }
307
308    let scale = to_norm / norm;
309    x.iter_mut().for_each(|i| (*i) *= scale);
310}
311
312///////////////////////
313// Distance Functors //
314///////////////////////
315
316impl AsFunctor<CompensatedSquaredL2> for ScalarQuantizer {
317    fn as_functor(&self) -> CompensatedSquaredL2 {
318        let scale = self.scale();
319        CompensatedSquaredL2::new(scale * scale)
320    }
321}
322
323impl AsFunctor<CompensatedIP> for ScalarQuantizer {
324    fn as_functor(&self) -> CompensatedIP {
325        let scale = self.scale();
326        CompensatedIP::new(scale * scale, self.shift_square_norm())
327    }
328}
329
330impl AsFunctor<CompensatedCosineNormalized> for ScalarQuantizer {
331    fn as_functor(&self) -> CompensatedCosineNormalized {
332        let scale = self.scale();
333        CompensatedCosineNormalized::new(scale * scale)
334    }
335}
336
337/////////////////
338// Compression //
339/////////////////
340
341impl<const NBITS: usize, T, Perm> CompressInto<&[T], MutBitSlice<'_, NBITS, Unsigned, Perm>>
342    for ScalarQuantizer
343where
344    T: Copy + Into<f32>,
345    Unsigned: Representation<NBITS>,
346    Perm: PermutationStrategy<NBITS>,
347{
348    type Error = InputContainsNaN;
349
350    type Output = ();
351
352    /// Compress the input vector `from` into the bitslice `into`.
353    ///
354    /// This method *does not* compute compensation coefficients required for fast
355    /// inner product computations. If only L2 distances is desired, this method can be
356    /// slightly faster.
357    ///
358    /// # Error
359    ///
360    /// Returns an error if the input contains `NaN`.
361    ///
362    /// # Panics
363    ///
364    /// Panics if:
365    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
366    ///   dimensionality as the quantizer.
367    /// * `into.len() != self.dim()`: Compressed vector must have the same dimensionality
368    ///   as the quantizer.
369    fn compress_into(
370        &self,
371        from: &[T],
372        into: MutBitSlice<'_, NBITS, Unsigned, Perm>,
373    ) -> Result<(), Self::Error> {
374        // In this case, we don't need to pass anything special for `callback` because
375        // there is no extra computation needed.
376        ScalarQuantizer::compress(self, from, into, |_, _| {})
377    }
378}
379
380impl<const NBITS: usize, T, Perm> CompressInto<&[T], MutCompensatedVectorRef<'_, NBITS, Perm>>
381    for ScalarQuantizer
382where
383    T: Copy + Into<f32>,
384    Unsigned: Representation<NBITS>,
385    Perm: PermutationStrategy<NBITS>,
386{
387    type Error = InputContainsNaN;
388
389    type Output = ();
390
391    /// Compress the input vector `from` into the bitslice `into`.
392    ///
393    /// This method computes and stores the compensation coefficient required for fast
394    /// inner product computations.
395    ///
396    /// # Error
397    ///
398    /// Returns an error if the input contains `NaN`.
399    ///
400    /// # Panics
401    ///
402    /// Panics if:
403    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
404    ///   dimensionality as the quantizer.
405    /// * `into.len() != self.dim()`: Compressed vector must have the same dimensionality
406    ///   as the quantizer.
407    fn compress_into(
408        &self,
409        from: &[T],
410        mut into: MutCompensatedVectorRef<'_, NBITS, Perm>,
411    ) -> Result<(), Self::Error> {
412        // Compress the raw code.
413        //
414        // While doing so, also compute the dot prodcut between the encoded vector and
415        // the shift.
416        let mut dot: f32 = 0.0;
417        let result = ScalarQuantizer::compress(
418            self,
419            from,
420            into.vector_mut(),
421            // Compute the dot-product between `shift` and the compressed values.
422            |code: f32, index: usize| {
423                dot = code.mul_add(self.shift[index], dot);
424            },
425        );
426        into.set_meta(Compensation(
427            self.scale * inverse_bit_scale::<NBITS>() * dot,
428        ));
429        result
430    }
431}
432
/// Reason why two [`ScalarQuantizer`] instances failed a field-by-field comparison.
///
/// Returned by [`ScalarQuantizer::compare`]; each variant carries the differing
/// values (or lengths) from the two quantizers being compared.
#[derive(Debug, Error, PartialEq)]
pub enum SQComparisonError {
    /// The `scale` fields differ.
    #[error("Scale mismatch: {0} vs {1}")]
    Scale(f32, f32),

    /// The `shift` vectors have different lengths.
    #[error("Shift vector length mismatch: {0} vs {1}")]
    ShiftLength(usize, usize),

    /// The `shift` vectors differ at `index`.
    #[error("Shift element at index {index} mismatch: {a} vs {b}")]
    ShiftElement { index: usize, a: f32, b: f32 },

    /// The cached `shift_square_norm` fields differ.
    #[error("Shift square norm mismatch: {0} vs {1}")]
    ShiftSquareNorm(f32, f32),

    /// Both quantizers have a mean norm, but the values differ.
    #[error("Mean norm mismatch: {0} vs {1}")]
    MeanNorm(f32, f32),

    /// Exactly one of the two quantizers has a mean norm.
    #[error("Mean norm is missing in one quantizer but present in the other")]
    MeanNormPresence,
}
453
454///////////
455// Tests //
456///////////
457
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use diskann_utils::{ReborrowMut, views};

    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
        seq::SliceRandom,
    };
    use rand_distr::Normal;

    use super::*;
    use crate::{
        bits::BoxedBitSlice,
        scalar::{CompensatedVector, inverse_bit_scale},
    };

    /// Test `rescale`, both as a free function and through `ScalarQuantizer::rescale`.
    #[test]
    fn test_rescale() {
        let dim = 32;
        let to_norm = 25.0;

        let mut rng = StdRng::seed_from_u64(0x64e956ca2eb726ee);
        let distribution = Normal::<f32>::new(0.0, 16.0).unwrap();

        let mut v: Vec<f32> = distribution.sample_iter(&mut rng).take(dim).collect();
        let norm = v.iter().map(|&i| i * i).sum::<f32>().sqrt();

        rescale(&mut v, to_norm);
        let norm_next = v.iter().map(|&i| i * i).sum::<f32>().sqrt();
        let relative_error = (norm_next - to_norm).abs() / to_norm;

        assert!(
            relative_error <= 1.0e-7,
            "vector was not renormalized, expected {}, got {}, started with {}. Relative error: {}",
            to_norm,
            norm_next,
            norm,
            relative_error,
        );

        // Ensure that zero normed vectors are handled properly.
        let mut v: Vec<f32> = vec![0.0; dim];
        rescale(&mut v, 10.0);
        assert!(v.iter().all(|&i| i == 0.0));

        // Test in the context of a quantizer.
        let mut quantizer = ScalarQuantizer::new(0.0, vec![0.0; dim], Some(to_norm));

        let mut v: Vec<f32> = distribution.sample_iter(&mut rng).take(dim).collect();
        let norm = v.iter().map(|&i| i * i).sum::<f32>().sqrt();

        quantizer.rescale(&mut v).unwrap();
        let norm_next = v.iter().map(|&i| i * i).sum::<f32>().sqrt();
        let relative_error = (norm_next - to_norm).abs() / to_norm;

        assert!(
            relative_error <= 1.0e-7,
            "vector was not renormalized, expected {}, got {}, started with {}. Relative error: {}",
            to_norm,
            norm_next,
            norm,
            relative_error,
        );

        // Ensure that zero normed vectors are handled properly.
        let mut v: Vec<f32> = vec![0.0; dim];
        quantizer.rescale(&mut v).unwrap();
        assert!(v.iter().all(|&i| i == 0.0));

        // If the `mean_norm` is `None`, ensure we get an error.
        quantizer.mean_norm = None;
        let r = quantizer.rescale(&mut v);
        assert!(matches!(r, Err(MeanNormMissing)));
    }

    /// Quantizer tests
    ///
    /// This test works as follows: we create a simple quantizer with a known shift and scale.
    ///
    /// We then provide a range of offsets relative to the shift vector: some below and
    /// enough above to hit all the codes representable by the quantizer.
    ///
    /// These offsets are applied in different orders to each dimension.
    ///
    /// Our checks are this:
    ///
    /// * If a value is *lower* than the shift (i.e., a negative offset), its code should be 0.
    /// * If an offset is *above* `2^NBITS * shift`, its code should be `2^NBITS - 1`.
    /// * For any offset in between, the reconstructed offsets computed by `shift * code`
    ///   should have an error no more than `shift / 2.0`.
    fn test_nbit_quantizer<const NBITS: usize>(dim: usize, rng: &mut StdRng)
    where
        Unsigned: Representation<NBITS>,
        ScalarQuantizer: for<'a, 'b> CompressInto<&'a [f32], MutBitSlice<'b, NBITS, Unsigned>>
            + for<'a, 'b> CompressInto<&'a [f32], MutCompensatedVectorRef<'b, NBITS>>,
    {
        // Integer-valued shifts keep the expected codes exactly representable.
        let distribution = Uniform::new_inclusive::<i64, i64>(-10, 10).unwrap();
        let shift: Vec<f32> = (0..dim).map(|_| distribution.sample(rng) as f32).collect();
        let scale: f32 = 2.0;
        let mean_norm: f32 = 1.0;

        let quantizer =
            ScalarQuantizer::new(scale * bit_scale::<NBITS>(), shift.clone(), Some(mean_norm));

        assert_eq!(quantizer.dim(), dim);
        assert_eq!(quantizer.scale(), scale * bit_scale::<NBITS>());
        assert_eq!(quantizer.shift(), shift);
        assert_eq!(quantizer.mean_norm().unwrap(), mean_norm);

        let expected_shift_norm: f32 = shift.iter().map(|&i| i * i).sum();
        assert_eq!(quantizer.shift_square_norm(), expected_shift_norm);

        // Check conversion to distance functors.
        {
            let l2: CompensatedSquaredL2 = quantizer.as_functor();
            assert_eq!(l2.scale_squared, quantizer.scale() * quantizer.scale());

            let ip: CompensatedIP = quantizer.as_functor();
            assert_eq!(ip.scale_squared, quantizer.scale() * quantizer.scale());
            assert_eq!(ip.shift_square_norm, quantizer.shift_square_norm());
        }

        // Our strategy here is to generate a range of values for each dimension that should
        // enable all encodings. The reconstruction error for encoded vectors should be
        // within `scale / 2.0` (if the values are in range).
        let sample_points: f32 = 1.25 * (2_usize.pow(NBITS as u32) as f32) + 10.0;

        let min_encodable: f32 = 0.0;
        let max_encodable: f32 = (*Unsigned::domain_const::<NBITS>().end() as f32) * scale;

        // Create a shuffled matrix of offset values for each dimension. This ensures that
        // each dimension covers the target dynamic range, but in a different order so
        // we can rule out cross-coupling of dimensions.
        let dim_offsets: views::Matrix<f32> = {
            let range_min = -min_encodable - 3.0 * scale;
            let range_max = max_encodable + 3.0 * scale;
            let mut base: Vec<f32> = Vec::new();

            let step_size = (range_max - range_min) / sample_points;
            let mut i: f32 = range_min;
            while i < range_max {
                base.push(i);
                i += step_size;
            }
            // Push one more to have one point above `range_max`.
            base.push(i);

            let mut output = views::Matrix::new(0.0, base.len(), dim);
            (0..dim).for_each(|j| {
                base.shuffle(rng);
                for (i, b) in base.iter().enumerate() {
                    output[(i, j)] = *b;
                }
            });
            output
        };
        let ntests = dim_offsets.nrows();
        assert!(ntests as f32 >= sample_points);

        // Post-run checks to ensure coverage.
        let mut seen_below_min = false;
        let mut seen_above_max = false;
        let mut seen: Vec<HashSet<i64>> = (0..dim).map(|_| HashSet::new()).collect();

        // Reuse query space across tests.
        let mut query: Vec<f32> = vec![0.0; dim];
        for test_number in 0..ntests {
            let offsets = dim_offsets.row(test_number);
            query
                .iter_mut()
                .zip(std::iter::zip(shift.iter(), offsets.iter()))
                .for_each(|(q, (c, o))| {
                    *q = *c + *o;
                });

            // Test both `UnsignedBitSlice` and `CompensatedVector`.
            let mut bitslice = BoxedBitSlice::<NBITS, _>::new_boxed(dim);
            let mut compensated = CompensatedVector::<NBITS>::new_boxed(dim);

            quantizer
                .compress_into(&*query, bitslice.reborrow_mut())
                .unwrap();
            quantizer
                .compress_into(&*query, compensated.reborrow_mut())
                .unwrap();

            // Start checking!
            let domain = Unsigned::domain_const::<NBITS>();

            // We compute the expected compensation inline with the checking code.
            let mut computed_compensation: f32 = 0.0;
            for d in 0..dim {
                let code = bitslice.get(d).unwrap();
                computed_compensation = (code as f32).mul_add(shift[d], computed_compensation);

                // Mark this code as having been observed.
                seen[d].insert(code);

                let offset = offsets[d];
                if offset <= min_encodable {
                    assert_eq!(
                        code,
                        *domain.start(),
                        "expected values below threshold to be set to zero \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                    );
                    seen_below_min = true;
                } else if offset >= max_encodable {
                    assert_eq!(
                        code,
                        *domain.end(),
                        "expected values below threshold to be set to max value \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                    );
                    seen_above_max = true;
                } else {
                    // This value is encodable - make sure its reconstruction error is within
                    // our tolerance.
                    let reconstructed =
                        quantizer.scale() * (code as f32) * inverse_bit_scale::<NBITS>();
                    let error = (offset - reconstructed).abs();
                    assert!(
                        error <= scale / 2.0,
                        "failed reconstruction check: \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {} \
                         code = {}, reconstructed = {}, error = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                        code,
                        reconstructed,
                        error,
                    );
                }

                // Now that we have checked the reconstruction, ensure that the
                // `CompensatedVector` has the same code.
                assert_eq!(
                    compensated.vector().get(d).unwrap(),
                    code,
                    "compensated disagrees with bitslice"
                );
            }
            assert_eq!(scale * computed_compensation, compensated.meta().0);
        }

        // Check coverage.
        assert!(seen_below_min);
        assert!(seen_above_max);
        let num_codes = 2usize.pow(NBITS as u32);
        for (i, s) in seen.iter().enumerate() {
            assert_eq!(
                s.len(),
                num_codes,
                "dimension {} did not have full coverage",
                i
            );
        }

        // Check NaN detection.
        {
            let mut query: Vec<f32> = shift.clone();
            let mut bitslice = BoxedBitSlice::<NBITS, _>::new_boxed(query.len());
            let mut compensated = CompensatedVector::<NBITS>::new_boxed(query.len());
            for i in 0..query.len() {
                let last = query[i];
                query[i] = f32::NAN;

                let err = quantizer
                    .compress_into(&*query, bitslice.reborrow_mut())
                    .unwrap_err();
                assert_eq!(err.to_string(), "input contains NaN");

                let err = quantizer
                    .compress_into(&*query, compensated.reborrow_mut())
                    .unwrap_err();
                assert_eq!(err.to_string(), "input contains NaN");

                query[i] = last;
            }
        }
    }

    // Miri exercises the unsafe paths but runs slowly, so shrink the problem size there.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TEST_DIM: usize = 2;
        } else {
            const TEST_DIM: usize = 10;
        }
    }

    // Instantiate `test_nbit_quantizer` for a given bit-width with a fixed RNG seed.
    macro_rules! test_quantizer {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                test_nbit_quantizer::<$nbits>(TEST_DIM, &mut rng);
            }
        };
    }

    test_quantizer!(test_8bit_quantizer, 8, 0xb7b4c124102b9fb9);
    test_quantizer!(test_7bit_quantizer, 7, 0x86d19a821fe934d1);
    test_quantizer!(test_6bit_quantizer, 6, 0x0de9610f0b9be4f7);
    test_quantizer!(test_5bit_quantizer, 5, 0x605ed3e7ed775047);
    test_quantizer!(test_4bit_quantizer, 4, 0x9b66ace7090fa728);
    test_quantizer!(test_3bit_quantizer, 3, 0x0ce424ddc61ebdb0);
    test_quantizer!(test_2bit_quantizer, 2, 0x2ba8e5ef6415d4f0);
    test_quantizer!(test_1bit_quantizer, 1, 0xdcd8c10c4a407956);

    /// Construct a small fixed quantizer used by the `compare` tests below.
    fn base_quantizer() -> ScalarQuantizer {
        ScalarQuantizer {
            scale: 2.0,
            shift: vec![1.0, -1.0, 0.5],
            shift_square_norm: 1.0_f32 * 1.0 + (-1.0_f32) * (-1.0) + 0.5_f32 * 0.5,
            mean_norm: Some(4.13),
        }
    }

    #[test]
    fn test_compare_identical_returns_ok() {
        let q1 = base_quantizer();
        let q2 = base_quantizer();
        assert!(q1.compare(&q2).is_ok());
    }

    #[test]
    fn test_compare_scale_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.scale = 4.0;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::Scale(2.0, 4.0));
    }

    #[test]
    fn test_compare_shift_length_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift.push(0.0);
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(
            err,
            SQComparisonError::ShiftLength(q1.shift.len(), q2.shift.len())
        );
    }

    #[test]
    fn test_compare_shift_element_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift[2] = 0.0;
        let err = q1.compare(&q2).unwrap_err();
        match err {
            SQComparisonError::ShiftElement { index, a, b } => {
                assert_eq!(index, 2);
                assert_eq!(a, 0.5);
                assert_eq!(b, 0.0);
            }
            _ => panic!("Expected ShiftElementMismatch variant"),
        }
    }

    #[test]
    fn test_compare_shift_square_norm_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift_square_norm = 9.0;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::ShiftSquareNorm(2.25, 9.0));
    }

    #[test]
    fn test_compare_mean_norm_value_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.mean_norm = Some(1.0);
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNorm(4.13, 1.0));
    }

    #[test]
    fn test_compare_mean_norm_presence_mismatch_left_none() {
        let mut q1 = base_quantizer();
        let q2 = base_quantizer();
        q1.mean_norm = None;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNormPresence);
    }

    #[test]
    fn test_compare_mean_norm_presence_mismatch_right_none() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.mean_norm = None;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNormPresence);
    }
}