//! Product quantization (`reductive/pq/pq.rs`).

1use std::iter;
2use std::iter::Sum;
3
4use log::info;
5use ndarray::{
6    concatenate, s, Array1, Array2, Array3, ArrayBase, ArrayView2, ArrayView3, ArrayViewMut1,
7    ArrayViewMut2, Axis, Data, Ix1, Ix2, NdFloat,
8};
9use num_traits::{AsPrimitive, Bounded, Zero};
10use ordered_float::OrderedFloat;
11use rand::{Rng, RngCore, SeedableRng};
12use rand_xorshift::XorShiftRng;
13use rayon::prelude::*;
14
15use super::primitives;
16use super::{QuantizeVector, Reconstruct, TrainPq};
17use crate::error::ReductiveError;
18use crate::kmeans::{
19    InitialCentroids, KMeansWithCentroids, NIterationsCondition, RandomInstanceCentroids,
20};
21
/// Product quantizer (Jégou et al., 2011).
///
/// A product quantizer is a vector quantizer that slices a vector and
/// assigns to the *i*-th slice the index of the nearest centroid of the
/// *i*-th subquantizer. Vector reconstruction consists of concatenating
/// the centroids that represent the slices.
#[derive(Clone, Debug, PartialEq)]
pub struct Pq<A> {
    /// Optional projection applied to vectors before quantization; its
    /// transpose is applied to reconstructions.
    pub(crate) projection: Option<Array2<A>>,
    /// Subquantizer centroids. Axis 0 indexes the subquantizer, axis 1
    /// the centroid within a subquantizer's codebook.
    pub(crate) quantizers: Array3<A>,
}
33
impl<A> Pq<A>
where
    A: NdFloat,
{
    /// Construct a product quantizer from its parts.
    ///
    /// When a `projection` matrix is given, it must be square, with the
    /// reconstructed vector length along both axes.
    ///
    /// # Panics
    ///
    /// Panics when `quantizers` is empty or when the projection matrix
    /// does not have the required shape.
    pub fn new(projection: Option<Array2<A>>, quantizers: Array3<A>) -> Self {
        assert!(
            !quantizers.is_empty(),
            "Attempted to construct a product quantizer without quantizers."
        );

        let reconstructed_len = primitives::reconstructed_len(quantizers.view());

        if let Some(ref projection) = projection {
            assert_eq!(
                projection.shape(),
                [reconstructed_len; 2],
                "Incorrect projection matrix shape, was: {:?}, should be [{}, {}]",
                projection.shape(),
                reconstructed_len,
                reconstructed_len
            );
        }

        Pq {
            projection,
            quantizers,
        }
    }

    /// Validate the hyperparameters for training a product quantizer.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    ///
    /// * `n_subquantizers` is zero, exceeds the number of instance
    ///   columns, or does not evenly divide the number of columns;
    /// * `n_subquantizer_bits` is zero or implies more centroids than
    ///   there are training instances;
    /// * `n_iterations` or `n_attempts` is zero.
    pub(crate) fn check_quantizer_invariants(
        n_subquantizers: usize,
        n_subquantizer_bits: u32,
        n_iterations: usize,
        n_attempts: usize,
        instances: ArrayView2<A>,
    ) -> Result<(), ReductiveError> {
        if n_subquantizers == 0 || n_subquantizers > instances.ncols() {
            return Err(ReductiveError::NSubquantizersOutsideRange {
                n_subquantizers,
                max_subquantizers: instances.ncols(),
            });
        }

        // Largest number of bits such that 2^bits <= nrows; with more
        // bits there would be more centroids than training instances.
        let max_subquantizer_bits = (instances.nrows() as f64).log2().trunc() as u32;
        if n_subquantizer_bits == 0 || n_subquantizer_bits > max_subquantizer_bits {
            return Err(ReductiveError::IncorrectNSubquantizerBits {
                max_subquantizer_bits,
            });
        }

        // Slices must all have the same length, so the number of
        // subquantizers has to divide the dimensionality evenly.
        if instances.ncols() % n_subquantizers != 0 {
            return Err(ReductiveError::IncorrectNumberSubquantizers {
                n_subquantizers,
                n_columns: instances.ncols(),
            });
        }

        if n_iterations == 0 {
            return Err(ReductiveError::IncorrectNIterations);
        }

        if n_attempts == 0 {
            return Err(ReductiveError::IncorrectNAttempts);
        }

        Ok(())
    }

    /// Get the number of centroids per quantizer.
    pub fn n_quantizer_centroids(&self) -> usize {
        self.quantizers.len_of(Axis(1))
    }

    /// Get the projection matrix (if used).
    pub fn projection(&self) -> Option<ArrayView2<A>> {
        self.projection.as_ref().map(Array2::view)
    }

    /// Create initial centroids for a single quantizer.
    ///
    /// `subquantizer_idx` is the subquantizer index for which the initial
    /// centroids should be picked. `subquantizer_idx < n_subquantizers`,
    /// the total number of subquantizers.
    ///
    /// The centroids are drawn from the columns of `instances` that
    /// belong to this subquantizer's slice.
    pub(crate) fn subquantizer_initial_centroids<S>(
        subquantizer_idx: usize,
        n_subquantizers: usize,
        codebook_len: usize,
        instances: ArrayBase<S, Ix2>,
        rng: &mut impl Rng,
    ) -> Array2<A>
    where
        S: Data<Elem = A>,
    {
        // Length of each sub-vector slice handled by one subquantizer.
        let sq_dims = instances.ncols() / n_subquantizers;

        let mut random_centroids = RandomInstanceCentroids::new(rng);

        // Select only this subquantizer's slice of every instance.
        let offset = subquantizer_idx * sq_dims;
        // ndarray#474
        #[allow(clippy::deref_addrof)]
        let sq_instances = instances.slice(s![.., offset..offset + sq_dims]);
        random_centroids.initial_centroids(sq_instances, Axis(0), codebook_len)
    }

    /// Train a subquantizer.
    ///
    /// `subquantizer_idx` is the index of the subquantizer, where
    /// `subquantizer_idx < n_subquantizers`, the overall number of
    /// subquantizers. `codebook_len` is the code book size of the
    /// quantizer.
    ///
    /// Runs k-means `n_attempts` times from fresh random initializations
    /// and returns the codebook with the lowest loss.
    fn train_subquantizer(
        subquantizer_idx: usize,
        n_subquantizers: usize,
        codebook_len: usize,
        n_iterations: usize,
        n_attempts: usize,
        instances: ArrayView2<A>,
        mut rng: impl Rng,
    ) -> Array2<A>
    where
        A: Sum,
        usize: AsPrimitive<A>,
    {
        assert!(n_attempts > 0, "Cannot train a subquantizer in 0 attempts.");

        info!("Training PQ subquantizer {}", subquantizer_idx);

        let sq_dims = instances.ncols() / n_subquantizers;

        // Train on this subquantizer's slice of every instance only.
        let offset = subquantizer_idx * sq_dims;
        // ndarray#474
        #[allow(clippy::deref_addrof)]
        let sq_instances = instances.slice(s![.., offset..offset + sq_dims]);

        iter::repeat_with(|| {
            let mut quantizer = Pq::subquantizer_initial_centroids(
                subquantizer_idx,
                n_subquantizers,
                codebook_len,
                instances,
                &mut rng,
            );
            let loss = sq_instances.kmeans_with_centroids(
                Axis(0),
                quantizer.view_mut(),
                NIterationsCondition(n_iterations),
            );
            (loss, quantizer)
        })
        .take(n_attempts)
        // OrderedFloat gives losses a total order so min_by_key applies.
        .map(|(loss, quantizer)| (OrderedFloat(loss), quantizer))
        .min_by_key(|attempt| attempt.0)
        .unwrap()
        .1
    }

    /// Get the subquantizer centroids.
    pub fn subquantizers(&self) -> ArrayView3<A> {
        self.quantizers.view()
    }
}
195
196impl<A> TrainPq<A> for Pq<A>
197where
198    A: NdFloat + Sum,
199    usize: AsPrimitive<A>,
200{
201    fn train_pq_using<S, R>(
202        n_subquantizers: usize,
203        n_subquantizer_bits: u32,
204        n_iterations: usize,
205        n_attempts: usize,
206        instances: ArrayBase<S, Ix2>,
207        mut rng: &mut R,
208    ) -> Result<Pq<A>, ReductiveError>
209    where
210        S: Sync + Data<Elem = A>,
211        R: RngCore + SeedableRng + Send,
212    {
213        Self::check_quantizer_invariants(
214            n_subquantizers,
215            n_subquantizer_bits,
216            n_iterations,
217            n_attempts,
218            instances.view(),
219        )?;
220
221        let rngs = iter::repeat_with(|| XorShiftRng::from_rng(&mut rng))
222            .take(n_subquantizers)
223            .collect::<Result<Vec<_>, _>>()
224            .map_err(ReductiveError::ConstructRng)?;
225
226        let quantizers = rngs
227            .into_par_iter()
228            .enumerate()
229            .map(|(idx, rng)| {
230                Self::train_subquantizer(
231                    idx,
232                    n_subquantizers,
233                    2usize.pow(n_subquantizer_bits),
234                    n_iterations,
235                    n_attempts,
236                    instances.view(),
237                    rng,
238                )
239                .insert_axis(Axis(0))
240            })
241            .collect::<Vec<_>>();
242
243        let views = quantizers.iter().map(|a| a.view()).collect::<Vec<_>>();
244
245        Ok(Pq {
246            projection: None,
247            quantizers: concatenate(Axis(0), &views).expect("Cannot concatenate subquantizers"),
248        })
249    }
250}
251
252impl<A> QuantizeVector<A> for Pq<A>
253where
254    A: NdFloat + Sum,
255{
256    fn quantize_batch<I, S>(&self, x: ArrayBase<S, Ix2>) -> Array2<I>
257    where
258        I: AsPrimitive<usize> + Bounded + Zero,
259        S: Data<Elem = A>,
260        usize: AsPrimitive<I>,
261    {
262        let mut quantized = Array2::zeros((x.nrows(), self.quantized_len()));
263        self.quantize_batch_into(x, quantized.view_mut());
264        quantized
265    }
266
267    /// Quantize a batch of vectors into an existing matrix.
268    fn quantize_batch_into<I, S>(&self, x: ArrayBase<S, Ix2>, mut quantized: ArrayViewMut2<I>)
269    where
270        I: AsPrimitive<usize> + Bounded + Zero,
271        S: Data<Elem = A>,
272        usize: AsPrimitive<I>,
273    {
274        match self.projection {
275            Some(ref projection) => {
276                let rx = x.dot(projection);
277                primitives::quantize_batch_into(self.quantizers.view(), rx, quantized.view_mut());
278            }
279            None => {
280                primitives::quantize_batch_into(self.quantizers.view(), x, quantized.view_mut());
281            }
282        }
283    }
284
285    fn quantize_vector<I, S>(&self, x: ArrayBase<S, Ix1>) -> Array1<I>
286    where
287        I: AsPrimitive<usize> + Bounded + Zero,
288        S: Data<Elem = A>,
289        usize: AsPrimitive<I>,
290    {
291        match self.projection {
292            Some(ref projection) => {
293                let rx = x.dot(projection);
294                primitives::quantize(self.quantizers.view(), self.reconstructed_len(), rx)
295            }
296            None => primitives::quantize(self.quantizers.view(), self.reconstructed_len(), x),
297        }
298    }
299
300    fn quantized_len(&self) -> usize {
301        self.quantizers.len_of(Axis(0))
302    }
303}
304
305impl<A> Reconstruct<A> for Pq<A>
306where
307    A: NdFloat + Sum,
308{
309    fn reconstruct_batch_into<I, S>(
310        &self,
311        quantized: ArrayBase<S, Ix2>,
312        mut reconstructions: ArrayViewMut2<A>,
313    ) where
314        I: AsPrimitive<usize>,
315        S: Data<Elem = I>,
316    {
317        primitives::reconstruct_batch_into(
318            self.quantizers.view(),
319            quantized,
320            reconstructions.view_mut(),
321        );
322
323        if let Some(ref projection) = self.projection {
324            let projected_reconstruction = reconstructions.dot(&projection.t());
325            reconstructions.assign(&projected_reconstruction);
326        }
327    }
328
329    fn reconstruct_into<I, S>(
330        &self,
331        quantized: ArrayBase<S, Ix1>,
332        mut reconstruction: ArrayViewMut1<A>,
333    ) where
334        I: AsPrimitive<usize>,
335        S: Data<Elem = I>,
336    {
337        primitives::reconstruct_into(self.quantizers.view(), quantized, reconstruction.view_mut());
338
339        if let Some(ref projection) = self.projection {
340            let projected_reconstruction = reconstruction.dot(&projection.t());
341            reconstruction.assign(&projected_reconstruction);
342        }
343    }
344
345    fn reconstructed_len(&self) -> usize {
346        primitives::reconstructed_len(self.quantizers.view())
347    }
348}
349
#[cfg(test)]
mod tests {
    use ndarray::{array, Array1, Array2, Array3, ArrayView2};
    use rand::distributions::Uniform;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use super::Pq;
    use crate::linalg::EuclideanDistance;
    use crate::ndarray_rand::RandomExt;
    use crate::pq::{QuantizeVector, Reconstruct, TrainPq};

