// diskann_quantization/scalar/train.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use super::quantizer::ScalarQuantizer;
7use crate::{
8    num::Positive,
9    utils::{compute_means_and_average_norm, compute_variances},
10};
11use diskann_utils::views;
12
/// Parameters controlling the generation of the scalar quantization Quantizer.
///
/// When performing scalar quantization, the mean of each dimension will be calculated and
/// the dataset will be shifted around this mean.
///
/// Next, the standard deviation of each dimension will be computed and the maximum `m` found.
///
/// The dynamic range of the final compressed encoding will then span
/// `2 * standard_deviations * m` for each dimension symmetrically about the mean for each
/// dimension. Values outside the spanned dynamic range will be clamped.
pub struct ScalarQuantizationParameters {
    // Number of maximal standard deviations spanned by the encoding's dynamic range.
    // Invariant: strictly positive (enforced by the `Positive` wrapper).
    standard_deviations: Positive<f64>,
}
26
27impl ScalarQuantizationParameters {
28    /// Construct a new quantizer with the given parameters.
29    ///
30    /// # Arguments
31    ///
32    /// * `standard_deviations`: The number of maximal standard deviations to use for the
33    ///   encoding's dynamic range. This number **must** be positive, and generally should
34    ///   be greater than 1.0.
35    ///
36    ///   A good starting value is generally 2.0.
37    pub fn new(standard_deviations: Positive<f64>) -> Self {
38        Self {
39            standard_deviations,
40        }
41    }
42
43    /// Return the current number of standard deviations being used to set the dynamic range.
44    pub fn standard_deviations(&self) -> Positive<f64> {
45        self.standard_deviations
46    }
47
48    /// Train a new [`ScalarQuantizer`] on the input training data.
49    ///
50    /// The training algoritm works as follows:
51    ///
52    /// 1. The medoid of the training data is computed.
53    ///
54    /// 2. The standard deviation for each dimension is then calculated across all rows
55    ///    of the training set.
56    ///
57    /// 3. The maximum standard deviation `s` is computed and the dynamic range `dyn` of the
58    ///    quantizer is computed as `dyn = 2.0 * self.standard_deviations() * s`.
59    ///
60    /// 4. The quantizer is then constructed with `scale = dyn / (2.pow(NBITS) - 1)`.
61    ///
62    /// # Complexity
63    ///
64    /// This method is linear in the number of rows and columns in `data`.
65    ///
66    /// # Allocates
67    ///
68    /// This method allocated memory on the order of `data.ncols()` (the dimensionality of
69    /// the data).
70    ///
71    /// # Parallelism
72    ///
73    /// This function is single threaded.
74    pub fn train<T>(&self, data: views::MatrixView<T>) -> ScalarQuantizer
75    where
76        T: Copy + Into<f64> + Into<f32>,
77    {
78        let (means, mean_norm) = compute_means_and_average_norm(data);
79        let variances = compute_variances(data, &means);
80
81        // Take the maximum variance - that will set our global scaling parameter.
82        let max_std = variances.iter().fold(0.0f64, |max, &x| max.max(x)).sqrt();
83        let p = max_std * self.standard_deviations.into_inner();
84
85        let scale = 2.0 * p;
86        let shift = means.into_iter().map(|i| (i - p) as f32).collect();
87
88        ScalarQuantizer::new(scale as f32, shift, Some(mean_norm as f32))
89    }
90}
91
92// 2.0 seems to be good starting point for scalar quantization.
93//
94// SAFETY: 2.0 is greater than 0.0.
95const DEFAULT_STDEV: Positive<f64> = unsafe { Positive::new_unchecked(2.0) };
96
97impl Default for ScalarQuantizationParameters {
98    fn default() -> Self {
99        Self::new(DEFAULT_STDEV)
100    }
101}
102
103///////////
104// Tests //
105///////////
106
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_util::create_test_problem;

    /// Exercise training on a randomly generated problem of the given shape,
    /// checking the scale, mean norm, and per-dimension shift against the
    /// problem's known statistics.
    fn test_train_impl(nrows: usize, ncols: usize, seed: u64) {
        // Test Default
        let default = ScalarQuantizationParameters::default();
        assert_eq!(default.standard_deviations(), DEFAULT_STDEV);

        let mut rng = StdRng::seed_from_u64(seed);
        let problem = create_test_problem(nrows, ncols, &mut rng);

        // Compute the maximum standard deviation from the expected variances.
        let problem_std_max = problem
            .variances
            .iter()
            .copied()
            .reduce(|a, b| a.max(b))
            .unwrap()
            .sqrt();

        // Provide a range of standard deviation requests to the training algorithm.
        let standard_deviations: [f64; 3] = [1.0, 1.5, 2.0];
        for std in standard_deviations {
            let parameters = ScalarQuantizationParameters::new(Positive::new(std).unwrap());

            let quantizer = parameters.train(problem.data.as_view());
            assert_eq!(quantizer.dim(), ncols);

            let expected_scale = std * 2.0 * problem_std_max;
            let got_scale = quantizer.scale();

            let relative_diff = (got_scale as f64 - expected_scale) / expected_scale;

            // Check the *magnitude* of the relative error. Comparing the signed
            // value (`relative_diff < 1.0e-7`) would pass vacuously whenever the
            // obtained scale is smaller than expected, however far off it is.
            assert!(
                relative_diff.abs() < 1.0e-7,
                "Relative difference in scaling of {}. Got {}, expected {} \
                 (nrows = {}, ncols = {})",
                relative_diff,
                got_scale,
                expected_scale,
                nrows,
                ncols
            );

            assert_eq!(quantizer.mean_norm().unwrap(), problem.mean_norm as f32);

            // The quantizer shift should be the dataset mean shifted by the appropriate
            // amount for the unsigned quantization.
            let shift = std * problem_std_max;
            let quantizer_shift = quantizer.shift();
            for (i, (&got, &expected)) in
                std::iter::zip(quantizer_shift.iter(), problem.means.iter()).enumerate()
            {
                let expected = expected - shift;

                assert_eq!(
                    got, expected as f32,
                    "Mismatch in shift amount at index {}, (nrows = {}, ncols = {})",
                    i, nrows, ncols,
                );
            }
        }
    }

    #[test]
    fn test_train() {
        test_train_impl(10, 16, 0x0b1d3ccb952d3079);
        test_train_impl(7, 8, 0xda9a5c0a672f43cd);
    }
}