Skip to main content

diskann_quantization/spherical/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use diskann_utils::{ReborrowMut, views::MatrixView};
9use diskann_vector::{
10    MathematicalValue, Norm, PureDistanceFunction, distance::InnerProduct, norm::FastL2Norm,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18    CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19    FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21use crate::{
22    AsFunctor, CompressIntoWith,
23    algorithms::{
24        heap::SliceHeap,
25        transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
26    },
27    alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
28    bits::{PermutationStrategy, Representation, Unsigned},
29    num::Positive,
30    utils::{CannotBeEmpty, compute_means_and_average_norm, compute_normalized_means},
31};
32#[cfg(feature = "flatbuffers")]
33use crate::{
34    algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
35};
36
37///////////////
38// Quantizer //
39///////////////
40
41#[derive(Debug)]
42#[cfg_attr(test, derive(PartialEq))]
43pub struct SphericalQuantizer<A = GlobalAllocator>
44where
45    A: Allocator,
46{
47    /// The offset to apply to each vector.
48    shift: Poly<[f32], A>,
49
50    /// The [`SphericalQuantizer`] supports several different strategies for performing the
51    /// distance-preserving transformation on dataset vectors, which may be applicable in
52    /// different scenarios.
53    ///
54    /// The different transformations may have restrictions on the number of supported dimensions.
55    /// While we will accept all non-zero input dimensions, the output dimension of a transform
56    /// may be higher or lower, depending on the configuration.
57    transform: Transform<A>,
58
59    /// The metric meant to be used by the quantizer.
60    metric: SupportedMetric,
61
62    /// When processing queries, it may be beneficial to modify the query norm to match the
63    /// dataset norm.
64    ///
65    /// This is only applicable when `InnerProduct` and `Cosine` are used, but serves to
66    /// move the query into the dynamic range of the quantization.
67    ///
68    /// You would think that the normalization step in RabitQ would mitigate this, but
69    /// that is not always right since range-adjustment happens before centering.
70    mean_norm: Positive<f32>,
71
72    /// To support 16-bit constants which have a limited dynamic range, we allow a
73    /// pre-scaling parameter that is multiplied to each value in compressed vectors.
74    ///
75    /// This allows to transparent handling of compressing integral data, which can
76    /// otherwise easily overflow `f16`.
77    pre_scale: Positive<f32>,
78}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82    A: Allocator,
83{
84    fn try_clone(&self) -> Result<Self, AllocatorError> {
85        Ok(Self {
86            shift: self.shift.try_clone()?,
87            transform: self.transform.try_clone()?,
88            metric: self.metric,
89            mean_norm: self.mean_norm,
90            pre_scale: self.pre_scale,
91        })
92    }
93}
94
95#[derive(Debug, Clone, Copy, Error)]
96#[non_exhaustive]
97pub enum TrainError {
98    #[error("data dim cannot be zero")]
99    DimCannotBeZero,
100    #[error("data cannot be empty")]
101    DataCannotBeEmpty,
102    #[error("pre-scale must be positive")]
103    PrescaleNotPositive,
104    #[error("norm must be positive")]
105    NormNotPositive,
106    #[error("computed norm contains infinity or NaN")]
107    NormNotFinite,
108    #[error("reciprocal norm contains infinity or NaN")]
109    ReciprocalNormNotFinite,
110    #[error(transparent)]
111    AllocatorError(#[from] AllocatorError),
112}
113
114impl<A> SphericalQuantizer<A>
115where
116    A: Allocator,
117{
118    /// Return the number dimensions this quantizer has been trained for.
119    pub fn input_dim(&self) -> usize {
120        self.shift.len()
121    }
122
123    /// Return the dimension of the post-transformed vector.
124    ///
125    /// Output storage vectors should use this dimension instead of `self.dim()` because
126    /// in general, the output dim **may** be different from the input dimension.
127    pub fn output_dim(&self) -> usize {
128        self.transform.output_dim()
129    }
130
131    /// Return the per-dimension shift vector.
132    ///
133    /// This vector is meant to accomplish two goals:
134    ///
135    /// 1. Centers the data around the training dataset mean.
136    /// 2. Offsets each dimension into a range that can be encoded in unsigned values.
137    pub fn shift(&self) -> &[f32] {
138        &self.shift
139    }
140
141    /// Return the approximate mean norm of the training data.
142    pub fn mean_norm(&self) -> Positive<f32> {
143        self.mean_norm
144    }
145
146    /// Return the pre-scaling parameter for data. This value is multiplied to every
147    /// compressed vector to adjust its dynamic range.
148    ///
149    /// A value of 1.0 means that no scaling is occurring.
150    pub fn pre_scale(&self) -> Positive<f32> {
151        self.pre_scale
152    }
153
154    /// Return a reference to the allocator used by this data structure.
155    pub fn allocator(&self) -> &A {
156        self.shift.allocator()
157    }
158
159    /// A lower-level constructor that accepts a centroid, mean norm, and pre-scale directly.
160    pub fn generate(
161        mut centroid: Poly<[f32], A>,
162        mean_norm: f32,
163        transform: TransformKind,
164        metric: SupportedMetric,
165        pre_scale: Option<f32>,
166        rng: &mut dyn RngCore,
167        allocator: A,
168    ) -> Result<Self, TrainError> {
169        let pre_scale = match pre_scale {
170            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
171            None => crate::num::POSITIVE_ONE_F32,
172        };
173
174        let dim = match NonZeroUsize::new(centroid.len()) {
175            Some(dim) => dim,
176            None => {
177                return Err(TrainError::DimCannotBeZero);
178            }
179        };
180
181        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;
182
183        // We passed in a 'rng' so `Transform::new` will not fail.
184        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
185            Ok(v) => v,
186            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
187            Err(NewTransformError::AllocatorError(err)) => {
188                return Err(TrainError::AllocatorError(err));
189            }
190        };
191
192        // Transform the centroid by the pre-scale.
193        centroid
194            .iter_mut()
195            .for_each(|v| *v *= pre_scale.into_inner());
196
197        Ok(SphericalQuantizer {
198            shift: centroid,
199            transform,
200            metric,
201            mean_norm,
202            pre_scale,
203        })
204    }
205
206    /// Return the metric used by this quantizer.
207    pub fn metric(&self) -> SupportedMetric {
208        self.metric
209    }
210
211    /// Construct a quantizer for vectors in the distribution of `data`.
212    ///
213    /// The type of distance-preserving transform to use is selected by the [`TransformKind`].
214    ///
215    /// Vectors compressed with this quantizer will be **metric specific** and optimized for
216    /// distance computations rather than reconstruction. This means that vectors compressed
217    /// targeting the inner-product distance will not return meaningful results if used for
218    /// L2 distance computations.
219    ///
220    /// Additionally, vectors compressed when using the [`SupportedMetric::Cosine`] distance
221    /// will be implicitly normalized before being compressed to enable better compression.
222    ///
223    /// If argument `pre_scale` is given, then all vectors compressed by this quantizer will
224    /// first be scaled by this value. Note that if given, `pre_scale` **must** be positive.
225    pub fn train<T, R>(
226        data: MatrixView<T>,
227        transform: TransformKind,
228        metric: SupportedMetric,
229        pre_scale: PreScale,
230        rng: &mut R,
231        allocator: A,
232    ) -> Result<Self, TrainError>
233    where
234        T: Copy + Into<f64> + Into<f32>,
235        R: Rng,
236    {
237        // An inner implementation erasing the type of the random number generator to
238        // cut down on excess monomorphization.
239        #[inline(never)]
240        fn train<T, A>(
241            data: MatrixView<T>,
242            transform: TransformKind,
243            metric: SupportedMetric,
244            pre_scale: PreScale,
245            rng: &mut dyn RngCore,
246            allocator: A,
247        ) -> Result<SphericalQuantizer<A>, TrainError>
248        where
249            T: Copy + Into<f64> + Into<f32>,
250            A: Allocator,
251        {
252            // This check is repeated in `Self::generate`, but we prefer to bail as early
253            // as possible if we detect an error.
254            if data.ncols() == 0 {
255                return Err(TrainError::DimCannotBeZero);
256            }
257
258            let (centroid, mean_norm) = match metric {
259                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
260                    compute_means_and_average_norm(data)
261                }
262                SupportedMetric::Cosine => (
263                    compute_normalized_means(data)
264                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
265                    1.0,
266                ),
267            };
268
269            let mean_norm = mean_norm as f32;
270
271            if mean_norm <= 0.0 {
272                return Err(TrainError::NormNotPositive);
273            }
274
275            if !mean_norm.is_finite() {
276                return Err(TrainError::NormNotFinite);
277            }
278
279            // Determining if (and how) the pre-scaling term will be calculated.
280            let pre_scale: Positive<f32> = match pre_scale {
281                PreScale::None => crate::num::POSITIVE_ONE_F32,
282                PreScale::Some(v) => v,
283                PreScale::ReciprocalMeanNorm => {
284                    // We've checked that `mean_norm` is both positive and finite.
285                    //
286                    // It's possible that when converted to `f32`,
287                    //
288                    // Taking the reciprocal is well defined. However, since the norms
289                    // and scales in `compute_means_and_average_norm` are done using `f64`,
290                    // it's possible that the computed `mean_norm` is subnormal, leading
291                    // to the reciprocal being infinity.
292                    let pre_scale = Positive::new(1.0 / mean_norm)
293                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;
294
295                    if !pre_scale.into_inner().is_finite() {
296                        return Err(TrainError::ReciprocalNormNotFinite);
297                    }
298
299                    pre_scale
300                }
301            };
302
303            // Allow the pre-scaling to take place inside `Self::generate`.
304            let centroid =
305                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;
306
307            SphericalQuantizer::generate(
308                centroid,
309                mean_norm,
310                transform,
311                metric,
312                Some(pre_scale.into_inner()),
313                rng,
314                allocator,
315            )
316        }
317
318        train(data, transform, metric, pre_scale, rng, allocator)
319    }
320
321    /// Rescale the argument `v` to be in the rough dynamic range of the training dataset.
322    pub fn rescale(&self, v: &mut [f32]) {
323        let norm = FastL2Norm.evaluate(&*v);
324        let m = self.mean_norm.into_inner() / norm;
325        v.iter_mut().for_each(|i| *i *= m);
326    }
327
328    /// Private helper function to do common data pre-processing.
329    ///
330    /// # Panics
331    ///
332    /// Panics if `data.len() != self.dim()`.
333    fn preprocess<'a>(
334        &self,
335        data: &[f32],
336        allocator: ScopedAllocator<'a>,
337    ) -> Result<Preprocessed<'a>, CompressionError> {
338        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");
339
340        // Fold in pre-scaling with the potential norm corretion for cosine.
341        //
342        // NOTE: When we're computing Cosine Similarity, we normalize the vector. As such,
343        // the `pre_scale` parameter become irrelvant since it just gets normalized away.
344        let scale = self.pre_scale.into_inner();
345        let mul: f32 = match self.metric {
346            SupportedMetric::Cosine => {
347                let norm: f32 = (FastL2Norm).evaluate(data);
348                if norm == 0.0 { 1.0 } else { 1.0 / norm }
349            }
350            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
351        };
352
353        // Center the vector and compute the squared norm of the shifted vector.
354        let shifted = Poly::from_iter(
355            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
356            allocator,
357        )?;
358
359        let shifted_norm = FastL2Norm.evaluate(&*shifted);
360        if !shifted_norm.is_finite() {
361            return Err(CompressionError::InputContainsNaN);
362        }
363        let inner_product_with_centroid = match self.metric {
364            SupportedMetric::SquaredL2 => None,
365            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
366                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
367                Some(ip.into_inner())
368            }
369        };
370
371        Ok(Preprocessed {
372            shifted,
373            shifted_norm,
374            inner_product_with_centroid,
375        })
376    }
377}
378
379/// Pre-scaling selector for spherical quantization training. Pre-scaling adjusts the
380/// dynamic range of the data (usually decreasing it uniformly) to keep the correction terms
381/// within the range expressible by 16-bit floating point numbers.
382#[derive(Debug, Clone, Copy)]
383pub enum PreScale {
384    /// Do not use any pre-scaling.
385    None,
386    /// Pre-scale all data by the specified amount.
387    Some(Positive<f32>),
388    /// Heuristically estimate a pre-scaling parameter by using the inverse approximate
389    /// mean norm. This will nearly normalize in-distribution vectors.
390    ReciprocalMeanNorm,
391}
392
/// Errors produced while reconstructing a [`SphericalQuantizer`] from its flatbuffer form.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The serialized transform could not be unpacked.
    #[error(transparent)]
    TransformError(#[from] TransformError),

    /// The buffer does not carry the expected flatbuffer identifier.
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,

    /// The centroid length disagrees with the transform's input dimension.
    #[error("transform length not equal to centroid")]
    DimMismatch,

    /// The stored mean norm was rejected as not positive.
    #[error("norm is missing or is not positive")]
    MissingNorm,

    /// The stored pre-scale was rejected as not positive.
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    /// The raw bytes failed flatbuffer verification.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// The stored metric is not one the quantizer supports.
    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    /// An allocation failed while rebuilding the quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
418
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Pack `self` into `buf` using the [`spherical::SphericalQuantizer`] serialized
    /// representation.
    ///
    /// The stored centroid is the shift vector exactly as held by `self` (already
    /// multiplied by the pre-scale). [`Self::try_unpack`] restores it verbatim without
    /// re-applying the pre-scale.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Save the centroid vector.
        let centroid = buf.create_vector(&self.shift);

        // Save the transform.
        let transform = self.transform.pack(buf);

        // Finish up.
        spherical::SphericalQuantizer::create(
            buf,
            &spherical::SphericalQuantizerArgs {
                centroid: Some(centroid),
                transform: Some(transform),
                metric: self.metric.into(),
                mean_norm: self.mean_norm.into_inner(),
                pre_scale: self.pre_scale.into_inner(),
            },
        )
    }

    /// Attempt to reconstruct a quantizer from its serialized
    /// [`spherical::SphericalQuantizer`] representation, returning any encountered error.
    ///
    /// # Errors
    ///
    /// * [`DeserializationError::InvalidMetric`] if the stored metric is unsupported.
    /// * [`DeserializationError::TransformError`] if the transform cannot be unpacked.
    /// * [`DeserializationError::DimMismatch`] if the centroid length differs from the
    ///   transform's input dimension.
    /// * [`DeserializationError::MissingNorm`] / [`DeserializationError::PreScaleNotPositive`]
    ///   if the mean norm or pre-scale is rejected by `Positive::new`.
    /// * [`DeserializationError::AllocatorError`] on allocation failure.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        // Unpack the centroid (stored already pre-scaled; see `pack`).
        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;

        // Unpack the transform.
        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // Ensure consistency between the shift dimensions and the transform.
        if shift.len() != transform.input_dim() {
            return Err(DeserializationError::DimMismatch);
        }

        // Make sure we get a sane value for the mean norm.
        let mean_norm =
            Positive::new(proto.mean_norm()).map_err(|_| DeserializationError::MissingNorm)?;

        let pre_scale = Positive::new(proto.pre_scale())
            .map_err(|_| DeserializationError::PreScaleNotPositive)?;

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
488
489struct Preprocessed<'a> {
490    shifted: Poly<[f32], ScopedAllocator<'a>>,
491    shifted_norm: f32,
492    inner_product_with_centroid: Option<f32>,
493}
494
495impl Preprocessed<'_> {
496    /// Return the metric specific correction term as sumamrized below:
497    ///
498    /// * Inner Product: The inner product between the shifted vector and the centroid.
499    /// * Squared L2: The squared norm of the shifted vector.
500    fn metric_specific(&self) -> f32 {
501        match self.inner_product_with_centroid {
502            Some(ip) => ip,
503            None => self.shifted_norm * self.shifted_norm,
504        }
505    }
506}
507
508///////////////////////
509// Distance Functors //
510///////////////////////
511
512impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
513where
514    A: Allocator,
515{
516    fn as_functor(&self) -> CompensatedSquaredL2 {
517        CompensatedSquaredL2::new(self.output_dim())
518    }
519}
520
521impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
522where
523    A: Allocator,
524{
525    fn as_functor(&self) -> CompensatedIP {
526        CompensatedIP::new(&self.shift, self.output_dim())
527    }
528}
529
530impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
531where
532    A: Allocator,
533{
534    fn as_functor(&self) -> CompensatedCosine {
535        CompensatedCosine::new(self.as_functor())
536    }
537}
538
539/////////////////
540// Compression //
541/////////////////
542
543#[derive(Debug, Error, Clone, Copy, PartialEq)]
544#[non_exhaustive]
545pub enum CompressionError {
546    #[error("input contains NaN")]
547    InputContainsNaN,
548
549    #[error("expected source vector to have length {expected}")]
550    SourceDimensionMismatch { expected: usize },
551
552    #[error("expected destination vector to have length {expected}")]
553    DestinationDimensionMismatch { expected: usize },
554
555    #[error(
556        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
557    )]
558    EncodingError(#[from] DataMetaError),
559
560    #[error(transparent)]
561    AllocatorError(#[from] AllocatorError),
562}
563
564fn check_dims(
565    input: usize,
566    output: usize,
567    from: usize,
568    into: usize,
569) -> Result<(), CompressionError> {
570    if from != input {
571        return Err(CompressionError::SourceDimensionMismatch { expected: input });
572    }
573    if into != output {
574        return Err(CompressionError::DestinationDimensionMismatch { expected: output });
575    }
576    Ok(())
577}
578
579/// Helper trait to dispatch to a faster 1-bit implementation and use the slower
580/// maximum-cosine algorithm when more than 1 bit is used.
581trait FinishCompressing {
582    fn finish_compressing(
583        &mut self,
584        preprocessed: &Preprocessed<'_>,
585        transformed: &[f32],
586        transformed_norm: f32,
587        allocator: ScopedAllocator<'_>,
588    ) -> Result<(), CompressionError>;
589}
590
591impl FinishCompressing for DataMut<'_, 1> {
592    fn finish_compressing(
593        &mut self,
594        preprocessed: &Preprocessed<'_>,
595        transformed: &[f32],
596        transformed_norm: f32,
597        _: ScopedAllocator<'_>,
598    ) -> Result<(), CompressionError> {
599        // Compute signed quantized vector (-1 or 1)
600        // and also populate the unsigned bit representation in `into` output vector.
601        let mut quant_raw_inner_product = 0.0f32;
602        let mut bit_sum = 0u32;
603        transformed.iter().enumerate().for_each(|(i, &r)| {
604            let bit: u8 = if r > 0.0 { 1 } else { 0 };
605
606            quant_raw_inner_product += r.abs();
607            bit_sum += <u8 as Into<u32>>::into(bit);
608
609            // SAFETY: From check 1, we know that `i < into.len()`.
610            unsafe { self.vector_mut().set_unchecked(i, bit) };
611        });
612
613        // The value we just computed for `quant_raw_inner_product` is:
614        // ```
615        // Y = <x', x> * sqrt(D)                        [1]
616        // ```
617        // The inner product correction term is
618        // ```
619        //       2 |X|
620        // -----------------                            [2]
621        // <x', x> * sqrt(D)
622        // ```
623        // [1] substitutes directly into [2] and we get
624        // ```
625        // 2 |X|
626        // -----
627        //   Y
628        // ```
629        // Therefore, the inner product correction term is
630        // ```
631        // 2.0 * shifted_norm / quant_raw_inner_product
632        // ```
633        let inner_product_correction =
634            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
635        self.set_meta(DataMeta::new(
636            inner_product_correction,
637            preprocessed.metric_specific(),
638            bit_sum,
639        )?);
640        Ok(())
641    }
642}
643
644impl FinishCompressing for DataMut<'_, 2> {
645    fn finish_compressing(
646        &mut self,
647        preprocessed: &Preprocessed<'_>,
648        transformed: &[f32],
649        transformed_norm: f32,
650        allocator: ScopedAllocator<'_>,
651    ) -> Result<(), CompressionError> {
652        compress_via_maximum_cosine(
653            self.reborrow_mut(),
654            preprocessed,
655            transformed,
656            transformed_norm,
657            allocator,
658        )
659    }
660}
661
662impl FinishCompressing for DataMut<'_, 4> {
663    fn finish_compressing(
664        &mut self,
665        preprocessed: &Preprocessed<'_>,
666        transformed: &[f32],
667        transformed_norm: f32,
668        allocator: ScopedAllocator<'_>,
669    ) -> Result<(), CompressionError> {
670        compress_via_maximum_cosine(
671            self.reborrow_mut(),
672            preprocessed,
673            transformed,
674            transformed_norm,
675            allocator,
676        )
677    }
678}
679
680impl FinishCompressing for DataMut<'_, 8> {
681    fn finish_compressing(
682        &mut self,
683        preprocessed: &Preprocessed<'_>,
684        transformed: &[f32],
685        transformed_norm: f32,
686        allocator: ScopedAllocator<'_>,
687    ) -> Result<(), CompressionError> {
688        compress_via_maximum_cosine(
689            self.reborrow_mut(),
690            preprocessed,
691            transformed,
692            transformed_norm,
693            allocator,
694        )
695    }
696}
697
698//////////////////////
699// Data Compression //
700//////////////////////
701
702impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
703where
704    A: Allocator,
705{
706    type Error = CompressionError;
707
708    /// Compress the input vector `from` into the bitslice `into`.
709    ///
710    /// # Error
711    ///
712    /// Returns an error if
713    /// * The input contains `NaN`.
714    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
715    ///   dimensionality as the quantizer.
716    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
717    ///   dimensionality as the output of the distance-preserving transform. Importantely,
718    ///   this **may** be different than `self.dim()` and should be retrieved from
719    ///   `self.output_dim()`.
720    fn compress_into_with(
721        &self,
722        from: &[f32],
723        mut into: FullQueryMut<'_>,
724        allocator: ScopedAllocator<'_>,
725    ) -> Result<(), Self::Error> {
726        let input_dim = self.shift.len();
727        let output_dim = self.output_dim();
728        check_dims(input_dim, output_dim, from.len(), into.len())?;
729
730        let mut preprocessed = self.preprocess(from, allocator)?;
731
732        // If the preprocessed norm is zero, then we tried to compress the center directly.
733        // In this case, we can get the correct behavior by setting `into` to all zeros.
734        if preprocessed.shifted_norm == 0.0 {
735            into.vector_mut().fill(0.0);
736            *into.meta_mut() = Default::default();
737            return Ok(());
738        }
739
740        preprocessed
741            .shifted
742            .iter_mut()
743            .for_each(|v| *v /= preprocessed.shifted_norm);
744
745        // Transformation can fail due to OOM - we want to handle that gracefully.
746        //
747        // If the transformation fails because we provided the wrong sizes, that is a hard
748        // program bug.
749        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
750        match self
751            .transform
752            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
753        {
754            Ok(()) => {}
755            Err(TransformFailed::AllocatorError(err)) => {
756                return Err(CompressionError::AllocatorError(err));
757            }
758            Err(TransformFailed::SourceMismatch { .. })
759            | Err(TransformFailed::DestinationMismatch { .. }) => {
760                panic!(
761                    "The sizes of these arrays should already be checked - this is a logic error"
762                );
763            }
764            #[cfg(feature = "linalg")]
765            Err(TransformFailed::SgemmError(_)) => {
766                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
767            }
768        }
769
770        *into.meta_mut() = FullQueryMeta {
771            sum: into.vector().iter().sum::<f32>(),
772            shifted_norm: preprocessed.shifted_norm,
773            metric_specific: preprocessed.metric_specific(),
774        };
775        Ok(())
776    }
777}
778
779impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
780    for SphericalQuantizer<A>
781where
782    A: Allocator,
783    Unsigned: Representation<NBITS>,
784    for<'a> DataMut<'a, NBITS>: FinishCompressing,
785{
786    type Error = CompressionError;
787
788    /// Compress the input vector `from` into the bitslice `into`.
789    ///
790    /// # Error
791    ///
792    /// Returns an error if
793    /// * The input contains `NaN`.
794    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
795    ///   dimensionality as the quantizer.
796    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
797    ///   dimensionality as the output of the distance-preserving transform. Importantely,
798    ///   this **may** be different than `self.dim()` and should be retrieved from
799    ///   `self.output_dim()`.
800    fn compress_into_with(
801        &self,
802        from: &[f32],
803        mut into: DataMut<'_, NBITS>,
804        allocator: ScopedAllocator<'_>,
805    ) -> Result<(), Self::Error> {
806        let input_dim = self.shift.len();
807        let output_dim = self.output_dim();
808        check_dims(input_dim, output_dim, from.len(), into.len())?;
809
810        let mut preprocessed = self.preprocess(from, allocator)?;
811
812        if preprocessed.shifted_norm == 0.0 {
813            into.set_meta(DataMeta::default());
814            return Ok(());
815        }
816
817        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
818        preprocessed
819            .shifted
820            .iter_mut()
821            .for_each(|v| *v /= preprocessed.shifted_norm);
822
823        // Transformation can fail due to OOM - we want to handle that gracefully.
824        //
825        // If the transformation fails because we provided the wrong sizes, that is a hard
826        // program bug.
827        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
828        match self
829            .transform
830            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
831        {
832            Ok(()) => {}
833            Err(TransformFailed::AllocatorError(err)) => {
834                return Err(CompressionError::AllocatorError(err));
835            }
836            Err(TransformFailed::SourceMismatch { .. })
837            | Err(TransformFailed::DestinationMismatch { .. }) => {
838                panic!(
839                    "The sizes of these arrays should already be checked - this is a logic error"
840                );
841            }
842            #[cfg(feature = "linalg")]
843            Err(TransformFailed::SgemmError(_)) => {
844                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
845            }
846        }
847
848        let transformed_norm = if self.transform.preserves_norms() {
849            1.0
850        } else {
851            (FastL2Norm).evaluate(&*transformed)
852        };
853
854        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
855        Ok(())
856    }
857}
858
/// Compile-time view of the bit-width `NBITS` as a [`NonZeroUsize`].
///
/// Evaluating [`AsNonZero::NON_ZERO`] with `NBITS == 0` fails const evaluation, so a
/// zero bit-width is rejected at compile time.
struct AsNonZero<const NBITS: usize>;

