1use std::iter;
2use std::iter::Sum;
3
4use log::info;
5use ndarray::{
6 concatenate, s, Array1, Array2, Array3, ArrayBase, ArrayView2, ArrayView3, ArrayViewMut1,
7 ArrayViewMut2, Axis, Data, Ix1, Ix2, NdFloat,
8};
9use num_traits::{AsPrimitive, Bounded, Zero};
10use ordered_float::OrderedFloat;
11use rand::{Rng, RngCore, SeedableRng};
12use rand_xorshift::XorShiftRng;
13use rayon::prelude::*;
14
15use super::primitives;
16use super::{QuantizeVector, Reconstruct, TrainPq};
17use crate::error::ReductiveError;
18use crate::kmeans::{
19 InitialCentroids, KMeansWithCentroids, NIterationsCondition, RandomInstanceCentroids,
20};
21
22#[derive(Clone, Debug, PartialEq)]
29pub struct Pq<A> {
30 pub(crate) projection: Option<Array2<A>>,
31 pub(crate) quantizers: Array3<A>,
32}
33
34impl<A> Pq<A>
35where
36 A: NdFloat,
37{
38 pub fn new(projection: Option<Array2<A>>, quantizers: Array3<A>) -> Self {
39 assert!(
40 !quantizers.is_empty(),
41 "Attempted to construct a product quantizer without quantizers."
42 );
43
44 let reconstructed_len = primitives::reconstructed_len(quantizers.view());
45
46 if let Some(ref projection) = projection {
47 assert_eq!(
48 projection.shape(),
49 [reconstructed_len; 2],
50 "Incorrect projection matrix shape, was: {:?}, should be [{}, {}]",
51 projection.shape(),
52 reconstructed_len,
53 reconstructed_len
54 );
55 }
56
57 Pq {
58 projection,
59 quantizers,
60 }
61 }
62
63 pub(crate) fn check_quantizer_invariants(
64 n_subquantizers: usize,
65 n_subquantizer_bits: u32,
66 n_iterations: usize,
67 n_attempts: usize,
68 instances: ArrayView2<A>,
69 ) -> Result<(), ReductiveError> {
70 if n_subquantizers == 0 || n_subquantizers > instances.ncols() {
71 return Err(ReductiveError::NSubquantizersOutsideRange {
72 n_subquantizers,
73 max_subquantizers: instances.ncols(),
74 });
75 }
76
77 let max_subquantizer_bits = (instances.nrows() as f64).log2().trunc() as u32;
78 if n_subquantizer_bits == 0 || n_subquantizer_bits > max_subquantizer_bits {
79 return Err(ReductiveError::IncorrectNSubquantizerBits {
80 max_subquantizer_bits,
81 });
82 }
83
84 if instances.ncols() % n_subquantizers != 0 {
85 return Err(ReductiveError::IncorrectNumberSubquantizers {
86 n_subquantizers,
87 n_columns: instances.ncols(),
88 });
89 }
90
91 if n_iterations == 0 {
92 return Err(ReductiveError::IncorrectNIterations);
93 }
94
95 if n_attempts == 0 {
96 return Err(ReductiveError::IncorrectNAttempts);
97 }
98
99 Ok(())
100 }
101
102 pub fn n_quantizer_centroids(&self) -> usize {
104 self.quantizers.len_of(Axis(1))
105 }
106
107 pub fn projection(&self) -> Option<ArrayView2<A>> {
109 self.projection.as_ref().map(Array2::view)
110 }
111
112 pub(crate) fn subquantizer_initial_centroids<S>(
118 subquantizer_idx: usize,
119 n_subquantizers: usize,
120 codebook_len: usize,
121 instances: ArrayBase<S, Ix2>,
122 rng: &mut impl Rng,
123 ) -> Array2<A>
124 where
125 S: Data<Elem = A>,
126 {
127 let sq_dims = instances.ncols() / n_subquantizers;
128
129 let mut random_centroids = RandomInstanceCentroids::new(rng);
130
131 let offset = subquantizer_idx * sq_dims;
132 #[allow(clippy::deref_addrof)]
134 let sq_instances = instances.slice(s![.., offset..offset + sq_dims]);
135 random_centroids.initial_centroids(sq_instances, Axis(0), codebook_len)
136 }
137
138 fn train_subquantizer(
145 subquantizer_idx: usize,
146 n_subquantizers: usize,
147 codebook_len: usize,
148 n_iterations: usize,
149 n_attempts: usize,
150 instances: ArrayView2<A>,
151 mut rng: impl Rng,
152 ) -> Array2<A>
153 where
154 A: Sum,
155 usize: AsPrimitive<A>,
156 {
157 assert!(n_attempts > 0, "Cannot train a subquantizer in 0 attempts.");
158
159 info!("Training PQ subquantizer {}", subquantizer_idx);
160
161 let sq_dims = instances.ncols() / n_subquantizers;
162
163 let offset = subquantizer_idx * sq_dims;
164 #[allow(clippy::deref_addrof)]
166 let sq_instances = instances.slice(s![.., offset..offset + sq_dims]);
167
168 iter::repeat_with(|| {
169 let mut quantizer = Pq::subquantizer_initial_centroids(
170 subquantizer_idx,
171 n_subquantizers,
172 codebook_len,
173 instances,
174 &mut rng,
175 );
176 let loss = sq_instances.kmeans_with_centroids(
177 Axis(0),
178 quantizer.view_mut(),
179 NIterationsCondition(n_iterations),
180 );
181 (loss, quantizer)
182 })
183 .take(n_attempts)
184 .map(|(loss, quantizer)| (OrderedFloat(loss), quantizer))
185 .min_by_key(|attempt| attempt.0)
186 .unwrap()
187 .1
188 }
189
190 pub fn subquantizers(&self) -> ArrayView3<A> {
192 self.quantizers.view()
193 }
194}
195
196impl<A> TrainPq<A> for Pq<A>
197where
198 A: NdFloat + Sum,
199 usize: AsPrimitive<A>,
200{
201 fn train_pq_using<S, R>(
202 n_subquantizers: usize,
203 n_subquantizer_bits: u32,
204 n_iterations: usize,
205 n_attempts: usize,
206 instances: ArrayBase<S, Ix2>,
207 mut rng: &mut R,
208 ) -> Result<Pq<A>, ReductiveError>
209 where
210 S: Sync + Data<Elem = A>,
211 R: RngCore + SeedableRng + Send,
212 {
213 Self::check_quantizer_invariants(
214 n_subquantizers,
215 n_subquantizer_bits,
216 n_iterations,
217 n_attempts,
218 instances.view(),
219 )?;
220
221 let rngs = iter::repeat_with(|| XorShiftRng::from_rng(&mut rng))
222 .take(n_subquantizers)
223 .collect::<Result<Vec<_>, _>>()
224 .map_err(ReductiveError::ConstructRng)?;
225
226 let quantizers = rngs
227 .into_par_iter()
228 .enumerate()
229 .map(|(idx, rng)| {
230 Self::train_subquantizer(
231 idx,
232 n_subquantizers,
233 2usize.pow(n_subquantizer_bits),
234 n_iterations,
235 n_attempts,
236 instances.view(),
237 rng,
238 )
239 .insert_axis(Axis(0))
240 })
241 .collect::<Vec<_>>();
242
243 let views = quantizers.iter().map(|a| a.view()).collect::<Vec<_>>();
244
245 Ok(Pq {
246 projection: None,
247 quantizers: concatenate(Axis(0), &views).expect("Cannot concatenate subquantizers"),
248 })
249 }
250}
251
252impl<A> QuantizeVector<A> for Pq<A>
253where
254 A: NdFloat + Sum,
255{
256 fn quantize_batch<I, S>(&self, x: ArrayBase<S, Ix2>) -> Array2<I>
257 where
258 I: AsPrimitive<usize> + Bounded + Zero,
259 S: Data<Elem = A>,
260 usize: AsPrimitive<I>,
261 {
262 let mut quantized = Array2::zeros((x.nrows(), self.quantized_len()));
263 self.quantize_batch_into(x, quantized.view_mut());
264 quantized
265 }
266
267 fn quantize_batch_into<I, S>(&self, x: ArrayBase<S, Ix2>, mut quantized: ArrayViewMut2<I>)
269 where
270 I: AsPrimitive<usize> + Bounded + Zero,
271 S: Data<Elem = A>,
272 usize: AsPrimitive<I>,
273 {
274 match self.projection {
275 Some(ref projection) => {
276 let rx = x.dot(projection);
277 primitives::quantize_batch_into(self.quantizers.view(), rx, quantized.view_mut());
278 }
279 None => {
280 primitives::quantize_batch_into(self.quantizers.view(), x, quantized.view_mut());
281 }
282 }
283 }
284
285 fn quantize_vector<I, S>(&self, x: ArrayBase<S, Ix1>) -> Array1<I>
286 where
287 I: AsPrimitive<usize> + Bounded + Zero,
288 S: Data<Elem = A>,
289 usize: AsPrimitive<I>,
290 {
291 match self.projection {
292 Some(ref projection) => {
293 let rx = x.dot(projection);
294 primitives::quantize(self.quantizers.view(), self.reconstructed_len(), rx)
295 }
296 None => primitives::quantize(self.quantizers.view(), self.reconstructed_len(), x),
297 }
298 }
299
300 fn quantized_len(&self) -> usize {
301 self.quantizers.len_of(Axis(0))
302 }
303}
304
305impl<A> Reconstruct<A> for Pq<A>
306where
307 A: NdFloat + Sum,
308{
309 fn reconstruct_batch_into<I, S>(
310 &self,
311 quantized: ArrayBase<S, Ix2>,
312 mut reconstructions: ArrayViewMut2<A>,
313 ) where
314 I: AsPrimitive<usize>,
315 S: Data<Elem = I>,
316 {
317 primitives::reconstruct_batch_into(
318 self.quantizers.view(),
319 quantized,
320 reconstructions.view_mut(),
321 );
322
323 if let Some(ref projection) = self.projection {
324 let projected_reconstruction = reconstructions.dot(&projection.t());
325 reconstructions.assign(&projected_reconstruction);
326 }
327 }
328
329 fn reconstruct_into<I, S>(
330 &self,
331 quantized: ArrayBase<S, Ix1>,
332 mut reconstruction: ArrayViewMut1<A>,
333 ) where
334 I: AsPrimitive<usize>,
335 S: Data<Elem = I>,
336 {
337 primitives::reconstruct_into(self.quantizers.view(), quantized, reconstruction.view_mut());
338
339 if let Some(ref projection) = self.projection {
340 let projected_reconstruction = reconstruction.dot(&projection.t());
341 reconstruction.assign(&projected_reconstruction);
342 }
343 }
344
345 fn reconstructed_len(&self) -> usize {
346 primitives::reconstructed_len(self.quantizers.view())
347 }
348}
349
#[cfg(test)]
mod tests {
    use ndarray::{array, Array1, Array2, Array3, ArrayView2};
    use rand::distributions::Uniform;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use super::Pq;
    use crate::linalg::EuclideanDistance;
    use crate::ndarray_rand::RandomExt;
    use crate::pq::{QuantizeVector, Reconstruct, TrainPq};

