diskann_quantization/spherical/
vectors.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! # Spherical Quantization Distance Functions
7//!
8//! ## Inner Product - 1-bit Symmetric
9//!
10//! Key:
11//! * `X'` (Upper case with "prime"): The original, full-precision vectors.
12//! * `C`: The dataset centroid.
13//! * `X` (Upper case, no "prime"): The result of `X' - C`. That is, the centered vectors.
14//! * `|X|`: The L2 norm of a vector.
//! * `x` (Lower case): The normalized version of the vector `X`, i.e. `x = X / |X|`.
//! * `x'`: The quantized reconstruction of `x`, computed as `x' = T(x!)` where
17//!
18//!   - `x!` is the binary encoded vector in `{-1/sqrt(dim), +1/sqrt(dim)}^dim`.
19//!   - `x -> T(x)` is the distance-preserving transformation.
20//!
21//! ```math
22//! <X', Y'> = <X + C, Y + C>
23//!          = <X, Y> + <X, C> + <Y, C> + <C, C>
24//!          = |X| |Y| <x, y> + <X, C> + <Y, C> + |C|^2
25//!                    ------
26//!                      |
27//!                 Normalized
28//!                 Components
29//! ```
30//!
31//! Now, working with the normalized components:
32//! ```math
33//! <x, y> \approx <x', y'> / (<x', x> * <y', y>)         [From the RabitQ Paper]
34//!                            -------   -------
35//!                               |         |
36//!                            Self Dot  Self Dot
37//!                            Product    Product
38//! ```
//! Here, `x'` and `y'` are the images under `T` of the binary-encoded vectors `x!` and `y!`,
//! which lie in the domain `{-1/sqrt(D), 1/sqrt(D)}^D`.
40//!
41//! This is the result from the RabitQ paper (though modified to work on two symmetrically
42//! compressed vectors).
43//!
44//! NOTE: The symmetric correction factor gives incorrect estimates for estimating the
45//! distance between a vector and itself because the term `<x', x>` is strictly less
46//! than one, bringing the estimate for the inner product `<x, x>` to a value greater than 1.
47//! In practice, this still yields better recall (both exhaustive and via graph build) than
48//! no correction, so we keep it.
49//!
//! Finally, to compute the inner product `<x', y'>` we use the following general approach:
//! ```math
//! <x', y'> = <a * (bx + b), c * (by + d)>
//!          = (a * c) ( <bx, by> + b*sum(by) + d*sum(bx) + D*b*d )
//!            -------   --------   - -------   - -------   -----
//!               |         |       |    |      |    |        |
//!            Scaling      |       | Bit Sum   | Bit Sum     |
//!             Terms       |       |           |             |
//!                         |       |        y offset         |
//!                    Bit Inner    |                     Constant
//!                     Product     |                       Term
//!                              x offset
//! ```
//!
//! Here `D` is the vector dimension. For spherically quantized data vectors, the offset is
//! `-(2^NBITS - 1) / 2`, which centers the unsigned codes around zero.
63//!
//! When the vectors `x` and `y` use the same scaling or offset terms, some of this
//! computation can be simplified. However, spherical quantization allows queries to use
//! a different compression (i.e., scalar quantization), so the expression above reflects
//! the general strategy.
68//!
//! Thus, for each vector `X`, we need the following compensation values:
//!
//! 1. `|X| * a / <x', x>`: The norm of the centered vector `X = X' - C`, multiplied by the
//!    quantization scaling parameter `a` and divided by the correction term `<x', x>`.
//!    The bit-level inner product is multiplied by this whole expression to obtain the
//!    full-norm estimate of the shifted inner product.
//!
//! 2. `<X, C>`: The inner product between the shifted vector and the centroid.
//!
//! 3. `sum(bx)`: The sum of the bits in the binary vector representation of `x'`.
//!
//! 4. `|X|`: The norm of the shifted vector, used to compute L2 distances.
81//!
82//! ## Squared L2 - 1-bit Symmetric
83//!
//! ```math
//! |X' - Y'|^2 = | (X' - C) - (Y' - C) |^2
//!             = | X - Y |^2
//!             = |X|^2 + |Y|^2 - 2 <X, Y>
//!             = |X|^2 + |Y|^2 - 2 |X| |Y| <x, y>
//!                                         ------
//!                                           |
//!                                Reuse from Inner Product
//! ```
//!
//! The compensation terms used here are the same as those used for the inner product.
95//!
96//! # Full Precision Queries
97//!
98//! When the vector `Y` is full-precision, the expression for the inner product becomes
99//! ```math
100//! <a(X + b), Y> = a (<X, Y> + b * sum(Y))
101//! ```
102//!
103//! # Dev Notes
104//!
//! The functions implemented here use the [`diskann_wide::arch::Target2`] interface to
//! propagate micro-architecture details from the caller.
//!
//! When calling implementations in [`crate::bits::distances`], be sure to use
//! [`diskann_wide::Architecture::run2`] to invoke the distance functions. This ensures that
//! architecture-specific
//! [target features](https://rust-lang.github.io/rfcs/2045-target-feature.html) are
//! inherited properly, even if these functions are not inlined.
113
114use diskann_utils::{Reborrow, ReborrowMut};
115use diskann_vector::{norm::FastL2NormSquared, Norm};
116use diskann_wide::{arch::Target2, Architecture};
117use half::f16;
118use thiserror::Error;
119
120#[cfg(feature = "flatbuffers")]
121use crate::flatbuffers as fb;
122use crate::{
123    alloc::{AllocatorCore, AllocatorError, Poly},
124    bits::{BitSlice, Dense, PermutationStrategy, Representation, Unsigned},
125    distances::{self, InnerProduct, MV},
126    meta,
127};
128
129//////////////////////
130// Supported Metric //
131//////////////////////
132
/// The metrics that are supported by [`crate::spherical::SphericalQuantizer`].
///
/// Conversion from the general-purpose [`diskann_vector::distance::Metric`] is provided
/// through `TryFrom`; metrics without a counterpart here are rejected with
/// [`UnsupportedMetric`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedMetric {
    /// Squared Euclidean distance (converted from `Metric::L2`).
    SquaredL2,
    /// Inner product.
    InnerProduct,
    /// Cosine, computed by delegating to the inner product (see [`CompensatedCosine`]).
    Cosine,
}
140
#[cfg(test)]
impl SupportedMetric {
    /// Test helper: choose the per-vector correction term appropriate for this metric.
    ///
    /// * Inner product / cosine: `inner_product_with_centroid`, unchanged.
    /// * Squared L2: the square of `shifted_norm`.
    ///
    /// (Presumably mirrors what `DataMeta::metric_specific` is expected to hold.)
    fn pick(self, shifted_norm: f32, inner_product_with_centroid: f32) -> f32 {
        match self {
            Self::InnerProduct | Self::Cosine => inner_product_with_centroid,
            Self::SquaredL2 => shifted_norm * shifted_norm,
        }
    }

