diskann_quantization/spherical/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use diskann_utils::{views::MatrixView, ReborrowMut};
9use diskann_vector::{
10    distance::InnerProduct, norm::FastL2Norm, MathematicalValue, Norm, PureDistanceFunction,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18    CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19    FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21#[cfg(feature = "flatbuffers")]
22use crate::{
23    algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
24};
25use crate::{
26    algorithms::{
27        heap::SliceHeap,
28        transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
29    },
30    alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
31    bits::{PermutationStrategy, Representation, Unsigned},
32    num::Positive,
33    utils::{compute_means_and_average_norm, compute_normalized_means, CannotBeEmpty},
34    AsFunctor, CompressIntoWith,
35};
36
37///////////////
38// Quantizer //
39///////////////
40
41#[derive(Debug)]
42#[cfg_attr(test, derive(PartialEq))]
43pub struct SphericalQuantizer<A = GlobalAllocator>
44where
45    A: Allocator,
46{
47    /// The offset to apply to each vector.
48    shift: Poly<[f32], A>,
49
50    /// The [`SphericalQuantizer`] supports several different strategies for performing the
51    /// distance-preserving transformation on dataset vectors, which may be applicable in
52    /// different scenarios.
53    ///
54    /// The different transformations may have restrictions on the number of supported dimensions.
55    /// While we will accept all non-zero input dimensions, the output dimension of a transform
56    /// may be higher or lower, depending on the configuration.
57    transform: Transform<A>,
58
59    /// The metric meant to be used by the quantizer.
60    metric: SupportedMetric,
61
62    /// When processing queries, it may be beneficial to modify the query norm to match the
63    /// dataset norm.
64    ///
65    /// This is only applicable when `InnerProduct` and `Cosine` are used, but serves to
66    /// move the query into the dynamic range of the quantization.
67    ///
68    /// You would think that the normalization step in RabitQ would mitigate this, but
69    /// that is not always right since range-adjustment happens before centering.
70    mean_norm: Positive<f32>,
71
72    /// To support 16-bit constants which have a limited dynamic range, we allow a
73    /// pre-scaling parameter that is multiplied to each value in compressed vectors.
74    ///
75    /// This allows to transparent handling of compressing integral data, which can
76    /// otherwise easily overflow `f16`.
77    pre_scale: Positive<f32>,
78}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82    A: Allocator,
83{
84    fn try_clone(&self) -> Result<Self, AllocatorError> {
85        Ok(Self {
86            shift: self.shift.try_clone()?,
87            transform: self.transform.try_clone()?,
88            metric: self.metric,
89            mean_norm: self.mean_norm,
90            pre_scale: self.pre_scale,
91        })
92    }
93}
94
95#[derive(Debug, Clone, Copy, Error)]
96#[non_exhaustive]
97pub enum TrainError {
98    #[error("data dim cannot be zero")]
99    DimCannotBeZero,
100    #[error("data cannot be empty")]
101    DataCannotBeEmpty,
102    #[error("pre-scale must be positive")]
103    PrescaleNotPositive,
104    #[error("norm must be positive")]
105    NormNotPositive,
106    #[error("computed norm contains infinity or NaN")]
107    NormNotFinite,
108    #[error("reciprocal norm contains infinity or NaN")]
109    ReciprocalNormNotFinite,
110    #[error(transparent)]
111    AllocatorError(#[from] AllocatorError),
112}
113
114impl<A> SphericalQuantizer<A>
115where
116    A: Allocator,
117{
118    /// Return the number dimensions this quantizer has been trained for.
119    pub fn input_dim(&self) -> usize {
120        self.shift.len()
121    }
122
123    /// Return the dimension of the post-transformed vector.
124    ///
125    /// Output storage vectors should use this dimension instead of `self.dim()` because
126    /// in general, the output dim **may** be different from the input dimension.
127    pub fn output_dim(&self) -> usize {
128        self.transform.output_dim()
129    }
130
131    /// Return the per-dimension shift vector.
132    ///
133    /// This vector is meant to accomplish two goals:
134    ///
135    /// 1. Centers the data around the training dataset mean.
136    /// 2. Offsets each dimension into a range that can be encoded in unsigned values.
137    pub fn shift(&self) -> &[f32] {
138        &self.shift
139    }
140
141    /// Return the approximate mean norm of the training data.
142    pub fn mean_norm(&self) -> Positive<f32> {
143        self.mean_norm
144    }
145
146    /// Return the pre-scaling parameter for data. This value is multiplied to every
147    /// compressed vector to adjust its dynamic range.
148    ///
149    /// A value of 1.0 means that no scaling is occurring.
150    pub fn pre_scale(&self) -> Positive<f32> {
151        self.pre_scale
152    }
153
154    /// Return a reference to the allocator used by this data structure.
155    pub fn allocator(&self) -> &A {
156        self.shift.allocator()
157    }
158
159    /// A lower-level constructor that accepts a centroid, mean norm, and pre-scale directly.
160    pub fn generate(
161        mut centroid: Poly<[f32], A>,
162        mean_norm: f32,
163        transform: TransformKind,
164        metric: SupportedMetric,
165        pre_scale: Option<f32>,
166        rng: &mut dyn RngCore,
167        allocator: A,
168    ) -> Result<Self, TrainError> {
169        let pre_scale = match pre_scale {
170            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
171            None => crate::num::POSITIVE_ONE_F32,
172        };
173
174        let dim = match NonZeroUsize::new(centroid.len()) {
175            Some(dim) => dim,
176            None => {
177                return Err(TrainError::DimCannotBeZero);
178            }
179        };
180
181        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;
182
183        // We passed in a 'rng' so `Transform::new` will not fail.
184        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
185            Ok(v) => v,
186            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
187            Err(NewTransformError::AllocatorError(err)) => {
188                return Err(TrainError::AllocatorError(err));
189            }
190        };
191
192        // Transform the centroid by the pre-scale.
193        centroid
194            .iter_mut()
195            .for_each(|v| *v *= pre_scale.into_inner());
196
197        Ok(SphericalQuantizer {
198            shift: centroid,
199            transform,
200            metric,
201            mean_norm,
202            pre_scale,
203        })
204    }
205
206    /// Return the metric used by this quantizer.
207    pub fn metric(&self) -> SupportedMetric {
208        self.metric
209    }
210
211    /// Construct a quantizer for vectors in the distribution of `data`.
212    ///
213    /// The type of distance-preserving transform to use is selected by the [`TransformKind`].
214    ///
215    /// Vectors compressed with this quantizer will be **metric specific** and optimized for
216    /// distance computations rather than reconstruction. This means that vectors compressed
217    /// targeting the inner-product distance will not return meaningful results if used for
218    /// L2 distance computations.
219    ///
220    /// Additionally, vectors compressed when using the [`SupportedMetric::Cosine`] distance
221    /// will be implicitly normalized before being compressed to enable better compression.
222    ///
223    /// If argument `pre_scale` is given, then all vectors compressed by this quantizer will
224    /// first be scaled by this value. Note that if given, `pre_scale` **must** be positive.
225    pub fn train<T, R>(
226        data: MatrixView<T>,
227        transform: TransformKind,
228        metric: SupportedMetric,
229        pre_scale: PreScale,
230        rng: &mut R,
231        allocator: A,
232    ) -> Result<Self, TrainError>
233    where
234        T: Copy + Into<f64> + Into<f32>,
235        R: Rng,
236    {
237        // An inner implementation erasing the type of the random number generator to
238        // cut down on excess monomorphization.
239        #[inline(never)]
240        fn train<T, A>(
241            data: MatrixView<T>,
242            transform: TransformKind,
243            metric: SupportedMetric,
244            pre_scale: PreScale,
245            rng: &mut dyn RngCore,
246            allocator: A,
247        ) -> Result<SphericalQuantizer<A>, TrainError>
248        where
249            T: Copy + Into<f64> + Into<f32>,
250            A: Allocator,
251        {
252            // This check is repeated in `Self::generate`, but we prefer to bail as early
253            // as possible if we detect an error.
254            if data.ncols() == 0 {
255                return Err(TrainError::DimCannotBeZero);
256            }
257
258            let (centroid, mean_norm) = match metric {
259                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
260                    compute_means_and_average_norm(data)
261                }
262                SupportedMetric::Cosine => (
263                    compute_normalized_means(data)
264                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
265                    1.0,
266                ),
267            };
268
269            let mean_norm = mean_norm as f32;
270
271            if mean_norm <= 0.0 {
272                return Err(TrainError::NormNotPositive);
273            }
274
275            if !mean_norm.is_finite() {
276                return Err(TrainError::NormNotFinite);
277            }
278
279            // Determining if (and how) the pre-scaling term will be calculated.
280            let pre_scale: Positive<f32> = match pre_scale {
281                PreScale::None => crate::num::POSITIVE_ONE_F32,
282                PreScale::Some(v) => v,
283                PreScale::ReciprocalMeanNorm => {
284                    // We've checked that `mean_norm` is both positive and finite.
285                    //
286                    // It's possible that when converted to `f32`,
287                    //
288                    // Taking the reciprocal is well defined. However, since the norms
289                    // and scales in `compute_means_and_average_norm` are done using `f64`,
290                    // it's possible that the computed `mean_norm` is subnormal, leading
291                    // to the reciprocal being infinity.
292                    let pre_scale = Positive::new(1.0 / mean_norm)
293                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;
294
295                    if !pre_scale.into_inner().is_finite() {
296                        return Err(TrainError::ReciprocalNormNotFinite);
297                    }
298
299                    pre_scale
300                }
301            };
302
303            // Allow the pre-scaling to take place inside `Self::generate`.
304            let centroid =
305                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;
306
307            SphericalQuantizer::generate(
308                centroid,
309                mean_norm,
310                transform,
311                metric,
312                Some(pre_scale.into_inner()),
313                rng,
314                allocator,
315            )
316        }
317
318        train(data, transform, metric, pre_scale, rng, allocator)
319    }
320
321    /// Rescale the argument `v` to be in the rough dynamic range of the training dataset.
322    pub fn rescale(&self, v: &mut [f32]) {
323        let norm = FastL2Norm.evaluate(&*v);
324        let m = self.mean_norm.into_inner() / norm;
325        v.iter_mut().for_each(|i| *i *= m);
326    }
327
328    /// Private helper function to do common data pre-processing.
329    ///
330    /// # Panics
331    ///
332    /// Panics if `data.len() != self.dim()`.
333    fn preprocess<'a>(
334        &self,
335        data: &[f32],
336        allocator: ScopedAllocator<'a>,
337    ) -> Result<Preprocessed<'a>, CompressionError> {
338        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");
339
340        // Fold in pre-scaling with the potential norm corretion for cosine.
341        //
342        // NOTE: When we're computing Cosine Similarity, we normalize the vector. As such,
343        // the `pre_scale` parameter become irrelvant since it just gets normalized away.
344        let scale = self.pre_scale.into_inner();
345        let mul: f32 = match self.metric {
346            SupportedMetric::Cosine => {
347                let norm: f32 = (FastL2Norm).evaluate(data);
348                if norm == 0.0 {
349                    1.0
350                } else {
351                    1.0 / norm
352                }
353            }
354            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
355        };
356
357        // Center the vector and compute the squared norm of the shifted vector.
358        let shifted = Poly::from_iter(
359            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
360            allocator,
361        )?;
362
363        let shifted_norm = FastL2Norm.evaluate(&*shifted);
364        if !shifted_norm.is_finite() {
365            return Err(CompressionError::InputContainsNaN);
366        }
367        let inner_product_with_centroid = match self.metric {
368            SupportedMetric::SquaredL2 => None,
369            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
370                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
371                Some(ip.into_inner())
372            }
373        };
374
375        Ok(Preprocessed {
376            shifted,
377            shifted_norm,
378            inner_product_with_centroid,
379        })
380    }
381}
382
383/// Pre-scaling selector for spherical quantization training. Pre-scaling adjusts the
384/// dynamic range of the data (usually decreasing it uniformly) to keep the correction terms
385/// within the range expressible by 16-bit floating point numbers.
386#[derive(Debug, Clone, Copy)]
387pub enum PreScale {
388    /// Do not use any pre-scaling.
389    None,
390    /// Pre-scale all data by the specified amount.
391    Some(Positive<f32>),
392    /// Heuristically estimate a pre-scaling parameter by using the inverse approximate
393    /// mean norm. This will nearly normalize in-distribution vectors.
394    ReciprocalMeanNorm,
395}
396
/// Errors encountered while reconstructing a [`SphericalQuantizer`] from its flatbuffer
/// representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The serialized transform could not be reconstructed.
    #[error(transparent)]
    TransformError(#[from] TransformError),

    /// The buffer's flatbuffer identifier was not recognized.
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,

    /// The centroid length does not match the transform's input dimension.
    #[error("transform length not equal to centroid")]
    DimMismatch,

    /// The stored mean norm is missing or not positive.
    #[error("norm is missing or is not positive")]
    MissingNorm,

    /// The stored pre-scale is missing or not positive.
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    /// The flatbuffer itself is invalid.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// The stored metric is not a supported metric.
    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    /// An allocation failed while unpacking.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
422
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Serialize `self` into `buf` as a [`spherical::SphericalQuantizer`] table and return
    /// the offset of the written table.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Nested objects (the centroid vector, then the transform) are written first;
        // the table referencing them is created last.
        let centroid = Some(buf.create_vector(&self.shift));
        let transform = Some(self.transform.pack(buf));

        let args = spherical::SphericalQuantizerArgs {
            centroid,
            transform,
            metric: self.metric.into(),
            mean_norm: self.mean_norm.into_inner(),
            pre_scale: self.pre_scale.into_inner(),
        };
        spherical::SphericalQuantizer::create(buf, &args)
    }

    /// Rebuild a quantizer from its serialized [`spherical::SphericalQuantizer`] form.
    ///
    /// Validation happens in this order: metric, centroid allocation, transform, the
    /// centroid/transform dimension agreement, the mean norm, and finally the pre-scale.
    /// The first failure is returned.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;
        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // The shift is applied before the transform, so their dimensions must agree.
        let dims_agree = shift.len() == transform.input_dim();
        if !dims_agree {
            return Err(DeserializationError::DimMismatch);
        }

        let mean_norm = Positive::new(proto.mean_norm())
            .map_err(|_| DeserializationError::MissingNorm)?;
        let pre_scale = Positive::new(proto.pre_scale())
            .map_err(|_| DeserializationError::PreScaleNotPositive)?;

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
492
493struct Preprocessed<'a> {
494    shifted: Poly<[f32], ScopedAllocator<'a>>,
495    shifted_norm: f32,
496    inner_product_with_centroid: Option<f32>,
497}
498
499impl Preprocessed<'_> {
500    /// Return the metric specific correction term as sumamrized below:
501    ///
502    /// * Inner Product: The inner product between the shifted vector and the centroid.
503    /// * Squared L2: The squared norm of the shifted vector.
504    fn metric_specific(&self) -> f32 {
505        match self.inner_product_with_centroid {
506            Some(ip) => ip,
507            None => self.shifted_norm * self.shifted_norm,
508        }
509    }
510}
511
512///////////////////////
513// Distance Functors //
514///////////////////////
515
516impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
517where
518    A: Allocator,
519{
520    fn as_functor(&self) -> CompensatedSquaredL2 {
521        CompensatedSquaredL2::new(self.output_dim())
522    }
523}
524
525impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
526where
527    A: Allocator,
528{
529    fn as_functor(&self) -> CompensatedIP {
530        CompensatedIP::new(&self.shift, self.output_dim())
531    }
532}
533
534impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
535where
536    A: Allocator,
537{
538    fn as_functor(&self) -> CompensatedCosine {
539        CompensatedCosine::new(self.as_functor())
540    }
541}
542
543/////////////////
544// Compression //
545/////////////////
546
547#[derive(Debug, Error, Clone, Copy, PartialEq)]
548#[non_exhaustive]
549pub enum CompressionError {
550    #[error("input contains NaN")]
551    InputContainsNaN,
552
553    #[error("expected source vector to have length {expected}")]
554    SourceDimensionMismatch { expected: usize },
555
556    #[error("expected destination vector to have length {expected}")]
557    DestinationDimensionMismatch { expected: usize },
558
559    #[error(
560        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
561    )]
562    EncodingError(#[from] DataMetaError),
563
564    #[error(transparent)]
565    AllocatorError(#[from] AllocatorError),
566}
567
568fn check_dims(
569    input: usize,
570    output: usize,
571    from: usize,
572    into: usize,
573) -> Result<(), CompressionError> {
574    if from != input {
575        return Err(CompressionError::SourceDimensionMismatch { expected: input });
576    }
577    if into != output {
578        return Err(CompressionError::DestinationDimensionMismatch { expected: output });
579    }
580    Ok(())
581}
582
583/// Helper trait to dispatch to a faster 1-bit implementation and use the slower
584/// maximum-cosine algorithm when more than 1 bit is used.
585trait FinishCompressing {
586    fn finish_compressing(
587        &mut self,
588        preprocessed: &Preprocessed<'_>,
589        transformed: &[f32],
590        transformed_norm: f32,
591        allocator: ScopedAllocator<'_>,
592    ) -> Result<(), CompressionError>;
593}
594
595impl FinishCompressing for DataMut<'_, 1> {
596    fn finish_compressing(
597        &mut self,
598        preprocessed: &Preprocessed<'_>,
599        transformed: &[f32],
600        transformed_norm: f32,
601        _: ScopedAllocator<'_>,
602    ) -> Result<(), CompressionError> {
603        // Compute signed quantized vector (-1 or 1)
604        // and also populate the unsigned bit representation in `into` output vector.
605        let mut quant_raw_inner_product = 0.0f32;
606        let mut bit_sum = 0u32;
607        transformed.iter().enumerate().for_each(|(i, &r)| {
608            let bit: u8 = if r > 0.0 { 1 } else { 0 };
609
610            quant_raw_inner_product += r.abs();
611            bit_sum += <u8 as Into<u32>>::into(bit);
612
613            // SAFETY: From check 1, we know that `i < into.len()`.
614            unsafe { self.vector_mut().set_unchecked(i, bit) };
615        });
616
617        // The value we just computed for `quant_raw_inner_product` is:
618        // ```
619        // Y = <x', x> * sqrt(D)                        [1]
620        // ```
621        // The inner product correction term is
622        // ```
623        //       2 |X|
624        // -----------------                            [2]
625        // <x', x> * sqrt(D)
626        // ```
627        // [1] substitutes directly into [2] and we get
628        // ```
629        // 2 |X|
630        // -----
631        //   Y
632        // ```
633        // Therefore, the inner product correction term is
634        // ```
635        // 2.0 * shifted_norm / quant_raw_inner_product
636        // ```
637        let inner_product_correction =
638            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
639        self.set_meta(DataMeta::new(
640            inner_product_correction,
641            preprocessed.metric_specific(),
642            bit_sum,
643        )?);
644        Ok(())
645    }
646}
647
648impl FinishCompressing for DataMut<'_, 2> {
649    fn finish_compressing(
650        &mut self,
651        preprocessed: &Preprocessed<'_>,
652        transformed: &[f32],
653        transformed_norm: f32,
654        allocator: ScopedAllocator<'_>,
655    ) -> Result<(), CompressionError> {
656        compress_via_maximum_cosine(
657            self.reborrow_mut(),
658            preprocessed,
659            transformed,
660            transformed_norm,
661            allocator,
662        )
663    }
664}
665
666impl FinishCompressing for DataMut<'_, 4> {
667    fn finish_compressing(
668        &mut self,
669        preprocessed: &Preprocessed<'_>,
670        transformed: &[f32],
671        transformed_norm: f32,
672        allocator: ScopedAllocator<'_>,
673    ) -> Result<(), CompressionError> {
674        compress_via_maximum_cosine(
675            self.reborrow_mut(),
676            preprocessed,
677            transformed,
678            transformed_norm,
679            allocator,
680        )
681    }
682}
683
684impl FinishCompressing for DataMut<'_, 8> {
685    fn finish_compressing(
686        &mut self,
687        preprocessed: &Preprocessed<'_>,
688        transformed: &[f32],
689        transformed_norm: f32,
690        allocator: ScopedAllocator<'_>,
691    ) -> Result<(), CompressionError> {
692        compress_via_maximum_cosine(
693            self.reborrow_mut(),
694            preprocessed,
695            transformed,
696            transformed_norm,
697            allocator,
698        )
699    }
700}
701
702//////////////////////
703// Data Compression //
704//////////////////////
705
706impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
707where
708    A: Allocator,
709{
710    type Error = CompressionError;
711
712    /// Compress the input vector `from` into the bitslice `into`.
713    ///
714    /// # Error
715    ///
716    /// Returns an error if
717    /// * The input contains `NaN`.
718    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
719    ///   dimensionality as the quantizer.
720    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
721    ///   dimensionality as the output of the distance-preserving transform. Importantely,
722    ///   this **may** be different than `self.dim()` and should be retrieved from
723    ///   `self.output_dim()`.
724    fn compress_into_with(
725        &self,
726        from: &[f32],
727        mut into: FullQueryMut<'_>,
728        allocator: ScopedAllocator<'_>,
729    ) -> Result<(), Self::Error> {
730        let input_dim = self.shift.len();
731        let output_dim = self.output_dim();
732        check_dims(input_dim, output_dim, from.len(), into.len())?;
733
734        let mut preprocessed = self.preprocess(from, allocator)?;
735
736        // If the preprocessed norm is zero, then we tried to compress the center directly.
737        // In this case, we can get the correct behavior by setting `into` to all zeros.
738        if preprocessed.shifted_norm == 0.0 {
739            into.vector_mut().fill(0.0);
740            *into.meta_mut() = Default::default();
741            return Ok(());
742        }
743
744        preprocessed
745            .shifted
746            .iter_mut()
747            .for_each(|v| *v /= preprocessed.shifted_norm);
748
749        // Transformation can fail due to OOM - we want to handle that gracefully.
750        //
751        // If the transformation fails because we provided the wrong sizes, that is a hard
752        // program bug.
753        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
754        match self
755            .transform
756            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
757        {
758            Ok(()) => {}
759            Err(TransformFailed::AllocatorError(err)) => {
760                return Err(CompressionError::AllocatorError(err))
761            }
762            Err(TransformFailed::SourceMismatch { .. })
763            | Err(TransformFailed::DestinationMismatch { .. }) => {
764                panic!(
765                    "The sizes of these arrays should already be checked - this is a logic error"
766                );
767            }
768        }
769
770        *into.meta_mut() = FullQueryMeta {
771            sum: into.vector().iter().sum::<f32>(),
772            shifted_norm: preprocessed.shifted_norm,
773            metric_specific: preprocessed.metric_specific(),
774        };
775        Ok(())
776    }
777}
778
779impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
780    for SphericalQuantizer<A>
781where
782    A: Allocator,
783    Unsigned: Representation<NBITS>,
784    for<'a> DataMut<'a, NBITS>: FinishCompressing,
785{
786    type Error = CompressionError;
787
788    /// Compress the input vector `from` into the bitslice `into`.
789    ///
790    /// # Error
791    ///
792    /// Returns an error if
793    /// * The input contains `NaN`.
794    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
795    ///   dimensionality as the quantizer.
796    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
797    ///   dimensionality as the output of the distance-preserving transform. Importantely,
798    ///   this **may** be different than `self.dim()` and should be retrieved from
799    ///   `self.output_dim()`.
800    fn compress_into_with(
801        &self,
802        from: &[f32],
803        mut into: DataMut<'_, NBITS>,
804        allocator: ScopedAllocator<'_>,
805    ) -> Result<(), Self::Error> {
806        let input_dim = self.shift.len();
807        let output_dim = self.output_dim();
808        check_dims(input_dim, output_dim, from.len(), into.len())?;
809
810        let mut preprocessed = self.preprocess(from, allocator)?;
811
812        if preprocessed.shifted_norm == 0.0 {
813            into.set_meta(DataMeta::default());
814            return Ok(());
815        }
816
817        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
818        preprocessed
819            .shifted
820            .iter_mut()
821            .for_each(|v| *v /= preprocessed.shifted_norm);
822
823        // Transformation can fail due to OOM - we want to handle that gracefully.
824        //
825        // If the transformation fails because we provided the wrong sizes, that is a hard
826        // program bug.
827        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
828        match self
829            .transform
830            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
831        {
832            Ok(()) => {}
833            Err(TransformFailed::AllocatorError(err)) => {
834                return Err(CompressionError::AllocatorError(err))
835            }
836            Err(TransformFailed::SourceMismatch { .. })
837            | Err(TransformFailed::DestinationMismatch { .. }) => {
838                panic!(
839                    "The sizes of these arrays should already be checked - this is a logic error"
840                );
841            }
842        }
843
844        let transformed_norm = if self.transform.preserves_norms() {
845            1.0
846        } else {
847            (FastL2Norm).evaluate(&*transformed)
848        };
849
850        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
851        Ok(())
852    }
853}
854
/// Compile-time conversion of the const generic bit width `NBITS` into a [`NonZeroUsize`].
struct AsNonZero<const NBITS: usize>;