    /// Mean Euclidean distance between each row of `instances` and its
    /// reconstruction after a round trip through `quantizer` with `u8` codes.
    fn avg_euclidean_loss(instances: ArrayView2<f32>, quantizer: &Pq<f32>) -> f32 {
        let codes: Array2<u8> = quantizer.quantize_batch(instances);
        let reconstructions = quantizer.reconstruct_batch(codes);

        let total_distance = instances
            .outer_iter()
            .zip(reconstructions.outer_iter())
            .map(|(instance, reconstruction)| instance.euclidean_distance(reconstruction))
            .fold(0f32, |sum, distance| sum + distance);

        total_distance / instances.nrows() as f32
    }

    fn test_vectors() -> Array2<f32> {
        array![
            [0., 2., 0., -0.5, 0., 0.],
            [1., -0.2, 0., 0.5, 0.5, 0.],
            [-0.2, 0.2, 0., 0., -2., 0.],
            [1., 0.2, 0., 0., -2., 0.],
        ]
    }

    /// Expected codes for `test_vectors` under `test_pq`.
    fn test_quantizations() -> Array2<usize> {
        array![[1, 1], [0, 1], [1, 0], [0, 0]]
    }

    /// Expected reconstructions of `test_quantizations` under `test_pq`.
    fn test_reconstructions() -> Array2<f32> {
        array![
            [0., 1., 0., 0., 1., 0.],
            [1., 0., 0., 0., 1., 0.],
            [0., 1., 0., 1., -1., 0.],
            [1., 0., 0., 1., -1., 0.]
        ]
    }

