Skip to main content

diskann_quantization/scalar/
vectors.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_vector::{DistanceFunction, PureDistanceFunction};
7
8use super::inverse_bit_scale;
9use crate::{
10    bits::{BitSlice, Dense, Representation, Unsigned},
11    distances::{self, InnerProduct, MV, SquaredL2, check_lengths},
12    meta,
13};
14
15/// A per-vector precomputed coefficient to help compute inner products.
16///
17/// To understand the use of the compensation coefficient, assume that we wish to compute
18/// the inner product between two scalar compressed vectors where the quantization has
19/// scale parameter `a` and centroid `B` (note: capital letters represent vectors, lower
20/// case letters represent scalars).
21///
22/// The inner product between a `X = a * (X' + B)` and `Y = a * (Y' + B)` where
23/// `X'` and `Y'` are the scalar encodings for `X` and `Y` respectively is:
24/// ```math
25/// P = <a * X' + B, a * Y' + B>
26///   = a^2 * <X', Y'> + a * <X', B> + a * <Y', B> + <B, B>
27///            ------    -----------   -----------   ------
28///               |           |             |           |
29///          Integer Dot      |        Compensation     |
30///            Product        |           for Y         |
31///                           |                    Constant for
32///                      Compensation               all vectors
33///                         for X
34///
35/// ```
36/// In other words, the inner product can be decomposed into an integer dot-product plus
37/// a bunch of other terms that compensate for the compression.
38///
39/// These compensation terms can be computed as the vectors are compressed. At run time,
40/// we can the return vectors consisting of the quantized encodings (e.g. `X'`) and the
41/// compensation `<X', B>`.
42///
43/// Computation of squared Euclidean distance is more straight forward:
44/// ```math
45/// P = sum( ((a * X' + B) - (a * Y' + B))^2 )
46///   = sum( a^2 * (X' - Y')^2 )
47///   = a^2 * sum( (X' - Y')^2 )
48/// ```
49/// This means the squared Euclidean distance is computed by scaling the squared Euclidean
50/// distance computed directly on the integer codes.
51///
52/// # Distance Implementations
53///
54/// The following distance function types are implemented:
55///
56/// * [`CompensatedSquaredL2`]: For computing squared euclidean distances.
57/// * [`CompensatedIP`]: For computing inner products.
58///
59/// # Examples
60///
61/// The `CompensatedVector` has several named variants that are commonly used:
62/// * [`CompensatedVector`]: An owning, indepndently allocated `CompensatedVector`.
63/// * [`MutCompensatedVectorRef`]: A mutable, reference-like type to a `CompensatedVector`.
64/// * [`CompensatedVectorRef`]: A const, reference-like type to a `CompensatedVector`.
65///
66/// ```
67/// use diskann_quantization::{
68///     scalar::{
69///         self,
70///         CompensatedVector,
71///         MutCompensatedVectorRef,
72///         CompensatedVectorRef
73///     },
74/// };
75///
76/// use diskann_utils::{Reborrow, ReborrowMut};
77///
78/// // Create a new heap-allocated CompensatedVector for 4-bit compressions capable of
79/// // holding 3 elements.
80/// let mut v = CompensatedVector::<4>::new_boxed(3);
81///
82/// // We can inspect the underlying bitslice.
83/// let bitslice = v.vector();
84/// assert_eq!(bitslice.get(0).unwrap(), 0);
85/// assert_eq!(bitslice.get(1).unwrap(), 0);
86/// assert_eq!(v.meta().0, 0.0, "expected default compensation value");
87///
88/// // If we want, we can mutably borrow the bitslice and mutate its components.
89/// let mut bitslice = v.vector_mut();
90/// bitslice.set(0, 1).unwrap();
91/// bitslice.set(1, 2).unwrap();
92/// bitslice.set(2, 3).unwrap();
93///
94/// assert!(bitslice.set(3, 4).is_err(), "out-of-bounds access");
95///
96/// // Get the underlying pointer for comparision.
97/// let ptr = bitslice.as_ptr();
98///
99/// // Vectors can be converted to a generalized reference.
100/// let mut v_ref = v.reborrow_mut();
101///
102/// // The generalized reference preserves the underlying pointer.
103/// assert_eq!(v_ref.vector().as_ptr(), ptr);
104/// let mut bitslice = v_ref.vector_mut();
105/// bitslice.set(0, 10).unwrap();
106///
107/// // Setting the underlying compensation will be visible in the original allocation.
108/// v_ref.set_meta(scalar::Compensation(1.0));
109///
110/// // Check that the changes are visible.
111/// assert_eq!(v.meta().0, 1.0);
112/// assert_eq!(v.vector().get(0).unwrap(), 10);
113///
114/// // Finally, the immutable ref also maintains pointer compatibility.
115/// let v_ref = v.reborrow();
116/// assert_eq!(v_ref.vector().as_ptr(), ptr);
117/// ```
118///
119/// ## Constructing a `MutCompensatedVectorRef` From Components
120///
121/// The following example shows how to assemble a `MutCompensatedVectorRef` from raw memory.
122/// ```
123/// use diskann_quantization::{
124///     bits::{Unsigned, MutBitSlice},
125///     scalar::{self, MutCompensatedVectorRef}
126/// };
127///
128/// // Start with 2 bytes of memory. We will impose a 4-bit scalar quantization on top of
129/// // these 4 bytes.
130/// let mut data = vec![0u8; 2];
131/// let mut compensation = scalar::Compensation(0.0);
132/// {
133///     // First, we need to construct a bit-slice over the data.
134///     // This will check that it is sized properly for 4, 4-bit values.
135///     let mut slice = MutBitSlice::<4, Unsigned>::new(data.as_mut_slice(), 4).unwrap();
136///
137///     // Next, we construct the `MutCompensatedVectorRef`.
138///     let mut v = MutCompensatedVectorRef::new(slice, &mut compensation);
139///
140///     // Through `v`, we can set all the components in `slice` and the compensation.
141///     v.set_meta(scalar::Compensation(1.0));
142///     let mut from_v = v.vector_mut();
143///     from_v.set(0, 1).unwrap();
144///     from_v.set(1, 2).unwrap();
145///     from_v.set(2, 3).unwrap();
146///     from_v.set(3, 4).unwrap();
147/// }
148///
149/// // Now we can check that the changes made internally are visible.
150/// assert_eq!(&data, &[0x21, 0x43]);
151/// assert_eq!(compensation.0, 1.0);
152/// ```
153#[derive(Default, Debug, Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
154#[repr(transparent)]
155pub struct Compensation(pub f32);
156
157/// A borrowed `ComptensatedVector`.
158///
159/// See: [`meta::Vector`].
160pub type CompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
161    meta::VectorRef<'a, NBITS, Unsigned, Compensation, Perm>;
162
163/// A mutably borrowed `ComptensatedVector`.
164///
165/// See: [`meta::Vector`].
166pub type MutCompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
167    meta::VectorMut<'a, NBITS, Unsigned, Compensation, Perm>;
168
169/// An owning `CompensatedVector`.
170///
171/// See: [`meta::Vector`].
172pub type CompensatedVector<const NBITS: usize, Perm = Dense> =
173    meta::Vector<NBITS, Unsigned, Compensation, Perm>;
174
175////////////////////////////
176// Compensated Squared L2 //
177////////////////////////////
178
179/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
180/// distance function over `CompensatedVectors` belonging to the same quantization space.
181#[derive(Debug, Clone, Copy)]
182pub struct CompensatedSquaredL2 {
183    pub(super) scale_squared: f32,
184}
185
186impl CompensatedSquaredL2 {
187    /// Construct a new `CompensatedSquaredL2` with the given scaling factor.
188    pub fn new(scale_squared: f32) -> Self {
189        Self { scale_squared }
190    }
191}
192
193/// Compute the squared euclidean distance between the two compensated vectors.
194///
195/// The value returned by this function is scaled properly, meaning that distances returned
196/// by this method are compatible with full-precision distances.
197///
198/// # Validity
199///
200/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
201/// the same quantizer.
202///
203/// # Panics
204///
205/// Panics if `x.len() != y.len()`.
206impl<const NBITS: usize>
207    DistanceFunction<
208        CompensatedVectorRef<'_, NBITS>,
209        CompensatedVectorRef<'_, NBITS>,
210        distances::MathematicalResult<f32>,
211    > for CompensatedSquaredL2
212where
213    Unsigned: Representation<NBITS>,
214    SquaredL2: for<'a, 'b> PureDistanceFunction<
215            BitSlice<'a, NBITS, Unsigned>,
216            BitSlice<'b, NBITS, Unsigned>,
217            distances::MathematicalResult<u32>,
218        >,
219{
220    fn evaluate_similarity(
221        &self,
222        x: CompensatedVectorRef<'_, NBITS>,
223        y: CompensatedVectorRef<'_, NBITS>,
224    ) -> distances::MathematicalResult<f32> {
225        check_lengths!(x, y)?;
226        let squared_l2: distances::MathematicalResult<u32> =
227            SquaredL2::evaluate(x.vector(), y.vector());
228        let squared_l2 = squared_l2?.into_inner() as f32;
229
230        // This should constant-propagate.
231        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();
232
233        let result = bit_scale * self.scale_squared * squared_l2;
234        Ok(MV::new(result))
235    }
236}
237
238/// Compute the squared euclidean distance between the two compensated vectors.
239///
240/// The value returned by this function is scaled properly, meaning that distances returned
241/// by this method are compatible with full-precision distances.
242///
243/// # Validity
244///
245/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
246/// the same quantizer.
247///
248/// # Panics
249///
250/// Panics if `x.len() != y.len()`.
251impl<const NBITS: usize>
252    DistanceFunction<
253        CompensatedVectorRef<'_, NBITS>,
254        CompensatedVectorRef<'_, NBITS>,
255        distances::Result<f32>,
256    > for CompensatedSquaredL2
257where
258    Unsigned: Representation<NBITS>,
259    Self: for<'a, 'b> DistanceFunction<
260            CompensatedVectorRef<'a, NBITS>,
261            CompensatedVectorRef<'b, NBITS>,
262            distances::MathematicalResult<f32>,
263        >,
264{
265    fn evaluate_similarity(
266        &self,
267        x: CompensatedVectorRef<'_, NBITS>,
268        y: CompensatedVectorRef<'_, NBITS>,
269    ) -> distances::Result<f32> {
270        let v: MV<f32> = self.evaluate_similarity(x, y)?;
271        Ok(v.into_inner())
272    }
273}
274
275////////////////////
276// Compensated IP //
277////////////////////
278
279/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
280/// distance function over `CompensatedVectors` belonging to the same quantization space.
281#[derive(Debug, Clone, Copy)]
282pub struct CompensatedIP {
283    pub(super) scale_squared: f32,
284    pub(super) shift_square_norm: f32,
285}
286
287impl CompensatedIP {
288    /// Construct a new `CompensatedIP` with the given scaling factor and shift norm.
289    pub fn new(scale_squared: f32, shift_square_norm: f32) -> Self {
290        Self {
291            scale_squared,
292            shift_square_norm,
293        }
294    }
295}
296
297/// Compute the inner product between the two compensated vectors.
298///
299/// The value returned by this function is scaled properly, meaning that distances returned
300/// by this method are compatible with full-precision computations.
301///
302/// # Validity
303///
304/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
305/// the same quantizer.
306///
307/// # Panics
308///
309/// Panics if `x.len() != y.len()`.
310impl<const NBITS: usize>
311    DistanceFunction<
312        CompensatedVectorRef<'_, NBITS>,
313        CompensatedVectorRef<'_, NBITS>,
314        distances::MathematicalResult<f32>,
315    > for CompensatedIP
316where
317    Unsigned: Representation<NBITS>,
318    InnerProduct: for<'a, 'b> PureDistanceFunction<
319            BitSlice<'a, NBITS, Unsigned>,
320            BitSlice<'b, NBITS, Unsigned>,
321            distances::MathematicalResult<u32>,
322        >,
323{
324    fn evaluate_similarity(
325        &self,
326        x: CompensatedVectorRef<'_, NBITS>,
327        y: CompensatedVectorRef<'_, NBITS>,
328    ) -> distances::MathematicalResult<f32> {
329        let product: MV<u32> = InnerProduct::evaluate(x.vector(), y.vector())?;
330
331        // This should constant-propagate.
332        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();
333
334        let result = (bit_scale * self.scale_squared)
335            .mul_add(product.into_inner() as f32, self.shift_square_norm)
336            + (y.meta().0 + x.meta().0);
337        Ok(MV::new(result))
338    }
339}
340
341/// Compute the inner product between the two compensated vectors.
342///
343/// The value returned by this function is scaled properly, meaning that distances returned
344/// by this method are compatible with full-precision computations.
345///
346/// # Validity
347///
348/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
349/// the same quantizer.
350///
351/// # Panics
352///
353/// Panics if `x.len() != y.len()`.
354impl<const NBITS: usize>
355    DistanceFunction<
356        CompensatedVectorRef<'_, NBITS>,
357        CompensatedVectorRef<'_, NBITS>,
358        distances::Result<f32>,
359    > for CompensatedIP
360where
361    Unsigned: Representation<NBITS>,
362    Self: for<'a, 'b> DistanceFunction<
363            CompensatedVectorRef<'a, NBITS>,
364            CompensatedVectorRef<'b, NBITS>,
365            distances::MathematicalResult<f32>,
366        >,
367{
368    fn evaluate_similarity(
369        &self,
370        x: CompensatedVectorRef<'_, NBITS>,
371        y: CompensatedVectorRef<'_, NBITS>,
372    ) -> distances::Result<f32> {
373        let v: MV<f32> = self.evaluate_similarity(x, y)?;
374        Ok(-v.into_inner())
375    }
376}
377
378/// Compensated CosineNormalized distance function.
379#[derive(Debug, Clone, Copy)]
380pub struct CompensatedCosineNormalized {
381    pub(super) scale_squared: f32,
382}
383
384impl CompensatedCosineNormalized {
385    pub fn new(scale_squared: f32) -> Self {
386        Self { scale_squared }
387    }
388}
389
390/// CosineNormalized
391///
392/// This implementation calculates the <x, y> = 1 - L2 / 2 value, which will be further used
393/// to compute the CosineNormalised distance function
394///
395/// # Notes
396///
397/// s = 1 - cosine(X, Y) = 1- <X, Y> / (||X|| * ||Y||)
398///
399/// We can make simply assumption that ||X|| = 1 and ||Y|| = 1.
400/// Then s = 1 - <X, Y>
401///
402/// The squared L2 distance can be computed as follows:
403/// p = ||x||^2 + ||y||^2 - 2<x, y>
404/// When vectors are normalized, this becomes
405/// p = 2 - 2<x, y> = 2 * (1 - <x, y>)
406///
407/// In other words, the similarity score for the squared L2 distance in an ideal world is
408/// 2 times that for cosine similarity. Therefore, squared L2 may serves as a stand-in for
409/// cosine normalized as ordering is preserved.
410impl<const NBITS: usize>
411    DistanceFunction<
412        CompensatedVectorRef<'_, NBITS>,
413        CompensatedVectorRef<'_, NBITS>,
414        distances::MathematicalResult<f32>,
415    > for CompensatedCosineNormalized
416where
417    Unsigned: Representation<NBITS>,
418    SquaredL2: for<'a, 'b> PureDistanceFunction<
419            BitSlice<'a, NBITS, Unsigned>,
420            BitSlice<'b, NBITS, Unsigned>,
421            distances::MathematicalResult<u32>,
422        >,
423{
424    fn evaluate_similarity(
425        &self,
426        x: CompensatedVectorRef<'_, NBITS>,
427        y: CompensatedVectorRef<'_, NBITS>,
428    ) -> distances::MathematicalResult<f32> {
429        let squared_l2: MV<u32> = SquaredL2::evaluate(x.vector(), y.vector())?;
430
431        // This should constant-propagate.
432        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();
433
434        let l2 = bit_scale * self.scale_squared * squared_l2.into_inner() as f32;
435
436        let result = 1.0 - l2 / 2.0;
437        Ok(MV::new(result))
438    }
439}
440
441impl<const NBITS: usize>
442    DistanceFunction<
443        CompensatedVectorRef<'_, NBITS>,
444        CompensatedVectorRef<'_, NBITS>,
445        distances::Result<f32>,
446    > for CompensatedCosineNormalized
447where
448    Unsigned: Representation<NBITS>,
449    Self: for<'a, 'b> DistanceFunction<
450            CompensatedVectorRef<'a, NBITS>,
451            CompensatedVectorRef<'b, NBITS>,
452            distances::MathematicalResult<f32>,
453        >,
454{
455    fn evaluate_similarity(
456        &self,
457        x: CompensatedVectorRef<'_, NBITS>,
458        y: CompensatedVectorRef<'_, NBITS>,
459    ) -> distances::Result<f32> {
460        let v: MV<f32> = self.evaluate_similarity(x, y)?;
461        Ok(1.0 - v.into_inner())
462    }
463}
464
465///////////
466// Tests //
467///////////
468
#[cfg(test)]
mod tests {
    use diskann_utils::{Reborrow, ReborrowMut};
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        bits::{Representation, Unsigned},
        scalar::bit_scale,
        test_util,
    };

    ///////////////
    // Distances //
    ///////////////

    /// This test works as follows:
    ///
    /// First, generate random values for `a`, `X'` and `B` where:
    ///
    /// * `a`: The scaling parameter.
    /// * `X'`: The integer compressed codes for a vector.
    /// * `B`: The floating point vector representing the dataset center.
    ///
    /// Next, compute the reconstructed vector using `X = a * X' + B` (where the codes are
    /// additionally rescaled by `inverse_bit_scale::<NBITS>()`).
    /// Repeat this process for another vector `Y` using the same `a` and `B`.
    ///
    /// Then, the result of a distance computation can be done on the compressed
    /// representation and on the reconstructed representation. The results should match
    /// (modulo floating-point rounding).
    ///
    /// To get a handle on floating point issues, we pick "nice" numbers for the values of
    /// `a` and each component of `B` that are either small integers, or nice binary fractions
    /// like 1/2 or 3/4.
    ///
    /// Even with nice numbers, there is still a small amount of rounding instability.
    ///
    /// All tolerances are inclusive: a comparison passes if the relative error is at most
    /// the corresponding `max_relative_err_*` or the absolute error is at most
    /// `max_absolute_error`.
    fn test_compensated_distance<const NBITS: usize, R>(
        dim: usize,
        ntrials: usize,
        max_relative_err_l2: f32,
        max_relative_err_ip: f32,
        max_relative_err_cos: f32,
        max_absolute_error: f32,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS>,
        R: Rng,
        CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::MathematicalResult<f32>,
            >,
        CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::Result<f32>,
            >,
        CompensatedIP: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::MathematicalResult<f32>,
            >,
        CompensatedIP: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::Result<f32>,
            >,
        CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::MathematicalResult<f32>,
            >,
        CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
                CompensatedVectorRef<'a, NBITS>,
                CompensatedVectorRef<'b, NBITS>,
                distances::Result<f32>,
            >,
    {
        // The distributions we use for `a` and `B` are taken from integer distributions,
        // which we then convert to `f32` and divide by a power of 2.
        //
        // This helps keep computations exact so we don't also have to worry about tracking
        // floating rounding issues.
        //
        // Here, `alpha` refers to `a` in the function docstring and `beta` refers to `B`.
        let alpha_distribution = Uniform::new_inclusive(-16, 16).unwrap();
        let beta_distribution = Uniform::new_inclusive(-32, 32).unwrap();

        // The divisors applied to the values drawn from the alpha and beta distributions.
        let alpha_divisor: f32 = 64.0;
        let beta_divisor: f32 = 128.0;

        let domain = Unsigned::domain_const::<NBITS>();
        let code_distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        // Preallocate buffers.
        let mut beta: Vec<f32> = vec![0.0; dim];
        let mut x_prime: Vec<u8> = vec![0; dim];
        let mut y_prime: Vec<u8> = vec![0; dim];
        let mut x_reconstructed: Vec<f32> = vec![0.0; dim];
        let mut y_reconstructed: Vec<f32> = vec![0.0; dim];

        let mut x_compensated = CompensatedVector::<NBITS>::new_boxed(dim);
        let mut y_compensated = CompensatedVector::<NBITS>::new_boxed(dim);

        // Populate a compensated vector from the codes and `beta`.
        let populate_compensation = |mut dst: MutCompensatedVectorRef<'_, NBITS>,
                                     codes: &[u8],
                                     alpha: f32,
                                     beta: &[f32]| {
            assert_eq!(dst.len(), codes.len());
            assert_eq!(dst.len(), beta.len());

            let mut compensation: f32 = 0.0;
            let mut vector = dst.vector_mut();
            for (i, (&c, &b)) in std::iter::zip(codes.iter(), beta.iter()).enumerate() {
                vector.set(i, c.into()).unwrap();

                let c: f32 = c.into();
                compensation += c * b;
            }
            dst.set_meta(Compensation(alpha * compensation / bit_scale::<NBITS>()));
        };

        for trial in 0..ntrials {
            // Generate the problem.
            let alpha = (alpha_distribution.sample(rng) as f32) / alpha_divisor;
            beta.iter_mut().for_each(|b| {
                *b = (beta_distribution.sample(rng) as f32) / beta_divisor;
            });
            x_prime
                .iter_mut()
                .for_each(|x| *x = code_distribution.sample(rng).try_into().unwrap());
            y_prime
                .iter_mut()
                .for_each(|y| *y = code_distribution.sample(rng).try_into().unwrap());

            // Generate the reconstructed vectors.
            let bit_scale = inverse_bit_scale::<NBITS>();
            x_reconstructed
                .iter_mut()
                .zip(x_prime.iter())
                .zip(beta.iter())
                .for_each(|((x, xp), b)| {
                    *x = (alpha * *xp as f32) * bit_scale + *b;
                });

            y_reconstructed
                .iter_mut()
                .zip(y_prime.iter())
                .zip(beta.iter())
                .for_each(|((y, yp), b)| {
                    *y = (alpha * *yp as f32) * bit_scale + *b;
                });

            populate_compensation(x_compensated.reborrow_mut(), &x_prime, alpha, &beta);
            populate_compensation(y_compensated.reborrow_mut(), &y_prime, alpha, &beta);

            // Squared L2
            let expected: MV<f32> =
                diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);

            let distance = CompensatedSquaredL2::new(alpha * alpha);
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            let relative_err =
                test_util::compute_relative_error(got.into_inner(), expected.into_inner());
            let absolute_err =
                test_util::compute_absolute_error(got.into_inner(), expected.into_inner());

            assert!(
                relative_err <= max_relative_err_l2 || absolute_err <= max_absolute_error,
                "failed SquaredL2 for NBITS = {}, dim = {}, trial = {}. \
                 Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                 Expected {}, got {}",
                NBITS,
                dim,
                trial,
                relative_err,
                absolute_err,
                max_relative_err_l2,
                max_absolute_error,
                expected.into_inner(),
                got.into_inner(),
            );

            // f32 should match the MathematicalValue.
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();
            assert_eq!(got.into_inner(), got_f32);

            // Inner Product
            let expected: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                &*x_reconstructed,
                &*y_reconstructed,
            );

            let distance =
                CompensatedIP::new(alpha * alpha, beta.iter().map(|&i| i * i).sum::<f32>());
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            let relative_err =
                test_util::compute_relative_error(got.into_inner(), expected.into_inner());
            let absolute_err =
                test_util::compute_absolute_error(got.into_inner(), expected.into_inner());

            assert!(
                relative_err <= max_relative_err_ip || absolute_err <= max_absolute_error,
                "failed InnerProduct for NBITS = {}, dim = {}, trial = {}. \
                 Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                 Expected {}, got {}",
                NBITS,
                dim,
                trial,
                relative_err,
                absolute_err,
                max_relative_err_ip,
                max_absolute_error,
                expected.into_inner(),
                got.into_inner(),
            );

            // f32 should be the negative of the MathematicalValue.
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();

            assert_eq!(-got.into_inner(), got_f32);

            // CosineNormalized:
            // The expected value is `1 - ||X - Y||^2 / 2` on the reconstructed vectors, which
            // equals the cosine similarity when both vectors have unit norm.
            let expected: MV<f32> =
                diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);
            let expected = 1.0 - expected.into_inner() / 2.0;

            let distance = CompensatedCosineNormalized::new(alpha * alpha);
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            if expected != 0.0 {
                let relative_err = test_util::compute_relative_error(got.into_inner(), expected);
                let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
                assert!(
                    relative_err <= max_relative_err_cos || absolute_err <= max_absolute_error,
                    "failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
                     Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                     Expected {}, got {}",
                    NBITS,
                    dim,
                    trial,
                    relative_err,
                    absolute_err,
                    max_relative_err_cos,
                    max_absolute_error,
                    expected,
                    got.into_inner(),
                );
            } else {
                // Relative error is undefined for a zero expected value; use absolute error.
                let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
                assert!(
                    absolute_err <= max_absolute_error,
                    "failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
                    Got an absolute error {} with tolerance {}. \
                    Expected {}, got {}",
                    NBITS,
                    dim,
                    trial,
                    absolute_err,
                    max_absolute_error,
                    expected,
                    got.into_inner(),
                );
            }

            // f32 should be `1 - MathematicalValue`.
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();
            assert_eq!(1.0 - got.into_inner(), got_f32);
        }
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // The max dim does not need to be as high for `CompensatedVectors` because they
            // defer their distance function implementation to `BitSlice`, which is more
            // heavily tested.
            const MAX_DIM: usize = 37;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }

    macro_rules! test_unsigned_compensated {
        (
            $name:ident,
            $nbits:literal,
            $relative_err_l2:literal,
            $relative_err_ip:literal,
            $relative_err_cos:literal,
            $seed:literal
        ) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let absolute_error: f32 = 2.0e-7;
                for dim in 0..MAX_DIM {
                    test_compensated_distance::<$nbits, _>(
                        dim,
                        TRIALS_PER_DIM,
                        $relative_err_l2,
                        $relative_err_ip,
                        $relative_err_cos,
                        absolute_error,
                        &mut rng,
                    );
                }
            }
        };
    }

    test_unsigned_compensated!(
        unsigned_compensated_distances_8bit,
        8,
        4.0e-4,
        3.0e-6,
        1.0e-3,
        0xa32d5658097a1c35
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_7bit,
        7,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0x0b65ca44ec7b47d8
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_6bit,
        6,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0x471b640fba5c520b
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_5bit,
        5,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0xf60c0c8d1aadc126
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_4bit,
        4,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0xcc2b897373a143f3
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_3bit,
        3,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0xaedf3d2a223b7b77
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_2bit,
        2,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0x2b34015910b34083
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_1bit,
        1,
        0.0,
        0.0,
        0.0,
        0x09fa14c42a9d7d98
    );
}