Skip to main content

diskann_quantization/product/
train.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_utils::{
7    strided::StridedView,
8    views::{self, Matrix},
9};
10#[cfg(feature = "rayon")]
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12use thiserror::Error;
13
14use crate::{
15    Parallelism,
16    algorithms::kmeans::{self, common::square_norm},
17    cancel::Cancelation,
18    multi_vector::BlockTransposed,
19    random::{BoxedRngBuilder, RngBuilder},
20};
21
/// Parameters for light-weight product-quantization (PQ) training.
///
/// Training (see [`TrainQuantizer::train`]) clusters each chunk of the PQ schema
/// independently: centers are seeded with k-means++ and then refined with Lloyd's
/// algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LightPQTrainingParameters {
    /// The number of centers for each partition.
    ncenters: usize,
    /// The maximum number of iterations for Lloyd's algorithm.
    lloyds_reps: usize,
}

impl LightPQTrainingParameters {
    /// Construct a new light-weight PQ trainer.
    ///
    /// * `ncenters`: The number of centers (pivots) to compute for each PQ chunk.
    /// * `lloyds_reps`: The maximum number of Lloyd's iterations to run for each chunk.
    pub fn new(ncenters: usize, lloyds_reps: usize) -> Self {
        Self {
            ncenters,
            lloyds_reps,
        }
    }
}
38
39#[derive(Debug)]
40pub struct SimplePivots {
41    dim: usize,
42    ncenters: usize,
43    pivots: Vec<Matrix<f32>>,
44}
45
46fn flatten<T: Copy + Default>(pivots: &[Matrix<T>], ncenters: usize, dim: usize) -> Matrix<T> {
47    let mut flattened = Matrix::new(T::default(), ncenters, dim);
48    let mut col_start = 0;
49    for matrix in pivots {
50        assert_eq!(matrix.nrows(), flattened.nrows());
51        for (row_index, row) in matrix.row_iter().enumerate() {
52            let dst = &mut flattened.row_mut(row_index)[col_start..col_start + row.len()];
53            dst.copy_from_slice(row);
54        }
55        col_start += matrix.ncols();
56    }
57    flattened
58}
59
60impl SimplePivots {
61    /// Return the selected pivots for each chunk.
62    pub fn pivots(&self) -> &[Matrix<f32>] {
63        &self.pivots
64    }
65
66    /// Concatenate the individual pivots into a dense representation.
67    pub fn flatten(&self) -> Vec<f32> {
68        flatten(self.pivots(), self.ncenters, self.dim)
69            .into_inner()
70            .into()
71    }
72}
73
74pub trait TrainQuantizer {
75    type Quantizer;
76    type Error: std::error::Error;
77
78    fn train<R, C>(
79        &self,
80        data: views::MatrixView<f32>,
81        schema: crate::views::ChunkOffsetsView<'_>,
82        parallelism: Parallelism,
83        rng_builder: &R,
84        cancelation: &C,
85    ) -> Result<Self::Quantizer, Self::Error>
86    where
87        R: RngBuilder<usize> + Sync,
88        C: Cancelation + Sync;
89}
90
91impl TrainQuantizer for LightPQTrainingParameters {
92    type Quantizer = SimplePivots;
93    type Error = PQTrainingError;
94
95    /// Perform product quantization training on the provided training set and return a
96    /// `SimplePivots` containing the result of kmeans clustering on each partition.
97    ///
98    /// # Panics
99    ///
100    /// Panics if `data.nrows() != schema.dim()`.
101    ///
102    /// # Errors
103    ///
104    /// An error type is returned under the following circumstances:
105    ///
106    /// * A cancellation request is received. This case can be queried by calling
107    ///   `was_canceled` on the returned `PQTrainingError`.
108    /// * `NaN` or infinities are observed during the training process.
109    fn train<R, C>(
110        &self,
111        data: views::MatrixView<f32>,
112        schema: crate::views::ChunkOffsetsView<'_>,
113        parallelism: Parallelism,
114        rng_builder: &R,
115        cancelation: &C,
116    ) -> Result<Self::Quantizer, Self::Error>
117    where
118        R: RngBuilder<usize> + Sync,
119        C: Cancelation + Sync,
120    {
121        // Inner method where we `dyn` away the cancellation token to reduce compile-times.
122        // Unfortunately, we can't quite do the same with the RngBuilder.
123        #[inline(never)]
124        fn train(
125            trainer: &LightPQTrainingParameters,
126            data: views::MatrixView<f32>,
127            schema: crate::views::ChunkOffsetsView<'_>,
128            parallelism: Parallelism,
129            rng_builder: &(dyn BoxedRngBuilder<usize> + Sync),
130            cancelation: &(dyn Cancelation + Sync),
131        ) -> Result<SimplePivots, PQTrainingError> {
132            // Make sure we're provided sane values for our schema.
133            assert_eq!(data.ncols(), schema.dim());
134
135            let thunk = |i| -> Result<Matrix<f32>, PQTrainingError> {
136                let range = schema.at(i);
137
138                // Check for cancelation.
139                let exit_if_canceled = || -> Result<(), PQTrainingError> {
140                    if cancelation.should_cancel() {
141                        Err(PQTrainingError {
142                            chunk: i,
143                            of: schema.len(),
144                            dim: range.len(),
145                            kind: PQTrainingErrorKind::Canceled,
146                        })
147                    } else {
148                        Ok(())
149                    }
150                };
151
152                // This is an early check - if another task hit cancelation, this allows
153                // the remaining tasks to exit early.
154                exit_if_canceled()?;
155
156                let view = StridedView::try_shrink_from(
157                    &(data.as_slice()[range.start..]),
158                    data.nrows(),
159                    range.len(),
160                    schema.dim(),
161                )
162                .map_err(|err| PQTrainingError {
163                    chunk: i,
164                    of: schema.len(),
165                    dim: range.len(),
166                    kind: PQTrainingErrorKind::InternalError(Box::new(err.as_static())),
167                })?;
168
169                // Allocate scratch data structures.
170                let norms: Vec<f32> = view.row_iter().map(square_norm).collect();
171                let transpose = BlockTransposed::<f32, 16>::from_strided(view);
172                let mut centers = Matrix::new(0.0, trainer.ncenters, range.len());
173
174                // Construct the random number generator seeded by the PQ chunk.
175                let mut rng = rng_builder.build_boxed_rng(i);
176
177                // Initialization
178                kmeans::plusplus::kmeans_plusplus_into_inner(
179                    centers.as_mut_view(),
180                    view,
181                    transpose.as_view(),
182                    &norms,
183                    &mut rng,
184                )
185                .or_else(|err| {
186                    // Suppress recoverable errors.
187                    if !err.is_numerically_recoverable() {
188                        Err(PQTrainingError {
189                            chunk: i,
190                            of: schema.len(),
191                            dim: range.len(),
192                            kind: PQTrainingErrorKind::Initialization(Box::new(err)),
193                        })
194                    } else {
195                        Ok(())
196                    }
197                })?;
198
199                // Did a cancelation request come while runing `kmeans++`?
200                exit_if_canceled()?;
201
202                // Kmeans
203                kmeans::lloyds::lloyds_inner(
204                    view,
205                    &norms,
206                    transpose.as_view(),
207                    centers.as_mut_view(),
208                    trainer.lloyds_reps,
209                );
210                Ok(centers)
211            };
212
213            let pivots: Result<Vec<_>, _> = match parallelism {
214                Parallelism::Sequential => (0..schema.len()).map(thunk).collect(),
215
216                #[cfg(feature = "rayon")]
217                Parallelism::Rayon => (0..schema.len()).into_par_iter().map(thunk).collect(),
218            };
219
220            let dim = data.ncols();
221            let ncenters = trainer.ncenters;
222            Ok(SimplePivots {
223                dim,
224                ncenters,
225                pivots: pivots?,
226            })
227        }
228
229        train(self, data, schema, parallelism, rng_builder, cancelation)
230    }
231}
232
233#[derive(Debug, Error)]
234#[error("pq training failed on chunk {chunk} of {of} (dim {dim})")]
235pub struct PQTrainingError {
236    chunk: usize,
237    of: usize,
238    dim: usize,
239    #[source]
240    kind: PQTrainingErrorKind,
241}
242
243impl PQTrainingError {
244    /// Return whether or not this error originated as a cancelation request.
245    pub fn was_canceled(&self) -> bool {
246        matches!(self.kind, PQTrainingErrorKind::Canceled)
247    }
248}
249
250#[derive(Debug, Error)]
251#[non_exhaustive]
252enum PQTrainingErrorKind {
253    #[error("canceled by request")]
254    Canceled,
255    #[error("initial pivot selection error")]
256    Initialization(#[source] Box<dyn std::error::Error + Send + Sync>),
257    #[error("internal logic error")]
258    InternalError(#[source] Box<dyn std::error::Error + Send + Sync>),
259}
260
261///////////
262// Tests //
263///////////
264
#[cfg(not(miri))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        seq::SliceRandom,
    };

    use diskann_utils::lazy_format;

    use super::*;
    use crate::{cancel::DontCancel, error::format, random::StdRngBuilder};

    // With this test - we create sub-matrices that when flattened, will yield the output
    // sequence `0, 1, 2, 3, 4, ...`.
    #[test]
    fn test_flatten() {
        // The number of rows in the final matrix.
        let nrows = 5;
        // The dimensions in each sub-matrix.
        let sub_dims = [1, 2, 3, 4, 5];
        // The prefix sum of the sub dimensions.
        let prefix_sum: Vec<usize> = sub_dims
            .iter()
            .scan(0, |state, i| {
                let this = *state;
                *state += *i;
                Some(this)
            })
            .collect();

        let dim: usize = sub_dims.iter().sum();

        // Create the sub matrices.
        let matrices: Vec<Matrix<usize>> = std::iter::zip(sub_dims.iter(), prefix_sum.iter())
            .map(|(&this_dim, &offset)| {
                let mut m = Matrix::new(0, nrows, this_dim);
                for r in 0..nrows {
                    for c in 0..this_dim {
                        m[(r, c)] = dim * r + offset + c;
                    }
                }
                m
            })
            .collect();

        let flattened = flatten(&matrices, nrows, dim);
        // Check that the output is correct.
        for (i, v) in flattened.as_slice().iter().enumerate() {
            assert_eq!(*v, i, "failed at index {i}");
        }
    }

    /// Configuration for a synthetic dataset with well-separated clusters in every chunk.
    struct DatasetBuilder {
        /// The number of clusters generated in each chunk.
        nclusters: usize,
        /// The number of points in each cluster.
        cluster_size: usize,
        /// The offset between adjacent cluster centers, applied to every dimension.
        step_between_clusters: f32,
    }

    /// A synthetic dataset together with the centers used to generate it.
    struct ClusteredDataset {
        data: Matrix<f32>,
        // The pre-configured center point for the manufactured clusters: entry
        // `(cluster, chunk)` is the value every dimension of `cluster` is centered on
        // within `chunk`.
        centers: Matrix<f32>,
    }

    impl DatasetBuilder {
        fn build<R>(
            &self,
            schema: crate::views::ChunkOffsetsView<'_>,
            rng: &mut R,
        ) -> ClusteredDataset
        where
            R: Rng,
        {
            let ndata = self.nclusters * self.cluster_size;

            // Start the clustering points at a different location for each chunk.
            // The starting offset is chosen from this distribution.
            let offsets_distribution = Uniform::<f32>::new(-100.0, 100.0).unwrap();

            // The perturbation for vectors within a cluster - all centered around some
            // mean.
            let perturbation_distribution = rand_distr::StandardNormal;

            // Indices that we use to shuffle the order of elements in the dataset.
            let mut indices: Vec<usize> = (0..ndata).collect();

            // Construct the dataset in pieces.
            let (pieces, centers): (Vec<_>, Vec<_>) = (0..schema.len())
                .map(|chunk| {
                    let dim = schema.at(chunk).len();

                    let mut initial = Matrix::new(0.0, ndata, dim);
                    let mut centers = Matrix::new(0.0, self.nclusters, 1);

                    // The starting offset for clusters.
                    let offset = offsets_distribution.sample(rng);

                    // Create `nclusters` clusters of `cluster_size` points each. Cluster
                    // centers are spaced `step_between_clusters` apart, starting at
                    // `offset`.
                    for cluster in 0..self.nclusters {
                        let this_offset = offset + (cluster as f32 * self.step_between_clusters);
                        centers[(cluster, 0)] = this_offset;

                        for element in 0..self.cluster_size {
                            let row = initial.row_mut(cluster * self.cluster_size + element);
                            for r in row.iter_mut() {
                                let perturbation: f32 = perturbation_distribution.sample(rng);
                                *r = this_offset + perturbation;
                            }
                        }
                    }

                    // Shuffle the dataset.
                    indices.shuffle(rng);
                    let mut piece = Matrix::new(0.0, ndata, dim);
                    for (dst, src) in indices.iter().enumerate() {
                        piece.row_mut(dst).copy_from_slice(initial.row(*src));
                    }
                    (piece, centers)
                })
                .unzip();

            ClusteredDataset {
                data: flatten(&pieces, ndata, schema.dim()),
                centers: flatten(&centers, self.nclusters, schema.len()),
            }
        }
    }

    /// Squared Euclidean distance between `x` and the vector with every entry equal to `y`.
    fn broadcast_distance(x: &[f32], y: f32) -> f32 {
        x.iter()
            .map(|i| {
                let d = *i - y;
                d * d
            })
            .sum()
    }

    // Happy Path check - varying over parallelism.
    fn test_pq_training_happy_path(parallelism: Parallelism) {
        let mut rng = StdRng::seed_from_u64(0x749cb951cf960384);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            // NOTE: We need to keep the step between clusters fairly large to ensure that
            // kmeans++ adequately initializes.
            step_between_clusters: 20.0,
        };

        let ncenters = builder.nclusters;

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(ncenters, 6);

        let quantizer = trainer
            .train(
                dataset.data.as_view(),
                schema,
                parallelism,
                &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                &DontCancel,
            )
            .unwrap();

        // Now that we have trained the quantizer - we need to double check that the chosen
        // centroids match what we expect.
        //
        // To do this - we loop through the centroids that training picked. We match the
        // centroids with one of the known centers in our clustering.
        //
        // We perform two main checks:
        //
        // 1. We ensure that the quantizer's center actually aligns with a cluster (i.e.,
        //    training did not invent values out of thin air).
        // 2. Every clustering in the original dataset has a representative in the quantizer.
        assert_eq!(quantizer.dim, schema.dim());
        assert_eq!(quantizer.ncenters, ncenters);
        assert_eq!(quantizer.pivots.len(), schema.len());
        for (i, pivot) in quantizer.pivots.iter().enumerate() {
            // Make sure the pivot has the correct dimension.
            assert_eq!(
                pivot.ncols(),
                schema.at(i).len(),
                "center {i} has the incorrect number of columns"
            );
            assert_eq!(pivot.nrows(), ncenters);

            // Start matching pivots to expected centers.
            let mut seen = vec![false; dataset.centers.nrows()];
            for row in pivot.row_iter() {
                let mut min_distance = f32::MAX;
                let mut min_index = 0;
                for c in 0..dataset.centers.nrows() {
                    let distance = broadcast_distance(row, dataset.centers[(c, i)]);
                    if distance < min_distance {
                        min_distance = distance;
                        min_index = c;
                    }
                }

                // Does the minimum distance suggest that we are inside a cluster.
                assert!(
                    min_distance < 1.0,
                    "got a minimum distance of {}, pivot = {}. Row = {:?}",
                    min_distance,
                    i,
                    row,
                );

                // Mark this index as seen.
                let seen_before = &mut seen[min_index];
                assert!(
                    !*seen_before,
                    "cluster {} has more than one assignment",
                    min_index
                );
                *seen_before = true;
            }

            // Make sure that all clusters were seen.
            assert!(seen.iter().all(|i| *i), "not all clusters were seen");
        }

        // Check `flatten`.
        let flattened = quantizer.flatten();
        assert_eq!(
            &flattened,
            flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim).as_slice()
        );
    }

    #[test]
    fn test_pq_training_happy_path_sequential() {
        test_pq_training_happy_path(Parallelism::Sequential);
    }

    #[test]
    #[cfg(feature = "rayon")]
    fn test_pq_training_happy_path_parallel() {
        test_pq_training_happy_path(Parallelism::Rayon);
    }

    // A canceler that cancels after a set number of invocations.
    struct CancelAfter {
        counter: AtomicUsize,
        after: usize,
    }

    impl CancelAfter {
        fn new(after: usize) -> Self {
            Self {
                counter: AtomicUsize::new(0),
                after,
            }
        }
    }

    impl Cancelation for CancelAfter {
        fn should_cancel(&self) -> bool {
            let v = self.counter.fetch_add(1, Ordering::Relaxed);
            v >= self.after
        }
    }

    #[test]
    fn test_cancel() {
        let mut rng = StdRng::seed_from_u64(0xb85352d38cc5353b);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            // NOTE: We need to keep the step between clusters fairly large to ensure that
            // kmeans++ adequately initializes.
            step_between_clusters: 20.0,
        };

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(builder.nclusters, 6);

        // The schema has 5 chunks and training polls the token twice per chunk, so a
        // full run makes 10 polls: any `after < 10` is guaranteed to trigger cancelation.
        for after in 0..10 {
            let parallelism = [
                Parallelism::Sequential,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon,
            ];

            for par in parallelism {
                let result = trainer.train(
                    dataset.data.as_view(),
                    schema,
                    par,
                    &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                    &CancelAfter::new(after),
                );
                assert!(result.is_err(), "expected the operation to be canceled");
                let err = result.unwrap_err();
                assert!(
                    err.was_canceled(),
                    "expected the failure reason to be cancellation"
                );
            }
        }
    }

    // In this test - we ensure that clustering succeeds even if the number of requested
    // pivots exceeds the number of dataset items.
    #[test]
    fn tests_succeeded_with_too_many_pivots() {
        let data = Matrix::<f32>::new(1.0, 10, 5);
        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(2 * data.nrows(), 6);

        let quantizer = trainer
            .train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            )
            .unwrap();

        // We are in the special position to actually know how this will behave.
        // Since the input dataset lacks diversity, only a single pivot should have been
        // selected.
        //
        // All the rest should be zero.
        let flat = flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim);

        assert!(
            flat.row(0).iter().all(|i| *i == 1.0),
            "expected pivot 0 to be the non-zero pivot"
        );

        for (i, row) in flat.row_iter().enumerate() {
            // skip the first row.
            if i == 0 {
                continue;
            }

            assert!(
                row.iter().all(|j| *j == 0.0),
                "expected pivot {i} to be all zeros"
            );
        }
    }

    #[test]
    fn test_infinity_and_nan_is_not_recoverable() {
        let num_trials = 10;
        let nrows = 10;
        let ncols = 5;

        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(nrows, 6);

        let row_distribution = Uniform::new(0, nrows).unwrap();
        let col_distribution = Uniform::new(0, ncols).unwrap();
        let mut rng = StdRng::seed_from_u64(0xe746cfebba2d7e35);

        for trial in 0..num_trials {
            let context = lazy_format!("trial {} of {}", trial + 1, num_trials);

            let r = row_distribution.sample(&mut rng);
            let c = col_distribution.sample(&mut rng);

            let check_result = |r: Result<_, PQTrainingError>| {
                assert!(
                    r.is_err(),
                    "expected error due to infinities/NaN -- {}",
                    context
                );
                let err = r.unwrap_err();
                assert!(!err.was_canceled());
                assert!(format(&err).contains("infinity"));
            };

            let mut data = Matrix::<f32>::new(1.0, nrows, ncols);

            // Positive Infinity
            data[(r, c)] = f32::INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            // Negative Infinity
            data[(r, c)] = f32::NEG_INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            // NaN
            data[(r, c)] = f32::NAN;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);
        }
    }
}