    /// Two subquantizers with two 3-dimensional centroids each, no projection.
    fn test_pq() -> Pq<f32> {
        Pq {
            projection: None,
            quantizers: array![[[1., 0., 0.], [0., 1., 0.]], [[1., -1., 0.], [0., 1., 0.]],],
        }
    }

    #[test]
    fn quantize_batch_with_predefined_codebook() {
        let codes = test_pq().quantize_batch::<usize, _>(test_vectors());
        assert_eq!(codes, test_quantizations());
    }

    #[test]
    fn quantize_with_predefined_codebook() {
        let pq = test_pq();
        let vectors = test_vectors();
        let expected = test_quantizations();

        for (vector, quantization) in vectors.outer_iter().zip(expected.outer_iter()) {
            assert_eq!(pq.quantize_vector::<usize, _>(vector), quantization);
        }
    }

    #[test]
    fn quantize_with_pq() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0f32, 1f32);
        let instances = Array2::random_using((256, 20), uniform, &mut rng);

        // 10 subquantizers with 2^7 centroids, 10 iterations, 1 attempt.
        let pq = Pq::train_pq_using(10, 7, 10, 1, instances.view(), &mut rng).unwrap();

        assert!(avg_euclidean_loss(instances.view(), &pq) < 0.08);
    }

    #[test]
    fn quantize_with_type() {
        let uniform = Uniform::new(0f32, 1f32);
        // 256 centroids: every code fits in a `u8`.
        let pq = Pq {
            projection: None,
            quantizers: Array3::random((1, 256, 10), uniform),
        };
        let vector = Array1::random((10,), uniform);
        pq.quantize_vector::<u8, _>(vector);
    }

    #[test]
    #[should_panic]
    fn quantize_with_too_narrow_type() {
        let uniform = Uniform::new(0f32, 1f32);
        // 257 centroids: one more than a `u8` can index.
        let pq = Pq {
            projection: None,
            quantizers: Array3::random((1, 257, 10), uniform),
        };
        let vector = Array1::random((10,), uniform);
        pq.quantize_vector::<u8, _>(vector);
    }

    #[test]
    fn quantizer_lens() {
        let pq = test_pq();

        assert_eq!(pq.quantized_len(), 2);
        assert_eq!(pq.reconstructed_len(), 6);
    }

    #[test]
    fn reconstruct_batch_with_predefined_codebook() {
        let reconstructions = test_pq().reconstruct_batch(test_quantizations());
        assert_eq!(reconstructions, test_reconstructions());
    }

    #[test]
    fn reconstruct_with_predefined_codebook() {
        let pq = test_pq();
        let codes = test_quantizations();
        let expected = test_reconstructions();

        for (quantization, reconstruction) in codes.outer_iter().zip(expected.outer_iter()) {
            assert_eq!(pq.reconstruct(quantization), reconstruction);
        }
    }
}
491}