diskann_quantization/minmax/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use super::vectors::{DataMutRef, MinMaxCompensation, MinMaxIP, MinMaxL2Squared};
7use core::f32;
8use diskann_utils::views::MutDenseData;
9
10use crate::{
11    algorithms::Transform,
12    alloc::{GlobalAllocator, ScopedAllocator},
13    bits::{Representation, Unsigned},
14    minmax::{vectors::FullQueryMeta, FullQuery, MinMaxCosine, MinMaxCosineNormalized},
15    num::Positive,
16    scalar::{bit_scale, InputContainsNaN},
17    AsFunctor, CompressInto,
18};
19
/// Quantizes vectors into `n`-bit codes using a per-vector shift and scale.
///
/// Recall from the module-level documentation that `MinMaxQuantizer` quantizes `X`
/// (the vector obtained after applying `transform`) into `n`-bit vectors as follows:
/// ```math
/// X' = round(((X - s) * (2^n - 1) / c).clamp(0, 2^n - 1))
/// ```
/// where `s` is a shift value and `c` is a scaling parameter computed from the range of values.
///
/// For most bit widths (>1), given a positive scaling parameter `grid_scale : f32`
/// (an input to the quantizer), these are computed as:
/// ```math
/// - m = (max_i X[i] + min_i X[i]) / 2.0
/// - w = (max_i X[i] - min_i X[i]) / 2.0
///
/// - s = m - w * grid_scale
/// - c = 2 * w * grid_scale
/// ```
/// When compressing, `c` is floored at `1e-8` so that a vector whose values are all
/// identical does not produce a division by zero.
///
/// For 1-bit quantization, to avoid outliers, `s` and `c` are derived differently:
/// - Values are first split into two groups: those below the mean and the rest.
/// - The average of each group takes the place of `min_i X[i]` and `max_i X[i]`
///   in the formulas above.
/// - With `grid_scale = 1`, `s` is therefore the average of values below the mean and
///   `c` is the difference between the average of the remaining values and `s`.
///
/// See [`MinMaxCompensation`] for notation.
/// We have then that
/// ```math
/// X = X' * (c / (2^n - 1)) + s
///          --------------    -
///                 |          |
///                ax          bx
/// ```
pub struct MinMaxQuantizer {
    /// Support for different strategies of pre-transforming vectors before applying compression.
    /// See [`Transform`] for more details on supported types. The input dimension of vectors
    /// to the quantizer is derived from `transform.input_dim()`.
    transform: Transform<GlobalAllocator>,