impl<const NBITS: usize> AsNonZero<NBITS> {
    // Lint: this panic can only trigger during const evaluation, never at runtime.
    #[allow(clippy::panic)]
    const NON_ZERO: NonZeroUsize = match NonZeroUsize::new(NBITS) {
        Some(n) => n,
        None => panic!("NBITS must be non-zero"),
    };
}
865
866fn compress_via_maximum_cosine<const NBITS: usize>(
867    mut data: DataMut<'_, NBITS>,
868    preprocessed: &Preprocessed<'_>,
869    transformed: &[f32],
870    transformed_norm: f32,
871    allocator: ScopedAllocator<'_>,
872) -> Result<(), CompressionError>
873where
874    Unsigned: Representation<NBITS>,
875{
876    assert_eq!(data.len(), transformed.len());
877
878    // Find the value we will use to multiply `transformed` to round it to the lattice
879    // element that has the maximum cosine-similarity.
880    let optimal_scale =
881        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;
882
883    let domain = Unsigned::domain_const::<NBITS>();
884    let min = *domain.start() as f32;
885    let max = *domain.end() as f32;
886    let offset = max / 2.0;
887
888    let mut self_inner_product = 0.0f32;
889    let mut bit_sum = 0u32;
890    for (i, t) in transformed.iter().enumerate() {
891        let v = (*t * optimal_scale + offset).clamp(min, max).round();
892        let dv = v - offset;
893        self_inner_product = dv.mul_add(*t, self_inner_product);
894
895        let v = v as u8;
896        bit_sum += <u8 as Into<u32>>::into(v);
897
898        // SAFETY: We have checked that `data.len() == transformed.len()`, so this access
899        // is in-bounds.
900        //
901        // Further, by construction, `v` is encodable by the `Unsigned`.
902        unsafe { data.vector_mut().set_unchecked(i, v) };
903    }
904
905    let shifted_norm = preprocessed.shifted_norm;
906    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
907    data.set_meta(DataMeta::new(
908        inner_product_correction,
909        preprocessed.metric_specific(),
910        bit_sum,
911    )?);
912    Ok(())
913}
914
// A candidate "critical value" for `maximize_cosine_similarity`.
//
// This struct does 2 things:
//
// 1. Records the index in `v` and `rounded` and the scaling parameter so that
//   `value * v[position]` gets rounded to `rounded[position] + 1` while
//   `(value - epsilon) * v[position]` is rounded to `rounded[position]` for some small
//   epsilon.
//
//   Informally, "what's the smallest scaling factor so `v[position]` gets rounded to
//   the next value".
//
// 2. Imposes a genuine total ordering on the `f32` values (via `f32::total_cmp`, so even
//   `NaN` has a well-defined place) so it can be used in a heap. The ordering is
//   reversed so that a max-heap such as `BinaryHeap` yields the *smallest* `value`
//   first, i.e. behaves like a min-heap.
//
//   Only `value` participates in comparisons; `position` is payload. Equality is derived
//   from the same ordering so that `Eq` and `Ord` always agree (in particular
//   `NaN == NaN`, which `Eq` requires for reflexivity).
#[derive(Debug, Clone, Copy)]
struct Pair {
    value: f32,
    position: u32,
}

