diskann_quantization/minmax/
vectors.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_utils::{Reborrow, ReborrowMut};
7use diskann_vector::{MathematicalValue, PureDistanceFunction};
8use thiserror::Error;
9
10use crate::{
11    bits::{BitSlice, Dense, Representation, Unsigned},
12    distances,
13    distances::{InnerProduct, MV},
14    meta::{self},
15};
16
17/// A per-vector precomputed coefficients to help compute inner products
18/// and squared L2 distances for the MinMax quantized vectors.
19///
20/// The inner product between `X = ax * X' + bx` and `Y = ay * Y' + by` for d-dimensional
21/// vectors X and Y is:
22/// ```math
23/// <X, Y> = <ax * X' + bx, ay * Y' + by>
24///        = ax * ay * <X', Y'> + ax * <X', by> + ay * <Y', bx> + d * bx * by.
25/// ```
26/// Let us define a grouping of these terms to make it easier to understand:
27/// ```math
28///  Nx = ax * sum_i X'[i],     Ny = ay * sum_i Y'[i],
29/// ```
30/// We can then simplify the dot product calculation as follows:
31/// ```math
32/// <X, Y> = ax * ay * <X', Y'> + Nx * by + Ny * bx +  d * bx * by
33///                    --------
34///                       |
35///               Integer Dot Product
36/// ```
37///
38/// To compute the squared L2 distance,
39/// ```math
40/// |X - Y|^2 = |ax * X' + bx|^2 + |ay * Y' + by|^2 - 2 * <X, Y>
41/// ```
42/// we can re-use the computation for inner-product from above.
43#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
44#[repr(C)]
45pub struct MinMaxCompensation {
46    pub dim: u32,          // - dimension
47    pub b: f32,            // - bx
48    pub n: f32,            // - Nx
49    pub a: f32,            // - ax
50    pub norm_squared: f32, // - |ax * X' + bx|^2
51}
52
53const META_BYTES: usize = std::mem::size_of::<MinMaxCompensation>(); // This will be 5 * 4 = 20 bytes.
54
55/// Error type for parsing a slice of bytes as a `DataRef`
56/// and returning corresponding dimension.
57#[derive(Debug, Error, Clone, PartialEq, Eq)]
58pub enum MetaParseError {
59    #[error("Invalid size: {0}, must contain at least {META_BYTES} bytes")]
60    NotCanonical(usize),
61}
62
63impl MinMaxCompensation {
64    /// Reads the dimension from the first 4 bytes of a MinMax quantized vector's metadata.
65    ///
66    /// This function is used to extract the vector dimension from serialized MinMax quantized
67    /// vector data without fully deserializing the entire vector structure.
68    ///
69    /// # Arguments
70    /// * `bytes` - A byte slice containing the serialized MinMax vector data
71    ///
72    /// # Returns
73    /// * `Ok(dimension)` - The dimension of the vector as a `usize`
74    /// * `Err(MetaParseError::NotCanonical(size))` - If the byte slice is shorter than 20 bytes (META_BYTES)
75    ///
76    /// # Usage
77    /// Use this when you need to determine the vector dimension from serialized data before
78    /// creating a `DataRef` or allocating appropriately sized buffers for decompression.
79    #[inline(always)]
80    pub fn read_dimension(bytes: &[u8]) -> Result<usize, MetaParseError> {
81        if bytes.len() < META_BYTES {
82            return Err(MetaParseError::NotCanonical(bytes.len()));
83        }
84
85        // SAFETY: There are at least `META_BYTES` = 20 bytes in the array so this access is within bounds.
86        let dim_bytes: [u8; 4] = bytes.get(..4).map_or_else(
87            || Err(MetaParseError::NotCanonical(bytes.len())),
88            |slice| {
89                slice
90                    .try_into()
91                    .map_err(|_| MetaParseError::NotCanonical(bytes.len()))
92            },
93        )?;
94
95        let dim = u32::from_le_bytes(dim_bytes) as usize;
96
97        Ok(dim)
98    }
99}
100
101/// An owning compressed data vector
102///
103/// See: [`meta::Vector`].
104pub type Data<const NBITS: usize> = meta::Vector<NBITS, Unsigned, MinMaxCompensation, Dense>;
105
106/// A borrowed `Data` vector
107///
108/// See: [`meta::Vector`].
109pub type DataRef<'a, const NBITS: usize> =
110    meta::VectorRef<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
111
112#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
113pub enum DecompressError {
114    #[error("expected src and dst length to be identical, instead src is {0}, and dst is {1}")]
115    LengthMismatch(usize, usize),
116}
117impl<const NBITS: usize> DataRef<'_, NBITS>
118where
119    Unsigned: Representation<NBITS>,
120{
121    /// Decompresses a MinMax quantized vector back into its original floating-point representation.
122    ///
123    /// This method reconstructs the original vector values using the stored quantization parameters
124    /// and the MinMax dequantization formula: `x = x' * a + b` and stores the result in `dst`
125    ///
126    /// # Arguments
127    ///
128    /// * `dst` - A mutable slice of `f32` values where the decompressed data will be written.
129    ///   Must have the same length as the compressed vector.
130    ///
131    /// # Returns
132    ///
133    /// * `Ok(())` - On successful decompression
134    /// * `Err(DecompressError::LengthMismatch(src_len, dst_len))` - If the destination slice
135    ///   length doesn't match the compressed vector length
136    pub fn decompress_into(&self, dst: &mut [f32]) -> Result<(), DecompressError> {
137        if dst.len() != self.len() {
138            return Err(DecompressError::LengthMismatch(self.len(), dst.len()));
139        }
140        let meta = self.meta();
141
142        // SAFETY: We checked that the length of the underlying vector is the same as
143        // as `dst` so we are guaranteed to be within bounds when accessing the vector.
144        dst.iter_mut().enumerate().for_each(|(i, d)| unsafe {
145            *d = self.vector().get_unchecked(i) as f32 * meta.a + meta.b
146        });
147        Ok(())
148    }
149}
150
151/// A mutable borrowed `Data` vector
152///
153/// See: [`meta::Vector`].
154pub type DataMutRef<'a, const NBITS: usize> =
155    meta::VectorMut<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
156
157////////////////////
158// Full Precision //
159////////////////////
160
/// A full-precision query, compared against MinMax quantized vectors.
///
/// The inner product between a quantized `X = ax * X' + bx` and a full-precision `Y`
/// (both d-dimensional) is:
/// ```math
/// <X, Y> = <ax * X' + bx, Y>
///        = ax * <X', Y> + bx * sum(Y).
///               --------
///                  |
///          Integer-Float Dot Product
/// ```
///
/// The squared L2 distance re-uses that compensated inner product:
/// ```math
/// |X - Y|^2 = |ax * X' + bx|^2 + |Y|^2 - 2 * <X, Y>
/// ```
#[derive(Debug)]
pub struct FullQuery {
    /// The data after transform is applied to it.
    pub data: Box<[f32]>,
    /// The sum and squared norm of `data`, read by the distance functions.
    /// Callers are responsible for keeping this in sync with `data`.
    pub meta: FullQueryMeta,
}