    /// Scaling parameter applied to the half-width of the range (min, max) in order to
    /// avoid outliers: values below 1 shrink the range, values above 1 widen it.
    /// The input must be a positive value. In general, any value between [0.8, 1] does well.
    grid_scale: Positive<f32>,
}
61
62impl MinMaxQuantizer {
63    /// Instantiates a new quantizer with specific transform.
64    pub fn new(transform: Transform<GlobalAllocator>, grid_scale: Positive<f32>) -> Self {
65        Self {
66            transform,
67            grid_scale,
68        }
69    }
70
71    /// Input dimension of vectors to quantizer.
72    pub fn dim(&self) -> usize {
73        self.transform.input_dim()
74    }
75
76    /// Output dimension of vectors after applying transform.
77    ///
78    /// Output storage vectors should use this dimension instead of `self.dim()` because
79    /// in general, the output dim **may** be different from the input dimension.
80    pub fn output_dim(&self) -> usize {
81        self.transform.output_dim()
82    }
83
84    /// Outputs the minimum and maximum value of the range of values
85    /// for an input vector `vec`. The function cases based on the
86    /// intended number of bits `NBITS` per dimension.
87    ///
88    /// * `1-bit` - In order to avoid outlier values, the range
89    ///   is defined by taking the values larger and smaller than
90    ///   the numeric mean, and then taking the respective means of
91    ///   each of these sets as the `max` and `min`.
92    ///
93    /// * `N-bits` - Computes the `min` and `max` of the vector values.
94    ///
95    /// # Returns
96    ///
97    /// * `(m - w * g, m + w * g)` - the lower and upper end of the range, where,
98    ///   `m = (max + min) / 2.0`, `w = (max - min) / 2.0`, and `g = self.grid_scale`.
99    fn get_range<const NBITS: usize>(&self, vec: &[f32]) -> (f32, f32) {
100        let (min, max) = match NBITS {
101            1 => {
102                let (mut min, mut min_count) = (0.0f32, 0.0f32);
103                let (mut max, mut max_count) = (0.0f32, 0.0f32);
104
105                let mean = vec.iter().sum::<f32>() / (vec.len() as f32);
106
107                vec.iter().for_each(|x| {
108                    let m = f32::from((*x < mean) as u8);
109                    min += m * x;
110                    min_count += m;
111                    max += (1.0 - m) * x;
112                    max_count += 1.0 - m;
113                });
114
115                ((min / min_count).min(mean), (max / max_count).max(mean))
116            }
117            _ => {
118                vec // Using `f32::NAN` since [`core::f32::min`] and `max` output the other value if one of them is NAN .
119                    .iter()
120                    .fold((f32::NAN, f32::NAN), |(cmin, cmax), &e| {
121                        (cmin.min(e), cmax.max(e))
122                    })
123            }
124        };
125
126        let width = (max - min) / 2.0;
127        let mid = min + width;
128
129        (
130            mid - width * self.grid_scale.into_inner(),
131            mid + width * self.grid_scale.into_inner(),
132        )
133    }
134
135    fn compress<const NBITS: usize, T>(
136        &self,
137        from: &[T],
138        mut into: DataMutRef<'_, NBITS>,
139    ) -> Result<L2Loss, InputContainsNaN>
140    where
141        T: Copy + Into<f32>,
142        Unsigned: Representation<NBITS>,
143    {
144        let mut into_vec = into.vector_mut();
145
146        assert_eq!(from.len(), self.dim());
147        assert_eq!(self.output_dim(), into_vec.len());
148
149        let domain = Unsigned::domain_const::<NBITS>();
150        let domain_min = *domain.start() as f32;
151        let domain_max = *domain.end() as f32;
152
153        let mut vec = vec![f32::default(); self.output_dim()];
154
155        // We know vec.len() == self.output_dim() and `from.len() == self.dim`
156        #[allow(clippy::unwrap_used)]
157        self.transform
158            .transform_into(
159                &mut vec,
160                &from.iter().map(|&x| x.into()).collect::<Vec<f32>>(),
161                ScopedAllocator::global(),
162            )
163            .unwrap();
164
165        let (min, max) = self.get_range::<NBITS>(&vec);
166
167        let inverse_scale = (max - min).max(1e-8) / bit_scale::<NBITS>(); // To avoid NaN. This is ONLY possible if the vector is all the same value.
168        let mut norm_squared: f32 = 0.0;
169        let mut code_sum: f32 = 0.0;
170        let mut loss: f32 = 0.0;
171
172        let mut nan_check = false;
173
174        vec.iter().enumerate().for_each(|(i, &v)| {
175            nan_check |= v.is_nan();
176
177            let code = ((v - min) / inverse_scale)
178                .clamp(domain_min, domain_max)
179                .round();
180
181            let v_r = (code * inverse_scale) + min; // reconstructed value for `v`.
182            norm_squared += v_r * v_r;
183            code_sum += code;
184            loss += (v_r - v).powi(2);
185
186            //SAFETY: we checked that the lengths of `from` and `into_vec` are the same.
187            unsafe {
188                into_vec.set_unchecked(i, code as u8);
189            }
190        });
191
192        let meta = MinMaxCompensation {
193            dim: self.output_dim() as u32,
194            b: min,
195            a: inverse_scale,
196            n: inverse_scale * code_sum,
197            norm_squared,
198        };
199
200        into.set_meta(meta);
201
202        if nan_check {
203            Err(InputContainsNaN)
204        } else {
205            Ok(match Positive::new(loss) {
206                Ok(p) => L2Loss::Positive(p),
207                Err(_) => L2Loss::Zero,
208            })
209        }
210    }
211}
212
213/////////////////
214// Compression //
215/////////////////
216
/// An enum describing the euclidean loss incurred by quantization.
///
/// For an input vector `x` and its representation `x'`,
/// this is supposed to store `||x - x'||^2`.
#[derive(Clone, Copy, Debug)]
pub enum L2Loss {
    /// Used when the accumulated loss is rejected by `Positive::new`
    /// (presumably because it is exactly zero); reported as `0.0`.
    Zero,
    /// A loss value that was accepted by `Positive::new`.
    Positive(Positive<f32>),
}
226
227impl L2Loss {
228    /// Euclidean loss as a `f32` value
229    pub fn as_f32(&self) -> f32 {
230        match self {
231            L2Loss::Zero => 0.0,
232            L2Loss::Positive(p) => p.into_inner(),
233        }
234    }
235}
236
237impl<const NBITS: usize, T> CompressInto<&[T], DataMutRef<'_, NBITS>> for MinMaxQuantizer
238where
239    T: Copy + Into<f32>,
240    Unsigned: Representation<NBITS>,
241{
242    type Error = InputContainsNaN;
243
244    type Output = L2Loss;
245
246    /// Compress the input vector `from` into a mut ref of Data `to`.
247    ///
248    /// This method computes and stores the compensation coefficients required for computing
249    /// distances correctly.
250    ///
251    /// # Error
252    ///
253    /// Returns an error if the input contains `NaN`.
254    ///
255    /// # Panics
256    ///
257    /// Panics if:
258    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
259    ///   dimensionality as the quantizer.
260    /// * `to.vector().len() != self.output_dim()`: Compressed vector must have the same dimensionality
261    ///   as the quantizer.
262    fn compress_into(&self, from: &[T], to: DataMutRef<'_, NBITS>) -> Result<L2Loss, Self::Error> {
263        self.compress::<NBITS, T>(from, to)
264    }
265}
266
267impl<T> CompressInto<&[T], &mut FullQuery> for MinMaxQuantizer
268where
269    T: Copy + Into<f32>,
270{
271    type Error = InputContainsNaN;
272
273    type Output = ();
274
275    /// Compress the input vector `from` into a mutable reference for a [`FullQuery`] `to`.
276    ///
277    /// This method simply applies the transformation to the input without
278    /// any compression.
279    ///
280    /// # Error
281    ///
282    /// Returns an error if the input contains `NaN`.
283    ///
284    /// # Panics
285    ///
286    /// Panics if:
287    /// * `from.len() != self.dim()`: Vector to be compressed must have the same
288    ///   dimensionality as the quantizer.
289    /// * `to.len() != self.output_dim()`: Compressed vector must have the same dimensionality
290    ///   as the quantizer.
291    fn compress_into(&self, from: &[T], to: &mut FullQuery) -> Result<(), Self::Error> {
292        assert_eq!(from.len(), self.dim());
293        assert_eq!(self.output_dim(), to.len());
294
295        // Transform the input vector and return error if it contains NaN
296        let from: Vec<f32> = from.iter().map(|&x| x.into()).collect();
297        if from.iter().any(|x| x.is_nan()) {
298            return Err(InputContainsNaN);
299        }
300
301        // We know vec.len() == self.output_dim() and `from.len() == self.dim`
302        #[allow(clippy::unwrap_used)]
303        self.transform
304            .transform_into(to.data.as_mut_slice(), &from, ScopedAllocator::global())
305            .unwrap();
306
307        let norm_squared = to.data.iter().map(|x| *x * *x).sum::<f32>();
308        let sum = to.data.iter().sum::<f32>();
309
310        to.meta = FullQueryMeta { norm_squared, sum };
311
312        Ok(())
313    }
314}
315
316///////////////////////
317// Distance Functors //
318///////////////////////
319
320macro_rules! impl_functor {
321    ($dist:ident) => {
322        impl AsFunctor<$dist> for MinMaxQuantizer {
323            // no need to do any work here.
324            fn as_functor(&self) -> $dist {
325                $dist
326            }
327        }
328    };
329}
330
331impl_functor!(MinMaxIP);
332impl_functor!(MinMaxL2Squared);
333impl_functor!(MinMaxCosine);
334impl_functor!(MinMaxCosineNormalized);
335
336///////////
337// Tests //
338///////////
339#[cfg(test)]
340mod minmax_quantizer_tests {
341    use std::num::NonZeroUsize;
342
343    use diskann_utils::{Reborrow, ReborrowMut};
344    use diskann_vector::{distance::SquaredL2, PureDistanceFunction};
345    use rand::{
346        distr::{Distribution, Uniform},
347        rngs::StdRng,
348        SeedableRng,
349    };
350
351    use super::*;
352    use crate::{
353        algorithms::transforms::NullTransform,
354        minmax::vectors::{Data, DataRef},
355    };
356
357    fn reconstruct_minmax<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
358    where
359        Unsigned: Representation<NBITS>,
360    {
361        (0..v.len())
362            .map(|i| {
363                let m = v.meta();
364                v.vector().get(i).unwrap() as f32 * m.a + m.b
365            })
366            .collect()
367    }
368
369    fn test_quantizer_encoding_random<const NBITS: usize>(
370        dim: usize,
371        rng: &mut StdRng,
372        relative_err: f32,
373        scale: f32,
374    ) where
375        Unsigned: Representation<NBITS>,
376        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>
377            + for<'a, 'b> CompressInto<&'a [f32], &'b mut FullQuery, Output = ()>,
378    {
379        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();
380
381        let quantizer = MinMaxQuantizer::new(
382            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
383            Positive::new(scale).unwrap(),
384        );
385
386        assert_eq!(quantizer.dim(), dim);
387
388        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();
389
390        let mut encoded = Data::new_boxed(dim);
391        let loss = quantizer
392            .compress_into(&*vector, encoded.reborrow_mut())
393            .unwrap();
394
395        let reconstructed = reconstruct_minmax::<NBITS>(encoded.reborrow());
396        assert_eq!(reconstructed.len(), dim);
397
398        let reconstruction_error: f32 = SquaredL2::evaluate(&*vector, &*reconstructed);
399        let norm = vector.iter().map(|x| x * x).sum::<f32>();
400        assert!(
401                (reconstruction_error / norm) <= relative_err,
402                "Expected vector : {:?} to be reconstructed within error {} but instead got : {:?}, with error {} for dim : {}",
403                &vector,
404                relative_err,
405                &reconstructed,
406                reconstruction_error / norm,
407                dim,
408            );
409
410        assert!((loss.as_f32() - reconstruction_error) <= 1e-4);
411
412        let expected_code_sum = (0..dim)
413            .map(|i| encoded.vector().get(i).unwrap() as f32)
414            .sum::<f32>();
415        let code_sum = encoded.reborrow().meta().n / encoded.reborrow().meta().a;
416        assert!(
417            (code_sum - expected_code_sum).abs() <= 2e-5 * (dim as f32),
418            "Encoded vector with dim : {dim} is {:?}, got error : {} for vector : {:?}",
419            encoded.reborrow(),
420            (code_sum - expected_code_sum).abs(),
421            &vector,
422        );
423        let recon_norm_sq = reconstructed.iter().map(|x| x * x).sum::<f32>();
424        assert!((encoded.reborrow().meta().norm_squared - recon_norm_sq).abs() <= 1e-3);
425
426        // FullQuery
427        let mut f = FullQuery::empty(dim);
428        quantizer
429            .compress_into(vector.as_slice(), f.reborrow_mut())
430            .unwrap();
431
432        f.data
433            .iter()
434            .enumerate()
435            .zip(vector.iter())
436            .for_each(|((i, x), y)| {
437                assert!(
438                    (*x - *y).abs() < 1e-10,
439                    "Full Query did not compress dimension {i} with value {} correctly, got {} instead.",
440                    *y,
441                    *x,
442                )
443            });
444
445        assert!(
446            (f.meta.norm_squared - norm).abs() < 1e-10,
447            "Full Query norm in meta should be {norm} but instead got {}",
448            f.meta.norm_squared
449        );
450
451        let sum = vector.iter().sum::<f32>();
452        assert!(
453            (f.meta.sum - sum) < 1e-10,
454            "Full Query norm in meta should be {sum} but instead got {}",
455            f.meta.sum
456        );
457    }
458
459    cfg_if::cfg_if! {
460        if #[cfg(miri)] {
461            // The max dim does not need to be as high for `CompensatedVectors` because they
462            // defer their distance function implementation to `BitSlice`, which is more
463            // heavily tested.
464            const TRIALS: usize = 2;
465        } else {
466            const TRIALS: usize = 10;
467        }
468    }
469
470    macro_rules! test_minmax_quantizer_encoding {
471        ($name:ident, $dim:literal, $nbits:literal, $seed:literal, $err:expr) => {
472            #[test]
473            fn $name() {
474                let mut rng = StdRng::seed_from_u64($seed);
475                let scales = [1.0, 1.1, 0.9];
476                for (s, e) in scales.iter().zip($err) {
477                    for d in 10..$dim {
478                        for _ in 0..TRIALS {
479                            test_quantizer_encoding_random::<$nbits>(d, &mut rng, e, *s);
480                        }
481                    }
482                }
483            }
484        };
485    }
486    test_minmax_quantizer_encoding!(
487        test_minmax_encoding_1bit,
488        100,
489        1,
490        0xa32d5658097a1c35,
491        vec![0.5, 0.5, 0.5]
492    );
493    test_minmax_quantizer_encoding!(
494        test_minmax_encoding_2bit,
495        100,
496        2,
497        0xf60c0c8d1aadc126,
498        vec![0.5, 0.5, 0.5]
499    );
500    test_minmax_quantizer_encoding!(
501        test_minmax_encoding_4bit,
502        100,
503        4,
504        0x09fa14c42a9d7d98,
505        vec![1.0e-2, 1.0e-2, 3.0e-2]
506    );
507    test_minmax_quantizer_encoding!(
508        test_minmax_encoding_8bit,
509        100,
510        8,
511        0xaedf3d2a223b7b77,
512        vec![2.0e-3, 2.0e-3, 7.0e-3]
513    );
514
515    macro_rules! expand_to_bitrates {
516        ($name:ident, $func:ident) => {
517            #[test]
518            fn $name() {
519                $func::<1>();
520                $func::<2>();
521                $func::<4>();
522                $func::<8>();
523            }
524        };
525    }
526
527    /// Tests the edge case where min == max but both are non-zero.
528    fn test_all_same_value_vector<const NBITS: usize>()
529    where
530        Unsigned: Representation<NBITS>,
531        MinMaxQuantizer:
532            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
533    {
534        let dim = 30;
535        let quantizer = MinMaxQuantizer::new(
536            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
537            Positive::new(1.0).unwrap(),
538        );
539        let constant_value = 42.5f32;
540        let vector = vec![constant_value; dim];
541
542        let mut encoded = Data::new_boxed(dim);
543        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
544
545        assert!(
546            result.is_ok(),
547            "Constant-value vector should compress successfully"
548        );
549
550        assert!(result.unwrap().as_f32().abs() <= 1e-6);
551
552        // Reconstruction should yield the original constant value (approximately)
553        let reconstructed = reconstruct_minmax(encoded.reborrow());
554        for &val in &reconstructed {
555            assert!(
556                (val - constant_value).abs() < 1e-3,
557                "Reconstructed value {} should be close to original {}. Compressed vector is {:?}",
558                val,
559                constant_value,
560                encoded.meta(),
561            );
562        }
563    }
564
565    /// This tests boundary conditions in the quantization logic.
566    fn test_two_distinct_values<const NBITS: usize>()
567    where
568        Unsigned: Representation<NBITS>,
569        MinMaxQuantizer:
570            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
571    {
572        let dim = 20;
573        let quantizer = MinMaxQuantizer::new(
574            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
575            Positive::new(1.0).unwrap(),
576        );
577
578        let val1 = -10.0f32;
579        let val2 = 15.0f32;
580        let mut vector = vec![val1; dim];
581        // Make half the vector the second value
582        for i in vector.iter_mut().skip(dim) {
583            *i = val2;
584        }
585
586        let mut encoded = Data::new_boxed(dim);
587        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
588
589        assert!(
590            result.is_ok(),
591            "Two-value vector should compress successfully"
592        );
593
594        assert!(result.unwrap().as_f32().abs() <= 1e-6);
595
596        // Verify that only two distinct codes are used
597        let mut codes_used = std::collections::HashSet::new();
598        for i in 0..dim {
599            codes_used.insert(encoded.vector().get(i).unwrap());
600        }
601
602        // For most bit widths, we should see exactly 2 codes (min and max of domain)
603        if NBITS > 1 {
604            assert!(
605                codes_used.len() <= 2,
606                "Should use at most 2 distinct codes for 2-value input, but used: {:?}",
607                codes_used
608            );
609        }
610
611        // Verify reconstruction maintains the two-value structure approximately
612        let reconstructed = reconstruct_minmax(encoded.reborrow());
613        for ((i, val), v) in reconstructed.into_iter().enumerate().zip(&vector) {
614            // Round to nearest 0.1 to account for quantization error
615            assert!(
616                (val - v).abs() < 1e-4,
617                "Reconstructed value in dim : {i} is {val}, when it should be {v}."
618            );
619        }
620    }
621
622    /// Verifies that NaN values in the input cause the expected error but
623    /// dimension in meta is correctly set.
624    fn test_nan_input_error<const NBITS: usize>()
625    where
626        Unsigned: Representation<NBITS>,
627        MinMaxQuantizer:
628            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
629    {
630        let dim = 100;
631        let quantizer = MinMaxQuantizer::new(
632            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
633            Positive::new(1.0).unwrap(),
634        );
635
636        // Test vector with NaN in the middle.
637        let mut vector_nan = vec![1.0f32; dim];
638        vector_nan[33] = f32::NAN;
639        let mut encoded = Data::new_boxed(dim);
640        let result = quantizer.compress_into(&vector_nan, encoded.reborrow_mut());
641        assert!(result.is_err(), "Vector with NaN should cause an error");
642
643        let meta = encoded.meta();
644        assert_eq!(meta.dim as usize, dim);
645    }
646
    // Turn each generic edge-case check into a `#[test]` (first argument is the test
    // name, second the generic function it calls for every supported bit width).
    expand_to_bitrates!(all_same_values_vector, test_all_same_value_vector);
    expand_to_bitrates!(two_distinct_values, test_two_distinct_values);
    expand_to_bitrates!(nan_input_error, test_nan_input_error);
650
651    /// Verifies that providing a vector with wrong dimensionality causes a panic.
652    #[test]
653    #[should_panic(expected = "assertion `left == right` failed\n  left: 15\n right: 10")]
654    fn test_dimension_mismatch_panic()
655    where
656        Unsigned: Representation<8>,
657        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, 8>, Output = L2Loss>,
658    {
659        let expected_dim = 10;
660        let quantizer = MinMaxQuantizer::new(
661            Transform::Null(NullTransform::new(NonZeroUsize::new(expected_dim).unwrap())),
662            Positive::new(1.0).unwrap(),
663        );
664
665        // Provide vector with wrong dimension
666        let wrong_vector = vec![1.0f32; expected_dim + 5]; // Too many dimensions
667        let mut encoded = Data::new_boxed(expected_dim);
668
669        // This should panic due to assertion in compress_into
670        let _ = quantizer.compress_into(&wrong_vector, encoded.reborrow_mut());
671    }
672}