impl PartialEq for Pair {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Pair {}

impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed on purpose: a smaller `value` compares as "greater" so that max-heaps
        // surface the minimum critical value at the root.
        other.value.total_cmp(&self.value)
    }
}
954
/// This is a tricky function - please read carefully.
///
/// Given a vector `v`, compute the scaling factor `s` such that the cosine similarity
/// between `v` and `r` is maximized, where `r` is defined as
/// ```math
/// let offset = (2^(num_bits) - 1) / 2;
/// let r = (s * v + offset).round().clamp(0, 2^num_bits - 1) - offset
/// ```
///
/// More informally, maximize the inner product between `v` and the points in a square
/// lattice with 2^num_bits values in each dimension, centered around zero. This lattice
/// takes the values (+0.5, -0.5, +1.5, -1.5 ...) to give equal weight above and below zero.
///
/// It works by slowly increasing the factor `s` such that the rounding of only one
/// dimension in `v` is changed at a time. A running tally of the cosine similarity is
/// computed for each scaling factor. For `num_bits >= 2`, every dimension with a non-zero
/// entry is stepped from magnitude `0.5` up to `2^(num_bits - 1) - 0.5`, one unit at a
/// time, so at most `D * (2^(num_bits - 1) - 1)` scaling factors are processed, where `D`
/// is the length of `v`.
///
/// The best scaling factor is returned.
///
/// Refer to algorithm 1 in <https://arxiv.org/pdf/2409.09913>.
///
/// # Panics
///
/// Panics if `v.is_empty()`.
///
/// # Implementation details
///
/// We work with the absolute value of the elements in the vector `v`.
/// This does not affect the final result as the scaling works the same in both the
/// positive and negative directions but simplifies the bookkeeping.
///
/// Critical values are kept in a heap of [`Pair`]s whose root is the minimum critical
/// value. A dimension that can no longer be incremented has its entry replaced by the
/// sentinel `f32::MAX`, and the search stops once that sentinel reaches the root.
///
/// NOTE(review): dimensions where `v` is exactly zero get an initial critical value of
/// `+inf`, which orders after the `f32::MAX` sentinel, so they are never visited unless
/// *every* entry of `v` is zero (in which case the returned scale is `+inf`). TODO
/// confirm that callers never pass an all-zero `v`.
///
/// NOTE(review): for `num_bits == 1`, `stop == 1` but the initial heap entries are still
/// finite, so each non-zero dimension is incremented once to magnitude `1.5`, which lies
/// outside the 1-bit lattice `{-0.5, +0.5}`. The returned scale is still a positive finite
/// value, which is presumably harmless for callers that clamp to the lattice, but the
/// work is wasted. Consider short-circuiting this case.
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // Initially, the lattice element has the value `0.5` for all dimensions.
    // This means the initial inner product between `v` and the rounded term is simply
    // `0.5 * sum(abs.(v))`. The absolute value is used because the lattice element is
    // always in the direction of the components in `v`.
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    // Likewise, the squared norm of the all-`0.5` lattice element is `0.25 * D`.
    let mut current_square_norm = 0.25 * (v.len() as f64);

    // Bookkeeping for the current value of the rounded vector.
    // The true numeric value is 0.5 less than this (in the direction of `v`), but we use
    // integers for a smaller memory footprint.
    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // Compute the critical values and store them on a heap.
    //
    // The binary heap will keep track of the minimum critical value. Multiplying `v` by the
    // minimum critical value `s` means that `s * v` will only change `rounded` from its
    // current value at a single index (the position associated with `s`).
    //
    // `eps` nudges each critical value slightly past the exact crossing point `r / |v|`,
    // so that scaling by it lands beyond the rounding boundary rather than on it.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    // Lint: This is a private method and all the callers have an invariant that they check
    // for non-empty inputs.
    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // Largest value an entry of `rounded` may reach. It corresponds to the lattice edge
    // of magnitude `2^(num_bits - 1) - 0.5`.
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            // The root holds the minimum, so once it is the "exhausted" sentinel every
            // remaining (finite) entry has been processed.
            if value == f32::MAX {
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            // By the nature of critical values, only `r` will change in `rounded` when
            // multiplying by `value`. And that change will be to increase by 1.
            *r += 1;

            // The inner product estimate simply increases by `vp.abs()` because:
            //
            // * `r` is the only value in `rounded` that changes.
            // * `r` is increased by 1.
            current_ip += vp.abs() as f64;

            // This uses the formula
            // ```math
            // (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2
            //                 = 2x + 1
            // ```
            // substitute `x = y - 1/2` to obtain the true value associated with rounded and
            // we get
            // ```math
            // 2 ( y - 1/2 ) + 1 = 2y - 1 + 1
            //                   = 2y
            // ```
            // Therefore, the change in the estimate for the square norm of `rounded` is
            // `2 * old_r`.
            current_square_norm += (2 * old_r) as f64;

            // Compute the current cosine similarity and update max if needed.
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            // Compute the scaling factor that will change this dimension to the next value.
            // Once the dimension has reached the lattice edge, retire it with the sentinel.
            if *r < stop {
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1096
1097///////////////////////
1098// Query Compression //
1099///////////////////////
1100
1101impl<const NBITS: usize, Perm, A>
1102    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
1103    for SphericalQuantizer<A>
1104where
1105    Unsigned: Representation<NBITS>,
1106    Perm: PermutationStrategy<NBITS>,
1107    A: Allocator,
1108{
1109    type Error = CompressionError;
1110
1111    /// Compress the input vector `from` into the bitslice `into`.
1112    ///
1113    /// # Error
1114    ///
1115    /// Returns an error if
1116    /// * The input contains `NaN`.
1117    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
1118    ///   dimensionality as the quantizer.
1119    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
1120    ///   dimensionality as the output of the distance-preserving transform. Importantely,
1121    ///   this **may** be different than `self.dim()` and should be retrieved from
1122    ///   `self.output_dim()`.
1123    fn compress_into_with(
1124        &self,
1125        from: &[f32],
1126        mut into: QueryMut<'_, NBITS, Perm>,
1127        allocator: ScopedAllocator<'_>,
1128    ) -> Result<(), Self::Error> {
1129        let input_dim = self.shift.len();
1130        let output_dim = self.output_dim();
1131        check_dims(input_dim, output_dim, from.len(), into.len())?;
1132
1133        let mut preprocessed = self.preprocess(from, allocator)?;
1134
1135        if preprocessed.shifted_norm == 0.0 {
1136            into.set_meta(QueryMeta::default());
1137            return Ok(());
1138        }
1139
1140        preprocessed
1141            .shifted
1142            .iter_mut()
1143            .for_each(|v| *v /= preprocessed.shifted_norm);
1144
1145        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
1146
1147        // Transformation can fail due to OOM - we want to handle that gracefully.
1148        //
1149        // If the transformation fails because we provided the wrong sizes, that is a hard
1150        // program bug.
1151        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
1152        match self
1153            .transform
1154            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
1155        {
1156            Ok(()) => {}
1157            Err(TransformFailed::AllocatorError(err)) => {
1158                return Err(CompressionError::AllocatorError(err));
1159            }
1160            Err(TransformFailed::SourceMismatch { .. })
1161            | Err(TransformFailed::DestinationMismatch { .. }) => {
1162                panic!(
1163                    "The sizes of these arrays should already be checked - this is a logic error"
1164                );
1165            }
1166            #[cfg(feature = "linalg")]
1167            Err(TransformFailed::SgemmError(_)) => {
1168                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
1169            }
1170        }
1171
1172        // Compute the minimum and maximum values of the transformed vector.
1173        let (min, max) = transformed
1174            .iter()
1175            .fold((f32::MAX, f32::MIN), |(min, max), i| {
1176                (i.min(min), i.max(max))
1177            });
1178
1179        let domain = Unsigned::domain_const::<NBITS>();
1180        let lo = (*domain.start()) as f32;
1181        let hi = (*domain.end()) as f32;
1182
1183        let scale = (max - min) / hi;
1184        let mut bit_sum: f32 = 0.0;
1185        transformed.iter().enumerate().for_each(|(i, v)| {
1186            let c = ((v - min) / scale).round().clamp(lo, hi);
1187            bit_sum += c;
1188
1189            // Lint: We have verified that `into.len() == transformed.len()`, so the index
1190            // `i` is in bounds.
1191            //
1192            // Further, `c` has beem clamped to `[0, 2^NBITS - 1]` and is thus encodable
1193            // with the NBITS-bit unsigned representation.
1194            #[allow(clippy::unwrap_used)]
1195            into.vector_mut().set(i, c as i64).unwrap();
1196        });
1197
1198        // Finish up the compensation terms.
1199        into.set_meta(QueryMeta {
1200            inner_product_correction: preprocessed.shifted_norm * scale,
1201            bit_sum,
1202            offset: min / scale,
1203            metric_specific: preprocessed.metric_specific(),
1204        });
1205
1206        Ok(())
1207    }
1208}
1209
1210///////////
1211// Tests //
1212///////////
1213
1214#[cfg(not(miri))]
1215#[cfg(test)]
1216mod tests {
1217    use super::*;
1218
1219    use std::fmt::Display;
1220
1221    use diskann_utils::{
1222        ReborrowMut, lazy_format,
1223        views::{self, Matrix},
1224    };
1225    use diskann_vector::{PureDistanceFunction, norm::FastL2NormSquared};
1226    use diskann_wide::ARCH;
1227    use rand::{
1228        SeedableRng,
1229        distr::{Distribution, Uniform},
1230        rngs::StdRng,
1231    };
1232    use rand_distr::StandardNormal;
1233
1234    use crate::{
1235        algorithms::transforms::TargetDim,
1236        alloc::GlobalAllocator,
1237        bits::{BitTranspose, Dense},
1238        spherical::{Data, DataMetaF32, FullQuery, Query},
1239        test_util,
1240    };
1241
1242    // Test cosine-similarity maximizer
1243    #[test]
1244    fn test_cosine_similarity_maximizer() {
1245        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
1246        let num_trials = 10000;
1247        let num_bits = NonZeroUsize::new(3).unwrap();
1248
1249        let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();
1250
1251        let run_test = |target: [f32; 4]| {
1252            let scale =
1253                maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();
1254
1255            let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
1256            let mut best_similarity: f32 = f32::NEG_INFINITY;
1257
1258            // This crazy series of nested loops performs an exhaustive search over the
1259            // encoding space.
1260            let min = -3.5;
1261            for i0 in 0..8 {
1262                for i1 in 0..8 {
1263                    for i2 in 0..8 {
1264                        for i3 in 0..8 {
1265                            let p: [f32; 4] = [
1266                                min + (i0 as f32),
1267                                min + (i1 as f32),
1268                                min + (i2 as f32),
1269                                min + (i3 as f32),
1270                            ];
1271
1272                            let sim: MathematicalValue<f32> =
1273                                diskann_vector::distance::Cosine::evaluate(&p, &target);
1274                            let sim = sim.into_inner();
1275                            if sim > best_similarity {
1276                                best_similarity = sim;
1277                                // Transform into an integer starting at zero.
1278                                best = p.map(|i| i - min);
1279                            }
1280                        }
1281                    }
1282                }
1283            }
1284
1285            // Now, rescale the input vector, clamp, and round.
1286            // Check if they agree.
1287            let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
1288            let clamped_cosine: MathematicalValue<f32> =
1289                diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);
1290
1291            // We expect to either get the best value found via exhaustive search, or some
1292            // scalar multiple of it (since that will have the same cosine similarity).
1293            let passed = if best == clamped {
1294                true
1295            } else {
1296                let ratio: Vec<f32> = std::iter::zip(best, clamped)
1297                    .map(|(b, c)| {
1298                        let ratio = (b + min) / (c + min);
1299                        assert_ne!(
1300                            ratio, 0.0,
1301                            "ratio should never be zero because `b` is an integer and \
1302                             `min` is not"
1303                        );
1304                        ratio
1305                    })
1306                    .collect();
1307
1308                ratio.iter().all(|i| *i == ratio[0])
1309            };
1310
1311            if !passed {
1312                panic!(
1313                    "failed for input {:?}.\
1314                     Best = {:?}, Found = {:?}\
1315                     Best similarity = {}, similarity with clamped = {}",
1316                    target,
1317                    best,
1318                    clamped,
1319                    best_similarity,
1320                    clamped_cosine.into_inner()
1321                );
1322            }
1323        };
1324
1325        // Run targeted tests.
1326        let min = -3.5;
1327        for i0 in (0..8).step_by(2) {
1328            for i1 in (1..9).step_by(2) {
1329                for i2 in (0..8).step_by(2) {
1330                    for i3 in (1..9).step_by(2) {
1331                        let p: [f32; 4] = [
1332                            min + (i0 as f32),
1333                            min + (i1 as f32),
1334                            min + (i2 as f32),
1335                            min + (i3 as f32),
1336                        ];
1337                        run_test(p)
1338                    }
1339                }
1340            }
1341        }
1342
1343        for _ in 0..num_trials {
1344            let this_scale: f32 = scale_distribution.sample(&mut rng);
1345            let v: [f32; 4] = [(); 4].map(|_| {
1346                let v: f32 = StandardNormal {}.sample(&mut rng);
1347                this_scale * v
1348            });
1349            run_test(v);
1350        }
1351    }
1352
1353    #[test]
1354    #[should_panic(expected = "calling code should not allow the slice to be empty")]
1355    fn empty_slice_panics() {
1356        maximize_cosine_similarity(
1357            &[],
1358            NonZeroUsize::new(4).unwrap(),
1359            ScopedAllocator::global(),
1360        )
1361        .unwrap();
1362    }
1363
1364    struct Setup {
1365        transform: TransformKind,
1366        nrows: usize,
1367        ncols: usize,
1368        num_trials: usize,
1369    }
1370
1371    fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1372        match scale {
1373            PreScale::None => 1.0,
1374            PreScale::Some(v) => v.into_inner(),
1375            PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1376        }
1377    }
1378
1379    fn test_l2<const Q: usize, const D: usize, Perm>(
1380        setup: &Setup,
1381        problem: &test_util::TestProblem,
1382        computed_means: &[f32],
1383        pre_scale: PreScale,
1384        rng: &mut StdRng,
1385    ) where
1386        Unsigned: Representation<Q>,
1387        Unsigned: Representation<D>,
1388        Perm: PermutationStrategy<Q>,
1389        for<'a> SphericalQuantizer:
1390            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1391        for<'a> SphericalQuantizer:
1392            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1393    {
1394        assert_eq!(setup.nrows, problem.data.nrows());
1395        assert_eq!(setup.ncols, problem.data.ncols());
1396
1397        let scoped_global = ScopedAllocator::global();
1398        let distribution = Uniform::new(0, setup.nrows).unwrap();
1399        let quantizer = SphericalQuantizer::train(
1400            problem.data.as_view(),
1401            setup.transform,
1402            SupportedMetric::SquaredL2,
1403            pre_scale,
1404            rng,
1405            GlobalAllocator,
1406        )
1407        .unwrap();
1408
1409        let scale = get_scale(pre_scale, &quantizer);
1410
1411        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1412        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1413        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1414
1415        assert_eq!(
1416            quantizer.mean_norm.into_inner(),
1417            problem.mean_norm as f32,
1418            "computed mean norm should not apply scale"
1419        );
1420        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1421        assert_eq!(&*scaled_means, quantizer.shift());
1422
1423        let l2: CompensatedSquaredL2 = quantizer.as_functor();
1424        assert_eq!(l2.dim, quantizer.output_dim() as f32);
1425
1426        for _ in 0..setup.num_trials {
1427            let i = distribution.sample(rng);
1428            let v = problem.data.row(i);
1429
1430            quantizer
1431                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1432                .unwrap();
1433            quantizer
1434                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1435                .unwrap();
1436            quantizer
1437                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1438                .unwrap();
1439
1440            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1441                .map(|(a, b)| scale * a - b)
1442                .collect();
1443
1444            // Check that the compensation coefficient were chosen correctly.
1445            {
1446                let DataMetaF32 {
1447                    inner_product_correction,
1448                    bit_sum,
1449                    metric_specific,
1450                } = b.meta().to_full(ARCH);
1451
1452                let shifted_square_norm = metric_specific;
1453
1454                // Check that the bit-count is correct. let bv = b.vector();
1455                let bv = b.vector();
1456                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1457                assert_eq!(s, bit_sum as usize);
1458
1459                // Check that the shifted norm is correct.
1460                {
1461                    let expected = FastL2NormSquared.evaluate(&*shifted);
1462                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1463                    assert!(
1464                        err < 5.0e-4, // twice the minimum normal f16 value.
1465                        "failed diff check, got {}, expected {} - relative error = {}",
1466                        shifted_square_norm,
1467                        expected,
1468                        err
1469                    );
1470                }
1471
1472                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1473                // the RaBitQ paper suggests.
1474                if const { D == 1 } {
1475                    let self_inner_product = 2.0 * shifted_square_norm.sqrt()
1476                        / (inner_product_correction * (bv.len() as f32).sqrt());
1477                    assert!(
1478                        (self_inner_product - 0.8).abs() < 0.13,
1479                        "self inner-product should be close to 0.8. Instead, it's {}",
1480                        self_inner_product
1481                    );
1482                }
1483            }
1484
1485            {
1486                let QueryMeta {
1487                    inner_product_correction,
1488                    bit_sum,
1489                    offset,
1490                    metric_specific,
1491                } = q.meta();
1492
1493                let shifted_square_norm = metric_specific;
1494                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1495                preprocessed
1496                    .shifted
1497                    .iter_mut()
1498                    .for_each(|i| *i /= preprocessed.shifted_norm);
1499
1500                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1501                quantizer
1502                    .transform
1503                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1504                    .unwrap();
1505
1506                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1507                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1508
1509                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1510
1511                // Shifted Norm
1512                {
1513                    let expected = FastL2NormSquared.evaluate(&*shifted);
1514                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1515                    assert!(
1516                        err < 2e-7,
1517                        "failed diff check, got {}, expected {} - relative error = {}",
1518                        shifted_square_norm,
1519                        expected,
1520                        err
1521                    );
1522                }
1523
1524                // Inner product correction
1525                {
1526                    let expected = shifted_square_norm.sqrt() * scale;
1527                    let got = inner_product_correction;
1528
1529                    let err = (expected - got).abs();
1530                    assert!(
1531                        err < 1.0e-7,
1532                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1533                        expected,
1534                        got,
1535                        err
1536                    );
1537                }
1538
1539                // Offset
1540                {
1541                    let expected = min / scale;
1542                    let got = offset;
1543
1544                    let err = (expected - got).abs();
1545                    assert!(
1546                        err < 1.0e-7,
1547                        "\"sum_scale\": expected {}, got {}, error = {}",
1548                        expected,
1549                        got,
1550                        err
1551                    );
1552                }
1553
1554                // Bit Sum
1555                {
1556                    let expected = (0..q.len())
1557                        .map(|i| q.vector().get(i).unwrap())
1558                        .sum::<i64>() as f32;
1559
1560                    let got = bit_sum;
1561
1562                    let err = (expected - got).abs();
1563                    assert!(
1564                        err < 1.0e-7,
1565                        "\"offset\": expected {}, got {}, error = {}",
1566                        expected,
1567                        got,
1568                        err
1569                    );
1570                }
1571            }
1572
1573            // Check that the compensation coefficient were chosen correctly.
1574            {
1575                // Check that the bit-count is correct.
1576                let s: f32 = f.data.iter().sum::<f32>();
1577                assert_eq!(s, f.meta.sum);
1578
1579                // Check that the shifted norm is correct.
1580                {
1581                    let expected = FastL2Norm.evaluate(&*shifted);
1582                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1583                    assert!(
1584                        err < 2e-7,
1585                        "failed diff check, got {}, expected {} - relative error = {}",
1586                        f.meta.shifted_norm,
1587                        expected,
1588                        err
1589                    );
1590                }
1591
1592                assert_eq!(
1593                    f.meta.metric_specific,
1594                    f.meta.shifted_norm * f.meta.shifted_norm,
1595                    "metric specific data for squared l2 is the square shifted norm",
1596                );
1597            }
1598        }
1599
1600        // Finally - test that if we compress the centroid, the metadata coefficients get
1601        // zeroed correctly.
1602        quantizer
1603            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1604            .unwrap();
1605        assert_eq!(b.meta(), DataMeta::default());
1606
1607        quantizer
1608            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1609            .unwrap();
1610        assert_eq!(q.meta(), QueryMeta::default());
1611
1612        f.data.fill(f32::INFINITY);
1613        quantizer
1614            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1615            .unwrap();
1616        assert!(f.data.iter().all(|&i| i == 0.0));
1617        assert_eq!(f.meta.sum, 0.0);
1618        assert_eq!(f.meta.metric_specific, 0.0);
1619    }
1620
1621    fn test_ip<const Q: usize, const D: usize, Perm>(
1622        setup: &Setup,
1623        problem: &test_util::TestProblem,
1624        computed_means: &[f32],
1625        pre_scale: PreScale,
1626        rng: &mut StdRng,
1627        ctx: &dyn Display,
1628    ) where
1629        Unsigned: Representation<Q>,
1630        Unsigned: Representation<D>,
1631        Perm: PermutationStrategy<Q>,
1632        for<'a> SphericalQuantizer:
1633            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1634        for<'a> SphericalQuantizer:
1635            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1636    {
1637        assert_eq!(setup.nrows, problem.data.nrows());
1638        assert_eq!(setup.ncols, problem.data.ncols());
1639
1640        let scoped_global = ScopedAllocator::global();
1641        let distribution = Uniform::new(0, setup.nrows).unwrap();
1642        let quantizer = SphericalQuantizer::train(
1643            problem.data.as_view(),
1644            setup.transform,
1645            SupportedMetric::InnerProduct,
1646            pre_scale,
1647            rng,
1648            GlobalAllocator,
1649        )
1650        .unwrap();
1651
1652        let scale = get_scale(pre_scale, &quantizer);
1653
1654        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1655        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1656        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1657
1658        assert_eq!(
1659            quantizer.mean_norm.into_inner(),
1660            problem.mean_norm as f32,
1661            "computed mean norm should not apply scale"
1662        );
1663        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1664        assert_eq!(&*scaled_means, quantizer.shift());
1665
1666        let ip: CompensatedIP = quantizer.as_functor();
1667
1668        assert_eq!(ip.dim, quantizer.output_dim() as f32);
1669        assert_eq!(
1670            ip.squared_shift_norm,
1671            FastL2NormSquared.evaluate(quantizer.shift())
1672        );
1673
1674        for _ in 0..setup.num_trials {
1675            let i = distribution.sample(rng);
1676            let v = problem.data.row(i);
1677
1678            quantizer
1679                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1680                .unwrap();
1681            quantizer
1682                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1683                .unwrap();
1684            quantizer
1685                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1686                .unwrap();
1687
1688            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1689                .map(|(a, b)| scale * a - b)
1690                .collect();
1691
1692            // Check that the compensation coefficient were chosen correctly.
1693            {
1694                let DataMetaF32 {
1695                    inner_product_correction,
1696                    bit_sum,
1697                    metric_specific,
1698                } = b.meta().to_full(ARCH);
1699
1700                let inner_product_with_centroid = metric_specific;
1701
1702                // Check that the bit-count is correct.
1703                let bv = b.vector();
1704                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1705                assert_eq!(s, bit_sum as usize);
1706
1707                // Check that the shifted norm is correct.
1708                let inner_product: MathematicalValue<f32> =
1709                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1710
1711                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
1712                assert!(
1713                    diff < 1.53e-5,
1714                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
1715                    diff,
1716                    inner_product.into_inner(),
1717                    inner_product_with_centroid,
1718                    ctx,
1719                );
1720
1721                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1722                // the RaBitQ paper suggests.
1723                if const { D == 1 } {
1724                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1725                        / (inner_product_correction * (bv.len() as f32).sqrt());
1726                    assert!(
1727                        (self_inner_product - 0.8).abs() < 0.12,
1728                        "self inner-product should be close to 0.8. Instead, it's {}",
1729                        self_inner_product
1730                    );
1731                }
1732            }
1733
1734            {
1735                let QueryMeta {
1736                    inner_product_correction,
1737                    bit_sum,
1738                    offset,
1739                    metric_specific,
1740                } = q.meta();
1741
1742                let inner_product_with_centroid = metric_specific;
1743                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1744                preprocessed
1745                    .shifted
1746                    .iter_mut()
1747                    .for_each(|i| *i /= preprocessed.shifted_norm);
1748
1749                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1750                quantizer
1751                    .transform
1752                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1753                    .unwrap();
1754
1755                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1756                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1757
1758                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1759
1760                // Inner product correction
1761                {
1762                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1763                    let got = inner_product_correction;
1764
1765                    let err = (expected - got).abs();
1766                    assert!(
1767                        err < 1.0e-7,
1768                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1769                        expected,
1770                        got,
1771                        err
1772                    );
1773                }
1774
1775                // Offset
1776                {
1777                    let expected = min / scale;
1778                    let got = offset;
1779
1780                    let err = (expected - got).abs();
1781                    assert!(
1782                        err < 1.0e-7,
1783                        "\"sum_scale\": expected {}, got {}, error = {}",
1784                        expected,
1785                        got,
1786                        err
1787                    );
1788                }
1789
1790                // Bit Sum
1791                {
1792                    let expected = (0..q.len())
1793                        .map(|i| q.vector().get(i).unwrap())
1794                        .sum::<i64>() as f32;
1795
1796                    let got = bit_sum;
1797
1798                    let err = (expected - got).abs();
1799                    assert!(
1800                        err < 1.0e-7,
1801                        "\"offset\": expected {}, got {}, error = {}",
1802                        expected,
1803                        got,
1804                        err
1805                    );
1806                }
1807
1808                // Inner Product with Centroid
1809                {
1810                    // Check that the shifted norm is correct.
1811                    let inner_product: MathematicalValue<f32> =
1812                        InnerProduct::evaluate(&*shifted, quantizer.shift());
1813                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
1814                }
1815            }
1816
1817            // Check that the compensation coefficient were chosen correctly.
1818            {
1819                // Check that the bit-count is correct.
1820                let s: f32 = f.data.iter().sum::<f32>();
1821                assert_eq!(s, f.meta.sum);
1822
1823                // Check that the shifted norm is correct.
1824                {
1825                    let expected = FastL2Norm.evaluate(&*shifted);
1826                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1827                    assert!(
1828                        err < 2e-7,
1829                        "failed diff check, got {}, expected {} - relative error = {}",
1830                        f.meta.shifted_norm,
1831                        expected,
1832                        err
1833                    );
1834                }
1835
1836                // Check that the shifted norm is correct. s
1837                let inner_product: MathematicalValue<f32> =
1838                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1839                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
1840            }
1841        }
1842
1843        // Finally - test that if we compress the centroid, the metadata coefficients get
1844        // zeroed correctly.
1845        quantizer
1846            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1847            .unwrap();
1848        assert_eq!(b.meta(), DataMeta::default());
1849
1850        quantizer
1851            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1852            .unwrap();
1853        assert_eq!(q.meta(), QueryMeta::default());
1854
1855        f.data.fill(f32::INFINITY);
1856        quantizer
1857            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1858            .unwrap();
1859        assert!(f.data.iter().all(|&i| i == 0.0));
1860        assert_eq!(f.meta.sum, 0.0);
1861        assert_eq!(f.meta.metric_specific, 0.0);
1862    }
1863
1864    fn test_cosine<const Q: usize, const D: usize, Perm>(
1865        setup: &Setup,
1866        problem: &test_util::TestProblem,
1867        pre_scale: PreScale,
1868        rng: &mut StdRng,
1869    ) where
1870        Unsigned: Representation<Q>,
1871        Unsigned: Representation<D>,
1872        Perm: PermutationStrategy<Q>,
1873        for<'a> SphericalQuantizer:
1874            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1875        for<'a> SphericalQuantizer:
1876            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1877    {
1878        assert_eq!(setup.nrows, problem.data.nrows());
1879        assert_eq!(setup.ncols, problem.data.ncols());
1880
1881        let scoped_global = ScopedAllocator::global();
1882        let distribution = Uniform::new(0, setup.nrows).unwrap();
1883        let quantizer = SphericalQuantizer::train(
1884            problem.data.as_view(),
1885            setup.transform,
1886            SupportedMetric::Cosine,
1887            pre_scale,
1888            rng,
1889            GlobalAllocator,
1890        )
1891        .unwrap();
1892
1893        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1894        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1895        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1896
1897        let cosine: CompensatedCosine = quantizer.as_functor();
1898
1899        assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
1900        assert_eq!(
1901            cosine.inner.squared_shift_norm,
1902            FastL2NormSquared.evaluate(quantizer.shift())
1903        );
1904
1905        const IP_BOUND: f32 = 2.6e-3;
1906
1907        let mut test_row = |v: &[f32]| {
1908            let vnorm = (FastL2Norm).evaluate(v);
1909            let v_normalized: Vec<f32> = v
1910                .iter()
1911                .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
1912                .collect();
1913
1914            quantizer
1915                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1916                .unwrap();
1917
1918            quantizer
1919                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1920                .unwrap();
1921
1922            quantizer
1923                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1924                .unwrap();
1925
1926            let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
1927                .map(|(a, b)| a - b)
1928                .collect();
1929
1930            // Check that the compensation coefficient were chosen correctly.
1931            {
1932                let DataMetaF32 {
1933                    inner_product_correction,
1934                    bit_sum,
1935                    metric_specific,
1936                } = b.meta().to_full(ARCH);
1937
1938                let inner_product_with_centroid = metric_specific;
1939
1940                // Check that the bit-count is correct.
1941                let bv = b.vector();
1942                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1943                assert_eq!(s, bit_sum as usize);
1944
1945                // Check that the shifted norm is correct. Since they are computed slightly
1946                // differnetly, allow a small amount of error.
1947                let inner_product: MathematicalValue<f32> =
1948                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1949
1950                let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
1951                let relative = abs / inner_product.into_inner().abs();
1952
1953                assert!(
1954                    abs < 1e-7 || relative < IP_BOUND,
1955                    "got an abs/rel of {}/{} with a bound of {}/{}",
1956                    abs,
1957                    relative,
1958                    1e-7,
1959                    IP_BOUND
1960                );
1961
1962                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1963                // the RaBitQ paper suggests.
1964                if const { D == 1 } {
1965                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1966                        / (inner_product_correction * (bv.len() as f32).sqrt());
1967                    assert!(
1968                        (self_inner_product - 0.8).abs() < 0.11,
1969                        "self inner-product should be close to 0.8. Instead, it's {}",
1970                        self_inner_product
1971                    );
1972                }
1973            }
1974
1975            {
1976                let QueryMeta {
1977                    inner_product_correction,
1978                    bit_sum,
1979                    offset,
1980                    metric_specific,
1981                } = q.meta();
1982
1983                let inner_product_with_centroid = metric_specific;
1984                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1985                preprocessed
1986                    .shifted
1987                    .iter_mut()
1988                    .for_each(|i| *i /= preprocessed.shifted_norm);
1989
1990                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1991                quantizer
1992                    .transform
1993                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1994                    .unwrap();
1995
1996                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1997                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1998
1999                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
2000
2001                // Inner product correction
2002                {
2003                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
2004                    let got = inner_product_correction;
2005
2006                    let err = (expected - got).abs();
2007                    assert!(
2008                        err < 1.0e-7,
2009                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
2010                        expected,
2011                        got,
2012                        err
2013                    );
2014                }
2015
2016                // Offset
2017                {
2018                    let expected = min / scale;
2019                    let got = offset;
2020
2021                    let err = (expected - got).abs();
2022                    assert!(
2023                        err < 1.0e-7,
2024                        "\"sum_scale\": expected {}, got {}, error = {}",
2025                        expected,
2026                        got,
2027                        err
2028                    );
2029                }
2030
2031                // Bit Sum
2032                {
2033                    let expected = (0..q.len())
2034                        .map(|i| q.vector().get(i).unwrap())
2035                        .sum::<i64>() as f32;
2036
2037                    let got = bit_sum;
2038
2039                    let err = (expected - got).abs();
2040                    assert!(
2041                        err < 1.0e-7,
2042                        "\"offset\": expected {}, got {}, error = {}",
2043                        expected,
2044                        got,
2045                        err
2046                    );
2047                }
2048
2049                // Inner Product with Centroid
2050                {
2051                    // Check that the shifted norm is correct.
2052                    let inner_product: MathematicalValue<f32> =
2053                        InnerProduct::evaluate(&*shifted, quantizer.shift());
2054
2055                    let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
2056                        / inner_product.into_inner().abs();
2057                    assert!(
2058                        err < IP_BOUND,
2059                        "\"offset\": expected {}, got {}, error = {}",
2060                        inner_product.into_inner(),
2061                        inner_product_with_centroid,
2062                        err
2063                    );
2064                }
2065            }
2066
2067            // Check that the compensation coefficient were chosen correctly.
2068            {
2069                // Check that the bit-count is correct.
2070                let s: f32 = f.data.iter().sum::<f32>();
2071                assert_eq!(s, f.meta.sum);
2072
2073                // Check that the shifted norm is correct.
2074                {
2075                    let expected = FastL2Norm.evaluate(&*shifted);
2076                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
2077                    assert!(
2078                        err < 2e-7,
2079                        "failed diff check, got {}, expected {} - relative error = {}",
2080                        f.meta.shifted_norm,
2081                        expected,
2082                        err
2083                    );
2084                }
2085
2086                // Check that the shifted norm is correct. s
2087                let inner_product: MathematicalValue<f32> =
2088                    InnerProduct::evaluate(&*shifted, quantizer.shift());
2089                let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
2090                    / inner_product.into_inner().abs();
2091                assert!(
2092                    err < IP_BOUND,
2093                    "\"offset\": expected {}, got {}, error = {}",
2094                    inner_product.into_inner(),
2095                    f.meta.metric_specific,
2096                    err
2097                );
2098            }
2099        };
2100
2101        for _ in 0..setup.num_trials {
2102            let i = distribution.sample(rng);
2103            let v = problem.data.row(i);
2104            test_row(v);
2105        }
2106
2107        // Ensure that if a zero vector is provided that we do not divide by zero.
2108        let zero = vec![0.0f32; quantizer.input_dim()];
2109        test_row(&zero);
2110    }
2111
2112    fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2113    where
2114        for<'a> T: ReborrowMut<'a>,
2115        for<'a> SphericalQuantizer: CompressIntoWith<
2116                &'a [f32],
2117                <T as ReborrowMut<'a>>::Target,
2118                ScopedAllocator<'a>,
2119                Error = CompressionError,
2120            >,
2121    {
2122        let mut succeeded = false;
2123        let mut failed = false;
2124        for max_allocations in 0..10 {
2125            match quantizer.compress_into_with(
2126                data,
2127                dst.reborrow_mut(),
2128                ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2129            ) {
2130                Ok(()) => {
2131                    succeeded = true;
2132                }
2133                Err(CompressionError::AllocatorError(_)) => {
2134                    failed = true;
2135                }
2136                Err(other) => {
2137                    panic!("received an unexpected error: {:?}", other);
2138                }
2139            }
2140        }
2141        assert!(succeeded);
2142        assert!(failed);
2143    }
2144
2145    fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
2146        setup: &Setup,
2147        problem: &test_util::TestProblem,
2148        pre_scale: PreScale,
2149        rng: &mut StdRng,
2150    ) where
2151        Unsigned: Representation<Q>,
2152        Unsigned: Representation<D>,
2153        Perm: PermutationStrategy<Q>,
2154        for<'a> SphericalQuantizer: CompressIntoWith<
2155                &'a [f32],
2156                DataMut<'a, D>,
2157                ScopedAllocator<'a>,
2158                Error = CompressionError,
2159            >,
2160        for<'a> SphericalQuantizer: CompressIntoWith<
2161                &'a [f32],
2162                QueryMut<'a, Q, Perm>,
2163                ScopedAllocator<'a>,
2164                Error = CompressionError,
2165            >,
2166    {
2167        assert_eq!(setup.nrows, problem.data.nrows());
2168        assert_eq!(setup.ncols, problem.data.ncols());
2169
2170        let quantizer = SphericalQuantizer::train(
2171            problem.data.as_view(),
2172            setup.transform,
2173            SupportedMetric::SquaredL2,
2174            pre_scale,
2175            rng,
2176            GlobalAllocator,
2177        )
2178        .unwrap();
2179
2180        // Data.
2181        let data = problem.data.row(0);
2182        _test_oom_resiliance::<Data<D, _>>(
2183            &quantizer,
2184            data,
2185            &mut Data::new_boxed(quantizer.output_dim()),
2186        );
2187        _test_oom_resiliance::<Query<Q, Perm, _>>(
2188            &quantizer,
2189            data,
2190            &mut Query::new_boxed(quantizer.output_dim()),
2191        );
2192        _test_oom_resiliance::<FullQuery<_>>(
2193            &quantizer,
2194            data,
2195            &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
2196        );
2197    }
2198
2199    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
2200    where
2201        Unsigned: Representation<Q>,
2202        Unsigned: Representation<D>,
2203        Perm: PermutationStrategy<Q>,
2204        for<'a> SphericalQuantizer: CompressIntoWith<
2205                &'a [f32],
2206                DataMut<'a, D>,
2207                ScopedAllocator<'a>,
2208                Error = CompressionError,
2209            >,
2210        for<'a> SphericalQuantizer: CompressIntoWith<
2211                &'a [f32],
2212                QueryMut<'a, Q, Perm>,
2213                ScopedAllocator<'a>,
2214                Error = CompressionError,
2215            >,
2216    {
2217        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
2218        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();
2219
2220        let scales = [
2221            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2222            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2223            PreScale::ReciprocalMeanNorm,
2224        ];
2225
2226        for scale in scales {
2227            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);
2228
2229            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
2230            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
2231            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
2232        }
2233
2234        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
2235    }
2236
2237    #[test]
2238    fn test_spherical_quantizer() {
2239        let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2240        for dim in [56, 72, 128, 255] {
2241            let setup = Setup {
2242                transform: TransformKind::PaddingHadamard {
2243                    target_dim: TargetDim::Same,
2244                },
2245                nrows: 64,
2246                ncols: dim,
2247                num_trials: 10,
2248            };
2249
2250            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2251            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2252            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2253            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2254
2255            let setup = Setup {
2256                transform: TransformKind::DoubleHadamard {
2257                    target_dim: TargetDim::Same,
2258                },
2259                nrows: 64,
2260                ncols: dim,
2261                num_trials: 10,
2262            };
2263            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2264            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2265            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2266            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2267        }
2268    }
2269
2270    ////////////
2271    // Errors //
2272    ////////////
2273
2274    #[test]
2275    fn err_dim_cannot_be_zero() {
2276        let data = Matrix::new(0.0f32, 10, 0);
2277        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2278        let err = SphericalQuantizer::train(
2279            data.as_view(),
2280            TransformKind::DoubleHadamard {
2281                target_dim: TargetDim::Same,
2282            },
2283            SupportedMetric::SquaredL2,
2284            PreScale::None,
2285            &mut rng,
2286            GlobalAllocator,
2287        )
2288        .unwrap_err();
2289        assert_eq!(err.to_string(), "data dim cannot be zero");
2290    }
2291
2292    #[test]
2293    fn err_norm_must_be_positive() {
2294        let data = Matrix::new(0.0f32, 10, 10);
2295        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2296        let err = SphericalQuantizer::train(
2297            data.as_view(),
2298            TransformKind::DoubleHadamard {
2299                target_dim: TargetDim::Same,
2300            },
2301            SupportedMetric::SquaredL2,
2302            PreScale::None,
2303            &mut rng,
2304            GlobalAllocator,
2305        )
2306        .unwrap_err();
2307        assert_eq!(err.to_string(), "norm must be positive");
2308    }
2309
2310    #[test]
2311    fn err_norm_cannot_be_infinity() {
2312        let mut data = Matrix::new(0.0f32, 10, 10);
2313        data[(2, 5)] = f32::INFINITY;
2314
2315        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2316        let err = SphericalQuantizer::train(
2317            data.as_view(),
2318            TransformKind::DoubleHadamard {
2319                target_dim: TargetDim::Same,
2320            },
2321            SupportedMetric::SquaredL2,
2322            PreScale::None,
2323            &mut rng,
2324            GlobalAllocator,
2325        )
2326        .unwrap_err();
2327        assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2328    }
2329
2330    #[test]
2331    fn err_reciprocal_norm_cannot_be_infinity() {
2332        let mut data = Matrix::new(0.0f32, 10, 10);
2333        data[(2, 5)] = 2.93863e-39;
2334
2335        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2336        let err = SphericalQuantizer::train(
2337            data.as_view(),
2338            TransformKind::DoubleHadamard {
2339                target_dim: TargetDim::Same,
2340            },
2341            SupportedMetric::SquaredL2,
2342            PreScale::ReciprocalMeanNorm,
2343            &mut rng,
2344            GlobalAllocator,
2345        )
2346        .unwrap_err();
2347        assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2348    }
2349
2350    #[test]
2351    fn err_mean_norm_cannot_be_zero_generate() {
2352        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2353        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2354        let err = SphericalQuantizer::generate(
2355            centroid,
2356            0.0,
2357            TransformKind::DoubleHadamard {
2358                target_dim: TargetDim::Same,
2359            },
2360            SupportedMetric::SquaredL2,
2361            None,
2362            &mut rng,
2363            GlobalAllocator,
2364        )
2365        .unwrap_err();
2366        assert_eq!(err.to_string(), "norm must be positive");
2367    }
2368
2369    #[test]
2370    fn err_scale_cannot_be_zero_generate() {
2371        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2372        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2373        let err = SphericalQuantizer::generate(
2374            centroid,
2375            1.0,
2376            TransformKind::DoubleHadamard {
2377                target_dim: TargetDim::Same,
2378            },
2379            SupportedMetric::SquaredL2,
2380            Some(0.0),
2381            &mut rng,
2382            GlobalAllocator,
2383        )
2384        .unwrap_err();
2385        assert_eq!(err.to_string(), "pre-scale must be positive");
2386    }
2387
2388    #[test]
2389    fn compression_errors_data() {
2390        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2391        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);
2392
2393        let quantizer = SphericalQuantizer::train(
2394            data.as_view(),
2395            TransformKind::PaddingHadamard {
2396                target_dim: TargetDim::Same,
2397            },
2398            SupportedMetric::SquaredL2,
2399            PreScale::None,
2400            &mut rng,
2401            GlobalAllocator,
2402        )
2403        .unwrap();
2404
2405        let scoped_global = ScopedAllocator::global();
2406
2407        // Input contains NaN.
2408        {
2409            let mut query: Vec<f32> = quantizer.shift().to_vec();
2410            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2411            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2412
2413            for i in 0..query.len() {
2414                let last = query[i];
2415                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
2416                    query[i] = v;
2417
2418                    let err = quantizer
2419                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2420                        .unwrap_err();
2421
2422                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2423
2424                    let err = quantizer
2425                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2426                        .unwrap_err();
2427
2428                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2429                }
2430                query[i] = last;
2431            }
2432        }
2433
2434        // Input has a large value.
2435        {
2436            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
2437            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2438
2439            let err = quantizer
2440                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2441                .unwrap_err();
2442
2443            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";
2444
2445            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
2446        }
2447
2448        // Input length
2449        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
2450            let query = vec![0.0f32; len];
2451            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2452            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2453
2454            let err = quantizer
2455                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2456                .unwrap_err();
2457            assert_eq!(
2458                err,
2459                CompressionError::SourceDimensionMismatch {
2460                    expected: quantizer.input_dim(),
2461                }
2462            );
2463
2464            let err = quantizer
2465                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2466                .unwrap_err();
2467            assert_eq!(
2468                err,
2469                CompressionError::SourceDimensionMismatch {
2470                    expected: quantizer.input_dim(),
2471                }
2472            );
2473        }
2474
2475        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
2476            let query = vec![0.0f32; quantizer.input_dim()];
2477            let mut d = Data::<1, _>::new_boxed(len);
2478            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);
2479
2480            let err = quantizer
2481                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2482                .unwrap_err();
2483            assert_eq!(
2484                err,
2485                CompressionError::DestinationDimensionMismatch {
2486                    expected: quantizer.output_dim(),
2487                }
2488            );
2489
2490            let err = quantizer
2491                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2492                .unwrap_err();
2493            assert_eq!(
2494                err,
2495                CompressionError::DestinationDimensionMismatch {
2496                    expected: quantizer.output_dim(),
2497                }
2498            );
2499        }
2500    }
2501
2502    #[test]
2503    fn centroid_scaling_happens_in_generate() {
2504        let centroid = Poly::from_iter(
2505            [1088.6732f32, 1393.32, 1547.877].into_iter(),
2506            GlobalAllocator,
2507        )
2508        .unwrap();
2509        let mean_norm = 2359.27;
2510        let pre_scale = 1.0 / mean_norm;
2511
2512        let quantizer = SphericalQuantizer::generate(
2513            centroid,
2514            mean_norm,
2515            TransformKind::Null,
2516            SupportedMetric::InnerProduct,
2517            Some(pre_scale),
2518            &mut StdRng::seed_from_u64(10),
2519            GlobalAllocator,
2520        )
2521        .unwrap();
2522
2523        let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
2524        let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
2525        assert!(
2526            quantizer
2527                .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
2528                .is_ok(),
2529            "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
2530        );
2531    }
2532}
2533
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Pack `quantizer` into a flatbuffer, parse it back, and attempt to unpack it.
    ///
    /// Panics if the serialized bytes cannot be parsed as a flatbuffer. Otherwise it
    /// returns the result of `SphericalQuantizer::try_unpack`, so callers can inspect
    /// either the reloaded quantizer or the deserialization error.
    fn roundtrip(
        quantizer: &SphericalQuantizer,
    ) -> Result<SphericalQuantizer, DeserializationError> {
        let data = to_flatbuffer(|buf| quantizer.pack(buf));
        let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
        SphericalQuantizer::try_unpack(GlobalAllocator, proto)
    }

    /// Every combination of transform kind, metric, and pre-scale must survive a
    /// flatbuffer round trip unchanged.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let low = NonZeroUsize::new(100).unwrap();
        let high = NonZeroUsize::new(150).unwrap();

        let kinds = [
            // Null
            TransformKind::Null,
            // Double Hadamard
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Padding Hadamard
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Random Rotation
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Same,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Natural,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(low),
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(high),
            },
        ];

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        let alloc = GlobalAllocator;
        for kind in kinds {
            for metric in SupportedMetric::all() {
                for pre_scale in pre_scales {
                    let quantizer = SphericalQuantizer::train(
                        problem.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    let reloaded = roundtrip(&quantizer).unwrap();
                    assert_eq!(quantizer, reloaded, "failed on transform {:?}", kind);
                }
            }
        }
    }

    /// Corrupt individual fields of a trained quantizer and check that deserialization
    /// reports the matching `DeserializationError` variant.
    #[test]
    fn test_error_checking() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        let alloc = GlobalAllocator;
        let mut make_quantizer = || {
            SphericalQuantizer::train(
                problem.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        type E = DeserializationError;

        // Missing norm: 0.0
        {
            let mut quantizer = make_quantizer();
            // SAFETY: `Positive::new_unchecked` normally requires a positive argument.
            // This test violates that on purpose to build a quantizer that deserialization
            // must reject. The invalid value is only packed into a flatbuffer and is never
            // used in a way that could trigger undefined behavior.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(0.0) };
            assert_eq!(roundtrip(&quantizer).unwrap_err(), E::MissingNorm);
        }

        // Missing norm: negative
        {
            let mut quantizer = make_quantizer();
            // SAFETY: `Positive::new_unchecked` normally requires a positive argument.
            // This test violates that on purpose to build a quantizer that deserialization
            // must reject. The invalid value is only packed into a flatbuffer and is never
            // used in a way that could trigger undefined behavior.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(-1.0) };
            assert_eq!(roundtrip(&quantizer).unwrap_err(), E::MissingNorm);
        }

        // PreScaleNotPositive
        {
            let mut quantizer = make_quantizer();
            // SAFETY: `Positive::new_unchecked` normally requires a positive argument.
            // This test violates that on purpose to build a quantizer that deserialization
            // must reject. The invalid value is only packed into a flatbuffer and is never
            // used in a way that could trigger undefined behavior.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };
            assert_eq!(roundtrip(&quantizer).unwrap_err(), E::PreScaleNotPositive);
        }

        // Dim Mismatch: replace the shift with a 3-element vector so its length no longer
        // agrees with the dimensions recorded elsewhere in the quantizer.
        {
            let mut quantizer = make_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();
            assert_eq!(roundtrip(&quantizer).unwrap_err(), E::DimMismatch);
        }
    }
}