/// The `sum` and squared norm of a [`FullQuery`]'s `data` after transformation
/// is applied to it.
#[derive(Debug, Clone, Copy, Default)]
pub struct FullQueryMeta {
    /// The sum of the elements of `data`.
    pub sum: f32,
    /// The squared L2 norm of `data`.
    pub norm_squared: f32,
}

impl FullQuery {
    /// Constructs a zero-filled `FullQuery` of `dim` elements with zeroed metadata.
    pub fn empty(dim: usize) -> Self {
        let data = vec![0.0_f32; dim].into_boxed_slice();
        Self {
            data,
            meta: FullQueryMeta::default(),
        }
    }

    /// Returns the number of elements in `data`.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if `data` has no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
213
214impl<'short> Reborrow<'short> for FullQuery {
215    type Target = &'short FullQuery;
216    fn reborrow(&'short self) -> Self::Target {
217        self
218    }
219}
220
221impl<'short> ReborrowMut<'short> for FullQuery {
222    type Target = &'short mut FullQuery;
223    fn reborrow_mut(&'short mut self) -> Self::Target {
224        self
225    }
226}
227
228///////////////////////////
229// Compensated Distances //
230///////////////////////////
231
232fn kernel<const NBITS: usize, F>(
233    x: DataRef<'_, NBITS>,
234    y: DataRef<'_, NBITS>,
235    f: F,
236) -> distances::MathematicalResult<f32>
237where
238    Unsigned: Representation<NBITS>,
239    InnerProduct: for<'a, 'b> PureDistanceFunction<
240        BitSlice<'a, NBITS, Unsigned>,
241        BitSlice<'b, NBITS, Unsigned>,
242        distances::MathematicalResult<u32>,
243    >,
244    F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
245{
246    let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?;
247    let (xm, ym) = (x.meta(), y.meta());
248    let term0 = xm.a * ym.a * raw_product.into_inner() as f32;
249    let term1_x = xm.n * ym.b;
250    let term1_y = ym.n * xm.b;
251    let term2 = xm.b * ym.b * (x.len() as f32);
252
253    let v = term0 + term1_x + term1_y + term2;
254    Ok(MV::new(f(v, &xm, &ym)))
255}
256
/// Inner-product similarity / distance for MinMax quantized vectors, against either
/// another quantized vector or a [`FullQuery`].
pub struct MinMaxIP;
258
259impl<const NBITS: usize>
260    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
261    for MinMaxIP
262where
263    Unsigned: Representation<NBITS>,
264    InnerProduct: for<'a, 'b> PureDistanceFunction<
265        BitSlice<'a, NBITS, Unsigned>,
266        BitSlice<'b, NBITS, Unsigned>,
267        distances::MathematicalResult<u32>,
268    >,
269{
270    fn evaluate(
271        x: DataRef<'_, NBITS>,
272        y: DataRef<'_, NBITS>,
273    ) -> distances::MathematicalResult<f32> {
274        kernel(x, y, |v, _, _| v)
275    }
276}
277
278impl<const NBITS: usize>
279    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::Result<f32>>
280    for MinMaxIP
281where
282    Unsigned: Representation<NBITS>,
283    InnerProduct: for<'a, 'b> PureDistanceFunction<
284        BitSlice<'a, NBITS, Unsigned>,
285        BitSlice<'b, NBITS, Unsigned>,
286        distances::MathematicalResult<u32>,
287    >,
288{
289    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
290        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
291        Ok(-v?.into_inner())
292    }
293}
294
295impl<const NBITS: usize>
296    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
297    for MinMaxIP
298where
299    Unsigned: Representation<NBITS>,
300    InnerProduct: for<'a, 'b> PureDistanceFunction<
301        &'a [f32],
302        BitSlice<'b, NBITS, Unsigned>,
303        distances::MathematicalResult<f32>,
304    >,
305{
306    fn evaluate(x: &FullQuery, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
307        let raw_product: f32 = InnerProduct::evaluate(&x.data, y.vector())?.into_inner();
308        Ok(MathematicalValue::new(
309            raw_product * y.meta().a + x.meta.sum * y.meta().b,
310        ))
311    }
312}
313
314impl<const NBITS: usize>
315    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxIP
316where
317    Unsigned: Representation<NBITS>,
318    InnerProduct: for<'a, 'b> PureDistanceFunction<
319        &'a [f32],
320        BitSlice<'b, NBITS, Unsigned>,
321        distances::MathematicalResult<f32>,
322    >,
323{
324    fn evaluate(x: &FullQuery, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
325        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
326        Ok(-v?.into_inner())
327    }
328}
329
/// Squared Euclidean distance for MinMax quantized vectors, against either another
/// quantized vector or a [`FullQuery`].
pub struct MinMaxL2Squared;
331
332impl<const NBITS: usize>
333    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
334    for MinMaxL2Squared
335where
336    Unsigned: Representation<NBITS>,
337    InnerProduct: for<'a, 'b> PureDistanceFunction<
338        BitSlice<'a, NBITS, Unsigned>,
339        BitSlice<'b, NBITS, Unsigned>,
340        distances::MathematicalResult<u32>,
341    >,
342{
343    fn evaluate(
344        x: DataRef<'_, NBITS>,
345        y: DataRef<'_, NBITS>,
346    ) -> distances::MathematicalResult<f32> {
347        kernel(x, y, |v, xm, ym| {
348            -2.0 * v + xm.norm_squared + ym.norm_squared
349        })
350    }
351}
352
353impl<const NBITS: usize>
354    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::Result<f32>>
355    for MinMaxL2Squared
356where
357    Unsigned: Representation<NBITS>,
358    InnerProduct: for<'a, 'b> PureDistanceFunction<
359        BitSlice<'a, NBITS, Unsigned>,
360        BitSlice<'b, NBITS, Unsigned>,
361        distances::MathematicalResult<u32>,
362    >,
363{
364    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
365        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
366        Ok(v?.into_inner())
367    }
368}
369
370impl<const NBITS: usize>
371    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
372    for MinMaxL2Squared
373where
374    Unsigned: Representation<NBITS>,
375    InnerProduct: for<'a, 'b> PureDistanceFunction<
376        &'a [f32],
377        BitSlice<'b, NBITS, Unsigned>,
378        distances::MathematicalResult<f32>,
379    >,
380{
381    fn evaluate(x: &FullQuery, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
382        let raw_product = InnerProduct::evaluate(&x.data, y.vector())?.into_inner();
383
384        let ym = y.meta();
385        let compensated_ip = raw_product * ym.a + x.meta.sum * ym.b;
386        Ok(MV::new(
387            x.meta.norm_squared + ym.norm_squared - 2.0 * compensated_ip,
388        ))
389    }
390}
391
392impl<const NBITS: usize>
393    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxL2Squared
394where
395    Unsigned: Representation<NBITS>,
396    InnerProduct: for<'a, 'b> PureDistanceFunction<
397        &'a [f32],
398        BitSlice<'b, NBITS, Unsigned>,
399        distances::MathematicalResult<f32>,
400    >,
401{
402    fn evaluate(x: &FullQuery, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
403        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
404        Ok(v?.into_inner())
405    }
406}
407
408///////////////////////
409// Cosine Distances //
410///////////////////////
411
/// Cosine distance `1 - <X, Y> / (|X| * |Y|)` for MinMax quantized vectors, using the
/// stored squared norms.
pub struct MinMaxCosine;
413
414impl<const NBITS: usize>
415    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::Result<f32>>
416    for MinMaxCosine
417where
418    Unsigned: Representation<NBITS>,
419    MinMaxIP: for<'a, 'b> PureDistanceFunction<
420        DataRef<'a, NBITS>,
421        DataRef<'b, NBITS>,
422        distances::MathematicalResult<f32>,
423    >,
424{
425    // 1 - <X, Y> / (|X| * |Y|)
426    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
427        let ip: MV<f32> = MinMaxIP::evaluate(x, y)?;
428        let (xm, ym) = (x.meta(), y.meta());
429        Ok(1.0 - ip.into_inner() / (xm.norm_squared.sqrt() * ym.norm_squared.sqrt()))
430    }
431}
432
433impl<const NBITS: usize>
434    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxCosine
435where
436    Unsigned: Representation<NBITS>,
437    MinMaxIP: for<'a, 'b> PureDistanceFunction<
438        &'a FullQuery,
439        DataRef<'b, NBITS>,
440        distances::MathematicalResult<f32>,
441    >,
442{
443    fn evaluate(x: &'_ FullQuery, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
444        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
445        let (xm, ym) = (x.meta.norm_squared, y.meta());
446        Ok(1.0 - ip.into_inner() / (xm.sqrt() * ym.norm_squared.sqrt()))
447        // 1 - <X, Y> / (|X| * |Y|)
448    }
449}
450
/// Cosine distance computed as `1 - <X, Y>`, without consulting the stored norms.
/// Intended (per its name) for inputs that are already unit-norm.
pub struct MinMaxCosineNormalized;
452
453impl<const NBITS: usize>
454    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, NBITS>, distances::Result<f32>>
455    for MinMaxCosineNormalized
456where
457    Unsigned: Representation<NBITS>,
458    MinMaxIP: for<'a, 'b> PureDistanceFunction<
459        DataRef<'a, NBITS>,
460        DataRef<'b, NBITS>,
461        distances::MathematicalResult<f32>,
462    >,
463{
464    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
465        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
466        Ok(1.0 - ip.into_inner()) // 1 - <X, Y>
467    }
468}
469
470impl<const NBITS: usize>
471    PureDistanceFunction<&FullQuery, DataRef<'_, NBITS>, distances::Result<f32>>
472    for MinMaxCosineNormalized
473where
474    Unsigned: Representation<NBITS>,
475    MinMaxIP: for<'a, 'b> PureDistanceFunction<
476        &'a FullQuery,
477        DataRef<'b, NBITS>,
478        distances::MathematicalResult<f32>,
479    >,
480{
481    fn evaluate(x: &'_ FullQuery, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
482        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
483        Ok(1.0 - ip.into_inner()) // 1 - <X, Y>
484    }
485}
486
487///////////
488// Tests //
489///////////
490
#[cfg(test)]
mod minmax_vector_tests {
    use diskann_utils::Reborrow;
    use rand::{
        distr::{Distribution, Uniform},
        rngs::StdRng,
        Rng, SeedableRng,
    };