    /// Every supported metric, listed in declaration order.
    #[cfg(feature = "flatbuffers")]
    pub(super) fn all() -> [Self; 3] {
        [Self::SquaredL2, Self::InnerProduct, Self::Cosine]
    }
}
155
156impl TryFrom<diskann_vector::distance::Metric> for SupportedMetric {
157    type Error = UnsupportedMetric;
158    fn try_from(metric: diskann_vector::distance::Metric) -> Result<Self, Self::Error> {
159        use diskann_vector::distance::Metric;
160        match metric {
161            Metric::L2 => Ok(Self::SquaredL2),
162            Metric::InnerProduct => Ok(Self::InnerProduct),
163            Metric::Cosine => Ok(Self::Cosine),
164            unsupported => Err(UnsupportedMetric(unsupported)),
165        }
166    }
167}
168
169impl PartialEq<diskann_vector::distance::Metric> for SupportedMetric {
170    fn eq(&self, metric: &diskann_vector::distance::Metric) -> bool {
171        match Self::try_from(*metric) {
172            Ok(m) => *self == m,
173            Err(_) => false,
174        }
175    }
176}
177
/// Error returned when converting a [`diskann_vector::distance::Metric`] that has no
/// [`SupportedMetric`] counterpart. Carries the rejected metric.
#[derive(Debug, Clone, Copy, Error)]
#[error("metric {0:?} is not supported for spherical quantization")]
pub struct UnsupportedMetric(pub(crate) diskann_vector::distance::Metric);
181
/// Error returned when a serialized (flatbuffers) metric tag does not correspond to any
/// [`SupportedMetric`]. Carries the raw tag value.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, PartialEq, Error)]
#[error("the value {0} is not recognized as a supported metric")]
pub struct InvalidMetric(i8);
187
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl TryFrom<fb::spherical::SupportedMetric> for SupportedMetric {
    type Error = InvalidMetric;

    /// Decode a serialized metric tag, rejecting any tag this version does not recognize.
    fn try_from(value: fb::spherical::SupportedMetric) -> Result<Self, Self::Error> {
        use fb::spherical::SupportedMetric as Tag;

        let decoded = match value {
            Tag::SquaredL2 => Self::SquaredL2,
            Tag::InnerProduct => Self::InnerProduct,
            Tag::Cosine => Self::Cosine,
            unknown => return Err(InvalidMetric(unknown.0)),
        };
        Ok(decoded)
    }
}
201
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl From<SupportedMetric> for fb::spherical::SupportedMetric {
    /// Encode a [`SupportedMetric`] as its serialized tag. Every variant has a tag, so the
    /// conversion cannot fail.
    fn from(value: SupportedMetric) -> Self {
        match value {
            SupportedMetric::SquaredL2 => Self::SquaredL2,
            SupportedMetric::InnerProduct => Self::InnerProduct,
            SupportedMetric::Cosine => Self::Cosine,
        }
    }
}
213
214//////////
215// Data //
216//////////
217
/// Metadata for correcting quantization for computing distances among quant vectors.
///
/// The terms are stored compactly (`f16`/`u16`) next to each compressed data vector (see
/// [`DataRef`]); use [`DataMeta::to_full`] to widen them to `f32` before computing.
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMeta {
    /// This is the whole term
    /// ```math
    /// |X| * a / <x', x>
    /// ```
    /// and represents the entire correction factor for computing inner products on the
    /// representation
    /// ```math
    /// bx - b
    /// ```
    /// where `bx` is the unsigned binary encoding of the vector and `b` (obtained from
    /// `Self::offset_term`) is the compression offset that is subtracted from the codes.
    pub inner_product_correction: f16,

    /// A metric-specific correction term. Refer to the module level documentation to
    /// understand the implication of the terms outlined here.
    ///
    /// * Squared L2: `|X|^2`
    /// * Inner Product (and Cosine): `<X', C>`
    ///
    /// NOTE(review): the inner product distance functions add `|C|^2` separately, which
    /// matches the module-level term `<X, C>` (centered `X`) rather than `<X', C>`; confirm
    /// which of the two the quantizer actually stores.
    pub metric_specific: f16,

    /// The sum of the unsigned integer codes in the binary representation of the
    /// transformed unit vector (for 1-bit vectors, the number of set bits).
    ///
    /// This is the term `sum(bx)` in the module level documentation.
    pub bit_sum: u16,
}
248
/// Errors returned by [`DataMeta::new`] when a component cannot be represented in the
/// compact 16-bit storage format.
#[derive(Debug, Error, Clone, Copy, PartialEq)]
pub enum DataMetaError {
    /// The inner product correction is not finite after conversion to `f16`.
    /// `value` is the original `f32` input.
    #[error("inner product correction {value} cannot fit in a 16-bit floating point number")]
    InnerProductCorrection { value: f32 },

    /// The metric-specific term is not finite after conversion to `f16`.
    /// `value` is the original `f32` input.
    #[error("metric specific correction {value} cannot fit in a 16-bit floating point number")]
    MetricSpecific { value: f32 },

    /// The bit sum does not fit in a `u16`.
    #[error("bit sum {value} cannot fit in a 16-bit unsigned integer")]
    BitSum { value: u32 },
}
260
261impl DataMeta {
262    /// Construct a new metadata from components.
263    ///
264    /// This will internally convert the `f32` values to `f16`.
265    pub fn new(
266        inner_product_correction: f32,
267        metric_specific: f32,
268        bit_sum: u32,
269    ) -> Result<Self, DataMetaError> {
270        let inner_product_correction_f16 = diskann_wide::cast_f32_to_f16(inner_product_correction);
271        if !inner_product_correction_f16.is_finite() {
272            return Err(DataMetaError::InnerProductCorrection {
273                value: inner_product_correction,
274            });
275        }
276
277        let metric_specific_f16 = diskann_wide::cast_f32_to_f16(metric_specific);
278        if !metric_specific_f16.is_finite() {
279            return Err(DataMetaError::MetricSpecific {
280                value: metric_specific,
281            });
282        }
283
284        let bit_sum_u16: u16 = bit_sum
285            .try_into()
286            .map_err(|_| DataMetaError::BitSum { value: bit_sum })?;
287
288        Ok(Self {
289            inner_product_correction: inner_product_correction_f16,
290            metric_specific: metric_specific_f16,
291            bit_sum: bit_sum_u16,
292        })
293    }
294
295    /// Compute the term `b` for a binary compression of a vector so the reconstruction can
296    /// be expressed as
297    /// ```math
298    /// a (bx + b)
299    /// ```
300    /// where
301    ///
302    /// * `a` is the scaling term to achieve the correct dynamic range.
303    /// * `bx` is the unsigned binary encoded vector.
304    ///
305    /// This value is computed as
306    /// ```math
307    /// 2 ^ NBITS - 1
308    /// -------------
309    ///      2
310    /// ```
311    /// and ensures equal coverage above and below 0.
312    const fn offset_term<const NBITS: usize>() -> f32 {
313        ((2usize).pow(NBITS as u32) as f32 - 1.0) / 2.0
314    }
315
316    /// Convert the values in `self` to their full precision representation for computation.
317    #[inline(always)]
318    pub fn to_full<A>(self, arch: A) -> DataMetaF32
319    where
320        A: Architecture,
321    {
322        use diskann_wide::SIMDVector;
323
324        // Relying on `diskann_wide::cast_f16_to_f32` to correctly propagation `target_features`
325        // correction does not seem to completely work.
326        //
327        // We take matters into our own hand and use the architecture's conversion routines
328        // directly.
329        let pre = [
330            self.metric_specific,
331            self.inner_product_correction,
332            half::f16::default(),
333            half::f16::default(),
334            half::f16::default(),
335            half::f16::default(),
336            half::f16::default(),
337            half::f16::default(),
338        ];
339
340        let post: <A as Architecture>::f32x8 =
341            <A as Architecture>::f16x8::from_array(arch, pre).into();
342        let post = post.to_array();
343
344        DataMetaF32 {
345            metric_specific: post[0],
346            inner_product_correction: post[1],
347            bit_sum: self.bit_sum.into(),
348        }
349    }
350}
351
/// The [`DataMeta`] correction terms widened to `f32` (produced by [`DataMeta::to_full`]).
///
/// Each field has the same meaning as the identically named field of [`DataMeta`];
/// `bit_sum` is stored as `f32` so it can enter floating point arithmetic directly.
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMetaF32 {
    pub inner_product_correction: f32,
    pub metric_specific: f32,
    pub bit_sum: f32,
}
359
/// A borrowed spherically quantized data vector: `NBITS`-bit unsigned codes together with
/// their [`DataMeta`] correction terms.
pub type DataRef<'a, const NBITS: usize> = meta::VectorRef<'a, NBITS, Unsigned, DataMeta>;

/// A mutably borrowed spherically quantized data vector (see [`DataRef`]).
pub type DataMut<'a, const NBITS: usize> = meta::VectorMut<'a, NBITS, Unsigned, DataMeta>;