impl<const NBITS: usize> AsNonZero<NBITS> {
    /// `NBITS` as a [`NonZeroUsize`].
    ///
    /// This is an associated constant, so the check runs during constant evaluation: using
    /// it with `NBITS == 0` is a compile-time error rather than a runtime panic.
    #[allow(
        clippy::panic,
        reason = "evaluated at compile time; a zero bit width is a build error"
    )]
    const NON_ZERO: NonZeroUsize = match NonZeroUsize::new(NBITS) {
        Some(bits) => bits,
        None => panic!("NBITS must be non-zero"),
    };
}
861
862fn compress_via_maximum_cosine<const NBITS: usize>(
863    mut data: DataMut<'_, NBITS>,
864    preprocessed: &Preprocessed<'_>,
865    transformed: &[f32],
866    transformed_norm: f32,
867    allocator: ScopedAllocator<'_>,
868) -> Result<(), CompressionError>
869where
870    Unsigned: Representation<NBITS>,
871{
872    assert_eq!(data.len(), transformed.len());
873
874    // Find the value we will use to multiply `transformed` to round it to the lattice
875    // element that has the maximum cosine-similarity.
876    let optimal_scale =
877        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;
878
879    let domain = Unsigned::domain_const::<NBITS>();
880    let min = *domain.start() as f32;
881    let max = *domain.end() as f32;
882    let offset = max / 2.0;
883
884    let mut self_inner_product = 0.0f32;
885    let mut bit_sum = 0u32;
886    for (i, t) in transformed.iter().enumerate() {
887        let v = (*t * optimal_scale + offset).clamp(min, max).round();
888        let dv = v - offset;
889        self_inner_product = dv.mul_add(*t, self_inner_product);
890
891        let v = v as u8;
892        bit_sum += <u8 as Into<u32>>::into(v);
893
894        // SAFETY: We have checked that `data.len() == transformed.len()`, so this access
895        // is in-bounds.
896        //
897        // Further, by construction, `v` is encodable by the `Unsigned`.
898        unsafe { data.vector_mut().set_unchecked(i, v) };
899    }
900
901    let shifted_norm = preprocessed.shifted_norm;
902    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
903    data.set_meta(DataMeta::new(
904        inner_product_correction,
905        preprocessed.metric_specific(),
906        bit_sum,
907    )?);
908    Ok(())
909}
910
// Heap entry used by `maximize_cosine_similarity`. This struct does 2 things:
//
// 1. Records the index in `v` and `rounded` and the scaling parameter so that
//   `value * v[position]` gets rounded to `rounded[position] + 1` while
//   `(value - epsilon) * v[position]` is rounded to `rounded[position]` for some small
//   epsilon.
//
//   Informally: "what's the smallest scaling factor so `v[position]` gets rounded to
//   the next value?"
//
// 2. Imposes a genuine total ordering on the `f32` values (via `f32::total_cmp`) so it
//   can be stored in a max-heap such as `BinaryHeap`. The ordering is reversed so that
//   such a heap models a min-heap: the entry with the *smallest* `value` compares as
//   the greatest and sits at the root.
//
//   Equality and ordering consider only `value`; `position` is ignored. Because `eq` is
//   derived from `cmp`, `Eq` and `Ord` always agree. The order also stays transitive if a
//   `NaN` ever appears, which `partial_cmp(..).unwrap_or(Equal)` did not guarantee.
//   `total_cmp` does distinguish `-0.0` from `+0.0`, but critical values are computed as a
//   positive numerator divided by an absolute value, so they are never `-0.0`.
#[derive(Debug, Clone, Copy)]
struct Pair {
    value: f32,
    position: u32,
}