    /// Calculate the average euclidean distances between the given
    /// instances and the instances returned by quantizing and then
    /// reconstructing the instances.
    fn avg_euclidean_loss(instances: ArrayView2<f32>, quantizer: &Pq<f32>) -> f32 {
        let mut euclidean_loss = 0f32;

        let quantized: Array2<u8> = quantizer.quantize_batch(instances);
        let reconstructions = quantizer.reconstruct_batch(quantized);

        for (instance, reconstruction) in instances.outer_iter().zip(reconstructions.outer_iter()) {
            euclidean_loss += instance.euclidean_distance(reconstruction);
        }

        euclidean_loss / instances.nrows() as f32
    }

    // Fixture: four 6-dimensional instances for the hand-built codebook.
    fn test_vectors() -> Array2<f32> {
        array![
            [0., 2., 0., -0.5, 0., 0.],
            [1., -0.2, 0., 0.5, 0.5, 0.],
            [-0.2, 0.2, 0., 0., -2., 0.],
            [1., 0.2, 0., 0., -2., 0.],
        ]
    }

    // Expected codes for `test_vectors` under `test_pq`'s codebook.
    fn test_quantizations() -> Array2<usize> {
        array![[1, 1], [0, 1], [1, 0], [0, 0]]
    }

    // Expected reconstructions for `test_quantizations`: each row is the
    // concatenation of the selected centroids.
    fn test_reconstructions() -> Array2<f32> {
        array![
            [0., 1., 0., 0., 1., 0.],
            [1., 0., 0., 0., 1., 0.],
            [0., 1., 0., 1., -1., 0.],
            [1., 0., 0., 1., -1., 0.]
        ]
    }