/// An owning spherically quantized data vector (see [`DataRef`] for the borrowed form).
pub type Data<const NBITS: usize, A> = meta::PolyVector<NBITS, Unsigned, DataMeta, Dense, A>;
368
369///////////
370// Query //
371///////////
372
/// Scalar quantization correction factors for computing distances between scalar quantized
/// queries and spherically quantized data elements.
///
/// Computing the distance between a query and a data vector uses the same formula derived
/// in the module level documentation.
///
/// The one difference is that the query must explicitly carry the "offset" term as it
/// cannot be derived from the number of bits used for the compression.
#[derive(Copy, Clone, Default, Debug, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct QueryMeta {
    /// The value with which to scale the bit-level inner product with the data vectors.
    pub inner_product_correction: f32,

    /// The query's counterpart to [`DataMeta::bit_sum`] (its `sum(bx)` term). The distance
    /// functions multiply it by the data vector's offset when correcting the bit-level
    /// inner product.
    pub bit_sum: f32,

    /// The query-specific offset, taking into account the scaling factor for the query as
    /// well as its minimum value. See the struct-level documentation for an explanation.
    ///
    /// Unlike the data vectors' offset (which is subtracted), this offset is *added* to the
    /// query codes: the distance functions treat the query as proportional to
    /// `bq + offset`.
    pub offset: f32,

    /// The corresponding metric specific term as [`DataMeta`].
    pub metric_specific: f32,
}
400
/// A scalar-quantized query vector carrying [`QueryMeta`] correction terms, used for
/// computing higher-precision inner products with data vectors.
///
/// `Perm` selects the bit layout of the codes (see [`PermutationStrategy`]).
pub type Query<const NBITS: usize, Perm, A> = meta::PolyVector<NBITS, Unsigned, QueryMeta, Perm, A>;

/// A reference-like version of `Query`.
pub type QueryRef<'a, const NBITS: usize, Perm> =
    meta::VectorRef<'a, NBITS, Unsigned, QueryMeta, Perm>;

/// A mutable reference-like version of `Query`.
pub type QueryMut<'a, const NBITS: usize, Perm> =
    meta::VectorMut<'a, NBITS, Unsigned, QueryMeta, Perm>;
411
412////////////////////
413// Full Precision //
414////////////////////
415
/// Correction terms carried alongside the `f32` data of a [`FullQuery`].
#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct FullQueryMeta {
    /// The sum of `data`. Used as `sum(x)` to remove the data vectors' offset from the
    /// full-precision inner product.
    pub sum: f32,
    /// The norm of the shifted vector.
    pub shifted_norm: f32,
    /// Metric specific correction term. See [`DataMeta`].
    ///
    /// The squared L2 distance function uses this directly as the squared norm of the
    /// shifted query.
    pub metric_specific: f32,
}
426
/// A full-precision query.
///
/// Unlike [`Query`], the data is kept as `f32`, so inner products with compressed data
/// vectors are computed against the unquantized (but centered/normalized/transformed)
/// query.
#[derive(Debug)]
pub struct FullQuery<A>
where
    A: AllocatorCore,
{
    /// The data after centering, normalization, and transformation.
    pub data: Poly<[f32], A>,
    /// Correction terms associated with `data`.
    pub meta: FullQueryMeta,
}
437
438impl<A> FullQuery<A>
439where
440    A: AllocatorCore,
441{
442    /// Construct an empty `FullQuery` for `dim` dimensional data.
443    pub fn empty(dim: usize, allocator: A) -> Result<Self, AllocatorError> {
444        Ok(Self {
445            data: Poly::broadcast(0.0f32, dim, allocator)?,
446            meta: Default::default(),
447        })
448    }
449}
450
/// A borrowed view of a [`FullQuery`]: its `f32` data together with its [`FullQueryMeta`].
pub type FullQueryRef<'a> = meta::slice::SliceRef<'a, f32, FullQueryMeta>;

