Skip to main content

diskann_quantization/product/
train.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_utils::{
7    strided::StridedView,
8    views::{self, Matrix},
9};
10#[cfg(feature = "rayon")]
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12use thiserror::Error;
13
14use crate::{
15    Parallelism,
16    algorithms::kmeans::{
17        self,
18        common::{BlockTranspose, square_norm},
19    },
20    cancel::Cancelation,
21    random::{BoxedRngBuilder, RngBuilder},
22};
23
/// Parameters controlling light-weight product quantization (PQ) training.
///
/// Training clusters each chunk of the input dimensions independently: centers are
/// seeded with kmeans++ and then refined with Lloyd's algorithm.
#[derive(Debug, Clone)]
pub struct LightPQTrainingParameters {
    /// The number of centers for each partition.
    ncenters: usize,
    /// The maximum number of iterations for Lloyd's algorithm.
    lloyds_reps: usize,
}
30
31impl LightPQTrainingParameters {
32    /// Construct a new light-weight PQ trainer.
33    pub fn new(ncenters: usize, lloyds_reps: usize) -> Self {
34        Self {
35            ncenters,
36            lloyds_reps,
37        }
38    }
39}
40
41#[derive(Debug)]
42pub struct SimplePivots {
43    dim: usize,
44    ncenters: usize,
45    pivots: Vec<Matrix<f32>>,
46}
47
48fn flatten<T: Copy + Default>(pivots: &[Matrix<T>], ncenters: usize, dim: usize) -> Matrix<T> {
49    let mut flattened = Matrix::new(T::default(), ncenters, dim);
50    let mut col_start = 0;
51    for matrix in pivots {
52        assert_eq!(matrix.nrows(), flattened.nrows());
53        for (row_index, row) in matrix.row_iter().enumerate() {
54            let dst = &mut flattened.row_mut(row_index)[col_start..col_start + row.len()];
55            dst.copy_from_slice(row);
56        }
57        col_start += matrix.ncols();
58    }
59    flattened
60}
61
62impl SimplePivots {
63    /// Return the selected pivots for each chunk.
64    pub fn pivots(&self) -> &[Matrix<f32>] {
65        &self.pivots
66    }
67
68    /// Concatenate the individual pivots into a dense representation.
69    pub fn flatten(&self) -> Vec<f32> {
70        flatten(self.pivots(), self.ncenters, self.dim)
71            .into_inner()
72            .into()
73    }
74}
75
76pub trait TrainQuantizer {
77    type Quantizer;
78    type Error: std::error::Error;
79
80    fn train<R, C>(
81        &self,
82        data: views::MatrixView<f32>,
83        schema: crate::views::ChunkOffsetsView<'_>,
84        parallelism: Parallelism,
85        rng_builder: &R,
86        cancelation: &C,
87    ) -> Result<Self::Quantizer, Self::Error>
88    where
89        R: RngBuilder<usize> + Sync,
90        C: Cancelation + Sync;
91}
92
93impl TrainQuantizer for LightPQTrainingParameters {
94    type Quantizer = SimplePivots;
95    type Error = PQTrainingError;
96
97    /// Perform product quantization training on the provided training set and return a
98    /// `SimplePivots` containing the result of kmeans clustering on each partition.
99    ///
100    /// # Panics
101    ///
102    /// Panics if `data.nrows() != schema.dim()`.
103    ///
104    /// # Errors
105    ///
106    /// An error type is returned under the following circumstances:
107    ///
108    /// * A cancellation request is received. This case can be queried by calling
109    ///   `was_canceled` on the returned `PQTrainingError`.
110    /// * `NaN` or infinities are observed during the training process.
111    fn train<R, C>(
112        &self,
113        data: views::MatrixView<f32>,
114        schema: crate::views::ChunkOffsetsView<'_>,
115        parallelism: Parallelism,
116        rng_builder: &R,
117        cancelation: &C,
118    ) -> Result<Self::Quantizer, Self::Error>
119    where
120        R: RngBuilder<usize> + Sync,
121        C: Cancelation + Sync,
122    {
123        // Inner method where we `dyn` away the cancellation token to reduce compile-times.
124        // Unfortunately, we can't quite do the same with the RngBuilder.
125        #[inline(never)]
126        fn train(
127            trainer: &LightPQTrainingParameters,
128            data: views::MatrixView<f32>,
129            schema: crate::views::ChunkOffsetsView<'_>,
130            parallelism: Parallelism,
131            rng_builder: &(dyn BoxedRngBuilder<usize> + Sync),
132            cancelation: &(dyn Cancelation + Sync),
133        ) -> Result<SimplePivots, PQTrainingError> {
134            // Make sure we're provided sane values for our schema.
135            assert_eq!(data.ncols(), schema.dim());
136
137            let thunk = |i| -> Result<Matrix<f32>, PQTrainingError> {
138                let range = schema.at(i);
139
140                // Check for cancelation.
141                let exit_if_canceled = || -> Result<(), PQTrainingError> {
142                    if cancelation.should_cancel() {
143                        Err(PQTrainingError {
144                            chunk: i,
145                            of: schema.len(),
146                            dim: range.len(),
147                            kind: PQTrainingErrorKind::Canceled,
148                        })
149                    } else {
150                        Ok(())
151                    }
152                };
153
154                // This is an early check - if another task hit cancelation, this allows
155                // the remaining tasks to exit early.
156                exit_if_canceled()?;
157
158                let view = StridedView::try_shrink_from(
159                    &(data.as_slice()[range.start..]),
160                    data.nrows(),
161                    range.len(),
162                    schema.dim(),
163                )
164                .map_err(|err| PQTrainingError {
165                    chunk: i,
166                    of: schema.len(),
167                    dim: range.len(),
168                    kind: PQTrainingErrorKind::InternalError(Box::new(err.as_static())),
169                })?;
170
171                // Allocate scratch data structures.
172                let norms: Vec<f32> = view.row_iter().map(square_norm).collect();
173                let transpose = BlockTranspose::<16>::from_strided(view);
174                let mut centers = Matrix::new(0.0, trainer.ncenters, range.len());
175
176                // Construct the random number generator seeded by the PQ chunk.
177                let mut rng = rng_builder.build_boxed_rng(i);
178
179                // Initialization
180                kmeans::plusplus::kmeans_plusplus_into_inner(
181                    centers.as_mut_view(),
182                    view,
183                    &transpose,
184                    &norms,
185                    &mut rng,
186                )
187                .or_else(|err| {
188                    // Suppress recoverable errors.
189                    if !err.is_numerically_recoverable() {
190                        Err(PQTrainingError {
191                            chunk: i,
192                            of: schema.len(),
193                            dim: range.len(),
194                            kind: PQTrainingErrorKind::Initialization(Box::new(err)),
195                        })
196                    } else {
197                        Ok(())
198                    }
199                })?;
200
201                // Did a cancelation request come while runing `kmeans++`?
202                exit_if_canceled()?;
203
204                // Kmeans
205                kmeans::lloyds::lloyds_inner(
206                    view,
207                    &norms,
208                    &transpose,
209                    centers.as_mut_view(),
210                    trainer.lloyds_reps,
211                );
212                Ok(centers)
213            };
214
215            let pivots: Result<Vec<_>, _> = match parallelism {
216                Parallelism::Sequential => (0..schema.len()).map(thunk).collect(),
217
218                #[cfg(feature = "rayon")]
219                Parallelism::Rayon => (0..schema.len()).into_par_iter().map(thunk).collect(),
220            };
221
222            let dim = data.ncols();
223            let ncenters = trainer.ncenters;
224            Ok(SimplePivots {
225                dim,
226                ncenters,
227                pivots: pivots?,
228            })
229        }
230
231        train(self, data, schema, parallelism, rng_builder, cancelation)
232    }
233}
234
235#[derive(Debug, Error)]
236#[error("pq training failed on chunk {chunk} of {of} (dim {dim})")]
237pub struct PQTrainingError {
238    chunk: usize,
239    of: usize,
240    dim: usize,
241    #[source]
242    kind: PQTrainingErrorKind,
243}
244
245impl PQTrainingError {
246    /// Return whether or not this error originated as a cancelation request.
247    pub fn was_canceled(&self) -> bool {
248        matches!(self.kind, PQTrainingErrorKind::Canceled)
249    }
250}
251
252#[derive(Debug, Error)]
253#[non_exhaustive]
254enum PQTrainingErrorKind {
255    #[error("canceled by request")]
256    Canceled,
257    #[error("initial pivot selection error")]
258    Initialization(#[source] Box<dyn std::error::Error + Send + Sync>),
259    #[error("internal logic error")]
260    InternalError(#[source] Box<dyn std::error::Error + Send + Sync>),
261}
262
263///////////
264// Tests //
265///////////
266
#[cfg(not(miri))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        seq::SliceRandom,
    };

    use diskann_utils::lazy_format;

    use super::*;
    use crate::{cancel::DontCancel, error::format, random::StdRngBuilder};

    // With this test - we create sub-matrices that when flattened, will yield the output
    // sequence `0, 1, 2, 3, 4, ...`.
    #[test]
    fn test_flatten() {
        // The number of rows in the final matrix.
        let nrows = 5;
        // The dimensions in each sub-matrix.
        let sub_dims = [1, 2, 3, 4, 5];
        // The prefix sum of the sub dimensions.
        let prefix_sum: Vec<usize> = sub_dims
            .iter()
            .scan(0, |state, i| {
                let this = *state;
                *state += *i;
                Some(this)
            })
            .collect();

        let dim: usize = sub_dims.iter().sum();

        // Create the sub matrices.
        let matrices: Vec<Matrix<usize>> = std::iter::zip(sub_dims.iter(), prefix_sum.iter())
            .map(|(&this_dim, &offset)| {
                let mut m = Matrix::new(0, nrows, this_dim);
                for r in 0..nrows {
                    for c in 0..this_dim {
                        m[(r, c)] = dim * r + offset + c;
                    }
                }
                m
            })
            .collect();

        let flattened = flatten(&matrices, nrows, dim);
        // Check that the output is correct.
        for (i, v) in flattened.as_slice().iter().enumerate() {
            assert_eq!(*v, i, "failed at index {i}");
        }
    }

    // Configuration for a synthetic dataset made of well-separated clusters.
    struct DatasetBuilder {
        // The number of clusters to generate in every chunk.
        nclusters: usize,
        // The number of points generated for each cluster.
        cluster_size: usize,
        // The spacing between the means of consecutive clusters.
        step_between_clusters: f32,
    }

    struct ClusteredDataset {
        // The generated (shuffled) data points, one per row.
        data: Matrix<f32>,
        // The pre-configured center point for the manufactured clusters.
        centers: Matrix<f32>,
    }

    impl DatasetBuilder {
        // Generate a dataset whose columns are partitioned according to `schema`.
        //
        // Within a chunk, every coordinate of a point is its cluster's scalar mean plus
        // standard normal noise. `centers[(cluster, chunk)]` records that scalar mean.
        fn build<R>(
            &self,
            schema: crate::views::ChunkOffsetsView<'_>,
            rng: &mut R,
        ) -> ClusteredDataset
        where
            R: Rng,
        {
            let ndata = self.nclusters * self.cluster_size;

            // Start the clustering points at a different location for each chunk.
            // The starting offset is chosen from this distribution.
            let offsets_distribution = Uniform::<f32>::new(-100.0, 100.0).unwrap();

            // The perturbation for vectors within a cluster - all centered around some
            // mean.
            let perturbation_distribution = rand_distr::StandardNormal;

            // Indices that we use to shuffle the order of elements in the dataset.
            let mut indices: Vec<usize> = (0..ndata).collect();

            // Construct the dataset in pieces.
            let (pieces, centers): (Vec<_>, Vec<_>) = (0..schema.len())
                .map(|chunk| {
                    let dim = schema.at(chunk).len();

                    let mut initial = Matrix::new(0.0, ndata, dim);
                    let mut centers = Matrix::new(0.0, self.nclusters, 1);

                    // The starting offset for clusters.
                    let offset = offsets_distribution.sample(rng);

                    // Create a dataset with `nclusters` clusters, each holding
                    // `cluster_size` points.
                    for cluster in 0..self.nclusters {
                        let this_offset = offset + (cluster as f32 * self.step_between_clusters);
                        centers[(cluster, 0)] = this_offset;

                        for element in 0..self.cluster_size {
                            let row = initial.row_mut(cluster * self.cluster_size + element);
                            for r in row.iter_mut() {
                                let perturbation: f32 = perturbation_distribution.sample(rng);
                                *r = this_offset + perturbation;
                            }
                        }
                    }

                    // Shuffle the dataset.
                    indices.shuffle(rng);
                    let mut piece = Matrix::new(0.0, ndata, dim);
                    for (dst, src) in indices.iter().enumerate() {
                        piece.row_mut(dst).copy_from_slice(initial.row(*src));
                    }
                    (piece, centers)
                })
                .unzip();

            ClusteredDataset {
                data: flatten(&pieces, ndata, schema.dim()),
                centers: flatten(&centers, self.nclusters, schema.len()),
            }
        }
    }

    // Squared Euclidean distance between `x` and the vector whose entries all equal `y`.
    fn broadcast_distance(x: &[f32], y: f32) -> f32 {
        x.iter()
            .map(|i| {
                let d = *i - y;
                d * d
            })
            .sum()
    }

    // Happy Path check - varying over parallelism.
    fn test_pq_training_happy_path(parallelism: Parallelism) {
        let mut rng = StdRng::seed_from_u64(0x749cb951cf960384);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            // NOTE: We need to keep the step between clusters fairly large to ensure that
            // kmeans++ adequately initializes.
            step_between_clusters: 20.0,
        };

        let ncenters = builder.nclusters;

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(ncenters, 6);

        let quantizer = trainer
            .train(
                dataset.data.as_view(),
                schema,
                parallelism,
                &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                &DontCancel,
            )
            .unwrap();

        // Now that we have trained the quantizer - we need to double check that the chosen
        // centroids match what we expect.
        //
        // To do this - we loop through the centroids that training picked. We match the
        // centroids with one of the known centers in our clustering.
        //
        // We perform two main checks:
        //
        // 1. We ensure that the quantizer's center actually aligns with a cluster (i.e.,
        //    training did not invent values out of thin air).
        // 2. Every clustering in the original dataset has a representative in the quantizer.
        assert_eq!(quantizer.dim, schema.dim());
        assert_eq!(quantizer.ncenters, ncenters);
        assert_eq!(quantizer.pivots.len(), schema.len());
        for (i, pivot) in quantizer.pivots.iter().enumerate() {
            // Make sure the pivot has the correct dimension.
            assert_eq!(
                pivot.ncols(),
                schema.at(i).len(),
                "center {i} has the incorrect number of columns"
            );
            assert_eq!(pivot.nrows(), ncenters);

            // Start matching pivots to expected centers.
            let mut seen = vec![false; dataset.centers.nrows()];
            for row in pivot.row_iter() {
                let mut min_distance = f32::MAX;
                let mut min_index = 0;
                for c in 0..dataset.centers.nrows() {
                    let distance = broadcast_distance(row, dataset.centers[(c, i)]);
                    if distance < min_distance {
                        min_distance = distance;
                        min_index = c;
                    }
                }

                // Does the minimum distance suggest that we are inside a cluster.
                assert!(
                    min_distance < 1.0,
                    "got a minimum distance of {}, pivot = {}. Row = {:?}",
                    min_distance,
                    i,
                    row,
                );

                // Mark this index as seen.
                let seen_before = &mut seen[min_index];
                assert!(
                    !*seen_before,
                    "cluster {} has more than one assignment",
                    min_index
                );
                *seen_before = true;
            }

            // Make sure that all clusters were seen.
            assert!(seen.iter().all(|i| *i), "not all clusters were seen");
        }

        // Check `flatten`.
        let flattened = quantizer.flatten();
        assert_eq!(
            &flattened,
            flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim).as_slice()
        );
    }

    #[test]
    fn test_pq_training_happy_path_sequential() {
        test_pq_training_happy_path(Parallelism::Sequential);
    }

    #[test]
    #[cfg(feature = "rayon")]
    fn test_pq_training_happy_path_parallel() {
        test_pq_training_happy_path(Parallelism::Rayon);
    }

    // A canceler that cancels after a set number of invocations.
    struct CancelAfter {
        // The number of times `should_cancel` has been called so far.
        counter: AtomicUsize,
        // The number of calls that are answered with "don't cancel".
        after: usize,
    }

    impl CancelAfter {
        fn new(after: usize) -> Self {
            Self {
                counter: AtomicUsize::new(0),
                after,
            }
        }
    }

    impl Cancelation for CancelAfter {
        fn should_cancel(&self) -> bool {
            let v = self.counter.fetch_add(1, Ordering::Relaxed);
            v >= self.after
        }
    }

    #[test]
    fn test_cancel() {
        let mut rng = StdRng::seed_from_u64(0xb85352d38cc5353b);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            // NOTE: We need to keep the step between clusters fairly large to ensure that
            // kmeans++ adequately initializes.
            step_between_clusters: 20.0,
        };

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(builder.nclusters, 6);

        // Training polls the token twice per chunk and there are five chunks, so every
        // `after` below ten is reached before training can finish.
        for after in 0..10 {
            let parallelism = [
                Parallelism::Sequential,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon,
            ];

            for par in parallelism {
                let result = trainer.train(
                    dataset.data.as_view(),
                    schema,
                    par,
                    &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                    &CancelAfter::new(after),
                );
                let err = result.expect_err("expected the operation to be canceled");
                assert!(
                    err.was_canceled(),
                    "expected the failure reason to be cancellation"
                );
            }
        }
    }

    // In this test - we ensure that clustering succeeds even if the number of requested
    // pivots exceeds the number of dataset items.
    #[test]
    fn tests_succeeded_with_too_many_pivots() {
        let data = Matrix::<f32>::new(1.0, 10, 5);
        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(2 * data.nrows(), 6);

        let quantizer = trainer
            .train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            )
            .unwrap();

        // We are in the special position to actually know how this will behave.
        // Since the input dataset lacks diversity, there should only have been a single
        // pivot actually selected.
        //
        // All the rest should be zero.
        let flat = flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim);

        assert!(
            flat.row(0).iter().all(|i| *i == 1.0),
            "expected pivot 0 to be the non-zero pivot"
        );

        // Skip the first row - it was checked above.
        for (i, row) in flat.row_iter().enumerate().skip(1) {
            assert!(
                row.iter().all(|j| *j == 0.0),
                "expected pivot {i} to be all zeros"
            );
        }
    }

    #[test]
    fn test_infinity_and_nan_is_not_recoverable() {
        let num_trials = 10;
        let nrows = 10;
        let ncols = 5;

        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(nrows, 6);

        let row_distribution = Uniform::new(0, nrows).unwrap();
        let col_distribution = Uniform::new(0, ncols).unwrap();
        let mut rng = StdRng::seed_from_u64(0xe746cfebba2d7e35);

        for trial in 0..num_trials {
            let context = lazy_format!("trial {} of {}", trial + 1, num_trials);

            // The position of the single entry that gets poisoned in this trial.
            let r = row_distribution.sample(&mut rng);
            let c = col_distribution.sample(&mut rng);

            let check_result = |outcome: Result<_, PQTrainingError>| {
                assert!(
                    outcome.is_err(),
                    "expected error due to infinities/NaN -- {}",
                    context
                );
                let err = outcome.unwrap_err();
                assert!(!err.was_canceled());
                assert!(format(&err).contains("infinity"));
            };

            let mut data = Matrix::<f32>::new(1.0, nrows, ncols);

            // Poison the chosen entry with each non-finite value in turn: positive
            // infinity, negative infinity, and NaN.
            for poison in [f32::INFINITY, f32::NEG_INFINITY, f32::NAN] {
                data[(r, c)] = poison;
                let result = trainer.train(
                    data.as_view(),
                    crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                    Parallelism::Sequential,
                    &StdRngBuilder::new(0),
                    &DontCancel,
                );
                check_result(result);
            }
        }
    }
}