// diskann_quantization/scalar/train.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use super::quantizer::ScalarQuantizer;
7use crate::{
8 num::Positive,
9 utils::{compute_means_and_average_norm, compute_variances},
10};
11use diskann_utils::views;
12
13/// Parameters controlling the generation of the scalar quantization Quantizer.
14///
15/// When performing scalar quantization, the mean of each dimension will be calculated and
16/// the dataset will be shifted around this mean.
17///
18/// Next, the standard deviation of each dimension will be computed and the maximum `m` found.
19///
20/// The dynamic range of the final compressed encoding will then span
21/// `2 * standard_deviations * m` for each dimension symmetrically about the mean for each
22/// dimension. Values outside the spanned dynamic range will be clamped.
23pub struct ScalarQuantizationParameters {
24 standard_deviations: Positive<f64>,
25}
26
27impl ScalarQuantizationParameters {
28 /// Construct a new quantizer with the given parameters.
29 ///
30 /// # Arguments
31 ///
32 /// * `standard_deviations`: The number of maximal standard deviations to use for the
33 /// encoding's dynamic range. This number **must** be positive, and generally should
34 /// be greater than 1.0.
35 ///
36 /// A good starting value is generally 2.0.
37 pub fn new(standard_deviations: Positive<f64>) -> Self {
38 Self {
39 standard_deviations,
40 }
41 }
42
43 /// Return the current number of standard deviations being used to set the dynamic range.
44 pub fn standard_deviations(&self) -> Positive<f64> {
45 self.standard_deviations
46 }
47
48 /// Train a new [`ScalarQuantizer`] on the input training data.
49 ///
50 /// The training algoritm works as follows:
51 ///
52 /// 1. The medoid of the training data is computed.
53 ///
54 /// 2. The standard deviation for each dimension is then calculated across all rows
55 /// of the training set.
56 ///
57 /// 3. The maximum standard deviation `s` is computed and the dynamic range `dyn` of the
58 /// quantizer is computed as `dyn = 2.0 * self.standard_deviations() * s`.
59 ///
60 /// 4. The quantizer is then constructed with `scale = dyn / (2.pow(NBITS) - 1)`.
61 ///
62 /// # Complexity
63 ///
64 /// This method is linear in the number of rows and columns in `data`.
65 ///
66 /// # Allocates
67 ///
68 /// This method allocated memory on the order of `data.ncols()` (the dimensionality of
69 /// the data).
70 ///
71 /// # Parallelism
72 ///
73 /// This function is single threaded.
74 pub fn train<T>(&self, data: views::MatrixView<T>) -> ScalarQuantizer
75 where
76 T: Copy + Into<f64> + Into<f32>,
77 {
78 let (means, mean_norm) = compute_means_and_average_norm(data);
79 let variances = compute_variances(data, &means);
80
81 // Take the maximum variance - that will set our global scaling parameter.
82 let max_std = variances.iter().fold(0.0f64, |max, &x| max.max(x)).sqrt();
83 let p = max_std * self.standard_deviations.into_inner();
84
85 let scale = 2.0 * p;
86 let shift = means.into_iter().map(|i| (i - p) as f32).collect();
87
88 ScalarQuantizer::new(scale as f32, shift, Some(mean_norm as f32))
89 }
90}
91
// 2.0 seems to be a good starting point for scalar quantization.
//
// SAFETY: 2.0 is greater than 0.0, so the `Positive` invariant is upheld.
const DEFAULT_STDEV: Positive<f64> = unsafe { Positive::new_unchecked(2.0) };

impl Default for ScalarQuantizationParameters {
    /// Construct parameters using the default of 2.0 standard deviations.
    fn default() -> Self {
        Self::new(DEFAULT_STDEV)
    }
}
102
103///////////
104// Tests //
105///////////
106
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_util::create_test_problem;

    /// Train quantizers on a random test problem and validate the resulting scale,
    /// mean norm, and per-dimension shift against values computed directly from the
    /// problem's known means and variances.
    fn test_train_impl(nrows: usize, ncols: usize, seed: u64) {
        // Test Default
        let default = ScalarQuantizationParameters::default();
        assert_eq!(default.standard_deviations(), DEFAULT_STDEV);

        let mut rng = StdRng::seed_from_u64(seed);
        let problem = create_test_problem(nrows, ncols, &mut rng);

        // Compute the maximum standard deviation from the expected variances.
        let problem_std_max = problem
            .variances
            .iter()
            .copied()
            .reduce(|a, b| a.max(b))
            .unwrap()
            .sqrt();

        // Provide a range of standard deviation requests to the training algorithm.
        let standard_deviations: [f64; 3] = [1.0, 1.5, 2.0];
        for std in standard_deviations {
            let parameters = ScalarQuantizationParameters::new(Positive::new(std).unwrap());

            let quantizer = parameters.train(problem.data.as_view());
            assert_eq!(quantizer.dim(), ncols);

            let expected_scale = std * 2.0 * problem_std_max;
            let got_scale = quantizer.scale();

            let relative_diff = (got_scale as f64 - expected_scale) / expected_scale;

            // Compare the *magnitude* of the relative error. Without `abs()`, a large
            // negative deviation (got_scale far below expected) would silently pass.
            assert!(
                relative_diff.abs() < 1.0e-7,
                "Relative difference in scaling of {}. Got {}, expected {} \
                (nrows = {}, ncols = {})",
                relative_diff,
                got_scale,
                expected_scale,
                nrows,
                ncols
            );

            assert_eq!(quantizer.mean_norm().unwrap(), problem.mean_norm as f32);

            // The quantizer shift should be the dataset mean shifted by the appropriate
            // amount for the unsigned quantization.
            let shift = std * problem_std_max;
            let quantizer_shift = quantizer.shift();
            for (i, (&got, &expected)) in
                std::iter::zip(quantizer_shift.iter(), problem.means.iter()).enumerate()
            {
                let expected = expected - shift;

                assert_eq!(
                    got, expected as f32,
                    "Mismatch in shift amount at index {}, (nrows = {}, ncols = {})",
                    i, nrows, ncols,
                );
            }
        }
    }

    #[test]
    fn test_train() {
        test_train_impl(10, 16, 0x0b1d3ccb952d3079);
        test_train_impl(7, 8, 0xda9a5c0a672f43cd);
    }
}
180}