/// A mutably borrowed view of a [`FullQuery`] (see [`FullQueryRef`]).
pub type FullQueryMut<'a> = meta::slice::SliceMut<'a, f32, FullQueryMeta>;
454
455impl<'short, A> Reborrow<'short> for FullQuery<A>
456where
457    A: AllocatorCore,
458{
459    type Target = FullQueryRef<'short>;
460    fn reborrow(&'short self) -> Self::Target {
461        FullQueryRef::new(&self.data, &self.meta)
462    }
463}
464
465impl<'short, A> ReborrowMut<'short> for FullQuery<A>
466where
467    A: AllocatorCore,
468{
469    type Target = FullQueryMut<'short>;
470    fn reborrow_mut(&'short mut self) -> Self::Target {
471        FullQueryMut::new(&mut self.data, &mut self.meta)
472    }
473}
474
475/////////////
476// Helpers //
477/////////////
478
479/// This is a workaround to the error `Can't use generic parameters from outer function.` by
480/// forcing constant evaluation of expressions involving offset terms.
481struct ConstOffset<const NBITS: usize>;
482
483impl<const NBITS: usize> ConstOffset<NBITS> {
484    const OFFSET: f32 = DataMeta::offset_term::<NBITS>();
485    const OFFSET_SQUARED: f32 = DataMeta::offset_term::<NBITS>() * DataMeta::offset_term::<NBITS>();
486}
487
488/// This represents the computation
489/// ```math
490/// |X'| |Y'| <x, y>
491/// ```
492/// from the module-level docstring.
493#[inline(always)]
494fn kernel<A, const NBITS: usize>(
495    arch: A,
496    x: DataRef<'_, NBITS>,
497    y: DataRef<'_, NBITS>,
498    dim: f32,
499) -> distances::Result<f32>
500where
501    A: Architecture,
502    Unsigned: Representation<NBITS>,
503    InnerProduct: for<'a> Target2<
504        A,
505        distances::MathematicalResult<u32>,
506        BitSlice<'a, NBITS, Unsigned>,
507        BitSlice<'a, NBITS, Unsigned>,
508    >,
509{
510    // NOTE: `Target2<_, _, _, _>` is used instead of `Architecture::run2` to ensure that
511    // the kernel is inlined into this callsize.
512    //
513    // Even using `Architecture::run2_inline` is not sufficient to guarantee inlining.
514    let ip: distances::MathematicalResult<u32> =
515        <_ as Target2<_, _, _, _>>::run(InnerProduct, arch, x.vector(), y.vector());
516
517    let ip = ip?.into_inner() as f32;
518
519    let offset = ConstOffset::<NBITS>::OFFSET;
520    let offset_squared = ConstOffset::<NBITS>::OFFSET_SQUARED;
521
522    let xc = x.meta().to_full(arch);
523    let yc = y.meta().to_full(arch);
524
525    Ok(xc.inner_product_correction
526        * yc.inner_product_correction
527        * (ip - offset * (xc.bit_sum + yc.bit_sum) + offset_squared * dim))
528}
529
530////////////////////////////
531// Compensated Squared L2 //
532////////////////////////////
533
534/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
535/// distance function over `CompensatedVectors` belonging to the same quantization space.
536#[derive(Debug, Clone, Copy)]
537pub struct CompensatedSquaredL2 {
538    pub(super) dim: f32,
539}
540
541impl CompensatedSquaredL2 {
542    /// Construct a new `CompensatedSquaredL2` with the given scaling factor.
543    pub fn new(dim: usize) -> Self {
544        Self { dim: dim as f32 }
545    }
546}
547
548/// A blanket implementation for applying the identity transformation from
549/// `MathematicalValue` to `f32` for Euclidean distance computations.
550impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedSquaredL2
551where
552    A: Architecture,
553    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
554{
555    #[inline(always)]
556    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
557        self.run(arch, x, y).map(|r| r.into_inner())
558    }
559}
560
561/// Compute the squared euclidean distance between the two compensated vectors.
562///
563/// The value returned by this function is scaled properly, meaning that distances returned
564/// by this method are compatible with full-precision distances.
565///
566/// # Validity
567///
568/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
569/// the same quantizer.
570impl<A, const NBITS: usize>
571    Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
572    for CompensatedSquaredL2
573where
574    A: Architecture,
575    Unsigned: Representation<NBITS>,
576    InnerProduct: for<'a> Target2<
577        A,
578        distances::MathematicalResult<u32>,
579        BitSlice<'a, NBITS, Unsigned>,
580        BitSlice<'a, NBITS, Unsigned>,
581    >,
582{
583    #[inline(always)]
584    fn run(
585        self,
586        arch: A,
587        x: DataRef<'_, NBITS>,
588        y: DataRef<'_, NBITS>,
589    ) -> distances::MathematicalResult<f32> {
590        let xc = x.meta().to_full(arch);
591        let yc = y.meta().to_full(arch);
592        let result = xc.metric_specific + yc.metric_specific - 2.0 * kernel(arch, x, y, self.dim)?;
593        Ok(MV::new(result))
594    }
595}
596
597impl<A, const Q: usize, const D: usize, Perm>
598    Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
599    for CompensatedSquaredL2
600where
601    A: Architecture,
602    Unsigned: Representation<Q>,
603    Unsigned: Representation<D>,
604    Perm: PermutationStrategy<Q>,
605    for<'a> InnerProduct: Target2<
606        A,
607        distances::MathematicalResult<u32>,
608        BitSlice<'a, Q, Unsigned, Perm>,
609        BitSlice<'a, D, Unsigned>,
610    >,
611{
612    #[inline(always)]
613    fn run(
614        self,
615        arch: A,
616        x: QueryRef<'_, Q, Perm>,
617        y: DataRef<'_, D>,
618    ) -> distances::MathematicalResult<f32> {
619        let ip: distances::MathematicalResult<u32> =
620            arch.run2_inline(InnerProduct, x.vector(), y.vector());
621        let ip = ip?.into_inner() as f32;
622
623        let yc = y.meta().to_full(arch);
624        let xc = x.meta();
625
626        let y_offset: f32 = DataMeta::offset_term::<D>();
627
628        let corrected_ip = yc.inner_product_correction
629            * xc.inner_product_correction
630            * (ip - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
631                - y_offset * xc.offset * self.dim);
632
633        Ok(MV::new(
634            yc.metric_specific + xc.metric_specific - 2.0 * corrected_ip,
635        ))
636    }
637}
638
639/// Compute the inner product between a full-precision query and a spherically quantized
640/// data vector.
641///
642/// Returns an error if the arguments have different lengths.
643impl<A, const NBITS: usize>
644    Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
645    for CompensatedSquaredL2
646where
647    A: Architecture,
648    Unsigned: Representation<NBITS>,
649    InnerProduct: for<'a> Target2<
650        A,
651        distances::MathematicalResult<f32>,
652        &'a [f32],
653        BitSlice<'a, NBITS, Unsigned>,
654    >,
655{
656    #[inline(always)]
657    fn run(
658        self,
659        arch: A,
660        x: FullQueryRef<'_>,
661        y: DataRef<'_, NBITS>,
662    ) -> distances::MathematicalResult<f32> {
663        let s = arch
664            .run2(InnerProduct, x.vector(), y.vector())?
665            .into_inner();
666
667        let xc = x.meta();
668        let yc = y.meta().to_full(arch);
669
670        let offset = ConstOffset::<NBITS>::OFFSET;
671        let ip = s - xc.sum * offset;
672
673        // NOTE: `xc.metric_specific` already carries the square norm, so we can save
674        // a multiple by using it directly.
675        let r = xc.metric_specific + yc.metric_specific
676            - 2.0 * xc.shifted_norm * yc.inner_product_correction * ip;
677        Ok(MV::new(r))
678    }
679}
680
681////////////////////
682// Compensated IP //
683////////////////////
684
685/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
686/// distance function over `CompensatedVectors` belonging to the same quantization space.
687#[derive(Debug, Clone, Copy)]
688pub struct CompensatedIP {
689    pub(super) squared_shift_norm: f32,
690    pub(super) dim: f32,
691}
692
693impl CompensatedIP {
694    /// Construct a new `CompensatedIP` with the given scaling factor and shift norm.
695    pub fn new(shift: &[f32], dim: usize) -> Self {
696        Self {
697            squared_shift_norm: FastL2NormSquared.evaluate(shift),
698            dim: dim as f32,
699        }
700    }
701}
702
703/// A blanket implementation for applying the negating transformation
704/// ```text
705/// x -> -x
706/// ```
707/// from `MathematicalValue` to `f32` for inner product distance computations.
708impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedIP
709where
710    A: Architecture,
711    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
712{
713    #[inline(always)]
714    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
715        arch.run2(self, x, y).map(|r| -r.into_inner())
716    }
717}
718
719/// Compute the inner product between the two compensated vectors.
720///
721/// Returns an error if the arguments have different lengths.
722///
723/// The value returned by this function is scaled properly, meaning that distances returned
724/// by this method are compatible with full-precision computations.
725///
726/// # Validity
727///
728/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
729/// the same quantizer.
730impl<A, const NBITS: usize>
731    Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
732    for CompensatedIP
733where
734    A: Architecture,
735    Unsigned: Representation<NBITS>,
736    InnerProduct: for<'a> Target2<
737        A,
738        distances::MathematicalResult<u32>,
739        BitSlice<'a, NBITS, Unsigned>,
740        BitSlice<'a, NBITS, Unsigned>,
741    >,
742{
743    #[inline(always)]
744    fn run(
745        self,
746        arch: A,
747        x: DataRef<'_, NBITS>,
748        y: DataRef<'_, NBITS>,
749    ) -> distances::MathematicalResult<f32> {
750        let xc = x.meta().to_full(arch);
751        let yc = y.meta().to_full(arch);
752
753        let result = xc.metric_specific
754            + yc.metric_specific
755            + kernel(arch, x, y, self.dim)?
756            + self.squared_shift_norm;
757        Ok(MV::new(result))
758    }
759}
760
761impl<A, const Q: usize, const D: usize, Perm>
762    Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
763    for CompensatedIP
764where
765    A: Architecture,
766    Unsigned: Representation<Q>,
767    Unsigned: Representation<D>,
768    Perm: PermutationStrategy<Q>,
769    for<'a> InnerProduct: Target2<
770        A,
771        distances::MathematicalResult<u32>,
772        BitSlice<'a, Q, Unsigned, Perm>,
773        BitSlice<'a, D, Unsigned>,
774    >,
775{
776    #[inline(always)]
777    fn run(
778        self,
779        arch: A,
780        x: QueryRef<'_, Q, Perm>,
781        y: DataRef<'_, D>,
782    ) -> distances::MathematicalResult<f32> {
783        // The inner product of the bit-level data.
784        let ip: MV<u32> = arch.run2_inline(InnerProduct, x.vector(), y.vector())?;
785
786        let yc = y.meta().to_full(arch);
787        let xc = x.meta();
788
789        // Rely on constant propagation to pre-compute these terms.
790        let y_offset: f32 = DataMeta::offset_term::<D>();
791
792        let corrected_ip = xc.inner_product_correction
793            * yc.inner_product_correction
794            * (ip.into_inner() as f32 - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
795                - y_offset * xc.offset * self.dim);
796
797        // Finally, reassemble the remaining compensation terms.
798        Ok(MV::new(
799            corrected_ip + yc.metric_specific + xc.metric_specific + self.squared_shift_norm,
800        ))
801    }
802}
803
804/// Compute the inner product between a full-precision query and a spherically quantized
805/// data vector.
806///
807/// Returns an error if the arguments have different lengths.
808impl<A, const NBITS: usize>
809    Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
810    for CompensatedIP
811where
812    A: Architecture,
813    Unsigned: Representation<NBITS>,
814    InnerProduct: for<'a> Target2<
815        A,
816        distances::MathematicalResult<f32>,
817        &'a [f32],
818        BitSlice<'a, NBITS, Unsigned>,
819    >,
820{
821    #[inline(always)]
822    fn run(
823        self,
824        arch: A,
825        x: FullQueryRef<'_>,
826        y: DataRef<'_, NBITS>,
827    ) -> distances::MathematicalResult<f32> {
828        let s = arch
829            .run2(InnerProduct, x.vector(), y.vector())?
830            .into_inner();
831
832        let yc = y.meta().to_full(arch);
833        let xc = x.meta();
834
835        let offset = ConstOffset::<NBITS>::OFFSET;
836        let ip = xc.shifted_norm * yc.inner_product_correction * (s - xc.sum * offset);
837
838        Ok(MV::new(
839            ip + xc.metric_specific + yc.metric_specific + self.squared_shift_norm,
840        ))
841    }
842}
843
844////////////////////////
845// Compensated Cosine //
846////////////////////////
847
848/// A `DistanceFunction` containing scaling parameters to enable distance the Cosine
849/// distance function over vectors belonging to the same quantization space.
850///
851/// This distance function works by assuming input vectors were normalized **prior** to
852/// compression and therefore cosine may be computed by delegating to inner product
853/// computations. The [`crate::spherical::SphericalQuantizer`] will ensure this
854/// pre-normalization when constructed with [`SupportedMetric::Cosine`].
855#[derive(Debug, Clone, Copy)]
856pub struct CompensatedCosine {
857    pub(super) inner: CompensatedIP,
858}
859
860impl CompensatedCosine {
861    /// Construct a new `CompensatedCosine` around the [`CompensatedIP`].
862    pub fn new(inner: CompensatedIP) -> Self {
863        Self { inner }
864    }
865}
866
867impl<A, T, U> Target2<A, distances::MathematicalResult<f32>, T, U> for CompensatedCosine
868where
869    A: Architecture,
870    CompensatedIP: Target2<A, distances::MathematicalResult<f32>, T, U>,
871{
872    #[inline(always)]
873    fn run(self, arch: A, x: T, y: U) -> distances::MathematicalResult<f32> {
874        self.inner.run(arch, x, y)
875    }
876}
877
878/// A blanket implementation for applying the transformation
879/// ```text
880/// x -> 1-x
881/// ```
882/// from `MathematicalValue` to `f32` for cosine distance computations.
883impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedCosine
884where
885    A: Architecture,
886    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
887{
888    #[inline(always)]
889    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
890        let r: MV<f32> = self.run(arch, x, y)?;
891        Ok(1.0 - r.into_inner())
892    }
893}
894
895///////////
896// Tests //
897///////////
898
899#[cfg(test)]
900mod tests {
901    use diskann_utils::{lazy_format, Reborrow};
902    use diskann_vector::{distance::Metric, norm::FastL2Norm, PureDistanceFunction};
903    use diskann_wide::ARCH;
904    use rand::{
905        distr::{Distribution, Uniform},
906        rngs::StdRng,
907        SeedableRng,
908    };
909    use rand_distr::StandardNormal;
910
911    use super::*;
912    use crate::{
913        alloc::GlobalAllocator,
914        bits::{BitTranspose, Dense},
915    };
916
917    #[derive(Debug, Clone, Copy, PartialEq)]
918    struct Approx {
919        absolute: f32,
920        relative: f32,
921    }
922
923    impl Approx {
924        const fn new(absolute: f32, relative: f32) -> Self {
925            assert!(absolute >= 0.0);
926            assert!(relative >= 0.0);
927            Self { absolute, relative }
928        }
929
930        fn check(&self, got: f32, expected: f32, ctx: Option<&dyn std::fmt::Display>) -> bool {
931            struct Ctx<'a>(Option<&'a dyn std::fmt::Display>);
932
933            impl std::fmt::Display for Ctx<'_> {
934                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
935                    match self.0 {
936                        None => write!(f, "none"),
937                        Some(d) => write!(f, "{}", d),
938                    }
939                }
940            }
941
942            let absolute = (got - expected).abs();
943            if absolute <= self.absolute {
944                true
945            } else {
946                let relative = absolute / expected.abs();
947                if relative <= self.relative {
948                    true
949                } else {
950                    panic!(
951                        "got {}, expected {}. Abs/Rel = {}/{} with bounds {}/{}: Ctx: {}",
952                        got,
953                        expected,
954                        absolute,
955                        relative,
956                        self.absolute,
957                        self.relative,
958                        Ctx(ctx)
959                    );
960                }
961            }
962        }
963    }
964
965    //////////////
966    // DataMeta //
967    //////////////
968
969    #[test]
970    fn test_data_meta() {
971        // Test constructor happy path.
972        let meta = DataMeta::new(1.0, 2.0, 10).unwrap();
973        let expected = DataMetaF32 {
974            inner_product_correction: 1.0,
975            metric_specific: 2.0,
976            bit_sum: 10.0,
977        };
978        assert_eq!(meta.to_full(ARCH), expected);
979
980        // Test constructor errors.
981        let err = DataMeta::new(65600.0, 2.0, 10).unwrap_err();
982        assert_eq!(
983            err.to_string(),
984            "inner product correction 65600 cannot fit in a 16-bit floating point number"
985        );
986
987        let err = DataMeta::new(2.0, 65600.0, 10).unwrap_err();
988        assert_eq!(
989            err.to_string(),
990            "metric specific correction 65600 cannot fit in a 16-bit floating point number"
991        );
992
993        let err = DataMeta::new(2.0, 2.0, 65536).unwrap_err();
994        assert_eq!(
995            err.to_string(),
996            "bit sum 65536 cannot fit in a 16-bit unsigned integer",
997        );
998    }
999
1000    //////////////////////
1001    // Supported Metric //
1002    //////////////////////
1003
1004    #[test]
1005    fn supported_metric() {
1006        assert_eq!(
1007            SupportedMetric::try_from(Metric::L2).unwrap(),
1008            SupportedMetric::SquaredL2
1009        );
1010        assert_eq!(
1011            SupportedMetric::try_from(Metric::InnerProduct).unwrap(),
1012            SupportedMetric::InnerProduct
1013        );
1014        assert_eq!(
1015            SupportedMetric::try_from(Metric::Cosine).unwrap(),
1016            SupportedMetric::Cosine
1017        );
1018        assert!(matches!(
1019            SupportedMetric::try_from(Metric::CosineNormalized),
1020            Err(UnsupportedMetric(Metric::CosineNormalized))
1021        ));
1022
1023        assert_eq!(SupportedMetric::SquaredL2, Metric::L2);
1024        assert_ne!(SupportedMetric::SquaredL2, Metric::InnerProduct);
1025        assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
1026        assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
1027
1028        assert_ne!(SupportedMetric::InnerProduct, Metric::L2);
1029        assert_eq!(SupportedMetric::InnerProduct, Metric::InnerProduct);
1030        assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
1031        assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
1032    }
1033
1034    ///////////////
1035    // Distances //
1036    ///////////////
1037
1038    struct Reference<T> {
1039        compressed: T,
1040        reconstructed: Vec<f32>,
1041        norm: f32,
1042        center_ip: f32,
1043        self_ip: Option<f32>,
1044    }
1045
1046    trait GenerateReference: Sized {
1047        fn generate_reference(
1048            center: &[f32],
1049            metric: SupportedMetric,
1050            rng: &mut StdRng,
1051        ) -> Reference<Self>;
1052    }
1053
    /// Generates a random `NBITS`-bit compressed data vector.
    ///
    /// NOTE: every value below is drawn from `rng` in a fixed order. Reordering the draws
    /// changes the generated vectors for a given seed, which may invalidate the tolerances
    /// chosen by the seeded tests.
    impl<const NBITS: usize> GenerateReference for Data<NBITS, GlobalAllocator>
    where
        Unsigned: Representation<NBITS>,
    {
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();

            let mut reconstructed = vec![0.0f32; dim];
            let mut compressed = Data::<NBITS, _>::new_boxed(dim);

            // Sample each code uniformly from the representable domain, recording its sum.
            // The reconstruction subtracts `(2^NBITS - 1) / 2`, which centers the code range
            // on zero. Because that offset is a half-integer, no reconstructed entry is
            // zero, so `r_norm` below is non-zero.
            let mut bit_sum = 0;
            let dist = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
            let offset = (2usize.pow(NBITS as u32) as f32 - 1.0) / 2.0;
            for (i, r) in reconstructed.iter_mut().enumerate() {
                let b: i64 = dist.sample(rng);
                bit_sum += b;
                compressed.vector_mut().set(i, b).unwrap();
                *r = (b as f32) - offset;
            }

            // Normalize the reconstruction to unit length.
            let r_norm = FastL2Norm.evaluate(reconstructed.as_slice());
            reconstructed.iter_mut().for_each(|i| *i /= r_norm);

            // Synthetic correction terms. They are not derived from a real vector; the tests
            // only need them to be consistent between the metadata and the reference fields.
            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(0.5, 2.5).unwrap().sample(rng);
            let self_ip: f32 = Uniform::new(0.5, 1.5).unwrap().sample(rng);

            // `norm / (self_ip * r_norm) * (code - offset) == norm / self_ip * reconstructed`.
            // The correction therefore folds both the unit normalization and the self
            // inner product into one scale (presumably applied by the kernel to the
            // offset-centered codes; confirm against the kernel implementation).
            // NOTE(review): `SupportedMetric::pick` is defined elsewhere; it presumably
            // selects the metric-specific term from `norm` / `center_ip`.
            compressed.set_meta(
                DataMeta::new(
                    norm / (self_ip * r_norm),
                    metric.pick(norm, center_ip),
                    bit_sum.try_into().unwrap(),
                )
                .unwrap(),
            );

            Reference {
                compressed,
                reconstructed,
                norm,
                center_ip,
                self_ip: Some(self_ip),
            }
        }
    }