    // A fixed quantizer with 2 subquantizers of 2 centroids each over
    // 3-dimensional slices; no projection.
    fn test_pq() -> Pq<f32> {
        let quantizers = array![[[1., 0., 0.], [0., 1., 0.]], [[1., -1., 0.], [0., 1., 0.]],];

        Pq {
            projection: None,
            quantizers,
        }
    }

    #[test]
    fn quantize_batch_with_predefined_codebook() {
        let pq = test_pq();

        assert_eq!(
            pq.quantize_batch::<usize, _>(test_vectors()),
            test_quantizations()
        );
    }

    #[test]
    fn quantize_with_predefined_codebook() {
        let pq = test_pq();

        // Per-vector quantization must agree with the batch results.
        for (vector, quantization) in test_vectors()
            .outer_iter()
            .zip(test_quantizations().outer_iter())
        {
            assert_eq!(pq.quantize_vector::<usize, _>(vector), quantization);
        }
    }

    #[test]
    fn quantize_with_pq() {
        // Seeded RNG keeps training (and thus the loss bound) reproducible.
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0f32, 1f32);
        let instances = Array2::random_using((256, 20), uniform, &mut rng);
        let pq = Pq::train_pq_using(10, 7, 10, 1, instances.view(), &mut rng).unwrap();
        let loss = avg_euclidean_loss(instances.view(), &pq);
        // Loss is around 0.077.
        assert!(loss < 0.08);
    }

    #[test]
    fn quantize_with_type() {
        // 256 centroids fit exactly into a u8 code.
        let uniform = Uniform::new(0f32, 1f32);
        let pq = Pq {
            projection: None,
            quantizers: Array3::random((1, 256, 10), uniform),
        };
        pq.quantize_vector::<u8, _>(Array1::random((10,), uniform));
    }

    #[test]
    #[should_panic]
    fn quantize_with_too_narrow_type() {
        // 257 centroids cannot be represented in a u8 code.
        let uniform = Uniform::new(0f32, 1f32);
        let pq = Pq {
            projection: None,
            quantizers: Array3::random((1, 257, 10), uniform),
        };
        pq.quantize_vector::<u8, _>(Array1::random((10,), uniform));
    }

    #[test]
    fn quantizer_lens() {
        let quantizer = test_pq();

        // 2 subquantizers, each covering 3 dimensions.
        assert_eq!(quantizer.quantized_len(), 2);
        assert_eq!(quantizer.reconstructed_len(), 6);
    }

    #[test]
    fn reconstruct_batch_with_predefined_codebook() {
        let pq = test_pq();
        assert_eq!(
            pq.reconstruct_batch(test_quantizations()),
            test_reconstructions()
        );
    }

    #[test]
    fn reconstruct_with_predefined_codebook() {
        let pq = test_pq();

        // Per-vector reconstruction must agree with the batch results.
        for (quantization, reconstruction) in test_quantizations()
            .outer_iter()
            .zip(test_reconstructions().outer_iter())
        {
            assert_eq!(pq.reconstruct(quantization), reconstruction);
        }
    }
}