impl PartialEq for Pair {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}

impl Eq for Pair {}

impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed on purpose: a smaller `value` is "greater" so a max-heap surfaces it first.
        other.value.total_cmp(&self.value)
    }
}
950
/// This is a tricky function - please read carefully.
///
/// Given a vector `v`, compute the scaling factor `s` such that the cosine similarity
/// between `v` and `r` is maximized, where `r` is defined as
/// ```math
/// let offset = (2^(num_bits) - 1) / 2;
/// let r = (s * v + offset).round().clamp(0, 2^num_bits - 1) - offset
/// ```
///
/// More informally, maximize the cosine similarity between `v` and the points in a square
/// lattice with 2^num_bits values in each dimension, centered around zero. This lattice
/// takes the values (+0.5, -0.5, +1.5, -1.5 ...) to give equal weight above and below zero.
///
/// It works by slowly increasing the factor `s` such that the rounding of only one
/// dimension in `v` is changed at a time. A running tally of the cosine similarity is
/// computed for each candidate scaling factor until every dimension has reached the
/// largest lattice magnitude `2^(num_bits - 1) - 0.5`. For `num_bits >= 2` and a `v`
/// with no zero components, that is `D * (2^(num_bits - 1) - 1)` candidates, where `D` is
/// the length of `v`.
///
/// The best scaling factor found is returned.
///
/// Notes on edge cases:
/// * The starting configuration (every component at magnitude 0.5) is not itself scored.
///   For `num_bits >= 2` and no zero components, the fully saturated configuration
///   reached at the end points in the same direction and *is* scored.
/// * Zero components get a critical value of `+inf`. This orders after the `f32::MAX`
///   "exhausted" sentinel, so such components are never advanced while any non-zero
///   component exists.
/// * If every component is zero, the first scored candidate is `+inf` and that is what
///   is returned.
/// * NOTE(review): with `num_bits == 1`, `stop == 1` equals the initial `rounded` value,
///   yet each dimension is still advanced once (to magnitude 1.5, which is outside the
///   1-bit lattice) before being marked exhausted. The returned scale is then one of the
///   critical values of `v`. Presumably any positive scale is acceptable for 1 bit because
///   callers clamp to the 1-bit range — confirm before relying on the scored similarity.
///
/// Refer to algorithm 1 in <https://arxiv.org/pdf/2409.09913>.
///
/// # Errors
///
/// Returns an [`AllocatorError`] if allocating the scratch buffers (`rounded` and the
/// heap storage) fails.
///
/// # Panics
///
/// Panics if `v.is_empty()`.
///
/// # Implementation details
///
/// We work with the absolute value of the elements in the vector `v`.
/// This does not affect the final result as the scaling works the same in both the
/// positive and negative directions but simplifies the bookkeeping.
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // Initially, the lattice element has the value `0.5` for all dimensions.
    // This means the initial inner product between `v` and the rounded term is simply
    // `0.5 * sum(abs.(v))`. The absolute value is used because the lattice element is
    // always in the direction of the components in `v`.
    // The squared norm of the all-0.5 lattice element is `0.25 * D`. Both running tallies
    // are kept in `f64` to limit accumulated rounding error.
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    let mut current_square_norm = 0.25 * (v.len() as f64);

    // Bookkeeping for the current value of the rounded vector.
    // The true numeric value is 0.5 less than this (in the direction of `v`), but we use
    // integers for a smaller memory footprint.
    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // Compute the critical values and store them on a heap.
    //
    // The binary heap will keep track of the minimum critical value. Multiplying `v` by the
    // minimum critical value `s` means that `s * v` will only change `rounded` from its
    // current value at a single index (the position associated with `s`).
    //
    // `eps` nudges each critical value slightly past the exact rounding boundary so the
    // targeted dimension has definitely moved to the next level.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    // Lint: This is a private method and all the callers have an invariant that they check
    // for non-empty inputs.
    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // Largest value `rounded` may reach, i.e. the largest lattice magnitude plus 0.5.
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            // `f32::MAX` marks a dimension that has already reached `stop`. Seeing it at
            // the root means no remaining dimension can advance, so the search is done.
            if value == f32::MAX {
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            // By the nature of critical values, only `r` will change in `rounded` when
            // multiplying by `value`. And that change will be to increase by 1.
            *r += 1;

            // The inner product estimate simply increases by `vp.abs()` because:
            //
            // * `r` is the only value in `rounded` that changes.
            // * `r` is increased by 1.
            current_ip += vp.abs() as f64;

            // This uses the formula
            // ```math
            // (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2
            //                 = 2x + 1
            // ```
            // substitute `x = y - 1/2` to obtain the true value associated with rounded and
            // we get
            // ```math
            // 2 ( y - 1/2 ) + 1 = 2y - 1 + 1
            //                   = 2y
            // ```
            // Therefore, the change in the estimate for the square norm of `rounded` is
            // `2 * old_r`.
            current_square_norm += (2 * old_r) as f64;

            // Compute the current cosine similarity and update max if needed.
            // (The constant factor `1 / |v|` is omitted since it does not affect the argmax.)
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            // Compute the scaling factor that will change this dimension to the next value.
            // Once the dimension is saturated, park it behind the `f32::MAX` sentinel.
            if *r < stop {
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1092
1093///////////////////////
1094// Query Compression //
1095///////////////////////
1096
1097impl<const NBITS: usize, Perm, A>
1098    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
1099    for SphericalQuantizer<A>
1100where
1101    Unsigned: Representation<NBITS>,
1102    Perm: PermutationStrategy<NBITS>,
1103    A: Allocator,
1104{
1105    type Error = CompressionError;
1106
1107    /// Compress the input vector `from` into the bitslice `into`.
1108    ///
1109    /// # Error
1110    ///
1111    /// Returns an error if
1112    /// * The input contains `NaN`.
1113    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
1114    ///   dimensionality as the quantizer.
1115    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
1116    ///   dimensionality as the output of the distance-preserving transform. Importantely,
1117    ///   this **may** be different than `self.dim()` and should be retrieved from
1118    ///   `self.output_dim()`.
1119    fn compress_into_with(
1120        &self,
1121        from: &[f32],
1122        mut into: QueryMut<'_, NBITS, Perm>,
1123        allocator: ScopedAllocator<'_>,
1124    ) -> Result<(), Self::Error> {
1125        let input_dim = self.shift.len();
1126        let output_dim = self.output_dim();
1127        check_dims(input_dim, output_dim, from.len(), into.len())?;
1128
1129        let mut preprocessed = self.preprocess(from, allocator)?;
1130
1131        if preprocessed.shifted_norm == 0.0 {
1132            into.set_meta(QueryMeta::default());
1133            return Ok(());
1134        }
1135
1136        preprocessed
1137            .shifted
1138            .iter_mut()
1139            .for_each(|v| *v /= preprocessed.shifted_norm);
1140
1141        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
1142
1143        // Transformation can fail due to OOM - we want to handle that gracefully.
1144        //
1145        // If the transformation fails because we provided the wrong sizes, that is a hard
1146        // program bug.
1147        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
1148        match self
1149            .transform
1150            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
1151        {
1152            Ok(()) => {}
1153            Err(TransformFailed::AllocatorError(err)) => {
1154                return Err(CompressionError::AllocatorError(err))
1155            }
1156            Err(TransformFailed::SourceMismatch { .. })
1157            | Err(TransformFailed::DestinationMismatch { .. }) => {
1158                panic!(
1159                    "The sizes of these arrays should already be checked - this is a logic error"
1160                );
1161            }
1162        }
1163
1164        // Compute the minimum and maximum values of the transformed vector.
1165        let (min, max) = transformed
1166            .iter()
1167            .fold((f32::MAX, f32::MIN), |(min, max), i| {
1168                (i.min(min), i.max(max))
1169            });
1170
1171        let domain = Unsigned::domain_const::<NBITS>();
1172        let lo = (*domain.start()) as f32;
1173        let hi = (*domain.end()) as f32;
1174
1175        let scale = (max - min) / hi;
1176        let mut bit_sum: f32 = 0.0;
1177        transformed.iter().enumerate().for_each(|(i, v)| {
1178            let c = ((v - min) / scale).round().clamp(lo, hi);
1179            bit_sum += c;
1180
1181            // Lint: We have verified that `into.len() == transformed.len()`, so the index
1182            // `i` is in bounds.
1183            //
1184            // Further, `c` has beem clamped to `[0, 2^NBITS - 1]` and is thus encodable
1185            // with the NBITS-bit unsigned representation.
1186            #[allow(clippy::unwrap_used)]
1187            into.vector_mut().set(i, c as i64).unwrap();
1188        });
1189
1190        // Finish up the compensation terms.
1191        into.set_meta(QueryMeta {
1192            inner_product_correction: preprocessed.shifted_norm * scale,
1193            bit_sum,
1194            offset: min / scale,
1195            metric_specific: preprocessed.metric_specific(),
1196        });
1197
1198        Ok(())
1199    }
1200}
1201
1202///////////
1203// Tests //
1204///////////
1205
1206#[cfg(not(miri))]
1207#[cfg(test)]
1208mod tests {
1209    use super::*;
1210
1211    use std::fmt::Display;
1212
1213    use diskann_utils::{
1214        lazy_format,
1215        views::{self, Matrix},
1216        ReborrowMut,
1217    };
1218    use diskann_vector::{norm::FastL2NormSquared, PureDistanceFunction};
1219    use diskann_wide::ARCH;
1220    use rand::{
1221        distr::{Distribution, Uniform},
1222        rngs::StdRng,
1223        SeedableRng,
1224    };
1225    use rand_distr::StandardNormal;
1226
1227    use crate::{
1228        algorithms::transforms::TargetDim,
1229        alloc::GlobalAllocator,
1230        bits::{BitTranspose, Dense},
1231        spherical::{Data, DataMetaF32, FullQuery, Query},
1232        test_util,
1233    };
1234
1235    // Test cosine-similarity maximizer
1236    #[test]
1237    fn test_cosine_similarity_maximizer() {
1238        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
1239        let num_trials = 10000;
1240        let num_bits = NonZeroUsize::new(3).unwrap();
1241
1242        let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();
1243
1244        let run_test = |target: [f32; 4]| {
1245            let scale =
1246                maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();
1247
1248            let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
1249            let mut best_similarity: f32 = f32::NEG_INFINITY;
1250
1251            // This crazy series of nested loops performs an exhaustive search over the
1252            // encoding space.
1253            let min = -3.5;
1254            for i0 in 0..8 {
1255                for i1 in 0..8 {
1256                    for i2 in 0..8 {
1257                        for i3 in 0..8 {
1258                            let p: [f32; 4] = [
1259                                min + (i0 as f32),
1260                                min + (i1 as f32),
1261                                min + (i2 as f32),
1262                                min + (i3 as f32),
1263                            ];
1264
1265                            let sim: MathematicalValue<f32> =
1266                                diskann_vector::distance::Cosine::evaluate(&p, &target);
1267                            let sim = sim.into_inner();
1268                            if sim > best_similarity {
1269                                best_similarity = sim;
1270                                // Transform into an integer starting at zero.
1271                                best = p.map(|i| i - min);
1272                            }
1273                        }
1274                    }
1275                }
1276            }
1277
1278            // Now, rescale the input vector, clamp, and round.
1279            // Check if they agree.
1280            let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
1281            let clamped_cosine: MathematicalValue<f32> =
1282                diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);
1283
1284            // We expect to either get the best value found via exhaustive search, or some
1285            // scalar multiple of it (since that will have the same cosine similarity).
1286            let passed = if best == clamped {
1287                true
1288            } else {
1289                let ratio: Vec<f32> = std::iter::zip(best, clamped)
1290                    .map(|(b, c)| {
1291                        let ratio = (b + min) / (c + min);
1292                        assert_ne!(
1293                            ratio, 0.0,
1294                            "ratio should never be zero because `b` is an integer and \
1295                             `min` is not"
1296                        );
1297                        ratio
1298                    })
1299                    .collect();
1300
1301                ratio.iter().all(|i| *i == ratio[0])
1302            };
1303
1304            if !passed {
1305                panic!(
1306                    "failed for input {:?}.\
1307                     Best = {:?}, Found = {:?}\
1308                     Best similarity = {}, similarity with clamped = {}",
1309                    target,
1310                    best,
1311                    clamped,
1312                    best_similarity,
1313                    clamped_cosine.into_inner()
1314                );
1315            }
1316        };
1317
1318        // Run targeted tests.
1319        let min = -3.5;
1320        for i0 in (0..8).step_by(2) {
1321            for i1 in (1..9).step_by(2) {
1322                for i2 in (0..8).step_by(2) {
1323                    for i3 in (1..9).step_by(2) {
1324                        let p: [f32; 4] = [
1325                            min + (i0 as f32),
1326                            min + (i1 as f32),
1327                            min + (i2 as f32),
1328                            min + (i3 as f32),
1329                        ];
1330                        run_test(p)
1331                    }
1332                }
1333            }
1334        }
1335
1336        for _ in 0..num_trials {
1337            let this_scale: f32 = scale_distribution.sample(&mut rng);
1338            let v: [f32; 4] = [(); 4].map(|_| {
1339                let v: f32 = StandardNormal {}.sample(&mut rng);
1340                this_scale * v
1341            });
1342            run_test(v);
1343        }
1344    }
1345
1346    #[test]
1347    #[should_panic(expected = "calling code should not allow the slice to be empty")]
1348    fn empty_slice_panics() {
1349        maximize_cosine_similarity(
1350            &[],
1351            NonZeroUsize::new(4).unwrap(),
1352            ScopedAllocator::global(),
1353        )
1354        .unwrap();
1355    }
1356
1357    struct Setup {
1358        transform: TransformKind,
1359        nrows: usize,
1360        ncols: usize,
1361        num_trials: usize,
1362    }
1363
1364    fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1365        match scale {
1366            PreScale::None => 1.0,
1367            PreScale::Some(v) => v.into_inner(),
1368            PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1369        }
1370    }
1371
1372    fn test_l2<const Q: usize, const D: usize, Perm>(
1373        setup: &Setup,
1374        problem: &test_util::TestProblem,
1375        computed_means: &[f32],
1376        pre_scale: PreScale,
1377        rng: &mut StdRng,
1378    ) where
1379        Unsigned: Representation<Q>,
1380        Unsigned: Representation<D>,
1381        Perm: PermutationStrategy<Q>,
1382        for<'a> SphericalQuantizer:
1383            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1384        for<'a> SphericalQuantizer:
1385            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1386    {
1387        assert_eq!(setup.nrows, problem.data.nrows());
1388        assert_eq!(setup.ncols, problem.data.ncols());
1389
1390        let scoped_global = ScopedAllocator::global();
1391        let distribution = Uniform::new(0, setup.nrows).unwrap();
1392        let quantizer = SphericalQuantizer::train(
1393            problem.data.as_view(),
1394            setup.transform,
1395            SupportedMetric::SquaredL2,
1396            pre_scale,
1397            rng,
1398            GlobalAllocator,
1399        )
1400        .unwrap();
1401
1402        let scale = get_scale(pre_scale, &quantizer);
1403
1404        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1405        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1406        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1407
1408        assert_eq!(
1409            quantizer.mean_norm.into_inner(),
1410            problem.mean_norm as f32,
1411            "computed mean norm should not apply scale"
1412        );
1413        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1414        assert_eq!(&*scaled_means, quantizer.shift());
1415
1416        let l2: CompensatedSquaredL2 = quantizer.as_functor();
1417        assert_eq!(l2.dim, quantizer.output_dim() as f32);
1418
1419        for _ in 0..setup.num_trials {
1420            let i = distribution.sample(rng);
1421            let v = problem.data.row(i);
1422
1423            quantizer
1424                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1425                .unwrap();
1426            quantizer
1427                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1428                .unwrap();
1429            quantizer
1430                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1431                .unwrap();
1432
1433            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1434                .map(|(a, b)| scale * a - b)
1435                .collect();
1436
1437            // Check that the compensation coefficient were chosen correctly.
1438            {
1439                let DataMetaF32 {
1440                    inner_product_correction,
1441                    bit_sum,
1442                    metric_specific,
1443                } = b.meta().to_full(ARCH);
1444
1445                let shifted_square_norm = metric_specific;
1446
1447                // Check that the bit-count is correct. let bv = b.vector();
1448                let bv = b.vector();
1449                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1450                assert_eq!(s, bit_sum as usize);
1451
1452                // Check that the shifted norm is correct.
1453                {
1454                    let expected = FastL2NormSquared.evaluate(&*shifted);
1455                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1456                    assert!(
1457                        err < 5.0e-4, // twice the minimum normal f16 value.
1458                        "failed diff check, got {}, expected {} - relative error = {}",
1459                        shifted_square_norm,
1460                        expected,
1461                        err
1462                    );
1463                }
1464
1465                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1466                // the RaBitQ paper suggests.
1467                if const { D == 1 } {
1468                    let self_inner_product = 2.0 * shifted_square_norm.sqrt()
1469                        / (inner_product_correction * (bv.len() as f32).sqrt());
1470                    assert!(
1471                        (self_inner_product - 0.8).abs() < 0.13,
1472                        "self inner-product should be close to 0.8. Instead, it's {}",
1473                        self_inner_product
1474                    );
1475                }
1476            }
1477
1478            {
1479                let QueryMeta {
1480                    inner_product_correction,
1481                    bit_sum,
1482                    offset,
1483                    metric_specific,
1484                } = q.meta();
1485
1486                let shifted_square_norm = metric_specific;
1487                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1488                preprocessed
1489                    .shifted
1490                    .iter_mut()
1491                    .for_each(|i| *i /= preprocessed.shifted_norm);
1492
1493                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1494                quantizer
1495                    .transform
1496                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1497                    .unwrap();
1498
1499                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1500                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1501
1502                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1503
1504                // Shifted Norm
1505                {
1506                    let expected = FastL2NormSquared.evaluate(&*shifted);
1507                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1508                    assert!(
1509                        err < 2e-7,
1510                        "failed diff check, got {}, expected {} - relative error = {}",
1511                        shifted_square_norm,
1512                        expected,
1513                        err
1514                    );
1515                }
1516
1517                // Inner product correction
1518                {
1519                    let expected = shifted_square_norm.sqrt() * scale;
1520                    let got = inner_product_correction;
1521
1522                    let err = (expected - got).abs();
1523                    assert!(
1524                        err < 1.0e-7,
1525                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1526                        expected,
1527                        got,
1528                        err
1529                    );
1530                }
1531
1532                // Offset
1533                {
1534                    let expected = min / scale;
1535                    let got = offset;
1536
1537                    let err = (expected - got).abs();
1538                    assert!(
1539                        err < 1.0e-7,
1540                        "\"sum_scale\": expected {}, got {}, error = {}",
1541                        expected,
1542                        got,
1543                        err
1544                    );
1545                }
1546
1547                // Bit Sum
1548                {
1549                    let expected = (0..q.len())
1550                        .map(|i| q.vector().get(i).unwrap())
1551                        .sum::<i64>() as f32;
1552
1553                    let got = bit_sum;
1554
1555                    let err = (expected - got).abs();
1556                    assert!(
1557                        err < 1.0e-7,
1558                        "\"offset\": expected {}, got {}, error = {}",
1559                        expected,
1560                        got,
1561                        err
1562                    );
1563                }
1564            }
1565
1566            // Check that the compensation coefficient were chosen correctly.
1567            {
1568                // Check that the bit-count is correct.
1569                let s: f32 = f.data.iter().sum::<f32>();
1570                assert_eq!(s, f.meta.sum);
1571
1572                // Check that the shifted norm is correct.
1573                {
1574                    let expected = FastL2Norm.evaluate(&*shifted);
1575                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1576                    assert!(
1577                        err < 2e-7,
1578                        "failed diff check, got {}, expected {} - relative error = {}",
1579                        f.meta.shifted_norm,
1580                        expected,
1581                        err
1582                    );
1583                }
1584
1585                assert_eq!(
1586                    f.meta.metric_specific,
1587                    f.meta.shifted_norm * f.meta.shifted_norm,
1588                    "metric specific data for squared l2 is the square shifted norm",
1589                );
1590            }
1591        }
1592
1593        // Finally - test that if we compress the centroid, the metadata coefficients get
1594        // zeroed correctly.
1595        quantizer
1596            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1597            .unwrap();
1598        assert_eq!(b.meta(), DataMeta::default());
1599
1600        quantizer
1601            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1602            .unwrap();
1603        assert_eq!(q.meta(), QueryMeta::default());
1604
1605        f.data.fill(f32::INFINITY);
1606        quantizer
1607            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1608            .unwrap();
1609        assert!(f.data.iter().all(|&i| i == 0.0));
1610        assert_eq!(f.meta.sum, 0.0);
1611        assert_eq!(f.meta.metric_specific, 0.0);
1612    }
1613
1614    fn test_ip<const Q: usize, const D: usize, Perm>(
1615        setup: &Setup,
1616        problem: &test_util::TestProblem,
1617        computed_means: &[f32],
1618        pre_scale: PreScale,
1619        rng: &mut StdRng,
1620        ctx: &dyn Display,
1621    ) where
1622        Unsigned: Representation<Q>,
1623        Unsigned: Representation<D>,
1624        Perm: PermutationStrategy<Q>,
1625        for<'a> SphericalQuantizer:
1626            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1627        for<'a> SphericalQuantizer:
1628            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1629    {
1630        assert_eq!(setup.nrows, problem.data.nrows());
1631        assert_eq!(setup.ncols, problem.data.ncols());
1632
1633        let scoped_global = ScopedAllocator::global();
1634        let distribution = Uniform::new(0, setup.nrows).unwrap();
1635        let quantizer = SphericalQuantizer::train(
1636            problem.data.as_view(),
1637            setup.transform,
1638            SupportedMetric::InnerProduct,
1639            pre_scale,
1640            rng,
1641            GlobalAllocator,
1642        )
1643        .unwrap();
1644
1645        let scale = get_scale(pre_scale, &quantizer);
1646
1647        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1648        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1649        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1650
1651        assert_eq!(
1652            quantizer.mean_norm.into_inner(),
1653            problem.mean_norm as f32,
1654            "computed mean norm should not apply scale"
1655        );
1656        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1657        assert_eq!(&*scaled_means, quantizer.shift());
1658
1659        let ip: CompensatedIP = quantizer.as_functor();
1660
1661        assert_eq!(ip.dim, quantizer.output_dim() as f32);
1662        assert_eq!(
1663            ip.squared_shift_norm,
1664            FastL2NormSquared.evaluate(quantizer.shift())
1665        );
1666
1667        for _ in 0..setup.num_trials {
1668            let i = distribution.sample(rng);
1669            let v = problem.data.row(i);
1670
1671            quantizer
1672                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1673                .unwrap();
1674            quantizer
1675                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1676                .unwrap();
1677            quantizer
1678                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1679                .unwrap();
1680
1681            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1682                .map(|(a, b)| scale * a - b)
1683                .collect();
1684
1685            // Check that the compensation coefficient were chosen correctly.
1686            {
1687                let DataMetaF32 {
1688                    inner_product_correction,
1689                    bit_sum,
1690                    metric_specific,
1691                } = b.meta().to_full(ARCH);
1692
1693                let inner_product_with_centroid = metric_specific;
1694
1695                // Check that the bit-count is correct.
1696                let bv = b.vector();
1697                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1698                assert_eq!(s, bit_sum as usize);
1699
1700                // Check that the shifted norm is correct.
1701                let inner_product: MathematicalValue<f32> =
1702                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1703
1704                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
1705                assert!(
1706                    diff < 1.53e-5,
1707                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
1708                    diff,
1709                    inner_product.into_inner(),
1710                    inner_product_with_centroid,
1711                    ctx,
1712                );
1713
1714                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1715                // the RaBitQ paper suggests.
1716                if const { D == 1 } {
1717                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1718                        / (inner_product_correction * (bv.len() as f32).sqrt());
1719                    assert!(
1720                        (self_inner_product - 0.8).abs() < 0.12,
1721                        "self inner-product should be close to 0.8. Instead, it's {}",
1722                        self_inner_product
1723                    );
1724                }
1725            }
1726
1727            {
1728                let QueryMeta {
1729                    inner_product_correction,
1730                    bit_sum,
1731                    offset,
1732                    metric_specific,
1733                } = q.meta();
1734
1735                let inner_product_with_centroid = metric_specific;
1736                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1737                preprocessed
1738                    .shifted
1739                    .iter_mut()
1740                    .for_each(|i| *i /= preprocessed.shifted_norm);
1741
1742                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1743                quantizer
1744                    .transform
1745                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1746                    .unwrap();
1747
1748                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1749                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1750
1751                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1752
1753                // Inner product correction
1754                {
1755                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1756                    let got = inner_product_correction;
1757
1758                    let err = (expected - got).abs();
1759                    assert!(
1760                        err < 1.0e-7,
1761                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1762                        expected,
1763                        got,
1764                        err
1765                    );
1766                }
1767
1768                // Offset
1769                {
1770                    let expected = min / scale;
1771                    let got = offset;
1772
1773                    let err = (expected - got).abs();
1774                    assert!(
1775                        err < 1.0e-7,
1776                        "\"sum_scale\": expected {}, got {}, error = {}",
1777                        expected,
1778                        got,
1779                        err
1780                    );
1781                }
1782
1783                // Bit Sum
1784                {
1785                    let expected = (0..q.len())
1786                        .map(|i| q.vector().get(i).unwrap())
1787                        .sum::<i64>() as f32;
1788
1789                    let got = bit_sum;
1790
1791                    let err = (expected - got).abs();
1792                    assert!(
1793                        err < 1.0e-7,
1794                        "\"offset\": expected {}, got {}, error = {}",
1795                        expected,
1796                        got,
1797                        err
1798                    );
1799                }
1800
1801                // Inner Product with Centroid
1802                {
1803                    // Check that the shifted norm is correct.
1804                    let inner_product: MathematicalValue<f32> =
1805                        InnerProduct::evaluate(&*shifted, quantizer.shift());
1806                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
1807                }
1808            }
1809
1810            // Check that the compensation coefficient were chosen correctly.
1811            {
1812                // Check that the bit-count is correct.
1813                let s: f32 = f.data.iter().sum::<f32>();
1814                assert_eq!(s, f.meta.sum);
1815
1816                // Check that the shifted norm is correct.
1817                {
1818                    let expected = FastL2Norm.evaluate(&*shifted);
1819                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1820                    assert!(
1821                        err < 2e-7,
1822                        "failed diff check, got {}, expected {} - relative error = {}",
1823                        f.meta.shifted_norm,
1824                        expected,
1825                        err
1826                    );
1827                }
1828
1829                // Check that the shifted norm is correct. s
1830                let inner_product: MathematicalValue<f32> =
1831                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1832                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
1833            }
1834        }
1835
1836        // Finally - test that if we compress the centroid, the metadata coefficients get
1837        // zeroed correctly.
1838        quantizer
1839            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1840            .unwrap();
1841        assert_eq!(b.meta(), DataMeta::default());
1842
1843        quantizer
1844            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1845            .unwrap();
1846        assert_eq!(q.meta(), QueryMeta::default());
1847
1848        f.data.fill(f32::INFINITY);
1849        quantizer
1850            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1851            .unwrap();
1852        assert!(f.data.iter().all(|&i| i == 0.0));
1853        assert_eq!(f.meta.sum, 0.0);
1854        assert_eq!(f.meta.metric_specific, 0.0);
1855    }
1856
1857    fn test_cosine<const Q: usize, const D: usize, Perm>(
1858        setup: &Setup,
1859        problem: &test_util::TestProblem,
1860        pre_scale: PreScale,
1861        rng: &mut StdRng,
1862    ) where
1863        Unsigned: Representation<Q>,
1864        Unsigned: Representation<D>,
1865        Perm: PermutationStrategy<Q>,
1866        for<'a> SphericalQuantizer:
1867            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1868        for<'a> SphericalQuantizer:
1869            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1870    {
1871        assert_eq!(setup.nrows, problem.data.nrows());
1872        assert_eq!(setup.ncols, problem.data.ncols());
1873
1874        let scoped_global = ScopedAllocator::global();
1875        let distribution = Uniform::new(0, setup.nrows).unwrap();
1876        let quantizer = SphericalQuantizer::train(
1877            problem.data.as_view(),
1878            setup.transform,
1879            SupportedMetric::Cosine,
1880            pre_scale,
1881            rng,
1882            GlobalAllocator,
1883        )
1884        .unwrap();
1885
1886        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1887        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1888        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1889
1890        let cosine: CompensatedCosine = quantizer.as_functor();
1891
1892        assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
1893        assert_eq!(
1894            cosine.inner.squared_shift_norm,
1895            FastL2NormSquared.evaluate(quantizer.shift())
1896        );
1897
1898        const IP_BOUND: f32 = 2.6e-3;
1899
1900        let mut test_row = |v: &[f32]| {
1901            let vnorm = (FastL2Norm).evaluate(v);
1902            let v_normalized: Vec<f32> = v
1903                .iter()
1904                .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
1905                .collect();
1906
1907            quantizer
1908                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1909                .unwrap();
1910
1911            quantizer
1912                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1913                .unwrap();
1914
1915            quantizer
1916                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1917                .unwrap();
1918
1919            let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
1920                .map(|(a, b)| a - b)
1921                .collect();
1922
1923            // Check that the compensation coefficient were chosen correctly.
1924            {
1925                let DataMetaF32 {
1926                    inner_product_correction,
1927                    bit_sum,
1928                    metric_specific,
1929                } = b.meta().to_full(ARCH);
1930
1931                let inner_product_with_centroid = metric_specific;
1932
1933                // Check that the bit-count is correct.
1934                let bv = b.vector();
1935                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1936                assert_eq!(s, bit_sum as usize);
1937
1938                // Check that the shifted norm is correct. Since they are computed slightly
1939                // differnetly, allow a small amount of error.
1940                let inner_product: MathematicalValue<f32> =
1941                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1942
1943                let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
1944                let relative = abs / inner_product.into_inner().abs();
1945
1946                assert!(
1947                    abs < 1e-7 || relative < IP_BOUND,
1948                    "got an abs/rel of {}/{} with a bound of {}/{}",
1949                    abs,
1950                    relative,
1951                    1e-7,
1952                    IP_BOUND
1953                );
1954
1955                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1956                // the RaBitQ paper suggests.
1957                if const { D == 1 } {
1958                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1959                        / (inner_product_correction * (bv.len() as f32).sqrt());
1960                    assert!(
1961                        (self_inner_product - 0.8).abs() < 0.11,
1962                        "self inner-product should be close to 0.8. Instead, it's {}",
1963                        self_inner_product
1964                    );
1965                }
1966            }
1967
1968            {
1969                let QueryMeta {
1970                    inner_product_correction,
1971                    bit_sum,
1972                    offset,
1973                    metric_specific,
1974                } = q.meta();
1975
1976                let inner_product_with_centroid = metric_specific;
1977                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1978                preprocessed
1979                    .shifted
1980                    .iter_mut()
1981                    .for_each(|i| *i /= preprocessed.shifted_norm);
1982
1983                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1984                quantizer
1985                    .transform
1986                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1987                    .unwrap();
1988
1989                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1990                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1991
1992                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1993
1994                // Inner product correction
1995                {
1996                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1997                    let got = inner_product_correction;
1998
1999                    let err = (expected - got).abs();
2000                    assert!(
2001                        err < 1.0e-7,
2002                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
2003                        expected,
2004                        got,
2005                        err
2006                    );
2007                }
2008
2009                // Offset
2010                {
2011                    let expected = min / scale;
2012                    let got = offset;
2013
2014                    let err = (expected - got).abs();
2015                    assert!(
2016                        err < 1.0e-7,
2017                        "\"sum_scale\": expected {}, got {}, error = {}",
2018                        expected,
2019                        got,
2020                        err
2021                    );
2022                }
2023
2024                // Bit Sum
2025                {
2026                    let expected = (0..q.len())
2027                        .map(|i| q.vector().get(i).unwrap())
2028                        .sum::<i64>() as f32;
2029
2030                    let got = bit_sum;
2031
2032                    let err = (expected - got).abs();
2033                    assert!(
2034                        err < 1.0e-7,
2035                        "\"offset\": expected {}, got {}, error = {}",
2036                        expected,
2037                        got,
2038                        err
2039                    );
2040                }
2041
2042                // Inner Product with Centroid
2043                {
2044                    // Check that the shifted norm is correct.
2045                    let inner_product: MathematicalValue<f32> =
2046                        InnerProduct::evaluate(&*shifted, quantizer.shift());
2047
2048                    let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
2049                        / inner_product.into_inner().abs();
2050                    assert!(
2051                        err < IP_BOUND,
2052                        "\"offset\": expected {}, got {}, error = {}",
2053                        inner_product.into_inner(),
2054                        inner_product_with_centroid,
2055                        err
2056                    );
2057                }
2058            }
2059
2060            // Check that the compensation coefficient were chosen correctly.
2061            {
2062                // Check that the bit-count is correct.
2063                let s: f32 = f.data.iter().sum::<f32>();
2064                assert_eq!(s, f.meta.sum);
2065
2066                // Check that the shifted norm is correct.
2067                {
2068                    let expected = FastL2Norm.evaluate(&*shifted);
2069                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
2070                    assert!(
2071                        err < 2e-7,
2072                        "failed diff check, got {}, expected {} - relative error = {}",
2073                        f.meta.shifted_norm,
2074                        expected,
2075                        err
2076                    );
2077                }
2078
2079                // Check that the shifted norm is correct. s
2080                let inner_product: MathematicalValue<f32> =
2081                    InnerProduct::evaluate(&*shifted, quantizer.shift());
2082                let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
2083                    / inner_product.into_inner().abs();
2084                assert!(
2085                    err < IP_BOUND,
2086                    "\"offset\": expected {}, got {}, error = {}",
2087                    inner_product.into_inner(),
2088                    f.meta.metric_specific,
2089                    err
2090                );
2091            }
2092        };
2093
2094        for _ in 0..setup.num_trials {
2095            let i = distribution.sample(rng);
2096            let v = problem.data.row(i);
2097            test_row(v);
2098        }
2099
2100        // Ensure that if a zero vector is provided that we do not divide by zero.
2101        let zero = vec![0.0f32; quantizer.input_dim()];
2102        test_row(&zero);
2103    }
2104
2105    fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2106    where
2107        for<'a> T: ReborrowMut<'a>,
2108        for<'a> SphericalQuantizer: CompressIntoWith<
2109            &'a [f32],
2110            <T as ReborrowMut<'a>>::Target,
2111            ScopedAllocator<'a>,
2112            Error = CompressionError,
2113        >,
2114    {
2115        let mut succeeded = false;
2116        let mut failed = false;
2117        for max_allocations in 0..10 {
2118            match quantizer.compress_into_with(
2119                data,
2120                dst.reborrow_mut(),
2121                ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2122            ) {
2123                Ok(()) => {
2124                    succeeded = true;
2125                }
2126                Err(CompressionError::AllocatorError(_)) => {
2127                    failed = true;
2128                }
2129                Err(other) => {
2130                    panic!("received an unexpected error: {:?}", other);
2131                }
2132            }
2133        }
2134        assert!(succeeded);
2135        assert!(failed);
2136    }
2137
2138    fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
2139        setup: &Setup,
2140        problem: &test_util::TestProblem,
2141        pre_scale: PreScale,
2142        rng: &mut StdRng,
2143    ) where
2144        Unsigned: Representation<Q>,
2145        Unsigned: Representation<D>,
2146        Perm: PermutationStrategy<Q>,
2147        for<'a> SphericalQuantizer: CompressIntoWith<
2148            &'a [f32],
2149            DataMut<'a, D>,
2150            ScopedAllocator<'a>,
2151            Error = CompressionError,
2152        >,
2153        for<'a> SphericalQuantizer: CompressIntoWith<
2154            &'a [f32],
2155            QueryMut<'a, Q, Perm>,
2156            ScopedAllocator<'a>,
2157            Error = CompressionError,
2158        >,
2159    {
2160        assert_eq!(setup.nrows, problem.data.nrows());
2161        assert_eq!(setup.ncols, problem.data.ncols());
2162
2163        let quantizer = SphericalQuantizer::train(
2164            problem.data.as_view(),
2165            setup.transform,
2166            SupportedMetric::SquaredL2,
2167            pre_scale,
2168            rng,
2169            GlobalAllocator,
2170        )
2171        .unwrap();
2172
2173        // Data.
2174        let data = problem.data.row(0);
2175        _test_oom_resiliance::<Data<D, _>>(
2176            &quantizer,
2177            data,
2178            &mut Data::new_boxed(quantizer.output_dim()),
2179        );
2180        _test_oom_resiliance::<Query<Q, Perm, _>>(
2181            &quantizer,
2182            data,
2183            &mut Query::new_boxed(quantizer.output_dim()),
2184        );
2185        _test_oom_resiliance::<FullQuery<_>>(
2186            &quantizer,
2187            data,
2188            &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
2189        );
2190    }
2191
2192    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
2193    where
2194        Unsigned: Representation<Q>,
2195        Unsigned: Representation<D>,
2196        Perm: PermutationStrategy<Q>,
2197        for<'a> SphericalQuantizer: CompressIntoWith<
2198            &'a [f32],
2199            DataMut<'a, D>,
2200            ScopedAllocator<'a>,
2201            Error = CompressionError,
2202        >,
2203        for<'a> SphericalQuantizer: CompressIntoWith<
2204            &'a [f32],
2205            QueryMut<'a, Q, Perm>,
2206            ScopedAllocator<'a>,
2207            Error = CompressionError,
2208        >,
2209    {
2210        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
2211        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();
2212
2213        let scales = [
2214            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2215            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2216            PreScale::ReciprocalMeanNorm,
2217        ];
2218
2219        for scale in scales {
2220            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);
2221
2222            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
2223            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
2224            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
2225        }
2226
2227        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
2228    }
2229
2230    #[test]
2231    fn test_spherical_quantizer() {
2232        let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2233        for dim in [56, 72, 128, 255] {
2234            let setup = Setup {
2235                transform: TransformKind::PaddingHadamard {
2236                    target_dim: TargetDim::Same,
2237                },
2238                nrows: 64,
2239                ncols: dim,
2240                num_trials: 10,
2241            };
2242
2243            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2244            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2245            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2246            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2247
2248            let setup = Setup {
2249                transform: TransformKind::DoubleHadamard {
2250                    target_dim: TargetDim::Same,
2251                },
2252                nrows: 64,
2253                ncols: dim,
2254                num_trials: 10,
2255            };
2256            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2257            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2258            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2259            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2260        }
2261    }
2262
2263    ////////////
2264    // Errors //
2265    ////////////
2266
2267    #[test]
2268    fn err_dim_cannot_be_zero() {
2269        let data = Matrix::new(0.0f32, 10, 0);
2270        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2271        let err = SphericalQuantizer::train(
2272            data.as_view(),
2273            TransformKind::DoubleHadamard {
2274                target_dim: TargetDim::Same,
2275            },
2276            SupportedMetric::SquaredL2,
2277            PreScale::None,
2278            &mut rng,
2279            GlobalAllocator,
2280        )
2281        .unwrap_err();
2282        assert_eq!(err.to_string(), "data dim cannot be zero");
2283    }
2284
2285    #[test]
2286    fn err_norm_must_be_positive() {
2287        let data = Matrix::new(0.0f32, 10, 10);
2288        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2289        let err = SphericalQuantizer::train(
2290            data.as_view(),
2291            TransformKind::DoubleHadamard {
2292                target_dim: TargetDim::Same,
2293            },
2294            SupportedMetric::SquaredL2,
2295            PreScale::None,
2296            &mut rng,
2297            GlobalAllocator,
2298        )
2299        .unwrap_err();
2300        assert_eq!(err.to_string(), "norm must be positive");
2301    }
2302
2303    #[test]
2304    fn err_norm_cannot_be_infinity() {
2305        let mut data = Matrix::new(0.0f32, 10, 10);
2306        data[(2, 5)] = f32::INFINITY;
2307
2308        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2309        let err = SphericalQuantizer::train(
2310            data.as_view(),
2311            TransformKind::DoubleHadamard {
2312                target_dim: TargetDim::Same,
2313            },
2314            SupportedMetric::SquaredL2,
2315            PreScale::None,
2316            &mut rng,
2317            GlobalAllocator,
2318        )
2319        .unwrap_err();
2320        assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2321    }
2322
2323    #[test]
2324    fn err_reciprocal_norm_cannot_be_infinity() {
2325        let mut data = Matrix::new(0.0f32, 10, 10);
2326        data[(2, 5)] = 2.93863e-39;
2327
2328        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2329        let err = SphericalQuantizer::train(
2330            data.as_view(),
2331            TransformKind::DoubleHadamard {
2332                target_dim: TargetDim::Same,
2333            },
2334            SupportedMetric::SquaredL2,
2335            PreScale::ReciprocalMeanNorm,
2336            &mut rng,
2337            GlobalAllocator,
2338        )
2339        .unwrap_err();
2340        assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2341    }
2342
2343    #[test]
2344    fn err_mean_norm_cannot_be_zero_generate() {
2345        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2346        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2347        let err = SphericalQuantizer::generate(
2348            centroid,
2349            0.0,
2350            TransformKind::DoubleHadamard {
2351                target_dim: TargetDim::Same,
2352            },
2353            SupportedMetric::SquaredL2,
2354            None,
2355            &mut rng,
2356            GlobalAllocator,
2357        )
2358        .unwrap_err();
2359        assert_eq!(err.to_string(), "norm must be positive");
2360    }
2361
2362    #[test]
2363    fn err_scale_cannot_be_zero_generate() {
2364        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2365        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2366        let err = SphericalQuantizer::generate(
2367            centroid,
2368            1.0,
2369            TransformKind::DoubleHadamard {
2370                target_dim: TargetDim::Same,
2371            },
2372            SupportedMetric::SquaredL2,
2373            Some(0.0),
2374            &mut rng,
2375            GlobalAllocator,
2376        )
2377        .unwrap_err();
2378        assert_eq!(err.to_string(), "pre-scale must be positive");
2379    }
2380
2381    #[test]
2382    fn compression_errors_data() {
2383        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2384        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);
2385
2386        let quantizer = SphericalQuantizer::train(
2387            data.as_view(),
2388            TransformKind::PaddingHadamard {
2389                target_dim: TargetDim::Same,
2390            },
2391            SupportedMetric::SquaredL2,
2392            PreScale::None,
2393            &mut rng,
2394            GlobalAllocator,
2395        )
2396        .unwrap();
2397
2398        let scoped_global = ScopedAllocator::global();
2399
2400        // Input contains NaN.
2401        {
2402            let mut query: Vec<f32> = quantizer.shift().to_vec();
2403            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2404            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2405
2406            for i in 0..query.len() {
2407                let last = query[i];
2408                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
2409                    query[i] = v;
2410
2411                    let err = quantizer
2412                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2413                        .unwrap_err();
2414
2415                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2416
2417                    let err = quantizer
2418                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2419                        .unwrap_err();
2420
2421                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2422                }
2423                query[i] = last;
2424            }
2425        }
2426
2427        // Input has a large value.
2428        {
2429            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
2430            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2431
2432            let err = quantizer
2433                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2434                .unwrap_err();
2435
2436            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";
2437
2438            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
2439        }
2440
2441        // Input length
2442        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
2443            let query = vec![0.0f32; len];
2444            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2445            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2446
2447            let err = quantizer
2448                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2449                .unwrap_err();
2450            assert_eq!(
2451                err,
2452                CompressionError::SourceDimensionMismatch {
2453                    expected: quantizer.input_dim(),
2454                }
2455            );
2456
2457            let err = quantizer
2458                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2459                .unwrap_err();
2460            assert_eq!(
2461                err,
2462                CompressionError::SourceDimensionMismatch {
2463                    expected: quantizer.input_dim(),
2464                }
2465            );
2466        }
2467
2468        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
2469            let query = vec![0.0f32; quantizer.input_dim()];
2470            let mut d = Data::<1, _>::new_boxed(len);
2471            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);
2472
2473            let err = quantizer
2474                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2475                .unwrap_err();
2476            assert_eq!(
2477                err,
2478                CompressionError::DestinationDimensionMismatch {
2479                    expected: quantizer.output_dim(),
2480                }
2481            );
2482
2483            let err = quantizer
2484                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2485                .unwrap_err();
2486            assert_eq!(
2487                err,
2488                CompressionError::DestinationDimensionMismatch {
2489                    expected: quantizer.output_dim(),
2490                }
2491            );
2492        }
2493    }
2494
2495    #[test]
2496    fn centroid_scaling_happens_in_generate() {
2497        let centroid = Poly::from_iter(
2498            [1088.6732f32, 1393.32, 1547.877].into_iter(),
2499            GlobalAllocator,
2500        )
2501        .unwrap();
2502        let mean_norm = 2359.27;
2503        let pre_scale = 1.0 / mean_norm;
2504
2505        let quantizer = SphericalQuantizer::generate(
2506            centroid,
2507            mean_norm,
2508            TransformKind::Null,
2509            SupportedMetric::InnerProduct,
2510            Some(pre_scale),
2511            &mut StdRng::seed_from_u64(10),
2512            GlobalAllocator,
2513        )
2514        .unwrap();
2515
2516        let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
2517        let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
2518        assert!(quantizer
2519            .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
2520            .is_ok(),
2521            "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
2522            );
2523    }
2524}
2525
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Round-trips trained quantizers through flatbuffers and checks that the reloaded
    /// quantizer compares equal to the original.
    ///
    /// The test covers every transform kind, every supported metric and a range of
    /// pre-scaling strategies.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let fixture = test_util::create_test_problem(10, 128, &mut rng);
        let alloc = GlobalAllocator;

        // Each transform family is paired with four target dimensions: unchanged, the
        // transform's natural choice, and explicit overrides of 100 and 150.
        let smaller = NonZeroUsize::new(100).unwrap();
        let larger = NonZeroUsize::new(150).unwrap();
        let target_dims = [
            TargetDim::Same,
            TargetDim::Natural,
            TargetDim::Override(smaller),
            TargetDim::Override(larger),
        ];

        let mut kinds = vec![TransformKind::Null];
        kinds.extend(
            target_dims
                .iter()
                .map(|&target_dim| TransformKind::DoubleHadamard { target_dim }),
        );
        kinds.extend(
            target_dims
                .iter()
                .map(|&target_dim| TransformKind::PaddingHadamard { target_dim }),
        );
        // Random rotations are only exercised with `linalg` enabled and outside of Miri.
        #[cfg(all(not(miri), feature = "linalg"))]
        kinds.extend(
            target_dims
                .iter()
                .map(|&target_dim| TransformKind::RandomRotation { target_dim }),
        );

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        for &kind in &kinds {
            for metric in SupportedMetric::all() {
                for &pre_scale in &pre_scales {
                    let trained = SphericalQuantizer::train(
                        fixture.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    let bytes = to_flatbuffer(|buf| trained.pack(buf));
                    let root =
                        flatbuffers::root::<fb::spherical::SphericalQuantizer>(&bytes).unwrap();
                    let round_tripped = SphericalQuantizer::try_unpack(alloc, root).unwrap();

                    assert_eq!(trained, round_tripped, "failed on transform {:?}", kind);
                }
            }
        }
    }

    /// Corrupts individual fields of a valid quantizer before serialization and checks
    /// that deserialization rejects each corruption with the expected error.
    #[test]
    fn test_error_checking() {
        type E = DeserializationError;

        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let fixture = test_util::create_test_problem(10, 128, &mut rng);
        let alloc = GlobalAllocator;
        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        // Every case starts from a freshly trained quantizer.
        let mut fresh_quantizer = || {
            SphericalQuantizer::train(
                fixture.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        // Serializes `quantizer` and returns the error reported when reading it back.
        let unpack_error = |quantizer: &SphericalQuantizer| {
            let bytes = to_flatbuffer(|buf| quantizer.pack(buf));
            let root = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&bytes).unwrap();
            SphericalQuantizer::try_unpack(alloc, root).unwrap_err()
        };

        // A mean norm of exactly zero is reported as a missing norm.
        {
            let mut quantizer = fresh_quantizer();
            // SAFETY: `Positive` places no layout restriction the compiler could exploit,
            // and the out-of-range value is only serialized so that the validation on
            // reload can be exercised.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(0.0) };
            assert_eq!(unpack_error(&quantizer), E::MissingNorm);
        }

        // A negative mean norm is likewise reported as a missing norm.
        {
            let mut quantizer = fresh_quantizer();
            // SAFETY: `Positive` places no layout restriction the compiler could exploit,
            // and the out-of-range value is only serialized so that the validation on
            // reload can be exercised.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(-1.0) };
            assert_eq!(unpack_error(&quantizer), E::MissingNorm);
        }

        // A zero pre-scale is rejected as not positive.
        {
            let mut quantizer = fresh_quantizer();
            // SAFETY: `Positive` places no layout restriction the compiler could exploit,
            // and the out-of-range value is only serialized so that the validation on
            // reload can be exercised.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };
            assert_eq!(unpack_error(&quantizer), E::PreScaleNotPositive);
        }

        // A `shift` whose length disagrees with the rest of the quantizer yields a
        // dimension mismatch.
        {
            let mut quantizer = fresh_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();
            assert_eq!(unpack_error(&quantizer), E::DimMismatch);
        }
    }
}