    use super::*;
    use crate::scalar::bit_scale;

    /// Builds two random `dim`-dimensional MinMax vectors with `NBITS`-bit codes and checks
    /// every distance function against references computed on the reconstructed
    /// full-precision vectors. This covers both the quantized-quantized and the
    /// full-precision-quantized paths. Also exercises `decompress_into` and
    /// `MinMaxCompensation::read_dimension`.
    fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
        R: Rng,
    {
        assert!(dim <= bit_scale::<NBITS>() as usize);

        // Create two vectors with known compensation values
        let mut v1 = Data::<NBITS>::new_boxed(dim);
        let mut v2 = Data::<NBITS>::new_boxed(dim);

        let domain = Unsigned::domain_const::<NBITS>();
        let code_distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        // Set bit values
        {
            let mut bitslice1 = v1.vector_mut();
            let mut bitslice2 = v2.vector_mut();

            for i in 0..dim {
                bitslice1.set(i, code_distribution.sample(rng)).unwrap();
                bitslice2.set(i, code_distribution.sample(rng)).unwrap();
            }
        }
        let a_rnd = Uniform::new_inclusive(0.0, 2.0).unwrap();
        let b_rnd = Uniform::new_inclusive(0.0, 2.0).unwrap();

        // Set compensation coefficients
        // v1: X = a1 * X' + b1
        // v2: Y = a2 * Y' + b2
        let a1 = a_rnd.sample(rng);
        let b1 = b_rnd.sample(rng);
        let a2 = a_rnd.sample(rng);
        let b2 = b_rnd.sample(rng);

        // Calculate sum of vector elements for n values
        let sum1: f32 = (0..dim).map(|i| v1.vector().get(i).unwrap() as f32).sum();
        let sum2: f32 = (0..dim).map(|i| v2.vector().get(i).unwrap() as f32).sum();

        // Create original full-precision vectors for reference calculations
        let mut original1 = Vec::with_capacity(dim);
        let mut original2 = Vec::with_capacity(dim);

        // Calculate the reconstructed original vectors and their norms
        for i in 0..dim {
            let val1 = a1 * v1.vector().get(i).unwrap() as f32 + b1;
            let val2 = a2 * v2.vector().get(i).unwrap() as f32 + b2;
            original1.push(val1);
            original2.push(val2);
        }

        // Calculate squared norms
        let norm1_squared: f32 = original1.iter().map(|x| x * x).sum();
        let norm2_squared: f32 = original2.iter().map(|x| x * x).sum();

        // Set compensation coefficients
        v1.set_meta(MinMaxCompensation {
            a: a1,
            b: b1,
            n: a1 * sum1,
            norm_squared: norm1_squared,
            dim: dim as u32,
        });

        v2.set_meta(MinMaxCompensation {
            a: a2,
            b: b2,
            n: a2 * sum2,
            norm_squared: norm2_squared,
            dim: dim as u32,
        });

        // Reference inner product on the reconstructed full-precision vectors
        let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();

        // Test inner product with f32
        let computed_ip_f32: distances::Result<f32> =
            MinMaxIP::evaluate(v1.reborrow(), v2.reborrow());
        let computed_ip_f32 = computed_ip_f32.unwrap();
        assert!(
            (expected_ip - (-computed_ip_f32)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            computed_ip_f32,
            dim
        );

        // Reference squared L2 distance, computed directly as sum_i (X[i] - Y[i])^2
        let expected_l2 = (0..dim)
            .map(|i| original1[i] - original2[i])
            .map(|x| x.powf(2.0))
            .sum::<f32>();

        // Test L2 distance with f32
        let computed_l2_f32: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v1.reborrow(), v2.reborrow());
        let computed_l2_f32 = computed_l2_f32.unwrap();
        assert!(
            ((computed_l2_f32 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            computed_l2_f32,
            dim
        );

        let expected_cosine = 1.0 - expected_ip / (norm1_squared.sqrt() * norm2_squared.sqrt());

        let computed_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(v1.reborrow(), v2.reborrow());
        let computed_cosine = computed_cosine.unwrap();

        {
            // Absolute check first: the cosine distance can be ~0 (e.g. for dim == 1),
            // where a relative check is meaningless.
            let diff = (computed_cosine - expected_cosine).abs();
            let passed = diff < 1e-6 || (diff / expected_cosine.abs()) < 1e-3;

            assert!(
                passed,
                "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
                expected_cosine, computed_cosine, dim
            );
        }

        let cosine_normalized: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(v1.reborrow(), v2.reborrow());
        let cosine_normalized = cosine_normalized.unwrap();
        let expected_cos_normalized = 1.0 - expected_ip;
        assert!(
            ((expected_cos_normalized - cosine_normalized).abs() / expected_cos_normalized.abs())
                < 1e-6,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            cosine_normalized,
            dim
        );

        // Calculate inner product with full precision vector
        let mut fp_query = FullQuery::empty(dim);
        let fp_meta = FullQueryMeta {
            norm_squared: norm1_squared,
            sum: original1.iter().sum::<f32>(),
        };
        fp_query.data = original1.clone().into_boxed_slice();
        fp_query.meta = fp_meta;

        let fp_ip: distances::Result<f32> = MinMaxIP::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_ip = fp_ip.unwrap();
        assert!(
            (expected_ip - (-fp_ip)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            fp_ip,
            dim
        );

        let fp_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_l2 = fp_l2.unwrap();
        assert!(
            ((fp_l2 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            fp_l2,
            dim
        );

        let fp_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cosine = fp_cosine.unwrap();
        let diff = (fp_cosine - expected_cosine).abs();
        assert!(
            (diff / expected_cosine.abs()) < 1e-3 || diff <= 1e-6,
            "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cosine,
            fp_cosine,
            dim
        );

        let fp_cos_norm: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cos_norm = fp_cos_norm.unwrap();
        // All reconstructed values are non-negative (a, b >= 0 and unsigned codes), so
        // `1 - <X, Y>` is frequently negative. Divide by its magnitude; dividing by the signed
        // value would make the relative check pass vacuously.
        let diff = (expected_cos_normalized - fp_cos_norm).abs();
        assert!(
            diff <= 1e-6 || (diff / expected_cos_normalized.abs()) < 1e-3,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            fp_cos_norm,
            dim
        );

        // Test `decompress_into` to make sure it outputs the full-precision vector correctly.
        let meta = v1.meta();
        let v1_ref = DataRef::new(v1.vector(), &meta);
        let dim = v1_ref.len();
        let mut boxed = vec![0f32; dim + 1];

        let pre = v1_ref.decompress_into(&mut boxed);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim + 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim - 1]);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim - 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim]);
        assert!(pre.is_ok());

        boxed
            .iter()
            .zip(original1.iter())
            .for_each(|(x, y)| assert!((*x - *y).abs() <= 1e-6));

        // Verify `read_dimension` is correct.
        let mut bytes = vec![0u8; Data::canonical_bytes(dim)];
        let mut data = DataMutRef::from_canonical_front_mut(bytes.as_mut_slice(), dim).unwrap();
        data.set_meta(meta);

        let pre = MinMaxCompensation::read_dimension(&bytes);
        assert!(pre.is_ok());
        let read_dim = pre.unwrap();
        assert_eq!(read_dim, dim);

        let pre = MinMaxCompensation::read_dimension(&[0_u8; 2]);
        assert_eq!(pre.unwrap_err(), MetaParseError::NotCanonical(2));
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Run fewer random trials per dimension under Miri to keep runtime manageable.
            // The integer inner products are delegated to `BitSlice` (via `InnerProduct`),
            // which is presumably covered more heavily by its own tests.
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    // NOTE(review): `dim` ranges over `1..bit_scale::<NBITS>()`. If `bit_scale::<1>()` is 1
    // (i.e. 2^1 - 1), this range is empty and the u1 test exercises nothing. Confirm.
    macro_rules! test_minmax_compensated {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 1..(bit_scale::<$nbits>() as usize) {
                    for _ in 0..TRIALS {
                        test_minmax_compensated_vectors::<$nbits, _>(dim, &mut rng);
                    }
                }
            }
        };
    }
    test_minmax_compensated!(unsigned_minmax_compensated_test_u1, 1, 0xa32d5658097a1c35);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);
}