Skip to main content

diskann_quantization/minmax/
vectors.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_vector::{MathematicalValue, PureDistanceFunction};
7use thiserror::Error;
8
9use crate::{
10    alloc::GlobalAllocator,
11    bits::{BitSlice, Dense, Representation, Unsigned},
12    distances,
13    distances::{InnerProduct, MV},
14    meta::{self, slice},
15};
16
17/// A per-vector precomputed coefficients to help compute inner products
18/// and squared L2 distances for the MinMax quantized vectors.
19///
20/// The inner product between `X = ax * X' + bx` and `Y = ay * Y' + by` for d-dimensional
21/// vectors X and Y is:
22/// ```math
23/// <X, Y> = <ax * X' + bx, ay * Y' + by>
24///        = ax * ay * <X', Y'> + ax * <X', by> + ay * <Y', bx> + d * bx * by.
25/// ```
26/// Let us define a grouping of these terms to make it easier to understand:
27/// ```math
28///  Nx = ax * sum_i X'[i],     Ny = ay * sum_i Y'[i],
29/// ```
30/// We can then simplify the dot product calculation as follows:
31/// ```math
32/// <X, Y> = ax * ay * <X', Y'> + Nx * by + Ny * bx +  d * bx * by
33///                    --------
34///                       |
35///               Integer Dot Product
36/// ```
37///
38/// To compute the squared L2 distance,
39/// ```math
40/// |X - Y|^2 = |ax * X' + bx|^2 + |ay * Y' + by|^2 - 2 * <X, Y>
41/// ```
42/// we can re-use the computation for inner-product from above.
43#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
44#[repr(C)]
45pub struct MinMaxCompensation {
46    pub dim: u32,          // - dimension
47    pub b: f32,            // - bx
48    pub n: f32,            // - Nx
49    pub a: f32,            // - ax
50    pub norm_squared: f32, // - |ax * X' + bx|^2
51}
52
53const META_BYTES: usize = std::mem::size_of::<MinMaxCompensation>(); // This will be 5 * 4 = 20 bytes.
54
55/// Error type for parsing a slice of bytes as a `DataRef`
56/// and returning corresponding dimension.
57#[derive(Debug, Error, Clone, PartialEq, Eq)]
58pub enum MetaParseError {
59    #[error("Invalid size: {0}, must contain at least {META_BYTES} bytes")]
60    NotCanonical(usize),
61}
62
63impl MinMaxCompensation {
64    /// Reads the dimension from the first 4 bytes of a MinMax quantized vector's metadata.
65    ///
66    /// This function is used to extract the vector dimension from serialized MinMax quantized
67    /// vector data without fully deserializing the entire vector structure.
68    ///
69    /// # Arguments
70    /// * `bytes` - A byte slice containing the serialized MinMax vector data
71    ///
72    /// # Returns
73    /// * `Ok(dimension)` - The dimension of the vector as a `usize`
74    /// * `Err(MetaParseError::NotCanonical(size))` - If the byte slice is shorter than 20 bytes (META_BYTES)
75    ///
76    /// # Usage
77    /// Use this when you need to determine the vector dimension from serialized data before
78    /// creating a `DataRef` or allocating appropriately sized buffers for decompression.
79    #[inline(always)]
80    pub fn read_dimension(bytes: &[u8]) -> Result<usize, MetaParseError> {
81        if bytes.len() < META_BYTES {
82            return Err(MetaParseError::NotCanonical(bytes.len()));
83        }
84
85        // SAFETY: There are at least `META_BYTES` = 20 bytes in the array so this access is within bounds.
86        let dim_bytes: [u8; 4] = bytes.get(..4).map_or_else(
87            || Err(MetaParseError::NotCanonical(bytes.len())),
88            |slice| {
89                slice
90                    .try_into()
91                    .map_err(|_| MetaParseError::NotCanonical(bytes.len()))
92            },
93        )?;
94
95        let dim = u32::from_le_bytes(dim_bytes) as usize;
96
97        Ok(dim)
98    }
99}
100
101/// An owning compressed data vector
102///
103/// See: [`meta::Vector`].
104pub type Data<const NBITS: usize> = meta::Vector<NBITS, Unsigned, MinMaxCompensation, Dense>;
105
106/// A borrowed `Data` vector
107///
108/// See: [`meta::Vector`].
109pub type DataRef<'a, const NBITS: usize> =
110    meta::VectorRef<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
111
112#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
113pub enum DecompressError {
114    #[error("expected src and dst length to be identical, instead src is {0}, and dst is {1}")]
115    LengthMismatch(usize, usize),
116}
117impl<const NBITS: usize> DataRef<'_, NBITS>
118where
119    Unsigned: Representation<NBITS>,
120{
121    /// Decompresses a MinMax quantized vector back into its original floating-point representation.
122    ///
123    /// This method reconstructs the original vector values using the stored quantization parameters
124    /// and the MinMax dequantization formula: `x = x' * a + b` and stores the result in `dst`
125    ///
126    /// # Arguments
127    ///
128    /// * `dst` - A mutable slice of `f32` values where the decompressed data will be written.
129    ///   Must have the same length as the compressed vector.
130    ///
131    /// # Returns
132    ///
133    /// * `Ok(())` - On successful decompression
134    /// * `Err(DecompressError::LengthMismatch(src_len, dst_len))` - If the destination slice
135    ///   length doesn't match the compressed vector length
136    pub fn decompress_into(&self, dst: &mut [f32]) -> Result<(), DecompressError> {
137        if dst.len() != self.len() {
138            return Err(DecompressError::LengthMismatch(self.len(), dst.len()));
139        }
140        let meta = self.meta();
141
142        // SAFETY: We checked that the length of the underlying vector is the same as
143        // as `dst` so we are guaranteed to be within bounds when accessing the vector.
144        dst.iter_mut().enumerate().for_each(|(i, d)| unsafe {
145            *d = self.vector().get_unchecked(i) as f32 * meta.a + meta.b
146        });
147        Ok(())
148    }
149}
150
151/// A mutable borrowed `Data` vector
152///
153/// See: [`meta::Vector`].
154pub type DataMutRef<'a, const NBITS: usize> =
155    meta::VectorMut<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
156
157////////////////////
158// Full Precision //
159////////////////////
160
161/// A meta struct storing the `sum` and `norm_squared` of a
162/// full query after transformation is applied to it.
163///
164/// The inner product between `X = ax * X' + bx` and `Y` for d-dimensional
165/// vectors X and Y is:
166/// ```math
167/// <X, Y> = <ax * X' + bx, Y>
168///        = ax * <X', Y> + bx * sum(Y).
169///               --------
170///                  |
171///          Integer-Float Dot Product
172/// ```
173///
174/// To compute the squared L2 distance,
175/// ```math
176/// |X - Y|^2 = |ax * X' + bx|^2 + |Y|^2 - 2 * <X', Y>
177/// ```
178#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
179#[repr(C)]
180pub struct FullQueryMeta {
181    /// The sum of `data`.
182    pub sum: f32,
183    /// The norm of the 'data'.
184    pub norm_squared: f32,
185}
186
187/// A full precision query.
188///
189/// See: [`slice::Slice`].
190pub type FullQuery<A = GlobalAllocator> = slice::PolySlice<f32, FullQueryMeta, A>;
191
192/// A borrowed full precision query.
193///
194/// See: [`slice::SliceRef`].
195pub type FullQueryRef<'a> = slice::SliceRef<'a, f32, FullQueryMeta>;
196
197/// A mutable borrowed full precision query.
198///
199/// See: [`slice::SliceMut`].
200pub type FullQueryMut<'a> = slice::SliceMut<'a, f32, FullQueryMeta>;
201
202///////////////////////////
203// Compensated Distances //
204///////////////////////////
205#[inline(always)]
206fn kernel<const NBITS: usize, const MBITS: usize, F>(
207    x: DataRef<'_, NBITS>,
208    y: DataRef<'_, MBITS>,
209    f: F,
210) -> distances::MathematicalResult<f32>
211where
212    Unsigned: Representation<NBITS> + Representation<MBITS>,
213    InnerProduct: for<'a, 'b> PureDistanceFunction<
214            BitSlice<'a, NBITS, Unsigned>,
215            BitSlice<'b, MBITS, Unsigned>,
216            distances::MathematicalResult<u32>,
217        >,
218    F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
219{
220    let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?;
221    let (xm, ym) = (x.meta(), y.meta());
222    let term0 = xm.a * ym.a * raw_product.into_inner() as f32;
223    let term1_x = xm.n * ym.b;
224    let term1_y = ym.n * xm.b;
225    let term2 = xm.b * ym.b * (x.len() as f32);
226
227    let v = term0 + term1_x + term1_y + term2;
228    Ok(MV::new(f(v, &xm, &ym)))
229}
230
231pub struct MinMaxIP;
232
233impl<const NBITS: usize, const MBITS: usize>
234    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
235    for MinMaxIP
236where
237    Unsigned: Representation<NBITS> + Representation<MBITS>,
238    InnerProduct: for<'a, 'b> PureDistanceFunction<
239            BitSlice<'a, NBITS, Unsigned>,
240            BitSlice<'b, MBITS, Unsigned>,
241            distances::MathematicalResult<u32>,
242        >,
243{
244    #[inline(always)]
245    fn evaluate(
246        x: DataRef<'_, NBITS>,
247        y: DataRef<'_, MBITS>,
248    ) -> distances::MathematicalResult<f32> {
249        kernel(x, y, |v, _, _| v)
250    }
251}
252
253impl<const NBITS: usize, const MBITS: usize>
254    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
255    for MinMaxIP
256where
257    Unsigned: Representation<NBITS> + Representation<MBITS>,
258    InnerProduct: for<'a, 'b> PureDistanceFunction<
259            BitSlice<'a, NBITS, Unsigned>,
260            BitSlice<'b, MBITS, Unsigned>,
261            distances::MathematicalResult<u32>,
262        >,
263{
264    #[inline(always)]
265    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
266        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
267        Ok(-v?.into_inner())
268    }
269}
270
271impl<const NBITS: usize>
272    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
273    for MinMaxIP
274where
275    Unsigned: Representation<NBITS>,
276    InnerProduct: for<'a, 'b> PureDistanceFunction<
277            &'a [f32],
278            BitSlice<'b, NBITS, Unsigned>,
279            distances::MathematicalResult<f32>,
280        >,
281{
282    #[inline(always)]
283    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
284        let raw_product: f32 = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
285        Ok(MathematicalValue::new(
286            raw_product * y.meta().a + x.meta().sum * y.meta().b,
287        ))
288    }
289}
290
291impl<const NBITS: usize>
292    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxIP
293where
294    Unsigned: Representation<NBITS>,
295    InnerProduct: for<'a, 'b> PureDistanceFunction<
296            &'a [f32],
297            BitSlice<'b, NBITS, Unsigned>,
298            distances::MathematicalResult<f32>,
299        >,
300{
301    #[inline(always)]
302    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
303        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
304        Ok(-v?.into_inner())
305    }
306}
307
308pub struct MinMaxL2Squared;
309
310impl<const NBITS: usize, const MBITS: usize>
311    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
312    for MinMaxL2Squared
313where
314    Unsigned: Representation<NBITS> + Representation<MBITS>,
315    InnerProduct: for<'a, 'b> PureDistanceFunction<
316            BitSlice<'a, NBITS, Unsigned>,
317            BitSlice<'b, MBITS, Unsigned>,
318            distances::MathematicalResult<u32>,
319        >,
320{
321    #[inline(always)]
322    fn evaluate(
323        x: DataRef<'_, NBITS>,
324        y: DataRef<'_, MBITS>,
325    ) -> distances::MathematicalResult<f32> {
326        kernel(x, y, |v, xm, ym| {
327            -2.0 * v + xm.norm_squared + ym.norm_squared
328        })
329    }
330}
331
332impl<const NBITS: usize, const MBITS: usize>
333    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
334    for MinMaxL2Squared
335where
336    Unsigned: Representation<NBITS> + Representation<MBITS>,
337    InnerProduct: for<'a, 'b> PureDistanceFunction<
338            BitSlice<'a, NBITS, Unsigned>,
339            BitSlice<'b, MBITS, Unsigned>,
340            distances::MathematicalResult<u32>,
341        >,
342{
343    #[inline(always)]
344    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
345        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
346        Ok(v?.into_inner())
347    }
348}
349
350impl<const NBITS: usize>
351    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
352    for MinMaxL2Squared
353where
354    Unsigned: Representation<NBITS>,
355    InnerProduct: for<'a, 'b> PureDistanceFunction<
356            &'a [f32],
357            BitSlice<'b, NBITS, Unsigned>,
358            distances::MathematicalResult<f32>,
359        >,
360{
361    #[inline(always)]
362    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
363        let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
364
365        let ym = y.meta();
366        let compensated_ip = raw_product * ym.a + x.meta().sum * ym.b;
367        Ok(MV::new(
368            x.meta().norm_squared + ym.norm_squared - 2.0 * compensated_ip,
369        ))
370    }
371}
372
373impl<const NBITS: usize>
374    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
375    for MinMaxL2Squared
376where
377    Unsigned: Representation<NBITS>,
378    InnerProduct: for<'a, 'b> PureDistanceFunction<
379            &'a [f32],
380            BitSlice<'b, NBITS, Unsigned>,
381            distances::MathematicalResult<f32>,
382        >,
383{
384    #[inline(always)]
385    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
386        let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
387        Ok(v?.into_inner())
388    }
389}
390
391///////////////////////
392// Cosine Distances //
393///////////////////////
394
395pub struct MinMaxCosine;
396
397impl<const NBITS: usize, const MBITS: usize>
398    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
399    for MinMaxCosine
400where
401    Unsigned: Representation<NBITS> + Representation<MBITS>,
402    MinMaxIP: for<'a, 'b> PureDistanceFunction<
403            DataRef<'a, NBITS>,
404            DataRef<'b, MBITS>,
405            distances::MathematicalResult<f32>,
406        >,
407{
408    // 1 - <X, Y> / (|X| * |Y|)
409    #[inline(always)]
410    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
411        let ip: MV<f32> = MinMaxIP::evaluate(x, y)?;
412        let (xm, ym) = (x.meta(), y.meta());
413        Ok(1.0 - ip.into_inner() / (xm.norm_squared.sqrt() * ym.norm_squared.sqrt()))
414    }
415}
416
417impl<const NBITS: usize>
418    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
419    for MinMaxCosine
420where
421    Unsigned: Representation<NBITS>,
422    MinMaxIP: for<'a, 'b> PureDistanceFunction<
423            FullQueryRef<'a>,
424            DataRef<'b, NBITS>,
425            distances::MathematicalResult<f32>,
426        >,
427{
428    #[inline(always)]
429    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
430        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
431        let (xm, ym) = (x.meta().norm_squared, y.meta());
432        Ok(1.0 - ip.into_inner() / (xm.sqrt() * ym.norm_squared.sqrt()))
433        // 1 - <X, Y> / (|X| * |Y|)
434    }
435}
436
437pub struct MinMaxCosineNormalized;
438
439impl<const NBITS: usize, const MBITS: usize>
440    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
441    for MinMaxCosineNormalized
442where
443    Unsigned: Representation<NBITS> + Representation<MBITS>,
444    MinMaxIP: for<'a, 'b> PureDistanceFunction<
445            DataRef<'a, NBITS>,
446            DataRef<'b, MBITS>,
447            distances::MathematicalResult<f32>,
448        >,
449{
450    #[inline(always)]
451    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
452        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
453        Ok(1.0 - ip.into_inner()) // 1 - <X, Y>
454    }
455}
456
457impl<const NBITS: usize>
458    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
459    for MinMaxCosineNormalized
460where
461    Unsigned: Representation<NBITS>,
462    MinMaxIP: for<'a, 'b> PureDistanceFunction<
463            FullQueryRef<'a>,
464            DataRef<'b, NBITS>,
465            distances::MathematicalResult<f32>,
466        >,
467{
468    #[inline(always)]
469    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
470        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
471        Ok(1.0 - ip.into_inner()) // 1 - <X, Y>
472    }
473}
474
475///////////
476// Tests //
477///////////
478
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_vector_tests {
    use diskann_utils::Reborrow;
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{alloc::GlobalAllocator, scalar::bit_scale};

    /// Builds a random MinMax quantized vector and its full-precision reconstruction.
    ///
    /// Returns `(compressed, original)` where `compressed` has its `MinMaxCompensation`
    /// metadata fully populated and `original` is the dequantized f32 vector.
    fn random_minmax_vector<const NBITS: usize>(
        dim: usize,
        rng: &mut impl Rng,
    ) -> (Data<NBITS>, Vec<f32>)
    where
        Unsigned: Representation<NBITS>,
    {
        let mut v = Data::<NBITS>::new_boxed(dim);

        let domain = Unsigned::domain_const::<NBITS>();
        let code_dist = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        {
            let mut bs = v.vector_mut();
            for i in 0..dim {
                bs.set(i, code_dist.sample(rng)).unwrap();
            }
        }

        let a: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
        let b: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);

        let original: Vec<f32> = (0..dim)
            .map(|i| a * v.vector().get(i).unwrap() as f32 + b)
            .collect();

        let code_sum: f32 = (0..dim).map(|i| v.vector().get(i).unwrap() as f32).sum();
        let norm_squared: f32 = original.iter().map(|x| x * x).sum();

        v.set_meta(MinMaxCompensation {
            a,
            b,
            n: a * code_sum,
            norm_squared,
            dim: dim as u32,
        });

        (v, original)
    }

    fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                &'a [f32],
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<f32>,
            >,
        R: Rng,
    {
        assert!(dim <= bit_scale::<NBITS>() as usize);

        let (v1, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v2, original2) = random_minmax_vector::<NBITS>(dim, rng);

        let norm1_squared = v1.meta().norm_squared;
        let norm2_squared = v2.meta().norm_squared;

        // Reference inner product, computed in full precision from the dequantized vectors.
        let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();

        // Test inner product with f32
        let computed_ip_f32: distances::Result<f32> =
            MinMaxIP::evaluate(v1.reborrow(), v2.reborrow());
        let computed_ip_f32 = computed_ip_f32.unwrap();
        assert!(
            (expected_ip - (-computed_ip_f32)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            computed_ip_f32,
            dim
        );

        // Reference squared L2 distance, computed directly as sum_i (X[i] - Y[i])^2.
        let expected_l2 = (0..dim)
            .map(|i| original1[i] - original2[i])
            .map(|x| x.powf(2.0))
            .sum::<f32>();

        // Test L2 distance with f32
        let computed_l2_f32: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v1.reborrow(), v2.reborrow());
        let computed_l2_f32 = computed_l2_f32.unwrap();
        assert!(
            ((computed_l2_f32 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            computed_l2_f32,
            dim
        );

        let expected_cosine = 1.0 - expected_ip / (norm1_squared.sqrt() * norm2_squared.sqrt());

        let computed_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(v1.reborrow(), v2.reborrow());
        let computed_cosine = computed_cosine.unwrap();

        {
            let passed = (computed_cosine - expected_cosine).abs() < 1e-6
                || ((computed_cosine - expected_cosine).abs() / expected_cosine) < 1e-3;

            assert!(
                passed,
                "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
                expected_cosine, computed_cosine, dim
            );
        }

        let cosine_normalized: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(v1.reborrow(), v2.reborrow());
        let cosine_normalized = cosine_normalized.unwrap();
        let expected_cos_normalized = 1.0 - expected_ip;
        assert!(
            ((expected_cos_normalized - cosine_normalized).abs() / expected_cos_normalized.abs())
                < 1e-6,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            cosine_normalized,
            dim
        );

        // Repeat the checks with a full-precision query built from `original1`.
        let mut fp_query = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        fp_query.vector_mut().copy_from_slice(&original1);
        *fp_query.meta_mut() = FullQueryMeta {
            norm_squared: norm1_squared,
            sum: original1.iter().sum::<f32>(),
        };

        let fp_ip: distances::Result<f32> = MinMaxIP::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_ip = fp_ip.unwrap();
        assert!(
            (expected_ip - (-fp_ip)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (full query) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            fp_ip,
            dim
        );

        let fp_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_l2 = fp_l2.unwrap();
        assert!(
            ((fp_l2 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (full query) failed: expected {}, got {} on dim : {}",
            expected_l2,
            fp_l2,
            dim
        );

        let fp_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cosine = fp_cosine.unwrap();
        let diff = (fp_cosine - expected_cosine).abs();
        assert!(
            (diff / expected_cosine) < 1e-3 || diff <= 1e-6,
            "Cosine distance (full query) failed: expected {}, got {} on dim : {}",
            expected_cosine,
            fp_cosine,
            dim
        );

        let fp_cos_norm: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cos_norm = fp_cos_norm.unwrap();
        assert!(
            (((1.0 - expected_ip) - fp_cos_norm).abs() / (1.0 - expected_ip)) < 1e-3,
            "CosineNormalized distance (full query) failed: expected {}, got {} on dim : {}",
            (1.0 - expected_ip),
            fp_cos_norm,
            dim
        );

        // Test `decompress_into` to make sure it outputs the full-precision vector correctly.
        let meta = v1.meta();
        let v1_ref = DataRef::new(v1.vector(), &meta);
        let dim = v1_ref.len();
        let mut boxed = vec![0f32; dim + 1];

        let pre = v1_ref.decompress_into(&mut boxed);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim + 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim - 1]);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim - 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim]);
        assert!(pre.is_ok());

        boxed
            .iter()
            .zip(original1.iter())
            .for_each(|(x, y)| assert!((*x - *y).abs() <= 1e-6));

        // Verify `read_dimension` is correct.
        let mut bytes = vec![0u8; Data::canonical_bytes(dim)];
        let mut data = DataMutRef::from_canonical_front_mut(bytes.as_mut_slice(), dim).unwrap();
        data.set_meta(meta);

        let pre = MinMaxCompensation::read_dimension(&bytes);
        assert!(pre.is_ok());
        let read_dim = pre.unwrap();
        assert_eq!(read_dim, dim);

        let pre = MinMaxCompensation::read_dimension(&[0_u8; 2]);
        assert_eq!(pre.unwrap_err(), MetaParseError::NotCanonical(2));
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri is slow, so run only a couple of randomized trials per dimension.
            // NOTE: this module is currently compiled only under `not(miri)`, so this
            // branch is inactive until that gate is lifted.
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    macro_rules! test_minmax_compensated {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = (bit_scale::<$nbits>() as usize);
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_compensated_vectors::<$nbits, _>(dim, &mut rng);
                    }
                }
            }
        };
    }
    test_minmax_compensated!(unsigned_minmax_compensated_test_u1, 1, 0xa33d5658097a1c35);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);

    /// Test the heterogeneous MinMax distances for N-bit queries × M-bit database vectors.
    ///
    /// Verifies two paths against the full-precision reference, for random codes and random
    /// compensation coefficients:
    /// * `kernel::<N, M, _>` called directly for the inner product;
    /// * [`MinMaxL2Squared`] for the squared L2 distance.
    fn test_minmax_heterogeneous_kernel<const NBITS: usize, const MBITS: usize, R>(
        dim: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, MBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        R: Rng,
    {
        let (v_query, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v_data, original2) = random_minmax_vector::<MBITS>(dim, rng);

        // ── Inner Product ──
        let expected_ip: f32 = original1.iter().zip(&original2).map(|(x, y)| x * y).sum();
        let computed_ip = kernel(v_query.reborrow(), v_data.reborrow(), |v, _, _| v)
            .unwrap()
            .into_inner();
        assert!(
            (expected_ip - computed_ip).abs() / expected_ip.abs().max(1e-10) < 1e-6,
            "Heterogeneous IP ({},{}) failed: expected {}, got {} on dim: {}",
            NBITS,
            MBITS,
            expected_ip,
            computed_ip,
            dim,
        );

        // ── Squared L2 ──
        let expected_l2: f32 = original1
            .iter()
            .zip(&original2)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();
        let computed_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v_query.reborrow(), v_data.reborrow());
        let computed_l2 = computed_l2.unwrap();
        // The compensated form `|X|^2 + |Y|^2 - 2<X, Y>` cancels large terms. So bound the
        // error relative to the norms rather than to the (possibly tiny) distance itself.
        let scale = (v_query.meta().norm_squared + v_data.meta().norm_squared).max(1.0);
        assert!(
            (expected_l2 - computed_l2).abs() <= 1e-4 * scale,
            "Heterogeneous L2 ({},{}) failed: expected {}, got {} on dim: {}",
            NBITS,
            MBITS,
            expected_l2,
            computed_l2,
            dim,
        );
    }

    macro_rules! test_minmax_heterogeneous {
        ($name:ident, $N:literal, $M:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                // Use the smaller bit width's scale as max dimension.
                const MAX_DIM: usize = bit_scale::<$M>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_heterogeneous_kernel::<$N, $M, _>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_minmax_heterogeneous!(minmax_heterogeneous_8x4, 8, 4, 0xb7c3d9e5f1a20864);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x2, 8, 2, 0x4e8f2c6a1d3b5079);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x1, 8, 1, 0x1b0f2c614d2a7141);
}