Skip to main content

diskann_quantization/spherical/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8use diskann_utils::{ReborrowMut, views::MatrixView};
9use diskann_vector::{
10    MathematicalValue, Norm, PureDistanceFunction, distance::InnerProduct, norm::FastL2Norm,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18    CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19    FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21use crate::{
22    AsFunctor, CompressIntoWith,
23    algorithms::{
24        heap::SliceHeap,
25        transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
26    },
27    alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
28    bits::{PermutationStrategy, Representation, Unsigned},
29    num::Positive,
30    utils::{CannotBeEmpty, compute_means_and_average_norm, compute_normalized_means},
31};
32#[cfg(feature = "flatbuffers")]
33use crate::{
34    algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
35};
36
37///////////////
38// Quantizer //
39///////////////
40
/// A spherical (RaBitQ-style) quantizer.
///
/// Compression pre-scales each input (or normalizes it for cosine), subtracts a learned
/// shift, projects the centered vector onto the unit sphere, applies a
/// distance-preserving transform, and encodes the result with a small number of bits
/// per dimension together with metric-specific correction terms.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct SphericalQuantizer<A = GlobalAllocator>
where
    A: Allocator,
{
    /// The offset to subtract from each (pre-scaled) vector.
    ///
    /// When built through `generate`, this is the training centroid multiplied by
    /// `pre_scale`.
    shift: Poly<[f32], A>,

    /// The [`SphericalQuantizer`] supports several different strategies for performing the
    /// distance-preserving transformation on dataset vectors, which may be applicable in
    /// different scenarios.
    ///
    /// The different transformations may have restrictions on the number of supported dimensions.
    /// While we will accept all non-zero input dimensions, the output dimension of a transform
    /// may be higher or lower, depending on the configuration.
    transform: Transform<A>,

    /// The metric meant to be used by the quantizer.
    metric: SupportedMetric,

    /// When processing queries, it may be beneficial to modify the query norm to match the
    /// dataset norm.
    ///
    /// This is only applicable when `InnerProduct` and `Cosine` are used, but serves to
    /// move the query into the dynamic range of the quantization.
    ///
    /// You would think that the normalization step in RaBitQ would mitigate this, but
    /// that is not always the case since range adjustment happens before centering.
    mean_norm: Positive<f32>,

    /// To support 16-bit constants, which have a limited dynamic range, we allow a
    /// pre-scaling parameter that is multiplied into each value of compressed vectors.
    ///
    /// This allows transparent handling of compressing integral data, which can
    /// otherwise easily overflow `f16`.
    pre_scale: Positive<f32>,
}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82    A: Allocator,
83{
84    fn try_clone(&self) -> Result<Self, AllocatorError> {
85        Ok(Self {
86            shift: self.shift.try_clone()?,
87            transform: self.transform.try_clone()?,
88            metric: self.metric,
89            mean_norm: self.mean_norm,
90            pre_scale: self.pre_scale,
91        })
92    }
93}
94
/// Errors that can occur while training or generating a [`SphericalQuantizer`].
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum TrainError {
    /// The training data (or supplied centroid) has zero dimensions.
    #[error("data dim cannot be zero")]
    DimCannotBeZero,
    /// The training dataset contains no vectors.
    #[error("data cannot be empty")]
    DataCannotBeEmpty,
    /// An explicitly supplied pre-scale was not positive.
    #[error("pre-scale must be positive")]
    PrescaleNotPositive,
    /// The (computed or supplied) mean norm was not positive.
    #[error("norm must be positive")]
    NormNotPositive,
    /// The computed mean norm was infinite or NaN.
    #[error("computed norm contains infinity or NaN")]
    NormNotFinite,
    /// The reciprocal of the mean norm, requested via `PreScale::ReciprocalMeanNorm`,
    /// was infinite or NaN.
    #[error("reciprocal norm contains infinity or NaN")]
    ReciprocalNormNotFinite,
    /// An allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
113
114impl<A> SphericalQuantizer<A>
115where
116    A: Allocator,
117{
118    /// Return the number dimensions this quantizer has been trained for.
119    pub fn input_dim(&self) -> usize {
120        self.shift.len()
121    }
122
123    /// Return the dimension of the post-transformed vector.
124    ///
125    /// Output storage vectors should use this dimension instead of `self.dim()` because
126    /// in general, the output dim **may** be different from the input dimension.
127    pub fn output_dim(&self) -> usize {
128        self.transform.output_dim()
129    }
130
131    /// Return the per-dimension shift vector.
132    ///
133    /// This vector is meant to accomplish two goals:
134    ///
135    /// 1. Centers the data around the training dataset mean.
136    /// 2. Offsets each dimension into a range that can be encoded in unsigned values.
137    pub fn shift(&self) -> &[f32] {
138        &self.shift
139    }
140
141    /// Return the approximate mean norm of the training data.
142    pub fn mean_norm(&self) -> Positive<f32> {
143        self.mean_norm
144    }
145
146    /// Return the pre-scaling parameter for data. This value is multiplied to every
147    /// compressed vector to adjust its dynamic range.
148    ///
149    /// A value of 1.0 means that no scaling is occurring.
150    pub fn pre_scale(&self) -> Positive<f32> {
151        self.pre_scale
152    }
153
154    /// Return a reference to the allocator used by this data structure.
155    pub fn allocator(&self) -> &A {
156        self.shift.allocator()
157    }
158
159    /// A lower-level constructor that accepts a centroid, mean norm, and pre-scale directly.
160    pub fn generate(
161        mut centroid: Poly<[f32], A>,
162        mean_norm: f32,
163        transform: TransformKind,
164        metric: SupportedMetric,
165        pre_scale: Option<f32>,
166        rng: &mut dyn RngCore,
167        allocator: A,
168    ) -> Result<Self, TrainError> {
169        let pre_scale = match pre_scale {
170            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
171            None => crate::num::POSITIVE_ONE_F32,
172        };
173
174        let dim = match NonZeroUsize::new(centroid.len()) {
175            Some(dim) => dim,
176            None => {
177                return Err(TrainError::DimCannotBeZero);
178            }
179        };
180
181        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;
182
183        // We passed in a 'rng' so `Transform::new` will not fail.
184        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
185            Ok(v) => v,
186            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
187            Err(NewTransformError::AllocatorError(err)) => {
188                return Err(TrainError::AllocatorError(err));
189            }
190        };
191
192        // Transform the centroid by the pre-scale.
193        centroid
194            .iter_mut()
195            .for_each(|v| *v *= pre_scale.into_inner());
196
197        Ok(SphericalQuantizer {
198            shift: centroid,
199            transform,
200            metric,
201            mean_norm,
202            pre_scale,
203        })
204    }
205
206    /// Return the metric used by this quantizer.
207    pub fn metric(&self) -> SupportedMetric {
208        self.metric
209    }
210
211    /// Construct a quantizer for vectors in the distribution of `data`.
212    ///
213    /// The type of distance-preserving transform to use is selected by the [`TransformKind`].
214    ///
215    /// Vectors compressed with this quantizer will be **metric specific** and optimized for
216    /// distance computations rather than reconstruction. This means that vectors compressed
217    /// targeting the inner-product distance will not return meaningful results if used for
218    /// L2 distance computations.
219    ///
220    /// Additionally, vectors compressed when using the [`SupportedMetric::Cosine`] distance
221    /// will be implicitly normalized before being compressed to enable better compression.
222    ///
223    /// If argument `pre_scale` is given, then all vectors compressed by this quantizer will
224    /// first be scaled by this value. Note that if given, `pre_scale` **must** be positive.
225    pub fn train<T, R>(
226        data: MatrixView<T>,
227        transform: TransformKind,
228        metric: SupportedMetric,
229        pre_scale: PreScale,
230        rng: &mut R,
231        allocator: A,
232    ) -> Result<Self, TrainError>
233    where
234        T: Copy + Into<f64> + Into<f32>,
235        R: Rng,
236    {
237        // An inner implementation erasing the type of the random number generator to
238        // cut down on excess monomorphization.
239        #[inline(never)]
240        fn train<T, A>(
241            data: MatrixView<T>,
242            transform: TransformKind,
243            metric: SupportedMetric,
244            pre_scale: PreScale,
245            rng: &mut dyn RngCore,
246            allocator: A,
247        ) -> Result<SphericalQuantizer<A>, TrainError>
248        where
249            T: Copy + Into<f64> + Into<f32>,
250            A: Allocator,
251        {
252            // This check is repeated in `Self::generate`, but we prefer to bail as early
253            // as possible if we detect an error.
254            if data.ncols() == 0 {
255                return Err(TrainError::DimCannotBeZero);
256            }
257
258            let (centroid, mean_norm) = match metric {
259                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
260                    compute_means_and_average_norm(data)
261                }
262                SupportedMetric::Cosine => (
263                    compute_normalized_means(data)
264                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
265                    1.0,
266                ),
267            };
268
269            let mean_norm = mean_norm as f32;
270
271            if mean_norm <= 0.0 {
272                return Err(TrainError::NormNotPositive);
273            }
274
275            if !mean_norm.is_finite() {
276                return Err(TrainError::NormNotFinite);
277            }
278
279            // Determining if (and how) the pre-scaling term will be calculated.
280            let pre_scale: Positive<f32> = match pre_scale {
281                PreScale::None => crate::num::POSITIVE_ONE_F32,
282                PreScale::Some(v) => v,
283                PreScale::ReciprocalMeanNorm => {
284                    // We've checked that `mean_norm` is both positive and finite.
285                    //
286                    // It's possible that when converted to `f32`,
287                    //
288                    // Taking the reciprocal is well defined. However, since the norms
289                    // and scales in `compute_means_and_average_norm` are done using `f64`,
290                    // it's possible that the computed `mean_norm` is subnormal, leading
291                    // to the reciprocal being infinity.
292                    let pre_scale = Positive::new(1.0 / mean_norm)
293                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;
294
295                    if !pre_scale.into_inner().is_finite() {
296                        return Err(TrainError::ReciprocalNormNotFinite);
297                    }
298
299                    pre_scale
300                }
301            };
302
303            // Allow the pre-scaling to take place inside `Self::generate`.
304            let centroid =
305                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;
306
307            SphericalQuantizer::generate(
308                centroid,
309                mean_norm,
310                transform,
311                metric,
312                Some(pre_scale.into_inner()),
313                rng,
314                allocator,
315            )
316        }
317
318        train(data, transform, metric, pre_scale, rng, allocator)
319    }
320
321    /// Rescale the argument `v` to be in the rough dynamic range of the training dataset.
322    pub fn rescale(&self, v: &mut [f32]) {
323        let norm = FastL2Norm.evaluate(&*v);
324        let m = self.mean_norm.into_inner() / norm;
325        v.iter_mut().for_each(|i| *i *= m);
326    }
327
328    /// Private helper function to do common data pre-processing.
329    ///
330    /// # Panics
331    ///
332    /// Panics if `data.len() != self.dim()`.
333    fn preprocess<'a>(
334        &self,
335        data: &[f32],
336        allocator: ScopedAllocator<'a>,
337    ) -> Result<Preprocessed<'a>, CompressionError> {
338        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");
339
340        // Fold in pre-scaling with the potential norm corretion for cosine.
341        //
342        // NOTE: When we're computing Cosine Similarity, we normalize the vector. As such,
343        // the `pre_scale` parameter become irrelvant since it just gets normalized away.
344        let scale = self.pre_scale.into_inner();
345        let mul: f32 = match self.metric {
346            SupportedMetric::Cosine => {
347                let norm: f32 = (FastL2Norm).evaluate(data);
348                if norm == 0.0 { 1.0 } else { 1.0 / norm }
349            }
350            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
351        };
352
353        // Center the vector and compute the squared norm of the shifted vector.
354        let shifted = Poly::from_iter(
355            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
356            allocator,
357        )?;
358
359        let shifted_norm = FastL2Norm.evaluate(&*shifted);
360        if !shifted_norm.is_finite() {
361            return Err(CompressionError::InputContainsNaN);
362        }
363        let inner_product_with_centroid = match self.metric {
364            SupportedMetric::SquaredL2 => None,
365            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
366                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
367                Some(ip.into_inner())
368            }
369        };
370
371        Ok(Preprocessed {
372            shifted,
373            shifted_norm,
374            inner_product_with_centroid,
375        })
376    }
377}
378
/// Pre-scaling selector for spherical quantization training. Pre-scaling adjusts the
/// dynamic range of the data (usually decreasing it uniformly) to keep the correction terms
/// within the range expressible by 16-bit floating point numbers.
#[derive(Debug, Clone, Copy)]
pub enum PreScale {
    /// Do not use any pre-scaling (equivalent to a pre-scale of 1.0).
    None,
    /// Pre-scale all data by the specified amount.
    Some(Positive<f32>),
    /// Heuristically estimate a pre-scaling parameter by using the inverse approximate
    /// mean norm. This will nearly normalize in-distribution vectors.
    ///
    /// For [`SupportedMetric::Cosine`], training uses a mean norm of 1.0, so this yields a
    /// pre-scale of 1.0. Training fails with [`TrainError::ReciprocalNormNotFinite`] if
    /// the reciprocal is not finite.
    ReciprocalMeanNorm,
}
392
/// Errors that can occur while unpacking a [`SphericalQuantizer`] from its flatbuffer
/// representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The serialized transform could not be unpacked.
    #[error(transparent)]
    TransformError(#[from] TransformError),
    /// The flatbuffer identifier was not recognized.
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,
    /// The centroid length does not match the transform's input dimension.
    #[error("transform length not equal to centroid")]
    DimMismatch,
    /// The serialized mean norm is not positive.
    #[error("norm is missing or is not positive")]
    MissingNorm,
    /// The serialized pre-scale is not positive.
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    /// The buffer failed flatbuffer verification.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// The serialized metric is not a supported metric.
    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    /// An allocation failed while rebuilding the quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
418
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Serialize `self` into `buf` as a [`spherical::SphericalQuantizer`] table, returning
    /// the offset of the written table.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Child objects (the centroid vector and the transform) are written before the
        // table that references them.
        let centroid = Some(buf.create_vector(&self.shift));
        let transform = Some(self.transform.pack(buf));

        let args = spherical::SphericalQuantizerArgs {
            centroid,
            transform,
            metric: self.metric.into(),
            mean_norm: self.mean_norm.into_inner(),
            pre_scale: self.pre_scale.into_inner(),
        };
        spherical::SphericalQuantizer::create(buf, &args)
    }

    /// Rebuild a quantizer from its serialized [`spherical::SphericalQuantizer`]
    /// representation.
    ///
    /// Validation happens in this order: metric, centroid allocation, transform,
    /// centroid/transform dimension agreement, mean norm, then pre-scale. The first
    /// failure is returned.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;
        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // The shift is subtracted before the transform runs, so their widths must agree.
        if transform.input_dim() != shift.len() {
            return Err(DeserializationError::DimMismatch);
        }

        let Ok(mean_norm) = Positive::new(proto.mean_norm()) else {
            return Err(DeserializationError::MissingNorm);
        };
        let Ok(pre_scale) = Positive::new(proto.pre_scale()) else {
            return Err(DeserializationError::PreScaleNotPositive);
        };

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
488
/// Intermediate per-vector state produced by `SphericalQuantizer::preprocess`.
struct Preprocessed<'a> {
    /// The input after metric-dependent scaling (pre-scale, or normalization for cosine)
    /// with the shift subtracted. Compression callers later divide it in place by
    /// `shifted_norm`.
    shifted: Poly<[f32], ScopedAllocator<'a>>,
    /// L2 norm of `shifted`, computed before any in-place normalization.
    shifted_norm: f32,
    /// Inner product between `shifted` and the shift; `None` for `SquaredL2`.
    inner_product_with_centroid: Option<f32>,
}
494
495impl Preprocessed<'_> {
496    /// Return the metric specific correction term as sumamrized below:
497    ///
498    /// * Inner Product: The inner product between the shifted vector and the centroid.
499    /// * Squared L2: The squared norm of the shifted vector.
500    fn metric_specific(&self) -> f32 {
501        match self.inner_product_with_centroid {
502            Some(ip) => ip,
503            None => self.shifted_norm * self.shifted_norm,
504        }
505    }
506}
507
508///////////////////////
509// Distance Functors //
510///////////////////////
511
512impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
513where
514    A: Allocator,
515{
516    fn as_functor(&self) -> CompensatedSquaredL2 {
517        CompensatedSquaredL2::new(self.output_dim())
518    }
519}
520
521impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
522where
523    A: Allocator,
524{
525    fn as_functor(&self) -> CompensatedIP {
526        CompensatedIP::new(&self.shift, self.output_dim())
527    }
528}
529
530impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
531where
532    A: Allocator,
533{
534    fn as_functor(&self) -> CompensatedCosine {
535        CompensatedCosine::new(self.as_functor())
536    }
537}
538
539/////////////////
540// Compression //
541/////////////////
542
/// Errors that can occur while compressing a vector with a [`SphericalQuantizer`].
#[derive(Debug, Error, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CompressionError {
    /// The centered input vector had a non-finite norm (e.g. the input contains NaN).
    #[error("input contains NaN")]
    InputContainsNaN,

    /// The source vector length did not equal the quantizer's input dimension.
    #[error("expected source vector to have length {expected}")]
    SourceDimensionMismatch { expected: usize },

    /// The destination length did not equal the quantizer's output dimension.
    #[error("expected destination vector to have length {expected}")]
    DestinationDimensionMismatch { expected: usize },

    /// The compressed vector's metadata could not be encoded.
    #[error(
        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
    )]
    EncodingError(#[from] DataMetaError),

    /// An allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
563
564fn check_dims(
565    input: usize,
566    output: usize,
567    from: usize,
568    into: usize,
569) -> Result<(), CompressionError> {
570    if from != input {
571        return Err(CompressionError::SourceDimensionMismatch { expected: input });
572    }
573    if into != output {
574        return Err(CompressionError::DestinationDimensionMismatch { expected: output });
575    }
576    Ok(())
577}
578
/// Helper trait to dispatch to a faster 1-bit implementation and use the slower
/// maximum-cosine algorithm when more than 1 bit is used.
trait FinishCompressing {
    /// Quantize `transformed` into `self` and populate its correction metadata.
    ///
    /// * `preprocessed`: the centered input and its norm / metric-specific term.
    /// * `transformed`: the transformed vector to encode. It is presumably expected to
    ///   have the same length as `self`; implementors should verify this before writing.
    /// * `transformed_norm`: the L2 norm of `transformed`.
    /// * `allocator`: scratch allocator for any temporary storage.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;
}
590
591impl FinishCompressing for DataMut<'_, 1> {
592    fn finish_compressing(
593        &mut self,
594        preprocessed: &Preprocessed<'_>,
595        transformed: &[f32],
596        transformed_norm: f32,
597        _: ScopedAllocator<'_>,
598    ) -> Result<(), CompressionError> {
599        // Compute signed quantized vector (-1 or 1)
600        // and also populate the unsigned bit representation in `into` output vector.
601        let mut quant_raw_inner_product = 0.0f32;
602        let mut bit_sum = 0u32;
603        transformed.iter().enumerate().for_each(|(i, &r)| {
604            let bit: u8 = if r > 0.0 { 1 } else { 0 };
605
606            quant_raw_inner_product += r.abs();
607            bit_sum += <u8 as Into<u32>>::into(bit);
608
609            // SAFETY: From check 1, we know that `i < into.len()`.
610            unsafe { self.vector_mut().set_unchecked(i, bit) };
611        });
612
613        // The value we just computed for `quant_raw_inner_product` is:
614        // ```
615        // Y = <x', x> * sqrt(D)                        [1]
616        // ```
617        // The inner product correction term is
618        // ```
619        //       2 |X|
620        // -----------------                            [2]
621        // <x', x> * sqrt(D)
622        // ```
623        // [1] substitutes directly into [2] and we get
624        // ```
625        // 2 |X|
626        // -----
627        //   Y
628        // ```
629        // Therefore, the inner product correction term is
630        // ```
631        // 2.0 * shifted_norm / quant_raw_inner_product
632        // ```
633        let inner_product_correction =
634            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
635        self.set_meta(DataMeta::new(
636            inner_product_correction,
637            preprocessed.metric_specific(),
638            bit_sum,
639        )?);
640        Ok(())
641    }
642}
643
644impl FinishCompressing for DataMut<'_, 2> {
645    fn finish_compressing(
646        &mut self,
647        preprocessed: &Preprocessed<'_>,
648        transformed: &[f32],
649        transformed_norm: f32,
650        allocator: ScopedAllocator<'_>,
651    ) -> Result<(), CompressionError> {
652        compress_via_maximum_cosine(
653            self.reborrow_mut(),
654            preprocessed,
655            transformed,
656            transformed_norm,
657            allocator,
658        )
659    }
660}
661
662impl FinishCompressing for DataMut<'_, 4> {
663    fn finish_compressing(
664        &mut self,
665        preprocessed: &Preprocessed<'_>,
666        transformed: &[f32],
667        transformed_norm: f32,
668        allocator: ScopedAllocator<'_>,
669    ) -> Result<(), CompressionError> {
670        compress_via_maximum_cosine(
671            self.reborrow_mut(),
672            preprocessed,
673            transformed,
674            transformed_norm,
675            allocator,
676        )
677    }
678}
679
680impl FinishCompressing for DataMut<'_, 8> {
681    fn finish_compressing(
682        &mut self,
683        preprocessed: &Preprocessed<'_>,
684        transformed: &[f32],
685        transformed_norm: f32,
686        allocator: ScopedAllocator<'_>,
687    ) -> Result<(), CompressionError> {
688        compress_via_maximum_cosine(
689            self.reborrow_mut(),
690            preprocessed,
691            transformed,
692            transformed_norm,
693            allocator,
694        )
695    }
696}
697
698//////////////////////
699// Data Compression //
700//////////////////////
701
702impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
703where
704    A: Allocator,
705{
706    type Error = CompressionError;
707
708    /// Compress the input vector `from` into the bitslice `into`.
709    ///
710    /// # Error
711    ///
712    /// Returns an error if
713    /// * The input contains `NaN`.
714    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
715    ///   dimensionality as the quantizer.
716    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
717    ///   dimensionality as the output of the distance-preserving transform. Importantely,
718    ///   this **may** be different than `self.dim()` and should be retrieved from
719    ///   `self.output_dim()`.
720    fn compress_into_with(
721        &self,
722        from: &[f32],
723        mut into: FullQueryMut<'_>,
724        allocator: ScopedAllocator<'_>,
725    ) -> Result<(), Self::Error> {
726        let input_dim = self.shift.len();
727        let output_dim = self.output_dim();
728        check_dims(input_dim, output_dim, from.len(), into.len())?;
729
730        let mut preprocessed = self.preprocess(from, allocator)?;
731
732        // If the preprocessed norm is zero, then we tried to compress the center directly.
733        // In this case, we can get the correct behavior by setting `into` to all zeros.
734        if preprocessed.shifted_norm == 0.0 {
735            into.vector_mut().fill(0.0);
736            *into.meta_mut() = Default::default();
737            return Ok(());
738        }
739
740        preprocessed
741            .shifted
742            .iter_mut()
743            .for_each(|v| *v /= preprocessed.shifted_norm);
744
745        // Transformation can fail due to OOM - we want to handle that gracefully.
746        //
747        // If the transformation fails because we provided the wrong sizes, that is a hard
748        // program bug.
749        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
750        match self
751            .transform
752            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
753        {
754            Ok(()) => {}
755            Err(TransformFailed::AllocatorError(err)) => {
756                return Err(CompressionError::AllocatorError(err));
757            }
758            Err(TransformFailed::SourceMismatch { .. })
759            | Err(TransformFailed::DestinationMismatch { .. }) => {
760                panic!(
761                    "The sizes of these arrays should already be checked - this is a logic error"
762                );
763            }
764        }
765
766        *into.meta_mut() = FullQueryMeta {
767            sum: into.vector().iter().sum::<f32>(),
768            shifted_norm: preprocessed.shifted_norm,
769            metric_specific: preprocessed.metric_specific(),
770        };
771        Ok(())
772    }
773}
774
775impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
776    for SphericalQuantizer<A>
777where
778    A: Allocator,
779    Unsigned: Representation<NBITS>,
780    for<'a> DataMut<'a, NBITS>: FinishCompressing,
781{
782    type Error = CompressionError;
783
784    /// Compress the input vector `from` into the bitslice `into`.
785    ///
786    /// # Error
787    ///
788    /// Returns an error if
789    /// * The input contains `NaN`.
790    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
791    ///   dimensionality as the quantizer.
792    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
793    ///   dimensionality as the output of the distance-preserving transform. Importantely,
794    ///   this **may** be different than `self.dim()` and should be retrieved from
795    ///   `self.output_dim()`.
796    fn compress_into_with(
797        &self,
798        from: &[f32],
799        mut into: DataMut<'_, NBITS>,
800        allocator: ScopedAllocator<'_>,
801    ) -> Result<(), Self::Error> {
802        let input_dim = self.shift.len();
803        let output_dim = self.output_dim();
804        check_dims(input_dim, output_dim, from.len(), into.len())?;
805
806        let mut preprocessed = self.preprocess(from, allocator)?;
807
808        if preprocessed.shifted_norm == 0.0 {
809            into.set_meta(DataMeta::default());
810            return Ok(());
811        }
812
813        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
814        preprocessed
815            .shifted
816            .iter_mut()
817            .for_each(|v| *v /= preprocessed.shifted_norm);
818
819        // Transformation can fail due to OOM - we want to handle that gracefully.
820        //
821        // If the transformation fails because we provided the wrong sizes, that is a hard
822        // program bug.
823        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
824        match self
825            .transform
826            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
827        {
828            Ok(()) => {}
829            Err(TransformFailed::AllocatorError(err)) => {
830                return Err(CompressionError::AllocatorError(err));
831            }
832            Err(TransformFailed::SourceMismatch { .. })
833            | Err(TransformFailed::DestinationMismatch { .. }) => {
834                panic!(
835                    "The sizes of these arrays should already be checked - this is a logic error"
836                );
837            }
838        }
839
840        let transformed_norm = if self.transform.preserves_norms() {
841            1.0
842        } else {
843            (FastL2Norm).evaluate(&*transformed)
844        };
845
846        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
847        Ok(())
848    }
849}
850
/// Converts the const generic `NBITS` into a [`NonZeroUsize`] at compile time.
///
/// Evaluating [`AsNonZero::NON_ZERO`] with `NBITS == 0` panics during const evaluation,
/// which surfaces as a compile-time error rather than a runtime failure.
struct AsNonZero<const NBITS: usize>;
impl<const NBITS: usize> AsNonZero<NBITS> {
    // Lint: Unwrap is being used in a const-context, so a zero `NBITS` is rejected at
    // compile time.
    #[allow(clippy::unwrap_used)]
    const NON_ZERO: NonZeroUsize = NonZeroUsize::new(NBITS).unwrap();
}
857
858fn compress_via_maximum_cosine<const NBITS: usize>(
859    mut data: DataMut<'_, NBITS>,
860    preprocessed: &Preprocessed<'_>,
861    transformed: &[f32],
862    transformed_norm: f32,
863    allocator: ScopedAllocator<'_>,
864) -> Result<(), CompressionError>
865where
866    Unsigned: Representation<NBITS>,
867{
868    assert_eq!(data.len(), transformed.len());
869
870    // Find the value we will use to multiply `transformed` to round it to the lattice
871    // element that has the maximum cosine-similarity.
872    let optimal_scale =
873        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;
874
875    let domain = Unsigned::domain_const::<NBITS>();
876    let min = *domain.start() as f32;
877    let max = *domain.end() as f32;
878    let offset = max / 2.0;
879
880    let mut self_inner_product = 0.0f32;
881    let mut bit_sum = 0u32;
882    for (i, t) in transformed.iter().enumerate() {
883        let v = (*t * optimal_scale + offset).clamp(min, max).round();
884        let dv = v - offset;
885        self_inner_product = dv.mul_add(*t, self_inner_product);
886
887        let v = v as u8;
888        bit_sum += <u8 as Into<u32>>::into(v);
889
890        // SAFETY: We have checked that `data.len() == transformed.len()`, so this access
891        // is in-bounds.
892        //
893        // Further, by construction, `v` is encodable by the `Unsigned`.
894        unsafe { data.vector_mut().set_unchecked(i, v) };
895    }
896
897    let shifted_norm = preprocessed.shifted_norm;
898    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
899    data.set_meta(DataMeta::new(
900        inner_product_correction,
901        preprocessed.metric_specific(),
902        bit_sum,
903    )?);
904    Ok(())
905}
906
/// A candidate scaling factor used by [`maximize_cosine_similarity`].
///
/// This struct does 2 things:
///
/// 1. Records the index in `v` and `rounded` and the scaling parameter so that
///    `value * v[position]` gets rounded to `rounded[position] + 1` while
///    `(value - epsilon) * v[position]` is rounded to `rounded[position]` for some small
///    epsilon.
///
///    Informally: "what's the smallest scaling factor so `v[position]` gets rounded to
///    the next value?"
///
/// 2. Imposes a genuine total ordering on the `f32` values (via [`f32::total_cmp`]) so it
///    can be used in a heap such as `SliceHeap`. The ordering is reversed so that a
///    max-heap models a min-heap instead.
///
/// Equality and ordering only consider `value`; `position` is ignored. Both are derived
/// from the same comparison, so `Eq` and `Ord` always agree (even for NaN).
#[derive(Debug, Clone, Copy)]
struct Pair {
    value: f32,
    position: u32,
}