1103
    /// Generates a random `NBITS`-bit compressed query whose reconstruction is the affine
    /// map `base + scale * code` (not normalized), with no self inner-product correction.
    ///
    /// NOTE: every value below is drawn from `rng` in a fixed order. Reordering the draws
    /// changes the generated vectors for a given seed.
    impl<const NBITS: usize, Perm> GenerateReference for Query<NBITS, Perm, GlobalAllocator>
    where
        Unsigned: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();

            let mut reconstructed = vec![0.0f32; dim];
            let mut compressed = Query::<NBITS, Perm, _>::new_boxed(dim);

            let distribution = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();

            // Random affine parameters; `scale` is made non-negative via `abs`.
            let base: f32 = StandardNormal {}.sample(rng);
            let scale: f32 = {
                let scale: f32 = StandardNormal {}.sample(rng);
                scale.abs()
            };

            let mut bit_sum = 0;
            for (i, r) in reconstructed.iter_mut().enumerate() {
                let b = distribution.sample(rng);
                compressed.vector_mut().set(i, b).unwrap();
                *r = base + scale * (b as f32);
                bit_sum += b;
            }

            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);

            // `norm * scale * (code + base / scale) == norm * (base + scale * code)`, so
            // the correction and offset reproduce `norm * reconstructed` (presumably how the
            // kernel applies them; confirm against the kernel implementation).
            compressed.set_meta(QueryMeta {
                inner_product_correction: norm * scale,
                bit_sum: bit_sum as f32,
                offset: base / scale,
                metric_specific: metric.pick(norm, center_ip),
            });

            Reference {
                compressed,
                reconstructed,
                norm,
                center_ip,
                self_ip: None,
            }
        }
    }
