Skip to main content

diskann_quantization/minmax/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use super::vectors::{DataMutRef, FullQueryMut, MinMaxCompensation, MinMaxIP, MinMaxL2Squared};
7use core::f32;
8
9use crate::{
10    AsFunctor, CompressInto,
11    algorithms::Transform,
12    alloc::{GlobalAllocator, ScopedAllocator},
13    bits::{Representation, Unsigned},
14    minmax::{MinMaxCosine, MinMaxCosineNormalized, vectors::FullQueryMeta},
15    num::Positive,
16    scalar::{InputContainsNaN, bit_scale},
17};
18
19/// Recall that from the module-level documentation, MinMaxQuantizer, quantizes X
20/// into `n` bit vectors as follows  -
21/// ```math
22/// X' = round((X - s) * (2^n - 1) / c).clamp(0, 2^n - 1))
23/// ```
24/// where `s` is a shift value and `c` is a scaling parameter computed from the range of values.
25///
26/// For most bit widths (>1), given a positive scaling parameter `grid_scale : f32`,
27/// these are computed as:
28/// ```math
29/// - m = (max_i X[i] + min_i X[i]) / 2.0
30/// - w = max_i X[i] - min_i X[i]
31///
32/// - s = m - w * grid_scale
33/// - c = 2 * w * grid_scale
34///
35/// where `grid_scale` is an input to the quantizer.
36/// ```
37/// For 1-bit quantization, to avoid outliers, `s` and `c` are derived differently:
38/// - Values are first split into two groups: those below and above the mean.
39/// - `s` is the average of values below the mean.
40/// - `c` is the difference between the average of values above the mean and `s`.
41///
42/// See [`MinMaxCompensation`] for notation.
43/// We have then that
44/// ```math
45/// X = X' * (c / (2^n - 1)) + s
46///          --------------    -
47///                 |          |
48///                ax          bx
49/// ```
50pub struct MinMaxQuantizer {
51    /// Support for different strategies of pre-transforming vectors before applying compression.
52    /// See [`Transform`] for more details on supported types. The input dimension of vectors
53    /// to the quantizer is derived from `transform.input_dim()`.
54    transform: Transform<GlobalAllocator>,
55
56    /// Scaling parameter used to scale the range (min, max) in order to avoid outliers.
57    /// The input must be a positive value. In general, any value between [0.8, 1] does well.
58    grid_scale: Positive<f32>,
59}
60
61impl MinMaxQuantizer {
62    /// Instantiates a new quantizer with specific transform.
63    pub fn new(transform: Transform<GlobalAllocator>, grid_scale: Positive<f32>) -> Self {
64        Self {
65            transform,
66            grid_scale,
67        }
68    }
69
70    /// Input dimension of vectors to quantizer.
71    pub fn dim(&self) -> usize {
72        self.transform.input_dim()
73    }
74
75    /// Output dimension of vectors after applying transform.
76    ///
77    /// Output storage vectors should use this dimension instead of `self.dim()` because
78    /// in general, the output dim **may** be different from the input dimension.
79    pub fn output_dim(&self) -> usize {
80        self.transform.output_dim()
81    }
82
83    /// Outputs the minimum and maximum value of the range of values
84    /// for an input vector `vec`. The function cases based on the
85    /// intended number of bits `NBITS` per dimension.
86    ///
87    /// * `1-bit` - In order to avoid outlier values, the range
88    ///   is defined by taking the values larger and smaller than
89    ///   the numeric mean, and then taking the respective means of
90    ///   each of these sets as the `max` and `min`.
91    ///
92    /// * `N-bits` - Computes the `min` and `max` of the vector values.
93    ///
94    /// # Returns
95    ///
96    /// * `(m - w * g, m + w * g)` - the lower and upper end of the range, where,
97    ///   `m = (max + min) / 2.0`, `w = (max - min) / 2.0`, and `g = self.grid_scale`.
98    fn get_range<const NBITS: usize>(&self, vec: &[f32]) -> (f32, f32) {
99        let (min, max) = match NBITS {
100            1 => {
101                let (mut min, mut min_count) = (0.0f32, 0.0f32);
102                let (mut max, mut max_count) = (0.0f32, 0.0f32);
103
104                let mean = vec.iter().sum::<f32>() / (vec.len() as f32);
105
106                vec.iter().for_each(|x| {
107                    let m = f32::from((*x < mean) as u8);
108                    min += m * x;
109                    min_count += m;
110                    max += (1.0 - m) * x;
111                    max_count += 1.0 - m;
112                });
113
114                ((min / min_count).min(mean), (max / max_count).max(mean))
115            }
116            _ => {
117                vec // Using `f32::NAN` since [`core::f32::min`] and `max` output the other value if one of them is NAN .
118                    .iter()
119                    .fold((f32::NAN, f32::NAN), |(cmin, cmax), &e| {
120                        (cmin.min(e), cmax.max(e))
121                    })
122            }
123        };
124
125        let width = (max - min) / 2.0;
126        let mid = min + width;
127
128        (
129            mid - width * self.grid_scale.into_inner(),
130            mid + width * self.grid_scale.into_inner(),
131        )
132    }
133
134    fn compress<const NBITS: usize, T>(
135        &self,
136        from: &[T],
137        mut into: DataMutRef<'_, NBITS>,
138    ) -> Result<L2Loss, InputContainsNaN>
139    where
140        T: Copy + Into<f32>,
141        Unsigned: Representation<NBITS>,
142    {
143        let mut into_vec = into.vector_mut();
144
145        assert_eq!(from.len(), self.dim());
146        assert_eq!(self.output_dim(), into_vec.len());
147
148        let domain = Unsigned::domain_const::<NBITS>();
149        let domain_min = *domain.start() as f32;
150        let domain_max = *domain.end() as f32;
151
152        let mut vec = vec![f32::default(); self.output_dim()];
153
154        // We know vec.len() == self.output_dim() and `from.len() == self.dim`
155        #[allow(clippy::unwrap_used)]
156        self.transform
157            .transform_into(
158                &mut vec,
159                &from.iter().map(|&x| x.into()).collect::<Vec<f32>>(),
160                ScopedAllocator::global(),
161            )
162            .unwrap();
163
164        let (min, max) = self.get_range::<NBITS>(&vec);
165
166        let inverse_scale = (max - min).max(1e-8) / bit_scale::<NBITS>(); // To avoid NaN. This is ONLY possible if the vector is all the same value.
167        let mut norm_squared: f32 = 0.0;
168        let mut code_sum: f32 = 0.0;
169        let mut loss: f32 = 0.0;
170
171        let mut nan_check = false;
172
173        vec.iter().enumerate().for_each(|(i, &v)| {
174            nan_check |= v.is_nan();
175
176            let code = ((v - min) / inverse_scale)
177                .clamp(domain_min, domain_max)
178                .round();
179
180            let v_r = (code * inverse_scale) + min; // reconstructed value for `v`.
181            norm_squared += v_r * v_r;
182            code_sum += code;
183            loss += (v_r - v).powi(2);
184
185            //SAFETY: we checked that the lengths of `from` and `into_vec` are the same.
186            unsafe {
187                into_vec.set_unchecked(i, code as u8);
188            }
189        });
190
191        let meta = MinMaxCompensation {
192            dim: self.output_dim() as u32,
193            b: min,
194            a: inverse_scale,
195            n: inverse_scale * code_sum,
196            norm_squared,
197        };
198
199        into.set_meta(meta);
200
201        if nan_check {
202            Err(InputContainsNaN)
203        } else {
204            Ok(match Positive::new(loss) {
205                Ok(p) => L2Loss::Positive(p),
206                Err(_) => L2Loss::Zero,
207            })
208        }
209    }
210}
211
212/////////////////
213// Compression //
214/////////////////
215
216/// A struct defining euclidean loss from quantization.
217///
218/// For an input vector `x` and its representation `x'`,
219/// this is supposed to store `||x - x'||^2`.
220#[derive(Clone, Copy, Debug)]
221pub enum L2Loss {
222    Zero,
223    Positive(Positive<f32>),
224}
225
226impl L2Loss {
227    /// Euclidean loss as a `f32` value
228    pub fn as_f32(&self) -> f32 {
229        match self {
230            L2Loss::Zero => 0.0,
231            L2Loss::Positive(p) => p.into_inner(),
232        }
233    }
234}
235
236impl<const NBITS: usize, T> CompressInto<&[T], DataMutRef<'_, NBITS>> for MinMaxQuantizer
237where
238    T: Copy + Into<f32>,
239    Unsigned: Representation<NBITS>,
240{
241    type Error = InputContainsNaN;
242
243    type Output = L2Loss;
244
245    /// Compress the input vector `from` into a mut ref of Data `to`.
246    ///
247    /// This method computes and stores the compensation coefficients required for computing
248    /// distances correctly.
249    ///
250    /// # Error
251    ///
252    /// Returns an error if the input contains `NaN`.
253    ///
254    /// # Panics
255    ///
256    /// Panics if:
257    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
258    ///   dimensionality as the quantizer.
259    /// * `to.vector().len() != self.output_dim()`: Compressed vector must have the same dimensionality
260    ///   as the quantizer.
261    fn compress_into(&self, from: &[T], to: DataMutRef<'_, NBITS>) -> Result<L2Loss, Self::Error> {
262        self.compress::<NBITS, T>(from, to)
263    }
264}
265
266impl<'a, T> CompressInto<&[T], FullQueryMut<'a>> for MinMaxQuantizer
267where
268    T: Copy + Into<f32>,
269{
270    type Error = InputContainsNaN;
271
272    type Output = ();
273
274    /// Compress the input vector `from` into a [`FullQueryMut`] `to`.
275    ///
276    /// This method simply applies the transformation to the input without
277    /// any compression.
278    ///
279    /// # Error
280    ///
281    /// Returns an error if the input contains `NaN`.
282    ///
283    /// # Panics
284    ///
285    /// Panics if:
286    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
287    ///   dimensionality as the quantizer.
288    /// * `to.len() != self.output_dim()`: Compressed vector must have the same dimensionality
289    ///   as the quantizer.
290    fn compress_into(&self, from: &[T], mut to: FullQueryMut<'a>) -> Result<(), Self::Error> {
291        assert_eq!(from.len(), self.dim());
292        assert_eq!(self.output_dim(), to.len());
293
294        // Transform the input vector and return error if it contains NaN
295        let from: Vec<f32> = from.iter().map(|&x| x.into()).collect();
296        if from.iter().any(|x| x.is_nan()) {
297            return Err(InputContainsNaN);
298        }
299
300        // We know vec.len() == self.output_dim() and `from.len() == self.dim`
301        #[allow(clippy::unwrap_used)]
302        self.transform
303            .transform_into(to.vector_mut(), &from, ScopedAllocator::global())
304            .unwrap();
305
306        let norm_squared = to.vector().iter().map(|x| *x * *x).sum::<f32>();
307        let sum = to.vector().iter().sum::<f32>();
308
309        *to.meta_mut() = FullQueryMeta { norm_squared, sum };
310
311        Ok(())
312    }
313}
314
315///////////////////////
316// Distance Functors //
317///////////////////////
318
319macro_rules! impl_functor {
320    ($dist:ident) => {
321        impl AsFunctor<$dist> for MinMaxQuantizer {
322            // no need to do any work here.
323            fn as_functor(&self) -> $dist {
324                $dist
325            }
326        }
327    };
328}
329
330impl_functor!(MinMaxIP);
331impl_functor!(MinMaxL2Squared);
332impl_functor!(MinMaxCosine);
333impl_functor!(MinMaxCosineNormalized);
334
335///////////
336// Tests //
337///////////
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_quantizer_tests {
    use std::num::NonZeroUsize;

    use diskann_utils::{Reborrow, ReborrowMut};
    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        algorithms::transforms::NullTransform,
        alloc::GlobalAllocator,
        minmax::vectors::{Data, DataRef, FullQuery, FullQueryMut},
    };

    /// Decodes a compressed vector back to `f32` as `code * a + b`.
    fn reconstruct_minmax<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
    where
        Unsigned: Representation<NBITS>,
    {
        let m = v.meta();
        (0..v.len())
            .map(|i| v.vector().get(i).unwrap() as f32 * m.a + m.b)
            .collect()
    }

    /// Compresses a random vector and checks the reconstruction error, the reported
    /// loss, the compensation metadata, and the `FullQuery` path.
    fn test_quantizer_encoding_random<const NBITS: usize>(
        dim: usize,
        rng: &mut StdRng,
        relative_err: f32,
        scale: f32,
    ) where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>
            + for<'a, 'b> CompressInto<&'a [f32], FullQueryMut<'b>, Output = ()>,
    {
        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();

        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(scale).unwrap(),
        );

        assert_eq!(quantizer.dim(), dim);

        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();

        let mut encoded = Data::new_boxed(dim);
        let loss = quantizer
            .compress_into(&*vector, encoded.reborrow_mut())
            .unwrap();

        let reconstructed = reconstruct_minmax::<NBITS>(encoded.reborrow());
        assert_eq!(reconstructed.len(), dim);

        let reconstruction_error: f32 = SquaredL2::evaluate(&*vector, &*reconstructed);
        let norm = vector.iter().map(|x| x * x).sum::<f32>();
        assert!(
            (reconstruction_error / norm) <= relative_err,
            "Expected vector : {:?} to be reconstructed within error {} but instead got : {:?}, with error {} for dim : {}",
            &vector,
            relative_err,
            &reconstructed,
            reconstruction_error / norm,
            dim,
        );

        // `loss` and `reconstruction_error` add up the same squared differences, possibly
        // in a different order, so compare them with a magnitude-relative tolerance in
        // both directions.
        let loss_diff = (loss.as_f32() - reconstruction_error).abs();
        assert!(
            loss_diff <= 1e-4 * reconstruction_error.max(1.0),
            "Reported loss {} differs from reconstruction error {} by {}",
            loss.as_f32(),
            reconstruction_error,
            loss_diff,
        );

        let expected_code_sum = (0..dim)
            .map(|i| encoded.vector().get(i).unwrap() as f32)
            .sum::<f32>();
        let code_sum = encoded.reborrow().meta().n / encoded.reborrow().meta().a;
        assert!(
            (code_sum - expected_code_sum).abs() <= 2e-5 * (dim as f32),
            "Encoded vector with dim : {dim} is {:?}, got error : {} for vector : {:?}",
            encoded.reborrow(),
            (code_sum - expected_code_sum).abs(),
            &vector,
        );
        let recon_norm_sq = reconstructed.iter().map(|x| x * x).sum::<f32>();
        assert!((encoded.reborrow().meta().norm_squared - recon_norm_sq).abs() <= 1e-3);

        // FullQuery
        let mut f = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        quantizer
            .compress_into(vector.as_slice(), f.reborrow_mut())
            .unwrap();

        f.vector()
            .iter()
            .enumerate()
            .zip(vector.iter())
            .for_each(|((i, x), y)| {
                assert!(
                    (*x - *y).abs() < 1e-10,
                    "Full Query did not compress dimension {i} with value {} correctly, got {} instead.",
                    *y,
                    *x,
                )
            });

        assert!(
            (f.meta().norm_squared - norm).abs() < 1e-10,
            "Full Query norm in meta should be {norm} but instead got {}",
            f.meta().norm_squared
        );

        let sum = vector.iter().sum::<f32>();
        assert!(
            (f.meta().sum - sum).abs() < 1e-10,
            "Full Query sum in meta should be {sum} but instead got {}",
            f.meta().sum
        );
    }

    // This whole module is compiled only under `cfg(not(miri))`, so a single trial count
    // is sufficient.
    const TRIALS: usize = 10;

    macro_rules! test_minmax_quantizer_encoding {
        ($name:ident, $dim:literal, $nbits:literal, $seed:literal, $err:expr) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let scales = [1.0, 1.1, 0.9];
                for (s, e) in scales.iter().zip($err) {
                    for d in 10..$dim {
                        for _ in 0..TRIALS {
                            test_quantizer_encoding_random::<$nbits>(d, &mut rng, e, *s);
                        }
                    }
                }
            }
        };
    }
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_1bit,
        100,
        1,
        0xa32d5658097a1c35,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_2bit,
        100,
        2,
        0xf60c0c8d1aadc126,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_4bit,
        100,
        4,
        0x09fa14c42a9d7d98,
        vec![1.0e-2, 1.0e-2, 3.0e-2]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_8bit,
        100,
        8,
        0xaedf3d2a223b7b77,
        vec![2.0e-3, 2.0e-3, 7.0e-3]
    );

    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    /// Tests the edge case where min == max but both are non-zero.
    fn test_all_same_value_vector<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 30;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let constant_value = 42.5f32;
        let vector = vec![constant_value; dim];

        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());

        assert!(
            result.is_ok(),
            "Constant-value vector should compress successfully"
        );

        assert!(result.unwrap().as_f32().abs() <= 1e-6);

        // Reconstruction should yield the original constant value (approximately).
        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for &val in &reconstructed {
            assert!(
                (val - constant_value).abs() < 1e-3,
                "Reconstructed value {} should be close to original {}. Compressed vector is {:?}",
                val,
                constant_value,
                encoded.meta(),
            );
        }
    }

    /// Tests boundary conditions of the quantization logic with a vector whose first half
    /// is one value and whose second half is another. Both values land exactly on the
    /// ends of the quantization range.
    fn test_two_distinct_values<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 20;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        let val1 = -10.0f32;
        let val2 = 15.0f32;
        let mut vector = vec![val1; dim];
        // Make the second half of the vector the second value.
        for i in vector.iter_mut().skip(dim / 2) {
            *i = val2;
        }

        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());

        assert!(
            result.is_ok(),
            "Two-value vector should compress successfully"
        );

        assert!(result.unwrap().as_f32().abs() <= 1e-6);

        // Verify that exactly two distinct codes are used: with `grid_scale == 1`, the two
        // values map to the two ends of the domain for every bit width.
        let mut codes_used = std::collections::HashSet::new();
        for i in 0..dim {
            codes_used.insert(encoded.vector().get(i).unwrap());
        }

        assert_eq!(
            codes_used.len(),
            2,
            "Should use exactly 2 distinct codes for 2-value input, but used: {:?}",
            codes_used
        );

        // Both values sit on the range endpoints, so reconstruction should be exact up to
        // floating-point error.
        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for ((i, val), v) in reconstructed.into_iter().enumerate().zip(&vector) {
            assert!(
                (val - v).abs() < 1e-4,
                "Reconstructed value in dim : {i} is {val}, when it should be {v}."
            );
        }
    }

    /// Verifies that NaN values in the input cause the expected error but
    /// dimension in meta is correctly set.
    fn test_nan_input_error<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 100;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        // Test vector with NaN in the middle.
        let mut vector_nan = vec![1.0f32; dim];
        vector_nan[33] = f32::NAN;
        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector_nan, encoded.reborrow_mut());
        assert!(result.is_err(), "Vector with NaN should cause an error");

        let meta = encoded.meta();
        assert_eq!(meta.dim as usize, dim);
    }

    expand_to_bitrates!(all_same_values_vector, test_all_same_value_vector);
    expand_to_bitrates!(two_distinct_values, test_two_distinct_values);
    expand_to_bitrates!(nan_input_error, test_nan_input_error);

    /// Verifies that providing a vector with wrong dimensionality causes a panic.
    ///
    /// The `where` clause pins the `CompressInto` impl, so that `Data::new_boxed` infers
    /// `NBITS = 8` and `&Vec<f32>` coerces to `&[f32]`.
    #[test]
    #[should_panic(expected = "assertion `left == right` failed\n  left: 15\n right: 10")]
    fn test_dimension_mismatch_panic()
    where
        Unsigned: Representation<8>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, 8>, Output = L2Loss>,
    {
        let expected_dim = 10;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(expected_dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        // Provide a vector with the wrong dimension.
        let wrong_vector = vec![1.0f32; expected_dim + 5]; // Too many dimensions
        let mut encoded = Data::new_boxed(expected_dim);

        // This should panic due to the assertion in `compress_into`.
        let _ = quantizer.compress_into(&wrong_vector, encoded.reborrow_mut());
    }
}