impl PartialEq for Pair {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that equality is reflexive and consistent with `Ord`.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Pair {}

impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed: a smaller `value` compares as "greater" so it rises to the heap root.
        // `total_cmp` agrees with `partial_cmp` on all non-NaN values of the same sign of
        // zero, but remains a lawful total order if a NaN ever appears.
        other.value.total_cmp(&self.value)
    }
}
946
/// This is a tricky function - please read carefully.
///
/// Given a vector `v` compute the scaling factor `s` such that the cosine similarity
/// between `v` and `r` is maximized where `r` is defined as
/// ```math
/// let offset = (2^(num_bits) - 1) / 2;
/// let r = (s * v + offset).round().clamp(0, 2^num_bits - 1) - offset
/// ```
///
/// More informally, find the point in a square lattice with 2^num_bits values in each
/// dimension, centered around zero, whose direction is closest to `v`. This lattice
/// takes the values (+0.5, -0.5, +1.5, -1.5 ...) to give equal weight above and below zero.
///
/// It works by slowly increasing the factor `s` such that the rounding of only one
/// dimension in `v` is changed at a time. A running tally of the cosine similarity is
/// computed for each such scaling factor. Each dimension is advanced until its rounded
/// magnitude reaches the largest lattice value, `2^(num_bits - 1) - 0.5`, so for
/// `num_bits >= 2` a total of `D * (2^(num_bits - 1) - 1)` scaling factors are visited,
/// where `D` is the number of nonzero components of `v`.
///
/// The best scaling factor is returned. Only states reached *after* an increment are
/// scored; the initial all-`0.5` state (every `s` below the smallest critical value) is
/// never itself a candidate.
///
/// Refer to algorithm 1 in <https://arxiv.org/pdf/2409.09913>.
///
/// # Errors
///
/// Returns an [`AllocatorError`] if the scratch buffers cannot be allocated from
/// `allocator`.
///
/// # Panics
///
/// Panics if `v.is_empty()`.
///
/// # Implementation details
///
/// We work with the absolute value of the elements in the vector `v`.
/// This does not affect the final result as the scaling works the same in both the
/// positive and negative directions but simplifies the bookkeeping.
///
/// Zero components of `v` receive an infinite critical value. This orders after the
/// `f32::MAX` sentinel used for exhausted dimensions, so such components are never
/// advanced as long as `v` has at least one nonzero component.
///
/// NOTE(review): two edge cases worth confirming with callers:
/// * If every component of `v` is zero, the first visited critical value is `inf`, and
///   that is what gets returned.
/// * When `num_bits == 1`, `stop == 1` but the initial critical values are still finite.
///   Every dimension is therefore advanced once to `1.5`, which is outside the 1-bit
///   lattice. The returned scale is then only meaningful if the caller clamps the rounded
///   result.
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // Initially, the lattice element has the value `0.5` for all dimensions.
    // This means the initial inner product between `v` and the rounded term is simply
    // `0.5 * sum(abs.(v))`, and its squared norm is `0.25 * D`. The absolute value is used
    // because the lattice element is always in the direction of the components in `v`.
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    let mut current_square_norm = 0.25 * (v.len() as f64);

    // Bookkeeping for the current value of the rounded vector.
    // The true numeric value is 0.5 less than this (in the direction of `v`), but we use
    // integers for a smaller memory footprint.
    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // Compute the critical values and store them on a heap.
    //
    // The heap will keep track of the minimum critical value. Multiplying `v` by the
    // minimum critical value `s` means that `s * v` will only change `rounded` from its
    // current value at a single index (the position associated with `s`).
    //
    // `eps` places each critical value just past its rounding boundary, so scaling by it
    // lands on the next lattice value rather than exactly on the tie.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            // The first boundary for every dimension is at `s * |v| == 1` (0.5 -> 1.5).
            // A zero component yields `inf` here.
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    // Lint: This is a private method and all the callers have an invariant that they check
    // for non-empty inputs.
    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // `rounded` saturates at `stop`, i.e. a true magnitude of `2^(num_bits - 1) - 0.5`.
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            // The root holds the smallest remaining critical value. Once that is the
            // `f32::MAX` sentinel, every (nonzero) dimension has been exhausted.
            if value == f32::MAX {
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            // By the nature of critical values, only `r` will change in `rounded` when
            // multiplying by `value`. And that change will be to increase by 1.
            *r += 1;

            // The inner product estimate simply increases by `vp.abs()` because:
            //
            // * `r` is the only value in `rounded` that changes.
            // * `r` is increased by 1.
            current_ip += vp.abs() as f64;

            // This uses the formula
            // ```math
            // (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2
            //                 = 2x + 1
            // ```
            // substitute `x = y - 1/2` to obtain the true value associated with rounded and
            // we get
            // ```math
            // 2 ( y - 1/2 ) + 1 = 2y - 1 + 1
            //                   = 2y
            // ```
            // Therefore, the change in the estimate for the square norm of `rounded` is
            // `2 * old_r`.
            current_square_norm += (2 * old_r) as f64;

            // Compute the current cosine similarity and update max if needed.
            // This omits the constant factor `1 / ||v||`, which does not change the argmax.
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            // Compute the scaling factor that will change this dimension to the next value.
            // Exhausted dimensions are parked at the `f32::MAX` sentinel.
            if *r < stop {
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1088
1089///////////////////////
1090// Query Compression //
1091///////////////////////
1092
1093impl<const NBITS: usize, Perm, A>
1094    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
1095    for SphericalQuantizer<A>
1096where
1097    Unsigned: Representation<NBITS>,
1098    Perm: PermutationStrategy<NBITS>,
1099    A: Allocator,
1100{
1101    type Error = CompressionError;
1102
1103    /// Compress the input vector `from` into the bitslice `into`.
1104    ///
1105    /// # Error
1106    ///
1107    /// Returns an error if
1108    /// * The input contains `NaN`.
1109    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
1110    ///   dimensionality as the quantizer.
1111    /// * `into.len() != self.output_dim()`: Compressed vector must have the same
1112    ///   dimensionality as the output of the distance-preserving transform. Importantely,
1113    ///   this **may** be different than `self.dim()` and should be retrieved from
1114    ///   `self.output_dim()`.
1115    fn compress_into_with(
1116        &self,
1117        from: &[f32],
1118        mut into: QueryMut<'_, NBITS, Perm>,
1119        allocator: ScopedAllocator<'_>,
1120    ) -> Result<(), Self::Error> {
1121        let input_dim = self.shift.len();
1122        let output_dim = self.output_dim();
1123        check_dims(input_dim, output_dim, from.len(), into.len())?;
1124
1125        let mut preprocessed = self.preprocess(from, allocator)?;
1126
1127        if preprocessed.shifted_norm == 0.0 {
1128            into.set_meta(QueryMeta::default());
1129            return Ok(());
1130        }
1131
1132        preprocessed
1133            .shifted
1134            .iter_mut()
1135            .for_each(|v| *v /= preprocessed.shifted_norm);
1136
1137        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
1138
1139        // Transformation can fail due to OOM - we want to handle that gracefully.
1140        //
1141        // If the transformation fails because we provided the wrong sizes, that is a hard
1142        // program bug.
1143        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
1144        match self
1145            .transform
1146            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
1147        {
1148            Ok(()) => {}
1149            Err(TransformFailed::AllocatorError(err)) => {
1150                return Err(CompressionError::AllocatorError(err));
1151            }
1152            Err(TransformFailed::SourceMismatch { .. })
1153            | Err(TransformFailed::DestinationMismatch { .. }) => {
1154                panic!(
1155                    "The sizes of these arrays should already be checked - this is a logic error"
1156                );
1157            }
1158        }
1159
1160        // Compute the minimum and maximum values of the transformed vector.
1161        let (min, max) = transformed
1162            .iter()
1163            .fold((f32::MAX, f32::MIN), |(min, max), i| {
1164                (i.min(min), i.max(max))
1165            });
1166
1167        let domain = Unsigned::domain_const::<NBITS>();
1168        let lo = (*domain.start()) as f32;
1169        let hi = (*domain.end()) as f32;
1170
1171        let scale = (max - min) / hi;
1172        let mut bit_sum: f32 = 0.0;
1173        transformed.iter().enumerate().for_each(|(i, v)| {
1174            let c = ((v - min) / scale).round().clamp(lo, hi);
1175            bit_sum += c;
1176
1177            // Lint: We have verified that `into.len() == transformed.len()`, so the index
1178            // `i` is in bounds.
1179            //
1180            // Further, `c` has beem clamped to `[0, 2^NBITS - 1]` and is thus encodable
1181            // with the NBITS-bit unsigned representation.
1182            #[allow(clippy::unwrap_used)]
1183            into.vector_mut().set(i, c as i64).unwrap();
1184        });
1185
1186        // Finish up the compensation terms.
1187        into.set_meta(QueryMeta {
1188            inner_product_correction: preprocessed.shifted_norm * scale,
1189            bit_sum,
1190            offset: min / scale,
1191            metric_specific: preprocessed.metric_specific(),
1192        });
1193
1194        Ok(())
1195    }
1196}
1197
1198///////////
1199// Tests //
1200///////////
1201
1202#[cfg(not(miri))]
1203#[cfg(test)]
1204mod tests {
1205    use super::*;
1206
1207    use std::fmt::Display;
1208
1209    use diskann_utils::{
1210        ReborrowMut, lazy_format,
1211        views::{self, Matrix},
1212    };
1213    use diskann_vector::{PureDistanceFunction, norm::FastL2NormSquared};
1214    use diskann_wide::ARCH;
1215    use rand::{
1216        SeedableRng,
1217        distr::{Distribution, Uniform},
1218        rngs::StdRng,
1219    };
1220    use rand_distr::StandardNormal;
1221
1222    use crate::{
1223        algorithms::transforms::TargetDim,
1224        alloc::GlobalAllocator,
1225        bits::{BitTranspose, Dense},
1226        spherical::{Data, DataMetaF32, FullQuery, Query},
1227        test_util,
1228    };
1229
1230    // Test cosine-similarity maximizer
1231    #[test]
1232    fn test_cosine_similarity_maximizer() {
1233        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
1234        let num_trials = 10000;
1235        let num_bits = NonZeroUsize::new(3).unwrap();
1236
1237        let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();
1238
1239        let run_test = |target: [f32; 4]| {
1240            let scale =
1241                maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();
1242
1243            let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
1244            let mut best_similarity: f32 = f32::NEG_INFINITY;
1245
1246            // This crazy series of nested loops performs an exhaustive search over the
1247            // encoding space.
1248            let min = -3.5;
1249            for i0 in 0..8 {
1250                for i1 in 0..8 {
1251                    for i2 in 0..8 {
1252                        for i3 in 0..8 {
1253                            let p: [f32; 4] = [
1254                                min + (i0 as f32),
1255                                min + (i1 as f32),
1256                                min + (i2 as f32),
1257                                min + (i3 as f32),
1258                            ];
1259
1260                            let sim: MathematicalValue<f32> =
1261                                diskann_vector::distance::Cosine::evaluate(&p, &target);
1262                            let sim = sim.into_inner();
1263                            if sim > best_similarity {
1264                                best_similarity = sim;
1265                                // Transform into an integer starting at zero.
1266                                best = p.map(|i| i - min);
1267                            }
1268                        }
1269                    }
1270                }
1271            }
1272
1273            // Now, rescale the input vector, clamp, and round.
1274            // Check if they agree.
1275            let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
1276            let clamped_cosine: MathematicalValue<f32> =
1277                diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);
1278
1279            // We expect to either get the best value found via exhaustive search, or some
1280            // scalar multiple of it (since that will have the same cosine similarity).
1281            let passed = if best == clamped {
1282                true
1283            } else {
1284                let ratio: Vec<f32> = std::iter::zip(best, clamped)
1285                    .map(|(b, c)| {
1286                        let ratio = (b + min) / (c + min);
1287                        assert_ne!(
1288                            ratio, 0.0,
1289                            "ratio should never be zero because `b` is an integer and \
1290                             `min` is not"
1291                        );
1292                        ratio
1293                    })
1294                    .collect();
1295
1296                ratio.iter().all(|i| *i == ratio[0])
1297            };
1298
1299            if !passed {
1300                panic!(
1301                    "failed for input {:?}.\
1302                     Best = {:?}, Found = {:?}\
1303                     Best similarity = {}, similarity with clamped = {}",
1304                    target,
1305                    best,
1306                    clamped,
1307                    best_similarity,
1308                    clamped_cosine.into_inner()
1309                );
1310            }
1311        };
1312
1313        // Run targeted tests.
1314        let min = -3.5;
1315        for i0 in (0..8).step_by(2) {
1316            for i1 in (1..9).step_by(2) {
1317                for i2 in (0..8).step_by(2) {
1318                    for i3 in (1..9).step_by(2) {
1319                        let p: [f32; 4] = [
1320                            min + (i0 as f32),
1321                            min + (i1 as f32),
1322                            min + (i2 as f32),
1323                            min + (i3 as f32),
1324                        ];
1325                        run_test(p)
1326                    }
1327                }
1328            }
1329        }
1330
1331        for _ in 0..num_trials {
1332            let this_scale: f32 = scale_distribution.sample(&mut rng);
1333            let v: [f32; 4] = [(); 4].map(|_| {
1334                let v: f32 = StandardNormal {}.sample(&mut rng);
1335                this_scale * v
1336            });
1337            run_test(v);
1338        }
1339    }
1340
1341    #[test]
1342    #[should_panic(expected = "calling code should not allow the slice to be empty")]
1343    fn empty_slice_panics() {
1344        maximize_cosine_similarity(
1345            &[],
1346            NonZeroUsize::new(4).unwrap(),
1347            ScopedAllocator::global(),
1348        )
1349        .unwrap();
1350    }
1351
1352    struct Setup {
1353        transform: TransformKind,
1354        nrows: usize,
1355        ncols: usize,
1356        num_trials: usize,
1357    }
1358
1359    fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1360        match scale {
1361            PreScale::None => 1.0,
1362            PreScale::Some(v) => v.into_inner(),
1363            PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1364        }
1365    }
1366
1367    fn test_l2<const Q: usize, const D: usize, Perm>(
1368        setup: &Setup,
1369        problem: &test_util::TestProblem,
1370        computed_means: &[f32],
1371        pre_scale: PreScale,
1372        rng: &mut StdRng,
1373    ) where
1374        Unsigned: Representation<Q>,
1375        Unsigned: Representation<D>,
1376        Perm: PermutationStrategy<Q>,
1377        for<'a> SphericalQuantizer:
1378            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1379        for<'a> SphericalQuantizer:
1380            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1381    {
1382        assert_eq!(setup.nrows, problem.data.nrows());
1383        assert_eq!(setup.ncols, problem.data.ncols());
1384
1385        let scoped_global = ScopedAllocator::global();
1386        let distribution = Uniform::new(0, setup.nrows).unwrap();
1387        let quantizer = SphericalQuantizer::train(
1388            problem.data.as_view(),
1389            setup.transform,
1390            SupportedMetric::SquaredL2,
1391            pre_scale,
1392            rng,
1393            GlobalAllocator,
1394        )
1395        .unwrap();
1396
1397        let scale = get_scale(pre_scale, &quantizer);
1398
1399        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1400        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1401        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1402
1403        assert_eq!(
1404            quantizer.mean_norm.into_inner(),
1405            problem.mean_norm as f32,
1406            "computed mean norm should not apply scale"
1407        );
1408        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1409        assert_eq!(&*scaled_means, quantizer.shift());
1410
1411        let l2: CompensatedSquaredL2 = quantizer.as_functor();
1412        assert_eq!(l2.dim, quantizer.output_dim() as f32);
1413
1414        for _ in 0..setup.num_trials {
1415            let i = distribution.sample(rng);
1416            let v = problem.data.row(i);
1417
1418            quantizer
1419                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1420                .unwrap();
1421            quantizer
1422                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1423                .unwrap();
1424            quantizer
1425                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1426                .unwrap();
1427
1428            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1429                .map(|(a, b)| scale * a - b)
1430                .collect();
1431
1432            // Check that the compensation coefficient were chosen correctly.
1433            {
1434                let DataMetaF32 {
1435                    inner_product_correction,
1436                    bit_sum,
1437                    metric_specific,
1438                } = b.meta().to_full(ARCH);
1439
1440                let shifted_square_norm = metric_specific;
1441
1442                // Check that the bit-count is correct. let bv = b.vector();
1443                let bv = b.vector();
1444                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1445                assert_eq!(s, bit_sum as usize);
1446
1447                // Check that the shifted norm is correct.
1448                {
1449                    let expected = FastL2NormSquared.evaluate(&*shifted);
1450                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1451                    assert!(
1452                        err < 5.0e-4, // twice the minimum normal f16 value.
1453                        "failed diff check, got {}, expected {} - relative error = {}",
1454                        shifted_square_norm,
1455                        expected,
1456                        err
1457                    );
1458                }
1459
1460                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1461                // the RaBitQ paper suggests.
1462                if const { D == 1 } {
1463                    let self_inner_product = 2.0 * shifted_square_norm.sqrt()
1464                        / (inner_product_correction * (bv.len() as f32).sqrt());
1465                    assert!(
1466                        (self_inner_product - 0.8).abs() < 0.13,
1467                        "self inner-product should be close to 0.8. Instead, it's {}",
1468                        self_inner_product
1469                    );
1470                }
1471            }
1472
1473            {
1474                let QueryMeta {
1475                    inner_product_correction,
1476                    bit_sum,
1477                    offset,
1478                    metric_specific,
1479                } = q.meta();
1480
1481                let shifted_square_norm = metric_specific;
1482                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1483                preprocessed
1484                    .shifted
1485                    .iter_mut()
1486                    .for_each(|i| *i /= preprocessed.shifted_norm);
1487
1488                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1489                quantizer
1490                    .transform
1491                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1492                    .unwrap();
1493
1494                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1495                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1496
1497                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1498
1499                // Shifted Norm
1500                {
1501                    let expected = FastL2NormSquared.evaluate(&*shifted);
1502                    let err = (shifted_square_norm - expected).abs() / expected.abs();
1503                    assert!(
1504                        err < 2e-7,
1505                        "failed diff check, got {}, expected {} - relative error = {}",
1506                        shifted_square_norm,
1507                        expected,
1508                        err
1509                    );
1510                }
1511
1512                // Inner product correction
1513                {
1514                    let expected = shifted_square_norm.sqrt() * scale;
1515                    let got = inner_product_correction;
1516
1517                    let err = (expected - got).abs();
1518                    assert!(
1519                        err < 1.0e-7,
1520                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1521                        expected,
1522                        got,
1523                        err
1524                    );
1525                }
1526
1527                // Offset
1528                {
1529                    let expected = min / scale;
1530                    let got = offset;
1531
1532                    let err = (expected - got).abs();
1533                    assert!(
1534                        err < 1.0e-7,
1535                        "\"sum_scale\": expected {}, got {}, error = {}",
1536                        expected,
1537                        got,
1538                        err
1539                    );
1540                }
1541
1542                // Bit Sum
1543                {
1544                    let expected = (0..q.len())
1545                        .map(|i| q.vector().get(i).unwrap())
1546                        .sum::<i64>() as f32;
1547
1548                    let got = bit_sum;
1549
1550                    let err = (expected - got).abs();
1551                    assert!(
1552                        err < 1.0e-7,
1553                        "\"offset\": expected {}, got {}, error = {}",
1554                        expected,
1555                        got,
1556                        err
1557                    );
1558                }
1559            }
1560
1561            // Check that the compensation coefficient were chosen correctly.
1562            {
1563                // Check that the bit-count is correct.
1564                let s: f32 = f.data.iter().sum::<f32>();
1565                assert_eq!(s, f.meta.sum);
1566
1567                // Check that the shifted norm is correct.
1568                {
1569                    let expected = FastL2Norm.evaluate(&*shifted);
1570                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1571                    assert!(
1572                        err < 2e-7,
1573                        "failed diff check, got {}, expected {} - relative error = {}",
1574                        f.meta.shifted_norm,
1575                        expected,
1576                        err
1577                    );
1578                }
1579
1580                assert_eq!(
1581                    f.meta.metric_specific,
1582                    f.meta.shifted_norm * f.meta.shifted_norm,
1583                    "metric specific data for squared l2 is the square shifted norm",
1584                );
1585            }
1586        }
1587
1588        // Finally - test that if we compress the centroid, the metadata coefficients get
1589        // zeroed correctly.
1590        quantizer
1591            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1592            .unwrap();
1593        assert_eq!(b.meta(), DataMeta::default());
1594
1595        quantizer
1596            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1597            .unwrap();
1598        assert_eq!(q.meta(), QueryMeta::default());
1599
1600        f.data.fill(f32::INFINITY);
1601        quantizer
1602            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1603            .unwrap();
1604        assert!(f.data.iter().all(|&i| i == 0.0));
1605        assert_eq!(f.meta.sum, 0.0);
1606        assert_eq!(f.meta.metric_specific, 0.0);
1607    }
1608
1609    fn test_ip<const Q: usize, const D: usize, Perm>(
1610        setup: &Setup,
1611        problem: &test_util::TestProblem,
1612        computed_means: &[f32],
1613        pre_scale: PreScale,
1614        rng: &mut StdRng,
1615        ctx: &dyn Display,
1616    ) where
1617        Unsigned: Representation<Q>,
1618        Unsigned: Representation<D>,
1619        Perm: PermutationStrategy<Q>,
1620        for<'a> SphericalQuantizer:
1621            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1622        for<'a> SphericalQuantizer:
1623            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1624    {
1625        assert_eq!(setup.nrows, problem.data.nrows());
1626        assert_eq!(setup.ncols, problem.data.ncols());
1627
1628        let scoped_global = ScopedAllocator::global();
1629        let distribution = Uniform::new(0, setup.nrows).unwrap();
1630        let quantizer = SphericalQuantizer::train(
1631            problem.data.as_view(),
1632            setup.transform,
1633            SupportedMetric::InnerProduct,
1634            pre_scale,
1635            rng,
1636            GlobalAllocator,
1637        )
1638        .unwrap();
1639
1640        let scale = get_scale(pre_scale, &quantizer);
1641
1642        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1643        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1644        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1645
1646        assert_eq!(
1647            quantizer.mean_norm.into_inner(),
1648            problem.mean_norm as f32,
1649            "computed mean norm should not apply scale"
1650        );
1651        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1652        assert_eq!(&*scaled_means, quantizer.shift());
1653
1654        let ip: CompensatedIP = quantizer.as_functor();
1655
1656        assert_eq!(ip.dim, quantizer.output_dim() as f32);
1657        assert_eq!(
1658            ip.squared_shift_norm,
1659            FastL2NormSquared.evaluate(quantizer.shift())
1660        );
1661
1662        for _ in 0..setup.num_trials {
1663            let i = distribution.sample(rng);
1664            let v = problem.data.row(i);
1665
1666            quantizer
1667                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1668                .unwrap();
1669            quantizer
1670                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1671                .unwrap();
1672            quantizer
1673                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1674                .unwrap();
1675
1676            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1677                .map(|(a, b)| scale * a - b)
1678                .collect();
1679
1680            // Check that the compensation coefficient were chosen correctly.
1681            {
1682                let DataMetaF32 {
1683                    inner_product_correction,
1684                    bit_sum,
1685                    metric_specific,
1686                } = b.meta().to_full(ARCH);
1687
1688                let inner_product_with_centroid = metric_specific;
1689
1690                // Check that the bit-count is correct.
1691                let bv = b.vector();
1692                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1693                assert_eq!(s, bit_sum as usize);
1694
1695                // Check that the shifted norm is correct.
1696                let inner_product: MathematicalValue<f32> =
1697                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1698
1699                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
1700                assert!(
1701                    diff < 1.53e-5,
1702                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
1703                    diff,
1704                    inner_product.into_inner(),
1705                    inner_product_with_centroid,
1706                    ctx,
1707                );
1708
1709                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1710                // the RaBitQ paper suggests.
1711                if const { D == 1 } {
1712                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1713                        / (inner_product_correction * (bv.len() as f32).sqrt());
1714                    assert!(
1715                        (self_inner_product - 0.8).abs() < 0.12,
1716                        "self inner-product should be close to 0.8. Instead, it's {}",
1717                        self_inner_product
1718                    );
1719                }
1720            }
1721
1722            {
1723                let QueryMeta {
1724                    inner_product_correction,
1725                    bit_sum,
1726                    offset,
1727                    metric_specific,
1728                } = q.meta();
1729
1730                let inner_product_with_centroid = metric_specific;
1731                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1732                preprocessed
1733                    .shifted
1734                    .iter_mut()
1735                    .for_each(|i| *i /= preprocessed.shifted_norm);
1736
1737                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1738                quantizer
1739                    .transform
1740                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1741                    .unwrap();
1742
1743                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1744                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1745
1746                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1747
1748                // Inner product correction
1749                {
1750                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1751                    let got = inner_product_correction;
1752
1753                    let err = (expected - got).abs();
1754                    assert!(
1755                        err < 1.0e-7,
1756                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1757                        expected,
1758                        got,
1759                        err
1760                    );
1761                }
1762
1763                // Offset
1764                {
1765                    let expected = min / scale;
1766                    let got = offset;
1767
1768                    let err = (expected - got).abs();
1769                    assert!(
1770                        err < 1.0e-7,
1771                        "\"sum_scale\": expected {}, got {}, error = {}",
1772                        expected,
1773                        got,
1774                        err
1775                    );
1776                }
1777
1778                // Bit Sum
1779                {
1780                    let expected = (0..q.len())
1781                        .map(|i| q.vector().get(i).unwrap())
1782                        .sum::<i64>() as f32;
1783
1784                    let got = bit_sum;
1785
1786                    let err = (expected - got).abs();
1787                    assert!(
1788                        err < 1.0e-7,
1789                        "\"offset\": expected {}, got {}, error = {}",
1790                        expected,
1791                        got,
1792                        err
1793                    );
1794                }
1795
1796                // Inner Product with Centroid
1797                {
1798                    // Check that the shifted norm is correct.
1799                    let inner_product: MathematicalValue<f32> =
1800                        InnerProduct::evaluate(&*shifted, quantizer.shift());
1801                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
1802                }
1803            }
1804
1805            // Check that the compensation coefficient were chosen correctly.
1806            {
1807                // Check that the bit-count is correct.
1808                let s: f32 = f.data.iter().sum::<f32>();
1809                assert_eq!(s, f.meta.sum);
1810
1811                // Check that the shifted norm is correct.
1812                {
1813                    let expected = FastL2Norm.evaluate(&*shifted);
1814                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1815                    assert!(
1816                        err < 2e-7,
1817                        "failed diff check, got {}, expected {} - relative error = {}",
1818                        f.meta.shifted_norm,
1819                        expected,
1820                        err
1821                    );
1822                }
1823
1824                // Check that the shifted norm is correct. s
1825                let inner_product: MathematicalValue<f32> =
1826                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1827                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
1828            }
1829        }
1830
1831        // Finally - test that if we compress the centroid, the metadata coefficients get
1832        // zeroed correctly.
1833        quantizer
1834            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1835            .unwrap();
1836        assert_eq!(b.meta(), DataMeta::default());
1837
1838        quantizer
1839            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1840            .unwrap();
1841        assert_eq!(q.meta(), QueryMeta::default());
1842
1843        f.data.fill(f32::INFINITY);
1844        quantizer
1845            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1846            .unwrap();
1847        assert!(f.data.iter().all(|&i| i == 0.0));
1848        assert_eq!(f.meta.sum, 0.0);
1849        assert_eq!(f.meta.metric_specific, 0.0);
1850    }
1851
1852    fn test_cosine<const Q: usize, const D: usize, Perm>(
1853        setup: &Setup,
1854        problem: &test_util::TestProblem,
1855        pre_scale: PreScale,
1856        rng: &mut StdRng,
1857    ) where
1858        Unsigned: Representation<Q>,
1859        Unsigned: Representation<D>,
1860        Perm: PermutationStrategy<Q>,
1861        for<'a> SphericalQuantizer:
1862            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1863        for<'a> SphericalQuantizer:
1864            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1865    {
1866        assert_eq!(setup.nrows, problem.data.nrows());
1867        assert_eq!(setup.ncols, problem.data.ncols());
1868
1869        let scoped_global = ScopedAllocator::global();
1870        let distribution = Uniform::new(0, setup.nrows).unwrap();
1871        let quantizer = SphericalQuantizer::train(
1872            problem.data.as_view(),
1873            setup.transform,
1874            SupportedMetric::Cosine,
1875            pre_scale,
1876            rng,
1877            GlobalAllocator,
1878        )
1879        .unwrap();
1880
1881        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1882        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1883        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1884
1885        let cosine: CompensatedCosine = quantizer.as_functor();
1886
1887        assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
1888        assert_eq!(
1889            cosine.inner.squared_shift_norm,
1890            FastL2NormSquared.evaluate(quantizer.shift())
1891        );
1892
1893        const IP_BOUND: f32 = 2.6e-3;
1894
1895        let mut test_row = |v: &[f32]| {
1896            let vnorm = (FastL2Norm).evaluate(v);
1897            let v_normalized: Vec<f32> = v
1898                .iter()
1899                .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
1900                .collect();
1901
1902            quantizer
1903                .compress_into_with(v, b.reborrow_mut(), scoped_global)
1904                .unwrap();
1905
1906            quantizer
1907                .compress_into_with(v, q.reborrow_mut(), scoped_global)
1908                .unwrap();
1909
1910            quantizer
1911                .compress_into_with(v, f.reborrow_mut(), scoped_global)
1912                .unwrap();
1913
1914            let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
1915                .map(|(a, b)| a - b)
1916                .collect();
1917
1918            // Check that the compensation coefficient were chosen correctly.
1919            {
1920                let DataMetaF32 {
1921                    inner_product_correction,
1922                    bit_sum,
1923                    metric_specific,
1924                } = b.meta().to_full(ARCH);
1925
1926                let inner_product_with_centroid = metric_specific;
1927
1928                // Check that the bit-count is correct.
1929                let bv = b.vector();
1930                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1931                assert_eq!(s, bit_sum as usize);
1932
1933                // Check that the shifted norm is correct. Since they are computed slightly
1934                // differnetly, allow a small amount of error.
1935                let inner_product: MathematicalValue<f32> =
1936                    InnerProduct::evaluate(&*shifted, quantizer.shift());
1937
1938                let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
1939                let relative = abs / inner_product.into_inner().abs();
1940
1941                assert!(
1942                    abs < 1e-7 || relative < IP_BOUND,
1943                    "got an abs/rel of {}/{} with a bound of {}/{}",
1944                    abs,
1945                    relative,
1946                    1e-7,
1947                    IP_BOUND
1948                );
1949
1950                // Finaly, verify that the self-inner-product is clustered around 0.8 as
1951                // the RaBitQ paper suggests.
1952                if const { D == 1 } {
1953                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1954                        / (inner_product_correction * (bv.len() as f32).sqrt());
1955                    assert!(
1956                        (self_inner_product - 0.8).abs() < 0.11,
1957                        "self inner-product should be close to 0.8. Instead, it's {}",
1958                        self_inner_product
1959                    );
1960                }
1961            }
1962
1963            {
1964                let QueryMeta {
1965                    inner_product_correction,
1966                    bit_sum,
1967                    offset,
1968                    metric_specific,
1969                } = q.meta();
1970
1971                let inner_product_with_centroid = metric_specific;
1972                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1973                preprocessed
1974                    .shifted
1975                    .iter_mut()
1976                    .for_each(|i| *i /= preprocessed.shifted_norm);
1977
1978                let mut transformed = vec![0.0f32; quantizer.output_dim()];
1979                quantizer
1980                    .transform
1981                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1982                    .unwrap();
1983
1984                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1985                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1986
1987                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1988
1989                // Inner product correction
1990                {
1991                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1992                    let got = inner_product_correction;
1993
1994                    let err = (expected - got).abs();
1995                    assert!(
1996                        err < 1.0e-7,
1997                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
1998                        expected,
1999                        got,
2000                        err
2001                    );
2002                }
2003
2004                // Offset
2005                {
2006                    let expected = min / scale;
2007                    let got = offset;
2008
2009                    let err = (expected - got).abs();
2010                    assert!(
2011                        err < 1.0e-7,
2012                        "\"sum_scale\": expected {}, got {}, error = {}",
2013                        expected,
2014                        got,
2015                        err
2016                    );
2017                }
2018
2019                // Bit Sum
2020                {
2021                    let expected = (0..q.len())
2022                        .map(|i| q.vector().get(i).unwrap())
2023                        .sum::<i64>() as f32;
2024
2025                    let got = bit_sum;
2026
2027                    let err = (expected - got).abs();
2028                    assert!(
2029                        err < 1.0e-7,
2030                        "\"offset\": expected {}, got {}, error = {}",
2031                        expected,
2032                        got,
2033                        err
2034                    );
2035                }
2036
2037                // Inner Product with Centroid
2038                {
2039                    // Check that the shifted norm is correct.
2040                    let inner_product: MathematicalValue<f32> =
2041                        InnerProduct::evaluate(&*shifted, quantizer.shift());
2042
2043                    let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
2044                        / inner_product.into_inner().abs();
2045                    assert!(
2046                        err < IP_BOUND,
2047                        "\"offset\": expected {}, got {}, error = {}",
2048                        inner_product.into_inner(),
2049                        inner_product_with_centroid,
2050                        err
2051                    );
2052                }
2053            }
2054
2055            // Check that the compensation coefficient were chosen correctly.
2056            {
2057                // Check that the bit-count is correct.
2058                let s: f32 = f.data.iter().sum::<f32>();
2059                assert_eq!(s, f.meta.sum);
2060
2061                // Check that the shifted norm is correct.
2062                {
2063                    let expected = FastL2Norm.evaluate(&*shifted);
2064                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
2065                    assert!(
2066                        err < 2e-7,
2067                        "failed diff check, got {}, expected {} - relative error = {}",
2068                        f.meta.shifted_norm,
2069                        expected,
2070                        err
2071                    );
2072                }
2073
2074                // Check that the shifted norm is correct. s
2075                let inner_product: MathematicalValue<f32> =
2076                    InnerProduct::evaluate(&*shifted, quantizer.shift());
2077                let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
2078                    / inner_product.into_inner().abs();
2079                assert!(
2080                    err < IP_BOUND,
2081                    "\"offset\": expected {}, got {}, error = {}",
2082                    inner_product.into_inner(),
2083                    f.meta.metric_specific,
2084                    err
2085                );
2086            }
2087        };
2088
2089        for _ in 0..setup.num_trials {
2090            let i = distribution.sample(rng);
2091            let v = problem.data.row(i);
2092            test_row(v);
2093        }
2094
2095        // Ensure that if a zero vector is provided that we do not divide by zero.
2096        let zero = vec![0.0f32; quantizer.input_dim()];
2097        test_row(&zero);
2098    }
2099
2100    fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2101    where
2102        for<'a> T: ReborrowMut<'a>,
2103        for<'a> SphericalQuantizer: CompressIntoWith<
2104                &'a [f32],
2105                <T as ReborrowMut<'a>>::Target,
2106                ScopedAllocator<'a>,
2107                Error = CompressionError,
2108            >,
2109    {
2110        let mut succeeded = false;
2111        let mut failed = false;
2112        for max_allocations in 0..10 {
2113            match quantizer.compress_into_with(
2114                data,
2115                dst.reborrow_mut(),
2116                ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2117            ) {
2118                Ok(()) => {
2119                    succeeded = true;
2120                }
2121                Err(CompressionError::AllocatorError(_)) => {
2122                    failed = true;
2123                }
2124                Err(other) => {
2125                    panic!("received an unexpected error: {:?}", other);
2126                }
2127            }
2128        }
2129        assert!(succeeded);
2130        assert!(failed);
2131    }
2132
2133    fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
2134        setup: &Setup,
2135        problem: &test_util::TestProblem,
2136        pre_scale: PreScale,
2137        rng: &mut StdRng,
2138    ) where
2139        Unsigned: Representation<Q>,
2140        Unsigned: Representation<D>,
2141        Perm: PermutationStrategy<Q>,
2142        for<'a> SphericalQuantizer: CompressIntoWith<
2143                &'a [f32],
2144                DataMut<'a, D>,
2145                ScopedAllocator<'a>,
2146                Error = CompressionError,
2147            >,
2148        for<'a> SphericalQuantizer: CompressIntoWith<
2149                &'a [f32],
2150                QueryMut<'a, Q, Perm>,
2151                ScopedAllocator<'a>,
2152                Error = CompressionError,
2153            >,
2154    {
2155        assert_eq!(setup.nrows, problem.data.nrows());
2156        assert_eq!(setup.ncols, problem.data.ncols());
2157
2158        let quantizer = SphericalQuantizer::train(
2159            problem.data.as_view(),
2160            setup.transform,
2161            SupportedMetric::SquaredL2,
2162            pre_scale,
2163            rng,
2164            GlobalAllocator,
2165        )
2166        .unwrap();
2167
2168        // Data.
2169        let data = problem.data.row(0);
2170        _test_oom_resiliance::<Data<D, _>>(
2171            &quantizer,
2172            data,
2173            &mut Data::new_boxed(quantizer.output_dim()),
2174        );
2175        _test_oom_resiliance::<Query<Q, Perm, _>>(
2176            &quantizer,
2177            data,
2178            &mut Query::new_boxed(quantizer.output_dim()),
2179        );
2180        _test_oom_resiliance::<FullQuery<_>>(
2181            &quantizer,
2182            data,
2183            &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
2184        );
2185    }
2186
2187    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
2188    where
2189        Unsigned: Representation<Q>,
2190        Unsigned: Representation<D>,
2191        Perm: PermutationStrategy<Q>,
2192        for<'a> SphericalQuantizer: CompressIntoWith<
2193                &'a [f32],
2194                DataMut<'a, D>,
2195                ScopedAllocator<'a>,
2196                Error = CompressionError,
2197            >,
2198        for<'a> SphericalQuantizer: CompressIntoWith<
2199                &'a [f32],
2200                QueryMut<'a, Q, Perm>,
2201                ScopedAllocator<'a>,
2202                Error = CompressionError,
2203            >,
2204    {
2205        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
2206        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();
2207
2208        let scales = [
2209            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2210            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
2211            PreScale::ReciprocalMeanNorm,
2212        ];
2213
2214        for scale in scales {
2215            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);
2216
2217            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
2218            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
2219            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
2220        }
2221
2222        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
2223    }
2224
2225    #[test]
2226    fn test_spherical_quantizer() {
2227        let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2228        for dim in [56, 72, 128, 255] {
2229            let setup = Setup {
2230                transform: TransformKind::PaddingHadamard {
2231                    target_dim: TargetDim::Same,
2232                },
2233                nrows: 64,
2234                ncols: dim,
2235                num_trials: 10,
2236            };
2237
2238            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2239            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2240            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2241            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2242
2243            let setup = Setup {
2244                transform: TransformKind::DoubleHadamard {
2245                    target_dim: TargetDim::Same,
2246                },
2247                nrows: 64,
2248                ncols: dim,
2249                num_trials: 10,
2250            };
2251            test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2252            test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2253            test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2254            test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2255        }
2256    }
2257
2258    ////////////
2259    // Errors //
2260    ////////////
2261
2262    #[test]
2263    fn err_dim_cannot_be_zero() {
2264        let data = Matrix::new(0.0f32, 10, 0);
2265        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2266        let err = SphericalQuantizer::train(
2267            data.as_view(),
2268            TransformKind::DoubleHadamard {
2269                target_dim: TargetDim::Same,
2270            },
2271            SupportedMetric::SquaredL2,
2272            PreScale::None,
2273            &mut rng,
2274            GlobalAllocator,
2275        )
2276        .unwrap_err();
2277        assert_eq!(err.to_string(), "data dim cannot be zero");
2278    }
2279
2280    #[test]
2281    fn err_norm_must_be_positive() {
2282        let data = Matrix::new(0.0f32, 10, 10);
2283        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2284        let err = SphericalQuantizer::train(
2285            data.as_view(),
2286            TransformKind::DoubleHadamard {
2287                target_dim: TargetDim::Same,
2288            },
2289            SupportedMetric::SquaredL2,
2290            PreScale::None,
2291            &mut rng,
2292            GlobalAllocator,
2293        )
2294        .unwrap_err();
2295        assert_eq!(err.to_string(), "norm must be positive");
2296    }
2297
2298    #[test]
2299    fn err_norm_cannot_be_infinity() {
2300        let mut data = Matrix::new(0.0f32, 10, 10);
2301        data[(2, 5)] = f32::INFINITY;
2302
2303        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2304        let err = SphericalQuantizer::train(
2305            data.as_view(),
2306            TransformKind::DoubleHadamard {
2307                target_dim: TargetDim::Same,
2308            },
2309            SupportedMetric::SquaredL2,
2310            PreScale::None,
2311            &mut rng,
2312            GlobalAllocator,
2313        )
2314        .unwrap_err();
2315        assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2316    }
2317
2318    #[test]
2319    fn err_reciprocal_norm_cannot_be_infinity() {
2320        let mut data = Matrix::new(0.0f32, 10, 10);
2321        data[(2, 5)] = 2.93863e-39;
2322
2323        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2324        let err = SphericalQuantizer::train(
2325            data.as_view(),
2326            TransformKind::DoubleHadamard {
2327                target_dim: TargetDim::Same,
2328            },
2329            SupportedMetric::SquaredL2,
2330            PreScale::ReciprocalMeanNorm,
2331            &mut rng,
2332            GlobalAllocator,
2333        )
2334        .unwrap_err();
2335        assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2336    }
2337
2338    #[test]
2339    fn err_mean_norm_cannot_be_zero_generate() {
2340        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2341        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2342        let err = SphericalQuantizer::generate(
2343            centroid,
2344            0.0,
2345            TransformKind::DoubleHadamard {
2346                target_dim: TargetDim::Same,
2347            },
2348            SupportedMetric::SquaredL2,
2349            None,
2350            &mut rng,
2351            GlobalAllocator,
2352        )
2353        .unwrap_err();
2354        assert_eq!(err.to_string(), "norm must be positive");
2355    }
2356
2357    #[test]
2358    fn err_scale_cannot_be_zero_generate() {
2359        let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2360        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2361        let err = SphericalQuantizer::generate(
2362            centroid,
2363            1.0,
2364            TransformKind::DoubleHadamard {
2365                target_dim: TargetDim::Same,
2366            },
2367            SupportedMetric::SquaredL2,
2368            Some(0.0),
2369            &mut rng,
2370            GlobalAllocator,
2371        )
2372        .unwrap_err();
2373        assert_eq!(err.to_string(), "pre-scale must be positive");
2374    }
2375
2376    #[test]
2377    fn compression_errors_data() {
2378        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2379        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);
2380
2381        let quantizer = SphericalQuantizer::train(
2382            data.as_view(),
2383            TransformKind::PaddingHadamard {
2384                target_dim: TargetDim::Same,
2385            },
2386            SupportedMetric::SquaredL2,
2387            PreScale::None,
2388            &mut rng,
2389            GlobalAllocator,
2390        )
2391        .unwrap();
2392
2393        let scoped_global = ScopedAllocator::global();
2394
2395        // Input contains NaN.
2396        {
2397            let mut query: Vec<f32> = quantizer.shift().to_vec();
2398            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2399            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2400
2401            for i in 0..query.len() {
2402                let last = query[i];
2403                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
2404                    query[i] = v;
2405
2406                    let err = quantizer
2407                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2408                        .unwrap_err();
2409
2410                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2411
2412                    let err = quantizer
2413                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2414                        .unwrap_err();
2415
2416                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
2417                }
2418                query[i] = last;
2419            }
2420        }
2421
2422        // Input has a large value.
2423        {
2424            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
2425            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2426
2427            let err = quantizer
2428                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2429                .unwrap_err();
2430
2431            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";
2432
2433            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
2434        }
2435
2436        // Input length
2437        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
2438            let query = vec![0.0f32; len];
2439            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
2440            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());
2441
2442            let err = quantizer
2443                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2444                .unwrap_err();
2445            assert_eq!(
2446                err,
2447                CompressionError::SourceDimensionMismatch {
2448                    expected: quantizer.input_dim(),
2449                }
2450            );
2451
2452            let err = quantizer
2453                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2454                .unwrap_err();
2455            assert_eq!(
2456                err,
2457                CompressionError::SourceDimensionMismatch {
2458                    expected: quantizer.input_dim(),
2459                }
2460            );
2461        }
2462
2463        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
2464            let query = vec![0.0f32; quantizer.input_dim()];
2465            let mut d = Data::<1, _>::new_boxed(len);
2466            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);
2467
2468            let err = quantizer
2469                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
2470                .unwrap_err();
2471            assert_eq!(
2472                err,
2473                CompressionError::DestinationDimensionMismatch {
2474                    expected: quantizer.output_dim(),
2475                }
2476            );
2477
2478            let err = quantizer
2479                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
2480                .unwrap_err();
2481            assert_eq!(
2482                err,
2483                CompressionError::DestinationDimensionMismatch {
2484                    expected: quantizer.output_dim(),
2485                }
2486            );
2487        }
2488    }
2489
2490    #[test]
2491    fn centroid_scaling_happens_in_generate() {
2492        let centroid = Poly::from_iter(
2493            [1088.6732f32, 1393.32, 1547.877].into_iter(),
2494            GlobalAllocator,
2495        )
2496        .unwrap();
2497        let mean_norm = 2359.27;
2498        let pre_scale = 1.0 / mean_norm;
2499
2500        let quantizer = SphericalQuantizer::generate(
2501            centroid,
2502            mean_norm,
2503            TransformKind::Null,
2504            SupportedMetric::InnerProduct,
2505            Some(pre_scale),
2506            &mut StdRng::seed_from_u64(10),
2507            GlobalAllocator,
2508        )
2509        .unwrap();
2510
2511        let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
2512        let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
2513        assert!(
2514            quantizer
2515                .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
2516                .is_ok(),
2517            "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
2518        );
2519    }
2520}
2521
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    //! Flatbuffer round-trip and validation tests for [`SphericalQuantizer`]
    //! (`pack` / `try_unpack`).

    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Trains a quantizer for every combination of transform kind, metric, and pre-scale.
    /// Each one is serialized to a flatbuffer and must deserialize back to an equal
    /// quantizer.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        // Two explicit `TargetDim::Override` values exercised for each transform family.
        let low = NonZeroUsize::new(100).unwrap();
        let high = NonZeroUsize::new(150).unwrap();

        let kinds = [
            // Null
            TransformKind::Null,
            // Double Hadamard
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Padding Hadamard
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Random Rotation
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Same,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Natural,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(low),
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(high),
            },
        ];

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        let alloc = GlobalAllocator;
        for kind in kinds.into_iter() {
            for metric in SupportedMetric::all() {
                for pre_scale in pre_scales {
                    let quantizer = SphericalQuantizer::train(
                        problem.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    let data = to_flatbuffer(|buf| quantizer.pack(buf));
                    let proto =
                        flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
                    let reloaded = SphericalQuantizer::try_unpack(alloc, proto).unwrap();

                    // Report the full parameter combination; the transform alone does not
                    // identify which of the nested loop iterations failed.
                    assert_eq!(
                        quantizer, reloaded,
                        "round trip failed for transform {:?}, metric {:?}, pre-scale {:?}",
                        kind, metric, pre_scale,
                    );
                }
            }
        }
    }

    /// Serializes deliberately corrupted quantizers. `try_unpack` must reject each one with
    /// the expected [`DeserializationError`] variant.
    #[test]
    fn test_error_checking() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        let alloc = GlobalAllocator;
        let mut make_quantizer = || {
            SphericalQuantizer::train(
                problem.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        // Packs `quantizer` into a flatbuffer and returns the error produced when reading
        // it back. Panics if deserialization unexpectedly succeeds.
        let unpack_err = |quantizer: &SphericalQuantizer| {
            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            SphericalQuantizer::try_unpack(alloc, proto).unwrap_err()
        };

        type E = DeserializationError;

        // Missing norm: both a zero and a negative mean norm must be rejected.
        for bad_norm in [0.0, -1.0] {
            let mut quantizer = make_quantizer();
            // SAFETY: `Positive::new_unchecked` requires a strictly positive argument, and
            // `bad_norm` deliberately violates that. This breaks only the library-level
            // invariant of `Positive`: the compiler knows nothing about its layout, so no
            // validity rule is violated. The value is used solely to serialize an invalid
            // quantizer for the deserializer to reject.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(bad_norm) };
            assert_eq!(
                unpack_err(&quantizer),
                E::MissingNorm,
                "mean norm {:?} was not rejected",
                bad_norm,
            );
        }

        // PreScaleNotPositive
        {
            let mut quantizer = make_quantizer();
            // SAFETY: Same reasoning as above. The non-positive value breaks only the
            // library invariant of `Positive` and is used solely to produce invalid
            // serialized data.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };
            assert_eq!(unpack_err(&quantizer), E::PreScaleNotPositive);
        }

        // Dim Mismatch: a 3-element shift is inconsistent with the rest of the trained
        // quantizer.
        {
            let mut quantizer = make_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();
            assert_eq!(unpack_err(&quantizer), E::DimMismatch);
        }
    }
}