1154
    /// Generates a random full-precision query: a unit-normalized Gaussian vector whose
    /// reconstruction is the stored data itself.
    ///
    /// NOTE: every value below is drawn from `rng` in a fixed order. Reordering the draws
    /// changes the generated vectors for a given seed.
    impl GenerateReference for FullQuery<GlobalAllocator> {
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();

            let mut query = FullQuery::empty(dim, GlobalAllocator).unwrap();

            // Fill with standard-normal entries while accumulating their sum.
            let mut sum = 0.0;
            let dist = StandardNormal {};
            for r in query.data.iter_mut() {
                let b: f32 = dist.sample(rng);
                sum += b;
                *r = b;
            }

            let r_norm = FastL2Norm.evaluate(&*query.data);
            query.data.iter_mut().for_each(|i| *i /= r_norm);

            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);

            // `sum / r_norm` is the element sum of the *normalized* query. The full-query
            // kernel multiplies `sum` by the data code offset (`xc.sum * offset`).
            query.meta = FullQueryMeta {
                sum: sum / r_norm,
                shifted_norm: norm,
                metric_specific: metric.pick(norm, center_ip),
            };

            let reconstructed = query.data.to_vec();
            Reference {
                compressed: query,
                reconstructed,
                norm,
                center_ip,
                self_ip: None,
            }
        }
    }
1195
    /// Check the symmetric compensated kernels (`NBITS`-bit data against `NBITS`-bit data)
    /// for inner product, cosine, and squared L2.
    ///
    /// Refer to the module level documentation for some insight into what these components
    /// mean.
    ///
    /// Each trial samples a random center and two random compressed vectors `x` and `y`
    /// (via `GenerateReference`). Each vector comes with its reconstruction and the
    /// synthetic norm, center inner product, and self inner product that were baked into its
    /// metadata. The expected result is computed manually from those components and compared
    /// against the compensated kernels within `err_ip` / `err_l2`. The `distances::Result`
    /// form of each kernel is also checked against its `distances::MathematicalResult` form.
    ///
    /// NOTE: all randomness is drawn from `rng` in a fixed order. Reordering the sampling
    /// changes the generated vectors for a given seed, which may invalidate the tolerances
    /// chosen by callers.
    fn test_compensated_distance<const NBITS: usize>(
        dim: usize,
        ntrials: usize,
        err_l2: Approx,
        err_ip: Approx,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<NBITS>,
        for<'a> CompensatedIP: Target2<
                diskann_wide::arch::Current,
                distances::Result<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            > + Target2<
                diskann_wide::arch::Current,
                distances::MathematicalResult<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            >,
        for<'a> CompensatedSquaredL2: Target2<
                diskann_wide::arch::Current,
                distances::Result<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            > + Target2<
                diskann_wide::arch::Current,
                distances::MathematicalResult<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            >,
    {
        let mut center = vec![0.0f32; dim];
        for trial in 0..ntrials {
            // Sample the center.
            center
                .iter_mut()
                .for_each(|c| *c = StandardNormal {}.sample(rng));

            let c_square_norm = FastL2NormSquared.evaluate(&*center);

            // Inner Product
            {
                let x = Data::<NBITS, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );
                let y = Data::<NBITS, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );

                // Expected `|X| |Y| <x, y>`, with the stored self inner products acting as
                // the `<x', x>` / `<y', y>` correction from the module docs.
                let kernel_result = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
                };

                // <X', Y'> = |X| |Y| <x, y> + <X, C> + <Y, C> + |C|^2, with `center_ip`
                // standing in for the `<X, C>` terms.
                let reference_ip = kernel_result + x.center_ip + y.center_ip + c_square_norm;
                let ip = CompensatedIP::new(&center, center.len());
                let got_ip: distances::MathematicalResult<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
                let got_ip = got_ip.unwrap();

                let ctx = &lazy_format!(
                    "Inner Product, trial {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_ip.check(got_ip.into_inner(), reference_ip, Some(ctx)));

                let got_ip_f32: distances::Result<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());

                let got_ip_f32 = got_ip_f32.unwrap();

                // As a distance, the inner product is reported negated.
                assert_eq!(got_ip_f32, -got_ip.into_inner());

                // Cosine (very similar to inner product): the mathematical value is the
                // inner product itself, and the distance form is `1 - value`.
                let cosine = CompensatedCosine::new(ip);
                let got_cosine: distances::MathematicalResult<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine = got_cosine.unwrap();
                assert_eq!(
                    got_cosine.into_inner(),
                    got_ip.into_inner(),
                    "cosine and IP should be the same"
                );

                let got_cosine_f32: distances::Result<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());

                let got_cosine_f32 = got_cosine_f32.unwrap();

                assert_eq!(
                    got_cosine_f32,
                    1.0 - got_cosine.into_inner(),
                    "incorrect transform performed"
                );
            }

            // Squared L2
            {
                let x =
                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let y =
                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);

                // Compute the expected value for the quantity `|X| |Y| <x, y>`, where `X` and
                // `Y` are the centered vectors.
                let kernel_result = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
                };

                // |X' - Y'|^2 = |X - Y|^2 = |X|^2 + |Y|^2 - 2 |X| |Y| <x, y>; the center
                // cancels out.
                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * kernel_result;
                let l2 = CompensatedSquaredL2::new(dim);
                let got_l2: distances::MathematicalResult<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2 = got_l2.unwrap();

                let ctx =
                    &lazy_format!("Squared L2, trial {} of {}, dim = {}", trial, ntrials, dim);
                assert!(err_l2.check(got_l2.into_inner(), reference_l2, Some(ctx)));

                let got_l2_f32: distances::Result<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2_f32 = got_l2_f32.unwrap();

                // Squared L2 is already a distance, so both forms agree exactly.
                assert_eq!(got_l2_f32, got_l2.into_inner());
            }
        }
    }
1344
1345    /// This works similarly to the 1-bit compensated distances, but checks the 4-bit query
1346    /// path.
1347    fn test_mixed_compensated_distance<const Q: usize, const D: usize, Perm>(
1348        dim: usize,
1349        ntrials: usize,
1350        err_l2: Approx,
1351        err_ip: Approx,
1352        rng: &mut StdRng,
1353    ) where
1354        Unsigned: Representation<Q>,
1355        Unsigned: Representation<D>,
1356        Perm: PermutationStrategy<Q>,
1357        for<'a> CompensatedIP: Target2<
1358            diskann_wide::arch::Current,
1359            distances::MathematicalResult<f32>,
1360            QueryRef<'a, Q, Perm>,
1361            DataRef<'a, D>,
1362        >,
1363        for<'a> CompensatedSquaredL2: Target2<
1364            diskann_wide::arch::Current,
1365            distances::MathematicalResult<f32>,
1366            QueryRef<'a, Q, Perm>,
1367            DataRef<'a, D>,
1368        >,
1369        for<'a> CompensatedCosine: Target2<
1370            diskann_wide::arch::Current,
1371            distances::MathematicalResult<f32>,
1372            QueryRef<'a, Q, Perm>,
1373            DataRef<'a, D>,
1374        >,
1375        for<'a> CompensatedIP: Target2<
1376            diskann_wide::arch::Current,
1377            distances::Result<f32>,
1378            QueryRef<'a, Q, Perm>,
1379            DataRef<'a, D>,
1380        >,
1381        for<'a> CompensatedSquaredL2: Target2<
1382            diskann_wide::arch::Current,
1383            distances::Result<f32>,
1384            QueryRef<'a, Q, Perm>,
1385            DataRef<'a, D>,
1386        >,
1387        for<'a> CompensatedCosine: Target2<
1388            diskann_wide::arch::Current,
1389            distances::Result<f32>,
1390            QueryRef<'a, Q, Perm>,
1391            DataRef<'a, D>,
1392        >,
1393    {
1394        // The center
1395        let mut center = vec![0.0f32; dim];
1396        for trial in 0..ntrials {
1397            // Sample the center.
1398            center
1399                .iter_mut()
1400                .for_each(|c| *c = StandardNormal {}.sample(rng));
1401
1402            let c_square_norm = FastL2NormSquared.evaluate(&*center);
1403
1404            // Inner Product
1405            {
1406                let x = Query::<Q, Perm, _>::generate_reference(
1407                    &center,
1408                    SupportedMetric::InnerProduct,
1409                    rng,
1410                );
1411                let y =
1412                    Data::<D, _>::generate_reference(&center, SupportedMetric::InnerProduct, rng);
1413
1414                // The expected scaled dot-product between the normalized vectors.
1415                let xy = {
1416                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1417                        &*x.reconstructed,
1418                        &*y.reconstructed,
1419                    );
1420                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1421                };
1422
1423                let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
1424                let ip = CompensatedIP::new(&center, center.len());
1425                let got_ip: distances::Result<f32> =
1426                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1427                let got_ip = got_ip.unwrap();
1428
1429                let ctx = &lazy_format!(
1430                    "Inner Product, trial = {} of {}, dim = {}",
1431                    trial,
1432                    ntrials,
1433                    dim
1434                );
1435
1436                assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
1437
1438                // Cosine (very similary to inner-product).
1439                let cosine = CompensatedCosine::new(ip);
1440                let got_cosine: distances::MathematicalResult<f32> =
1441                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1442
1443                let got_cosine = got_cosine.unwrap();
1444                assert_eq!(
1445                    got_cosine.into_inner(),
1446                    -got_ip,
1447                    "cosine and IP should be the same"
1448                );
1449
1450                let got_cosine_f32: distances::Result<f32> =
1451                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1452
1453                let got_cosine_f32 = got_cosine_f32.unwrap();
1454                assert_eq!(
1455                    got_cosine_f32,
1456                    1.0 - got_cosine.into_inner(),
1457                    "incorrect transform performed"
1458                );
1459            }
1460
1461            // Squared L2
1462            {
1463                let x = Query::<Q, Perm, _>::generate_reference(
1464                    &center,
1465                    SupportedMetric::SquaredL2,
1466                    rng,
1467                );
1468                let y = Data::<D, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
1469
1470                // The expected scaled dot-product between the normalized vectors.
1471                let xy = {
1472                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1473                        &*x.reconstructed,
1474                        &*y.reconstructed,
1475                    );
1476                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1477                };
1478                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
1479                let l2 = CompensatedSquaredL2::new(dim);
1480                let got_l2: distances::Result<f32> =
1481                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1482                let got_l2 = got_l2.unwrap();
1483
1484                let ctx = &lazy_format!(
1485                    "Squared L2, trial = {} of {}, dim = {}",
1486                    trial,
1487                    ntrials,
1488                    dim
1489                );
1490
1491                assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
1492            }
1493        }
1494    }
1495
1496    fn test_full_distances<const NBITS: usize>(
1497        dim: usize,
1498        ntrials: usize,
1499        err_l2: Approx,
1500        err_ip: Approx,
1501        rng: &mut StdRng,
1502    ) where
1503        Unsigned: Representation<NBITS>,
1504        for<'a> CompensatedIP: Target2<
1505            diskann_wide::arch::Current,
1506            distances::MathematicalResult<f32>,
1507            FullQueryRef<'a>,
1508            DataRef<'a, NBITS>,
1509        >,
1510        for<'a> CompensatedSquaredL2: Target2<
1511            diskann_wide::arch::Current,
1512            distances::MathematicalResult<f32>,
1513            FullQueryRef<'a>,
1514            DataRef<'a, NBITS>,
1515        >,
1516        for<'a> CompensatedCosine: Target2<
1517            diskann_wide::arch::Current,
1518            distances::MathematicalResult<f32>,
1519            FullQueryRef<'a>,
1520            DataRef<'a, NBITS>,
1521        >,
1522        for<'a> CompensatedIP: Target2<
1523            diskann_wide::arch::Current,
1524            distances::Result<f32>,
1525            FullQueryRef<'a>,
1526            DataRef<'a, NBITS>,
1527        >,
1528        for<'a> CompensatedSquaredL2: Target2<
1529            diskann_wide::arch::Current,
1530            distances::Result<f32>,
1531            FullQueryRef<'a>,
1532            DataRef<'a, NBITS>,
1533        >,
1534        for<'a> CompensatedCosine: Target2<
1535            diskann_wide::arch::Current,
1536            distances::Result<f32>,
1537            FullQueryRef<'a>,
1538            DataRef<'a, NBITS>,
1539        >,
1540    {
1541        // The center
1542        let mut center = vec![0.0f32; dim];
1543        for trial in 0..ntrials {
1544            // Sample the center.
1545            center
1546                .iter_mut()
1547                .for_each(|c| *c = StandardNormal {}.sample(rng));
1548
1549            let c_square_norm = FastL2NormSquared.evaluate(&*center);
1550
1551            // Inner Product
1552            {
1553                let x = FullQuery::generate_reference(&center, SupportedMetric::InnerProduct, rng);
1554                let y = Data::<NBITS, _>::generate_reference(
1555                    &center,
1556                    SupportedMetric::InnerProduct,
1557                    rng,
1558                );
1559
1560                // The expected scaled dot-product between the normalized vectors.
1561                let xy = {
1562                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1563                        &*x.reconstructed,
1564                        &*y.reconstructed,
1565                    );
1566                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1567                };
1568
1569                let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
1570                let ip = CompensatedIP::new(&center, center.len());
1571                let got_ip: distances::Result<f32> =
1572                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1573                let got_ip = got_ip.unwrap();
1574
1575                let ctx = &lazy_format!(
1576                    "Inner Product, trial = {} of {}, dim = {}",
1577                    trial,
1578                    ntrials,
1579                    dim
1580                );
1581
1582                assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
1583
1584                // Cosine (very similary to inner-product).
1585                let cosine = CompensatedCosine::new(ip);
1586                let got_cosine: distances::MathematicalResult<f32> =
1587                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1588                let got_cosine = got_cosine.unwrap();
1589                assert_eq!(
1590                    got_cosine.into_inner(),
1591                    -got_ip,
1592                    "cosine and IP should be the same"
1593                );
1594
1595                let got_cosine_f32: distances::Result<f32> =
1596                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1597
1598                let got_cosine_f32 = got_cosine_f32.unwrap();
1599                assert_eq!(
1600                    got_cosine_f32,
1601                    1.0 - got_cosine.into_inner(),
1602                    "incorrect transform performed"
1603                );
1604            }
1605
1606            // Squared L2
1607            {
1608                let x = FullQuery::generate_reference(&center, SupportedMetric::SquaredL2, rng);
1609                let y =
1610                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
1611
1612                // The expected scaled dot-product between the normalized vectors.
1613                let xy = {
1614                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1615                        &*x.reconstructed,
1616                        &*y.reconstructed,
1617                    );
1618                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1619                };
1620
1621                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
1622                let l2 = CompensatedSquaredL2::new(dim);
1623                let got_l2: distances::Result<f32> =
1624                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1625                let got_l2 = got_l2.unwrap();
1626
1627                let ctx = &lazy_format!(
1628                    "Squared L2, trial = {} of {}, dim = {}",
1629                    trial,
1630                    ntrials,
1631                    dim
1632                );
1633                assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
1634            }
1635        }
1636    }
1637
1638    cfg_if::cfg_if! {
1639        if #[cfg(miri)] {
1640            // The max dim does not need to be as high for these vectors because they
1641            // defer their distance function implementation to `BitSlice`, which is more
1642            // heavily tested.
1643            const MAX_DIM: usize = 37;
1644            const TRIALS_PER_DIM: usize = 1;
1645        } else {
1646            const MAX_DIM: usize = 256;
1647            const TRIALS_PER_DIM: usize = 20;
1648        }
1649    }
1650
1651    #[test]
1652    fn test_symmetric_distances_1bit() {
1653        let mut rng = StdRng::seed_from_u64(0x2a5f79a2469218f6);
1654        for dim in 1..MAX_DIM {
1655            test_compensated_distance::<1>(
1656                dim,
1657                TRIALS_PER_DIM,
1658                Approx::new(4.0e-3, 3.0e-3),
1659                Approx::new(1.0e-3, 5.0e-4),
1660                &mut rng,
1661            );
1662        }
1663    }
1664
1665    #[test]
1666    fn test_symmetric_distances_2bit() {
1667        let mut rng = StdRng::seed_from_u64(0x68f8f52057f94399);
1668        for dim in 1..MAX_DIM {
1669            test_compensated_distance::<2>(
1670                dim,
1671                TRIALS_PER_DIM,
1672                Approx::new(3.5e-3, 2.0e-3),
1673                Approx::new(2.0e-3, 5.0e-4),
1674                &mut rng,
1675            );
1676        }
1677    }
1678
1679    #[test]
1680    fn test_symmetric_distances_4bit() {
1681        let mut rng = StdRng::seed_from_u64(0xb88d76ac4c58e923);
1682        for dim in 1..MAX_DIM {
1683            test_compensated_distance::<4>(
1684                dim,
1685                TRIALS_PER_DIM,
1686                Approx::new(2.0e-3, 2.0e-3),
1687                Approx::new(2.0e-3, 5.0e-4),
1688                &mut rng,
1689            );
1690        }
1691    }
1692
1693    #[test]
1694    fn test_symmetric_distances_8bit() {
1695        let mut rng = StdRng::seed_from_u64(0x1c2b79873ee32626);
1696        for dim in 1..MAX_DIM {
1697            test_compensated_distance::<8>(
1698                dim,
1699                TRIALS_PER_DIM,
1700                Approx::new(2.0e-3, 2.0e-3),
1701                Approx::new(2.0e-3, 4.0e-4),
1702                &mut rng,
1703            );
1704        }
1705    }
1706
#[test]
/// Mixed-width (4-bit x 1-bit, `BitTranspose`) compensated-distance check.
///
/// Sweeps every dimension in `1..MAX_DIM`, running `TRIALS_PER_DIM` trials per dimension
/// through `test_mixed_compensated_distance`. A fixed seed keeps failures reproducible.
fn test_mixed_distances_4x1() {
    const SEED: u64 = 0x1efb4d87ed0a8ada;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(4.0e-3, 3.0e-3), Approx::new(1.3e-2, 8.3e-3));
        test_mixed_compensated_distance::<4, 1, BitTranspose>(
            dim,
            TRIALS_PER_DIM,
            tol_a,
            tol_b,
            &mut rng,
        );
    }
}
1720
#[test]
/// Mixed-width (4-bit x 4-bit, `Dense`) compensated-distance check.
///
/// Sweeps every dimension in `1..MAX_DIM`, running `TRIALS_PER_DIM` trials per dimension
/// through `test_mixed_compensated_distance`. A fixed seed keeps failures reproducible.
fn test_mixed_distances_4x4() {
    const SEED: u64 = 0x508554264eb7a51b;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // NOTE(review): the second component of `tol_b` (8.3e-2) is 10x looser than the
        // 4x1 test's 8.3e-3, while the first (3.0e-4) is far tighter. This may be a
        // transposed exponent; confirm it is intentional before relying on it.
        let (tol_a, tol_b) = (Approx::new(4.0e-3, 3.0e-3), Approx::new(3.0e-4, 8.3e-2));
        test_mixed_compensated_distance::<4, 4, Dense>(
            dim,
            TRIALS_PER_DIM,
            tol_a,
            tol_b,
            &mut rng,
        );
    }
}
1734
#[test]
/// Mixed-width (8-bit x 8-bit, `Dense`) compensated-distance check.
///
/// Sweeps every dimension in `1..MAX_DIM`, running `TRIALS_PER_DIM` trials per dimension
/// through `test_mixed_compensated_distance`. A fixed seed keeps failures reproducible.
fn test_mixed_distances_8x8() {
    const SEED: u64 = 0x8acd8e4224c76c43;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(2.0e-3, 6.0e-3), Approx::new(1.0e-2, 3.0e-2));
        test_mixed_compensated_distance::<8, 8, Dense>(
            dim,
            TRIALS_PER_DIM,
            tol_a,
            tol_b,
            &mut rng,
        );
    }
}
1748
// Full distances: tests driven by the shared `test_full_distances` helper.
#[test]
/// 1-bit check via `test_full_distances::<1>`.
///
/// Sweeps every dimension in `1..MAX_DIM` with `TRIALS_PER_DIM` trials per dimension.
/// A fixed seed keeps failures reproducible.
fn test_full_distances_1bit() {
    const SEED: u64 = 0x7f93530559f42d66;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(1.0e-3, 2.0e-3), Approx::new(0.0, 5.0e-3));
        test_full_distances::<1>(dim, TRIALS_PER_DIM, tol_a, tol_b, &mut rng);
    }
}
1763
#[test]
/// 2-bit check via `test_full_distances::<2>`.
///
/// Sweeps every dimension in `1..MAX_DIM` with `TRIALS_PER_DIM` trials per dimension.
/// A fixed seed keeps failures reproducible.
fn test_full_distances_2bit() {
    const SEED: u64 = 0xa3ad61d3d03a0c5a;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(2.0e-3, 1.1e-3), Approx::new(7.0e-4, 1.0e-3));
        test_full_distances::<2>(dim, TRIALS_PER_DIM, tol_a, tol_b, &mut rng);
    }
}
1777
#[test]
/// 4-bit check via `test_full_distances::<4>`.
///
/// Sweeps every dimension in `1..MAX_DIM` with `TRIALS_PER_DIM` trials per dimension.
/// A fixed seed keeps failures reproducible.
fn test_full_distances_4bit() {
    const SEED: u64 = 0x3e2f50ed7c64f0c2;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(2.0e-3, 1.0e-2), Approx::new(1.0e-3, 5.0e-4));
        test_full_distances::<4>(dim, TRIALS_PER_DIM, tol_a, tol_b, &mut rng);
    }
}
1791
#[test]
/// 8-bit check via `test_full_distances::<8>`.
///
/// Sweeps every dimension in `1..MAX_DIM` with `TRIALS_PER_DIM` trials per dimension.
/// A fixed seed keeps failures reproducible.
fn test_full_distances_8bit() {
    const SEED: u64 = 0x95705070e415c6d3;
    let mut rng = StdRng::seed_from_u64(SEED);
    for dim in 1..MAX_DIM {
        // Tolerance pair forwarded unchanged to the shared helper.
        let (tol_a, tol_b) = (Approx::new(1.0e-3, 1.0e-3), Approx::new(2.0e-3, 1.0e-4));
        test_full_distances::<8>(dim, TRIALS_PER_DIM, tol_a, tol_b, &mut rng);
    }
}
1805}